1use std::collections::HashMap;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicI64, Ordering};
6
7use serde::Serialize;
8use serde::de::DeserializeOwned;
9use serde_json::Value;
10use tokio::sync::{Mutex, mpsc, oneshot};
11use tokio::task::JoinHandle;
12use tokio::time::{Duration, timeout};
13use tracing::{debug, error, trace, warn};
14
15use crate::config::LspServerConfig;
16use crate::error::{Error, Result};
17use crate::lsp::transport::LspTransport;
18use crate::lsp::types::{InboundMessage, JsonRpcRequest, RequestId};
19
/// Handle to a language server connection.
///
/// For clients built with `from_transport`, all transport I/O runs on a
/// background task (`LspClient::message_loop`); this handle only talks to that
/// task over an mpsc command channel. Cloning is cheap: clones share the
/// lifecycle state, the request-id counter and the command channel, but only
/// the original handle holds the background task's `JoinHandle`.
#[derive(Debug)]
pub struct LspClient {
    /// Configuration this client was created with.
    config: LspServerConfig,

    /// Lifecycle state, shared between all clones of this client.
    state: Arc<Mutex<super::ServerState>>,

    /// Source of JSON-RPC request ids; shared between clones so ids stay unique.
    request_counter: Arc<AtomicI64>,

    /// Sending side of the channel to the background message loop.
    command_tx: mpsc::Sender<ClientCommand>,

    /// Background message-loop task. `None` for clones and for clients
    /// created with `new`, which never spawn a task.
    receiver_task: Option<JoinHandle<Result<()>>>,
}
44
45impl Clone for LspClient {
46 fn clone(&self) -> Self {
47 Self {
48 config: self.config.clone(),
49 state: Arc::clone(&self.state),
50 request_counter: Arc::clone(&self.request_counter),
51 command_tx: self.command_tx.clone(),
52 receiver_task: None,
53 }
54 }
55}
56
/// Commands sent from `LspClient` handles to the background message loop.
enum ClientCommand {
    /// Write a JSON-RPC request to the transport and deliver its outcome
    /// through `response_tx` once the matching response arrives.
    SendRequest {
        /// The request to send; its `id` is used to match the response.
        request: JsonRpcRequest,
        /// Receives either the response's result value or an error.
        response_tx: oneshot::Sender<Result<Value>>,
    },
    /// Write a JSON-RPC notification to the transport; nothing is awaited.
    SendNotification {
        /// JSON-RPC method name.
        method: String,
        /// Parameters for the notification, if any.
        params: Option<Value>,
    },
    /// Ask the message loop to stop.
    Shutdown,
}
72
73impl LspClient {
74 #[must_use]
79 pub fn new(config: LspServerConfig) -> Self {
80 let (command_tx, _command_rx) = mpsc::channel(100);
81
82 Self {
83 config,
84 state: Arc::new(Mutex::new(super::ServerState::Uninitialized)),
85 request_counter: Arc::new(AtomicI64::new(1)),
86 command_tx,
87 receiver_task: None,
88 }
89 }
90
91 pub(crate) fn from_transport(config: LspServerConfig, transport: LspTransport) -> Self {
95 let state = Arc::new(Mutex::new(super::ServerState::Initializing));
96 let request_counter = Arc::new(AtomicI64::new(1));
97 let pending_requests = Arc::new(Mutex::new(HashMap::new()));
98
99 let (command_tx, command_rx) = mpsc::channel(100);
100
101 let receiver_task =
102 tokio::spawn(Self::message_loop(transport, command_rx, pending_requests));
103
104 Self {
105 config,
106 state,
107 request_counter,
108 command_tx,
109 receiver_task: Some(receiver_task),
110 }
111 }
112
113 #[must_use]
115 pub fn language_id(&self) -> &str {
116 &self.config.language_id
117 }
118
119 pub async fn state(&self) -> super::ServerState {
121 *self.state.lock().await
122 }
123
124 pub async fn request<P, R>(
139 &self,
140 method: &str,
141 params: P,
142 timeout_duration: Duration,
143 ) -> Result<R>
144 where
145 P: Serialize,
146 R: DeserializeOwned,
147 {
148 let id = RequestId::Number(self.request_counter.fetch_add(1, Ordering::SeqCst));
149 let params_value = serde_json::to_value(params)?;
150
151 let (response_tx, response_rx) = oneshot::channel();
152
153 let request = JsonRpcRequest {
154 jsonrpc: "2.0".to_string(),
155 id: id.clone(),
156 method: method.to_string(),
157 params: Some(params_value),
158 };
159
160 debug!("Sending request: {} (id={:?})", method, id);
161
162 self.command_tx
163 .send(ClientCommand::SendRequest {
164 request,
165 response_tx,
166 })
167 .await
168 .map_err(|_| Error::ServerTerminated)?;
169
170 let result_value = timeout(timeout_duration, response_rx)
171 .await
172 .map_err(|_| Error::Timeout(timeout_duration.as_secs()))?
173 .map_err(|_| Error::ServerTerminated)??;
174
175 serde_json::from_value(result_value)
176 .map_err(|e| Error::LspProtocolError(format!("Failed to deserialize response: {e}")))
177 }
178
179 pub async fn notify<P>(&self, method: &str, params: P) -> Result<()>
185 where
186 P: Serialize,
187 {
188 let params_value = serde_json::to_value(params)?;
189
190 debug!("Sending notification: {}", method);
191
192 self.command_tx
193 .send(ClientCommand::SendNotification {
194 method: method.to_string(),
195 params: Some(params_value),
196 })
197 .await
198 .map_err(|_| Error::ServerTerminated)?;
199
200 Ok(())
201 }
202
203 pub async fn shutdown(mut self) -> Result<()> {
211 debug!("Shutting down LSP client");
212
213 let _ = self.command_tx.send(ClientCommand::Shutdown).await;
214
215 if let Some(task) = self.receiver_task.take() {
216 task.await
217 .map_err(|e| Error::Transport(format!("Receiver task failed: {e}")))??;
218 }
219
220 *self.state.lock().await = super::ServerState::Shutdown;
221
222 Ok(())
223 }
224
225 async fn message_loop(
232 mut transport: LspTransport,
233 mut command_rx: mpsc::Receiver<ClientCommand>,
234 pending_requests: Arc<Mutex<HashMap<RequestId, oneshot::Sender<Result<Value>>>>>,
235 ) -> Result<()> {
236 loop {
237 tokio::select! {
238 Some(command) = command_rx.recv() => {
239 match command {
240 ClientCommand::SendRequest { request, response_tx } => {
241 pending_requests.lock().await.insert(
242 request.id.clone(),
243 response_tx,
244 );
245
246 let value = serde_json::to_value(&request)?;
247 transport.send(&value).await?;
248 }
249 ClientCommand::SendNotification { method, params } => {
250 let notification = serde_json::json!({
251 "jsonrpc": "2.0",
252 "method": method,
253 "params": params,
254 });
255 transport.send(¬ification).await?;
256 }
257 ClientCommand::Shutdown => {
258 debug!("Client shutdown requested");
259 break;
260 }
261 }
262 }
263
264 message = transport.receive() => {
265 let message = message?;
266 match message {
267 InboundMessage::Response(response) => {
268 trace!("Received response: id={:?}", response.id);
269
270 let sender = pending_requests.lock().await.remove(&response.id);
271
272 if let Some(sender) = sender {
273 if let Some(error) = response.error {
274 error!("LSP error response: {} (code {})", error.message, error.code);
275 let _ = sender.send(Err(Error::LspServerError {
276 code: error.code,
277 message: error.message,
278 }));
279 } else if let Some(result) = response.result {
280 let _ = sender.send(Ok(result));
281 } else {
282 warn!("Response with neither result nor error: {:?}", response.id);
283 }
284 } else {
285 warn!("Received response for unknown request ID: {:?}", response.id);
286 }
287 }
288 InboundMessage::Notification(notification) => {
289 debug!("Received notification: {}", notification.method);
290 }
293 }
294 }
295 }
296 }
297
298 Ok(())
299 }
300}
301
#[cfg(test)]
// Lets tests use `unwrap` freely without tripping the crate-wide clippy lint.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Ids drawn from a counter seeded at 1 come out as 1, 2, 3, matching how
    /// `LspClient` seeds and advances its request counter.
    #[test]
    fn test_request_id_generation() {
        let counter = AtomicI64::new(1);

        let issued: Vec<i64> = (0..3)
            .map(|_| counter.fetch_add(1, Ordering::SeqCst))
            .collect();

        assert_eq!(issued, [1, 2, 3]);
    }

    /// A client built from the rust-analyzer preset reports language id `rust`.
    #[test]
    fn test_client_creation() {
        let client = LspClient::new(LspServerConfig::rust_analyzer());
        assert_eq!(client.language_id(), "rust");
    }
}