agent_client_protocol_schema/
rpc.rs1use std::sync::Arc;
2
3use derive_more::Display;
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize, de::DeserializeOwned};
6use serde_json::value::RawValue;
7
8use crate::{
9 AGENT_METHOD_NAMES, AgentNotification, AgentRequest, AgentResponse, CLIENT_METHOD_NAMES,
10 ClientNotification, ClientRequest, ClientResponse, Error, ExtNotification, ExtRequest, Result,
11};
12
13#[derive(
23 Debug, PartialEq, Clone, Hash, Eq, Deserialize, Serialize, PartialOrd, Ord, Display, JsonSchema,
24)]
25#[serde(deny_unknown_fields)]
26#[serde(untagged)]
27#[schemars(inline)]
28pub enum RequestId {
29 #[display("null")]
30 #[schemars(title = "null")]
31 Null,
32 #[schemars(title = "number")]
33 Number(i64),
34 #[schemars(title = "string")]
35 Str(String),
36}
37
38#[derive(Serialize, Deserialize, Clone, JsonSchema)]
39#[serde(untagged)]
40#[schemars(inline)]
41pub enum OutgoingMessage<Local: Side, Remote: Side> {
42 Request {
43 id: RequestId,
44 method: Arc<str>,
45 #[serde(skip_serializing_if = "Option::is_none")]
46 params: Option<Remote::InRequest>,
47 },
48 Response {
49 id: RequestId,
50 #[serde(flatten)]
51 result: ResponseResult<Local::OutResponse>,
52 },
53 Notification {
54 method: Arc<str>,
55 #[serde(skip_serializing_if = "Option::is_none")]
56 params: Option<Remote::InNotification>,
57 },
58}
59
60#[derive(Debug, Serialize, Deserialize, JsonSchema)]
61#[schemars(inline)]
62enum JsonRpcVersion {
63 #[serde(rename = "2.0")]
64 V2,
65}
66
67#[derive(Debug, Serialize, Deserialize, JsonSchema)]
72#[schemars(inline)]
73pub struct JsonRpcMessage<M> {
74 jsonrpc: JsonRpcVersion,
75 #[serde(flatten)]
76 message: M,
77}
78
79impl<M> JsonRpcMessage<M> {
80 #[must_use]
83 pub fn wrap(message: M) -> Self {
84 Self {
85 jsonrpc: JsonRpcVersion::V2,
86 message,
87 }
88 }
89}
90
91#[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)]
92#[serde(rename_all = "snake_case")]
93pub enum ResponseResult<Res> {
94 Result(Res),
95 Error(Error),
96}
97
98impl<T> From<Result<T>> for ResponseResult<T> {
99 fn from(result: Result<T>) -> Self {
100 match result {
101 Ok(value) => ResponseResult::Result(value),
102 Err(error) => ResponseResult::Error(error),
103 }
104 }
105}
106
107pub trait Side: Clone {
108 type InRequest: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
109 type InNotification: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
110 type OutResponse: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
111
112 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<Self::InRequest>;
113
114 fn decode_notification(method: &str, params: Option<&RawValue>)
115 -> Result<Self::InNotification>;
116}
117
118#[derive(Clone, JsonSchema)]
125pub struct ClientSide;
126
127impl Side for ClientSide {
128 type InRequest = AgentRequest;
129 type InNotification = AgentNotification;
130 type OutResponse = ClientResponse;
131
132 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest> {
133 let params = params.ok_or_else(Error::invalid_params)?;
134
135 match method {
136 m if m == CLIENT_METHOD_NAMES.session_request_permission => {
137 serde_json::from_str(params.get())
138 .map(AgentRequest::RequestPermissionRequest)
139 .map_err(Into::into)
140 }
141 m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
142 .map(AgentRequest::WriteTextFileRequest)
143 .map_err(Into::into),
144 m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
145 .map(AgentRequest::ReadTextFileRequest)
146 .map_err(Into::into),
147 m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
148 .map(AgentRequest::CreateTerminalRequest)
149 .map_err(Into::into),
150 m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
151 .map(AgentRequest::TerminalOutputRequest)
152 .map_err(Into::into),
153 m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
154 .map(AgentRequest::KillTerminalCommandRequest)
155 .map_err(Into::into),
156 m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
157 .map(AgentRequest::ReleaseTerminalRequest)
158 .map_err(Into::into),
159 m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
160 serde_json::from_str(params.get())
161 .map(AgentRequest::WaitForTerminalExitRequest)
162 .map_err(Into::into)
163 }
164 _ => {
165 if let Some(custom_method) = method.strip_prefix('_') {
166 Ok(AgentRequest::ExtMethodRequest(ExtRequest {
167 method: custom_method.into(),
168 params: params.to_owned().into(),
169 }))
170 } else {
171 Err(Error::method_not_found())
172 }
173 }
174 }
175 }
176
177 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<AgentNotification> {
178 let params = params.ok_or_else(Error::invalid_params)?;
179
180 match method {
181 m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
182 .map(AgentNotification::SessionNotification)
183 .map_err(Into::into),
184 _ => {
185 if let Some(custom_method) = method.strip_prefix('_') {
186 Ok(AgentNotification::ExtNotification(ExtNotification {
187 method: custom_method.into(),
188 params: RawValue::from_string(params.get().to_string())?.into(),
189 }))
190 } else {
191 Err(Error::method_not_found())
192 }
193 }
194 }
195 }
196}
197
198#[derive(Clone, JsonSchema)]
205pub struct AgentSide;
206
207impl Side for AgentSide {
208 type InRequest = ClientRequest;
209 type InNotification = ClientNotification;
210 type OutResponse = AgentResponse;
211
212 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest> {
213 let params = params.ok_or_else(Error::invalid_params)?;
214
215 match method {
216 m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
217 .map(ClientRequest::InitializeRequest)
218 .map_err(Into::into),
219 m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
220 .map(ClientRequest::AuthenticateRequest)
221 .map_err(Into::into),
222 m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
223 .map(ClientRequest::NewSessionRequest)
224 .map_err(Into::into),
225 m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
226 .map(ClientRequest::LoadSessionRequest)
227 .map_err(Into::into),
228 m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
229 .map(ClientRequest::SetSessionModeRequest)
230 .map_err(Into::into),
231 #[cfg(feature = "unstable")]
232 m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
233 .map(ClientRequest::SetSessionModelRequest)
234 .map_err(Into::into),
235 m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
236 .map(ClientRequest::PromptRequest)
237 .map_err(Into::into),
238 _ => {
239 if let Some(custom_method) = method.strip_prefix('_') {
240 Ok(ClientRequest::ExtMethodRequest(ExtRequest {
241 method: custom_method.into(),
242 params: params.to_owned().into(),
243 }))
244 } else {
245 Err(Error::method_not_found())
246 }
247 }
248 }
249 }
250
251 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<ClientNotification> {
252 let params = params.ok_or_else(Error::invalid_params)?;
253
254 match method {
255 m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
256 .map(ClientNotification::CancelNotification)
257 .map_err(Into::into),
258 _ => {
259 if let Some(custom_method) = method.strip_prefix('_') {
260 Ok(ClientNotification::ExtNotification(ExtNotification {
261 method: custom_method.into(),
262 params: RawValue::from_string(params.get().to_string())?.into(),
263 }))
264 } else {
265 Err(Error::method_not_found())
266 }
267 }
268 }
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    use serde_json::{Value, json};

    /// Every `RequestId` shape, paired with its JSON encoding and its `Display` output.
    fn samples() -> [(RequestId, Value, &'static str); 4] {
        [
            (RequestId::Null, Value::Null, "null"),
            (RequestId::Number(1), json!(1), "1"),
            (RequestId::Number(-1), json!(-1), "-1"),
            (RequestId::Str("id".to_owned()), json!("id"), "id"),
        ]
    }

    #[test]
    fn id_deserialization() {
        for (expected, wire, _) in samples() {
            let decoded: RequestId = serde_json::from_value(wire).unwrap();
            assert_eq!(decoded, expected);
        }
    }

    #[test]
    fn id_serialization() {
        for (id, expected, _) in samples() {
            assert_eq!(serde_json::to_value(id).unwrap(), expected);
        }
    }

    #[test]
    fn id_display() {
        for (id, _, expected) in samples() {
            assert_eq!(id.to_string(), expected);
        }
    }
}
325
/// Wire-format checks for outgoing notifications wrapped in a JSON-RPC envelope.
///
/// Lives in its own `#[cfg(test)]` module so that `use super::*` refers to this module
/// (`JsonRpcMessage`, `OutgoingMessage`, the `Side` markers). When the test sat directly at
/// module level, `super` referred to this module's parent instead.
#[cfg(test)]
mod wire_format_tests {
    use super::*;
    use crate::{
        CancelNotification, ContentBlock, ContentChunk, SessionId, SessionNotification,
        SessionUpdate, TextContent,
    };

    use serde_json::{Value, json};

    #[test]
    fn test_notification_wire_format() {
        // Client → agent: `None` meta fields must be omitted from the params object.
        let outgoing_msg =
            JsonRpcMessage::wrap(OutgoingMessage::<ClientSide, AgentSide>::Notification {
                method: "cancel".into(),
                params: Some(ClientNotification::CancelNotification(CancelNotification {
                    session_id: SessionId("test-123".into()),
                    meta: None,
                })),
            });

        let serialized: Value = serde_json::to_value(&outgoing_msg).unwrap();
        assert_eq!(
            serialized,
            json!({
                "jsonrpc": "2.0",
                "method": "cancel",
                "params": {
                    "sessionId": "test-123"
                },
            })
        );

        // Agent → client: nested session update with tagged content.
        let outgoing_msg =
            JsonRpcMessage::wrap(OutgoingMessage::<AgentSide, ClientSide>::Notification {
                method: "sessionUpdate".into(),
                params: Some(AgentNotification::SessionNotification(
                    SessionNotification {
                        session_id: SessionId("test-456".into()),
                        update: SessionUpdate::AgentMessageChunk(ContentChunk {
                            content: ContentBlock::Text(TextContent {
                                annotations: None,
                                text: "Hello".to_string(),
                                meta: None,
                            }),
                            meta: None,
                        }),
                        meta: None,
                    },
                )),
            });

        let serialized: Value = serde_json::to_value(&outgoing_msg).unwrap();
        assert_eq!(
            serialized,
            json!({
                "jsonrpc": "2.0",
                "method": "sessionUpdate",
                "params": {
                    "sessionId": "test-456",
                    "update": {
                        "sessionUpdate": "agent_message_chunk",
                        "content": {
                            "type": "text",
                            "text": "Hello"
                        }
                    }
                }
            })
        );
    }
}