use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicI64, Ordering};
use serde::Serialize;
use serde::de::DeserializeOwned;
use serde_json::Value;
use tokio::sync::{Mutex, mpsc, oneshot};
use tokio::task::JoinHandle;
use tokio::time::{Duration, timeout};
use tracing::{debug, error, trace, warn};
use crate::config::LspServerConfig;
use crate::error::{Error, Result};
use crate::lsp::transport::LspTransport;
use crate::lsp::types::{InboundMessage, JsonRpcRequest, LspNotification, RequestId};
/// JSON-RPC protocol version string placed in the `jsonrpc` field of outgoing requests.
const JSONRPC_VERSION: &str = "2.0";
/// In-flight requests: maps each request id to the oneshot sender that delivers its outcome.
type PendingRequests = HashMap<RequestId, oneshot::Sender<Result<Value>>>;
/// Handle used to talk to a language server through a background message loop.
///
/// Requests and notifications are queued on `command_tx`; the task spawned by
/// the `from_transport*` constructors performs the actual transport I/O.
#[derive(Debug)]
pub struct LspClient {
/// Server configuration; `language_id()` reads its `language_id` field.
config: LspServerConfig,
/// Lifecycle state, shared (via `Arc`) with every clone of this client.
state: Arc<Mutex<super::ServerState>>,
/// Source of numeric request ids, shared with every clone; starts at 1.
request_counter: Arc<AtomicI64>,
/// Sending half of the command channel consumed by the message loop.
command_tx: mpsc::Sender<ClientCommand>,
/// Join handle of the message-loop task; `None` for clones and for clients built with `new`.
receiver_task: Option<JoinHandle<Result<()>>>,
}
impl Clone for LspClient {
    /// Produces another handle onto the same server connection.
    ///
    /// The state, the request-id counter and the command channel are shared
    /// with the original. The join handle is not: only the original client
    /// keeps `receiver_task`, the clone always gets `None`.
    fn clone(&self) -> Self {
        let Self {
            config,
            state,
            request_counter,
            command_tx,
            receiver_task: _,
        } = self;
        Self {
            receiver_task: None,
            command_tx: command_tx.clone(),
            request_counter: Arc::clone(request_counter),
            state: Arc::clone(state),
            config: config.clone(),
        }
    }
}
/// Work items sent from client handles to the background message loop.
enum ClientCommand {
/// Send `request` to the server and deliver the matching response (or error) on `response_tx`.
SendRequest {
request: JsonRpcRequest,
response_tx: oneshot::Sender<Result<Value>>,
},
/// Send a notification with the given method and params; no reply is tracked.
SendNotification {
method: String,
params: Option<Value>,
},
/// Ask the message loop to stop.
Shutdown,
}
impl LspClient {
/// Creates a client that is not attached to any transport.
///
/// No message loop is spawned and the receiving half of the command channel
/// is dropped when this function returns, so sends on `command_tx` fail and
/// `request`/`notify` report `Error::ServerTerminated`. The state starts as
/// `Uninitialized` and request ids start at 1.
#[must_use]
pub fn new(config: LspServerConfig) -> Self {
    let initial_state = Arc::new(Mutex::new(super::ServerState::Uninitialized));
    let first_id = Arc::new(AtomicI64::new(1));
    // Only the sender is kept; nothing ever reads from this channel.
    let (command_tx, _detached_rx) = mpsc::channel(1);
    Self {
        config,
        state: initial_state,
        request_counter: first_id,
        command_tx,
        receiver_task: None,
    }
}
pub(crate) fn from_transport(config: LspServerConfig, transport: LspTransport) -> Self {
let state = Arc::new(Mutex::new(super::ServerState::Initializing));
let request_counter = Arc::new(AtomicI64::new(1));
let pending_requests = Arc::new(Mutex::new(HashMap::new()));
let (command_tx, command_rx) = mpsc::channel(100);
let receiver_task = tokio::spawn(Self::message_loop(
transport,
command_rx,
pending_requests,
None,
));
Self {
config,
state,
request_counter,
command_tx,
receiver_task: Some(receiver_task),
}
}
#[allow(dead_code)] pub(crate) fn from_transport_with_notifications(
config: LspServerConfig,
transport: LspTransport,
notification_tx: mpsc::Sender<LspNotification>,
) -> Self {
let state = Arc::new(Mutex::new(super::ServerState::Initializing));
let request_counter = Arc::new(AtomicI64::new(1));
let pending_requests = Arc::new(Mutex::new(HashMap::new()));
let (command_tx, command_rx) = mpsc::channel(100);
let receiver_task = tokio::spawn(Self::message_loop(
transport,
command_rx,
pending_requests,
Some(notification_tx),
));
Self {
config,
state,
request_counter,
command_tx,
receiver_task: Some(receiver_task),
}
}
/// Returns the language identifier stored in this client's configuration.
#[must_use]
pub fn language_id(&self) -> &str {
&self.config.language_id
}
/// Returns a copy of the current lifecycle state.
///
/// Waits for the state mutex, which is shared with every clone of this client.
pub async fn state(&self) -> super::ServerState {
    let guard = self.state.lock().await;
    *guard
}
pub async fn request<P, R>(
&self,
method: &str,
params: P,
timeout_duration: Duration,
) -> Result<R>
where
P: Serialize,
R: DeserializeOwned,
{
let id = RequestId::Number(self.request_counter.fetch_add(1, Ordering::SeqCst));
let params_value = serde_json::to_value(params)?;
let (response_tx, response_rx) = oneshot::channel();
let request = JsonRpcRequest {
jsonrpc: JSONRPC_VERSION.to_string(),
id: id.clone(),
method: method.to_string(),
params: Some(params_value),
};
debug!("Sending request: {} (id={:?})", method, id);
self.command_tx
.send(ClientCommand::SendRequest {
request,
response_tx,
})
.await
.map_err(|_| Error::ServerTerminated)?;
let result_value = timeout(timeout_duration, response_rx)
.await
.map_err(|_| Error::Timeout(timeout_duration.as_secs()))?
.map_err(|_| Error::ServerTerminated)??;
serde_json::from_value(result_value)
.map_err(|e| Error::LspProtocolError(format!("Failed to deserialize response: {e}")))
}
pub async fn notify<P>(&self, method: &str, params: P) -> Result<()>
where
P: Serialize,
{
let params_value = serde_json::to_value(params)?;
debug!("Sending notification: {}", method);
self.command_tx
.send(ClientCommand::SendNotification {
method: method.to_string(),
params: Some(params_value),
})
.await
.map_err(|_| Error::ServerTerminated)?;
Ok(())
}
pub async fn shutdown(mut self) -> Result<()> {
debug!("Shutting down LSP client");
let _ = self.command_tx.send(ClientCommand::Shutdown).await;
if let Some(task) = self.receiver_task.take() {
task.await
.map_err(|e| Error::Transport(format!("Receiver task failed: {e}")))??;
}
*self.state.lock().await = super::ServerState::Shutdown;
Ok(())
}
async fn message_loop(
mut transport: LspTransport,
mut command_rx: mpsc::Receiver<ClientCommand>,
pending_requests: Arc<Mutex<PendingRequests>>,
notification_tx: Option<mpsc::Sender<LspNotification>>,
) -> Result<()> {
debug!("Message loop started");
let result = Self::message_loop_inner(
&mut transport,
&mut command_rx,
&pending_requests,
notification_tx.as_ref(),
)
.await;
if let Err(ref e) = result {
error!("Message loop exiting with error: {}", e);
} else {
debug!("Message loop exiting normally");
}
result
}
/// Drives all transport I/O for a client until shutdown or a fatal error.
///
/// Each iteration waits for whichever happens first:
///
/// * a [`ClientCommand`] from a client handle. Requests are registered in
///   `pending_requests` and written to the transport, notifications are
///   written directly, and `Shutdown` ends the loop with `Ok(())`;
/// * an inbound message from the server. Responses are routed to the waiting
///   requester by id; notifications are parsed and, when `notification_tx` is
///   present, forwarded without blocking (dropped if the channel is full or
///   closed).
///
/// # Errors
///
/// Returns the first serialization, send, or receive error encountered.
async fn message_loop_inner(
    transport: &mut LspTransport,
    command_rx: &mut mpsc::Receiver<ClientCommand>,
    pending_requests: &Arc<Mutex<PendingRequests>>,
    notification_tx: Option<&mpsc::Sender<LspNotification>>,
) -> Result<()> {
    loop {
        // NOTE(review): the branch that loses the race is dropped mid-flight,
        // so this relies on `transport.receive()` being cancellation-safe — confirm.
        tokio::select! {
            // If every sender is gone `recv()` yields `None`, the pattern does
            // not match and this branch is disabled for the current iteration;
            // the loop then keeps servicing inbound messages only.
            Some(command) = command_rx.recv() => {
                match command {
                    ClientCommand::SendRequest { request, response_tx } => {
                        // Register the waiter before sending so the reply
                        // always finds it.
                        pending_requests
                            .lock()
                            .await
                            .insert(request.id.clone(), response_tx);
                        let value = serde_json::to_value(&request)?;
                        transport.send(&value).await?;
                    }
                    ClientCommand::SendNotification { method, params } => {
                        let notification = serde_json::json!({
                            "jsonrpc": JSONRPC_VERSION,
                            "method": method,
                            "params": params,
                        });
                        transport.send(&notification).await?;
                    }
                    ClientCommand::Shutdown => {
                        debug!("Client shutdown requested");
                        break;
                    }
                }
            }
            message = transport.receive() => {
                let message = match message {
                    Ok(m) => m,
                    Err(e) => {
                        error!("Transport receive error: {}", e);
                        return Err(e);
                    }
                };
                match message {
                    InboundMessage::Response(response) => {
                        trace!("Received response: id={:?}", response.id);
                        let sender = pending_requests.lock().await.remove(&response.id);
                        if let Some(sender) = sender {
                            if let Some(error) = response.error {
                                // Only the log line is shortened; the caller
                                // still receives the full message.
                                let message = Self::truncate_for_log(&error.message);
                                error!("LSP error response: {} (code {})", message, error.code);
                                // A failed send means the requester stopped waiting.
                                let _ = sender.send(Err(Error::LspServerError {
                                    code: error.code,
                                    message: error.message,
                                }));
                            } else if let Some(result) = response.result {
                                let _ = sender.send(Ok(result));
                            } else {
                                // Neither `result` nor `error`: treat as a null result.
                                trace!("Response with null result: {:?}", response.id);
                                let _ = sender.send(Ok(Value::Null));
                            }
                        } else {
                            warn!("Received response for unknown request ID: {:?}", response.id);
                        }
                    }
                    InboundMessage::Notification(notification) => {
                        debug!("Received notification: {}", notification.method);
                        let typed = LspNotification::parse(&notification.method, notification.params);
                        if let Some(tx) = notification_tx {
                            if let LspNotification::PublishDiagnostics(ref params) = typed {
                                debug!(
                                    "Forwarding diagnostics for {}: {} items",
                                    params.uri.as_str(),
                                    params.diagnostics.len()
                                );
                            } else {
                                trace!("Forwarding notification: {:?}", typed);
                            }
                            // `try_send` keeps a slow consumer from stalling the loop.
                            if tx.try_send(typed).is_err() {
                                warn!("Notification channel full or closed, dropping notification");
                            }
                        }
                    }
                }
            }
        }
    }
    Ok(())
}

/// Returns `message` shortened for logging.
///
/// Messages of at most 200 bytes are returned unchanged. Longer ones are cut
/// at the last UTF-8 character boundary at or before byte 200 and suffixed
/// with `... (truncated)`.
fn truncate_for_log(message: &str) -> std::borrow::Cow<'_, str> {
    const MAX_LOGGED_BYTES: usize = 200;
    if message.len() <= MAX_LOGGED_BYTES {
        return std::borrow::Cow::Borrowed(message);
    }
    // Slicing a `str` inside a multi-byte character panics, so back up to the
    // nearest boundary. Offset 0 is always a boundary, so this terminates.
    let mut end = MAX_LOGGED_BYTES;
    while !message.is_char_boundary(end) {
        end -= 1;
    }
    std::borrow::Cow::Owned(format!("{}... (truncated)", &message[..end]))
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
use super::*;
/// `fetch_add` returns the value before the increment, so a counter seeded
/// with 1 hands out 1, 2, 3 in order.
#[test]
fn test_request_id_generation() {
    let counter = AtomicI64::new(1);
    let ids: Vec<i64> = (0..3)
        .map(|_| counter.fetch_add(1, Ordering::SeqCst))
        .collect();
    assert_eq!(ids, vec![1, 2, 3]);
}
/// A client built from the rust-analyzer config reports "rust" as its language id.
#[test]
fn test_client_creation() {
    let client = LspClient::new(LspServerConfig::rust_analyzer());
    assert_eq!(client.language_id(), "rust");
}
/// Cloning keeps the configuration but never hands over the receiver task.
#[test]
fn test_client_clone() {
    let original = LspClient::new(LspServerConfig::rust_analyzer());
    #[allow(clippy::redundant_clone)]
    let copy = original.clone();
    assert_eq!(copy.language_id(), "rust");
    let owns_task = copy.receiver_task.is_some();
    assert!(!owns_task, "Cloned client should not own receiver task");
}
/// A response carrying neither `result` nor `error` is delivered to the
/// waiting requester as `Ok(Value::Null)`.
///
/// This replays the routing steps by hand against a pending-request table
/// rather than going through a live message loop.
#[tokio::test]
async fn test_null_response_handling() {
    use crate::lsp::types::{JsonRpcResponse, RequestId};

    let (response_tx, response_rx) = oneshot::channel::<Result<Value>>();
    let pending_requests: Arc<Mutex<PendingRequests>> = Arc::new(Mutex::new(HashMap::new()));
    {
        let mut table = pending_requests.lock().await;
        table.insert(RequestId::Number(1), response_tx);
    }

    let null_response = JsonRpcResponse {
        jsonrpc: "2.0".to_string(),
        id: RequestId::Number(1),
        result: None,
        error: None,
    };

    let waiter = pending_requests.lock().await.remove(&null_response.id);
    if let Some(waiter) = waiter {
        let _ = waiter.send(Ok(Value::Null));
    }

    let wait = tokio::time::Duration::from_millis(100);
    let timeout_result = tokio::time::timeout(wait, response_rx).await;
    assert!(timeout_result.is_ok(), "Should not timeout");

    let channel_result = timeout_result.unwrap();
    assert!(
        channel_result.is_ok(),
        "Channel should not be closed: {:?}",
        channel_result.err()
    );

    let response = channel_result.unwrap();
    assert!(
        response.is_ok(),
        "Should receive Ok(Value::Null), not Err: {:?}",
        response.err()
    );

    assert_eq!(response.unwrap(), Value::Null, "Should receive Value::Null");
}
#[tokio::test]
async fn test_error_response_handling() {
use crate::lsp::types::{JsonRpcError, JsonRpcResponse, RequestId};
let pending_requests: Arc<Mutex<PendingRequests>> = Arc::new(Mutex::new(HashMap::new()));
let (response_tx, response_rx) = oneshot::channel::<Result<Value>>();
pending_requests
.lock()
.await
.insert(RequestId::Number(1), response_tx);
let error_response = JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id: RequestId::Number(1),
result: None,
error: Some(JsonRpcError {
code: -32601,
message: "Method not found".to_string(),
data: None,
}),
};
let sender = pending_requests.lock().await.remove(&error_response.id);
if let Some(sender) = sender {
if let Some(error) = error_response.error {
let _ = sender.send(Err(Error::LspServerError {
code: error.code,
message: error.message,
}));
}
}
let result = response_rx.await.unwrap();
assert!(result.is_err(), "Should receive error");
if let Err(Error::LspServerError { code, message }) = result {
assert_eq!(code, -32601);
assert_eq!(message, "Method not found");
} else {
panic!("Expected LspServerError");
}
}
/// Looking up a response id that was never registered finds no waiter.
#[tokio::test]
async fn test_unknown_request_id() {
    use crate::lsp::types::{JsonRpcResponse, RequestId};

    let response = JsonRpcResponse {
        jsonrpc: "2.0".to_string(),
        id: RequestId::Number(999),
        result: Some(Value::Null),
        error: None,
    };

    // The table is left empty on purpose.
    let pending_requests: Arc<Mutex<PendingRequests>> = Arc::new(Mutex::new(HashMap::new()));
    let mut table = pending_requests.lock().await;
    let waiter = table.remove(&response.id);
    assert!(waiter.is_none(), "Should not find sender for unknown ID");
}
#[tokio::test]
async fn test_long_error_message_truncation() {
use crate::lsp::types::{JsonRpcError, JsonRpcResponse, RequestId};
let pending_requests: Arc<Mutex<PendingRequests>> = Arc::new(Mutex::new(HashMap::new()));
let (response_tx, response_rx) = oneshot::channel::<Result<Value>>();
pending_requests
.lock()
.await
.insert(RequestId::Number(1), response_tx);
let long_message = "x".repeat(250);
let error_response = JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id: RequestId::Number(1),
result: None,
error: Some(JsonRpcError {
code: -32700,
message: long_message.clone(),
data: None,
}),
};
let sender = pending_requests.lock().await.remove(&error_response.id);
if let Some(sender) = sender {
if let Some(error) = error_response.error {
let _ = sender.send(Err(Error::LspServerError {
code: error.code,
message: error.message,
}));
}
}
let result = response_rx.await.unwrap();
assert!(result.is_err());
if let Err(Error::LspServerError { code, message }) = result {
assert_eq!(code, -32700);
assert_eq!(
message.len(),
250,
"Full message should be preserved in Error"
);
} else {
panic!("Expected LspServerError");
}
}
/// Three tasks drawing from one shared counter get the ids 1, 2 and 3, each
/// exactly once (in any order).
#[tokio::test]
async fn test_concurrent_request_ids() {
    let counter = Arc::new(AtomicI64::new(1));
    let handles: Vec<_> = (0..3)
        .map(|_| {
            let shared = Arc::clone(&counter);
            tokio::spawn(async move { shared.fetch_add(1, Ordering::SeqCst) })
        })
        .collect();

    let mut ids = Vec::with_capacity(handles.len());
    for handle in handles {
        ids.push(handle.await.unwrap());
    }
    // Completion order is not guaranteed, so compare the sorted ids.
    ids.sort_unstable();
    assert_eq!(ids, vec![1, 2, 3], "IDs should be unique and sequential");
}
/// Guards against accidental edits to the protocol version string.
#[test]
fn test_jsonrpc_version_constant() {
assert_eq!(JSONRPC_VERSION, "2.0");
}
}