ricecoder_lsp/hover/
symbol_resolver.rs

1//! Symbol Position Resolution
2//!
3//! This module provides symbol resolution at specific positions in code,
4//! handling nested scopes and symbol shadowing.
5
6use crate::types::{Position, Symbol};
7use std::collections::HashMap;
8
9/// Symbol scope information
10#[derive(Debug, Clone)]
11pub struct SymbolScope {
12    /// Symbols in this scope
13    pub symbols: Vec<Symbol>,
14    /// Parent scope (for nested scopes)
15    pub parent: Option<Box<SymbolScope>>,
16    /// Scope range (start and end positions)
17    pub start_line: u32,
18    pub end_line: u32,
19}
20
21impl SymbolScope {
22    /// Create a new symbol scope
23    pub fn new(start_line: u32, end_line: u32) -> Self {
24        Self {
25            symbols: Vec::new(),
26            parent: None,
27            start_line,
28            end_line,
29        }
30    }
31
32    /// Add a symbol to this scope
33    pub fn add_symbol(&mut self, symbol: Symbol) {
34        self.symbols.push(symbol);
35    }
36
37    /// Find symbol at position in this scope
38    pub fn find_symbol_at_position(&self, position: Position) -> Option<Symbol> {
39        // Check if position is within this scope
40        if position.line < self.start_line || position.line > self.end_line {
41            return None;
42        }
43
44        // Look for symbol in this scope
45        for symbol in &self.symbols {
46            if symbol.range.start.line <= position.line && position.line <= symbol.range.end.line {
47                // Check if position is within the symbol's range
48                if position.line == symbol.range.start.line
49                    && position.character < symbol.range.start.character
50                {
51                    continue;
52                }
53                if position.line == symbol.range.end.line
54                    && position.character > symbol.range.end.character
55                {
56                    continue;
57                }
58                return Some(symbol.clone());
59            }
60        }
61
62        // Check parent scope
63        if let Some(parent) = &self.parent {
64            parent.find_symbol_at_position(position)
65        } else {
66            None
67        }
68    }
69
70    /// Get all symbols visible at position (including parent scopes)
71    pub fn get_visible_symbols(&self, position: Position) -> Vec<Symbol> {
72        let mut symbols = Vec::new();
73
74        // Add symbols from parent scopes
75        if let Some(parent) = &self.parent {
76            symbols.extend(parent.get_visible_symbols(position));
77        }
78
79        // Add symbols from this scope
80        for symbol in &self.symbols {
81            if symbol.range.start.line <= position.line && position.line <= symbol.range.end.line {
82                symbols.push(symbol.clone());
83            }
84        }
85
86        symbols
87    }
88}
89
90/// Symbol resolver for position-based lookups
91pub struct SymbolResolver {
92    /// Global symbol index
93    global_symbols: HashMap<String, Symbol>,
94    /// Scoped symbols
95    scopes: Vec<SymbolScope>,
96}
97
98impl SymbolResolver {
99    /// Create a new symbol resolver
100    pub fn new() -> Self {
101        Self {
102            global_symbols: HashMap::new(),
103            scopes: Vec::new(),
104        }
105    }
106
107    /// Index symbols for resolution
108    pub fn index_symbols(&mut self, symbols: Vec<Symbol>) {
109        self.global_symbols.clear();
110        for symbol in symbols {
111            self.global_symbols.insert(symbol.name.clone(), symbol);
112        }
113    }
114
115    /// Add a scope with symbols
116    pub fn add_scope(&mut self, scope: SymbolScope) {
117        self.scopes.push(scope);
118    }
119
120    /// Resolve symbol at position
121    pub fn resolve_at_position(&self, position: Position) -> Option<Symbol> {
122        // First check scoped symbols (more specific)
123        for scope in &self.scopes {
124            if let Some(symbol) = scope.find_symbol_at_position(position) {
125                return Some(symbol);
126            }
127        }
128
129        // Fall back to global symbols
130        // Find symbol by name at position (simple approach)
131        for symbol in self.global_symbols.values() {
132            if symbol.range.start.line <= position.line && position.line <= symbol.range.end.line {
133                if position.line == symbol.range.start.line
134                    && position.character >= symbol.range.start.character
135                    && position.character <= symbol.range.end.character
136                {
137                    return Some(symbol.clone());
138                }
139                if position.line == symbol.range.end.line
140                    && position.character >= symbol.range.start.character
141                    && position.character <= symbol.range.end.character
142                {
143                    return Some(symbol.clone());
144                }
145                if position.line > symbol.range.start.line && position.line < symbol.range.end.line
146                {
147                    return Some(symbol.clone());
148                }
149            }
150        }
151
152        None
153    }
154
155    /// Get symbol by name
156    pub fn get_symbol(&self, name: &str) -> Option<Symbol> {
157        self.global_symbols.get(name).cloned()
158    }
159
160    /// Get all symbols
161    pub fn get_all_symbols(&self) -> Vec<Symbol> {
162        self.global_symbols.values().cloned().collect()
163    }
164
165    /// Handle symbol shadowing - get the most specific symbol at position
166    pub fn resolve_with_shadowing(&self, position: Position) -> Option<Symbol> {
167        // Check scoped symbols first (they shadow global symbols)
168        for scope in &self.scopes {
169            if position.line >= scope.start_line && position.line <= scope.end_line {
170                if let Some(symbol) = scope.find_symbol_at_position(position) {
171                    return Some(symbol);
172                }
173            }
174        }
175
176        // Fall back to global symbols
177        self.resolve_at_position(position)
178    }
179
180    /// Get all visible symbols at position (for autocomplete, etc.)
181    pub fn get_visible_at_position(&self, position: Position) -> Vec<Symbol> {
182        let mut visible = Vec::new();
183
184        // Add global symbols
185        visible.extend(self.global_symbols.values().cloned());
186
187        // Add scoped symbols
188        for scope in &self.scopes {
189            if position.line >= scope.start_line && position.line <= scope.end_line {
190                visible.extend(scope.get_visible_symbols(position));
191            }
192        }
193
194        // Remove duplicates (keep the most specific one)
195        let mut seen = std::collections::HashSet::new();
196        visible.retain(|s| seen.insert(s.name.clone()));
197
198        visible
199    }
200}
201
202impl Default for SymbolResolver {
203    fn default() -> Self {
204        Self::new()
205    }
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Range, SymbolKind};

    /// Builds a test symbol spanning `start..=end` with no definition, references or docs.
    fn make_symbol(name: &str, kind: SymbolKind, start: Position, end: Position) -> Symbol {
        Symbol {
            name: name.to_string(),
            kind,
            range: Range::new(start, end),
            definition: None,
            references: vec![],
            documentation: None,
        }
    }

    #[test]
    fn test_symbol_scope_creation() {
        let scope = SymbolScope::new(0, 10);

        assert_eq!(scope.start_line, 0);
        assert_eq!(scope.end_line, 10);
        assert!(scope.symbols.is_empty());
    }

    #[test]
    fn test_add_symbol_to_scope() {
        let mut scope = SymbolScope::new(0, 10);
        scope.add_symbol(make_symbol(
            "test_fn",
            SymbolKind::Function,
            Position::new(2, 0),
            Position::new(5, 0),
        ));

        assert_eq!(scope.symbols.len(), 1);
    }

    #[test]
    fn test_find_symbol_in_scope() {
        let mut scope = SymbolScope::new(0, 10);
        scope.add_symbol(make_symbol(
            "my_fn",
            SymbolKind::Function,
            Position::new(2, 0),
            Position::new(5, 0),
        ));

        // Inside the symbol's range.
        let found = scope
            .find_symbol_at_position(Position::new(3, 5))
            .expect("my_fn covers line 3");
        assert_eq!(found.name, "my_fn");

        // Past the end of the symbol's range.
        assert!(scope.find_symbol_at_position(Position::new(6, 0)).is_none());
    }

    #[test]
    fn test_symbol_resolver_creation() {
        let resolver = SymbolResolver::new();

        assert_eq!(resolver.global_symbols.len(), 0);
        assert_eq!(resolver.scopes.len(), 0);
    }

    #[test]
    fn test_index_symbols() {
        let mut resolver = SymbolResolver::new();
        resolver.index_symbols(vec![make_symbol(
            "global_fn",
            SymbolKind::Function,
            Position::new(0, 0),
            Position::new(10, 0),
        )]);

        assert_eq!(resolver.global_symbols.len(), 1);
    }

    #[test]
    fn test_get_symbol_by_name() {
        let mut resolver = SymbolResolver::new();
        resolver.index_symbols(vec![make_symbol(
            "my_function",
            SymbolKind::Function,
            Position::new(0, 0),
            Position::new(5, 0),
        )]);

        let found = resolver
            .get_symbol("my_function")
            .expect("my_function was indexed");
        assert_eq!(found.name, "my_function");
    }

    #[test]
    fn test_resolve_at_position() {
        let mut resolver = SymbolResolver::new();
        resolver.index_symbols(vec![make_symbol(
            "test_var",
            SymbolKind::Variable,
            Position::new(5, 0),
            Position::new(5, 8),
        )]);

        let found = resolver
            .resolve_at_position(Position::new(5, 4))
            .expect("test_var covers (5, 4)");
        assert_eq!(found.name, "test_var");
    }

    #[test]
    fn test_symbol_shadowing() {
        let mut resolver = SymbolResolver::new();

        // Global `x` spans the whole file.
        resolver.index_symbols(vec![make_symbol(
            "x",
            SymbolKind::Variable,
            Position::new(0, 0),
            Position::new(20, 0),
        )]);

        // Local `x` inside lines 5..=15 shadows the global one.
        let mut scope = SymbolScope::new(5, 15);
        scope.add_symbol(make_symbol(
            "x",
            SymbolKind::Variable,
            Position::new(6, 0),
            Position::new(10, 0),
        ));
        resolver.add_scope(scope);

        let symbol = resolver
            .resolve_with_shadowing(Position::new(7, 0))
            .expect("x is visible at line 7");
        assert_eq!(symbol.name, "x");
        assert_eq!(symbol.range.start.line, 6);
    }

    #[test]
    fn test_get_visible_symbols() {
        let mut resolver = SymbolResolver::new();

        let globals = vec![
            make_symbol(
                "global_fn",
                SymbolKind::Function,
                Position::new(0, 0),
                Position::new(20, 0),
            ),
            make_symbol(
                "global_var",
                SymbolKind::Variable,
                Position::new(0, 0),
                Position::new(20, 0),
            ),
        ];
        resolver.index_symbols(globals);

        let mut scope = SymbolScope::new(5, 15);
        scope.add_symbol(make_symbol(
            "local_var",
            SymbolKind::Variable,
            Position::new(6, 0),
            Position::new(10, 0),
        ));
        resolver.add_scope(scope);

        // At least global_fn, global_var and local_var are visible inside the local scope.
        let visible = resolver.get_visible_at_position(Position::new(7, 0));
        assert!(visible.len() >= 3);
    }

    #[test]
    fn test_nested_scopes() {
        let mut outer_scope = SymbolScope::new(0, 20);
        outer_scope.add_symbol(make_symbol(
            "outer_var",
            SymbolKind::Variable,
            Position::new(1, 0),
            Position::new(19, 0),
        ));

        let mut inner_scope = SymbolScope::new(5, 15);
        inner_scope.add_symbol(make_symbol(
            "inner_var",
            SymbolKind::Variable,
            Position::new(6, 0),
            Position::new(14, 0),
        ));
        inner_scope.parent = Some(Box::new(outer_scope));

        // Lookups in the inner scope also consult the parent chain.
        assert!(inner_scope
            .find_symbol_at_position(Position::new(10, 0))
            .is_some());
    }
}