1use anyhow::{anyhow, Result};
20
21pub trait PyrightLspClientTrait {
23 fn open_file(&mut self, file_path: &str, content: &str) -> Result<()>;
24 fn update_file(&mut self, file_path: &str, content: &str, version: i32) -> Result<()>;
25 fn query_type(
26 &mut self,
27 file_path: &str,
28 content: &str,
29 line: u32,
30 column: u32,
31 ) -> Result<Option<String>>;
32 fn shutdown(&mut self) -> Result<()>;
33}
34use serde::{Deserialize, Serialize};
35use serde_json::{json, Value};
36use std::io::{BufRead, BufReader, Read, Write};
37use std::process::{Child, Command, Stdio};
38use std::sync::atomic::{AtomicU64, Ordering};
39use std::sync::{Arc, Mutex};
40use std::time::Duration;
41
42#[derive(Debug, Serialize)]
44struct LspRequest {
45 jsonrpc: &'static str,
46 id: u64,
47 method: String,
48 params: Value,
49}
50
51#[derive(Debug, Serialize)]
53struct LspNotification {
54 jsonrpc: &'static str,
55 method: String,
56 params: Value,
57}
58
59#[derive(Debug, Deserialize)]
61struct LspResponse {
62 #[allow(dead_code)]
63 jsonrpc: String,
64 id: Option<u64>,
65 result: Option<Value>,
66 error: Option<LspError>,
67}
68
69#[derive(Debug, Deserialize)]
71struct LspError {
72 #[allow(dead_code)]
73 code: i32,
74 message: String,
75 #[allow(dead_code)]
76 data: Option<Value>,
77}
78
79#[derive(Debug, Serialize)]
81struct Position {
82 line: u32,
83 character: u32,
84}
85
86#[derive(Debug, Serialize)]
88struct TextDocumentIdentifier {
89 uri: String,
90}
91
92#[derive(Debug, Serialize)]
94#[allow(dead_code)]
95struct TextDocumentItem {
96 uri: String,
97 #[serde(rename = "languageId")]
98 language_id: String,
99 version: i32,
100 text: String,
101}
102
103#[derive(Debug, Serialize)]
105struct HoverParams {
106 #[serde(rename = "textDocument")]
107 text_document: TextDocumentIdentifier,
108 position: Position,
109}
110
111#[derive(Debug, Serialize)]
113struct TypeDefinitionParams {
114 #[serde(rename = "textDocument")]
115 text_document: TextDocumentIdentifier,
116 position: Position,
117}
118
/// Synchronous LSP client that drives one pyright language-server process
/// over its stdin/stdout pipes.
pub struct PyrightLspClient {
    /// The server process; messages are written to its stdin.
    process: Arc<Mutex<Child>>,
    /// Counter handing out request ids, one per request.
    request_id: AtomicU64,
    /// Buffered reader over the server's stdout, shared by all response readers.
    reader: Arc<Mutex<BufReader<std::process::ChildStdout>>>,
    /// Set once `shutdown` has started so later calls (e.g. from `Drop`) are no-ops.
    is_shutdown: Arc<Mutex<bool>>,
}
126
127impl PyrightLspClient {
128 pub fn new(workspace_root: Option<&str>) -> Result<Self> {
130 tracing::debug!("Starting PyrightLspClient::new()");
131 let pyright_cmd = if Command::new("pyright-langserver")
133 .arg("--version")
134 .output()
135 .is_ok()
136 {
137 "pyright-langserver"
138 } else if Command::new("pyright").arg("--version").output().is_ok() {
139 "pyright"
141 } else {
142 return Err(anyhow!(
143 "pyright not found. Please install pyright: pip install pyright"
144 ));
145 };
146
147 tracing::debug!("Starting pyright process with command: {}", pyright_cmd);
149 let mut process = Command::new(pyright_cmd)
150 .args(["--stdio"])
151 .stdin(Stdio::piped())
152 .stdout(Stdio::piped())
153 .stderr(Stdio::null())
154 .spawn()
155 .map_err(|e| anyhow!("Failed to start pyright: {}", e))?;
156
157 let stdout = process.stdout.take().ok_or_else(|| anyhow!("No stdout"))?;
158 let reader = BufReader::new(stdout);
159
160 let mut client = Self {
161 process: Arc::new(Mutex::new(process)),
162 request_id: AtomicU64::new(0),
163 reader: Arc::new(Mutex::new(reader)),
164 is_shutdown: Arc::new(Mutex::new(false)),
165 };
166
167 client.initialize(workspace_root)?;
169
170 Ok(client)
171 }
172
173 fn initialize(&mut self, workspace_root: Option<&str>) -> Result<()> {
175 let workspace_root = if let Some(root) = workspace_root {
177 std::path::Path::new(root).to_path_buf()
178 } else {
179 std::env::current_dir()?
180 };
181 let workspace_uri = format!("file://{}", workspace_root.display());
182
183 tracing::debug!(
184 "Initializing pyright with workspace: {}",
185 workspace_root.display()
186 );
187
188 let init_params = json!({
189 "processId": std::process::id(),
190 "clientInfo": {
191 "name": "dissolve",
192 "version": "0.1.0"
193 },
194 "locale": "en",
195 "rootPath": workspace_root.to_str(),
196 "rootUri": workspace_uri,
197 "capabilities": {
198 "textDocument": {
199 "hover": {
200 "contentFormat": ["plaintext", "markdown"]
201 },
202 "typeDefinition": {
203 "dynamicRegistration": false
204 }
205 }
206 },
207 "trace": "off",
208 "workspaceFolders": [{
209 "uri": workspace_uri,
210 "name": "test_workspace"
211 }],
212 "initializationOptions": {
213 "autoSearchPaths": true,
214 "useLibraryCodeForTypes": true,
215 "typeCheckingMode": "basic",
216 "python": {
217 "analysis": {
218 "extraPaths": []
219 }
220 }
221 }
222 });
223
224 let _response =
226 self.send_request_with_timeout("initialize", init_params, Duration::from_secs(10))?;
227
228 self.send_notification("initialized", json!({}))?;
230
231 Ok(())
232 }
233
234 fn send_request(&mut self, method: &str, params: Value) -> Result<Value> {
236 self.send_request_with_timeout(method, params, Duration::from_secs(5))
238 }
239
240 fn send_request_with_timeout(
242 &mut self,
243 method: &str,
244 params: Value,
245 timeout: Duration,
246 ) -> Result<Value> {
247 let id = self.request_id.fetch_add(1, Ordering::SeqCst);
248 let request = LspRequest {
249 jsonrpc: "2.0",
250 id,
251 method: method.to_string(),
252 params,
253 };
254
255 self.send_message(&request)?;
256
257 self.read_response_with_timeout(id, timeout)
259 }
260
261 fn send_notification(&mut self, method: &str, params: Value) -> Result<()> {
263 let notification = LspNotification {
264 jsonrpc: "2.0",
265 method: method.to_string(),
266 params,
267 };
268
269 self.send_message(¬ification)
270 }
271
272 fn send_message<T: Serialize>(&mut self, message: &T) -> Result<()> {
274 let content = serde_json::to_string(message)?;
275 let header = format!("Content-Length: {}\r\n\r\n", content.len());
276
277 let mut process = self.process.lock().unwrap();
278 let stdin = process.stdin.as_mut().ok_or_else(|| anyhow!("No stdin"))?;
279 stdin.write_all(header.as_bytes())?;
280 stdin.write_all(content.as_bytes())?;
281 stdin.flush()?;
282
283 Ok(())
284 }
285
286 #[allow(dead_code)]
288 fn read_response(&self, expected_id: u64) -> Result<Value> {
289 let mut reader = self.reader.lock().unwrap();
290
291 loop {
292 let mut headers = Vec::new();
294 loop {
295 let mut line = String::new();
296 reader.read_line(&mut line)?;
297 if line == "\r\n" || line == "\n" {
298 break;
299 }
300 headers.push(line);
301 }
302
303 let content_length = headers
305 .iter()
306 .find(|h| h.starts_with("Content-Length:"))
307 .and_then(|h| h.split(':').nth(1))
308 .and_then(|v| v.trim().parse::<usize>().ok())
309 .ok_or_else(|| anyhow!("Missing or invalid Content-Length header"))?;
310
311 let mut content = vec![0u8; content_length];
313 reader.read_exact(&mut content)?;
314
315 let response: LspResponse = serde_json::from_slice(&content)?;
317
318 if response.id.is_none() {
320 continue;
321 }
322
323 if response.id == Some(expected_id) {
325 if let Some(error) = response.error {
326 return Err(anyhow!("LSP error: {}", error.message));
327 }
328 return response
329 .result
330 .ok_or_else(|| anyhow!("No result in response"));
331 }
332 }
333 }
334
335 fn read_response_with_timeout(&self, expected_id: u64, timeout: Duration) -> Result<Value> {
337 use std::time::Instant;
338 let start = Instant::now();
339
340 let mut reader = self.reader.lock().unwrap();
341
342 while start.elapsed() < timeout {
344 std::thread::sleep(Duration::from_millis(10));
346
347 {
349 let mut process = self.process.lock().unwrap();
350 match process.try_wait() {
351 Ok(Some(_)) => return Err(anyhow!("Pyright process has exited")),
352 Ok(None) => {} Err(e) => return Err(anyhow!("Failed to check process status: {}", e)),
354 }
355 }
356
357 loop {
359 let mut headers = Vec::new();
361 loop {
362 let mut line = String::new();
363 match reader.read_line(&mut line) {
364 Ok(0) => return Err(anyhow!("Connection closed")),
365 Ok(_) => {
366 if line == "\r\n" || line == "\n" {
367 break;
368 }
369 headers.push(line);
370 }
371 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
372 break;
374 }
375 Err(e) => return Err(anyhow!("Failed to read line: {}", e)),
376 }
377 }
378
379 if headers.is_empty() {
380 break; }
382
383 let content_length = headers
385 .iter()
386 .find(|h| h.starts_with("Content-Length:"))
387 .and_then(|h| h.split(':').nth(1))
388 .and_then(|v| v.trim().parse::<usize>().ok())
389 .ok_or_else(|| anyhow!("Missing or invalid Content-Length header"))?;
390
391 let mut content = vec![0u8; content_length];
393 reader.read_exact(&mut content)?;
394
395 let response: LspResponse = serde_json::from_slice(&content)?;
397
398 if response.id.is_none() {
400 continue;
401 }
402
403 if response.id == Some(expected_id) {
405 if let Some(error) = response.error {
406 return Err(anyhow!("LSP error: {}", error.message));
407 }
408 return response
409 .result
410 .ok_or_else(|| anyhow!("No result in response"));
411 }
412 }
413 }
414
415 Err(anyhow!(
416 "Timeout waiting for LSP response ({}s)",
417 timeout.as_secs()
418 ))
419 }
420
421 pub fn open_file(&mut self, file_path: &str, content: &str) -> Result<()> {
423 let abs_path = if std::path::Path::new(file_path).is_relative() {
425 std::env::current_dir()?.join(file_path)
426 } else {
427 std::path::PathBuf::from(file_path)
428 };
429 let uri = format!("file://{}", abs_path.display());
430 let params = json!({
431 "textDocument": {
432 "uri": uri,
433 "languageId": "python",
434 "version": 1,
435 "text": content
436 }
437 });
438
439 self.send_notification("textDocument/didOpen", params)?;
440
441 std::thread::sleep(Duration::from_millis(100));
443
444 Ok(())
445 }
446
447 pub fn update_file(&mut self, file_path: &str, content: &str, version: i32) -> Result<()> {
449 tracing::debug!(
450 "Updating file in pyright LSP: {} (version {})",
451 file_path,
452 version
453 );
454
455 let abs_path = if std::path::Path::new(file_path).is_relative() {
457 std::env::current_dir()?.join(file_path)
458 } else {
459 std::path::PathBuf::from(file_path)
460 };
461 let uri = format!("file://{}", abs_path.display());
462 let params = json!({
463 "textDocument": {
464 "uri": uri,
465 "version": version
466 },
467 "contentChanges": [{
468 "text": content
469 }]
470 });
471
472 self.send_notification("textDocument/didChange", params)?;
473
474 std::thread::sleep(Duration::from_millis(100));
476
477 Ok(())
478 }
479
480 pub fn get_hover(&mut self, file_path: &str, line: u32, column: u32) -> Result<Option<String>> {
482 let abs_path = if std::path::Path::new(file_path).is_relative() {
484 std::env::current_dir()?.join(file_path)
485 } else {
486 std::path::PathBuf::from(file_path)
487 };
488 let uri = format!("file://{}", abs_path.display());
489 let params = HoverParams {
490 text_document: TextDocumentIdentifier { uri },
491 position: Position {
492 line: line - 1, character: column,
494 },
495 };
496
497 let response = self.send_request("textDocument/hover", serde_json::to_value(params)?)?;
498
499 if let Some(hover) = response.as_object() {
501 if let Some(contents) = hover.get("contents") {
502 let type_info = match contents {
503 Value::String(s) => s.clone(),
504 Value::Object(obj) => {
505 if let Some(Value::String(s)) = obj.get("value") {
506 s.clone()
507 } else {
508 return Ok(None);
509 }
510 }
511 _ => return Ok(None),
512 };
513
514 tracing::debug!("Pyright hover response: {}", type_info);
520
521 if type_info.starts_with("(module) ") {
523 let module_start = "(module) ".len();
525 let module_end = type_info[module_start..]
526 .find('\n')
527 .map(|pos| module_start + pos)
528 .unwrap_or(type_info.len());
529 let module_name = type_info[module_start..module_end].trim();
530 tracing::debug!("Extracted module type: {}", module_name);
531 return Ok(Some(module_name.to_string()));
532 }
533
534 if type_info.starts_with("(class) ") {
536 let class_start = "(class) ".len();
538 let class_end = type_info[class_start..]
539 .find('\n')
540 .map(|pos| class_start + pos)
541 .unwrap_or(type_info.len());
542 let class_name = type_info[class_start..class_end].trim();
543 tracing::debug!("Extracted class type: {}", class_name);
544 return Ok(Some(class_name.to_string()));
545 }
546
547 if let Some(colon_pos) = type_info.find(':') {
549 let type_part = type_info[colon_pos + 1..].trim();
550 tracing::debug!("Extracted type: {}", type_part);
551
552 if type_part == "Unknown" {
554 tracing::warn!(
555 "Pyright returned 'Unknown' type at {}:{}:{}",
556 file_path,
557 line,
558 column
559 );
560 return Ok(None);
561 }
562
563 return Ok(Some(type_part.to_string()));
564 }
565 }
566 }
567
568 Ok(None)
569 }
570
571 pub fn get_type_definition(
573 &mut self,
574 file_path: &str,
575 line: u32,
576 column: u32,
577 ) -> Result<Option<String>> {
578 let abs_path = if std::path::Path::new(file_path).is_relative() {
580 std::env::current_dir()?.join(file_path)
581 } else {
582 std::path::PathBuf::from(file_path)
583 };
584 let uri = format!("file://{}", abs_path.display());
585 let params = TypeDefinitionParams {
586 text_document: TextDocumentIdentifier { uri: uri.clone() },
587 position: Position {
588 line: line - 1, character: column,
590 },
591 };
592
593 let response =
594 self.send_request("textDocument/typeDefinition", serde_json::to_value(params)?)?;
595
596 if let Some(locations) = response.as_array() {
598 if let Some(first_location) = locations.first() {
599 if let Some(target_uri) = first_location.get("uri").and_then(|u| u.as_str()) {
600 if let Some(target_range) = first_location.get("range") {
602 tracing::debug!(
605 "Type definition location: {} at {:?}",
606 target_uri,
607 target_range
608 );
609
610 if let Some(path) = target_uri.strip_prefix("file://") {
612 if let Some(module_name) = path
613 .strip_suffix(".py")
614 .and_then(|p| p.split('/').next_back())
615 {
616 return Ok(Some(module_name.to_string()));
618 }
619 }
620 }
621 }
622 }
623 }
624
625 Ok(None)
626 }
627
628 pub fn query_type(
630 &mut self,
631 file_path: &str,
632 _content: &str,
633 line: u32,
634 column: u32,
635 ) -> Result<Option<String>> {
636 let hover_result = self.get_hover(file_path, line, column);
640
641 match &hover_result {
643 Ok(Some(type_str)) => {
644 tracing::debug!("Pyright hover returned type: {}", type_str);
646
647 if !type_str.contains('.') {
649 if let Ok(Some(type_def_info)) =
650 self.get_type_definition(file_path, line, column)
651 {
652 tracing::debug!("Type definition info: {}", type_def_info);
653 }
656 }
657
658 return Ok(Some(type_str.clone()));
659 }
660 Ok(None) => {
661 tracing::debug!("Pyright returned no type information");
662 }
663 Err(e) => {
664 tracing::debug!("Pyright error: {}", e);
665 }
666 }
667
668 hover_result
669 }
670
671 pub fn shutdown(&mut self) -> Result<()> {
673 {
674 let mut is_shutdown = self.is_shutdown.lock().unwrap();
675 if *is_shutdown {
676 return Ok(());
677 }
678 *is_shutdown = true;
679 }
680
681 let id = self.request_id.fetch_add(1, Ordering::SeqCst);
683 let request = LspRequest {
684 jsonrpc: "2.0",
685 id,
686 method: "shutdown".to_string(),
687 params: json!({}),
688 };
689 self.send_message(&request)?;
690
691 self.read_shutdown_response(id)?;
693
694 self.send_notification("exit", json!({}))?;
695 Ok(())
696 }
697
698 fn read_shutdown_response(&self, expected_id: u64) -> Result<()> {
700 let mut reader = self.reader.lock().unwrap();
701
702 loop {
703 let mut headers = Vec::new();
705 loop {
706 let mut line = String::new();
707 reader.read_line(&mut line)?;
708 if line == "\r\n" || line == "\n" {
709 break;
710 }
711 headers.push(line);
712 }
713
714 let content_length = headers
716 .iter()
717 .find(|h| h.starts_with("Content-Length:"))
718 .and_then(|h| h.split(':').nth(1))
719 .and_then(|v| v.trim().parse::<usize>().ok())
720 .ok_or_else(|| anyhow!("Missing or invalid Content-Length header"))?;
721
722 let mut content = vec![0u8; content_length];
724 reader.read_exact(&mut content)?;
725
726 let response: LspResponse = serde_json::from_slice(&content)?;
728
729 if response.id.is_none() {
731 continue;
732 }
733
734 if response.id == Some(expected_id) {
736 if let Some(error) = response.error {
737 return Err(anyhow!("LSP error: {}", error.message));
738 }
739 return Ok(());
741 }
742 }
743 }
744}
745
746impl Drop for PyrightLspClient {
747 fn drop(&mut self) {
748 let _ = self.shutdown();
750
751 if let Ok(mut process) = self.process.lock() {
753 let _ = process.kill();
754 let _ = process.wait();
755 }
756 }
757}
758
759impl PyrightLspClientTrait for PyrightLspClient {
760 fn open_file(&mut self, file_path: &str, content: &str) -> Result<()> {
761 self.open_file(file_path, content)
762 }
763
764 fn update_file(&mut self, file_path: &str, content: &str, version: i32) -> Result<()> {
765 self.update_file(file_path, content, version)
766 }
767
768 fn query_type(
769 &mut self,
770 file_path: &str,
771 content: &str,
772 line: u32,
773 column: u32,
774 ) -> Result<Option<String>> {
775 self.query_type(file_path, content, line, column)
776 }
777
778 fn shutdown(&mut self) -> Result<()> {
779 self.shutdown()
780 }
781}
782
783pub fn get_type_with_pyright(
785 file_path: &str,
786 content: &str,
787 line: u32,
788 column: u32,
789) -> Result<Option<String>> {
790 let mut client = PyrightLspClient::new(None)?;
791 client.query_type(file_path, content, line, column)
792}
793
#[cfg(test)]
pub mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::fs;
    use std::sync::{Arc, Mutex, OnceLock, PoisonError};

    /// Pooled concurrent clients keyed by workspace root (`"default"` when none),
    /// so tests reuse running servers instead of spawning one per test.
    type ClientPool =
        Mutex<HashMap<String, Arc<crate::concurrent_lsp::SyncConcurrentPyrightClient>>>;

    static CONCURRENT_CLIENT_POOL: OnceLock<ClientPool> = OnceLock::new();

    /// Returns the pooled client for `workspace_root`, creating it on first use.
    ///
    /// # Panics
    ///
    /// Panics if a new client cannot be created.
    pub fn get_workspace_concurrent_client(
        workspace_root: Option<&str>,
    ) -> Arc<crate::concurrent_lsp::SyncConcurrentPyrightClient> {
        let pool = CONCURRENT_CLIENT_POOL.get_or_init(|| Mutex::new(HashMap::new()));
        let workspace_key = workspace_root.unwrap_or("default").to_string();

        // A test that panicked while holding the lock (e.g. client creation
        // failed) must not make every later test fail with a PoisonError; the
        // map is only changed by completed inserts and clears.
        let mut clients = pool.lock().unwrap_or_else(PoisonError::into_inner);
        let client = clients.entry(workspace_key).or_insert_with(|| {
            Arc::new(
                crate::concurrent_lsp::SyncConcurrentPyrightClient::new(workspace_root)
                    .expect("Failed to create concurrent pyright client for tests"),
            )
        });
        Arc::clone(client)
    }

    /// Shuts down and removes every pooled client.
    pub fn clear_client_pool() {
        let Some(pool) = CONCURRENT_CLIENT_POOL.get() else {
            return;
        };
        let mut clients = pool.lock().unwrap_or_else(PoisonError::into_inner);
        for (workspace, client) in clients.drain() {
            tracing::debug!("Shutting down pyright client for workspace: {}", workspace);
            let _ = client.shutdown();
        }
        tracing::debug!("Cleared pyright client pool for test isolation");
    }

    /// Tears down all pooled pyright processes.
    pub fn cleanup_all_pyright_processes() {
        clear_client_pool();
        tracing::info!("Cleaned up all pyright processes");
    }

    /// Adapts a pooled concurrent client to [`PyrightLspClientTrait`].
    ///
    /// `open_file`, `update_file` and `shutdown` are no-ops: `query_type`
    /// forwards the content with every query, and the shared client is owned by
    /// the pool (see `clear_client_pool`).
    pub struct ConcurrentPyrightClientWrapper {
        concurrent: Arc<crate::concurrent_lsp::SyncConcurrentPyrightClient>,
    }

    impl ConcurrentPyrightClientWrapper {
        /// Wraps the pooled client for the default workspace.
        pub fn new() -> Self {
            Self {
                concurrent: get_workspace_concurrent_client(None),
            }
        }

        /// Wraps the pooled client for `workspace_root`.
        pub fn new_with_workspace(workspace_root: Option<&str>) -> Self {
            Self {
                concurrent: get_workspace_concurrent_client(workspace_root),
            }
        }

        pub fn open_file(&self, _file_path: &str, _content: &str) -> Result<()> {
            Ok(())
        }

        pub fn update_file(&self, _file_path: &str, _content: &str, _version: i32) -> Result<()> {
            Ok(())
        }

        pub fn query_type(
            &self,
            file_path: &str,
            content: &str,
            line: u32,
            column: u32,
        ) -> Result<Option<String>> {
            self.concurrent
                .query_type_concurrent(file_path, content, line, column)
        }

        pub fn shutdown(&mut self) -> Result<()> {
            Ok(())
        }
    }

    impl Default for ConcurrentPyrightClientWrapper {
        fn default() -> Self {
            Self::new()
        }
    }

    impl super::PyrightLspClientTrait for ConcurrentPyrightClientWrapper {
        fn open_file(&mut self, file_path: &str, content: &str) -> Result<()> {
            ConcurrentPyrightClientWrapper::open_file(self, file_path, content)
        }

        fn update_file(&mut self, file_path: &str, content: &str, version: i32) -> Result<()> {
            ConcurrentPyrightClientWrapper::update_file(self, file_path, content, version)
        }

        fn query_type(
            &mut self,
            file_path: &str,
            content: &str,
            line: u32,
            column: u32,
        ) -> Result<Option<String>> {
            ConcurrentPyrightClientWrapper::query_type(self, file_path, content, line, column)
        }

        fn shutdown(&mut self) -> Result<()> {
            ConcurrentPyrightClientWrapper::shutdown(self)
        }
    }

    #[test]
    #[ignore]
    fn test_pyright_type_inference() {
        let code = r#"
class Repo:
    @staticmethod
    def init(path):
        return Repo()

def test():
    repo = Repo.init(".")
"#;

        // A `.py` suffix so pyright treats the temporary file as Python source.
        let temp_file = tempfile::Builder::new()
            .suffix(".py")
            .tempfile()
            .unwrap();
        fs::write(temp_file.path(), code).unwrap();

        // 1-based line 8 is `    repo = ...` (line 1 is the empty line right
        // after `r#"`); column 4 is the `r` of `repo`.
        let result = get_type_with_pyright(temp_file.path().to_str().unwrap(), code, 8, 4);

        match result {
            Ok(Some(type_str)) => {
                assert!(
                    type_str.contains("Repo"),
                    "Expected Repo type, got: {}",
                    type_str
                );
            }
            Ok(None) => panic!("No type information returned"),
            Err(e) => panic!("Error: {}", e),
        }
    }
}