use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::Instant;

use parking_lot::{Mutex, RwLock};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader};
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio::task::JoinHandle;
use tracing::{Instrument, debug, error, warn};

use crate::{Error, ProtocolError};
15
/// A JSON-RPC 2.0 request: a call that carries an `id` and expects a response with that `id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct JsonRpcRequest {
    /// Protocol version string; `JsonRpcRequest::new` always sets it to `"2.0"`.
    pub jsonrpc: String,
    /// Request identifier, matched against the `id` of the incoming response.
    pub id: u64,
    /// Name of the method being invoked.
    pub method: String,
    /// Method arguments; the key is omitted from the serialized object when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<Value>,
}
30
/// A JSON-RPC 2.0 response to the request that was sent with the same `id`.
///
/// JSON-RPC expects exactly one of `result` / `error` to be present; this type does not
/// enforce that, so callers should check `error` (or `is_error`) before using `result`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct JsonRpcResponse {
    /// Protocol version string (expected to be `"2.0"`).
    pub jsonrpc: String,
    /// Id of the request being answered. This is a plain `u64`, so a response whose `id` is
    /// `null` or a string fails to deserialize into this type.
    pub id: u64,
    /// Success payload; the key is omitted from the serialized object when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<Value>,
    /// Error payload; the key is omitted from the serialized object when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<JsonRpcError>,
}
46
/// The `error` object of a failed JSON-RPC 2.0 response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcError {
    /// Numeric error code (see the `error_codes` module for the standard values used here).
    pub code: i32,
    /// Human-readable description of the error.
    pub message: String,
    /// Optional extra detail; the key is omitted from the serialized object when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<Value>,
}
58
/// Standard JSON-RPC 2.0 error codes.
pub mod error_codes {
    /// The requested method does not exist or is not available.
    pub const METHOD_NOT_FOUND: i32 = -32601;

    /// The method exists, but the supplied parameters are invalid.
    pub const INVALID_PARAMS: i32 = -32602;

    /// A generic internal error while handling the call.
    #[allow(dead_code, reason = "standard JSON-RPC code, reserved for future use")]
    pub const INTERNAL_ERROR: i32 = -32603;
}
69
/// A JSON-RPC 2.0 notification: a one-way message with a `method` and no `id`, so it is never
/// answered.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct JsonRpcNotification {
    /// Protocol version string (expected to be `"2.0"`).
    pub jsonrpc: String,
    /// Name of the notification / event.
    pub method: String,
    /// Notification payload; the key is omitted from the serialized object when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<Value>,
}
82
83#[derive(Debug, Clone, Serialize)]
85pub enum JsonRpcMessage {
86 Request(JsonRpcRequest),
88 Response(JsonRpcResponse),
90 Notification(JsonRpcNotification),
92}
93
94impl<'de> Deserialize<'de> for JsonRpcMessage {
103 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
104 where
105 D: serde::Deserializer<'de>,
106 {
107 let value = Value::deserialize(deserializer)?;
108 let obj = value
109 .as_object()
110 .ok_or_else(|| serde::de::Error::custom("expected a JSON object"))?;
111
112 let has_id = obj.contains_key("id");
113 let has_method = obj.contains_key("method");
114
115 if has_id && has_method {
116 JsonRpcRequest::deserialize(value)
117 .map(JsonRpcMessage::Request)
118 .map_err(serde::de::Error::custom)
119 } else if has_id {
120 JsonRpcResponse::deserialize(value)
121 .map(JsonRpcMessage::Response)
122 .map_err(serde::de::Error::custom)
123 } else {
124 JsonRpcNotification::deserialize(value)
125 .map(JsonRpcMessage::Notification)
126 .map_err(serde::de::Error::custom)
127 }
128 }
129}
130
131impl JsonRpcRequest {
132 pub fn new(id: u64, method: &str, params: Option<Value>) -> Self {
134 Self {
135 jsonrpc: "2.0".to_string(),
136 id,
137 method: method.to_string(),
138 params,
139 }
140 }
141}
142
143impl JsonRpcResponse {
144 #[allow(dead_code)]
146 pub fn is_error(&self) -> bool {
147 self.error.is_some()
148 }
149}
150
/// Header prefix that frames every message: `Content-Length: <body bytes>\r\n\r\n<body>`.
/// Incoming header lines must start with exactly this text (case and spacing included) to be
/// recognised.
const CONTENT_LENGTH_HEADER: &str = "Content-Length: ";
152
/// A unit of work for the writer task: one fully framed message plus a channel on which the
/// writer reports whether writing and flushing it succeeded.
struct WriteCommand {
    /// Header and body bytes, written as-is.
    frame: Vec<u8>,
    /// Receives the outcome of `write_all` + `flush` for this frame.
    ack: oneshot::Sender<Result<(), std::io::Error>>,
}
165
166pub struct JsonRpcClient {
177 request_id: AtomicU64,
178 write_tx: mpsc::UnboundedSender<WriteCommand>,
185 pending_requests: Arc<RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
186 notification_tx: broadcast::Sender<JsonRpcNotification>,
187 request_tx: mpsc::UnboundedSender<JsonRpcRequest>,
188 read_task: Mutex<Option<JoinHandle<()>>>,
189 write_task: Mutex<Option<JoinHandle<()>>>,
190}
191
192impl JsonRpcClient {
193 pub fn new(
200 writer: impl AsyncWrite + Unpin + Send + 'static,
201 reader: impl AsyncRead + Unpin + Send + 'static,
202 notification_tx: broadcast::Sender<JsonRpcNotification>,
203 request_tx: mpsc::UnboundedSender<JsonRpcRequest>,
204 ) -> Self {
205 let (write_tx, write_rx) = mpsc::unbounded_channel::<WriteCommand>();
206
207 let writer_span = tracing::error_span!("jsonrpc_write_loop");
208 let write_task = tokio::spawn(Self::write_loop(writer, write_rx).instrument(writer_span));
209
210 let client = Self {
211 request_id: AtomicU64::new(1),
212 write_tx,
213 pending_requests: Arc::new(RwLock::new(HashMap::new())),
214 notification_tx,
215 request_tx,
216 read_task: Mutex::new(None),
217 write_task: Mutex::new(Some(write_task)),
218 };
219
220 let pending_requests = client.pending_requests.clone();
221 let notification_tx_clone = client.notification_tx.clone();
222 let request_tx_clone = client.request_tx.clone();
223 let reader_span = tracing::error_span!("jsonrpc_read_loop");
224
225 let read_task = tokio::spawn(
226 async move {
227 Self::read_loop(
228 reader,
229 pending_requests,
230 notification_tx_clone,
231 request_tx_clone,
232 )
233 .await;
234 }
235 .instrument(reader_span),
236 );
237 *client.read_task.lock() = Some(read_task);
238
239 client
240 }
241
242 pub(crate) fn force_close(&self) {
243 if let Some(task) = self.read_task.lock().take() {
244 task.abort();
245 }
246 if let Some(task) = self.write_task.lock().take() {
247 task.abort();
248 }
249 self.pending_requests.write().clear();
250 }
251
252 async fn write_loop(
265 mut writer: impl AsyncWrite + Unpin + Send + 'static,
266 mut rx: mpsc::UnboundedReceiver<WriteCommand>,
267 ) {
268 while let Some(WriteCommand { frame, ack }) = rx.recv().await {
269 let result = async {
270 writer.write_all(&frame).await?;
271 writer.flush().await?;
272 Ok::<_, std::io::Error>(())
273 }
274 .await;
275
276 let _ = ack.send(result);
280 }
281 }
282
283 async fn read_loop(
284 reader: impl AsyncRead + Unpin + Send,
285 pending_requests: Arc<RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
286 notification_tx: broadcast::Sender<JsonRpcNotification>,
287 request_tx: mpsc::UnboundedSender<JsonRpcRequest>,
288 ) {
289 let mut reader = BufReader::new(reader);
290
291 loop {
292 match Self::read_message(&mut reader).await {
293 Ok(Some(message)) => match message {
294 JsonRpcMessage::Response(response) => {
295 let id = response.id;
296 let tx = pending_requests.write().remove(&id);
297 if let Some(tx) = tx {
298 if tx.send(response).is_err() {
299 warn!(request_id = %id, "failed to send response for request");
300 }
301 } else {
302 warn!(request_id = %id, "received response for unknown request id");
303 }
304 }
305 JsonRpcMessage::Notification(notification) => {
306 let _ = notification_tx.send(notification);
307 }
308 JsonRpcMessage::Request(request) => {
309 if request_tx.send(request).is_err() {
310 warn!("failed to forward JSON-RPC request, channel closed");
311 }
312 }
313 },
314 Ok(None) => {
315 break;
316 }
317 Err(e) => {
318 error!(error = %e, "error reading from CLI");
319 break;
320 }
321 }
322 }
323
324 let mut pending = pending_requests.write();
327 if !pending.is_empty() {
328 warn!(
329 count = pending.len(),
330 "draining pending requests after read loop exit"
331 );
332 pending.clear();
333 }
334 }
335
336 async fn read_message(
337 reader: &mut BufReader<impl AsyncRead + Unpin>,
338 ) -> Result<Option<JsonRpcMessage>, Error> {
339 let mut line = String::new();
340 let mut content_length = None;
341
342 loop {
343 line.clear();
344 if reader.read_line(&mut line).await? == 0 {
345 return Ok(None);
346 }
347
348 let trimmed = line.trim();
349 if trimmed.is_empty() {
350 break;
351 }
352
353 if let Some(value) = trimmed.strip_prefix(CONTENT_LENGTH_HEADER) {
354 content_length = Some(value.trim().parse::<usize>().map_err(|_| {
355 Error::Protocol(ProtocolError::InvalidContentLength(
356 value.trim().to_string(),
357 ))
358 })?);
359 }
360 }
361
362 let Some(length) = content_length else {
363 return Err(Error::Protocol(ProtocolError::MissingContentLength));
364 };
365
366 let mut body = vec![0u8; length];
367 reader.read_exact(&mut body).await?;
368
369 let message: JsonRpcMessage = serde_json::from_slice(&body)?;
370 Ok(Some(message))
371 }
372
373 pub async fn send_request(
384 &self,
385 method: &str,
386 params: Option<serde_json::Value>,
387 ) -> Result<JsonRpcResponse, Error> {
388 let request_start = Instant::now();
389 let id = self.request_id.fetch_add(1, Ordering::SeqCst);
390 let request = JsonRpcRequest::new(id, method, params);
391
392 let (tx, rx) = oneshot::channel();
393 self.pending_requests.write().insert(id, tx);
394
395 let mut guard = PendingGuard {
400 map: &self.pending_requests,
401 id,
402 armed: true,
403 };
404
405 if let Err(error) = self.write(&request).await {
409 warn!(
410 elapsed_ms = request_start.elapsed().as_millis(),
411 method = %method,
412 request_id = id,
413 status = "failed",
414 error = %error,
415 "JsonRpcClient::send_request JSON-RPC request finished"
416 );
417 return Err(error);
418 }
419
420 let response = match rx.await {
421 Ok(response) => response,
422 Err(_) => {
423 let error = Error::Protocol(ProtocolError::RequestCancelled);
424 warn!(
425 elapsed_ms = request_start.elapsed().as_millis(),
426 method = %method,
427 request_id = id,
428 status = "failed",
429 error = %error,
430 "JsonRpcClient::send_request JSON-RPC request finished"
431 );
432 return Err(error);
433 }
434 };
435 guard.disarm();
436 if let Some(error) = &response.error {
437 warn!(
438 elapsed_ms = request_start.elapsed().as_millis(),
439 method = %method,
440 request_id = id,
441 status = "failed",
442 code = error.code,
443 error = %error.message,
444 "JsonRpcClient::send_request JSON-RPC request finished"
445 );
446 } else {
447 debug!(
448 elapsed_ms = request_start.elapsed().as_millis(),
449 method = %method,
450 request_id = id,
451 status = "succeeded",
452 "JsonRpcClient::send_request JSON-RPC request finished"
453 );
454 }
455 Ok(response)
456 }
457
458 pub async fn write<T: serde::Serialize>(&self, message: &T) -> Result<(), Error> {
467 let body = serde_json::to_vec(message)?;
468 let mut frame = Vec::with_capacity(CONTENT_LENGTH_HEADER.len() + 16 + body.len() + 4);
469 frame.extend_from_slice(CONTENT_LENGTH_HEADER.as_bytes());
470 frame.extend_from_slice(body.len().to_string().as_bytes());
471 frame.extend_from_slice(b"\r\n\r\n");
472 frame.extend_from_slice(&body);
473
474 let (ack_tx, ack_rx) = oneshot::channel();
475 self.write_tx
476 .send(WriteCommand { frame, ack: ack_tx })
477 .map_err(|_| {
478 Error::Io(std::io::Error::new(
479 std::io::ErrorKind::BrokenPipe,
480 "writer actor has shut down",
481 ))
482 })?;
483
484 match ack_rx.await {
485 Ok(Ok(())) => Ok(()),
486 Ok(Err(e)) => Err(Error::Io(e)),
487 Err(_) => Err(Error::Io(std::io::Error::new(
488 std::io::ErrorKind::BrokenPipe,
489 "writer actor dropped ack without responding",
490 ))),
491 }
492 }
493}
494
495struct PendingGuard<'a> {
499 map: &'a RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>,
500 id: u64,
501 armed: bool,
502}
503
504impl PendingGuard<'_> {
505 fn disarm(&mut self) {
506 self.armed = false;
507 }
508}
509
510impl Drop for PendingGuard<'_> {
511 fn drop(&mut self) {
512 if self.armed {
513 self.map.write().remove(&self.id);
514 }
515 }
516}
517
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a JSON literal into a [`JsonRpcMessage`], panicking if it is rejected.
    fn parse(json: &str) -> JsonRpcMessage {
        serde_json::from_str(json).unwrap()
    }

    #[test]
    fn deserialize_notification() {
        match parse(r#"{"jsonrpc":"2.0","method":"session.event","params":{"id":"e1"}}"#) {
            JsonRpcMessage::Notification(notification) => {
                assert_eq!(notification.method, "session.event");
            }
            other => panic!("expected Notification, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_request() {
        let json =
            r#"{"jsonrpc":"2.0","id":5,"method":"permission.request","params":{"kind":"shell"}}"#;
        match parse(json) {
            JsonRpcMessage::Request(request) => {
                assert_eq!(request.id, 5);
                assert_eq!(request.method, "permission.request");
            }
            other => panic!("expected Request, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_response_with_result() {
        match parse(r#"{"jsonrpc":"2.0","id":3,"result":{"ok":true}}"#) {
            JsonRpcMessage::Response(response) => {
                assert_eq!(response.id, 3);
                assert!(!response.is_error());
            }
            other => panic!("expected Response, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_error_response() {
        let json =
            r#"{"jsonrpc":"2.0","id":7,"error":{"code":-32600,"message":"Invalid Request"}}"#;
        match parse(json) {
            JsonRpcMessage::Response(response) => {
                assert!(response.is_error());
                let error = response.error.expect("error object should be present");
                assert_eq!(error.code, -32600);
                assert_eq!(error.message, "Invalid Request");
            }
            other => panic!("expected Response, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_rejects_non_object() {
        assert!(serde_json::from_str::<JsonRpcMessage>(r#""not an object""#).is_err());
    }

    #[test]
    fn request_new_sets_version() {
        let JsonRpcRequest {
            jsonrpc,
            id,
            method,
            params,
        } = JsonRpcRequest::new(42, "test.method", None);
        assert_eq!(jsonrpc, "2.0");
        assert_eq!(id, 42);
        assert_eq!(method, "test.method");
        assert!(params.is_none());
    }

    #[test]
    fn request_serializes_camel_case() {
        let request = JsonRpcRequest::new(1, "ping", Some(serde_json::json!({})));
        let json = serde_json::to_string(&request).unwrap();
        for expected in [r#""jsonrpc":"2.0""#, r#""id":1"#, r#""method":"ping""#] {
            assert!(json.contains(expected), "missing {expected} in {json}");
        }
    }

    #[test]
    fn notification_without_params_omits_field() {
        let notification = JsonRpcNotification {
            jsonrpc: "2.0".into(),
            method: "ping".into(),
            params: None,
        };
        let json = serde_json::to_string(&notification).unwrap();
        assert!(!json.contains("params"));
    }

    #[test]
    fn response_without_error_omits_field() {
        let response = JsonRpcResponse {
            jsonrpc: "2.0".into(),
            id: 1,
            result: Some(serde_json::json!(true)),
            error: None,
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(!json.contains("error"));
    }
}