agent_client_protocol_schema/
rpc.rs

1use std::sync::Arc;
2
3use derive_more::Display;
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize, de::DeserializeOwned};
6use serde_json::value::RawValue;
7
8use crate::{
9    AGENT_METHOD_NAMES, AgentNotification, AgentRequest, AgentResponse, CLIENT_METHOD_NAMES,
10    ClientNotification, ClientRequest, ClientResponse, Error, ExtNotification, ExtRequest, Result,
11};
12
/// JSON RPC Request Id
///
/// An identifier established by the Client that MUST contain a String, Number, or NULL value if included. If it is not included it is assumed to be a notification. The value SHOULD normally not be Null [1] and Numbers SHOULD NOT contain fractional parts [2]
///
/// The Server MUST reply with the same value in the Response object if included. This member is used to correlate the context between the two objects.
///
/// In the JSON-RPC wording above, "Client" is whichever peer sends the request and
/// "Server" is whichever peer answers it. In ACP, both the agent and the client can
/// send requests, so either one can play either role.
///
/// Numeric ids are stored as `i64`. Fractional or out-of-range numbers match no variant
/// and therefore fail to deserialize.
///
/// [1] The use of Null as a value for the id member in a Request object is discouraged, because this specification uses a value of Null for Responses with an unknown id. Also, because JSON-RPC 1.0 uses an id value of Null for Notifications this could cause confusion in handling.
///
/// [2] Fractional parts may be problematic, since many decimal fractions cannot be represented exactly as binary fractions.
// Variant order matters. The untagged deserializer tries `Null`, then `Number`, then
// `Str`, and the derived `Ord` also sorts variants in this declaration order.
#[derive(
    Debug, PartialEq, Clone, Hash, Eq, Deserialize, Serialize, PartialOrd, Ord, Display, JsonSchema,
)]
// NOTE(review): no variant is a struct, so `deny_unknown_fields` appears to have no
// effect on this untagged enum. Confirm before relying on it.
#[serde(deny_unknown_fields)]
#[serde(untagged)]
#[schemars(inline)]
pub enum RequestId {
    // Renders as the literal text `null` via `Display`.
    #[display("null")]
    #[schemars(title = "null")]
    Null,
    #[schemars(title = "number")]
    Number(i64),
    #[schemars(title = "string")]
    Str(String),
}
37
/// A JSON-RPC message sent from the `Local` side to the `Remote` side of a connection.
///
/// Requests and notifications carry the payload types the remote side receives
/// (`Remote::InRequest` / `Remote::InNotification`). Responses carry the type the local
/// side sends back (`Local::OutResponse`). Wrap a value in [`JsonRpcMessage`] to add the
/// `"jsonrpc": "2.0"` field.
// The enum is `untagged`: it serializes as the bare fields of the active variant.
// Deserialization tries the variants in declaration order (Request, Response,
// Notification), so reordering them can change how ambiguous input is parsed.
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
#[serde(untagged)]
#[schemars(inline)]
pub enum OutgoingMessage<Local: Side, Remote: Side> {
    /// A request that expects a response. `params` is omitted from the JSON when `None`.
    Request {
        id: RequestId,
        method: Arc<str>,
        #[serde(skip_serializing_if = "Option::is_none")]
        params: Option<Remote::InRequest>,
    },
    /// A reply to a request received from the remote side, echoing that request's `id`.
    Response {
        id: RequestId,
        // Flattened, so the outcome appears beside `id` as either a `result` or an
        // `error` key.
        #[serde(flatten)]
        result: ResponseResult<Local::OutResponse>,
    },
    /// A one-way message without an `id`. `params` is omitted from the JSON when `None`.
    Notification {
        method: Arc<str>,
        #[serde(skip_serializing_if = "Option::is_none")]
        params: Option<Remote::InNotification>,
    },
}
59
/// The JSON-RPC protocol version marker carried in every message's `jsonrpc` field.
///
/// `V2` is the only variant, so deserialization rejects any value other than `"2.0"`.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[schemars(inline)]
enum JsonRpcVersion {
    #[serde(rename = "2.0")]
    V2,
}
66
/// A message (request, response, or notification) with `"jsonrpc": "2.0"` specified as
/// [required by JSON-RPC 2.0 Specification][1].
///
/// [1]: https://www.jsonrpc.org/specification#compatibility
// NOTE(review): both fields are private and this module defines no accessor. A value
// obtained by deserializing cannot be unwrapped from here, so confirm whether one is
// needed.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[schemars(inline)]
pub struct JsonRpcMessage<M> {
    // Always serialized as `"2.0"`. The single-variant type rejects other values on input.
    jsonrpc: JsonRpcVersion,
    // Flattened, so the inner message's fields sit beside `jsonrpc` in one JSON object.
    #[serde(flatten)]
    message: M,
}
78
79impl<M> JsonRpcMessage<M> {
80    /// Wraps the provided [`OutgoingMessage`] or [`IncomingMessage`] into a versioned
81    /// [`JsonRpcMessage`].
82    #[must_use]
83    pub fn wrap(message: M) -> Self {
84        Self {
85            jsonrpc: JsonRpcVersion::V2,
86            message,
87        }
88    }
89}
90
/// The outcome of handling a request: either a successful payload or an [`Error`].
///
/// The enum is externally tagged with `snake_case` names, so it serializes as a
/// single-key object: `{"result": ...}` or `{"error": ...}`.
#[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum ResponseResult<Res> {
    /// Successful payload, serialized under the `result` key.
    Result(Res),
    /// Failure, serialized under the `error` key.
    Error(Error),
}
97
98impl<T> From<Result<T>> for ResponseResult<T> {
99    fn from(result: Result<T>) -> Self {
100        match result {
101            Ok(value) => ResponseResult::Result(value),
102            Err(error) => ResponseResult::Error(error),
103        }
104    }
105}
106
/// One end of an ACP connection.
///
/// A `Side` names the message types that end receives and sends. It also knows how to
/// turn a raw JSON-RPC `method` + `params` pair into a typed incoming message. This
/// module implements it for [`ClientSide`] and [`AgentSide`].
pub trait Side: Clone {
    /// Requests this side receives from the other side.
    type InRequest: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
    /// Notifications this side receives from the other side.
    type InNotification: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
    /// Responses this side sends back to the other side.
    type OutResponse: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;

    /// Decodes an incoming request for `method` from its raw JSON `params`.
    ///
    /// # Errors
    ///
    /// Implementation-defined. The implementations in this module fail when `params` is
    /// absent, when `params` does not deserialize, or when `method` is unknown.
    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<Self::InRequest>;

    /// Decodes an incoming notification for `method` from its raw JSON `params`.
    ///
    /// # Errors
    ///
    /// Implementation-defined, following the same pattern as [`Side::decode_request`].
    fn decode_notification(method: &str, params: Option<&RawValue>)
    -> Result<Self::InNotification>;
}
117
118/// Marker type representing the client side of an ACP connection.
119///
120/// This type is used by the RPC layer to determine which messages
121/// are incoming vs outgoing from the client's perspective.
122///
123/// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
124#[derive(Clone, JsonSchema)]
125pub struct ClientSide;
126
127impl Side for ClientSide {
128    type InRequest = AgentRequest;
129    type InNotification = AgentNotification;
130    type OutResponse = ClientResponse;
131
132    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest> {
133        let params = params.ok_or_else(Error::invalid_params)?;
134
135        match method {
136            m if m == CLIENT_METHOD_NAMES.session_request_permission => {
137                serde_json::from_str(params.get())
138                    .map(AgentRequest::RequestPermissionRequest)
139                    .map_err(Into::into)
140            }
141            m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
142                .map(AgentRequest::WriteTextFileRequest)
143                .map_err(Into::into),
144            m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
145                .map(AgentRequest::ReadTextFileRequest)
146                .map_err(Into::into),
147            m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
148                .map(AgentRequest::CreateTerminalRequest)
149                .map_err(Into::into),
150            m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
151                .map(AgentRequest::TerminalOutputRequest)
152                .map_err(Into::into),
153            m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
154                .map(AgentRequest::KillTerminalCommandRequest)
155                .map_err(Into::into),
156            m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
157                .map(AgentRequest::ReleaseTerminalRequest)
158                .map_err(Into::into),
159            m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
160                serde_json::from_str(params.get())
161                    .map(AgentRequest::WaitForTerminalExitRequest)
162                    .map_err(Into::into)
163            }
164            _ => {
165                if let Some(custom_method) = method.strip_prefix('_') {
166                    Ok(AgentRequest::ExtMethodRequest(ExtRequest {
167                        method: custom_method.into(),
168                        params: params.to_owned().into(),
169                    }))
170                } else {
171                    Err(Error::method_not_found())
172                }
173            }
174        }
175    }
176
177    fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<AgentNotification> {
178        let params = params.ok_or_else(Error::invalid_params)?;
179
180        match method {
181            m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
182                .map(AgentNotification::SessionNotification)
183                .map_err(Into::into),
184            _ => {
185                if let Some(custom_method) = method.strip_prefix('_') {
186                    Ok(AgentNotification::ExtNotification(ExtNotification {
187                        method: custom_method.into(),
188                        params: RawValue::from_string(params.get().to_string())?.into(),
189                    }))
190                } else {
191                    Err(Error::method_not_found())
192                }
193            }
194        }
195    }
196}
197
198/// Marker type representing the agent side of an ACP connection.
199///
200/// This type is used by the RPC layer to determine which messages
201/// are incoming vs outgoing from the agent's perspective.
202///
203/// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
204#[derive(Clone, JsonSchema)]
205pub struct AgentSide;
206
207impl Side for AgentSide {
208    type InRequest = ClientRequest;
209    type InNotification = ClientNotification;
210    type OutResponse = AgentResponse;
211
212    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest> {
213        let params = params.ok_or_else(Error::invalid_params)?;
214
215        match method {
216            m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
217                .map(ClientRequest::InitializeRequest)
218                .map_err(Into::into),
219            m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
220                .map(ClientRequest::AuthenticateRequest)
221                .map_err(Into::into),
222            m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
223                .map(ClientRequest::NewSessionRequest)
224                .map_err(Into::into),
225            m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
226                .map(ClientRequest::LoadSessionRequest)
227                .map_err(Into::into),
228            m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
229                .map(ClientRequest::SetSessionModeRequest)
230                .map_err(Into::into),
231            #[cfg(feature = "unstable")]
232            m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
233                .map(ClientRequest::SetSessionModelRequest)
234                .map_err(Into::into),
235            m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
236                .map(ClientRequest::PromptRequest)
237                .map_err(Into::into),
238            _ => {
239                if let Some(custom_method) = method.strip_prefix('_') {
240                    Ok(ClientRequest::ExtMethodRequest(ExtRequest {
241                        method: custom_method.into(),
242                        params: params.to_owned().into(),
243                    }))
244                } else {
245                    Err(Error::method_not_found())
246                }
247            }
248        }
249    }
250
251    fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<ClientNotification> {
252        let params = params.ok_or_else(Error::invalid_params)?;
253
254        match method {
255            m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
256                .map(ClientNotification::CancelNotification)
257                .map_err(Into::into),
258            _ => {
259                if let Some(custom_method) = method.strip_prefix('_') {
260                    Ok(ClientNotification::ExtNotification(ExtNotification {
261                        method: custom_method.into(),
262                        params: RawValue::from_string(params.get().to_string())?.into(),
263                    }))
264                } else {
265                    Err(Error::method_not_found())
266                }
267            }
268        }
269    }
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    use serde_json::{Number, Value};

    /// `(wire JSON, typed id)` pairs shared by the serialization round-trip tests.
    fn id_fixtures() -> Vec<(Value, RequestId)> {
        vec![
            (Value::Null, RequestId::Null),
            (
                Value::Number(Number::from_u128(1).unwrap()),
                RequestId::Number(1),
            ),
            (
                Value::Number(Number::from_i128(-1).unwrap()),
                RequestId::Number(-1),
            ),
            (
                Value::String("id".to_owned()),
                RequestId::Str("id".to_owned()),
            ),
        ]
    }

    #[test]
    fn id_deserialization() {
        for (wire, expected) in id_fixtures() {
            let parsed = serde_json::from_value::<RequestId>(wire).unwrap();
            assert_eq!(parsed, expected);
        }
    }

    #[test]
    fn id_serialization() {
        for (expected_wire, typed) in id_fixtures() {
            let encoded = serde_json::to_value(typed).unwrap();
            assert_eq!(encoded, expected_wire);
        }
    }

    #[test]
    fn id_display() {
        let cases = [
            (RequestId::Null, "null"),
            (RequestId::Number(1), "1"),
            (RequestId::Number(-1), "-1"),
            (RequestId::Str("id".to_owned()), "id"),
        ];
        for (typed, rendered) in cases {
            assert_eq!(typed.to_string(), rendered);
        }
    }
}
325
#[test]
fn test_notification_wire_format() {
    use super::*;

    use serde_json::{Value, json};

    // Serializes `message` inside the `"jsonrpc": "2.0"` envelope.
    fn to_wire<M: Serialize>(message: M) -> Value {
        serde_json::to_value(JsonRpcMessage::wrap(message)).unwrap()
    }

    // Client -> agent: a cancel notification.
    let cancel = ClientNotification::CancelNotification(CancelNotification {
        session_id: SessionId("test-123".into()),
        meta: None,
    });
    let expected_cancel = json!({
        "jsonrpc": "2.0",
        "method": "cancel",
        "params": {
            "sessionId": "test-123"
        }
    });
    let actual_cancel = to_wire(OutgoingMessage::<ClientSide, AgentSide>::Notification {
        method: "cancel".into(),
        params: Some(cancel),
    });
    assert_eq!(actual_cancel, expected_cancel);

    // Agent -> client: a streamed agent message chunk.
    let chunk = SessionNotification {
        session_id: SessionId("test-456".into()),
        update: SessionUpdate::AgentMessageChunk(ContentChunk {
            content: ContentBlock::Text(TextContent {
                annotations: None,
                text: "Hello".to_string(),
                meta: None,
            }),
            meta: None,
        }),
        meta: None,
    };
    let expected_chunk = json!({
        "jsonrpc": "2.0",
        "method": "sessionUpdate",
        "params": {
            "sessionId": "test-456",
            "update": {
                "sessionUpdate": "agent_message_chunk",
                "content": {
                    "type": "text",
                    "text": "Hello"
                }
            }
        }
    });
    let actual_chunk = to_wire(OutgoingMessage::<AgentSide, ClientSide>::Notification {
        method: "sessionUpdate".into(),
        params: Some(AgentNotification::SessionNotification(chunk)),
    });
    assert_eq!(actual_chunk, expected_chunk);
}
390        })
391    );
392}