#![cfg_attr(coverage_nightly, coverage(off))]
#[cfg(feature = "python-ast")]
use crate::services::context::AstItem;
#[cfg(feature = "python-ast")]
use std::path::{Path, PathBuf};
#[cfg(feature = "python-ast")]
use tree_sitter::{Node, Tree};
/// Collects Python function and class definitions from a tree-sitter syntax
/// tree as [`AstItem`]s whose names are qualified with `::`.
#[cfg(feature = "python-ast")]
pub struct EnhancedPythonVisitor {
    /// Text of the file being visited; definition names are sliced out of it
    /// by node byte range.
    source: String,
    /// Path of the file being visited; stored at construction but not read by
    /// any code visible here.
    _file_path: PathBuf,
    /// Leading segments for qualified names. Starts empty and is not pushed to
    /// by the visitor methods visible here.
    module_path: Vec<String>,
    /// Names of the classes enclosing the definition currently being visited,
    /// outermost first.
    class_stack: Vec<String>,
    /// Definitions collected so far, in visitation order.
    items: Vec<AstItem>,
}
#[cfg(feature = "python-ast")]
impl EnhancedPythonVisitor {
    /// Creates a visitor for the Python file at `file_path` whose text is `source`.
    ///
    /// `source` should be the exact text that the tree later passed to
    /// [`extract_items`](Self::extract_items) was parsed from. Definition names
    /// are recovered by slicing it with node byte ranges. A name whose range
    /// does not fit `source` is skipped rather than causing a panic.
    #[must_use]
    pub fn new(file_path: &Path, source: &str) -> Self {
        Self {
            items: Vec::new(),
            _file_path: file_path.to_path_buf(),
            module_path: Vec::new(),
            class_stack: Vec::new(),
            source: source.to_string(),
        }
    }

    /// Walks `tree` and returns the function/method and class definitions it
    /// contains, in source order. A class is emitted before its methods.
    ///
    /// Functions become [`AstItem::Function`] and classes become
    /// [`AstItem::Struct`]. Names are prefixed with their enclosing classes,
    /// joined by `::`. Nested functions are not prefixed with the enclosing
    /// function's name. Every item is reported with visibility `"public"`.
    #[must_use]
    pub fn extract_items(mut self, tree: &Tree) -> Vec<AstItem> {
        let root = tree.root_node();
        self.visit_node(&root);
        self.items
    }

    /// Joins the module path, the enclosing class names and `name` with `::`.
    fn get_qualified_name(&self, name: &str) -> String {
        self.module_path
            .iter()
            .chain(&self.class_stack)
            .map(String::as_str)
            .chain(std::iter::once(name))
            .collect::<Vec<_>>()
            .join("::")
    }

    /// Converts tree-sitter's 0-based start row of `node` into a 1-based line number.
    fn get_line(&self, node: &Node) -> usize {
        node.start_position().row + 1
    }

    /// Dispatches on the node kind. Definitions get dedicated handling; every
    /// other node is searched recursively for definitions.
    fn visit_node(&mut self, node: &Node) {
        match node.kind() {
            "function_definition" => self.visit_function_def(node),
            "class_definition" => self.visit_class_def(node),
            _ => self.visit_children(node),
        }
    }

    /// Visits every direct child of `node` (named and anonymous) in order.
    fn visit_children(&mut self, node: &Node) {
        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            self.visit_node(&child);
        }
    }

    /// Reports whether a `function_definition` node was written as `async def`.
    ///
    /// The tree-sitter-python grammar places the optional `async` keyword as the
    /// first (anonymous) child of `function_definition`.
    fn is_async_def(node: &Node) -> bool {
        node.child(0).is_some_and(|first| first.kind() == "async")
    }

    /// Reports whether a class-body child defines a method. This is either a
    /// plain `function_definition` or a `decorated_definition` (e.g.
    /// `@property`) wrapping one.
    fn is_method_node(node: &Node) -> bool {
        match node.kind() {
            "function_definition" => true,
            "decorated_definition" => node
                .child_by_field_name("definition")
                .is_some_and(|def| def.kind() == "function_definition"),
            _ => false,
        }
    }

    /// Records a function/method definition, then searches its body for nested
    /// definitions.
    ///
    /// The body is searched even when the name cannot be recovered.
    fn visit_function_def(&mut self, node: &Node) {
        // Borrow only the `source` field so `self.items` can still be pushed to.
        let source = self.source.as_str();
        let name = node
            .child_by_field_name("name")
            .and_then(|name_node| source.get(name_node.byte_range()));
        if let Some(name) = name {
            let qualified_name = self.get_qualified_name(name);
            let line = self.get_line(node);
            let is_async = Self::is_async_def(node);
            self.items.push(AstItem::Function {
                name: qualified_name,
                visibility: "public".to_string(),
                is_async,
                line,
            });
        }
        if let Some(body) = node.child_by_field_name("body") {
            self.visit_children(&body);
        }
    }

    /// Records a class definition, then visits its body with the class name
    /// pushed onto the qualification stack.
    ///
    /// `fields_count` is the number of methods directly in the class body,
    /// including decorated ones. A class whose name cannot be recovered is
    /// skipped together with its body.
    fn visit_class_def(&mut self, node: &Node) {
        // Borrow only the `source` field so `self.items` / `self.class_stack`
        // can still be mutated while `name` is alive.
        let source = self.source.as_str();
        let Some(name) = node
            .child_by_field_name("name")
            .and_then(|name_node| source.get(name_node.byte_range()))
        else {
            return;
        };
        let body = node.child_by_field_name("body");

        let mut fields_count = 0;
        if let Some(body) = body {
            let mut cursor = body.walk();
            for child in body.children(&mut cursor) {
                if Self::is_method_node(&child) {
                    fields_count += 1;
                }
            }
        }

        let qualified_name = self.get_qualified_name(name);
        let line = self.get_line(node);
        self.items.push(AstItem::Struct {
            name: qualified_name,
            visibility: "public".to_string(),
            fields_count,
            derives: vec![],
            line,
        });

        self.class_stack.push(name.to_string());
        if let Some(body) = body {
            self.visit_children(&body);
        }
        self.class_stack.pop();
    }
}

#[cfg(all(test, feature = "python-ast"))]
mod async_and_method_count_tests {
    use super::*;
    use tree_sitter::Parser as TsParser;

    /// Parses `code` with the tree-sitter Python grammar.
    fn parse(code: &str) -> Tree {
        let mut parser = TsParser::new();
        parser
            .set_language(&tree_sitter_python::LANGUAGE.into())
            .expect("Failed to set Python language");
        parser
            .parse(code, None)
            .expect("Failed to parse Python code")
    }

    /// Parses `code` and runs a fresh visitor over it.
    fn extract(code: &str) -> Vec<AstItem> {
        let tree = parse(code);
        EnhancedPythonVisitor::new(Path::new("test.py"), code).extract_items(&tree)
    }

    #[test]
    fn async_flag_follows_async_keyword() {
        let code = "def top():\n    pass\n\nasync def top_async():\n    pass\n\nclass C:\n    def m(self):\n        pass\n\n    async def am(self):\n        pass\n";
        let flags: Vec<(String, bool)> = extract(code)
            .into_iter()
            .filter_map(|item| match item {
                AstItem::Function { name, is_async, .. } => Some((name, is_async)),
                _ => None,
            })
            .collect();
        assert_eq!(
            flags,
            vec![
                ("top".to_string(), false),
                ("top_async".to_string(), true),
                ("C::m".to_string(), false),
                ("C::am".to_string(), true),
            ]
        );
    }

    #[test]
    fn decorated_methods_count_toward_fields() {
        let code = "class Point:\n    @property\n    def x(self):\n        return 1\n\n    @staticmethod\n    def origin():\n        pass\n\n    def plain(self):\n        pass\n";
        let items = extract(code);
        assert_eq!(items.len(), 4);
        let AstItem::Struct {
            name, fields_count, ..
        } = &items[0]
        else {
            panic!("Expected class item");
        };
        assert_eq!(name, "Point");
        assert_eq!(*fields_count, 3);
    }

    #[test]
    fn mismatched_source_does_not_panic() {
        let code = "def a_rather_long_function_name():\n    pass\n";
        let tree = parse(code);
        let items = EnhancedPythonVisitor::new(Path::new("test.py"), "def f").extract_items(&tree);
        assert!(items.is_empty());
    }
}
#[cfg(all(test, feature = "python-ast"))]
mod tests {
    use super::*;
    use std::path::Path;
    use tree_sitter::Parser as TsParser;

    /// Parses `code` with the tree-sitter Python grammar, panicking on failure.
    fn parse_python(code: &str) -> Tree {
        let mut ts_parser = TsParser::new();
        ts_parser
            .set_language(&tree_sitter_python::LANGUAGE.into())
            .expect("Failed to set Python language");
        ts_parser
            .parse(code, None)
            .expect("Failed to parse Python code")
    }

    /// Parses `code`, then runs a fresh visitor (pretending the file is
    /// `test.py`) over the resulting tree.
    fn collect_items(code: &str) -> Vec<AstItem> {
        let tree = parse_python(code);
        EnhancedPythonVisitor::new(Path::new("test.py"), code).extract_items(&tree)
    }

    #[test]
    fn test_simple_function() {
        let items = collect_items(
            r#"
def hello_world():
    print("Hello, World!")
"#,
        );
        assert_eq!(items.len(), 1);
        let AstItem::Function { name, .. } = &items[0] else {
            panic!("Expected function item");
        };
        assert_eq!(name, "hello_world");
    }

    #[test]
    fn test_async_function() {
        let items = collect_items(
            r#"
async def async_hello():
    await some_task()
"#,
        );
        assert_eq!(items.len(), 1);
        let AstItem::Function { name, .. } = &items[0] else {
            panic!("Expected async function item");
        };
        assert_eq!(name, "async_hello");
    }

    #[test]
    fn test_class_with_methods() {
        let items = collect_items(
            r#"
class Calculator:
    def add(self, a, b):
        return a + b
    async def multiply_async(self, a, b):
        return a * b
"#,
        );
        assert_eq!(items.len(), 3);

        match &items[0] {
            AstItem::Struct {
                name, fields_count, ..
            } => {
                assert_eq!(name, "Calculator");
                assert_eq!(*fields_count, 2);
            }
            _ => panic!("Expected class item"),
        }

        let AstItem::Function { name: add_name, .. } = &items[1] else {
            panic!("Expected method item");
        };
        assert_eq!(add_name, "Calculator::add");

        let AstItem::Function {
            name: multiply_name,
            ..
        } = &items[2]
        else {
            panic!("Expected async method item");
        };
        assert_eq!(multiply_name, "Calculator::multiply_async");
    }

    #[test]
    fn test_nested_functions() {
        let items = collect_items(
            r#"
def outer_function():
    def inner_function():
        pass
    inner_function()
"#,
        );
        assert_eq!(items.len(), 2);

        let AstItem::Function { name: outer, .. } = &items[0] else {
            panic!("Expected outer function");
        };
        assert_eq!(outer, "outer_function");

        let AstItem::Function { name: inner, .. } = &items[1] else {
            panic!("Expected inner function");
        };
        assert_eq!(inner, "inner_function");
    }

    #[test]
    fn test_complex_qualified_names() {
        let items = collect_items(
            r#"
class Database:
    class Connection:
        def connect(self):
            pass
        async def disconnect(self):
            pass
"#,
        );
        assert_eq!(items.len(), 4);

        let names: Vec<String> = items
            .iter()
            .map(|item| match item {
                AstItem::Function { name, .. } | AstItem::Struct { name, .. } => name.clone(),
                _ => "unknown".to_string(),
            })
            .collect();
        assert!(names.contains(&"Database".to_string()));
        assert!(names.contains(&"Database::Connection".to_string()));
        assert!(names.contains(&"Database::Connection::connect".to_string()));
        assert!(names.contains(&"Database::Connection::disconnect".to_string()));
    }
}
#[cfg(all(test, feature = "python-ast"))]
mod property_tests {
    use super::*;
    use proptest::prelude::*;
    use tree_sitter::Parser as TsParser;

    /// Python 3 reserved words. The identifier regex used below can generate
    /// these, but `class if:` or `def None(self):` is not valid Python. Such
    /// cases are discarded so that the property exercises the visitor rather
    /// than tree-sitter's error recovery.
    const PYTHON_KEYWORDS: &[&str] = &[
        "False", "None", "True", "and", "as", "assert", "async", "await", "break", "class",
        "continue", "def", "del", "elif", "else", "except", "finally", "for", "from", "global",
        "if", "import", "in", "is", "lambda", "nonlocal", "not", "or", "pass", "raise", "return",
        "try", "while", "with", "yield",
    ];

    /// Returns `true` when `ident` is a reserved word and so cannot name a
    /// class or function.
    fn is_python_keyword(ident: &str) -> bool {
        PYTHON_KEYWORDS.contains(&ident)
    }

    /// Parses `code` as Python. Returns `None` if the grammar cannot be loaded
    /// or the parse is aborted.
    fn try_parse_python(code: &str) -> Option<Tree> {
        let mut parser = TsParser::new();
        parser
            .set_language(&tree_sitter_python::LANGUAGE.into())
            .ok()?;
        parser.parse(code, None)
    }

    proptest! {
        // A class containing one method always yields at least two items, and
        // the generated names (not placeholder-looking ones) are reported.
        #[test]
        fn test_visitor_handles_any_valid_python(
            func_name in "[a-zA-Z_][a-zA-Z0-9_]*",
            class_name in "[a-zA-Z_][a-zA-Z0-9_]*"
        ) {
            prop_assume!(!is_python_keyword(&func_name));
            prop_assume!(!is_python_keyword(&class_name));
            let code = format!(r#"
class {}:
    def {}(self):
        pass
"#, class_name, func_name);
            if let Some(tree) = try_parse_python(&code) {
                let visitor = EnhancedPythonVisitor::new(Path::new("test.py"), &code);
                let items = visitor.extract_items(&tree);
                prop_assert!(items.len() >= 2);
                let has_real_names = items.iter().any(|item| match item {
                    AstItem::Function { name, .. } => !name.starts_with("function_"),
                    AstItem::Struct { name, .. } => !name.starts_with("class_"),
                    _ => true,
                });
                prop_assert!(has_real_names);
            }
        }

        // N one-line top-level functions yield exactly N function items, in order.
        #[test]
        fn test_visitor_complexity_bounds(
            function_count in 1usize..10,
        ) {
            let mut code = String::new();
            for i in 0..function_count {
                code.push_str(&format!("def function_{}(): pass\n", i));
            }
            if let Some(tree) = try_parse_python(&code) {
                let visitor = EnhancedPythonVisitor::new(Path::new("test.py"), &code);
                let items = visitor.extract_items(&tree);
                prop_assert_eq!(items.len(), function_count);
                for (i, item) in items.iter().enumerate() {
                    if let AstItem::Function { name, .. } = item {
                        prop_assert_eq!(name, &format!("function_{}", i));
                    } else {
                        prop_assert!(false, "Expected function item");
                    }
                }
            }
        }
    }
}