1use super::protocol::{JsonRpcHandler, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, RequestId};
4use crate::error::{ExternalLspError, Result};
5use serde_json::Value;
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::sync::{broadcast, RwLock};
10
11pub struct PendingRequest {
13 pub id: RequestId,
15 pub method: String,
17 pub sent_at: Instant,
19 pub timeout: Duration,
21 pub response_tx: tokio::sync::oneshot::Sender<Result<Value>>,
23}
24
25pub type NotificationHandler = Box<dyn Fn(&str, Option<Value>) + Send + Sync>;
27
28pub struct LspConnection {
30 handler: JsonRpcHandler,
32 pending_requests: Arc<RwLock<HashMap<RequestId, PendingRequest>>>,
34 notification_tx: broadcast::Sender<(String, Option<Value>)>,
36}
37
38impl LspConnection {
39 pub fn new() -> Self {
41 let (notification_tx, _) = broadcast::channel(100);
42 Self {
43 handler: JsonRpcHandler::new(),
44 pending_requests: Arc::new(RwLock::new(HashMap::new())),
45 notification_tx,
46 }
47 }
48
49 pub fn handler(&self) -> &JsonRpcHandler {
51 &self.handler
52 }
53
54 pub async fn create_tracked_request(
56 &self,
57 method: impl Into<String>,
58 params: Option<Value>,
59 timeout: Duration,
60 ) -> Result<(JsonRpcRequest, tokio::sync::oneshot::Receiver<Result<Value>>)> {
61 let request = self.handler.create_request(method.into(), params);
62 let request_id = request.id.ok_or_else(|| {
63 ExternalLspError::ProtocolError("Request ID not set".to_string())
64 })?;
65
66 let (tx, rx) = tokio::sync::oneshot::channel();
67
68 let pending = PendingRequest {
69 id: request_id,
70 method: request.method.clone(),
71 sent_at: Instant::now(),
72 timeout,
73 response_tx: tx,
74 };
75
76 self.pending_requests.write().await.insert(request_id, pending);
77
78 Ok((request, rx))
79 }
80
81 pub async fn handle_response(&self, response: JsonRpcResponse) -> Result<()> {
83 let mut pending = self.pending_requests.write().await;
84
85 if let Some(pending_req) = pending.remove(&response.id) {
86 if pending_req.sent_at.elapsed() > pending_req.timeout {
88 return Err(ExternalLspError::Timeout {
89 timeout_ms: pending_req.timeout.as_millis() as u64,
90 });
91 }
92
93 let result = if let Some(error) = response.error {
95 Err(ExternalLspError::ProtocolError(format!(
96 "{}: {}",
97 error.code, error.message
98 )))
99 } else {
100 Ok(response.result.unwrap_or(Value::Null))
101 };
102
103 let _ = pending_req.response_tx.send(result);
105
106 Ok(())
107 } else {
108 Err(ExternalLspError::ProtocolError(format!(
109 "Received response for unknown request ID: {}",
110 response.id
111 )))
112 }
113 }
114
115 pub async fn pending_request_count(&self) -> usize {
117 self.pending_requests.read().await.len()
118 }
119
120 pub async fn cleanup_timed_out_requests(&self) -> Vec<RequestId> {
122 let mut pending = self.pending_requests.write().await;
123 let mut timed_out = Vec::new();
124 let mut to_remove = Vec::new();
125
126 for (id, req) in pending.iter() {
127 if req.sent_at.elapsed() > req.timeout {
128 timed_out.push(*id);
129 to_remove.push(*id);
130 }
131 }
132
133 for id in to_remove {
134 if let Some(pending_req) = pending.remove(&id) {
135 let _ = pending_req.response_tx.send(Err(ExternalLspError::Timeout {
137 timeout_ms: pending_req.timeout.as_millis() as u64,
138 }));
139 }
140 }
141
142 timed_out
143 }
144
145 pub async fn clear_pending_requests(&self) {
147 self.pending_requests.write().await.clear();
148 }
149
150 pub async fn get_pending_request_ids(&self) -> Vec<RequestId> {
152 self.pending_requests.read().await.keys().copied().collect()
153 }
154
155 pub async fn handle_notification(&self, notification: JsonRpcNotification) -> Result<()> {
157 let _ = self.notification_tx.send((notification.method, notification.params));
159 Ok(())
160 }
161
162 pub fn subscribe_notifications(&self) -> broadcast::Receiver<(String, Option<Value>)> {
164 self.notification_tx.subscribe()
165 }
166
167 pub async fn handle_publish_diagnostics(
169 &self,
170 params: Option<Value>,
171 ) -> Result<()> {
172 self.handle_notification(JsonRpcNotification {
173 jsonrpc: "2.0".to_string(),
174 method: "textDocument/publishDiagnostics".to_string(),
175 params,
176 })
177 .await
178 }
179
180 pub async fn handle_log_message(&self, params: Option<Value>) -> Result<()> {
182 self.handle_notification(JsonRpcNotification {
183 jsonrpc: "2.0".to_string(),
184 method: "window/logMessage".to_string(),
185 params,
186 })
187 .await
188 }
189
190 pub async fn handle_show_message(&self, params: Option<Value>) -> Result<()> {
192 self.handle_notification(JsonRpcNotification {
193 jsonrpc: "2.0".to_string(),
194 method: "window/showMessage".to_string(),
195 params,
196 })
197 .await
198 }
199
200 pub async fn send_did_open(
202 &self,
203 uri: String,
204 language_id: String,
205 version: i32,
206 text: String,
207 ) -> Result<()> {
208 let params = serde_json::json!({
209 "textDocument": {
210 "uri": uri,
211 "languageId": language_id,
212 "version": version,
213 "text": text
214 }
215 });
216
217 self.handle_notification(JsonRpcNotification {
218 jsonrpc: "2.0".to_string(),
219 method: "textDocument/didOpen".to_string(),
220 params: Some(params),
221 })
222 .await
223 }
224
225 pub async fn send_did_change(
227 &self,
228 uri: String,
229 version: i32,
230 content_changes: Vec<Value>,
231 ) -> Result<()> {
232 let params = serde_json::json!({
233 "textDocument": {
234 "uri": uri,
235 "version": version
236 },
237 "contentChanges": content_changes
238 });
239
240 self.handle_notification(JsonRpcNotification {
241 jsonrpc: "2.0".to_string(),
242 method: "textDocument/didChange".to_string(),
243 params: Some(params),
244 })
245 .await
246 }
247
248 pub async fn send_did_close(&self, uri: String) -> Result<()> {
250 let params = serde_json::json!({
251 "textDocument": {
252 "uri": uri
253 }
254 });
255
256 self.handle_notification(JsonRpcNotification {
257 jsonrpc: "2.0".to_string(),
258 method: "textDocument/didClose".to_string(),
259 params: Some(params),
260 })
261 .await
262 }
263
264 pub async fn send_did_save(&self, uri: String, text: Option<String>) -> Result<()> {
266 let mut params = serde_json::json!({
267 "textDocument": {
268 "uri": uri
269 }
270 });
271
272 if let Some(text) = text {
273 params["text"] = serde_json::json!(text);
274 }
275
276 self.handle_notification(JsonRpcNotification {
277 jsonrpc: "2.0".to_string(),
278 method: "textDocument/didSave".to_string(),
279 params: Some(params),
280 })
281 .await
282 }
283}
284
285impl Default for LspConnection {
286 fn default() -> Self {
287 Self::new()
288 }
289}
290
#[cfg(test)]
mod tests {
    use super::*;

    /// Timeout long enough that it never elapses during a test.
    const LONG_TIMEOUT: Duration = Duration::from_secs(5);

    /// URI shared by the document-synchronisation tests.
    const TEST_URI: &str = "file:///test.rs";

    /// Builds a JSON-RPC 2.0 response for `id`.
    fn response_for(
        id: RequestId,
        result: Option<Value>,
        error: Option<crate::client::protocol::JsonRpcError>,
    ) -> JsonRpcResponse {
        JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            result,
            error,
            id,
        }
    }

    /// Builds a JSON-RPC 2.0 notification.
    fn notification(method: &str, params: Option<Value>) -> JsonRpcNotification {
        JsonRpcNotification {
            jsonrpc: "2.0".to_string(),
            method: method.to_string(),
            params,
        }
    }

    /// Receives the next broadcast notification, panicking if there is none.
    async fn next(
        rx: &mut broadcast::Receiver<(String, Option<Value>)>,
    ) -> (String, Option<Value>) {
        rx.recv().await.unwrap()
    }

    #[tokio::test]
    async fn test_create_tracked_request() {
        let connection = LspConnection::new();

        let (request, _receiver) = connection
            .create_tracked_request("test", None, LONG_TIMEOUT)
            .await
            .unwrap();

        assert_eq!(request.method, "test");
        assert!(request.id.is_some());
        assert_eq!(connection.pending_request_count().await, 1);
    }

    #[tokio::test]
    async fn test_handle_response() {
        let connection = LspConnection::new();
        let (request, receiver) = connection
            .create_tracked_request("test", None, LONG_TIMEOUT)
            .await
            .unwrap();

        let expected = Value::String("success".to_string());
        let reply = response_for(request.id.unwrap(), Some(expected.clone()), None);

        connection.handle_response(reply).await.unwrap();

        let delivered = receiver.await.unwrap().unwrap();
        assert_eq!(delivered, expected);
        assert_eq!(connection.pending_request_count().await, 0);
    }

    #[tokio::test]
    async fn test_handle_error_response() {
        let connection = LspConnection::new();
        let (request, receiver) = connection
            .create_tracked_request("test", None, LONG_TIMEOUT)
            .await
            .unwrap();

        let failure = crate::client::protocol::JsonRpcError {
            code: -32600,
            message: "Invalid Request".to_string(),
            data: None,
        };
        let reply = response_for(request.id.unwrap(), None, Some(failure));

        connection.handle_response(reply).await.unwrap();

        let delivered = receiver.await.unwrap();
        assert!(delivered.is_err());
    }

    #[tokio::test]
    async fn test_cleanup_timed_out_requests() {
        let connection = LspConnection::new();
        let (request, _receiver) = connection
            .create_tracked_request("test", None, Duration::from_millis(1))
            .await
            .unwrap();

        // Let the 1 ms timeout elapse before sweeping.
        tokio::time::sleep(Duration::from_millis(10)).await;

        let expired = connection.cleanup_timed_out_requests().await;
        assert_eq!(expired.len(), 1);
        assert_eq!(expired[0], request.id.unwrap());
        assert_eq!(connection.pending_request_count().await, 0);
    }

    #[tokio::test]
    async fn test_unknown_response_id() {
        let connection = LspConnection::new();
        let reply = response_for(999, Some(Value::String("success".to_string())), None);

        let outcome = connection.handle_response(reply).await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_handle_notification() {
        let connection = LspConnection::new();
        let mut subscriber = connection.subscribe_notifications();

        let payload = Some(Value::String("test".to_string()));
        connection
            .handle_notification(notification("test/notification", payload.clone()))
            .await
            .unwrap();

        let (method, params) = next(&mut subscriber).await;
        assert_eq!(method, "test/notification");
        assert_eq!(params, payload);
    }

    #[tokio::test]
    async fn test_handle_publish_diagnostics() {
        let connection = LspConnection::new();
        let mut subscriber = connection.subscribe_notifications();

        let payload = Some(serde_json::json!({
            "uri": "file:///test.rs",
            "diagnostics": []
        }));

        connection
            .handle_publish_diagnostics(payload.clone())
            .await
            .unwrap();

        let (method, params) = next(&mut subscriber).await;
        assert_eq!(method, "textDocument/publishDiagnostics");
        assert_eq!(params, payload);
    }

    #[tokio::test]
    async fn test_handle_log_message() {
        let connection = LspConnection::new();
        let mut subscriber = connection.subscribe_notifications();

        let payload = Some(serde_json::json!({
            "type": 1,
            "message": "Test log message"
        }));

        connection.handle_log_message(payload.clone()).await.unwrap();

        let (method, params) = next(&mut subscriber).await;
        assert_eq!(method, "window/logMessage");
        assert_eq!(params, payload);
    }

    #[tokio::test]
    async fn test_handle_show_message() {
        let connection = LspConnection::new();
        let mut subscriber = connection.subscribe_notifications();

        let payload = Some(serde_json::json!({
            "type": 1,
            "message": "Test show message"
        }));

        connection
            .handle_show_message(payload.clone())
            .await
            .unwrap();

        let (method, params) = next(&mut subscriber).await;
        assert_eq!(method, "window/showMessage");
        assert_eq!(params, payload);
    }

    #[tokio::test]
    async fn test_multiple_notification_subscribers() {
        let connection = LspConnection::new();
        let mut first = connection.subscribe_notifications();
        let mut second = connection.subscribe_notifications();

        connection
            .handle_notification(notification("test", None))
            .await
            .unwrap();

        let (first_method, _) = next(&mut first).await;
        let (second_method, _) = next(&mut second).await;

        assert_eq!(first_method, "test");
        assert_eq!(second_method, "test");
    }

    #[tokio::test]
    async fn test_send_did_open() {
        let connection = LspConnection::new();
        let mut subscriber = connection.subscribe_notifications();

        connection
            .send_did_open(
                TEST_URI.to_string(),
                "rust".to_string(),
                1,
                "fn main() {}".to_string(),
            )
            .await
            .unwrap();

        let (method, params) = next(&mut subscriber).await;
        assert_eq!(method, "textDocument/didOpen");
        assert!(params.is_some());

        let body = params.unwrap();
        let document = &body["textDocument"];
        assert_eq!(document["uri"], "file:///test.rs");
        assert_eq!(document["languageId"], "rust");
        assert_eq!(document["version"], 1);
        assert_eq!(document["text"], "fn main() {}");
    }

    #[tokio::test]
    async fn test_send_did_change() {
        let connection = LspConnection::new();
        let mut subscriber = connection.subscribe_notifications();

        let edits = vec![serde_json::json!({
            "range": {
                "start": {"line": 0, "character": 0},
                "end": {"line": 0, "character": 0}
            },
            "text": "// comment\n"
        })];

        connection
            .send_did_change(TEST_URI.to_string(), 2, edits.clone())
            .await
            .unwrap();

        let (method, params) = next(&mut subscriber).await;
        assert_eq!(method, "textDocument/didChange");
        assert!(params.is_some());

        let body = params.unwrap();
        assert_eq!(body["textDocument"]["uri"], "file:///test.rs");
        assert_eq!(body["textDocument"]["version"], 2);
        assert_eq!(body["contentChanges"], Value::Array(edits));
    }

    #[tokio::test]
    async fn test_send_did_close() {
        let connection = LspConnection::new();
        let mut subscriber = connection.subscribe_notifications();

        connection
            .send_did_close(TEST_URI.to_string())
            .await
            .unwrap();

        let (method, params) = next(&mut subscriber).await;
        assert_eq!(method, "textDocument/didClose");
        assert!(params.is_some());

        let body = params.unwrap();
        assert_eq!(body["textDocument"]["uri"], "file:///test.rs");
    }

    #[tokio::test]
    async fn test_send_did_save() {
        let connection = LspConnection::new();
        let mut subscriber = connection.subscribe_notifications();

        let saved_text = Some("fn main() {}".to_string());
        connection
            .send_did_save(TEST_URI.to_string(), saved_text)
            .await
            .unwrap();

        let (method, params) = next(&mut subscriber).await;
        assert_eq!(method, "textDocument/didSave");
        assert!(params.is_some());

        let body = params.unwrap();
        assert_eq!(body["textDocument"]["uri"], "file:///test.rs");
        assert_eq!(body["text"], "fn main() {}");
    }

    #[tokio::test]
    async fn test_send_did_save_without_text() {
        let connection = LspConnection::new();
        let mut subscriber = connection.subscribe_notifications();

        connection
            .send_did_save(TEST_URI.to_string(), None)
            .await
            .unwrap();

        let (method, params) = next(&mut subscriber).await;
        assert_eq!(method, "textDocument/didSave");
        assert!(params.is_some());

        let body = params.unwrap();
        assert_eq!(body["textDocument"]["uri"], "file:///test.rs");
        // The optional member must be absent, not null.
        assert!(body.get("text").is_none());
    }
}