// dissolve_python/concurrent_lsp.rs

//! Concurrent LSP client implementation using message passing.
//!
//! This module provides concurrent access to LSP servers such as
//! `pyright-langserver` using plain threads and channels — no async runtime needed.

use anyhow::{anyhow, Result};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::io::{BufRead, BufReader, Read, Write};
use std::process::{Child, Command, Stdio};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{mpsc, Arc, Mutex};
use std::thread;
use std::time::{Duration, Instant};
use tracing::{info, warn};
use url::Url;

17/// LSP request with response channel
18struct LspRequest {
19    id: u64,
20    method: String,
21    params: Value,
22    response_tx: mpsc::Sender<Result<Value>>,
23}
24
25/// LSP notification (no response expected)  
26struct LspNotification {
27    method: String,
28    params: Value,
29}
30
31enum LspMessage {
32    Request(LspRequest),
33    Notification(LspNotification),
34    Shutdown,
35}
36
37/// Sync concurrent Pyright LSP client - uses message passing for true parallelism
38/// Multiple threads can make requests simultaneously without blocking each other!
39pub struct SyncConcurrentPyrightClient {
40    message_tx: mpsc::Sender<LspMessage>,
41    request_id: Arc<AtomicU64>,
42    workspace_files: Arc<Mutex<HashMap<String, i32>>>, // file_path -> version
43    _workspace_root: Option<String>,
44    _process_handle: Arc<thread::JoinHandle<()>>,
45}
46
47impl SyncConcurrentPyrightClient {
48    /// Create and initialize a new sync concurrent pyright client
49    pub fn new(workspace_root: Option<&str>) -> Result<Self> {
50        info!("Starting sync concurrent pyright-langserver...");
51
52        // Start pyright-langserver process
53        let mut process = Command::new("pyright-langserver")
54            .arg("--stdio")
55            .stdin(Stdio::piped())
56            .stdout(Stdio::piped())
57            .stderr(Stdio::piped())
58            .spawn()
59            .map_err(|e| anyhow!("Failed to start pyright-langserver: {}", e))?;
60
61        let mut stdin = process
62            .stdin
63            .take()
64            .ok_or_else(|| anyhow!("Failed to get stdin"))?;
65        let stdout = process
66            .stdout
67            .take()
68            .ok_or_else(|| anyhow!("Failed to get stdout"))?;
69
70        let (message_tx, message_rx) = mpsc::channel();
71        let request_id = Arc::new(AtomicU64::new(1));
72
73        // Shared pending requests map for coordinating responses
74        let pending_requests_arc = Arc::new(Mutex::new(
75            HashMap::<u64, mpsc::Sender<Result<Value>>>::new(),
76        ));
77        let pending_clone = pending_requests_arc.clone();
78
79        // Thread to read server responses and forward to pending requests
80        let _response_reader = thread::spawn(move || {
81            let mut stdout_reader = BufReader::new(stdout);
82            while let Ok(response) = Self::read_lsp_message(&mut stdout_reader) {
83                if let Some(id) = response.get("id").and_then(|v| v.as_u64()) {
84                    let mut pending = pending_clone.lock().unwrap();
85                    if let Some(tx) = pending.remove(&id) {
86                        let result = if let Some(error) = response.get("error") {
87                            Err(anyhow!("LSP error: {}", error))
88                        } else {
89                            Ok(response.get("result").cloned().unwrap_or(Value::Null))
90                        };
91                        let _ = tx.send(result);
92                    }
93                }
94            }
95        });
96
97        // Main message processing thread
98        let process_handle = thread::spawn(move || {
99            // Process client messages and send to LSP server
100            while let Ok(msg) = message_rx.recv() {
101                match msg {
102                    LspMessage::Request(req) => {
103                        {
104                            let mut pending = pending_requests_arc.lock().unwrap();
105                            pending.insert(req.id, req.response_tx);
106                        }
107
108                        let lsp_request = json!({
109                            "jsonrpc": "2.0",
110                            "id": req.id,
111                            "method": req.method,
112                            "params": req.params
113                        });
114
115                        let request_str = format!(
116                            "Content-Length: {}\r\n\r\n{}",
117                            lsp_request.to_string().len(),
118                            lsp_request
119                        );
120
121                        if let Err(e) = stdin.write_all(request_str.as_bytes()) {
122                            warn!("Failed to write LSP request: {}", e);
123                            let mut pending = pending_requests_arc.lock().unwrap();
124                            if let Some(tx) = pending.remove(&req.id) {
125                                let _ = tx.send(Err(anyhow!("Failed to send request: {}", e)));
126                            }
127                        } else if let Err(e) = stdin.flush() {
128                            warn!("Failed to flush LSP request: {}", e);
129                            let mut pending = pending_requests_arc.lock().unwrap();
130                            if let Some(tx) = pending.remove(&req.id) {
131                                let _ = tx.send(Err(anyhow!("Failed to flush request: {}", e)));
132                            }
133                        }
134                    }
135                    LspMessage::Notification(notif) => {
136                        let lsp_notif = json!({
137                            "jsonrpc": "2.0",
138                            "method": notif.method,
139                            "params": notif.params
140                        });
141
142                        let notif_str = format!(
143                            "Content-Length: {}\r\n\r\n{}",
144                            lsp_notif.to_string().len(),
145                            lsp_notif
146                        );
147
148                        if let Err(e) = stdin.write_all(notif_str.as_bytes()) {
149                            warn!("Failed to write LSP notification: {}", e);
150                        } else {
151                            let _ = stdin.flush();
152                        }
153                    }
154                    LspMessage::Shutdown => break,
155                }
156            }
157        });
158
159        // Initialize LSP server
160        let client = Self {
161            message_tx,
162            request_id,
163            workspace_files: Arc::new(Mutex::new(HashMap::new())),
164            _workspace_root: workspace_root.map(|s| s.to_string()),
165            _process_handle: Arc::new(process_handle),
166        };
167
168        client.initialize(workspace_root)?;
169        Ok(client)
170    }
171
172    fn read_lsp_message(reader: &mut BufReader<std::process::ChildStdout>) -> Result<Value> {
173        // Read headers
174        let mut headers = Vec::new();
175        loop {
176            let mut line = String::new();
177            reader.read_line(&mut line)?;
178            let line = line.trim();
179            if line.is_empty() {
180                break; // End of headers
181            }
182            headers.push(line.to_string());
183        }
184
185        // Parse Content-Length header
186        let content_length = headers
187            .iter()
188            .find(|h| h.starts_with("Content-Length:"))
189            .and_then(|h| h.split(':').nth(1))
190            .and_then(|v| v.trim().parse::<usize>().ok())
191            .ok_or_else(|| anyhow!("Missing or invalid Content-Length header"))?;
192
193        // Read content
194        let mut content = vec![0u8; content_length];
195        reader.read_exact(&mut content)?;
196
197        // Parse JSON
198        let content_str = String::from_utf8(content)?;
199        serde_json::from_str(&content_str).map_err(|e| anyhow!("Failed to parse JSON: {}", e))
200    }
201
202    fn initialize(&self, workspace_root: Option<&str>) -> Result<()> {
203        let root_uri = workspace_root
204            .and_then(|root| Url::from_file_path(root).ok())
205            .map(|url| url.to_string());
206
207        let params = json!({
208            "processId": std::process::id(),
209            "rootUri": root_uri,
210            "capabilities": {
211                "textDocument": {
212                    "hover": {
213                        "contentFormat": ["plaintext", "markdown"]
214                    }
215                }
216            }
217        });
218
219        let (response_tx, response_rx) = mpsc::channel();
220        let req_id = self.request_id.fetch_add(1, Ordering::SeqCst);
221
222        let request = LspRequest {
223            id: req_id,
224            method: "initialize".to_string(),
225            params,
226            response_tx,
227        };
228
229        self.message_tx
230            .send(LspMessage::Request(request))
231            .map_err(|e| anyhow!("Failed to send initialize request: {}", e))?;
232
233        // Wait for response with timeout
234        match response_rx.recv_timeout(std::time::Duration::from_secs(30)) {
235            Ok(Ok(_)) => {
236                // Send initialized notification
237                let notif = LspNotification {
238                    method: "initialized".to_string(),
239                    params: json!({}),
240                };
241
242                self.message_tx
243                    .send(LspMessage::Notification(notif))
244                    .map_err(|e| anyhow!("Failed to send initialized notification: {}", e))?;
245
246                info!("Sync concurrent Pyright LSP client initialized successfully");
247                Ok(())
248            }
249            Ok(Err(e)) => Err(anyhow!("Initialize failed: {}", e)),
250            Err(_) => Err(anyhow!("Initialize timed out")),
251        }
252    }
253
254    /// Query type - truly concurrent across multiple threads!
255    /// This is the key method that enables parallel test execution
256    pub fn query_type_concurrent(
257        &self,
258        file_path: &str,
259        content: &str,
260        line: u32,
261        column: u32,
262    ) -> Result<Option<String>> {
263        // Convert to absolute path if relative
264        let abs_path = if std::path::Path::new(file_path).is_relative() {
265            std::env::current_dir()?.join(file_path)
266        } else {
267            std::path::PathBuf::from(file_path)
268        };
269
270        // Open document first if needed
271        self.open_document(&abs_path.to_string_lossy(), content)?;
272
273        let uri = format!("file://{}", abs_path.display());
274        let params = json!({
275            "textDocument": {
276                "uri": uri
277            },
278            "position": {
279                "line": line - 1, // Convert to 0-based line numbering for LSP
280                "character": column
281            }
282        });
283
284        let (response_tx, response_rx) = mpsc::channel();
285        let req_id = self.request_id.fetch_add(1, Ordering::SeqCst);
286
287        let request = LspRequest {
288            id: req_id,
289            method: "textDocument/hover".to_string(),
290            params,
291            response_tx,
292        };
293
294        self.message_tx
295            .send(LspMessage::Request(request))
296            .map_err(|e| anyhow!("Failed to send hover request: {}", e))?;
297
298        // Wait for response with timeout - return proper errors, not None!
299        match response_rx.recv_timeout(std::time::Duration::from_secs(5)) {
300            Ok(Ok(response)) => {
301                // Parse hover response - match pyright's actual response format
302                if let Some(hover) = response.as_object() {
303                    if let Some(contents) = hover.get("contents") {
304                        let type_info = match contents {
305                            Value::String(s) => s.clone(),
306                            Value::Object(obj) => {
307                                if let Some(Value::String(s)) = obj.get("value") {
308                                    s.clone()
309                                } else {
310                                    return Ok(None);
311                                }
312                            }
313                            _ => return Ok(None),
314                        };
315
316                        // Parse pyright's hover format like the original client
317                        tracing::debug!("Pyright hover response: {}", type_info);
318
319                        // Check for module format first
320                        if type_info.starts_with("(module) ") {
321                            let module_start = "(module) ".len();
322                            let module_end = type_info[module_start..]
323                                .find('\n')
324                                .map(|pos| module_start + pos)
325                                .unwrap_or(type_info.len());
326                            let module_name = type_info[module_start..module_end].trim();
327                            return Ok(Some(module_name.to_string()));
328                        }
329
330                        // Check for class format
331                        if type_info.starts_with("(class) ") {
332                            let class_start = "(class) ".len();
333                            let class_end = type_info[class_start..]
334                                .find('\n')
335                                .map(|pos| class_start + pos)
336                                .unwrap_or(type_info.len());
337                            let class_name = type_info[class_start..class_end].trim();
338                            return Ok(Some(class_name.to_string()));
339                        }
340
341                        // Look for colon format for variables
342                        if let Some(colon_pos) = type_info.find(':') {
343                            let type_part = type_info[colon_pos + 1..].trim();
344
345                            // Check if pyright returned "Unknown" - treat as no type info
346                            if type_part == "Unknown" {
347                                tracing::warn!(
348                                    "Pyright returned 'Unknown' type at {}:{}:{}",
349                                    file_path,
350                                    line,
351                                    column
352                                );
353                                return Ok(None);
354                            }
355
356                            return Ok(Some(type_part.to_string()));
357                        }
358                    }
359                }
360                Ok(None)
361            }
362            Ok(Err(e)) => Err(anyhow!("Hover request failed: {}", e)),
363            Err(_) => Err(anyhow!(
364                "Hover request timed out after 5 seconds for {}:{}",
365                line,
366                column
367            )),
368        }
369    }
370
371    fn open_document(&self, file_path: &str, content: &str) -> Result<()> {
372        let mut files = self.workspace_files.lock().unwrap();
373        if !files.contains_key(file_path) {
374            files.insert(file_path.to_string(), 1);
375            drop(files);
376
377            let params = json!({
378                "textDocument": {
379                    "uri": format!("file://{}", file_path),
380                    "languageId": "python",
381                    "version": 1,
382                    "text": content
383                }
384            });
385
386            let notif = LspNotification {
387                method: "textDocument/didOpen".to_string(),
388                params,
389            };
390
391            self.message_tx
392                .send(LspMessage::Notification(notif))
393                .map_err(|e| anyhow!("Failed to send didOpen notification: {}", e))?;
394
395            // Wait for pyright to be ready by polling for analysis completion
396            self.wait_for_analysis_ready(file_path)?;
397        }
398        Ok(())
399    }
400
401    /// Wait for pyright to complete analysis by polling for diagnostics or other readiness indicators
402    fn wait_for_analysis_ready(&self, file_path: &str) -> Result<()> {
403        const MAX_WAIT_MS: u64 = 2000; // Maximum 2 seconds
404        const POLL_INTERVAL_MS: u64 = 50; // Check every 50ms
405
406        let start = std::time::Instant::now();
407
408        // Strategy: Try a simple hover request on line 1 and see if we get any response
409        // Once pyright starts responding to hover requests, it's likely ready
410        while start.elapsed().as_millis() < MAX_WAIT_MS as u128 {
411            // Send a simple hover request to line 1, column 1 to test readiness
412            let params = json!({
413                "textDocument": {
414                    "uri": format!("file://{}", file_path)
415                },
416                "position": {
417                    "line": 0,
418                    "character": 0
419                }
420            });
421
422            let (response_tx, response_rx) = mpsc::channel();
423            let req_id = self.request_id.fetch_add(1, Ordering::SeqCst);
424
425            let request = LspRequest {
426                id: req_id,
427                method: "textDocument/hover".to_string(),
428                params,
429                response_tx,
430            };
431
432            if self.message_tx.send(LspMessage::Request(request)).is_ok() {
433                // Wait for response with a short timeout
434                match response_rx.recv_timeout(std::time::Duration::from_millis(200)) {
435                    Ok(Ok(_)) => {
436                        // Got a response (even if it's null) - pyright is ready
437                        tracing::debug!("Pyright analysis ready for {}", file_path);
438                        return Ok(());
439                    }
440                    Ok(Err(_)) => {
441                        // Got an error response - pyright is responding, so it's ready
442                        tracing::debug!(
443                            "Pyright analysis ready for {} (error response)",
444                            file_path
445                        );
446                        return Ok(());
447                    }
448                    Err(_) => {
449                        // Timeout - pyright might still be analyzing, continue polling
450                    }
451                }
452            }
453
454            std::thread::sleep(std::time::Duration::from_millis(POLL_INTERVAL_MS));
455        }
456
457        // If we've waited too long, log a warning but continue
458        tracing::warn!(
459            "Timeout waiting for pyright analysis of {}, continuing anyway",
460            file_path
461        );
462        Ok(())
463    }
464
465    /// Shutdown the client gracefully  
466    pub fn shutdown(&self) -> Result<()> {
467        tracing::debug!("Shutting down sync concurrent pyright client");
468
469        // Send shutdown message using the message channel
470        if let Err(e) = self.message_tx.send(LspMessage::Shutdown) {
471            tracing::warn!("Failed to send shutdown message: {}", e);
472        }
473
474        // Give the process a moment to shutdown gracefully
475        std::thread::sleep(std::time::Duration::from_millis(100));
476
477        Ok(())
478    }
479}