// llmcc_python/collect.rs

1use std::mem;
2use std::path::{Path, PathBuf};
3
4use llmcc_core::context::CompileUnit;
5use llmcc_core::ir::{HirIdent, HirNode};
6use llmcc_core::symbol::{Scope, ScopeStack, Symbol, SymbolKind};
7
8use crate::descriptor::class::PythonClassDescriptor;
9use crate::descriptor::function::PythonFunctionDescriptor;
10use crate::descriptor::import::ImportDescriptor;
11use crate::descriptor::variable::VariableDescriptor;
12use crate::token::AstVisitorPython;
13use crate::token::LangPython;
14
15#[derive(Debug)]
16pub struct CollectionResult {
17    pub functions: Vec<PythonFunctionDescriptor>,
18    pub classes: Vec<PythonClassDescriptor>,
19    pub variables: Vec<VariableDescriptor>,
20    pub imports: Vec<ImportDescriptor>,
21}
22
23#[derive(Debug)]
24struct DeclCollector<'tcx> {
25    unit: CompileUnit<'tcx>,
26    scopes: ScopeStack<'tcx>,
27    functions: Vec<PythonFunctionDescriptor>,
28    classes: Vec<PythonClassDescriptor>,
29    variables: Vec<VariableDescriptor>,
30    imports: Vec<ImportDescriptor>,
31}
32
33impl<'tcx> DeclCollector<'tcx> {
34    pub fn new(unit: CompileUnit<'tcx>, globals: &'tcx Scope<'tcx>) -> Self {
35        let mut scopes = ScopeStack::new(&unit.cc.arena, &unit.cc.interner, &unit.cc.symbol_map);
36        scopes.push_with_symbol(globals, None);
37        Self {
38            unit,
39            scopes,
40            functions: Vec::new(),
41            classes: Vec::new(),
42            variables: Vec::new(),
43            imports: Vec::new(),
44        }
45    }
46
47    fn parent_symbol(&self) -> Option<&'tcx Symbol> {
48        self.scopes.scoped_symbol()
49    }
50
51    fn scoped_fqn(&self, _node: &HirNode<'tcx>, name: &str) -> String {
52        if let Some(parent) = self.parent_symbol() {
53            let parent_fqn = parent.fqn_name.borrow();
54            if parent_fqn.is_empty() {
55                name.to_string()
56            } else {
57                format!("{}::{}", parent_fqn.as_str(), name)
58            }
59        } else {
60            name.to_string()
61        }
62    }
63
64    fn take_functions(&mut self) -> Vec<PythonFunctionDescriptor> {
65        mem::take(&mut self.functions)
66    }
67
68    fn take_classes(&mut self) -> Vec<PythonClassDescriptor> {
69        mem::take(&mut self.classes)
70    }
71
72    fn take_variables(&mut self) -> Vec<VariableDescriptor> {
73        mem::take(&mut self.variables)
74    }
75
76    fn take_imports(&mut self) -> Vec<ImportDescriptor> {
77        mem::take(&mut self.imports)
78    }
79
80    fn create_new_symbol(
81        &mut self,
82        node: &HirNode<'tcx>,
83        field_id: u16,
84        global: bool,
85        kind: SymbolKind,
86    ) -> Option<(&'tcx Symbol, &'tcx HirIdent<'tcx>, String)> {
87        let ident_node = node.opt_child_by_field(self.unit, field_id)?;
88        let ident = ident_node.as_ident()?;
89        let fqn = self.scoped_fqn(node, &ident.name);
90        let owner = node.hir_id();
91
92        let symbol = match self.scopes.find_symbol_local(&ident.name) {
93            Some(existing) if existing.kind() != SymbolKind::Unknown && existing.kind() != kind => {
94                self.insert_into_scope(owner, ident, global, &fqn, kind)
95            }
96            Some(existing) => existing,
97            None => self.insert_into_scope(owner, ident, global, &fqn, kind),
98        };
99
100        Some((symbol, ident, fqn))
101    }
102
103    fn insert_into_scope(
104        &mut self,
105        owner: llmcc_core::ir::HirId,
106        ident: &'tcx HirIdent<'tcx>,
107        global: bool,
108        fqn: &str,
109        kind: SymbolKind,
110    ) -> &'tcx Symbol {
111        let interner = self.unit.interner();
112        let unit_index = self.unit.index;
113
114        self.scopes.insert_with(owner, ident, global, |symbol| {
115            symbol.set_owner(owner);
116            symbol.set_fqn(fqn.to_string(), interner);
117            symbol.set_kind(kind);
118            symbol.set_unit_index(unit_index);
119        })
120    }
121
122    fn visit_children_scope(&mut self, node: &HirNode<'tcx>, symbol: Option<&'tcx Symbol>) {
123        let depth = self.scopes.depth();
124        // Allocate scope for this node
125        let scope = self.unit.alloc_scope(node.hir_id());
126        self.scopes.push_with_symbol(scope, symbol);
127        self.visit_children(node);
128        self.scopes.pop_until(depth);
129    }
130
131    fn visit_children(&mut self, node: &HirNode<'tcx>) {
132        for id in node.children() {
133            let child = self.unit.hir_node(*id);
134            self.visit_node(child);
135        }
136    }
137
138    fn module_segments_from_path(path: &Path) -> Vec<String> {
139        if path.extension().and_then(|ext| ext.to_str()) != Some("py") {
140            return Vec::new();
141        }
142
143        let mut segments: Vec<String> = Vec::new();
144
145        if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
146            if stem != "__init__" && !stem.is_empty() {
147                segments.push(stem.to_string());
148            }
149        }
150
151        let mut current = path.parent();
152        while let Some(dir) = current {
153            let dir_name = match dir.file_name().and_then(|n| n.to_str()) {
154                Some(name) if !name.is_empty() => name.to_string(),
155                _ => break,
156            };
157
158            let has_init = dir.join("__init__.py").exists() || dir.join("__init__.pyi").exists();
159            if has_init {
160                segments.push(dir_name);
161                current = dir.parent();
162                continue;
163            }
164
165            if segments.is_empty() {
166                segments.push(dir_name);
167            }
168            break;
169        }
170
171        segments.reverse();
172        segments
173    }
174
175    fn ensure_module_symbol(&mut self, node: &HirNode<'tcx>) -> Option<&'tcx Symbol> {
176        let scope = self.unit.alloc_scope(node.hir_id());
177        if let Some(symbol) = scope.symbol() {
178            return Some(symbol);
179        }
180
181        let raw_path = self.unit.file_path().or_else(|| self.unit.file().path());
182        let path = raw_path
183            .map(PathBuf::from)
184            .and_then(|p| p.canonicalize().ok().or(Some(p)))
185            .unwrap_or_else(|| PathBuf::from("__module__"));
186
187        let segments = Self::module_segments_from_path(&path);
188        let interner = self.unit.interner();
189
190        let (name, fqn) = if segments.is_empty() {
191            let fallback = path
192                .file_stem()
193                .and_then(|s| s.to_str())
194                .unwrap_or("__module__")
195                .to_string();
196            (fallback.clone(), fallback)
197        } else {
198            let name = segments
199                .last()
200                .cloned()
201                .unwrap_or_else(|| "__module__".to_string());
202            let fqn = segments.join("::");
203            (name, fqn)
204        };
205
206        let key = interner.intern(&name);
207        let symbol = Symbol::new(node.hir_id(), name.clone(), key);
208        let symbol = self.unit.cc.arena.alloc(symbol);
209        symbol.set_kind(SymbolKind::Module);
210        symbol.set_unit_index(self.unit.index);
211        symbol.set_fqn(fqn, interner);
212
213        self.unit
214            .cc
215            .symbol_map
216            .borrow_mut()
217            .insert(symbol.id, symbol);
218
219        let _ = self.scopes.insert_symbol(symbol, true);
220        scope.set_symbol(Some(symbol));
221        Some(symbol)
222    }
223
224    fn extract_base_classes(
225        &mut self,
226        arg_list_node: &HirNode<'tcx>,
227        class: &mut PythonClassDescriptor,
228    ) {
229        for child_id in arg_list_node.children() {
230            let child = self.unit.hir_node(*child_id);
231            if child.kind_id() == LangPython::identifier {
232                if let Some(ident) = child.as_ident() {
233                    class.add_base_class(ident.name.clone());
234                }
235            }
236        }
237    }
238
239    fn extract_class_members(
240        &mut self,
241        body_node: &HirNode<'tcx>,
242        class: &mut PythonClassDescriptor,
243    ) {
244        for child_id in body_node.children() {
245            let child = self.unit.hir_node(*child_id);
246            let kind_id = child.kind_id();
247
248            if kind_id == LangPython::function_definition {
249                if let Some(name_node) = child.opt_child_by_field(self.unit, LangPython::field_name)
250                {
251                    if let Some(ident) = name_node.as_ident() {
252                        class.add_method(ident.name.clone());
253                    }
254                }
255                self.extract_instance_fields_from_method(&child, class);
256            } else if kind_id == LangPython::decorated_definition {
257                if let Some(method_name) = self.extract_decorated_method_name(&child) {
258                    class.add_method(method_name);
259                }
260                if let Some(method_node) = self.method_node_from_decorated(&child) {
261                    self.extract_instance_fields_from_method(&method_node, class);
262                }
263            } else if kind_id == LangPython::assignment {
264                if let Some(field) = self.extract_class_field(&child) {
265                    self.upsert_class_field(class, field);
266                }
267            } else if kind_id == LangPython::expression_statement {
268                for stmt_child_id in child.children() {
269                    let stmt_child = self.unit.hir_node(*stmt_child_id);
270                    if stmt_child.kind_id() == LangPython::assignment {
271                        if let Some(field) = self.extract_class_field(&stmt_child) {
272                            self.upsert_class_field(class, field);
273                        }
274                    }
275                }
276            }
277        }
278    }
279
280    fn extract_decorated_method_name(&self, node: &HirNode<'tcx>) -> Option<String> {
281        for child_id in node.children() {
282            let child = self.unit.hir_node(*child_id);
283            if child.kind_id() == LangPython::function_definition {
284                if let Some(name_node) = child.opt_child_by_field(self.unit, LangPython::field_name)
285                {
286                    if let Some(ident) = name_node.as_ident() {
287                        return Some(ident.name.clone());
288                    }
289                }
290            }
291        }
292        None
293    }
294
295    fn method_node_from_decorated(&self, node: &HirNode<'tcx>) -> Option<HirNode<'tcx>> {
296        for child_id in node.children() {
297            let child = self.unit.hir_node(*child_id);
298            if child.kind_id() == LangPython::function_definition {
299                return Some(child);
300            }
301        }
302        None
303    }
304
305    fn extract_class_field(
306        &self,
307        node: &HirNode<'tcx>,
308    ) -> Option<crate::descriptor::class::ClassField> {
309        let left_node = node.opt_child_by_field(self.unit, LangPython::field_left)?;
310        let ident = left_node.as_ident()?;
311
312        let mut field = crate::descriptor::class::ClassField::new(ident.name.clone());
313
314        let type_hint = node
315            .opt_child_by_field(self.unit, LangPython::field_type)
316            .and_then(|type_node| {
317                let text = self.unit.get_text(
318                    type_node.inner_ts_node().start_byte(),
319                    type_node.inner_ts_node().end_byte(),
320                );
321                let trimmed = text.trim();
322                if trimmed.is_empty() {
323                    None
324                } else {
325                    Some(trimmed.to_string())
326                }
327            })
328            .or_else(|| {
329                for child_id in node.children() {
330                    let child = self.unit.hir_node(*child_id);
331                    if child.kind_id() == LangPython::type_node {
332                        let text = self.unit.get_text(
333                            child.inner_ts_node().start_byte(),
334                            child.inner_ts_node().end_byte(),
335                        );
336                        let trimmed = text.trim();
337                        if !trimmed.is_empty() {
338                            return Some(trimmed.to_string());
339                        }
340                    }
341                }
342                None
343            });
344
345        if let Some(type_hint) = type_hint {
346            field = field.with_type_hint(type_hint);
347        }
348
349        Some(field)
350    }
351
352    fn upsert_class_field(
353        &self,
354        class: &mut PythonClassDescriptor,
355        field: crate::descriptor::class::ClassField,
356    ) {
357        if let Some(existing) = class.fields.iter_mut().find(|f| f.name == field.name) {
358            if existing.type_hint.is_none() && field.type_hint.is_some() {
359                existing.type_hint = field.type_hint;
360            }
361        } else {
362            class.add_field(field);
363        }
364    }
365
366    fn extract_instance_fields_from_method(
367        &mut self,
368        method_node: &HirNode<'tcx>,
369        class: &mut PythonClassDescriptor,
370    ) {
371        self.collect_instance_fields_recursive(method_node, class);
372    }
373
374    fn collect_instance_fields_recursive(
375        &mut self,
376        node: &HirNode<'tcx>,
377        class: &mut PythonClassDescriptor,
378    ) {
379        if node.kind_id() == LangPython::assignment {
380            self.extract_instance_field_from_assignment(node, class);
381        }
382
383        for child_id in node.children() {
384            let child = self.unit.hir_node(*child_id);
385            self.collect_instance_fields_recursive(&child, class);
386        }
387    }
388
389    fn extract_instance_field_from_assignment(
390        &mut self,
391        node: &HirNode<'tcx>,
392        class: &mut PythonClassDescriptor,
393    ) {
394        let left_node = match node.opt_child_by_field(self.unit, LangPython::field_left) {
395            Some(node) => node,
396            None => return,
397        };
398
399        if left_node.kind_id() != LangPython::attribute {
400            return;
401        }
402
403        let mut identifier_names = Vec::new();
404        for child_id in left_node.children() {
405            let child = self.unit.hir_node(*child_id);
406            if child.kind_id() == LangPython::identifier {
407                if let Some(ident) = child.as_ident() {
408                    identifier_names.push(ident.name.clone());
409                }
410            }
411        }
412
413        if identifier_names.first().map(String::as_str) != Some("self") {
414            return;
415        }
416
417        let field_name = match identifier_names.last() {
418            Some(name) if name != "self" => name.clone(),
419            _ => return,
420        };
421
422        let field = crate::descriptor::class::ClassField::new(field_name);
423        self.upsert_class_field(class, field);
424    }
425}
426
427impl<'tcx> AstVisitorPython<'tcx> for DeclCollector<'tcx> {
428    fn unit(&self) -> CompileUnit<'tcx> {
429        self.unit
430    }
431
432    fn visit_source_file(&mut self, node: HirNode<'tcx>) {
433        let module_symbol = self.ensure_module_symbol(&node);
434        self.visit_children_scope(&node, module_symbol);
435    }
436
437    fn visit_function_definition(&mut self, node: HirNode<'tcx>) {
438        if let Some((symbol, ident, _fqn)) =
439            self.create_new_symbol(&node, LangPython::field_name, true, SymbolKind::Function)
440        {
441            let mut func = PythonFunctionDescriptor::new(ident.name.clone());
442
443            // Extract parameters and return type using AST walking methods
444            for child_id in node.children() {
445                let child = self.unit.hir_node(*child_id);
446                let kind_id = child.kind_id();
447
448                if kind_id == LangPython::parameters {
449                    func.extract_parameters_from_ast(&child, self.unit);
450                }
451            }
452
453            // Extract return type by walking the AST
454            func.extract_return_type_from_ast(&node, self.unit);
455
456            self.functions.push(func);
457            self.visit_children_scope(&node, Some(symbol));
458        }
459    }
460
461    fn visit_class_definition(&mut self, node: HirNode<'tcx>) {
462        if let Some((symbol, ident, _fqn)) =
463            self.create_new_symbol(&node, LangPython::field_name, true, SymbolKind::Struct)
464        {
465            let mut class = PythonClassDescriptor::new(ident.name.clone());
466
467            // Look for base classes and body
468            for child_id in node.children() {
469                let child = self.unit.hir_node(*child_id);
470                let kind_id = child.kind_id();
471
472                if kind_id == LangPython::argument_list {
473                    // These are base classes
474                    self.extract_base_classes(&child, &mut class);
475                } else if kind_id == LangPython::block {
476                    // This is the class body
477                    self.extract_class_members(&child, &mut class);
478                }
479            }
480
481            self.classes.push(class);
482            self.visit_children_scope(&node, Some(symbol));
483        }
484    }
485
486    fn visit_decorated_definition(&mut self, node: HirNode<'tcx>) {
487        // decorated_definition contains decorators followed by the actual definition (function or class)
488        let mut decorators = Vec::new();
489
490        for child_id in node.children() {
491            let child = self.unit.hir_node(*child_id);
492            let kind_id = child.kind_id();
493
494            if kind_id == LangPython::decorator {
495                // Extract decorator name
496                // A decorator is usually just an identifier or a call expression
497                // For now, extract the text of the decorator
498                let decorator_text = self.unit.get_text(
499                    child.inner_ts_node().start_byte(),
500                    child.inner_ts_node().end_byte(),
501                );
502                if !decorator_text.is_empty() {
503                    decorators.push(decorator_text.trim_start_matches('@').trim().to_string());
504                }
505            }
506        }
507
508        // Visit the decorated definition and apply decorators to the last collected function/class
509        self.visit_children(&node);
510
511        // Apply decorators to the last function or class that was added
512        if !decorators.is_empty() {
513            if let Some(last_func) = self.functions.last_mut() {
514                last_func.decorators = decorators.clone();
515            }
516        }
517    }
518
519    fn visit_import_statement(&mut self, node: HirNode<'tcx>) {
520        // Handle: import os, sys, etc.
521        let mut cursor = node.inner_ts_node().walk();
522
523        for child in node.inner_ts_node().children(&mut cursor) {
524            if child.kind() == "dotted_name" || child.kind() == "identifier" {
525                let text = self.unit.get_text(child.start_byte(), child.end_byte());
526                let _import =
527                    ImportDescriptor::new(text, crate::descriptor::import::ImportKind::Simple);
528                self.imports.push(_import);
529            }
530        }
531    }
532
533    fn visit_import_from(&mut self, _node: HirNode<'tcx>) {
534        // Handle: from x import y
535        // This is more complex - we need to parse module and names
536        // For now, simple implementation
537    }
538
539    fn visit_assignment(&mut self, node: HirNode<'tcx>) {
540        // Handle: x = value
541        // In tree-sitter, the "left" side of assignment is the target
542        if let Some((_symbol, ident, _)) =
543            self.create_new_symbol(&node, LangPython::field_left, false, SymbolKind::Variable)
544        {
545            use crate::descriptor::variable::VariableScope;
546            let var = VariableDescriptor::new(ident.name.clone(), VariableScope::FunctionLocal);
547            self.variables.push(var);
548        }
549    }
550
551    fn visit_unknown(&mut self, node: HirNode<'tcx>) {
552        self.visit_children(&node);
553    }
554}
555
556pub fn collect_symbols<'tcx>(
557    unit: CompileUnit<'tcx>,
558    globals: &'tcx Scope<'tcx>,
559) -> CollectionResult {
560    let root = unit.file_start_hir_id().unwrap();
561    let node = unit.hir_node(root);
562    let mut collector = DeclCollector::new(unit, globals);
563    collector.visit_node(node);
564
565    CollectionResult {
566        functions: collector.take_functions(),
567        classes: collector.take_classes(),
568        variables: collector.take_variables(),
569        imports: collector.take_imports(),
570    }
571}