Skip to main content

dk_engine/parser/
python_parser.rs

1use super::LanguageParser;
2use dk_core::{CallKind, Import, RawCallEdge, Result, Span, Symbol, SymbolKind, TypeInfo, Visibility};
3use std::path::Path;
4use tree_sitter::{Node, Parser, TreeCursor};
5use uuid::Uuid;
6
/// Tree-sitter-backed parser for Python source files.
///
/// Implements [`LanguageParser`]: it extracts top-level symbols, call edges
/// and imports. Type extraction is currently a stub that yields nothing.
pub struct PythonParser;
12
13impl PythonParser {
14    pub fn new() -> Self {
15        Self
16    }
17
18    /// Create a configured tree-sitter parser for Python.
19    fn create_parser() -> Result<Parser> {
20        let mut parser = Parser::new();
21        parser
22            .set_language(&tree_sitter_python::LANGUAGE.into())
23            .map_err(|e| dk_core::Error::ParseError(format!("Failed to load Python grammar: {e}")))?;
24        Ok(parser)
25    }
26
27    /// Parse source bytes into a tree-sitter tree.
28    fn parse_tree(source: &[u8]) -> Result<tree_sitter::Tree> {
29        let mut parser = Self::create_parser()?;
30        parser
31            .parse(source, None)
32            .ok_or_else(|| dk_core::Error::ParseError("tree-sitter parse returned None".into()))
33    }
34
35    /// Get the text of a node as a UTF-8 string.
36    fn node_text<'a>(node: &Node, source: &'a [u8]) -> &'a str {
37        let text = &source[node.start_byte()..node.end_byte()];
38        std::str::from_utf8(text).unwrap_or("")
39    }
40
41    /// Determine visibility based on Python naming conventions.
42    /// Names starting with `_` are considered private; everything else is public.
43    fn name_visibility(name: &str) -> Visibility {
44        if name.starts_with('_') {
45            Visibility::Private
46        } else {
47            Visibility::Public
48        }
49    }
50
51    /// Extract the name from a function_definition or class_definition node.
52    fn node_name(node: &Node, source: &[u8]) -> Option<String> {
53        node.child_by_field_name("name")
54            .map(|n| Self::node_text(&n, source).to_string())
55    }
56
57    /// Extract the first line of the node's source text as the signature.
58    fn node_signature(node: &Node, source: &[u8]) -> Option<String> {
59        let text_str = Self::node_text(node, source);
60        let first_line = text_str.lines().next()?;
61        Some(first_line.trim().to_string())
62    }
63
64    /// Extract docstring from a function or class body.
65    ///
66    /// In Python, a docstring is the first statement in the body if it is an
67    /// `expression_statement` containing a `string` node.
68    fn extract_docstring(node: &Node, source: &[u8]) -> Option<String> {
69        // Look for the "body" field (block node)
70        let body = node.child_by_field_name("body")?;
71
72        // The first child of the block should be the potential docstring
73        let first_stmt = body.child(0)?;
74
75        if first_stmt.kind() == "expression_statement" {
76            let expr = first_stmt.child(0)?;
77            if expr.kind() == "string" {
78                let raw = Self::node_text(&expr, source);
79                // Strip triple-quote delimiters and clean up
80                let content = raw
81                    .strip_prefix("\"\"\"")
82                    .and_then(|s| s.strip_suffix("\"\"\""))
83                    .or_else(|| {
84                        raw.strip_prefix("'''")
85                            .and_then(|s| s.strip_suffix("'''"))
86                    })
87                    .unwrap_or(raw);
88                let trimmed = content.trim().to_string();
89                if !trimmed.is_empty() {
90                    return Some(trimmed);
91                }
92            }
93        }
94
95        None
96    }
97
98    /// Collect preceding `#` comments for a node.
99    ///
100    /// Preserves the `#` prefix so that AST merge can reconstruct valid Python.
101    fn doc_comments(node: &Node, source: &[u8]) -> Option<String> {
102        let mut comments = Vec::new();
103        let mut sibling = node.prev_sibling();
104
105        while let Some(prev) = sibling {
106            if prev.kind() == "comment" {
107                // Preserve the full comment text including `#` prefix
108                let text = Self::node_text(&prev, source).trim().to_string();
109                comments.push(text);
110                sibling = prev.prev_sibling();
111                continue;
112            }
113            break;
114        }
115
116        if comments.is_empty() {
117            None
118        } else {
119            comments.reverse();
120            Some(comments.join("\n"))
121        }
122    }
123
124    /// Extract a symbol from a function_definition or class_definition node.
125    fn extract_symbol_from_def(
126        node: &Node,
127        source: &[u8],
128        file_path: &Path,
129    ) -> Option<Symbol> {
130        let kind = match node.kind() {
131            "function_definition" => SymbolKind::Function,
132            "class_definition" => SymbolKind::Class,
133            _ => return None,
134        };
135
136        let name = Self::node_name(node, source)?;
137        if name.is_empty() {
138            return None;
139        }
140
141        let visibility = Self::name_visibility(&name);
142        let signature = Self::node_signature(node, source);
143
144        // Try docstring first, fall back to preceding comments
145        let doc_comment = Self::extract_docstring(node, source)
146            .or_else(|| Self::doc_comments(node, source));
147
148        Some(Symbol {
149            id: Uuid::new_v4(),
150            name: name.clone(),
151            qualified_name: name,
152            kind,
153            visibility,
154            file_path: file_path.to_path_buf(),
155            span: Span {
156                start_byte: node.start_byte() as u32,
157                end_byte: node.end_byte() as u32,
158            },
159            signature,
160            doc_comment,
161            parent: None,
162            last_modified_by: None,
163            last_modified_intent: None,
164        })
165    }
166
167    /// Extract the name from a simple assignment at the top level.
168    /// e.g. `MAX_RETRIES = 3` yields "MAX_RETRIES".
169    /// Only handles simple identifier = value assignments (not tuple unpacking, etc.).
170    fn extract_assignment_name(node: &Node, source: &[u8]) -> Option<String> {
171        if node.kind() != "expression_statement" {
172            return None;
173        }
174
175        // The expression_statement should contain an assignment
176        let child = node.child(0)?;
177        if child.kind() != "assignment" {
178            return None;
179        }
180
181        // The left side should be a simple identifier
182        let left = child.child_by_field_name("left")?;
183        if left.kind() != "identifier" {
184            return None;
185        }
186
187        let name = Self::node_text(&left, source).to_string();
188        if name.is_empty() {
189            None
190        } else {
191            Some(name)
192        }
193    }
194
195    /// Find the name of the enclosing function for a given node, if any.
196    fn enclosing_function_name(node: &Node, source: &[u8]) -> String {
197        let mut current = node.parent();
198        while let Some(parent) = current {
199            if parent.kind() == "function_definition" {
200                if let Some(name_node) = parent.child_by_field_name("name") {
201                    let name = Self::node_text(&name_node, source);
202                    if !name.is_empty() {
203                        return name.to_string();
204                    }
205                }
206            }
207            current = parent.parent();
208        }
209        "<module>".to_string()
210    }
211
212    /// Extract the callee name and call kind from a call node's function field.
213    fn extract_callee_info(node: &Node, source: &[u8]) -> (String, CallKind) {
214        match node.kind() {
215            "attribute" => {
216                // e.g. obj.method — the callee is the attribute (method name)
217                if let Some(attr) = node.child_by_field_name("attribute") {
218                    let name = Self::node_text(&attr, source).to_string();
219                    return (name, CallKind::MethodCall);
220                }
221                let text = Self::node_text(node, source).to_string();
222                (text, CallKind::MethodCall)
223            }
224            "identifier" => {
225                let name = Self::node_text(node, source).to_string();
226                (name, CallKind::DirectCall)
227            }
228            _ => {
229                let text = Self::node_text(node, source).to_string();
230                (text, CallKind::DirectCall)
231            }
232        }
233    }
234
235    /// Recursively walk the tree to extract call edges.
236    fn walk_calls(cursor: &mut TreeCursor, source: &[u8], calls: &mut Vec<RawCallEdge>) {
237        let node = cursor.node();
238
239        match node.kind() {
240            "call" => {
241                // Python call node has a "function" field
242                if let Some(func_node) = node.child_by_field_name("function") {
243                    let (callee, kind) = Self::extract_callee_info(&func_node, source);
244                    if !callee.is_empty() {
245                        let caller = Self::enclosing_function_name(&node, source);
246                        calls.push(RawCallEdge {
247                            caller_name: caller,
248                            callee_name: callee,
249                            call_site: Span {
250                                start_byte: node.start_byte() as u32,
251                                end_byte: node.end_byte() as u32,
252                            },
253                            kind,
254                        });
255                    }
256                }
257            }
258            "decorator" => {
259                // A decorator is effectively a call to the decorator function.
260                // The decorator node contains the decorator expression (after @).
261                // It can be a simple identifier like `@login_required`,
262                // a call like `@app.route("/api")`, or an attribute like `@app.middleware`.
263                //
264                // For `@login_required`, the child is an identifier.
265                // For `@app.route("/api")`, the child is a call node (which walk_calls handles).
266                // For `@app.middleware`, the child is an attribute.
267                //
268                // We handle the identifier and attribute cases here; the call case
269                // is handled recursively when we descend into children.
270                let mut inner_cursor = node.walk();
271                for child in node.children(&mut inner_cursor) {
272                    match child.kind() {
273                        "identifier" => {
274                            let name = Self::node_text(&child, source).to_string();
275                            if !name.is_empty() {
276                                let caller = Self::enclosing_function_name(&node, source);
277                                calls.push(RawCallEdge {
278                                    caller_name: caller,
279                                    callee_name: name,
280                                    call_site: Span {
281                                        start_byte: node.start_byte() as u32,
282                                        end_byte: node.end_byte() as u32,
283                                    },
284                                    kind: CallKind::DirectCall,
285                                });
286                            }
287                        }
288                        "attribute" => {
289                            if let Some(attr) = child.child_by_field_name("attribute") {
290                                let name = Self::node_text(&attr, source).to_string();
291                                if !name.is_empty() {
292                                    let caller = Self::enclosing_function_name(&node, source);
293                                    calls.push(RawCallEdge {
294                                        caller_name: caller,
295                                        callee_name: name,
296                                        call_site: Span {
297                                            start_byte: node.start_byte() as u32,
298                                            end_byte: node.end_byte() as u32,
299                                        },
300                                        kind: CallKind::MethodCall,
301                                    });
302                                }
303                            }
304                        }
305                        _ => {}
306                    }
307                }
308            }
309            _ => {}
310        }
311
312        // Recurse into children
313        if cursor.goto_first_child() {
314            loop {
315                Self::walk_calls(cursor, source, calls);
316                if !cursor.goto_next_sibling() {
317                    break;
318                }
319            }
320            cursor.goto_parent();
321        }
322    }
323
324    /// Extract imports from an `import_statement` node.
325    /// e.g. `import os` or `import os, sys`
326    fn extract_import_statement(node: &Node, source: &[u8]) -> Vec<Import> {
327        let mut imports = Vec::new();
328        let mut cursor = node.walk();
329
330        for child in node.children(&mut cursor) {
331            match child.kind() {
332                "dotted_name" => {
333                    let module = Self::node_text(&child, source).to_string();
334                    if !module.is_empty() {
335                        imports.push(Import {
336                            module_path: module.clone(),
337                            imported_name: module,
338                            alias: None,
339                            is_external: true,
340                        });
341                    }
342                }
343                "aliased_import" => {
344                    let name_node = child.child_by_field_name("name");
345                    let alias_node = child.child_by_field_name("alias");
346
347                    if let Some(name_n) = name_node {
348                        let module = Self::node_text(&name_n, source).to_string();
349                        let alias = alias_node
350                            .map(|a| Self::node_text(&a, source).to_string());
351                        imports.push(Import {
352                            module_path: module.clone(),
353                            imported_name: module,
354                            alias,
355                            is_external: true,
356                        });
357                    }
358                }
359                _ => {}
360            }
361        }
362
363        imports
364    }
365
366    /// Extract imports from an `import_from_statement` node.
367    /// e.g. `from os.path import join, exists` or `from .local import helper`
368    fn extract_import_from_statement(node: &Node, source: &[u8]) -> Vec<Import> {
369        let mut imports = Vec::new();
370
371        // Get the module name. In tree-sitter-python the module is in the
372        // "module_name" field. For relative imports it includes the dots.
373        let module_path = Self::extract_from_module_path(node, source);
374        let is_external = !module_path.starts_with('.');
375
376        // Collect imported names
377        let mut cursor = node.walk();
378        for child in node.children(&mut cursor) {
379            match child.kind() {
380                "dotted_name" | "identifier" => {
381                    // Skip the module name itself (already captured)
382                    // The imported names come after the "import" keyword
383                    // In tree-sitter-python, the imported names are in the node's
384                    // named children that are not the module_name field.
385                    // We need to distinguish module from imported names.
386                }
387                "aliased_import" => {
388                    let name_node = child.child_by_field_name("name");
389                    let alias_node = child.child_by_field_name("alias");
390
391                    if let Some(name_n) = name_node {
392                        let imported_name = Self::node_text(&name_n, source).to_string();
393                        let alias = alias_node
394                            .map(|a| Self::node_text(&a, source).to_string());
395                        imports.push(Import {
396                            module_path: module_path.clone(),
397                            imported_name,
398                            alias,
399                            is_external,
400                        });
401                    }
402                }
403                "wildcard_import" => {
404                    imports.push(Import {
405                        module_path: module_path.clone(),
406                        imported_name: "*".to_string(),
407                        alias: None,
408                        is_external,
409                    });
410                }
411                _ => {}
412            }
413        }
414
415        // If we found no imports from the structured children above, parse
416        // the imported names from the node text. The tree-sitter-python grammar
417        // places imported names as direct children of import_from_statement.
418        if imports.is_empty() {
419            Self::extract_from_imported_names(node, source, &module_path, is_external, &mut imports);
420        }
421
422        imports
423    }
424
425    /// Extract the module path from a `from ... import` statement.
426    /// Handles both absolute (`from os.path`) and relative (`from .local`) imports.
427    fn extract_from_module_path(node: &Node, source: &[u8]) -> String {
428        // The module_name field contains the dotted name (may include leading dots for relative).
429        if let Some(module_node) = node.child_by_field_name("module_name") {
430            return Self::node_text(&module_node, source).to_string();
431        }
432
433        // Fallback: reconstruct from the node text between `from` and `import`.
434        let text = Self::node_text(node, source);
435        if let Some(from_idx) = text.find("from") {
436            let after_from = &text[from_idx + 4..];
437            if let Some(import_idx) = after_from.find("import") {
438                let module = after_from[..import_idx].trim();
439                return module.to_string();
440            }
441        }
442
443        String::new()
444    }
445
446    /// Extract imported names from a from-import statement by walking its children.
447    fn extract_from_imported_names(
448        node: &Node,
449        source: &[u8],
450        module_path: &str,
451        is_external: bool,
452        imports: &mut Vec<Import>,
453    ) {
454        // Walk through all children looking for imported names.
455        // In tree-sitter-python, after the module_name and "import" keyword,
456        // the imported identifiers appear as children.
457        let mut found_import_keyword = false;
458        let mut cursor = node.walk();
459
460        for child in node.children(&mut cursor) {
461            let text = Self::node_text(&child, source);
462
463            if text == "import" {
464                found_import_keyword = true;
465                continue;
466            }
467
468            if !found_import_keyword {
469                continue;
470            }
471
472            match child.kind() {
473                "dotted_name" | "identifier" => {
474                    let imported_name = text.to_string();
475                    if !imported_name.is_empty() && imported_name != "," {
476                        imports.push(Import {
477                            module_path: module_path.to_string(),
478                            imported_name,
479                            alias: None,
480                            is_external,
481                        });
482                    }
483                }
484                "aliased_import" => {
485                    let name_node = child.child_by_field_name("name");
486                    let alias_node = child.child_by_field_name("alias");
487
488                    if let Some(name_n) = name_node {
489                        let imported_name = Self::node_text(&name_n, source).to_string();
490                        let alias = alias_node
491                            .map(|a| Self::node_text(&a, source).to_string());
492                        imports.push(Import {
493                            module_path: module_path.to_string(),
494                            imported_name,
495                            alias,
496                            is_external,
497                        });
498                    }
499                }
500                "wildcard_import" => {
501                    imports.push(Import {
502                        module_path: module_path.to_string(),
503                        imported_name: "*".to_string(),
504                        alias: None,
505                        is_external,
506                    });
507                }
508                _ => {}
509            }
510        }
511    }
512}
513
514impl Default for PythonParser {
515    fn default() -> Self {
516        Self::new()
517    }
518}
519
520impl LanguageParser for PythonParser {
521    fn extensions(&self) -> &[&str] {
522        &["py"]
523    }
524
525    fn extract_symbols(&self, source: &[u8], file_path: &Path) -> Result<Vec<Symbol>> {
526        if source.is_empty() {
527            return Ok(vec![]);
528        }
529
530        let tree = Self::parse_tree(source)?;
531        let root = tree.root_node();
532        let mut symbols = Vec::new();
533        let mut cursor = root.walk();
534
535        for node in root.children(&mut cursor) {
536            match node.kind() {
537                "function_definition" | "class_definition" => {
538                    if let Some(sym) = Self::extract_symbol_from_def(&node, source, file_path) {
539                        symbols.push(sym);
540                    }
541                }
542                "decorated_definition" => {
543                    // Unwrap the decorated_definition to find the inner function or class
544                    if let Some(definition) = node.child_by_field_name("definition") {
545                        match definition.kind() {
546                            "function_definition" | "class_definition" => {
547                                if let Some(mut sym) =
548                                    Self::extract_symbol_from_def(&definition, source, file_path)
549                                {
550                                    // Use the span of the whole decorated definition
551                                    sym.span = Span {
552                                        start_byte: node.start_byte() as u32,
553                                        end_byte: node.end_byte() as u32,
554                                    };
555                                    // Include the decorator in the signature
556                                    sym.signature = Self::node_signature(&node, source);
557                                    symbols.push(sym);
558                                }
559                            }
560                            _ => {}
561                        }
562                    }
563                }
564                "expression_statement" => {
565                    // Module-level assignment
566                    if let Some(name) = Self::extract_assignment_name(&node, source) {
567                        let visibility = Self::name_visibility(&name);
568                        symbols.push(Symbol {
569                            id: Uuid::new_v4(),
570                            name: name.clone(),
571                            qualified_name: name,
572                            kind: SymbolKind::Variable,
573                            visibility,
574                            file_path: file_path.to_path_buf(),
575                            span: Span {
576                                start_byte: node.start_byte() as u32,
577                                end_byte: node.end_byte() as u32,
578                            },
579                            signature: Self::node_signature(&node, source),
580                            doc_comment: Self::doc_comments(&node, source),
581                            parent: None,
582                            last_modified_by: None,
583                            last_modified_intent: None,
584                        });
585                    }
586                }
587                _ => {}
588            }
589        }
590
591        Ok(symbols)
592    }
593
594    fn extract_calls(&self, source: &[u8], _file_path: &Path) -> Result<Vec<RawCallEdge>> {
595        if source.is_empty() {
596            return Ok(vec![]);
597        }
598
599        let tree = Self::parse_tree(source)?;
600        let root = tree.root_node();
601        let mut calls = Vec::new();
602        let mut cursor = root.walk();
603
604        Self::walk_calls(&mut cursor, source, &mut calls);
605
606        Ok(calls)
607    }
608
609    fn extract_types(&self, _source: &[u8], _file_path: &Path) -> Result<Vec<TypeInfo>> {
610        // Stub: will be enhanced later
611        Ok(vec![])
612    }
613
614    fn extract_imports(&self, source: &[u8], _file_path: &Path) -> Result<Vec<Import>> {
615        if source.is_empty() {
616            return Ok(vec![]);
617        }
618
619        let tree = Self::parse_tree(source)?;
620        let root = tree.root_node();
621        let mut imports = Vec::new();
622        let mut cursor = root.walk();
623
624        for node in root.children(&mut cursor) {
625            match node.kind() {
626                "import_statement" => {
627                    imports.extend(Self::extract_import_statement(&node, source));
628                }
629                "import_from_statement" => {
630                    imports.extend(Self::extract_import_from_statement(&node, source));
631                }
632                _ => {}
633            }
634        }
635
636        Ok(imports)
637    }
638}