Skip to main content

aft/
lsp_hints.rs

1//! LSP-enhanced symbol disambiguation.
2//!
3//! When the plugin has LSP access, it can attach `lsp_hints` to a request with
4//! file + line information for the symbol(s) in play. This module parses those
5//! hints and uses them to narrow ambiguous tree-sitter matches down to the
6//! single correct candidate.
7
8use crate::protocol::RawRequest;
9use crate::symbols::SymbolMatch;
10use serde::Deserialize;
11use std::path::Path;
12
13/// A single LSP-sourced symbol hint: name, file path, line number, and optional kind.
14#[derive(Debug, Clone, Deserialize)]
15pub struct LspSymbolHint {
16    pub name: String,
17    pub file: String,
18    pub line: u32,
19    #[serde(default)]
20    pub kind: Option<String>,
21}
22
23/// Collection of LSP symbol hints attached to a request.
24#[derive(Debug, Clone, Deserialize)]
25pub struct LspHints {
26    pub symbols: Vec<LspSymbolHint>,
27}
28
/// Remove a leading `file://` scheme so an LSP URI can be compared as a path.
///
/// Only the literal prefix is removed (no percent-decoding); inputs without the
/// prefix are returned unchanged.
fn strip_file_uri(path: &str) -> &str {
    match path.strip_prefix("file://") {
        Some(rest) => rest,
        None => path,
    }
}
33
34/// Parse `lsp_hints` from a raw request.
35///
36/// Returns `Some(hints)` if `req.lsp_hints` is present and valid JSON matching
37/// the `LspHints` schema. Returns `None` (with a stderr warning) on malformed
38/// data, and `None` silently when the field is absent.
39pub fn parse_lsp_hints(req: &RawRequest) -> Option<LspHints> {
40    let value = req.lsp_hints.as_ref()?;
41    match serde_json::from_value::<LspHints>(value.clone()) {
42        Ok(hints) => {
43            log::debug!(
44                "lsp_hints: parsed {} symbol hints",
45                hints.symbols.len()
46            );
47            Some(hints)
48        }
49        Err(e) => {
50            log::warn!("lsp_hints: ignoring malformed data: {}", e);
51            None
52        }
53    }
54}
55
56/// Use LSP hints to disambiguate multiple tree-sitter symbol matches.
57///
58/// For each candidate match, checks whether any hint's name + file + line aligns
59/// (the hint line falls within the symbol's start_line..=end_line range). If
60/// exactly one candidate aligns with a hint, returns just that match. Otherwise,
61/// returns all matches unchanged (graceful fallback).
62pub fn apply_lsp_disambiguation(matches: Vec<SymbolMatch>, hints: &LspHints) -> Vec<SymbolMatch> {
63    if matches.len() <= 1 || hints.symbols.is_empty() {
64        return matches;
65    }
66
67    let aligned_indices: Vec<usize> = matches
68        .iter()
69        .enumerate()
70        .filter_map(|(i, m)| {
71            let is_aligned = hints.symbols.iter().any(|hint| {
72                let hint_file = strip_file_uri(&hint.file);
73                hint.name == m.symbol.name
74                    && paths_match(hint_file, &m.file)
75                    && hint.line >= m.symbol.range.start_line
76                    && hint.line <= m.symbol.range.end_line
77            });
78            if is_aligned {
79                Some(i)
80            } else {
81                None
82            }
83        })
84        .collect();
85
86    // Only disambiguate if we narrowed to exactly one match.
87    // If zero or multiple still match, fall back to all original candidates.
88    if aligned_indices.len() == 1 {
89        let idx = aligned_indices[0];
90        matches
91            .into_iter()
92            .nth(idx)
93            .map_or_else(Vec::new, |m| vec![m])
94    } else {
95        matches
96    }
97}
98
/// Decide whether two path strings refer to the same file.
///
/// If both paths resolve with `std::fs::canonicalize`, the resolved paths are
/// compared and that answer is final. Otherwise both strings have backslashes
/// normalized to `/`. They match when they are equal, or when the shorter one is
/// a suffix of the longer one that begins right after a `/`. For example,
/// `src/app.ts` matches `/proj/src/app.ts`, but `ar.ts` does not match
/// `foo/bar.ts`.
fn paths_match(hint_path: &str, match_path: &str) -> bool {
    let canonical = (
        std::fs::canonicalize(Path::new(hint_path)),
        std::fs::canonicalize(Path::new(match_path)),
    );
    if let (Ok(resolved_hint), Ok(resolved_match)) = canonical {
        // Both resolved on disk: trust the filesystem's answer.
        return resolved_hint == resolved_match;
    }

    // At least one path failed to canonicalize (e.g. it does not exist relative
    // to the current directory), so compare the strings textually instead.
    let left = hint_path.replace('\\', "/");
    let right = match_path.replace('\\', "/");
    if left == right {
        return true;
    }

    let (longer, shorter) = if left.len() >= right.len() {
        (left.as_str(), right.as_str())
    } else {
        (right.as_str(), left.as_str())
    };
    suffix_at_separator_boundary(longer, shorter)
}

/// Return `true` when all of these hold:
/// - `shorter` is non-empty;
/// - `shorter` is a proper suffix of `longer`;
/// - the character just before that suffix in `longer` is `/`.
///
/// In other words, the suffix starts at a path-component boundary.
fn suffix_at_separator_boundary(longer: &str, shorter: &str) -> bool {
    if shorter.is_empty() || shorter.len() >= longer.len() {
        return false;
    }
    matches!(longer.strip_suffix(shorter), Some(head) if head.ends_with('/'))
}
130
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbols::{Range, Symbol, SymbolKind, SymbolMatch};

    fn make_request(lsp_hints: Option<serde_json::Value>) -> RawRequest {
        RawRequest {
            id: "test-1".into(),
            command: "edit_symbol".into(),
            lsp_hints,
            session_id: None,
            params: serde_json::json!({}),
        }
    }

    fn make_match(
        name: &str,
        file: &str,
        start_line: u32,
        end_line: u32,
        kind: SymbolKind,
    ) -> SymbolMatch {
        SymbolMatch {
            symbol: Symbol {
                name: name.into(),
                kind,
                range: Range {
                    start_line,
                    start_col: 0,
                    end_line,
                    end_col: 0,
                },
                signature: None,
                scope_chain: vec![],
                exported: true,
                parent: None,
            },
            file: file.into(),
        }
    }

    /// Build a hint without a `kind`.
    fn make_hint(name: &str, file: &str, line: u32) -> LspSymbolHint {
        LspSymbolHint {
            name: name.into(),
            file: file.into(),
            line,
            kind: None,
        }
    }

    // --- Parsing tests ---

    #[test]
    fn parse_valid_hints() {
        let req = make_request(Some(serde_json::json!({
            "symbols": [
                {"name": "process", "file": "src/app.ts", "line": 10, "kind": "function"},
                {"name": "process", "file": "src/app.ts", "line": 25}
            ]
        })));
        let hints = parse_lsp_hints(&req).unwrap();
        assert_eq!(hints.symbols.len(), 2);
        assert_eq!(hints.symbols[0].name, "process");
        assert_eq!(hints.symbols[0].kind, Some("function".into()));
        assert_eq!(hints.symbols[1].kind, None);
    }

    #[test]
    fn parse_absent_hints_returns_none() {
        let req = make_request(None);
        assert!(parse_lsp_hints(&req).is_none());
    }

    #[test]
    fn parse_malformed_json_returns_none() {
        // Missing required "symbols" field
        let req = make_request(Some(serde_json::json!({"bad": "data"})));
        assert!(parse_lsp_hints(&req).is_none());
    }

    #[test]
    fn parse_empty_symbols_array() {
        let req = make_request(Some(serde_json::json!({"symbols": []})));
        let hints = parse_lsp_hints(&req).unwrap();
        assert!(hints.symbols.is_empty());
    }

    #[test]
    fn parse_missing_required_field_in_hint() {
        // Each hint requires name, file, line — missing "line" here
        let req = make_request(Some(serde_json::json!({
            "symbols": [{"name": "foo", "file": "bar.ts"}]
        })));
        assert!(parse_lsp_hints(&req).is_none());
    }

    #[test]
    fn parse_negative_line_returns_none() {
        // `line` is a u32, so a negative value is rejected as malformed.
        let req = make_request(Some(serde_json::json!({
            "symbols": [{"name": "foo", "file": "bar.ts", "line": -1}]
        })));
        assert!(parse_lsp_hints(&req).is_none());
    }

    // --- Disambiguation tests ---

    #[test]
    fn disambiguate_single_match_by_line() {
        let matches = vec![
            make_match("process", "src/app.ts", 2, 4, SymbolKind::Function),
            make_match("process", "src/app.ts", 7, 10, SymbolKind::Method),
        ];
        let hints = LspHints {
            symbols: vec![LspSymbolHint {
                name: "process".into(),
                file: "src/app.ts".into(),
                line: 3,
                kind: None,
            }],
        };
        let result = apply_lsp_disambiguation(matches, &hints);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].symbol.range.start_line, 2);
    }

    #[test]
    fn disambiguate_no_match_returns_all() {
        let matches = vec![
            make_match("process", "src/app.ts", 2, 4, SymbolKind::Function),
            make_match("process", "src/app.ts", 7, 10, SymbolKind::Method),
        ];
        let hints = LspHints {
            symbols: vec![LspSymbolHint {
                name: "process".into(),
                file: "other/file.ts".into(),
                line: 99,
                kind: None,
            }],
        };
        let result = apply_lsp_disambiguation(matches, &hints);
        assert_eq!(
            result.len(),
            2,
            "no hint matches → fallback to all candidates"
        );
    }

    #[test]
    fn disambiguate_stale_hint_ignored() {
        // Hint line doesn't fall in any symbol's range
        let matches = vec![
            make_match("process", "src/app.ts", 2, 4, SymbolKind::Function),
            make_match("process", "src/app.ts", 7, 10, SymbolKind::Method),
        ];
        let hints = LspHints {
            symbols: vec![LspSymbolHint {
                name: "process".into(),
                file: "src/app.ts".into(),
                line: 50, // stale — doesn't match either range
                kind: None,
            }],
        };
        let result = apply_lsp_disambiguation(matches, &hints);
        assert_eq!(
            result.len(),
            2,
            "stale hint should fall back to all candidates"
        );
    }

    #[test]
    fn disambiguate_file_uri_stripped() {
        let matches = vec![
            make_match("handler", "src/api.ts", 10, 20, SymbolKind::Function),
            make_match("handler", "src/api.ts", 30, 40, SymbolKind::Function),
        ];
        let hints = LspHints {
            symbols: vec![LspSymbolHint {
                name: "handler".into(),
                file: "file://src/api.ts".into(),
                line: 15,
                kind: None,
            }],
        };
        let result = apply_lsp_disambiguation(matches, &hints);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].symbol.range.start_line, 10);
    }

    #[test]
    fn disambiguate_absolute_file_uri() {
        // Real LSP URIs carry an absolute path: `file:///abs/path`.
        let matches = vec![
            make_match("handler", "src/api.ts", 10, 20, SymbolKind::Function),
            make_match("handler", "src/api.ts", 30, 40, SymbolKind::Function),
        ];
        let hints = LspHints {
            symbols: vec![make_hint(
                "handler",
                "file:///home/user/project/src/api.ts",
                35,
            )],
        };
        let result = apply_lsp_disambiguation(matches, &hints);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].symbol.range.start_line, 30);
    }

    #[test]
    fn disambiguate_multiple_aligned_returns_all() {
        let matches = vec![
            make_match("process", "src/app.ts", 2, 4, SymbolKind::Function),
            make_match("process", "src/app.ts", 7, 10, SymbolKind::Method),
        ];
        let hints = LspHints {
            symbols: vec![
                make_hint("process", "src/app.ts", 3),
                make_hint("process", "src/app.ts", 8),
            ],
        };
        let result = apply_lsp_disambiguation(matches, &hints);
        assert_eq!(
            result.len(),
            2,
            "hints aligning with several candidates → fallback to all"
        );
    }

    #[test]
    fn disambiguate_name_mismatch_ignored() {
        let matches = vec![
            make_match("process", "src/app.ts", 2, 4, SymbolKind::Function),
            make_match("process", "src/app.ts", 7, 10, SymbolKind::Method),
        ];
        let hints = LspHints {
            symbols: vec![make_hint("render", "src/app.ts", 3)],
        };
        let result = apply_lsp_disambiguation(matches, &hints);
        assert_eq!(result.len(), 2, "hint for another name must not narrow");
    }

    #[test]
    fn disambiguate_by_file() {
        // Both ranges contain the hint line; only the file tells them apart.
        let matches = vec![
            make_match("handler", "src/a.ts", 1, 10, SymbolKind::Function),
            make_match("handler", "src/b.ts", 1, 20, SymbolKind::Function),
        ];
        let hints = LspHints {
            symbols: vec![make_hint("handler", "src/b.ts", 5)],
        };
        let result = apply_lsp_disambiguation(matches, &hints);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].symbol.range.end_line, 20);
    }

    #[test]
    fn disambiguate_empty_hints_unchanged() {
        let matches = vec![
            make_match("process", "src/app.ts", 2, 4, SymbolKind::Function),
            make_match("process", "src/app.ts", 7, 10, SymbolKind::Method),
        ];
        let hints = LspHints { symbols: vec![] };
        let result = apply_lsp_disambiguation(matches, &hints);
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn disambiguate_single_input_unchanged() {
        let matches = vec![make_match("foo", "bar.ts", 1, 5, SymbolKind::Function)];
        let hints = LspHints {
            symbols: vec![LspSymbolHint {
                name: "foo".into(),
                file: "bar.ts".into(),
                line: 3,
                kind: None,
            }],
        };
        let result = apply_lsp_disambiguation(matches, &hints);
        assert_eq!(result.len(), 1);
    }

    // --- URI stripping tests ---

    #[test]
    fn strip_file_uri_removes_scheme_only() {
        assert_eq!(
            strip_file_uri("file:///home/user/app.ts"),
            "/home/user/app.ts"
        );
        assert_eq!(strip_file_uri("src/app.ts"), "src/app.ts");
    }

    // --- Path matching tests ---

    #[test]
    fn paths_match_exact() {
        assert!(paths_match("src/app.ts", "src/app.ts"));
    }

    #[test]
    fn paths_match_suffix() {
        assert!(paths_match("/home/user/project/src/app.ts", "src/app.ts"));
    }

    #[test]
    fn paths_match_suffix_either_argument_order() {
        assert!(paths_match("src/app.ts", "/home/user/project/src/app.ts"));
    }

    #[test]
    fn paths_match_filename_suffix_at_separator_boundary() {
        assert!(paths_match("/home/user/project/src/app.ts", "app.ts"));
    }

    #[test]
    fn paths_match_backslash_separators() {
        assert!(paths_match("C:\\project\\src\\app.ts", "src/app.ts"));
    }

    #[test]
    fn paths_do_not_match_partial_filename_suffixes() {
        assert!(!paths_match("foo/bar/baz.ts", "z.ts"));
        assert!(!paths_match("foo/bar.ts", "ar.ts"));
    }

    #[test]
    fn paths_no_match() {
        assert!(!paths_match("src/other.ts", "src/app.ts"));
    }
}