mcpls_core/lsp/
client.rs

1//! LSP client implementation with async request/response handling.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicI64, Ordering};
6
7use serde::Serialize;
8use serde::de::DeserializeOwned;
9use serde_json::Value;
10use tokio::sync::{Mutex, mpsc, oneshot};
11use tokio::task::JoinHandle;
12use tokio::time::{Duration, timeout};
13use tracing::{debug, error, trace, warn};
14
15use crate::config::LspServerConfig;
16use crate::error::{Error, Result};
17use crate::lsp::transport::LspTransport;
18use crate::lsp::types::{InboundMessage, JsonRpcRequest, RequestId};
19
20/// LSP client with async request/response handling.
21///
22/// This client manages communication with an LSP server, handling:
23/// - Concurrent requests with unique ID tracking
24/// - Background message loop for receiving responses
25/// - Timeout support for all requests
26/// - Graceful shutdown
27#[derive(Debug)]
28pub struct LspClient {
29    /// Configuration for this LSP server.
30    config: LspServerConfig,
31
32    /// Current server state.
33    state: Arc<Mutex<super::ServerState>>,
34
35    /// Atomic counter for request IDs.
36    request_counter: Arc<AtomicI64>,
37
38    /// Command sender for outbound messages.
39    command_tx: mpsc::Sender<ClientCommand>,
40
41    /// Background receiver task handle.
42    receiver_task: Option<JoinHandle<Result<()>>>,
43}
44
45impl Clone for LspClient {
46    fn clone(&self) -> Self {
47        Self {
48            config: self.config.clone(),
49            state: Arc::clone(&self.state),
50            request_counter: Arc::clone(&self.request_counter),
51            command_tx: self.command_tx.clone(),
52            receiver_task: None,
53        }
54    }
55}
56
57/// Commands for client control.
58enum ClientCommand {
59    /// Send a request and wait for response.
60    SendRequest {
61        request: JsonRpcRequest,
62        response_tx: oneshot::Sender<Result<Value>>,
63    },
64    /// Send a notification (no response expected).
65    SendNotification {
66        method: String,
67        params: Option<Value>,
68    },
69    /// Shutdown the client.
70    Shutdown,
71}
72
73impl LspClient {
74    /// Create a new LSP client with the given configuration.
75    ///
76    /// The client starts in an uninitialized state. Call `initialize()` to
77    /// start the server and complete the initialization handshake.
78    #[must_use]
79    pub fn new(config: LspServerConfig) -> Self {
80        let (command_tx, _command_rx) = mpsc::channel(100);
81
82        Self {
83            config,
84            state: Arc::new(Mutex::new(super::ServerState::Uninitialized)),
85            request_counter: Arc::new(AtomicI64::new(1)),
86            command_tx,
87            receiver_task: None,
88        }
89    }
90
91    /// Create client from transport (for testing or custom spawning).
92    ///
93    /// This method initializes the background message loop with the provided transport.
94    pub(crate) fn from_transport(config: LspServerConfig, transport: LspTransport) -> Self {
95        let state = Arc::new(Mutex::new(super::ServerState::Initializing));
96        let request_counter = Arc::new(AtomicI64::new(1));
97        let pending_requests = Arc::new(Mutex::new(HashMap::new()));
98
99        let (command_tx, command_rx) = mpsc::channel(100);
100
101        let receiver_task =
102            tokio::spawn(Self::message_loop(transport, command_rx, pending_requests));
103
104        Self {
105            config,
106            state,
107            request_counter,
108            command_tx,
109            receiver_task: Some(receiver_task),
110        }
111    }
112
113    /// Get the language ID for this client.
114    #[must_use]
115    pub fn language_id(&self) -> &str {
116        &self.config.language_id
117    }
118
119    /// Get the current server state.
120    pub async fn state(&self) -> super::ServerState {
121        *self.state.lock().await
122    }
123
124    /// Send request and wait for response with timeout.
125    ///
126    /// # Type Parameters
127    ///
128    /// * `P` - The type of the request parameters (must be serializable)
129    /// * `R` - The type of the response result (must be deserializable)
130    ///
131    /// # Errors
132    ///
133    /// Returns an error if:
134    /// - Server has shut down
135    /// - Request times out
136    /// - Response cannot be deserialized
137    /// - LSP server returns an error
138    pub async fn request<P, R>(
139        &self,
140        method: &str,
141        params: P,
142        timeout_duration: Duration,
143    ) -> Result<R>
144    where
145        P: Serialize,
146        R: DeserializeOwned,
147    {
148        let id = RequestId::Number(self.request_counter.fetch_add(1, Ordering::SeqCst));
149        let params_value = serde_json::to_value(params)?;
150
151        let (response_tx, response_rx) = oneshot::channel();
152
153        let request = JsonRpcRequest {
154            jsonrpc: "2.0".to_string(),
155            id: id.clone(),
156            method: method.to_string(),
157            params: Some(params_value),
158        };
159
160        debug!("Sending request: {} (id={:?})", method, id);
161
162        self.command_tx
163            .send(ClientCommand::SendRequest {
164                request,
165                response_tx,
166            })
167            .await
168            .map_err(|_| Error::ServerTerminated)?;
169
170        let result_value = timeout(timeout_duration, response_rx)
171            .await
172            .map_err(|_| Error::Timeout(timeout_duration.as_secs()))?
173            .map_err(|_| Error::ServerTerminated)??;
174
175        serde_json::from_value(result_value)
176            .map_err(|e| Error::LspProtocolError(format!("Failed to deserialize response: {e}")))
177    }
178
179    /// Send notification (fire-and-forget, no response expected).
180    ///
181    /// # Errors
182    ///
183    /// Returns an error if the server has shut down.
184    pub async fn notify<P>(&self, method: &str, params: P) -> Result<()>
185    where
186        P: Serialize,
187    {
188        let params_value = serde_json::to_value(params)?;
189
190        debug!("Sending notification: {}", method);
191
192        self.command_tx
193            .send(ClientCommand::SendNotification {
194                method: method.to_string(),
195                params: Some(params_value),
196            })
197            .await
198            .map_err(|_| Error::ServerTerminated)?;
199
200        Ok(())
201    }
202
203    /// Shutdown client gracefully.
204    ///
205    /// This sends a shutdown command to the background task and waits for it to complete.
206    ///
207    /// # Errors
208    ///
209    /// Returns an error if the background task failed.
210    pub async fn shutdown(mut self) -> Result<()> {
211        debug!("Shutting down LSP client");
212
213        let _ = self.command_tx.send(ClientCommand::Shutdown).await;
214
215        if let Some(task) = self.receiver_task.take() {
216            task.await
217                .map_err(|e| Error::Transport(format!("Receiver task failed: {e}")))??;
218        }
219
220        *self.state.lock().await = super::ServerState::Shutdown;
221
222        Ok(())
223    }
224
225    /// Background task: handle message I/O.
226    ///
227    /// This task runs in the background, handling:
228    /// - Outbound requests and notifications
229    /// - Inbound responses and server notifications
230    /// - Matching responses to pending requests
231    async fn message_loop(
232        mut transport: LspTransport,
233        mut command_rx: mpsc::Receiver<ClientCommand>,
234        pending_requests: Arc<Mutex<HashMap<RequestId, oneshot::Sender<Result<Value>>>>>,
235    ) -> Result<()> {
236        loop {
237            tokio::select! {
238                Some(command) = command_rx.recv() => {
239                    match command {
240                        ClientCommand::SendRequest { request, response_tx } => {
241                            pending_requests.lock().await.insert(
242                                request.id.clone(),
243                                response_tx,
244                            );
245
246                            let value = serde_json::to_value(&request)?;
247                            transport.send(&value).await?;
248                        }
249                        ClientCommand::SendNotification { method, params } => {
250                            let notification = serde_json::json!({
251                                "jsonrpc": "2.0",
252                                "method": method,
253                                "params": params,
254                            });
255                            transport.send(&notification).await?;
256                        }
257                        ClientCommand::Shutdown => {
258                            debug!("Client shutdown requested");
259                            break;
260                        }
261                    }
262                }
263
264                message = transport.receive() => {
265                    let message = message?;
266                    match message {
267                        InboundMessage::Response(response) => {
268                            trace!("Received response: id={:?}", response.id);
269
270                            let sender = pending_requests.lock().await.remove(&response.id);
271
272                            if let Some(sender) = sender {
273                                if let Some(error) = response.error {
274                                    error!("LSP error response: {} (code {})", error.message, error.code);
275                                    let _ = sender.send(Err(Error::LspServerError {
276                                        code: error.code,
277                                        message: error.message,
278                                    }));
279                                } else if let Some(result) = response.result {
280                                    let _ = sender.send(Ok(result));
281                                } else {
282                                    warn!("Response with neither result nor error: {:?}", response.id);
283                                }
284                            } else {
285                                warn!("Received response for unknown request ID: {:?}", response.id);
286                            }
287                        }
288                        InboundMessage::Notification(notification) => {
289                            debug!("Received notification: {}", notification.method);
290                            // TODO: Handle server notifications (diagnostics, etc.)
291                            // For Phase 2, just log and ignore
292                        }
293                    }
294                }
295            }
296        }
297
298        Ok(())
299    }
300}
301
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// IDs start at 1 and grow by one per `fetch_add`, as `LspClient` relies on.
    #[test]
    fn test_request_id_generation() {
        let counter = AtomicI64::new(1);

        let ids: Vec<i64> = (0..3)
            .map(|_| counter.fetch_add(1, Ordering::SeqCst))
            .collect();

        assert_eq!(ids, [1, 2, 3]);
    }

    /// A freshly built client reports the language of its configuration.
    #[test]
    fn test_client_creation() {
        let client = LspClient::new(LspServerConfig::rust_analyzer());
        assert_eq!(client.language_id(), "rust");
    }
}
327}