agent_client_protocol_schema/
rpc.rs

1use std::sync::Arc;
2
3use derive_more::Display;
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize, de::DeserializeOwned};
6use serde_json::value::RawValue;
7
8use crate::{
9    AGENT_METHOD_NAMES, AgentNotification, AgentRequest, AgentResponse, CLIENT_METHOD_NAMES,
10    ClientNotification, ClientRequest, ClientResponse, Error, ExtNotification, ExtRequest, Result,
11};
12
13/// JSON RPC Request Id
14///
15/// An identifier established by the Client that MUST contain a String, Number, or NULL value if included. If it is not included it is assumed to be a notification. The value SHOULD normally not be Null [1] and Numbers SHOULD NOT contain fractional parts [2]
16///
17/// The Server MUST reply with the same value in the Response object if included. This member is used to correlate the context between the two objects.
18///
19/// [1] The use of Null as a value for the id member in a Request object is discouraged, because this specification uses a value of Null for Responses with an unknown id. Also, because JSON-RPC 1.0 uses an id value of Null for Notifications this could cause confusion in handling.
20///
21/// [2] Fractional parts may be problematic, since many decimal fractions cannot be represented exactly as binary fractions.
22#[derive(
23    Debug, PartialEq, Clone, Hash, Eq, Deserialize, Serialize, PartialOrd, Ord, Display, JsonSchema,
24)]
25#[serde(untagged)]
26#[schemars(inline)]
27#[allow(
28    clippy::exhaustive_enums,
29    reason = "This comes from the JSON-RPC specification itself"
30)]
31pub enum RequestId {
32    #[display("null")]
33    Null,
34    Number(i64),
35    Str(String),
36}
37
38#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
39#[serde(untagged)]
40#[schemars(inline)]
41#[non_exhaustive]
42pub enum OutgoingMessage<Local: Side, Remote: Side> {
43    Request {
44        id: RequestId,
45        method: Arc<str>,
46        #[serde(skip_serializing_if = "Option::is_none")]
47        params: Option<Remote::InRequest>,
48    },
49    Response {
50        id: RequestId,
51        #[serde(flatten)]
52        result: ResponseResult<Local::OutResponse>,
53    },
54    Notification {
55        method: Arc<str>,
56        #[serde(skip_serializing_if = "Option::is_none")]
57        params: Option<Remote::InNotification>,
58    },
59}
60
61#[derive(Debug, Serialize, Deserialize, JsonSchema)]
62#[schemars(inline)]
63enum JsonRpcVersion {
64    #[serde(rename = "2.0")]
65    V2,
66}
67
68/// A message (request, response, or notification) with `"jsonrpc": "2.0"` specified as
69/// [required by JSON-RPC 2.0 Specification][1].
70///
71/// [1]: https://www.jsonrpc.org/specification#compatibility
72#[derive(Debug, Serialize, Deserialize, JsonSchema)]
73#[schemars(inline)]
74pub struct JsonRpcMessage<M> {
75    jsonrpc: JsonRpcVersion,
76    #[serde(flatten)]
77    message: M,
78}
79
80impl<M> JsonRpcMessage<M> {
81    /// Wraps the provided [`OutgoingMessage`] or [`IncomingMessage`] into a versioned
82    /// [`JsonRpcMessage`].
83    #[must_use]
84    pub fn wrap(message: M) -> Self {
85        Self {
86            jsonrpc: JsonRpcVersion::V2,
87            message,
88        }
89    }
90}
91
92#[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)]
93#[serde(rename_all = "snake_case")]
94#[non_exhaustive]
95pub enum ResponseResult<Res> {
96    Result(Res),
97    Error(Error),
98}
99
100impl<T> From<Result<T>> for ResponseResult<T> {
101    fn from(result: Result<T>) -> Self {
102        match result {
103            Ok(value) => ResponseResult::Result(value),
104            Err(error) => ResponseResult::Error(error),
105        }
106    }
107}
108
109pub trait Side: Clone {
110    type InRequest: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
111    type InNotification: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
112    type OutResponse: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
113
114    /// Decode a request for a given method. This will encapsulate the knowledge of mapping which
115    /// serialization struct to use for each method.
116    ///
117    /// # Errors
118    ///
119    /// This function will return an error if the method is not recognized or if the parameters
120    /// cannot be deserialized into the expected type.
121    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<Self::InRequest>;
122
123    /// Decode a notification for a given method. This will encapsulate the knowledge of mapping which
124    /// serialization struct to use for each method.
125    ///
126    /// # Errors
127    ///
128    /// This function will return an error if the method is not recognized or if the parameters
129    /// cannot be deserialized into the expected type.
130    fn decode_notification(method: &str, params: Option<&RawValue>)
131    -> Result<Self::InNotification>;
132}
133
134/// Marker type representing the client side of an ACP connection.
135///
136/// This type is used by the RPC layer to determine which messages
137/// are incoming vs outgoing from the client's perspective.
138///
139/// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
140#[derive(Clone, Default, Debug, JsonSchema)]
141#[non_exhaustive]
142pub struct ClientSide;
143
144impl Side for ClientSide {
145    type InRequest = AgentRequest;
146    type InNotification = AgentNotification;
147    type OutResponse = ClientResponse;
148
149    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest> {
150        let params = params.ok_or_else(Error::invalid_params)?;
151
152        match method {
153            m if m == CLIENT_METHOD_NAMES.session_request_permission => {
154                serde_json::from_str(params.get())
155                    .map(AgentRequest::RequestPermissionRequest)
156                    .map_err(Into::into)
157            }
158            m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
159                .map(AgentRequest::WriteTextFileRequest)
160                .map_err(Into::into),
161            m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
162                .map(AgentRequest::ReadTextFileRequest)
163                .map_err(Into::into),
164            m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
165                .map(AgentRequest::CreateTerminalRequest)
166                .map_err(Into::into),
167            m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
168                .map(AgentRequest::TerminalOutputRequest)
169                .map_err(Into::into),
170            m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
171                .map(AgentRequest::KillTerminalCommandRequest)
172                .map_err(Into::into),
173            m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
174                .map(AgentRequest::ReleaseTerminalRequest)
175                .map_err(Into::into),
176            m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
177                serde_json::from_str(params.get())
178                    .map(AgentRequest::WaitForTerminalExitRequest)
179                    .map_err(Into::into)
180            }
181            _ => {
182                if let Some(custom_method) = method.strip_prefix('_') {
183                    Ok(AgentRequest::ExtMethodRequest(ExtRequest {
184                        method: custom_method.into(),
185                        params: params.to_owned().into(),
186                    }))
187                } else {
188                    Err(Error::method_not_found())
189                }
190            }
191        }
192    }
193
194    fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<AgentNotification> {
195        let params = params.ok_or_else(Error::invalid_params)?;
196
197        match method {
198            m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
199                .map(AgentNotification::SessionNotification)
200                .map_err(Into::into),
201            _ => {
202                if let Some(custom_method) = method.strip_prefix('_') {
203                    Ok(AgentNotification::ExtNotification(ExtNotification {
204                        method: custom_method.into(),
205                        params: RawValue::from_string(params.get().to_string())?.into(),
206                    }))
207                } else {
208                    Err(Error::method_not_found())
209                }
210            }
211        }
212    }
213}
214
215/// Marker type representing the agent side of an ACP connection.
216///
217/// This type is used by the RPC layer to determine which messages
218/// are incoming vs outgoing from the agent's perspective.
219///
220/// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
221#[derive(Clone, Default, Debug, JsonSchema)]
222#[non_exhaustive]
223pub struct AgentSide;
224
225impl Side for AgentSide {
226    type InRequest = ClientRequest;
227    type InNotification = ClientNotification;
228    type OutResponse = AgentResponse;
229
230    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest> {
231        let params = params.ok_or_else(Error::invalid_params)?;
232
233        match method {
234            m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
235                .map(ClientRequest::InitializeRequest)
236                .map_err(Into::into),
237            m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
238                .map(ClientRequest::AuthenticateRequest)
239                .map_err(Into::into),
240            m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
241                .map(ClientRequest::NewSessionRequest)
242                .map_err(Into::into),
243            m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
244                .map(ClientRequest::LoadSessionRequest)
245                .map_err(Into::into),
246            #[cfg(feature = "unstable_session_list")]
247            m if m == AGENT_METHOD_NAMES.session_list => serde_json::from_str(params.get())
248                .map(ClientRequest::ListSessionsRequest)
249                .map_err(Into::into),
250            m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
251                .map(ClientRequest::SetSessionModeRequest)
252                .map_err(Into::into),
253            #[cfg(feature = "unstable_session_model")]
254            m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
255                .map(ClientRequest::SetSessionModelRequest)
256                .map_err(Into::into),
257            m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
258                .map(ClientRequest::PromptRequest)
259                .map_err(Into::into),
260            _ => {
261                if let Some(custom_method) = method.strip_prefix('_') {
262                    Ok(ClientRequest::ExtMethodRequest(ExtRequest {
263                        method: custom_method.into(),
264                        params: params.to_owned().into(),
265                    }))
266                } else {
267                    Err(Error::method_not_found())
268                }
269            }
270        }
271    }
272
273    fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<ClientNotification> {
274        let params = params.ok_or_else(Error::invalid_params)?;
275
276        match method {
277            m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
278                .map(ClientNotification::CancelNotification)
279                .map_err(Into::into),
280            _ => {
281                if let Some(custom_method) = method.strip_prefix('_') {
282                    Ok(ClientNotification::ExtNotification(ExtNotification {
283                        method: custom_method.into(),
284                        params: RawValue::from_string(params.get().to_string())?.into(),
285                    }))
286                } else {
287                    Err(Error::method_not_found())
288                }
289            }
290        }
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    use serde_json::{Number, Value};

    /// Each JSON id value paired with the `RequestId` it must round-trip with.
    fn id_cases() -> [(Value, RequestId); 4] {
        [
            (Value::Null, RequestId::Null),
            (
                Value::Number(Number::from_u128(1).unwrap()),
                RequestId::Number(1),
            ),
            (
                Value::Number(Number::from_i128(-1).unwrap()),
                RequestId::Number(-1),
            ),
            (
                Value::String("id".to_owned()),
                RequestId::Str("id".to_owned()),
            ),
        ]
    }

    #[test]
    fn id_deserialization() {
        for (json, expected) in id_cases() {
            let decoded = serde_json::from_value::<RequestId>(json).unwrap();
            assert_eq!(decoded, expected);
        }
    }

    #[test]
    fn id_serialization() {
        for (expected, id) in id_cases() {
            let encoded = serde_json::to_value(id).unwrap();
            assert_eq!(encoded, expected);
        }
    }

    #[test]
    fn id_display() {
        let cases = [
            (RequestId::Null, "null"),
            (RequestId::Number(1), "1"),
            (RequestId::Number(-1), "-1"),
            (RequestId::Str("id".to_owned()), "id"),
        ];
        for (id, rendered) in cases {
            assert_eq!(id.to_string(), rendered);
        }
    }
}
347
/// Pins the exact JSON produced for notifications in both directions. Uses the method names
/// the decoders actually recognize, and checks that each serialized message decodes back
/// through the receiving side's `decode_notification`.
#[test]
fn test_notification_wire_format() {
    use super::*;

    use serde_json::{Value, json};

    // Test client -> agent notification wire format
    let cancel_method = AGENT_METHOD_NAMES.session_cancel;
    let outgoing_msg =
        JsonRpcMessage::wrap(OutgoingMessage::<ClientSide, AgentSide>::Notification {
            method: cancel_method.into(),
            params: Some(ClientNotification::CancelNotification(CancelNotification {
                session_id: SessionId("test-123".into()),
                meta: None,
            })),
        });

    let serialized: Value = serde_json::to_value(&outgoing_msg).unwrap();
    assert_eq!(
        serialized,
        json!({
            "jsonrpc": "2.0",
            "method": cancel_method,
            "params": {
                "sessionId": "test-123"
            },
        })
    );

    // The agent must be able to decode what the client just sent.
    let raw_params = serde_json::value::to_raw_value(&serialized["params"]).unwrap();
    let decoded = AgentSide::decode_notification(cancel_method, Some(&*raw_params)).unwrap();
    assert!(matches!(decoded, ClientNotification::CancelNotification(_)));

    // Test agent -> client notification wire format
    let update_method = CLIENT_METHOD_NAMES.session_update;
    let outgoing_msg =
        JsonRpcMessage::wrap(OutgoingMessage::<AgentSide, ClientSide>::Notification {
            method: update_method.into(),
            params: Some(AgentNotification::SessionNotification(
                SessionNotification {
                    session_id: SessionId("test-456".into()),
                    update: SessionUpdate::AgentMessageChunk(ContentChunk {
                        content: ContentBlock::Text(TextContent {
                            annotations: None,
                            text: "Hello".to_string(),
                            meta: None,
                        }),
                        meta: None,
                    }),
                    meta: None,
                },
            )),
        });

    let serialized: Value = serde_json::to_value(&outgoing_msg).unwrap();
    assert_eq!(
        serialized,
        json!({
            "jsonrpc": "2.0",
            "method": update_method,
            "params": {
                "sessionId": "test-456",
                "update": {
                    "sessionUpdate": "agent_message_chunk",
                    "content": {
                        "type": "text",
                        "text": "Hello"
                    }
                }
            }
        })
    );

    // The client must be able to decode what the agent just sent.
    let raw_params = serde_json::value::to_raw_value(&serialized["params"]).unwrap();
    let decoded = ClientSide::decode_notification(update_method, Some(&*raw_params)).unwrap();
    assert!(matches!(decoded, AgentNotification::SessionNotification(_)));
}