use crate::types::complexity::{CodeSmell, ComplexityGrade, ComplexityMetrics};
use tree_sitter::{Node, Parser};
/// Line count above which a `CodeSmell::LongFunction` is reported. Measured on the
/// analysed function node when one is found, otherwise on the whole input.
const LONG_FUNCTION_THRESHOLD: usize = 50;
/// A maximum nesting depth strictly greater than this reports `CodeSmell::DeepNesting`.
const DEEP_NESTING_THRESHOLD: u8 = 4;
/// A parameter count strictly greater than this reports `CodeSmell::TooManyParams`.
const TOO_MANY_PARAMS_THRESHOLD: usize = 5;
pub fn compute_complexity_rust(content: &str) -> Option<ComplexityMetrics> {
let mut parser = Parser::new();
parser
.set_language(&tree_sitter_rust::LANGUAGE.into())
.ok()?;
let tree = parser.parse(content, None)?;
let root = tree.root_node();
let src = content.as_bytes();
let mut state = WalkState::default();
walk_rust(root, src, 0, &mut state);
let cyclomatic = state.cyclomatic.saturating_add(1);
let cognitive = state.cognitive;
let grade = ComplexityGrade::from_cyclomatic(cyclomatic);
let smells = detect_smells_rust(root, src, &state);
tracing::debug!(
cyclomatic,
cognitive,
?grade,
max_nesting = state.max_nesting,
"compute_complexity_rust"
);
Some(ComplexityMetrics {
cyclomatic,
cognitive,
grade,
smells,
})
}
pub fn compute_complexity_typescript(content: &str) -> Option<ComplexityMetrics> {
let mut parser = Parser::new();
parser
.set_language(&tree_sitter_typescript::LANGUAGE_TSX.into())
.ok()?;
let tree = parser.parse(content, None)?;
let root = tree.root_node();
let src = content.as_bytes();
let mut state = WalkState::default();
walk_ts(root, src, 0, &mut state);
let cyclomatic = state.cyclomatic.saturating_add(1);
let cognitive = state.cognitive;
let grade = ComplexityGrade::from_cyclomatic(cyclomatic);
let smells = detect_smells_ts(root, src, &state);
tracing::debug!(
cyclomatic,
cognitive,
?grade,
max_nesting = state.max_nesting,
"compute_complexity_typescript"
);
Some(ComplexityMetrics {
cyclomatic,
cognitive,
grade,
smells,
})
}
/// Accumulator threaded through the syntax-tree walkers.
#[derive(Default)]
struct WalkState {
    /// Number of decision points seen; callers add 1 for the entry path.
    cyclomatic: u32,
    /// Sum of nesting-weighted increments (`depth + 1` per decision point).
    cognitive: u32,
    /// Deepest nesting level observed at any visited node.
    max_nesting: u8,
}

impl WalkState {
    /// Records a single decision point found at nesting level `depth`.
    ///
    /// Cyclomatic grows by one; cognitive grows by `depth + 1`. Both saturate at `u32::MAX`.
    fn note_branch(&mut self, depth: u8) {
        let nesting_weight = u32::from(depth) + 1;
        self.cognitive = self.cognitive.saturating_add(nesting_weight);
        self.cyclomatic = self.cyclomatic.saturating_add(1);
    }
}
/// Recursively walks a Rust syntax tree, recording decision points in `state`.
///
/// * `depth` is the number of enclosing nesting constructs (`if`, `match`, loops,
///   closures). It weights each cognitive increment and feeds `state.max_nesting`.
/// * `src` holds the source bytes, needed to read operator tokens.
///
/// Counted as decision points: `if`, `else if`, every `match` arm after the first,
/// `while`/`loop`/`for`, closures, `&&`/`||`, and `?`.
///
/// An `else if` is counted exactly once, by its `else_clause`. The `if_expression`
/// nested directly in that clause is a continuation of the chain, so it neither counts
/// again nor opens a deeper nesting level. A long `else if` chain therefore stays flat.
fn walk_rust(node: Node, src: &[u8], depth: u8, state: &mut WalkState) {
    state.max_nesting = state.max_nesting.max(depth);
    let mut nest_inc: u8 = 0;
    match node.kind() {
        "if_expression" => {
            let is_else_if = matches!(node.parent().map(|p| p.kind()), Some("else_clause"));
            if !is_else_if {
                state.note_branch(depth);
                nest_inc = 1;
            }
        }
        "else_clause" if has_child_kind(node, "if_expression") => state.note_branch(depth),
        // A match with n arms is n-1 decisions: the first arm is the base path.
        "match_arm" if !is_first_match_arm(node) => state.note_branch(depth),
        "match_expression" => nest_inc = 1,
        "while_expression" | "loop_expression" | "for_expression" | "closure_expression" => {
            state.note_branch(depth);
            nest_inc = 1;
        }
        "binary_expression" if is_short_circuit_op(node, src) => state.note_branch(depth),
        "try_expression" => state.note_branch(depth),
        _ => {}
    }
    let child_depth = depth.saturating_add(nest_inc);
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        walk_rust(child, src, child_depth, state);
    }
}
/// Recursively walks a TypeScript/TSX syntax tree, recording decision points in `state`.
///
/// * `depth` is the number of enclosing nesting constructs. It weights each cognitive
///   increment and feeds `state.max_nesting`.
/// * `src` holds the source bytes, needed to read operator tokens.
///
/// Counted as decision points: `if`, `else if`, every `case` label, loops, `&&` / `||` /
/// `??`, ternaries, nested function expressions / arrow functions, and `catch`.
///
/// An `else if` is counted exactly once, by its `else_clause`. The `if_statement`
/// nested directly in that clause continues the chain rather than nesting deeper.
/// Each `switch_case` is one decision. `switch_default` is the fall-through path and
/// is not counted.
fn walk_ts(node: Node, src: &[u8], depth: u8, state: &mut WalkState) {
    state.max_nesting = state.max_nesting.max(depth);
    let mut nest_inc: u8 = 0;
    match node.kind() {
        "if_statement" => {
            let is_else_if = matches!(node.parent().map(|p| p.kind()), Some("else_clause"));
            if !is_else_if {
                state.note_branch(depth);
                nest_inc = 1;
            }
        }
        "else_clause" if has_child_kind(node, "if_statement") => state.note_branch(depth),
        "switch_case" => state.note_branch(depth),
        "switch_statement" => nest_inc = 1,
        "while_statement" | "do_statement" | "for_statement" | "for_in_statement"
        | "for_of_statement" | "arrow_function" | "function_expression" | "catch_clause" => {
            state.note_branch(depth);
            nest_inc = 1;
        }
        "binary_expression" if is_short_circuit_op(node, src) => state.note_branch(depth),
        "ternary_expression" => state.note_branch(depth),
        _ => {}
    }
    let child_depth = depth.saturating_add(nest_inc);
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        walk_ts(child, src, child_depth, state);
    }
}

/// Returns `true` when a `binary_expression` uses a short-circuiting operator.
///
/// The operators are `&&`, `||`, and the nullish-coalescing `??` (which only appears
/// in TypeScript/JavaScript). The operator is read from the node's `operator` field.
/// A missing field or non-UTF-8 text yields `false`.
fn is_short_circuit_op(node: Node, src: &[u8]) -> bool {
    let operator = node
        .child_by_field_name("operator")
        .and_then(|op| op.utf8_text(src).ok());
    matches!(operator, Some("&&" | "||" | "??"))
}

/// Returns `true` if `node` is the first `match_arm` among its parent's children.
///
/// A node without a parent, or a parent with no `match_arm` child, also yields `true`.
fn is_first_match_arm(node: Node) -> bool {
    let Some(parent) = node.parent() else {
        return true;
    };
    let mut cursor = parent.walk();
    let first_arm = parent
        .children(&mut cursor)
        .find(|child| child.kind() == "match_arm");
    match first_arm {
        Some(arm) => arm.id() == node.id(),
        None => true,
    }
}
/// Reports whether `node` has at least one direct child whose grammar kind is `kind`.
fn has_child_kind(node: Node, kind: &str) -> bool {
    let mut cursor = node.walk();
    // Bound to a local so the iterator's borrow of `cursor` ends before `cursor` drops.
    let matched = node.children(&mut cursor).any(|child| child.kind() == kind);
    matched
}
/// Derives code smells for Rust input.
///
/// The first `function_item` found by `find_first_kind` is the analysed function. Its
/// span is checked for length, its `parameter` children are counted, and it is checked
/// for a preceding doc comment. Without such a function, the whole input's line count
/// and a file-wide doc-marker scan are used instead. Deep nesting always comes from the
/// walk's `max_nesting`.
///
/// Smells are returned in this order: LongFunction, DeepNesting, TooManyParams,
/// MissingDocstring.
fn detect_smells_rust(root: Node, src: &[u8], state: &WalkState) -> Vec<CodeSmell> {
    let target = find_first_kind(root, "function_item");
    let lines = match target {
        Some(func) => func.end_position().row.saturating_sub(func.start_position().row) + 1,
        None => line_count(src),
    };

    let mut found = Vec::new();
    if lines > LONG_FUNCTION_THRESHOLD {
        found.push(CodeSmell::LongFunction { lines });
    }
    if state.max_nesting > DEEP_NESTING_THRESHOLD {
        found.push(CodeSmell::DeepNesting {
            max_depth: state.max_nesting,
        });
    }
    match target {
        Some(func) => {
            let param_count = func
                .child_by_field_name("parameters")
                .map_or(0, |params| count_named_children_kind(params, "parameter"));
            if param_count > TOO_MANY_PARAMS_THRESHOLD {
                found.push(CodeSmell::TooManyParams { count: param_count });
            }
            if !has_rust_doc(func, src) {
                found.push(CodeSmell::MissingDocstring);
            }
        }
        None => {
            if !contains_doc_marker(src) {
                found.push(CodeSmell::MissingDocstring);
            }
        }
    }
    found
}
/// Derives code smells for TypeScript input.
///
/// The analysed function is the first `function_declaration`. Failing that, it is the
/// first `method_definition`, then the first `arrow_function`. Its span is checked for
/// length, its parameters are counted, and it is checked for a JSDoc comment. Without
/// any of these, the whole input's line count and a file-wide doc-marker scan are used
/// instead. Deep nesting always comes from the walk's `max_nesting`.
///
/// Smells are returned in this order: LongFunction, DeepNesting, TooManyParams,
/// MissingDocstring.
fn detect_smells_ts(root: Node, src: &[u8], state: &WalkState) -> Vec<CodeSmell> {
    let target = ["function_declaration", "method_definition", "arrow_function"]
        .iter()
        .find_map(|kind| find_first_kind(root, kind));
    let lines = match target {
        Some(func) => func.end_position().row.saturating_sub(func.start_position().row) + 1,
        None => line_count(src),
    };

    let mut found = Vec::new();
    if lines > LONG_FUNCTION_THRESHOLD {
        found.push(CodeSmell::LongFunction { lines });
    }
    if state.max_nesting > DEEP_NESTING_THRESHOLD {
        found.push(CodeSmell::DeepNesting {
            max_depth: state.max_nesting,
        });
    }
    match target {
        Some(func) => {
            let param_count = func
                .child_by_field_name("parameters")
                .map_or(0, count_param_children);
            if param_count > TOO_MANY_PARAMS_THRESHOLD {
                found.push(CodeSmell::TooManyParams { count: param_count });
            }
            if !has_jsdoc(func, src) {
                found.push(CodeSmell::MissingDocstring);
            }
        }
        None => {
            if !contains_doc_marker(src) {
                found.push(CodeSmell::MissingDocstring);
            }
        }
    }
    found
}
/// Counts parameter-like direct children of a TypeScript function's `parameters` node.
///
/// Required/optional parameters, rest patterns, bare identifiers, defaults and
/// destructuring patterns each count as one. Punctuation and other kinds are ignored.
fn count_param_children(params: Node) -> usize {
    let mut cursor = params.walk();
    let total = params
        .children(&mut cursor)
        .filter(|child| {
            matches!(
                child.kind(),
                "required_parameter"
                    | "optional_parameter"
                    | "rest_pattern"
                    | "identifier"
                    | "assignment_pattern"
                    | "object_pattern"
                    | "array_pattern"
            )
        })
        .count();
    total
}
/// Counts the direct children of `node` whose grammar kind equals `kind`.
///
/// All children are inspected, named or anonymous; the kind comparison alone filters them.
fn count_named_children_kind(node: Node, kind: &str) -> usize {
    let mut cursor = node.walk();
    let hits = node
        .children(&mut cursor)
        .filter(|child| child.kind() == kind)
        .count();
    hits
}
/// Returns the first node of grammar type `kind` in document order, including `root` itself.
///
/// The traversal is an iterative pre-order, depth-first walk. Children are pushed onto
/// the stack in reverse so the leftmost child is popped next. Pushing them in source
/// order would visit the right-most subtree first and return the *last* match instead.
fn find_first_kind<'a>(root: Node<'a>, kind: &str) -> Option<Node<'a>> {
    let mut stack = vec![root];
    while let Some(node) = stack.pop() {
        if node.kind() == kind {
            return Some(node);
        }
        let mut cursor = node.walk();
        let first_child_slot = stack.len();
        stack.extend(node.children(&mut cursor));
        stack[first_child_slot..].reverse();
    }
    None
}
/// Returns `true` when `fn_node` is preceded by an outer doc comment.
///
/// The search walks backwards over previous siblings. It skips attributes
/// (`#[...]`, `#![...]`) and ordinary comments, and stops at the first other node.
fn has_rust_doc(fn_node: Node, src: &[u8]) -> bool {
    let mut sib = fn_node.prev_sibling();
    while let Some(s) = sib {
        match s.kind() {
            "line_comment" | "block_comment" => {
                if is_outer_doc_comment(s.utf8_text(src).unwrap_or("")) {
                    return true;
                }
            }
            "attribute_item" | "inner_attribute_item" => {}
            _ => break,
        }
        sib = s.prev_sibling();
    }
    false
}

/// Returns `true` if `text` starts an outer doc comment: `/// …` or `/** … */`.
///
/// Per the Rust Reference, `////…`, `/***…` and the empty `/**/` are ordinary comments.
/// Inner doc comments (`//!`, `/*!`) document the enclosing module, not the next item.
fn is_outer_doc_comment(text: &str) -> bool {
    if let Some(rest) = text.strip_prefix("///") {
        !rest.starts_with('/')
    } else if let Some(rest) = text.strip_prefix("/**") {
        !rest.starts_with('*') && !rest.starts_with('/')
    } else {
        false
    }
}
/// Returns `true` when the declaration containing `fn_node` is preceded by a JSDoc
/// block (a `comment` starting with `/**`).
///
/// A JSDoc is attached to the outermost statement wrapping the function, so the search
/// starts from there. This covers `export function f() {}` (wrapped in
/// `export_statement`) and `const f = () => {}` (wrapped in `variable_declarator`, then
/// `lexical_declaration` / `variable_declaration`, possibly inside `export_statement`).
/// Walking backwards, ordinary comments and decorators are skipped and any other node
/// ends the search.
fn has_jsdoc(fn_node: Node, src: &[u8]) -> bool {
    let mut anchor = fn_node;
    while let Some(parent) = anchor.parent() {
        if matches!(
            parent.kind(),
            "export_statement" | "variable_declarator" | "lexical_declaration" | "variable_declaration"
        ) {
            anchor = parent;
        } else {
            break;
        }
    }

    let mut sib = anchor.prev_sibling();
    while let Some(s) = sib {
        match s.kind() {
            "comment" => {
                if s.utf8_text(src).unwrap_or("").starts_with("/**") {
                    return true;
                }
            }
            "decorator" => {}
            _ => break,
        }
        sib = s.prev_sibling();
    }
    false
}
/// Heuristic fallback: does the source contain any common doc-comment or docstring opener?
///
/// Recognises `///`, `/**`, `"""` and `'''` anywhere in the text. Input that is not
/// valid UTF-8 yields `false`.
fn contains_doc_marker(src: &[u8]) -> bool {
    const MARKERS: [&str; 4] = ["///", "/**", "\"\"\"", "'''"];
    match std::str::from_utf8(src) {
        Ok(text) => MARKERS.iter().any(|&marker| text.contains(marker)),
        Err(_) => false,
    }
}
/// Number of lines in `src`, never less than 1.
///
/// Lines are split via `str::lines`, so a trailing newline does not add a line. Input
/// that is not valid UTF-8 is treated as empty and therefore counts as 1.
fn line_count(src: &[u8]) -> usize {
    let text = std::str::from_utf8(src).unwrap_or_default();
    std::cmp::max(text.lines().count(), 1)
}
/// Maps a language identifier to its tree-sitter grammar and the node kinds that
/// `walk_generic` treats as decision points.
///
/// Operator kinds in these lists (`binary_expression`, `binary`, `infix_expression`,
/// `boolean_operator`) are further filtered by `walk_generic` to logical operators only.
/// Returns `None` for identifiers without a mapping.
fn generic_language(lang: &str) -> Option<(tree_sitter::Language, &'static [&'static str])> {
    const PYTHON: &[&str] = &[
        "if_statement",
        "elif_clause",
        "for_statement",
        "while_statement",
        "except_clause",
        "with_statement",
        "boolean_operator",
        "conditional_expression",
    ];
    const JAVA: &[&str] = &[
        "if_statement",
        "for_statement",
        "enhanced_for_statement",
        "while_statement",
        "do_statement",
        "catch_clause",
        "switch_label",
        "binary_expression",
        "ternary_expression",
    ];
    const KOTLIN: &[&str] = &[
        "if_expression",
        "for_statement",
        "while_statement",
        "do_while_statement",
        "catch_block",
        "when_entry",
        "conjunction_expression",
        "disjunction_expression",
    ];
    const GO: &[&str] = &[
        "if_statement",
        "for_statement",
        "type_switch_statement",
        "expression_switch_statement",
        "select_statement",
        "expression_case",
        "type_case",
        "communication_case",
        "binary_expression",
    ];
    const C: &[&str] = &[
        "if_statement",
        "for_statement",
        "while_statement",
        "do_statement",
        "case_statement",
        "binary_expression",
        "conditional_expression",
    ];
    const CPP: &[&str] = &[
        "if_statement",
        "for_statement",
        "for_range_loop",
        "while_statement",
        "do_statement",
        "case_statement",
        "catch_clause",
        "binary_expression",
        "conditional_expression",
    ];
    const RUBY: &[&str] = &[
        "if",
        "unless",
        "while",
        "until",
        "for",
        "rescue",
        "when",
        "elsif",
        "if_modifier",
        "unless_modifier",
        "while_modifier",
        "until_modifier",
        "binary",
    ];
    const PHP: &[&str] = &[
        "if_statement",
        "else_if_clause",
        "foreach_statement",
        "for_statement",
        "while_statement",
        "do_statement",
        "catch_clause",
        "match_expression",
        "case_statement",
        "binary_expression",
        "conditional_expression",
    ];
    const CSHARP: &[&str] = &[
        "if_statement",
        "for_statement",
        "for_each_statement",
        "while_statement",
        "do_statement",
        "catch_clause",
        "switch_section",
        "case_switch_label",
        "binary_expression",
        "conditional_expression",
    ];
    const SCALA: &[&str] = &[
        "if_expression",
        "for_expression",
        "while_expression",
        "do_while_expression",
        "catch_clause",
        "case_clause",
        "infix_expression",
    ];
    const SWIFT: &[&str] = &[
        "if_statement",
        "for_statement",
        "while_statement",
        "repeat_while_statement",
        "guard_statement",
        "catch_block",
        "switch_entry",
        "ternary_expression",
    ];

    let (grammar, kinds): (tree_sitter::Language, &'static [&'static str]) = match lang {
        "python" => (tree_sitter_python::LANGUAGE.into(), PYTHON),
        "java" => (tree_sitter_java::LANGUAGE.into(), JAVA),
        "kotlin" => (tree_sitter_kotlin_ng::LANGUAGE.into(), KOTLIN),
        "go" => (tree_sitter_go::LANGUAGE.into(), GO),
        "c" => (tree_sitter_c::LANGUAGE.into(), C),
        "cpp" => (tree_sitter_cpp::LANGUAGE.into(), CPP),
        "ruby" => (tree_sitter_ruby::LANGUAGE.into(), RUBY),
        "php" => (tree_sitter_php::LANGUAGE_PHP.into(), PHP),
        "csharp" => (tree_sitter_c_sharp::LANGUAGE.into(), CSHARP),
        "scala" => (tree_sitter_scala::LANGUAGE.into(), SCALA),
        "swift" => (tree_sitter_swift::LANGUAGE.into(), SWIFT),
        _ => return None,
    };
    Some((grammar, kinds))
}
/// Computes complexity metrics for a language supported by `generic_language`.
///
/// `lang` selects the grammar and the decision-point node kinds. Smells here are
/// file-level only: line count, nesting depth, and a doc-marker scan.
///
/// Returns `None` for an unmapped `lang`, if the grammar cannot be installed, or if
/// parsing yields no tree.
pub fn compute_complexity_generic(content: &str, lang: &str) -> Option<ComplexityMetrics> {
    let (grammar, branch_kinds) = generic_language(lang)?;
    let mut parser = Parser::new();
    parser.set_language(&grammar).ok()?;
    let tree = parser.parse(content, None)?;
    let root = tree.root_node();
    let bytes = content.as_bytes();

    let mut walk = WalkState::default();
    walk_generic(root, bytes, 0, branch_kinds, &mut walk);

    let cyclomatic = walk.cyclomatic.saturating_add(1);
    let metrics = ComplexityMetrics {
        cyclomatic,
        cognitive: walk.cognitive,
        grade: ComplexityGrade::from_cyclomatic(cyclomatic),
        smells: detect_smells_generic(bytes, &walk),
    };
    tracing::debug!(
        lang,
        cyclomatic = metrics.cyclomatic,
        cognitive = metrics.cognitive,
        grade = ?metrics.grade,
        max_nesting = walk.max_nesting,
        "compute_complexity_generic"
    );
    Some(metrics)
}
/// Recursively walks a syntax tree for any grammar from `generic_language`.
///
/// Nodes whose kind is in `branch_kinds` are recorded in `state` as decision points:
/// * Operator kinds (`binary_expression`, `binary`, `infix_expression`,
///   `boolean_operator`) count only when `is_logical_op` says the operator is logical.
///   They never deepen nesting.
/// * `conjunction_expression` / `disjunction_expression` (Kotlin's `&&` / `||` nodes)
///   are logical by construction. They always count and never deepen nesting, so chained
///   conditions do not inflate `max_nesting`.
/// * Every other listed kind counts and raises the nesting depth of its subtree.
fn walk_generic(node: Node, src: &[u8], depth: u8, branch_kinds: &[&str], state: &mut WalkState) {
    state.max_nesting = state.max_nesting.max(depth);
    let kind = node.kind();
    let mut nest_inc: u8 = 0;
    if branch_kinds.contains(&kind) {
        match kind {
            "conjunction_expression" | "disjunction_expression" => state.note_branch(depth),
            "binary_expression" | "binary" | "infix_expression" | "boolean_operator" => {
                if is_logical_op(node, src) {
                    state.note_branch(depth);
                }
            }
            _ => {
                state.note_branch(depth);
                nest_inc = 1;
            }
        }
    }
    let child_depth = depth.saturating_add(nest_inc);
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        walk_generic(child, src, child_depth, branch_kinds, state);
    }
}
/// Returns `true` when `node` is a logical `&&` / `||` / `and` / `or` operation.
///
/// Uses the grammar's `operator` field when present. Otherwise it falls back to
/// scanning the direct children for a logical-operator token.
fn is_logical_op(node: Node, src: &[u8]) -> bool {
    fn is_logical_token(text: &str) -> bool {
        matches!(text, "&&" | "||" | "and" | "or")
    }

    match node.child_by_field_name("operator") {
        Some(op) => is_logical_token(op.utf8_text(src).unwrap_or("")),
        None => {
            let mut cursor = node.walk();
            let found = node
                .children(&mut cursor)
                .any(|child| is_logical_token(child.utf8_text(src).unwrap_or("")));
            found
        }
    }
}
/// Derives file-level code smells for generically analysed languages.
///
/// The whole input's line count stands in for function length, nesting comes from the
/// walk, and documentation is a file-wide doc-marker scan. Smells are returned in this
/// order: LongFunction, DeepNesting, MissingDocstring.
fn detect_smells_generic(src: &[u8], state: &WalkState) -> Vec<CodeSmell> {
    let lines = line_count(src);
    let long = (lines > LONG_FUNCTION_THRESHOLD).then_some(CodeSmell::LongFunction { lines });
    let deep = (state.max_nesting > DEEP_NESTING_THRESHOLD).then_some(CodeSmell::DeepNesting {
        max_depth: state.max_nesting,
    });
    let undocumented = (!contains_doc_marker(src)).then_some(CodeSmell::MissingDocstring);
    long.into_iter().chain(deep).chain(undocumented).collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src` as Rust, panicking if no metrics are produced.
    fn rust_metrics(src: &str) -> ComplexityMetrics {
        compute_complexity_rust(src).expect("parse should succeed")
    }

    /// Parses `src` as TypeScript, panicking if no metrics are produced.
    fn ts_metrics(src: &str) -> ComplexityMetrics {
        compute_complexity_typescript(src).expect("parse should succeed")
    }

    /// True when any reported smell satisfies `pred`.
    fn has_smell(metrics: &ComplexityMetrics, pred: impl Fn(&CodeSmell) -> bool) -> bool {
        metrics.smells.iter().any(pred)
    }

    #[test]
    fn compute_complexity_rust_single_branch() {
        let metrics = rust_metrics("fn foo(a: i32, b: i32) -> i32 { if a > b { a } else { b } }");
        assert!(
            metrics.cyclomatic >= 2,
            "expected cyclomatic >= 2, got {}",
            metrics.cyclomatic
        );
        assert!(
            matches!(metrics.grade, ComplexityGrade::A | ComplexityGrade::B),
            "expected grade A or B, got {:?}",
            metrics.grade
        );
    }

    #[test]
    fn compute_complexity_rust_no_branches() {
        let metrics = rust_metrics("fn foo() -> i32 { 42 }");
        assert_eq!(
            metrics.cyclomatic, 1,
            "expected cyclomatic == 1, got {}",
            metrics.cyclomatic
        );
        assert_eq!(metrics.grade, ComplexityGrade::A);
    }

    #[test]
    fn compute_complexity_rust_match_arms_count() {
        let src = r#"
fn classify(n: i32) -> &'static str {
    match n {
        0 => "zero",
        1 => "one",
        2 => "two",
        _ => "many",
    }
}
"#;
        let metrics = rust_metrics(src);
        assert!(
            metrics.cyclomatic >= 3,
            "expected cyclomatic >= 3, got {}",
            metrics.cyclomatic
        );
    }

    #[test]
    fn compute_complexity_rust_short_circuit_counts() {
        let metrics = rust_metrics(r#"fn f(a: bool, b: bool, c: bool) -> bool { a && b || c }"#);
        assert!(metrics.cyclomatic >= 3);
    }

    #[test]
    fn compute_complexity_typescript_single_branch() {
        let metrics =
            ts_metrics("function foo(a: number, b: number): number { return a > b ? a : b; }");
        assert!(metrics.cyclomatic >= 2);
    }

    #[test]
    fn compute_complexity_typescript_no_branches() {
        let metrics = ts_metrics("function foo(): number { return 42; }");
        assert_eq!(metrics.cyclomatic, 1);
        assert_eq!(metrics.grade, ComplexityGrade::A);
    }

    #[test]
    fn long_function_smell_fires_for_long_fn() {
        let body: String = std::iter::once("/// doc\nfn big(a: i32) -> i32 {\n")
            .chain(std::iter::repeat("    let _ = 1;\n").take(60))
            .chain(std::iter::once("    a\n}\n"))
            .collect();
        let metrics = rust_metrics(&body);
        assert!(
            has_smell(&metrics, |s| matches!(s, CodeSmell::LongFunction { .. })),
            "expected LongFunction smell, got {:?}",
            metrics.smells
        );
    }

    #[test]
    fn missing_docstring_smell_for_undocumented_rust_fn() {
        let metrics = rust_metrics("fn f() {}");
        assert!(has_smell(&metrics, |s| matches!(s, CodeSmell::MissingDocstring)));
    }

    #[test]
    fn doc_comment_suppresses_missing_docstring() {
        let metrics = rust_metrics("/// hi\nfn f() {}");
        assert!(!has_smell(&metrics, |s| matches!(s, CodeSmell::MissingDocstring)));
    }

    #[test]
    fn generic_complexity_counts_python_branches() {
        let src = "def f(a, b):\n    if a > b and b > 0:\n        return a\n    for x in range(b):\n        pass\n    return b\n";
        let metrics = compute_complexity_generic(src, "python").expect("python should parse");
        assert!(
            metrics.cyclomatic >= 4,
            "expected cyclomatic >= 4, got {}",
            metrics.cyclomatic
        );
    }

    #[test]
    fn generic_complexity_no_branches_is_one() {
        let metrics = compute_complexity_generic("def f():\n    return 42\n", "python")
            .expect("python should parse");
        assert_eq!(metrics.cyclomatic, 1);
        assert_eq!(metrics.grade, ComplexityGrade::A);
    }

    #[test]
    fn generic_complexity_handles_go_and_ruby() {
        let cases = [
            ("go", "func f(a int) int {\n\tif a > 0 {\n\t\treturn a\n\t}\n\treturn 0\n}\n"),
            ("ruby", "def f(a)\n  if a > 0\n    a\n  else\n    0\n  end\nend\n"),
        ];
        for (lang, src) in cases {
            let metrics = compute_complexity_generic(src, lang)
                .unwrap_or_else(|| panic!("{lang} should parse"));
            assert!(metrics.cyclomatic >= 2, "{lang} cyclomatic {}", metrics.cyclomatic);
        }
    }

    #[test]
    fn generic_complexity_unknown_language_is_none() {
        assert!(compute_complexity_generic("anything", "klingon").is_none());
    }

    #[test]
    fn generic_complexity_handles_java_branches() {
        let src = "class A {\n  int f(int a) {\n    if (a > 0 && a < 10) { return a; }\n    return 0;\n  }\n}\n";
        let metrics = compute_complexity_generic(src, "java").expect("java should parse");
        assert!(metrics.cyclomatic >= 3, "java cyclomatic {}", metrics.cyclomatic);
    }
}