Skip to main content

github_copilot_sdk/
jsonrpc.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicU64, Ordering};
4
5use parking_lot::RwLock;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader};
9use tokio::sync::{broadcast, mpsc, oneshot};
10use tracing::{Instrument, error, warn};
11
12use crate::{Error, ProtocolError};
13
14/// A JSON-RPC 2.0 request message.
15#[derive(Debug, Clone, Serialize, Deserialize)]
16#[serde(rename_all = "camelCase")]
17pub struct JsonRpcRequest {
18    /// Protocol version (always `"2.0"`).
19    pub jsonrpc: String,
20    /// Request ID for correlating responses.
21    pub id: u64,
22    /// RPC method name.
23    pub method: String,
24    /// Optional method parameters.
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub params: Option<Value>,
27}
28
29/// A JSON-RPC 2.0 response message.
30#[derive(Debug, Clone, Serialize, Deserialize)]
31#[serde(rename_all = "camelCase")]
32pub struct JsonRpcResponse {
33    /// Protocol version (always `"2.0"`).
34    pub jsonrpc: String,
35    /// Request ID this response correlates to.
36    pub id: u64,
37    /// Success payload (mutually exclusive with `error`).
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub result: Option<Value>,
40    /// Error payload (mutually exclusive with `result`).
41    #[serde(skip_serializing_if = "Option::is_none")]
42    pub error: Option<JsonRpcError>,
43}
44
45/// A JSON-RPC 2.0 error object.
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct JsonRpcError {
48    /// Numeric error code.
49    pub code: i32,
50    /// Human-readable error description.
51    pub message: String,
52    /// Optional structured error data.
53    #[serde(skip_serializing_if = "Option::is_none")]
54    pub data: Option<Value>,
55}
56
/// Standard JSON-RPC 2.0 error codes, as defined by the JSON-RPC 2.0
/// specification (section 5.1, "Error object").
///
/// The specification also reserves `-32000..=-32099` for
/// implementation-defined server errors; no constant is provided for that
/// range here.
pub mod error_codes {
    /// Invalid JSON was received (-32700).
    #[allow(dead_code, reason = "standard JSON-RPC code, reserved for future use")]
    pub const PARSE_ERROR: i32 = -32700;
    /// The JSON sent is not a valid request object (-32600).
    #[allow(dead_code, reason = "standard JSON-RPC code, reserved for future use")]
    pub const INVALID_REQUEST: i32 = -32600;
    /// Method not found (-32601).
    pub const METHOD_NOT_FOUND: i32 = -32601;
    /// Invalid method parameters (-32602).
    pub const INVALID_PARAMS: i32 = -32602;
    /// Internal server error (-32603).
    #[allow(dead_code, reason = "standard JSON-RPC code, reserved for future use")]
    pub const INTERNAL_ERROR: i32 = -32603;
}
67
68/// A JSON-RPC 2.0 notification (no `id`, no response expected).
69#[derive(Debug, Clone, Serialize, Deserialize)]
70#[serde(rename_all = "camelCase")]
71pub struct JsonRpcNotification {
72    /// Protocol version (always `"2.0"`).
73    pub jsonrpc: String,
74    /// Notification method name.
75    pub method: String,
76    /// Optional notification parameters.
77    #[serde(skip_serializing_if = "Option::is_none")]
78    pub params: Option<Value>,
79}
80
81/// A parsed JSON-RPC 2.0 message — request, response, or notification.
82#[derive(Debug, Clone, Serialize)]
83pub enum JsonRpcMessage {
84    /// An incoming or outgoing request.
85    Request(JsonRpcRequest),
86    /// A response to a previous request.
87    Response(JsonRpcResponse),
88    /// A fire-and-forget notification.
89    Notification(JsonRpcNotification),
90}
91
92/// Custom deserializer that dispatches based on field presence instead of
93/// `#[serde(untagged)]` which tries each variant sequentially (3× parse
94/// attempts for Notification — the hot-path streaming variant).
95///
96/// Dispatch logic:
97/// - has `id` + has `method` → Request
98/// - has `id` + no `method` → Response
99/// - no `id`                → Notification
100impl<'de> Deserialize<'de> for JsonRpcMessage {
101    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
102    where
103        D: serde::Deserializer<'de>,
104    {
105        let value = Value::deserialize(deserializer)?;
106        let obj = value
107            .as_object()
108            .ok_or_else(|| serde::de::Error::custom("expected a JSON object"))?;
109
110        let has_id = obj.contains_key("id");
111        let has_method = obj.contains_key("method");
112
113        if has_id && has_method {
114            JsonRpcRequest::deserialize(value)
115                .map(JsonRpcMessage::Request)
116                .map_err(serde::de::Error::custom)
117        } else if has_id {
118            JsonRpcResponse::deserialize(value)
119                .map(JsonRpcMessage::Response)
120                .map_err(serde::de::Error::custom)
121        } else {
122            JsonRpcNotification::deserialize(value)
123                .map(JsonRpcMessage::Notification)
124                .map_err(serde::de::Error::custom)
125        }
126    }
127}
128
129impl JsonRpcRequest {
130    /// Create a new JSON-RPC request with the given ID, method, and params.
131    pub fn new(id: u64, method: &str, params: Option<Value>) -> Self {
132        Self {
133            jsonrpc: "2.0".to_string(),
134            id,
135            method: method.to_string(),
136            params,
137        }
138    }
139}
140
141impl JsonRpcResponse {
142    /// Returns `true` if this response contains an error.
143    #[allow(dead_code)]
144    pub fn is_error(&self) -> bool {
145        self.error.is_some()
146    }
147}
148
/// Framing-header prefix written before every outgoing JSON-RPC body. The
/// body's decimal byte length follows it, then `\r\n\r\n`, then the body.
const CONTENT_LENGTH_HEADER: &str = "Content-Length: ";
150
151/// One framed JSON-RPC message handed to the writer actor.
152///
153/// `frame` is the fully serialized bytes (header + body); the caller pays
154/// the serde cost synchronously before enqueueing so the actor never sees a
155/// `Result` from JSON encoding. `ack` resolves once the bytes have been
156/// fully written and flushed (or the underlying I/O reports an error). If
157/// the caller drops the `oneshot::Receiver`, the actor still completes the
158/// frame — caller cancellation cannot desync the wire.
159struct WriteCommand {
160    frame: Vec<u8>,
161    ack: oneshot::Sender<Result<(), std::io::Error>>,
162}
163
164/// Low-level JSON-RPC 2.0 client over Content-Length-framed streams.
165///
166/// # Cancel safety
167///
168/// All public methods (`write`, `send_request`) are **cancel-safe**: the
169/// actual bytes hit the wire on a dedicated background actor task, so
170/// dropping the caller's future after `await` returns `Pending` cannot
171/// produce a partial frame on the wire. Frames either land atomically or
172/// the underlying I/O fails. See `cancel-safety review` artifact for the
173/// full RFD-400 reasoning.
174pub struct JsonRpcClient {
175    request_id: AtomicU64,
176    /// Sender side of the writer actor's command queue. Public methods
177    /// pre-serialize their frames and enqueue here; the background actor
178    /// drains the queue and serializes writes onto the underlying
179    /// `AsyncWrite`. Unbounded by design — RFD 400 explicitly permits this
180    /// for cancel-safety, and JSON-RPC frames are small relative to the
181    /// natural request/response back-pressure of the wire.
182    write_tx: mpsc::UnboundedSender<WriteCommand>,
183    pending_requests: Arc<RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
184    notification_tx: broadcast::Sender<JsonRpcNotification>,
185    request_tx: mpsc::UnboundedSender<JsonRpcRequest>,
186}
187
188impl JsonRpcClient {
189    /// Create a new client from async read/write streams.
190    ///
191    /// Spawns two background tasks: a reader that dispatches incoming
192    /// messages to pending request channels, the notification broadcast,
193    /// or the request-forwarding channel; and a writer actor that owns the
194    /// underlying `AsyncWrite` and serializes frames atomically.
195    pub fn new(
196        writer: impl AsyncWrite + Unpin + Send + 'static,
197        reader: impl AsyncRead + Unpin + Send + 'static,
198        notification_tx: broadcast::Sender<JsonRpcNotification>,
199        request_tx: mpsc::UnboundedSender<JsonRpcRequest>,
200    ) -> Self {
201        let (write_tx, write_rx) = mpsc::unbounded_channel::<WriteCommand>();
202
203        let writer_span = tracing::error_span!("jsonrpc_write_loop");
204        tokio::spawn(Self::write_loop(writer, write_rx).instrument(writer_span));
205
206        let client = Self {
207            request_id: AtomicU64::new(1),
208            write_tx,
209            pending_requests: Arc::new(RwLock::new(HashMap::new())),
210            notification_tx,
211            request_tx,
212        };
213
214        let pending_requests = client.pending_requests.clone();
215        let notification_tx_clone = client.notification_tx.clone();
216        let request_tx_clone = client.request_tx.clone();
217        let reader_span = tracing::error_span!("jsonrpc_read_loop");
218
219        tokio::spawn(
220            async move {
221                Self::read_loop(
222                    reader,
223                    pending_requests,
224                    notification_tx_clone,
225                    request_tx_clone,
226                )
227                .await;
228            }
229            .instrument(reader_span),
230        );
231
232        client
233    }
234
235    /// Writer-actor task. Owns the `AsyncWrite`, drains the command queue,
236    /// and writes each frame atomically (header + body + flush) before
237    /// signaling the ack.
238    ///
239    /// Caller-side cancellation cannot interrupt a write in progress:
240    /// dropping the ack `oneshot::Receiver` does not cancel the in-flight
241    /// I/O. Once `WriteCommand` is enqueued the frame is committed to land
242    /// on the wire (or surface an `io::Error` to the ack receiver if the
243    /// transport is broken).
244    ///
245    /// Exits cleanly when all senders drop (channel closes), flushing any
246    /// final buffered bytes.
247    async fn write_loop(
248        mut writer: impl AsyncWrite + Unpin + Send + 'static,
249        mut rx: mpsc::UnboundedReceiver<WriteCommand>,
250    ) {
251        while let Some(WriteCommand { frame, ack }) = rx.recv().await {
252            let result = async {
253                writer.write_all(&frame).await?;
254                writer.flush().await?;
255                Ok::<_, std::io::Error>(())
256            }
257            .await;
258
259            // Caller may have dropped the ack receiver (e.g. their
260            // `await` was cancelled); that's fine — we still completed
261            // the write, which was the whole point.
262            let _ = ack.send(result);
263        }
264    }
265
266    async fn read_loop(
267        reader: impl AsyncRead + Unpin + Send,
268        pending_requests: Arc<RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
269        notification_tx: broadcast::Sender<JsonRpcNotification>,
270        request_tx: mpsc::UnboundedSender<JsonRpcRequest>,
271    ) {
272        let mut reader = BufReader::new(reader);
273
274        loop {
275            match Self::read_message(&mut reader).await {
276                Ok(Some(message)) => match message {
277                    JsonRpcMessage::Response(response) => {
278                        let id = response.id;
279                        let tx = pending_requests.write().remove(&id);
280                        if let Some(tx) = tx {
281                            if tx.send(response).is_err() {
282                                warn!(request_id = %id, "failed to send response for request");
283                            }
284                        } else {
285                            warn!(request_id = %id, "received response for unknown request id");
286                        }
287                    }
288                    JsonRpcMessage::Notification(notification) => {
289                        let _ = notification_tx.send(notification);
290                    }
291                    JsonRpcMessage::Request(request) => {
292                        if request_tx.send(request).is_err() {
293                            warn!("failed to forward JSON-RPC request, channel closed");
294                        }
295                    }
296                },
297                Ok(None) => {
298                    break;
299                }
300                Err(e) => {
301                    error!(error = %e, "error reading from CLI");
302                    break;
303                }
304            }
305        }
306
307        // Drain in-flight requests so callers observe cancellation
308        // instead of hanging on a oneshot receiver.
309        let mut pending = pending_requests.write();
310        if !pending.is_empty() {
311            warn!(
312                count = pending.len(),
313                "draining pending requests after read loop exit"
314            );
315            pending.clear();
316        }
317    }
318
319    async fn read_message(
320        reader: &mut BufReader<impl AsyncRead + Unpin>,
321    ) -> Result<Option<JsonRpcMessage>, Error> {
322        let mut line = String::new();
323        let mut content_length = None;
324
325        loop {
326            line.clear();
327            if reader.read_line(&mut line).await? == 0 {
328                return Ok(None);
329            }
330
331            let trimmed = line.trim();
332            if trimmed.is_empty() {
333                break;
334            }
335
336            if let Some(value) = trimmed.strip_prefix(CONTENT_LENGTH_HEADER) {
337                content_length = Some(value.trim().parse::<usize>().map_err(|_| {
338                    Error::Protocol(ProtocolError::InvalidContentLength(
339                        value.trim().to_string(),
340                    ))
341                })?);
342            }
343        }
344
345        let Some(length) = content_length else {
346            return Err(Error::Protocol(ProtocolError::MissingContentLength));
347        };
348
349        let mut body = vec![0u8; length];
350        reader.read_exact(&mut body).await?;
351
352        let message: JsonRpcMessage = serde_json::from_slice(&body)?;
353        Ok(Some(message))
354    }
355
356    /// Send a JSON-RPC request and wait for the matching response.
357    ///
358    /// # Cancel safety
359    ///
360    /// **Cancel-safe.** The frame is committed to the wire via the writer
361    /// actor before this future yields; cancelling the await drops the
362    /// response oneshot but does not desync the transport. The pending-
363    /// requests map is cleaned up automatically (the `PendingGuard` drop
364    /// removes the entry, and the read loop's response handling tolerates
365    /// a missing entry).
366    pub async fn send_request(
367        &self,
368        method: &str,
369        params: Option<serde_json::Value>,
370    ) -> Result<JsonRpcResponse, Error> {
371        let id = self.request_id.fetch_add(1, Ordering::SeqCst);
372        let request = JsonRpcRequest::new(id, method, params);
373
374        let (tx, rx) = oneshot::channel();
375        self.pending_requests.write().insert(id, tx);
376
377        // RAII guard that removes the pending entry if this future is
378        // dropped before the response arrives. Disarmed below before the
379        // success return so the read loop owns the cleanup on the happy
380        // path.
381        let mut guard = PendingGuard {
382            map: &self.pending_requests,
383            id,
384            armed: true,
385        };
386
387        // The PendingGuard's drop removes the entry on every error path
388        // and on cancellation; disarmed below before the success return so
389        // the read loop owns the cleanup on the happy path.
390        self.write(&request).await?;
391
392        let response = rx
393            .await
394            .map_err(|_| Error::Protocol(ProtocolError::RequestCancelled))?;
395        guard.disarm();
396        Ok(response)
397    }
398
399    /// Write a Content-Length-framed JSON-RPC message to the transport.
400    ///
401    /// # Cancel safety
402    ///
403    /// **Cancel-safe.** Pre-serializes the body, enqueues it on the writer
404    /// actor's command channel, and awaits an ack. Caller cancellation
405    /// drops the ack receiver; the actor still completes the frame and
406    /// flushes. A partial frame can never appear on the wire.
407    pub async fn write<T: serde::Serialize>(&self, message: &T) -> Result<(), Error> {
408        let body = serde_json::to_vec(message)?;
409        let mut frame = Vec::with_capacity(CONTENT_LENGTH_HEADER.len() + 16 + body.len() + 4);
410        frame.extend_from_slice(CONTENT_LENGTH_HEADER.as_bytes());
411        frame.extend_from_slice(body.len().to_string().as_bytes());
412        frame.extend_from_slice(b"\r\n\r\n");
413        frame.extend_from_slice(&body);
414
415        let (ack_tx, ack_rx) = oneshot::channel();
416        self.write_tx
417            .send(WriteCommand { frame, ack: ack_tx })
418            .map_err(|_| {
419                Error::Io(std::io::Error::new(
420                    std::io::ErrorKind::BrokenPipe,
421                    "writer actor has shut down",
422                ))
423            })?;
424
425        match ack_rx.await {
426            Ok(Ok(())) => Ok(()),
427            Ok(Err(e)) => Err(Error::Io(e)),
428            Err(_) => Err(Error::Io(std::io::Error::new(
429                std::io::ErrorKind::BrokenPipe,
430                "writer actor dropped ack without responding",
431            ))),
432        }
433    }
434}
435
436/// RAII guard that removes a pending-request entry from the map if the
437/// owning future is dropped before the response arrives. Disarmed on the
438/// happy path so the read loop's response handling owns the cleanup.
439struct PendingGuard<'a> {
440    map: &'a RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>,
441    id: u64,
442    armed: bool,
443}
444
445impl PendingGuard<'_> {
446    fn disarm(&mut self) {
447        self.armed = false;
448    }
449}
450
451impl Drop for PendingGuard<'_> {
452    fn drop(&mut self) {
453        if self.armed {
454            self.map.write().remove(&self.id);
455        }
456    }
457}
458
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes `json` into a [`JsonRpcMessage`], panicking on failure.
    fn parse(json: &str) -> JsonRpcMessage {
        serde_json::from_str(json).unwrap()
    }

    #[test]
    fn deserialize_notification() {
        let parsed = parse(r#"{"jsonrpc":"2.0","method":"session.event","params":{"id":"e1"}}"#);
        assert!(matches!(
            parsed,
            JsonRpcMessage::Notification(ref n) if n.method == "session.event"
        ));
    }

    #[test]
    fn deserialize_request() {
        let parsed = parse(
            r#"{"jsonrpc":"2.0","id":5,"method":"permission.request","params":{"kind":"shell"}}"#,
        );
        assert!(matches!(
            parsed,
            JsonRpcMessage::Request(ref r) if r.id == 5 && r.method == "permission.request"
        ));
    }

    #[test]
    fn deserialize_response_with_result() {
        let parsed = parse(r#"{"jsonrpc":"2.0","id":3,"result":{"ok":true}}"#);
        assert!(matches!(
            parsed,
            JsonRpcMessage::Response(ref r) if r.id == 3 && !r.is_error()
        ));
    }

    #[test]
    fn deserialize_error_response() {
        let parsed =
            parse(r#"{"jsonrpc":"2.0","id":7,"error":{"code":-32600,"message":"Invalid Request"}}"#);
        let response = match parsed {
            JsonRpcMessage::Response(r) => r,
            other => panic!("expected Response, got {other:?}"),
        };
        assert!(response.is_error());
        let err = response.error.unwrap();
        assert_eq!(err.code, -32600);
        assert_eq!(err.message, "Invalid Request");
    }

    #[test]
    fn deserialize_rejects_non_object() {
        let outcome = serde_json::from_str::<JsonRpcMessage>(r#""not an object""#);
        assert!(outcome.is_err());
    }

    #[test]
    fn request_new_sets_version() {
        let JsonRpcRequest {
            jsonrpc,
            id,
            method,
            params,
        } = JsonRpcRequest::new(42, "test.method", None);
        assert_eq!(jsonrpc, "2.0");
        assert_eq!(id, 42);
        assert_eq!(method, "test.method");
        assert!(params.is_none());
    }

    #[test]
    fn request_serializes_camel_case() {
        let request = JsonRpcRequest::new(1, "ping", Some(serde_json::json!({})));
        let encoded = serde_json::to_string(&request).unwrap();
        for fragment in [r#""jsonrpc":"2.0""#, r#""id":1"#, r#""method":"ping""#] {
            assert!(encoded.contains(fragment));
        }
    }

    #[test]
    fn notification_without_params_omits_field() {
        let notification = JsonRpcNotification {
            jsonrpc: "2.0".to_string(),
            method: "ping".to_string(),
            params: None,
        };
        let encoded = serde_json::to_string(&notification).unwrap();
        assert!(!encoded.contains("params"));
    }

    #[test]
    fn response_without_error_omits_field() {
        let response = JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            id: 1,
            result: Some(serde_json::json!(true)),
            error: None,
        };
        let encoded = serde_json::to_string(&response).unwrap();
        assert!(!encoded.contains("error"));
    }
}