use std::collections::HashMap;
use tree_sitter::Node;
/// Borrowed view over two Python import maps, bundled so they can be passed
/// around together.
///
/// The struct holds only shared references, so it is `Copy` and cheap to pass
/// by value.
#[derive(Debug, Clone, Copy)]
pub(super) struct PythonAliases<'a> {
    /// Local-name -> module map; presumably the output of
    /// `collect_python_from_imports` (TODO: confirm at the construction site).
    pub imports: &'a HashMap<String, String>,
    /// Bound-name -> module-path map; presumably the output of
    /// `collect_python_module_aliases` (TODO: confirm at the construction site).
    pub modules: &'a HashMap<String, String>,
}

impl<'a> PythonAliases<'a> {
    /// Bundles the two maps into a `PythonAliases` view without copying them.
    pub(super) fn new(
        imports: &'a HashMap<String, String>,
        modules: &'a HashMap<String, String>,
    ) -> Self {
        Self { imports, modules }
    }
}
/// Maps each local name bound by a top-level `from <module> import ...`
/// statement to the module text it was imported from.
///
/// - `from hashlib import md5` yields `"md5" -> "hashlib"`.
/// - `from hashlib import md5 as m` yields `"m" -> "hashlib"`. The original
///   name `md5` is not recorded because it is not bound locally.
///
/// Only direct children of `root` are inspected, so imports nested inside
/// functions or other blocks are not collected. Only `dotted_name` and
/// `aliased_import` children produce entries; anything else (e.g. the `*` of a
/// wildcard import) adds nothing. A statement with a missing, empty, or
/// non-UTF-8 module name is skipped. If a name is bound more than once, the
/// last binding in source order wins.
pub(super) fn collect_python_from_imports(
    root: Node<'_>,
    source: &[u8],
) -> HashMap<String, String> {
    let mut map = HashMap::new();
    let mut cursor = root.walk();
    for top in root.children(&mut cursor) {
        if top.kind() != "import_from_statement" {
            continue;
        }
        // Fetch the module node once. Its text becomes the map value, and its
        // id lets the loop below skip it: it is itself a named child and may be
        // a `dotted_name`, which would otherwise be mistaken for an imported name.
        let module_node = match top.child_by_field_name("module_name") {
            Some(node) => node,
            None => continue,
        };
        let module = match node_text(module_node, source) {
            Some(text) if !text.is_empty() => text.to_string(),
            _ => continue,
        };
        let mut nc = top.walk();
        for child in top.children(&mut nc) {
            if !child.is_named() || child.id() == module_node.id() {
                continue;
            }
            // The locally bound name: the imported name itself, or its alias.
            let local = match child.kind() {
                "dotted_name" => node_text(child, source),
                "aliased_import" => child
                    .child_by_field_name("alias")
                    .and_then(|n| node_text(n, source)),
                _ => None,
            };
            if let Some(local) = local {
                map.insert(local.to_string(), module.clone());
            }
        }
    }
    map
}
/// Maps each name bound by a top-level `import ...` statement to the module
/// path it refers to.
///
/// - `import os.path` yields `"os.path" -> "os.path"`. The full dotted text is
///   used as both key and value.
/// - `import hashlib as hl` yields `"hl" -> "hashlib"`. The un-aliased module
///   name is not recorded.
///
/// Only direct children of `root` are inspected, so imports nested inside
/// functions or other blocks are ignored, as are `from ... import ...`
/// statements. An aliased import missing either its name or its alias text is
/// skipped. Later bindings of the same name overwrite earlier ones.
pub(super) fn collect_python_module_aliases(
    root: Node<'_>,
    source: &[u8],
) -> HashMap<String, String> {
    let mut aliases = HashMap::new();
    let mut top_cursor = root.walk();
    let statements = root
        .children(&mut top_cursor)
        .filter(|stmt| stmt.kind() == "import_statement");
    for stmt in statements {
        let mut item_cursor = stmt.walk();
        for item in stmt.children(&mut item_cursor).filter(|n| n.is_named()) {
            // (bound name, module path) for this import item, if any.
            let binding = match item.kind() {
                "dotted_name" => node_text(item, source).map(|path| (path, path)),
                "aliased_import" => {
                    let target = item
                        .child_by_field_name("name")
                        .and_then(|n| node_text(n, source));
                    let local = item
                        .child_by_field_name("alias")
                        .and_then(|n| node_text(n, source));
                    local.zip(target)
                }
                _ => None,
            };
            if let Some((local, target)) = binding {
                aliases.insert(local.to_string(), target.to_string());
            }
        }
    }
    aliases
}
/// Returns the source text spanned by `node`. Returns `None` if that text is
/// not valid UTF-8 or if the node starts beyond the end of `source`.
///
/// The end offset is clamped to `source.len()`, so a node that runs past the
/// end of the buffer yields the text up to the end of the buffer.
fn node_text<'a>(node: Node<'_>, source: &'a [u8]) -> Option<&'a str> {
    let start = node.start_byte();
    let end = node.end_byte().min(source.len());
    // Use `get` rather than indexing: after clamping `end`, `start` can exceed
    // it (node starts past the buffer), and `source[start..end]` would panic.
    source
        .get(start..end)
        .and_then(|bytes| std::str::from_utf8(bytes).ok())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use tree_sitter::Parser;

    fn parse_python(src: &str) -> tree_sitter::Tree {
        let mut parser = Parser::new();
        parser
            .set_language(&tree_sitter_python::LANGUAGE.into())
            .expect("load python grammar");
        parser.parse(src, None).expect("parse")
    }

    /// Parses `src` and runs `collect_python_from_imports` on its root.
    fn from_imports(src: &str) -> HashMap<String, String> {
        let tree = parse_python(src);
        collect_python_from_imports(tree.root_node(), src.as_bytes())
    }

    /// Parses `src` and runs `collect_python_module_aliases` on its root.
    fn module_aliases(src: &str) -> HashMap<String, String> {
        let tree = parse_python(src);
        collect_python_module_aliases(tree.root_node(), src.as_bytes())
    }

    /// Looks up `key`, borrowing the value as `&str` for terse assertions.
    fn lookup<'m>(map: &'m HashMap<String, String>, key: &str) -> Option<&'m str> {
        map.get(key).map(String::as_str)
    }

    #[test]
    fn from_imports_simple() {
        let map = from_imports("from hashlib import md5\n");
        assert_eq!(lookup(&map, "md5"), Some("hashlib"));
    }

    #[test]
    fn from_imports_with_alias() {
        let map = from_imports("from hashlib import md5 as m\n");
        assert_eq!(lookup(&map, "m"), Some("hashlib"));
        assert!(!map.contains_key("md5"));
    }

    #[test]
    fn from_imports_multi_with_alias() {
        let map = from_imports("from hashlib import md5, sha1 as sha_one\n");
        assert_eq!(lookup(&map, "md5"), Some("hashlib"));
        assert_eq!(lookup(&map, "sha_one"), Some("hashlib"));
    }

    #[test]
    fn from_imports_dotted_module() {
        // The module name is itself a `dotted_name`; it must not be recorded
        // as an imported name.
        let map = from_imports("from os.path import join\n");
        assert_eq!(lookup(&map, "join"), Some("os.path"));
        assert_eq!(map.len(), 1, "got: {:?}", map);
    }

    #[test]
    fn from_imports_parenthesized_list() {
        let map = from_imports("from os import (path, sep as s)\n");
        assert_eq!(lookup(&map, "path"), Some("os"));
        assert_eq!(lookup(&map, "s"), Some("os"));
        assert_eq!(map.len(), 2, "got: {:?}", map);
    }

    #[test]
    fn from_imports_wildcard_ignored() {
        let map = from_imports("from os import *\n");
        assert!(map.is_empty(), "got: {:?}", map);
    }

    #[test]
    fn module_aliases_unaliased_identity() {
        let map = module_aliases("import hashlib\n");
        assert_eq!(lookup(&map, "hashlib"), Some("hashlib"));
    }

    #[test]
    fn module_aliases_aliased() {
        let map = module_aliases("import hashlib as hl\n");
        assert_eq!(lookup(&map, "hl"), Some("hashlib"));
        assert!(!map.contains_key("hashlib"));
    }

    #[test]
    fn module_aliases_multi_in_one_statement() {
        let map = module_aliases("import os, sys\n");
        assert_eq!(lookup(&map, "os"), Some("os"));
        assert_eq!(lookup(&map, "sys"), Some("sys"));
    }

    #[test]
    fn module_aliases_multi_aliased() {
        let map = module_aliases("import os as o, sys as s\n");
        assert_eq!(lookup(&map, "o"), Some("os"));
        assert_eq!(lookup(&map, "s"), Some("sys"));
    }

    #[test]
    fn module_aliases_mixed_plain_and_aliased() {
        let map = module_aliases("import os, sys as s\n");
        assert_eq!(lookup(&map, "os"), Some("os"));
        assert_eq!(lookup(&map, "s"), Some("sys"));
        assert!(!map.contains_key("sys"));
    }

    #[test]
    fn module_aliases_dotted() {
        let map = module_aliases("import os.path\n");
        assert_eq!(lookup(&map, "os.path"), Some("os.path"));
    }

    #[test]
    fn module_aliases_dotted_aliased() {
        let map = module_aliases("import os.path as op\n");
        assert_eq!(lookup(&map, "op"), Some("os.path"));
    }

    #[test]
    fn module_aliases_ignores_function_local_imports() {
        let map = module_aliases("def f():\n import hashlib as hl\n return hl\n");
        assert!(map.is_empty(), "got: {:?}", map);
    }

    #[test]
    fn module_aliases_ignores_from_imports() {
        let map = module_aliases("from os import system\n");
        assert!(map.is_empty());
    }

    #[test]
    fn from_imports_ignores_plain_import() {
        let map = from_imports("import hashlib as hl\n");
        assert!(map.is_empty());
    }

    #[test]
    fn module_aliases_alias_shadows_real_module() {
        let map = module_aliases("import safelib as os\n");
        assert_eq!(lookup(&map, "os"), Some("safelib"));
    }
}