use crate::config::ParserConfig;
use crate::visitor::{extract_decorators, extract_docstring};
use codegraph_parser_api::{
CallRelation, ClassEntity, CodeIR, ComplexityBuilder, ComplexityMetrics, FunctionEntity,
ImportRelation, InheritanceRelation, ModuleEntity, Parameter,
};
use std::path::Path;
use tree_sitter::{Node, Parser};
pub fn extract(source: &str, file_path: &Path, config: &ParserConfig) -> Result<CodeIR, String> {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_python::language())
.map_err(|e| format!("Failed to set language: {e}"))?;
let tree = parser
.parse(source, None)
.ok_or_else(|| "Failed to parse".to_string())?;
let root_node = tree.root_node();
if root_node.has_error() {
let mut cursor = root_node.walk();
for child in root_node.children(&mut cursor) {
if child.is_error() || child.has_error() {
return Err(format!(
"Syntax error at line {}, column {}: {}",
child.start_position().row + 1,
child.start_position().column,
file_path.display()
));
}
}
return Err(format!("Syntax error in {}", file_path.display()));
}
let source_bytes = source.as_bytes();
let mut ir = CodeIR::new(file_path.to_path_buf());
let module_name = file_path
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("module")
.to_string();
let line_count = source.lines().count();
let module = ModuleEntity::new(
module_name.clone(),
file_path.display().to_string(),
"python",
)
.with_line_count(line_count);
ir.set_module(module);
let mut cursor = root_node.walk();
for child in root_node.children(&mut cursor) {
match child.kind() {
"function_definition" => {
if let Some(func) = extract_function(source_bytes, child, config, None) {
let calls =
extract_calls_from_node(source_bytes, child, &func.name, func.line_start);
for call in calls {
ir.add_call(call);
}
ir.add_function(func);
}
}
"decorated_definition" => {
if let Some(definition) = find_definition_in_decorated(child) {
match definition.kind() {
"function_definition" => {
if let Some(func) =
extract_function(source_bytes, definition, config, None)
{
let calls = extract_calls_from_node(
source_bytes,
definition,
&func.name,
func.line_start,
);
for call in calls {
ir.add_call(call);
}
ir.add_function(func);
}
}
"class_definition" => {
if let Some((class, methods, calls, inheritance)) =
extract_class(source_bytes, definition, config)
{
for method in methods {
ir.add_function(method);
}
for call in calls {
ir.add_call(call);
}
for inh in inheritance {
ir.add_inheritance(inh);
}
ir.add_class(class);
}
}
_ => {}
}
}
}
"class_definition" => {
if let Some((class, methods, calls, inheritance)) =
extract_class(source_bytes, child, config)
{
for method in methods {
ir.add_function(method);
}
for call in calls {
ir.add_call(call);
}
for inh in inheritance {
ir.add_inheritance(inh);
}
ir.add_class(class);
}
}
"import_statement" => {
let imports = extract_import(source_bytes, child, &module_name);
for import in imports {
ir.add_import(import);
}
}
"import_from_statement" => {
let imports = extract_import_from(source_bytes, child, &module_name);
for import in imports {
ir.add_import(import);
}
}
_ => {}
}
}
Ok(ir)
}
/// Returns the first direct child of a `decorated_definition` node that is a
/// `function_definition` or `class_definition`, or `None` if there is none.
fn find_definition_in_decorated(node: Node) -> Option<Node> {
    let mut walker = node.walk();
    if !walker.goto_first_child() {
        return None;
    }
    loop {
        let candidate = walker.node();
        if matches!(candidate.kind(), "function_definition" | "class_definition") {
            return Some(candidate);
        }
        if !walker.goto_next_sibling() {
            return None;
        }
    }
}
/// Builds a [`FunctionEntity`] from a `function_definition` node.
///
/// If `parent_class` is provided, it is recorded as the entity's `parent_class`
/// (the function is a method).
///
/// Returns `None` in three cases:
/// - the node has no `name` field;
/// - the name is private (a leading `_`, excluding dunders such as `__init__`) and
///   `config.include_private` is false;
/// - the name starts with `test_`/`Test` and `config.include_tests` is false.
fn extract_function(
    source: &[u8],
    node: Node,
    config: &ParserConfig,
    parent_class: Option<&str>,
) -> Option<FunctionEntity> {
    let name = node
        .child_by_field_name("name")
        .map(|n| n.utf8_text(source).unwrap_or("unknown").to_string())?;

    let is_dunder = name.starts_with("__") && name.ends_with("__") && name.len() > 4;
    if !config.include_private && name.starts_with('_') && !is_dunder {
        return None;
    }
    if !config.include_tests && (name.starts_with("test_") || name.starts_with("Test")) {
        return None;
    }

    let line_start = node.start_position().row + 1;
    let line_end = node.end_position().row + 1;
    // Only the `async` keyword makes a function async. Being wrapped in a
    // `decorated_definition` says nothing about it (e.g. `@property def x(self)`).
    let is_async = has_async_keyword(source, node);
    let parameters = extract_parameters(source, node);
    let return_type = node
        .child_by_field_name("return_type")
        .map(|n| n.utf8_text(source).unwrap_or("").to_string());
    let body = node.child_by_field_name("body");
    let doc_comment = body.and_then(|body| extract_docstring(source, body));
    // Decorators live on the enclosing `decorated_definition`, not on the function node.
    let decorators = node
        .parent()
        .filter(|parent| parent.kind() == "decorated_definition")
        .map(|parent| extract_decorators(source, parent))
        .unwrap_or_default();
    let is_static = decorators.iter().any(|d| d.contains("staticmethod"));
    // Substring match, so `pytest...` decorators are covered as well.
    let is_test = decorators.iter().any(|d| d.contains("test"));
    let complexity = body.map(|body| calculate_complexity_from_node(source, body));

    let mut func = FunctionEntity::new(&name, line_start, line_end);
    func.parameters = parameters;
    func.return_type = return_type;
    func.doc_comment = doc_comment;
    func.is_async = is_async;
    func.is_static = is_static;
    func.is_test = is_test;
    func.complexity = complexity;
    if let Some(class_name) = parent_class {
        func.parent_class = Some(class_name.to_string());
    }
    Some(func)
}
/// Reports whether the first child token of `node` has the source text `async`
/// (i.e. `async def ...`).
fn has_async_keyword(source: &[u8], node: Node) -> bool {
    matches!(
        node.child(0).map(|first| first.utf8_text(source)),
        Some(Ok("async"))
    )
}
/// Extracts the parameters declared in the `parameters` field of a `function_definition`.
///
/// The following forms are handled:
/// - plain identifiers (`x`);
/// - typed parameters (`x: int`);
/// - defaulted parameters (`x=1`);
/// - typed defaulted parameters (`x: int = 1`);
/// - splats (`*args`, `**kwargs`).
///
/// Annotated splats (`*args: int`) parse as a `typed_parameter` whose first child is
/// the splat pattern; they are reported with the bare name and `is_variadic = true`.
/// Any other child (punctuation, separators) is skipped. Returns an empty list when
/// the node has no `parameters` field.
fn extract_parameters(source: &[u8], node: Node) -> Vec<Parameter> {
    let params_node = match node.child_by_field_name("parameters") {
        Some(params_node) => params_node,
        None => return Vec::new(),
    };
    let mut params = Vec::new();
    let mut cursor = params_node.walk();
    for child in params_node.children(&mut cursor) {
        let param = match child.kind() {
            "identifier" => Parameter {
                name: node_text_or(source, child, "unknown"),
                type_annotation: None,
                default_value: None,
                is_variadic: false,
            },
            "typed_parameter" | "default_parameter" | "typed_default_parameter" => {
                // `typed_parameter` has no `name` field, so fall back to its first child.
                let target = child
                    .child_by_field_name("name")
                    .or_else(|| child.child(0));
                let splat = target.filter(|n| {
                    matches!(n.kind(), "list_splat_pattern" | "dictionary_splat_pattern")
                });
                let name = match (splat, target) {
                    (Some(splat), _) => splat_parameter_name(source, splat),
                    (None, Some(target)) => node_text_or(source, target, "unknown"),
                    (None, None) => "unknown".to_string(),
                };
                Parameter {
                    name,
                    type_annotation: child
                        .child_by_field_name("type")
                        .map(|n| node_text_or(source, n, "")),
                    default_value: child
                        .child_by_field_name("value")
                        .map(|n| node_text_or(source, n, "")),
                    is_variadic: splat.is_some(),
                }
            }
            "list_splat_pattern" | "dictionary_splat_pattern" => Parameter {
                name: splat_parameter_name(source, child),
                type_annotation: None,
                default_value: None,
                is_variadic: true,
            },
            _ => continue,
        };
        params.push(param);
    }
    params
}

/// Returns the source text of `node`, or `fallback` if it is not valid UTF-8.
fn node_text_or(source: &[u8], node: Node, fallback: &str) -> String {
    node.utf8_text(source).unwrap_or(fallback).to_string()
}

/// Returns the identifier after the `*`/`**` of a splat pattern.
///
/// Returns `"args"` when the identifier child is absent, and `"unknown"` when its
/// text is not valid UTF-8.
fn splat_parameter_name(source: &[u8], splat: Node) -> String {
    splat
        .child(1)
        .map(|n| node_text_or(source, n, "unknown"))
        .unwrap_or_else(|| "args".to_string())
}
/// Extracts a `class_definition` node.
///
/// Returns a tuple of four items:
/// - the [`ClassEntity`], whose `methods` field holds a copy of the methods;
/// - the methods themselves, both plain and decorated;
/// - every call made inside those methods, attributed to `Class.method`;
/// - one [`InheritanceRelation`] per recognised base class.
///
/// Returns `None` when the node has no `name` field.
fn extract_class(
    source: &[u8],
    node: Node,
    config: &ParserConfig,
) -> Option<(
    ClassEntity,
    Vec<FunctionEntity>,
    Vec<CallRelation>,
    Vec<InheritanceRelation>,
)> {
    let class_name = node
        .child_by_field_name("name")?
        .utf8_text(source)
        .unwrap_or("Class")
        .to_string();
    let line_start = node.start_position().row + 1;
    let line_end = node.end_position().row + 1;

    let mut inheritance = Vec::new();
    if let Some(bases) = node.child_by_field_name("superclasses") {
        let mut cursor = bases.walk();
        inheritance.extend(
            bases
                .children(&mut cursor)
                .filter_map(|base| extract_base_class_name(source, base))
                .map(|parent| InheritanceRelation::new(&class_name, parent)),
        );
    }

    let body = node.child_by_field_name("body");
    let doc_comment = body.and_then(|block| extract_docstring(source, block));

    let mut methods = Vec::new();
    let mut calls = Vec::new();
    if let Some(block) = body {
        let mut cursor = block.walk();
        for member in block.children(&mut cursor) {
            // Decorated methods (@staticmethod, @property, ...) wrap the real definition;
            // nested classes and other statements are not methods.
            let method_node = match member.kind() {
                "function_definition" => Some(member),
                "decorated_definition" => find_definition_in_decorated(member)
                    .filter(|inner| inner.kind() == "function_definition"),
                _ => None,
            };
            if let Some(method_node) = method_node {
                if let Some(method) =
                    extract_function(source, method_node, config, Some(class_name.as_str()))
                {
                    let qualified_caller = format!("{}.{}", class_name, method.name);
                    calls.extend(extract_calls_from_node(
                        source,
                        method_node,
                        &qualified_caller,
                        method.line_start,
                    ));
                    methods.push(method);
                }
            }
        }
    }

    let mut class = ClassEntity::new(&class_name, line_start, line_end);
    class.doc_comment = doc_comment;
    class.methods = methods.clone();
    Some((class, methods, calls, inheritance))
}
/// Returns the name of a base class listed in a class's `superclasses` argument list.
///
/// Plain names (`Base`) and dotted names (`abc.ABC`) are returned verbatim.
/// Subscripted generics (`Generic[T]`, `typing.Mapping[str, int]`) yield the name
/// of the subscripted value. Everything else yields `None`, including punctuation,
/// keyword arguments such as `metaclass=...`, and call expressions.
fn extract_base_class_name(source: &[u8], node: Node) -> Option<String> {
    match node.kind() {
        "identifier" | "attribute" => Some(node.utf8_text(source).unwrap_or("").to_string()),
        "subscript" => node
            .child_by_field_name("value")
            .and_then(|value| extract_base_class_name(source, value)),
        _ => None,
    }
}
/// Collects every call expression in `node`'s subtree, attributing each one to
/// `caller_name`. `line_offset` is forwarded to the walker, which does not use it;
/// call lines come from the call node's own position.
fn extract_calls_from_node(
    source: &[u8],
    node: Node,
    caller_name: &str,
    line_offset: usize,
) -> Vec<CallRelation> {
    let mut found = Vec::new();
    extract_calls_recursive(source, node, caller_name, line_offset, &mut found);
    found
}
/// Appends a [`CallRelation`] to `calls` for each `call` node in `node`'s subtree
/// whose callee resolves to a non-empty name.
///
/// Nodes are visited in pre-order, left to right. An explicit stack is used instead
/// of recursion, so very deep trees cannot overflow the call stack. Each relation's
/// line is the 1-based row of the call node.
fn extract_calls_recursive(
    source: &[u8],
    node: Node,
    caller_name: &str,
    _line_offset: usize,
    calls: &mut Vec<CallRelation>,
) {
    let mut pending = vec![node];
    while let Some(current) = pending.pop() {
        if current.kind() == "call" {
            let callee_name = current
                .child_by_field_name("function")
                .map(|callee| extract_callee_name(source, callee))
                .unwrap_or_default();
            if !callee_name.is_empty() {
                let call_line = current.start_position().row + 1;
                calls.push(CallRelation::new(caller_name, &callee_name, call_line));
            }
        }
        // Push children in reverse so the leftmost child is popped (visited) first.
        let mut cursor = current.walk();
        let children: Vec<_> = current.children(&mut cursor).collect();
        pending.extend(children.into_iter().rev());
    }
}
/// Returns the source text of a call's `function` node.
///
/// Only plain names (`foo`) and attribute chains (`obj.method`) are handled. Any
/// other callee shape, or text that is not valid UTF-8, yields an empty string.
fn extract_callee_name(source: &[u8], node: Node) -> String {
    if matches!(node.kind(), "identifier" | "attribute") {
        node.utf8_text(source).map(str::to_string).unwrap_or_default()
    } else {
        String::new()
    }
}
/// Converts an `import a.b, c as d` statement into one [`ImportRelation`] per module.
///
/// Aliased imports record their alias. An aliased import without a `name` field
/// is skipped.
fn extract_import(source: &[u8], node: Node, importer: &str) -> Vec<ImportRelation> {
    let mut relations = Vec::new();
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        let (module_node, alias_node) = match child.kind() {
            "dotted_name" => (Some(child), None),
            "aliased_import" => (
                child.child_by_field_name("name"),
                child.child_by_field_name("alias"),
            ),
            _ => continue,
        };
        let alias = alias_node.map(|n| n.utf8_text(source).unwrap_or("").to_string());
        if let Some(module_node) = module_node {
            let module = module_node.utf8_text(source).unwrap_or("").to_string();
            let mut relation = ImportRelation::new(importer, &module);
            if let Some(alias) = alias {
                relation = relation.with_alias(&alias);
            }
            relations.push(relation);
        }
    }
    relations
}
/// Converts a `from <module> import ...` statement into a single [`ImportRelation`].
///
/// The module is the text of the `module_name` field, or `"."` when absent. The
/// relation is marked wildcard for `import *`. Otherwise it carries the imported
/// names; for `x as y` it records `x`. If no names were found, it carries no symbols.
fn extract_import_from(source: &[u8], node: Node, importer: &str) -> Vec<ImportRelation> {
    let module_node = node.child_by_field_name("module_name");
    let from_module = module_node
        .map(|n| n.utf8_text(source).unwrap_or(".").to_string())
        .unwrap_or_else(|| ".".to_string());
    // Imported names start after the module path; names up to its end are the path itself.
    let module_end = module_node.map_or(0, |n| n.end_byte());

    let mut is_wildcard = false;
    let mut symbols = Vec::new();
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        match child.kind() {
            "wildcard_import" => is_wildcard = true,
            "dotted_name" | "identifier" if child.start_byte() > module_end => {
                symbols.push(child.utf8_text(source).unwrap_or("").to_string());
            }
            "aliased_import" => {
                if let Some(original_name) = child.child_by_field_name("name") {
                    symbols.push(original_name.utf8_text(source).unwrap_or("").to_string());
                }
            }
            _ => {}
        }
    }

    let relation = ImportRelation::new(importer, &from_module);
    let relation = if is_wildcard {
        relation.wildcard()
    } else if symbols.is_empty() {
        relation
    } else {
        relation.with_symbols(symbols)
    };
    vec![relation]
}
/// Computes [`ComplexityMetrics`] for the subtree rooted at `node` (a function body
/// when called from `extract_function`).
fn calculate_complexity_from_node(source: &[u8], node: Node) -> ComplexityMetrics {
    let mut metrics = ComplexityBuilder::new();
    calculate_complexity_recursive(source, node, &mut metrics);
    metrics.build()
}
/// Accumulates the complexity contributions of `node`'s subtree into `builder`.
///
/// Each construct is recorded as follows:
/// - Branches: `if`/`elif`/`else` arms, `match` cases, conditional expressions, and
///   comprehension `if` filters.
/// - Loops: `for`/`while` statements and comprehension `for` clauses.
/// - Exception handlers: each `except`/`except*` clause.
/// - Logical operators: every `boolean_operator`, including those in `if`, `elif`
///   and `while` conditions and in `case` guards.
///
/// Compound statements visit their own children, wrapping each body in
/// `enter_scope`/`exit_scope`. They are excluded from the generic descent at the
/// end, so every node is counted exactly once.
fn calculate_complexity_recursive(source: &[u8], node: Node, builder: &mut ComplexityBuilder) {
    match node.kind() {
        "if_statement" => {
            builder.add_branch();
            complexity_in_scope(source, node.child_by_field_name("consequence"), builder);
            let mut cursor = node.walk();
            for clause in node.children(&mut cursor) {
                match clause.kind() {
                    "elif_clause" => {
                        builder.add_branch();
                        complexity_in_scope(
                            source,
                            clause.child_by_field_name("consequence"),
                            builder,
                        );
                        // The generic descent skips `if_statement`, so elif conditions
                        // must be inspected here, just like the `if` condition below.
                        if let Some(condition) = clause.child_by_field_name("condition") {
                            count_logical_operators(source, condition, builder);
                        }
                    }
                    "else_clause" => {
                        builder.add_branch();
                        complexity_in_scope(source, clause.child_by_field_name("body"), builder);
                    }
                    _ => {}
                }
            }
            if let Some(condition) = node.child_by_field_name("condition") {
                count_logical_operators(source, condition, builder);
            }
        }
        "while_statement" | "for_statement" => {
            builder.add_loop();
            complexity_in_scope(source, node.child_by_field_name("body"), builder);
            // Only `while` has a `condition` field.
            if let Some(condition) = node.child_by_field_name("condition") {
                count_logical_operators(source, condition, builder);
            }
            // `for ... else` / `while ... else`: the else body would otherwise never be visited.
            if let Some(else_clause) = node.child_by_field_name("alternative") {
                complexity_in_scope(source, else_clause.child_by_field_name("body"), builder);
            }
        }
        "with_statement" => {
            let body = node.child_by_field_name("body");
            complexity_in_scope(source, body, builder);
            // Visit the context-manager part (everything but the body) exactly once.
            let body_id = body.map(|b| b.id());
            let mut cursor = node.walk();
            for child in node.children(&mut cursor) {
                if Some(child.id()) != body_id {
                    calculate_complexity_recursive(source, child, builder);
                }
            }
        }
        "try_statement" => {
            complexity_in_scope(source, node.child_by_field_name("body"), builder);
            let mut cursor = node.walk();
            for clause in node.children(&mut cursor) {
                match clause.kind() {
                    "except_clause" | "except_group_clause" => {
                        builder.add_exception_handler();
                        complexity_of_clause_blocks(source, clause, builder);
                    }
                    "else_clause" => {
                        complexity_in_scope(source, clause.child_by_field_name("body"), builder);
                    }
                    "finally_clause" => {
                        complexity_of_clause_blocks(source, clause, builder);
                    }
                    _ => {}
                }
            }
        }
        "match_statement" => {
            // Depending on the grammar version, `case` clauses are direct children of
            // the statement or children of its `body` block; accept both layouts.
            let mut cursor = node.walk();
            for child in node.children(&mut cursor) {
                match child.kind() {
                    "case_clause" => complexity_of_case_clause(source, child, builder),
                    "block" => {
                        let mut block_cursor = child.walk();
                        for clause in child.children(&mut block_cursor) {
                            if clause.kind() == "case_clause" {
                                complexity_of_case_clause(source, clause, builder);
                            }
                        }
                    }
                    _ => {}
                }
            }
        }
        "boolean_operator" => {
            builder.add_logical_operator();
        }
        "conditional_expression" => {
            builder.add_branch();
        }
        "list_comprehension"
        | "set_comprehension"
        | "dictionary_comprehension"
        | "generator_expression" => {
            let mut cursor = node.walk();
            for child in node.children(&mut cursor) {
                if child.kind() == "for_in_clause" {
                    builder.add_loop();
                }
                if child.kind() == "if_clause" {
                    builder.add_branch();
                }
            }
        }
        _ => {}
    }

    let handles_own_children = matches!(
        node.kind(),
        "if_statement"
            | "while_statement"
            | "for_statement"
            | "with_statement"
            | "try_statement"
            | "match_statement"
    );
    if !handles_own_children {
        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            calculate_complexity_recursive(source, child, builder);
        }
    }
}

/// Visits `body` (when present) inside its own nesting scope.
fn complexity_in_scope(source: &[u8], body: Option<Node>, builder: &mut ComplexityBuilder) {
    builder.enter_scope();
    if let Some(body) = body {
        calculate_complexity_recursive(source, body, builder);
    }
    builder.exit_scope();
}

/// Visits every `block` child of an `except`/`finally` clause inside one scope.
fn complexity_of_clause_blocks(source: &[u8], clause: Node, builder: &mut ComplexityBuilder) {
    builder.enter_scope();
    let mut cursor = clause.walk();
    for child in clause.children(&mut cursor) {
        if child.kind() == "block" {
            calculate_complexity_recursive(source, child, builder);
        }
    }
    builder.exit_scope();
}

/// Records one `case` arm: a branch, its body in a new scope, and the logical
/// operators in its optional guard.
fn complexity_of_case_clause(source: &[u8], clause: Node, builder: &mut ComplexityBuilder) {
    builder.add_branch();
    complexity_in_scope(source, clause.child_by_field_name("consequence"), builder);
    if let Some(guard) = clause.child_by_field_name("guard") {
        count_logical_operators(source, guard, builder);
    }
}
/// Calls `builder.add_logical_operator()` once for every `boolean_operator` node in
/// `node`'s subtree, `node` itself included.
///
/// Uses an explicit stack; only the total number of calls matters, not their order.
fn count_logical_operators(_source: &[u8], node: Node, builder: &mut ComplexityBuilder) {
    let mut stack = vec![node];
    while let Some(current) = stack.pop() {
        if current.kind() == "boolean_operator" {
            builder.add_logical_operator();
        }
        let mut cursor = current.walk();
        stack.extend(current.children(&mut cursor));
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs [`extract`] on `source` as `test.py` with the default configuration,
    /// panicking if the fixture does not parse.
    fn extract_default(source: &str) -> CodeIR {
        extract(source, Path::new("test.py"), &ParserConfig::default())
            .expect("fixture should be valid Python")
    }

    #[test]
    fn test_code_ir_new() {
        let ir = CodeIR::new(Path::new("test.py").to_path_buf());
        assert_eq!(ir.entity_count(), 0);
        assert_eq!(ir.relationship_count(), 0);
    }

    #[test]
    fn test_extract_simple_function() {
        let source = r#"
def greet(name):
    print(f"Hello, {name}")
    return name.upper()
"#;
        let ir = extract_default(source);
        assert_eq!(ir.functions.len(), 1);
        assert_eq!(ir.functions[0].name, "greet");
        assert_eq!(ir.functions[0].line_start, 2);
    }

    #[test]
    fn test_extract_class_with_methods() {
        let source = r#"
class Calculator:
    def add(self, a, b):
        return a + b
    def multiply(self, a, b):
        return a * b
"#;
        let ir = extract_default(source);
        assert_eq!(ir.classes.len(), 1);
        assert_eq!(ir.classes[0].name, "Calculator");
        assert_eq!(ir.classes[0].line_start, 2);
    }

    #[test]
    fn test_extract_calls() {
        let source = r#"
def main():
    greet("World")
    result = greet("Alice")
def greet(name):
    print(f"Hello, {name}")
"#;
        let ir = extract_default(source);
        assert_eq!(ir.functions.len(), 2);
        assert!(ir.calls.len() >= 2, "Should find at least 2 calls");
    }

    #[test]
    fn test_extract_imports() {
        let source = r#"
import os
import sys
from pathlib import Path
from typing import List, Dict
from collections import *
def main():
    pass
"#;
        let ir = extract_default(source);
        assert!(
            ir.imports.len() >= 4,
            "Should find at least 4 import statements"
        );
    }

    #[test]
    fn test_extract_inheritance() {
        let source = r#"
class Animal:
    def move(self):
        pass
class Dog(Animal):
    def bark(self):
        pass
"#;
        let ir = extract_default(source);
        assert_eq!(ir.classes.len(), 2);
        assert_eq!(ir.inheritance.len(), 1);
        assert_eq!(ir.inheritance[0].child, "Dog");
        assert_eq!(ir.inheritance[0].parent, "Animal");
    }

    #[test]
    fn test_complexity_simple_function() {
        let source = r#"
def simple():
    return 1
"#;
        let ir = extract_default(source);
        assert_eq!(ir.functions.len(), 1);
        let func = &ir.functions[0];
        assert!(func.complexity.is_some());
        let complexity = func.complexity.as_ref().unwrap();
        assert_eq!(complexity.cyclomatic_complexity, 1);
    }

    #[test]
    fn test_complexity_with_branches() {
        let source = r#"
def branching(x):
    if x > 0:
        return 1
    elif x < 0:
        return -1
    else:
        return 0
"#;
        let ir = extract_default(source);
        assert_eq!(ir.functions.len(), 1);
        let complexity = ir.functions[0].complexity.as_ref().unwrap();
        assert!(complexity.branches >= 3);
    }

    #[test]
    fn test_complexity_with_loops() {
        let source = r#"
def loopy(items):
    total = 0
    for item in items:
        while item > 0:
            total += 1
            item -= 1
    return total
"#;
        let ir = extract_default(source);
        assert_eq!(ir.functions.len(), 1);
        let complexity = ir.functions[0].complexity.as_ref().unwrap();
        assert_eq!(complexity.loops, 2);
    }

    #[test]
    fn test_complexity_with_logical_operators() {
        let source = r#"
def complex_condition(a, b, c):
    if a > 0 and b > 0 or c > 0:
        return True
    return False
"#;
        let ir = extract_default(source);
        assert_eq!(ir.functions.len(), 1);
        let complexity = ir.functions[0].complexity.as_ref().unwrap();
        assert!(complexity.logical_operators >= 2);
    }

    #[test]
    fn test_complexity_with_try_except() {
        let source = r#"
def risky():
    try:
        result = dangerous_operation()
    except ValueError:
        result = 0
    except TypeError:
        result = -1
    return result
"#;
        let ir = extract_default(source);
        assert_eq!(ir.functions.len(), 1);
        let complexity = ir.functions[0].complexity.as_ref().unwrap();
        assert_eq!(complexity.exception_handlers, 2);
    }

    #[test]
    fn test_accurate_line_numbers() {
        let source = "def first():\n pass\n\ndef second():\n pass";
        let ir = extract_default(source);
        assert_eq!(ir.functions.len(), 2);
        assert_eq!(ir.functions[0].name, "first");
        assert_eq!(ir.functions[0].line_start, 1);
        assert_eq!(ir.functions[1].name, "second");
        assert_eq!(ir.functions[1].line_start, 4);
    }

    #[test]
    fn test_async_function() {
        let source = r#"
async def fetch_data():
    return "data"
"#;
        let ir = extract_default(source);
        assert_eq!(ir.functions.len(), 1);
        assert_eq!(ir.functions[0].name, "fetch_data");
        assert!(ir.functions[0].is_async);
    }
}