llmcc_python/
collect.rs

1use std::mem;
2
3use llmcc_core::context::CompileUnit;
4use llmcc_core::ir::{HirIdent, HirNode};
5use llmcc_core::symbol::{Scope, ScopeStack, Symbol, SymbolKind};
6
7use crate::descriptor::class::PythonClassDescriptor;
8use crate::descriptor::function::PythonFunctionDescriptor;
9use crate::descriptor::import::ImportDescriptor;
10use crate::descriptor::variable::VariableDescriptor;
11use crate::token::AstVisitorPython;
12use crate::token::LangPython;
13
14#[derive(Debug)]
15pub struct CollectionResult {
16    pub functions: Vec<PythonFunctionDescriptor>,
17    pub classes: Vec<PythonClassDescriptor>,
18    pub variables: Vec<VariableDescriptor>,
19    pub imports: Vec<ImportDescriptor>,
20}
21
22#[derive(Debug)]
23struct DeclCollector<'tcx> {
24    unit: CompileUnit<'tcx>,
25    scopes: ScopeStack<'tcx>,
26    functions: Vec<PythonFunctionDescriptor>,
27    classes: Vec<PythonClassDescriptor>,
28    variables: Vec<VariableDescriptor>,
29    imports: Vec<ImportDescriptor>,
30}
31
32impl<'tcx> DeclCollector<'tcx> {
33    pub fn new(unit: CompileUnit<'tcx>, globals: &'tcx Scope<'tcx>) -> Self {
34        let mut scopes = ScopeStack::new(&unit.cc.arena, &unit.cc.interner, &unit.cc.symbol_map);
35        scopes.push_with_symbol(globals, None);
36        Self {
37            unit,
38            scopes,
39            functions: Vec::new(),
40            classes: Vec::new(),
41            variables: Vec::new(),
42            imports: Vec::new(),
43        }
44    }
45
46    fn parent_symbol(&self) -> Option<&'tcx Symbol> {
47        self.scopes.scoped_symbol()
48    }
49
50    fn scoped_fqn(&self, _node: &HirNode<'tcx>, name: &str) -> String {
51        if let Some(parent) = self.parent_symbol() {
52            let parent_fqn = parent.fqn_name.borrow();
53            if parent_fqn.is_empty() {
54                name.to_string()
55            } else {
56                format!("{}::{}", parent_fqn.as_str(), name)
57            }
58        } else {
59            name.to_string()
60        }
61    }
62
63    fn take_functions(&mut self) -> Vec<PythonFunctionDescriptor> {
64        mem::take(&mut self.functions)
65    }
66
67    fn take_classes(&mut self) -> Vec<PythonClassDescriptor> {
68        mem::take(&mut self.classes)
69    }
70
71    fn take_variables(&mut self) -> Vec<VariableDescriptor> {
72        mem::take(&mut self.variables)
73    }
74
75    fn take_imports(&mut self) -> Vec<ImportDescriptor> {
76        mem::take(&mut self.imports)
77    }
78
79    fn create_new_symbol(
80        &mut self,
81        node: &HirNode<'tcx>,
82        field_id: u16,
83        global: bool,
84        kind: SymbolKind,
85    ) -> Option<(&'tcx Symbol, &'tcx HirIdent<'tcx>, String)> {
86        let ident_node = node.opt_child_by_field(self.unit, field_id)?;
87        let ident = ident_node.as_ident()?;
88        let fqn = self.scoped_fqn(node, &ident.name);
89        let owner = node.hir_id();
90
91        let symbol = match self.scopes.find_symbol_local(&ident.name) {
92            Some(existing) if existing.kind() != SymbolKind::Unknown && existing.kind() != kind => {
93                self.insert_into_scope(owner, ident, global, &fqn, kind)
94            }
95            Some(existing) => existing,
96            None => self.insert_into_scope(owner, ident, global, &fqn, kind),
97        };
98
99        Some((symbol, ident, fqn))
100    }
101
102    fn insert_into_scope(
103        &mut self,
104        owner: llmcc_core::ir::HirId,
105        ident: &'tcx HirIdent<'tcx>,
106        global: bool,
107        fqn: &str,
108        kind: SymbolKind,
109    ) -> &'tcx Symbol {
110        let interner = self.unit.interner();
111        let unit_index = self.unit.index;
112
113        self.scopes.insert_with(owner, ident, global, |symbol| {
114            symbol.set_owner(owner);
115            symbol.set_fqn(fqn.to_string(), interner);
116            symbol.set_kind(kind);
117            symbol.set_unit_index(unit_index);
118        })
119    }
120
121    fn visit_children_scope(&mut self, node: &HirNode<'tcx>, symbol: Option<&'tcx Symbol>) {
122        let depth = self.scopes.depth();
123        // Allocate scope for this node
124        let scope = self.unit.alloc_scope(node.hir_id());
125        self.scopes.push_with_symbol(scope, symbol);
126        self.visit_children(node);
127        self.scopes.pop_until(depth);
128    }
129
130    fn visit_children(&mut self, node: &HirNode<'tcx>) {
131        for id in node.children() {
132            let child = self.unit.hir_node(*id);
133            self.visit_node(child);
134        }
135    }
136
137    fn extract_base_classes(
138        &mut self,
139        arg_list_node: &HirNode<'tcx>,
140        class: &mut PythonClassDescriptor,
141    ) {
142        for child_id in arg_list_node.children() {
143            let child = self.unit.hir_node(*child_id);
144            if child.kind_id() == LangPython::identifier {
145                if let Some(ident) = child.as_ident() {
146                    class.add_base_class(ident.name.clone());
147                }
148            }
149        }
150    }
151
152    fn extract_class_members(
153        &mut self,
154        body_node: &HirNode<'tcx>,
155        class: &mut PythonClassDescriptor,
156    ) {
157        for child_id in body_node.children() {
158            let child = self.unit.hir_node(*child_id);
159            let kind_id = child.kind_id();
160
161            if kind_id == LangPython::function_definition {
162                if let Some(name_node) = child.opt_child_by_field(self.unit, LangPython::field_name)
163                {
164                    if let Some(ident) = name_node.as_ident() {
165                        class.add_method(ident.name.clone());
166                    }
167                }
168                self.extract_instance_fields_from_method(&child, class);
169            } else if kind_id == LangPython::decorated_definition {
170                if let Some(method_name) = self.extract_decorated_method_name(&child) {
171                    class.add_method(method_name);
172                }
173                if let Some(method_node) = self.method_node_from_decorated(&child) {
174                    self.extract_instance_fields_from_method(&method_node, class);
175                }
176            } else if kind_id == LangPython::assignment {
177                if let Some(field) = self.extract_class_field(&child) {
178                    self.upsert_class_field(class, field);
179                }
180            } else if kind_id == LangPython::expression_statement {
181                for stmt_child_id in child.children() {
182                    let stmt_child = self.unit.hir_node(*stmt_child_id);
183                    if stmt_child.kind_id() == LangPython::assignment {
184                        if let Some(field) = self.extract_class_field(&stmt_child) {
185                            self.upsert_class_field(class, field);
186                        }
187                    }
188                }
189            }
190        }
191    }
192
193    fn extract_decorated_method_name(&self, node: &HirNode<'tcx>) -> Option<String> {
194        for child_id in node.children() {
195            let child = self.unit.hir_node(*child_id);
196            if child.kind_id() == LangPython::function_definition {
197                if let Some(name_node) = child.opt_child_by_field(self.unit, LangPython::field_name)
198                {
199                    if let Some(ident) = name_node.as_ident() {
200                        return Some(ident.name.clone());
201                    }
202                }
203            }
204        }
205        None
206    }
207
208    fn method_node_from_decorated(&self, node: &HirNode<'tcx>) -> Option<HirNode<'tcx>> {
209        for child_id in node.children() {
210            let child = self.unit.hir_node(*child_id);
211            if child.kind_id() == LangPython::function_definition {
212                return Some(child);
213            }
214        }
215        None
216    }
217
218    fn extract_class_field(
219        &self,
220        node: &HirNode<'tcx>,
221    ) -> Option<crate::descriptor::class::ClassField> {
222        let left_node = node.opt_child_by_field(self.unit, LangPython::field_left)?;
223        let ident = left_node.as_ident()?;
224
225        let mut field = crate::descriptor::class::ClassField::new(ident.name.clone());
226
227        let type_hint = node
228            .opt_child_by_field(self.unit, LangPython::field_type)
229            .and_then(|type_node| {
230                let text = self.unit.get_text(
231                    type_node.inner_ts_node().start_byte(),
232                    type_node.inner_ts_node().end_byte(),
233                );
234                let trimmed = text.trim();
235                if trimmed.is_empty() {
236                    None
237                } else {
238                    Some(trimmed.to_string())
239                }
240            })
241            .or_else(|| {
242                for child_id in node.children() {
243                    let child = self.unit.hir_node(*child_id);
244                    if child.kind_id() == LangPython::type_node {
245                        let text = self.unit.get_text(
246                            child.inner_ts_node().start_byte(),
247                            child.inner_ts_node().end_byte(),
248                        );
249                        let trimmed = text.trim();
250                        if !trimmed.is_empty() {
251                            return Some(trimmed.to_string());
252                        }
253                    }
254                }
255                None
256            });
257
258        if let Some(type_hint) = type_hint {
259            field = field.with_type_hint(type_hint);
260        }
261
262        Some(field)
263    }
264
265    fn upsert_class_field(
266        &self,
267        class: &mut PythonClassDescriptor,
268        field: crate::descriptor::class::ClassField,
269    ) {
270        if let Some(existing) = class.fields.iter_mut().find(|f| f.name == field.name) {
271            if existing.type_hint.is_none() && field.type_hint.is_some() {
272                existing.type_hint = field.type_hint;
273            }
274        } else {
275            class.add_field(field);
276        }
277    }
278
279    fn extract_instance_fields_from_method(
280        &mut self,
281        method_node: &HirNode<'tcx>,
282        class: &mut PythonClassDescriptor,
283    ) {
284        self.collect_instance_fields_recursive(method_node, class);
285    }
286
287    fn collect_instance_fields_recursive(
288        &mut self,
289        node: &HirNode<'tcx>,
290        class: &mut PythonClassDescriptor,
291    ) {
292        if node.kind_id() == LangPython::assignment {
293            self.extract_instance_field_from_assignment(node, class);
294        }
295
296        for child_id in node.children() {
297            let child = self.unit.hir_node(*child_id);
298            self.collect_instance_fields_recursive(&child, class);
299        }
300    }
301
302    fn extract_instance_field_from_assignment(
303        &mut self,
304        node: &HirNode<'tcx>,
305        class: &mut PythonClassDescriptor,
306    ) {
307        let left_node = match node.opt_child_by_field(self.unit, LangPython::field_left) {
308            Some(node) => node,
309            None => return,
310        };
311
312        if left_node.kind_id() != LangPython::attribute {
313            return;
314        }
315
316        let mut identifier_names = Vec::new();
317        for child_id in left_node.children() {
318            let child = self.unit.hir_node(*child_id);
319            if child.kind_id() == LangPython::identifier {
320                if let Some(ident) = child.as_ident() {
321                    identifier_names.push(ident.name.clone());
322                }
323            }
324        }
325
326        if identifier_names.first().map(String::as_str) != Some("self") {
327            return;
328        }
329
330        let field_name = match identifier_names.last() {
331            Some(name) if name != "self" => name.clone(),
332            _ => return,
333        };
334
335        let field = crate::descriptor::class::ClassField::new(field_name);
336        self.upsert_class_field(class, field);
337    }
338}
339
340impl<'tcx> AstVisitorPython<'tcx> for DeclCollector<'tcx> {
341    fn unit(&self) -> CompileUnit<'tcx> {
342        self.unit
343    }
344
345    fn visit_source_file(&mut self, node: HirNode<'tcx>) {
346        self.visit_children_scope(&node, None);
347    }
348
349    fn visit_function_definition(&mut self, node: HirNode<'tcx>) {
350        if let Some((symbol, ident, _fqn)) =
351            self.create_new_symbol(&node, LangPython::field_name, true, SymbolKind::Function)
352        {
353            let mut func = PythonFunctionDescriptor::new(ident.name.clone());
354
355            // Extract parameters and return type using AST walking methods
356            for child_id in node.children() {
357                let child = self.unit.hir_node(*child_id);
358                let kind_id = child.kind_id();
359
360                if kind_id == LangPython::parameters {
361                    func.extract_parameters_from_ast(&child, self.unit);
362                }
363            }
364
365            // Extract return type by walking the AST
366            func.extract_return_type_from_ast(&node, self.unit);
367
368            self.functions.push(func);
369            self.visit_children_scope(&node, Some(symbol));
370        }
371    }
372
373    fn visit_class_definition(&mut self, node: HirNode<'tcx>) {
374        if let Some((symbol, ident, _fqn)) =
375            self.create_new_symbol(&node, LangPython::field_name, true, SymbolKind::Struct)
376        {
377            let mut class = PythonClassDescriptor::new(ident.name.clone());
378
379            // Look for base classes and body
380            for child_id in node.children() {
381                let child = self.unit.hir_node(*child_id);
382                let kind_id = child.kind_id();
383
384                if kind_id == LangPython::argument_list {
385                    // These are base classes
386                    self.extract_base_classes(&child, &mut class);
387                } else if kind_id == LangPython::block {
388                    // This is the class body
389                    self.extract_class_members(&child, &mut class);
390                }
391            }
392
393            self.classes.push(class);
394            self.visit_children_scope(&node, Some(symbol));
395        }
396    }
397
398    fn visit_decorated_definition(&mut self, node: HirNode<'tcx>) {
399        // decorated_definition contains decorators followed by the actual definition (function or class)
400        let mut decorators = Vec::new();
401
402        for child_id in node.children() {
403            let child = self.unit.hir_node(*child_id);
404            let kind_id = child.kind_id();
405
406            if kind_id == LangPython::decorator {
407                // Extract decorator name
408                // A decorator is usually just an identifier or a call expression
409                // For now, extract the text of the decorator
410                let decorator_text = self.unit.get_text(
411                    child.inner_ts_node().start_byte(),
412                    child.inner_ts_node().end_byte(),
413                );
414                if !decorator_text.is_empty() {
415                    decorators.push(decorator_text.trim_start_matches('@').trim().to_string());
416                }
417            }
418        }
419
420        // Visit the decorated definition and apply decorators to the last collected function/class
421        self.visit_children(&node);
422
423        // Apply decorators to the last function or class that was added
424        if !decorators.is_empty() {
425            if let Some(last_func) = self.functions.last_mut() {
426                last_func.decorators = decorators.clone();
427            }
428        }
429    }
430
431    fn visit_import_statement(&mut self, node: HirNode<'tcx>) {
432        // Handle: import os, sys, etc.
433        let mut cursor = node.inner_ts_node().walk();
434
435        for child in node.inner_ts_node().children(&mut cursor) {
436            if child.kind() == "dotted_name" || child.kind() == "identifier" {
437                let text = self.unit.get_text(child.start_byte(), child.end_byte());
438                let _import =
439                    ImportDescriptor::new(text, crate::descriptor::import::ImportKind::Simple);
440                self.imports.push(_import);
441            }
442        }
443    }
444
445    fn visit_import_from(&mut self, _node: HirNode<'tcx>) {
446        // Handle: from x import y
447        // This is more complex - we need to parse module and names
448        // For now, simple implementation
449    }
450
451    fn visit_assignment(&mut self, node: HirNode<'tcx>) {
452        // Handle: x = value
453        // In tree-sitter, the "left" side of assignment is the target
454        if let Some((_symbol, ident, _)) =
455            self.create_new_symbol(&node, LangPython::field_left, false, SymbolKind::Variable)
456        {
457            use crate::descriptor::variable::VariableScope;
458            let var = VariableDescriptor::new(ident.name.clone(), VariableScope::FunctionLocal);
459            self.variables.push(var);
460        }
461    }
462
463    fn visit_unknown(&mut self, node: HirNode<'tcx>) {
464        self.visit_children(&node);
465    }
466}
467
468pub fn collect_symbols<'tcx>(
469    unit: CompileUnit<'tcx>,
470    globals: &'tcx Scope<'tcx>,
471) -> CollectionResult {
472    let root = unit.file_start_hir_id().unwrap();
473    let node = unit.hir_node(root);
474    let mut collector = DeclCollector::new(unit, globals);
475    collector.visit_node(node);
476
477    CollectionResult {
478        functions: collector.take_functions(),
479        classes: collector.take_classes(),
480        variables: collector.take_variables(),
481        imports: collector.take_imports(),
482    }
483}