use crate::models::{Class, Function};
use crate::parsers::{ImportInfo, ParseResult};
use anyhow::{Context, Result};
use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use std::path::Path;
use std::sync::OnceLock;
use tree_sitter::{Node, Parser, Query, QueryCursor, StreamingIterator};
thread_local! {
    /// One tree-sitter parser per thread, preconfigured for Python.
    ///
    /// The parser is kept behind a `RefCell` so every call on the same thread
    /// reuses it instead of constructing and configuring a new one.
    static PY_PARSER: RefCell<Parser> = {
        let mut parser = Parser::new();
        parser
            .set_language(&tree_sitter_python::LANGUAGE.into())
            .expect("Python language");
        RefCell::new(parser)
    };
}
/// Tree-sitter query matching module-level function definitions, either bare
/// or wrapped in a `decorated_definition`.
///
/// Captures: `@func` (the `function_definition` node), `@func_name`,
/// `@params`, and the optional `@return_type`. Both patterns are anchored on
/// `module` as the direct parent (or grandparent through the decorator
/// wrapper), so methods and nested functions are not matched here.
const PY_FUNC_QUERY_STR: &str = r#"
(module
(function_definition
name: (identifier) @func_name
parameters: (parameters) @params
return_type: (type)? @return_type
) @func
)
(module
(decorated_definition
(function_definition
name: (identifier) @func_name
parameters: (parameters) @params
return_type: (type)? @return_type
) @func
)
)
"#;
/// Compiled form of `PY_FUNC_QUERY_STR`, built on first use in
/// `extract_functions` and reused afterwards.
static PY_FUNC_QUERY: OnceLock<Query> = OnceLock::new();
/// Collect the decorator names attached to a `decorated_definition` node.
///
/// For every `decorator` child, the first token that is neither the `@`
/// sigil nor a comment is taken. Call arguments are dropped by cutting at the
/// first `(`, so `@app.route('/x')` yields `app.route`. Names that end up
/// empty are skipped.
fn extract_decorators(node: &Node, source: &[u8]) -> Vec<String> {
    let names: Vec<String> = node
        .children(&mut node.walk())
        .filter(|child| child.kind() == "decorator")
        .filter_map(|decorator| {
            let head = decorator
                .children(&mut decorator.walk())
                .find(|tok| !matches!(tok.kind(), "@" | "comment"))?;
            let text = head.utf8_text(source).unwrap_or("");
            let name = text.split('(').next().unwrap_or(text).trim();
            (!name.is_empty()).then(|| name.to_string())
        })
        .collect();
    names
}
#[allow(dead_code)]
pub fn parse(path: &Path) -> Result<ParseResult> {
let source = std::fs::read_to_string(path)
.with_context(|| format!("Failed to read file: {}", path.display()))?;
parse_source(&source, path)
}
/// Parse Python `source` (attributed to `path`) and return only the extracted
/// [`ParseResult`], discarding the syntax tree.
///
/// # Errors
/// Propagates any failure from [`parse_source_with_tree`].
pub fn parse_source(source: &str, path: &Path) -> Result<ParseResult> {
    let (parsed, _tree) = parse_source_with_tree(source, path)?;
    Ok(parsed)
}
/// Parse Python `source` and run every extraction pass over the resulting tree.
///
/// The passes run in a fixed order because later ones read what earlier ones
/// wrote:
/// - call extraction attributes call sites to the functions and methods
///   already collected;
/// - export annotation tags the collected functions and classes.
///
/// Returns the populated [`ParseResult`] together with the tree-sitter tree.
///
/// # Errors
/// Fails with "Failed to parse Python source" if the parser yields no tree,
/// or with any error returned by an extraction pass.
pub fn parse_source_with_tree(source: &str, path: &Path) -> Result<(ParseResult, tree_sitter::Tree)> {
    let parsed_tree = PY_PARSER
        .with(|parser| parser.borrow_mut().parse(source, None))
        .context("Failed to parse Python source")?;
    let bytes = source.as_bytes();
    let root_node = parsed_tree.root_node();
    let mut out = ParseResult::default();

    // Module-level functions first, then classes and the methods inside them.
    extract_functions(&root_node, bytes, path, &mut out)?;
    let classes_found = extract_classes(&root_node, bytes, path, &mut out)?;
    extract_class_methods(&classes_found, bytes, path, &mut out)?;
    extract_imports(&root_node, bytes, &mut out)?;
    // Needs every function recorded above to resolve the calling scope.
    extract_calls(&root_node, bytes, path, &mut out)?;
    annotate_exports(&root_node, bytes, &mut out);

    Ok((out, parsed_tree))
}
fn annotate_exports(root: &Node, source: &[u8], result: &mut ParseResult) {
let all_names = extract_dunder_all(root, source);
for func in &mut result.functions {
let name_part = func.qualified_name.rsplit("::").next().unwrap_or("");
if name_part.contains('.') {
continue;
}
let is_exported = if let Some(ref names) = all_names {
names.contains(&func.name)
} else {
!func.name.starts_with('_')
};
if is_exported && !func.annotations.contains(&"exported".to_string()) {
func.annotations.push("exported".to_string());
}
}
for class in &mut result.classes {
let is_exported = if let Some(ref names) = all_names {
names.contains(&class.name)
} else {
!class.name.starts_with('_')
};
if is_exported && !class.annotations.contains(&"exported".to_string()) {
class.annotations.push("exported".to_string());
}
}
}
/// Find the first module-level `__all__ = ...` assignment and return the string
/// names it lists.
///
/// Only assignments that are direct statements of the module are examined.
/// Returns `None` when no such assignment exists.
fn extract_dunder_all(root: &Node, source: &[u8]) -> Option<HashSet<String>> {
    let mut stmt_cursor = root.walk();
    for stmt in root.children(&mut stmt_cursor) {
        if stmt.kind() != "expression_statement" {
            continue;
        }
        let mut expr_cursor = stmt.walk();
        for expr in stmt.children(&mut expr_cursor) {
            if expr.kind() != "assignment" {
                continue;
            }
            let (Some(target), Some(value)) = (
                expr.child_by_field_name("left"),
                expr.child_by_field_name("right"),
            ) else {
                continue;
            };
            if target.utf8_text(source).unwrap_or("") == "__all__" {
                return Some(extract_string_list(&value, source));
            }
        }
    }
    None
}
fn extract_string_list(node: &Node, source: &[u8]) -> HashSet<String> {
let mut names = HashSet::new();
if node.kind() == "list" || node.kind() == "tuple" {
for child in node.children(&mut node.walk()) {
if child.kind() == "string" {
let text = child.utf8_text(source).unwrap_or("");
let stripped = text
.trim_start_matches(['\'', '"'])
.trim_end_matches(['\'', '"']);
if !stripped.is_empty() {
names.insert(stripped.to_string());
}
}
}
}
names
}
/// Extract module-level function definitions (plain and decorated) into
/// `result.functions`.
///
/// Afterwards it delegates to [`extract_async_functions`] for grammars that
/// model `async def` as a separate node kind.
///
/// Each function record gets:
/// - a qualified name of the form `path::name:line`;
/// - its parameters and optional return annotation;
/// - an async flag;
/// - its decorators, as annotations;
/// - a complexity estimate.
///
/// # Errors
/// Propagates errors from [`extract_async_functions`].
///
/// # Panics
/// Panics if the built-in `PY_FUNC_QUERY_STR` fails to compile, which would
/// be a programming error rather than bad input.
fn extract_functions(
    root: &Node,
    source: &[u8],
    path: &Path,
    result: &mut ParseResult,
) -> Result<()> {
    let query = PY_FUNC_QUERY.get_or_init(|| {
        Query::new(&tree_sitter_python::LANGUAGE.into(), PY_FUNC_QUERY_STR)
            .expect("valid Python function query")
    });
    let capture_names = query.capture_names();
    let mut cursor = QueryCursor::new();
    let mut matches = cursor.matches(query, *root, source);
    while let Some(m) = matches.next() {
        let mut func_node = None;
        let mut name = String::new();
        let mut params_node = None;
        let mut return_type_node = None;
        for capture in m.captures.iter() {
            match capture_names[capture.index as usize] {
                "func" => func_node = Some(capture.node),
                "func_name" => name = capture.node.utf8_text(source).unwrap_or("").to_string(),
                "params" => params_node = Some(capture.node),
                "return_type" => return_type_node = Some(capture.node),
                _ => {}
            }
        }
        let Some(node) = func_node else {
            continue;
        };

        // A coroutine's `function_definition` starts with the `async` keyword.
        // Checking the bytes at the node start avoids the previous fixed
        // 10-byte window + `from_utf8`, which could split a multi-byte char.
        let is_async = source
            .get(node.start_byte()..)
            .is_some_and(|rest| rest.starts_with(b"async"));
        let line_start = node.start_position().row as u32 + 1;
        let line_end = node.end_position().row as u32 + 1;
        let qualified_name = format!("{}::{}:{}", path.display(), name, line_start);
        // Decorators live on the wrapping `decorated_definition`, if any.
        let annotations = node
            .parent()
            .filter(|parent| parent.kind() == "decorated_definition")
            .map(|parent| extract_decorators(&parent, source))
            .unwrap_or_default();

        result.functions.push(Function {
            name,
            qualified_name,
            file_path: path.to_path_buf(),
            line_start,
            line_end,
            parameters: extract_parameters(params_node, source),
            return_type: return_type_node
                .map(|n| n.utf8_text(source).unwrap_or("").to_string()),
            is_async,
            complexity: Some(calculate_complexity(&node, source)),
            max_nesting: None,
            doc_comment: None,
            annotations,
        });
    }
    extract_async_functions(root, source, path, result)?;
    Ok(())
}
/// Record module-level `async_function_definition` nodes, bare or decorated,
/// as async functions.
///
/// A function is skipped when one with the same qualified name is already in
/// `result.functions`.
///
/// NOTE(review): newer tree-sitter-python grammars fold `async` into
/// `function_definition` (already handled by the query pass). Confirm whether
/// this node kind can still occur with the grammar in use.
///
/// # Errors
/// Currently never fails; returns `Result` for uniformity with other passes.
fn extract_async_functions(
    root: &Node,
    source: &[u8],
    path: &Path,
    result: &mut ParseResult,
) -> Result<()> {
    let mut top_cursor = root.walk();
    for stmt in root.children(&mut top_cursor) {
        let candidates: Vec<Node<'_>> = match stmt.kind() {
            "async_function_definition" => vec![stmt],
            "decorated_definition" => stmt
                .children(&mut stmt.walk())
                .filter(|c| c.kind() == "async_function_definition")
                .collect(),
            _ => Vec::new(),
        };
        for candidate in candidates {
            let Some(func) = parse_function_node(&candidate, source, path, true) else {
                continue;
            };
            let already_seen = result
                .functions
                .iter()
                .any(|f| f.qualified_name == func.qualified_name);
            if !already_seen {
                result.functions.push(func);
            }
        }
    }
    Ok(())
}
/// Build a [`Function`] record from a function-definition node.
///
/// Returns `None` when the node has no `name` field or the name is not valid
/// UTF-8. The async flag comes from the caller and is not inspected here.
/// Decorators are read from a wrapping `decorated_definition` parent, if any.
fn parse_function_node(
    node: &Node,
    source: &[u8],
    path: &Path,
    is_async: bool,
) -> Option<Function> {
    let name = node
        .child_by_field_name("name")?
        .utf8_text(source)
        .ok()?
        .to_string();
    let line_start = node.start_position().row as u32 + 1;
    let line_end = node.end_position().row as u32 + 1;
    let decorators = match node.parent() {
        Some(parent) if parent.kind() == "decorated_definition" => {
            extract_decorators(&parent, source)
        }
        _ => Vec::new(),
    };
    Some(Function {
        qualified_name: format!("{}::{}:{}", path.display(), name, line_start),
        parameters: extract_parameters(node.child_by_field_name("parameters"), source),
        return_type: node
            .child_by_field_name("return_type")
            .and_then(|n| n.utf8_text(source).ok())
            .map(str::to_string),
        complexity: Some(calculate_complexity(node, source)),
        file_path: path.to_path_buf(),
        name,
        line_start,
        line_end,
        is_async,
        max_nesting: None,
        doc_comment: None,
        annotations: decorators,
    })
}
fn extract_parameters(params_node: Option<Node>, source: &[u8]) -> Vec<String> {
let Some(node) = params_node else {
return vec![];
};
let mut params = Vec::new();
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
match child.kind() {
"identifier" => {
if let Ok(text) = child.utf8_text(source) {
params.push(text.to_string());
}
}
"typed_parameter" | "default_parameter" | "typed_default_parameter" => {
if let Some(name_node) = child.child_by_field_name("name") {
if let Ok(text) = name_node.utf8_text(source) {
params.push(text.to_string());
}
} else {
for grandchild in child.children(&mut child.walk()) {
if grandchild.kind() == "identifier" {
if let Ok(text) = grandchild.utf8_text(source) {
params.push(text.to_string());
break;
}
}
}
}
}
"list_splat_pattern" | "dictionary_splat_pattern" => {
for grandchild in child.children(&mut child.walk()) {
if grandchild.kind() == "identifier" {
if let Ok(text) = grandchild.utf8_text(source) {
let prefix = if child.kind() == "list_splat_pattern" {
"*"
} else {
"**"
};
params.push(format!("{}{}", prefix, text));
break;
}
}
}
}
_ => {}
}
}
params
}
/// Record every module-level class, bare or decorated, in `result.classes`.
///
/// Returns `(class name, class_definition node)` pairs in source order so the
/// method pass can walk each class body afterwards.
///
/// # Errors
/// Currently never fails; returns `Result` for uniformity with other passes.
fn extract_classes<'a>(
    root: &Node<'a>,
    source: &[u8],
    path: &Path,
    result: &mut ParseResult,
) -> Result<Vec<(String, Node<'a>)>> {
    let mut found = Vec::new();
    let mut cursor = root.walk();
    for stmt in root.children(&mut cursor) {
        let definition = match stmt.kind() {
            "class_definition" => Some(stmt),
            "decorated_definition" => stmt
                .children(&mut stmt.walk())
                .find(|c| c.kind() == "class_definition"),
            _ => None,
        };
        let Some(definition) = definition else {
            continue;
        };
        if let Some(class) = parse_class_node(&definition, source, path) {
            found.push((class.name.clone(), definition));
            result.classes.push(class);
        }
    }
    Ok(found)
}
/// Build a [`Class`] record from a `class_definition` node.
///
/// The record includes base classes, direct method names, and decorators from
/// a wrapping `decorated_definition`. Returns `None` when the node has no
/// readable `name`. `field_count` is always reported as 0 here.
fn parse_class_node(node: &Node, source: &[u8], path: &Path) -> Option<Class> {
    let name = node
        .child_by_field_name("name")?
        .utf8_text(source)
        .ok()?
        .to_string();
    let line_start = node.start_position().row as u32 + 1;
    let annotations = node
        .parent()
        .filter(|p| p.kind() == "decorated_definition")
        .map(|p| extract_decorators(&p, source))
        .unwrap_or_default();
    Some(Class {
        qualified_name: format!("{}::{}:{}", path.display(), name, line_start),
        file_path: path.to_path_buf(),
        line_start,
        line_end: node.end_position().row as u32 + 1,
        methods: extract_methods(node, source),
        field_count: 0,
        bases: extract_bases(node, source),
        doc_comment: None,
        annotations,
        name,
    })
}
/// Names of the base classes listed in a class header's argument list.
///
/// Each argument is resolved through `extract_base_name`; arguments it does
/// not recognise (e.g. `metaclass=...`) are skipped.
fn extract_bases(class_node: &Node, source: &[u8]) -> Vec<String> {
    let mut names = Vec::new();
    let mut cursor = class_node.walk();
    for arg_list in class_node
        .children(&mut cursor)
        .filter(|c| c.kind() == "argument_list")
    {
        names.extend(
            arg_list
                .children(&mut arg_list.walk())
                .filter_map(|arg| extract_base_name(&arg, source)),
        );
    }
    names
}
/// Textual name of a single base-class expression.
///
/// - Identifiers and dotted attributes are returned verbatim.
/// - A subscript such as `Generic[T]` resolves to its subscripted value
///   (`Generic`).
/// - Keyword arguments, punctuation and any other node kind yield `None`.
fn extract_base_name(node: &Node, source: &[u8]) -> Option<String> {
    if matches!(node.kind(), "identifier" | "attribute") {
        return node.utf8_text(source).ok().map(str::to_string);
    }
    if node.kind() == "subscript" {
        return node
            .child_by_field_name("value")
            .and_then(|inner| extract_base_name(&inner, source));
    }
    None
}
/// Names of the methods defined directly in a class body, in source order.
///
/// Plain and decorated `function_definition` / `async_function_definition`
/// children count. Only immediate children of the body are examined, so
/// functions nested inside method bodies (and lambdas) do not. If the class
/// has no `body` field, the first `block` child is used instead.
fn extract_methods(class_node: &Node, source: &[u8]) -> Vec<String> {
    let is_def =
        |kind: &str| kind == "function_definition" || kind == "async_function_definition";
    let Some(body) = class_node
        .child_by_field_name("body")
        .or_else(|| class_node.children(&mut class_node.walk()).find(|c| c.kind() == "block"))
    else {
        return Vec::new();
    };
    let mut methods = Vec::new();
    let mut cursor = body.walk();
    for member in body.children(&mut cursor) {
        let def_node = if is_def(member.kind()) {
            Some(member)
        } else if member.kind() == "decorated_definition" {
            member.children(&mut member.walk()).find(|c| is_def(c.kind()))
        } else {
            None
        };
        let name = def_node
            .and_then(|d| d.child_by_field_name("name"))
            .and_then(|n| n.utf8_text(source).ok());
        if let Some(name) = name {
            methods.push(name.to_string());
        }
    }
    methods
}
/// Emit a [`Function`] entry for every method defined directly in each class.
///
/// Methods get qualified names of the form `path::Class.method:line`; the `.`
/// in that tail is what keeps methods from being tagged as exports. Decorated
/// methods carry their decorators as annotations. The async flag is set for
/// `async_function_definition` nodes or any method whose text starts with
/// `async`. Classes without a `body` field are skipped.
///
/// # Errors
/// Currently never fails; returns `Result` for uniformity with other passes.
fn extract_class_methods(
    class_nodes: &[(String, Node)],
    source: &[u8],
    path: &Path,
    result: &mut ParseResult,
) -> Result<()> {
    let is_def =
        |kind: &str| kind == "function_definition" || kind == "async_function_definition";
    for (class_name, class_node) in class_nodes {
        let Some(body) = class_node.child_by_field_name("body") else {
            continue;
        };
        let mut cursor = body.walk();
        for member in body.children(&mut cursor) {
            let method = if is_def(member.kind()) {
                member
            } else if member.kind() == "decorated_definition" {
                match member.children(&mut member.walk()).find(|c| is_def(c.kind())) {
                    Some(inner) => inner,
                    None => continue,
                }
            } else {
                continue;
            };
            let Some(name) = method
                .child_by_field_name("name")
                .and_then(|n| n.utf8_text(source).ok())
            else {
                continue;
            };
            let is_async = method.kind() == "async_function_definition"
                || method
                    .utf8_text(source)
                    .is_ok_and(|t| t.trim_start().starts_with("async"));
            let line_start = method.start_position().row as u32 + 1;
            let annotations = match method.parent() {
                Some(parent) if parent.kind() == "decorated_definition" => {
                    extract_decorators(&parent, source)
                }
                _ => Vec::new(),
            };
            result.functions.push(Function {
                name: name.to_string(),
                qualified_name: format!(
                    "{}::{}.{}:{}",
                    path.display(),
                    class_name,
                    name,
                    line_start
                ),
                file_path: path.to_path_buf(),
                line_start,
                line_end: method.end_position().row as u32 + 1,
                parameters: extract_parameters(method.child_by_field_name("parameters"), source),
                return_type: method
                    .child_by_field_name("return_type")
                    .and_then(|n| n.utf8_text(source).ok())
                    .map(str::to_string),
                is_async,
                complexity: Some(calculate_complexity(&method, source)),
                max_nesting: None,
                doc_comment: None,
                annotations,
            });
        }
    }
    Ok(())
}
/// Record the modules imported by top-level import statements as runtime
/// imports.
///
/// - `import a.b` and `import a.b as c` record `a.b`.
/// - `from m import x` records the statement's `module_name`. If that field
///   is absent, the first `dotted_name` or `relative_import` child is used.
///
/// Imports nested inside functions, classes or other blocks are not visited.
///
/// # Errors
/// Currently never fails; returns `Result` for uniformity with other passes.
fn extract_imports(root: &Node, source: &[u8], result: &mut ParseResult) -> Result<()> {
    let mut cursor = root.walk();
    for stmt in root.children(&mut cursor) {
        if stmt.kind() == "import_statement" {
            for item in stmt.children(&mut stmt.walk()) {
                let module_node = match item.kind() {
                    "dotted_name" => Some(item),
                    "aliased_import" => item.child_by_field_name("name"),
                    _ => None,
                };
                if let Some(text) = module_node.and_then(|n| n.utf8_text(source).ok()) {
                    result.imports.push(ImportInfo::runtime(text.to_string()));
                }
            }
        } else if stmt.kind() == "import_from_statement" {
            let module_node = match stmt.child_by_field_name("module_name") {
                Some(named) => Some(named),
                None => stmt
                    .children(&mut stmt.walk())
                    .find(|c| matches!(c.kind(), "dotted_name" | "relative_import")),
            };
            if let Some(text) = module_node.and_then(|n| n.utf8_text(source).ok()) {
                result.imports.push(ImportInfo::runtime(text.to_string()));
            }
        }
    }
    Ok(())
}
/// Record `(caller, callee)` pairs for every call expression in the tree.
///
/// Builds a map from each recorded function's `(line_start, line_end)` span to
/// its qualified name, then walks the whole tree. If two functions share a
/// span, the later one wins. Must run after functions and methods have been
/// collected.
///
/// # Errors
/// Currently never fails; returns `Result` for uniformity with other passes.
fn extract_calls(root: &Node, source: &[u8], path: &Path, result: &mut ParseResult) -> Result<()> {
    let scope_map: HashMap<(u32, u32), String> = result
        .functions
        .iter()
        .map(|f| ((f.line_start, f.line_end), f.qualified_name.clone()))
        .collect();
    extract_calls_recursive(root, source, path, &scope_map, result);
    Ok(())
}
fn extract_calls_recursive(
node: &Node,
source: &[u8],
path: &Path,
scope_map: &HashMap<(u32, u32), String>,
result: &mut ParseResult,
) {
if node.kind() == "call" {
let call_line = node.start_position().row as u32 + 1;
let caller = find_containing_scope(call_line, scope_map)
.unwrap_or_else(|| path.display().to_string());
if let Some(func_node) = node.child_by_field_name("function") {
if let Some(callee) = extract_call_target(&func_node, source) {
if !callee.starts_with("self.") || !caller.contains(&callee.replace("self.", "")) {
result.calls.push((caller, callee));
}
}
}
}
for child in node.children(&mut node.walk()) {
extract_calls_recursive(&child, source, path, scope_map, result);
}
}
/// Look up the scope for a 1-based `line` in `scope_map`.
///
/// In this file the map is built by `extract_calls`: keys are
/// `(line_start, line_end)` spans and values are qualified names.
/// Delegates to the parent module's shared `find_containing_scope`.
/// NOTE(review): the selection rule for nested or overlapping spans lives
/// there — confirm it picks the innermost scope.
fn find_containing_scope(line: u32, scope_map: &HashMap<(u32, u32), String>) -> Option<String> {
    super::find_containing_scope(line, scope_map)
}
/// Best-effort textual name of what a call expression invokes.
///
/// - `foo(...)` yields `foo`, and `obj.method(...)` yields the full dotted
///   text `obj.method`.
/// - `table[key](...)` resolves through the subscripted value.
/// - `make()(...)` resolves through the inner call's function (`make`).
///
/// Any other callee node kind yields `None`.
fn extract_call_target(node: &Node, source: &[u8]) -> Option<String> {
    let inner_field = match node.kind() {
        "identifier" | "attribute" => return node.utf8_text(source).ok().map(str::to_string),
        "subscript" => "value",
        "call" => "function",
        _ => return None,
    };
    let inner = node.child_by_field_name(inner_field)?;
    extract_call_target(&inner, source)
}
/// Estimate the cyclomatic complexity of `node`, normally a function
/// definition.
///
/// Starts at 1 and adds 1 for each decision point anywhere beneath the node,
/// nested functions included. Counted points:
/// - `if` / `elif`, `while`, `for`, `except`;
/// - boolean operators and conditional expressions;
/// - `with` and `assert` statements;
/// - each `case` arm of a `match`;
/// - each `if_clause` filter in list/set/dict comprehensions and generator
///   expressions, and as a `case` guard.
///
/// `_source` is unused; it is kept for signature compatibility.
fn calculate_complexity(node: &Node, _source: &[u8]) -> u32 {
    fn count_branches(node: &Node, complexity: &mut u32) {
        // `case_clause` and `if_clause` are counted where they occur rather
        // than as direct children of `match_statement` / comprehension nodes.
        // tree-sitter-python nests the `case_clause`s inside the match's body
        // `block`, so a direct-child scan of `match_statement` counted no arms.
        // Counting at the node itself also covers generator expressions.
        if matches!(
            node.kind(),
            "if_statement"
                | "elif_clause"
                | "while_statement"
                | "for_statement"
                | "except_clause"
                | "boolean_operator"
                | "conditional_expression"
                | "with_statement"
                | "assert_statement"
                | "case_clause"
                | "if_clause"
        ) {
            *complexity += 1;
        }
        for child in node.children(&mut node.walk()) {
            count_branches(&child, complexity);
        }
    }
    let mut complexity = 1;
    count_branches(node, &mut complexity);
    complexity
}
// Unit tests for the Python parser. Each test parses an inline Python snippet
// through `parse_source` with a dummy `test.py` path and inspects the
// resulting `ParseResult`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    // A single annotated top-level function: name, parameter list, async flag
    // and 1-based start line (line 1 is the empty line right after `r#"`).
    #[test]
    fn test_parse_simple_function() {
        let source = r#"
def hello(name: str) -> str:
"""Greet someone."""
return f"Hello, {name}!"
"#;
        let path = PathBuf::from("test.py");
        let result = parse_source(source, &path).expect("should parse simple function");
        assert_eq!(result.functions.len(), 1);
        let func = &result.functions[0];
        assert_eq!(func.name, "hello");
        assert_eq!(func.parameters, vec!["name"]);
        assert!(!func.is_async);
        assert_eq!(func.line_start, 2);
    }

    // `async def` is recorded exactly once and flagged as async.
    #[test]
    fn test_parse_async_function() {
        let source = r#"
async def fetch_data(url: str) -> bytes:
return await http.get(url)
"#;
        let path = PathBuf::from("test.py");
        let result = parse_source(source, &path).expect("should parse async function");
        assert_eq!(result.functions.len(), 1);
        let func = &result.functions[0];
        assert_eq!(func.name, "fetch_data");
        assert!(func.is_async);
    }

    // Class name, base classes and method names in source order.
    #[test]
    fn test_parse_class() {
        let source = r#"
class MyClass(BaseClass, Mixin):
def __init__(self):
pass
def method(self, x):
return x * 2
"#;
        let path = PathBuf::from("test.py");
        let result = parse_source(source, &path).expect("should parse class");
        assert_eq!(result.classes.len(), 1);
        let class = &result.classes[0];
        assert_eq!(class.name, "MyClass");
        assert_eq!(class.bases, vec!["BaseClass", "Mixin"]);
        assert_eq!(class.methods, vec!["__init__", "method"]);
    }

    // Both `import x` and `from x import y` record the module path.
    #[test]
    fn test_parse_imports() {
        let source = r#"
import os
import sys
from pathlib import Path
from typing import List, Optional
"#;
        let path = PathBuf::from("test.py");
        let result = parse_source(source, &path).expect("should parse imports");
        assert!(result.imports.iter().any(|i| i.path == "os"));
        assert!(result.imports.iter().any(|i| i.path == "sys"));
        assert!(result.imports.iter().any(|i| i.path == "pathlib"));
        assert!(result.imports.iter().any(|i| i.path == "typing"));
    }

    // Calls made inside a function show up as callee names in `calls`.
    #[test]
    fn test_parse_calls() {
        let source = r#"
def caller():
result = some_function()
other_function(result)
return result
def some_function():
return 42
def other_function(x):
print(x)
"#;
        let path = PathBuf::from("test.py");
        let result = parse_source(source, &path).expect("should parse calls");
        assert!(!result.calls.is_empty());
        let call_targets: Vec<&str> = result.calls.iter().map(|(_, t)| t.as_str()).collect();
        assert!(call_targets.contains(&"some_function"));
        assert!(call_targets.contains(&"other_function"));
    }

    // Nested if / elif branches raise complexity to at least 4
    // (base 1 + if + nested if + elif).
    #[test]
    fn test_complexity_calculation() {
        let source = r#"
def complex_function(x):
if x > 0:
if x > 10:
return "big"
else:
return "small positive"
elif x < 0:
return "negative"
else:
return "zero"
"#;
        let path = PathBuf::from("test.py");
        let result = parse_source(source, &path).expect("should parse complex function");
        let func = &result.functions[0];
        assert!(func.complexity.expect("should have complexity") >= 4);
    }

    // Decorated top-level functions are still extracted.
    #[test]
    fn test_parse_decorated_function() {
        let source = r#"
@decorator
def decorated():
pass
@property
def prop(self):
return self._value
"#;
        let path = PathBuf::from("test.py");
        let result = parse_source(source, &path).expect("should parse decorated function");
        assert_eq!(result.functions.len(), 2);
    }

    // Variadic parameters keep their `*` / `**` sigils.
    #[test]
    fn test_parse_star_args() {
        let source = r#"
def varargs(*args, **kwargs):
pass
"#;
        let path = PathBuf::from("test.py");
        let result = parse_source(source, &path).expect("should parse star args");
        let func = &result.functions[0];
        assert!(func.parameters.contains(&"*args".to_string()));
        assert!(func.parameters.contains(&"**kwargs".to_string()));
    }

    // Lambdas and functions nested inside a method body are not counted as
    // class methods.
    #[test]
    fn test_method_count_excludes_nested() {
        let source = r#"
class DataProcessor:
def __init__(self):
self.handlers = []
def process(self, items):
# These should NOT be counted as methods:
inner_helper = lambda x: x * 2
results = list(map(lambda item: item.strip(), items))
def local_transform(val):
return val.upper()
return [local_transform(r) for r in results]
def register(self, handler):
self.handlers.append(handler)
"#;
        let path = PathBuf::from("test.py");
        let result = parse_source(source, &path).expect("should parse nested methods");
        assert_eq!(result.classes.len(), 1);
        let class = &result.classes[0];
        assert_eq!(class.name, "DataProcessor");
        assert_eq!(
            class.methods.len(),
            3,
            "Expected 3 methods (__init__, process, register), got {:?}",
            class.methods
        );
        assert!(class.methods.contains(&"__init__".to_string()));
        assert!(class.methods.contains(&"process".to_string()));
        assert!(class.methods.contains(&"register".to_string()));
    }

    // Decorated methods count once each in the class's method list.
    #[test]
    fn test_decorated_methods_counted_correctly() {
        let source = r#"
class MyClass:
@property
def value(self):
return self._value
@staticmethod
def create():
return MyClass()
@classmethod
def from_string(cls, s):
return cls()
"#;
        let path = PathBuf::from("test.py");
        let result = parse_source(source, &path).expect("should parse decorated methods");
        let class = &result.classes[0];
        assert_eq!(
            class.methods.len(),
            3,
            "Expected 3 methods (value, create, from_string), got {:?}",
            class.methods
        );
    }

    // With `__all__` present, only the listed names get the "exported" tag.
    #[test]
    fn test_export_detection_python() {
        let code = r#"
__all__ = ['public_func', 'PublicClass']
def public_func():
pass
def _private_func():
pass
class PublicClass:
pass
"#;
        let path = PathBuf::from("test.py");
        let result = parse_source(code, &path).expect("should parse exports");
        let public = result
            .functions
            .iter()
            .find(|f| f.name == "public_func")
            .unwrap();
        assert!(
            public.annotations.iter().any(|a| a == "exported"),
            "public_func should be exported, annotations: {:?}",
            public.annotations
        );
        let private = result
            .functions
            .iter()
            .find(|f| f.name == "_private_func")
            .unwrap();
        assert!(
            !private.annotations.iter().any(|a| a == "exported"),
            "_private_func should NOT be exported"
        );
        let public_class = result
            .classes
            .iter()
            .find(|c| c.name == "PublicClass")
            .unwrap();
        assert!(
            public_class.annotations.iter().any(|a| a == "exported"),
            "PublicClass should be exported, annotations: {:?}",
            public_class.annotations
        );
    }

    // Without `__all__`, names not starting with `_` are exported.
    #[test]
    fn test_export_detection_python_no_all() {
        let code = r#"
def public_func():
pass
def _private_func():
pass
class PublicClass:
pass
class _PrivateClass:
pass
"#;
        let path = PathBuf::from("test.py");
        let result = parse_source(code, &path).expect("should parse exports without __all__");
        let public = result
            .functions
            .iter()
            .find(|f| f.name == "public_func")
            .unwrap();
        assert!(
            public.annotations.iter().any(|a| a == "exported"),
            "public_func should be exported (no __all__), annotations: {:?}",
            public.annotations
        );
        let private = result
            .functions
            .iter()
            .find(|f| f.name == "_private_func")
            .unwrap();
        assert!(
            !private.annotations.iter().any(|a| a == "exported"),
            "_private_func should NOT be exported"
        );
        let public_class = result
            .classes
            .iter()
            .find(|c| c.name == "PublicClass")
            .unwrap();
        assert!(
            public_class.annotations.iter().any(|a| a == "exported"),
            "PublicClass should be exported (no __all__), annotations: {:?}",
            public_class.annotations
        );
        let private_class = result
            .classes
            .iter()
            .find(|c| c.name == "_PrivateClass")
            .unwrap();
        assert!(
            !private_class.annotations.iter().any(|a| a == "exported"),
            "_PrivateClass should NOT be exported"
        );
    }

    // Decorator names (arguments stripped) become annotations on functions
    // and classes; an undecorated class carries only the export tag.
    #[test]
    fn test_decorator_extraction() {
        let code = r#"
@app.route('/users')
def get_users():
return []
@login_required
@cache(timeout=300)
def admin_page():
pass
class MyModel:
pass
@dataclass
class UserDTO:
name: str
"#;
        let path = PathBuf::from("test.py");
        let result = parse_source(code, &path).expect("should parse decorators");
        let get_users = result
            .functions
            .iter()
            .find(|f| f.name == "get_users")
            .unwrap();
        assert!(
            get_users
                .annotations
                .iter()
                .any(|a| a.contains("app.route")),
            "get_users should have app.route annotation, got: {:?}",
            get_users.annotations
        );
        let admin = result
            .functions
            .iter()
            .find(|f| f.name == "admin_page")
            .unwrap();
        assert!(
            admin.annotations.len() >= 2,
            "admin_page should have 2+ annotations, got: {:?}",
            admin.annotations
        );
        let my_model = result
            .classes
            .iter()
            .find(|c| c.name == "MyModel")
            .unwrap();
        assert!(
            my_model
                .annotations
                .iter()
                .all(|a| a == "exported"),
            "MyModel should only have 'exported' annotation (no decorators), got: {:?}",
            my_model.annotations
        );
        let user_dto = result
            .classes
            .iter()
            .find(|c| c.name == "UserDTO")
            .unwrap();
        assert!(
            user_dto
                .annotations
                .iter()
                .any(|a| a.contains("dataclass")),
            "UserDTO should have dataclass annotation, got: {:?}",
            user_dto.annotations
        );
    }

    // Class methods also appear in `functions`, with `Class.method` qualified
    // names, parameters and the async flag.
    #[test]
    fn test_class_methods_as_function_entries() {
        let source = r#"
class MyClass:
def __init__(self, value):
self._value = value
def process(self, data):
return data * self._value
async def fetch(self, url):
return await http.get(url)
"#;
        let path = PathBuf::from("test.py");
        let result = parse_source(source, &path).expect("should parse class methods");
        let method_names: Vec<&str> = result.functions.iter().map(|f| f.name.as_str()).collect();
        assert!(
            method_names.contains(&"__init__"),
            "missing __init__, got: {:?}",
            method_names
        );
        assert!(
            method_names.contains(&"process"),
            "missing process, got: {:?}",
            method_names
        );
        assert!(
            method_names.contains(&"fetch"),
            "missing fetch, got: {:?}",
            method_names
        );
        let init = result
            .functions
            .iter()
            .find(|f| f.name == "__init__")
            .unwrap();
        assert!(
            init.qualified_name.contains("MyClass.__init__"),
            "expected ClassName.method format, got: {}",
            init.qualified_name
        );
        let fetch = result
            .functions
            .iter()
            .find(|f| f.name == "fetch")
            .unwrap();
        assert!(fetch.is_async, "fetch should be async");
        assert!(init.parameters.contains(&"self".to_string()));
        assert!(init.parameters.contains(&"value".to_string()));
    }

    // Decorated methods keep their decorator annotations, and a decorated
    // async method is still flagged async.
    #[test]
    fn test_decorated_class_methods_as_function_entries() {
        let source = r#"
class MyView:
@property
def value(self):
return self._value
@staticmethod
def create():
return MyView()
@app.route('/api')
async def handle(self):
pass
"#;
        let path = PathBuf::from("test.py");
        let result = parse_source(source, &path).expect("should parse decorated methods");
        let value = result
            .functions
            .iter()
            .find(|f| f.name == "value")
            .unwrap();
        assert!(
            value.annotations.iter().any(|a| a == "property"),
            "value should have @property, got: {:?}",
            value.annotations
        );
        let create = result
            .functions
            .iter()
            .find(|f| f.name == "create")
            .unwrap();
        assert!(
            create.annotations.iter().any(|a| a == "staticmethod"),
            "create should have @staticmethod, got: {:?}",
            create.annotations
        );
        let handle = result
            .functions
            .iter()
            .find(|f| f.name == "handle")
            .unwrap();
        assert!(handle.is_async, "handle should be async");
        assert!(
            handle.annotations.iter().any(|a| a.contains("app.route")),
            "handle should have @app.route, got: {:?}",
            handle.annotations
        );
    }

    // Exporting a class via `__all__` must not tag its methods as exported.
    #[test]
    fn test_class_methods_not_individually_exported() {
        let source = r#"
__all__ = ['MyClass']
class MyClass:
def process(self):
pass
def _internal(self):
pass
"#;
        let path = PathBuf::from("test.py");
        let result = parse_source(source, &path).expect("should parse");
        let process = result
            .functions
            .iter()
            .find(|f| f.name == "process")
            .unwrap();
        assert!(
            !process.annotations.iter().any(|a| a == "exported"),
            "class method 'process' should not be individually exported, got: {:?}",
            process.annotations
        );
    }

    // Method entries carry a complexity value computed like top-level
    // functions.
    #[test]
    fn test_class_methods_have_complexity() {
        let source = r#"
class Handler:
def handle(self, request):
if request.method == "GET":
if request.user:
return self.get(request)
else:
return self.unauthorized()
elif request.method == "POST":
return self.post(request)
return self.not_found()
"#;
        let path = PathBuf::from("test.py");
        let result = parse_source(source, &path).expect("should parse");
        let handle = result
            .functions
            .iter()
            .find(|f| f.name == "handle")
            .unwrap();
        let complexity = handle.complexity.expect("method should have complexity");
        assert!(
            complexity >= 4,
            "handle should have complexity >= 4, got: {}",
            complexity
        );
    }
}