// llmcc_python/bind.rs

1use llmcc_core::context::CompileUnit;
2use llmcc_core::interner::InternedStr;
3use llmcc_core::ir::HirNode;
4use llmcc_core::symbol::{Scope, ScopeStack, Symbol, SymbolKind};
5
6use crate::token::{AstVisitorPython, LangPython};
7
/// Output of [`bind_symbols`] for one compile unit.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct BindingResult {
    /// Call edges discovered in the unit, in the order they were visited.
    pub calls: Vec<CallBinding>,
}

/// A resolved call edge, identified by the fully-qualified names of both ends.
///
/// Derives `Eq` + `Hash` so callers can compare edges or deduplicate them in a set.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CallBinding {
    /// FQN of the calling symbol, or `"<module>"` for calls made outside any symbol's scope.
    pub caller: String,
    /// FQN of the called function.
    pub target: String,
}
18
19#[derive(Debug)]
20struct SymbolBinder<'tcx> {
21    unit: CompileUnit<'tcx>,
22    scopes: ScopeStack<'tcx>,
23    calls: Vec<CallBinding>,
24}
25
26impl<'tcx> SymbolBinder<'tcx> {
27    pub fn new(unit: CompileUnit<'tcx>, globals: &'tcx Scope<'tcx>) -> Self {
28        let mut scopes = ScopeStack::new(&unit.cc.arena, &unit.cc.interner, &unit.cc.symbol_map);
29        scopes.push(globals);
30        Self {
31            unit,
32            scopes,
33            calls: Vec::new(),
34        }
35    }
36
37    fn interner(&self) -> &llmcc_core::interner::InternPool {
38        self.unit.interner()
39    }
40
41    fn current_symbol(&self) -> Option<&'tcx Symbol> {
42        self.scopes.scoped_symbol()
43    }
44
45    #[allow(dead_code)]
46    fn visit_children_scope(&mut self, node: &HirNode<'tcx>, symbol: Option<&'tcx Symbol>) {
47        let depth = self.scopes.depth();
48        if let Some(symbol) = symbol {
49            if let Some(parent) = self.scopes.scoped_symbol() {
50                parent.add_dependency(symbol);
51            }
52        }
53
54        let scope = self.unit.opt_get_scope(node.hir_id());
55        if let Some(scope) = scope {
56            self.scopes.push_with_symbol(scope, symbol);
57            self.visit_children(node);
58            self.scopes.pop_until(depth);
59        } else {
60            self.visit_children(node);
61        }
62    }
63
64    fn lookup_symbol_suffix(
65        &mut self,
66        suffix: &[InternedStr],
67        kind: Option<SymbolKind>,
68    ) -> Option<&'tcx Symbol> {
69        let file_index = self.unit.index;
70        self.scopes
71            .find_scoped_suffix_with_filters(suffix, kind, Some(file_index))
72            .or_else(|| {
73                self.scopes
74                    .find_scoped_suffix_with_filters(suffix, kind, None)
75            })
76            .or_else(|| {
77                self.scopes
78                    .find_global_suffix_with_filters(suffix, kind, Some(file_index))
79            })
80            .or_else(|| {
81                self.scopes
82                    .find_global_suffix_with_filters(suffix, kind, None)
83            })
84    }
85
86    fn add_symbol_relation(&mut self, symbol: Option<&'tcx Symbol>) {
87        if let (Some(current), Some(target)) = (self.current_symbol(), symbol) {
88            current.add_dependency(target);
89        }
90    }
91
92    fn visit_children(&mut self, node: &HirNode<'tcx>) {
93        // Use HIR children instead of tree-sitter children
94        for child_id in node.children() {
95            let child = self.unit.hir_node(*child_id);
96            self.visit_node(child);
97        }
98    }
99
100    fn visit_decorated_def(&mut self, node: &HirNode<'tcx>) {
101        let mut decorator_symbols = Vec::new();
102        let mut definition_idx = None;
103
104        for (idx, child_id) in node.children().iter().enumerate() {
105            let child = self.unit.hir_node(*child_id);
106            let kind_id = child.kind_id();
107
108            if kind_id == LangPython::decorator {
109                let content = self.unit.file().content();
110                let ts_node = child.inner_ts_node();
111                if let Ok(decorator_text) = ts_node.utf8_text(&content) {
112                    let decorator_name = decorator_text.trim_start_matches('@').trim();
113                    let key = self.interner().intern(decorator_name);
114                    if let Some(decorator_symbol) =
115                        self.lookup_symbol_suffix(&[key], Some(SymbolKind::Function))
116                    {
117                        decorator_symbols.push(decorator_symbol);
118                    }
119                }
120            } else if kind_id == LangPython::function_definition
121                || kind_id == LangPython::class_definition
122            {
123                definition_idx = Some(idx);
124                break;
125            }
126        }
127
128        if let Some(idx) = definition_idx {
129            let definition_id = node.children()[idx];
130            let definition = self.unit.hir_node(definition_id);
131            self.visit_definition_node(&definition, &decorator_symbols);
132        }
133    }
134
135    fn visit_call_impl(&mut self, node: &HirNode<'tcx>) {
136        // Extract function being called
137        let ts_node = node.inner_ts_node();
138
139        // In tree-sitter-python, call has a `function` field
140        if let Some(func_node) = ts_node.child_by_field_name("function") {
141            let content = self.unit.file().content();
142            let record_target = |name: &str, this: &mut SymbolBinder<'tcx>| {
143                let key = this.interner().intern(name);
144
145                // First try to find in current scoped context (for method calls)
146                if let Some(target) = this.lookup_symbol_suffix(&[key], Some(SymbolKind::Function))
147                {
148                    this.add_symbol_relation(Some(target));
149                    let caller_name = this
150                        .current_symbol()
151                        .map(|s| s.fqn_name.borrow().clone())
152                        .unwrap_or_else(|| "<module>".to_string());
153                    let target_name = target.fqn_name.borrow().clone();
154                    this.calls.push(CallBinding {
155                        caller: caller_name,
156                        target: target_name,
157                    });
158                    return true;
159                }
160
161                // Try to find a struct (class) constructor call
162                if let Some(target) = this.lookup_symbol_suffix(&[key], Some(SymbolKind::Struct)) {
163                    this.add_symbol_relation(Some(target));
164                    return true;
165                }
166
167                // For method calls (self.method()), try looking up within parent class
168                // If current symbol is a method, parent is the class
169                if let Some(current) = this.current_symbol() {
170                    if current.kind() == SymbolKind::Function {
171                        let fqn = current.fqn_name.borrow();
172                        // Split "ClassName.method_name" to get class name
173                        if let Some(dot_pos) = fqn.rfind("::") {
174                            let class_name = &fqn[..dot_pos];
175                            // Build the method FQN: "ClassName.method_name"
176                            let method_fqn = format!("{}::{}", class_name, name);
177                            // Look up the method with no kind filter first
178                            if let Some(target) = this.scopes.find_global_suffix_with_filters(
179                                &[this.interner().intern(&method_fqn)],
180                                None,
181                                None,
182                            ) {
183                                if target.kind() == SymbolKind::Function {
184                                    this.add_symbol_relation(Some(target));
185                                    let caller_name = fqn.clone();
186                                    let target_name = target.fqn_name.borrow().clone();
187                                    this.calls.push(CallBinding {
188                                        caller: caller_name,
189                                        target: target_name,
190                                    });
191                                    return true;
192                                }
193                            }
194                        }
195                    }
196                }
197
198                // If not found, try looking with no kind filter (generic lookup)
199                if let Some(target) = this.lookup_symbol_suffix(&[key], None) {
200                    this.add_symbol_relation(Some(target));
201                    if target.kind() == SymbolKind::Function {
202                        let caller_name = this
203                            .current_symbol()
204                            .map(|s| s.fqn_name.borrow().clone())
205                            .unwrap_or_else(|| "<module>".to_string());
206                        let target_name = target.fqn_name.borrow().clone();
207                        this.calls.push(CallBinding {
208                            caller: caller_name,
209                            target: target_name,
210                        });
211                    }
212                    return true;
213                }
214
215                false
216            };
217            let handled = match func_node.kind_id() {
218                id if id == LangPython::identifier => {
219                    if let Ok(name) = func_node.utf8_text(&content) {
220                        record_target(name, self)
221                    } else {
222                        false
223                    }
224                }
225                id if id == LangPython::attribute => {
226                    // For attribute access (e.g., self.method()), extract the method name
227                    if let Some(attr_node) = func_node.child_by_field_name("attribute") {
228                        if let Ok(name) = attr_node.utf8_text(&content) {
229                            record_target(name, self)
230                        } else {
231                            false
232                        }
233                    } else {
234                        false
235                    }
236                }
237                _ => false,
238            };
239
240            if !handled {
241                if let Ok(name) = func_node.utf8_text(&content) {
242                    let _ = record_target(name.trim(), self);
243                }
244            }
245        }
246
247        self.visit_children(node);
248    }
249
250    fn visit_definition_node(&mut self, node: &HirNode<'tcx>, decorator_symbols: &[&'tcx Symbol]) {
251        let kind_id = node.kind_id();
252        let name_node = match node.opt_child_by_field(self.unit, LangPython::field_name) {
253            Some(name) => name,
254            None => {
255                self.visit_children(node);
256                return;
257            }
258        };
259
260        let ident = match name_node.as_ident() {
261            Some(ident) => ident,
262            None => {
263                self.visit_children(node);
264                return;
265            }
266        };
267
268        let key = self.interner().intern(&ident.name);
269        let preferred_kind = if kind_id == LangPython::function_definition {
270            Some(SymbolKind::Function)
271        } else if kind_id == LangPython::class_definition {
272            Some(SymbolKind::Struct)
273        } else {
274            None
275        };
276
277        let mut symbol = preferred_kind
278            .and_then(|kind| self.lookup_symbol_suffix(&[key], Some(kind)))
279            .or_else(|| self.lookup_symbol_suffix(&[key], None));
280
281        let parent_symbol = self.current_symbol();
282
283        if let Some(scope) = self.unit.opt_get_scope(node.hir_id()) {
284            if symbol.is_none() {
285                symbol = scope.symbol();
286            }
287
288            let depth = self.scopes.depth();
289            self.scopes.push_with_symbol(scope, symbol);
290
291            if let Some(current_symbol) = self.current_symbol() {
292                if kind_id == LangPython::function_definition {
293                    if let Some(class_symbol) = parent_symbol {
294                        if class_symbol.kind() == SymbolKind::Struct {
295                            class_symbol.add_dependency(current_symbol);
296                        }
297                    }
298                } else if kind_id == LangPython::class_definition {
299                    self.add_base_class_dependencies(node, current_symbol);
300                }
301
302                for decorator_symbol in decorator_symbols {
303                    current_symbol.add_dependency(decorator_symbol);
304                }
305            }
306
307            self.visit_children(node);
308            self.scopes.pop_until(depth);
309        } else {
310            self.visit_children(node);
311        }
312    }
313
314    fn add_base_class_dependencies(&mut self, node: &HirNode<'tcx>, class_symbol: &Symbol) {
315        for child_id in node.children() {
316            let child = self.unit.hir_node(*child_id);
317            if child.kind_id() == LangPython::argument_list {
318                for base_id in child.children() {
319                    let base_node = self.unit.hir_node(*base_id);
320
321                    if let Some(ident) = base_node.as_ident() {
322                        let key = self.interner().intern(&ident.name);
323                        if let Some(base_symbol) =
324                            self.lookup_symbol_suffix(&[key], Some(SymbolKind::Struct))
325                        {
326                            class_symbol.add_dependency(base_symbol);
327                        }
328                    } else if base_node.kind_id() == LangPython::attribute {
329                        if let Some(attr_node) =
330                            base_node.inner_ts_node().child_by_field_name("attribute")
331                        {
332                            let content = self.unit.file().content();
333                            if let Ok(name) = attr_node.utf8_text(&content) {
334                                let key = self.interner().intern(name);
335                                if let Some(base_symbol) =
336                                    self.lookup_symbol_suffix(&[key], Some(SymbolKind::Struct))
337                                {
338                                    class_symbol.add_dependency(base_symbol);
339                                }
340                            }
341                        }
342                    }
343                }
344            }
345        }
346    }
347}
348
349impl<'tcx> AstVisitorPython<'tcx> for SymbolBinder<'tcx> {
350    fn unit(&self) -> CompileUnit<'tcx> {
351        self.unit
352    }
353
354    fn visit_source_file(&mut self, node: HirNode<'tcx>) {
355        self.visit_children_scope(&node, None);
356    }
357
358    fn visit_function_definition(&mut self, node: HirNode<'tcx>) {
359        let name_node = match node.opt_child_by_field(self.unit, LangPython::field_name) {
360            Some(n) => n,
361            None => {
362                self.visit_children(&node);
363                return;
364            }
365        };
366
367        let ident = match name_node.as_ident() {
368            Some(id) => id,
369            None => {
370                self.visit_children(&node);
371                return;
372            }
373        };
374
375        let key = self.interner().intern(&ident.name);
376        let mut symbol = self.lookup_symbol_suffix(&[key], Some(SymbolKind::Function));
377
378        // Get the parent symbol before pushing a new scope
379        let parent_symbol = self.current_symbol();
380
381        if let Some(scope) = self.unit.opt_get_scope(node.hir_id()) {
382            // If symbol not found by lookup, get it from the scope
383            if symbol.is_none() {
384                symbol = scope.symbol();
385            }
386
387            let depth = self.scopes.depth();
388            self.scopes.push_with_symbol(scope, symbol);
389
390            if let Some(current_symbol) = self.current_symbol() {
391                // If parent is a class, class depends on method
392                if let Some(parent) = parent_symbol {
393                    if parent.kind() == SymbolKind::Struct {
394                        parent.add_dependency(current_symbol);
395                    }
396                }
397            }
398
399            self.visit_children(&node);
400            self.scopes.pop_until(depth);
401        } else {
402            self.visit_children(&node);
403        }
404    }
405
406    fn visit_class_definition(&mut self, node: HirNode<'tcx>) {
407        let name_node = match node.opt_child_by_field(self.unit, LangPython::field_name) {
408            Some(n) => n,
409            None => {
410                self.visit_children(&node);
411                return;
412            }
413        };
414
415        let ident = match name_node.as_ident() {
416            Some(id) => id,
417            None => {
418                self.visit_children(&node);
419                return;
420            }
421        };
422
423        let key = self.interner().intern(&ident.name);
424        let mut symbol = self.lookup_symbol_suffix(&[key], Some(SymbolKind::Struct));
425
426        if let Some(scope) = self.unit.opt_get_scope(node.hir_id()) {
427            // If symbol not found by lookup, get it from the scope
428            if symbol.is_none() {
429                symbol = scope.symbol();
430            }
431
432            let depth = self.scopes.depth();
433            self.scopes.push_with_symbol(scope, symbol);
434
435            if let Some(current_symbol) = self.current_symbol() {
436                self.add_base_class_dependencies(&node, current_symbol);
437            }
438
439            self.visit_children(&node);
440            self.scopes.pop_until(depth);
441        } else {
442            self.visit_children(&node);
443        }
444    }
445
446    fn visit_decorated_definition(&mut self, node: HirNode<'tcx>) {
447        self.visit_decorated_def(&node);
448    }
449
450    fn visit_block(&mut self, node: HirNode<'tcx>) {
451        self.visit_children_scope(&node, None);
452    }
453
454    fn visit_call(&mut self, node: HirNode<'tcx>) {
455        // Delegate to the existing visit_call method
456        self.visit_call_impl(&node);
457    }
458
459    fn visit_unknown(&mut self, node: HirNode<'tcx>) {
460        self.visit_children(&node);
461    }
462}
463
464pub fn bind_symbols<'tcx>(unit: CompileUnit<'tcx>, globals: &'tcx Scope<'tcx>) -> BindingResult {
465    let mut binder = SymbolBinder::new(unit, globals);
466
467    if let Some(file_start_id) = unit.file_start_hir_id() {
468        if let Some(root) = unit.opt_hir_node(file_start_id) {
469            binder.visit_children(&root);
470        }
471    }
472
473    BindingResult {
474        calls: binder.calls,
475    }
476}