Skip to main content

github_copilot_sdk/
jsonrpc.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::time::Instant;
5
6use parking_lot::{Mutex, RwLock};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader};
10use tokio::sync::{broadcast, mpsc, oneshot};
11use tokio::task::JoinHandle;
12use tracing::{Instrument, debug, error, warn};
13
14use crate::{Error, ProtocolError};
15
16/// A JSON-RPC 2.0 request message.
17#[derive(Debug, Clone, Serialize, Deserialize)]
18#[serde(rename_all = "camelCase")]
19pub struct JsonRpcRequest {
20    /// Protocol version (always `"2.0"`).
21    pub jsonrpc: String,
22    /// Request ID for correlating responses.
23    pub id: u64,
24    /// RPC method name.
25    pub method: String,
26    /// Optional method parameters.
27    #[serde(skip_serializing_if = "Option::is_none")]
28    pub params: Option<Value>,
29}
30
31/// A JSON-RPC 2.0 response message.
32#[derive(Debug, Clone, Serialize, Deserialize)]
33#[serde(rename_all = "camelCase")]
34pub struct JsonRpcResponse {
35    /// Protocol version (always `"2.0"`).
36    pub jsonrpc: String,
37    /// Request ID this response correlates to.
38    pub id: u64,
39    /// Success payload (mutually exclusive with `error`).
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub result: Option<Value>,
42    /// Error payload (mutually exclusive with `result`).
43    #[serde(skip_serializing_if = "Option::is_none")]
44    pub error: Option<JsonRpcError>,
45}
46
47/// A JSON-RPC 2.0 error object.
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct JsonRpcError {
50    /// Numeric error code.
51    pub code: i32,
52    /// Human-readable error description.
53    pub message: String,
54    /// Optional structured error data.
55    #[serde(skip_serializing_if = "Option::is_none")]
56    pub data: Option<Value>,
57}
58
/// Predefined error codes from the JSON-RPC 2.0 specification.
///
/// The specification reserves the range -32768 to -32000 for
/// protocol-level errors; the five codes it assigns meaning to are listed
/// here.
pub mod error_codes {
    /// Parse error (-32700): the receiver got text that is not valid JSON.
    #[allow(dead_code, reason = "standard JSON-RPC code, reserved for future use")]
    pub const PARSE_ERROR: i32 = -32700;
    /// Invalid request (-32600): the JSON is not a valid request object.
    #[allow(dead_code, reason = "standard JSON-RPC code, reserved for future use")]
    pub const INVALID_REQUEST: i32 = -32600;
    /// Method not found (-32601).
    pub const METHOD_NOT_FOUND: i32 = -32601;
    /// Invalid method parameters (-32602).
    pub const INVALID_PARAMS: i32 = -32602;
    /// Internal server error (-32603).
    #[allow(dead_code, reason = "standard JSON-RPC code, reserved for future use")]
    pub const INTERNAL_ERROR: i32 = -32603;
}
69
70/// A JSON-RPC 2.0 notification (no `id`, no response expected).
71#[derive(Debug, Clone, Serialize, Deserialize)]
72#[serde(rename_all = "camelCase")]
73pub struct JsonRpcNotification {
74    /// Protocol version (always `"2.0"`).
75    pub jsonrpc: String,
76    /// Notification method name.
77    pub method: String,
78    /// Optional notification parameters.
79    #[serde(skip_serializing_if = "Option::is_none")]
80    pub params: Option<Value>,
81}
82
83/// A parsed JSON-RPC 2.0 message — request, response, or notification.
84#[derive(Debug, Clone, Serialize)]
85pub enum JsonRpcMessage {
86    /// An incoming or outgoing request.
87    Request(JsonRpcRequest),
88    /// A response to a previous request.
89    Response(JsonRpcResponse),
90    /// A fire-and-forget notification.
91    Notification(JsonRpcNotification),
92}
93
94/// Custom deserializer that dispatches based on field presence instead of
95/// `#[serde(untagged)]` which tries each variant sequentially (3× parse
96/// attempts for Notification — the hot-path streaming variant).
97///
98/// Dispatch logic:
99/// - has `id` + has `method` → Request
100/// - has `id` + no `method` → Response
101/// - no `id`                → Notification
102impl<'de> Deserialize<'de> for JsonRpcMessage {
103    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
104    where
105        D: serde::Deserializer<'de>,
106    {
107        let value = Value::deserialize(deserializer)?;
108        let obj = value
109            .as_object()
110            .ok_or_else(|| serde::de::Error::custom("expected a JSON object"))?;
111
112        let has_id = obj.contains_key("id");
113        let has_method = obj.contains_key("method");
114
115        if has_id && has_method {
116            JsonRpcRequest::deserialize(value)
117                .map(JsonRpcMessage::Request)
118                .map_err(serde::de::Error::custom)
119        } else if has_id {
120            JsonRpcResponse::deserialize(value)
121                .map(JsonRpcMessage::Response)
122                .map_err(serde::de::Error::custom)
123        } else {
124            JsonRpcNotification::deserialize(value)
125                .map(JsonRpcMessage::Notification)
126                .map_err(serde::de::Error::custom)
127        }
128    }
129}
130
131impl JsonRpcRequest {
132    /// Create a new JSON-RPC request with the given ID, method, and params.
133    pub fn new(id: u64, method: &str, params: Option<Value>) -> Self {
134        Self {
135            jsonrpc: "2.0".to_string(),
136            id,
137            method: method.to_string(),
138            params,
139        }
140    }
141}
142
143impl JsonRpcResponse {
144    /// Returns `true` if this response contains an error.
145    #[allow(dead_code)]
146    pub fn is_error(&self) -> bool {
147        self.error.is_some()
148    }
149}
150
151const CONTENT_LENGTH_HEADER: &str = "Content-Length: ";
152
153/// One framed JSON-RPC message handed to the writer actor.
154///
155/// `frame` is the fully serialized bytes (header + body); the caller pays
156/// the serde cost synchronously before enqueueing so the actor never sees a
157/// `Result` from JSON encoding. `ack` resolves once the bytes have been
158/// fully written and flushed (or the underlying I/O reports an error). If
159/// the caller drops the `oneshot::Receiver`, the actor still completes the
160/// frame — caller cancellation cannot desync the wire.
161struct WriteCommand {
162    frame: Vec<u8>,
163    ack: oneshot::Sender<Result<(), std::io::Error>>,
164}
165
166/// Low-level JSON-RPC 2.0 client over Content-Length-framed streams.
167///
168/// # Cancel safety
169///
170/// All public methods (`write`, `send_request`) are **cancel-safe**: the
171/// actual bytes hit the wire on a dedicated background actor task, so
172/// dropping the caller's future after `await` returns `Pending` cannot
173/// produce a partial frame on the wire. Frames either land atomically or
174/// the underlying I/O fails. See `cancel-safety review` artifact for the
175/// full RFD-400 reasoning.
176pub struct JsonRpcClient {
177    request_id: AtomicU64,
178    /// Sender side of the writer actor's command queue. Public methods
179    /// pre-serialize their frames and enqueue here; the background actor
180    /// drains the queue and serializes writes onto the underlying
181    /// `AsyncWrite`. Unbounded by design — RFD 400 explicitly permits this
182    /// for cancel-safety, and JSON-RPC frames are small relative to the
183    /// natural request/response back-pressure of the wire.
184    write_tx: mpsc::UnboundedSender<WriteCommand>,
185    pending_requests: Arc<RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
186    notification_tx: broadcast::Sender<JsonRpcNotification>,
187    request_tx: mpsc::UnboundedSender<JsonRpcRequest>,
188    read_task: Mutex<Option<JoinHandle<()>>>,
189    write_task: Mutex<Option<JoinHandle<()>>>,
190}
191
192impl JsonRpcClient {
193    /// Create a new client from async read/write streams.
194    ///
195    /// Spawns two background tasks: a reader that dispatches incoming
196    /// messages to pending request channels, the notification broadcast,
197    /// or the request-forwarding channel; and a writer actor that owns the
198    /// underlying `AsyncWrite` and serializes frames atomically.
199    pub fn new(
200        writer: impl AsyncWrite + Unpin + Send + 'static,
201        reader: impl AsyncRead + Unpin + Send + 'static,
202        notification_tx: broadcast::Sender<JsonRpcNotification>,
203        request_tx: mpsc::UnboundedSender<JsonRpcRequest>,
204    ) -> Self {
205        let (write_tx, write_rx) = mpsc::unbounded_channel::<WriteCommand>();
206
207        let writer_span = tracing::error_span!("jsonrpc_write_loop");
208        let write_task = tokio::spawn(Self::write_loop(writer, write_rx).instrument(writer_span));
209
210        let client = Self {
211            request_id: AtomicU64::new(1),
212            write_tx,
213            pending_requests: Arc::new(RwLock::new(HashMap::new())),
214            notification_tx,
215            request_tx,
216            read_task: Mutex::new(None),
217            write_task: Mutex::new(Some(write_task)),
218        };
219
220        let pending_requests = client.pending_requests.clone();
221        let notification_tx_clone = client.notification_tx.clone();
222        let request_tx_clone = client.request_tx.clone();
223        let reader_span = tracing::error_span!("jsonrpc_read_loop");
224
225        let read_task = tokio::spawn(
226            async move {
227                Self::read_loop(
228                    reader,
229                    pending_requests,
230                    notification_tx_clone,
231                    request_tx_clone,
232                )
233                .await;
234            }
235            .instrument(reader_span),
236        );
237        *client.read_task.lock() = Some(read_task);
238
239        client
240    }
241
242    pub(crate) fn force_close(&self) {
243        if let Some(task) = self.read_task.lock().take() {
244            task.abort();
245        }
246        if let Some(task) = self.write_task.lock().take() {
247            task.abort();
248        }
249        self.pending_requests.write().clear();
250    }
251
252    /// Writer-actor task. Owns the `AsyncWrite`, drains the command queue,
253    /// and writes each frame atomically (header + body + flush) before
254    /// signaling the ack.
255    ///
256    /// Caller-side cancellation cannot interrupt a write in progress:
257    /// dropping the ack `oneshot::Receiver` does not cancel the in-flight
258    /// I/O. Once `WriteCommand` is enqueued the frame is committed to land
259    /// on the wire (or surface an `io::Error` to the ack receiver if the
260    /// transport is broken).
261    ///
262    /// Exits cleanly when all senders drop (channel closes), flushing any
263    /// final buffered bytes.
264    async fn write_loop(
265        mut writer: impl AsyncWrite + Unpin + Send + 'static,
266        mut rx: mpsc::UnboundedReceiver<WriteCommand>,
267    ) {
268        while let Some(WriteCommand { frame, ack }) = rx.recv().await {
269            let result = async {
270                writer.write_all(&frame).await?;
271                writer.flush().await?;
272                Ok::<_, std::io::Error>(())
273            }
274            .await;
275
276            // Caller may have dropped the ack receiver (e.g. their
277            // `await` was cancelled); that's fine — we still completed
278            // the write, which was the whole point.
279            let _ = ack.send(result);
280        }
281    }
282
283    async fn read_loop(
284        reader: impl AsyncRead + Unpin + Send,
285        pending_requests: Arc<RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
286        notification_tx: broadcast::Sender<JsonRpcNotification>,
287        request_tx: mpsc::UnboundedSender<JsonRpcRequest>,
288    ) {
289        let mut reader = BufReader::new(reader);
290
291        loop {
292            match Self::read_message(&mut reader).await {
293                Ok(Some(message)) => match message {
294                    JsonRpcMessage::Response(response) => {
295                        let id = response.id;
296                        let tx = pending_requests.write().remove(&id);
297                        if let Some(tx) = tx {
298                            if tx.send(response).is_err() {
299                                warn!(request_id = %id, "failed to send response for request");
300                            }
301                        } else {
302                            warn!(request_id = %id, "received response for unknown request id");
303                        }
304                    }
305                    JsonRpcMessage::Notification(notification) => {
306                        let _ = notification_tx.send(notification);
307                    }
308                    JsonRpcMessage::Request(request) => {
309                        if request_tx.send(request).is_err() {
310                            warn!("failed to forward JSON-RPC request, channel closed");
311                        }
312                    }
313                },
314                Ok(None) => {
315                    break;
316                }
317                Err(e) => {
318                    error!(error = %e, "error reading from CLI");
319                    break;
320                }
321            }
322        }
323
324        // Drain in-flight requests so callers observe cancellation
325        // instead of hanging on a oneshot receiver.
326        let mut pending = pending_requests.write();
327        if !pending.is_empty() {
328            warn!(
329                count = pending.len(),
330                "draining pending requests after read loop exit"
331            );
332            pending.clear();
333        }
334    }
335
336    async fn read_message(
337        reader: &mut BufReader<impl AsyncRead + Unpin>,
338    ) -> Result<Option<JsonRpcMessage>, Error> {
339        let mut line = String::new();
340        let mut content_length = None;
341
342        loop {
343            line.clear();
344            if reader.read_line(&mut line).await? == 0 {
345                return Ok(None);
346            }
347
348            let trimmed = line.trim();
349            if trimmed.is_empty() {
350                break;
351            }
352
353            if let Some(value) = trimmed.strip_prefix(CONTENT_LENGTH_HEADER) {
354                content_length = Some(value.trim().parse::<usize>().map_err(|_| {
355                    Error::Protocol(ProtocolError::InvalidContentLength(
356                        value.trim().to_string(),
357                    ))
358                })?);
359            }
360        }
361
362        let Some(length) = content_length else {
363            return Err(Error::Protocol(ProtocolError::MissingContentLength));
364        };
365
366        let mut body = vec![0u8; length];
367        reader.read_exact(&mut body).await?;
368
369        let message: JsonRpcMessage = serde_json::from_slice(&body)?;
370        Ok(Some(message))
371    }
372
373    /// Send a JSON-RPC request and wait for the matching response.
374    ///
375    /// # Cancel safety
376    ///
377    /// **Cancel-safe.** The frame is committed to the wire via the writer
378    /// actor before this future yields; cancelling the await drops the
379    /// response oneshot but does not desync the transport. The pending-
380    /// requests map is cleaned up automatically (the `PendingGuard` drop
381    /// removes the entry, and the read loop's response handling tolerates
382    /// a missing entry).
383    pub async fn send_request(
384        &self,
385        method: &str,
386        params: Option<serde_json::Value>,
387    ) -> Result<JsonRpcResponse, Error> {
388        let request_start = Instant::now();
389        let id = self.request_id.fetch_add(1, Ordering::SeqCst);
390        let request = JsonRpcRequest::new(id, method, params);
391
392        let (tx, rx) = oneshot::channel();
393        self.pending_requests.write().insert(id, tx);
394
395        // RAII guard that removes the pending entry if this future is
396        // dropped before the response arrives. Disarmed below before the
397        // success return so the read loop owns the cleanup on the happy
398        // path.
399        let mut guard = PendingGuard {
400            map: &self.pending_requests,
401            id,
402            armed: true,
403        };
404
405        // The PendingGuard's drop removes the entry on every error path
406        // and on cancellation; disarmed below before the success return so
407        // the read loop owns the cleanup on the happy path.
408        if let Err(error) = self.write(&request).await {
409            warn!(
410                elapsed_ms = request_start.elapsed().as_millis(),
411                method = %method,
412                request_id = id,
413                status = "failed",
414                error = %error,
415                "JsonRpcClient::send_request JSON-RPC request finished"
416            );
417            return Err(error);
418        }
419
420        let response = match rx.await {
421            Ok(response) => response,
422            Err(_) => {
423                let error = Error::Protocol(ProtocolError::RequestCancelled);
424                warn!(
425                    elapsed_ms = request_start.elapsed().as_millis(),
426                    method = %method,
427                    request_id = id,
428                    status = "failed",
429                    error = %error,
430                    "JsonRpcClient::send_request JSON-RPC request finished"
431                );
432                return Err(error);
433            }
434        };
435        guard.disarm();
436        if let Some(error) = &response.error {
437            warn!(
438                elapsed_ms = request_start.elapsed().as_millis(),
439                method = %method,
440                request_id = id,
441                status = "failed",
442                code = error.code,
443                error = %error.message,
444                "JsonRpcClient::send_request JSON-RPC request finished"
445            );
446        } else {
447            debug!(
448                elapsed_ms = request_start.elapsed().as_millis(),
449                method = %method,
450                request_id = id,
451                status = "succeeded",
452                "JsonRpcClient::send_request JSON-RPC request finished"
453            );
454        }
455        Ok(response)
456    }
457
458    /// Write a Content-Length-framed JSON-RPC message to the transport.
459    ///
460    /// # Cancel safety
461    ///
462    /// **Cancel-safe.** Pre-serializes the body, enqueues it on the writer
463    /// actor's command channel, and awaits an ack. Caller cancellation
464    /// drops the ack receiver; the actor still completes the frame and
465    /// flushes. A partial frame can never appear on the wire.
466    pub async fn write<T: serde::Serialize>(&self, message: &T) -> Result<(), Error> {
467        let body = serde_json::to_vec(message)?;
468        let mut frame = Vec::with_capacity(CONTENT_LENGTH_HEADER.len() + 16 + body.len() + 4);
469        frame.extend_from_slice(CONTENT_LENGTH_HEADER.as_bytes());
470        frame.extend_from_slice(body.len().to_string().as_bytes());
471        frame.extend_from_slice(b"\r\n\r\n");
472        frame.extend_from_slice(&body);
473
474        let (ack_tx, ack_rx) = oneshot::channel();
475        self.write_tx
476            .send(WriteCommand { frame, ack: ack_tx })
477            .map_err(|_| {
478                Error::Io(std::io::Error::new(
479                    std::io::ErrorKind::BrokenPipe,
480                    "writer actor has shut down",
481                ))
482            })?;
483
484        match ack_rx.await {
485            Ok(Ok(())) => Ok(()),
486            Ok(Err(e)) => Err(Error::Io(e)),
487            Err(_) => Err(Error::Io(std::io::Error::new(
488                std::io::ErrorKind::BrokenPipe,
489                "writer actor dropped ack without responding",
490            ))),
491        }
492    }
493}
494
495/// RAII guard that removes a pending-request entry from the map if the
496/// owning future is dropped before the response arrives. Disarmed on the
497/// happy path so the read loop's response handling owns the cleanup.
498struct PendingGuard<'a> {
499    map: &'a RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>,
500    id: u64,
501    armed: bool,
502}
503
504impl PendingGuard<'_> {
505    fn disarm(&mut self) {
506        self.armed = false;
507    }
508}
509
510impl Drop for PendingGuard<'_> {
511    fn drop(&mut self) {
512        if self.armed {
513            self.map.write().remove(&self.id);
514        }
515    }
516}
517
#[cfg(test)]
mod tests {
    use super::*;

    /// Decode a JSON fixture as a [`JsonRpcMessage`], panicking if it fails.
    fn decode(raw: &str) -> JsonRpcMessage {
        serde_json::from_str(raw).unwrap()
    }

    #[test]
    fn deserialize_notification() {
        let msg = decode(r#"{"jsonrpc":"2.0","method":"session.event","params":{"id":"e1"}}"#);
        match msg {
            JsonRpcMessage::Notification(n) => assert_eq!(n.method, "session.event"),
            other => panic!("expected Notification, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_request() {
        let msg = decode(
            r#"{"jsonrpc":"2.0","id":5,"method":"permission.request","params":{"kind":"shell"}}"#,
        );
        match msg {
            JsonRpcMessage::Request(r) => {
                assert_eq!(r.id, 5);
                assert_eq!(r.method, "permission.request");
            }
            other => panic!("expected Request, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_response_with_result() {
        match decode(r#"{"jsonrpc":"2.0","id":3,"result":{"ok":true}}"#) {
            JsonRpcMessage::Response(r) => {
                assert_eq!(r.id, 3);
                assert!(!r.is_error());
            }
            other => panic!("expected Response, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_error_response() {
        let msg =
            decode(r#"{"jsonrpc":"2.0","id":7,"error":{"code":-32600,"message":"Invalid Request"}}"#);
        let JsonRpcMessage::Response(response) = msg else {
            panic!("expected Response");
        };
        assert!(response.is_error());
        let JsonRpcError { code, message, .. } = response.error.unwrap();
        assert_eq!(code, -32600);
        assert_eq!(message, "Invalid Request");
    }

    #[test]
    fn deserialize_rejects_non_object() {
        let outcome = serde_json::from_str::<JsonRpcMessage>(r#""not an object""#);
        assert!(outcome.is_err());
    }

    #[test]
    fn request_new_sets_version() {
        let JsonRpcRequest {
            jsonrpc,
            id,
            method,
            params,
        } = JsonRpcRequest::new(42, "test.method", None);
        assert_eq!(jsonrpc, "2.0");
        assert_eq!(id, 42);
        assert_eq!(method, "test.method");
        assert!(params.is_none());
    }

    #[test]
    fn request_serializes_camel_case() {
        let request = JsonRpcRequest::new(1, "ping", Some(serde_json::json!({})));
        let encoded = serde_json::to_string(&request).unwrap();
        for fragment in [r#""jsonrpc":"2.0""#, r#""id":1"#, r#""method":"ping""#] {
            assert!(encoded.contains(fragment), "missing {fragment} in {encoded}");
        }
    }

    #[test]
    fn notification_without_params_omits_field() {
        let notification = JsonRpcNotification {
            jsonrpc: "2.0".into(),
            method: "ping".into(),
            params: None,
        };
        let encoded = serde_json::to_string(&notification).unwrap();
        assert!(!encoded.contains("params"));
    }

    #[test]
    fn response_without_error_omits_field() {
        let response = JsonRpcResponse {
            jsonrpc: "2.0".into(),
            id: 1,
            result: Some(serde_json::json!(true)),
            error: None,
        };
        let encoded = serde_json::to_string(&response).unwrap();
        assert!(!encoded.contains("error"));
    }
}
608}