agent_client_protocol_schema/
rpc.rs

1use std::sync::Arc;
2
3use derive_more::{Display, From};
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize, de::DeserializeOwned};
6use serde_json::value::RawValue;
7
8use crate::{
9    AGENT_METHOD_NAMES, AgentNotification, AgentRequest, AgentResponse, CLIENT_METHOD_NAMES,
10    ClientNotification, ClientRequest, ClientResponse, Error, ExtNotification, ExtRequest, Result,
11};
12
13/// JSON RPC Request Id
14///
15/// An identifier established by the Client that MUST contain a String, Number, or NULL value if included. If it is not included it is assumed to be a notification. The value SHOULD normally not be Null [1] and Numbers SHOULD NOT contain fractional parts [2]
16///
17/// The Server MUST reply with the same value in the Response object if included. This member is used to correlate the context between the two objects.
18///
19/// [1] The use of Null as a value for the id member in a Request object is discouraged, because this specification uses a value of Null for Responses with an unknown id. Also, because JSON-RPC 1.0 uses an id value of Null for Notifications this could cause confusion in handling.
20///
21/// [2] Fractional parts may be problematic, since many decimal fractions cannot be represented exactly as binary fractions.
22#[derive(
23    Debug,
24    PartialEq,
25    Clone,
26    Hash,
27    Eq,
28    Deserialize,
29    Serialize,
30    PartialOrd,
31    Ord,
32    Display,
33    JsonSchema,
34    From,
35)]
36#[serde(untagged)]
37#[allow(
38    clippy::exhaustive_enums,
39    reason = "This comes from the JSON-RPC specification itself"
40)]
41#[from(String, i64)]
42pub enum RequestId {
43    #[display("null")]
44    Null,
45    Number(i64),
46    Str(String),
47}
48
49#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
50#[allow(
51    clippy::exhaustive_structs,
52    reason = "This comes from the JSON-RPC specification itself"
53)]
54#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
55pub struct Request<Params> {
56    pub id: RequestId,
57    pub method: Arc<str>,
58    #[serde(skip_serializing_if = "Option::is_none")]
59    pub params: Option<Params>,
60}
61
62#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
63#[allow(
64    clippy::exhaustive_enums,
65    reason = "This comes from the JSON-RPC specification itself"
66)]
67#[serde(untagged)]
68#[schemars(rename = "{Result}", extend("x-docs-ignore" = true))]
69pub enum Response<Result> {
70    Result { id: RequestId, result: Result },
71    Error { id: RequestId, error: Error },
72}
73
74impl<R> Response<R> {
75    pub fn new(id: impl Into<RequestId>, result: Result<R>) -> Self {
76        match result {
77            Ok(result) => Self::Result {
78                id: id.into(),
79                result,
80            },
81            Err(error) => Self::Error {
82                id: id.into(),
83                error,
84            },
85        }
86    }
87}
88
89#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
90#[allow(
91    clippy::exhaustive_structs,
92    reason = "This comes from the JSON-RPC specification itself"
93)]
94#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
95pub struct Notification<Params> {
96    pub method: Arc<str>,
97    #[serde(skip_serializing_if = "Option::is_none")]
98    pub params: Option<Params>,
99}
100
101#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
102#[serde(untagged)]
103#[schemars(inline)]
104#[allow(
105    clippy::exhaustive_enums,
106    reason = "This comes from the JSON-RPC specification itself"
107)]
108pub enum OutgoingMessage<Local: Side, Remote: Side> {
109    Request(Request<Remote::InRequest>),
110    Response(Response<Local::OutResponse>),
111    Notification(Notification<Remote::InNotification>),
112}
113
114#[derive(Debug, Serialize, Deserialize, JsonSchema)]
115#[schemars(inline)]
116enum JsonRpcVersion {
117    #[serde(rename = "2.0")]
118    V2,
119}
120
121/// A message (request, response, or notification) with `"jsonrpc": "2.0"` specified as
122/// [required by JSON-RPC 2.0 Specification][1].
123///
124/// [1]: https://www.jsonrpc.org/specification#compatibility
125#[derive(Debug, Serialize, Deserialize, JsonSchema)]
126#[schemars(inline)]
127pub struct JsonRpcMessage<M> {
128    jsonrpc: JsonRpcVersion,
129    #[serde(flatten)]
130    message: M,
131}
132
133impl<M> JsonRpcMessage<M> {
134    /// Wraps the provided [`OutgoingMessage`] or [`IncomingMessage`] into a versioned
135    /// [`JsonRpcMessage`].
136    #[must_use]
137    pub fn wrap(message: M) -> Self {
138        Self {
139            jsonrpc: JsonRpcVersion::V2,
140            message,
141        }
142    }
143}
144
145pub trait Side: Clone {
146    type InRequest: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
147    type InNotification: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
148    type OutResponse: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
149
150    /// Decode a request for a given method. This will encapsulate the knowledge of mapping which
151    /// serialization struct to use for each method.
152    ///
153    /// # Errors
154    ///
155    /// This function will return an error if the method is not recognized or if the parameters
156    /// cannot be deserialized into the expected type.
157    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<Self::InRequest>;
158
159    /// Decode a notification for a given method. This will encapsulate the knowledge of mapping which
160    /// serialization struct to use for each method.
161    ///
162    /// # Errors
163    ///
164    /// This function will return an error if the method is not recognized or if the parameters
165    /// cannot be deserialized into the expected type.
166    fn decode_notification(method: &str, params: Option<&RawValue>)
167    -> Result<Self::InNotification>;
168}
169
170/// Marker type representing the client side of an ACP connection.
171///
172/// This type is used by the RPC layer to determine which messages
173/// are incoming vs outgoing from the client's perspective.
174///
175/// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
176#[derive(Clone, Default, Debug, JsonSchema)]
177#[non_exhaustive]
178pub struct ClientSide;
179
180impl Side for ClientSide {
181    type InRequest = AgentRequest;
182    type InNotification = AgentNotification;
183    type OutResponse = ClientResponse;
184
185    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest> {
186        let params = params.ok_or_else(Error::invalid_params)?;
187
188        match method {
189            m if m == CLIENT_METHOD_NAMES.session_request_permission => {
190                serde_json::from_str(params.get())
191                    .map(AgentRequest::RequestPermissionRequest)
192                    .map_err(Into::into)
193            }
194            m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
195                .map(AgentRequest::WriteTextFileRequest)
196                .map_err(Into::into),
197            m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
198                .map(AgentRequest::ReadTextFileRequest)
199                .map_err(Into::into),
200            m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
201                .map(AgentRequest::CreateTerminalRequest)
202                .map_err(Into::into),
203            m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
204                .map(AgentRequest::TerminalOutputRequest)
205                .map_err(Into::into),
206            m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
207                .map(AgentRequest::KillTerminalCommandRequest)
208                .map_err(Into::into),
209            m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
210                .map(AgentRequest::ReleaseTerminalRequest)
211                .map_err(Into::into),
212            m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
213                serde_json::from_str(params.get())
214                    .map(AgentRequest::WaitForTerminalExitRequest)
215                    .map_err(Into::into)
216            }
217            _ => {
218                if let Some(custom_method) = method.strip_prefix('_') {
219                    Ok(AgentRequest::ExtMethodRequest(ExtRequest {
220                        method: custom_method.into(),
221                        params: params.to_owned().into(),
222                    }))
223                } else {
224                    Err(Error::method_not_found())
225                }
226            }
227        }
228    }
229
230    fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<AgentNotification> {
231        let params = params.ok_or_else(Error::invalid_params)?;
232
233        match method {
234            m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
235                .map(AgentNotification::SessionNotification)
236                .map_err(Into::into),
237            _ => {
238                if let Some(custom_method) = method.strip_prefix('_') {
239                    Ok(AgentNotification::ExtNotification(ExtNotification {
240                        method: custom_method.into(),
241                        params: RawValue::from_string(params.get().to_string())?.into(),
242                    }))
243                } else {
244                    Err(Error::method_not_found())
245                }
246            }
247        }
248    }
249}
250
251/// Marker type representing the agent side of an ACP connection.
252///
253/// This type is used by the RPC layer to determine which messages
254/// are incoming vs outgoing from the agent's perspective.
255///
256/// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
257#[derive(Clone, Default, Debug, JsonSchema)]
258#[non_exhaustive]
259pub struct AgentSide;
260
261impl Side for AgentSide {
262    type InRequest = ClientRequest;
263    type InNotification = ClientNotification;
264    type OutResponse = AgentResponse;
265
266    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest> {
267        let params = params.ok_or_else(Error::invalid_params)?;
268
269        match method {
270            m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
271                .map(ClientRequest::InitializeRequest)
272                .map_err(Into::into),
273            m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
274                .map(ClientRequest::AuthenticateRequest)
275                .map_err(Into::into),
276            m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
277                .map(ClientRequest::NewSessionRequest)
278                .map_err(Into::into),
279            m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
280                .map(ClientRequest::LoadSessionRequest)
281                .map_err(Into::into),
282            #[cfg(feature = "unstable_session_list")]
283            m if m == AGENT_METHOD_NAMES.session_list => serde_json::from_str(params.get())
284                .map(ClientRequest::ListSessionsRequest)
285                .map_err(Into::into),
286            #[cfg(feature = "unstable_session_fork")]
287            m if m == AGENT_METHOD_NAMES.session_fork => serde_json::from_str(params.get())
288                .map(ClientRequest::ForkSessionRequest)
289                .map_err(Into::into),
290            m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
291                .map(ClientRequest::SetSessionModeRequest)
292                .map_err(Into::into),
293            #[cfg(feature = "unstable_session_model")]
294            m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
295                .map(ClientRequest::SetSessionModelRequest)
296                .map_err(Into::into),
297            m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
298                .map(ClientRequest::PromptRequest)
299                .map_err(Into::into),
300            _ => {
301                if let Some(custom_method) = method.strip_prefix('_') {
302                    Ok(ClientRequest::ExtMethodRequest(ExtRequest {
303                        method: custom_method.into(),
304                        params: params.to_owned().into(),
305                    }))
306                } else {
307                    Err(Error::method_not_found())
308                }
309            }
310        }
311    }
312
313    fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<ClientNotification> {
314        let params = params.ok_or_else(Error::invalid_params)?;
315
316        match method {
317            m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
318                .map(ClientNotification::CancelNotification)
319                .map_err(Into::into),
320            _ => {
321                if let Some(custom_method) = method.strip_prefix('_') {
322                    Ok(ClientNotification::ExtNotification(ExtNotification {
323                        method: custom_method.into(),
324                        params: RawValue::from_string(params.get().to_string())?.into(),
325                    }))
326                } else {
327                    Err(Error::method_not_found())
328                }
329            }
330        }
331    }
332}
333
#[cfg(test)]
mod tests {
    use super::*;

    use serde_json::{Number, Value};

    #[test]
    fn id_deserialization() {
        // Each JSON id shape must land in the matching `RequestId` variant.
        let decode = |json: Value| serde_json::from_value::<RequestId>(json).unwrap();

        assert_eq!(decode(Value::Null), RequestId::Null);
        assert_eq!(
            decode(Value::Number(Number::from_u128(1).unwrap())),
            RequestId::Number(1)
        );
        assert_eq!(
            decode(Value::Number(Number::from_i128(-1).unwrap())),
            RequestId::Number(-1)
        );
        assert_eq!(
            decode(Value::String("id".to_owned())),
            RequestId::Str("id".to_owned())
        );
    }

    #[test]
    fn id_serialization() {
        let cases = [
            (RequestId::Null, Value::Null),
            (RequestId::Number(1), Value::Number(Number::from_u128(1).unwrap())),
            (RequestId::Number(-1), Value::Number(Number::from_i128(-1).unwrap())),
            (RequestId::Str("id".to_owned()), Value::String("id".to_owned())),
        ];

        for (id, expected) in cases {
            assert_eq!(serde_json::to_value(id).unwrap(), expected);
        }
    }

    #[test]
    fn id_display() {
        let cases = [
            (RequestId::Null, "null"),
            (RequestId::Number(1), "1"),
            (RequestId::Number(-1), "-1"),
            (RequestId::Str("id".to_owned()), "id"),
        ];

        for (id, expected) in cases {
            assert_eq!(id.to_string(), expected);
        }
    }
}
387
#[test]
fn test_notification_wire_format() {
    use super::*;

    use serde_json::{Value, json};

    // client -> agent: a cancel notification.
    let cancel = Notification {
        method: "cancel".into(),
        params: Some(ClientNotification::CancelNotification(CancelNotification {
            session_id: SessionId("test-123".into()),
            meta: None,
        })),
    };
    let wire: Value = serde_json::to_value(&JsonRpcMessage::wrap(
        OutgoingMessage::<ClientSide, AgentSide>::Notification(cancel),
    ))
    .unwrap();
    assert_eq!(
        wire,
        json!({
            "jsonrpc": "2.0",
            "method": "cancel",
            "params": {
                "sessionId": "test-123"
            },
        })
    );

    // agent -> client: a session update carrying a text chunk from the agent.
    let chunk = ContentChunk {
        content: ContentBlock::Text(TextContent {
            annotations: None,
            text: "Hello".to_string(),
            meta: None,
        }),
        meta: None,
    };
    let update = Notification {
        method: "sessionUpdate".into(),
        params: Some(AgentNotification::SessionNotification(SessionNotification {
            session_id: SessionId("test-456".into()),
            update: SessionUpdate::AgentMessageChunk(chunk),
            meta: None,
        })),
    };
    let wire: Value = serde_json::to_value(&JsonRpcMessage::wrap(
        OutgoingMessage::<AgentSide, ClientSide>::Notification(update),
    ))
    .unwrap();
    assert_eq!(
        wire,
        json!({
            "jsonrpc": "2.0",
            "method": "sessionUpdate",
            "params": {
                "sessionId": "test-456",
                "update": {
                    "sessionUpdate": "agent_message_chunk",
                    "content": {
                        "type": "text",
                        "text": "Hello"
                    }
                }
            }
        })
    );
}