Skip to main content

codelens_engine/lsp/
session.rs

1use crate::project::ProjectRoot;
2use anyhow::{Context, Result, bail};
3use serde_json::{Value, json};
4use std::collections::HashMap;
5use std::io::{BufRead, BufReader};
6use std::path::Path;
7use std::process::{Child, ChildStdin, ChildStdout, Command, Stdio};
8use std::thread;
9use std::time::{Duration, Instant};
10use url::Url;
11
12use super::parsers::{
13    diagnostics_from_response, method_suffix_to_hierarchy, references_from_response,
14    rename_plan_from_response, type_hierarchy_node_from_item, type_hierarchy_to_map,
15    workspace_symbols_from_response,
16};
17use super::protocol::{language_id_for_path, poll_readable, read_message, send_message};
18use super::types::{
19    LspDiagnostic, LspDiagnosticRequest, LspReference, LspRenamePlan, LspRenamePlanRequest,
20    LspRequest, LspTypeHierarchyNode, LspTypeHierarchyRequest, LspWorkspaceSymbol,
21    LspWorkspaceSymbolRequest,
22};
23
/// Lookup key for a pooled LSP session: the server command plus its argument list.
///
/// Two requests reuse the same session only when both the command string and
/// every argument compare equal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct SessionKey {
    /// Command used to spawn the server, exactly as supplied by the request.
    command: String,
    /// Arguments passed to the server process, in order.
    args: Vec<String>,
}
29
/// What this client last told the server about one open document.
#[derive(Debug, Clone)]
struct OpenDocumentState {
    /// Version number sent with the latest `didOpen` / `didChange`; starts at 1.
    version: i32,
    /// Full text sent with that notification, kept to detect unchanged files.
    text: String,
}
35
/// Pool of running LSP server sessions for a single project root.
///
/// Sessions are created on demand and keyed by the command/arguments used to
/// launch the server.
pub struct LspSessionPool {
    /// Project root handed to every session started by this pool.
    project: ProjectRoot,
    /// Live sessions; the mutex is held for the full duration of each request.
    sessions: std::sync::Mutex<HashMap<SessionKey, LspSession>>,
}
40
/// One running LSP server process together with the client-side protocol state.
struct LspSession {
    /// Project root used to resolve request paths and to build the root URI.
    project: ProjectRoot,
    /// Handle to the spawned server process.
    child: Child,
    /// Pipe used to write JSON-RPC messages to the server.
    stdin: ChildStdin,
    /// Buffered reader over the server's stdout, from which messages are read.
    reader: BufReader<ChildStdout>,
    /// Id assigned to the next outgoing request; starts at 1 and only grows.
    next_request_id: u64,
    /// Documents already opened on the server, keyed by their URI string.
    documents: HashMap<String, OpenDocumentState>,
    /// Recent server stderr output, filled by a background thread.
    #[allow(dead_code)] // retained for future stderr diagnostics
    stderr_buffer: std::sync::Arc<std::sync::Mutex<String>>,
}
51
52/// Known-safe LSP server binaries. Commands not in this list are rejected.
53pub(super) fn is_allowed_lsp_command(command: &str) -> bool {
54    // Extract the binary name from the command path (e.g., "/usr/bin/pyright-langserver" → "pyright-langserver")
55    let binary = std::path::Path::new(command)
56        .file_name()
57        .and_then(|n| n.to_str())
58        .unwrap_or(command);
59
60    ALLOWED_COMMANDS.contains(&binary)
61}
62
63pub(super) const ALLOWED_COMMANDS: &[&str] = &[
64    // From LSP_RECIPES
65    "pyright-langserver",
66    "typescript-language-server",
67    "rust-analyzer",
68    "gopls",
69    "jdtls",
70    "kotlin-language-server",
71    "clangd",
72    "solargraph",
73    "intelephense",
74    "sourcekit-lsp",
75    "csharp-ls",
76    "dart",
77    // Additional well-known LSP servers
78    "metals",
79    "lua-language-server",
80    "terraform-ls",
81    "yaml-language-server",
82    // Test support: allow python3/python for mock LSP in tests
83    "python3",
84    "python",
85];
86
87fn ensure_session<'a>(
88    sessions: &'a mut HashMap<SessionKey, LspSession>,
89    project: &ProjectRoot,
90    command: &str,
91    args: &[String],
92) -> Result<&'a mut LspSession> {
93    if !is_allowed_lsp_command(command) {
94        bail!(
95            "Blocked: '{command}' is not a known LSP server. Only whitelisted LSP binaries are allowed."
96        );
97    }
98
99    let key = SessionKey {
100        command: command.to_owned(),
101        args: args.to_owned(),
102    };
103
104    // Check for dead sessions: if the child process has exited, remove the stale entry.
105    if let Some(session) = sessions.get_mut(&key) {
106        match session.child.try_wait() {
107            Ok(Some(_status)) => {
108                // Process exited — remove stale session so we start fresh below.
109                sessions.remove(&key);
110            }
111            Ok(None) => {} // Still running — will return it via Occupied below.
112            Err(_) => {
113                sessions.remove(&key);
114            }
115        }
116    }
117
118    match sessions.entry(key) {
119        std::collections::hash_map::Entry::Occupied(e) => Ok(e.into_mut()),
120        std::collections::hash_map::Entry::Vacant(e) => {
121            let session = LspSession::start(project, command, args)?;
122            Ok(e.insert(session))
123        }
124    }
125}
126
127impl LspSessionPool {
128    pub fn new(project: ProjectRoot) -> Self {
129        Self {
130            project,
131            sessions: std::sync::Mutex::new(HashMap::new()),
132        }
133    }
134
135    /// Replace the project root and close all existing sessions.
136    pub fn reset(&self, project: ProjectRoot) -> Self {
137        // Drop existing sessions so LSP processes are killed.
138        self.sessions
139            .lock()
140            .unwrap_or_else(|p| p.into_inner())
141            .clear();
142        Self::new(project)
143    }
144
145    pub fn session_count(&self) -> usize {
146        self.sessions
147            .lock()
148            .unwrap_or_else(|p| p.into_inner())
149            .len()
150    }
151
152    pub fn find_referencing_symbols(&self, request: &LspRequest) -> Result<Vec<LspReference>> {
153        let mut sessions = self.sessions.lock().unwrap_or_else(|p| p.into_inner());
154        let session = ensure_session(
155            &mut sessions,
156            &self.project,
157            &request.command,
158            &request.args,
159        )?;
160        session.find_references(request)
161    }
162
163    pub fn get_diagnostics(&self, request: &LspDiagnosticRequest) -> Result<Vec<LspDiagnostic>> {
164        let mut sessions = self.sessions.lock().unwrap_or_else(|p| p.into_inner());
165        let session = ensure_session(
166            &mut sessions,
167            &self.project,
168            &request.command,
169            &request.args,
170        )?;
171        session.get_diagnostics(request)
172    }
173
174    pub fn search_workspace_symbols(
175        &self,
176        request: &LspWorkspaceSymbolRequest,
177    ) -> Result<Vec<LspWorkspaceSymbol>> {
178        let mut sessions = self.sessions.lock().unwrap_or_else(|p| p.into_inner());
179        let session = ensure_session(
180            &mut sessions,
181            &self.project,
182            &request.command,
183            &request.args,
184        )?;
185        session.search_workspace_symbols(request)
186    }
187
188    pub fn get_type_hierarchy(
189        &self,
190        request: &LspTypeHierarchyRequest,
191    ) -> Result<HashMap<String, Value>> {
192        let mut sessions = self.sessions.lock().unwrap_or_else(|p| p.into_inner());
193        let session = ensure_session(
194            &mut sessions,
195            &self.project,
196            &request.command,
197            &request.args,
198        )?;
199        session.get_type_hierarchy(request)
200    }
201
202    pub fn get_rename_plan(&self, request: &LspRenamePlanRequest) -> Result<LspRenamePlan> {
203        let mut sessions = self.sessions.lock().unwrap_or_else(|p| p.into_inner());
204        let session = ensure_session(
205            &mut sessions,
206            &self.project,
207            &request.command,
208            &request.args,
209        )?;
210        session.get_rename_plan(request)
211    }
212}
213
214impl LspSession {
215    fn start(project: &ProjectRoot, command: &str, args: &[String]) -> Result<Self> {
216        let mut child = Command::new(command)
217            .args(args)
218            .stdin(Stdio::piped())
219            .stdout(Stdio::piped())
220            .stderr(Stdio::piped())
221            .spawn()
222            .with_context(|| format!("failed to spawn LSP server {}", command))?;
223
224        let stdin = child.stdin.take().context("failed to open LSP stdin")?;
225        let stdout = child.stdout.take().context("failed to open LSP stdout")?;
226
227        // Capture stderr in a background thread (bounded 4KB ring buffer).
228        let stderr_buffer = std::sync::Arc::new(std::sync::Mutex::new(String::new()));
229        if let Some(stderr) = child.stderr.take() {
230            let buf = std::sync::Arc::clone(&stderr_buffer);
231            thread::spawn(move || {
232                let mut reader = BufReader::new(stderr);
233                let mut line = String::new();
234                while reader.read_line(&mut line).unwrap_or(0) > 0 {
235                    if let Ok(mut b) = buf.lock() {
236                        if b.len() > 4096 {
237                            let drain_to = b.len() - 2048;
238                            b.drain(..drain_to);
239                        }
240                        b.push_str(&line);
241                    }
242                    line.clear();
243                }
244            });
245        }
246
247        let mut session = Self {
248            project: project.clone(),
249            child,
250            stdin,
251            reader: BufReader::new(stdout),
252            next_request_id: 1,
253            documents: HashMap::new(),
254            stderr_buffer,
255        };
256        session.initialize()?;
257        Ok(session)
258    }
259
260    fn initialize(&mut self) -> Result<()> {
261        let id = self.next_id();
262        let root_uri = Url::from_directory_path(self.project.as_path())
263            .ok()
264            .map(|url| url.to_string());
265        self.send_request(
266            id,
267            "initialize",
268            json!({
269                "processId":null,
270                "rootUri": root_uri,
271                "capabilities":{},
272                "workspaceFolders":[
273                    {
274                        "uri": Url::from_directory_path(self.project.as_path()).ok().map(|url| url.to_string()),
275                        "name": self.project.as_path().file_name().and_then(|n| n.to_str()).unwrap_or("workspace")
276                    }
277                ]
278            }),
279        )?;
280        let _ = self.read_response_for_id(id)?;
281        self.send_notification("initialized", json!({}))?;
282        Ok(())
283    }
284
285    fn find_references(&mut self, request: &LspRequest) -> Result<Vec<LspReference>> {
286        let absolute_path = self.project.resolve(&request.file_path)?;
287        let (uri_string, _source) = self.prepare_document(&absolute_path)?;
288
289        let id = self.next_id();
290        self.send_request(
291            id,
292            "textDocument/references",
293            json!({
294                "textDocument":{"uri":uri_string},
295                "position":{"line":request.line.saturating_sub(1),"character":request.column.saturating_sub(1)},
296                "context":{"includeDeclaration":true}
297            }),
298        )?;
299        let response = self.read_response_for_id(id)?;
300        references_from_response(&self.project, response, request.max_results)
301    }
302
303    fn get_diagnostics(&mut self, request: &LspDiagnosticRequest) -> Result<Vec<LspDiagnostic>> {
304        let absolute_path = self.project.resolve(&request.file_path)?;
305        let (uri_string, _source) = self.prepare_document(&absolute_path)?;
306
307        let id = self.next_id();
308        self.send_request(
309            id,
310            "textDocument/diagnostic",
311            json!({
312                "textDocument":{"uri":uri_string}
313            }),
314        )?;
315        let response = self.read_response_for_id(id)?;
316        diagnostics_from_response(&self.project, response, request.max_results)
317    }
318
319    fn search_workspace_symbols(
320        &mut self,
321        request: &LspWorkspaceSymbolRequest,
322    ) -> Result<Vec<LspWorkspaceSymbol>> {
323        let id = self.next_id();
324        self.send_request(
325            id,
326            "workspace/symbol",
327            json!({
328                "query": request.query
329            }),
330        )?;
331        let response = self.read_response_for_id(id)?;
332        workspace_symbols_from_response(&self.project, response, request.max_results)
333    }
334
335    fn get_type_hierarchy(
336        &mut self,
337        request: &LspTypeHierarchyRequest,
338    ) -> Result<HashMap<String, Value>> {
339        let workspace_symbols = self.search_workspace_symbols(&LspWorkspaceSymbolRequest {
340            command: request.command.clone(),
341            args: request.args.clone(),
342            query: request.query.clone(),
343            max_results: 20,
344        })?;
345        let seed = workspace_symbols
346            .into_iter()
347            .find(|symbol| match &request.relative_path {
348                Some(path) => &symbol.file_path == path,
349                None => true,
350            })
351            .with_context(|| format!("No workspace symbol found for '{}'", request.query))?;
352
353        let absolute_path = self.project.resolve(&seed.file_path)?;
354        let (uri_string, _source) = self.prepare_document(&absolute_path)?;
355
356        let id = self.next_id();
357        self.send_request(
358            id,
359            "textDocument/prepareTypeHierarchy",
360            json!({
361                "textDocument":{"uri":uri_string},
362                "position":{"line":seed.line.saturating_sub(1),"character":seed.column.saturating_sub(1)}
363            }),
364        )?;
365        let response = self.read_response_for_id(id)?;
366        let items = response
367            .get("result")
368            .and_then(Value::as_array)
369            .cloned()
370            .unwrap_or_default();
371        let root_item = items
372            .into_iter()
373            .next()
374            .context("LSP prepareTypeHierarchy returned no items")?;
375
376        let root = self.build_type_hierarchy_node(
377            &root_item,
378            request.depth,
379            request.hierarchy_type.as_str(),
380        )?;
381        Ok(type_hierarchy_to_map(&root))
382    }
383
384    fn get_rename_plan(&mut self, request: &LspRenamePlanRequest) -> Result<LspRenamePlan> {
385        let absolute_path = self.project.resolve(&request.file_path)?;
386        let (uri_string, source) = self.prepare_document(&absolute_path)?;
387
388        let id = self.next_id();
389        self.send_request(
390            id,
391            "textDocument/prepareRename",
392            json!({
393                "textDocument":{"uri":uri_string},
394                "position":{"line":request.line.saturating_sub(1),"character":request.column.saturating_sub(1)}
395            }),
396        )?;
397        let response = self.read_response_for_id(id)?;
398        rename_plan_from_response(
399            &self.project,
400            &request.file_path,
401            &source,
402            response,
403            request.new_name.clone(),
404        )
405    }
406
407    fn build_type_hierarchy_node(
408        &mut self,
409        item: &Value,
410        depth: usize,
411        hierarchy_type: &str,
412    ) -> Result<LspTypeHierarchyNode> {
413        let mut node = type_hierarchy_node_from_item(item)?;
414
415        if depth == 0 {
416            return Ok(node);
417        }
418
419        let next_depth = depth.saturating_sub(1);
420        if hierarchy_type == "super" || hierarchy_type == "both" {
421            node.supertypes = self.fetch_type_hierarchy_branch(item, "supertypes", next_depth)?;
422        }
423        if hierarchy_type == "sub" || hierarchy_type == "both" {
424            node.subtypes = self.fetch_type_hierarchy_branch(item, "subtypes", next_depth)?;
425        }
426        Ok(node)
427    }
428
429    fn fetch_type_hierarchy_branch(
430        &mut self,
431        item: &Value,
432        method_suffix: &str,
433        depth: usize,
434    ) -> Result<Vec<LspTypeHierarchyNode>> {
435        let id = self.next_id();
436        self.send_request(
437            id,
438            &format!("typeHierarchy/{method_suffix}"),
439            json!({
440                "item": item
441            }),
442        )?;
443        let response = self.read_response_for_id(id)?;
444        let Some(items) = response.get("result").and_then(Value::as_array) else {
445            return Ok(Vec::new());
446        };
447
448        let mut nodes = Vec::new();
449        for child in items {
450            nodes.push(self.build_type_hierarchy_node(
451                child,
452                depth,
453                method_suffix_to_hierarchy(method_suffix),
454            )?);
455        }
456        Ok(nodes)
457    }
458
459    fn prepare_document(&mut self, absolute_path: &Path) -> Result<(String, String)> {
460        let uri = Url::from_file_path(absolute_path).map_err(|_| {
461            anyhow::anyhow!("failed to build file uri for {}", absolute_path.display())
462        })?;
463        let uri_string = uri.to_string();
464        let source = std::fs::read_to_string(absolute_path)
465            .with_context(|| format!("failed to read {}", absolute_path.display()))?;
466        let language_id = language_id_for_path(absolute_path)?;
467        self.sync_document(&uri_string, language_id, &source)?;
468        Ok((uri_string, source))
469    }
470
471    fn sync_document(&mut self, uri: &str, language_id: &str, source: &str) -> Result<()> {
472        if let Some(state) = self.documents.get(uri)
473            && state.text == source
474        {
475            return Ok(());
476        }
477
478        if let Some(state) = self.documents.get_mut(uri) {
479            state.version += 1;
480            state.text = source.to_owned();
481            let version = state.version;
482            return self.send_notification(
483                "textDocument/didChange",
484                json!({
485                    "textDocument":{"uri":uri,"version":version},
486                    "contentChanges":[{"text":source}]
487                }),
488            );
489        }
490
491        self.documents.insert(
492            uri.to_owned(),
493            OpenDocumentState {
494                version: 1,
495                text: source.to_owned(),
496            },
497        );
498        self.send_notification(
499            "textDocument/didOpen",
500            json!({
501                "textDocument":{
502                    "uri":uri,
503                    "languageId":language_id,
504                    "version":1,
505                    "text":source
506                }
507            }),
508        )
509    }
510
511    fn next_id(&mut self) -> u64 {
512        let id = self.next_request_id;
513        self.next_request_id += 1;
514        id
515    }
516
517    fn send_request(&mut self, id: u64, method: &str, params: Value) -> Result<()> {
518        send_message(
519            &mut self.stdin,
520            &json!({
521                "jsonrpc":"2.0",
522                "id":id,
523                "method":method,
524                "params":params
525            }),
526        )
527    }
528
529    fn send_notification(&mut self, method: &str, params: Value) -> Result<()> {
530        send_message(
531            &mut self.stdin,
532            &json!({
533                "jsonrpc":"2.0",
534                "method":method,
535                "params":params
536            }),
537        )
538    }
539
540    fn read_response_for_id(&mut self, expected_id: u64) -> Result<Value> {
541        let deadline = Instant::now() + Duration::from_secs(30);
542        let mut discarded = 0u32;
543        const MAX_DISCARDED: u32 = 500;
544
545        loop {
546            let remaining = deadline.saturating_duration_since(Instant::now());
547            if remaining.is_zero() {
548                bail!(
549                    "LSP response timeout: no response for request id {expected_id} within 30s \
550                     ({discarded} unrelated messages discarded)"
551                );
552            }
553            if discarded >= MAX_DISCARDED {
554                bail!(
555                    "LSP response loop: discarded {MAX_DISCARDED} messages without finding id {expected_id}"
556                );
557            }
558
559            // Poll the pipe before blocking read — prevents infinite hang
560            if !poll_readable(self.reader.get_ref(), remaining.min(Duration::from_secs(5))) {
561                continue; // no data yet, re-check deadline
562            }
563
564            let message = read_message(&mut self.reader)?;
565            let matches_id = message
566                .get("id")
567                .and_then(Value::as_u64)
568                .map(|id| id == expected_id)
569                .unwrap_or(false);
570            if matches_id {
571                if let Some(error) = message.get("error") {
572                    let code = error.get("code").and_then(Value::as_i64).unwrap_or(-1);
573                    let error_message = error
574                        .get("message")
575                        .and_then(Value::as_str)
576                        .unwrap_or("unknown LSP error");
577                    bail!("LSP request failed ({code}): {error_message}");
578                }
579                return Ok(message);
580            }
581            discarded += 1;
582        }
583    }
584
585    fn shutdown(&mut self) -> Result<()> {
586        let id = self.next_id();
587        self.send_request(id, "shutdown", Value::Null)?;
588        let _ = self.read_response_for_id(id)?;
589        self.send_notification("exit", Value::Null)
590    }
591}
592
593impl Drop for LspSession {
594    fn drop(&mut self) {
595        let _ = self.shutdown();
596        let deadline = Instant::now() + Duration::from_millis(250);
597        while Instant::now() < deadline {
598            match self.child.try_wait() {
599                Ok(Some(_status)) => return,
600                Ok(None) => thread::sleep(Duration::from_millis(10)),
601                Err(_) => break,
602            }
603        }
604        let _ = self.child.kill();
605        let _ = self.child.wait();
606    }
607}