Skip to main content

aft/lsp/
manager.rs

1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::str::FromStr;
4
5use crossbeam_channel::{unbounded, Receiver, RecvTimeoutError, Sender};
6use lsp_types::notification::{DidChangeTextDocument, DidCloseTextDocument, DidOpenTextDocument};
7use lsp_types::{
8    DidChangeTextDocumentParams, DidCloseTextDocumentParams, DidOpenTextDocumentParams,
9    TextDocumentContentChangeEvent, TextDocumentIdentifier, TextDocumentItem,
10    VersionedTextDocumentIdentifier,
11};
12
13use crate::lsp::client::{LspClient, LspEvent, ServerState};
14use crate::lsp::diagnostics::{from_lsp_diagnostics, DiagnosticsStore, StoredDiagnostic};
15use crate::lsp::document::DocumentStore;
16use crate::lsp::registry::{servers_for_file, ServerDef, ServerKind};
17use crate::lsp::roots::{find_workspace_root, ServerKey};
18use crate::lsp::LspError;
19
20pub struct LspManager {
21    /// Active server instances, keyed by (ServerKind, workspace_root).
22    clients: HashMap<ServerKey, LspClient>,
23    /// Tracks opened documents and versions per active server.
24    documents: HashMap<ServerKey, DocumentStore>,
25    /// Stored publishDiagnostics payloads across all servers.
26    diagnostics: DiagnosticsStore,
27    /// Unified event channel — all server reader threads send here.
28    event_tx: Sender<LspEvent>,
29    event_rx: Receiver<LspEvent>,
30    /// Optional binary path overrides used by integration tests.
31    binary_overrides: HashMap<ServerKind, PathBuf>,
32}
33
34impl LspManager {
35    pub fn new() -> Self {
36        let (event_tx, event_rx) = unbounded();
37        Self {
38            clients: HashMap::new(),
39            documents: HashMap::new(),
40            diagnostics: DiagnosticsStore::new(),
41            event_tx,
42            event_rx,
43            binary_overrides: HashMap::new(),
44        }
45    }
46
47    /// For testing: override the binary for a server kind.
48    pub fn override_binary(&mut self, kind: ServerKind, binary_path: PathBuf) {
49        self.binary_overrides.insert(kind, binary_path);
50    }
51
52    /// Ensure a server is running for the given file. Spawns if needed.
53    /// Returns the active server keys for the file, or an empty vec if none match.
54    pub fn ensure_server_for_file(&mut self, file_path: &Path) -> Vec<ServerKey> {
55        let defs = servers_for_file(file_path);
56        let mut keys = Vec::new();
57
58        for def in defs {
59            let Some(root) = find_workspace_root(file_path, def.root_markers) else {
60                continue;
61            };
62
63            let key = ServerKey {
64                kind: def.kind,
65                root,
66            };
67
68            if !self.clients.contains_key(&key) {
69                match self.spawn_server(def, &key.root) {
70                    Ok(client) => {
71                        self.clients.insert(key.clone(), client);
72                        self.documents.entry(key.clone()).or_default();
73                    }
74                    Err(err) => {
75                        log::error!("failed to spawn {}: {}", def.name, err);
76                        continue;
77                    }
78                }
79            }
80
81            keys.push(key);
82        }
83
84        keys
85    }
86    /// Ensure that servers are running for the file and that the document is open
87    /// in each server's DocumentStore. Reads file content from disk if not already open.
88    /// Returns the server keys for the file.
89    pub fn ensure_file_open(&mut self, file_path: &Path) -> Result<Vec<ServerKey>, LspError> {
90        let canonical_path = canonicalize_for_lsp(file_path)?;
91        let server_keys = self.ensure_server_for_file(&canonical_path);
92        if server_keys.is_empty() {
93            return Ok(server_keys);
94        }
95
96        let uri = uri_for_path(&canonical_path)?;
97        let language_id = language_id_for_extension(
98            canonical_path
99                .extension()
100                .and_then(|ext| ext.to_str())
101                .unwrap_or_default(),
102        )
103        .to_string();
104
105        for key in &server_keys {
106            let already_open = self
107                .documents
108                .get(key)
109                .map_or(false, |store| store.is_open(&canonical_path));
110
111            if !already_open {
112                let content = std::fs::read_to_string(&canonical_path).map_err(LspError::Io)?;
113                if let Some(client) = self.clients.get_mut(key) {
114                    client.send_notification::<DidOpenTextDocument>(DidOpenTextDocumentParams {
115                        text_document: TextDocumentItem::new(
116                            uri.clone(),
117                            language_id.clone(),
118                            0,
119                            content,
120                        ),
121                    })?;
122                }
123                self.documents
124                    .entry(key.clone())
125                    .or_default()
126                    .open(canonical_path.clone());
127            }
128        }
129
130        Ok(server_keys)
131    }
132
133    /// Notify relevant LSP servers that a file has been written/changed.
134    /// This is the main hook called after every file write in AFT.
135    ///
136    /// If the file's server isn't running yet, starts it (lazy spawn).
137    /// If the file isn't open in LSP yet, sends didOpen. Otherwise sends didChange.
138    pub fn notify_file_changed(&mut self, file_path: &Path, content: &str) -> Result<(), LspError> {
139        let canonical_path = canonicalize_for_lsp(file_path)?;
140        let server_keys = self.ensure_server_for_file(&canonical_path);
141        if server_keys.is_empty() {
142            return Ok(());
143        }
144
145        let uri = uri_for_path(&canonical_path)?;
146        let language_id = language_id_for_extension(
147            canonical_path
148                .extension()
149                .and_then(|ext| ext.to_str())
150                .unwrap_or_default(),
151        )
152        .to_string();
153
154        for key in server_keys {
155            let current_version = self
156                .documents
157                .get(&key)
158                .and_then(|store| store.version(&canonical_path));
159
160            if let Some(version) = current_version {
161                let next_version = version + 1;
162                if let Some(client) = self.clients.get_mut(&key) {
163                    client.send_notification::<DidChangeTextDocument>(
164                        DidChangeTextDocumentParams {
165                            text_document: VersionedTextDocumentIdentifier::new(
166                                uri.clone(),
167                                next_version,
168                            ),
169                            content_changes: vec![TextDocumentContentChangeEvent {
170                                range: None,
171                                range_length: None,
172                                text: content.to_string(),
173                            }],
174                        },
175                    )?;
176                }
177                if let Some(store) = self.documents.get_mut(&key) {
178                    store.bump_version(&canonical_path);
179                }
180                continue;
181            }
182
183            if let Some(client) = self.clients.get_mut(&key) {
184                client.send_notification::<DidOpenTextDocument>(DidOpenTextDocumentParams {
185                    text_document: TextDocumentItem::new(
186                        uri.clone(),
187                        language_id.clone(),
188                        0,
189                        content.to_string(),
190                    ),
191                })?;
192            }
193            self.documents
194                .entry(key)
195                .or_default()
196                .open(canonical_path.clone());
197        }
198
199        Ok(())
200    }
201
202    /// Close a document in all servers that have it open.
203    pub fn notify_file_closed(&mut self, file_path: &Path) -> Result<(), LspError> {
204        let canonical_path = canonicalize_for_lsp(file_path)?;
205        let uri = uri_for_path(&canonical_path)?;
206        let keys: Vec<ServerKey> = self.documents.keys().cloned().collect();
207
208        for key in keys {
209            let was_open = self
210                .documents
211                .get(&key)
212                .map(|store| store.is_open(&canonical_path))
213                .unwrap_or(false);
214            if !was_open {
215                continue;
216            }
217
218            if let Some(client) = self.clients.get_mut(&key) {
219                client.send_notification::<DidCloseTextDocument>(DidCloseTextDocumentParams {
220                    text_document: TextDocumentIdentifier::new(uri.clone()),
221                })?;
222            }
223
224            if let Some(store) = self.documents.get_mut(&key) {
225                store.close(&canonical_path);
226            }
227        }
228
229        Ok(())
230    }
231
232    /// Get an active client for a file path, if one exists.
233    pub fn client_for_file(&self, file_path: &Path) -> Option<&LspClient> {
234        let key = self.server_key_for_file(file_path)?;
235        self.clients.get(&key)
236    }
237
238    /// Get a mutable active client for a file path, if one exists.
239    pub fn client_for_file_mut(&mut self, file_path: &Path) -> Option<&mut LspClient> {
240        let key = self.server_key_for_file(file_path)?;
241        self.clients.get_mut(&key)
242    }
243
244    /// Number of tracked server clients.
245    pub fn active_client_count(&self) -> usize {
246        self.clients.len()
247    }
248
249    /// Drain all pending LSP events. Call from the main loop.
250    pub fn drain_events(&mut self) -> Vec<LspEvent> {
251        let mut events = Vec::new();
252        while let Ok(event) = self.event_rx.try_recv() {
253            self.handle_event(&event);
254            events.push(event);
255        }
256        events
257    }
258
259    /// Wait for diagnostics to arrive for a specific file until a timeout expires.
260    pub fn wait_for_diagnostics(
261        &mut self,
262        file_path: &Path,
263        timeout: std::time::Duration,
264    ) -> Vec<StoredDiagnostic> {
265        let deadline = std::time::Instant::now() + timeout;
266        self.wait_for_file_diagnostics(file_path, deadline)
267    }
268
269    /// Wait for diagnostics to arrive for a specific file until a deadline.
270    ///
271    /// Drains already-queued events first, then blocks on the shared event
272    /// channel only until either `publishDiagnostics` arrives for this file or
273    /// the deadline is reached.
274    pub fn wait_for_file_diagnostics(
275        &mut self,
276        file_path: &Path,
277        deadline: std::time::Instant,
278    ) -> Vec<StoredDiagnostic> {
279        let lookup_path = normalize_lookup_path(file_path);
280
281        if self.server_key_for_file(&lookup_path).is_none() {
282            return Vec::new();
283        }
284
285        loop {
286            if self.drain_events_for_file(&lookup_path) {
287                break;
288            }
289
290            let now = std::time::Instant::now();
291            if now >= deadline {
292                break;
293            }
294
295            let timeout = deadline.saturating_duration_since(now);
296            match self.event_rx.recv_timeout(timeout) {
297                Ok(event) => {
298                    if matches!(
299                        self.handle_event(&event),
300                        Some(ref published_file) if published_file.as_path() == lookup_path.as_path()
301                    ) {
302                        break;
303                    }
304                }
305                Err(RecvTimeoutError::Timeout) | Err(RecvTimeoutError::Disconnected) => break,
306            }
307        }
308
309        self.get_diagnostics_for_file(&lookup_path)
310            .into_iter()
311            .cloned()
312            .collect()
313    }
314
315    /// Shutdown all servers gracefully.
316    pub fn shutdown_all(&mut self) {
317        for (key, mut client) in self.clients.drain() {
318            if let Err(err) = client.shutdown() {
319                log::error!("error shutting down {:?}: {}", key, err);
320            }
321        }
322        self.documents.clear();
323        self.diagnostics = DiagnosticsStore::new();
324    }
325
326    /// Check if any server is active.
327    pub fn has_active_servers(&self) -> bool {
328        self.clients
329            .values()
330            .any(|client| client.state() == ServerState::Ready)
331    }
332
333    pub fn get_diagnostics_for_file(&self, file: &Path) -> Vec<&StoredDiagnostic> {
334        let normalized = normalize_lookup_path(file);
335        self.diagnostics.for_file(&normalized)
336    }
337
338    pub fn get_diagnostics_for_directory(&self, dir: &Path) -> Vec<&StoredDiagnostic> {
339        let normalized = normalize_lookup_path(dir);
340        self.diagnostics.for_directory(&normalized)
341    }
342
343    pub fn get_all_diagnostics(&self) -> Vec<&StoredDiagnostic> {
344        self.diagnostics.all()
345    }
346
347    fn drain_events_for_file(&mut self, file_path: &Path) -> bool {
348        let mut saw_file_diagnostics = false;
349        while let Ok(event) = self.event_rx.try_recv() {
350            if matches!(
351                self.handle_event(&event),
352                Some(ref published_file) if published_file.as_path() == file_path
353            ) {
354                saw_file_diagnostics = true;
355            }
356        }
357        saw_file_diagnostics
358    }
359
360    fn handle_event(&mut self, event: &LspEvent) -> Option<PathBuf> {
361        match event {
362            LspEvent::Notification {
363                server_kind,
364                method,
365                params: Some(params),
366                ..
367            } if method == "textDocument/publishDiagnostics" => {
368                self.handle_publish_diagnostics(*server_kind, params)
369            }
370            LspEvent::ServerExited { server_kind, root } => {
371                let key = ServerKey {
372                    kind: *server_kind,
373                    root: root.clone(),
374                };
375                self.clients.remove(&key);
376                self.documents.remove(&key);
377                self.diagnostics.clear_server(*server_kind);
378                None
379            }
380            _ => None,
381        }
382    }
383
384    fn handle_publish_diagnostics(
385        &mut self,
386        server: ServerKind,
387        params: &serde_json::Value,
388    ) -> Option<PathBuf> {
389        if let Ok(publish_params) =
390            serde_json::from_value::<lsp_types::PublishDiagnosticsParams>(params.clone())
391        {
392            let Some(file) = uri_to_path(&publish_params.uri) else {
393                return None;
394            };
395            let stored = from_lsp_diagnostics(file.clone(), publish_params.diagnostics);
396            self.diagnostics.publish(server, file, stored);
397            return Some(uri_to_path(&publish_params.uri)?);
398        }
399        None
400    }
401
402    fn spawn_server(&self, def: &ServerDef, root: &Path) -> Result<LspClient, LspError> {
403        let binary = self.resolve_binary(def)?;
404        let mut client = LspClient::spawn(
405            def.kind,
406            root.to_path_buf(),
407            &binary,
408            def.args,
409            self.event_tx.clone(),
410        )?;
411        client.initialize(root)?;
412        Ok(client)
413    }
414
415    fn resolve_binary(&self, def: &ServerDef) -> Result<PathBuf, LspError> {
416        if let Some(path) = self.binary_overrides.get(&def.kind) {
417            if path.exists() {
418                return Ok(path.clone());
419            }
420            return Err(LspError::NotFound(format!(
421                "override binary for {:?} not found: {}",
422                def.kind,
423                path.display()
424            )));
425        }
426
427        if let Some(path) = env_binary_override(def.kind) {
428            if path.exists() {
429                return Ok(path);
430            }
431            return Err(LspError::NotFound(format!(
432                "environment override binary for {:?} not found: {}",
433                def.kind,
434                path.display()
435            )));
436        }
437
438        which::which(def.binary).map_err(|_| {
439            LspError::NotFound(format!(
440                "language server binary '{}' not found on PATH",
441                def.binary
442            ))
443        })
444    }
445
446    fn server_key_for_file(&self, file_path: &Path) -> Option<ServerKey> {
447        for def in servers_for_file(file_path) {
448            let root = find_workspace_root(file_path, def.root_markers)?;
449            let key = ServerKey {
450                kind: def.kind,
451                root,
452            };
453            if self.clients.contains_key(&key) {
454                return Some(key);
455            }
456        }
457        None
458    }
459}
460
461impl Default for LspManager {
462    fn default() -> Self {
463        Self::new()
464    }
465}
466
467fn canonicalize_for_lsp(file_path: &Path) -> Result<PathBuf, LspError> {
468    std::fs::canonicalize(file_path).map_err(LspError::from)
469}
470
471fn uri_for_path(path: &Path) -> Result<lsp_types::Uri, LspError> {
472    let url = url::Url::from_file_path(path).map_err(|_| {
473        LspError::NotFound(format!(
474            "failed to convert '{}' to file URI",
475            path.display()
476        ))
477    })?;
478    lsp_types::Uri::from_str(url.as_str()).map_err(|_| {
479        LspError::NotFound(format!("failed to parse file URI for '{}'", path.display()))
480    })
481}
482
/// Map a file extension (without the leading dot, case-sensitive) to the LSP
/// `languageId` sent with `didOpen`. Unknown extensions map to `"plaintext"`.
fn language_id_for_extension(ext: &str) -> &'static str {
    const LANGUAGE_IDS: &[(&[&str], &str)] = &[
        (&["ts"], "typescript"),
        (&["tsx"], "typescriptreact"),
        (&["js", "mjs", "cjs"], "javascript"),
        (&["jsx"], "javascriptreact"),
        (&["py", "pyi"], "python"),
        (&["rs"], "rust"),
        (&["go"], "go"),
    ];

    LANGUAGE_IDS
        .iter()
        .find(|(extensions, _)| extensions.contains(&ext))
        .map_or("plaintext", |&(_, id)| id)
}
495
/// Best-effort canonicalization for lookups.
///
/// Returns the canonical form of `path` when it can be resolved on disk, and
/// otherwise returns `path` unchanged (for example when it no longer exists).
fn normalize_lookup_path(path: &Path) -> PathBuf {
    match std::fs::canonicalize(path) {
        Ok(canonical) => canonical,
        Err(_) => path.to_path_buf(),
    }
}
499
500fn uri_to_path(uri: &lsp_types::Uri) -> Option<PathBuf> {
501    let url = url::Url::parse(uri.as_str()).ok()?;
502    url.to_file_path()
503        .ok()
504        .map(|path| normalize_lookup_path(&path))
505}
506
507fn env_binary_override(kind: ServerKind) -> Option<PathBuf> {
508    let key = match kind {
509        ServerKind::TypeScript => "AFT_LSP_TYPESCRIPT_BINARY",
510        ServerKind::Python => "AFT_LSP_PYTHON_BINARY",
511        ServerKind::Rust => "AFT_LSP_RUST_BINARY",
512        ServerKind::Go => "AFT_LSP_GO_BINARY",
513    };
514    std::env::var_os(key).map(PathBuf::from)
515}