use std::collections::{HashMap, HashSet};
use tree_sitter::Node;
/// Borrowed views of a Python file's top-level import bindings, used to resolve call
/// targets back to fully-qualified dotted names.
///
/// The struct only holds two shared references, so it is cheap to copy and pass by value.
#[derive(Debug, Clone, Copy)]
pub(super) struct PythonAliases<'a> {
    /// Local name -> source module for `from <module> import <name> [as <local>]`
    /// (typically built by `collect_python_from_imports`).
    pub imports: &'a HashMap<String, String>,
    /// Local name -> module path for `import <module> [as <local>]`
    /// (typically built by `collect_python_module_aliases`).
    pub modules: &'a HashMap<String, String>,
}

impl<'a> PythonAliases<'a> {
    /// Bundles the two alias maps; both are borrowed, not copied.
    pub(super) fn new(
        imports: &'a HashMap<String, String>,
        modules: &'a HashMap<String, String>,
    ) -> Self {
        Self { imports, modules }
    }
}
/// Maps each local name bound by a top-level `from <module> import ...` statement to the
/// text of `<module>`.
///
/// Only direct children of `root` are inspected, so imports nested inside functions or
/// classes are ignored. For `from m import a as b` the key is the local alias `b`; the
/// original name `a` is not recorded. Statements whose module text is missing, empty, or not
/// valid UTF-8 are skipped, and named children other than plain or aliased names (such as a
/// wildcard) bind nothing.
pub(super) fn collect_python_from_imports(
    root: Node<'_>,
    source: &[u8],
) -> HashMap<String, String> {
    let mut bindings = HashMap::new();
    let mut outer = root.walk();
    for stmt in root.children(&mut outer) {
        if stmt.kind() != "import_from_statement" {
            continue;
        }
        let Some(module_node) = stmt.child_by_field_name("module_name") else {
            continue;
        };
        let module = match node_text(module_node, source) {
            Some(text) if !text.is_empty() => text.to_string(),
            _ => continue,
        };
        let module_id = module_node.id();

        let mut inner = stmt.walk();
        for item in stmt.named_children(&mut inner) {
            // The module path is itself a named child (often a `dotted_name`); skip it.
            if item.id() == module_id {
                continue;
            }
            let local = match item.kind() {
                "dotted_name" => node_text(item, source),
                "aliased_import" => item
                    .child_by_field_name("alias")
                    .and_then(|alias| node_text(alias, source)),
                _ => None,
            };
            if let Some(local) = local {
                bindings.insert(local.to_string(), module.clone());
            }
        }
    }
    bindings
}
/// Maps each local name bound by a top-level `import ...` statement to the module path it
/// refers to.
///
/// * `import a.b` records `"a.b" -> "a.b"`. The key is the full dotted text, matching how a
///   call site spells the receiver (`a.b.X()`).
/// * `import a.b as c` records `"c" -> "a.b"`; the real module name is not added as a key.
///
/// Only direct children of `root` are inspected, so function-local imports and
/// `from ... import` statements are ignored.
pub(super) fn collect_python_module_aliases(
    root: Node<'_>,
    source: &[u8],
) -> HashMap<String, String> {
    let mut aliases = HashMap::new();
    let mut outer = root.walk();
    let statements = root
        .children(&mut outer)
        .filter(|stmt| stmt.kind() == "import_statement");
    for stmt in statements {
        let mut inner = stmt.walk();
        for item in stmt.named_children(&mut inner) {
            // (local name, module path) pair, when this child binds anything.
            let binding = match item.kind() {
                "dotted_name" => node_text(item, source).map(|path| (path, path)),
                "aliased_import" => {
                    let target = item
                        .child_by_field_name("name")
                        .and_then(|n| node_text(n, source));
                    let local = item
                        .child_by_field_name("alias")
                        .and_then(|n| node_text(n, source));
                    local.zip(target)
                }
                _ => None,
            };
            if let Some((local, target)) = binding {
                aliases.insert(local.to_string(), target.to_string());
            }
        }
    }
    aliases
}
/// Returns the UTF-8 source text spanned by `node`.
///
/// The end offset is clamped to `source.len()`, so a node running past the buffer yields its
/// truncated prefix. Returns `None` when the bytes are not valid UTF-8, or when the node starts
/// beyond the end of `source` (e.g. `source` is not the buffer the tree was parsed from).
/// In that last case the range would be inverted, and slicing it directly would panic.
fn node_text<'a>(node: Node<'_>, source: &'a [u8]) -> Option<&'a str> {
    let end = node.end_byte().min(source.len());
    // `get` returns `None` for an inverted range instead of panicking like indexing does.
    let bytes = source.get(node.start_byte()..end)?;
    std::str::from_utf8(bytes).ok()
}
/// Maps each class name in the tree to the `self.<attr>` fields that are assigned the result
/// of a call resolving to one of `ctor_names`.
///
/// Every `class_definition` in the tree is visited, however deeply nested. Classes without a
/// matching assignment get no entry, and classes sharing a name share one entry.
pub(super) fn class_attr_constructors_of(
    root: Node<'_>,
    source: &[u8],
    aliases: &PythonAliases<'_>,
    ctor_names: &HashSet<&str>,
) -> HashMap<String, HashSet<String>> {
    let mut found: HashMap<String, HashSet<String>> = HashMap::new();
    // Explicit work list rather than recursion, so deep trees stay off the call stack.
    let mut pending = vec![root];
    while let Some(current) = pending.pop() {
        let mut walker = current.walk();
        pending.extend(current.children(&mut walker));

        if current.kind() != "class_definition" {
            continue;
        }
        let (Some(name_node), Some(body)) = (
            current.child_by_field_name("name"),
            current.child_by_field_name("body"),
        ) else {
            continue;
        };
        if let Some(class_name) = node_text(name_node, source) {
            collect_self_attr_ctors(body, source, aliases, ctor_names, class_name, &mut found);
        }
    }
    found
}
/// Records into `out[class_name]` every `self.<attr>` within `node` (normally a class body)
/// that is assigned the result of a call resolving to one of `ctor_names`.
///
/// * Nested `class_definition`s are not descended into. Inside their methods `self` refers to
///   the nested class, and `class_attr_constructors_of` visits and records that class
///   separately. Previously, the nested class's fields were also credited to the enclosing
///   class.
/// * Chained assignments (`self.a = self.b = Ctor()`) are nested by the grammar as an
///   `assignment` in the `right` field. The innermost value is checked, so every `self.<attr>`
///   target in the chain is credited.
///
/// The walk uses an explicit stack so deeply nested bodies cannot overflow the call stack.
fn collect_self_attr_ctors(
    node: Node<'_>,
    source: &[u8],
    aliases: &PythonAliases<'_>,
    ctor_names: &HashSet<&str>,
    class_name: &str,
    out: &mut HashMap<String, HashSet<String>>,
) {
    let mut stack = vec![node];
    while let Some(current) = stack.pop() {
        if current.kind() == "assignment" {
            if let Some(attr_name) =
                self_attribute_name(current.child_by_field_name("left"), source)
            {
                // Follow chained assignments down to the value actually assigned.
                let mut value = current.child_by_field_name("right");
                while let Some(inner) = value.filter(|v| v.kind() == "assignment") {
                    value = inner.child_by_field_name("right");
                }
                if let Some(rhs) = value {
                    if rhs.kind() == "call"
                        && call_resolves_to_known_ctor(rhs, source, aliases, ctor_names)
                    {
                        out.entry(class_name.to_string())
                            .or_default()
                            .insert(attr_name);
                    }
                }
            }
        }
        let mut cursor = current.walk();
        stack.extend(
            current
                .children(&mut cursor)
                .filter(|child| child.kind() != "class_definition"),
        );
    }
}
/// Returns `attr` when `node` is an attribute access written as `self.attr`.
///
/// The receiver's source text must be exactly `self`. `None` is returned for any other node
/// kind or receiver, a missing field, or text that is not valid UTF-8.
fn self_attribute_name(node: Option<Node<'_>>, source: &[u8]) -> Option<String> {
    let access = node.filter(|candidate| candidate.kind() == "attribute")?;
    let receiver = access.child_by_field_name("object")?;
    if node_text(receiver, source)? != "self" {
        return None;
    }
    access
        .child_by_field_name("attribute")
        .and_then(|field| node_text(field, source))
        .map(str::to_owned)
}
/// Reports whether `call`'s callee names one of `ctor_names`.
///
/// Bare identifier callees (`Ctor()`) are tried first as `<module>.Ctor`, when
/// `aliases.imports` records a `from <module> import` binding for `Ctor`. They are then tried
/// as the bare name.
///
/// Attribute callees (`recv.Ctor()`) are tried as:
/// 1. `<module>.Ctor`, when the whole receiver text is a key of `aliases.modules`;
/// 2. the receiver with only its first dotted segment resolved through `aliases.modules` or
///    `aliases.imports`. For example, `rq.cookies.Ctor` becomes `requests.cookies.Ctor` after
///    `import requests as rq`, and `cookies.Ctor` becomes the same after
///    `from requests import cookies`;
/// 3. the callee's literal source text;
/// 4. the bare attribute name.
///
/// Any other callee shape (subscripts, calls returning callables, ...) yields `false`.
///
/// NOTE(review): for `from m import Orig as Alias`, `aliases.imports` holds only
/// `Alias -> m`. The dotted candidate is therefore `m.Alias` rather than `m.Orig`, and such
/// calls match only through the bare-name fallback.
fn call_resolves_to_known_ctor(
    call: Node<'_>,
    source: &[u8],
    aliases: &PythonAliases<'_>,
    ctor_names: &HashSet<&str>,
) -> bool {
    let Some(func) = call.child_by_field_name("function") else {
        return false;
    };
    match func.kind() {
        "identifier" => {
            let Some(name) = node_text(func, source) else {
                return false;
            };
            if let Some(module) = aliases.imports.get(name) {
                if ctor_names.contains(format!("{module}.{name}").as_str()) {
                    return true;
                }
            }
            ctor_names.contains(name)
        }
        "attribute" => {
            let Some(attr) = func
                .child_by_field_name("attribute")
                .and_then(|n| node_text(n, source))
            else {
                return false;
            };
            let Some(obj) = func.child_by_field_name("object") else {
                return false;
            };
            let obj_text = node_text(obj, source).unwrap_or("");
            // 1. The whole receiver is an imported module or module alias.
            if let Some(resolved_module) = aliases.modules.get(obj_text) {
                if ctor_names.contains(format!("{resolved_module}.{attr}").as_str()) {
                    return true;
                }
            }
            // 2. Only the receiver's leading segment is an imported name.
            if let Some(resolved_obj) = resolve_receiver_head(obj_text, aliases) {
                if ctor_names.contains(format!("{resolved_obj}.{attr}").as_str()) {
                    return true;
                }
            }
            // 3. The callee exactly as written, e.g. `http.cookiejar.CookieJar`.
            if let Some(full) = node_text(func, source) {
                if ctor_names.contains(full) {
                    return true;
                }
            }
            // 4. Permissive fallback on the bare attribute name.
            ctor_names.contains(attr)
        }
        _ => false,
    }
}

/// Resolves the first dotted segment of a call receiver through the file's import maps.
///
/// `aliases.modules` is consulted first (`rq` -> `requests`), then `aliases.imports`
/// (`cookies` -> `requests.cookies`). Any remaining segments are appended unchanged. Returns
/// `None` when the first segment is a key of neither map.
fn resolve_receiver_head(receiver: &str, aliases: &PythonAliases<'_>) -> Option<String> {
    let (head, rest) = match receiver.split_once('.') {
        Some((head, rest)) => (head, Some(rest)),
        None => (receiver, None),
    };
    let resolved_head = match aliases.modules.get(head) {
        Some(module) => module.clone(),
        None => format!("{}.{head}", aliases.imports.get(head)?),
    };
    Some(match rest {
        Some(rest) => format!("{resolved_head}.{rest}"),
        None => resolved_head,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use tree_sitter::Parser;

    /// Parses `src` with the tree-sitter Python grammar.
    fn parse_python(src: &str) -> tree_sitter::Tree {
        let mut parser = Parser::new();
        parser
            .set_language(&tree_sitter_python::LANGUAGE.into())
            .expect("load python grammar");
        parser.parse(src, None).expect("parse")
    }

    /// Runs `collect_python_from_imports` over the parsed `src`.
    fn from_imports_of(src: &str) -> HashMap<String, String> {
        let tree = parse_python(src);
        collect_python_from_imports(tree.root_node(), src.as_bytes())
    }

    /// Runs `collect_python_module_aliases` over the parsed `src`.
    fn module_aliases_of(src: &str) -> HashMap<String, String> {
        let tree = parse_python(src);
        collect_python_module_aliases(tree.root_node(), src.as_bytes())
    }

    /// Full pipeline: builds both alias maps, then runs `class_attr_constructors_of` with
    /// the cookie-jar constructor set.
    fn class_attrs_of(src: &str) -> HashMap<String, HashSet<String>> {
        let tree = parse_python(src);
        let (root, bytes) = (tree.root_node(), src.as_bytes());
        let imports = collect_python_from_imports(root, bytes);
        let modules = collect_python_module_aliases(root, bytes);
        let aliases = PythonAliases::new(&imports, &modules);
        let ctors = cookie_ctor_names();
        class_attr_constructors_of(root, bytes, &aliases, &ctors)
    }

    /// Looks up `key`, returning the mapped value as `&str` for terse comparisons.
    fn value<'m>(map: &'m HashMap<String, String>, key: &str) -> Option<&'m str> {
        map.get(key).map(String::as_str)
    }

    fn cookie_ctor_names() -> HashSet<&'static str> {
        HashSet::from([
            "CookieJar",
            "RequestsCookieJar",
            "http.cookiejar.CookieJar",
            "requests.cookies.RequestsCookieJar",
        ])
    }

    #[test]
    fn from_imports_simple() {
        let map = from_imports_of("from hashlib import md5\n");
        assert_eq!(value(&map, "md5"), Some("hashlib"));
    }

    #[test]
    fn from_imports_with_alias() {
        let map = from_imports_of("from hashlib import md5 as m\n");
        assert_eq!(value(&map, "m"), Some("hashlib"));
        assert!(!map.contains_key("md5"));
    }

    #[test]
    fn from_imports_multi_with_alias() {
        let map = from_imports_of("from hashlib import md5, sha1 as sha_one\n");
        assert_eq!(value(&map, "md5"), Some("hashlib"));
        assert_eq!(value(&map, "sha_one"), Some("hashlib"));
    }

    #[test]
    fn module_aliases_unaliased_identity() {
        let map = module_aliases_of("import hashlib\n");
        assert_eq!(value(&map, "hashlib"), Some("hashlib"));
    }

    #[test]
    fn module_aliases_aliased() {
        let map = module_aliases_of("import hashlib as hl\n");
        assert_eq!(value(&map, "hl"), Some("hashlib"));
        assert!(!map.contains_key("hashlib"));
    }

    #[test]
    fn module_aliases_multi_in_one_statement() {
        let map = module_aliases_of("import os, sys\n");
        assert_eq!(value(&map, "os"), Some("os"));
        assert_eq!(value(&map, "sys"), Some("sys"));
    }

    #[test]
    fn module_aliases_multi_aliased() {
        let map = module_aliases_of("import os as o, sys as s\n");
        assert_eq!(value(&map, "o"), Some("os"));
        assert_eq!(value(&map, "s"), Some("sys"));
    }

    #[test]
    fn module_aliases_dotted() {
        let map = module_aliases_of("import os.path\n");
        assert_eq!(value(&map, "os.path"), Some("os.path"));
    }

    #[test]
    fn module_aliases_dotted_aliased() {
        let map = module_aliases_of("import os.path as op\n");
        assert_eq!(value(&map, "op"), Some("os.path"));
    }

    #[test]
    fn module_aliases_ignores_function_local_imports() {
        let map = module_aliases_of("def f():\n import hashlib as hl\n return hl\n");
        assert!(map.is_empty(), "got: {:?}", map);
    }

    #[test]
    fn module_aliases_ignores_from_imports() {
        let map = module_aliases_of("from os import system\n");
        assert!(map.is_empty());
    }

    #[test]
    fn from_imports_ignores_plain_import() {
        let map = from_imports_of("import hashlib as hl\n");
        assert!(map.is_empty());
    }

    #[test]
    fn module_aliases_alias_shadows_real_module() {
        let map = module_aliases_of("import safelib as os\n");
        assert_eq!(value(&map, "os"), Some("safelib"));
    }

    #[test]
    fn class_attr_ctors_bare_name_from_import() {
        let map = class_attrs_of(
            "from http.cookiejar import CookieJar\nclass C:\n def __init__(self):\n self.jar = CookieJar()\n",
        );
        let attrs = map.get("C").expect("C should be present");
        assert!(attrs.contains("jar"), "expected `jar` in {attrs:?}");
    }

    #[test]
    fn class_attr_ctors_module_alias() {
        let map = class_attrs_of(
            "import http.cookiejar as cl\nclass C:\n def __init__(self):\n self.jar = cl.CookieJar()\n",
        );
        let attrs = map.get("C").expect("C should be present");
        assert!(attrs.contains("jar"), "expected `jar` in {attrs:?}");
    }

    #[test]
    fn class_attr_ctors_only_some_match() {
        let map = class_attrs_of(
            "from http.cookiejar import CookieJar\nclass C:\n def __init__(self):\n self.jar = CookieJar()\n self.timeout = 30\n",
        );
        let attrs = map.get("C").expect("C should be present");
        assert!(attrs.contains("jar"));
        assert!(!attrs.contains("timeout"));
        assert_eq!(attrs.len(), 1);
    }

    #[test]
    fn class_attr_ctors_unknown_ctor_empty_map() {
        let map = class_attrs_of(
            "from collections import OrderedDict\nclass C:\n def __init__(self):\n self.data = OrderedDict()\n",
        );
        assert!(map.is_empty(), "expected empty map, got: {map:?}");
    }
}