Skip to main content

infiniloom_engine/embedding/
complexity.rs

1//! Cyclomatic complexity scoring for embedding chunks
2//!
3//! Computes cyclomatic complexity by counting branch-point AST node types
4//! using Tree-sitter. The formula is: `complexity = 1 + count(branch_nodes)`.
5//!
6//! Supports Rust, Python, JavaScript, TypeScript, Go, Java, C, C++, C#,
7//! Ruby, Kotlin, Swift, PHP, and Dart. Returns `None` for unsupported languages.
8
9use crate::parser::Language;
10
11/// Compute the cyclomatic complexity score for a chunk of source code.
12///
13/// Returns `Some(score)` where score >= 1 for supported languages,
14/// or `None` if the language is unsupported or parsing fails.
15///
16/// The score is `1 + count(branch_nodes)` where branch nodes are
17/// language-specific AST nodes representing control flow decisions.
18pub fn compute_complexity(content: &str, language: Language) -> Option<u32> {
19    let ts_lang = language.tree_sitter_language()?;
20
21    let mut parser = tree_sitter::Parser::new();
22    parser.set_language(&ts_lang).ok()?;
23
24    let tree = parser.parse(content, None)?;
25    let root = tree.root_node();
26
27    let branch_types = branch_node_types(language)?;
28    let logical_ops = logical_operator_types(language);
29
30    let mut count = 0u32;
31    count_branch_nodes(root, content.as_bytes(), &branch_types, &logical_ops, &mut count);
32
33    Some(1 + count)
34}
35
36/// Recursively walk the AST and count branch-point nodes.
37fn count_branch_nodes(
38    node: tree_sitter::Node,
39    source: &[u8],
40    branch_types: &[&str],
41    logical_ops: &LogicalOps,
42    count: &mut u32,
43) {
44    let kind = node.kind();
45
46    if branch_types.contains(&kind) {
47        *count += 1;
48    } else if is_logical_operator_node(node, source, kind, logical_ops) {
49        *count += 1;
50    }
51
52    let child_count = node.child_count();
53    for i in 0..child_count {
54        if let Some(child) = node.child(i as u32) {
55            count_branch_nodes(child, source, branch_types, logical_ops, count);
56        }
57    }
58}
59
/// Per-language settings used to recognise logical operator expressions.
struct LogicalOps {
    /// Operator spellings that each count as a branch point (e.g. "&&", "or").
    operators: &'static [&'static str],
    /// Kind of the AST node that wraps such an operator (e.g. "binary_expression").
    binary_node_kind: &'static str,
}
67
68/// Check if a node represents a logical operator (&&, ||, `and`, `or`).
69fn is_logical_operator_node(
70    node: tree_sitter::Node,
71    source: &[u8],
72    kind: &str,
73    ops: &LogicalOps,
74) -> bool {
75    if kind != ops.binary_node_kind {
76        return false;
77    }
78
79    // Look for the operator child node
80    let child_count = node.child_count();
81    for i in 0..child_count {
82        if let Some(child) = node.child(i as u32) {
83            // Check if this child is an operator token
84            let child_kind = child.kind();
85            if ops.operators.contains(&child_kind) {
86                return true;
87            }
88            // Check the text content of the operator node
89            if let Ok(text) = child.utf8_text(source) {
90                if ops.operators.contains(&text) {
91                    return true;
92                }
93            }
94        }
95    }
96    false
97}
98
99/// Return the list of branch-point AST node types for a given language.
100///
101/// These are control flow constructs that increase cyclomatic complexity.
102/// Returns `None` for unsupported languages.
103#[allow(deprecated)]
104fn branch_node_types(language: Language) -> Option<Vec<&'static str>> {
105    let types: Vec<&str> = match language {
106        Language::Rust => vec![
107            "if_expression",
108            "else_clause",
109            "match_arm",
110            "for_expression",
111            "while_expression",
112            "loop_expression",
113        ],
114        Language::Python => vec![
115            "if_statement",
116            "elif_clause",
117            "for_statement",
118            "while_statement",
119            "except_clause",
120            "conditional_expression",
121        ],
122        Language::JavaScript | Language::TypeScript => vec![
123            "if_statement",
124            "else_clause",
125            "switch_case",
126            "for_statement",
127            "for_in_statement",
128            "while_statement",
129            "do_statement",
130            "ternary_expression",
131            "catch_clause",
132        ],
133        Language::Go => vec!["if_statement", "expression_case", "for_statement"],
134        Language::Java => vec![
135            "if_statement",
136            "switch_block_statement_group",
137            "for_statement",
138            "enhanced_for_statement",
139            "while_statement",
140            "do_statement",
141            "catch_clause",
142            "ternary_expression",
143        ],
144        Language::C | Language::Cpp => vec![
145            "if_statement",
146            "else_clause",
147            "case_statement",
148            "for_statement",
149            "while_statement",
150            "do_statement",
151            "conditional_expression",
152        ],
153        Language::CSharp => vec![
154            "if_statement",
155            "else_clause",
156            "switch_section",
157            "for_statement",
158            "for_each_statement",
159            "while_statement",
160            "do_statement",
161            "catch_clause",
162            "conditional_expression",
163        ],
164        Language::Ruby => {
165            vec!["if", "elsif", "unless", "while", "until", "for", "when", "rescue", "conditional"]
166        },
167        Language::Php => vec![
168            "if_statement",
169            "else_clause",
170            "case_statement",
171            "for_statement",
172            "foreach_statement",
173            "while_statement",
174            "do_statement",
175            "catch_clause",
176        ],
177        Language::Kotlin => vec![
178            "if_expression",
179            "when_entry",
180            "for_statement",
181            "while_statement",
182            "do_while_statement",
183            "catch_block",
184        ],
185        Language::Swift => vec![
186            "if_statement",
187            "guard_statement",
188            "switch_case",
189            "for_in_statement",
190            "while_statement",
191            "repeat_while_statement",
192            "catch_clause",
193        ],
194        Language::Dart => vec![
195            "if_statement",
196            "else_clause",
197            "switch_case",
198            "for_statement",
199            "while_statement",
200            "do_statement",
201            "catch_clause",
202            "conditional_expression",
203        ],
204        _ => return None,
205    };
206    Some(types)
207}
208
209/// Return logical operator detection configuration for a given language.
210fn logical_operator_types(language: Language) -> LogicalOps {
211    match language {
212        Language::Python => {
213            LogicalOps { binary_node_kind: "boolean_operator", operators: &["and", "or"] }
214        },
215        Language::Ruby => {
216            LogicalOps { binary_node_kind: "binary", operators: &["&&", "||", "and", "or"] }
217        },
218        _ => LogicalOps { binary_node_kind: "binary_expression", operators: &["&&", "||"] },
219    }
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    /// Score `code` as `language`, panicking when no score is produced.
    fn score_of(code: &str, language: Language) -> u32 {
        compute_complexity(code, language).unwrap()
    }

    #[test]
    fn test_linear_function_rust() {
        let code = r#"
fn add(a: i32, b: i32) -> i32 {
    let result = a + b;
    result
}
"#;
        assert_eq!(score_of(code, Language::Rust), 1, "Linear function should have complexity 1");
    }

    #[test]
    fn test_single_if_rust() {
        let code = r#"
fn check(x: i32) -> bool {
    if x > 0 {
        return true;
    }
    false
}
"#;
        assert_eq!(score_of(code, Language::Rust), 2, "Single if should have complexity 2");
    }

    #[test]
    fn test_if_else_rust() {
        let code = r#"
fn check(x: i32) -> &str {
    if x > 0 {
        "positive"
    } else {
        "non-positive"
    }
}
"#;
        // Two branch points (if_expression, else_clause) on top of the base 1.
        assert_eq!(score_of(code, Language::Rust), 3, "if/else should have complexity 3");
    }

    #[test]
    fn test_nested_control_flow_rust() {
        let code = r#"
fn complex(items: &[i32]) -> i32 {
    let mut sum = 0;
    for item in items {
        if *item > 0 {
            sum += item;
        } else {
            if *item < -10 {
                continue;
            }
        }
    }
    sum
}
"#;
        // Four branch points: for_expression, if_expression, else_clause, if_expression.
        assert_eq!(
            score_of(code, Language::Rust),
            5,
            "Nested control flow should have complexity 5"
        );
    }

    #[test]
    fn test_logical_operators_rust() {
        let code = r#"
fn check(a: bool, b: bool, c: bool) -> bool {
    if a && b || c {
        true
    } else {
        false
    }
}
"#;
        // Four branch points: if_expression, else_clause, `&&`, `||`.
        assert_eq!(
            score_of(code, Language::Rust),
            5,
            "Logical operators should add to complexity"
        );
    }

    #[test]
    fn test_match_rust() {
        let code = r#"
fn classify(x: i32) -> &str {
    match x {
        0 => "zero",
        1..=10 => "small",
        _ => "large",
    }
}
"#;
        // Each of the three match_arm nodes is one branch point.
        assert_eq!(
            score_of(code, Language::Rust),
            4,
            "Match with 3 arms should have complexity 4"
        );
    }

    #[test]
    fn test_linear_function_python() {
        let code = r#"
def add(a, b):
    result = a + b
    return result
"#;
        assert_eq!(
            score_of(code, Language::Python),
            1,
            "Linear Python function should have complexity 1"
        );
    }

    #[test]
    fn test_if_elif_python() {
        let code = r#"
def classify(x):
    if x > 0:
        return "positive"
    elif x == 0:
        return "zero"
    else:
        return "negative"
"#;
        // Two branch points: if_statement and elif_clause.
        assert_eq!(score_of(code, Language::Python), 3, "if/elif/else should have complexity 3");
    }

    #[test]
    fn test_linear_function_javascript() {
        let code = r#"
function add(a, b) {
    const result = a + b;
    return result;
}
"#;
        assert_eq!(
            score_of(code, Language::JavaScript),
            1,
            "Linear JS function should have complexity 1"
        );
    }

    #[test]
    fn test_if_else_javascript() {
        let code = r#"
function check(x) {
    if (x > 0) {
        return true;
    } else {
        return false;
    }
}
"#;
        // Two branch points: if_statement and else_clause.
        assert_eq!(
            score_of(code, Language::JavaScript),
            3,
            "if/else JS should have complexity 3"
        );
    }

    #[test]
    fn test_unsupported_language_returns_none() {
        // Haskell has no entry in the branch_node_types table.
        let outcome = compute_complexity("some code here", Language::Haskell);
        assert!(outcome.is_none(), "Unsupported language should return None");
    }

    #[test]
    fn test_go_if_for() {
        let code = r#"
func process(items []int) int {
    sum := 0
    for _, item := range items {
        if item > 0 {
            sum += item
        }
    }
    return sum
}
"#;
        // Two branch points: for_statement and if_statement.
        assert_eq!(score_of(code, Language::Go), 3, "Go for+if should have complexity 3");
    }

    #[test]
    fn test_empty_content() {
        assert_eq!(score_of("", Language::Rust), 1, "Empty content should have complexity 1");
    }
}