// dk_engine/parser/python_parser.rs

1use super::LanguageParser;
2use dk_core::{CallKind, Import, RawCallEdge, Result, Span, Symbol, SymbolKind, TypeInfo, Visibility};
3use std::path::Path;
4use tree_sitter::{Node, Parser, TreeCursor};
5use uuid::Uuid;
6
/// Python parser backed by tree-sitter.
///
/// Extracts symbols, call edges, imports, and (stub) type information from
/// Python source files.
///
/// The type carries no state: every extraction call builds its own
/// tree-sitter parser, so a value can be freely copied and shared.
#[derive(Debug, Clone, Copy)]
pub struct PythonParser;
12
13impl PythonParser {
14    pub fn new() -> Self {
15        Self
16    }
17
18    /// Create a configured tree-sitter parser for Python.
19    fn create_parser() -> Result<Parser> {
20        let mut parser = Parser::new();
21        parser
22            .set_language(&tree_sitter_python::LANGUAGE.into())
23            .map_err(|e| dk_core::Error::ParseError(format!("Failed to load Python grammar: {e}")))?;
24        Ok(parser)
25    }
26
27    /// Parse source bytes into a tree-sitter tree.
28    fn parse_tree(source: &[u8]) -> Result<tree_sitter::Tree> {
29        let mut parser = Self::create_parser()?;
30        parser
31            .parse(source, None)
32            .ok_or_else(|| dk_core::Error::ParseError("tree-sitter parse returned None".into()))
33    }
34
35    /// Get the text of a node as a UTF-8 string.
36    fn node_text<'a>(node: &Node, source: &'a [u8]) -> &'a str {
37        let text = &source[node.start_byte()..node.end_byte()];
38        std::str::from_utf8(text).unwrap_or("")
39    }
40
41    /// Determine visibility based on Python naming conventions.
42    /// Names starting with `_` are considered private; everything else is public.
43    fn name_visibility(name: &str) -> Visibility {
44        if name.starts_with('_') {
45            Visibility::Private
46        } else {
47            Visibility::Public
48        }
49    }
50
51    /// Extract the name from a function_definition or class_definition node.
52    fn node_name(node: &Node, source: &[u8]) -> Option<String> {
53        node.child_by_field_name("name")
54            .map(|n| Self::node_text(&n, source).to_string())
55    }
56
57    /// Extract the first line of the node's source text as the signature.
58    fn node_signature(node: &Node, source: &[u8]) -> Option<String> {
59        let text_str = Self::node_text(node, source);
60        let first_line = text_str.lines().next()?;
61        Some(first_line.trim().to_string())
62    }
63
64    /// Extract docstring from a function or class body.
65    ///
66    /// In Python, a docstring is the first statement in the body if it is an
67    /// `expression_statement` containing a `string` node.
68    fn extract_docstring(node: &Node, source: &[u8]) -> Option<String> {
69        // Look for the "body" field (block node)
70        let body = node.child_by_field_name("body")?;
71
72        // The first child of the block should be the potential docstring
73        let first_stmt = body.child(0)?;
74
75        if first_stmt.kind() == "expression_statement" {
76            let expr = first_stmt.child(0)?;
77            if expr.kind() == "string" {
78                let raw = Self::node_text(&expr, source);
79                // Strip triple-quote delimiters and clean up
80                let content = raw
81                    .strip_prefix("\"\"\"")
82                    .and_then(|s| s.strip_suffix("\"\"\""))
83                    .or_else(|| {
84                        raw.strip_prefix("'''")
85                            .and_then(|s| s.strip_suffix("'''"))
86                    })
87                    .unwrap_or(raw);
88                let trimmed = content.trim().to_string();
89                if !trimmed.is_empty() {
90                    return Some(trimmed);
91                }
92            }
93        }
94
95        None
96    }
97
98    /// Collect preceding `#` comments for a node.
99    fn doc_comments(node: &Node, source: &[u8]) -> Option<String> {
100        let mut comments = Vec::new();
101        let mut sibling = node.prev_sibling();
102
103        while let Some(prev) = sibling {
104            if prev.kind() == "comment" {
105                let text = Self::node_text(&prev, source).trim().to_string();
106                // Strip the `# ` prefix
107                let content = text
108                    .strip_prefix("# ")
109                    .unwrap_or(text.strip_prefix('#').unwrap_or(&text));
110                comments.push(content.to_string());
111                sibling = prev.prev_sibling();
112                continue;
113            }
114            break;
115        }
116
117        if comments.is_empty() {
118            None
119        } else {
120            comments.reverse();
121            Some(comments.join("\n"))
122        }
123    }
124
125    /// Extract a symbol from a function_definition or class_definition node.
126    fn extract_symbol_from_def(
127        node: &Node,
128        source: &[u8],
129        file_path: &Path,
130    ) -> Option<Symbol> {
131        let kind = match node.kind() {
132            "function_definition" => SymbolKind::Function,
133            "class_definition" => SymbolKind::Class,
134            _ => return None,
135        };
136
137        let name = Self::node_name(node, source)?;
138        if name.is_empty() {
139            return None;
140        }
141
142        let visibility = Self::name_visibility(&name);
143        let signature = Self::node_signature(node, source);
144
145        // Try docstring first, fall back to preceding comments
146        let doc_comment = Self::extract_docstring(node, source)
147            .or_else(|| Self::doc_comments(node, source));
148
149        Some(Symbol {
150            id: Uuid::new_v4(),
151            name: name.clone(),
152            qualified_name: name,
153            kind,
154            visibility,
155            file_path: file_path.to_path_buf(),
156            span: Span {
157                start_byte: node.start_byte() as u32,
158                end_byte: node.end_byte() as u32,
159            },
160            signature,
161            doc_comment,
162            parent: None,
163            last_modified_by: None,
164            last_modified_intent: None,
165        })
166    }
167
168    /// Extract the name from a simple assignment at the top level.
169    /// e.g. `MAX_RETRIES = 3` yields "MAX_RETRIES".
170    /// Only handles simple identifier = value assignments (not tuple unpacking, etc.).
171    fn extract_assignment_name(node: &Node, source: &[u8]) -> Option<String> {
172        if node.kind() != "expression_statement" {
173            return None;
174        }
175
176        // The expression_statement should contain an assignment
177        let child = node.child(0)?;
178        if child.kind() != "assignment" {
179            return None;
180        }
181
182        // The left side should be a simple identifier
183        let left = child.child_by_field_name("left")?;
184        if left.kind() != "identifier" {
185            return None;
186        }
187
188        let name = Self::node_text(&left, source).to_string();
189        if name.is_empty() {
190            None
191        } else {
192            Some(name)
193        }
194    }
195
196    /// Find the name of the enclosing function for a given node, if any.
197    fn enclosing_function_name(node: &Node, source: &[u8]) -> String {
198        let mut current = node.parent();
199        while let Some(parent) = current {
200            if parent.kind() == "function_definition" {
201                if let Some(name_node) = parent.child_by_field_name("name") {
202                    let name = Self::node_text(&name_node, source);
203                    if !name.is_empty() {
204                        return name.to_string();
205                    }
206                }
207            }
208            current = parent.parent();
209        }
210        "<module>".to_string()
211    }
212
213    /// Extract the callee name and call kind from a call node's function field.
214    fn extract_callee_info(node: &Node, source: &[u8]) -> (String, CallKind) {
215        match node.kind() {
216            "attribute" => {
217                // e.g. obj.method — the callee is the attribute (method name)
218                if let Some(attr) = node.child_by_field_name("attribute") {
219                    let name = Self::node_text(&attr, source).to_string();
220                    return (name, CallKind::MethodCall);
221                }
222                let text = Self::node_text(node, source).to_string();
223                (text, CallKind::MethodCall)
224            }
225            "identifier" => {
226                let name = Self::node_text(node, source).to_string();
227                (name, CallKind::DirectCall)
228            }
229            _ => {
230                let text = Self::node_text(node, source).to_string();
231                (text, CallKind::DirectCall)
232            }
233        }
234    }
235
236    /// Recursively walk the tree to extract call edges.
237    fn walk_calls(cursor: &mut TreeCursor, source: &[u8], calls: &mut Vec<RawCallEdge>) {
238        let node = cursor.node();
239
240        match node.kind() {
241            "call" => {
242                // Python call node has a "function" field
243                if let Some(func_node) = node.child_by_field_name("function") {
244                    let (callee, kind) = Self::extract_callee_info(&func_node, source);
245                    if !callee.is_empty() {
246                        let caller = Self::enclosing_function_name(&node, source);
247                        calls.push(RawCallEdge {
248                            caller_name: caller,
249                            callee_name: callee,
250                            call_site: Span {
251                                start_byte: node.start_byte() as u32,
252                                end_byte: node.end_byte() as u32,
253                            },
254                            kind,
255                        });
256                    }
257                }
258            }
259            "decorator" => {
260                // A decorator is effectively a call to the decorator function.
261                // The decorator node contains the decorator expression (after @).
262                // It can be a simple identifier like `@login_required`,
263                // a call like `@app.route("/api")`, or an attribute like `@app.middleware`.
264                //
265                // For `@login_required`, the child is an identifier.
266                // For `@app.route("/api")`, the child is a call node (which walk_calls handles).
267                // For `@app.middleware`, the child is an attribute.
268                //
269                // We handle the identifier and attribute cases here; the call case
270                // is handled recursively when we descend into children.
271                let mut inner_cursor = node.walk();
272                for child in node.children(&mut inner_cursor) {
273                    match child.kind() {
274                        "identifier" => {
275                            let name = Self::node_text(&child, source).to_string();
276                            if !name.is_empty() {
277                                let caller = Self::enclosing_function_name(&node, source);
278                                calls.push(RawCallEdge {
279                                    caller_name: caller,
280                                    callee_name: name,
281                                    call_site: Span {
282                                        start_byte: node.start_byte() as u32,
283                                        end_byte: node.end_byte() as u32,
284                                    },
285                                    kind: CallKind::DirectCall,
286                                });
287                            }
288                        }
289                        "attribute" => {
290                            if let Some(attr) = child.child_by_field_name("attribute") {
291                                let name = Self::node_text(&attr, source).to_string();
292                                if !name.is_empty() {
293                                    let caller = Self::enclosing_function_name(&node, source);
294                                    calls.push(RawCallEdge {
295                                        caller_name: caller,
296                                        callee_name: name,
297                                        call_site: Span {
298                                            start_byte: node.start_byte() as u32,
299                                            end_byte: node.end_byte() as u32,
300                                        },
301                                        kind: CallKind::MethodCall,
302                                    });
303                                }
304                            }
305                        }
306                        _ => {}
307                    }
308                }
309            }
310            _ => {}
311        }
312
313        // Recurse into children
314        if cursor.goto_first_child() {
315            loop {
316                Self::walk_calls(cursor, source, calls);
317                if !cursor.goto_next_sibling() {
318                    break;
319                }
320            }
321            cursor.goto_parent();
322        }
323    }
324
325    /// Extract imports from an `import_statement` node.
326    /// e.g. `import os` or `import os, sys`
327    fn extract_import_statement(node: &Node, source: &[u8]) -> Vec<Import> {
328        let mut imports = Vec::new();
329        let mut cursor = node.walk();
330
331        for child in node.children(&mut cursor) {
332            match child.kind() {
333                "dotted_name" => {
334                    let module = Self::node_text(&child, source).to_string();
335                    if !module.is_empty() {
336                        imports.push(Import {
337                            module_path: module.clone(),
338                            imported_name: module,
339                            alias: None,
340                            is_external: true,
341                        });
342                    }
343                }
344                "aliased_import" => {
345                    let name_node = child.child_by_field_name("name");
346                    let alias_node = child.child_by_field_name("alias");
347
348                    if let Some(name_n) = name_node {
349                        let module = Self::node_text(&name_n, source).to_string();
350                        let alias = alias_node
351                            .map(|a| Self::node_text(&a, source).to_string());
352                        imports.push(Import {
353                            module_path: module.clone(),
354                            imported_name: module,
355                            alias,
356                            is_external: true,
357                        });
358                    }
359                }
360                _ => {}
361            }
362        }
363
364        imports
365    }
366
367    /// Extract imports from an `import_from_statement` node.
368    /// e.g. `from os.path import join, exists` or `from .local import helper`
369    fn extract_import_from_statement(node: &Node, source: &[u8]) -> Vec<Import> {
370        let mut imports = Vec::new();
371
372        // Get the module name. In tree-sitter-python the module is in the
373        // "module_name" field. For relative imports it includes the dots.
374        let module_path = Self::extract_from_module_path(node, source);
375        let is_external = !module_path.starts_with('.');
376
377        // Collect imported names
378        let mut cursor = node.walk();
379        for child in node.children(&mut cursor) {
380            match child.kind() {
381                "dotted_name" | "identifier" => {
382                    // Skip the module name itself (already captured)
383                    // The imported names come after the "import" keyword
384                    // In tree-sitter-python, the imported names are in the node's
385                    // named children that are not the module_name field.
386                    // We need to distinguish module from imported names.
387                }
388                "aliased_import" => {
389                    let name_node = child.child_by_field_name("name");
390                    let alias_node = child.child_by_field_name("alias");
391
392                    if let Some(name_n) = name_node {
393                        let imported_name = Self::node_text(&name_n, source).to_string();
394                        let alias = alias_node
395                            .map(|a| Self::node_text(&a, source).to_string());
396                        imports.push(Import {
397                            module_path: module_path.clone(),
398                            imported_name,
399                            alias,
400                            is_external,
401                        });
402                    }
403                }
404                "wildcard_import" => {
405                    imports.push(Import {
406                        module_path: module_path.clone(),
407                        imported_name: "*".to_string(),
408                        alias: None,
409                        is_external,
410                    });
411                }
412                _ => {}
413            }
414        }
415
416        // If we found no imports from the structured children above, parse
417        // the imported names from the node text. The tree-sitter-python grammar
418        // places imported names as direct children of import_from_statement.
419        if imports.is_empty() {
420            Self::extract_from_imported_names(node, source, &module_path, is_external, &mut imports);
421        }
422
423        imports
424    }
425
426    /// Extract the module path from a `from ... import` statement.
427    /// Handles both absolute (`from os.path`) and relative (`from .local`) imports.
428    fn extract_from_module_path(node: &Node, source: &[u8]) -> String {
429        // The module_name field contains the dotted name (may include leading dots for relative).
430        if let Some(module_node) = node.child_by_field_name("module_name") {
431            return Self::node_text(&module_node, source).to_string();
432        }
433
434        // Fallback: reconstruct from the node text between `from` and `import`.
435        let text = Self::node_text(node, source);
436        if let Some(from_idx) = text.find("from") {
437            let after_from = &text[from_idx + 4..];
438            if let Some(import_idx) = after_from.find("import") {
439                let module = after_from[..import_idx].trim();
440                return module.to_string();
441            }
442        }
443
444        String::new()
445    }
446
447    /// Extract imported names from a from-import statement by walking its children.
448    fn extract_from_imported_names(
449        node: &Node,
450        source: &[u8],
451        module_path: &str,
452        is_external: bool,
453        imports: &mut Vec<Import>,
454    ) {
455        // Walk through all children looking for imported names.
456        // In tree-sitter-python, after the module_name and "import" keyword,
457        // the imported identifiers appear as children.
458        let mut found_import_keyword = false;
459        let mut cursor = node.walk();
460
461        for child in node.children(&mut cursor) {
462            let text = Self::node_text(&child, source);
463
464            if text == "import" {
465                found_import_keyword = true;
466                continue;
467            }
468
469            if !found_import_keyword {
470                continue;
471            }
472
473            match child.kind() {
474                "dotted_name" | "identifier" => {
475                    let imported_name = text.to_string();
476                    if !imported_name.is_empty() && imported_name != "," {
477                        imports.push(Import {
478                            module_path: module_path.to_string(),
479                            imported_name,
480                            alias: None,
481                            is_external,
482                        });
483                    }
484                }
485                "aliased_import" => {
486                    let name_node = child.child_by_field_name("name");
487                    let alias_node = child.child_by_field_name("alias");
488
489                    if let Some(name_n) = name_node {
490                        let imported_name = Self::node_text(&name_n, source).to_string();
491                        let alias = alias_node
492                            .map(|a| Self::node_text(&a, source).to_string());
493                        imports.push(Import {
494                            module_path: module_path.to_string(),
495                            imported_name,
496                            alias,
497                            is_external,
498                        });
499                    }
500                }
501                "wildcard_import" => {
502                    imports.push(Import {
503                        module_path: module_path.to_string(),
504                        imported_name: "*".to_string(),
505                        alias: None,
506                        is_external,
507                    });
508                }
509                _ => {}
510            }
511        }
512    }
513}
514
515impl Default for PythonParser {
516    fn default() -> Self {
517        Self::new()
518    }
519}
520
521impl LanguageParser for PythonParser {
522    fn extensions(&self) -> &[&str] {
523        &["py"]
524    }
525
526    fn extract_symbols(&self, source: &[u8], file_path: &Path) -> Result<Vec<Symbol>> {
527        if source.is_empty() {
528            return Ok(vec![]);
529        }
530
531        let tree = Self::parse_tree(source)?;
532        let root = tree.root_node();
533        let mut symbols = Vec::new();
534        let mut cursor = root.walk();
535
536        for node in root.children(&mut cursor) {
537            match node.kind() {
538                "function_definition" | "class_definition" => {
539                    if let Some(sym) = Self::extract_symbol_from_def(&node, source, file_path) {
540                        symbols.push(sym);
541                    }
542                }
543                "decorated_definition" => {
544                    // Unwrap the decorated_definition to find the inner function or class
545                    if let Some(definition) = node.child_by_field_name("definition") {
546                        match definition.kind() {
547                            "function_definition" | "class_definition" => {
548                                if let Some(mut sym) =
549                                    Self::extract_symbol_from_def(&definition, source, file_path)
550                                {
551                                    // Use the span of the whole decorated definition
552                                    sym.span = Span {
553                                        start_byte: node.start_byte() as u32,
554                                        end_byte: node.end_byte() as u32,
555                                    };
556                                    // Include the decorator in the signature
557                                    sym.signature = Self::node_signature(&node, source);
558                                    symbols.push(sym);
559                                }
560                            }
561                            _ => {}
562                        }
563                    }
564                }
565                "expression_statement" => {
566                    // Module-level assignment
567                    if let Some(name) = Self::extract_assignment_name(&node, source) {
568                        let visibility = Self::name_visibility(&name);
569                        symbols.push(Symbol {
570                            id: Uuid::new_v4(),
571                            name: name.clone(),
572                            qualified_name: name,
573                            kind: SymbolKind::Variable,
574                            visibility,
575                            file_path: file_path.to_path_buf(),
576                            span: Span {
577                                start_byte: node.start_byte() as u32,
578                                end_byte: node.end_byte() as u32,
579                            },
580                            signature: Self::node_signature(&node, source),
581                            doc_comment: Self::doc_comments(&node, source),
582                            parent: None,
583                            last_modified_by: None,
584                            last_modified_intent: None,
585                        });
586                    }
587                }
588                _ => {}
589            }
590        }
591
592        Ok(symbols)
593    }
594
595    fn extract_calls(&self, source: &[u8], _file_path: &Path) -> Result<Vec<RawCallEdge>> {
596        if source.is_empty() {
597            return Ok(vec![]);
598        }
599
600        let tree = Self::parse_tree(source)?;
601        let root = tree.root_node();
602        let mut calls = Vec::new();
603        let mut cursor = root.walk();
604
605        Self::walk_calls(&mut cursor, source, &mut calls);
606
607        Ok(calls)
608    }
609
610    fn extract_types(&self, _source: &[u8], _file_path: &Path) -> Result<Vec<TypeInfo>> {
611        // Stub: will be enhanced later
612        Ok(vec![])
613    }
614
615    fn extract_imports(&self, source: &[u8], _file_path: &Path) -> Result<Vec<Import>> {
616        if source.is_empty() {
617            return Ok(vec![]);
618        }
619
620        let tree = Self::parse_tree(source)?;
621        let root = tree.root_node();
622        let mut imports = Vec::new();
623        let mut cursor = root.walk();
624
625        for node in root.children(&mut cursor) {
626            match node.kind() {
627                "import_statement" => {
628                    imports.extend(Self::extract_import_statement(&node, source));
629                }
630                "import_from_statement" => {
631                    imports.extend(Self::extract_import_from_statement(&node, source));
632                }
633                _ => {}
634            }
635        }
636
637        Ok(imports)
638    }
639}