agent_client_protocol_schema/
rpc.rs1use std::sync::Arc;
2
3use derive_more::{Display, From};
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize};
6use serde_with::skip_serializing_none;
7
8#[derive(
18 Debug,
19 PartialEq,
20 Clone,
21 Hash,
22 Eq,
23 Deserialize,
24 Serialize,
25 PartialOrd,
26 Ord,
27 Display,
28 JsonSchema,
29 From,
30)]
31#[serde(untagged)]
32#[allow(
33 clippy::exhaustive_enums,
34 reason = "This comes from the JSON-RPC specification itself"
35)]
36#[from(String, i64)]
37pub enum RequestId {
38 #[display("null")]
39 Null,
40 Number(i64),
41 Str(String),
42}
43
44#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
45#[allow(
46 clippy::exhaustive_structs,
47 reason = "This comes from the JSON-RPC specification itself"
48)]
49#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
50#[skip_serializing_none]
51pub struct Request<Params> {
52 pub id: RequestId,
53 pub method: Arc<str>,
54 pub params: Option<Params>,
55}
56
57#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
58#[allow(
59 clippy::exhaustive_enums,
60 reason = "This comes from the JSON-RPC specification itself"
61)]
62#[serde(untagged)]
63#[schemars(rename = "{Result}", extend("x-docs-ignore" = true))]
64pub enum Response<Result, Error> {
65 Result { id: RequestId, result: Result },
66 Error { id: RequestId, error: Error },
67}
68
69impl<R, E> Response<R, E> {
70 #[must_use]
71 pub fn new(id: impl Into<RequestId>, result: std::result::Result<R, E>) -> Self {
72 match result {
73 Ok(result) => Self::Result {
74 id: id.into(),
75 result,
76 },
77 Err(error) => Self::Error {
78 id: id.into(),
79 error,
80 },
81 }
82 }
83}
84
85#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
86#[allow(
87 clippy::exhaustive_structs,
88 reason = "This comes from the JSON-RPC specification itself"
89)]
90#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
91#[skip_serializing_none]
92pub struct Notification<Params> {
93 pub method: Arc<str>,
94 pub params: Option<Params>,
95}
96
97#[derive(Debug, Serialize, Deserialize, JsonSchema)]
98#[schemars(inline)]
99enum JsonRpcVersion {
100 #[serde(rename = "2.0")]
101 V2,
102}
103
104#[derive(Debug, Serialize, Deserialize, JsonSchema)]
109#[schemars(inline)]
110pub struct JsonRpcMessage<M> {
111 jsonrpc: JsonRpcVersion,
112 #[serde(flatten)]
113 message: M,
114}
115
116impl<M> JsonRpcMessage<M> {
117 #[must_use]
119 pub fn wrap(message: M) -> Self {
120 Self {
121 jsonrpc: JsonRpcVersion::V2,
122 message,
123 }
124 }
125
126 #[must_use]
128 pub fn into_inner(self) -> M {
129 self.message
130 }
131}
132
/// Error produced when a [`JsonRpcBatch`] would be built from zero messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub struct EmptyJsonRpcBatch;

impl std::fmt::Display for EmptyJsonRpcBatch {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("JSON-RPC batch must contain at least one message")
    }
}

// No underlying cause, so the default `source()` (None) is kept.
impl std::error::Error for EmptyJsonRpcBatch {}
139
140#[derive(Debug, Serialize, JsonSchema)]
142#[schemars(inline)]
143#[serde(transparent)]
144#[allow(
145 clippy::exhaustive_structs,
146 reason = "This comes from the JSON-RPC specification itself"
147)]
148pub struct JsonRpcBatch<M>(#[schemars(length(min = 1))] Vec<JsonRpcMessage<M>>);
149
150impl<M> JsonRpcBatch<M> {
151 pub fn new(messages: Vec<JsonRpcMessage<M>>) -> Result<Self, EmptyJsonRpcBatch> {
160 if messages.is_empty() {
161 Err(EmptyJsonRpcBatch)
162 } else {
163 Ok(Self(messages))
164 }
165 }
166
167 #[must_use]
169 pub fn as_slice(&self) -> &[JsonRpcMessage<M>] {
170 &self.0
171 }
172
173 #[must_use]
175 pub fn into_vec(self) -> Vec<JsonRpcMessage<M>> {
176 self.0
177 }
178}
179
180impl<M> TryFrom<Vec<JsonRpcMessage<M>>> for JsonRpcBatch<M> {
181 type Error = EmptyJsonRpcBatch;
182
183 fn try_from(messages: Vec<JsonRpcMessage<M>>) -> Result<Self, Self::Error> {
184 Self::new(messages)
185 }
186}
187
188impl<'de, M> Deserialize<'de> for JsonRpcBatch<M>
189where
190 M: Deserialize<'de>,
191{
192 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
193 where
194 D: serde::Deserializer<'de>,
195 {
196 let messages = Vec::<JsonRpcMessage<M>>::deserialize(deserializer)?;
197 Self::new(messages).map_err(serde::de::Error::custom)
198 }
199}
200
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{
        AgentNotification, CancelNotification, ClientNotification, ContentBlock, ContentChunk,
        SessionId, SessionNotification, SessionUpdate, TextContent,
    };
    use serde_json::{Number, Value, json};

    /// The `cancel` client notification shared by the batch and wire-format tests.
    fn cancel_notification() -> JsonRpcMessage<Notification<ClientNotification>> {
        JsonRpcMessage::wrap(Notification {
            method: "cancel".into(),
            params: Some(ClientNotification::CancelNotification(CancelNotification {
                session_id: SessionId("test-123".into()),
                meta: None,
            })),
        })
    }

    /// Expected wire form of [`cancel_notification`].
    fn cancel_notification_json() -> Value {
        json!({
            "jsonrpc": "2.0",
            "method": "cancel",
            "params": {
                "sessionId": "test-123"
            },
        })
    }

    #[test]
    fn id_deserialization() {
        let cases = [
            (Value::Null, RequestId::Null),
            (Value::Number(Number::from_u128(1).unwrap()), RequestId::Number(1)),
            (Value::Number(Number::from_i128(-1).unwrap()), RequestId::Number(-1)),
            (Value::String("id".to_owned()), RequestId::Str("id".to_owned())),
        ];
        for (wire, expected) in cases {
            assert_eq!(serde_json::from_value::<RequestId>(wire).unwrap(), expected);
        }
    }

    #[test]
    fn id_serialization() {
        let cases = [
            (RequestId::Null, Value::Null),
            (RequestId::Number(1), Value::Number(Number::from_u128(1).unwrap())),
            (RequestId::Number(-1), Value::Number(Number::from_i128(-1).unwrap())),
            (RequestId::Str("id".to_owned()), Value::String("id".to_owned())),
        ];
        for (id, expected) in cases {
            assert_eq!(serde_json::to_value(id).unwrap(), expected);
        }
    }

    #[test]
    fn id_display() {
        let cases = [
            (RequestId::Null, "null"),
            (RequestId::Number(1), "1"),
            (RequestId::Number(-1), "-1"),
            (RequestId::Str("id".to_owned()), "id"),
        ];
        for (id, expected) in cases {
            assert_eq!(id.to_string(), expected);
        }
    }

    #[test]
    fn batch_deserialization_requires_at_least_one_message() {
        let result =
            serde_json::from_value::<JsonRpcBatch<Notification<ClientNotification>>>(json!([]));
        let message = result.unwrap_err().to_string();
        assert!(message.contains("at least one message"));
    }

    #[test]
    fn batch_serialization_round_trips_non_empty_messages() {
        let batch = JsonRpcBatch::new(vec![cancel_notification()]).unwrap();

        let wire = serde_json::to_value(&batch).unwrap();
        assert_eq!(wire, Value::Array(vec![cancel_notification_json()]));

        let decoded: JsonRpcBatch<Notification<ClientNotification>> =
            serde_json::from_value(wire).unwrap();
        assert_eq!(decoded.into_vec().len(), 1);
    }

    #[test]
    fn notification_wire_format() {
        let cancel = serde_json::to_value(cancel_notification()).unwrap();
        assert_eq!(cancel, cancel_notification_json());

        let session_update = JsonRpcMessage::wrap(Notification {
            method: "sessionUpdate".into(),
            params: Some(AgentNotification::SessionNotification(SessionNotification {
                session_id: SessionId("test-456".into()),
                update: SessionUpdate::AgentMessageChunk(ContentChunk {
                    content: ContentBlock::Text(TextContent {
                        annotations: None,
                        text: "Hello".to_string(),
                        meta: None,
                    }),
                    message_id: None,
                    meta: None,
                }),
                meta: None,
            })),
        });

        assert_eq!(
            serde_json::to_value(&session_update).unwrap(),
            json!({
                "jsonrpc": "2.0",
                "method": "sessionUpdate",
                "params": {
                    "sessionId": "test-456",
                    "update": {
                        "sessionUpdate": "agent_message_chunk",
                        "content": {
                            "type": "text",
                            "text": "Hello"
                        }
                    }
                }
            })
        );
    }
}