Skip to main content

cha_parser/
python.rs

1use std::collections::hash_map::DefaultHasher;
2use std::hash::{Hash, Hasher};
3
4use cha_core::{ClassInfo, FunctionInfo, ImportInfo, SourceFile, SourceModel};
5use tree_sitter::{Node, Parser};
6
7use crate::LanguageParser;
8
/// Tree-sitter based [`LanguageParser`] that builds a `SourceModel` from Python source.
pub struct PythonParser;
10
11impl LanguageParser for PythonParser {
12    fn language_name(&self) -> &str {
13        "python"
14    }
15
16    fn parse(&self, file: &SourceFile) -> Option<SourceModel> {
17        let mut parser = Parser::new();
18        parser
19            .set_language(&tree_sitter_python::LANGUAGE.into())
20            .ok()?;
21        let tree = parser.parse(&file.content, None)?;
22        let root = tree.root_node();
23        let src = file.content.as_bytes();
24
25        let mut functions = Vec::new();
26        let mut classes = Vec::new();
27        let mut imports = Vec::new();
28
29        collect_top_level(root, src, &mut functions, &mut classes, &mut imports);
30
31        Some(SourceModel {
32            language: "python".into(),
33            total_lines: file.line_count(),
34            functions,
35            classes,
36            imports,
37            comments: collect_comments(root, src),
38            type_aliases: vec![], // TODO(parser): extract type aliases from 'type X = Y' / 'X = Y' declarations
39        })
40    }
41}
42
43fn push_definition(
44    node: Node,
45    src: &[u8],
46    functions: &mut Vec<FunctionInfo>,
47    classes: &mut Vec<ClassInfo>,
48) {
49    match node.kind() {
50        "function_definition" => {
51            if let Some(f) = extract_function(node, src) {
52                functions.push(f);
53            }
54        }
55        "class_definition" => {
56            if let Some(c) = extract_class(node, src, functions) {
57                classes.push(c);
58            }
59        }
60        _ => {}
61    }
62}
63
64fn collect_top_level(
65    node: Node,
66    src: &[u8],
67    functions: &mut Vec<FunctionInfo>,
68    classes: &mut Vec<ClassInfo>,
69    imports: &mut Vec<ImportInfo>,
70) {
71    let mut cursor = node.walk();
72    for child in node.children(&mut cursor) {
73        match child.kind() {
74            "function_definition" | "class_definition" => {
75                push_definition(child, src, functions, classes);
76            }
77            "import_statement" => collect_import(child, src, imports),
78            "import_from_statement" => collect_import_from(child, src, imports),
79            "decorated_definition" => {
80                let mut inner = child.walk();
81                for c in child.children(&mut inner) {
82                    push_definition(c, src, functions, classes);
83                }
84            }
85            _ => {}
86        }
87    }
88}
89
90fn extract_function(node: Node, src: &[u8]) -> Option<FunctionInfo> {
91    let name = node_text(node.child_by_field_name("name")?, src).to_string();
92    let start_line = node.start_position().row + 1;
93    let end_line = node.end_position().row + 1;
94    let body = node.child_by_field_name("body");
95    let params = node.child_by_field_name("parameters");
96    let (param_count, param_types) = params
97        .map(|p| extract_params(p, src))
98        .unwrap_or((0, vec![]));
99
100    Some(FunctionInfo {
101        name,
102        start_line,
103        end_line,
104        line_count: end_line - start_line + 1,
105        complexity: count_complexity(node),
106        body_hash: body.map(hash_ast_structure),
107        is_exported: true,
108        parameter_count: param_count,
109        parameter_types: param_types,
110        chain_depth: body.map(max_chain_depth).unwrap_or(0),
111        switch_arms: body.map(count_match_arms).unwrap_or(0),
112        external_refs: body
113            .map(|b| collect_external_refs(b, src))
114            .unwrap_or_default(),
115        is_delegating: body.map(|b| check_delegating(b, src)).unwrap_or(false),
116        comment_lines: count_comment_lines(node, src),
117        referenced_fields: body.map(|b| collect_self_refs(b, src)).unwrap_or_default(),
118        null_check_fields: body
119            .map(|b| collect_none_checks(b, src))
120            .unwrap_or_default(),
121        switch_dispatch_target: body.and_then(|b| extract_match_target_py(b, src)),
122        optional_param_count: params.map(count_optional).unwrap_or(0),
123        called_functions: body.map(|b| collect_calls_py(b, src)).unwrap_or_default(),
124        cognitive_complexity: body.map(cognitive_complexity_py).unwrap_or(0),
125    })
126}
127
128fn find_method_def(child: Node) -> Option<Node> {
129    if child.kind() == "function_definition" {
130        return Some(child);
131    }
132    if child.kind() == "decorated_definition" {
133        let mut inner = child.walk();
134        return child
135            .children(&mut inner)
136            .find(|c| c.kind() == "function_definition");
137    }
138    None
139}
140
141fn extract_parent_name(node: Node, src: &[u8]) -> Option<String> {
142    node.child_by_field_name("superclasses").and_then(|sc| {
143        let mut c = sc.walk();
144        sc.children(&mut c)
145            .find(|n| n.kind() != "(" && n.kind() != ")" && n.kind() != ",")
146            .map(|n| node_text(n, src).to_string())
147    })
148}
149
/// Returns `true` when a field name suggests it stores listeners/handlers/callbacks/observers.
///
/// Matching is ASCII case-insensitive, so both snake_case (`click_listener`) and camelCase
/// (`onClickListener`) names are recognized.
fn has_listener_name(name: &str) -> bool {
    const KEYWORDS: [&str; 4] = ["listener", "handler", "callback", "observer"];
    let lowered = name.to_ascii_lowercase();
    KEYWORDS.iter().any(|kw| lowered.contains(kw))
}
156
157fn process_method(
158    func_node: Node,
159    f: &mut FunctionInfo,
160    src: &[u8],
161    field_names: &mut Vec<String>,
162) -> (bool, bool, bool, usize) {
163    let method_name = &f.name;
164    let mut has_behavior = false;
165    let mut is_override = false;
166    let mut is_notify = false;
167    if method_name == "__init__" {
168        collect_init_fields(func_node, src, field_names);
169    } else {
170        has_behavior = true;
171    }
172    let sc = func_node
173        .child_by_field_name("body")
174        .map(|b| count_self_calls(b, src))
175        .unwrap_or(0);
176    if method_name.starts_with("__") && method_name.ends_with("__") && method_name != "__init__" {
177        is_override = true;
178    }
179    if method_name.contains("notify") || method_name.contains("emit") {
180        is_notify = true;
181    }
182    f.is_exported = !method_name.starts_with('_');
183    (has_behavior, is_override, is_notify, sc)
184}
185
186struct ClassScan {
187    methods: Vec<FunctionInfo>,
188    field_names: Vec<String>,
189    delegating_count: usize,
190    has_behavior: bool,
191    override_count: usize,
192    self_call_count: usize,
193    has_notify_method: bool,
194}
195
196fn scan_class_methods(body: Node, src: &[u8]) -> ClassScan {
197    let mut s = ClassScan {
198        methods: Vec::new(),
199        field_names: Vec::new(),
200        delegating_count: 0,
201        has_behavior: false,
202        override_count: 0,
203        self_call_count: 0,
204        has_notify_method: false,
205    };
206    let mut cursor = body.walk();
207    for child in body.children(&mut cursor) {
208        let Some(func_node) = find_method_def(child) else {
209            continue;
210        };
211        let Some(mut f) = extract_function(func_node, src) else {
212            continue;
213        };
214        if f.is_delegating {
215            s.delegating_count += 1;
216        }
217        let (behav, over, notify, sc) = process_method(func_node, &mut f, src, &mut s.field_names);
218        s.has_behavior |= behav;
219        if over {
220            s.override_count += 1;
221        }
222        if notify {
223            s.has_notify_method = true;
224        }
225        s.self_call_count += sc;
226        s.methods.push(f);
227    }
228    s
229}
230
231fn extract_class(
232    node: Node,
233    src: &[u8],
234    top_functions: &mut Vec<FunctionInfo>,
235) -> Option<ClassInfo> {
236    let name = node_text(node.child_by_field_name("name")?, src).to_string();
237    let start_line = node.start_position().row + 1;
238    let end_line = node.end_position().row + 1;
239    let body = node.child_by_field_name("body")?;
240    let s = scan_class_methods(body, src);
241    let method_count = s.methods.len();
242    top_functions.extend(s.methods);
243
244    Some(ClassInfo {
245        name,
246        start_line,
247        end_line,
248        line_count: end_line - start_line + 1,
249        method_count,
250        is_exported: true,
251        delegating_method_count: s.delegating_count,
252        field_count: s.field_names.len(),
253        has_listener_field: s.field_names.iter().any(|n| has_listener_name(n)),
254        field_names: s.field_names,
255        field_types: Vec::new(),
256        has_behavior: s.has_behavior,
257        is_interface: has_only_pass_or_ellipsis(body, src),
258        parent_name: extract_parent_name(node, src),
259        override_count: s.override_count,
260        self_call_count: s.self_call_count,
261        has_notify_method: s.has_notify_method,
262    })
263}
264
265// --- imports ---
266
267fn collect_import(node: Node, src: &[u8], imports: &mut Vec<ImportInfo>) {
268    let line = node.start_position().row + 1;
269    let mut cursor = node.walk();
270    for child in node.children(&mut cursor) {
271        if child.kind() == "dotted_name" || child.kind() == "aliased_import" {
272            let text = node_text(child, src);
273            imports.push(ImportInfo {
274                source: text.to_string(),
275                line,
276                ..Default::default()
277            });
278        }
279    }
280}
281
282fn collect_import_from(node: Node, src: &[u8], imports: &mut Vec<ImportInfo>) {
283    let line = node.start_position().row + 1;
284    let module = node
285        .child_by_field_name("module_name")
286        .map(|n| node_text(n, src).to_string())
287        .unwrap_or_default();
288    let mut cursor = node.walk();
289    let mut has_names = false;
290    for child in node.children(&mut cursor) {
291        if child.kind() == "dotted_name" || child.kind() == "aliased_import" {
292            let n = node_text(child, src).to_string();
293            if n != module {
294                imports.push(ImportInfo {
295                    source: format!("{module}.{n}"),
296                    line,
297                    ..Default::default()
298                });
299                has_names = true;
300            }
301        }
302    }
303    if !has_names {
304        imports.push(ImportInfo {
305            source: module,
306            line,
307            ..Default::default()
308        });
309    }
310}
311
312// --- helpers ---
313
314fn node_text<'a>(node: Node, src: &'a [u8]) -> &'a str {
315    node.utf8_text(src).unwrap_or("")
316}
317
318fn count_complexity(node: Node) -> usize {
319    let mut complexity = 1usize;
320    let mut cursor = node.walk();
321    visit_all(node, &mut cursor, &mut |n| {
322        match n.kind() {
323            "if_statement"
324            | "elif_clause"
325            | "for_statement"
326            | "while_statement"
327            | "except_clause"
328            | "with_statement"
329            | "assert_statement"
330            | "conditional_expression"
331            | "boolean_operator"
332            | "list_comprehension"
333            | "set_comprehension"
334            | "dictionary_comprehension"
335            | "generator_expression" => {
336                complexity += 1;
337            }
338            "match_statement" => {} // match itself doesn't add, cases do
339            "case_clause" => {
340                complexity += 1;
341            }
342            _ => {}
343        }
344    });
345    complexity
346}
347
348fn hash_ast_structure(node: Node) -> u64 {
349    let mut hasher = DefaultHasher::new();
350    hash_node(node, &mut hasher);
351    hasher.finish()
352}
353
354fn hash_node(node: Node, hasher: &mut DefaultHasher) {
355    node.kind().hash(hasher);
356    let mut cursor = node.walk();
357    for child in node.children(&mut cursor) {
358        hash_node(child, hasher);
359    }
360}
361
362fn max_chain_depth(node: Node) -> usize {
363    let mut max = 0usize;
364    let mut cursor = node.walk();
365    visit_all(node, &mut cursor, &mut |n| {
366        if n.kind() == "attribute" {
367            let depth = chain_len(n);
368            if depth > max {
369                max = depth;
370            }
371        }
372    });
373    max
374}
375
376fn chain_len(node: Node) -> usize {
377    let mut depth = 0usize;
378    let mut current = node;
379    while current.kind() == "attribute" || current.kind() == "call" {
380        if current.kind() == "attribute" {
381            depth += 1;
382        }
383        if let Some(obj) = current.child(0) {
384            current = obj;
385        } else {
386            break;
387        }
388    }
389    depth
390}
391
392fn count_match_arms(node: Node) -> usize {
393    let mut count = 0usize;
394    let mut cursor = node.walk();
395    visit_all(node, &mut cursor, &mut |n| {
396        if n.kind() == "case_clause" {
397            count += 1;
398        }
399    });
400    count
401}
402
403fn collect_external_refs(node: Node, src: &[u8]) -> Vec<String> {
404    let mut refs = Vec::new();
405    let mut cursor = node.walk();
406    visit_all(node, &mut cursor, &mut |n| {
407        if n.kind() != "attribute" {
408            return;
409        }
410        let Some(obj) = n.child(0) else { return };
411        let text = node_text(obj, src);
412        if text != "self"
413            && !text.is_empty()
414            && text.starts_with(|c: char| c.is_lowercase())
415            && !refs.contains(&text.to_string())
416        {
417            refs.push(text.to_string());
418        }
419    });
420    refs
421}
422
423fn unwrap_single_call(body: Node) -> Option<Node> {
424    let mut c = body.walk();
425    let stmts: Vec<Node> = body
426        .children(&mut c)
427        .filter(|n| !n.is_extra() && n.kind() != "pass_statement" && n.kind() != "comment")
428        .collect();
429    if stmts.len() != 1 {
430        return None;
431    }
432    let stmt = stmts[0];
433    match stmt.kind() {
434        "return_statement" => stmt.child(1).filter(|v| v.kind() == "call"),
435        "expression_statement" => stmt.child(0).filter(|v| v.kind() == "call"),
436        _ => None,
437    }
438}
439
440fn check_delegating(body: Node, src: &[u8]) -> bool {
441    let Some(func) = unwrap_single_call(body).and_then(|c| c.child(0)) else {
442        return false;
443    };
444    let text = node_text(func, src);
445    text.contains('.') && !text.starts_with("self.")
446}
447
448fn count_comment_lines(node: Node, src: &[u8]) -> usize {
449    let mut count = 0usize;
450    let mut cursor = node.walk();
451    visit_all(node, &mut cursor, &mut |n| {
452        if n.kind() == "comment" {
453            count += 1;
454        } else if n.kind() == "string" || n.kind() == "expression_statement" {
455            // docstrings
456            let text = node_text(n, src);
457            if text.starts_with("\"\"\"") || text.starts_with("'''") {
458                count += text.lines().count();
459            }
460        }
461    });
462    count
463}
464
465fn collect_self_refs(body: Node, src: &[u8]) -> Vec<String> {
466    let mut refs = Vec::new();
467    let mut cursor = body.walk();
468    visit_all(body, &mut cursor, &mut |n| {
469        if n.kind() != "attribute" {
470            return;
471        }
472        let is_self = n.child(0).is_some_and(|o| node_text(o, src) == "self");
473        if !is_self {
474            return;
475        }
476        if let Some(attr) = n.child_by_field_name("attribute") {
477            let name = node_text(attr, src).to_string();
478            if !refs.contains(&name) {
479                refs.push(name);
480            }
481        }
482    });
483    refs
484}
485
486fn collect_none_checks(body: Node, src: &[u8]) -> Vec<String> {
487    let mut fields = Vec::new();
488    let mut cursor = body.walk();
489    visit_all(body, &mut cursor, &mut |n| {
490        if n.kind() != "comparison_operator" {
491            return;
492        }
493        let text = node_text(n, src);
494        if !text.contains("is None") && !text.contains("is not None") && !text.contains("== None") {
495            return;
496        }
497        if let Some(left) = n.child(0) {
498            let name = node_text(left, src).to_string();
499            if !fields.contains(&name) {
500                fields.push(name);
501            }
502        }
503    });
504    fields
505}
506
/// Whether `name` is one of the conventional implicit receivers, `self` or `cls`.
fn is_self_or_cls(name: &str) -> bool {
    matches!(name, "self" | "cls")
}
510
511fn param_name_and_type(child: Node, src: &[u8]) -> Option<(String, String)> {
512    match child.kind() {
513        "identifier" => {
514            let name = node_text(child, src);
515            (!is_self_or_cls(name)).then(|| (name.to_string(), "Any".to_string()))
516        }
517        "typed_parameter" | "default_parameter" | "typed_default_parameter" => {
518            let name = child
519                .child_by_field_name("name")
520                .or_else(|| child.child(0))
521                .map(|n| node_text(n, src))
522                .unwrap_or("");
523            if is_self_or_cls(name) {
524                return None;
525            }
526            let ty = child
527                .child_by_field_name("type")
528                .map(|n| node_text(n, src).to_string())
529                .unwrap_or_else(|| "Any".to_string());
530            Some((name.to_string(), ty))
531        }
532        "list_splat_pattern" | "dictionary_splat_pattern" => {
533            Some(("*".to_string(), "Any".to_string()))
534        }
535        _ => None,
536    }
537}
538
539fn extract_params(params_node: Node, src: &[u8]) -> (usize, Vec<String>) {
540    let mut count = 0usize;
541    let mut types = Vec::new();
542    let mut cursor = params_node.walk();
543    for child in params_node.children(&mut cursor) {
544        if let Some((_name, ty)) = param_name_and_type(child, src) {
545            count += 1;
546            types.push(ty);
547        }
548    }
549    (count, types)
550}
551
552fn count_optional(params_node: Node) -> usize {
553    let mut count = 0usize;
554    let mut cursor = params_node.walk();
555    for child in params_node.children(&mut cursor) {
556        if child.kind() == "default_parameter" || child.kind() == "typed_default_parameter" {
557            count += 1;
558        }
559    }
560    count
561}
562
563fn collect_init_fields(func_node: Node, src: &[u8], fields: &mut Vec<String>) {
564    let Some(body) = func_node.child_by_field_name("body") else {
565        return;
566    };
567    let mut cursor = body.walk();
568    visit_all(body, &mut cursor, &mut |n| {
569        if n.kind() != "assignment" {
570            return;
571        }
572        let Some(left) = n.child_by_field_name("left") else {
573            return;
574        };
575        if left.kind() != "attribute" {
576            return;
577        }
578        let is_self = left.child(0).is_some_and(|o| node_text(o, src) == "self");
579        if !is_self {
580            return;
581        }
582        if let Some(attr) = left.child_by_field_name("attribute") {
583            let name = node_text(attr, src).to_string();
584            if !fields.contains(&name) {
585                fields.push(name);
586            }
587        }
588    });
589}
590
591fn count_self_calls(body: Node, src: &[u8]) -> usize {
592    let mut count = 0;
593    let mut cursor = body.walk();
594    visit_all(body, &mut cursor, &mut |n| {
595        if n.kind() != "call" {
596            return;
597        }
598        let is_self_call = n
599            .child(0)
600            .filter(|f| f.kind() == "attribute")
601            .and_then(|f| f.child(0))
602            .is_some_and(|obj| node_text(obj, src) == "self");
603        if is_self_call {
604            count += 1;
605        }
606    });
607    count
608}
609
610fn is_stub_body(node: Node, src: &[u8]) -> bool {
611    node.child_by_field_name("body")
612        .is_none_or(|b| has_only_pass_or_ellipsis(b, src))
613}
614
615fn has_only_pass_or_ellipsis(body: Node, src: &[u8]) -> bool {
616    let mut cursor = body.walk();
617    for child in body.children(&mut cursor) {
618        let ok = match child.kind() {
619            "pass_statement" | "ellipsis" | "comment" => true,
620            "expression_statement" => child.child(0).is_none_or(|expr| {
621                let text = node_text(expr, src);
622                text == "..." || text.starts_with("\"\"\"") || text.starts_with("'''")
623            }),
624            "function_definition" => is_stub_body(child, src),
625            "decorated_definition" => {
626                let mut inner = child.walk();
627                child
628                    .children(&mut inner)
629                    .filter(|c| c.kind() == "function_definition")
630                    .all(|c| is_stub_body(c, src))
631            }
632            _ => false,
633        };
634        if !ok {
635            return false;
636        }
637    }
638    true
639}
640
641fn cognitive_complexity_py(node: tree_sitter::Node) -> usize {
642    let mut score = 0;
643    cc_walk_py(node, 0, &mut score);
644    score
645}
646
647fn cc_walk_py(node: tree_sitter::Node, nesting: usize, score: &mut usize) {
648    match node.kind() {
649        "if_statement" => {
650            *score += 1 + nesting;
651            cc_children_py(node, nesting + 1, score);
652            return;
653        }
654        "for_statement" | "while_statement" => {
655            *score += 1 + nesting;
656            cc_children_py(node, nesting + 1, score);
657            return;
658        }
659        "match_statement" => {
660            *score += 1 + nesting;
661            cc_children_py(node, nesting + 1, score);
662            return;
663        }
664        "elif_clause" | "else_clause" => {
665            *score += 1;
666        }
667        "boolean_operator" => {
668            *score += 1;
669        }
670        "except_clause" => {
671            *score += 1 + nesting;
672            cc_children_py(node, nesting + 1, score);
673            return;
674        }
675        "lambda" => {
676            cc_children_py(node, nesting + 1, score);
677            return;
678        }
679        _ => {}
680    }
681    cc_children_py(node, nesting, score);
682}
683
684fn cc_children_py(node: tree_sitter::Node, nesting: usize, score: &mut usize) {
685    let mut cursor = node.walk();
686    for child in node.children(&mut cursor) {
687        cc_walk_py(child, nesting, score);
688    }
689}
690
691fn extract_match_target_py(body: tree_sitter::Node, src: &[u8]) -> Option<String> {
692    let mut target = None;
693    let mut cursor = body.walk();
694    visit_all(body, &mut cursor, &mut |n| {
695        if n.kind() == "match_statement"
696            && target.is_none()
697            && let Some(subj) = n.child_by_field_name("subject")
698        {
699            target = Some(node_text(subj, src).to_string());
700        }
701    });
702    target
703}
704
705fn collect_calls_py(body: tree_sitter::Node, src: &[u8]) -> Vec<String> {
706    let mut calls = Vec::new();
707    let mut cursor = body.walk();
708    visit_all(body, &mut cursor, &mut |n| {
709        if n.kind() == "call"
710            && let Some(func) = n.child(0)
711        {
712            let name = node_text(func, src).to_string();
713            if !calls.contains(&name) {
714                calls.push(name);
715            }
716        }
717    });
718    calls
719}
720
721fn collect_comments(root: Node, src: &[u8]) -> Vec<cha_core::CommentInfo> {
722    let mut comments = Vec::new();
723    let mut cursor = root.walk();
724    visit_all(root, &mut cursor, &mut |n| {
725        if n.kind().contains("comment") {
726            comments.push(cha_core::CommentInfo {
727                text: node_text(n, src).to_string(),
728                line: n.start_position().row + 1,
729            });
730        }
731    });
732    comments
733}
734
735fn visit_all<F: FnMut(Node)>(node: Node, cursor: &mut tree_sitter::TreeCursor, f: &mut F) {
736    f(node);
737    if cursor.goto_first_child() {
738        loop {
739            let child_node = cursor.node();
740            let mut child_cursor = child_node.walk();
741            visit_all(child_node, &mut child_cursor, f);
742            if !cursor.goto_next_sibling() {
743                break;
744            }
745        }
746        cursor.goto_parent();
747    }
748}