Skip to main content

aft/
lsp_hints.rs

1//! LSP-enhanced symbol disambiguation.
2//!
3//! When the plugin has LSP access, it can attach `lsp_hints` to a request with
4//! file + line information for the symbol(s) in play. This module parses those
5//! hints and uses them to narrow ambiguous tree-sitter matches down to the
6//! single correct candidate.
7
8use crate::protocol::RawRequest;
9use crate::symbols::SymbolMatch;
10use serde::Deserialize;
11
12/// A single LSP-sourced symbol hint: name, file path, line number, and optional kind.
13#[derive(Debug, Clone, Deserialize)]
14pub struct LspSymbolHint {
15    pub name: String,
16    pub file: String,
17    pub line: u32,
18    #[serde(default)]
19    pub kind: Option<String>,
20}
21
22/// Collection of LSP symbol hints attached to a request.
23#[derive(Debug, Clone, Deserialize)]
24pub struct LspHints {
25    pub symbols: Vec<LspSymbolHint>,
26}
27
/// Remove a leading `file://` scheme so an LSP URI can be compared as a plain path.
///
/// Inputs that do not start with exactly `file://` are returned unchanged. Only the
/// literal prefix is removed: no percent-decoding or host handling is performed.
fn strip_file_uri(path: &str) -> &str {
    match path.strip_prefix("file://") {
        Some(rest) => rest,
        None => path,
    }
}
32
33/// Parse `lsp_hints` from a raw request.
34///
35/// Returns `Some(hints)` if `req.lsp_hints` is present and valid JSON matching
36/// the `LspHints` schema. Returns `None` (with a stderr warning) on malformed
37/// data, and `None` silently when the field is absent.
38pub fn parse_lsp_hints(req: &RawRequest) -> Option<LspHints> {
39    let value = req.lsp_hints.as_ref()?;
40    match serde_json::from_value::<LspHints>(value.clone()) {
41        Ok(hints) => {
42            log::debug!(
43                "[aft] lsp_hints: parsed {} symbol hints",
44                hints.symbols.len()
45            );
46            Some(hints)
47        }
48        Err(e) => {
49            log::warn!("lsp_hints: ignoring malformed data: {}", e);
50            None
51        }
52    }
53}
54
55/// Use LSP hints to disambiguate multiple tree-sitter symbol matches.
56///
57/// For each candidate match, checks whether any hint's name + file + line aligns
58/// (the hint line falls within the symbol's start_line..=end_line range). If
59/// exactly one candidate aligns with a hint, returns just that match. Otherwise,
60/// returns all matches unchanged (graceful fallback).
61pub fn apply_lsp_disambiguation(matches: Vec<SymbolMatch>, hints: &LspHints) -> Vec<SymbolMatch> {
62    if matches.len() <= 1 || hints.symbols.is_empty() {
63        return matches;
64    }
65
66    let aligned_indices: Vec<usize> = matches
67        .iter()
68        .enumerate()
69        .filter_map(|(i, m)| {
70            let is_aligned = hints.symbols.iter().any(|hint| {
71                let hint_file = strip_file_uri(&hint.file);
72                hint.name == m.symbol.name
73                    && paths_match(hint_file, &m.file)
74                    && hint.line >= m.symbol.range.start_line
75                    && hint.line <= m.symbol.range.end_line
76            });
77            if is_aligned {
78                Some(i)
79            } else {
80                None
81            }
82        })
83        .collect();
84
85    // Only disambiguate if we narrowed to exactly one match.
86    // If zero or multiple still match, fall back to all original candidates.
87    if aligned_indices.len() == 1 {
88        let idx = aligned_indices[0];
89        vec![matches.into_iter().nth(idx).unwrap()]
90    } else {
91        matches
92    }
93}
94
/// Check whether two file paths plausibly refer to the same file.
///
/// The hint path (from the LSP) is often absolute while the match path may be
/// relative, so besides exact equality this accepts one path being a suffix of the
/// other — but only on a path-component boundary. Backslashes are treated as `/`,
/// and leading `./` segments are ignored. An empty path never matches anything.
///
/// Requiring a component boundary prevents false positives such as `mysrc/app.ts`
/// matching `src/app.ts`, or `xapp.ts` matching `app.ts`.
fn paths_match(hint_path: &str, match_path: &str) -> bool {
    let hint = normalize_path(hint_path);
    let m = normalize_path(match_path);

    // An empty path (e.g. a bare `file://` hint) carries no information. Without this
    // guard it would be a suffix of, and therefore "match", every path.
    if hint.is_empty() || m.is_empty() {
        return false;
    }

    hint == m || is_component_suffix(&hint, &m) || is_component_suffix(&m, &hint)
}

/// Normalize separators to `/` and drop any leading `./` segments.
fn normalize_path(path: &str) -> String {
    path.replace('\\', "/").trim_start_matches("./").to_string()
}

/// True when `suffix` ends `path` and starts right after a `/` in `path`, i.e. it is a
/// trailing run of whole path components.
fn is_component_suffix(path: &str, suffix: &str) -> bool {
    matches!(path.strip_suffix(suffix), Some(prefix) if prefix.ends_with('/'))
}
109
// Unit tests for hint parsing (`parse_lsp_hints`), hint-based narrowing
// (`apply_lsp_disambiguation`) and path comparison (`paths_match`).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbols::{Range, Symbol, SymbolKind, SymbolMatch};

    /// Build a minimal `RawRequest` whose only varying field is `lsp_hints`;
    /// `params` is an empty JSON object.
    fn make_request(lsp_hints: Option<serde_json::Value>) -> RawRequest {
        RawRequest {
            id: "test-1".into(),
            command: "edit_symbol".into(),
            lsp_hints,
            params: serde_json::json!({}),
        }
    }

    /// Build an exported `SymbolMatch` named `name` in `file` whose range spans
    /// `start_line..=end_line` (both columns 0), with no signature, an empty scope
    /// chain and no parent.
    fn make_match(
        name: &str,
        file: &str,
        start_line: u32,
        end_line: u32,
        kind: SymbolKind,
    ) -> SymbolMatch {
        SymbolMatch {
            symbol: Symbol {
                name: name.into(),
                kind,
                range: Range {
                    start_line,
                    start_col: 0,
                    end_line,
                    end_col: 0,
                },
                signature: None,
                scope_chain: vec![],
                exported: true,
                parent: None,
            },
            file: file.into(),
        }
    }

    // --- Parsing tests ---

    #[test]
    fn parse_valid_hints() {
        let req = make_request(Some(serde_json::json!({
            "symbols": [
                {"name": "process", "file": "src/app.ts", "line": 10, "kind": "function"},
                {"name": "process", "file": "src/app.ts", "line": 25}
            ]
        })));
        let hints = parse_lsp_hints(&req).unwrap();
        assert_eq!(hints.symbols.len(), 2);
        assert_eq!(hints.symbols[0].name, "process");
        assert_eq!(hints.symbols[0].kind, Some("function".into()));
        // Omitted "kind" deserializes to None.
        assert_eq!(hints.symbols[1].kind, None);
    }

    #[test]
    fn parse_absent_hints_returns_none() {
        let req = make_request(None);
        assert!(parse_lsp_hints(&req).is_none());
    }

    #[test]
    fn parse_malformed_json_returns_none() {
        // Missing required "symbols" field
        let req = make_request(Some(serde_json::json!({"bad": "data"})));
        assert!(parse_lsp_hints(&req).is_none());
    }

    #[test]
    fn parse_empty_symbols_array() {
        let req = make_request(Some(serde_json::json!({"symbols": []})));
        let hints = parse_lsp_hints(&req).unwrap();
        assert!(hints.symbols.is_empty());
    }

    #[test]
    fn parse_missing_required_field_in_hint() {
        // Each hint requires name, file, line — missing "line" here
        let req = make_request(Some(serde_json::json!({
            "symbols": [{"name": "foo", "file": "bar.ts"}]
        })));
        assert!(parse_lsp_hints(&req).is_none());
    }

    // --- Disambiguation tests ---

    #[test]
    fn disambiguate_single_match_by_line() {
        // The hint line (3) falls only inside the first candidate's range (2..=4).
        let matches = vec![
            make_match("process", "src/app.ts", 2, 4, SymbolKind::Function),
            make_match("process", "src/app.ts", 7, 10, SymbolKind::Method),
        ];
        let hints = LspHints {
            symbols: vec![LspSymbolHint {
                name: "process".into(),
                file: "src/app.ts".into(),
                line: 3,
                kind: None,
            }],
        };
        let result = apply_lsp_disambiguation(matches, &hints);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].symbol.range.start_line, 2);
    }

    #[test]
    fn disambiguate_no_match_returns_all() {
        let matches = vec![
            make_match("process", "src/app.ts", 2, 4, SymbolKind::Function),
            make_match("process", "src/app.ts", 7, 10, SymbolKind::Method),
        ];
        let hints = LspHints {
            symbols: vec![LspSymbolHint {
                name: "process".into(),
                file: "other/file.ts".into(),
                line: 99,
                kind: None,
            }],
        };
        let result = apply_lsp_disambiguation(matches, &hints);
        assert_eq!(
            result.len(),
            2,
            "no hint matches → fallback to all candidates"
        );
    }

    #[test]
    fn disambiguate_stale_hint_ignored() {
        // Hint line doesn't fall in any symbol's range
        let matches = vec![
            make_match("process", "src/app.ts", 2, 4, SymbolKind::Function),
            make_match("process", "src/app.ts", 7, 10, SymbolKind::Method),
        ];
        let hints = LspHints {
            symbols: vec![LspSymbolHint {
                name: "process".into(),
                file: "src/app.ts".into(),
                line: 50, // stale — doesn't match either range
                kind: None,
            }],
        };
        let result = apply_lsp_disambiguation(matches, &hints);
        assert_eq!(
            result.len(),
            2,
            "stale hint should fall back to all candidates"
        );
    }

    #[test]
    fn disambiguate_file_uri_stripped() {
        // The hint file carries a `file://` prefix that must be removed before comparing.
        let matches = vec![
            make_match("handler", "src/api.ts", 10, 20, SymbolKind::Function),
            make_match("handler", "src/api.ts", 30, 40, SymbolKind::Function),
        ];
        let hints = LspHints {
            symbols: vec![LspSymbolHint {
                name: "handler".into(),
                file: "file://src/api.ts".into(),
                line: 15,
                kind: None,
            }],
        };
        let result = apply_lsp_disambiguation(matches, &hints);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].symbol.range.start_line, 10);
    }

    #[test]
    fn disambiguate_single_input_unchanged() {
        // A single candidate is returned as-is without consulting hints.
        let matches = vec![make_match("foo", "bar.ts", 1, 5, SymbolKind::Function)];
        let hints = LspHints {
            symbols: vec![LspSymbolHint {
                name: "foo".into(),
                file: "bar.ts".into(),
                line: 3,
                kind: None,
            }],
        };
        let result = apply_lsp_disambiguation(matches, &hints);
        assert_eq!(result.len(), 1);
    }

    // --- Path matching tests ---

    #[test]
    fn paths_match_exact() {
        assert!(paths_match("src/app.ts", "src/app.ts"));
    }

    #[test]
    fn paths_match_suffix() {
        // Absolute hint path vs. relative match path.
        assert!(paths_match("/home/user/project/src/app.ts", "src/app.ts"));
    }

    #[test]
    fn paths_no_match() {
        assert!(!paths_match("src/other.ts", "src/app.ts"));
    }
}