Skip to main content

agent_client_protocol_schema/
rpc.rs

1use std::sync::Arc;
2
3use derive_more::{Display, From};
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize, de::DeserializeOwned};
6use serde_json::value::RawValue;
7
8use crate::{
9    AGENT_METHOD_NAMES, AgentNotification, AgentRequest, AgentResponse, CLIENT_METHOD_NAMES,
10    ClientNotification, ClientRequest, ClientResponse, Error, ExtNotification, ExtRequest, Result,
11};
12
13/// JSON RPC Request Id
14///
15/// An identifier established by the Client that MUST contain a String, Number, or NULL value if included. If it is not included it is assumed to be a notification. The value SHOULD normally not be Null [1] and Numbers SHOULD NOT contain fractional parts [2]
16///
17/// The Server MUST reply with the same value in the Response object if included. This member is used to correlate the context between the two objects.
18///
19/// [1] The use of Null as a value for the id member in a Request object is discouraged, because this specification uses a value of Null for Responses with an unknown id. Also, because JSON-RPC 1.0 uses an id value of Null for Notifications this could cause confusion in handling.
20///
21/// [2] Fractional parts may be problematic, since many decimal fractions cannot be represented exactly as binary fractions.
22#[derive(
23    Debug,
24    PartialEq,
25    Clone,
26    Hash,
27    Eq,
28    Deserialize,
29    Serialize,
30    PartialOrd,
31    Ord,
32    Display,
33    JsonSchema,
34    From,
35)]
36#[serde(untagged)]
37#[allow(
38    clippy::exhaustive_enums,
39    reason = "This comes from the JSON-RPC specification itself"
40)]
41#[from(String, i64)]
42pub enum RequestId {
43    #[display("null")]
44    Null,
45    Number(i64),
46    Str(String),
47}
48
49#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
50#[allow(
51    clippy::exhaustive_structs,
52    reason = "This comes from the JSON-RPC specification itself"
53)]
54#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
55pub struct Request<Params> {
56    pub id: RequestId,
57    pub method: Arc<str>,
58    #[serde(skip_serializing_if = "Option::is_none")]
59    pub params: Option<Params>,
60}
61
62#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
63#[allow(
64    clippy::exhaustive_enums,
65    reason = "This comes from the JSON-RPC specification itself"
66)]
67#[serde(untagged)]
68#[schemars(rename = "{Result}", extend("x-docs-ignore" = true))]
69pub enum Response<Result> {
70    Result { id: RequestId, result: Result },
71    Error { id: RequestId, error: Error },
72}
73
74impl<R> Response<R> {
75    #[must_use]
76    pub fn new(id: impl Into<RequestId>, result: Result<R>) -> Self {
77        match result {
78            Ok(result) => Self::Result {
79                id: id.into(),
80                result,
81            },
82            Err(error) => Self::Error {
83                id: id.into(),
84                error,
85            },
86        }
87    }
88}
89
90#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
91#[allow(
92    clippy::exhaustive_structs,
93    reason = "This comes from the JSON-RPC specification itself"
94)]
95#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
96pub struct Notification<Params> {
97    pub method: Arc<str>,
98    #[serde(skip_serializing_if = "Option::is_none")]
99    pub params: Option<Params>,
100}
101
102#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
103#[serde(untagged)]
104#[schemars(inline)]
105#[allow(
106    clippy::exhaustive_enums,
107    reason = "This comes from the JSON-RPC specification itself"
108)]
109pub enum OutgoingMessage<Local: Side, Remote: Side> {
110    Request(Request<Remote::InRequest>),
111    Response(Response<Local::OutResponse>),
112    Notification(Notification<Remote::InNotification>),
113}
114
115#[derive(Debug, Serialize, Deserialize, JsonSchema)]
116#[schemars(inline)]
117enum JsonRpcVersion {
118    #[serde(rename = "2.0")]
119    V2,
120}
121
122/// A message (request, response, or notification) with `"jsonrpc": "2.0"` specified as
123/// [required by JSON-RPC 2.0 Specification][1].
124///
125/// [1]: https://www.jsonrpc.org/specification#compatibility
126#[derive(Debug, Serialize, Deserialize, JsonSchema)]
127#[schemars(inline)]
128pub struct JsonRpcMessage<M> {
129    jsonrpc: JsonRpcVersion,
130    #[serde(flatten)]
131    message: M,
132}
133
134impl<M> JsonRpcMessage<M> {
135    /// Wraps the provided [`OutgoingMessage`] or [`IncomingMessage`] into a versioned
136    /// [`JsonRpcMessage`].
137    #[must_use]
138    pub fn wrap(message: M) -> Self {
139        Self {
140            jsonrpc: JsonRpcVersion::V2,
141            message,
142        }
143    }
144}
145
146pub trait Side: Clone {
147    type InRequest: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
148    type InNotification: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
149    type OutResponse: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
150
151    /// Decode a request for a given method. This will encapsulate the knowledge of mapping which
152    /// serialization struct to use for each method.
153    ///
154    /// # Errors
155    ///
156    /// This function will return an error if the method is not recognized or if the parameters
157    /// cannot be deserialized into the expected type.
158    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<Self::InRequest>;
159
160    /// Decode a notification for a given method. This will encapsulate the knowledge of mapping which
161    /// serialization struct to use for each method.
162    ///
163    /// # Errors
164    ///
165    /// This function will return an error if the method is not recognized or if the parameters
166    /// cannot be deserialized into the expected type.
167    fn decode_notification(method: &str, params: Option<&RawValue>)
168    -> Result<Self::InNotification>;
169}
170
171/// Marker type representing the client side of an ACP connection.
172///
173/// This type is used by the RPC layer to determine which messages
174/// are incoming vs outgoing from the client's perspective.
175///
176/// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
177#[derive(Clone, Default, Debug, JsonSchema)]
178#[non_exhaustive]
179pub struct ClientSide;
180
181impl Side for ClientSide {
182    type InRequest = AgentRequest;
183    type InNotification = AgentNotification;
184    type OutResponse = ClientResponse;
185
186    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest> {
187        let params = params.ok_or_else(Error::invalid_params)?;
188
189        match method {
190            m if m == CLIENT_METHOD_NAMES.session_request_permission => {
191                serde_json::from_str(params.get())
192                    .map(AgentRequest::RequestPermissionRequest)
193                    .map_err(Into::into)
194            }
195            m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
196                .map(AgentRequest::WriteTextFileRequest)
197                .map_err(Into::into),
198            m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
199                .map(AgentRequest::ReadTextFileRequest)
200                .map_err(Into::into),
201            m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
202                .map(AgentRequest::CreateTerminalRequest)
203                .map_err(Into::into),
204            m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
205                .map(AgentRequest::TerminalOutputRequest)
206                .map_err(Into::into),
207            m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
208                .map(AgentRequest::KillTerminalRequest)
209                .map_err(Into::into),
210            m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
211                .map(AgentRequest::ReleaseTerminalRequest)
212                .map_err(Into::into),
213            m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
214                serde_json::from_str(params.get())
215                    .map(AgentRequest::WaitForTerminalExitRequest)
216                    .map_err(Into::into)
217            }
218            _ => {
219                if let Some(custom_method) = method.strip_prefix('_') {
220                    Ok(AgentRequest::ExtMethodRequest(ExtRequest {
221                        method: custom_method.into(),
222                        params: params.to_owned().into(),
223                    }))
224                } else {
225                    Err(Error::method_not_found())
226                }
227            }
228        }
229    }
230
231    fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<AgentNotification> {
232        let params = params.ok_or_else(Error::invalid_params)?;
233
234        match method {
235            m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
236                .map(AgentNotification::SessionNotification)
237                .map_err(Into::into),
238            _ => {
239                if let Some(custom_method) = method.strip_prefix('_') {
240                    Ok(AgentNotification::ExtNotification(ExtNotification {
241                        method: custom_method.into(),
242                        params: params.to_owned().into(),
243                    }))
244                } else {
245                    Err(Error::method_not_found())
246                }
247            }
248        }
249    }
250}
251
252/// Marker type representing the agent side of an ACP connection.
253///
254/// This type is used by the RPC layer to determine which messages
255/// are incoming vs outgoing from the agent's perspective.
256///
257/// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
258#[derive(Clone, Default, Debug, JsonSchema)]
259#[non_exhaustive]
260pub struct AgentSide;
261
262impl Side for AgentSide {
263    type InRequest = ClientRequest;
264    type InNotification = ClientNotification;
265    type OutResponse = AgentResponse;
266
267    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest> {
268        let params = params.ok_or_else(Error::invalid_params)?;
269
270        match method {
271            m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
272                .map(ClientRequest::InitializeRequest)
273                .map_err(Into::into),
274            m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
275                .map(ClientRequest::AuthenticateRequest)
276                .map_err(Into::into),
277            m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
278                .map(ClientRequest::NewSessionRequest)
279                .map_err(Into::into),
280            m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
281                .map(ClientRequest::LoadSessionRequest)
282                .map_err(Into::into),
283            #[cfg(feature = "unstable_session_list")]
284            m if m == AGENT_METHOD_NAMES.session_list => serde_json::from_str(params.get())
285                .map(ClientRequest::ListSessionsRequest)
286                .map_err(Into::into),
287            #[cfg(feature = "unstable_session_fork")]
288            m if m == AGENT_METHOD_NAMES.session_fork => serde_json::from_str(params.get())
289                .map(ClientRequest::ForkSessionRequest)
290                .map_err(Into::into),
291            #[cfg(feature = "unstable_session_resume")]
292            m if m == AGENT_METHOD_NAMES.session_resume => serde_json::from_str(params.get())
293                .map(ClientRequest::ResumeSessionRequest)
294                .map_err(Into::into),
295            #[cfg(feature = "unstable_session_stop")]
296            m if m == AGENT_METHOD_NAMES.session_stop => serde_json::from_str(params.get())
297                .map(ClientRequest::StopSessionRequest)
298                .map_err(Into::into),
299            m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
300                .map(ClientRequest::SetSessionModeRequest)
301                .map_err(Into::into),
302            m if m == AGENT_METHOD_NAMES.session_set_config_option => {
303                serde_json::from_str(params.get())
304                    .map(ClientRequest::SetSessionConfigOptionRequest)
305                    .map_err(Into::into)
306            }
307            #[cfg(feature = "unstable_session_model")]
308            m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
309                .map(ClientRequest::SetSessionModelRequest)
310                .map_err(Into::into),
311            m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
312                .map(ClientRequest::PromptRequest)
313                .map_err(Into::into),
314            _ => {
315                if let Some(custom_method) = method.strip_prefix('_') {
316                    Ok(ClientRequest::ExtMethodRequest(ExtRequest {
317                        method: custom_method.into(),
318                        params: params.to_owned().into(),
319                    }))
320                } else {
321                    Err(Error::method_not_found())
322                }
323            }
324        }
325    }
326
327    fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<ClientNotification> {
328        let params = params.ok_or_else(Error::invalid_params)?;
329
330        match method {
331            m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
332                .map(ClientNotification::CancelNotification)
333                .map_err(Into::into),
334            _ => {
335                if let Some(custom_method) = method.strip_prefix('_') {
336                    Ok(ClientNotification::ExtNotification(ExtNotification {
337                        method: custom_method.into(),
338                        params: params.to_owned().into(),
339                    }))
340                } else {
341                    Err(Error::method_not_found())
342                }
343            }
344        }
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    use serde_json::{Number, Value};

    // Each JSON value must deserialize into the matching `RequestId` variant.
    #[test]
    fn id_deserialization() {
        let cases = [
            (Value::Null, RequestId::Null),
            (
                Value::Number(Number::from_u128(1).unwrap()),
                RequestId::Number(1),
            ),
            (
                Value::Number(Number::from_i128(-1).unwrap()),
                RequestId::Number(-1),
            ),
            (
                Value::String("id".to_owned()),
                RequestId::Str("id".to_owned()),
            ),
        ];

        for (json, expected) in cases {
            let id = serde_json::from_value::<RequestId>(json).unwrap();
            assert_eq!(id, expected);
        }
    }

    // Each `RequestId` variant must serialize to the bare JSON value (no tag).
    #[test]
    fn id_serialization() {
        let cases = [
            (RequestId::Null, Value::Null),
            (
                RequestId::Number(1),
                Value::Number(Number::from_u128(1).unwrap()),
            ),
            (
                RequestId::Number(-1),
                Value::Number(Number::from_i128(-1).unwrap()),
            ),
            (
                RequestId::Str("id".to_owned()),
                Value::String("id".to_owned()),
            ),
        ];

        for (id, expected) in cases {
            assert_eq!(serde_json::to_value(id).unwrap(), expected);
        }
    }

    // `Display` renders `Null` as "null" and other variants as their inner value.
    #[test]
    fn id_display() {
        let cases = [
            (RequestId::Null, "null"),
            (RequestId::Number(1), "1"),
            (RequestId::Number(-1), "-1"),
            (RequestId::Str("id".to_owned()), "id"),
        ];

        for (id, expected) in cases {
            assert_eq!(id.to_string(), expected);
        }
    }
}
401
// Pins the exact JSON produced for an outgoing notification in each direction of a connection.
#[test]
fn test_notification_wire_format() {
    use super::*;

    use serde_json::{Value, json};

    // Client -> agent: a cancel notification, typed as the agent side's incoming notification.
    let cancel = ClientNotification::CancelNotification(CancelNotification {
        session_id: SessionId("test-123".into()),
        meta: None,
    });
    let client_to_agent = JsonRpcMessage::wrap(
        OutgoingMessage::<ClientSide, AgentSide>::Notification(Notification {
            method: "cancel".into(),
            params: Some(cancel),
        }),
    );
    let expected_cancel: Value = json!({
        "jsonrpc": "2.0",
        "method": "cancel",
        "params": {
            "sessionId": "test-123"
        },
    });
    assert_eq!(
        serde_json::to_value(&client_to_agent).unwrap(),
        expected_cancel
    );

    // Agent -> client: a streamed text chunk, typed as the client side's incoming notification.
    let chunk = ContentChunk {
        content: ContentBlock::Text(TextContent {
            annotations: None,
            text: "Hello".to_string(),
            meta: None,
        }),
        #[cfg(feature = "unstable_message_id")]
        message_id: None,
        meta: None,
    };
    let session_update = AgentNotification::SessionNotification(SessionNotification {
        session_id: SessionId("test-456".into()),
        update: SessionUpdate::AgentMessageChunk(chunk),
        meta: None,
    });
    let agent_to_client = JsonRpcMessage::wrap(
        OutgoingMessage::<AgentSide, ClientSide>::Notification(Notification {
            method: "sessionUpdate".into(),
            params: Some(session_update),
        }),
    );
    let expected_update: Value = json!({
        "jsonrpc": "2.0",
        "method": "sessionUpdate",
        "params": {
            "sessionId": "test-456",
            "update": {
                "sessionUpdate": "agent_message_chunk",
                "content": {
                    "type": "text",
                    "text": "Hello"
                }
            }
        }
    });
    assert_eq!(
        serde_json::to_value(&agent_to_client).unwrap(),
        expected_update
    );
}