Skip to main content

aft/
lsp_hints.rs

1//! LSP-enhanced symbol disambiguation.
2//!
3//! When the plugin has LSP access, it can attach `lsp_hints` to a request with
4//! file + line information for the symbol(s) in play. This module parses those
5//! hints and uses them to narrow ambiguous tree-sitter matches down to the
6//! single correct candidate.
7
8use crate::protocol::RawRequest;
9use crate::symbols::SymbolMatch;
10use serde::Deserialize;
11
12/// A single LSP-sourced symbol hint: name, file path, line number, and optional kind.
13#[derive(Debug, Clone, Deserialize)]
14pub struct LspSymbolHint {
15    pub name: String,
16    pub file: String,
17    pub line: u32,
18    #[serde(default)]
19    pub kind: Option<String>,
20}
21
22/// Collection of LSP symbol hints attached to a request.
23#[derive(Debug, Clone, Deserialize)]
24pub struct LspHints {
25    pub symbols: Vec<LspSymbolHint>,
26}
27
/// Return `path` without a single leading `file://` scheme prefix.
///
/// Inputs lacking the prefix come back unchanged. Nothing else is rewritten:
/// `file:///a/b` becomes `/a/b`, and percent-escapes are left as-is.
fn strip_file_uri(path: &str) -> &str {
    let Some(bare) = path.strip_prefix("file://") else {
        return path;
    };
    bare
}
32
33/// Parse `lsp_hints` from a raw request.
34///
35/// Returns `Some(hints)` if `req.lsp_hints` is present and valid JSON matching
36/// the `LspHints` schema. Returns `None` (with a stderr warning) on malformed
37/// data, and `None` silently when the field is absent.
38pub fn parse_lsp_hints(req: &RawRequest) -> Option<LspHints> {
39    let value = req.lsp_hints.as_ref()?;
40    match serde_json::from_value::<LspHints>(value.clone()) {
41        Ok(hints) => {
42            log::debug!(
43                "[aft] lsp_hints: parsed {} symbol hints",
44                hints.symbols.len()
45            );
46            Some(hints)
47        }
48        Err(e) => {
49            log::warn!("lsp_hints: ignoring malformed data: {}", e);
50            None
51        }
52    }
53}
54
55/// Use LSP hints to disambiguate multiple tree-sitter symbol matches.
56///
57/// For each candidate match, checks whether any hint's name + file + line aligns
58/// (the hint line falls within the symbol's start_line..=end_line range). If
59/// exactly one candidate aligns with a hint, returns just that match. Otherwise,
60/// returns all matches unchanged (graceful fallback).
61pub fn apply_lsp_disambiguation(matches: Vec<SymbolMatch>, hints: &LspHints) -> Vec<SymbolMatch> {
62    if matches.len() <= 1 || hints.symbols.is_empty() {
63        return matches;
64    }
65
66    let aligned_indices: Vec<usize> = matches
67        .iter()
68        .enumerate()
69        .filter_map(|(i, m)| {
70            let is_aligned = hints.symbols.iter().any(|hint| {
71                let hint_file = strip_file_uri(&hint.file);
72                hint.name == m.symbol.name
73                    && paths_match(hint_file, &m.file)
74                    && hint.line >= m.symbol.range.start_line
75                    && hint.line <= m.symbol.range.end_line
76            });
77            if is_aligned {
78                Some(i)
79            } else {
80                None
81            }
82        })
83        .collect();
84
85    // Only disambiguate if we narrowed to exactly one match.
86    // If zero or multiple still match, fall back to all original candidates.
87    if aligned_indices.len() == 1 {
88        let idx = aligned_indices[0];
89        matches
90            .into_iter()
91            .nth(idx)
92            .map_or_else(Vec::new, |m| vec![m])
93    } else {
94        matches
95    }
96}
97
/// Check if two file paths refer to the same file.
///
/// The hint path may be absolute while the match path is relative (or vice
/// versa), so the paths are compared by suffix after normalization:
///
/// * backslashes are converted to `/`, so Windows-style paths compare equal to
///   forward-slash paths;
/// * leading `./` segments are dropped;
/// * the shorter path must match the tail of the longer one on a whole
///   path-component boundary. So `src/app.ts` matches `/proj/src/app.ts` but
///   not `/proj/mysrc/app.ts`, and `app.ts` does not match `lib/myapp.ts`.
///
/// A path that is empty after normalization never matches anything.
fn paths_match(hint_path: &str, match_path: &str) -> bool {
    let hint = normalize_path_for_match(hint_path);
    let m = normalize_path_for_match(match_path);

    // An empty string is a suffix of every path; treat it as "unknown" rather
    // than letting it match every file.
    if hint.is_empty() || m.is_empty() {
        return false;
    }

    // Exact equality is covered as well: the leftover prefix is then empty.
    ends_with_path_components(&hint, &m) || ends_with_path_components(&m, &hint)
}

/// Normalize a path for comparison: `\` becomes `/` and leading `./` segments are removed.
fn normalize_path_for_match(path: &str) -> String {
    let unified = path.replace('\\', "/");
    let mut trimmed = unified.as_str();
    while let Some(rest) = trimmed.strip_prefix("./") {
        trimmed = rest;
    }
    trimmed.to_owned()
}

/// Whether `path` ends with `suffix` and the split point lies on a `/` boundary.
fn ends_with_path_components(path: &str, suffix: &str) -> bool {
    match path.strip_suffix(suffix) {
        // The suffix covers whole components if it is the whole path, if a `/`
        // sits right before it, or if it itself begins at a separator.
        Some(prefix) => prefix.is_empty() || prefix.ends_with('/') || suffix.starts_with('/'),
        None => false,
    }
}
112
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbols::{Range, Symbol, SymbolKind, SymbolMatch};

    fn make_request(lsp_hints: Option<serde_json::Value>) -> RawRequest {
        RawRequest {
            id: "test-1".into(),
            command: "edit_symbol".into(),
            lsp_hints,
            session_id: None,
            params: serde_json::json!({}),
        }
    }

    fn make_match(
        name: &str,
        file: &str,
        start_line: u32,
        end_line: u32,
        kind: SymbolKind,
    ) -> SymbolMatch {
        SymbolMatch {
            symbol: Symbol {
                name: name.into(),
                kind,
                range: Range {
                    start_line,
                    start_col: 0,
                    end_line,
                    end_col: 0,
                },
                signature: None,
                scope_chain: vec![],
                exported: true,
                parent: None,
            },
            file: file.into(),
        }
    }

    /// Build a hint with no `kind` (disambiguation does not read it).
    fn make_hint(name: &str, file: &str, line: u32) -> LspSymbolHint {
        LspSymbolHint {
            name: name.into(),
            file: file.into(),
            line,
            kind: None,
        }
    }

    /// Wrap a list of hints in the request-level container.
    fn make_hints(symbols: Vec<LspSymbolHint>) -> LspHints {
        LspHints { symbols }
    }

    /// Two same-named candidates in one file: lines 2-4 and 7-10.
    fn two_process_matches() -> Vec<SymbolMatch> {
        vec![
            make_match("process", "src/app.ts", 2, 4, SymbolKind::Function),
            make_match("process", "src/app.ts", 7, 10, SymbolKind::Method),
        ]
    }

    // --- Parsing tests ---

    #[test]
    fn parse_valid_hints() {
        let req = make_request(Some(serde_json::json!({
            "symbols": [
                {"name": "process", "file": "src/app.ts", "line": 10, "kind": "function"},
                {"name": "process", "file": "src/app.ts", "line": 25}
            ]
        })));
        let hints = parse_lsp_hints(&req).unwrap();
        assert_eq!(hints.symbols.len(), 2);
        assert_eq!(hints.symbols[0].name, "process");
        assert_eq!(hints.symbols[0].kind, Some("function".into()));
        assert_eq!(hints.symbols[1].kind, None);
    }

    #[test]
    fn parse_absent_hints_returns_none() {
        let req = make_request(None);
        assert!(parse_lsp_hints(&req).is_none());
    }

    #[test]
    fn parse_malformed_json_returns_none() {
        // Missing required "symbols" field
        let req = make_request(Some(serde_json::json!({"bad": "data"})));
        assert!(parse_lsp_hints(&req).is_none());
    }

    #[test]
    fn parse_empty_symbols_array() {
        let req = make_request(Some(serde_json::json!({"symbols": []})));
        let hints = parse_lsp_hints(&req).unwrap();
        assert!(hints.symbols.is_empty());
    }

    #[test]
    fn parse_missing_required_field_in_hint() {
        // Each hint requires name, file, line — missing "line" here
        let req = make_request(Some(serde_json::json!({
            "symbols": [{"name": "foo", "file": "bar.ts"}]
        })));
        assert!(parse_lsp_hints(&req).is_none());
    }

    #[test]
    fn parse_ignores_unknown_fields() {
        let req = make_request(Some(serde_json::json!({
            "symbols": [{"name": "foo", "file": "bar.ts", "line": 1, "extra": true}],
            "version": 2
        })));
        let hints = parse_lsp_hints(&req).unwrap();
        assert_eq!(hints.symbols.len(), 1);
        assert_eq!(hints.symbols[0].line, 1);
    }

    #[test]
    fn parse_negative_line_returns_none() {
        // `line` is a u32, so a negative value is malformed
        let req = make_request(Some(serde_json::json!({
            "symbols": [{"name": "foo", "file": "bar.ts", "line": -1}]
        })));
        assert!(parse_lsp_hints(&req).is_none());
    }

    #[test]
    fn parse_non_object_returns_none() {
        let req = make_request(Some(serde_json::json!("not an object")));
        assert!(parse_lsp_hints(&req).is_none());
        let req = make_request(Some(serde_json::Value::Null));
        assert!(parse_lsp_hints(&req).is_none());
    }

    // --- Disambiguation tests ---

    #[test]
    fn disambiguate_single_match_by_line() {
        let matches = vec![
            make_match("process", "src/app.ts", 2, 4, SymbolKind::Function),
            make_match("process", "src/app.ts", 7, 10, SymbolKind::Method),
        ];
        let hints = LspHints {
            symbols: vec![LspSymbolHint {
                name: "process".into(),
                file: "src/app.ts".into(),
                line: 3,
                kind: None,
            }],
        };
        let result = apply_lsp_disambiguation(matches, &hints);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].symbol.range.start_line, 2);
    }

    #[test]
    fn disambiguate_no_match_returns_all() {
        let matches = vec![
            make_match("process", "src/app.ts", 2, 4, SymbolKind::Function),
            make_match("process", "src/app.ts", 7, 10, SymbolKind::Method),
        ];
        let hints = LspHints {
            symbols: vec![LspSymbolHint {
                name: "process".into(),
                file: "other/file.ts".into(),
                line: 99,
                kind: None,
            }],
        };
        let result = apply_lsp_disambiguation(matches, &hints);
        assert_eq!(
            result.len(),
            2,
            "no hint matches → fallback to all candidates"
        );
    }

    #[test]
    fn disambiguate_stale_hint_ignored() {
        // Hint line doesn't fall in any symbol's range
        let matches = vec![
            make_match("process", "src/app.ts", 2, 4, SymbolKind::Function),
            make_match("process", "src/app.ts", 7, 10, SymbolKind::Method),
        ];
        let hints = LspHints {
            symbols: vec![LspSymbolHint {
                name: "process".into(),
                file: "src/app.ts".into(),
                line: 50, // stale — doesn't match either range
                kind: None,
            }],
        };
        let result = apply_lsp_disambiguation(matches, &hints);
        assert_eq!(
            result.len(),
            2,
            "stale hint should fall back to all candidates"
        );
    }

    #[test]
    fn disambiguate_file_uri_stripped() {
        let matches = vec![
            make_match("handler", "src/api.ts", 10, 20, SymbolKind::Function),
            make_match("handler", "src/api.ts", 30, 40, SymbolKind::Function),
        ];
        let hints = LspHints {
            symbols: vec![LspSymbolHint {
                name: "handler".into(),
                file: "file://src/api.ts".into(),
                line: 15,
                kind: None,
            }],
        };
        let result = apply_lsp_disambiguation(matches, &hints);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].symbol.range.start_line, 10);
    }

    #[test]
    fn disambiguate_single_input_unchanged() {
        let matches = vec![make_match("foo", "bar.ts", 1, 5, SymbolKind::Function)];
        let hints = LspHints {
            symbols: vec![LspSymbolHint {
                name: "foo".into(),
                file: "bar.ts".into(),
                line: 3,
                kind: None,
            }],
        };
        let result = apply_lsp_disambiguation(matches, &hints);
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn disambiguate_empty_hints_returns_all() {
        let result = apply_lsp_disambiguation(two_process_matches(), &make_hints(vec![]));
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn disambiguate_multiple_aligned_returns_all() {
        // One hint per candidate: the hints cannot single out one symbol.
        let hints = make_hints(vec![
            make_hint("process", "src/app.ts", 3),
            make_hint("process", "src/app.ts", 8),
        ]);
        let result = apply_lsp_disambiguation(two_process_matches(), &hints);
        assert_eq!(result.len(), 2, "ambiguous hints → fallback to all");
    }

    #[test]
    fn disambiguate_name_mismatch_returns_all() {
        let hints = make_hints(vec![make_hint("other", "src/app.ts", 3)]);
        let result = apply_lsp_disambiguation(two_process_matches(), &hints);
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn disambiguate_selects_non_first_candidate() {
        let matches = vec![
            make_match("process", "src/app.ts", 2, 4, SymbolKind::Function),
            make_match("process", "src/app.ts", 7, 10, SymbolKind::Method),
            make_match("process", "src/app.ts", 12, 15, SymbolKind::Method),
        ];
        let hints = make_hints(vec![make_hint("process", "src/app.ts", 8)]);
        let result = apply_lsp_disambiguation(matches, &hints);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].symbol.range.start_line, 7);
        assert_eq!(result[0].symbol.range.end_line, 10);
    }

    #[test]
    fn disambiguate_range_bounds_inclusive() {
        let on_start = make_hints(vec![make_hint("process", "src/app.ts", 7)]);
        let result = apply_lsp_disambiguation(two_process_matches(), &on_start);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].symbol.range.start_line, 7);

        let on_end = make_hints(vec![make_hint("process", "src/app.ts", 4)]);
        let result = apply_lsp_disambiguation(two_process_matches(), &on_end);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].symbol.range.start_line, 2);
    }

    #[test]
    fn disambiguate_absolute_hint_path() {
        let hints = make_hints(vec![make_hint(
            "process",
            "/home/user/project/src/app.ts",
            3,
        )]);
        let result = apply_lsp_disambiguation(two_process_matches(), &hints);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].symbol.range.start_line, 2);
    }

    #[test]
    fn disambiguate_absolute_file_uri() {
        let matches = vec![
            make_match("handler", "src/api.ts", 10, 20, SymbolKind::Function),
            make_match("handler", "src/api.ts", 30, 40, SymbolKind::Function),
        ];
        let hints = make_hints(vec![make_hint(
            "handler",
            "file:///home/user/project/src/api.ts",
            35,
        )]);
        let result = apply_lsp_disambiguation(matches, &hints);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].symbol.range.start_line, 30);
    }

    #[test]
    fn disambiguate_windows_hint_path() {
        let hints = make_hints(vec![make_hint("process", "C:\\project\\src\\app.ts", 9)]);
        let result = apply_lsp_disambiguation(two_process_matches(), &hints);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].symbol.range.start_line, 7);
    }

    // --- Path matching tests ---

    #[test]
    fn paths_match_exact() {
        assert!(paths_match("src/app.ts", "src/app.ts"));
    }

    #[test]
    fn paths_match_suffix() {
        assert!(paths_match("/home/user/project/src/app.ts", "src/app.ts"));
    }

    #[test]
    fn paths_match_reverse_suffix() {
        assert!(paths_match("src/app.ts", "/home/user/project/src/app.ts"));
    }

    #[test]
    fn paths_match_backslashes() {
        assert!(paths_match("C:\\project\\src\\app.ts", "src/app.ts"));
        assert!(paths_match("src\\app.ts", "src/app.ts"));
    }

    #[test]
    fn paths_no_match() {
        assert!(!paths_match("src/other.ts", "src/app.ts"));
    }

    #[test]
    fn strip_file_uri_variants() {
        assert_eq!(strip_file_uri("file:///abs/app.ts"), "/abs/app.ts");
        assert_eq!(strip_file_uri("file://src/app.ts"), "src/app.ts");
        assert_eq!(strip_file_uri("src/app.ts"), "src/app.ts");
    }
}