// cgx_engine/parsers/py.rs

1use tree_sitter::{Parser, Query, QueryCursor};
2
3use crate::parser::{EdgeDef, EdgeKind, LanguageParser, NodeDef, NodeKind, ParseResult};
4use crate::walker::SourceFile;
5
6pub struct PythonParser {
7    language: tree_sitter::Language,
8}
9
10impl PythonParser {
11    pub fn new() -> Self {
12        Self {
13            language: tree_sitter_python::language(),
14        }
15    }
16}
17
18impl Default for PythonParser {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl LanguageParser for PythonParser {
25    fn extensions(&self) -> &[&str] {
26        &["py"]
27    }
28
29    fn extract(&self, file: &SourceFile) -> anyhow::Result<ParseResult> {
30        let mut parser = Parser::new();
31        parser.set_language(&self.language)?;
32
33        let tree = parser.parse(&file.content, None).ok_or_else(|| {
34            anyhow::anyhow!("failed to parse {}", file.relative_path)
35        })?;
36
37        let source_bytes = file.content.as_bytes();
38        let root = tree.root_node();
39        let mut nodes = Vec::new();
40        let mut edges = Vec::new();
41
42        let fp = format!("file:{}", file.relative_path);
43
44        // Function definitions
45        if let Ok(query) = Query::new(
46            &self.language,
47            "(function_definition name: (identifier) @name) @fn",
48        ) {
49            let mut cursor = QueryCursor::new();
50            for m in cursor.matches(&query, root, source_bytes) {
51                let Some(name_capture) = m
52                    .captures
53                    .iter()
54                    .find(|c| query.capture_names()[c.index as usize] == "name")
55                else {
56                    continue;
57                };
58                let name = node_text(name_capture.node, source_bytes);
59                let start = name_capture.node.start_position();
60                let body_end = m.captures.iter()
61                    .find(|c| query.capture_names()[c.index as usize] == "fn")
62                    .map(|c| c.node.end_position())
63                    .unwrap_or_else(|| name_capture.node.end_position());
64                let id = format!("fn:{}:{}", file.relative_path, name);
65
66                nodes.push(NodeDef {
67                    id: id.clone(),
68                    kind: NodeKind::Function,
69                    name,
70                    path: file.relative_path.clone(),
71                    line_start: start.row as u32 + 1,
72                    line_end: body_end.row as u32 + 1,
73                    ..Default::default()
74                });
75
76                edges.push(EdgeDef {
77                    src: fp.clone(),
78                    dst: id,
79                    kind: EdgeKind::Exports,
80                    ..Default::default()
81                });
82            }
83        }
84
85        // Class definitions
86        if let Ok(query) = Query::new(
87            &self.language,
88            "(class_definition name: (identifier) @name) @cls",
89        ) {
90            let mut cursor = QueryCursor::new();
91            for m in cursor.matches(&query, root, source_bytes) {
92                let Some(name_capture) = m
93                    .captures
94                    .iter()
95                    .find(|c| query.capture_names()[c.index as usize] == "name")
96                else {
97                    continue;
98                };
99                let name = node_text(name_capture.node, source_bytes);
100                let start = name_capture.node.start_position();
101                let body_end = m.captures.iter()
102                    .find(|c| query.capture_names()[c.index as usize] == "cls")
103                    .map(|c| c.node.end_position())
104                    .unwrap_or_else(|| name_capture.node.end_position());
105                let id = format!("cls:{}:{}", file.relative_path, name);
106
107                nodes.push(NodeDef {
108                    id: id.clone(),
109                    kind: NodeKind::Class,
110                    name,
111                    path: file.relative_path.clone(),
112                    line_start: start.row as u32 + 1,
113                    line_end: body_end.row as u32 + 1,
114                    ..Default::default()
115                });
116
117                edges.push(EdgeDef {
118                    src: fp.clone(),
119                    dst: id,
120                    kind: EdgeKind::Exports,
121                    ..Default::default()
122                });
123            }
124        }
125
126        // Import statements: `from X import Y`
127        if let Ok(query) = Query::new(
128            &self.language,
129            r#"(import_from_statement
130                module_name: (dotted_name) @mod
131                name: (dotted_name (identifier) @name))
132            "#,
133        ) {
134            let mut cursor = QueryCursor::new();
135            for m in cursor.matches(&query, root, source_bytes) {
136                let mod_name = m
137                    .captures
138                    .iter()
139                    .find(|c| query.capture_names()[c.index as usize] == "mod")
140                    .map(|c| node_text(c.node, source_bytes));
141                let import_name = m
142                    .captures
143                    .iter()
144                    .find(|c| query.capture_names()[c.index as usize] == "name")
145                    .map(|c| node_text(c.node, source_bytes));
146
147                if let (Some(mod_name), Some(_import_name)) = (mod_name, import_name) {
148                    let import_path = resolve_py_import(&file.relative_path, &mod_name);
149                    edges.push(EdgeDef {
150                        src: fp.clone(),
151                        dst: format!("file:{}", import_path),
152                        kind: EdgeKind::Imports,
153                        ..Default::default()
154                    });
155                }
156            }
157        }
158
159        // Import statements: `import X`
160        if let Ok(query) = Query::new(
161            &self.language,
162            "(import_statement name: (dotted_name (identifier) @name))",
163        ) {
164            let mut cursor = QueryCursor::new();
165            for m in cursor.matches(&query, root, source_bytes) {
166                if let Some(cap) = m
167                    .captures
168                    .iter()
169                    .find(|c| query.capture_names()[c.index as usize] == "name")
170                {
171                    let mod_name = node_text(cap.node, source_bytes);
172                    let import_path = format!("{}.py", mod_name.replace('.', "/"));
173                    edges.push(EdgeDef {
174                        src: fp.clone(),
175                        dst: format!("file:{}", import_path),
176                        kind: EdgeKind::Imports,
177                        ..Default::default()
178                    });
179                }
180            }
181        }
182
183        Ok(ParseResult { nodes, edges })
184    }
185}
186
187fn node_text(node: tree_sitter::Node, source: &[u8]) -> String {
188    node.utf8_text(source).unwrap_or("").to_string()
189}
190
/// Maps the module named in a Python import to a `/`-separated `.py` path.
///
/// * Absolute modules (`a.b.c`) become `a/b/c.py`.
/// * Relative modules are resolved against the directory of `current_file`:
///   one leading dot stays in that directory and every additional dot climbs
///   one directory higher. Climbing past the top is clamped rather than
///   treated as an error.
/// * A bare run of dots (`from . import X`) resolves to the package's
///   `__init__.py`.
///
/// `current_file` is expected to use `/` as its separator. The result is a
/// purely textual guess; the file system is never consulted.
fn resolve_py_import(current_file: &str, module_name: &str) -> String {
    let remainder = module_name.trim_start_matches('.');
    // Dots are single-byte, so the length difference is the number of dots.
    let dot_count = module_name.len() - remainder.len();

    if dot_count == 0 {
        // Absolute import
        return format!("{}.py", module_name.replace('.', "/"));
    }

    let mut parts: Vec<&str> = current_file.split('/').collect();
    parts.pop(); // remove filename

    // `.` = stay in this directory, `..` = up 1, `...` = up 2, ...
    let keep = parts.len().saturating_sub(dot_count - 1);
    parts.truncate(keep);

    if remainder.is_empty() {
        // `from . import X` — importing from the package's __init__.py
        parts.push("__init__");
    } else {
        // A dotted remainder (`..core.models`) names nested packages, so each
        // component becomes its own path segment.
        parts.extend(remainder.split('.'));
    }

    format!("{}.py", parts.join("/"))
}