Skip to main content

cgx_engine/parsers/
py.rs

1use tree_sitter::{Parser, Query, QueryCursor};
2
3use crate::parser::{meta_set, EdgeDef, EdgeKind, LanguageParser, NodeDef, NodeKind, ParseResult};
4use crate::walker::SourceFile;
5
6pub struct PythonParser {
7    language: tree_sitter::Language,
8}
9
10impl PythonParser {
11    pub fn new() -> Self {
12        Self {
13            language: tree_sitter_python::language(),
14        }
15    }
16}
17
18impl Default for PythonParser {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl LanguageParser for PythonParser {
25    fn extensions(&self) -> &[&str] {
26        &["py"]
27    }
28
29    fn extract(&self, file: &SourceFile) -> anyhow::Result<ParseResult> {
30        let mut parser = Parser::new();
31        parser.set_language(&self.language)?;
32
33        let tree = parser
34            .parse(&file.content, None)
35            .ok_or_else(|| anyhow::anyhow!("failed to parse {}", file.relative_path))?;
36
37        let source_bytes = file.content.as_bytes();
38        let root = tree.root_node();
39        let mut nodes = Vec::new();
40        let mut edges = Vec::new();
41
42        let fp = format!("file:{}", file.relative_path);
43
44        // Function definitions
45        if let Ok(query) = Query::new(
46            &self.language,
47            "(function_definition name: (identifier) @name) @fn",
48        ) {
49            let mut cursor = QueryCursor::new();
50            for m in cursor.matches(&query, root, source_bytes) {
51                let Some(name_capture) = m
52                    .captures
53                    .iter()
54                    .find(|c| query.capture_names()[c.index as usize] == "name")
55                else {
56                    continue;
57                };
58                let fn_node = m
59                    .captures
60                    .iter()
61                    .find(|c| query.capture_names()[c.index as usize] == "fn")
62                    .map(|c| c.node);
63                let name = node_text(name_capture.node, source_bytes);
64                let start = name_capture.node.start_position();
65                let body_end = fn_node
66                    .map(|n| n.end_position())
67                    .unwrap_or_else(|| name_capture.node.end_position());
68                let id = format!("fn:{}:{}", file.relative_path, name);
69
70                let doc_comment = fn_node.and_then(|n| extract_py_docstring(n, source_bytes));
71
72                let mut def = NodeDef {
73                    id: id.clone(),
74                    kind: NodeKind::Function,
75                    name,
76                    path: file.relative_path.clone(),
77                    line_start: start.row as u32 + 1,
78                    line_end: body_end.row as u32 + 1,
79                    ..Default::default()
80                };
81                if let Some(doc) = doc_comment {
82                    meta_set(&mut def, "doc_comment", serde_json::Value::String(doc));
83                }
84                nodes.push(def);
85
86                edges.push(EdgeDef {
87                    src: fp.clone(),
88                    dst: id,
89                    kind: EdgeKind::Exports,
90                    ..Default::default()
91                });
92            }
93        }
94
95        // Class definitions
96        if let Ok(query) = Query::new(
97            &self.language,
98            "(class_definition name: (identifier) @name) @cls",
99        ) {
100            let mut cursor = QueryCursor::new();
101            for m in cursor.matches(&query, root, source_bytes) {
102                let Some(name_capture) = m
103                    .captures
104                    .iter()
105                    .find(|c| query.capture_names()[c.index as usize] == "name")
106                else {
107                    continue;
108                };
109                let cls_node = m
110                    .captures
111                    .iter()
112                    .find(|c| query.capture_names()[c.index as usize] == "cls")
113                    .map(|c| c.node);
114                let name = node_text(name_capture.node, source_bytes);
115                let start = name_capture.node.start_position();
116                let body_end = cls_node
117                    .map(|n| n.end_position())
118                    .unwrap_or_else(|| name_capture.node.end_position());
119                let id = format!("cls:{}:{}", file.relative_path, name);
120
121                let doc_comment = cls_node.and_then(|n| extract_py_docstring(n, source_bytes));
122
123                let mut def = NodeDef {
124                    id: id.clone(),
125                    kind: NodeKind::Class,
126                    name,
127                    path: file.relative_path.clone(),
128                    line_start: start.row as u32 + 1,
129                    line_end: body_end.row as u32 + 1,
130                    ..Default::default()
131                };
132                if let Some(doc) = doc_comment {
133                    meta_set(&mut def, "doc_comment", serde_json::Value::String(doc));
134                }
135                nodes.push(def);
136
137                edges.push(EdgeDef {
138                    src: fp.clone(),
139                    dst: id,
140                    kind: EdgeKind::Exports,
141                    ..Default::default()
142                });
143            }
144        }
145
146        // Import statements: `from X import Y`
147        if let Ok(query) = Query::new(
148            &self.language,
149            r#"(import_from_statement
150                module_name: (dotted_name) @mod
151                name: (dotted_name (identifier) @name))
152            "#,
153        ) {
154            let mut cursor = QueryCursor::new();
155            for m in cursor.matches(&query, root, source_bytes) {
156                let mod_name = m
157                    .captures
158                    .iter()
159                    .find(|c| query.capture_names()[c.index as usize] == "mod")
160                    .map(|c| node_text(c.node, source_bytes));
161                let import_name = m
162                    .captures
163                    .iter()
164                    .find(|c| query.capture_names()[c.index as usize] == "name")
165                    .map(|c| node_text(c.node, source_bytes));
166
167                if let (Some(mod_name), Some(_import_name)) = (mod_name, import_name) {
168                    let import_path = resolve_py_import(&file.relative_path, &mod_name);
169                    edges.push(EdgeDef {
170                        src: fp.clone(),
171                        dst: format!("file:{}", import_path),
172                        kind: EdgeKind::Imports,
173                        ..Default::default()
174                    });
175                }
176            }
177        }
178
179        // Import statements: `import X`
180        if let Ok(query) = Query::new(
181            &self.language,
182            "(import_statement name: (dotted_name (identifier) @name))",
183        ) {
184            let mut cursor = QueryCursor::new();
185            for m in cursor.matches(&query, root, source_bytes) {
186                if let Some(cap) = m
187                    .captures
188                    .iter()
189                    .find(|c| query.capture_names()[c.index as usize] == "name")
190                {
191                    let mod_name = node_text(cap.node, source_bytes);
192                    let import_path = format!("{}.py", mod_name.replace('.', "/"));
193                    edges.push(EdgeDef {
194                        src: fp.clone(),
195                        dst: format!("file:{}", import_path),
196                        kind: EdgeKind::Imports,
197                        ..Default::default()
198                    });
199                }
200            }
201        }
202
203        // In Python, top-level functions/classes not starting with _ are considered exported.
204        // Preserve any existing metadata (e.g. doc_comment) by merging rather than overwriting.
205        for node_def in &mut nodes {
206            if !node_def.name.starts_with('_') {
207                meta_set(node_def, "exported", serde_json::Value::Bool(true));
208            }
209        }
210
211        Ok(ParseResult {
212            nodes,
213            edges,
214            ..Default::default()
215        })
216    }
217}
218
219fn node_text(node: tree_sitter::Node, source: &[u8]) -> String {
220    node.utf8_text(source).unwrap_or("").to_string()
221}
222
223/// Python docstring = the first `string` statement inside the function/class body block.
224fn extract_py_docstring(def_node: tree_sitter::Node, source: &[u8]) -> Option<String> {
225    let body = def_node.child_by_field_name("body")?;
226    let mut cursor = body.walk();
227    if !cursor.goto_first_child() {
228        return None;
229    }
230    loop {
231        let stmt = cursor.node();
232        // Tree-sitter-python wraps the docstring as: block > expression_statement > string
233        if stmt.kind() == "expression_statement" {
234            let mut inner = stmt.walk();
235            if inner.goto_first_child() && inner.node().kind() == "string" {
236                let raw = inner.node().utf8_text(source).unwrap_or("").trim();
237                return Some(strip_py_string_quotes(raw));
238            }
239        }
240        // Stop at first non-string statement — docstring must be the very first thing.
241        if stmt.kind() != "comment" {
242            return None;
243        }
244        if !cursor.goto_next_sibling() {
245            return None;
246        }
247    }
248}
249
/// Removes the string prefix and surrounding quotes from a Python string literal.
///
/// Accepts any run of `u`/`b`/`r`/`f` prefix letters (either case) followed by a
/// triple- or single-quoted literal, and returns the trimmed text between the quotes.
/// A missing closing delimiter is tolerated: everything after the opening one is kept.
/// Escape sequences are left as written.
///
/// Input that does not start with a quote (after the optional prefix) is returned
/// trimmed but otherwise untouched, so plain text beginning with one of the prefix
/// letters is not truncated.
fn strip_py_string_quotes(raw: &str) -> String {
    let trimmed = raw.trim();
    let unprefixed = trimmed.trim_start_matches(['u', 'b', 'r', 'f', 'U', 'B', 'R', 'F']);
    // Leading letters only count as a string prefix when a quote follows them.
    let literal = if unprefixed.starts_with(['"', '\'']) {
        unprefixed
    } else {
        trimmed
    };

    // Triple-quoted forms must be tried before their single-character counterparts.
    const DELIMITERS: [&str; 4] = ["\"\"\"", "'''", "\"", "'"];
    let inner = DELIMITERS
        .iter()
        .find_map(|&quote| {
            literal
                .strip_prefix(quote)
                .map(|rest| rest.strip_suffix(quote).unwrap_or(rest))
        })
        .unwrap_or(literal);

    inner.trim().to_string()
}
267
/// Maps a Python module reference to a `.py` path using `/` separators.
///
/// * Absolute modules (`a.b.c`) become `a/b/c.py`; `current_file` is not consulted.
/// * Relative modules start from the directory of `current_file`: one leading dot stays
///   in that directory and each further dot moves one directory up (going above the
///   first path component is clamped). The dotted remainder is then appended as nested
///   path segments, so `.sub.models` from `pkg/a.py` is `pkg/sub/models.py`.
/// * A relative module made only of dots (`from . import X`) resolves to the
///   `__init__.py` of the directory reached.
fn resolve_py_import(current_file: &str, module_name: &str) -> String {
    let dot_count = module_name.chars().take_while(|c| *c == '.').count();

    if dot_count == 0 {
        // Absolute import
        return format!("{}.py", module_name.replace('.', "/"));
    }

    // The leading dots are ASCII, so `dot_count` is also a valid byte offset.
    let remainder = &module_name[dot_count..];

    let mut parts: Vec<&str> = current_file.split('/').collect();
    parts.pop(); // remove filename

    // `.` = 0 extra levels up, `..` = 1, `...` = 2; never climbs past the root.
    let keep = parts.len().saturating_sub(dot_count - 1);
    parts.truncate(keep);

    if remainder.is_empty() {
        // `from . import X` — importing from the package's __init__.py
        parts.push("__init__");
    } else {
        // Each dotted component is its own directory level.
        parts.extend(remainder.split('.'));
    }
    format!("{}.py", parts.join("/"))
}