1use std::collections::HashMap;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicI64, Ordering};
6
7use serde::Serialize;
8use serde::de::DeserializeOwned;
9use serde_json::Value;
10use tokio::sync::{Mutex, mpsc, oneshot};
11use tokio::task::JoinHandle;
12use tokio::time::{Duration, timeout};
13use tracing::{debug, error, trace, warn};
14
15use crate::config::LspServerConfig;
16use crate::error::{Error, Result};
17use crate::lsp::transport::LspTransport;
18use crate::lsp::types::{InboundMessage, JsonRpcRequest, LspNotification, RequestId};
19
/// JSON-RPC protocol version string placed in the `jsonrpc` field of outgoing requests.
const JSONRPC_VERSION: &str = "2.0";

/// Requests awaiting a server response, keyed by request id.
///
/// Each entry holds the one-shot sender through which the message loop delivers the
/// server's result (or error) back to the waiting `LspClient::request` call.
type PendingRequests = HashMap<RequestId, oneshot::Sender<Result<Value>>>;
25
/// Handle to a language-server connection.
///
/// Requests and notifications are not written to the transport directly: they are
/// sent as `ClientCommand`s over `command_tx` to a background message loop that owns
/// the transport. Cloned handles share the state, the request-id counter and the
/// command channel.
#[derive(Debug)]
pub struct LspClient {
    /// Server configuration (source of `language_id`).
    config: LspServerConfig,

    /// Server lifecycle state, shared by every clone of this handle.
    state: Arc<Mutex<super::ServerState>>,

    /// Source of JSON-RPC request ids (starts at 1, incremented per request); shared by clones.
    request_counter: Arc<AtomicI64>,

    /// Sending half of the command channel consumed by the message loop.
    command_tx: mpsc::Sender<ClientCommand>,

    /// Join handle of the background message loop. `None` for clones and for clients
    /// built with `LspClient::new`, which spawns no loop.
    receiver_task: Option<JoinHandle<Result<()>>>,
}
50
51impl Clone for LspClient {
52 fn clone(&self) -> Self {
57 Self {
58 config: self.config.clone(),
59 state: Arc::clone(&self.state),
60 request_counter: Arc::clone(&self.request_counter),
61 command_tx: self.command_tx.clone(),
62 receiver_task: None,
63 }
64 }
65}
66
/// Work items sent from `LspClient` handles to the background message loop.
enum ClientCommand {
    /// Register `response_tx` under `request.id`, then write the request to the transport.
    SendRequest {
        /// Fully-formed JSON-RPC request (id already assigned by the client).
        request: JsonRpcRequest,
        /// Receives the server's result or error when the matching response arrives.
        response_tx: oneshot::Sender<Result<Value>>,
    },
    /// Write a JSON-RPC notification; no response is expected.
    SendNotification {
        /// LSP method name.
        method: String,
        /// Notification parameters, if any.
        params: Option<Value>,
    },
    /// Stop the message loop.
    Shutdown,
}
82
83impl LspClient {
84 #[must_use]
89 pub fn new(config: LspServerConfig) -> Self {
90 let (command_tx, _command_rx) = mpsc::channel(1); Self {
96 config,
97 state: Arc::new(Mutex::new(super::ServerState::Uninitialized)),
98 request_counter: Arc::new(AtomicI64::new(1)),
99 command_tx,
100 receiver_task: None,
101 }
102 }
103
104 pub(crate) fn from_transport(config: LspServerConfig, transport: LspTransport) -> Self {
108 let state = Arc::new(Mutex::new(super::ServerState::Initializing));
109 let request_counter = Arc::new(AtomicI64::new(1));
110 let pending_requests = Arc::new(Mutex::new(HashMap::new()));
111
112 let (command_tx, command_rx) = mpsc::channel(100);
113
114 let receiver_task = tokio::spawn(Self::message_loop(
115 transport,
116 command_rx,
117 pending_requests,
118 None,
119 ));
120
121 Self {
122 config,
123 state,
124 request_counter,
125 command_tx,
126 receiver_task: Some(receiver_task),
127 }
128 }
129
130 #[allow(dead_code)] pub(crate) fn from_transport_with_notifications(
136 config: LspServerConfig,
137 transport: LspTransport,
138 notification_tx: mpsc::Sender<LspNotification>,
139 ) -> Self {
140 let state = Arc::new(Mutex::new(super::ServerState::Initializing));
141 let request_counter = Arc::new(AtomicI64::new(1));
142 let pending_requests = Arc::new(Mutex::new(HashMap::new()));
143
144 let (command_tx, command_rx) = mpsc::channel(100);
145
146 let receiver_task = tokio::spawn(Self::message_loop(
147 transport,
148 command_rx,
149 pending_requests,
150 Some(notification_tx),
151 ));
152
153 Self {
154 config,
155 state,
156 request_counter,
157 command_tx,
158 receiver_task: Some(receiver_task),
159 }
160 }
161
162 #[must_use]
164 pub fn language_id(&self) -> &str {
165 &self.config.language_id
166 }
167
168 pub async fn state(&self) -> super::ServerState {
170 *self.state.lock().await
171 }
172
173 pub async fn request<P, R>(
188 &self,
189 method: &str,
190 params: P,
191 timeout_duration: Duration,
192 ) -> Result<R>
193 where
194 P: Serialize,
195 R: DeserializeOwned,
196 {
197 let id = RequestId::Number(self.request_counter.fetch_add(1, Ordering::SeqCst));
198 let params_value = serde_json::to_value(params)?;
199
200 let (response_tx, response_rx) = oneshot::channel();
201
202 let request = JsonRpcRequest {
203 jsonrpc: JSONRPC_VERSION.to_string(),
204 id: id.clone(),
205 method: method.to_string(),
206 params: Some(params_value),
207 };
208
209 debug!("Sending request: {} (id={:?})", method, id);
210
211 self.command_tx
212 .send(ClientCommand::SendRequest {
213 request,
214 response_tx,
215 })
216 .await
217 .map_err(|_| Error::ServerTerminated)?;
218
219 let result_value = timeout(timeout_duration, response_rx)
220 .await
221 .map_err(|_| Error::Timeout(timeout_duration.as_secs()))?
222 .map_err(|_| Error::ServerTerminated)??;
223
224 serde_json::from_value(result_value)
225 .map_err(|e| Error::LspProtocolError(format!("Failed to deserialize response: {e}")))
226 }
227
228 pub async fn notify<P>(&self, method: &str, params: P) -> Result<()>
234 where
235 P: Serialize,
236 {
237 let params_value = serde_json::to_value(params)?;
238
239 debug!("Sending notification: {}", method);
240
241 self.command_tx
242 .send(ClientCommand::SendNotification {
243 method: method.to_string(),
244 params: Some(params_value),
245 })
246 .await
247 .map_err(|_| Error::ServerTerminated)?;
248
249 Ok(())
250 }
251
252 pub async fn shutdown(mut self) -> Result<()> {
260 debug!("Shutting down LSP client");
261
262 let _ = self.command_tx.send(ClientCommand::Shutdown).await;
263
264 if let Some(task) = self.receiver_task.take() {
265 task.await
266 .map_err(|e| Error::Transport(format!("Receiver task failed: {e}")))??;
267 }
268
269 *self.state.lock().await = super::ServerState::Shutdown;
270
271 Ok(())
272 }
273
274 async fn message_loop(
281 mut transport: LspTransport,
282 mut command_rx: mpsc::Receiver<ClientCommand>,
283 pending_requests: Arc<Mutex<PendingRequests>>,
284 notification_tx: Option<mpsc::Sender<LspNotification>>,
285 ) -> Result<()> {
286 debug!("Message loop started");
287 let result = Self::message_loop_inner(
288 &mut transport,
289 &mut command_rx,
290 &pending_requests,
291 notification_tx.as_ref(),
292 )
293 .await;
294 if let Err(ref e) = result {
295 error!("Message loop exiting with error: {}", e);
296 } else {
297 debug!("Message loop exiting normally");
298 }
299 result
300 }
301
302 async fn message_loop_inner(
303 transport: &mut LspTransport,
304 command_rx: &mut mpsc::Receiver<ClientCommand>,
305 pending_requests: &Arc<Mutex<PendingRequests>>,
306 notification_tx: Option<&mpsc::Sender<LspNotification>>,
307 ) -> Result<()> {
308 loop {
309 tokio::select! {
310 Some(command) = command_rx.recv() => {
311 match command {
312 ClientCommand::SendRequest { request, response_tx } => {
313 pending_requests.lock().await.insert(
314 request.id.clone(),
315 response_tx,
316 );
317
318 let value = serde_json::to_value(&request)?;
319 transport.send(&value).await?;
320 }
321 ClientCommand::SendNotification { method, params } => {
322 let notification = serde_json::json!({
323 "jsonrpc": "2.0",
324 "method": method,
325 "params": params,
326 });
327 transport.send(¬ification).await?;
328 }
329 ClientCommand::Shutdown => {
330 debug!("Client shutdown requested");
331 break;
332 }
333 }
334 }
335
336 message = transport.receive() => {
337 let message = match message {
338 Ok(m) => m,
339 Err(e) => {
340 error!("Transport receive error: {}", e);
341 return Err(e);
342 }
343 };
344 match message {
345 InboundMessage::Response(response) => {
346 trace!("Received response: id={:?}", response.id);
347
348 let sender = pending_requests.lock().await.remove(&response.id);
349
350 if let Some(sender) = sender {
351 if let Some(error) = response.error {
352 let message = if error.message.len() > 200 {
353 format!("{}... (truncated)", &error.message[..200])
354 } else {
355 error.message.clone()
356 };
357 error!("LSP error response: {} (code {})", message, error.code);
358 let _ = sender.send(Err(Error::LspServerError {
359 code: error.code,
360 message: error.message,
361 }));
362 } else if let Some(result) = response.result {
363 let _ = sender.send(Ok(result));
364 } else {
365 trace!("Response with null result: {:?}", response.id);
368 let _ = sender.send(Ok(Value::Null));
369 }
370 } else {
371 warn!("Received response for unknown request ID: {:?}", response.id);
372 }
373 }
374 InboundMessage::Notification(notification) => {
375 debug!("Received notification: {}", notification.method);
376
377 let typed = LspNotification::parse(¬ification.method, notification.params);
379
380 if let Some(tx) = notification_tx {
382 if let LspNotification::PublishDiagnostics(ref params) = typed {
384 debug!(
385 "Forwarding diagnostics for {}: {} items",
386 params.uri.as_str(),
387 params.diagnostics.len()
388 );
389 } else {
390 trace!("Forwarding notification: {:?}", typed);
391 }
392
393 if tx.try_send(typed).is_err() {
395 warn!("Notification channel full or closed, dropping notification");
396 }
397 }
398 }
399 }
400 }
401 }
402 }
403
404 Ok(())
405 }
406}
407
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_request_id_generation() {
        // Same allocation scheme as `LspClient::request`: start at 1, bump by one.
        let counter = AtomicI64::new(1);
        let issued: Vec<i64> = (0..3)
            .map(|_| counter.fetch_add(1, Ordering::SeqCst))
            .collect();

        assert_eq!(issued, vec![1, 2, 3]);
    }

    #[test]
    fn test_client_creation() {
        let client = LspClient::new(LspServerConfig::rust_analyzer());
        assert_eq!(client.language_id(), "rust");
    }

    #[tokio::test]
    async fn test_null_response_handling() {
        use crate::lsp::types::{JsonRpcResponse, RequestId};

        // Register a pending request the way the message loop does.
        let pending: Arc<Mutex<PendingRequests>> = Arc::new(Mutex::new(HashMap::new()));
        let (response_tx, response_rx) = oneshot::channel::<Result<Value>>();
        pending.lock().await.insert(RequestId::Number(1), response_tx);

        // A response that carries neither `result` nor `error`.
        let null_response = JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            id: RequestId::Number(1),
            result: None,
            error: None,
        };

        // Resolve it with `Value::Null`, as the loop's null-result branch does.
        let removed = pending.lock().await.remove(&null_response.id);
        if let Some(sender) = removed {
            let _ = sender.send(Ok(Value::Null));
        }

        let waited = timeout(Duration::from_millis(100), response_rx).await;
        assert!(waited.is_ok(), "Should not timeout");

        let received = waited.unwrap();
        assert!(
            received.is_ok(),
            "Channel should not be closed: {:?}",
            received.err()
        );

        let response = received.unwrap();
        assert!(
            response.is_ok(),
            "Should receive Ok(Value::Null), not Err: {:?}",
            response.err()
        );

        assert_eq!(response.unwrap(), Value::Null, "Should receive Value::Null");
    }
}