Skip to main content

deagle_parse/
python_parser.rs

1//! Python language parser using tree-sitter-python.
2
3use deagle_core::{DeagleError, EdgeKind, Language, Node, NodeKind, Result};
4use std::path::Path;
5
6use crate::ParseResult;
7
8/// Parse a Python source file and extract definitions.
9pub fn parse(path: &Path, content: &str) -> Result<Vec<Node>> {
10    parse_with_edges(path, content).map(|r| r.nodes)
11}
12
13/// Parse with edge extraction — returns nodes and relationship tuples.
14pub fn parse_with_edges(path: &Path, content: &str) -> Result<ParseResult> {
15    let mut parser = tree_sitter::Parser::new();
16    let language = tree_sitter_python::LANGUAGE;
17    parser.set_language(&language.into()).map_err(|e| DeagleError::Parse {
18        file: path.display().to_string(),
19        message: format!("Failed to set language: {}", e),
20    })?;
21
22    let tree = parser.parse(content, None).ok_or_else(|| DeagleError::Parse {
23        file: path.display().to_string(),
24        message: "Failed to parse file".into(),
25    })?;
26
27    let mut nodes = Vec::new();
28    let file_path = path.to_string_lossy().to_string();
29
30    // Insert file node as index 0
31    nodes.push(Node {
32        id: 0,
33        name: path.file_name().and_then(|n| n.to_str()).unwrap_or("unknown").to_string(),
34        kind: NodeKind::File,
35        language: Language::Python,
36        file_path: file_path.clone(),
37        line_start: 1,
38        line_end: content.lines().count() as u32,
39        content: None,
40    });
41
42    extract_definitions(tree.root_node(), content, &file_path, &mut nodes, false);
43
44    // Build CONTAINS edges: file (idx 0) → each top-level entity
45    let mut edges = Vec::new();
46    for i in 1..nodes.len() {
47        edges.push((0, i, EdgeKind::Contains));
48    }
49
50    Ok(ParseResult { nodes, edges })
51}
52
53fn extract_definitions(
54    node: tree_sitter::Node,
55    source: &str,
56    file_path: &str,
57    results: &mut Vec<Node>,
58    inside_class: bool,
59) {
60    let kind = match node.kind() {
61        "function_definition" => {
62            if inside_class {
63                Some(NodeKind::Method)
64            } else {
65                Some(NodeKind::Function)
66            }
67        }
68        "class_definition" => Some(NodeKind::Class),
69        "import_statement" | "import_from_statement" => Some(NodeKind::Import),
70        "global_statement" => None, // skip
71        "expression_statement" => {
72            // Check for top-level assignments (module-level constants)
73            if !inside_class {
74                if let Some(child) = node.child(0) {
75                    if child.kind() == "assignment" {
76                        // Only capture UPPER_CASE assignments as constants
77                        if let Some(name) = extract_assignment_name(child, source) {
78                            if name.chars().all(|c| c.is_uppercase() || c == '_' || c.is_ascii_digit()) && !name.is_empty() {
79                                let start = node.start_position();
80                                let end = node.end_position();
81                                let content = node.utf8_text(source.as_bytes()).ok().map(|s| {
82                                    if s.len() > 500 { format!("{}...", &s[..500]) } else { s.to_string() }
83                                });
84                                results.push(Node {
85                                    id: 0,
86                                    name,
87                                    kind: NodeKind::Constant,
88                                    language: Language::Python,
89                                    file_path: file_path.to_string(),
90                                    line_start: (start.row + 1) as u32,
91                                    line_end: (end.row + 1) as u32,
92                                    content,
93                                });
94                            }
95                        }
96                    }
97                }
98            }
99            None
100        }
101        _ => None,
102    };
103
104    if let Some(kind) = kind {
105        if let Some(name) = extract_name(node, source, kind) {
106            let start = node.start_position();
107            let end = node.end_position();
108            let content = node.utf8_text(source.as_bytes()).ok().map(|s| {
109                if s.len() > 500 { format!("{}...", &s[..500]) } else { s.to_string() }
110            });
111
112            results.push(Node {
113                id: 0,
114                name,
115                kind,
116                language: Language::Python,
117                file_path: file_path.to_string(),
118                line_start: (start.row + 1) as u32,
119                line_end: (end.row + 1) as u32,
120                content,
121            });
122        }
123
124        // If this is a class, recurse into its body to find methods
125        if kind == NodeKind::Class {
126            if let Some(body) = node.child_by_field_name("body") {
127                let mut cursor = body.walk();
128                for child in body.children(&mut cursor) {
129                    extract_definitions(child, source, file_path, results, true);
130                }
131            }
132            return; // Don't double-recurse into class children
133        }
134    }
135
136    // Recurse into children (but not into class bodies — handled above)
137    if node.kind() != "class_definition" {
138        let mut cursor = node.walk();
139        for child in node.children(&mut cursor) {
140            extract_definitions(child, source, file_path, results, inside_class);
141        }
142    }
143}
144
145fn extract_name(node: tree_sitter::Node, source: &str, kind: NodeKind) -> Option<String> {
146    match kind {
147        NodeKind::Import => {
148            // Full import text
149            node.utf8_text(source.as_bytes())
150                .ok()
151                .map(|s| s.trim().to_string())
152        }
153        _ => {
154            // Find the 'name' field (tree-sitter-python uses field names)
155            node.child_by_field_name("name")
156                .and_then(|n| n.utf8_text(source.as_bytes()).ok())
157                .map(|s| s.to_string())
158        }
159    }
160}
161
162fn extract_assignment_name(node: tree_sitter::Node, source: &str) -> Option<String> {
163    // Assignment left side — could be identifier or pattern
164    node.child_by_field_name("left")
165        .and_then(|n| {
166            if n.kind() == "identifier" {
167                n.utf8_text(source.as_bytes()).ok().map(|s| s.to_string())
168            } else {
169                None
170            }
171        })
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Fixture exercising imports, constants, classes, methods and functions.
    const SAMPLE_PYTHON: &str = r#"
import os
from pathlib import Path

MAX_SIZE = 1024
DEBUG = True

class Config:
    """Configuration holder."""

    def __init__(self, name: str):
        self.name = name
        self.values = {}

    def get(self, key: str) -> str:
        return self.values.get(key, "")

    @staticmethod
    def default() -> "Config":
        return Config("default")

class Status:
    ACTIVE = "active"
    INACTIVE = "inactive"

def process(data: list) -> dict:
    result = {}
    for item in data:
        result[item] = True
    return result

def main():
    config = Config("test")
    print(config.get("key"))
"#;

    /// Parse `source` as if it lived at `file_name`, panicking on failure.
    fn nodes_for(file_name: &str, source: &str) -> Vec<Node> {
        parse(&PathBuf::from(file_name), source).unwrap()
    }

    /// Borrow every node of the requested kind, in extraction order.
    fn of_kind(nodes: &[Node], kind: NodeKind) -> Vec<&Node> {
        nodes.iter().filter(|n| n.kind == kind).collect()
    }

    /// Whether any node in `nodes` carries exactly this name.
    fn has_named(nodes: &[&Node], name: &str) -> bool {
        nodes.iter().any(|n| n.name == name)
    }

    #[test]
    fn test_parse_python_finds_all_definitions() {
        let nodes = nodes_for("test.py", SAMPLE_PYTHON);

        let expectations = [
            (NodeKind::Import, "should find import"),
            (NodeKind::Constant, "should find constant"),
            (NodeKind::Class, "should find class"),
            (NodeKind::Function, "should find function"),
        ];
        for (kind, message) in expectations {
            assert!(nodes.iter().any(|n| n.kind == kind), "{}", message);
        }
    }

    #[test]
    fn test_parse_python_finds_methods() {
        let nodes = nodes_for("test.py", SAMPLE_PYTHON);
        let methods = of_kind(&nodes, NodeKind::Method);

        assert!(methods.len() >= 3, "should find methods (__init__, get, default), got {}", methods.len());
        for expected in ["__init__", "get", "default"] {
            assert!(has_named(&methods, expected));
        }
    }

    #[test]
    fn test_parse_python_class_name() {
        let nodes = nodes_for("test.py", SAMPLE_PYTHON);
        let classes = of_kind(&nodes, NodeKind::Class);

        assert_eq!(classes.len(), 2);
        assert!(has_named(&classes, "Config"));
        assert!(has_named(&classes, "Status"));
        assert_eq!(classes[0].language, Language::Python);
    }

    #[test]
    fn test_parse_python_constants() {
        let nodes = nodes_for("test.py", SAMPLE_PYTHON);
        let constants = of_kind(&nodes, NodeKind::Constant);

        assert!(has_named(&constants, "MAX_SIZE"), "should find MAX_SIZE");
        assert!(has_named(&constants, "DEBUG"), "should find DEBUG");
    }

    #[test]
    fn test_parse_python_line_numbers() {
        let nodes = nodes_for("test.py", SAMPLE_PYTHON);

        let main_fn = nodes
            .iter()
            .find(|n| n.kind == NodeKind::Function && n.name == "main")
            .expect("should find main function");
        assert!(main_fn.line_start > 0, "line numbers should be 1-indexed");
    }

    #[test]
    fn test_parse_python_imports() {
        let nodes = nodes_for("test.py", SAMPLE_PYTHON);
        let imports = of_kind(&nodes, NodeKind::Import);

        assert_eq!(imports.len(), 2, "should find 2 import statements");
        assert!(imports.iter().any(|i| i.name.contains("os")));
        assert!(imports.iter().any(|i| i.name.contains("pathlib")));
    }

    #[test]
    fn test_parse_python_edges() {
        let path = PathBuf::from("test.py");
        let result = parse_with_edges(&path, SAMPLE_PYTHON).unwrap();

        assert!(!result.edges.is_empty(), "should have CONTAINS edges");
        // Every edge must start at the file node (idx 0).
        for (from_idx, _, kind) in result.edges.iter() {
            assert_eq!(*from_idx, 0);
            assert_eq!(*kind, EdgeKind::Contains);
        }
    }

    #[test]
    fn test_parse_empty_python_file() {
        let nodes = nodes_for("empty.py", "");
        assert!(nodes.len() <= 1);
    }

    #[test]
    fn test_parse_python_decorated_function() {
        let source = r#"
import functools

def decorator(f):
    return f

@decorator
def decorated():
    pass

class MyClass:
    @staticmethod
    def static_method():
        pass

    @classmethod
    def class_method(cls):
        pass
"#;
        let nodes = nodes_for("deco.py", source);

        let fns = of_kind(&nodes, NodeKind::Function);
        assert!(has_named(&fns, "decorator"));
        assert!(has_named(&fns, "decorated"));

        let methods = of_kind(&nodes, NodeKind::Method);
        assert!(has_named(&methods, "static_method"));
        assert!(has_named(&methods, "class_method"));
    }

    #[test]
    fn test_parse_python_nested_class() {
        let source = r#"
class Outer:
    class Inner:
        def inner_method(self):
            pass

    def outer_method(self):
        pass
"#;
        let nodes = nodes_for("nested.py", source);

        let classes = of_kind(&nodes, NodeKind::Class);
        assert!(has_named(&classes, "Outer"));
    }

    #[test]
    fn test_parse_python_async_function() {
        let source = r#"
import asyncio

async def fetch_data(url: str) -> dict:
    return {}

class Client:
    async def connect(self):
        pass
"#;
        let nodes = nodes_for("async.py", source);

        let fns = of_kind(&nodes, NodeKind::Function);
        assert!(has_named(&fns, "fetch_data"), "should find async function");

        let methods = of_kind(&nodes, NodeKind::Method);
        assert!(has_named(&methods, "connect"), "should find async method");
    }

    #[test]
    fn test_parse_python_lowercase_not_constant() {
        let source = r#"
MAX_SIZE = 100
lowercase_var = "not a constant"
_private = True
"#;
        let nodes = nodes_for("vars.py", source);

        let constants = of_kind(&nodes, NodeKind::Constant);
        assert!(has_named(&constants, "MAX_SIZE"));
        // Names containing lowercase letters must NOT be captured as constants.
        assert!(!has_named(&constants, "lowercase_var"));
        assert!(!has_named(&constants, "_private"));
    }

    #[test]
    fn test_parse_python_multiple_imports() {
        let source = r#"
import os
import sys
from typing import Dict, List, Optional
from pathlib import Path
from collections import defaultdict
"#;
        let nodes = nodes_for("imports.py", source);

        let imports = of_kind(&nodes, NodeKind::Import);
        assert_eq!(imports.len(), 5, "should find all 5 import statements");
    }
}