Skip to main content

cha_parser/
python.rs

1use std::collections::hash_map::DefaultHasher;
2use std::hash::{Hash, Hasher};
3
4use cha_core::{ClassInfo, FunctionInfo, ImportInfo, SourceFile, SourceModel};
5use tree_sitter::{Node, Parser};
6
7use crate::LanguageParser;
8
9pub struct PythonParser;
10
11impl LanguageParser for PythonParser {
12    fn language_name(&self) -> &str {
13        "python"
14    }
15
16    fn parse(&self, file: &SourceFile) -> Option<SourceModel> {
17        let mut parser = Parser::new();
18        parser
19            .set_language(&tree_sitter_python::LANGUAGE.into())
20            .ok()?;
21        let tree = parser.parse(&file.content, None)?;
22        let root = tree.root_node();
23        let src = file.content.as_bytes();
24
25        let mut functions = Vec::new();
26        let mut classes = Vec::new();
27        let mut imports = Vec::new();
28
29        collect_top_level(root, src, &mut functions, &mut classes, &mut imports);
30
31        Some(SourceModel {
32            language: "python".into(),
33            total_lines: file.line_count(),
34            functions,
35            classes,
36            imports,
37            comments: collect_comments(root, src),
38            type_aliases: vec![], // TODO(parser): extract type aliases from 'type X = Y' / 'X = Y' declarations
39        })
40    }
41}
42
43fn push_definition(
44    node: Node,
45    src: &[u8],
46    functions: &mut Vec<FunctionInfo>,
47    classes: &mut Vec<ClassInfo>,
48) {
49    match node.kind() {
50        "function_definition" => {
51            if let Some(f) = extract_function(node, src) {
52                functions.push(f);
53            }
54        }
55        "class_definition" => {
56            if let Some(c) = extract_class(node, src, functions) {
57                classes.push(c);
58            }
59        }
60        _ => {}
61    }
62}
63
64fn collect_top_level(
65    node: Node,
66    src: &[u8],
67    functions: &mut Vec<FunctionInfo>,
68    classes: &mut Vec<ClassInfo>,
69    imports: &mut Vec<ImportInfo>,
70) {
71    let mut cursor = node.walk();
72    for child in node.children(&mut cursor) {
73        match child.kind() {
74            "function_definition" | "class_definition" => {
75                push_definition(child, src, functions, classes);
76            }
77            "import_statement" => collect_import(child, src, imports),
78            "import_from_statement" => collect_import_from(child, src, imports),
79            "decorated_definition" => {
80                let mut inner = child.walk();
81                for c in child.children(&mut inner) {
82                    push_definition(c, src, functions, classes);
83                }
84            }
85            _ => {}
86        }
87    }
88}
89
90fn extract_function(node: Node, src: &[u8]) -> Option<FunctionInfo> {
91    let name = node_text(node.child_by_field_name("name")?, src).to_string();
92    let start_line = node.start_position().row + 1;
93    let end_line = node.end_position().row + 1;
94    let body = node.child_by_field_name("body");
95    let params = node.child_by_field_name("parameters");
96    let (param_count, param_types) = params
97        .map(|p| extract_params(p, src))
98        .unwrap_or((0, vec![]));
99
100    Some(FunctionInfo {
101        name,
102        start_line,
103        end_line,
104        line_count: end_line - start_line + 1,
105        complexity: count_complexity(node),
106        body_hash: body.map(hash_ast_structure),
107        is_exported: true,
108        parameter_count: param_count,
109        parameter_types: param_types,
110        chain_depth: body.map(max_chain_depth).unwrap_or(0),
111        switch_arms: body.map(count_match_arms).unwrap_or(0),
112        external_refs: body
113            .map(|b| collect_external_refs(b, src))
114            .unwrap_or_default(),
115        is_delegating: body.map(|b| check_delegating(b, src)).unwrap_or(false),
116        comment_lines: count_comment_lines(node, src),
117        referenced_fields: body.map(|b| collect_self_refs(b, src)).unwrap_or_default(),
118        null_check_fields: body
119            .map(|b| collect_none_checks(b, src))
120            .unwrap_or_default(),
121        switch_dispatch_target: body.and_then(|b| extract_match_target_py(b, src)),
122        optional_param_count: params.map(count_optional).unwrap_or(0),
123        called_functions: body.map(|b| collect_calls_py(b, src)).unwrap_or_default(),
124        cognitive_complexity: body.map(cognitive_complexity_py).unwrap_or(0),
125    })
126}
127
128fn find_method_def(child: Node) -> Option<Node> {
129    if child.kind() == "function_definition" {
130        return Some(child);
131    }
132    if child.kind() == "decorated_definition" {
133        let mut inner = child.walk();
134        return child
135            .children(&mut inner)
136            .find(|c| c.kind() == "function_definition");
137    }
138    None
139}
140
141fn extract_parent_name(node: Node, src: &[u8]) -> Option<String> {
142    node.child_by_field_name("superclasses").and_then(|sc| {
143        let mut c = sc.walk();
144        sc.children(&mut c)
145            .find(|n| n.kind() != "(" && n.kind() != ")" && n.kind() != ",")
146            .map(|n| node_text(n, src).to_string())
147    })
148}
149
/// Reports whether a field name suggests an event listener, handler, callback or observer.
///
/// Matching is ASCII case-insensitive, so camelCase names such as `clickHandler` or
/// `onChangeListener` are recognised alongside snake_case names like `on_click_handler`.
fn has_listener_name(name: &str) -> bool {
    const MARKERS: [&str; 4] = ["listener", "handler", "callback", "observer"];
    let lowered = name.to_ascii_lowercase();
    MARKERS.iter().any(|&marker| lowered.contains(marker))
}
156
157fn process_method(
158    func_node: Node,
159    f: &mut FunctionInfo,
160    src: &[u8],
161    field_names: &mut Vec<String>,
162) -> (bool, bool, bool, usize) {
163    let method_name = &f.name;
164    let mut has_behavior = false;
165    let mut is_override = false;
166    let mut is_notify = false;
167    if method_name == "__init__" {
168        collect_init_fields(func_node, src, field_names);
169    } else {
170        has_behavior = true;
171    }
172    let sc = func_node
173        .child_by_field_name("body")
174        .map(|b| count_self_calls(b, src))
175        .unwrap_or(0);
176    if method_name.starts_with("__") && method_name.ends_with("__") && method_name != "__init__" {
177        is_override = true;
178    }
179    if method_name.contains("notify") || method_name.contains("emit") {
180        is_notify = true;
181    }
182    f.is_exported = !method_name.starts_with('_');
183    (has_behavior, is_override, is_notify, sc)
184}
185
186struct ClassScan {
187    methods: Vec<FunctionInfo>,
188    field_names: Vec<String>,
189    delegating_count: usize,
190    has_behavior: bool,
191    override_count: usize,
192    self_call_count: usize,
193    has_notify_method: bool,
194}
195
196fn scan_class_methods(body: Node, src: &[u8]) -> ClassScan {
197    let mut s = ClassScan {
198        methods: Vec::new(),
199        field_names: Vec::new(),
200        delegating_count: 0,
201        has_behavior: false,
202        override_count: 0,
203        self_call_count: 0,
204        has_notify_method: false,
205    };
206    let mut cursor = body.walk();
207    for child in body.children(&mut cursor) {
208        let Some(func_node) = find_method_def(child) else {
209            continue;
210        };
211        let Some(mut f) = extract_function(func_node, src) else {
212            continue;
213        };
214        if f.is_delegating {
215            s.delegating_count += 1;
216        }
217        let (behav, over, notify, sc) = process_method(func_node, &mut f, src, &mut s.field_names);
218        s.has_behavior |= behav;
219        if over {
220            s.override_count += 1;
221        }
222        if notify {
223            s.has_notify_method = true;
224        }
225        s.self_call_count += sc;
226        s.methods.push(f);
227    }
228    s
229}
230
231fn extract_class(
232    node: Node,
233    src: &[u8],
234    top_functions: &mut Vec<FunctionInfo>,
235) -> Option<ClassInfo> {
236    let name = node_text(node.child_by_field_name("name")?, src).to_string();
237    let start_line = node.start_position().row + 1;
238    let end_line = node.end_position().row + 1;
239    let body = node.child_by_field_name("body")?;
240    let s = scan_class_methods(body, src);
241    let method_count = s.methods.len();
242    top_functions.extend(s.methods);
243
244    Some(ClassInfo {
245        name,
246        start_line,
247        end_line,
248        line_count: end_line - start_line + 1,
249        method_count,
250        is_exported: true,
251        delegating_method_count: s.delegating_count,
252        field_count: s.field_names.len(),
253        has_listener_field: s.field_names.iter().any(|n| has_listener_name(n)),
254        field_names: s.field_names,
255        field_types: Vec::new(),
256        has_behavior: s.has_behavior,
257        is_interface: has_only_pass_or_ellipsis(body, src),
258        parent_name: extract_parent_name(node, src),
259        override_count: s.override_count,
260        self_call_count: s.self_call_count,
261        has_notify_method: s.has_notify_method,
262    })
263}
264
265// --- imports ---
266
267fn collect_import(node: Node, src: &[u8], imports: &mut Vec<ImportInfo>) {
268    let line = node.start_position().row + 1;
269    let mut cursor = node.walk();
270    for child in node.children(&mut cursor) {
271        if child.kind() == "dotted_name" || child.kind() == "aliased_import" {
272            let text = node_text(child, src);
273            imports.push(ImportInfo {
274                source: text.to_string(),
275                line,
276            });
277        }
278    }
279}
280
281fn collect_import_from(node: Node, src: &[u8], imports: &mut Vec<ImportInfo>) {
282    let line = node.start_position().row + 1;
283    let module = node
284        .child_by_field_name("module_name")
285        .map(|n| node_text(n, src).to_string())
286        .unwrap_or_default();
287    let mut cursor = node.walk();
288    let mut has_names = false;
289    for child in node.children(&mut cursor) {
290        if child.kind() == "dotted_name" || child.kind() == "aliased_import" {
291            let n = node_text(child, src).to_string();
292            if n != module {
293                imports.push(ImportInfo {
294                    source: format!("{module}.{n}"),
295                    line,
296                });
297                has_names = true;
298            }
299        }
300    }
301    if !has_names {
302        imports.push(ImportInfo {
303            source: module,
304            line,
305        });
306    }
307}
308
309// --- helpers ---
310
311fn node_text<'a>(node: Node, src: &'a [u8]) -> &'a str {
312    node.utf8_text(src).unwrap_or("")
313}
314
315fn count_complexity(node: Node) -> usize {
316    let mut complexity = 1usize;
317    let mut cursor = node.walk();
318    visit_all(node, &mut cursor, &mut |n| {
319        match n.kind() {
320            "if_statement"
321            | "elif_clause"
322            | "for_statement"
323            | "while_statement"
324            | "except_clause"
325            | "with_statement"
326            | "assert_statement"
327            | "conditional_expression"
328            | "boolean_operator"
329            | "list_comprehension"
330            | "set_comprehension"
331            | "dictionary_comprehension"
332            | "generator_expression" => {
333                complexity += 1;
334            }
335            "match_statement" => {} // match itself doesn't add, cases do
336            "case_clause" => {
337                complexity += 1;
338            }
339            _ => {}
340        }
341    });
342    complexity
343}
344
345fn hash_ast_structure(node: Node) -> u64 {
346    let mut hasher = DefaultHasher::new();
347    hash_node(node, &mut hasher);
348    hasher.finish()
349}
350
351fn hash_node(node: Node, hasher: &mut DefaultHasher) {
352    node.kind().hash(hasher);
353    let mut cursor = node.walk();
354    for child in node.children(&mut cursor) {
355        hash_node(child, hasher);
356    }
357}
358
359fn max_chain_depth(node: Node) -> usize {
360    let mut max = 0usize;
361    let mut cursor = node.walk();
362    visit_all(node, &mut cursor, &mut |n| {
363        if n.kind() == "attribute" {
364            let depth = chain_len(n);
365            if depth > max {
366                max = depth;
367            }
368        }
369    });
370    max
371}
372
373fn chain_len(node: Node) -> usize {
374    let mut depth = 0usize;
375    let mut current = node;
376    while current.kind() == "attribute" || current.kind() == "call" {
377        if current.kind() == "attribute" {
378            depth += 1;
379        }
380        if let Some(obj) = current.child(0) {
381            current = obj;
382        } else {
383            break;
384        }
385    }
386    depth
387}
388
389fn count_match_arms(node: Node) -> usize {
390    let mut count = 0usize;
391    let mut cursor = node.walk();
392    visit_all(node, &mut cursor, &mut |n| {
393        if n.kind() == "case_clause" {
394            count += 1;
395        }
396    });
397    count
398}
399
400fn collect_external_refs(node: Node, src: &[u8]) -> Vec<String> {
401    let mut refs = Vec::new();
402    let mut cursor = node.walk();
403    visit_all(node, &mut cursor, &mut |n| {
404        if n.kind() != "attribute" {
405            return;
406        }
407        let Some(obj) = n.child(0) else { return };
408        let text = node_text(obj, src);
409        if text != "self"
410            && !text.is_empty()
411            && text.starts_with(|c: char| c.is_lowercase())
412            && !refs.contains(&text.to_string())
413        {
414            refs.push(text.to_string());
415        }
416    });
417    refs
418}
419
420fn unwrap_single_call(body: Node) -> Option<Node> {
421    let mut c = body.walk();
422    let stmts: Vec<Node> = body
423        .children(&mut c)
424        .filter(|n| !n.is_extra() && n.kind() != "pass_statement" && n.kind() != "comment")
425        .collect();
426    if stmts.len() != 1 {
427        return None;
428    }
429    let stmt = stmts[0];
430    match stmt.kind() {
431        "return_statement" => stmt.child(1).filter(|v| v.kind() == "call"),
432        "expression_statement" => stmt.child(0).filter(|v| v.kind() == "call"),
433        _ => None,
434    }
435}
436
437fn check_delegating(body: Node, src: &[u8]) -> bool {
438    let Some(func) = unwrap_single_call(body).and_then(|c| c.child(0)) else {
439        return false;
440    };
441    let text = node_text(func, src);
442    text.contains('.') && !text.starts_with("self.")
443}
444
445fn count_comment_lines(node: Node, src: &[u8]) -> usize {
446    let mut count = 0usize;
447    let mut cursor = node.walk();
448    visit_all(node, &mut cursor, &mut |n| {
449        if n.kind() == "comment" {
450            count += 1;
451        } else if n.kind() == "string" || n.kind() == "expression_statement" {
452            // docstrings
453            let text = node_text(n, src);
454            if text.starts_with("\"\"\"") || text.starts_with("'''") {
455                count += text.lines().count();
456            }
457        }
458    });
459    count
460}
461
462fn collect_self_refs(body: Node, src: &[u8]) -> Vec<String> {
463    let mut refs = Vec::new();
464    let mut cursor = body.walk();
465    visit_all(body, &mut cursor, &mut |n| {
466        if n.kind() != "attribute" {
467            return;
468        }
469        let is_self = n.child(0).is_some_and(|o| node_text(o, src) == "self");
470        if !is_self {
471            return;
472        }
473        if let Some(attr) = n.child_by_field_name("attribute") {
474            let name = node_text(attr, src).to_string();
475            if !refs.contains(&name) {
476                refs.push(name);
477            }
478        }
479    });
480    refs
481}
482
483fn collect_none_checks(body: Node, src: &[u8]) -> Vec<String> {
484    let mut fields = Vec::new();
485    let mut cursor = body.walk();
486    visit_all(body, &mut cursor, &mut |n| {
487        if n.kind() != "comparison_operator" {
488            return;
489        }
490        let text = node_text(n, src);
491        if !text.contains("is None") && !text.contains("is not None") && !text.contains("== None") {
492            return;
493        }
494        if let Some(left) = n.child(0) {
495            let name = node_text(left, src).to_string();
496            if !fields.contains(&name) {
497                fields.push(name);
498            }
499        }
500    });
501    fields
502}
503
/// True for the conventional receiver parameter names `self` and `cls` (exact match).
fn is_self_or_cls(name: &str) -> bool {
    matches!(name, "self" | "cls")
}
507
508fn param_name_and_type(child: Node, src: &[u8]) -> Option<(String, String)> {
509    match child.kind() {
510        "identifier" => {
511            let name = node_text(child, src);
512            (!is_self_or_cls(name)).then(|| (name.to_string(), "Any".to_string()))
513        }
514        "typed_parameter" | "default_parameter" | "typed_default_parameter" => {
515            let name = child
516                .child_by_field_name("name")
517                .or_else(|| child.child(0))
518                .map(|n| node_text(n, src))
519                .unwrap_or("");
520            if is_self_or_cls(name) {
521                return None;
522            }
523            let ty = child
524                .child_by_field_name("type")
525                .map(|n| node_text(n, src).to_string())
526                .unwrap_or_else(|| "Any".to_string());
527            Some((name.to_string(), ty))
528        }
529        "list_splat_pattern" | "dictionary_splat_pattern" => {
530            Some(("*".to_string(), "Any".to_string()))
531        }
532        _ => None,
533    }
534}
535
536fn extract_params(params_node: Node, src: &[u8]) -> (usize, Vec<String>) {
537    let mut count = 0usize;
538    let mut types = Vec::new();
539    let mut cursor = params_node.walk();
540    for child in params_node.children(&mut cursor) {
541        if let Some((_name, ty)) = param_name_and_type(child, src) {
542            count += 1;
543            types.push(ty);
544        }
545    }
546    (count, types)
547}
548
549fn count_optional(params_node: Node) -> usize {
550    let mut count = 0usize;
551    let mut cursor = params_node.walk();
552    for child in params_node.children(&mut cursor) {
553        if child.kind() == "default_parameter" || child.kind() == "typed_default_parameter" {
554            count += 1;
555        }
556    }
557    count
558}
559
560fn collect_init_fields(func_node: Node, src: &[u8], fields: &mut Vec<String>) {
561    let Some(body) = func_node.child_by_field_name("body") else {
562        return;
563    };
564    let mut cursor = body.walk();
565    visit_all(body, &mut cursor, &mut |n| {
566        if n.kind() != "assignment" {
567            return;
568        }
569        let Some(left) = n.child_by_field_name("left") else {
570            return;
571        };
572        if left.kind() != "attribute" {
573            return;
574        }
575        let is_self = left.child(0).is_some_and(|o| node_text(o, src) == "self");
576        if !is_self {
577            return;
578        }
579        if let Some(attr) = left.child_by_field_name("attribute") {
580            let name = node_text(attr, src).to_string();
581            if !fields.contains(&name) {
582                fields.push(name);
583            }
584        }
585    });
586}
587
588fn count_self_calls(body: Node, src: &[u8]) -> usize {
589    let mut count = 0;
590    let mut cursor = body.walk();
591    visit_all(body, &mut cursor, &mut |n| {
592        if n.kind() != "call" {
593            return;
594        }
595        let is_self_call = n
596            .child(0)
597            .filter(|f| f.kind() == "attribute")
598            .and_then(|f| f.child(0))
599            .is_some_and(|obj| node_text(obj, src) == "self");
600        if is_self_call {
601            count += 1;
602        }
603    });
604    count
605}
606
607fn is_stub_body(node: Node, src: &[u8]) -> bool {
608    node.child_by_field_name("body")
609        .is_none_or(|b| has_only_pass_or_ellipsis(b, src))
610}
611
612fn has_only_pass_or_ellipsis(body: Node, src: &[u8]) -> bool {
613    let mut cursor = body.walk();
614    for child in body.children(&mut cursor) {
615        let ok = match child.kind() {
616            "pass_statement" | "ellipsis" | "comment" => true,
617            "expression_statement" => child.child(0).is_none_or(|expr| {
618                let text = node_text(expr, src);
619                text == "..." || text.starts_with("\"\"\"") || text.starts_with("'''")
620            }),
621            "function_definition" => is_stub_body(child, src),
622            "decorated_definition" => {
623                let mut inner = child.walk();
624                child
625                    .children(&mut inner)
626                    .filter(|c| c.kind() == "function_definition")
627                    .all(|c| is_stub_body(c, src))
628            }
629            _ => false,
630        };
631        if !ok {
632            return false;
633        }
634    }
635    true
636}
637
638fn cognitive_complexity_py(node: tree_sitter::Node) -> usize {
639    let mut score = 0;
640    cc_walk_py(node, 0, &mut score);
641    score
642}
643
644fn cc_walk_py(node: tree_sitter::Node, nesting: usize, score: &mut usize) {
645    match node.kind() {
646        "if_statement" => {
647            *score += 1 + nesting;
648            cc_children_py(node, nesting + 1, score);
649            return;
650        }
651        "for_statement" | "while_statement" => {
652            *score += 1 + nesting;
653            cc_children_py(node, nesting + 1, score);
654            return;
655        }
656        "match_statement" => {
657            *score += 1 + nesting;
658            cc_children_py(node, nesting + 1, score);
659            return;
660        }
661        "elif_clause" | "else_clause" => {
662            *score += 1;
663        }
664        "boolean_operator" => {
665            *score += 1;
666        }
667        "except_clause" => {
668            *score += 1 + nesting;
669            cc_children_py(node, nesting + 1, score);
670            return;
671        }
672        "lambda" => {
673            cc_children_py(node, nesting + 1, score);
674            return;
675        }
676        _ => {}
677    }
678    cc_children_py(node, nesting, score);
679}
680
681fn cc_children_py(node: tree_sitter::Node, nesting: usize, score: &mut usize) {
682    let mut cursor = node.walk();
683    for child in node.children(&mut cursor) {
684        cc_walk_py(child, nesting, score);
685    }
686}
687
688fn extract_match_target_py(body: tree_sitter::Node, src: &[u8]) -> Option<String> {
689    let mut target = None;
690    let mut cursor = body.walk();
691    visit_all(body, &mut cursor, &mut |n| {
692        if n.kind() == "match_statement"
693            && target.is_none()
694            && let Some(subj) = n.child_by_field_name("subject")
695        {
696            target = Some(node_text(subj, src).to_string());
697        }
698    });
699    target
700}
701
702fn collect_calls_py(body: tree_sitter::Node, src: &[u8]) -> Vec<String> {
703    let mut calls = Vec::new();
704    let mut cursor = body.walk();
705    visit_all(body, &mut cursor, &mut |n| {
706        if n.kind() == "call"
707            && let Some(func) = n.child(0)
708        {
709            let name = node_text(func, src).to_string();
710            if !calls.contains(&name) {
711                calls.push(name);
712            }
713        }
714    });
715    calls
716}
717
718fn collect_comments(root: Node, src: &[u8]) -> Vec<cha_core::CommentInfo> {
719    let mut comments = Vec::new();
720    let mut cursor = root.walk();
721    visit_all(root, &mut cursor, &mut |n| {
722        if n.kind().contains("comment") {
723            comments.push(cha_core::CommentInfo {
724                text: node_text(n, src).to_string(),
725                line: n.start_position().row + 1,
726            });
727        }
728    });
729    comments
730}
731
732fn visit_all<F: FnMut(Node)>(node: Node, cursor: &mut tree_sitter::TreeCursor, f: &mut F) {
733    f(node);
734    if cursor.goto_first_child() {
735        loop {
736            let child_node = cursor.node();
737            let mut child_cursor = child_node.walk();
738            visit_all(child_node, &mut child_cursor, f);
739            if !cursor.goto_next_sibling() {
740                break;
741            }
742        }
743        cursor.goto_parent();
744    }
745}