agent_client_protocol_schema/
rpc.rs1use std::sync::Arc;
7
8use derive_more::{Display, From};
9use schemars::JsonSchema;
10use serde::{Deserialize, Serialize};
11use serde_with::skip_serializing_none;
12
13#[derive(
23 Debug,
24 PartialEq,
25 Clone,
26 Hash,
27 Eq,
28 Deserialize,
29 Serialize,
30 PartialOrd,
31 Ord,
32 Display,
33 JsonSchema,
34 From,
35)]
36#[serde(untagged)]
37#[allow(
38 clippy::exhaustive_enums,
39 reason = "This comes from the JSON-RPC specification itself"
40)]
41#[from(String, i64)]
42pub enum RequestId {
43 #[display("null")]
45 Null,
46 Number(i64),
48 Str(String),
50}
51
52#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
54#[allow(
55 clippy::exhaustive_structs,
56 reason = "This comes from the JSON-RPC specification itself"
57)]
58#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
59#[skip_serializing_none]
60pub struct Request<Params> {
61 pub id: RequestId,
63 pub method: Arc<str>,
65 pub params: Option<Params>,
67}
68
69#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
71#[allow(
72 clippy::exhaustive_enums,
73 reason = "This comes from the JSON-RPC specification itself"
74)]
75#[serde(untagged)]
76#[schemars(rename = "{Result}", extend("x-docs-ignore" = true))]
77pub enum Response<Result, Error> {
78 Result {
80 id: RequestId,
82 result: Result,
84 },
85 Error {
87 id: RequestId,
89 error: Error,
91 },
92}
93
94impl<R, E> Response<R, E> {
95 #[must_use]
97 pub fn new(id: impl Into<RequestId>, result: std::result::Result<R, E>) -> Self {
98 match result {
99 Ok(result) => Self::Result {
100 id: id.into(),
101 result,
102 },
103 Err(error) => Self::Error {
104 id: id.into(),
105 error,
106 },
107 }
108 }
109}
110
111#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
113#[allow(
114 clippy::exhaustive_structs,
115 reason = "This comes from the JSON-RPC specification itself"
116)]
117#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
118#[skip_serializing_none]
119pub struct Notification<Params> {
120 pub method: Arc<str>,
122 pub params: Option<Params>,
124}
125
126#[derive(Debug, Serialize, Deserialize, JsonSchema)]
127#[schemars(inline)]
128enum JsonRpcVersion {
129 #[serde(rename = "2.0")]
130 V2,
131}
132
133#[derive(Debug, Serialize, Deserialize, JsonSchema)]
138#[schemars(inline)]
139pub struct JsonRpcMessage<M> {
140 jsonrpc: JsonRpcVersion,
141 #[serde(flatten)]
142 message: M,
143}
144
145impl<M> JsonRpcMessage<M> {
146 #[must_use]
148 pub fn wrap(message: M) -> Self {
149 Self {
150 jsonrpc: JsonRpcVersion::V2,
151 message,
152 }
153 }
154
155 #[must_use]
157 pub fn inner(&self) -> &M {
158 &self.message
159 }
160
161 #[must_use]
163 pub fn into_inner(self) -> M {
164 self.message
165 }
166}
167
/// Error returned when a [`JsonRpcBatch`] would contain no messages.
///
/// A JSON-RPC 2.0 batch must hold at least one message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub struct EmptyJsonRpcBatch;

impl std::fmt::Display for EmptyJsonRpcBatch {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("JSON-RPC batch must contain at least one message")
    }
}

impl std::error::Error for EmptyJsonRpcBatch {}
175
176#[derive(Debug, Serialize, JsonSchema)]
178#[schemars(inline)]
179#[serde(transparent)]
180#[allow(
181 clippy::exhaustive_structs,
182 reason = "This comes from the JSON-RPC specification itself"
183)]
184pub struct JsonRpcBatch<M>(#[schemars(length(min = 1))] Vec<JsonRpcMessage<M>>);
185
186impl<M> JsonRpcBatch<M> {
187 pub fn new(messages: Vec<JsonRpcMessage<M>>) -> Result<Self, EmptyJsonRpcBatch> {
196 if messages.is_empty() {
197 Err(EmptyJsonRpcBatch)
198 } else {
199 Ok(Self(messages))
200 }
201 }
202
203 #[must_use]
205 pub fn as_slice(&self) -> &[JsonRpcMessage<M>] {
206 &self.0
207 }
208
209 #[must_use]
211 pub fn into_vec(self) -> Vec<JsonRpcMessage<M>> {
212 self.0
213 }
214}
215
216impl<M> TryFrom<Vec<JsonRpcMessage<M>>> for JsonRpcBatch<M> {
217 type Error = EmptyJsonRpcBatch;
218
219 fn try_from(messages: Vec<JsonRpcMessage<M>>) -> Result<Self, Self::Error> {
220 Self::new(messages)
221 }
222}
223
224impl<'de, M> Deserialize<'de> for JsonRpcBatch<M>
225where
226 M: Deserialize<'de>,
227{
228 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
229 where
230 D: serde::Deserializer<'de>,
231 {
232 let messages = Vec::<JsonRpcMessage<M>>::deserialize(deserializer)?;
233 Self::new(messages).map_err(serde::de::Error::custom)
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{
        AgentNotification, CancelNotification, ClientNotification, ContentBlock, ContentChunk,
        SessionId, SessionNotification, SessionUpdate, TextContent,
    };
    use serde_json::{Number, Value, json};

    /// The `cancel` notification for session `test-123`, shared by several tests.
    fn cancel_notification() -> JsonRpcMessage<Notification<ClientNotification>> {
        JsonRpcMessage::wrap(Notification {
            method: "cancel".into(),
            params: Some(ClientNotification::CancelNotification(CancelNotification {
                session_id: SessionId("test-123".into()),
                meta: None,
            })),
        })
    }

    /// Expected wire form of `cancel_notification()`.
    fn cancel_notification_json() -> Value {
        json!({
            "jsonrpc": "2.0",
            "method": "cancel",
            "params": {
                "sessionId": "test-123"
            },
        })
    }

    #[test]
    fn id_deserialization() {
        let id = serde_json::from_value::<RequestId>(Value::Null).unwrap();
        assert_eq!(id, RequestId::Null);

        let id = serde_json::from_value::<RequestId>(Value::Number(Number::from_u128(1).unwrap()))
            .unwrap();
        assert_eq!(id, RequestId::Number(1));

        let id = serde_json::from_value::<RequestId>(Value::Number(Number::from_i128(-1).unwrap()))
            .unwrap();
        assert_eq!(id, RequestId::Number(-1));

        let id = serde_json::from_value::<RequestId>(Value::String("id".to_owned())).unwrap();
        assert_eq!(id, RequestId::Str("id".to_owned()));
    }

    #[test]
    fn id_serialization() {
        let id = serde_json::to_value(RequestId::Null).unwrap();
        assert_eq!(id, Value::Null);

        let id = serde_json::to_value(RequestId::Number(1)).unwrap();
        assert_eq!(id, Value::Number(Number::from_u128(1).unwrap()));

        let id = serde_json::to_value(RequestId::Number(-1)).unwrap();
        assert_eq!(id, Value::Number(Number::from_i128(-1).unwrap()));

        let id = serde_json::to_value(RequestId::Str("id".to_owned())).unwrap();
        assert_eq!(id, Value::String("id".to_owned()));
    }

    #[test]
    fn id_display() {
        let id = RequestId::Null;
        assert_eq!(id.to_string(), "null");

        let id = RequestId::Number(1);
        assert_eq!(id.to_string(), "1");

        let id = RequestId::Number(-1);
        assert_eq!(id.to_string(), "-1");

        let id = RequestId::Str("id".to_owned());
        assert_eq!(id.to_string(), "id");
    }

    #[test]
    fn response_new_selects_variant_from_result() {
        let ok = Response::<u32, String>::new(RequestId::Number(1), Ok(5));
        assert_eq!(
            serde_json::to_value(&ok).unwrap(),
            json!({ "id": 1, "result": 5 })
        );

        let err = Response::<u32, String>::new(
            RequestId::Str("req".to_owned()),
            Err("boom".to_owned()),
        );
        assert_eq!(
            serde_json::to_value(&err).unwrap(),
            json!({ "id": "req", "error": "boom" })
        );
    }

    #[test]
    fn response_deserialization_falls_back_to_error_variant() {
        let response =
            serde_json::from_value::<Response<u32, String>>(json!({ "id": 2, "error": "bad" }))
                .unwrap();
        match response {
            Response::Error { id, error } => {
                assert_eq!(id, RequestId::Number(2));
                assert_eq!(error, "bad");
            }
            Response::Result { .. } => panic!("expected an error response"),
        }
    }

    #[test]
    fn message_rejects_unsupported_jsonrpc_version() {
        let mut wire = cancel_notification_json();
        wire["jsonrpc"] = json!("1.0");
        let result =
            serde_json::from_value::<JsonRpcMessage<Notification<ClientNotification>>>(wire);
        assert!(result.is_err());
    }

    #[test]
    fn batch_deserialization_requires_at_least_one_message() {
        let err = serde_json::from_value::<JsonRpcBatch<Notification<ClientNotification>>>(
            Value::Array(Vec::new()),
        )
        .unwrap_err();
        assert!(err.to_string().contains("at least one message"));
    }

    #[test]
    fn batch_try_from_rejects_empty_vec() {
        let err =
            JsonRpcBatch::<Notification<ClientNotification>>::try_from(Vec::new()).unwrap_err();
        assert_eq!(err, EmptyJsonRpcBatch);
    }

    #[test]
    fn batch_serialization_round_trips_non_empty_messages() {
        let batch = JsonRpcBatch::new(vec![cancel_notification()]).unwrap();
        let serialized = serde_json::to_value(&batch).unwrap();
        assert_eq!(serialized, Value::Array(vec![cancel_notification_json()]));

        let deserialized =
            serde_json::from_value::<JsonRpcBatch<Notification<ClientNotification>>>(serialized)
                .unwrap();
        assert_eq!(deserialized.as_slice().len(), 1);
        assert_eq!(deserialized.as_slice()[0].inner().method.as_ref(), "cancel");
        assert_eq!(deserialized.into_vec().len(), 1);
    }

    #[test]
    fn notification_wire_format() {
        let serialized: Value = serde_json::to_value(&cancel_notification()).unwrap();
        assert_eq!(serialized, cancel_notification_json());

        let outgoing_msg = JsonRpcMessage::wrap(Notification {
            method: "sessionUpdate".into(),
            params: Some(AgentNotification::SessionNotification(
                SessionNotification {
                    session_id: SessionId("test-456".into()),
                    update: SessionUpdate::AgentMessageChunk(ContentChunk {
                        content: ContentBlock::Text(TextContent {
                            annotations: None,
                            text: "Hello".to_string(),
                            meta: None,
                        }),
                        message_id: None,
                        meta: None,
                    }),
                    meta: None,
                },
            )),
        });

        let serialized: Value = serde_json::to_value(&outgoing_msg).unwrap();
        assert_eq!(
            serialized,
            json!({
                "jsonrpc": "2.0",
                "method": "sessionUpdate",
                "params": {
                    "sessionId": "test-456",
                    "update": {
                        "sessionUpdate": "agent_message_chunk",
                        "content": {
                            "type": "text",
                            "text": "Hello"
                        }
                    }
                }
            })
        );
    }
}