use crate::io::transport::Transport;
use crate::lsp::framing::LspFraming;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::{Mutex, mpsc};
use tracing::{debug, error, trace};
/// A JSON-RPC request: a method call identified by `id`, to be answered by a response
/// carrying the same `id`.
///
/// Built by `JsonRpcClient::request_with_timeout` for outgoing calls, and built by the
/// transport task for requests received from the peer before they are passed to the
/// registered request handler.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcRequest {
/// Protocol version string; this file always fills it from `jsonrpc_utils::JSONRPC_VERSION`.
pub jsonrpc: String,
/// Request identifier. Outgoing requests from this client use numeric ids.
pub id: serde_json::Value,
/// Name of the method being invoked.
pub method: String,
/// Method parameters; left out of the serialized JSON entirely when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<serde_json::Value>,
}
/// A JSON-RPC response answering the request with the same `id`.
///
/// When the client consumes a response, `error` is checked before `result`. A response
/// with neither is reported to the caller as `JsonRpcError::MissingResult`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcResponse {
/// Protocol version string.
pub jsonrpc: String,
/// Id of the request being answered. Inbound responses are only matched to a pending
/// request when this is a non-negative integer (`Value::as_u64`).
pub id: serde_json::Value,
/// Successful result; omitted from the serialized JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<serde_json::Value>,
/// Error details; omitted from the serialized JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<JsonRpcErrorObject>,
}
/// A JSON-RPC notification: a method call without an `id`, for which no response is sent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcNotification {
/// Protocol version string.
pub jsonrpc: String,
/// Name of the notified method.
pub method: String,
/// Notification parameters; omitted from the serialized JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<serde_json::Value>,
}
/// The `error` member of a JSON-RPC response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcErrorObject {
/// Numeric error code (see `JsonRpcErrorCode` for the codes this file uses).
pub code: i32,
/// Short human-readable description of the error.
pub message: String,
/// Optional extra error payload; omitted from the serialized JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<serde_json::Value>,
}
/// Standard JSON-RPC 2.0 error codes produced by this client.
///
/// The discriminant of each variant is its wire value, so `code as i32` and
/// [`JsonRpcErrorCode::code`] return the number sent in a `JsonRpcErrorObject`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum JsonRpcErrorCode {
    /// Internal JSON-RPC error (`-32603`).
    InternalError = -32603,
}

impl JsonRpcErrorCode {
    /// Returns the numeric wire value of this error code.
    pub const fn code(self) -> i32 {
        self as i32
    }
}

impl From<JsonRpcErrorCode> for i32 {
    fn from(code: JsonRpcErrorCode) -> Self {
        code.code()
    }
}
/// Errors returned by `JsonRpcClient` operations.
#[derive(Debug, thiserror::Error)]
pub enum JsonRpcError {
/// The peer answered the request with a JSON-RPC error object.
#[error("JSON-RPC server error ({code}): {message}")]
Server {
code: i32,
message: String,
data: Option<serde_json::Value>,
},
/// A message could not be handed to the transport. In this file it is produced when
/// the outbound channel to the transport task is closed.
#[error("Transport error: {0}")]
Transport(String),
/// Request or notification content could not be encoded as JSON.
#[error("Serialization error: {0}")]
Serialization(serde_json::Error),
/// A response `result` could not be decoded into the caller's expected type.
#[error("Deserialization error: {0}")]
Deserialization(serde_json::Error),
/// No response arrived before the request timeout elapsed.
#[error("Request timeout")]
Timeout,
/// The response channel for the request closed before any response arrived.
#[error("Request was cancelled")]
RequestCancelled,
/// The response carried neither an `error` nor a `result`.
#[error("Missing result in response")]
MissingResult,
}
/// Callback invoked for each notification received from the peer. It is called inline on
/// the client's transport task, so a slow handler delays all further I/O.
type NotificationHandler = Arc<dyn Fn(JsonRpcNotification) + Send + Sync>;
/// Callback that turns a request received from the peer into the response sent back. It
/// is called inline on the client's transport task, so a slow handler delays all I/O.
type RequestHandler = Arc<dyn Fn(JsonRpcRequest) -> JsonRpcResponse + Send + Sync>;
/// Internal classification of one raw inbound message, produced by
/// `JsonRpcMessage::classify` and consumed by the transport task's dispatcher.
#[derive(Debug, Clone)]
enum JsonRpcMessage {
/// A call from the peer that must be answered (has `method` and a non-null `id`).
Request {
method: String,
id: serde_json::Value,
params: Option<serde_json::Value>,
},
/// A call from the peer that expects no answer (has `method`, no usable `id`).
Notification {
method: String,
params: Option<serde_json::Value>,
},
/// The peer's answer to one of our requests (non-null `id`, no `method`).
Response {
id: serde_json::Value,
result: Option<serde_json::Value>,
error: Option<JsonRpcErrorObject>,
},
/// The text was not valid JSON or matched none of the shapes above; holds the reason.
Invalid(String),
}
impl JsonRpcMessage {
fn classify(message: &str) -> Self {
let parsed = match serde_json::from_str::<serde_json::Value>(message) {
Ok(value) => value,
Err(e) => return Self::Invalid(format!("JSON parse error: {e}")),
};
let method = parsed
.get("method")
.and_then(|m| m.as_str())
.map(|s| s.to_string());
let id = parsed.get("id").cloned();
let params = parsed.get("params").cloned();
match (method, id) {
(Some(method), Some(id)) if !id.is_null() => Self::Request { method, id, params },
(Some(method), _) => Self::Notification { method, params },
(None, Some(id)) if !id.is_null() => {
let result = parsed.get("result").cloned();
let error = parsed
.get("error")
.and_then(|e| serde_json::from_value::<JsonRpcErrorObject>(e.clone()).ok());
Self::Response { id, result, error }
}
_ => Self::Invalid("Missing required fields or invalid structure".to_string()),
}
}
}
/// Mutable state shared (behind a mutex) between the client handle and its transport task.
#[derive(Default)]
struct ClientState {
/// Handler for inbound notifications, if one has been registered.
notification_handler: Option<NotificationHandler>,
/// Handler for inbound requests, if one has been registered.
request_handler: Option<RequestHandler>,
/// Response channels for in-flight outgoing requests, keyed by numeric request id.
/// Entries are removed when the matching response arrives, when the request times
/// out or is cancelled, or by `cleanup_pending_requests`.
pending_requests: HashMap<u64, mpsc::UnboundedSender<JsonRpcResponse>>,
}
/// JSON-RPC client that talks to a peer over a [`Transport`] using LSP message framing.
///
/// Outgoing messages are queued on `outbound_sender` and written by a background task
/// spawned in `JsonRpcClient::new`. That task also reads inbound messages and
/// dispatches them using the shared `state`.
pub struct JsonRpcClient<T: Transport> {
/// Queue of serialized JSON messages for the transport task to write.
outbound_sender: mpsc::UnboundedSender<String>,
/// Id for the next outgoing request; starts at 1 and is incremented per request.
request_id: AtomicU64,
/// Registered handlers and the pending-request table, shared with the transport task.
state: Arc<Mutex<ClientState>>,
/// Ties the client to its transport type `T`; the transport itself lives in the task.
_phantom: std::marker::PhantomData<T>,
}
impl<T: Transport + 'static> JsonRpcClient<T> {
pub fn new(transport: T) -> Self {
let framed_transport = LspFraming::new(transport);
let transport_arc = Arc::new(Mutex::new(framed_transport));
let (outbound_sender, mut outbound_receiver) = mpsc::unbounded_channel::<String>();
let state = Arc::new(Mutex::new(ClientState::default()));
let transport_for_task = Arc::clone(&transport_arc);
let state_for_task = Arc::clone(&state);
let outbound_sender_for_task = outbound_sender.clone();
tokio::spawn(async move {
loop {
tokio::select! {
Some(message) = outbound_receiver.recv() => {
let mut transport = transport_for_task.lock().await;
if let Err(e) = transport.send(&message).await {
error!("Failed to send message: {}", e);
break;
}
drop(transport);
}
result = async {
let mut transport = transport_for_task.lock().await;
transport.receive().await
} => {
match result {
Ok(message) => {
Self::process_inbound_message(message, &state_for_task, &outbound_sender_for_task).await;
}
Err(e) => {
error!("Failed to receive message: {}", e);
break;
}
}
}
}
}
trace!("Transport handler task finished");
});
Self {
outbound_sender,
request_id: AtomicU64::new(1),
state,
_phantom: std::marker::PhantomData,
}
}
pub async fn on_notification<F>(&self, handler: F)
where
F: Fn(JsonRpcNotification) + Send + Sync + 'static,
{
let mut state = self.state.lock().await;
state.notification_handler = Some(Arc::new(handler));
}
pub async fn on_request<F>(&self, handler: F)
where
F: Fn(JsonRpcRequest) -> JsonRpcResponse + Send + Sync + 'static,
{
let mut state = self.state.lock().await;
state.request_handler = Some(Arc::new(handler));
}
async fn process_inbound_message(
message: String,
state: &Arc<Mutex<ClientState>>,
outbound_sender: &mpsc::UnboundedSender<String>,
) {
trace!("JsonRpcClient: Received {} bytes", message.len());
let classified_message = JsonRpcMessage::classify(&message);
match classified_message {
JsonRpcMessage::Request { method, id, params } => {
debug!("Received request: {} with id: {:?}", method, id);
let request_handler = {
let state = state.lock().await;
state.request_handler.clone()
};
if let Some(handler) = request_handler {
let request = JsonRpcRequest {
jsonrpc: crate::lsp::jsonrpc_utils::JSONRPC_VERSION.to_string(),
method,
id: id.clone(),
params,
};
let response = handler(request);
let response_json = match serde_json::to_string(&response) {
Ok(json) => json,
Err(e) => {
debug!("Failed to serialize response: {}", e);
return;
}
};
if outbound_sender.send(response_json).is_err() {
debug!("Failed to send response back to server");
}
} else {
debug!("No request handler registered for method: {}", method);
}
}
JsonRpcMessage::Notification { method, params } => {
debug!("Received notification: {}", method);
let notification_handler = {
let state = state.lock().await;
state.notification_handler.clone()
};
if let Some(handler) = notification_handler {
let notification = JsonRpcNotification {
jsonrpc: crate::lsp::jsonrpc_utils::JSONRPC_VERSION.to_string(),
method,
params,
};
handler(notification);
}
}
JsonRpcMessage::Response { id, result, error } => {
if let Some(id_u64) = id.as_u64() {
let mut state = state.lock().await;
if let Some(sender) = state.pending_requests.remove(&id_u64) {
let response = JsonRpcResponse {
jsonrpc: crate::lsp::jsonrpc_utils::JSONRPC_VERSION.to_string(),
id,
result,
error,
};
if sender.send(response).is_err() {
debug!("Response receiver dropped for request {}", id_u64);
}
} else {
debug!("Received response for unknown request {}", id_u64);
}
} else {
debug!("Response has non-numeric ID, cannot match: {:?}", id);
}
}
JsonRpcMessage::Invalid(reason) => {
debug!("Received invalid JSON-RPC message: {}", reason);
}
}
}
pub async fn request<P, R>(
&mut self,
method: &str,
params: Option<P>,
) -> Result<R, JsonRpcError>
where
P: serde::Serialize,
R: for<'de> serde::Deserialize<'de>,
{
self.request_with_timeout(method, params, std::time::Duration::from_secs(30))
.await
}
pub async fn request_with_timeout<P, R>(
&mut self,
method: &str,
params: Option<P>,
timeout: std::time::Duration,
) -> Result<R, JsonRpcError>
where
P: serde::Serialize,
R: for<'de> serde::Deserialize<'de>,
{
let id = self.request_id.fetch_add(1, Ordering::SeqCst);
let (response_sender, mut response_receiver) = mpsc::unbounded_channel();
{
let mut state = self.state.lock().await;
state.pending_requests.insert(id, response_sender);
}
let request = JsonRpcRequest {
jsonrpc: crate::lsp::jsonrpc_utils::JSONRPC_VERSION.to_string(),
id: Value::Number(serde_json::Number::from(id)),
method: method.to_string(),
params: params
.map(|p| serde_json::to_value(p).map_err(JsonRpcError::Serialization))
.transpose()?,
};
let request_json = serde_json::to_string(&request).map_err(JsonRpcError::Serialization)?;
debug!("JsonRpcClient: Sending request: {}", request_json);
self.outbound_sender
.send(request_json)
.map_err(|_| JsonRpcError::Transport("Outbound channel closed".to_string()))?;
let response_result = tokio::time::timeout(timeout, response_receiver.recv()).await;
let response = match response_result {
Ok(Some(response)) => response,
Ok(None) => {
let mut state = self.state.lock().await;
state.pending_requests.remove(&id);
return Err(JsonRpcError::RequestCancelled);
}
Err(_) => {
let mut state = self.state.lock().await;
state.pending_requests.remove(&id);
return Err(JsonRpcError::Timeout);
}
};
if let Some(error) = response.error {
return Err(JsonRpcError::Server {
code: error.code,
message: error.message,
data: error.data,
});
}
match response.result {
Some(Value::Null) => {
serde_json::from_value(Value::Null).map_err(JsonRpcError::Deserialization)
}
Some(result) => serde_json::from_value(result).map_err(JsonRpcError::Deserialization),
None => Err(JsonRpcError::MissingResult),
}
}
pub async fn notify<P>(&mut self, method: &str, params: Option<P>) -> Result<(), JsonRpcError>
where
P: serde::Serialize,
{
let notification = JsonRpcNotification {
jsonrpc: crate::lsp::jsonrpc_utils::JSONRPC_VERSION.to_string(),
method: method.to_string(),
params: params
.map(|p| serde_json::to_value(p).map_err(JsonRpcError::Serialization))
.transpose()?,
};
let notification_json =
serde_json::to_string(¬ification).map_err(JsonRpcError::Serialization)?;
debug!("JsonRpcClient: Sending notification: {}", notification_json);
self.outbound_sender
.send(notification_json)
.map_err(|_| JsonRpcError::Transport("Outbound channel closed".to_string()))?;
Ok(())
}
pub async fn cleanup_pending_requests(&mut self) {
let mut state = self.state.lock().await;
for (id, sender) in state.pending_requests.drain() {
debug!("JsonRpcClient: Cleaning up pending request ID {}", id);
let _ = sender.send(JsonRpcResponse {
jsonrpc: crate::lsp::jsonrpc_utils::JSONRPC_VERSION.to_string(),
id: Value::Number(serde_json::Number::from(id)),
result: None,
error: Some(JsonRpcErrorObject {
code: JsonRpcErrorCode::InternalError as i32,
message: "Request cancelled due to connection restart".to_string(),
data: None,
}),
});
}
}
pub async fn close(&mut self) -> Result<(), JsonRpcError> {
self.cleanup_pending_requests().await;
Ok(())
}
}