ricecoder_external_lsp/client/
connection.rs

1//! LSP client connection management
2
3use super::protocol::{JsonRpcHandler, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, RequestId};
4use crate::error::{ExternalLspError, Result};
5use serde_json::Value;
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::sync::{broadcast, RwLock};
10
11/// A pending request awaiting a response
12pub struct PendingRequest {
13    /// Request ID
14    pub id: RequestId,
15    /// Request method name
16    pub method: String,
17    /// Time when request was sent
18    pub sent_at: Instant,
19    /// Request timeout
20    pub timeout: Duration,
21    /// Response channel sender
22    pub response_tx: tokio::sync::oneshot::Sender<Result<Value>>,
23}
24
25/// Notification handler callback
26pub type NotificationHandler = Box<dyn Fn(&str, Option<Value>) + Send + Sync>;
27
28/// Manages connection to an external LSP server
29pub struct LspConnection {
30    /// JSON-RPC protocol handler
31    handler: JsonRpcHandler,
32    /// Pending requests awaiting responses
33    pending_requests: Arc<RwLock<HashMap<RequestId, PendingRequest>>>,
34    /// Notification broadcast channel
35    notification_tx: broadcast::Sender<(String, Option<Value>)>,
36}
37
38impl LspConnection {
39    /// Create a new LSP connection
40    pub fn new() -> Self {
41        let (notification_tx, _) = broadcast::channel(100);
42        Self {
43            handler: JsonRpcHandler::new(),
44            pending_requests: Arc::new(RwLock::new(HashMap::new())),
45            notification_tx,
46        }
47    }
48
49    /// Get the JSON-RPC handler
50    pub fn handler(&self) -> &JsonRpcHandler {
51        &self.handler
52    }
53
54    /// Create a new request and track it
55    pub async fn create_tracked_request(
56        &self,
57        method: impl Into<String>,
58        params: Option<Value>,
59        timeout: Duration,
60    ) -> Result<(JsonRpcRequest, tokio::sync::oneshot::Receiver<Result<Value>>)> {
61        let request = self.handler.create_request(method.into(), params);
62        let request_id = request.id.ok_or_else(|| {
63            ExternalLspError::ProtocolError("Request ID not set".to_string())
64        })?;
65
66        let (tx, rx) = tokio::sync::oneshot::channel();
67
68        let pending = PendingRequest {
69            id: request_id,
70            method: request.method.clone(),
71            sent_at: Instant::now(),
72            timeout,
73            response_tx: tx,
74        };
75
76        self.pending_requests.write().await.insert(request_id, pending);
77
78        Ok((request, rx))
79    }
80
81    /// Handle a response and correlate it to a pending request
82    pub async fn handle_response(&self, response: JsonRpcResponse) -> Result<()> {
83        let mut pending = self.pending_requests.write().await;
84
85        if let Some(pending_req) = pending.remove(&response.id) {
86            // Check if request timed out
87            if pending_req.sent_at.elapsed() > pending_req.timeout {
88                return Err(ExternalLspError::Timeout {
89                    timeout_ms: pending_req.timeout.as_millis() as u64,
90                });
91            }
92
93            // Send response to waiting task
94            let result = if let Some(error) = response.error {
95                Err(ExternalLspError::ProtocolError(format!(
96                    "{}: {}",
97                    error.code, error.message
98                )))
99            } else {
100                Ok(response.result.unwrap_or(Value::Null))
101            };
102
103            // Ignore send error if receiver was dropped
104            let _ = pending_req.response_tx.send(result);
105
106            Ok(())
107        } else {
108            Err(ExternalLspError::ProtocolError(format!(
109                "Received response for unknown request ID: {}",
110                response.id
111            )))
112        }
113    }
114
115    /// Get pending request count
116    pub async fn pending_request_count(&self) -> usize {
117        self.pending_requests.read().await.len()
118    }
119
120    /// Check for timed out requests and clean them up
121    pub async fn cleanup_timed_out_requests(&self) -> Vec<RequestId> {
122        let mut pending = self.pending_requests.write().await;
123        let mut timed_out = Vec::new();
124        let mut to_remove = Vec::new();
125
126        for (id, req) in pending.iter() {
127            if req.sent_at.elapsed() > req.timeout {
128                timed_out.push(*id);
129                to_remove.push(*id);
130            }
131        }
132
133        for id in to_remove {
134            if let Some(pending_req) = pending.remove(&id) {
135                // Send timeout error to waiting task
136                let _ = pending_req.response_tx.send(Err(ExternalLspError::Timeout {
137                    timeout_ms: pending_req.timeout.as_millis() as u64,
138                }));
139            }
140        }
141
142        timed_out
143    }
144
145    /// Clear all pending requests
146    pub async fn clear_pending_requests(&self) {
147        self.pending_requests.write().await.clear();
148    }
149
150    /// Get list of pending request IDs
151    pub async fn get_pending_request_ids(&self) -> Vec<RequestId> {
152        self.pending_requests.read().await.keys().copied().collect()
153    }
154
155    /// Handle a notification from the server
156    pub async fn handle_notification(&self, notification: JsonRpcNotification) -> Result<()> {
157        // Broadcast notification to all subscribers
158        let _ = self.notification_tx.send((notification.method, notification.params));
159        Ok(())
160    }
161
162    /// Subscribe to notifications
163    pub fn subscribe_notifications(&self) -> broadcast::Receiver<(String, Option<Value>)> {
164        self.notification_tx.subscribe()
165    }
166
167    /// Handle textDocument/publishDiagnostics notification
168    pub async fn handle_publish_diagnostics(
169        &self,
170        params: Option<Value>,
171    ) -> Result<()> {
172        self.handle_notification(JsonRpcNotification {
173            jsonrpc: "2.0".to_string(),
174            method: "textDocument/publishDiagnostics".to_string(),
175            params,
176        })
177        .await
178    }
179
180    /// Handle window/logMessage notification
181    pub async fn handle_log_message(&self, params: Option<Value>) -> Result<()> {
182        self.handle_notification(JsonRpcNotification {
183            jsonrpc: "2.0".to_string(),
184            method: "window/logMessage".to_string(),
185            params,
186        })
187        .await
188    }
189
190    /// Handle window/showMessage notification
191    pub async fn handle_show_message(&self, params: Option<Value>) -> Result<()> {
192        self.handle_notification(JsonRpcNotification {
193            jsonrpc: "2.0".to_string(),
194            method: "window/showMessage".to_string(),
195            params,
196        })
197        .await
198    }
199
200    /// Send textDocument/didOpen notification
201    pub async fn send_did_open(
202        &self,
203        uri: String,
204        language_id: String,
205        version: i32,
206        text: String,
207    ) -> Result<()> {
208        let params = serde_json::json!({
209            "textDocument": {
210                "uri": uri,
211                "languageId": language_id,
212                "version": version,
213                "text": text
214            }
215        });
216
217        self.handle_notification(JsonRpcNotification {
218            jsonrpc: "2.0".to_string(),
219            method: "textDocument/didOpen".to_string(),
220            params: Some(params),
221        })
222        .await
223    }
224
225    /// Send textDocument/didChange notification
226    pub async fn send_did_change(
227        &self,
228        uri: String,
229        version: i32,
230        content_changes: Vec<Value>,
231    ) -> Result<()> {
232        let params = serde_json::json!({
233            "textDocument": {
234                "uri": uri,
235                "version": version
236            },
237            "contentChanges": content_changes
238        });
239
240        self.handle_notification(JsonRpcNotification {
241            jsonrpc: "2.0".to_string(),
242            method: "textDocument/didChange".to_string(),
243            params: Some(params),
244        })
245        .await
246    }
247
248    /// Send textDocument/didClose notification
249    pub async fn send_did_close(&self, uri: String) -> Result<()> {
250        let params = serde_json::json!({
251            "textDocument": {
252                "uri": uri
253            }
254        });
255
256        self.handle_notification(JsonRpcNotification {
257            jsonrpc: "2.0".to_string(),
258            method: "textDocument/didClose".to_string(),
259            params: Some(params),
260        })
261        .await
262    }
263
264    /// Send textDocument/didSave notification
265    pub async fn send_did_save(&self, uri: String, text: Option<String>) -> Result<()> {
266        let mut params = serde_json::json!({
267            "textDocument": {
268                "uri": uri
269            }
270        });
271
272        if let Some(text) = text {
273            params["text"] = serde_json::json!(text);
274        }
275
276        self.handle_notification(JsonRpcNotification {
277            jsonrpc: "2.0".to_string(),
278            method: "textDocument/didSave".to_string(),
279            params: Some(params),
280        })
281        .await
282    }
283}
284
285impl Default for LspConnection {
286    fn default() -> Self {
287        Self::new()
288    }
289}
290
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_create_tracked_request() {
        let conn = LspConnection::new();
        let (request, _rx) = conn
            .create_tracked_request("test", None, Duration::from_secs(5))
            .await
            .unwrap();

        assert_eq!(request.method, "test");
        assert!(request.id.is_some());
        assert_eq!(conn.pending_request_count().await, 1);
    }

    #[tokio::test]
    async fn test_handle_response() {
        let conn = LspConnection::new();
        let (request, rx) = conn
            .create_tracked_request("test", None, Duration::from_secs(5))
            .await
            .unwrap();

        let request_id = request.id.unwrap();

        let response = JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            result: Some(Value::String("success".to_string())),
            error: None,
            id: request_id,
        };

        conn.handle_response(response).await.unwrap();

        let result = rx.await.unwrap().unwrap();
        assert_eq!(result, Value::String("success".to_string()));
        assert_eq!(conn.pending_request_count().await, 0);
    }

    #[tokio::test]
    async fn test_handle_error_response() {
        let conn = LspConnection::new();
        let (request, rx) = conn
            .create_tracked_request("test", None, Duration::from_secs(5))
            .await
            .unwrap();

        let request_id = request.id.unwrap();

        let response = JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            result: None,
            error: Some(crate::client::protocol::JsonRpcError {
                code: -32600,
                message: "Invalid Request".to_string(),
                data: None,
            }),
            id: request_id,
        };

        conn.handle_response(response).await.unwrap();

        let result = rx.await.unwrap();
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_late_response_is_rejected() {
        let conn = LspConnection::new();
        let (request, _rx) = conn
            .create_tracked_request("test", None, Duration::from_millis(1))
            .await
            .unwrap();

        // Let the request expire before its response shows up.
        tokio::time::sleep(Duration::from_millis(10)).await;

        let response = JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            result: Some(Value::Null),
            error: None,
            id: request.id.unwrap(),
        };

        let err = conn.handle_response(response).await.unwrap_err();
        assert!(matches!(err, ExternalLspError::Timeout { .. }));
        assert_eq!(conn.pending_request_count().await, 0);
    }

    #[tokio::test]
    async fn test_cleanup_timed_out_requests() {
        let conn = LspConnection::new();
        let (request, rx) = conn
            .create_tracked_request("test", None, Duration::from_millis(1))
            .await
            .unwrap();

        // Wait for timeout
        tokio::time::sleep(Duration::from_millis(10)).await;

        let timed_out = conn.cleanup_timed_out_requests().await;
        assert_eq!(timed_out.len(), 1);
        assert_eq!(timed_out[0], request.id.unwrap());
        assert_eq!(conn.pending_request_count().await, 0);

        // The waiting task must be told about the timeout, not just see a closed channel.
        let err = rx.await.unwrap().unwrap_err();
        assert!(matches!(err, ExternalLspError::Timeout { timeout_ms: 1 }));
    }

    #[tokio::test]
    async fn test_cleanup_keeps_live_requests() {
        let conn = LspConnection::new();
        let (request, _rx) = conn
            .create_tracked_request("test", None, Duration::from_secs(60))
            .await
            .unwrap();

        assert!(conn.cleanup_timed_out_requests().await.is_empty());
        assert_eq!(
            conn.get_pending_request_ids().await,
            vec![request.id.unwrap()]
        );
    }

    #[tokio::test]
    async fn test_clear_pending_requests() {
        let conn = LspConnection::new();
        let (_request, rx) = conn
            .create_tracked_request("test", None, Duration::from_secs(5))
            .await
            .unwrap();
        assert_eq!(conn.pending_request_count().await, 1);

        conn.clear_pending_requests().await;

        assert_eq!(conn.pending_request_count().await, 0);
        assert!(conn.get_pending_request_ids().await.is_empty());
        // The sender was dropped, so the waiter observes a closed channel.
        assert!(rx.await.is_err());
    }

    #[tokio::test]
    async fn test_unknown_response_id() {
        let conn = LspConnection::new();

        let response = JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            result: Some(Value::String("success".to_string())),
            error: None,
            id: 999,
        };

        let result = conn.handle_response(response).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_handle_notification() {
        let conn = LspConnection::new();
        let mut rx = conn.subscribe_notifications();

        let notification = JsonRpcNotification {
            jsonrpc: "2.0".to_string(),
            method: "test/notification".to_string(),
            params: Some(Value::String("test".to_string())),
        };

        conn.handle_notification(notification).await.unwrap();

        let (method, params) = rx.recv().await.unwrap();
        assert_eq!(method, "test/notification");
        assert_eq!(params, Some(Value::String("test".to_string())));
    }

    #[tokio::test]
    async fn test_handle_notification_without_subscribers() {
        let conn = LspConnection::new();

        let notification = JsonRpcNotification {
            jsonrpc: "2.0".to_string(),
            method: "test/notification".to_string(),
            params: None,
        };

        // Broadcasting with nobody listening is best-effort, not an error.
        assert!(conn.handle_notification(notification).await.is_ok());
    }

    #[tokio::test]
    async fn test_handle_publish_diagnostics() {
        let conn = LspConnection::new();
        let mut rx = conn.subscribe_notifications();

        let params = Some(serde_json::json!({
            "uri": "file:///test.rs",
            "diagnostics": []
        }));

        conn.handle_publish_diagnostics(params.clone())
            .await
            .unwrap();

        let (method, received_params) = rx.recv().await.unwrap();
        assert_eq!(method, "textDocument/publishDiagnostics");
        assert_eq!(received_params, params);
    }

    #[tokio::test]
    async fn test_handle_log_message() {
        let conn = LspConnection::new();
        let mut rx = conn.subscribe_notifications();

        let params = Some(serde_json::json!({
            "type": 1,
            "message": "Test log message"
        }));

        conn.handle_log_message(params.clone()).await.unwrap();

        let (method, received_params) = rx.recv().await.unwrap();
        assert_eq!(method, "window/logMessage");
        assert_eq!(received_params, params);
    }

    #[tokio::test]
    async fn test_handle_show_message() {
        let conn = LspConnection::new();
        let mut rx = conn.subscribe_notifications();

        let params = Some(serde_json::json!({
            "type": 1,
            "message": "Test show message"
        }));

        conn.handle_show_message(params.clone()).await.unwrap();

        let (method, received_params) = rx.recv().await.unwrap();
        assert_eq!(method, "window/showMessage");
        assert_eq!(received_params, params);
    }

    #[tokio::test]
    async fn test_multiple_notification_subscribers() {
        let conn = LspConnection::new();
        let mut rx1 = conn.subscribe_notifications();
        let mut rx2 = conn.subscribe_notifications();

        let notification = JsonRpcNotification {
            jsonrpc: "2.0".to_string(),
            method: "test".to_string(),
            params: None,
        };

        conn.handle_notification(notification).await.unwrap();

        let (method1, _) = rx1.recv().await.unwrap();
        let (method2, _) = rx2.recv().await.unwrap();

        assert_eq!(method1, "test");
        assert_eq!(method2, "test");
    }

    #[tokio::test]
    async fn test_send_did_open() {
        let conn = LspConnection::new();
        let mut rx = conn.subscribe_notifications();

        conn.send_did_open(
            "file:///test.rs".to_string(),
            "rust".to_string(),
            1,
            "fn main() {}".to_string(),
        )
        .await
        .unwrap();

        let (method, params) = rx.recv().await.unwrap();
        assert_eq!(method, "textDocument/didOpen");
        assert!(params.is_some());

        let params = params.unwrap();
        assert_eq!(params["textDocument"]["uri"], "file:///test.rs");
        assert_eq!(params["textDocument"]["languageId"], "rust");
        assert_eq!(params["textDocument"]["version"], 1);
        assert_eq!(params["textDocument"]["text"], "fn main() {}");
    }

    #[tokio::test]
    async fn test_send_did_change() {
        let conn = LspConnection::new();
        let mut rx = conn.subscribe_notifications();

        let changes = vec![serde_json::json!({
            "range": {
                "start": {"line": 0, "character": 0},
                "end": {"line": 0, "character": 0}
            },
            "text": "// comment\n"
        })];

        conn.send_did_change("file:///test.rs".to_string(), 2, changes.clone())
            .await
            .unwrap();

        let (method, params) = rx.recv().await.unwrap();
        assert_eq!(method, "textDocument/didChange");
        assert!(params.is_some());

        let params = params.unwrap();
        assert_eq!(params["textDocument"]["uri"], "file:///test.rs");
        assert_eq!(params["textDocument"]["version"], 2);
        assert_eq!(params["contentChanges"], serde_json::json!(changes));
    }

    #[tokio::test]
    async fn test_send_did_close() {
        let conn = LspConnection::new();
        let mut rx = conn.subscribe_notifications();

        conn.send_did_close("file:///test.rs".to_string())
            .await
            .unwrap();

        let (method, params) = rx.recv().await.unwrap();
        assert_eq!(method, "textDocument/didClose");
        assert!(params.is_some());

        let params = params.unwrap();
        assert_eq!(params["textDocument"]["uri"], "file:///test.rs");
    }

    #[tokio::test]
    async fn test_send_did_save() {
        let conn = LspConnection::new();
        let mut rx = conn.subscribe_notifications();

        conn.send_did_save("file:///test.rs".to_string(), Some("fn main() {}".to_string()))
            .await
            .unwrap();

        let (method, params) = rx.recv().await.unwrap();
        assert_eq!(method, "textDocument/didSave");
        assert!(params.is_some());

        let params = params.unwrap();
        assert_eq!(params["textDocument"]["uri"], "file:///test.rs");
        assert_eq!(params["text"], "fn main() {}");
    }

    #[tokio::test]
    async fn test_send_did_save_without_text() {
        let conn = LspConnection::new();
        let mut rx = conn.subscribe_notifications();

        conn.send_did_save("file:///test.rs".to_string(), None)
            .await
            .unwrap();

        let (method, params) = rx.recv().await.unwrap();
        assert_eq!(method, "textDocument/didSave");
        assert!(params.is_some());

        let params = params.unwrap();
        assert_eq!(params["textDocument"]["uri"], "file:///test.rs");
        assert!(params.get("text").is_none());
    }
}