agent_client_protocol_schema/
rpc.rs

1use std::sync::Arc;
2
3use derive_more::{Display, From};
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize, de::DeserializeOwned};
6use serde_json::value::RawValue;
7
8use crate::{
9    AGENT_METHOD_NAMES, AgentNotification, AgentRequest, AgentResponse, CLIENT_METHOD_NAMES,
10    ClientNotification, ClientRequest, ClientResponse, Error, ExtNotification, ExtRequest, Result,
11};
12
13/// JSON RPC Request Id
14///
15/// An identifier established by the Client that MUST contain a String, Number, or NULL value if included. If it is not included it is assumed to be a notification. The value SHOULD normally not be Null [1] and Numbers SHOULD NOT contain fractional parts [2]
16///
17/// The Server MUST reply with the same value in the Response object if included. This member is used to correlate the context between the two objects.
18///
19/// [1] The use of Null as a value for the id member in a Request object is discouraged, because this specification uses a value of Null for Responses with an unknown id. Also, because JSON-RPC 1.0 uses an id value of Null for Notifications this could cause confusion in handling.
20///
21/// [2] Fractional parts may be problematic, since many decimal fractions cannot be represented exactly as binary fractions.
22#[derive(
23    Debug,
24    PartialEq,
25    Clone,
26    Hash,
27    Eq,
28    Deserialize,
29    Serialize,
30    PartialOrd,
31    Ord,
32    Display,
33    JsonSchema,
34    From,
35)]
36#[serde(untagged)]
37#[allow(
38    clippy::exhaustive_enums,
39    reason = "This comes from the JSON-RPC specification itself"
40)]
41#[from(String, i64)]
42pub enum RequestId {
43    #[display("null")]
44    Null,
45    Number(i64),
46    Str(String),
47}
48
49#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
50#[allow(
51    clippy::exhaustive_structs,
52    reason = "This comes from the JSON-RPC specification itself"
53)]
54#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
55pub struct Request<Params> {
56    pub id: RequestId,
57    pub method: Arc<str>,
58    #[serde(skip_serializing_if = "Option::is_none")]
59    pub params: Option<Params>,
60}
61
62#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
63#[allow(
64    clippy::exhaustive_enums,
65    reason = "This comes from the JSON-RPC specification itself"
66)]
67#[serde(untagged)]
68#[schemars(rename = "{Result}", extend("x-docs-ignore" = true))]
69pub enum Response<Result> {
70    Result { id: RequestId, result: Result },
71    Error { id: RequestId, error: Error },
72}
73
74impl<R> Response<R> {
75    pub fn new(id: impl Into<RequestId>, result: Result<R>) -> Self {
76        match result {
77            Ok(result) => Self::Result {
78                id: id.into(),
79                result,
80            },
81            Err(error) => Self::Error {
82                id: id.into(),
83                error,
84            },
85        }
86    }
87}
88
89#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
90#[allow(
91    clippy::exhaustive_structs,
92    reason = "This comes from the JSON-RPC specification itself"
93)]
94#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
95pub struct Notification<Params> {
96    pub method: Arc<str>,
97    #[serde(skip_serializing_if = "Option::is_none")]
98    pub params: Option<Params>,
99}
100
101#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
102#[serde(untagged)]
103#[schemars(inline)]
104#[allow(
105    clippy::exhaustive_enums,
106    reason = "This comes from the JSON-RPC specification itself"
107)]
108pub enum OutgoingMessage<Local: Side, Remote: Side> {
109    Request(Request<Remote::InRequest>),
110    Response(Response<Local::OutResponse>),
111    Notification(Notification<Remote::InNotification>),
112}
113
114#[derive(Debug, Serialize, Deserialize, JsonSchema)]
115#[schemars(inline)]
116enum JsonRpcVersion {
117    #[serde(rename = "2.0")]
118    V2,
119}
120
121/// A message (request, response, or notification) with `"jsonrpc": "2.0"` specified as
122/// [required by JSON-RPC 2.0 Specification][1].
123///
124/// [1]: https://www.jsonrpc.org/specification#compatibility
125#[derive(Debug, Serialize, Deserialize, JsonSchema)]
126#[schemars(inline)]
127pub struct JsonRpcMessage<M> {
128    jsonrpc: JsonRpcVersion,
129    #[serde(flatten)]
130    message: M,
131}
132
133impl<M> JsonRpcMessage<M> {
134    /// Wraps the provided [`OutgoingMessage`] or [`IncomingMessage`] into a versioned
135    /// [`JsonRpcMessage`].
136    #[must_use]
137    pub fn wrap(message: M) -> Self {
138        Self {
139            jsonrpc: JsonRpcVersion::V2,
140            message,
141        }
142    }
143}
144
145pub trait Side: Clone {
146    type InRequest: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
147    type InNotification: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
148    type OutResponse: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
149
150    /// Decode a request for a given method. This will encapsulate the knowledge of mapping which
151    /// serialization struct to use for each method.
152    ///
153    /// # Errors
154    ///
155    /// This function will return an error if the method is not recognized or if the parameters
156    /// cannot be deserialized into the expected type.
157    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<Self::InRequest>;
158
159    /// Decode a notification for a given method. This will encapsulate the knowledge of mapping which
160    /// serialization struct to use for each method.
161    ///
162    /// # Errors
163    ///
164    /// This function will return an error if the method is not recognized or if the parameters
165    /// cannot be deserialized into the expected type.
166    fn decode_notification(method: &str, params: Option<&RawValue>)
167    -> Result<Self::InNotification>;
168}
169
170/// Marker type representing the client side of an ACP connection.
171///
172/// This type is used by the RPC layer to determine which messages
173/// are incoming vs outgoing from the client's perspective.
174///
175/// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
176#[derive(Clone, Default, Debug, JsonSchema)]
177#[non_exhaustive]
178pub struct ClientSide;
179
180impl Side for ClientSide {
181    type InRequest = AgentRequest;
182    type InNotification = AgentNotification;
183    type OutResponse = ClientResponse;
184
185    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest> {
186        let params = params.ok_or_else(Error::invalid_params)?;
187
188        match method {
189            m if m == CLIENT_METHOD_NAMES.session_request_permission => {
190                serde_json::from_str(params.get())
191                    .map(AgentRequest::RequestPermissionRequest)
192                    .map_err(Into::into)
193            }
194            m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
195                .map(AgentRequest::WriteTextFileRequest)
196                .map_err(Into::into),
197            m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
198                .map(AgentRequest::ReadTextFileRequest)
199                .map_err(Into::into),
200            m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
201                .map(AgentRequest::CreateTerminalRequest)
202                .map_err(Into::into),
203            m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
204                .map(AgentRequest::TerminalOutputRequest)
205                .map_err(Into::into),
206            m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
207                .map(AgentRequest::KillTerminalCommandRequest)
208                .map_err(Into::into),
209            m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
210                .map(AgentRequest::ReleaseTerminalRequest)
211                .map_err(Into::into),
212            m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
213                serde_json::from_str(params.get())
214                    .map(AgentRequest::WaitForTerminalExitRequest)
215                    .map_err(Into::into)
216            }
217            _ => {
218                if let Some(custom_method) = method.strip_prefix('_') {
219                    Ok(AgentRequest::ExtMethodRequest(ExtRequest {
220                        method: custom_method.into(),
221                        params: params.to_owned().into(),
222                    }))
223                } else {
224                    Err(Error::method_not_found())
225                }
226            }
227        }
228    }
229
230    fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<AgentNotification> {
231        let params = params.ok_or_else(Error::invalid_params)?;
232
233        match method {
234            m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
235                .map(AgentNotification::SessionNotification)
236                .map_err(Into::into),
237            _ => {
238                if let Some(custom_method) = method.strip_prefix('_') {
239                    Ok(AgentNotification::ExtNotification(ExtNotification {
240                        method: custom_method.into(),
241                        params: RawValue::from_string(params.get().to_string())?.into(),
242                    }))
243                } else {
244                    Err(Error::method_not_found())
245                }
246            }
247        }
248    }
249}
250
251/// Marker type representing the agent side of an ACP connection.
252///
253/// This type is used by the RPC layer to determine which messages
254/// are incoming vs outgoing from the agent's perspective.
255///
256/// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
257#[derive(Clone, Default, Debug, JsonSchema)]
258#[non_exhaustive]
259pub struct AgentSide;
260
261impl Side for AgentSide {
262    type InRequest = ClientRequest;
263    type InNotification = ClientNotification;
264    type OutResponse = AgentResponse;
265
266    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest> {
267        let params = params.ok_or_else(Error::invalid_params)?;
268
269        match method {
270            m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
271                .map(ClientRequest::InitializeRequest)
272                .map_err(Into::into),
273            m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
274                .map(ClientRequest::AuthenticateRequest)
275                .map_err(Into::into),
276            m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
277                .map(ClientRequest::NewSessionRequest)
278                .map_err(Into::into),
279            m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
280                .map(ClientRequest::LoadSessionRequest)
281                .map_err(Into::into),
282            #[cfg(feature = "unstable_session_list")]
283            m if m == AGENT_METHOD_NAMES.session_list => serde_json::from_str(params.get())
284                .map(ClientRequest::ListSessionsRequest)
285                .map_err(Into::into),
286            #[cfg(feature = "unstable_session_fork")]
287            m if m == AGENT_METHOD_NAMES.session_fork => serde_json::from_str(params.get())
288                .map(ClientRequest::ForkSessionRequest)
289                .map_err(Into::into),
290            #[cfg(feature = "unstable_session_resume")]
291            m if m == AGENT_METHOD_NAMES.session_resume => serde_json::from_str(params.get())
292                .map(ClientRequest::ResumeSessionRequest)
293                .map_err(Into::into),
294            m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
295                .map(ClientRequest::SetSessionModeRequest)
296                .map_err(Into::into),
297            #[cfg(feature = "unstable_session_model")]
298            m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
299                .map(ClientRequest::SetSessionModelRequest)
300                .map_err(Into::into),
301            m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
302                .map(ClientRequest::PromptRequest)
303                .map_err(Into::into),
304            _ => {
305                if let Some(custom_method) = method.strip_prefix('_') {
306                    Ok(ClientRequest::ExtMethodRequest(ExtRequest {
307                        method: custom_method.into(),
308                        params: params.to_owned().into(),
309                    }))
310                } else {
311                    Err(Error::method_not_found())
312                }
313            }
314        }
315    }
316
317    fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<ClientNotification> {
318        let params = params.ok_or_else(Error::invalid_params)?;
319
320        match method {
321            m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
322                .map(ClientNotification::CancelNotification)
323                .map_err(Into::into),
324            _ => {
325                if let Some(custom_method) = method.strip_prefix('_') {
326                    Ok(ClientNotification::ExtNotification(ExtNotification {
327                        method: custom_method.into(),
328                        params: RawValue::from_string(params.get().to_string())?.into(),
329                    }))
330                } else {
331                    Err(Error::method_not_found())
332                }
333            }
334        }
335    }
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    use serde_json::{Value, json};

    // Payload types used by the wire-format test. They are not imported by this module's
    // parent, so they are pulled in from the crate root explicitly.
    use crate::{
        CancelNotification, ContentBlock, ContentChunk, SessionId, SessionNotification,
        SessionUpdate, TextContent,
    };

    #[test]
    fn id_deserialization() {
        let id = serde_json::from_value::<RequestId>(json!(null)).unwrap();
        assert_eq!(id, RequestId::Null);

        let id = serde_json::from_value::<RequestId>(json!(1)).unwrap();
        assert_eq!(id, RequestId::Number(1));

        let id = serde_json::from_value::<RequestId>(json!(-1)).unwrap();
        assert_eq!(id, RequestId::Number(-1));

        let id = serde_json::from_value::<RequestId>(json!("id")).unwrap();
        assert_eq!(id, RequestId::Str("id".to_owned()));
    }

    #[test]
    fn id_serialization() {
        assert_eq!(serde_json::to_value(RequestId::Null).unwrap(), Value::Null);
        assert_eq!(serde_json::to_value(RequestId::Number(1)).unwrap(), json!(1));
        assert_eq!(serde_json::to_value(RequestId::Number(-1)).unwrap(), json!(-1));
        assert_eq!(
            serde_json::to_value(RequestId::Str("id".to_owned())).unwrap(),
            json!("id")
        );
    }

    #[test]
    fn id_display() {
        assert_eq!(RequestId::Null.to_string(), "null");
        assert_eq!(RequestId::Number(1).to_string(), "1");
        assert_eq!(RequestId::Number(-1).to_string(), "-1");
        assert_eq!(RequestId::Str("id".to_owned()).to_string(), "id");
    }

    /// Notifications in both directions serialize with the `"jsonrpc": "2.0"` tag, the
    /// method name, and the typed params, with `None` fields left out.
    #[test]
    fn test_notification_wire_format() {
        // Client -> agent notification, using the method name the agent side decodes.
        let outgoing_msg = JsonRpcMessage::wrap(
            OutgoingMessage::<ClientSide, AgentSide>::Notification(Notification {
                method: AGENT_METHOD_NAMES.session_cancel.into(),
                params: Some(ClientNotification::CancelNotification(CancelNotification {
                    session_id: SessionId("test-123".into()),
                    meta: None,
                })),
            }),
        );

        let serialized: Value = serde_json::to_value(&outgoing_msg).unwrap();
        assert_eq!(
            serialized,
            json!({
                "jsonrpc": "2.0",
                "method": AGENT_METHOD_NAMES.session_cancel,
                "params": {
                    "sessionId": "test-123"
                },
            })
        );

        // Agent -> client notification, using the method name the client side decodes.
        let outgoing_msg = JsonRpcMessage::wrap(
            OutgoingMessage::<AgentSide, ClientSide>::Notification(Notification {
                method: CLIENT_METHOD_NAMES.session_update.into(),
                params: Some(AgentNotification::SessionNotification(
                    SessionNotification {
                        session_id: SessionId("test-456".into()),
                        update: SessionUpdate::AgentMessageChunk(ContentChunk {
                            content: ContentBlock::Text(TextContent {
                                annotations: None,
                                text: "Hello".to_string(),
                                meta: None,
                            }),
                            meta: None,
                        }),
                        meta: None,
                    },
                )),
            }),
        );

        let serialized: Value = serde_json::to_value(&outgoing_msg).unwrap();
        assert_eq!(
            serialized,
            json!({
                "jsonrpc": "2.0",
                "method": CLIENT_METHOD_NAMES.session_update,
                "params": {
                    "sessionId": "test-456",
                    "update": {
                        "sessionUpdate": "agent_message_chunk",
                        "content": {
                            "type": "text",
                            "text": "Hello"
                        }
                    }
                }
            })
        );
    }
}