Skip to main content

aft/
lsp_hints.rs

1//! LSP-enhanced symbol disambiguation.
2//!
3//! When the plugin has LSP access, it can attach `lsp_hints` to a request with
4//! file + line information for the symbol(s) in play. This module parses those
5//! hints and uses them to narrow ambiguous tree-sitter matches down to the
6//! single correct candidate.
7
8use crate::protocol::RawRequest;
9use crate::symbols::SymbolMatch;
10use serde::Deserialize;
11
12/// A single LSP-sourced symbol hint: name, file path, line number, and optional kind.
13#[derive(Debug, Clone, Deserialize)]
14pub struct LspSymbolHint {
15    pub name: String,
16    pub file: String,
17    pub line: u32,
18    #[serde(default)]
19    pub kind: Option<String>,
20}
21
22/// Collection of LSP symbol hints attached to a request.
23#[derive(Debug, Clone, Deserialize)]
24pub struct LspHints {
25    pub symbols: Vec<LspSymbolHint>,
26}
27
/// Remove a leading `file://` scheme so a URI can be compared as a plain path.
///
/// Inputs without that exact (case-sensitive) prefix are returned unchanged.
/// No percent-decoding is performed.
fn strip_file_uri(path: &str) -> &str {
    match path.strip_prefix("file://") {
        Some(bare) => bare,
        None => path,
    }
}
32
33/// Parse `lsp_hints` from a raw request.
34///
35/// Returns `Some(hints)` if `req.lsp_hints` is present and valid JSON matching
36/// the `LspHints` schema. Returns `None` (with a stderr warning) on malformed
37/// data, and `None` silently when the field is absent.
38pub fn parse_lsp_hints(req: &RawRequest) -> Option<LspHints> {
39    let value = req.lsp_hints.as_ref()?;
40    match serde_json::from_value::<LspHints>(value.clone()) {
41        Ok(hints) => {
42            log::debug!(
43                "[aft] lsp_hints: parsed {} symbol hints",
44                hints.symbols.len()
45            );
46            Some(hints)
47        }
48        Err(e) => {
49            log::warn!("lsp_hints: ignoring malformed data: {}", e);
50            None
51        }
52    }
53}
54
55/// Use LSP hints to disambiguate multiple tree-sitter symbol matches.
56///
57/// For each candidate match, checks whether any hint's name + file + line aligns
58/// (the hint line falls within the symbol's start_line..=end_line range). If
59/// exactly one candidate aligns with a hint, returns just that match. Otherwise,
60/// returns all matches unchanged (graceful fallback).
61pub fn apply_lsp_disambiguation(matches: Vec<SymbolMatch>, hints: &LspHints) -> Vec<SymbolMatch> {
62    if matches.len() <= 1 || hints.symbols.is_empty() {
63        return matches;
64    }
65
66    let aligned_indices: Vec<usize> = matches
67        .iter()
68        .enumerate()
69        .filter_map(|(i, m)| {
70            let is_aligned = hints.symbols.iter().any(|hint| {
71                let hint_file = strip_file_uri(&hint.file);
72                hint.name == m.symbol.name
73                    && paths_match(hint_file, &m.file)
74                    && hint.line >= m.symbol.range.start_line
75                    && hint.line <= m.symbol.range.end_line
76            });
77            if is_aligned {
78                Some(i)
79            } else {
80                None
81            }
82        })
83        .collect();
84
85    // Only disambiguate if we narrowed to exactly one match.
86    // If zero or multiple still match, fall back to all original candidates.
87    if aligned_indices.len() == 1 {
88        let idx = aligned_indices[0];
89        matches
90            .into_iter()
91            .nth(idx)
92            .map_or_else(Vec::new, |m| vec![m])
93    } else {
94        matches
95    }
96}
97
/// Check whether two file paths plausibly refer to the same file.
///
/// The hint path (from the LSP) is often absolute while the tree-sitter match
/// path may be relative. So besides exact equality, this also accepts a match
/// when one path is a suffix of the other at a path-component boundary. For
/// example, `/p/src/app.ts` matches `src/app.ts` but not `myapp.ts`.
///
/// Before comparison, separators are normalized (`\` becomes `/`) and leading
/// `./` segments are dropped.
///
/// Empty paths never match: the empty string is a suffix of every path, so
/// accepting it would make an empty hint align with every candidate.
fn paths_match(hint_path: &str, match_path: &str) -> bool {
    let hint = normalize_path(hint_path);
    let m = normalize_path(match_path);

    if hint.is_empty() || m.is_empty() {
        return false;
    }

    hint == m || is_component_suffix(&hint, &m) || is_component_suffix(&m, &hint)
}

/// Normalize separators to `/` and strip any leading `./` segments.
fn normalize_path(path: &str) -> String {
    let normalized = path.replace('\\', "/");
    let mut rest = normalized.as_str();
    while let Some(stripped) = rest.strip_prefix("./") {
        rest = stripped;
    }
    rest.to_owned()
}

/// True when `suffix` forms whole trailing path components of `full`.
fn is_component_suffix(full: &str, suffix: &str) -> bool {
    match full.strip_suffix(suffix) {
        // The remaining prefix must end exactly at a separator (or be empty),
        // unless the suffix itself begins with one.
        Some(prefix) => prefix.is_empty() || prefix.ends_with('/') || suffix.starts_with('/'),
        None => false,
    }
}
112
/// Unit tests for LSP hint parsing, disambiguation and path matching.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbols::{Range, Symbol, SymbolKind, SymbolMatch};

    /// Build an `edit_symbol` `RawRequest` with id `"test-1"`, empty params, and
    /// the given `lsp_hints` payload (or none).
    fn make_request(lsp_hints: Option<serde_json::Value>) -> RawRequest {
        RawRequest {
            id: "test-1".into(),
            command: "edit_symbol".into(),
            lsp_hints,
            params: serde_json::json!({}),
        }
    }

    /// Build a `SymbolMatch` for `name` in `file` covering lines
    /// `start_line..=end_line`. Columns are fixed at 0 and the remaining
    /// `Symbol` fields get fixed placeholder values.
    fn make_match(
        name: &str,
        file: &str,
        start_line: u32,
        end_line: u32,
        kind: SymbolKind,
    ) -> SymbolMatch {
        SymbolMatch {
            symbol: Symbol {
                name: name.into(),
                kind,
                range: Range {
                    start_line,
                    start_col: 0,
                    end_line,
                    end_col: 0,
                },
                signature: None,
                scope_chain: vec![],
                exported: true,
                parent: None,
            },
            file: file.into(),
        }
    }

    // --- Parsing tests ---

    #[test]
    fn parse_valid_hints() {
        // The second hint omits the optional "kind", which must parse as None.
        let req = make_request(Some(serde_json::json!({
            "symbols": [
                {"name": "process", "file": "src/app.ts", "line": 10, "kind": "function"},
                {"name": "process", "file": "src/app.ts", "line": 25}
            ]
        })));
        let hints = parse_lsp_hints(&req).unwrap();
        assert_eq!(hints.symbols.len(), 2);
        assert_eq!(hints.symbols[0].name, "process");
        assert_eq!(hints.symbols[0].kind, Some("function".into()));
        assert_eq!(hints.symbols[1].kind, None);
    }

    #[test]
    fn parse_absent_hints_returns_none() {
        let req = make_request(None);
        assert!(parse_lsp_hints(&req).is_none());
    }

    #[test]
    fn parse_malformed_json_returns_none() {
        // Missing required "symbols" field
        let req = make_request(Some(serde_json::json!({"bad": "data"})));
        assert!(parse_lsp_hints(&req).is_none());
    }

    #[test]
    fn parse_empty_symbols_array() {
        // An empty list is valid: the key is required, its contents are not.
        let req = make_request(Some(serde_json::json!({"symbols": []})));
        let hints = parse_lsp_hints(&req).unwrap();
        assert!(hints.symbols.is_empty());
    }

    #[test]
    fn parse_missing_required_field_in_hint() {
        // Each hint requires name, file, line — missing "line" here
        let req = make_request(Some(serde_json::json!({
            "symbols": [{"name": "foo", "file": "bar.ts"}]
        })));
        assert!(parse_lsp_hints(&req).is_none());
    }

    // --- Disambiguation tests ---

    #[test]
    fn disambiguate_single_match_by_line() {
        // Two same-named symbols in one file; hint line 3 lies only in 2..=4.
        let matches = vec![
            make_match("process", "src/app.ts", 2, 4, SymbolKind::Function),
            make_match("process", "src/app.ts", 7, 10, SymbolKind::Method),
        ];
        let hints = LspHints {
            symbols: vec![LspSymbolHint {
                name: "process".into(),
                file: "src/app.ts".into(),
                line: 3,
                kind: None,
            }],
        };
        let result = apply_lsp_disambiguation(matches, &hints);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].symbol.range.start_line, 2);
    }

    #[test]
    fn disambiguate_no_match_returns_all() {
        // The hint points at a different file, so no candidate aligns.
        let matches = vec![
            make_match("process", "src/app.ts", 2, 4, SymbolKind::Function),
            make_match("process", "src/app.ts", 7, 10, SymbolKind::Method),
        ];
        let hints = LspHints {
            symbols: vec![LspSymbolHint {
                name: "process".into(),
                file: "other/file.ts".into(),
                line: 99,
                kind: None,
            }],
        };
        let result = apply_lsp_disambiguation(matches, &hints);
        assert_eq!(
            result.len(),
            2,
            "no hint matches → fallback to all candidates"
        );
    }

    #[test]
    fn disambiguate_stale_hint_ignored() {
        // Hint line doesn't fall in any symbol's range
        let matches = vec![
            make_match("process", "src/app.ts", 2, 4, SymbolKind::Function),
            make_match("process", "src/app.ts", 7, 10, SymbolKind::Method),
        ];
        let hints = LspHints {
            symbols: vec![LspSymbolHint {
                name: "process".into(),
                file: "src/app.ts".into(),
                line: 50, // stale — doesn't match either range
                kind: None,
            }],
        };
        let result = apply_lsp_disambiguation(matches, &hints);
        assert_eq!(
            result.len(),
            2,
            "stale hint should fall back to all candidates"
        );
    }

    #[test]
    fn disambiguate_file_uri_stripped() {
        // The hint's "file://" prefix must be stripped before path comparison.
        let matches = vec![
            make_match("handler", "src/api.ts", 10, 20, SymbolKind::Function),
            make_match("handler", "src/api.ts", 30, 40, SymbolKind::Function),
        ];
        let hints = LspHints {
            symbols: vec![LspSymbolHint {
                name: "handler".into(),
                file: "file://src/api.ts".into(),
                line: 15,
                kind: None,
            }],
        };
        let result = apply_lsp_disambiguation(matches, &hints);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].symbol.range.start_line, 10);
    }

    #[test]
    fn disambiguate_single_input_unchanged() {
        // With fewer than two candidates there is nothing to disambiguate.
        let matches = vec![make_match("foo", "bar.ts", 1, 5, SymbolKind::Function)];
        let hints = LspHints {
            symbols: vec![LspSymbolHint {
                name: "foo".into(),
                file: "bar.ts".into(),
                line: 3,
                kind: None,
            }],
        };
        let result = apply_lsp_disambiguation(matches, &hints);
        assert_eq!(result.len(), 1);
    }

    // --- Path matching tests ---

    #[test]
    fn paths_match_exact() {
        assert!(paths_match("src/app.ts", "src/app.ts"));
    }

    #[test]
    fn paths_match_suffix() {
        // Absolute hint path vs. relative match path.
        assert!(paths_match("/home/user/project/src/app.ts", "src/app.ts"));
    }

    #[test]
    fn paths_no_match() {
        assert!(!paths_match("src/other.ts", "src/app.ts"));
    }
}