1use std::collections::HashMap;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::time::Instant;
5
6use parking_lot::{Mutex, RwLock};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader};
10use tokio::sync::{broadcast, mpsc, oneshot};
11use tokio::task::JoinHandle;
12use tracing::{Instrument, debug, error, warn};
13
14use crate::{Error, ErrorKind, ProtocolErrorKind};
15
16pub(crate) type InlineResponseCallback =
26 Box<dyn FnOnce(&JsonRpcResponse) -> Result<(), Error> + Send + Sync>;
27
28struct PendingRequest {
31 sender: oneshot::Sender<JsonRpcResponse>,
32 inline_callback: Option<InlineResponseCallback>,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
37#[serde(rename_all = "camelCase")]
38pub struct JsonRpcRequest {
39 pub jsonrpc: String,
41 pub id: u64,
43 pub method: String,
45 #[serde(skip_serializing_if = "Option::is_none")]
47 pub params: Option<Value>,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
52#[serde(rename_all = "camelCase")]
53pub struct JsonRpcResponse {
54 pub jsonrpc: String,
56 pub id: u64,
58 #[serde(skip_serializing_if = "Option::is_none")]
60 pub result: Option<Value>,
61 #[serde(skip_serializing_if = "Option::is_none")]
63 pub error: Option<JsonRpcError>,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct JsonRpcError {
69 pub code: i32,
71 pub message: String,
73 #[serde(skip_serializing_if = "Option::is_none")]
75 pub data: Option<Value>,
76}
77
/// Standard JSON-RPC 2.0 error codes referenced by this crate.
pub mod error_codes {
    /// The requested method does not exist or is not available.
    pub const METHOD_NOT_FOUND: i32 = -32601;

    /// The method parameters were invalid.
    pub const INVALID_PARAMS: i32 = -32602;

    /// Internal JSON-RPC error.
    #[allow(dead_code, reason = "standard JSON-RPC code, reserved for future use")]
    pub const INTERNAL_ERROR: i32 = -32603;
}
88
89#[derive(Debug, Clone, Serialize, Deserialize)]
91#[serde(rename_all = "camelCase")]
92pub struct JsonRpcNotification {
93 pub jsonrpc: String,
95 pub method: String,
97 #[serde(skip_serializing_if = "Option::is_none")]
99 pub params: Option<Value>,
100}
101
102#[derive(Debug, Clone, Serialize)]
104pub enum JsonRpcMessage {
105 Request(JsonRpcRequest),
107 Response(JsonRpcResponse),
109 Notification(JsonRpcNotification),
111}
112
113impl<'de> Deserialize<'de> for JsonRpcMessage {
122 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
123 where
124 D: serde::Deserializer<'de>,
125 {
126 let value = Value::deserialize(deserializer)?;
127 let obj = value
128 .as_object()
129 .ok_or_else(|| serde::de::Error::custom("expected a JSON object"))?;
130
131 let has_id = obj.contains_key("id");
132 let has_method = obj.contains_key("method");
133
134 if has_id && has_method {
135 JsonRpcRequest::deserialize(value)
136 .map(JsonRpcMessage::Request)
137 .map_err(serde::de::Error::custom)
138 } else if has_id {
139 JsonRpcResponse::deserialize(value)
140 .map(JsonRpcMessage::Response)
141 .map_err(serde::de::Error::custom)
142 } else {
143 JsonRpcNotification::deserialize(value)
144 .map(JsonRpcMessage::Notification)
145 .map_err(serde::de::Error::custom)
146 }
147 }
148}
149
150impl JsonRpcRequest {
151 pub fn new(id: u64, method: &str, params: Option<Value>) -> Self {
153 Self {
154 jsonrpc: "2.0".to_string(),
155 id,
156 method: method.to_string(),
157 params,
158 }
159 }
160}
161
162impl JsonRpcResponse {
163 #[allow(dead_code)]
165 pub fn is_error(&self) -> bool {
166 self.error.is_some()
167 }
168}
169
/// Header prefix written before every message body; the body's byte length follows it.
const CONTENT_LENGTH_HEADER: &str = "Content-Length: ";
171
172struct WriteCommand {
181 frame: Vec<u8>,
182 ack: oneshot::Sender<Result<(), std::io::Error>>,
183}
184
185pub struct JsonRpcClient {
196 request_id: AtomicU64,
197 write_tx: mpsc::UnboundedSender<WriteCommand>,
204 pending_requests: Arc<RwLock<HashMap<u64, PendingRequest>>>,
205 notification_tx: broadcast::Sender<JsonRpcNotification>,
206 request_tx: mpsc::UnboundedSender<JsonRpcRequest>,
207 read_task: Mutex<Option<JoinHandle<()>>>,
208 write_task: Mutex<Option<JoinHandle<()>>>,
209}
210
211impl JsonRpcClient {
212 pub fn new(
219 writer: impl AsyncWrite + Unpin + Send + 'static,
220 reader: impl AsyncRead + Unpin + Send + 'static,
221 notification_tx: broadcast::Sender<JsonRpcNotification>,
222 request_tx: mpsc::UnboundedSender<JsonRpcRequest>,
223 ) -> Self {
224 let (write_tx, write_rx) = mpsc::unbounded_channel::<WriteCommand>();
225
226 let writer_span = tracing::error_span!("jsonrpc_write_loop");
227 let write_task = tokio::spawn(Self::write_loop(writer, write_rx).instrument(writer_span));
228
229 let client = Self {
230 request_id: AtomicU64::new(1),
231 write_tx,
232 pending_requests: Arc::new(RwLock::new(HashMap::new())),
233 notification_tx,
234 request_tx,
235 read_task: Mutex::new(None),
236 write_task: Mutex::new(Some(write_task)),
237 };
238
239 let pending_requests = client.pending_requests.clone();
240 let notification_tx_clone = client.notification_tx.clone();
241 let request_tx_clone = client.request_tx.clone();
242 let reader_span = tracing::error_span!("jsonrpc_read_loop");
243
244 let read_task = tokio::spawn(
245 async move {
246 Self::read_loop(
247 reader,
248 pending_requests,
249 notification_tx_clone,
250 request_tx_clone,
251 )
252 .await;
253 }
254 .instrument(reader_span),
255 );
256 *client.read_task.lock() = Some(read_task);
257
258 client
259 }
260
261 pub(crate) fn force_close(&self) {
262 if let Some(task) = self.read_task.lock().take() {
263 task.abort();
264 }
265 if let Some(task) = self.write_task.lock().take() {
266 task.abort();
267 }
268 self.pending_requests.write().clear();
269 }
270
271 async fn write_loop(
284 mut writer: impl AsyncWrite + Unpin + Send + 'static,
285 mut rx: mpsc::UnboundedReceiver<WriteCommand>,
286 ) {
287 while let Some(WriteCommand { frame, ack }) = rx.recv().await {
288 let result = async {
289 writer.write_all(&frame).await?;
290 writer.flush().await?;
291 Ok::<_, std::io::Error>(())
292 }
293 .await;
294
295 let _ = ack.send(result);
299 }
300 }
301
302 async fn read_loop(
303 reader: impl AsyncRead + Unpin + Send,
304 pending_requests: Arc<RwLock<HashMap<u64, PendingRequest>>>,
305 notification_tx: broadcast::Sender<JsonRpcNotification>,
306 request_tx: mpsc::UnboundedSender<JsonRpcRequest>,
307 ) {
308 let mut reader = BufReader::new(reader);
309
310 loop {
311 match Self::read_message(&mut reader).await {
312 Ok(Some(message)) => match message {
313 JsonRpcMessage::Response(mut response) => {
314 let id = response.id;
315 let pending = pending_requests.write().remove(&id);
316 if let Some(PendingRequest {
317 sender,
318 inline_callback,
319 }) = pending
320 {
321 if let Some(cb) = inline_callback
327 && response.error.is_none()
328 {
329 let cb_outcome =
330 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
331 cb(&response)
332 }));
333 match cb_outcome {
334 Ok(Ok(())) => {}
335 Ok(Err(error)) => {
336 response.result = None;
337 response.error = Some(JsonRpcError {
338 code: -32603,
339 message: error.to_string(),
340 data: None,
341 });
342 }
343 Err(panic) => {
344 let message = panic
345 .downcast_ref::<&'static str>()
346 .map(|s| (*s).to_string())
347 .or_else(|| panic.downcast_ref::<String>().cloned())
348 .unwrap_or_else(|| {
349 "inline response callback panicked".to_string()
350 });
351 response.result = None;
352 response.error = Some(JsonRpcError {
353 code: -32603,
354 message,
355 data: None,
356 });
357 }
358 }
359 }
360 if sender.send(response).is_err() {
361 warn!(request_id = %id, "failed to send response for request");
362 }
363 } else {
364 warn!(request_id = %id, "received response for unknown request id");
365 }
366 }
367 JsonRpcMessage::Notification(notification) => {
368 let _ = notification_tx.send(notification);
369 }
370 JsonRpcMessage::Request(request) => {
371 if request_tx.send(request).is_err() {
372 warn!("failed to forward JSON-RPC request, channel closed");
373 }
374 }
375 },
376 Ok(None) => {
377 break;
378 }
379 Err(e) => {
380 error!(error = %e, "error reading from CLI");
381 break;
382 }
383 }
384 }
385
386 let mut pending = pending_requests.write();
389 if !pending.is_empty() {
390 warn!(
391 count = pending.len(),
392 "draining pending requests after read loop exit"
393 );
394 pending.clear();
395 }
396 }
397
398 async fn read_message(
399 reader: &mut BufReader<impl AsyncRead + Unpin>,
400 ) -> Result<Option<JsonRpcMessage>, Error> {
401 let mut line = String::new();
402 let mut content_length = None;
403
404 loop {
405 line.clear();
406 if reader.read_line(&mut line).await? == 0 {
407 return Ok(None);
408 }
409
410 let trimmed = line.trim();
411 if trimmed.is_empty() {
412 break;
413 }
414
415 if let Some(value) = trimmed.strip_prefix(CONTENT_LENGTH_HEADER) {
416 content_length = Some(value.trim().parse::<usize>().map_err(|_| {
417 Error::from(ErrorKind::Protocol(
418 ProtocolErrorKind::InvalidContentLength(value.trim().to_string()),
419 ))
420 })?);
421 }
422 }
423
424 let Some(length) = content_length else {
425 return Err(ErrorKind::Protocol(ProtocolErrorKind::MissingContentLength).into());
426 };
427
428 let mut body = vec![0u8; length];
429 reader.read_exact(&mut body).await?;
430
431 let message: JsonRpcMessage = serde_json::from_slice(&body)?;
432 Ok(Some(message))
433 }
434
435 #[allow(dead_code, reason = "public API exported via crate::JsonRpcClient")]
446 pub async fn send_request(
447 &self,
448 method: &str,
449 params: Option<serde_json::Value>,
450 ) -> Result<JsonRpcResponse, Error> {
451 self.send_request_with_inline_callback(method, params, None)
452 .await
453 }
454
455 pub(crate) async fn send_request_with_inline_callback(
472 &self,
473 method: &str,
474 params: Option<serde_json::Value>,
475 inline_callback: Option<InlineResponseCallback>,
476 ) -> Result<JsonRpcResponse, Error> {
477 let request_start = Instant::now();
478 let id = self.request_id.fetch_add(1, Ordering::SeqCst);
479 let request = JsonRpcRequest::new(id, method, params);
480
481 let (tx, rx) = oneshot::channel();
482 self.pending_requests.write().insert(
483 id,
484 PendingRequest {
485 sender: tx,
486 inline_callback,
487 },
488 );
489
490 let mut guard = PendingGuard {
495 map: &self.pending_requests,
496 id,
497 armed: true,
498 };
499
500 if let Err(error) = self.write(&request).await {
504 warn!(
505 elapsed_ms = request_start.elapsed().as_millis(),
506 method = %method,
507 request_id = id,
508 status = "failed",
509 error = %error,
510 "JsonRpcClient::send_request JSON-RPC request finished"
511 );
512 return Err(error);
513 }
514
515 let response = match rx.await {
516 Ok(response) => response,
517 Err(_) => {
518 let error = ErrorKind::Protocol(ProtocolErrorKind::RequestCancelled).into();
519 warn!(
520 elapsed_ms = request_start.elapsed().as_millis(),
521 method = %method,
522 request_id = id,
523 status = "failed",
524 error = %error,
525 "JsonRpcClient::send_request JSON-RPC request finished"
526 );
527 return Err(error);
528 }
529 };
530 guard.disarm();
531 if let Some(error) = &response.error {
532 warn!(
533 elapsed_ms = request_start.elapsed().as_millis(),
534 method = %method,
535 request_id = id,
536 status = "failed",
537 code = error.code,
538 error = %error.message,
539 "JsonRpcClient::send_request JSON-RPC request finished"
540 );
541 } else {
542 debug!(
543 elapsed_ms = request_start.elapsed().as_millis(),
544 method = %method,
545 request_id = id,
546 status = "succeeded",
547 "JsonRpcClient::send_request JSON-RPC request finished"
548 );
549 }
550 Ok(response)
551 }
552
553 pub async fn write<T: serde::Serialize>(&self, message: &T) -> Result<(), Error> {
562 let body = serde_json::to_vec(message)?;
563 let mut frame = Vec::with_capacity(CONTENT_LENGTH_HEADER.len() + 16 + body.len() + 4);
564 frame.extend_from_slice(CONTENT_LENGTH_HEADER.as_bytes());
565 frame.extend_from_slice(body.len().to_string().as_bytes());
566 frame.extend_from_slice(b"\r\n\r\n");
567 frame.extend_from_slice(&body);
568
569 let (ack_tx, ack_rx) = oneshot::channel();
570 self.write_tx
571 .send(WriteCommand { frame, ack: ack_tx })
572 .map_err(|_| {
573 Error::from(std::io::Error::new(
574 std::io::ErrorKind::BrokenPipe,
575 "writer actor has shut down",
576 ))
577 })?;
578
579 match ack_rx.await {
580 Ok(Ok(())) => Ok(()),
581 Ok(Err(e)) => Err(Error::from(e)),
582 Err(_) => Err(Error::from(std::io::Error::new(
583 std::io::ErrorKind::BrokenPipe,
584 "writer actor dropped ack without responding",
585 ))),
586 }
587 }
588}
589
590struct PendingGuard<'a> {
594 map: &'a RwLock<HashMap<u64, PendingRequest>>,
595 id: u64,
596 armed: bool,
597}
598
599impl PendingGuard<'_> {
600 fn disarm(&mut self) {
601 self.armed = false;
602 }
603}
604
605impl Drop for PendingGuard<'_> {
606 fn drop(&mut self) {
607 if self.armed {
608 self.map.write().remove(&self.id);
609 }
610 }
611}
612
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes `json` as a [`JsonRpcMessage`], panicking if it is rejected.
    fn decode(json: &str) -> JsonRpcMessage {
        serde_json::from_str(json).expect("message should deserialize")
    }

    #[test]
    fn deserialize_notification() {
        match decode(r#"{"jsonrpc":"2.0","method":"session.event","params":{"id":"e1"}}"#) {
            JsonRpcMessage::Notification(n) => assert_eq!(n.method, "session.event"),
            other => panic!("expected Notification, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_request() {
        let json =
            r#"{"jsonrpc":"2.0","id":5,"method":"permission.request","params":{"kind":"shell"}}"#;
        match decode(json) {
            JsonRpcMessage::Request(r) => {
                assert_eq!(r.id, 5);
                assert_eq!(r.method, "permission.request");
            }
            other => panic!("expected Request, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_response_with_result() {
        match decode(r#"{"jsonrpc":"2.0","id":3,"result":{"ok":true}}"#) {
            JsonRpcMessage::Response(r) => {
                assert_eq!(r.id, 3);
                assert!(!r.is_error());
            }
            other => panic!("expected Response, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_error_response() {
        let json =
            r#"{"jsonrpc":"2.0","id":7,"error":{"code":-32600,"message":"Invalid Request"}}"#;
        let response = match decode(json) {
            JsonRpcMessage::Response(r) => r,
            other => panic!("expected Response, got {other:?}"),
        };
        assert!(response.is_error());
        let JsonRpcError { code, message, .. } = response.error.unwrap();
        assert_eq!(code, -32600);
        assert_eq!(message, "Invalid Request");
    }

    #[test]
    fn deserialize_rejects_non_object() {
        assert!(serde_json::from_str::<JsonRpcMessage>(r#""not an object""#).is_err());
    }

    #[test]
    fn request_new_sets_version() {
        let JsonRpcRequest {
            jsonrpc,
            id,
            method,
            params,
        } = JsonRpcRequest::new(42, "test.method", None);
        assert_eq!(jsonrpc, "2.0");
        assert_eq!(id, 42);
        assert_eq!(method, "test.method");
        assert!(params.is_none());
    }

    #[test]
    fn request_serializes_camel_case() {
        let req = JsonRpcRequest::new(1, "ping", Some(serde_json::json!({})));
        let json = serde_json::to_string(&req).unwrap();
        for expected in [r#""jsonrpc":"2.0""#, r#""id":1"#, r#""method":"ping""#] {
            assert!(json.contains(expected), "{json} should contain {expected}");
        }
    }

    #[test]
    fn notification_without_params_omits_field() {
        let notification = JsonRpcNotification {
            jsonrpc: "2.0".into(),
            method: "ping".into(),
            params: None,
        };
        let json = serde_json::to_string(&notification).unwrap();
        assert!(!json.contains("params"), "{json} should not mention params");
    }

    #[test]
    fn response_without_error_omits_field() {
        let response = JsonRpcResponse {
            jsonrpc: "2.0".into(),
            id: 1,
            result: Some(serde_json::json!(true)),
            error: None,
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(!json.contains("error"), "{json} should not mention error");
    }
}