// agent_client_protocol_schema/rpc.rs

1use std::sync::Arc;
2
3use derive_more::{Display, From};
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize, de::DeserializeOwned};
6use serde_json::value::RawValue;
7
8use crate::{
9    AGENT_METHOD_NAMES, AgentNotification, AgentRequest, AgentResponse, CLIENT_METHOD_NAMES,
10    ClientNotification, ClientRequest, ClientResponse, Error, ExtNotification, ExtRequest, Result,
11};
12
13/// JSON RPC Request Id
14///
15/// An identifier established by the Client that MUST contain a String, Number, or NULL value if included. If it is not included it is assumed to be a notification. The value SHOULD normally not be Null [1] and Numbers SHOULD NOT contain fractional parts [2]
16///
17/// The Server MUST reply with the same value in the Response object if included. This member is used to correlate the context between the two objects.
18///
19/// [1] The use of Null as a value for the id member in a Request object is discouraged, because this specification uses a value of Null for Responses with an unknown id. Also, because JSON-RPC 1.0 uses an id value of Null for Notifications this could cause confusion in handling.
20///
21/// [2] Fractional parts may be problematic, since many decimal fractions cannot be represented exactly as binary fractions.
22#[derive(
23    Debug,
24    PartialEq,
25    Clone,
26    Hash,
27    Eq,
28    Deserialize,
29    Serialize,
30    PartialOrd,
31    Ord,
32    Display,
33    JsonSchema,
34    From,
35)]
36#[serde(untagged)]
37#[allow(
38    clippy::exhaustive_enums,
39    reason = "This comes from the JSON-RPC specification itself"
40)]
41#[from(String, i64)]
42pub enum RequestId {
43    #[display("null")]
44    Null,
45    Number(i64),
46    Str(String),
47}
48
49#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
50#[allow(
51    clippy::exhaustive_structs,
52    reason = "This comes from the JSON-RPC specification itself"
53)]
54#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
55pub struct Request<Params> {
56    pub id: RequestId,
57    pub method: Arc<str>,
58    #[serde(skip_serializing_if = "Option::is_none")]
59    pub params: Option<Params>,
60}
61
62#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
63#[allow(
64    clippy::exhaustive_enums,
65    reason = "This comes from the JSON-RPC specification itself"
66)]
67#[serde(untagged)]
68#[schemars(rename = "{Result}", extend("x-docs-ignore" = true))]
69pub enum Response<Result> {
70    Result { id: RequestId, result: Result },
71    Error { id: RequestId, error: Error },
72}
73
74impl<R> Response<R> {
75    #[must_use]
76    pub fn new(id: impl Into<RequestId>, result: Result<R>) -> Self {
77        match result {
78            Ok(result) => Self::Result {
79                id: id.into(),
80                result,
81            },
82            Err(error) => Self::Error {
83                id: id.into(),
84                error,
85            },
86        }
87    }
88}
89
90#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
91#[allow(
92    clippy::exhaustive_structs,
93    reason = "This comes from the JSON-RPC specification itself"
94)]
95#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
96pub struct Notification<Params> {
97    pub method: Arc<str>,
98    #[serde(skip_serializing_if = "Option::is_none")]
99    pub params: Option<Params>,
100}
101
102#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
103#[serde(untagged)]
104#[schemars(inline)]
105#[allow(
106    clippy::exhaustive_enums,
107    reason = "This comes from the JSON-RPC specification itself"
108)]
109pub enum OutgoingMessage<Local: Side, Remote: Side> {
110    Request(Request<Remote::InRequest>),
111    Response(Response<Local::OutResponse>),
112    Notification(Notification<Remote::InNotification>),
113}
114
115#[derive(Debug, Serialize, Deserialize, JsonSchema)]
116#[schemars(inline)]
117enum JsonRpcVersion {
118    #[serde(rename = "2.0")]
119    V2,
120}
121
122/// A message (request, response, or notification) with `"jsonrpc": "2.0"` specified as
123/// [required by JSON-RPC 2.0 Specification][1].
124///
125/// [1]: https://www.jsonrpc.org/specification#compatibility
126#[derive(Debug, Serialize, Deserialize, JsonSchema)]
127#[schemars(inline)]
128pub struct JsonRpcMessage<M> {
129    jsonrpc: JsonRpcVersion,
130    #[serde(flatten)]
131    message: M,
132}
133
134impl<M> JsonRpcMessage<M> {
135    /// Wraps the provided [`OutgoingMessage`] or [`IncomingMessage`] into a versioned
136    /// [`JsonRpcMessage`].
137    #[must_use]
138    pub fn wrap(message: M) -> Self {
139        Self {
140            jsonrpc: JsonRpcVersion::V2,
141            message,
142        }
143    }
144}
145
146pub trait Side: Clone {
147    type InRequest: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
148    type InNotification: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
149    type OutResponse: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
150
151    /// Decode a request for a given method. This will encapsulate the knowledge of mapping which
152    /// serialization struct to use for each method.
153    ///
154    /// # Errors
155    ///
156    /// This function will return an error if the method is not recognized or if the parameters
157    /// cannot be deserialized into the expected type.
158    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<Self::InRequest>;
159
160    /// Decode a notification for a given method. This will encapsulate the knowledge of mapping which
161    /// serialization struct to use for each method.
162    ///
163    /// # Errors
164    ///
165    /// This function will return an error if the method is not recognized or if the parameters
166    /// cannot be deserialized into the expected type.
167    fn decode_notification(method: &str, params: Option<&RawValue>)
168    -> Result<Self::InNotification>;
169}
170
171/// Marker type representing the client side of an ACP connection.
172///
173/// This type is used by the RPC layer to determine which messages
174/// are incoming vs outgoing from the client's perspective.
175///
176/// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
177#[derive(Clone, Default, Debug, JsonSchema)]
178#[non_exhaustive]
179pub struct ClientSide;
180
181impl Side for ClientSide {
182    type InRequest = AgentRequest;
183    type InNotification = AgentNotification;
184    type OutResponse = ClientResponse;
185
186    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest> {
187        let params = params.ok_or_else(Error::invalid_params)?;
188
189        match method {
190            m if m == CLIENT_METHOD_NAMES.session_request_permission => {
191                serde_json::from_str(params.get())
192                    .map(AgentRequest::RequestPermissionRequest)
193                    .map_err(Into::into)
194            }
195            m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
196                .map(AgentRequest::WriteTextFileRequest)
197                .map_err(Into::into),
198            m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
199                .map(AgentRequest::ReadTextFileRequest)
200                .map_err(Into::into),
201            m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
202                .map(AgentRequest::CreateTerminalRequest)
203                .map_err(Into::into),
204            m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
205                .map(AgentRequest::TerminalOutputRequest)
206                .map_err(Into::into),
207            m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
208                .map(AgentRequest::KillTerminalRequest)
209                .map_err(Into::into),
210            m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
211                .map(AgentRequest::ReleaseTerminalRequest)
212                .map_err(Into::into),
213            m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
214                serde_json::from_str(params.get())
215                    .map(AgentRequest::WaitForTerminalExitRequest)
216                    .map_err(Into::into)
217            }
218            _ => {
219                if let Some(custom_method) = method.strip_prefix('_') {
220                    Ok(AgentRequest::ExtMethodRequest(ExtRequest {
221                        method: custom_method.into(),
222                        params: params.to_owned().into(),
223                    }))
224                } else {
225                    Err(Error::method_not_found())
226                }
227            }
228        }
229    }
230
231    fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<AgentNotification> {
232        let params = params.ok_or_else(Error::invalid_params)?;
233
234        match method {
235            m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
236                .map(AgentNotification::SessionNotification)
237                .map_err(Into::into),
238            _ => {
239                if let Some(custom_method) = method.strip_prefix('_') {
240                    Ok(AgentNotification::ExtNotification(ExtNotification {
241                        method: custom_method.into(),
242                        params: params.to_owned().into(),
243                    }))
244                } else {
245                    Err(Error::method_not_found())
246                }
247            }
248        }
249    }
250}
251
252/// Marker type representing the agent side of an ACP connection.
253///
254/// This type is used by the RPC layer to determine which messages
255/// are incoming vs outgoing from the agent's perspective.
256///
257/// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
258#[derive(Clone, Default, Debug, JsonSchema)]
259#[non_exhaustive]
260pub struct AgentSide;
261
262impl Side for AgentSide {
263    type InRequest = ClientRequest;
264    type InNotification = ClientNotification;
265    type OutResponse = AgentResponse;
266
267    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest> {
268        let params = params.ok_or_else(Error::invalid_params)?;
269
270        match method {
271            m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
272                .map(ClientRequest::InitializeRequest)
273                .map_err(Into::into),
274            m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
275                .map(ClientRequest::AuthenticateRequest)
276                .map_err(Into::into),
277            m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
278                .map(ClientRequest::NewSessionRequest)
279                .map_err(Into::into),
280            m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
281                .map(ClientRequest::LoadSessionRequest)
282                .map_err(Into::into),
283            m if m == AGENT_METHOD_NAMES.session_list => serde_json::from_str(params.get())
284                .map(ClientRequest::ListSessionsRequest)
285                .map_err(Into::into),
286            #[cfg(feature = "unstable_session_fork")]
287            m if m == AGENT_METHOD_NAMES.session_fork => serde_json::from_str(params.get())
288                .map(ClientRequest::ForkSessionRequest)
289                .map_err(Into::into),
290            #[cfg(feature = "unstable_session_resume")]
291            m if m == AGENT_METHOD_NAMES.session_resume => serde_json::from_str(params.get())
292                .map(ClientRequest::ResumeSessionRequest)
293                .map_err(Into::into),
294            #[cfg(feature = "unstable_session_close")]
295            m if m == AGENT_METHOD_NAMES.session_close => serde_json::from_str(params.get())
296                .map(ClientRequest::CloseSessionRequest)
297                .map_err(Into::into),
298            m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
299                .map(ClientRequest::SetSessionModeRequest)
300                .map_err(Into::into),
301            m if m == AGENT_METHOD_NAMES.session_set_config_option => {
302                serde_json::from_str(params.get())
303                    .map(ClientRequest::SetSessionConfigOptionRequest)
304                    .map_err(Into::into)
305            }
306            #[cfg(feature = "unstable_session_model")]
307            m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
308                .map(ClientRequest::SetSessionModelRequest)
309                .map_err(Into::into),
310            m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
311                .map(ClientRequest::PromptRequest)
312                .map_err(Into::into),
313            _ => {
314                if let Some(custom_method) = method.strip_prefix('_') {
315                    Ok(ClientRequest::ExtMethodRequest(ExtRequest {
316                        method: custom_method.into(),
317                        params: params.to_owned().into(),
318                    }))
319                } else {
320                    Err(Error::method_not_found())
321                }
322            }
323        }
324    }
325
326    fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<ClientNotification> {
327        let params = params.ok_or_else(Error::invalid_params)?;
328
329        match method {
330            m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
331                .map(ClientNotification::CancelNotification)
332                .map_err(Into::into),
333            _ => {
334                if let Some(custom_method) = method.strip_prefix('_') {
335                    Ok(ClientNotification::ExtNotification(ExtNotification {
336                        method: custom_method.into(),
337                        params: params.to_owned().into(),
338                    }))
339                } else {
340                    Err(Error::method_not_found())
341                }
342            }
343        }
344    }
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    use serde_json::{Number, Value};

    /// Every JSON shape an id may take deserializes into the matching variant.
    #[test]
    fn id_deserialization() {
        let cases = [
            (Value::Null, RequestId::Null),
            (
                Value::Number(Number::from_u128(1).unwrap()),
                RequestId::Number(1),
            ),
            (
                Value::Number(Number::from_i128(-1).unwrap()),
                RequestId::Number(-1),
            ),
            (
                Value::String("id".to_owned()),
                RequestId::Str("id".to_owned()),
            ),
        ];

        for (json, expected) in cases {
            let id = serde_json::from_value::<RequestId>(json).unwrap();
            assert_eq!(id, expected);
        }
    }

    /// Every variant serializes to a bare JSON value (no enum tag).
    #[test]
    fn id_serialization() {
        let cases = [
            (RequestId::Null, Value::Null),
            (
                RequestId::Number(1),
                Value::Number(Number::from_u128(1).unwrap()),
            ),
            (
                RequestId::Number(-1),
                Value::Number(Number::from_i128(-1).unwrap()),
            ),
            (
                RequestId::Str("id".to_owned()),
                Value::String("id".to_owned()),
            ),
        ];

        for (request_id, expected) in cases {
            let id = serde_json::to_value(request_id).unwrap();
            assert_eq!(id, expected);
        }
    }

    /// `Display` prints `null` for the null id and the inner value otherwise.
    #[test]
    fn id_display() {
        let cases = [
            (RequestId::Null, "null"),
            (RequestId::Number(1), "1"),
            (RequestId::Number(-1), "-1"),
            (RequestId::Str("id".to_owned()), "id"),
        ];

        for (id, expected) in cases {
            assert_eq!(id.to_string(), expected);
        }
    }
}
400
/// Pins the exact JSON produced for an outgoing notification in each direction:
/// the version marker, the method and the params all sit in one flat object, and
/// `None` fields do not appear.
#[test]
fn test_notification_wire_format() {
    use super::*;

    use serde_json::{Value, json};

    // --- client -> agent -------------------------------------------------------------
    let cancel = CancelNotification {
        session_id: SessionId("test-123".into()),
        meta: None,
    };
    let client_to_agent = OutgoingMessage::<ClientSide, AgentSide>::Notification(Notification {
        method: "cancel".into(),
        params: Some(ClientNotification::CancelNotification(cancel)),
    });
    let wrapped = JsonRpcMessage::wrap(client_to_agent);

    let actual: Value = serde_json::to_value(&wrapped).unwrap();
    let expected = json!({
        "jsonrpc": "2.0",
        "method": "cancel",
        "params": {
            "sessionId": "test-123"
        },
    });
    assert_eq!(actual, expected);

    // --- agent -> client -------------------------------------------------------------
    let text = TextContent {
        annotations: None,
        text: "Hello".to_string(),
        meta: None,
    };
    let chunk = ContentChunk {
        content: ContentBlock::Text(text),
        #[cfg(feature = "unstable_message_id")]
        message_id: None,
        meta: None,
    };
    let session_update = SessionNotification {
        session_id: SessionId("test-456".into()),
        update: SessionUpdate::AgentMessageChunk(chunk),
        meta: None,
    };
    let agent_to_client = OutgoingMessage::<AgentSide, ClientSide>::Notification(Notification {
        method: "sessionUpdate".into(),
        params: Some(AgentNotification::SessionNotification(session_update)),
    });
    let wrapped = JsonRpcMessage::wrap(agent_to_client);

    let actual: Value = serde_json::to_value(&wrapped).unwrap();
    let expected = json!({
        "jsonrpc": "2.0",
        "method": "sessionUpdate",
        "params": {
            "sessionId": "test-456",
            "update": {
                "sessionUpdate": "agent_message_chunk",
                "content": {
                    "type": "text",
                    "text": "Hello"
                }
            }
        }
    });
    assert_eq!(actual, expected);
}
471}