Skip to main content

cha_parser/
python.rs

1use std::collections::hash_map::DefaultHasher;
2use std::hash::{Hash, Hasher};
3
4use cha_core::{ClassInfo, FunctionInfo, ImportInfo, SourceFile, SourceModel};
5use tree_sitter::{Node, Parser};
6
7use crate::LanguageParser;
8
9pub struct PythonParser;
10
11impl LanguageParser for PythonParser {
12    fn language_name(&self) -> &str {
13        "python"
14    }
15
16    fn parse(&self, file: &SourceFile) -> Option<SourceModel> {
17        let mut parser = Parser::new();
18        parser
19            .set_language(&tree_sitter_python::LANGUAGE.into())
20            .ok()?;
21        let tree = parser.parse(&file.content, None)?;
22        let root = tree.root_node();
23        let src = file.content.as_bytes();
24
25        let mut functions = Vec::new();
26        let mut classes = Vec::new();
27        let mut imports = Vec::new();
28
29        collect_top_level(root, src, &mut functions, &mut classes, &mut imports);
30
31        Some(SourceModel {
32            language: "python".into(),
33            total_lines: file.line_count(),
34            functions,
35            classes,
36            imports,
37            comments: collect_comments(root, src),
38            type_aliases: vec![], // TODO(parser): extract type aliases from 'type X = Y' / 'X = Y' declarations
39        })
40    }
41}
42
43fn push_definition(
44    node: Node,
45    src: &[u8],
46    functions: &mut Vec<FunctionInfo>,
47    classes: &mut Vec<ClassInfo>,
48) {
49    match node.kind() {
50        "function_definition" => {
51            if let Some(f) = extract_function(node, src) {
52                functions.push(f);
53            }
54        }
55        "class_definition" => {
56            if let Some(c) = extract_class(node, src, functions) {
57                classes.push(c);
58            }
59        }
60        _ => {}
61    }
62}
63
64fn collect_top_level(
65    node: Node,
66    src: &[u8],
67    functions: &mut Vec<FunctionInfo>,
68    classes: &mut Vec<ClassInfo>,
69    imports: &mut Vec<ImportInfo>,
70) {
71    let mut cursor = node.walk();
72    for child in node.children(&mut cursor) {
73        match child.kind() {
74            "function_definition" | "class_definition" => {
75                push_definition(child, src, functions, classes);
76            }
77            "import_statement" => collect_import(child, src, imports),
78            "import_from_statement" => collect_import_from(child, src, imports),
79            "decorated_definition" => {
80                let mut inner = child.walk();
81                for c in child.children(&mut inner) {
82                    push_definition(c, src, functions, classes);
83                }
84            }
85            _ => {}
86        }
87    }
88}
89
90fn extract_function(node: Node, src: &[u8]) -> Option<FunctionInfo> {
91    let name_node = node.child_by_field_name("name")?;
92    let name = node_text(name_node, src).to_string();
93    let name_col = name_node.start_position().column;
94    let name_end_col = name_node.end_position().column;
95    let start_line = node.start_position().row + 1;
96    let end_line = node.end_position().row + 1;
97    let body = node.child_by_field_name("body");
98    let params = node.child_by_field_name("parameters");
99    let (param_count, param_types) = params
100        .map(|p| extract_params(p, src))
101        .unwrap_or((0, vec![]));
102
103    Some(FunctionInfo {
104        name,
105        start_line,
106        end_line,
107        name_col,
108        name_end_col,
109        line_count: end_line - start_line + 1,
110        complexity: count_complexity(node),
111        body_hash: body.map(hash_ast_structure),
112        is_exported: true,
113        parameter_count: param_count,
114        parameter_types: param_types,
115        chain_depth: body.map(max_chain_depth).unwrap_or(0),
116        switch_arms: body.map(count_match_arms).unwrap_or(0),
117        external_refs: body
118            .map(|b| collect_external_refs(b, src))
119            .unwrap_or_default(),
120        is_delegating: body.map(|b| check_delegating(b, src)).unwrap_or(false),
121        comment_lines: count_comment_lines(node, src),
122        referenced_fields: body.map(|b| collect_self_refs(b, src)).unwrap_or_default(),
123        null_check_fields: body
124            .map(|b| collect_none_checks(b, src))
125            .unwrap_or_default(),
126        switch_dispatch_target: body.and_then(|b| extract_match_target_py(b, src)),
127        optional_param_count: params.map(count_optional).unwrap_or(0),
128        called_functions: body.map(|b| collect_calls_py(b, src)).unwrap_or_default(),
129        cognitive_complexity: body.map(cognitive_complexity_py).unwrap_or(0),
130    })
131}
132
133fn find_method_def(child: Node) -> Option<Node> {
134    if child.kind() == "function_definition" {
135        return Some(child);
136    }
137    if child.kind() == "decorated_definition" {
138        let mut inner = child.walk();
139        return child
140            .children(&mut inner)
141            .find(|c| c.kind() == "function_definition");
142    }
143    None
144}
145
146fn extract_parent_name(node: Node, src: &[u8]) -> Option<String> {
147    node.child_by_field_name("superclasses").and_then(|sc| {
148        let mut c = sc.walk();
149        sc.children(&mut c)
150            .find(|n| n.kind() != "(" && n.kind() != ")" && n.kind() != ",")
151            .map(|n| node_text(n, src).to_string())
152    })
153}
154
/// Returns `true` if a field name contains one of the case-sensitive markers
/// `listener`, `handler`, `callback` or `observer`.
fn has_listener_name(name: &str) -> bool {
    const MARKERS: [&str; 4] = ["listener", "handler", "callback", "observer"];
    MARKERS.iter().any(|marker| name.contains(*marker))
}
161
162fn process_method(
163    func_node: Node,
164    f: &mut FunctionInfo,
165    src: &[u8],
166    field_names: &mut Vec<String>,
167) -> (bool, bool, bool, usize) {
168    let method_name = &f.name;
169    let mut has_behavior = false;
170    let mut is_override = false;
171    let mut is_notify = false;
172    if method_name == "__init__" {
173        collect_init_fields(func_node, src, field_names);
174    } else {
175        has_behavior = true;
176    }
177    let sc = func_node
178        .child_by_field_name("body")
179        .map(|b| count_self_calls(b, src))
180        .unwrap_or(0);
181    if method_name.starts_with("__") && method_name.ends_with("__") && method_name != "__init__" {
182        is_override = true;
183    }
184    if method_name.contains("notify") || method_name.contains("emit") {
185        is_notify = true;
186    }
187    f.is_exported = !method_name.starts_with('_');
188    (has_behavior, is_override, is_notify, sc)
189}
190
191struct ClassScan {
192    methods: Vec<FunctionInfo>,
193    field_names: Vec<String>,
194    delegating_count: usize,
195    has_behavior: bool,
196    override_count: usize,
197    self_call_count: usize,
198    has_notify_method: bool,
199}
200
201fn scan_class_methods(body: Node, src: &[u8]) -> ClassScan {
202    let mut s = ClassScan {
203        methods: Vec::new(),
204        field_names: Vec::new(),
205        delegating_count: 0,
206        has_behavior: false,
207        override_count: 0,
208        self_call_count: 0,
209        has_notify_method: false,
210    };
211    let mut cursor = body.walk();
212    for child in body.children(&mut cursor) {
213        let Some(func_node) = find_method_def(child) else {
214            continue;
215        };
216        let Some(mut f) = extract_function(func_node, src) else {
217            continue;
218        };
219        if f.is_delegating {
220            s.delegating_count += 1;
221        }
222        let (behav, over, notify, sc) = process_method(func_node, &mut f, src, &mut s.field_names);
223        s.has_behavior |= behav;
224        if over {
225            s.override_count += 1;
226        }
227        if notify {
228            s.has_notify_method = true;
229        }
230        s.self_call_count += sc;
231        s.methods.push(f);
232    }
233    s
234}
235
236fn extract_class(
237    node: Node,
238    src: &[u8],
239    top_functions: &mut Vec<FunctionInfo>,
240) -> Option<ClassInfo> {
241    let name_node = node.child_by_field_name("name")?;
242    let name = node_text(name_node, src).to_string();
243    let name_col = name_node.start_position().column;
244    let name_end_col = name_node.end_position().column;
245    let start_line = node.start_position().row + 1;
246    let end_line = node.end_position().row + 1;
247    let body = node.child_by_field_name("body")?;
248    let s = scan_class_methods(body, src);
249    let method_count = s.methods.len();
250    top_functions.extend(s.methods);
251
252    Some(ClassInfo {
253        name,
254        start_line,
255        end_line,
256        name_col,
257        name_end_col,
258        line_count: end_line - start_line + 1,
259        method_count,
260        is_exported: true,
261        delegating_method_count: s.delegating_count,
262        field_count: s.field_names.len(),
263        has_listener_field: s.field_names.iter().any(|n| has_listener_name(n)),
264        field_names: s.field_names,
265        field_types: Vec::new(),
266        has_behavior: s.has_behavior,
267        is_interface: has_only_pass_or_ellipsis(body, src),
268        parent_name: extract_parent_name(node, src),
269        override_count: s.override_count,
270        self_call_count: s.self_call_count,
271        has_notify_method: s.has_notify_method,
272    })
273}
274
275// --- imports ---
276
277fn collect_import(node: Node, src: &[u8], imports: &mut Vec<ImportInfo>) {
278    let line = node.start_position().row + 1;
279    let col = node.start_position().column;
280    let mut cursor = node.walk();
281    for child in node.children(&mut cursor) {
282        if child.kind() == "dotted_name" || child.kind() == "aliased_import" {
283            let text = node_text(child, src);
284            imports.push(ImportInfo {
285                source: text.to_string(),
286                line,
287                col,
288                ..Default::default()
289            });
290        }
291    }
292}
293
294fn collect_import_from(node: Node, src: &[u8], imports: &mut Vec<ImportInfo>) {
295    let line = node.start_position().row + 1;
296    let col = node.start_position().column;
297    let module = node
298        .child_by_field_name("module_name")
299        .map(|n| node_text(n, src).to_string())
300        .unwrap_or_default();
301    let mut cursor = node.walk();
302    let mut has_names = false;
303    for child in node.children(&mut cursor) {
304        if child.kind() == "dotted_name" || child.kind() == "aliased_import" {
305            let n = node_text(child, src).to_string();
306            if n != module {
307                imports.push(ImportInfo {
308                    source: format!("{module}.{n}"),
309                    line,
310                    col,
311                    ..Default::default()
312                });
313                has_names = true;
314            }
315        }
316    }
317    if !has_names {
318        imports.push(ImportInfo {
319            source: module,
320            line,
321            col,
322            ..Default::default()
323        });
324    }
325}
326
327// --- helpers ---
328
329fn node_text<'a>(node: Node, src: &'a [u8]) -> &'a str {
330    node.utf8_text(src).unwrap_or("")
331}
332
333fn count_complexity(node: Node) -> usize {
334    let mut complexity = 1usize;
335    let mut cursor = node.walk();
336    visit_all(node, &mut cursor, &mut |n| {
337        match n.kind() {
338            "if_statement"
339            | "elif_clause"
340            | "for_statement"
341            | "while_statement"
342            | "except_clause"
343            | "with_statement"
344            | "assert_statement"
345            | "conditional_expression"
346            | "boolean_operator"
347            | "list_comprehension"
348            | "set_comprehension"
349            | "dictionary_comprehension"
350            | "generator_expression" => {
351                complexity += 1;
352            }
353            "match_statement" => {} // match itself doesn't add, cases do
354            "case_clause" => {
355                complexity += 1;
356            }
357            _ => {}
358        }
359    });
360    complexity
361}
362
363fn hash_ast_structure(node: Node) -> u64 {
364    let mut hasher = DefaultHasher::new();
365    hash_node(node, &mut hasher);
366    hasher.finish()
367}
368
369fn hash_node(node: Node, hasher: &mut DefaultHasher) {
370    node.kind().hash(hasher);
371    let mut cursor = node.walk();
372    for child in node.children(&mut cursor) {
373        hash_node(child, hasher);
374    }
375}
376
377fn max_chain_depth(node: Node) -> usize {
378    let mut max = 0usize;
379    let mut cursor = node.walk();
380    visit_all(node, &mut cursor, &mut |n| {
381        if n.kind() == "attribute" {
382            let depth = chain_len(n);
383            if depth > max {
384                max = depth;
385            }
386        }
387    });
388    max
389}
390
391fn chain_len(node: Node) -> usize {
392    let mut depth = 0usize;
393    let mut current = node;
394    while current.kind() == "attribute" || current.kind() == "call" {
395        if current.kind() == "attribute" {
396            depth += 1;
397        }
398        if let Some(obj) = current.child(0) {
399            current = obj;
400        } else {
401            break;
402        }
403    }
404    depth
405}
406
407fn count_match_arms(node: Node) -> usize {
408    let mut count = 0usize;
409    let mut cursor = node.walk();
410    visit_all(node, &mut cursor, &mut |n| {
411        if n.kind() == "case_clause" {
412            count += 1;
413        }
414    });
415    count
416}
417
418fn collect_external_refs(node: Node, src: &[u8]) -> Vec<String> {
419    let mut refs = Vec::new();
420    let mut cursor = node.walk();
421    visit_all(node, &mut cursor, &mut |n| {
422        if n.kind() != "attribute" {
423            return;
424        }
425        let Some(obj) = n.child(0) else { return };
426        let text = node_text(obj, src);
427        if text != "self"
428            && !text.is_empty()
429            && text.starts_with(|c: char| c.is_lowercase())
430            && !refs.contains(&text.to_string())
431        {
432            refs.push(text.to_string());
433        }
434    });
435    refs
436}
437
438fn unwrap_single_call(body: Node) -> Option<Node> {
439    let mut c = body.walk();
440    let stmts: Vec<Node> = body
441        .children(&mut c)
442        .filter(|n| !n.is_extra() && n.kind() != "pass_statement" && n.kind() != "comment")
443        .collect();
444    if stmts.len() != 1 {
445        return None;
446    }
447    let stmt = stmts[0];
448    match stmt.kind() {
449        "return_statement" => stmt.child(1).filter(|v| v.kind() == "call"),
450        "expression_statement" => stmt.child(0).filter(|v| v.kind() == "call"),
451        _ => None,
452    }
453}
454
455fn check_delegating(body: Node, src: &[u8]) -> bool {
456    let Some(func) = unwrap_single_call(body).and_then(|c| c.child(0)) else {
457        return false;
458    };
459    let text = node_text(func, src);
460    text.contains('.') && !text.starts_with("self.")
461}
462
463fn count_comment_lines(node: Node, src: &[u8]) -> usize {
464    let mut count = 0usize;
465    let mut cursor = node.walk();
466    visit_all(node, &mut cursor, &mut |n| {
467        if n.kind() == "comment" {
468            count += 1;
469        } else if n.kind() == "string" || n.kind() == "expression_statement" {
470            // docstrings
471            let text = node_text(n, src);
472            if text.starts_with("\"\"\"") || text.starts_with("'''") {
473                count += text.lines().count();
474            }
475        }
476    });
477    count
478}
479
480fn collect_self_refs(body: Node, src: &[u8]) -> Vec<String> {
481    let mut refs = Vec::new();
482    let mut cursor = body.walk();
483    visit_all(body, &mut cursor, &mut |n| {
484        if n.kind() != "attribute" {
485            return;
486        }
487        let is_self = n.child(0).is_some_and(|o| node_text(o, src) == "self");
488        if !is_self {
489            return;
490        }
491        if let Some(attr) = n.child_by_field_name("attribute") {
492            let name = node_text(attr, src).to_string();
493            if !refs.contains(&name) {
494                refs.push(name);
495            }
496        }
497    });
498    refs
499}
500
501fn collect_none_checks(body: Node, src: &[u8]) -> Vec<String> {
502    let mut fields = Vec::new();
503    let mut cursor = body.walk();
504    visit_all(body, &mut cursor, &mut |n| {
505        if n.kind() != "comparison_operator" {
506            return;
507        }
508        let text = node_text(n, src);
509        if !text.contains("is None") && !text.contains("is not None") && !text.contains("== None") {
510            return;
511        }
512        if let Some(left) = n.child(0) {
513            let name = node_text(left, src).to_string();
514            if !fields.contains(&name) {
515                fields.push(name);
516            }
517        }
518    });
519    fields
520}
521
/// `true` for the conventional receiver parameter names `self` and `cls`.
fn is_self_or_cls(name: &str) -> bool {
    matches!(name, "self" | "cls")
}
525
526fn param_name_and_type(child: Node, src: &[u8]) -> Option<(String, String)> {
527    match child.kind() {
528        "identifier" => {
529            let name = node_text(child, src);
530            (!is_self_or_cls(name)).then(|| (name.to_string(), "Any".to_string()))
531        }
532        "typed_parameter" | "default_parameter" | "typed_default_parameter" => {
533            let name = child
534                .child_by_field_name("name")
535                .or_else(|| child.child(0))
536                .map(|n| node_text(n, src))
537                .unwrap_or("");
538            if is_self_or_cls(name) {
539                return None;
540            }
541            let ty = child
542                .child_by_field_name("type")
543                .map(|n| node_text(n, src).to_string())
544                .unwrap_or_else(|| "Any".to_string());
545            Some((name.to_string(), ty))
546        }
547        "list_splat_pattern" | "dictionary_splat_pattern" => {
548            Some(("*".to_string(), "Any".to_string()))
549        }
550        _ => None,
551    }
552}
553
554fn extract_params(params_node: Node, src: &[u8]) -> (usize, Vec<String>) {
555    let mut count = 0usize;
556    let mut types = Vec::new();
557    let mut cursor = params_node.walk();
558    for child in params_node.children(&mut cursor) {
559        if let Some((_name, ty)) = param_name_and_type(child, src) {
560            count += 1;
561            types.push(ty);
562        }
563    }
564    (count, types)
565}
566
567fn count_optional(params_node: Node) -> usize {
568    let mut count = 0usize;
569    let mut cursor = params_node.walk();
570    for child in params_node.children(&mut cursor) {
571        if child.kind() == "default_parameter" || child.kind() == "typed_default_parameter" {
572            count += 1;
573        }
574    }
575    count
576}
577
578fn collect_init_fields(func_node: Node, src: &[u8], fields: &mut Vec<String>) {
579    let Some(body) = func_node.child_by_field_name("body") else {
580        return;
581    };
582    let mut cursor = body.walk();
583    visit_all(body, &mut cursor, &mut |n| {
584        if n.kind() != "assignment" {
585            return;
586        }
587        let Some(left) = n.child_by_field_name("left") else {
588            return;
589        };
590        if left.kind() != "attribute" {
591            return;
592        }
593        let is_self = left.child(0).is_some_and(|o| node_text(o, src) == "self");
594        if !is_self {
595            return;
596        }
597        if let Some(attr) = left.child_by_field_name("attribute") {
598            let name = node_text(attr, src).to_string();
599            if !fields.contains(&name) {
600                fields.push(name);
601            }
602        }
603    });
604}
605
606fn count_self_calls(body: Node, src: &[u8]) -> usize {
607    let mut count = 0;
608    let mut cursor = body.walk();
609    visit_all(body, &mut cursor, &mut |n| {
610        if n.kind() != "call" {
611            return;
612        }
613        let is_self_call = n
614            .child(0)
615            .filter(|f| f.kind() == "attribute")
616            .and_then(|f| f.child(0))
617            .is_some_and(|obj| node_text(obj, src) == "self");
618        if is_self_call {
619            count += 1;
620        }
621    });
622    count
623}
624
625fn is_stub_body(node: Node, src: &[u8]) -> bool {
626    node.child_by_field_name("body")
627        .is_none_or(|b| has_only_pass_or_ellipsis(b, src))
628}
629
630fn has_only_pass_or_ellipsis(body: Node, src: &[u8]) -> bool {
631    let mut cursor = body.walk();
632    for child in body.children(&mut cursor) {
633        let ok = match child.kind() {
634            "pass_statement" | "ellipsis" | "comment" => true,
635            "expression_statement" => child.child(0).is_none_or(|expr| {
636                let text = node_text(expr, src);
637                text == "..." || text.starts_with("\"\"\"") || text.starts_with("'''")
638            }),
639            "function_definition" => is_stub_body(child, src),
640            "decorated_definition" => {
641                let mut inner = child.walk();
642                child
643                    .children(&mut inner)
644                    .filter(|c| c.kind() == "function_definition")
645                    .all(|c| is_stub_body(c, src))
646            }
647            _ => false,
648        };
649        if !ok {
650            return false;
651        }
652    }
653    true
654}
655
656fn cognitive_complexity_py(node: tree_sitter::Node) -> usize {
657    let mut score = 0;
658    cc_walk_py(node, 0, &mut score);
659    score
660}
661
662fn cc_walk_py(node: tree_sitter::Node, nesting: usize, score: &mut usize) {
663    match node.kind() {
664        "if_statement" => {
665            *score += 1 + nesting;
666            cc_children_py(node, nesting + 1, score);
667            return;
668        }
669        "for_statement" | "while_statement" => {
670            *score += 1 + nesting;
671            cc_children_py(node, nesting + 1, score);
672            return;
673        }
674        "match_statement" => {
675            *score += 1 + nesting;
676            cc_children_py(node, nesting + 1, score);
677            return;
678        }
679        "elif_clause" | "else_clause" => {
680            *score += 1;
681        }
682        "boolean_operator" => {
683            *score += 1;
684        }
685        "except_clause" => {
686            *score += 1 + nesting;
687            cc_children_py(node, nesting + 1, score);
688            return;
689        }
690        "lambda" => {
691            cc_children_py(node, nesting + 1, score);
692            return;
693        }
694        _ => {}
695    }
696    cc_children_py(node, nesting, score);
697}
698
699fn cc_children_py(node: tree_sitter::Node, nesting: usize, score: &mut usize) {
700    let mut cursor = node.walk();
701    for child in node.children(&mut cursor) {
702        cc_walk_py(child, nesting, score);
703    }
704}
705
706fn extract_match_target_py(body: tree_sitter::Node, src: &[u8]) -> Option<String> {
707    let mut target = None;
708    let mut cursor = body.walk();
709    visit_all(body, &mut cursor, &mut |n| {
710        if n.kind() == "match_statement"
711            && target.is_none()
712            && let Some(subj) = n.child_by_field_name("subject")
713        {
714            target = Some(node_text(subj, src).to_string());
715        }
716    });
717    target
718}
719
720fn collect_calls_py(body: tree_sitter::Node, src: &[u8]) -> Vec<String> {
721    let mut calls = Vec::new();
722    let mut cursor = body.walk();
723    visit_all(body, &mut cursor, &mut |n| {
724        if n.kind() == "call"
725            && let Some(func) = n.child(0)
726        {
727            let name = node_text(func, src).to_string();
728            if !calls.contains(&name) {
729                calls.push(name);
730            }
731        }
732    });
733    calls
734}
735
736fn collect_comments(root: Node, src: &[u8]) -> Vec<cha_core::CommentInfo> {
737    let mut comments = Vec::new();
738    let mut cursor = root.walk();
739    visit_all(root, &mut cursor, &mut |n| {
740        if n.kind().contains("comment") {
741            comments.push(cha_core::CommentInfo {
742                text: node_text(n, src).to_string(),
743                line: n.start_position().row + 1,
744            });
745        }
746    });
747    comments
748}
749
750fn visit_all<F: FnMut(Node)>(node: Node, cursor: &mut tree_sitter::TreeCursor, f: &mut F) {
751    f(node);
752    if cursor.goto_first_child() {
753        loop {
754            let child_node = cursor.node();
755            let mut child_cursor = child_node.walk();
756            visit_all(child_node, &mut child_cursor, f);
757            if !cursor.goto_next_sibling() {
758                break;
759            }
760        }
761        cursor.goto_parent();
762    }
763}