use tree_sitter::Node;
use crate::model::{Statement, StatementKind};
pub fn cyclomatic_complexity(node: Node) -> u32 {
let mut count = 1u32;
count_branches(node, &mut count);
count
}
pub fn extract_statements<'a>(
node: Node<'a>,
source: &'a [u8],
parent_symbol: &str,
file: &str,
) -> Vec<Statement> {
let mut stmts = Vec::new();
let mut counter = 0u32;
collect_statements(
node,
source,
parent_symbol,
file,
0,
&mut stmts,
&mut counter,
);
stmts
}
fn node_condition_text<'a>(node: Node<'a>, source: &'a [u8], kind: &str) -> String {
let cond = node
.child_by_field_name("condition")
.or_else(|| node.child_by_field_name("value"))
.or_else(|| {
if kind.starts_with("for") {
node.child_by_field_name("left")
.or_else(|| node.child_by_field_name("pattern"))
} else {
None
}
});
match cond {
Some(c) => {
let text = c.utf8_text(source).unwrap_or("");
truncate_condition(text)
}
None => String::new(),
}
}
fn truncate_condition(text: &str) -> String {
let cleaned: String = text.chars().filter(|c| !c.is_control()).collect();
if cleaned.len() > 120 {
format!("{}...", &cleaned[..117])
} else {
cleaned
}
}
#[allow(clippy::only_used_in_recursion)]
fn collect_statements<'a>(
node: Node<'a>,
source: &'a [u8],
parent_symbol: &str,
file: &str,
depth: u32,
stmts: &mut Vec<Statement>,
counter: &mut u32,
) {
let kind = node.kind();
let stmt_kind = match kind {
"if_expression" | "if_statement" => Some(StatementKind::If),
"elif_clause" | "else_if_clause" => Some(StatementKind::ElseIf),
"else_clause" => Some(StatementKind::Else),
"for_statement" | "for_expression" | "for_in_statement" => Some(StatementKind::For),
"while_statement" | "while_expression" => Some(StatementKind::While),
"do_statement" => Some(StatementKind::DoWhile),
"loop_expression" => Some(StatementKind::Loop),
"match_expression" | "switch_expression" | "switch_statement" | "when_expression" => {
Some(StatementKind::Match)
}
"match_arm" | "case_clause" | "switch_case" | "arm" | "when_clause" => {
Some(StatementKind::Case)
}
"try_statement" | "try_expression" => Some(StatementKind::Try),
"catch_clause" | "except_clause" | "rescue_clause" => Some(StatementKind::Catch),
"ternary_expression" | "conditional_expression" => Some(StatementKind::Ternary),
"return_statement" => {
if depth == 1
&& node.start_position().row
< node
.parent()
.map(|p| p.end_position().row.saturating_sub(2))
.unwrap_or(0)
{
Some(StatementKind::Guard)
} else {
None
}
}
_ => None,
};
let is_branch = stmt_kind.is_some();
if let Some(sk) = stmt_kind {
let condition = node_condition_text(node, source, kind);
let id = format!("{}::stmt_{}", parent_symbol, counter);
*counter += 1;
stmts.push(Statement {
id,
kind: sk,
condition,
start_line: node.start_position().row as u32 + 1,
end_line: node.end_position().row as u32 + 1,
depth,
parent_symbol: parent_symbol.to_string(),
});
}
let next_depth = if is_branch { depth + 1 } else { depth };
for i in 0..node.child_count() {
if let Some(child) = node.child(i as u32) {
let ck = child.kind();
if ck == "function_definition"
|| ck == "function_item"
|| ck == "function_declaration"
|| ck == "method_declaration"
|| ck == "method_definition"
|| ck == "closure_expression"
|| ck == "lambda_expression"
|| ck == "arrow_function"
|| ck == "anonymous_function_creation_expression"
{
continue;
}
collect_statements(
child,
source,
parent_symbol,
file,
next_depth,
stmts,
counter,
);
}
}
}
fn count_branches(node: Node, count: &mut u32) {
let kind = node.kind();
match kind {
"if_expression" | "if_statement" | "elif_clause" | "else_if_clause" |
"else_clause" | "when_clause" |
"for_statement" | "for_expression" | "for_in_statement" |
"while_statement" | "while_expression" |
"do_statement" | "loop_expression" |
"match_arm" | "case_clause" | "switch_case" | "arm" |
"catch_clause" | "except_clause" | "rescue_clause" |
"ternary_expression" | "conditional_expression" |
"binary_expression" => {
if kind == "binary_expression" {
let op = node.child_by_field_name("operator")
.map(|n| n.kind())
.unwrap_or("");
if op == "&&" || op == "||" || op == "and" || op == "or" || op == "??" {
*count += 1;
}
} else {
*count += 1;
}
}
"try_expression" | "propagation_expression" => *count += 1,
_ => {}
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i as u32) {
count_branches(child, count);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn parse_python(source: &str) -> Vec<Statement> {
let mut parser = tree_sitter::Parser::new();
parser
.set_language(&tree_sitter_python::LANGUAGE.into())
.unwrap();
let tree = parser.parse(source.as_bytes(), None).unwrap();
let root = tree.root_node();
let fn_node = find_first_function(root).unwrap_or(root);
extract_statements(fn_node, source.as_bytes(), "test::func", "test.py")
}
fn find_first_function(node: Node) -> Option<Node> {
if node.kind() == "function_definition" {
return Some(node);
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i as u32) {
if let Some(found) = find_first_function(child) {
return Some(found);
}
}
}
None
}
#[test]
fn test_python_if_else() {
let source =
"def check(x):\n if x > 0:\n return True\n else:\n return False\n";
let stmts = parse_python(source);
let kinds: Vec<&str> = stmts.iter().map(|s| s.kind.as_str()).collect();
assert!(kinds.contains(&"If"), "expected If, got {:?}", kinds);
assert!(kinds.contains(&"Else"), "expected Else, got {:?}", kinds);
}
#[test]
fn test_python_elif() {
let source = "def classify(x):\n if x > 100:\n return 'high'\n elif x > 50:\n return 'medium'\n else:\n return 'low'\n";
let stmts = parse_python(source);
let kinds: Vec<&str> = stmts.iter().map(|s| s.kind.as_str()).collect();
assert!(kinds.contains(&"If"), "expected If, got {:?}", kinds);
assert!(
kinds.contains(&"ElseIf"),
"expected ElseIf, got {:?}",
kinds
);
assert!(kinds.contains(&"Else"), "expected Else, got {:?}", kinds);
}
#[test]
fn test_python_for_while() {
let source = "def process(items):\n for item in items:\n if item > 0:\n print(item)\n i = 0\n while i < 10:\n i += 1\n";
let stmts = parse_python(source);
let kinds: Vec<&str> = stmts.iter().map(|s| s.kind.as_str()).collect();
assert!(kinds.contains(&"For"), "expected For, got {:?}", kinds);
assert!(kinds.contains(&"While"), "expected While, got {:?}", kinds);
assert!(kinds.contains(&"If"), "expected If, got {:?}", kinds);
}
#[test]
fn test_python_try_except() {
let source = "def process():\n try:\n do_work()\n except Exception as e:\n handle_error(e)\n";
let stmts = parse_python(source);
let kinds: Vec<&str> = stmts.iter().map(|s| s.kind.as_str()).collect();
assert!(kinds.contains(&"Try"), "expected Try, got {:?}", kinds);
assert!(kinds.contains(&"Catch"), "expected Catch, got {:?}", kinds);
}
#[test]
fn test_python_ternary() {
let source = "def pick(x):\n return 'yes' if x > 0 else 'no'\n";
let stmts = parse_python(source);
let kinds: Vec<&str> = stmts.iter().map(|s| s.kind.as_str()).collect();
assert!(
kinds.contains(&"Ternary"),
"expected Ternary, got {:?}",
kinds
);
}
#[test]
fn test_depth_tracking() {
let source = "def nested(x):\n if x > 0:\n if x > 10:\n if x > 100:\n do_thing()\n";
let stmts = parse_python(source);
let ifs: Vec<&Statement> = stmts
.iter()
.filter(|s| s.kind == StatementKind::If)
.collect();
assert_eq!(ifs.len(), 3, "expected 3 nested ifs, got {}", ifs.len());
assert_eq!(ifs[0].depth, 0);
assert_eq!(ifs[1].depth, 1);
assert_eq!(ifs[2].depth, 2);
}
#[test]
fn test_condition_text_extracted() {
let source = "def check(x):\n if x > 42:\n print('big')\n";
let stmts = parse_python(source);
let if_stmt = stmts
.iter()
.find(|s| s.kind == StatementKind::If)
.expect("expected If");
assert!(
if_stmt.condition.contains("x > 42"),
"expected 'x > 42', got '{}'",
if_stmt.condition
);
}
#[test]
fn test_no_statements_simple_function() {
let source = "def add(a, b):\n return a + b\n";
let stmts = parse_python(source);
assert!(
stmts.is_empty(),
"expected no statements, got {}",
stmts.len()
);
}
#[test]
fn test_statement_ids_unique() {
let source = "def multi(x):\n if x > 0:\n pass\n if x > 1:\n pass\n for i in range(x):\n pass\n";
let stmts = parse_python(source);
let ids: Vec<&str> = stmts.iter().map(|s| s.id.as_str()).collect();
let unique: std::collections::HashSet<&str> = ids.iter().cloned().collect();
assert_eq!(ids.len(), unique.len(), "IDs not unique: {:?}", ids);
}
#[test]
fn test_python_guard_early_return() {
let source = "def process(x):\n if x is None:\n return\n do_work(x)\n do_more(x)\n return x\n";
let stmts = parse_python(source);
let kinds: Vec<&str> = stmts.iter().map(|s| s.kind.as_str()).collect();
assert!(kinds.contains(&"If"), "expected If, got {:?}", kinds);
}
}