agent_client_protocol_schema/
rpc.rs

1use std::sync::Arc;
2
3use derive_more::Display;
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize, de::DeserializeOwned};
6use serde_json::value::RawValue;
7
8use crate::{
9    AGENT_METHOD_NAMES, AgentNotification, AgentRequest, AgentResponse, CLIENT_METHOD_NAMES,
10    ClientNotification, ClientRequest, ClientResponse, Error, ExtNotification, ExtRequest, Result,
11};
12
13/// JSON RPC Request Id
14///
15/// An identifier established by the Client that MUST contain a String, Number, or NULL value if included. If it is not included it is assumed to be a notification. The value SHOULD normally not be Null [1] and Numbers SHOULD NOT contain fractional parts [2]
16///
17/// The Server MUST reply with the same value in the Response object if included. This member is used to correlate the context between the two objects.
18///
19/// [1] The use of Null as a value for the id member in a Request object is discouraged, because this specification uses a value of Null for Responses with an unknown id. Also, because JSON-RPC 1.0 uses an id value of Null for Notifications this could cause confusion in handling.
20///
21/// [2] Fractional parts may be problematic, since many decimal fractions cannot be represented exactly as binary fractions.
22#[derive(
23    Debug, PartialEq, Clone, Hash, Eq, Deserialize, Serialize, PartialOrd, Ord, Display, JsonSchema,
24)]
25#[serde(untagged)]
26#[allow(
27    clippy::exhaustive_enums,
28    reason = "This comes from the JSON-RPC specification itself"
29)]
30pub enum RequestId {
31    #[display("null")]
32    Null,
33    Number(i64),
34    Str(String),
35}
36
37#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
38#[allow(
39    clippy::exhaustive_structs,
40    reason = "This comes from the JSON-RPC specification itself"
41)]
42#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
43pub struct Request<Params> {
44    pub id: RequestId,
45    pub method: Arc<str>,
46    #[serde(skip_serializing_if = "Option::is_none")]
47    pub params: Option<Params>,
48}
49
50#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
51#[allow(
52    clippy::exhaustive_enums,
53    reason = "This comes from the JSON-RPC specification itself"
54)]
55#[serde(untagged)]
56#[schemars(rename = "{Result}", extend("x-docs-ignore" = true))]
57pub enum Response<Result> {
58    Result { id: RequestId, result: Result },
59    Error { id: RequestId, error: Error },
60}
61
62impl<R> Response<R> {
63    pub fn new(id: RequestId, result: Result<R>) -> Self {
64        match result {
65            Ok(result) => Self::Result { id, result },
66            Err(error) => Self::Error { id, error },
67        }
68    }
69}
70
71#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
72#[allow(
73    clippy::exhaustive_structs,
74    reason = "This comes from the JSON-RPC specification itself"
75)]
76#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
77pub struct Notification<Params> {
78    pub method: Arc<str>,
79    #[serde(skip_serializing_if = "Option::is_none")]
80    pub params: Option<Params>,
81}
82
83#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
84#[serde(untagged)]
85#[schemars(inline)]
86#[allow(
87    clippy::exhaustive_enums,
88    reason = "This comes from the JSON-RPC specification itself"
89)]
90pub enum OutgoingMessage<Local: Side, Remote: Side> {
91    Request(Request<Remote::InRequest>),
92    Response(Response<Local::OutResponse>),
93    Notification(Notification<Remote::InNotification>),
94}
95
96#[derive(Debug, Serialize, Deserialize, JsonSchema)]
97#[schemars(inline)]
98enum JsonRpcVersion {
99    #[serde(rename = "2.0")]
100    V2,
101}
102
103/// A message (request, response, or notification) with `"jsonrpc": "2.0"` specified as
104/// [required by JSON-RPC 2.0 Specification][1].
105///
106/// [1]: https://www.jsonrpc.org/specification#compatibility
107#[derive(Debug, Serialize, Deserialize, JsonSchema)]
108#[schemars(inline)]
109pub struct JsonRpcMessage<M> {
110    jsonrpc: JsonRpcVersion,
111    #[serde(flatten)]
112    message: M,
113}
114
115impl<M> JsonRpcMessage<M> {
116    /// Wraps the provided [`OutgoingMessage`] or [`IncomingMessage`] into a versioned
117    /// [`JsonRpcMessage`].
118    #[must_use]
119    pub fn wrap(message: M) -> Self {
120        Self {
121            jsonrpc: JsonRpcVersion::V2,
122            message,
123        }
124    }
125}
126
127pub trait Side: Clone {
128    type InRequest: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
129    type InNotification: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
130    type OutResponse: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
131
132    /// Decode a request for a given method. This will encapsulate the knowledge of mapping which
133    /// serialization struct to use for each method.
134    ///
135    /// # Errors
136    ///
137    /// This function will return an error if the method is not recognized or if the parameters
138    /// cannot be deserialized into the expected type.
139    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<Self::InRequest>;
140
141    /// Decode a notification for a given method. This will encapsulate the knowledge of mapping which
142    /// serialization struct to use for each method.
143    ///
144    /// # Errors
145    ///
146    /// This function will return an error if the method is not recognized or if the parameters
147    /// cannot be deserialized into the expected type.
148    fn decode_notification(method: &str, params: Option<&RawValue>)
149    -> Result<Self::InNotification>;
150}
151
152/// Marker type representing the client side of an ACP connection.
153///
154/// This type is used by the RPC layer to determine which messages
155/// are incoming vs outgoing from the client's perspective.
156///
157/// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
158#[derive(Clone, Default, Debug, JsonSchema)]
159#[non_exhaustive]
160pub struct ClientSide;
161
162impl Side for ClientSide {
163    type InRequest = AgentRequest;
164    type InNotification = AgentNotification;
165    type OutResponse = ClientResponse;
166
167    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest> {
168        let params = params.ok_or_else(Error::invalid_params)?;
169
170        match method {
171            m if m == CLIENT_METHOD_NAMES.session_request_permission => {
172                serde_json::from_str(params.get())
173                    .map(AgentRequest::RequestPermissionRequest)
174                    .map_err(Into::into)
175            }
176            m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
177                .map(AgentRequest::WriteTextFileRequest)
178                .map_err(Into::into),
179            m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
180                .map(AgentRequest::ReadTextFileRequest)
181                .map_err(Into::into),
182            m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
183                .map(AgentRequest::CreateTerminalRequest)
184                .map_err(Into::into),
185            m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
186                .map(AgentRequest::TerminalOutputRequest)
187                .map_err(Into::into),
188            m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
189                .map(AgentRequest::KillTerminalCommandRequest)
190                .map_err(Into::into),
191            m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
192                .map(AgentRequest::ReleaseTerminalRequest)
193                .map_err(Into::into),
194            m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
195                serde_json::from_str(params.get())
196                    .map(AgentRequest::WaitForTerminalExitRequest)
197                    .map_err(Into::into)
198            }
199            _ => {
200                if let Some(custom_method) = method.strip_prefix('_') {
201                    Ok(AgentRequest::ExtMethodRequest(ExtRequest {
202                        method: custom_method.into(),
203                        params: params.to_owned().into(),
204                    }))
205                } else {
206                    Err(Error::method_not_found())
207                }
208            }
209        }
210    }
211
212    fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<AgentNotification> {
213        let params = params.ok_or_else(Error::invalid_params)?;
214
215        match method {
216            m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
217                .map(AgentNotification::SessionNotification)
218                .map_err(Into::into),
219            _ => {
220                if let Some(custom_method) = method.strip_prefix('_') {
221                    Ok(AgentNotification::ExtNotification(ExtNotification {
222                        method: custom_method.into(),
223                        params: RawValue::from_string(params.get().to_string())?.into(),
224                    }))
225                } else {
226                    Err(Error::method_not_found())
227                }
228            }
229        }
230    }
231}
232
233/// Marker type representing the agent side of an ACP connection.
234///
235/// This type is used by the RPC layer to determine which messages
236/// are incoming vs outgoing from the agent's perspective.
237///
238/// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model)
239#[derive(Clone, Default, Debug, JsonSchema)]
240#[non_exhaustive]
241pub struct AgentSide;
242
243impl Side for AgentSide {
244    type InRequest = ClientRequest;
245    type InNotification = ClientNotification;
246    type OutResponse = AgentResponse;
247
248    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest> {
249        let params = params.ok_or_else(Error::invalid_params)?;
250
251        match method {
252            m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
253                .map(ClientRequest::InitializeRequest)
254                .map_err(Into::into),
255            m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
256                .map(ClientRequest::AuthenticateRequest)
257                .map_err(Into::into),
258            m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
259                .map(ClientRequest::NewSessionRequest)
260                .map_err(Into::into),
261            m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
262                .map(ClientRequest::LoadSessionRequest)
263                .map_err(Into::into),
264            #[cfg(feature = "unstable_session_list")]
265            m if m == AGENT_METHOD_NAMES.session_list => serde_json::from_str(params.get())
266                .map(ClientRequest::ListSessionsRequest)
267                .map_err(Into::into),
268            m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
269                .map(ClientRequest::SetSessionModeRequest)
270                .map_err(Into::into),
271            #[cfg(feature = "unstable_session_model")]
272            m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
273                .map(ClientRequest::SetSessionModelRequest)
274                .map_err(Into::into),
275            m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
276                .map(ClientRequest::PromptRequest)
277                .map_err(Into::into),
278            _ => {
279                if let Some(custom_method) = method.strip_prefix('_') {
280                    Ok(ClientRequest::ExtMethodRequest(ExtRequest {
281                        method: custom_method.into(),
282                        params: params.to_owned().into(),
283                    }))
284                } else {
285                    Err(Error::method_not_found())
286                }
287            }
288        }
289    }
290
291    fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<ClientNotification> {
292        let params = params.ok_or_else(Error::invalid_params)?;
293
294        match method {
295            m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
296                .map(ClientNotification::CancelNotification)
297                .map_err(Into::into),
298            _ => {
299                if let Some(custom_method) = method.strip_prefix('_') {
300                    Ok(ClientNotification::ExtNotification(ExtNotification {
301                        method: custom_method.into(),
302                        params: RawValue::from_string(params.get().to_string())?.into(),
303                    }))
304                } else {
305                    Err(Error::method_not_found())
306                }
307            }
308        }
309    }
310}
311
#[cfg(test)]
mod tests {
    use super::*;

    use serde_json::{Number, Value, json};

    #[test]
    fn id_deserialization() {
        let id = serde_json::from_value::<RequestId>(Value::Null).unwrap();
        assert_eq!(id, RequestId::Null);

        let id = serde_json::from_value::<RequestId>(Value::Number(Number::from_u128(1).unwrap()))
            .unwrap();
        assert_eq!(id, RequestId::Number(1));

        let id = serde_json::from_value::<RequestId>(Value::Number(Number::from_i128(-1).unwrap()))
            .unwrap();
        assert_eq!(id, RequestId::Number(-1));

        let id = serde_json::from_value::<RequestId>(Value::String("id".to_owned())).unwrap();
        assert_eq!(id, RequestId::Str("id".to_owned()));
    }

    #[test]
    fn id_serialization() {
        let id = serde_json::to_value(RequestId::Null).unwrap();
        assert_eq!(id, Value::Null);

        let id = serde_json::to_value(RequestId::Number(1)).unwrap();
        assert_eq!(id, Value::Number(Number::from_u128(1).unwrap()));

        let id = serde_json::to_value(RequestId::Number(-1)).unwrap();
        assert_eq!(id, Value::Number(Number::from_i128(-1).unwrap()));

        let id = serde_json::to_value(RequestId::Str("id".to_owned())).unwrap();
        assert_eq!(id, Value::String("id".to_owned()));
    }

    #[test]
    fn id_display() {
        let id = RequestId::Null;
        assert_eq!(id.to_string(), "null");

        let id = RequestId::Number(1);
        assert_eq!(id.to_string(), "1");

        let id = RequestId::Number(-1);
        assert_eq!(id.to_string(), "-1");

        let id = RequestId::Str("id".to_owned());
        assert_eq!(id.to_string(), "id");
    }

    /// Outgoing notifications serialize with `jsonrpc` flattened next to the message members.
    #[test]
    fn notification_wire_format() {
        use crate::{
            CancelNotification, ContentBlock, ContentChunk, SessionId, SessionNotification,
            SessionUpdate, TextContent,
        };

        // client -> agent notification
        let outgoing_msg = JsonRpcMessage::wrap(
            OutgoingMessage::<ClientSide, AgentSide>::Notification(Notification {
                method: "cancel".into(),
                params: Some(ClientNotification::CancelNotification(CancelNotification {
                    session_id: SessionId("test-123".into()),
                    meta: None,
                })),
            }),
        );

        let serialized: Value = serde_json::to_value(&outgoing_msg).unwrap();
        assert_eq!(
            serialized,
            json!({
                "jsonrpc": "2.0",
                "method": "cancel",
                "params": {
                    "sessionId": "test-123"
                },
            })
        );

        // agent -> client notification
        let outgoing_msg = JsonRpcMessage::wrap(
            OutgoingMessage::<AgentSide, ClientSide>::Notification(Notification {
                method: "sessionUpdate".into(),
                params: Some(AgentNotification::SessionNotification(
                    SessionNotification {
                        session_id: SessionId("test-456".into()),
                        update: SessionUpdate::AgentMessageChunk(ContentChunk {
                            content: ContentBlock::Text(TextContent {
                                annotations: None,
                                text: "Hello".to_string(),
                                meta: None,
                            }),
                            meta: None,
                        }),
                        meta: None,
                    },
                )),
            }),
        );

        let serialized: Value = serde_json::to_value(&outgoing_msg).unwrap();
        assert_eq!(
            serialized,
            json!({
                "jsonrpc": "2.0",
                "method": "sessionUpdate",
                "params": {
                    "sessionId": "test-456",
                    "update": {
                        "sessionUpdate": "agent_message_chunk",
                        "content": {
                            "type": "text",
                            "text": "Hello"
                        }
                    }
                }
            })
        );
    }

    /// The envelope only accepts `"jsonrpc": "2.0"` and requires the member to be present.
    #[test]
    fn envelope_requires_jsonrpc_2_0() {
        let accepted = serde_json::from_value::<JsonRpcMessage<Notification<Value>>>(
            json!({ "jsonrpc": "2.0", "method": "ping" }),
        );
        assert!(accepted.is_ok());

        let wrong_version = serde_json::from_value::<JsonRpcMessage<Notification<Value>>>(
            json!({ "jsonrpc": "1.0", "method": "ping" }),
        );
        assert!(wrong_version.is_err());

        let missing_version = serde_json::from_value::<JsonRpcMessage<Notification<Value>>>(
            json!({ "method": "ping" }),
        );
        assert!(missing_version.is_err());
    }

    /// `_`-prefixed methods decode as extensions: prefix stripped, params kept verbatim.
    #[test]
    fn extension_messages_keep_raw_params() {
        const PARAMS: &str = r#"{"answer":42}"#;
        let raw = RawValue::from_string(PARAMS.to_owned()).unwrap();

        let ClientNotification::ExtNotification(notification) =
            AgentSide::decode_notification("_example/ping", Some(&*raw)).unwrap()
        else {
            panic!("expected an extension notification");
        };
        assert_eq!(&*notification.method, "example/ping");
        assert_eq!(notification.params.get(), PARAMS);

        let AgentNotification::ExtNotification(notification) =
            ClientSide::decode_notification("_example/ping", Some(&*raw)).unwrap()
        else {
            panic!("expected an extension notification");
        };
        assert_eq!(&*notification.method, "example/ping");
        assert_eq!(notification.params.get(), PARAMS);

        let AgentRequest::ExtMethodRequest(request) =
            ClientSide::decode_request("_example/call", Some(&*raw)).unwrap()
        else {
            panic!("expected an extension request");
        };
        assert_eq!(&*request.method, "example/call");
        assert_eq!(request.params.get(), PARAMS);
    }

    /// Missing params and unknown non-extension methods are rejected.
    #[test]
    fn decoding_rejects_missing_params_and_unknown_methods() {
        let raw = RawValue::from_string("{}".to_owned()).unwrap();

        assert!(AgentSide::decode_notification("_example/ping", None).is_err());
        assert!(ClientSide::decode_request("_example/call", None).is_err());
        assert!(AgentSide::decode_request("not/a/method", Some(&*raw)).is_err());
        assert!(ClientSide::decode_notification("not/a/method", Some(&*raw)).is_err());
    }
}