use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};

use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader};
use tokio::sync::{broadcast, mpsc, oneshot};
use tracing::{Instrument, error, warn};

use crate::{Error, ProtocolError};
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16#[serde(rename_all = "camelCase")]
17pub struct JsonRpcRequest {
18 pub jsonrpc: String,
20 pub id: u64,
22 pub method: String,
24 #[serde(skip_serializing_if = "Option::is_none")]
26 pub params: Option<Value>,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
31#[serde(rename_all = "camelCase")]
32pub struct JsonRpcResponse {
33 pub jsonrpc: String,
35 pub id: u64,
37 #[serde(skip_serializing_if = "Option::is_none")]
39 pub result: Option<Value>,
40 #[serde(skip_serializing_if = "Option::is_none")]
42 pub error: Option<JsonRpcError>,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct JsonRpcError {
48 pub code: i32,
50 pub message: String,
52 #[serde(skip_serializing_if = "Option::is_none")]
54 pub data: Option<Value>,
55}
56
/// Standard JSON-RPC 2.0 error codes (the reserved `-32xxx` range).
pub mod error_codes {
    /// The requested method does not exist or is not available.
    pub const METHOD_NOT_FOUND: i32 = -32601;

    /// The method was called with invalid parameters.
    pub const INVALID_PARAMS: i32 = -32602;

    /// Internal JSON-RPC error.
    #[allow(dead_code, reason = "standard JSON-RPC code, reserved for future use")]
    pub const INTERNAL_ERROR: i32 = -32603;
}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
70#[serde(rename_all = "camelCase")]
71pub struct JsonRpcNotification {
72 pub jsonrpc: String,
74 pub method: String,
76 #[serde(skip_serializing_if = "Option::is_none")]
78 pub params: Option<Value>,
79}
80
81#[derive(Debug, Clone, Serialize)]
83pub enum JsonRpcMessage {
84 Request(JsonRpcRequest),
86 Response(JsonRpcResponse),
88 Notification(JsonRpcNotification),
90}
91
92impl<'de> Deserialize<'de> for JsonRpcMessage {
101 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
102 where
103 D: serde::Deserializer<'de>,
104 {
105 let value = Value::deserialize(deserializer)?;
106 let obj = value
107 .as_object()
108 .ok_or_else(|| serde::de::Error::custom("expected a JSON object"))?;
109
110 let has_id = obj.contains_key("id");
111 let has_method = obj.contains_key("method");
112
113 if has_id && has_method {
114 JsonRpcRequest::deserialize(value)
115 .map(JsonRpcMessage::Request)
116 .map_err(serde::de::Error::custom)
117 } else if has_id {
118 JsonRpcResponse::deserialize(value)
119 .map(JsonRpcMessage::Response)
120 .map_err(serde::de::Error::custom)
121 } else {
122 JsonRpcNotification::deserialize(value)
123 .map(JsonRpcMessage::Notification)
124 .map_err(serde::de::Error::custom)
125 }
126 }
127}
128
129impl JsonRpcRequest {
130 pub fn new(id: u64, method: &str, params: Option<Value>) -> Self {
132 Self {
133 jsonrpc: "2.0".to_string(),
134 id,
135 method: method.to_string(),
136 params,
137 }
138 }
139}
140
141impl JsonRpcResponse {
142 #[allow(dead_code)]
144 pub fn is_error(&self) -> bool {
145 self.error.is_some()
146 }
147}
148
/// Header-line prefix (including the trailing space) that announces the byte
/// length of the JSON body following the blank separator line.
const CONTENT_LENGTH_HEADER: &str = "Content-Length: ";
150
151struct WriteCommand {
160 frame: Vec<u8>,
161 ack: oneshot::Sender<Result<(), std::io::Error>>,
162}
163
164pub struct JsonRpcClient {
175 request_id: AtomicU64,
176 write_tx: mpsc::UnboundedSender<WriteCommand>,
183 pending_requests: Arc<RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
184 notification_tx: broadcast::Sender<JsonRpcNotification>,
185 request_tx: mpsc::UnboundedSender<JsonRpcRequest>,
186}
187
188impl JsonRpcClient {
189 pub fn new(
196 writer: impl AsyncWrite + Unpin + Send + 'static,
197 reader: impl AsyncRead + Unpin + Send + 'static,
198 notification_tx: broadcast::Sender<JsonRpcNotification>,
199 request_tx: mpsc::UnboundedSender<JsonRpcRequest>,
200 ) -> Self {
201 let (write_tx, write_rx) = mpsc::unbounded_channel::<WriteCommand>();
202
203 let writer_span = tracing::error_span!("jsonrpc_write_loop");
204 tokio::spawn(Self::write_loop(writer, write_rx).instrument(writer_span));
205
206 let client = Self {
207 request_id: AtomicU64::new(1),
208 write_tx,
209 pending_requests: Arc::new(RwLock::new(HashMap::new())),
210 notification_tx,
211 request_tx,
212 };
213
214 let pending_requests = client.pending_requests.clone();
215 let notification_tx_clone = client.notification_tx.clone();
216 let request_tx_clone = client.request_tx.clone();
217 let reader_span = tracing::error_span!("jsonrpc_read_loop");
218
219 tokio::spawn(
220 async move {
221 Self::read_loop(
222 reader,
223 pending_requests,
224 notification_tx_clone,
225 request_tx_clone,
226 )
227 .await;
228 }
229 .instrument(reader_span),
230 );
231
232 client
233 }
234
235 async fn write_loop(
248 mut writer: impl AsyncWrite + Unpin + Send + 'static,
249 mut rx: mpsc::UnboundedReceiver<WriteCommand>,
250 ) {
251 while let Some(WriteCommand { frame, ack }) = rx.recv().await {
252 let result = async {
253 writer.write_all(&frame).await?;
254 writer.flush().await?;
255 Ok::<_, std::io::Error>(())
256 }
257 .await;
258
259 let _ = ack.send(result);
263 }
264 }
265
266 async fn read_loop(
267 reader: impl AsyncRead + Unpin + Send,
268 pending_requests: Arc<RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
269 notification_tx: broadcast::Sender<JsonRpcNotification>,
270 request_tx: mpsc::UnboundedSender<JsonRpcRequest>,
271 ) {
272 let mut reader = BufReader::new(reader);
273
274 loop {
275 match Self::read_message(&mut reader).await {
276 Ok(Some(message)) => match message {
277 JsonRpcMessage::Response(response) => {
278 let id = response.id;
279 let tx = pending_requests.write().remove(&id);
280 if let Some(tx) = tx {
281 if tx.send(response).is_err() {
282 warn!(request_id = %id, "failed to send response for request");
283 }
284 } else {
285 warn!(request_id = %id, "received response for unknown request id");
286 }
287 }
288 JsonRpcMessage::Notification(notification) => {
289 let _ = notification_tx.send(notification);
290 }
291 JsonRpcMessage::Request(request) => {
292 if request_tx.send(request).is_err() {
293 warn!("failed to forward JSON-RPC request, channel closed");
294 }
295 }
296 },
297 Ok(None) => {
298 break;
299 }
300 Err(e) => {
301 error!(error = %e, "error reading from CLI");
302 break;
303 }
304 }
305 }
306
307 let mut pending = pending_requests.write();
310 if !pending.is_empty() {
311 warn!(
312 count = pending.len(),
313 "draining pending requests after read loop exit"
314 );
315 pending.clear();
316 }
317 }
318
319 async fn read_message(
320 reader: &mut BufReader<impl AsyncRead + Unpin>,
321 ) -> Result<Option<JsonRpcMessage>, Error> {
322 let mut line = String::new();
323 let mut content_length = None;
324
325 loop {
326 line.clear();
327 if reader.read_line(&mut line).await? == 0 {
328 return Ok(None);
329 }
330
331 let trimmed = line.trim();
332 if trimmed.is_empty() {
333 break;
334 }
335
336 if let Some(value) = trimmed.strip_prefix(CONTENT_LENGTH_HEADER) {
337 content_length = Some(value.trim().parse::<usize>().map_err(|_| {
338 Error::Protocol(ProtocolError::InvalidContentLength(
339 value.trim().to_string(),
340 ))
341 })?);
342 }
343 }
344
345 let Some(length) = content_length else {
346 return Err(Error::Protocol(ProtocolError::MissingContentLength));
347 };
348
349 let mut body = vec![0u8; length];
350 reader.read_exact(&mut body).await?;
351
352 let message: JsonRpcMessage = serde_json::from_slice(&body)?;
353 Ok(Some(message))
354 }
355
356 pub async fn send_request(
367 &self,
368 method: &str,
369 params: Option<serde_json::Value>,
370 ) -> Result<JsonRpcResponse, Error> {
371 let id = self.request_id.fetch_add(1, Ordering::SeqCst);
372 let request = JsonRpcRequest::new(id, method, params);
373
374 let (tx, rx) = oneshot::channel();
375 self.pending_requests.write().insert(id, tx);
376
377 let mut guard = PendingGuard {
382 map: &self.pending_requests,
383 id,
384 armed: true,
385 };
386
387 self.write(&request).await?;
391
392 let response = rx
393 .await
394 .map_err(|_| Error::Protocol(ProtocolError::RequestCancelled))?;
395 guard.disarm();
396 Ok(response)
397 }
398
399 pub async fn write<T: serde::Serialize>(&self, message: &T) -> Result<(), Error> {
408 let body = serde_json::to_vec(message)?;
409 let mut frame = Vec::with_capacity(CONTENT_LENGTH_HEADER.len() + 16 + body.len() + 4);
410 frame.extend_from_slice(CONTENT_LENGTH_HEADER.as_bytes());
411 frame.extend_from_slice(body.len().to_string().as_bytes());
412 frame.extend_from_slice(b"\r\n\r\n");
413 frame.extend_from_slice(&body);
414
415 let (ack_tx, ack_rx) = oneshot::channel();
416 self.write_tx
417 .send(WriteCommand { frame, ack: ack_tx })
418 .map_err(|_| {
419 Error::Io(std::io::Error::new(
420 std::io::ErrorKind::BrokenPipe,
421 "writer actor has shut down",
422 ))
423 })?;
424
425 match ack_rx.await {
426 Ok(Ok(())) => Ok(()),
427 Ok(Err(e)) => Err(Error::Io(e)),
428 Err(_) => Err(Error::Io(std::io::Error::new(
429 std::io::ErrorKind::BrokenPipe,
430 "writer actor dropped ack without responding",
431 ))),
432 }
433 }
434}
435
436struct PendingGuard<'a> {
440 map: &'a RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>,
441 id: u64,
442 armed: bool,
443}
444
445impl PendingGuard<'_> {
446 fn disarm(&mut self) {
447 self.armed = false;
448 }
449}
450
451impl Drop for PendingGuard<'_> {
452 fn drop(&mut self) {
453 if self.armed {
454 self.map.write().remove(&self.id);
455 }
456 }
457}
458
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `json` into a [`JsonRpcMessage`], panicking if it is rejected.
    fn parse(json: &str) -> JsonRpcMessage {
        serde_json::from_str(json).unwrap()
    }

    #[test]
    fn deserialize_notification() {
        let msg = parse(r#"{"jsonrpc":"2.0","method":"session.event","params":{"id":"e1"}}"#);
        let ok = matches!(msg, JsonRpcMessage::Notification(n) if n.method == "session.event");
        assert!(ok);
    }

    #[test]
    fn deserialize_request() {
        let msg = parse(
            r#"{"jsonrpc":"2.0","id":5,"method":"permission.request","params":{"kind":"shell"}}"#,
        );
        let ok = matches!(
            msg,
            JsonRpcMessage::Request(r) if r.id == 5 && r.method == "permission.request"
        );
        assert!(ok);
    }

    #[test]
    fn deserialize_response_with_result() {
        let msg = parse(r#"{"jsonrpc":"2.0","id":3,"result":{"ok":true}}"#);
        let ok = matches!(msg, JsonRpcMessage::Response(r) if r.id == 3 && !r.is_error());
        assert!(ok);
    }

    #[test]
    fn deserialize_error_response() {
        let msg = parse(
            r#"{"jsonrpc":"2.0","id":7,"error":{"code":-32600,"message":"Invalid Request"}}"#,
        );
        match msg {
            JsonRpcMessage::Response(r) => {
                assert!(r.is_error());
                let JsonRpcError { code, message, .. } = r.error.unwrap();
                assert_eq!(code, -32600);
                assert_eq!(message, "Invalid Request");
            }
            other => panic!("expected Response, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_rejects_non_object() {
        let parsed = serde_json::from_str::<JsonRpcMessage>(r#""not an object""#);
        assert!(parsed.is_err());
    }

    #[test]
    fn request_new_sets_version() {
        let JsonRpcRequest {
            jsonrpc,
            id,
            method,
            params,
        } = JsonRpcRequest::new(42, "test.method", None);
        assert_eq!(jsonrpc, "2.0");
        assert_eq!(id, 42);
        assert_eq!(method, "test.method");
        assert!(params.is_none());
    }

    #[test]
    fn request_serializes_camel_case() {
        let req = JsonRpcRequest::new(1, "ping", Some(serde_json::json!({})));
        let json = serde_json::to_string(&req).unwrap();
        for expected in [r#""jsonrpc":"2.0""#, r#""id":1"#, r#""method":"ping""#] {
            assert!(json.contains(expected));
        }
    }

    #[test]
    fn notification_without_params_omits_field() {
        let notification = JsonRpcNotification {
            jsonrpc: "2.0".into(),
            method: "ping".into(),
            params: None,
        };
        let json = serde_json::to_string(&notification).unwrap();
        assert!(!json.contains("params"));
    }

    #[test]
    fn response_without_error_omits_field() {
        let response = JsonRpcResponse {
            jsonrpc: "2.0".into(),
            id: 1,
            result: Some(serde_json::json!(true)),
            error: None,
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(!json.contains("error"));
    }
}