Skip to main content

agent_client_protocol_schema/
rpc.rs

1use std::sync::Arc;
2
3use derive_more::{Display, From};
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize, de::DeserializeOwned};
6use serde_json::value::RawValue;
7
8use crate::{
9    AGENT_METHOD_NAMES, AgentNotification, AgentRequest, AgentResponse, CLIENT_METHOD_NAMES,
10    ClientNotification, ClientRequest, ClientResponse, Error, ExtNotification, ExtRequest, Result,
11};
12
13/// JSON RPC Request Id
14///
15/// An identifier established by the Client that MUST contain a String, Number, or NULL value if included. If it is not included it is assumed to be a notification. The value SHOULD normally not be Null [1] and Numbers SHOULD NOT contain fractional parts [2]
16///
17/// The Server MUST reply with the same value in the Response object if included. This member is used to correlate the context between the two objects.
18///
19/// [1] The use of Null as a value for the id member in a Request object is discouraged, because this specification uses a value of Null for Responses with an unknown id. Also, because JSON-RPC 1.0 uses an id value of Null for Notifications this could cause confusion in handling.
20///
21/// [2] Fractional parts may be problematic, since many decimal fractions cannot be represented exactly as binary fractions.
22#[derive(
23    Debug,
24    PartialEq,
25    Clone,
26    Hash,
27    Eq,
28    Deserialize,
29    Serialize,
30    PartialOrd,
31    Ord,
32    Display,
33    JsonSchema,
34    From,
35)]
36#[serde(untagged)]
37#[allow(
38    clippy::exhaustive_enums,
39    reason = "This comes from the JSON-RPC specification itself"
40)]
41#[from(String, i64)]
42pub enum RequestId {
43    #[display("null")]
44    Null,
45    Number(i64),
46    Str(String),
47}
48
49#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
50#[allow(
51    clippy::exhaustive_structs,
52    reason = "This comes from the JSON-RPC specification itself"
53)]
54#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
55pub struct Request<Params> {
56    pub id: RequestId,
57    pub method: Arc<str>,
58    #[serde(skip_serializing_if = "Option::is_none")]
59    pub params: Option<Params>,
60}
61
62#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
63#[allow(
64    clippy::exhaustive_enums,
65    reason = "This comes from the JSON-RPC specification itself"
66)]
67#[serde(untagged)]
68#[schemars(rename = "{Result}", extend("x-docs-ignore" = true))]
69pub enum Response<Result> {
70    Result { id: RequestId, result: Result },
71    Error { id: RequestId, error: Error },
72}
73
74impl<R> Response<R> {
75    #[must_use]
76    pub fn new(id: impl Into<RequestId>, result: Result<R>) -> Self {
77        match result {
78            Ok(result) => Self::Result {
79                id: id.into(),
80                result,
81            },
82            Err(error) => Self::Error {
83                id: id.into(),
84                error,
85            },
86        }
87    }
88}
89
90#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
91#[allow(
92    clippy::exhaustive_structs,
93    reason = "This comes from the JSON-RPC specification itself"
94)]
95#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
96pub struct Notification<Params> {
97    pub method: Arc<str>,
98    #[serde(skip_serializing_if = "Option::is_none")]
99    pub params: Option<Params>,
100}
101
102#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
103#[serde(untagged)]
104#[schemars(inline)]
105#[allow(
106    clippy::exhaustive_enums,
107    reason = "This comes from the JSON-RPC specification itself"
108)]
109pub enum OutgoingMessage<Local: Side, Remote: Side> {
110    Request(Request<Remote::InRequest>),
111    Response(Response<Local::OutResponse>),
112    Notification(Notification<Remote::InNotification>),
113}
114
115#[derive(Debug, Serialize, Deserialize, JsonSchema)]
116#[schemars(inline)]
117enum JsonRpcVersion {
118    #[serde(rename = "2.0")]
119    V2,
120}
121
122/// A message (request, response, or notification) with `"jsonrpc": "2.0"` specified as
123/// [required by JSON-RPC 2.0 Specification][1].
124///
125/// [1]: https://www.jsonrpc.org/specification#compatibility
126#[derive(Debug, Serialize, Deserialize, JsonSchema)]
127#[schemars(inline)]
128pub struct JsonRpcMessage<M> {
129    jsonrpc: JsonRpcVersion,
130    #[serde(flatten)]
131    message: M,
132}
133
134impl<M> JsonRpcMessage<M> {
135    /// Wraps the provided [`OutgoingMessage`] or [`IncomingMessage`] into a versioned
136    /// [`JsonRpcMessage`].
137    #[must_use]
138    pub fn wrap(message: M) -> Self {
139        Self {
140            jsonrpc: JsonRpcVersion::V2,
141            message,
142        }
143    }
144}
145
146pub trait Side: Clone {
147    type InRequest: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
148    type InNotification: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
149    type OutResponse: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
150
151    /// Decode a request for a given method. This will encapsulate the knowledge of mapping which
152    /// serialization struct to use for each method.
153    ///
154    /// # Errors
155    ///
156    /// This function will return an error if the method is not recognized or if the parameters
157    /// cannot be deserialized into the expected type.
158    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<Self::InRequest>;
159
160    /// Decode a notification for a given method. This will encapsulate the knowledge of mapping which
161    /// serialization struct to use for each method.
162    ///
163    /// # Errors
164    ///
165    /// This function will return an error if the method is not recognized or if the parameters
166    /// cannot be deserialized into the expected type.
167    fn decode_notification(method: &str, params: Option<&RawValue>)
168    -> Result<Self::InNotification>;
169}
170
171/// Marker type representing the client side of an ACP connection.
172///
173/// This type is used by the RPC layer to determine which messages
174/// are incoming vs outgoing from the client's perspective.
175///
176/// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
177#[derive(Clone, Default, Debug, JsonSchema)]
178#[non_exhaustive]
179pub struct ClientSide;
180
181impl Side for ClientSide {
182    type InRequest = AgentRequest;
183    type InNotification = AgentNotification;
184    type OutResponse = ClientResponse;
185
186    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest> {
187        let params = params.ok_or_else(Error::invalid_params)?;
188
189        match method {
190            m if m == CLIENT_METHOD_NAMES.session_request_permission => {
191                serde_json::from_str(params.get())
192                    .map(AgentRequest::RequestPermissionRequest)
193                    .map_err(Into::into)
194            }
195            m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
196                .map(AgentRequest::WriteTextFileRequest)
197                .map_err(Into::into),
198            m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
199                .map(AgentRequest::ReadTextFileRequest)
200                .map_err(Into::into),
201            m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
202                .map(AgentRequest::CreateTerminalRequest)
203                .map_err(Into::into),
204            m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
205                .map(AgentRequest::TerminalOutputRequest)
206                .map_err(Into::into),
207            m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
208                .map(AgentRequest::KillTerminalRequest)
209                .map_err(Into::into),
210            m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
211                .map(AgentRequest::ReleaseTerminalRequest)
212                .map_err(Into::into),
213            m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
214                serde_json::from_str(params.get())
215                    .map(AgentRequest::WaitForTerminalExitRequest)
216                    .map_err(Into::into)
217            }
218            #[cfg(feature = "unstable_elicitation")]
219            m if m == CLIENT_METHOD_NAMES.session_elicitation => serde_json::from_str(params.get())
220                .map(AgentRequest::ElicitationRequest)
221                .map_err(Into::into),
222            _ => {
223                if let Some(custom_method) = method.strip_prefix('_') {
224                    Ok(AgentRequest::ExtMethodRequest(ExtRequest {
225                        method: custom_method.into(),
226                        params: params.to_owned().into(),
227                    }))
228                } else {
229                    Err(Error::method_not_found())
230                }
231            }
232        }
233    }
234
235    fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<AgentNotification> {
236        let params = params.ok_or_else(Error::invalid_params)?;
237
238        match method {
239            m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
240                .map(AgentNotification::SessionNotification)
241                .map_err(Into::into),
242            #[cfg(feature = "unstable_elicitation")]
243            m if m == CLIENT_METHOD_NAMES.session_elicitation_complete => {
244                serde_json::from_str(params.get())
245                    .map(AgentNotification::ElicitationCompleteNotification)
246                    .map_err(Into::into)
247            }
248            _ => {
249                if let Some(custom_method) = method.strip_prefix('_') {
250                    Ok(AgentNotification::ExtNotification(ExtNotification {
251                        method: custom_method.into(),
252                        params: params.to_owned().into(),
253                    }))
254                } else {
255                    Err(Error::method_not_found())
256                }
257            }
258        }
259    }
260}
261
262/// Marker type representing the agent side of an ACP connection.
263///
264/// This type is used by the RPC layer to determine which messages
265/// are incoming vs outgoing from the agent's perspective.
266///
267/// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
268#[derive(Clone, Default, Debug, JsonSchema)]
269#[non_exhaustive]
270pub struct AgentSide;
271
272impl Side for AgentSide {
273    type InRequest = ClientRequest;
274    type InNotification = ClientNotification;
275    type OutResponse = AgentResponse;
276
277    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest> {
278        let params = params.ok_or_else(Error::invalid_params)?;
279
280        match method {
281            m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
282                .map(ClientRequest::InitializeRequest)
283                .map_err(Into::into),
284            m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
285                .map(ClientRequest::AuthenticateRequest)
286                .map_err(Into::into),
287            #[cfg(feature = "unstable_logout")]
288            m if m == AGENT_METHOD_NAMES.logout => serde_json::from_str(params.get())
289                .map(ClientRequest::LogoutRequest)
290                .map_err(Into::into),
291            m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
292                .map(ClientRequest::NewSessionRequest)
293                .map_err(Into::into),
294            m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
295                .map(ClientRequest::LoadSessionRequest)
296                .map_err(Into::into),
297            m if m == AGENT_METHOD_NAMES.session_list => serde_json::from_str(params.get())
298                .map(ClientRequest::ListSessionsRequest)
299                .map_err(Into::into),
300            #[cfg(feature = "unstable_session_fork")]
301            m if m == AGENT_METHOD_NAMES.session_fork => serde_json::from_str(params.get())
302                .map(ClientRequest::ForkSessionRequest)
303                .map_err(Into::into),
304            #[cfg(feature = "unstable_session_resume")]
305            m if m == AGENT_METHOD_NAMES.session_resume => serde_json::from_str(params.get())
306                .map(ClientRequest::ResumeSessionRequest)
307                .map_err(Into::into),
308            #[cfg(feature = "unstable_session_close")]
309            m if m == AGENT_METHOD_NAMES.session_close => serde_json::from_str(params.get())
310                .map(ClientRequest::CloseSessionRequest)
311                .map_err(Into::into),
312            m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
313                .map(ClientRequest::SetSessionModeRequest)
314                .map_err(Into::into),
315            m if m == AGENT_METHOD_NAMES.session_set_config_option => {
316                serde_json::from_str(params.get())
317                    .map(ClientRequest::SetSessionConfigOptionRequest)
318                    .map_err(Into::into)
319            }
320            #[cfg(feature = "unstable_session_model")]
321            m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
322                .map(ClientRequest::SetSessionModelRequest)
323                .map_err(Into::into),
324            m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
325                .map(ClientRequest::PromptRequest)
326                .map_err(Into::into),
327            _ => {
328                if let Some(custom_method) = method.strip_prefix('_') {
329                    Ok(ClientRequest::ExtMethodRequest(ExtRequest {
330                        method: custom_method.into(),
331                        params: params.to_owned().into(),
332                    }))
333                } else {
334                    Err(Error::method_not_found())
335                }
336            }
337        }
338    }
339
340    fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<ClientNotification> {
341        let params = params.ok_or_else(Error::invalid_params)?;
342
343        match method {
344            m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
345                .map(ClientNotification::CancelNotification)
346                .map_err(Into::into),
347            _ => {
348                if let Some(custom_method) = method.strip_prefix('_') {
349                    Ok(ClientNotification::ExtNotification(ExtNotification {
350                        method: custom_method.into(),
351                        params: params.to_owned().into(),
352                    }))
353                } else {
354                    Err(Error::method_not_found())
355                }
356            }
357        }
358    }
359}
360
#[cfg(test)]
mod tests {
    use super::*;

    use serde_json::{Number, Value};

    /// Each `RequestId` shape, paired with its JSON wire value and its `Display` text.
    fn fixtures() -> [(RequestId, Value, &'static str); 4] {
        [
            (RequestId::Null, Value::Null, "null"),
            (
                RequestId::Number(1),
                Value::Number(Number::from_u128(1).unwrap()),
                "1",
            ),
            (
                RequestId::Number(-1),
                Value::Number(Number::from_i128(-1).unwrap()),
                "-1",
            ),
            (
                RequestId::Str("id".to_owned()),
                Value::String("id".to_owned()),
                "id",
            ),
        ]
    }

    #[test]
    fn id_deserialization() {
        for (expected, wire, _) in fixtures() {
            let decoded = serde_json::from_value::<RequestId>(wire).unwrap();
            assert_eq!(decoded, expected);
        }
    }

    #[test]
    fn id_serialization() {
        for (id, expected, _) in fixtures() {
            assert_eq!(serde_json::to_value(id).unwrap(), expected);
        }
    }

    #[test]
    fn id_display() {
        for (id, _, expected) in fixtures() {
            assert_eq!(id.to_string(), expected);
        }
    }
}
414
/// Checks the on-the-wire shape of notifications in both directions. Each notification is
/// built with the canonical method-name constant, and the test confirms that the receiving
/// side decodes the emitted params back into the expected variant.
#[test]
fn test_notification_wire_format() {
    use super::*;

    use serde_json::{Value, json};

    // Client -> agent. Use the same constant the agent's decoder matches on, instead of a
    // hand-typed string that may not be a real method name.
    let cancel_method = AGENT_METHOD_NAMES.session_cancel;
    let outgoing_msg = JsonRpcMessage::wrap(
        OutgoingMessage::<ClientSide, AgentSide>::Notification(Notification {
            method: cancel_method.into(),
            params: Some(ClientNotification::CancelNotification(CancelNotification {
                session_id: SessionId("test-123".into()),
                meta: None,
            })),
        }),
    );

    let serialized: Value = serde_json::to_value(&outgoing_msg).unwrap();
    assert_eq!(
        serialized,
        json!({
            "jsonrpc": "2.0",
            "method": cancel_method,
            "params": {
                "sessionId": "test-123"
            },
        })
    );

    // The agent side must decode what the client side just emitted.
    let raw_params = serde_json::value::to_raw_value(&serialized["params"]).unwrap();
    let decoded = AgentSide::decode_notification(cancel_method, Some(&*raw_params)).unwrap();
    assert!(matches!(decoded, ClientNotification::CancelNotification(_)));

    // Agent -> client.
    let update_method = CLIENT_METHOD_NAMES.session_update;
    let outgoing_msg = JsonRpcMessage::wrap(
        OutgoingMessage::<AgentSide, ClientSide>::Notification(Notification {
            method: update_method.into(),
            params: Some(AgentNotification::SessionNotification(
                SessionNotification {
                    session_id: SessionId("test-456".into()),
                    update: SessionUpdate::AgentMessageChunk(ContentChunk {
                        content: ContentBlock::Text(TextContent {
                            annotations: None,
                            text: "Hello".to_string(),
                            meta: None,
                        }),
                        #[cfg(feature = "unstable_message_id")]
                        message_id: None,
                        meta: None,
                    }),
                    meta: None,
                },
            )),
        }),
    );

    let serialized: Value = serde_json::to_value(&outgoing_msg).unwrap();
    assert_eq!(
        serialized,
        json!({
            "jsonrpc": "2.0",
            "method": update_method,
            "params": {
                "sessionId": "test-456",
                "update": {
                    "sessionUpdate": "agent_message_chunk",
                    "content": {
                        "type": "text",
                        "text": "Hello"
                    }
                }
            }
        })
    );

    // The client side must decode what the agent side just emitted.
    let raw_params = serde_json::value::to_raw_value(&serialized["params"]).unwrap();
    let decoded = ClientSide::decode_notification(update_method, Some(&*raw_params)).unwrap();
    assert!(matches!(decoded, AgentNotification::SessionNotification(_)));
}