use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{ChildStdin, ChildStdout};
use tokio::sync::{Mutex, mpsc, oneshot};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Request {
pub jsonrpc: String,
pub method: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<u64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Response {
pub jsonrpc: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<RpcError>,
pub id: Option<u64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RpcError {
pub code: i64,
pub message: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<Value>,
}
impl std::fmt::Display for RpcError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "RPC error {}: {}", self.code, self.message)
}
}
impl std::error::Error for RpcError {}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IncomingMessage {
pub jsonrpc: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub method: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<RpcError>,
}
impl IncomingMessage {
pub fn is_response(&self) -> bool {
self.method.is_none() && (self.result.is_some() || self.error.is_some())
}
pub fn is_notification(&self) -> bool {
self.method.is_some() && self.id.is_none()
}
pub fn is_rpc_call(&self) -> bool {
self.method.is_some() && self.id.is_some()
}
pub fn into_response(self) -> Response {
Response {
jsonrpc: self.jsonrpc,
result: self.result,
error: self.error,
id: self.id,
}
}
}
pub struct JsonRpcClient {
writer: Arc<Mutex<ChildStdin>>,
next_id: Arc<AtomicU64>,
pending: Arc<Mutex<HashMap<u64, oneshot::Sender<Response>>>>,
incoming_rx: Option<mpsc::UnboundedReceiver<IncomingMessage>>,
}
impl std::fmt::Debug for JsonRpcClient {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("JsonRpcClient").finish_non_exhaustive()
}
}
impl JsonRpcClient {
pub fn new(stdin: ChildStdin, stdout: ChildStdout) -> Self {
let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<Response>>>> =
Arc::new(Mutex::new(HashMap::new()));
let (incoming_tx, incoming_rx) = mpsc::unbounded_channel::<IncomingMessage>();
let reader_pending = Arc::clone(&pending);
let reader_tx = incoming_tx;
tokio::spawn(async move {
let mut reader = BufReader::new(stdout);
let mut line = String::new();
loop {
line.clear();
match reader.read_line(&mut line).await {
Ok(0) => {
break;
}
Ok(_) => {
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
let msg: IncomingMessage = match serde_json::from_str(trimmed) {
Ok(m) => m,
Err(e) => {
log::error!("Failed to parse JSON-RPC message: {e}");
continue;
}
};
if msg.is_response() {
if let Some(id) = msg.id {
let mut map = reader_pending.lock().await;
if let Some(tx) = map.remove(&id) {
let _ = tx.send(msg.into_response());
} else {
log::error!("Received response for unknown request id {id}");
}
} else {
log::error!("Received response without id: {trimmed}");
}
} else {
if reader_tx.send(msg).is_err() {
break;
}
}
}
Err(e) => {
log::error!("Error reading from child stdout: {e}");
break;
}
}
}
let mut map = reader_pending.lock().await;
for (id, tx) in map.drain() {
let _ = tx.send(Response {
jsonrpc: "2.0".to_string(),
result: None,
error: Some(RpcError {
code: -32003,
message: "Agent process exited".to_string(),
data: None,
}),
id: Some(id),
});
}
});
Self {
writer: Arc::new(Mutex::new(stdin)),
next_id: Arc::new(AtomicU64::new(1)),
pending,
incoming_rx: Some(incoming_rx),
}
}
pub fn take_incoming(&mut self) -> Option<mpsc::UnboundedReceiver<IncomingMessage>> {
self.incoming_rx.take()
}
pub async fn request(
&self,
method: &str,
params: Option<Value>,
) -> Result<Response, Box<dyn std::error::Error + Send + Sync>> {
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
let req = Request {
jsonrpc: "2.0".to_string(),
method: method.to_string(),
params,
id: Some(id),
};
let (tx, rx) = oneshot::channel::<Response>();
{
let mut map = self.pending.lock().await;
map.insert(id, tx);
}
let json = serde_json::to_string(&req)?;
{
let mut writer = self.writer.lock().await;
writer.write_all(format!("{json}\n").as_bytes()).await?;
writer.flush().await?;
}
let response = rx.await?;
Ok(response)
}
pub async fn notify(
&self,
method: &str,
params: Option<Value>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let req = Request {
jsonrpc: "2.0".to_string(),
method: method.to_string(),
params,
id: None,
};
let json = serde_json::to_string(&req)?;
let mut writer = self.writer.lock().await;
writer.write_all(format!("{json}\n").as_bytes()).await?;
writer.flush().await?;
Ok(())
}
pub async fn respond(
&self,
id: u64,
result: Option<Value>,
error: Option<RpcError>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let resp = Response {
jsonrpc: "2.0".to_string(),
result,
error,
id: Some(id),
};
let json = serde_json::to_string(&resp)?;
log::info!("ACP WIRE OUT: {json}");
let mut writer = self.writer.lock().await;
writer.write_all(format!("{json}\n").as_bytes()).await?;
writer.flush().await?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_incoming_message_classification() {
let msg: IncomingMessage =
serde_json::from_str(r#"{"jsonrpc":"2.0","id":1,"result":{"ok":true}}"#).unwrap();
assert!(msg.is_response());
assert!(!msg.is_notification());
assert!(!msg.is_rpc_call());
let msg: IncomingMessage =
serde_json::from_str(r#"{"jsonrpc":"2.0","method":"session/update","params":{}}"#)
.unwrap();
assert!(!msg.is_response());
assert!(msg.is_notification());
assert!(!msg.is_rpc_call());
let msg: IncomingMessage = serde_json::from_str(
r#"{"jsonrpc":"2.0","id":5,"method":"session/request_permission","params":{}}"#,
)
.unwrap();
assert!(!msg.is_response());
assert!(!msg.is_notification());
assert!(msg.is_rpc_call());
}
#[test]
fn test_request_serialization() {
let req = Request {
jsonrpc: "2.0".to_string(),
method: "initialize".to_string(),
params: Some(serde_json::json!({"protocolVersion": 1})),
id: Some(1),
};
let json = serde_json::to_string(&req).unwrap();
assert!(json.contains("initialize"));
assert!(json.contains("protocolVersion"));
}
#[test]
fn test_notification_has_no_id() {
let req = Request {
jsonrpc: "2.0".to_string(),
method: "session/update".to_string(),
params: Some(serde_json::json!({"status": "active"})),
id: None,
};
let json = serde_json::to_string(&req).unwrap();
assert!(!json.contains("\"id\""));
}
#[test]
fn test_response_serialization() {
let resp = Response {
jsonrpc: "2.0".to_string(),
result: Some(serde_json::json!({"capabilities": {}})),
error: None,
id: Some(1),
};
let json = serde_json::to_string(&resp).unwrap();
assert!(json.contains("capabilities"));
assert!(!json.contains("error"));
}
#[test]
fn test_rpc_error_display() {
let err = RpcError {
code: -32600,
message: "Invalid Request".to_string(),
data: None,
};
assert_eq!(format!("{err}"), "RPC error -32600: Invalid Request");
}
#[test]
fn test_incoming_into_response() {
let msg: IncomingMessage =
serde_json::from_str(r#"{"jsonrpc":"2.0","id":42,"result":{"data":"hello"}}"#).unwrap();
assert!(msg.is_response());
let resp = msg.into_response();
assert_eq!(resp.id, Some(42));
assert!(resp.result.is_some());
assert!(resp.error.is_none());
}
}