Skip to main content

agent_client_protocol_schema/
rpc.rs

1use std::sync::Arc;
2
3use derive_more::{Display, From};
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize};
6use serde_with::skip_serializing_none;
7
8/// JSON RPC Request Id
9///
10/// An identifier established by the Client that MUST contain a String, Number, or NULL value if included. If it is not included it is assumed to be a notification. The value SHOULD normally not be Null \[1\] and Numbers SHOULD NOT contain fractional parts \[2\]
11///
12/// The Server MUST reply with the same value in the Response object if included. This member is used to correlate the context between the two objects.
13///
14/// \[1\] The use of Null as a value for the id member in a Request object is discouraged, because this specification uses a value of Null for Responses with an unknown id. Also, because JSON-RPC 1.0 uses an id value of Null for Notifications this could cause confusion in handling.
15///
16/// \[2\] Fractional parts may be problematic, since many decimal fractions cannot be represented exactly as binary fractions.
17#[derive(
18    Debug,
19    PartialEq,
20    Clone,
21    Hash,
22    Eq,
23    Deserialize,
24    Serialize,
25    PartialOrd,
26    Ord,
27    Display,
28    JsonSchema,
29    From,
30)]
31#[serde(untagged)]
32#[allow(
33    clippy::exhaustive_enums,
34    reason = "This comes from the JSON-RPC specification itself"
35)]
36#[from(String, i64)]
37pub enum RequestId {
38    #[display("null")]
39    Null,
40    Number(i64),
41    Str(String),
42}
43
44#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
45#[allow(
46    clippy::exhaustive_structs,
47    reason = "This comes from the JSON-RPC specification itself"
48)]
49#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
50#[skip_serializing_none]
51pub struct Request<Params> {
52    pub id: RequestId,
53    pub method: Arc<str>,
54    pub params: Option<Params>,
55}
56
57#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
58#[allow(
59    clippy::exhaustive_enums,
60    reason = "This comes from the JSON-RPC specification itself"
61)]
62#[serde(untagged)]
63#[schemars(rename = "{Result}", extend("x-docs-ignore" = true))]
64pub enum Response<Result, Error> {
65    Result { id: RequestId, result: Result },
66    Error { id: RequestId, error: Error },
67}
68
69impl<R, E> Response<R, E> {
70    #[must_use]
71    pub fn new(id: impl Into<RequestId>, result: std::result::Result<R, E>) -> Self {
72        match result {
73            Ok(result) => Self::Result {
74                id: id.into(),
75                result,
76            },
77            Err(error) => Self::Error {
78                id: id.into(),
79                error,
80            },
81        }
82    }
83}
84
85#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
86#[allow(
87    clippy::exhaustive_structs,
88    reason = "This comes from the JSON-RPC specification itself"
89)]
90#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
91#[skip_serializing_none]
92pub struct Notification<Params> {
93    pub method: Arc<str>,
94    pub params: Option<Params>,
95}
96
97#[derive(Debug, Serialize, Deserialize, JsonSchema)]
98#[schemars(inline)]
99enum JsonRpcVersion {
100    #[serde(rename = "2.0")]
101    V2,
102}
103
104/// A message (request, response, or notification) with `"jsonrpc": "2.0"` specified as
105/// [required by JSON-RPC 2.0 Specification][1].
106///
107/// [1]: https://www.jsonrpc.org/specification#compatibility
108#[derive(Debug, Serialize, Deserialize, JsonSchema)]
109#[schemars(inline)]
110pub struct JsonRpcMessage<M> {
111    jsonrpc: JsonRpcVersion,
112    #[serde(flatten)]
113    message: M,
114}
115
116impl<M> JsonRpcMessage<M> {
117    /// Wraps the provided message into a versioned [`JsonRpcMessage`].
118    #[must_use]
119    pub fn wrap(message: M) -> Self {
120        Self {
121            jsonrpc: JsonRpcVersion::V2,
122            message,
123        }
124    }
125
126    /// Unwraps the contained message.
127    #[must_use]
128    pub fn into_inner(self) -> M {
129        self.message
130    }
131}
132
133#[derive(Debug, Clone, Copy, PartialEq, Eq, Display)]
134#[display("JSON-RPC batch must contain at least one message")]
135#[non_exhaustive]
136pub struct EmptyJsonRpcBatch;
137
138impl std::error::Error for EmptyJsonRpcBatch {}
139
140/// A non-empty JSON-RPC 2.0 batch message.
141#[derive(Debug, Serialize, JsonSchema)]
142#[schemars(inline)]
143#[serde(transparent)]
144#[allow(
145    clippy::exhaustive_structs,
146    reason = "This comes from the JSON-RPC specification itself"
147)]
148pub struct JsonRpcBatch<M>(#[schemars(length(min = 1))] Vec<JsonRpcMessage<M>>);
149
150impl<M> JsonRpcBatch<M> {
151    /// Creates a non-empty JSON-RPC batch.
152    ///
153    /// Returns an error if `messages` is empty, because JSON-RPC 2.0 treats an
154    /// empty batch array as an invalid request.
155    ///
156    /// # Errors
157    ///
158    /// Returns [`EmptyJsonRpcBatch`] when `messages` is empty.
159    pub fn new(messages: Vec<JsonRpcMessage<M>>) -> Result<Self, EmptyJsonRpcBatch> {
160        if messages.is_empty() {
161            Err(EmptyJsonRpcBatch)
162        } else {
163            Ok(Self(messages))
164        }
165    }
166
167    /// Returns the messages in this batch.
168    #[must_use]
169    pub fn as_slice(&self) -> &[JsonRpcMessage<M>] {
170        &self.0
171    }
172
173    /// Consumes this batch and returns its messages.
174    #[must_use]
175    pub fn into_vec(self) -> Vec<JsonRpcMessage<M>> {
176        self.0
177    }
178}
179
180impl<M> TryFrom<Vec<JsonRpcMessage<M>>> for JsonRpcBatch<M> {
181    type Error = EmptyJsonRpcBatch;
182
183    fn try_from(messages: Vec<JsonRpcMessage<M>>) -> Result<Self, Self::Error> {
184        Self::new(messages)
185    }
186}
187
188impl<'de, M> Deserialize<'de> for JsonRpcBatch<M>
189where
190    M: Deserialize<'de>,
191{
192    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
193    where
194        D: serde::Deserializer<'de>,
195    {
196        let messages = Vec::<JsonRpcMessage<M>>::deserialize(deserializer)?;
197        Self::new(messages).map_err(serde::de::Error::custom)
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{
        AgentNotification, CancelNotification, ClientNotification, ContentBlock, ContentChunk,
        SessionId, SessionNotification, SessionUpdate, TextContent,
    };
    use serde_json::{Number, Value, json};

    #[test]
    fn id_deserialization() {
        let id = serde_json::from_value::<RequestId>(Value::Null).unwrap();
        assert_eq!(id, RequestId::Null);

        let id = serde_json::from_value::<RequestId>(Value::Number(Number::from_u128(1).unwrap()))
            .unwrap();
        assert_eq!(id, RequestId::Number(1));

        let id = serde_json::from_value::<RequestId>(Value::Number(Number::from_i128(-1).unwrap()))
            .unwrap();
        assert_eq!(id, RequestId::Number(-1));

        let id = serde_json::from_value::<RequestId>(Value::String("id".to_owned())).unwrap();
        assert_eq!(id, RequestId::Str("id".to_owned()));
    }

    #[test]
    fn id_rejects_values_outside_the_untagged_variants() {
        // Fractional numbers, integers beyond `i64`, booleans and structured
        // values match none of `Null`, `Number(i64)` or `Str(String)`.
        for value in [
            json!(1.5),
            json!(u64::MAX),
            json!(true),
            json!([1]),
            json!({ "id": 1 }),
        ] {
            assert!(
                serde_json::from_value::<RequestId>(value.clone()).is_err(),
                "expected {value} to be rejected as a request id"
            );
        }
    }

    #[test]
    fn id_serialization() {
        let id = serde_json::to_value(RequestId::Null).unwrap();
        assert_eq!(id, Value::Null);

        let id = serde_json::to_value(RequestId::Number(1)).unwrap();
        assert_eq!(id, Value::Number(Number::from_u128(1).unwrap()));

        let id = serde_json::to_value(RequestId::Number(-1)).unwrap();
        assert_eq!(id, Value::Number(Number::from_i128(-1).unwrap()));

        let id = serde_json::to_value(RequestId::Str("id".to_owned())).unwrap();
        assert_eq!(id, Value::String("id".to_owned()));
    }

    #[test]
    fn id_display() {
        let id = RequestId::Null;
        assert_eq!(id.to_string(), "null");

        let id = RequestId::Number(1);
        assert_eq!(id.to_string(), "1");

        let id = RequestId::Number(-1);
        assert_eq!(id.to_string(), "-1");

        let id = RequestId::Str("id".to_owned());
        assert_eq!(id.to_string(), "id");
    }

    #[test]
    fn response_new_selects_variant_from_result() {
        let ok = Response::<Value, Value>::new(1_i64, Ok(json!("done")));
        assert_eq!(
            serde_json::to_value(&ok).unwrap(),
            json!({ "id": 1, "result": "done" })
        );

        let err =
            Response::<Value, Value>::new("req-1".to_owned(), Err(json!({ "code": -32601 })));
        assert_eq!(
            serde_json::to_value(&err).unwrap(),
            json!({ "id": "req-1", "error": { "code": -32601 } })
        );
    }

    #[test]
    fn message_requires_jsonrpc_2_0() {
        let missing = serde_json::from_value::<JsonRpcMessage<Notification<Value>>>(json!({
            "method": "cancel",
        }));
        assert!(missing.is_err());

        let wrong = serde_json::from_value::<JsonRpcMessage<Notification<Value>>>(json!({
            "jsonrpc": "1.0",
            "method": "cancel",
        }));
        assert!(wrong.is_err());

        let message = serde_json::from_value::<JsonRpcMessage<Notification<Value>>>(json!({
            "jsonrpc": "2.0",
            "method": "cancel",
            "params": { "sessionId": "test-123" },
        }))
        .unwrap();
        let notification = message.into_inner();
        assert_eq!(&*notification.method, "cancel");
        assert_eq!(
            notification.params,
            Some(json!({ "sessionId": "test-123" }))
        );
    }

    #[test]
    fn batch_new_rejects_empty_vec() {
        let result = JsonRpcBatch::<Notification<ClientNotification>>::new(Vec::new());
        assert!(matches!(result, Err(EmptyJsonRpcBatch)));
    }

    #[test]
    fn batch_deserialization_requires_at_least_one_message() {
        let err = serde_json::from_value::<JsonRpcBatch<Notification<ClientNotification>>>(
            Value::Array(Vec::new()),
        )
        .unwrap_err();
        assert!(err.to_string().contains("at least one message"));
    }

    #[test]
    fn batch_serialization_round_trips_non_empty_messages() {
        let notification = JsonRpcMessage::wrap(Notification {
            method: "cancel".into(),
            params: Some(ClientNotification::CancelNotification(CancelNotification {
                session_id: SessionId("test-123".into()),
                meta: None,
            })),
        });

        let batch = JsonRpcBatch::new(vec![notification]).unwrap();
        let serialized = serde_json::to_value(&batch).unwrap();
        assert_eq!(
            serialized,
            json!([{
                "jsonrpc": "2.0",
                "method": "cancel",
                "params": {
                    "sessionId": "test-123"
                },
            }])
        );

        let deserialized =
            serde_json::from_value::<JsonRpcBatch<Notification<ClientNotification>>>(serialized)
                .unwrap();
        assert_eq!(deserialized.as_slice().len(), 1);
    }

    #[test]
    fn notification_wire_format() {
        // Test client -> agent notification wire format
        let outgoing_msg = JsonRpcMessage::wrap(Notification {
            method: "cancel".into(),
            params: Some(ClientNotification::CancelNotification(CancelNotification {
                session_id: SessionId("test-123".into()),
                meta: None,
            })),
        });

        let serialized: Value = serde_json::to_value(&outgoing_msg).unwrap();
        assert_eq!(
            serialized,
            json!({
                "jsonrpc": "2.0",
                "method": "cancel",
                "params": {
                    "sessionId": "test-123"
                },
            })
        );

        // Test agent -> client notification wire format
        let outgoing_msg = JsonRpcMessage::wrap(Notification {
            method: "sessionUpdate".into(),
            params: Some(AgentNotification::SessionNotification(
                SessionNotification {
                    session_id: SessionId("test-456".into()),
                    update: SessionUpdate::AgentMessageChunk(ContentChunk {
                        content: ContentBlock::Text(TextContent {
                            annotations: None,
                            text: "Hello".to_string(),
                            meta: None,
                        }),
                        message_id: None,
                        meta: None,
                    }),
                    meta: None,
                },
            )),
        });

        let serialized: Value = serde_json::to_value(&outgoing_msg).unwrap();
        assert_eq!(
            serialized,
            json!({
                "jsonrpc": "2.0",
                "method": "sessionUpdate",
                "params": {
                    "sessionId": "test-456",
                    "update": {
                        "sessionUpdate": "agent_message_chunk",
                        "content": {
                            "type": "text",
                            "text": "Hello"
                        }
                    }
                }
            })
        );
    }
}