mcpls_core/lsp/
client.rs

1//! LSP client implementation with async request/response handling.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicI64, Ordering};
6
7use serde::Serialize;
8use serde::de::DeserializeOwned;
9use serde_json::Value;
10use tokio::sync::{Mutex, mpsc, oneshot};
11use tokio::task::JoinHandle;
12use tokio::time::{Duration, timeout};
13use tracing::{debug, error, trace, warn};
14
15use crate::config::LspServerConfig;
16use crate::error::{Error, Result};
17use crate::lsp::transport::LspTransport;
18use crate::lsp::types::{InboundMessage, JsonRpcRequest, LspNotification, RequestId};
19
20/// JSON-RPC protocol version.
21const JSONRPC_VERSION: &str = "2.0";
22
23/// Type alias for pending request tracking map.
24type PendingRequests = HashMap<RequestId, oneshot::Sender<Result<Value>>>;
25
26/// LSP client with async request/response handling.
27///
28/// This client manages communication with an LSP server, handling:
29/// - Concurrent requests with unique ID tracking
30/// - Background message loop for receiving responses
31/// - Timeout support for all requests
32/// - Graceful shutdown
33#[derive(Debug)]
34pub struct LspClient {
35    /// Configuration for this LSP server.
36    config: LspServerConfig,
37
38    /// Current server state.
39    state: Arc<Mutex<super::ServerState>>,
40
41    /// Atomic counter for request IDs.
42    request_counter: Arc<AtomicI64>,
43
44    /// Command sender for outbound messages.
45    command_tx: mpsc::Sender<ClientCommand>,
46
47    /// Background receiver task handle.
48    receiver_task: Option<JoinHandle<Result<()>>>,
49}
50
51impl Clone for LspClient {
52    /// Creates a clone that shares the underlying connection.
53    ///
54    /// The clone does not own the receiver task and cannot perform shutdown.
55    /// All clones share the same command channel for sending requests.
56    fn clone(&self) -> Self {
57        Self {
58            config: self.config.clone(),
59            state: Arc::clone(&self.state),
60            request_counter: Arc::clone(&self.request_counter),
61            command_tx: self.command_tx.clone(),
62            receiver_task: None,
63        }
64    }
65}
66
67/// Commands for client control.
68enum ClientCommand {
69    /// Send a request and wait for response.
70    SendRequest {
71        request: JsonRpcRequest,
72        response_tx: oneshot::Sender<Result<Value>>,
73    },
74    /// Send a notification (no response expected).
75    SendNotification {
76        method: String,
77        params: Option<Value>,
78    },
79    /// Shutdown the client.
80    Shutdown,
81}
82
83impl LspClient {
84    /// Create a new LSP client with the given configuration.
85    ///
86    /// The client starts in an uninitialized state. Call `initialize()` to
87    /// start the server and complete the initialization handshake.
88    #[must_use]
89    pub fn new(config: LspServerConfig) -> Self {
90        // Placeholder channel - the receiver is intentionally dropped since
91        // the client starts uninitialized. A real channel is created when
92        // `from_transport` or `from_transport_with_notifications` is called.
93        let (command_tx, _command_rx) = mpsc::channel(1); // Minimal capacity for placeholder
94
95        Self {
96            config,
97            state: Arc::new(Mutex::new(super::ServerState::Uninitialized)),
98            request_counter: Arc::new(AtomicI64::new(1)),
99            command_tx,
100            receiver_task: None,
101        }
102    }
103
104    /// Create client from transport (for testing or custom spawning).
105    ///
106    /// This method initializes the background message loop with the provided transport.
107    pub(crate) fn from_transport(config: LspServerConfig, transport: LspTransport) -> Self {
108        let state = Arc::new(Mutex::new(super::ServerState::Initializing));
109        let request_counter = Arc::new(AtomicI64::new(1));
110        let pending_requests = Arc::new(Mutex::new(HashMap::new()));
111
112        let (command_tx, command_rx) = mpsc::channel(100);
113
114        let receiver_task = tokio::spawn(Self::message_loop(
115            transport,
116            command_rx,
117            pending_requests,
118            None,
119        ));
120
121        Self {
122            config,
123            state,
124            request_counter,
125            command_tx,
126            receiver_task: Some(receiver_task),
127        }
128    }
129
130    /// Create client from transport with notification forwarding.
131    ///
132    /// Notifications received from the LSP server will be parsed and sent
133    /// through the provided channel.
134    #[allow(dead_code)] // Used in Phase 4
135    pub(crate) fn from_transport_with_notifications(
136        config: LspServerConfig,
137        transport: LspTransport,
138        notification_tx: mpsc::Sender<LspNotification>,
139    ) -> Self {
140        let state = Arc::new(Mutex::new(super::ServerState::Initializing));
141        let request_counter = Arc::new(AtomicI64::new(1));
142        let pending_requests = Arc::new(Mutex::new(HashMap::new()));
143
144        let (command_tx, command_rx) = mpsc::channel(100);
145
146        let receiver_task = tokio::spawn(Self::message_loop(
147            transport,
148            command_rx,
149            pending_requests,
150            Some(notification_tx),
151        ));
152
153        Self {
154            config,
155            state,
156            request_counter,
157            command_tx,
158            receiver_task: Some(receiver_task),
159        }
160    }
161
162    /// Get the language ID for this client.
163    #[must_use]
164    pub fn language_id(&self) -> &str {
165        &self.config.language_id
166    }
167
168    /// Get the current server state.
169    pub async fn state(&self) -> super::ServerState {
170        *self.state.lock().await
171    }
172
173    /// Send request and wait for response with timeout.
174    ///
175    /// # Type Parameters
176    ///
177    /// * `P` - The type of the request parameters (must be serializable)
178    /// * `R` - The type of the response result (must be deserializable)
179    ///
180    /// # Errors
181    ///
182    /// Returns an error if:
183    /// - Server has shut down
184    /// - Request times out
185    /// - Response cannot be deserialized
186    /// - LSP server returns an error
187    pub async fn request<P, R>(
188        &self,
189        method: &str,
190        params: P,
191        timeout_duration: Duration,
192    ) -> Result<R>
193    where
194        P: Serialize,
195        R: DeserializeOwned,
196    {
197        let id = RequestId::Number(self.request_counter.fetch_add(1, Ordering::SeqCst));
198        let params_value = serde_json::to_value(params)?;
199
200        let (response_tx, response_rx) = oneshot::channel();
201
202        let request = JsonRpcRequest {
203            jsonrpc: JSONRPC_VERSION.to_string(),
204            id: id.clone(),
205            method: method.to_string(),
206            params: Some(params_value),
207        };
208
209        debug!("Sending request: {} (id={:?})", method, id);
210
211        self.command_tx
212            .send(ClientCommand::SendRequest {
213                request,
214                response_tx,
215            })
216            .await
217            .map_err(|_| Error::ServerTerminated)?;
218
219        let result_value = timeout(timeout_duration, response_rx)
220            .await
221            .map_err(|_| Error::Timeout(timeout_duration.as_secs()))?
222            .map_err(|_| Error::ServerTerminated)??;
223
224        serde_json::from_value(result_value)
225            .map_err(|e| Error::LspProtocolError(format!("Failed to deserialize response: {e}")))
226    }
227
228    /// Send notification (fire-and-forget, no response expected).
229    ///
230    /// # Errors
231    ///
232    /// Returns an error if the server has shut down.
233    pub async fn notify<P>(&self, method: &str, params: P) -> Result<()>
234    where
235        P: Serialize,
236    {
237        let params_value = serde_json::to_value(params)?;
238
239        debug!("Sending notification: {}", method);
240
241        self.command_tx
242            .send(ClientCommand::SendNotification {
243                method: method.to_string(),
244                params: Some(params_value),
245            })
246            .await
247            .map_err(|_| Error::ServerTerminated)?;
248
249        Ok(())
250    }
251
252    /// Shutdown client gracefully.
253    ///
254    /// This sends a shutdown command to the background task and waits for it to complete.
255    ///
256    /// # Errors
257    ///
258    /// Returns an error if the background task failed.
259    pub async fn shutdown(mut self) -> Result<()> {
260        debug!("Shutting down LSP client");
261
262        let _ = self.command_tx.send(ClientCommand::Shutdown).await;
263
264        if let Some(task) = self.receiver_task.take() {
265            task.await
266                .map_err(|e| Error::Transport(format!("Receiver task failed: {e}")))??;
267        }
268
269        *self.state.lock().await = super::ServerState::Shutdown;
270
271        Ok(())
272    }
273
274    /// Background task: handle message I/O.
275    ///
276    /// This task runs in the background, handling:
277    /// - Outbound requests and notifications
278    /// - Inbound responses and server notifications
279    /// - Matching responses to pending requests
280    async fn message_loop(
281        mut transport: LspTransport,
282        mut command_rx: mpsc::Receiver<ClientCommand>,
283        pending_requests: Arc<Mutex<PendingRequests>>,
284        notification_tx: Option<mpsc::Sender<LspNotification>>,
285    ) -> Result<()> {
286        debug!("Message loop started");
287        let result = Self::message_loop_inner(
288            &mut transport,
289            &mut command_rx,
290            &pending_requests,
291            notification_tx.as_ref(),
292        )
293        .await;
294        if let Err(ref e) = result {
295            error!("Message loop exiting with error: {}", e);
296        } else {
297            debug!("Message loop exiting normally");
298        }
299        result
300    }
301
302    async fn message_loop_inner(
303        transport: &mut LspTransport,
304        command_rx: &mut mpsc::Receiver<ClientCommand>,
305        pending_requests: &Arc<Mutex<PendingRequests>>,
306        notification_tx: Option<&mpsc::Sender<LspNotification>>,
307    ) -> Result<()> {
308        loop {
309            tokio::select! {
310                Some(command) = command_rx.recv() => {
311                    match command {
312                        ClientCommand::SendRequest { request, response_tx } => {
313                            pending_requests.lock().await.insert(
314                                request.id.clone(),
315                                response_tx,
316                            );
317
318                            let value = serde_json::to_value(&request)?;
319                            transport.send(&value).await?;
320                        }
321                        ClientCommand::SendNotification { method, params } => {
322                            let notification = serde_json::json!({
323                                "jsonrpc": "2.0",
324                                "method": method,
325                                "params": params,
326                            });
327                            transport.send(&notification).await?;
328                        }
329                        ClientCommand::Shutdown => {
330                            debug!("Client shutdown requested");
331                            break;
332                        }
333                    }
334                }
335
336                message = transport.receive() => {
337                    let message = match message {
338                        Ok(m) => m,
339                        Err(e) => {
340                            error!("Transport receive error: {}", e);
341                            return Err(e);
342                        }
343                    };
344                    match message {
345                        InboundMessage::Response(response) => {
346                            trace!("Received response: id={:?}", response.id);
347
348                            let sender = pending_requests.lock().await.remove(&response.id);
349
350                            if let Some(sender) = sender {
351                                if let Some(error) = response.error {
352                                    let message = if error.message.len() > 200 {
353                                        format!("{}... (truncated)", &error.message[..200])
354                                    } else {
355                                        error.message.clone()
356                                    };
357                                    error!("LSP error response: {} (code {})", message, error.code);
358                                    let _ = sender.send(Err(Error::LspServerError {
359                                        code: error.code,
360                                        message: error.message,
361                                    }));
362                                } else if let Some(result) = response.result {
363                                    let _ = sender.send(Ok(result));
364                                } else {
365                                    // LSP spec allows null result for some requests (e.g., hover with no info).
366                                    // Treat as successful response with null value.
367                                    trace!("Response with null result: {:?}", response.id);
368                                    let _ = sender.send(Ok(Value::Null));
369                                }
370                            } else {
371                                warn!("Received response for unknown request ID: {:?}", response.id);
372                            }
373                        }
374                        InboundMessage::Notification(notification) => {
375                            debug!("Received notification: {}", notification.method);
376
377                            // Parse notification into typed variant
378                            let typed = LspNotification::parse(&notification.method, notification.params);
379
380                            // Forward to notification handler if sender is available
381                            if let Some(tx) = notification_tx {
382                                // Log diagnostics count since it's useful for debugging
383                                if let LspNotification::PublishDiagnostics(ref params) = typed {
384                                    debug!(
385                                        "Forwarding diagnostics for {}: {} items",
386                                        params.uri.as_str(),
387                                        params.diagnostics.len()
388                                    );
389                                } else {
390                                    trace!("Forwarding notification: {:?}", typed);
391                                }
392
393                                // Send the notification with backpressure handling
394                                if tx.try_send(typed).is_err() {
395                                    warn!("Notification channel full or closed, dropping notification");
396                                }
397                            }
398                        }
399                    }
400                }
401            }
402        }
403
404        Ok(())
405    }
406}
407
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// `fetch_add` returns the value before the increment, so a counter
    /// seeded with 1 yields 1, 2, 3.
    #[test]
    fn test_request_id_generation() {
        let counter = AtomicI64::new(1);

        let issued: Vec<i64> = (0..3)
            .map(|_| counter.fetch_add(1, Ordering::SeqCst))
            .collect();

        assert_eq!(issued[0], 1);
        assert_eq!(issued[1], 2);
        assert_eq!(issued[2], 3);
    }

    /// A freshly constructed client reports the language ID of its configuration.
    #[test]
    fn test_client_creation() {
        let client = LspClient::new(LspServerConfig::rust_analyzer());

        assert_eq!(client.language_id(), "rust");
    }

    /// A response with neither `result` nor `error` reaches the caller as `Ok(Value::Null)`.
    ///
    /// This replays the routing steps by hand on a local pending-request
    /// table; it does not run the message loop itself.
    #[tokio::test]
    async fn test_null_response_handling() {
        use crate::lsp::types::{JsonRpcResponse, RequestId};

        // Register a waiting caller under request ID 1.
        let (response_tx, response_rx) = oneshot::channel::<Result<Value>>();
        let pending_requests = Arc::new(Mutex::new(HashMap::new()));
        {
            let mut table = pending_requests.lock().await;
            table.insert(RequestId::Number(1), response_tx);
        }

        let null_response = JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            id: RequestId::Number(1),
            result: None,
            error: None,
        };

        // Route the response exactly as the null-result branch does.
        let waiting = pending_requests.lock().await.remove(&null_response.id);
        if let Some(sender) = waiting {
            let _ = sender.send(Ok(Value::Null));
        }

        let timeout_result = timeout(Duration::from_millis(100), response_rx).await;
        assert!(timeout_result.is_ok(), "Should not timeout");

        let channel_result = timeout_result.unwrap();
        assert!(
            channel_result.is_ok(),
            "Channel should not be closed: {:?}",
            channel_result.err()
        );

        let response = channel_result.unwrap();
        assert!(
            response.is_ok(),
            "Should receive Ok(Value::Null), not Err: {:?}",
            response.err()
        );

        assert_eq!(response.unwrap(), Value::Null, "Should receive Value::Null");
    }
}