use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};

use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader};
use tokio::sync::{broadcast, mpsc, oneshot};
use tracing::{Instrument, error, warn};

use crate::{Error, ProtocolError};
/// A JSON-RPC 2.0 request: a method call carrying an `id` that the matching response echoes.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct JsonRpcRequest {
    /// Protocol version string; [`JsonRpcRequest::new`] sets it to `"2.0"`.
    pub jsonrpc: String,
    /// Request identifier, used to pair the response with this request.
    pub id: u64,
    /// Name of the method being invoked.
    pub method: String,
    /// Method arguments; the field is omitted from the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<Value>,
}
/// A JSON-RPC 2.0 response to an earlier request, paired with it by `id`.
///
/// JSON-RPC 2.0 expects exactly one of `result` / `error` to be present; this type does not
/// enforce that, since both are independent `Option`s.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct JsonRpcResponse {
    /// Protocol version string.
    pub jsonrpc: String,
    /// Id of the request being answered. Only numeric ids are representable, so a response
    /// whose `id` is `null` or a string fails to deserialize into this type.
    pub id: u64,
    /// Successful result; omitted from the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<Value>,
    /// Error reported by the peer; omitted from the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<JsonRpcError>,
}
/// The `error` member of a failed JSON-RPC response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcError {
    /// Numeric error code (see [`error_codes`] for the standard codes defined in this module).
    pub code: i32,
    /// Human-readable description of the error.
    pub message: String,
    /// Optional additional detail; omitted from the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<Value>,
}
/// Standard error codes from the JSON-RPC 2.0 specification.
pub mod error_codes {
    /// The requested method does not exist or is not available.
    pub const METHOD_NOT_FOUND: i32 = -32601;
    /// The method parameters are invalid.
    pub const INVALID_PARAMS: i32 = -32602;
    /// Internal JSON-RPC error.
    #[allow(dead_code, reason = "standard JSON-RPC code, reserved for future use")]
    pub const INTERNAL_ERROR: i32 = -32603;
}
/// A JSON-RPC 2.0 notification: a method call without an `id`, so no response is paired with it.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct JsonRpcNotification {
    /// Protocol version string.
    pub jsonrpc: String,
    /// Name of the notified method / event.
    pub method: String,
    /// Notification payload; omitted from the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<Value>,
}
#[derive(Debug, Clone, Serialize)]
pub enum JsonRpcMessage {
Request(JsonRpcRequest),
Response(JsonRpcResponse),
Notification(JsonRpcNotification),
}
impl<'de> Deserialize<'de> for JsonRpcMessage {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let value = Value::deserialize(deserializer)?;
let obj = value
.as_object()
.ok_or_else(|| serde::de::Error::custom("expected a JSON object"))?;
let has_id = obj.contains_key("id");
let has_method = obj.contains_key("method");
if has_id && has_method {
JsonRpcRequest::deserialize(value)
.map(JsonRpcMessage::Request)
.map_err(serde::de::Error::custom)
} else if has_id {
JsonRpcResponse::deserialize(value)
.map(JsonRpcMessage::Response)
.map_err(serde::de::Error::custom)
} else {
JsonRpcNotification::deserialize(value)
.map(JsonRpcMessage::Notification)
.map_err(serde::de::Error::custom)
}
}
}
impl JsonRpcRequest {
    /// Builds a request with the given id, method name and optional params, stamped with the
    /// JSON-RPC version string `"2.0"`.
    pub fn new(id: u64, method: &str, params: Option<Value>) -> Self {
        let jsonrpc = String::from("2.0");
        let method = method.to_owned();
        Self {
            jsonrpc,
            id,
            method,
            params,
        }
    }
}
impl JsonRpcResponse {
    /// Returns `true` when the response carries an `error` member.
    #[allow(dead_code, reason = "public convenience accessor; exercised by tests")]
    pub fn is_error(&self) -> bool {
        self.error.is_some()
    }
}
/// Header prefix that frames each message body (note the trailing space before the length).
const CONTENT_LENGTH_HEADER: &str = "Content-Length: ";
/// One unit of work for the writer task: a fully framed message and the channel on which
/// the outcome of writing and flushing it is reported.
struct WriteCommand {
    /// Complete frame bytes: the `Content-Length` header, the blank line, then the JSON body.
    frame: Vec<u8>,
    /// Receives the result of the write + flush.
    ack: oneshot::Sender<Result<(), std::io::Error>>,
}
pub struct JsonRpcClient {
request_id: AtomicU64,
write_tx: mpsc::UnboundedSender<WriteCommand>,
pending_requests: Arc<RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
notification_tx: broadcast::Sender<JsonRpcNotification>,
request_tx: mpsc::UnboundedSender<JsonRpcRequest>,
}
impl JsonRpcClient {
pub fn new(
writer: impl AsyncWrite + Unpin + Send + 'static,
reader: impl AsyncRead + Unpin + Send + 'static,
notification_tx: broadcast::Sender<JsonRpcNotification>,
request_tx: mpsc::UnboundedSender<JsonRpcRequest>,
) -> Self {
let (write_tx, write_rx) = mpsc::unbounded_channel::<WriteCommand>();
let writer_span = tracing::error_span!("jsonrpc_write_loop");
tokio::spawn(Self::write_loop(writer, write_rx).instrument(writer_span));
let client = Self {
request_id: AtomicU64::new(1),
write_tx,
pending_requests: Arc::new(RwLock::new(HashMap::new())),
notification_tx,
request_tx,
};
let pending_requests = client.pending_requests.clone();
let notification_tx_clone = client.notification_tx.clone();
let request_tx_clone = client.request_tx.clone();
let reader_span = tracing::error_span!("jsonrpc_read_loop");
tokio::spawn(
async move {
Self::read_loop(
reader,
pending_requests,
notification_tx_clone,
request_tx_clone,
)
.await;
}
.instrument(reader_span),
);
client
}
async fn write_loop(
mut writer: impl AsyncWrite + Unpin + Send + 'static,
mut rx: mpsc::UnboundedReceiver<WriteCommand>,
) {
while let Some(WriteCommand { frame, ack }) = rx.recv().await {
let result = async {
writer.write_all(&frame).await?;
writer.flush().await?;
Ok::<_, std::io::Error>(())
}
.await;
let _ = ack.send(result);
}
}
async fn read_loop(
reader: impl AsyncRead + Unpin + Send,
pending_requests: Arc<RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
notification_tx: broadcast::Sender<JsonRpcNotification>,
request_tx: mpsc::UnboundedSender<JsonRpcRequest>,
) {
let mut reader = BufReader::new(reader);
loop {
match Self::read_message(&mut reader).await {
Ok(Some(message)) => match message {
JsonRpcMessage::Response(response) => {
let id = response.id;
let tx = pending_requests.write().remove(&id);
if let Some(tx) = tx {
if tx.send(response).is_err() {
warn!(request_id = %id, "failed to send response for request");
}
} else {
warn!(request_id = %id, "received response for unknown request id");
}
}
JsonRpcMessage::Notification(notification) => {
let _ = notification_tx.send(notification);
}
JsonRpcMessage::Request(request) => {
if request_tx.send(request).is_err() {
warn!("failed to forward JSON-RPC request, channel closed");
}
}
},
Ok(None) => {
break;
}
Err(e) => {
error!(error = %e, "error reading from CLI");
break;
}
}
}
let mut pending = pending_requests.write();
if !pending.is_empty() {
warn!(
count = pending.len(),
"draining pending requests after read loop exit"
);
pending.clear();
}
}
async fn read_message(
reader: &mut BufReader<impl AsyncRead + Unpin>,
) -> Result<Option<JsonRpcMessage>, Error> {
let mut line = String::new();
let mut content_length = None;
loop {
line.clear();
if reader.read_line(&mut line).await? == 0 {
return Ok(None);
}
let trimmed = line.trim();
if trimmed.is_empty() {
break;
}
if let Some(value) = trimmed.strip_prefix(CONTENT_LENGTH_HEADER) {
content_length = Some(value.trim().parse::<usize>().map_err(|_| {
Error::Protocol(ProtocolError::InvalidContentLength(
value.trim().to_string(),
))
})?);
}
}
let Some(length) = content_length else {
return Err(Error::Protocol(ProtocolError::MissingContentLength));
};
let mut body = vec![0u8; length];
reader.read_exact(&mut body).await?;
let message: JsonRpcMessage = serde_json::from_slice(&body)?;
Ok(Some(message))
}
pub async fn send_request(
&self,
method: &str,
params: Option<serde_json::Value>,
) -> Result<JsonRpcResponse, Error> {
let id = self.request_id.fetch_add(1, Ordering::SeqCst);
let request = JsonRpcRequest::new(id, method, params);
let (tx, rx) = oneshot::channel();
self.pending_requests.write().insert(id, tx);
let mut guard = PendingGuard {
map: &self.pending_requests,
id,
armed: true,
};
self.write(&request).await?;
let response = rx
.await
.map_err(|_| Error::Protocol(ProtocolError::RequestCancelled))?;
guard.disarm();
Ok(response)
}
pub async fn write<T: serde::Serialize>(&self, message: &T) -> Result<(), Error> {
let body = serde_json::to_vec(message)?;
let mut frame = Vec::with_capacity(CONTENT_LENGTH_HEADER.len() + 16 + body.len() + 4);
frame.extend_from_slice(CONTENT_LENGTH_HEADER.as_bytes());
frame.extend_from_slice(body.len().to_string().as_bytes());
frame.extend_from_slice(b"\r\n\r\n");
frame.extend_from_slice(&body);
let (ack_tx, ack_rx) = oneshot::channel();
self.write_tx
.send(WriteCommand { frame, ack: ack_tx })
.map_err(|_| {
Error::Io(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"writer actor has shut down",
))
})?;
match ack_rx.await {
Ok(Ok(())) => Ok(()),
Ok(Err(e)) => Err(Error::Io(e)),
Err(_) => Err(Error::Io(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"writer actor dropped ack without responding",
))),
}
}
}
/// Drop guard that removes a request's entry from the pending map unless it is disarmed.
///
/// `send_request` disarms it only after a response arrives. The entry is removed if the
/// request errors out or its future is dropped early, so no stale sender lingers in the map.
struct PendingGuard<'a> {
    /// The shared pending-request map.
    map: &'a RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>,
    /// Id of the entry to remove on drop.
    id: u64,
    /// When `false`, dropping the guard leaves the map untouched.
    armed: bool,
}
impl PendingGuard<'_> {
    /// Marks the request as settled so dropping the guard no longer touches the map.
    fn disarm(&mut self) {
        let Self { armed, .. } = self;
        *armed = false;
    }
}
impl Drop for PendingGuard<'_> {
    /// Removes the guarded entry from the pending map, unless the guard was disarmed.
    fn drop(&mut self) {
        if !self.armed {
            return;
        }
        let id = self.id;
        self.map.write().remove(&id);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `json` as a [`JsonRpcMessage`], panicking if it is rejected.
    fn parse(json: &str) -> JsonRpcMessage {
        serde_json::from_str(json).unwrap()
    }

    #[test]
    fn deserialize_notification() {
        match parse(r#"{"jsonrpc":"2.0","method":"session.event","params":{"id":"e1"}}"#) {
            JsonRpcMessage::Notification(n) => assert_eq!(n.method, "session.event"),
            other => panic!("expected Notification, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_request() {
        let json =
            r#"{"jsonrpc":"2.0","id":5,"method":"permission.request","params":{"kind":"shell"}}"#;
        match parse(json) {
            JsonRpcMessage::Request(r) => {
                assert_eq!(r.id, 5);
                assert_eq!(r.method, "permission.request");
            }
            other => panic!("expected Request, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_response_with_result() {
        match parse(r#"{"jsonrpc":"2.0","id":3,"result":{"ok":true}}"#) {
            JsonRpcMessage::Response(r) => {
                assert_eq!(r.id, 3);
                assert!(!r.is_error());
            }
            other => panic!("expected Response, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_error_response() {
        let json =
            r#"{"jsonrpc":"2.0","id":7,"error":{"code":-32600,"message":"Invalid Request"}}"#;
        let JsonRpcMessage::Response(r) = parse(json) else {
            panic!("expected Response");
        };
        assert!(r.is_error());
        let JsonRpcError { code, message, .. } = r.error.unwrap();
        assert_eq!(code, -32600);
        assert_eq!(message, "Invalid Request");
    }

    #[test]
    fn deserialize_rejects_non_object() {
        let outcome = serde_json::from_str::<JsonRpcMessage>(r#""not an object""#);
        assert!(outcome.is_err());
    }

    #[test]
    fn request_new_sets_version() {
        let JsonRpcRequest {
            jsonrpc,
            id,
            method,
            params,
        } = JsonRpcRequest::new(42, "test.method", None);
        assert_eq!(jsonrpc, "2.0");
        assert_eq!(id, 42);
        assert_eq!(method, "test.method");
        assert!(params.is_none());
    }

    #[test]
    fn request_serializes_camel_case() {
        let request = JsonRpcRequest::new(1, "ping", Some(serde_json::json!({})));
        let json = serde_json::to_string(&request).unwrap();
        for fragment in [r#""jsonrpc":"2.0""#, r#""id":1"#, r#""method":"ping""#] {
            assert!(json.contains(fragment));
        }
    }

    #[test]
    fn notification_without_params_omits_field() {
        let notification = JsonRpcNotification {
            jsonrpc: "2.0".into(),
            method: "ping".into(),
            params: None,
        };
        let json = serde_json::to_string(&notification).unwrap();
        assert!(!json.contains("params"));
    }

    #[test]
    fn response_without_error_omits_field() {
        let response = JsonRpcResponse {
            jsonrpc: "2.0".into(),
            id: 1,
            result: Some(serde_json::json!(true)),
            error: None,
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(!json.contains("error"));
    }
}