Skip to main content

cgx_engine/parsers/
py.rs

1use tree_sitter::{Parser, Query, QueryCursor};
2
3use crate::parser::{EdgeDef, EdgeKind, LanguageParser, NodeDef, NodeKind, ParseResult};
4use crate::walker::SourceFile;
5
6pub struct PythonParser {
7    language: tree_sitter::Language,
8}
9
10impl PythonParser {
11    pub fn new() -> Self {
12        Self {
13            language: tree_sitter_python::language(),
14        }
15    }
16}
17
18impl Default for PythonParser {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl LanguageParser for PythonParser {
25    fn extensions(&self) -> &[&str] {
26        &["py"]
27    }
28
29    fn extract(&self, file: &SourceFile) -> anyhow::Result<ParseResult> {
30        let mut parser = Parser::new();
31        parser.set_language(&self.language)?;
32
33        let tree = parser
34            .parse(&file.content, None)
35            .ok_or_else(|| anyhow::anyhow!("failed to parse {}", file.relative_path))?;
36
37        let source_bytes = file.content.as_bytes();
38        let root = tree.root_node();
39        let mut nodes = Vec::new();
40        let mut edges = Vec::new();
41
42        let fp = format!("file:{}", file.relative_path);
43
44        // Function definitions
45        if let Ok(query) = Query::new(
46            &self.language,
47            "(function_definition name: (identifier) @name) @fn",
48        ) {
49            let mut cursor = QueryCursor::new();
50            for m in cursor.matches(&query, root, source_bytes) {
51                let Some(name_capture) = m
52                    .captures
53                    .iter()
54                    .find(|c| query.capture_names()[c.index as usize] == "name")
55                else {
56                    continue;
57                };
58                let name = node_text(name_capture.node, source_bytes);
59                let start = name_capture.node.start_position();
60                let body_end = m
61                    .captures
62                    .iter()
63                    .find(|c| query.capture_names()[c.index as usize] == "fn")
64                    .map(|c| c.node.end_position())
65                    .unwrap_or_else(|| name_capture.node.end_position());
66                let id = format!("fn:{}:{}", file.relative_path, name);
67
68                nodes.push(NodeDef {
69                    id: id.clone(),
70                    kind: NodeKind::Function,
71                    name,
72                    path: file.relative_path.clone(),
73                    line_start: start.row as u32 + 1,
74                    line_end: body_end.row as u32 + 1,
75                    ..Default::default()
76                });
77
78                edges.push(EdgeDef {
79                    src: fp.clone(),
80                    dst: id,
81                    kind: EdgeKind::Exports,
82                    ..Default::default()
83                });
84            }
85        }
86
87        // Class definitions
88        if let Ok(query) = Query::new(
89            &self.language,
90            "(class_definition name: (identifier) @name) @cls",
91        ) {
92            let mut cursor = QueryCursor::new();
93            for m in cursor.matches(&query, root, source_bytes) {
94                let Some(name_capture) = m
95                    .captures
96                    .iter()
97                    .find(|c| query.capture_names()[c.index as usize] == "name")
98                else {
99                    continue;
100                };
101                let name = node_text(name_capture.node, source_bytes);
102                let start = name_capture.node.start_position();
103                let body_end = m
104                    .captures
105                    .iter()
106                    .find(|c| query.capture_names()[c.index as usize] == "cls")
107                    .map(|c| c.node.end_position())
108                    .unwrap_or_else(|| name_capture.node.end_position());
109                let id = format!("cls:{}:{}", file.relative_path, name);
110
111                nodes.push(NodeDef {
112                    id: id.clone(),
113                    kind: NodeKind::Class,
114                    name,
115                    path: file.relative_path.clone(),
116                    line_start: start.row as u32 + 1,
117                    line_end: body_end.row as u32 + 1,
118                    ..Default::default()
119                });
120
121                edges.push(EdgeDef {
122                    src: fp.clone(),
123                    dst: id,
124                    kind: EdgeKind::Exports,
125                    ..Default::default()
126                });
127            }
128        }
129
130        // Import statements: `from X import Y`
131        if let Ok(query) = Query::new(
132            &self.language,
133            r#"(import_from_statement
134                module_name: (dotted_name) @mod
135                name: (dotted_name (identifier) @name))
136            "#,
137        ) {
138            let mut cursor = QueryCursor::new();
139            for m in cursor.matches(&query, root, source_bytes) {
140                let mod_name = m
141                    .captures
142                    .iter()
143                    .find(|c| query.capture_names()[c.index as usize] == "mod")
144                    .map(|c| node_text(c.node, source_bytes));
145                let import_name = m
146                    .captures
147                    .iter()
148                    .find(|c| query.capture_names()[c.index as usize] == "name")
149                    .map(|c| node_text(c.node, source_bytes));
150
151                if let (Some(mod_name), Some(_import_name)) = (mod_name, import_name) {
152                    let import_path = resolve_py_import(&file.relative_path, &mod_name);
153                    edges.push(EdgeDef {
154                        src: fp.clone(),
155                        dst: format!("file:{}", import_path),
156                        kind: EdgeKind::Imports,
157                        ..Default::default()
158                    });
159                }
160            }
161        }
162
163        // Import statements: `import X`
164        if let Ok(query) = Query::new(
165            &self.language,
166            "(import_statement name: (dotted_name (identifier) @name))",
167        ) {
168            let mut cursor = QueryCursor::new();
169            for m in cursor.matches(&query, root, source_bytes) {
170                if let Some(cap) = m
171                    .captures
172                    .iter()
173                    .find(|c| query.capture_names()[c.index as usize] == "name")
174                {
175                    let mod_name = node_text(cap.node, source_bytes);
176                    let import_path = format!("{}.py", mod_name.replace('.', "/"));
177                    edges.push(EdgeDef {
178                        src: fp.clone(),
179                        dst: format!("file:{}", import_path),
180                        kind: EdgeKind::Imports,
181                        ..Default::default()
182                    });
183                }
184            }
185        }
186
187        // In Python, top-level functions/classes not starting with _ are considered exported
188        for node_def in &mut nodes {
189            if !node_def.name.starts_with('_') {
190                node_def.metadata = serde_json::json!({"exported": true});
191            }
192        }
193
194        Ok(ParseResult {
195            nodes,
196            edges,
197            ..Default::default()
198        })
199    }
200}
201
202fn node_text(node: tree_sitter::Node, source: &[u8]) -> String {
203    node.utf8_text(source).unwrap_or("").to_string()
204}
205
/// Maps a Python module reference to the `/`-separated path of the `.py` file
/// it names.
///
/// * Absolute names (`a.b.c`) become `a/b/c.py`.
/// * Relative names are resolved against the directory of `current_file`:
///   one leading dot stays in that directory and each additional dot moves up
///   one level (stopping at the root if there are more dots than
///   directories). The dotted remainder is then appended one path segment per
///   component, so `.a.b` in `pkg/mod.py` becomes `pkg/a/b.py`.
/// * A bare run of dots (`from . import x`) resolves to the `__init__.py` of
///   the directory reached.
///
/// `current_file` is expected to use `/` as its separator.
fn resolve_py_import(current_file: &str, module_name: &str) -> String {
    let remainder = module_name.trim_start_matches('.');
    // '.' is a single byte, so the length difference is the number of dots.
    let dot_count = module_name.len() - remainder.len();

    if dot_count == 0 {
        // Absolute import
        return format!("{}.py", module_name.replace('.', "/"));
    }

    let mut parts: Vec<&str> = current_file.split('/').collect();
    parts.pop(); // remove filename

    // `.` = stay in this directory, `..` = up 1, `...` = up 2, ...
    let keep = parts.len().saturating_sub(dot_count - 1);
    parts.truncate(keep);

    if remainder.is_empty() {
        // `from . import X` — importing from the package's __init__.py
        parts.push("__init__");
    } else {
        // Each dotted component is its own directory / file segment.
        parts.extend(remainder.split('.'));
    }
    format!("{}.py", parts.join("/"))
}