use std::collections::HashSet;
use tree_sitter::{Language, Node, Query, QueryCursor, StreamingIterator, Tree};
use crate::graph::node::{DecoratorInfo, SymbolInfo, SymbolKind, SymbolVisibility};
/// Tree-sitter query that locates candidate Python symbols.
///
/// Every pattern captures the symbol's name as `@name` and a node as `@symbol`:
/// - bare `function_definition` / `class_definition` nodes. These two patterns
///   are not anchored to the module, so they match at any nesting depth
///   (methods, inner functions, nested classes). They also match the definition
///   wrapped inside a `decorated_definition`.
/// - a `decorated_definition` wrapping a function or class. Here `@symbol` is
///   the wrapper, so its decorators can be read from it.
/// - assignments whose target is a single identifier and that are direct
///   statements of the module. Here `@symbol` is the `module` node itself, not
///   the assignment statement.
///
/// NOTE(review): de-duplicating decorated definitions and discarding nested
/// definitions is left to whoever consumes the matches.
const PYTHON_SYMBOL_QUERY: &str = r#"
(function_definition name: (identifier) @name) @symbol
(class_definition name: (identifier) @name) @symbol
(decorated_definition
(function_definition name: (identifier) @name)) @symbol
(decorated_definition
(class_definition name: (identifier) @name)) @symbol
(module
(expression_statement
(assignment left: (identifier) @name))) @symbol
"#;
/// Global cache for the compiled [`PYTHON_SYMBOL_QUERY`].
static PY_SYMBOL_QUERY: std::sync::OnceLock<Query> = std::sync::OnceLock::new();

/// Returns the compiled symbol query, compiling it on first use.
///
/// After the first call the cached query is returned regardless of the
/// `language` argument, so callers are expected to always pass the Python
/// grammar.
///
/// # Panics
///
/// Panics if [`PYTHON_SYMBOL_QUERY`] does not compile against `language`.
fn py_symbol_query(language: &Language) -> &'static Query {
    let compile =
        || Query::new(language, PYTHON_SYMBOL_QUERY).expect("invalid Python symbol query");
    PY_SYMBOL_QUERY.get_or_init(compile)
}
/// Returns the source text covered by `node`.
///
/// Returns `""` when that byte range is not valid UTF-8.
fn node_text<'a>(node: Node<'a>, source: &'a [u8]) -> &'a str {
    match node.utf8_text(source) {
        Ok(text) => text,
        Err(_) => "",
    }
}
/// Maps a Python identifier to its conventional visibility.
///
/// Any leading underscore makes the name `Private`. That covers `_x`, `__x`,
/// and also dunders such as `__init__`. Every other name is `Pub`.
fn python_visibility(name: &str) -> SymbolVisibility {
    match name.as_bytes().first() {
        Some(b'_') => SymbolVisibility::Private,
        _ => SymbolVisibility::Pub,
    }
}
/// Parses a single `decorator` node into a [`DecoratorInfo`].
///
/// Recognised shapes (the `@` sigil is never part of `name`):
/// - `@name`: `name` only, no object or attribute.
/// - `@obj.attr`: `name = "obj.attr"`, `object = "obj"`, `attribute = "attr"`.
/// - `@target(args)`: the same split applied to `target`. `args_raw` holds the
///   raw text of the call's `arguments` node.
/// - any other decorator expression: `name` is that expression's text. For a
///   call, this is the callee's text only, since the arguments are already in
///   `args_raw`.
///
/// `framework` is always `None` here.
fn parse_python_decorator(decorator_node: Node, source: &[u8]) -> DecoratorInfo {
    let Some(expr) = decorator_node.named_child(0) else {
        // No expression child (e.g. error recovery). Fall back to the raw
        // decorator text without its `@` sigil, for consistency with the
        // other shapes.
        return DecoratorInfo {
            name: node_text(decorator_node, source)
                .trim_start_matches('@')
                .trim()
                .to_owned(),
            object: None,
            attribute: None,
            args_raw: None,
            framework: None,
        };
    };

    // For `@target(args)` describe the callee and keep the arguments apart.
    let (target, args_raw) = if expr.kind() == "call" {
        let args = expr
            .child_by_field_name("arguments")
            .map(|n| node_text(n, source).to_owned());
        (expr.child_by_field_name("function").unwrap_or(expr), args)
    } else {
        (expr, None)
    };

    let (name, object, attribute) = describe_decorator_target(target, source);
    DecoratorInfo {
        name,
        object,
        attribute,
        args_raw,
        framework: None,
    }
}

/// Splits a decorator target into `(name, object, attribute)`.
///
/// Attribute access `a.b.c` yields `("a.b.c", Some("a.b"), Some("c"))`. Any
/// other node (identifier or arbitrary expression) yields its text with no
/// object or attribute.
fn describe_decorator_target(
    target: Node,
    source: &[u8],
) -> (String, Option<String>, Option<String>) {
    if target.kind() != "attribute" {
        return (node_text(target, source).to_owned(), None, None);
    }
    let object = target
        .child_by_field_name("object")
        .map(|n| node_text(n, source).to_owned());
    let attribute = target
        .child_by_field_name("attribute")
        .map(|n| node_text(n, source).to_owned());
    let name = format!(
        "{}.{}",
        object.as_deref().unwrap_or(""),
        attribute.as_deref().unwrap_or("")
    );
    (name, object, attribute)
}
/// Parses every `decorator` child of a `decorated_definition`.
///
/// Decorators are returned in source order, top to bottom. Non-decorator
/// children, such as the wrapped definition itself, are ignored.
fn extract_python_decorators(decorated_node: Node, source: &[u8]) -> Vec<DecoratorInfo> {
    let mut walker = decorated_node.walk();
    let decorators: Vec<DecoratorInfo> = decorated_node
        .children(&mut walker)
        .filter(|child| child.kind() == "decorator")
        .map(|child| parse_python_decorator(child, source))
        .collect();
    decorators
}
/// Finds the names listed in the module's `__all__`, if it assigns one.
///
/// Only direct statements of the module are scanned. The first
/// `__all__ = <value>` assignment wins, and the string-literal elements of
/// `<value>` are returned.
///
/// Annotation-only declarations such as `__all__: list[str]` have no
/// right-hand side. They are skipped rather than ending the search, so a
/// later `__all__ = [...]` is still honoured.
///
/// Returns `None` when no valued `__all__` assignment exists.
fn extract_all_exports(root: Node, source: &[u8]) -> Option<HashSet<String>> {
    let mut cursor = root.walk();
    for statement in root.children(&mut cursor) {
        if statement.kind() != "expression_statement" {
            continue;
        }
        let mut expr_cursor = statement.walk();
        for expr in statement.children(&mut expr_cursor) {
            if expr.kind() != "assignment" {
                continue;
            }
            let Some(left) = expr.child_by_field_name("left") else {
                continue;
            };
            if left.kind() != "identifier" || node_text(left, source) != "__all__" {
                continue;
            }
            // `__all__: T` without a value carries no list; keep scanning.
            let Some(right) = expr.child_by_field_name("right") else {
                continue;
            };
            return Some(collect_string_list(right, source));
        }
    }
    None
}
/// Collects the string-literal elements of a list or tuple node, such as the
/// right-hand side of `__all__ = [...]`.
///
/// Surrounding quotes and any string prefix (`u`, `r`, `b`, ...) are removed.
/// Elements that are not plain `string` nodes are ignored.
fn collect_string_list(list_node: Node, source: &[u8]) -> HashSet<String> {
    let mut cursor = list_node.walk();
    let names: HashSet<String> = list_node
        .children(&mut cursor)
        .filter(|child| child.kind() == "string")
        .map(|child| strip_python_string_quotes(node_text(child, source)).to_owned())
        .collect();
    names
}

/// Strips an optional prefix and the surrounding quotes from the source text
/// of a Python string literal.
///
/// For example `u"Foo"`, `'Foo'` and `"""Foo"""` all become `Foo`.
fn strip_python_string_quotes(literal: &str) -> &str {
    let is_quote = |c: char| c == '"' || c == '\'';
    literal
        // Prefix letters come before the opening quote, e.g. `u"..."`/`rb'...'`.
        .trim_start_matches(|c: char| c.is_ascii_alphabetic())
        .trim_start_matches(is_quote)
        .trim_end_matches(is_quote)
}
/// Collects the direct members of a Python class body as child symbols.
///
/// - Plain and decorated `def`s become `SymbolKind::Method`.
/// - Plain and decorated nested classes become `SymbolKind::Class`. They are
///   not descended into.
/// - Decorated members carry their decorators and take `line_end` from the
///   enclosing `decorated_definition`.
/// - Any other body statement (class attributes, docstrings, ...) is ignored.
///
/// Returns an empty list when `class_node` has no `block` child.
fn extract_python_class_members(class_node: Node, source: &[u8]) -> Vec<SymbolInfo> {
    let mut class_walker = class_node.walk();
    let Some(body) = class_node
        .children(&mut class_walker)
        .find(|node| node.kind() == "block")
    else {
        return Vec::new();
    };

    let mut members = Vec::new();
    let mut body_walker = body.walk();
    for statement in body.children(&mut body_walker) {
        match statement.kind() {
            "function_definition" | "class_definition" => {
                members.extend(member_symbol(statement, None, source));
            }
            "decorated_definition" => {
                let mut inner_walker = statement.walk();
                for inner in statement.children(&mut inner_walker) {
                    members.extend(member_symbol(inner, Some(statement), source));
                }
            }
            _ => {}
        }
    }
    members
}

/// Builds the member symbol for a class body definition.
///
/// A `function_definition` becomes a method and a `class_definition` a
/// nested class. Returns `None` for any other node kind, or when the
/// definition has no `name` field.
///
/// `wrapper` is the enclosing `decorated_definition`, if any. It supplies the
/// member's decorators and its end line.
fn member_symbol<'t>(
    definition: Node<'t>,
    wrapper: Option<Node<'t>>,
    source: &[u8],
) -> Option<SymbolInfo> {
    let kind = match definition.kind() {
        "function_definition" => SymbolKind::Method,
        "class_definition" => SymbolKind::Class,
        _ => return None,
    };
    let name_node = definition.child_by_field_name("name")?;
    let name = node_text(name_node, source).to_owned();
    let visibility = python_visibility(&name);
    let start = name_node.start_position();
    let end_row = wrapper.unwrap_or(definition).end_position().row;
    let mut member = SymbolInfo {
        name,
        kind,
        line: start.row + 1,
        col: start.column,
        line_end: end_row + 1,
        visibility,
        ..Default::default()
    };
    if let Some(decorated) = wrapper {
        member.decorators = extract_python_decorators(decorated, source);
    }
    Some(member)
}
pub fn extract_python_symbols(
tree: &Tree,
source: &[u8],
language: &Language,
) -> Vec<(SymbolInfo, Vec<SymbolInfo>)> {
let root = tree.root_node();
let all_exports_opt = extract_all_exports(root, source);
let query = py_symbol_query(language);
let name_idx = query
.capture_index_for_name("name")
.expect("python symbol query must have @name");
let symbol_idx = query
.capture_index_for_name("symbol")
.expect("python symbol query must have @symbol");
let mut cursor = QueryCursor::new();
let mut matches = cursor.matches(query, root, source);
let mut results: Vec<(SymbolInfo, Vec<SymbolInfo>)> = Vec::new();
while let Some(m) = matches.next() {
let mut name_node: Option<Node> = None;
let mut symbol_node: Option<Node> = None;
for capture in m.captures {
if capture.index == name_idx {
name_node = Some(capture.node);
} else if capture.index == symbol_idx {
symbol_node = Some(capture.node);
}
}
let (name_n, sym_n) = match (name_node, symbol_node) {
(Some(n), Some(s)) => (n, s),
_ => continue,
};
let name = node_text(name_n, source).to_owned();
let sym_kind = sym_n.kind();
if (sym_kind == "function_definition" || sym_kind == "class_definition")
&& let Some(parent) = sym_n.parent()
&& parent.kind() == "decorated_definition"
{
continue;
}
let kind = match sym_kind {
"function_definition" => SymbolKind::Function,
"class_definition" => SymbolKind::Class,
"decorated_definition" => {
let mut inner_kind = SymbolKind::Function;
let mut inner_cursor = sym_n.walk();
for child in sym_n.children(&mut inner_cursor) {
match child.kind() {
"function_definition" => {
inner_kind = SymbolKind::Function;
break;
}
"class_definition" => {
inner_kind = SymbolKind::Class;
break;
}
_ => {}
}
}
inner_kind
}
"module" => {
SymbolKind::Variable
}
"expression_statement" => {
SymbolKind::Variable
}
_ => continue,
};
let def_node = if sym_kind == "module" {
find_assignment_node(root, name_n)
} else {
Some(sym_n)
};
let def_node = match def_node {
Some(n) => n,
None => continue,
};
let decorators = if sym_kind == "decorated_definition" {
extract_python_decorators(sym_n, source)
} else {
Vec::new()
};
let visibility = python_visibility(&name);
let is_exported = match &all_exports_opt {
Some(all_exports) => all_exports.contains(&name),
None => !name.starts_with('_'),
};
let pos = name_n.start_position();
let line = pos.row + 1;
let col = pos.column;
let line_end = def_node.end_position().row + 1;
let symbol = SymbolInfo {
name: name.clone(),
kind: kind.clone(),
line,
col,
line_end,
is_exported,
is_default: false,
visibility,
trait_impl: None,
decorators,
};
let children = if kind == SymbolKind::Class {
let class_node = if sym_kind == "decorated_definition" {
let mut found = None;
let mut c = sym_n.walk();
for child in sym_n.children(&mut c) {
if child.kind() == "class_definition" {
found = Some(child);
break;
}
}
found.unwrap_or(sym_n)
} else {
sym_n
};
extract_python_class_members(class_node, source)
} else {
Vec::new()
};
results.push((symbol, children));
}
results.extend(extract_type_aliases(root, source, &all_exports_opt));
results
}
fn find_assignment_node<'a>(root: Node<'a>, name_node: Node<'a>) -> Option<Node<'a>> {
let assignment = name_node.parent()?;
if assignment.kind() == "assignment" {
let expr_stmt = assignment.parent()?;
if expr_stmt.kind() == "expression_statement"
&& let Some(p) = expr_stmt.parent()
&& p.id() == root.id()
{
return Some(expr_stmt);
}
}
None
}
/// Collects module-level `type X = ...` statements as `TypeAlias` symbols.
///
/// Only direct children of `root` are considered, and each entry has no
/// children.
/// - `line`/`col` point at the statement's `left` field, or at the statement
///   itself when that field is absent.
/// - `line_end` is the statement's last line.
/// - Export rule: with `__all__` (`Some`), a name is exported iff it is
///   listed; without `__all__`, iff it has no leading underscore.
fn extract_type_aliases(
    root: Node,
    source: &[u8],
    all_exports_opt: &Option<HashSet<String>>,
) -> Vec<(SymbolInfo, Vec<SymbolInfo>)> {
    let mut aliases = Vec::new();
    let mut walker = root.walk();
    for statement in root.children(&mut walker) {
        if statement.kind() != "type_alias_statement" {
            continue;
        }
        let Some(name) = extract_type_alias_name(statement, source) else {
            continue;
        };
        let anchor = statement.child_by_field_name("left").unwrap_or(statement);
        let start = anchor.start_position();
        let visibility = python_visibility(&name);
        let is_exported = all_exports_opt
            .as_ref()
            .map_or_else(|| !name.starts_with('_'), |exports| exports.contains(&name));
        let alias = SymbolInfo {
            name,
            kind: SymbolKind::TypeAlias,
            line: start.row + 1,
            col: start.column,
            line_end: statement.end_position().row + 1,
            is_exported,
            is_default: false,
            visibility,
            trait_impl: None,
            decorators: Vec::new(),
        };
        aliases.push((alias, Vec::new()));
    }
    aliases
}
/// Returns the alias name declared by a `type_alias_statement`.
///
/// The `left` field wraps the declared name, possibly together with type
/// parameters (`type Pair[T] = tuple[T, T]`). This descends through first
/// named children until an `identifier` is reached, so generic aliases yield
/// `Pair` rather than `Pair[T]`. If no identifier is found, the deepest
/// first-child node's text is used.
///
/// Returns `None` when there is no `left` field, when a non-identifier `left`
/// has no named child, or when the resulting text is empty.
fn extract_type_alias_name<'a>(node: Node<'a>, source: &'a [u8]) -> Option<String> {
    let left = node.child_by_field_name("left")?;
    let mut current = if left.kind() == "identifier" {
        left
    } else {
        left.named_child(0)?
    };
    // Descend through wrappers (e.g. generic/subscript forms) to the
    // leading identifier; stop at a leaf if there is none.
    while current.kind() != "identifier" {
        match current.named_child(0) {
            Some(child) => current = child,
            None => break,
        }
    }
    let text = node_text(current, source);
    (!text.is_empty()).then(|| text.to_owned())
}
#[cfg(test)]
mod tests {
    // Unit tests for the Python symbol extractor. Each test parses a small
    // Python snippet with the real tree-sitter grammar and inspects the
    // `(symbol, children)` pairs returned by `extract_python_symbols`.
    use super::*;
    use crate::parser::languages::language_for_extension;

    // Parses `source` with the Python grammar registered for the "py"
    // extension; returns the tree together with the language it was parsed with.
    fn parse_py(source: &str) -> (Tree, Language) {
        let lang = language_for_extension("py").unwrap();
        let mut parser = tree_sitter::Parser::new();
        parser.set_language(&lang).unwrap();
        let tree = parser.parse(source.as_bytes(), None).unwrap();
        (tree, lang)
    }

    // Convenience wrapper: parse and run the extractor in one step.
    fn extract(source: &str) -> Vec<(SymbolInfo, Vec<SymbolInfo>)> {
        let (tree, lang) = parse_py(source);
        extract_python_symbols(&tree, source.as_bytes(), &lang)
    }

    // A single public top-level function: reported once, exported, no children.
    #[test]
    fn test_python_function() {
        let src = "def hello():\n pass\n";
        let syms = extract(src);
        assert_eq!(
            syms.len(),
            1,
            "expected 1 symbol, got {}: {:?}",
            syms.len(),
            syms.iter().map(|(s, _)| &s.name).collect::<Vec<_>>()
        );
        let (sym, children) = &syms[0];
        assert_eq!(sym.name, "hello");
        assert_eq!(sym.kind, SymbolKind::Function);
        assert_eq!(sym.visibility, SymbolVisibility::Pub);
        assert!(sym.is_exported);
        assert!(children.is_empty());
    }

    // `async def` is reported as an ordinary Function.
    #[test]
    fn test_python_async_function() {
        let src = "async def fetch():\n pass\n";
        let syms = extract(src);
        assert_eq!(syms.len(), 1);
        let (sym, _) = &syms[0];
        assert_eq!(sym.name, "fetch");
        assert_eq!(sym.kind, SymbolKind::Function);
        assert_eq!(sym.visibility, SymbolVisibility::Pub);
        assert!(sym.is_exported);
    }

    // A public top-level class is reported as an exported Class.
    #[test]
    fn test_python_class() {
        let src = "class MyClass:\n pass\n";
        let syms = extract(src);
        assert_eq!(syms.len(), 1);
        let (sym, _) = &syms[0];
        assert_eq!(sym.name, "MyClass");
        assert_eq!(sym.kind, SymbolKind::Class);
        assert_eq!(sym.visibility, SymbolVisibility::Pub);
        assert!(sym.is_exported);
    }

    // A module-level assignment is reported as an exported Variable.
    #[test]
    fn test_python_assignment() {
        let src = "MAX_SIZE = 100\n";
        let syms = extract(src);
        assert_eq!(syms.len(), 1, "expected 1 symbol");
        let (sym, _) = &syms[0];
        assert_eq!(sym.name, "MAX_SIZE");
        assert_eq!(sym.kind, SymbolKind::Variable);
        assert_eq!(sym.visibility, SymbolVisibility::Pub);
        assert!(sym.is_exported);
    }

    // A `type X = ...` statement is reported as a TypeAlias named `X`.
    #[test]
    fn test_python_type_alias() {
        let src = "type Alias = int\n";
        let syms = extract(src);
        assert_eq!(syms.len(), 1, "expected 1 symbol");
        let (sym, _) = &syms[0];
        assert_eq!(sym.name, "Alias");
        assert_eq!(sym.kind, SymbolKind::TypeAlias);
    }

    // Without `__all__`, a single leading underscore means Private and not exported.
    #[test]
    fn test_python_visibility_private() {
        let src = "_helper = 1\n";
        let syms = extract(src);
        assert_eq!(syms.len(), 1);
        let (sym, _) = &syms[0];
        assert_eq!(sym.name, "_helper");
        assert_eq!(sym.visibility, SymbolVisibility::Private);
        assert!(!sym.is_exported);
    }

    // A double leading underscore is treated the same as a single one.
    #[test]
    fn test_python_visibility_dunder() {
        let src = "__secret = 1\n";
        let syms = extract(src);
        assert_eq!(syms.len(), 1);
        let (sym, _) = &syms[0];
        assert_eq!(sym.name, "__secret");
        assert_eq!(sym.visibility, SymbolVisibility::Private);
        assert!(!sym.is_exported);
    }

    // Class methods become Method children. Dunder methods such as
    // `__init__` count as Private under the underscore rule.
    #[test]
    fn test_python_dunder_method() {
        let src = "class MyClass:\n def __init__(self):\n pass\n";
        let syms = extract(src);
        let (class_sym, children) = &syms[0];
        assert_eq!(class_sym.name, "MyClass");
        assert_eq!(children.len(), 1);
        let method = &children[0];
        assert_eq!(method.name, "__init__");
        assert_eq!(method.kind, SymbolKind::Method);
        assert_eq!(method.visibility, SymbolVisibility::Private);
    }

    // When `__all__` is present, it alone decides `is_exported`. If `__all__`
    // itself is reported, it is a Variable.
    #[test]
    fn test_python_all_exports() {
        let src = "__all__ = [\"Foo\"]\n\nclass Foo:\n pass\n\nclass Bar:\n pass\n";
        let syms = extract(src);
        let foo = syms.iter().find(|(s, _)| s.name == "Foo").unwrap();
        let bar = syms.iter().find(|(s, _)| s.name == "Bar").unwrap();
        let all = syms.iter().find(|(s, _)| s.name == "__all__");
        assert!(foo.0.is_exported, "Foo should be exported");
        assert!(!bar.0.is_exported, "Bar should NOT be exported");
        if let Some((all_sym, _)) = all {
            assert_eq!(all_sym.kind, SymbolKind::Variable);
        }
    }

    // `__all__` can export an underscore name; visibility stays Private,
    // because the two properties are independent.
    #[test]
    fn test_python_all_exports_private() {
        let src = "__all__ = [\"_helper\"]\n\n_helper = 1\n";
        let syms = extract(src);
        let helper = syms.iter().find(|(s, _)| s.name == "_helper").unwrap();
        assert!(
            helper.0.is_exported,
            "is_exported should be true (in __all__)"
        );
        assert_eq!(
            helper.0.visibility,
            SymbolVisibility::Private,
            "visibility should be Private"
        );
    }

    // A bare-identifier decorator is recorded with no object/attribute split,
    // and the decorated function is reported only once.
    #[test]
    fn test_python_decorated_function() {
        let src = "@decorator\ndef foo():\n pass\n";
        let syms = extract(src);
        assert_eq!(syms.len(), 1, "expected 1 symbol (no duplicate)");
        let (sym, _) = &syms[0];
        assert_eq!(sym.name, "foo");
        assert_eq!(sym.kind, SymbolKind::Function);
        assert_eq!(sym.decorators.len(), 1);
        assert_eq!(sym.decorators[0].name, "decorator");
        assert!(sym.decorators[0].object.is_none());
        assert!(sym.decorators[0].attribute.is_none());
    }

    // Stacked decorators are kept in source order (top to bottom).
    #[test]
    fn test_python_stacked_decorators() {
        let src = "@first\n@second\ndef foo():\n pass\n";
        let syms = extract(src);
        assert_eq!(syms.len(), 1);
        let (sym, _) = &syms[0];
        assert_eq!(sym.decorators.len(), 2);
        assert_eq!(sym.decorators[0].name, "first");
        assert_eq!(sym.decorators[1].name, "second");
    }

    // `@obj.attr(args)` decorators are split into object/attribute, and
    // their raw argument text is kept.
    #[test]
    fn test_python_attribute_decorator() {
        let src = "@app.route(\"/api\")\ndef handler():\n pass\n";
        let syms = extract(src);
        assert_eq!(syms.len(), 1);
        let (sym, _) = &syms[0];
        assert_eq!(sym.decorators.len(), 1);
        let dec = &sym.decorators[0];
        assert_eq!(dec.object.as_deref(), Some("app"));
        assert_eq!(dec.attribute.as_deref(), Some("route"));
        assert!(
            dec.args_raw.is_some(),
            "args_raw should be present for call decorator"
        );
    }

    // `line_end` covers the whole definition body, not only the header line.
    #[test]
    fn test_python_line_end() {
        let src = "def long_func():\n x = 1\n y = 2\n return x + y\n";
        let syms = extract(src);
        assert_eq!(syms.len(), 1);
        let (sym, _) = &syms[0];
        assert!(
            sym.line_end > sym.line,
            "line_end ({}) should be > line ({})",
            sym.line_end,
            sym.line
        );
    }

    // Regression guard: a decorated function must not be reported both with
    // and without its decorator wrapper.
    #[test]
    fn test_python_no_duplicate_decorated() {
        let src = "@my_decorator\ndef process():\n pass\n";
        let syms = extract(src);
        let count = syms.iter().filter(|(s, _)| s.name == "process").count();
        assert_eq!(
            count, 1,
            "decorated function should appear exactly once, got {}",
            count
        );
    }

    // Every method of a class is collected as a Method child of that class.
    #[test]
    fn test_python_nested_class_methods() {
        let src = "class Animal:\n def speak(self):\n pass\n def move(self):\n pass\n";
        let syms = extract(src);
        let (class_sym, children) = &syms[0];
        assert_eq!(class_sym.name, "Animal");
        assert_eq!(class_sym.kind, SymbolKind::Class);
        assert_eq!(children.len(), 2, "expected 2 methods");
        let names: Vec<_> = children.iter().map(|c| c.name.as_str()).collect();
        assert!(names.contains(&"speak"));
        assert!(names.contains(&"move"));
        for child in children {
            assert_eq!(child.kind, SymbolKind::Method);
        }
    }

    // Assignments inside a function body are not module-level variables.
    #[test]
    fn test_python_module_level_only_assignments() {
        let src = "def my_func():\n local_var = 42\n return local_var\n\ntop_level = 1\n";
        let syms = extract(src);
        let names: Vec<_> = syms.iter().map(|(s, _)| s.name.as_str()).collect();
        assert!(names.contains(&"my_func"), "should have my_func");
        assert!(names.contains(&"top_level"), "should have top_level");
        assert!(
            !names.contains(&"local_var"),
            "should NOT have local_var (function-body assignment)"
        );
    }
}