agent_client_protocol_schema/
rpc.rs1use std::sync::Arc;
2
3use derive_more::Display;
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize, de::DeserializeOwned};
6use serde_json::value::RawValue;
7
8use crate::{
9 AGENT_METHOD_NAMES, AgentNotification, AgentRequest, AgentResponse, CLIENT_METHOD_NAMES,
10 ClientNotification, ClientRequest, ClientResponse, Error, ExtNotification, ExtRequest, Result,
11};
12
13#[derive(
23 Debug, PartialEq, Clone, Hash, Eq, Deserialize, Serialize, PartialOrd, Ord, Display, JsonSchema,
24)]
25#[serde(untagged)]
26#[allow(
27 clippy::exhaustive_enums,
28 reason = "This comes from the JSON-RPC specification itself"
29)]
30pub enum RequestId {
31 #[display("null")]
32 Null,
33 Number(i64),
34 Str(String),
35}
36
37#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
38#[allow(
39 clippy::exhaustive_structs,
40 reason = "This comes from the JSON-RPC specification itself"
41)]
42#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
43pub struct Request<Params> {
44 pub id: RequestId,
45 pub method: Arc<str>,
46 #[serde(skip_serializing_if = "Option::is_none")]
47 pub params: Option<Params>,
48}
49
50#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
51#[allow(
52 clippy::exhaustive_enums,
53 reason = "This comes from the JSON-RPC specification itself"
54)]
55#[serde(untagged)]
56#[schemars(rename = "{Result}", extend("x-docs-ignore" = true))]
57pub enum Response<Result> {
58 Result { id: RequestId, result: Result },
59 Error { id: RequestId, error: Error },
60}
61
62impl<R> Response<R> {
63 pub fn new(id: RequestId, result: Result<R>) -> Self {
64 match result {
65 Ok(result) => Self::Result { id, result },
66 Err(error) => Self::Error { id, error },
67 }
68 }
69}
70
71#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
72#[allow(
73 clippy::exhaustive_structs,
74 reason = "This comes from the JSON-RPC specification itself"
75)]
76#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
77pub struct Notification<Params> {
78 pub method: Arc<str>,
79 #[serde(skip_serializing_if = "Option::is_none")]
80 pub params: Option<Params>,
81}
82
83#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
84#[serde(untagged)]
85#[schemars(inline)]
86#[allow(
87 clippy::exhaustive_enums,
88 reason = "This comes from the JSON-RPC specification itself"
89)]
90pub enum OutgoingMessage<Local: Side, Remote: Side> {
91 Request(Request<Remote::InRequest>),
92 Response(Response<Local::OutResponse>),
93 Notification(Notification<Remote::InNotification>),
94}
95
96#[derive(Debug, Serialize, Deserialize, JsonSchema)]
97#[schemars(inline)]
98enum JsonRpcVersion {
99 #[serde(rename = "2.0")]
100 V2,
101}
102
103#[derive(Debug, Serialize, Deserialize, JsonSchema)]
108#[schemars(inline)]
109pub struct JsonRpcMessage<M> {
110 jsonrpc: JsonRpcVersion,
111 #[serde(flatten)]
112 message: M,
113}
114
115impl<M> JsonRpcMessage<M> {
116 #[must_use]
119 pub fn wrap(message: M) -> Self {
120 Self {
121 jsonrpc: JsonRpcVersion::V2,
122 message,
123 }
124 }
125}
126
127pub trait Side: Clone {
128 type InRequest: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
129 type InNotification: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
130 type OutResponse: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
131
132 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<Self::InRequest>;
140
141 fn decode_notification(method: &str, params: Option<&RawValue>)
149 -> Result<Self::InNotification>;
150}
151
152#[derive(Clone, Default, Debug, JsonSchema)]
159#[non_exhaustive]
160pub struct ClientSide;
161
162impl Side for ClientSide {
163 type InRequest = AgentRequest;
164 type InNotification = AgentNotification;
165 type OutResponse = ClientResponse;
166
167 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest> {
168 let params = params.ok_or_else(Error::invalid_params)?;
169
170 match method {
171 m if m == CLIENT_METHOD_NAMES.session_request_permission => {
172 serde_json::from_str(params.get())
173 .map(AgentRequest::RequestPermissionRequest)
174 .map_err(Into::into)
175 }
176 m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
177 .map(AgentRequest::WriteTextFileRequest)
178 .map_err(Into::into),
179 m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
180 .map(AgentRequest::ReadTextFileRequest)
181 .map_err(Into::into),
182 m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
183 .map(AgentRequest::CreateTerminalRequest)
184 .map_err(Into::into),
185 m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
186 .map(AgentRequest::TerminalOutputRequest)
187 .map_err(Into::into),
188 m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
189 .map(AgentRequest::KillTerminalCommandRequest)
190 .map_err(Into::into),
191 m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
192 .map(AgentRequest::ReleaseTerminalRequest)
193 .map_err(Into::into),
194 m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
195 serde_json::from_str(params.get())
196 .map(AgentRequest::WaitForTerminalExitRequest)
197 .map_err(Into::into)
198 }
199 _ => {
200 if let Some(custom_method) = method.strip_prefix('_') {
201 Ok(AgentRequest::ExtMethodRequest(ExtRequest {
202 method: custom_method.into(),
203 params: params.to_owned().into(),
204 }))
205 } else {
206 Err(Error::method_not_found())
207 }
208 }
209 }
210 }
211
212 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<AgentNotification> {
213 let params = params.ok_or_else(Error::invalid_params)?;
214
215 match method {
216 m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
217 .map(AgentNotification::SessionNotification)
218 .map_err(Into::into),
219 _ => {
220 if let Some(custom_method) = method.strip_prefix('_') {
221 Ok(AgentNotification::ExtNotification(ExtNotification {
222 method: custom_method.into(),
223 params: RawValue::from_string(params.get().to_string())?.into(),
224 }))
225 } else {
226 Err(Error::method_not_found())
227 }
228 }
229 }
230 }
231}
232
233#[derive(Clone, Default, Debug, JsonSchema)]
240#[non_exhaustive]
241pub struct AgentSide;
242
243impl Side for AgentSide {
244 type InRequest = ClientRequest;
245 type InNotification = ClientNotification;
246 type OutResponse = AgentResponse;
247
248 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest> {
249 let params = params.ok_or_else(Error::invalid_params)?;
250
251 match method {
252 m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
253 .map(ClientRequest::InitializeRequest)
254 .map_err(Into::into),
255 m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
256 .map(ClientRequest::AuthenticateRequest)
257 .map_err(Into::into),
258 m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
259 .map(ClientRequest::NewSessionRequest)
260 .map_err(Into::into),
261 m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
262 .map(ClientRequest::LoadSessionRequest)
263 .map_err(Into::into),
264 #[cfg(feature = "unstable_session_list")]
265 m if m == AGENT_METHOD_NAMES.session_list => serde_json::from_str(params.get())
266 .map(ClientRequest::ListSessionsRequest)
267 .map_err(Into::into),
268 m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
269 .map(ClientRequest::SetSessionModeRequest)
270 .map_err(Into::into),
271 #[cfg(feature = "unstable_session_model")]
272 m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
273 .map(ClientRequest::SetSessionModelRequest)
274 .map_err(Into::into),
275 m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
276 .map(ClientRequest::PromptRequest)
277 .map_err(Into::into),
278 _ => {
279 if let Some(custom_method) = method.strip_prefix('_') {
280 Ok(ClientRequest::ExtMethodRequest(ExtRequest {
281 method: custom_method.into(),
282 params: params.to_owned().into(),
283 }))
284 } else {
285 Err(Error::method_not_found())
286 }
287 }
288 }
289 }
290
291 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<ClientNotification> {
292 let params = params.ok_or_else(Error::invalid_params)?;
293
294 match method {
295 m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
296 .map(ClientNotification::CancelNotification)
297 .map_err(Into::into),
298 _ => {
299 if let Some(custom_method) = method.strip_prefix('_') {
300 Ok(ClientNotification::ExtNotification(ExtNotification {
301 method: custom_method.into(),
302 params: RawValue::from_string(params.get().to_string())?.into(),
303 }))
304 } else {
305 Err(Error::method_not_found())
306 }
307 }
308 }
309 }
310}
311
#[cfg(test)]
mod tests {
    use super::*;

    use serde_json::{Number, Value};

    /// Each `RequestId` variant paired with its JSON encoding.
    fn id_cases() -> Vec<(RequestId, Value)> {
        vec![
            (RequestId::Null, Value::Null),
            (
                RequestId::Number(1),
                Value::Number(Number::from_u128(1).unwrap()),
            ),
            (
                RequestId::Number(-1),
                Value::Number(Number::from_i128(-1).unwrap()),
            ),
            (
                RequestId::Str("id".to_owned()),
                Value::String("id".to_owned()),
            ),
        ]
    }

    #[test]
    fn id_deserialization() {
        for (expected, json) in id_cases() {
            let parsed = serde_json::from_value::<RequestId>(json).unwrap();
            assert_eq!(parsed, expected);
        }
    }

    #[test]
    fn id_serialization() {
        for (id, expected) in id_cases() {
            let encoded = serde_json::to_value(id).unwrap();
            assert_eq!(encoded, expected);
        }
    }

    #[test]
    fn id_display() {
        let cases = [
            (RequestId::Null, "null"),
            (RequestId::Number(1), "1"),
            (RequestId::Number(-1), "-1"),
            (RequestId::Str("id".to_owned()), "id"),
        ];
        for (id, expected) in cases {
            assert_eq!(id.to_string(), expected);
        }
    }
}
365
#[test]
fn test_notification_wire_format() {
    use super::*;

    use serde_json::{Value, json};

    // Client -> agent: a cancel notification.
    let cancel = ClientNotification::CancelNotification(CancelNotification {
        session_id: SessionId("test-123".into()),
        meta: None,
    });
    let client_to_agent = JsonRpcMessage::wrap(
        OutgoingMessage::<ClientSide, AgentSide>::Notification(Notification {
            method: "cancel".into(),
            params: Some(cancel),
        }),
    );
    let expected_cancel = json!({
        "jsonrpc": "2.0",
        "method": "cancel",
        "params": {
            "sessionId": "test-123"
        },
    });
    let actual_cancel: Value = serde_json::to_value(&client_to_agent).unwrap();
    assert_eq!(actual_cancel, expected_cancel);

    // Agent -> client: a session update streaming a text chunk.
    let chunk = ContentChunk {
        content: ContentBlock::Text(TextContent {
            annotations: None,
            text: "Hello".to_string(),
            meta: None,
        }),
        meta: None,
    };
    let update = AgentNotification::SessionNotification(SessionNotification {
        session_id: SessionId("test-456".into()),
        update: SessionUpdate::AgentMessageChunk(chunk),
        meta: None,
    });
    let agent_to_client = JsonRpcMessage::wrap(
        OutgoingMessage::<AgentSide, ClientSide>::Notification(Notification {
            method: "sessionUpdate".into(),
            params: Some(update),
        }),
    );
    let expected_update = json!({
        "jsonrpc": "2.0",
        "method": "sessionUpdate",
        "params": {
            "sessionId": "test-456",
            "update": {
                "sessionUpdate": "agent_message_chunk",
                "content": {
                    "type": "text",
                    "text": "Hello"
                }
            }
        }
    });
    let actual_update: Value = serde_json::to_value(&agent_to_client).unwrap();
    assert_eq!(actual_update, expected_update);
}