1use crate::parser::Language;
10
11pub fn compute_complexity(content: &str, language: Language) -> Option<u32> {
19 let ts_lang = language.tree_sitter_language()?;
20
21 let mut parser = tree_sitter::Parser::new();
22 parser.set_language(&ts_lang).ok()?;
23
24 let tree = parser.parse(content, None)?;
25 let root = tree.root_node();
26
27 let branch_types = branch_node_types(language)?;
28 let logical_ops = logical_operator_types(language);
29
30 let mut count = 0u32;
31 count_branch_nodes(root, content.as_bytes(), &branch_types, &logical_ops, &mut count);
32
33 Some(1 + count)
34}
35
36fn count_branch_nodes(
38 node: tree_sitter::Node,
39 source: &[u8],
40 branch_types: &[&str],
41 logical_ops: &LogicalOps,
42 count: &mut u32,
43) {
44 let kind = node.kind();
45
46 if branch_types.contains(&kind) {
47 *count += 1;
48 } else if is_logical_operator_node(node, source, kind, logical_ops) {
49 *count += 1;
50 }
51
52 let child_count = node.child_count();
53 for i in 0..child_count {
54 if let Some(child) = node.child(i as u32) {
55 count_branch_nodes(child, source, branch_types, logical_ops, count);
56 }
57 }
58}
59
/// Describes how one grammar represents logical operators that count as decision points.
#[derive(Debug, Clone, Copy)]
struct LogicalOps {
    /// Kind of the syntax node that has to match before its children are inspected.
    binary_node_kind: &'static str,
    /// Operator spellings; a child matches when either its kind or its source text
    /// equals one of these.
    operators: &'static [&'static str],
}
67
68fn is_logical_operator_node(
70 node: tree_sitter::Node,
71 source: &[u8],
72 kind: &str,
73 ops: &LogicalOps,
74) -> bool {
75 if kind != ops.binary_node_kind {
76 return false;
77 }
78
79 let child_count = node.child_count();
81 for i in 0..child_count {
82 if let Some(child) = node.child(i as u32) {
83 let child_kind = child.kind();
85 if ops.operators.contains(&child_kind) {
86 return true;
87 }
88 if let Ok(text) = child.utf8_text(source) {
90 if ops.operators.contains(&text) {
91 return true;
92 }
93 }
94 }
95 }
96 false
97}
98
99#[allow(deprecated)]
104fn branch_node_types(language: Language) -> Option<Vec<&'static str>> {
105 let types: Vec<&str> = match language {
106 Language::Rust => vec![
107 "if_expression",
108 "else_clause",
109 "match_arm",
110 "for_expression",
111 "while_expression",
112 "loop_expression",
113 ],
114 Language::Python => vec![
115 "if_statement",
116 "elif_clause",
117 "for_statement",
118 "while_statement",
119 "except_clause",
120 "conditional_expression",
121 ],
122 Language::JavaScript | Language::TypeScript => vec![
123 "if_statement",
124 "else_clause",
125 "switch_case",
126 "for_statement",
127 "for_in_statement",
128 "while_statement",
129 "do_statement",
130 "ternary_expression",
131 "catch_clause",
132 ],
133 Language::Go => vec!["if_statement", "expression_case", "for_statement"],
134 Language::Java => vec![
135 "if_statement",
136 "switch_block_statement_group",
137 "for_statement",
138 "enhanced_for_statement",
139 "while_statement",
140 "do_statement",
141 "catch_clause",
142 "ternary_expression",
143 ],
144 Language::C | Language::Cpp => vec![
145 "if_statement",
146 "else_clause",
147 "case_statement",
148 "for_statement",
149 "while_statement",
150 "do_statement",
151 "conditional_expression",
152 ],
153 Language::CSharp => vec![
154 "if_statement",
155 "else_clause",
156 "switch_section",
157 "for_statement",
158 "for_each_statement",
159 "while_statement",
160 "do_statement",
161 "catch_clause",
162 "conditional_expression",
163 ],
164 Language::Ruby => {
165 vec!["if", "elsif", "unless", "while", "until", "for", "when", "rescue", "conditional"]
166 },
167 Language::Php => vec![
168 "if_statement",
169 "else_clause",
170 "case_statement",
171 "for_statement",
172 "foreach_statement",
173 "while_statement",
174 "do_statement",
175 "catch_clause",
176 ],
177 Language::Kotlin => vec![
178 "if_expression",
179 "when_entry",
180 "for_statement",
181 "while_statement",
182 "do_while_statement",
183 "catch_block",
184 ],
185 Language::Swift => vec![
186 "if_statement",
187 "guard_statement",
188 "switch_case",
189 "for_in_statement",
190 "while_statement",
191 "repeat_while_statement",
192 "catch_clause",
193 ],
194 Language::Dart => vec![
195 "if_statement",
196 "else_clause",
197 "switch_case",
198 "for_statement",
199 "while_statement",
200 "do_statement",
201 "catch_clause",
202 "conditional_expression",
203 ],
204 _ => return None,
205 };
206 Some(types)
207}
208
209fn logical_operator_types(language: Language) -> LogicalOps {
211 match language {
212 Language::Python => {
213 LogicalOps { binary_node_kind: "boolean_operator", operators: &["and", "or"] }
214 },
215 Language::Ruby => {
216 LogicalOps { binary_node_kind: "binary", operators: &["&&", "||", "and", "or"] }
217 },
218 _ => LogicalOps { binary_node_kind: "binary_expression", operators: &["&&", "||"] },
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    // Every expected score is the base of 1 plus one per counted node.

    #[test]
    fn test_linear_function_rust() {
        let source = r#"
fn add(a: i32, b: i32) -> i32 {
    let result = a + b;
    result
}
"#;
        assert_eq!(
            compute_complexity(source, Language::Rust),
            Some(1),
            "Linear function should have complexity 1"
        );
    }

    #[test]
    fn test_single_if_rust() {
        let source = r#"
fn check(x: i32) -> bool {
    if x > 0 {
        return true;
    }
    false
}
"#;
        assert_eq!(
            compute_complexity(source, Language::Rust),
            Some(2),
            "Single if should have complexity 2"
        );
    }

    #[test]
    fn test_if_else_rust() {
        let source = r#"
fn check(x: i32) -> &str {
    if x > 0 {
        "positive"
    } else {
        "non-positive"
    }
}
"#;
        // base + `if_expression` + `else_clause`
        assert_eq!(
            compute_complexity(source, Language::Rust),
            Some(3),
            "if/else should have complexity 3"
        );
    }

    #[test]
    fn test_nested_control_flow_rust() {
        let source = r#"
fn complex(items: &[i32]) -> i32 {
    let mut sum = 0;
    for item in items {
        if *item > 0 {
            sum += item;
        } else {
            if *item < -10 {
                continue;
            }
        }
    }
    sum
}
"#;
        // base + `for` + outer `if` + `else` + inner `if`
        assert_eq!(
            compute_complexity(source, Language::Rust),
            Some(5),
            "Nested control flow should have complexity 5"
        );
    }

    #[test]
    fn test_logical_operators_rust() {
        let source = r#"
fn check(a: bool, b: bool, c: bool) -> bool {
    if a && b || c {
        true
    } else {
        false
    }
}
"#;
        // base + `if` + `else` + `&&` + `||`
        assert_eq!(
            compute_complexity(source, Language::Rust),
            Some(5),
            "Logical operators should add to complexity"
        );
    }

    #[test]
    fn test_match_rust() {
        let source = r#"
fn classify(x: i32) -> &str {
    match x {
        0 => "zero",
        1..=10 => "small",
        _ => "large",
    }
}
"#;
        // base + one per `match_arm`
        assert_eq!(
            compute_complexity(source, Language::Rust),
            Some(4),
            "Match with 3 arms should have complexity 4"
        );
    }

    #[test]
    fn test_linear_function_python() {
        let source = r#"
def add(a, b):
    result = a + b
    return result
"#;
        assert_eq!(
            compute_complexity(source, Language::Python),
            Some(1),
            "Linear Python function should have complexity 1"
        );
    }

    #[test]
    fn test_if_elif_python() {
        let source = r#"
def classify(x):
    if x > 0:
        return "positive"
    elif x == 0:
        return "zero"
    else:
        return "negative"
"#;
        // base + `if_statement` + `elif_clause`; the Python table has no entry for `else`
        assert_eq!(
            compute_complexity(source, Language::Python),
            Some(3),
            "if/elif/else should have complexity 3"
        );
    }

    #[test]
    fn test_linear_function_javascript() {
        let source = r#"
function add(a, b) {
    const result = a + b;
    return result;
}
"#;
        assert_eq!(
            compute_complexity(source, Language::JavaScript),
            Some(1),
            "Linear JS function should have complexity 1"
        );
    }

    #[test]
    fn test_if_else_javascript() {
        let source = r#"
function check(x) {
    if (x > 0) {
        return true;
    } else {
        return false;
    }
}
"#;
        // base + `if_statement` + `else_clause`
        assert_eq!(
            compute_complexity(source, Language::JavaScript),
            Some(3),
            "if/else JS should have complexity 3"
        );
    }

    #[test]
    fn test_unsupported_language_returns_none() {
        let outcome = compute_complexity("some code here", Language::Haskell);
        assert!(outcome.is_none(), "Unsupported language should return None");
    }

    #[test]
    fn test_go_if_for() {
        let source = r#"
func process(items []int) int {
    sum := 0
    for _, item := range items {
        if item > 0 {
            sum += item
        }
    }
    return sum
}
"#;
        // base + `for_statement` + `if_statement`
        assert_eq!(
            compute_complexity(source, Language::Go),
            Some(3),
            "Go for+if should have complexity 3"
        );
    }

    #[test]
    fn test_empty_content() {
        assert_eq!(
            compute_complexity("", Language::Rust),
            Some(1),
            "Empty content should have complexity 1"
        );
    }
}