Skip to main content

dissolve_python/
pyright_lsp.rs

1// Copyright (C) 2024 Jelmer Vernooij <jelmer@samba.org>
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//    http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Pyright LSP integration for type inference
16//!
//! This module provides type-querying capabilities using the pyright language server.
18
19use anyhow::{anyhow, Result};
20
/// Trait for pyright-like LSP clients that can be used in TypeIntrospectionContext
///
/// The four methods mirror the public methods of the same names on
/// `PyrightLspClient`, which implements this trait by forwarding to them.
pub trait PyrightLspClientTrait {
    /// Tell the client that `file_path` has the given `content`.
    fn open_file(&mut self, file_path: &str, content: &str) -> Result<()>;
    /// Tell the client that `file_path` now has `content`, tagged with `version`.
    fn update_file(&mut self, file_path: &str, content: &str, version: i32) -> Result<()>;
    /// Ask for the type at `line`/`column` of `file_path`.
    ///
    /// `Ok(None)` means no type could be determined. In the `PyrightLspClient`
    /// implementation `line` is 1-based (it is decremented before being sent)
    /// and `column` is passed through unchanged.
    fn query_type(
        &mut self,
        file_path: &str,
        content: &str,
        line: u32,
        column: u32,
    ) -> Result<Option<String>>;
    /// Shut the client down.
    fn shutdown(&mut self) -> Result<()>;
}
34use serde::{Deserialize, Serialize};
35use serde_json::{json, Value};
36use std::io::{BufRead, BufReader, Read, Write};
37use std::process::{Child, Command, Stdio};
38use std::sync::atomic::{AtomicU64, Ordering};
39use std::sync::{Arc, Mutex};
40use std::time::Duration;
41
/// LSP request message
///
/// Serialized to JSON and written to the server's stdin by `send_message`.
#[derive(Debug, Serialize)]
struct LspRequest {
    /// Protocol version tag; every construction site in this file uses `"2.0"`.
    jsonrpc: &'static str,
    /// Request id; the readers in this file match it against the `id` of incoming responses.
    id: u64,
    /// Method name, e.g. `"initialize"` or `"textDocument/hover"`.
    method: String,
    /// Method-specific parameters as an arbitrary JSON value.
    params: Value,
}
50
/// LSP notification message
///
/// Same shape as `LspRequest` but without an `id`; `send_notification` does
/// not wait for any reply after sending one.
#[derive(Debug, Serialize)]
struct LspNotification {
    /// Protocol version tag; `send_notification` always uses `"2.0"`.
    jsonrpc: &'static str,
    /// Method name, e.g. `"initialized"` or `"textDocument/didOpen"`.
    method: String,
    /// Method-specific parameters as an arbitrary JSON value.
    params: Value,
}
58
/// LSP response message
///
/// Every message read from the server is deserialized into this shape.
#[derive(Debug, Deserialize)]
struct LspResponse {
    /// Protocol version tag; deserialized but never inspected.
    #[allow(dead_code)]
    jsonrpc: String,
    /// Id of the request being answered. Messages whose `id` is `None` are
    /// treated as notifications and skipped by the readers in this file.
    id: Option<u64>,
    /// Successful result payload, if present.
    result: Option<Value>,
    /// Error payload, if present; readers turn it into an `Err`.
    error: Option<LspError>,
}
68
/// LSP error
///
/// Error object carried by an `LspResponse`. Only `message` is read by this
/// file (it is embedded in the `"LSP error: ..."` error text).
#[derive(Debug, Deserialize)]
struct LspError {
    /// Numeric error code; deserialized but not used.
    #[allow(dead_code)]
    code: i32,
    /// Human-readable error description.
    message: String,
    /// Optional extra error data; deserialized but not used.
    #[allow(dead_code)]
    data: Option<Value>,
}
78
/// Position in a text document
///
/// Callers in this file fill `line` with a 0-based line number (they subtract
/// one from their 1-based input) and pass the column through as `character`.
#[derive(Debug, Serialize)]
struct Position {
    /// Line number as sent to the server.
    line: u32,
    /// Character offset within the line as sent to the server.
    character: u32,
}
85
/// Text document identifier
///
/// Identifies a document by URI inside request parameters.
#[derive(Debug, Serialize)]
struct TextDocumentIdentifier {
    /// Document URI; callers in this file build it as `file://` + absolute path.
    uri: String,
}
91
92/// Text document item
93#[derive(Debug, Serialize)]
94#[allow(dead_code)]
95struct TextDocumentItem {
96    uri: String,
97    #[serde(rename = "languageId")]
98    language_id: String,
99    version: i32,
100    text: String,
101}
102
103/// Hover params
104#[derive(Debug, Serialize)]
105struct HoverParams {
106    #[serde(rename = "textDocument")]
107    text_document: TextDocumentIdentifier,
108    position: Position,
109}
110
111/// Type definition params (same structure as hover params)
112#[derive(Debug, Serialize)]
113struct TypeDefinitionParams {
114    #[serde(rename = "textDocument")]
115    text_document: TextDocumentIdentifier,
116    position: Position,
117}
118
/// Pyright LSP client
///
/// Owns a spawned pyright child process and exchanges `Content-Length`-framed
/// JSON messages with it over the child's stdin and stdout.
pub struct PyrightLspClient {
    /// The pyright child process; its stdin is used for all outgoing messages.
    process: Arc<Mutex<Child>>,
    /// Counter that supplies the `id` of each outgoing request (starts at 0).
    request_id: AtomicU64,
    /// Buffered reader over the child's stdout, from which responses are read.
    reader: Arc<Mutex<BufReader<std::process::ChildStdout>>>,
    /// Set when `shutdown` starts, so that later calls return immediately.
    is_shutdown: Arc<Mutex<bool>>,
}
126
127impl PyrightLspClient {
128    /// Create and start a new pyright LSP client
129    pub fn new(workspace_root: Option<&str>) -> Result<Self> {
130        tracing::debug!("Starting PyrightLspClient::new()");
131        // Try to find pyright executable
132        let pyright_cmd = if Command::new("pyright-langserver")
133            .arg("--version")
134            .output()
135            .is_ok()
136        {
137            "pyright-langserver"
138        } else if Command::new("pyright").arg("--version").output().is_ok() {
139            // Some installations use 'pyright' directly
140            "pyright"
141        } else {
142            return Err(anyhow!(
143                "pyright not found. Please install pyright: pip install pyright"
144            ));
145        };
146
147        // Start pyright in LSP mode
148        tracing::debug!("Starting pyright process with command: {}", pyright_cmd);
149        let mut process = Command::new(pyright_cmd)
150            .args(["--stdio"])
151            .stdin(Stdio::piped())
152            .stdout(Stdio::piped())
153            .stderr(Stdio::null())
154            .spawn()
155            .map_err(|e| anyhow!("Failed to start pyright: {}", e))?;
156
157        let stdout = process.stdout.take().ok_or_else(|| anyhow!("No stdout"))?;
158        let reader = BufReader::new(stdout);
159
160        let mut client = Self {
161            process: Arc::new(Mutex::new(process)),
162            request_id: AtomicU64::new(0),
163            reader: Arc::new(Mutex::new(reader)),
164            is_shutdown: Arc::new(Mutex::new(false)),
165        };
166
167        // Initialize the LSP connection
168        client.initialize(workspace_root)?;
169
170        Ok(client)
171    }
172
173    /// Initialize the LSP connection
174    fn initialize(&mut self, workspace_root: Option<&str>) -> Result<()> {
175        // Use provided workspace root or fall back to current directory
176        let workspace_root = if let Some(root) = workspace_root {
177            std::path::Path::new(root).to_path_buf()
178        } else {
179            std::env::current_dir()?
180        };
181        let workspace_uri = format!("file://{}", workspace_root.display());
182
183        tracing::debug!(
184            "Initializing pyright with workspace: {}",
185            workspace_root.display()
186        );
187
188        let init_params = json!({
189            "processId": std::process::id(),
190            "clientInfo": {
191                "name": "dissolve",
192                "version": "0.1.0"
193            },
194            "locale": "en",
195            "rootPath": workspace_root.to_str(),
196            "rootUri": workspace_uri,
197            "capabilities": {
198                "textDocument": {
199                    "hover": {
200                        "contentFormat": ["plaintext", "markdown"]
201                    },
202                    "typeDefinition": {
203                        "dynamicRegistration": false
204                    }
205                }
206            },
207            "trace": "off",
208            "workspaceFolders": [{
209                "uri": workspace_uri,
210                "name": "test_workspace"
211            }],
212            "initializationOptions": {
213                "autoSearchPaths": true,
214                "useLibraryCodeForTypes": true,
215                "typeCheckingMode": "basic",
216                "python": {
217                    "analysis": {
218                        "extraPaths": []
219                    }
220                }
221            }
222        });
223
224        // Use timeout for initialization
225        let _response =
226            self.send_request_with_timeout("initialize", init_params, Duration::from_secs(10))?;
227
228        // Send initialized notification
229        self.send_notification("initialized", json!({}))?;
230
231        Ok(())
232    }
233
234    /// Send a request to the language server
235    fn send_request(&mut self, method: &str, params: Value) -> Result<Value> {
236        // Use timeout for all requests, not just initialization
237        self.send_request_with_timeout(method, params, Duration::from_secs(5))
238    }
239
240    /// Send a request to the language server with timeout
241    fn send_request_with_timeout(
242        &mut self,
243        method: &str,
244        params: Value,
245        timeout: Duration,
246    ) -> Result<Value> {
247        let id = self.request_id.fetch_add(1, Ordering::SeqCst);
248        let request = LspRequest {
249            jsonrpc: "2.0",
250            id,
251            method: method.to_string(),
252            params,
253        };
254
255        self.send_message(&request)?;
256
257        // Read response with timeout
258        self.read_response_with_timeout(id, timeout)
259    }
260
261    /// Send a notification to the language server
262    fn send_notification(&mut self, method: &str, params: Value) -> Result<()> {
263        let notification = LspNotification {
264            jsonrpc: "2.0",
265            method: method.to_string(),
266            params,
267        };
268
269        self.send_message(&notification)
270    }
271
272    /// Send a message to the language server
273    fn send_message<T: Serialize>(&mut self, message: &T) -> Result<()> {
274        let content = serde_json::to_string(message)?;
275        let header = format!("Content-Length: {}\r\n\r\n", content.len());
276
277        let mut process = self.process.lock().unwrap();
278        let stdin = process.stdin.as_mut().ok_or_else(|| anyhow!("No stdin"))?;
279        stdin.write_all(header.as_bytes())?;
280        stdin.write_all(content.as_bytes())?;
281        stdin.flush()?;
282
283        Ok(())
284    }
285
286    /// Read a response from the language server
287    #[allow(dead_code)]
288    fn read_response(&self, expected_id: u64) -> Result<Value> {
289        let mut reader = self.reader.lock().unwrap();
290
291        loop {
292            // Read headers
293            let mut headers = Vec::new();
294            loop {
295                let mut line = String::new();
296                reader.read_line(&mut line)?;
297                if line == "\r\n" || line == "\n" {
298                    break;
299                }
300                headers.push(line);
301            }
302
303            // Parse Content-Length header
304            let content_length = headers
305                .iter()
306                .find(|h| h.starts_with("Content-Length:"))
307                .and_then(|h| h.split(':').nth(1))
308                .and_then(|v| v.trim().parse::<usize>().ok())
309                .ok_or_else(|| anyhow!("Missing or invalid Content-Length header"))?;
310
311            // Read content
312            let mut content = vec![0u8; content_length];
313            reader.read_exact(&mut content)?;
314
315            // Parse JSON
316            let response: LspResponse = serde_json::from_slice(&content)?;
317
318            // Skip notifications
319            if response.id.is_none() {
320                continue;
321            }
322
323            // Check if this is our response
324            if response.id == Some(expected_id) {
325                if let Some(error) = response.error {
326                    return Err(anyhow!("LSP error: {}", error.message));
327                }
328                return response
329                    .result
330                    .ok_or_else(|| anyhow!("No result in response"));
331            }
332        }
333    }
334
335    /// Read a response from the language server with timeout
336    fn read_response_with_timeout(&self, expected_id: u64, timeout: Duration) -> Result<Value> {
337        use std::time::Instant;
338        let start = Instant::now();
339
340        let mut reader = self.reader.lock().unwrap();
341
342        // Poll for response with timeout
343        while start.elapsed() < timeout {
344            // Try to read with a small timeout to avoid blocking indefinitely
345            std::thread::sleep(Duration::from_millis(10));
346
347            // Check if the process is still alive
348            {
349                let mut process = self.process.lock().unwrap();
350                match process.try_wait() {
351                    Ok(Some(_)) => return Err(anyhow!("Pyright process has exited")),
352                    Ok(None) => {} // Still running
353                    Err(e) => return Err(anyhow!("Failed to check process status: {}", e)),
354                }
355            }
356
357            // Try to read response
358            loop {
359                // Read headers
360                let mut headers = Vec::new();
361                loop {
362                    let mut line = String::new();
363                    match reader.read_line(&mut line) {
364                        Ok(0) => return Err(anyhow!("Connection closed")),
365                        Ok(_) => {
366                            if line == "\r\n" || line == "\n" {
367                                break;
368                            }
369                            headers.push(line);
370                        }
371                        Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
372                            // No data available yet, continue outer loop
373                            break;
374                        }
375                        Err(e) => return Err(anyhow!("Failed to read line: {}", e)),
376                    }
377                }
378
379                if headers.is_empty() {
380                    break; // No data available, continue with timeout loop
381                }
382
383                // Parse Content-Length header
384                let content_length = headers
385                    .iter()
386                    .find(|h| h.starts_with("Content-Length:"))
387                    .and_then(|h| h.split(':').nth(1))
388                    .and_then(|v| v.trim().parse::<usize>().ok())
389                    .ok_or_else(|| anyhow!("Missing or invalid Content-Length header"))?;
390
391                // Read content
392                let mut content = vec![0u8; content_length];
393                reader.read_exact(&mut content)?;
394
395                // Parse JSON
396                let response: LspResponse = serde_json::from_slice(&content)?;
397
398                // Skip notifications
399                if response.id.is_none() {
400                    continue;
401                }
402
403                // Check if this is our response
404                if response.id == Some(expected_id) {
405                    if let Some(error) = response.error {
406                        return Err(anyhow!("LSP error: {}", error.message));
407                    }
408                    return response
409                        .result
410                        .ok_or_else(|| anyhow!("No result in response"));
411                }
412            }
413        }
414
415        Err(anyhow!(
416            "Timeout waiting for LSP response ({}s)",
417            timeout.as_secs()
418        ))
419    }
420
421    /// Open a file in the language server
422    pub fn open_file(&mut self, file_path: &str, content: &str) -> Result<()> {
423        // Convert to absolute path if relative
424        let abs_path = if std::path::Path::new(file_path).is_relative() {
425            std::env::current_dir()?.join(file_path)
426        } else {
427            std::path::PathBuf::from(file_path)
428        };
429        let uri = format!("file://{}", abs_path.display());
430        let params = json!({
431            "textDocument": {
432                "uri": uri,
433                "languageId": "python",
434                "version": 1,
435                "text": content
436            }
437        });
438
439        self.send_notification("textDocument/didOpen", params)?;
440
441        // Give pyright time to analyze the file
442        std::thread::sleep(Duration::from_millis(100));
443
444        Ok(())
445    }
446
447    /// Update file content in the language server
448    pub fn update_file(&mut self, file_path: &str, content: &str, version: i32) -> Result<()> {
449        tracing::debug!(
450            "Updating file in pyright LSP: {} (version {})",
451            file_path,
452            version
453        );
454
455        // Convert to absolute path if relative
456        let abs_path = if std::path::Path::new(file_path).is_relative() {
457            std::env::current_dir()?.join(file_path)
458        } else {
459            std::path::PathBuf::from(file_path)
460        };
461        let uri = format!("file://{}", abs_path.display());
462        let params = json!({
463            "textDocument": {
464                "uri": uri,
465                "version": version
466            },
467            "contentChanges": [{
468                "text": content
469            }]
470        });
471
472        self.send_notification("textDocument/didChange", params)?;
473
474        // Give pyright time to analyze the changes
475        std::thread::sleep(Duration::from_millis(100));
476
477        Ok(())
478    }
479
480    /// Get hover information (type) at a specific position
481    pub fn get_hover(&mut self, file_path: &str, line: u32, column: u32) -> Result<Option<String>> {
482        // Convert to absolute path if relative
483        let abs_path = if std::path::Path::new(file_path).is_relative() {
484            std::env::current_dir()?.join(file_path)
485        } else {
486            std::path::PathBuf::from(file_path)
487        };
488        let uri = format!("file://{}", abs_path.display());
489        let params = HoverParams {
490            text_document: TextDocumentIdentifier { uri },
491            position: Position {
492                line: line - 1, // Convert to 0-based
493                character: column,
494            },
495        };
496
497        let response = self.send_request("textDocument/hover", serde_json::to_value(params)?)?;
498
499        // Extract type information from hover response
500        if let Some(hover) = response.as_object() {
501            if let Some(contents) = hover.get("contents") {
502                let type_info = match contents {
503                    Value::String(s) => s.clone(),
504                    Value::Object(obj) => {
505                        if let Some(Value::String(s)) = obj.get("value") {
506                            s.clone()
507                        } else {
508                            return Ok(None);
509                        }
510                    }
511                    _ => return Ok(None),
512                };
513
514                // Parse pyright's hover format
515                // Examples:
516                //   "(variable) repo: Repo"
517                //   "(module) porcelain\n..."
518                // eprintln!("DEBUG HOVER: Raw hover response: '{}'", type_info);
519                tracing::debug!("Pyright hover response: {}", type_info);
520
521                // Check for module format first
522                if type_info.starts_with("(module) ") {
523                    // Extract module name - it's between "(module) " and the first newline or end of string
524                    let module_start = "(module) ".len();
525                    let module_end = type_info[module_start..]
526                        .find('\n')
527                        .map(|pos| module_start + pos)
528                        .unwrap_or(type_info.len());
529                    let module_name = type_info[module_start..module_end].trim();
530                    tracing::debug!("Extracted module type: {}", module_name);
531                    return Ok(Some(module_name.to_string()));
532                }
533
534                // Check for class format
535                if type_info.starts_with("(class) ") {
536                    // Extract class name - it's between "(class) " and the first newline or end of string
537                    let class_start = "(class) ".len();
538                    let class_end = type_info[class_start..]
539                        .find('\n')
540                        .map(|pos| class_start + pos)
541                        .unwrap_or(type_info.len());
542                    let class_name = type_info[class_start..class_end].trim();
543                    tracing::debug!("Extracted class type: {}", class_name);
544                    return Ok(Some(class_name.to_string()));
545                }
546
547                // Otherwise look for colon format for variables
548                if let Some(colon_pos) = type_info.find(':') {
549                    let type_part = type_info[colon_pos + 1..].trim();
550                    tracing::debug!("Extracted type: {}", type_part);
551
552                    // Check if pyright returned "Unknown" - treat as no type info
553                    if type_part == "Unknown" {
554                        tracing::warn!(
555                            "Pyright returned 'Unknown' type at {}:{}:{}",
556                            file_path,
557                            line,
558                            column
559                        );
560                        return Ok(None);
561                    }
562
563                    return Ok(Some(type_part.to_string()));
564                }
565            }
566        }
567
568        Ok(None)
569    }
570
571    /// Get type definition location
572    pub fn get_type_definition(
573        &mut self,
574        file_path: &str,
575        line: u32,
576        column: u32,
577    ) -> Result<Option<String>> {
578        // Convert to absolute path if relative
579        let abs_path = if std::path::Path::new(file_path).is_relative() {
580            std::env::current_dir()?.join(file_path)
581        } else {
582            std::path::PathBuf::from(file_path)
583        };
584        let uri = format!("file://{}", abs_path.display());
585        let params = TypeDefinitionParams {
586            text_document: TextDocumentIdentifier { uri: uri.clone() },
587            position: Position {
588                line: line - 1, // Convert to 0-based
589                character: column,
590            },
591        };
592
593        let response =
594            self.send_request("textDocument/typeDefinition", serde_json::to_value(params)?)?;
595
596        // Parse the response to get the location
597        if let Some(locations) = response.as_array() {
598            if let Some(first_location) = locations.first() {
599                if let Some(target_uri) = first_location.get("uri").and_then(|u| u.as_str()) {
600                    // The URI contains the file path which might have the module information
601                    if let Some(target_range) = first_location.get("range") {
602                        // We have the location of the type definition
603                        // Now we need to read that location to get the type name
604                        tracing::debug!(
605                            "Type definition location: {} at {:?}",
606                            target_uri,
607                            target_range
608                        );
609
610                        // For now, just extract the filename which might give us module info
611                        if let Some(path) = target_uri.strip_prefix("file://") {
612                            if let Some(module_name) = path
613                                .strip_suffix(".py")
614                                .and_then(|p| p.split('/').next_back())
615                            {
616                                // This is a simple heuristic - the file name is often the module name
617                                return Ok(Some(module_name.to_string()));
618                            }
619                        }
620                    }
621                }
622            }
623        }
624
625        Ok(None)
626    }
627
628    /// Query type at a specific location
629    pub fn query_type(
630        &mut self,
631        file_path: &str,
632        _content: &str,
633        line: u32,
634        column: u32,
635    ) -> Result<Option<String>> {
636        // Note: we assume the file is already open to avoid redundant open calls
637
638        // First try hover for immediate type info
639        let hover_result = self.get_hover(file_path, line, column);
640
641        // Debug output
642        match &hover_result {
643            Ok(Some(type_str)) => {
644                // eprintln!("DEBUG PYRIGHT: Querying {}:{} returned type: {}", line, column, type_str);
645                tracing::debug!("Pyright hover returned type: {}", type_str);
646
647                // If we get a simple type name, try to get more info from type definition
648                if !type_str.contains('.') {
649                    if let Ok(Some(type_def_info)) =
650                        self.get_type_definition(file_path, line, column)
651                    {
652                        tracing::debug!("Type definition info: {}", type_def_info);
653                        // For now, still return the hover result
654                        // In the future we could combine this info
655                    }
656                }
657
658                return Ok(Some(type_str.clone()));
659            }
660            Ok(None) => {
661                tracing::debug!("Pyright returned no type information");
662            }
663            Err(e) => {
664                tracing::debug!("Pyright error: {}", e);
665            }
666        }
667
668        hover_result
669    }
670
671    /// Shutdown the language server
672    pub fn shutdown(&mut self) -> Result<()> {
673        {
674            let mut is_shutdown = self.is_shutdown.lock().unwrap();
675            if *is_shutdown {
676                return Ok(());
677            }
678            *is_shutdown = true;
679        }
680
681        // For shutdown, we expect a null result, so we need special handling
682        let id = self.request_id.fetch_add(1, Ordering::SeqCst);
683        let request = LspRequest {
684            jsonrpc: "2.0",
685            id,
686            method: "shutdown".to_string(),
687            params: json!({}),
688        };
689        self.send_message(&request)?;
690
691        // Read shutdown response - expect null result
692        self.read_shutdown_response(id)?;
693
694        self.send_notification("exit", json!({}))?;
695        Ok(())
696    }
697
698    /// Read shutdown response that expects null result
699    fn read_shutdown_response(&self, expected_id: u64) -> Result<()> {
700        let mut reader = self.reader.lock().unwrap();
701
702        loop {
703            // Read headers
704            let mut headers = Vec::new();
705            loop {
706                let mut line = String::new();
707                reader.read_line(&mut line)?;
708                if line == "\r\n" || line == "\n" {
709                    break;
710                }
711                headers.push(line);
712            }
713
714            // Parse Content-Length header
715            let content_length = headers
716                .iter()
717                .find(|h| h.starts_with("Content-Length:"))
718                .and_then(|h| h.split(':').nth(1))
719                .and_then(|v| v.trim().parse::<usize>().ok())
720                .ok_or_else(|| anyhow!("Missing or invalid Content-Length header"))?;
721
722            // Read content
723            let mut content = vec![0u8; content_length];
724            reader.read_exact(&mut content)?;
725
726            // Parse JSON
727            let response: LspResponse = serde_json::from_slice(&content)?;
728
729            // Skip notifications
730            if response.id.is_none() {
731                continue;
732            }
733
734            // Check if this is our response
735            if response.id == Some(expected_id) {
736                if let Some(error) = response.error {
737                    return Err(anyhow!("LSP error: {}", error.message));
738                }
739                // For shutdown, result is null - this is expected and valid
740                return Ok(());
741            }
742        }
743    }
744}
745
746impl Drop for PyrightLspClient {
747    fn drop(&mut self) {
748        // Try to shutdown gracefully
749        let _ = self.shutdown();
750
751        // Kill the process if it's still running
752        if let Ok(mut process) = self.process.lock() {
753            let _ = process.kill();
754            let _ = process.wait();
755        }
756    }
757}
758
759impl PyrightLspClientTrait for PyrightLspClient {
760    fn open_file(&mut self, file_path: &str, content: &str) -> Result<()> {
761        self.open_file(file_path, content)
762    }
763
764    fn update_file(&mut self, file_path: &str, content: &str, version: i32) -> Result<()> {
765        self.update_file(file_path, content, version)
766    }
767
768    fn query_type(
769        &mut self,
770        file_path: &str,
771        content: &str,
772        line: u32,
773        column: u32,
774    ) -> Result<Option<String>> {
775        self.query_type(file_path, content, line, column)
776    }
777
778    fn shutdown(&mut self) -> Result<()> {
779        self.shutdown()
780    }
781}
782
783/// Get type for a variable at a specific location using pyright
784pub fn get_type_with_pyright(
785    file_path: &str,
786    content: &str,
787    line: u32,
788    column: u32,
789) -> Result<Option<String>> {
790    let mut client = PyrightLspClient::new(None)?;
791    client.query_type(file_path, content, line, column)
792}
793
794#[cfg(test)]
795pub mod tests {
796    use super::*;
797    use std::collections::HashMap;
798    use std::fs;
799    use std::sync::{Arc, Mutex, OnceLock};
800
    /// Pool of concurrent pyright clients for tests - one per workspace to avoid cross-test pollution
    /// Uses message-passing with request IDs to distribute responses to correct threads
    ///
    /// The map is created lazily on first use and is keyed by the workspace
    /// root string; `get_workspace_concurrent_client` uses the key `"default"`
    /// when no workspace root is given.
    static CONCURRENT_CLIENT_POOL: OnceLock<
        Arc<Mutex<HashMap<String, Arc<crate::concurrent_lsp::SyncConcurrentPyrightClient>>>>,
    > = OnceLock::new();
806
807    /// Get or create a concurrent pyright client for a specific workspace
808    /// This enables TRUE PARALLELISM while maintaining test isolation by using separate clients per workspace
809    pub fn get_workspace_concurrent_client(
810        workspace_root: Option<&str>,
811    ) -> Arc<crate::concurrent_lsp::SyncConcurrentPyrightClient> {
812        let pool = CONCURRENT_CLIENT_POOL.get_or_init(|| Arc::new(Mutex::new(HashMap::new())));
813
814        let workspace_key = workspace_root
815            .map(|s| s.to_string())
816            .unwrap_or_else(|| "default".to_string());
817
818        let mut clients = pool.lock().unwrap();
819
820        if let Some(client) = clients.get(&workspace_key) {
821            client.clone()
822        } else {
823            let client = crate::concurrent_lsp::SyncConcurrentPyrightClient::new(workspace_root)
824                .expect("Failed to create concurrent pyright client for tests");
825            let arc_client = Arc::new(client);
826            clients.insert(workspace_key, arc_client.clone());
827            arc_client
828        }
829    }
830
831    /// Clear the client pool to force creation of fresh clients (for test isolation)
832    pub fn clear_client_pool() {
833        if let Some(pool) = CONCURRENT_CLIENT_POOL.get() {
834            if let Ok(mut clients) = pool.lock() {
835                // Shutdown existing clients gracefully
836                for (workspace, client) in clients.iter() {
837                    tracing::debug!("Shutting down pyright client for workspace: {}", workspace);
838                    let _ = client.shutdown();
839                }
840                clients.clear();
841                tracing::debug!("Cleared pyright client pool for test isolation");
842            }
843        }
844    }
845
    /// Cleanup function to be called at the end of test runs
    /// This ensures all pyright processes are terminated
    ///
    /// Implemented by calling `clear_client_pool` and then logging at info
    /// level; only clients held in the pool are affected.
    pub fn cleanup_all_pyright_processes() {
        clear_client_pool();
        tracing::info!("Cleaned up all pyright processes");
    }
852
    /// Wrapper that exposes the pyright client interface (`open_file`, `update_file`,
    /// `query_type`, `shutdown`) on top of a shared concurrent client.
    ///
    /// The wrapper owns only an `Arc` handle; the client itself comes from the
    /// per-workspace pool, so several wrappers for the same workspace share one client.
    /// Per the original author's note, the concurrent client uses message passing instead
    /// of mutex-based locking to avoid serialising queries across tests.
    pub struct ConcurrentPyrightClientWrapper {
        /// Shared handle to the pooled concurrent client that all queries are forwarded to.
        concurrent: Arc<crate::concurrent_lsp::SyncConcurrentPyrightClient>,
    }
859
860    impl ConcurrentPyrightClientWrapper {
861        pub fn new() -> Self {
862            Self {
863                concurrent: get_workspace_concurrent_client(None),
864            }
865        }
866
867        pub fn new_with_workspace(workspace_root: Option<&str>) -> Self {
868            Self {
869                concurrent: get_workspace_concurrent_client(workspace_root),
870            }
871        }
872
873        /// Open a file - handled automatically by concurrent client on first query
874        pub fn open_file(&self, _file_path: &str, _content: &str) -> Result<()> {
875            // The concurrent client handles file opening automatically when queries are made
876            Ok(())
877        }
878
879        /// Update a file - not needed for concurrent client
880        pub fn update_file(&self, _file_path: &str, _content: &str, _version: i32) -> Result<()> {
881            // The concurrent client doesn't need explicit file updates for our use case
882            Ok(())
883        }
884
885        /// Query type using the concurrent client - THIS IS THE PERFORMANCE WIN!
886        /// Multiple threads can call this simultaneously without blocking each other
887        pub fn query_type(
888            &self,
889            file_path: &str,
890            content: &str,
891            line: u32,
892            column: u32,
893        ) -> Result<Option<String>> {
894            self.concurrent
895                .query_type_concurrent(file_path, content, line, column)
896        }
897
898        /// Shutdown the client - for the concurrent wrapper we don't shutdown the shared client
899        /// but we could trigger cleanup if needed
900        pub fn shutdown(&mut self) -> Result<()> {
901            // Note: We don't shut down the shared client here because it's shared across tests
902            // The shared clients are cleaned up via clear_client_pool() when needed
903            Ok(())
904        }
905    }
906
907    impl super::PyrightLspClientTrait for ConcurrentPyrightClientWrapper {
908        fn open_file(&mut self, file_path: &str, content: &str) -> Result<()> {
909            ConcurrentPyrightClientWrapper::open_file(self, file_path, content)
910        }
911
912        fn update_file(&mut self, file_path: &str, content: &str, version: i32) -> Result<()> {
913            ConcurrentPyrightClientWrapper::update_file(self, file_path, content, version)
914        }
915
916        fn query_type(
917            &mut self,
918            file_path: &str,
919            content: &str,
920            line: u32,
921            column: u32,
922        ) -> Result<Option<String>> {
923            ConcurrentPyrightClientWrapper::query_type(self, file_path, content, line, column)
924        }
925
926        fn shutdown(&mut self) -> Result<()> {
927            ConcurrentPyrightClientWrapper::shutdown(self)
928        }
929    }
930    use tempfile::NamedTempFile;
931
932    #[test]
933    #[ignore] // Ignore by default as it requires pyright to be installed
934    fn test_pyright_type_inference() {
935        let code = r#"
936class Repo:
937    @staticmethod
938    def init(path):
939        return Repo()
940
941def test():
942    repo = Repo.init(".")
943"#;
944
945        let temp_file = NamedTempFile::new().unwrap();
946        fs::write(&temp_file, code).unwrap();
947
948        let result = get_type_with_pyright(
949            temp_file.path().to_str().unwrap(),
950            code,
951            8, // Line with 'repo' variable
952            4, // Column of 'repo'
953        );
954
955        match result {
956            Ok(Some(type_str)) => {
957                assert!(
958                    type_str.contains("Repo"),
959                    "Expected Repo type, got: {}",
960                    type_str
961                );
962            }
963            Ok(None) => panic!("No type information returned"),
964            Err(e) => panic!("Error: {}", e),
965        }
966    }
967}