Skip to main content

ai_refactor_cli/ast/
python.rs

1//! Python AST detection helpers powered by tree-sitter-python.
2//!
3//! v0.2.0 replaces the regex-based Python detection from v0.1.0 with
4//! real AST queries that eliminate false positives in comments and strings.
5
6use anyhow::Result;
7use tree_sitter::{Node, Parser};
8
9use crate::scanner::Finding;
10
11// ── Parser construction ──────────────────────────────────────────────────────
12
13/// Build a tree-sitter Parser configured for Python.
14pub fn make_parser() -> Result<Parser> {
15    let mut parser = Parser::new();
16    parser
17        .set_language(&tree_sitter_python::language())
18        .map_err(|e| anyhow::anyhow!("failed to load Python grammar: {}", e))?;
19    Ok(parser)
20}
21
22// ── Detection: python-missing-typing ────────────────────────────────────────
23
24/// Walk every `function_definition` node and report those that have at
25/// least one parameter without a type annotation (excluding `self`/`cls`).
26///
27/// A parameter is "untyped" when the grammar emits an `identifier` node
28/// (not a `typed_parameter` or `dictionary_splat_pattern` / `list_splat_pattern`).
29/// No-argument functions are **not** flagged — the rule targets missing param
30/// annotations, not missing return annotations (which a separate rule can cover).
31pub fn detect_missing_typing(
32    file: &str,
33    source: &[u8],
34    parser: &mut Parser,
35) -> Result<Vec<Finding>> {
36    let tree = parser
37        .parse(source, None)
38        .ok_or_else(|| anyhow::anyhow!("tree-sitter parse returned None for {}", file))?;
39
40    let mut findings = Vec::new();
41    walk_missing_typing(tree.root_node(), source, file, &mut findings);
42    Ok(findings)
43}
44
45fn walk_missing_typing(node: Node, source: &[u8], file: &str, out: &mut Vec<Finding>) {
46    if node.kind() == "function_definition" {
47        if let Some(params) = node.child_by_field_name("parameters") {
48            let has_untyped = has_untyped_param(params, source);
49            if has_untyped {
50                let name_node = node.child_by_field_name("name").unwrap_or(node);
51                let line = name_node.start_position().row + 1;
52                let snippet = source_line(source, line.saturating_sub(1));
53                out.push(Finding {
54                    rule_id: "python-missing-typing".to_string(),
55                    file: file.to_string(),
56                    line,
57                    snippet,
58                });
59            }
60        }
61    }
62    let mut cursor = node.walk();
63    for child in node.children(&mut cursor) {
64        walk_missing_typing(child, source, file, out);
65    }
66}
67
68/// Returns true if the `parameters` node contains at least one bare
69/// `identifier` child that is not `self` or `cls`.
70fn has_untyped_param(params: Node, source: &[u8]) -> bool {
71    let mut cursor = params.walk();
72    for child in params.children(&mut cursor) {
73        if child.kind() == "identifier" {
74            let name = node_text(child, source);
75            if name != "self" && name != "cls" {
76                return true;
77            }
78        }
79    }
80    false
81}
82
83// ── Detection: django-fbv ────────────────────────────────────────────────────
84
85/// Detect Django function-based views: top-level (module-level) functions
86/// whose first positional parameter is the bare identifier `request`.
87///
88/// This correctly skips:
89/// - Typed request params: `def view(request: HttpRequest)` (typed_parameter)
90/// - CBV methods: `def get(self, request)` (self is first param)
91/// - Nested functions inside classes
92pub fn detect_django_fbv(file: &str, source: &[u8], parser: &mut Parser) -> Result<Vec<Finding>> {
93    let tree = parser
94        .parse(source, None)
95        .ok_or_else(|| anyhow::anyhow!("tree-sitter parse returned None for {}", file))?;
96
97    let mut findings = Vec::new();
98    // Only look at top-level (module-level) function definitions.
99    let root = tree.root_node();
100    let mut cursor = root.walk();
101    for child in root.children(&mut cursor) {
102        if child.kind() == "function_definition" && is_fbv(child, source) {
103            let name_node = child.child_by_field_name("name").unwrap_or(child);
104            let line = name_node.start_position().row + 1;
105            let snippet = source_line(source, line.saturating_sub(1));
106            findings.push(Finding {
107                rule_id: "django-fbv".to_string(),
108                file: file.to_string(),
109                line,
110                snippet,
111            });
112        }
113        // Also check decorated functions at module level.
114        if child.kind() == "decorated_definition" {
115            let mut inner = child.walk();
116            for grandchild in child.children(&mut inner) {
117                if grandchild.kind() == "function_definition" && is_fbv(grandchild, source) {
118                    let name_node = grandchild.child_by_field_name("name").unwrap_or(grandchild);
119                    let line = name_node.start_position().row + 1;
120                    let snippet = source_line(source, line.saturating_sub(1));
121                    findings.push(Finding {
122                        rule_id: "django-fbv".to_string(),
123                        file: file.to_string(),
124                        line,
125                        snippet,
126                    });
127                }
128            }
129        }
130    }
131    Ok(findings)
132}
133
134/// Returns true when `func` has a bare `identifier` named `request` as its
135/// first positional parameter (not a typed_parameter).
136fn is_fbv(func: Node, source: &[u8]) -> bool {
137    let params = match func.child_by_field_name("parameters") {
138        Some(p) => p,
139        None => return false,
140    };
141    let mut cursor = params.walk();
142    let first_ident = params.children(&mut cursor).find(|n| {
143        matches!(
144            n.kind(),
145            "identifier" | "typed_parameter" | "list_splat_pattern" | "dictionary_splat_pattern"
146        )
147    });
148    match first_ident {
149        Some(n) if n.kind() == "identifier" => node_text(n, source) == "request",
150        _ => false,
151    }
152}
153
154// ── Utilities ────────────────────────────────────────────────────────────────
155
156fn node_text<'a>(node: Node, source: &'a [u8]) -> &'a str {
157    std::str::from_utf8(&source[node.byte_range()]).unwrap_or("")
158}
159
/// Return the line at **zero-based** index `line_idx` of `source`, with
/// leading and trailing whitespace removed.
///
/// Lines are separated by `\n`; a trailing `\r` is removed by the trim, so
/// CRLF input works too. An index past the last line yields an empty
/// string.
///
/// Only the requested line is decoded, and decoding is lossy: invalid UTF-8
/// bytes on that line become U+FFFD, and invalid bytes elsewhere in the
/// file have no effect on the result.
pub fn source_line(source: &[u8], line_idx: usize) -> String {
    source
        // Splitting on the raw byte is safe: 0x0A never occurs inside a
        // multi-byte UTF-8 sequence.
        .split(|&byte| byte == b'\n')
        .nth(line_idx)
        .map(|line| String::from_utf8_lossy(line).trim().to_string())
        .unwrap_or_default()
}
165
166// ── Tests ────────────────────────────────────────────────────────────────────
167
/// Unit tests for the Python detectors. Each test parses a small inline
/// Python snippet with a fresh parser and checks the findings produced.
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh Python parser for a single test.
    fn parser() -> Parser {
        make_parser().unwrap()
    }

    // ---- python-missing-typing ----

    #[test]
    fn missing_typing_catches_untyped_params() {
        let src = b"def foo(x, y):\n    pass\n";
        let findings = detect_missing_typing("test.py", src, &mut parser()).unwrap();
        assert_eq!(findings.len(), 1);
        assert_eq!(findings[0].line, 1);
        assert_eq!(findings[0].rule_id, "python-missing-typing");
    }

    #[test]
    fn missing_typing_skips_fully_typed() {
        let src = b"def foo(x: int, y: str) -> bool:\n    return True\n";
        let findings = detect_missing_typing("test.py", src, &mut parser()).unwrap();
        assert!(findings.is_empty(), "should be empty, got {:?}", findings);
    }

    #[test]
    fn missing_typing_skips_no_arg_function() {
        let src = b"def foo():\n    pass\n";
        let findings = detect_missing_typing("test.py", src, &mut parser()).unwrap();
        assert!(findings.is_empty());
    }

    #[test]
    fn missing_typing_skips_self_only() {
        let src = b"class C:\n    def method(self):\n        pass\n";
        let findings = detect_missing_typing("test.py", src, &mut parser()).unwrap();
        assert!(findings.is_empty(), "self-only methods must not flag");
    }

    #[test]
    fn missing_typing_catches_partial_typing() {
        let src = b"def foo(x, y: int):\n    pass\n";
        let findings = detect_missing_typing("test.py", src, &mut parser()).unwrap();
        assert_eq!(findings.len(), 1);
    }

    #[test]
    fn missing_typing_skips_string_literal_lookalike() {
        // The regex version would flag lines containing "def foo(x):" even inside
        // string literals; tree-sitter should not.
        let src = b"s = \"def foo(x):\"\n\ndef real(x: int) -> int:\n    return x\n";
        let findings = detect_missing_typing("test.py", src, &mut parser()).unwrap();
        assert!(
            findings.is_empty(),
            "must not flag def inside string literal"
        );
    }

    #[test]
    fn missing_typing_comment_false_positive() {
        // Regex would flag comments containing def; tree-sitter should not.
        let src = b"# def untyped(x):\ndef real(x: int):\n    pass\n";
        let findings = detect_missing_typing("test.py", src, &mut parser()).unwrap();
        assert!(findings.is_empty(), "comment line must not be flagged");
    }

    // ---- django-fbv ----

    #[test]
    fn fbv_catches_simple_view() {
        let src = b"def home(request):\n    return None\n";
        let findings = detect_django_fbv("views.py", src, &mut parser()).unwrap();
        assert_eq!(findings.len(), 1);
        assert_eq!(findings[0].rule_id, "django-fbv");
    }

    #[test]
    fn fbv_catches_view_with_extra_args() {
        let src = b"def detail(request, pk):\n    return None\n";
        let findings = detect_django_fbv("views.py", src, &mut parser()).unwrap();
        assert_eq!(findings.len(), 1);
    }

    #[test]
    fn fbv_skips_cbv_get_method() {
        let src = b"class V:\n    def get(self, request):\n        return None\n";
        let findings = detect_django_fbv("views.py", src, &mut parser()).unwrap();
        assert!(findings.is_empty(), "CBV method must not be flagged");
    }

    #[test]
    fn fbv_skips_typed_request_param() {
        // An annotated `request` is a `typed_parameter` node, so the first
        // parameter is not a bare `request` identifier and the function is
        // skipped, as documented on `detect_django_fbv`.
        let src = b"def view(request: HttpRequest):\n    return None\n";
        let findings = detect_django_fbv("views.py", src, &mut parser()).unwrap();
        assert!(
            findings.is_empty(),
            "typed request param must not be flagged, got {:?}",
            findings
        );
    }

    #[test]
    fn fbv_skips_non_request_first_param() {
        let src = b"def helper(x, y):\n    return x + y\n";
        let findings = detect_django_fbv("views.py", src, &mut parser()).unwrap();
        assert!(findings.is_empty());
    }

    #[test]
    fn fbv_catches_decorated_view() {
        let src = b"@login_required\ndef dashboard(request):\n    return None\n";
        let findings = detect_django_fbv("views.py", src, &mut parser()).unwrap();
        assert_eq!(findings.len(), 1);
    }
}