Skip to main content

krait/lsp/
symbols.rs

1use std::path::Path;
2
3use anyhow::{bail, Context};
4use serde_json::{json, Value};
5
6use super::client::LspClient;
7use super::files::FileTracker;
8use crate::commands::find::symbol_kind_name;
9
/// A resolved symbol location with exact line ranges.
///
/// Built from either an LSP `DocumentSymbol` (hierarchical, possibly with
/// `children`) or a `SymbolInformation` (flat, so `children` is empty).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SymbolLocation {
    /// Symbol name as reported by the language server.
    pub name: String,
    /// Human-readable symbol kind (e.g. `function`, `struct`).
    pub kind: String,
    /// 0-indexed start line.
    pub start_line: u32,
    /// 0-indexed end line (inclusive).
    pub end_line: u32,
    /// Nested symbols (methods, fields, ...) in the order the server listed them.
    pub children: Vec<SymbolLocation>,
}
21
22/// Resolve a symbol's exact range using `textDocument/documentSymbol`.
23///
24/// Supports nested names like `Config.new` by walking children.
25///
26/// `hint_line` (0-indexed) is used to disambiguate overloads: when multiple
27/// symbols share the same name (TypeScript overload stubs + implementation),
28/// the one whose `start_line` is closest to `hint_line` is selected.
29///
30/// # Errors
31/// Returns an error if the file can't be opened or the symbol isn't found.
32///
33/// # Panics
34/// Panics if `hint_line` is `Some` but `all_matches` is unexpectedly empty after
35/// the emptiness check — this is a logic invariant that should never fire.
36pub async fn resolve_symbol_range(
37    name: &str,
38    file_path: &Path,
39    hint_line: Option<u32>,
40    client: &mut LspClient,
41    file_tracker: &mut FileTracker,
42) -> anyhow::Result<SymbolLocation> {
43    file_tracker
44        .ensure_open(file_path, client.transport_mut())
45        .await
46        .with_context(|| format!("failed to open: {}", file_path.display()))?;
47
48    let uri = super::client::path_to_uri(file_path)?;
49    let params = json!({
50        "textDocument": { "uri": uri.as_str() }
51    });
52
53    let request_id = client
54        .transport_mut()
55        .send_request("textDocument/documentSymbol", params)
56        .await?;
57
58    let response = client
59        .wait_for_response_public(request_id)
60        .await
61        .context("textDocument/documentSymbol request failed")?;
62
63    let tree = parse_symbol_locations(&response);
64
65    // Support nested names: "Config.new" → find "Config", then child "new"
66    let parts: Vec<&str> = name.split('.').collect();
67
68    let mut current_list = &tree;
69    let mut result: Option<&SymbolLocation> = None;
70
71    for (i, part) in parts.iter().enumerate() {
72        // For the last part of the name, collect ALL matches so we can
73        // pick the one closest to hint_line (handles TypeScript overloads).
74        let is_last = i == parts.len() - 1;
75
76        if let (true, Some(hint)) = (is_last, hint_line) {
77            let mut all_matches: Vec<&SymbolLocation> = Vec::new();
78            collect_recursive(current_list, part, &mut all_matches);
79            if all_matches.is_empty() {
80                collect_recursive(&tree, part, &mut all_matches);
81            }
82            if all_matches.is_empty() {
83                bail!("symbol '{name}' not found in document symbols");
84            }
85            let best = all_matches
86                .iter()
87                .min_by_key(|s| (i64::from(s.start_line) - i64::from(hint)).unsigned_abs())
88                .copied()
89                .expect("all_matches is non-empty, checked above");
90            result = Some(best);
91        } else {
92            // First try the current level; if not found, search the full subtree.
93            // This handles methods inside classes (e.g. `createPromotions` inside
94            // `PromotionModuleService`) without requiring dotted notation.
95            let found = current_list
96                .iter()
97                .find(|s| name_matches(&s.name, part))
98                .or_else(|| find_recursive(&tree, part));
99            match found {
100                Some(sym) => {
101                    result = Some(sym);
102                    current_list = &sym.children;
103                }
104                None => bail!("symbol '{name}' not found in document symbols"),
105            }
106        }
107    }
108
109    let sym = result.context("empty symbol name")?;
110
111    Ok(SymbolLocation {
112        name: sym.name.clone(),
113        kind: sym.kind.clone(),
114        start_line: sym.start_line,
115        end_line: sym.end_line,
116        children: Vec::new(), // Don't clone the whole subtree
117    })
118}
119
/// Decide whether a symbol name reported by the server satisfies `query`.
///
/// An identical name always matches. A longer name matches only when `query`
/// is a prefix immediately followed by `<`, `(` or a space. This lets a bare
/// query find generic or parameterised symbols, e.g. `IRepository<T, ID>`
/// for `IRepository`.
fn name_matches(symbol_name: &str, query: &str) -> bool {
    match symbol_name.strip_prefix(query) {
        // Nothing left over: the names are identical.
        Some("") => true,
        // Prefix matched: accept only if a generic/parameter delimiter follows.
        Some(rest) => matches!(rest.as_bytes().first(), Some(b'<' | b'(' | b' ')),
        None => false,
    }
}
135
136/// Depth-first search for a symbol by name through the full document symbol tree.
137fn find_recursive<'a>(nodes: &'a [SymbolLocation], name: &str) -> Option<&'a SymbolLocation> {
138    for node in nodes {
139        if name_matches(&node.name, name) {
140            return Some(node);
141        }
142        if let Some(found) = find_recursive(&node.children, name) {
143            return Some(found);
144        }
145    }
146    None
147}
148
149/// Collect ALL symbols with the given name (depth-first) into `out`.
150fn collect_recursive<'a>(
151    nodes: &'a [SymbolLocation],
152    name: &str,
153    out: &mut Vec<&'a SymbolLocation>,
154) {
155    for node in nodes {
156        if name_matches(&node.name, name) {
157            out.push(node);
158        }
159        collect_recursive(&node.children, name, out);
160    }
161}
162
163/// Parse an LSP `documentSymbol` response into a hierarchical tree.
164pub fn parse_symbol_locations(value: &Value) -> Vec<SymbolLocation> {
165    let Some(items) = value.as_array() else {
166        return Vec::new();
167    };
168
169    items.iter().map(parse_single_symbol).collect()
170}
171
172#[allow(clippy::cast_possible_truncation)]
173fn parse_single_symbol(item: &Value) -> SymbolLocation {
174    let name = item
175        .get("name")
176        .and_then(Value::as_str)
177        .unwrap_or_default()
178        .to_string();
179
180    let kind = symbol_kind_name(item.get("kind").and_then(Value::as_u64).unwrap_or(0)).to_string();
181
182    let start_line = item
183        .pointer("/range/start/line")
184        .or_else(|| item.pointer("/location/range/start/line"))
185        .and_then(Value::as_u64)
186        .unwrap_or(0) as u32;
187
188    let end_line = item
189        .pointer("/range/end/line")
190        .or_else(|| item.pointer("/location/range/end/line"))
191        .and_then(Value::as_u64)
192        .unwrap_or(0) as u32;
193
194    let children = item
195        .get("children")
196        .map(parse_symbol_locations)
197        .unwrap_or_default();
198
199    SymbolLocation {
200        name,
201        kind,
202        start_line,
203        end_line,
204        children,
205    }
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a test symbol with a fixed kind.
    fn sym(name: &str, start_line: u32, end_line: u32, children: Vec<SymbolLocation>) -> SymbolLocation {
        SymbolLocation {
            name: name.to_string(),
            kind: "method".to_string(),
            start_line,
            end_line,
            children,
        }
    }

    #[test]
    fn parse_empty_response() {
        let result = parse_symbol_locations(&json!(null));
        assert!(result.is_empty());
    }

    #[test]
    fn parse_non_array_response() {
        let result = parse_symbol_locations(&json!({ "name": "not-a-list" }));
        assert!(result.is_empty());
    }

    #[test]
    fn parse_flat_symbols() {
        let response = json!([
            {
                "name": "greet",
                "kind": 12,
                "range": {
                    "start": { "line": 0, "character": 0 },
                    "end": { "line": 3, "character": 1 }
                },
                "selectionRange": {
                    "start": { "line": 0, "character": 3 },
                    "end": { "line": 0, "character": 8 }
                }
            },
            {
                "name": "Config",
                "kind": 23,
                "range": {
                    "start": { "line": 5, "character": 0 },
                    "end": { "line": 10, "character": 1 }
                },
                "selectionRange": {
                    "start": { "line": 5, "character": 11 },
                    "end": { "line": 5, "character": 17 }
                }
            }
        ]);

        let symbols = parse_symbol_locations(&response);
        assert_eq!(symbols.len(), 2);
        assert_eq!(symbols[0].name, "greet");
        assert_eq!(symbols[0].kind, "function");
        assert_eq!(symbols[0].start_line, 0);
        assert_eq!(symbols[0].end_line, 3);
        assert_eq!(symbols[1].name, "Config");
        assert_eq!(symbols[1].kind, "struct");
    }

    #[test]
    fn parse_nested_symbols() {
        let response = json!([
            {
                "name": "Config",
                "kind": 5,
                "range": {
                    "start": { "line": 0, "character": 0 },
                    "end": { "line": 20, "character": 1 }
                },
                "selectionRange": {
                    "start": { "line": 0, "character": 6 },
                    "end": { "line": 0, "character": 12 }
                },
                "children": [
                    {
                        "name": "new",
                        "kind": 6,
                        "range": {
                            "start": { "line": 5, "character": 2 },
                            "end": { "line": 10, "character": 3 }
                        },
                        "selectionRange": {
                            "start": { "line": 5, "character": 4 },
                            "end": { "line": 5, "character": 7 }
                        }
                    },
                    {
                        "name": "validate",
                        "kind": 6,
                        "range": {
                            "start": { "line": 12, "character": 2 },
                            "end": { "line": 18, "character": 3 }
                        },
                        "selectionRange": {
                            "start": { "line": 12, "character": 4 },
                            "end": { "line": 12, "character": 12 }
                        }
                    }
                ]
            }
        ]);

        let symbols = parse_symbol_locations(&response);
        assert_eq!(symbols.len(), 1);
        assert_eq!(symbols[0].name, "Config");
        assert_eq!(symbols[0].children.len(), 2);
        assert_eq!(symbols[0].children[0].name, "new");
        assert_eq!(symbols[0].children[0].start_line, 5);
        assert_eq!(symbols[0].children[0].end_line, 10);
        assert_eq!(symbols[0].children[1].name, "validate");
    }

    #[test]
    fn parse_symbol_with_location_fallback() {
        let response = json!([
            {
                "name": "test",
                "kind": 12,
                "location": {
                    "uri": "file:///tmp/test.rs",
                    "range": {
                        "start": { "line": 3, "character": 0 },
                        "end": { "line": 7, "character": 1 }
                    }
                }
            }
        ]);

        let symbols = parse_symbol_locations(&response);
        assert_eq!(symbols.len(), 1);
        assert_eq!(symbols[0].start_line, 3);
        assert_eq!(symbols[0].end_line, 7);
    }

    #[test]
    fn parse_symbol_with_missing_fields_uses_defaults() {
        let symbols = parse_symbol_locations(&json!([{}]));
        assert_eq!(symbols.len(), 1);
        assert_eq!(symbols[0].name, "");
        assert_eq!(symbols[0].start_line, 0);
        assert_eq!(symbols[0].end_line, 0);
        assert!(symbols[0].children.is_empty());
    }

    #[test]
    fn name_matches_exact_and_generic_prefix() {
        assert!(name_matches("Config", "Config"));
        assert!(name_matches("IRepository<T, ID>", "IRepository"));
        assert!(name_matches("handler(req)", "handler"));
        assert!(!name_matches("ConfigBuilder", "Config"));
        assert!(!name_matches("Conf", "Config"));
    }

    #[test]
    fn find_recursive_is_preorder_depth_first() {
        let tree = vec![
            sym("Outer", 0, 20, vec![sym("inner", 2, 5, vec![])]),
            sym("inner", 22, 25, vec![]),
        ];
        let found = find_recursive(&tree, "inner").expect("inner should be found");
        assert_eq!(found.start_line, 2);
        assert!(find_recursive(&tree, "missing").is_none());
    }

    #[test]
    fn collect_recursive_gathers_every_match_in_order() {
        let tree = vec![
            sym("Outer", 0, 20, vec![sym("inner", 2, 5, vec![sym("inner", 3, 4, vec![])])]),
            sym("inner", 22, 25, vec![]),
        ];
        let mut out = Vec::new();
        collect_recursive(&tree, "inner", &mut out);
        let lines: Vec<u32> = out.iter().map(|s| s.start_line).collect();
        assert_eq!(lines, vec![2, 3, 22]);
    }
}