// a2a/jsonrpc.rs

// Copyright AGNTCY Contributors (https://github.com/agntcy)
// SPDX-License-Identifier: Apache-2.0
3use serde::{Deserialize, Deserializer, Serialize, Serializer};
4use serde_json::Value;
5
// ---------------------------------------------------------------------------
// JSON-RPC 2.0 types
// ---------------------------------------------------------------------------
9
10/// JSON-RPC 2.0 request.
11#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
12pub struct JsonRpcRequest {
13    pub jsonrpc: String,
14    pub id: JsonRpcId,
15    pub method: String,
16    #[serde(default, skip_serializing_if = "Option::is_none")]
17    pub params: Option<Value>,
18}
19
20impl JsonRpcRequest {
21    pub fn new(id: JsonRpcId, method: impl Into<String>, params: Option<Value>) -> Self {
22        JsonRpcRequest {
23            jsonrpc: "2.0".to_string(),
24            id,
25            method: method.into(),
26            params,
27        }
28    }
29}
30
31/// JSON-RPC 2.0 response.
32#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
33pub struct JsonRpcResponse {
34    pub jsonrpc: String,
35    pub id: JsonRpcId,
36    #[serde(default, skip_serializing_if = "Option::is_none")]
37    pub result: Option<Value>,
38    #[serde(default, skip_serializing_if = "Option::is_none")]
39    pub error: Option<JsonRpcError>,
40}
41
42impl JsonRpcResponse {
43    pub fn success(id: JsonRpcId, result: Value) -> Self {
44        JsonRpcResponse {
45            jsonrpc: "2.0".to_string(),
46            id,
47            result: Some(result),
48            error: None,
49        }
50    }
51
52    pub fn error(id: JsonRpcId, error: JsonRpcError) -> Self {
53        JsonRpcResponse {
54            jsonrpc: "2.0".to_string(),
55            id,
56            result: None,
57            error: Some(error),
58        }
59    }
60}
61
62/// JSON-RPC 2.0 error object.
63#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
64pub struct JsonRpcError {
65    pub code: i32,
66    pub message: String,
67    #[serde(default, skip_serializing_if = "Option::is_none")]
68    pub data: Option<Value>,
69}
70
// ---------------------------------------------------------------------------
// JsonRpcId — preserves string vs number
// ---------------------------------------------------------------------------
74
/// A JSON-RPC ID that can be a string, integer, or null.
///
/// String and numeric ids are separate variants, so `"1"` and `1` never
/// compare equal. The type implements `Eq` and `Hash` in addition to
/// `PartialEq`, which lets ids be used directly as `HashMap`/`HashSet` keys
/// (for example to match responses to in-flight requests).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum JsonRpcId {
    /// An id given as a JSON string.
    String(String),
    /// An id given as a JSON integer that fits in `i64`.
    Number(i64),
    /// An id given as JSON `null`.
    Null,
}
82
83impl Serialize for JsonRpcId {
84    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
85        match self {
86            JsonRpcId::String(s) => serializer.serialize_str(s),
87            JsonRpcId::Number(n) => serializer.serialize_i64(*n),
88            JsonRpcId::Null => serializer.serialize_none(),
89        }
90    }
91}
92
93impl<'de> Deserialize<'de> for JsonRpcId {
94    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
95        let v = Value::deserialize(deserializer)?;
96        match v {
97            Value::String(s) => Ok(JsonRpcId::String(s)),
98            Value::Number(n) => {
99                if let Some(i) = n.as_i64() {
100                    Ok(JsonRpcId::Number(i))
101                } else {
102                    Err(serde::de::Error::custom(
103                        "JSON-RPC id must be a non-fractional number",
104                    ))
105                }
106            }
107            Value::Null => Ok(JsonRpcId::Null),
108            _ => Err(serde::de::Error::custom(
109                "JSON-RPC id must be string, integer, or null",
110            )),
111        }
112    }
113}
114
115impl From<String> for JsonRpcId {
116    fn from(s: String) -> Self {
117        JsonRpcId::String(s)
118    }
119}
120
121impl From<&str> for JsonRpcId {
122    fn from(s: &str) -> Self {
123        JsonRpcId::String(s.to_string())
124    }
125}
126
127impl From<i64> for JsonRpcId {
128    fn from(n: i64) -> Self {
129        JsonRpcId::Number(n)
130    }
131}
132
// ---------------------------------------------------------------------------
// A2A JSON-RPC method names
// ---------------------------------------------------------------------------
136
/// Method-name strings for A2A JSON-RPC requests, plus helpers that classify
/// an incoming method name. All comparisons are exact string matches.
pub mod methods {
    pub const SEND_MESSAGE: &str = "SendMessage";
    pub const SEND_STREAMING_MESSAGE: &str = "SendStreamingMessage";
    pub const GET_TASK: &str = "GetTask";
    pub const LIST_TASKS: &str = "ListTasks";
    pub const CANCEL_TASK: &str = "CancelTask";
    pub const SUBSCRIBE_TO_TASK: &str = "SubscribeToTask";
    pub const CREATE_PUSH_CONFIG: &str = "CreateTaskPushNotificationConfig";
    pub const GET_PUSH_CONFIG: &str = "GetTaskPushNotificationConfig";
    pub const LIST_PUSH_CONFIGS: &str = "ListTaskPushNotificationConfigs";
    pub const DELETE_PUSH_CONFIG: &str = "DeleteTaskPushNotificationConfig";
    pub const GET_EXTENDED_AGENT_CARD: &str = "GetExtendedAgentCard";

    /// The method names that [`is_streaming`] reports as streaming.
    const STREAMING: [&str; 2] = [SEND_STREAMING_MESSAGE, SUBSCRIBE_TO_TASK];

    /// Every method name that [`is_valid`] accepts.
    const KNOWN: [&str; 11] = [
        SEND_MESSAGE,
        SEND_STREAMING_MESSAGE,
        GET_TASK,
        LIST_TASKS,
        CANCEL_TASK,
        SUBSCRIBE_TO_TASK,
        CREATE_PUSH_CONFIG,
        GET_PUSH_CONFIG,
        LIST_PUSH_CONFIGS,
        DELETE_PUSH_CONFIG,
        GET_EXTENDED_AGENT_CARD,
    ];

    /// Returns `true` when `method` is `SEND_STREAMING_MESSAGE` or
    /// `SUBSCRIBE_TO_TASK`, and `false` for every other string.
    pub fn is_streaming(method: &str) -> bool {
        STREAMING.iter().any(|name| *name == method)
    }

    /// Returns `true` when `method` equals one of the method-name constants
    /// declared in this module, and `false` for every other string.
    pub fn is_valid(method: &str) -> bool {
        KNOWN.iter().any(|name| *name == method)
    }
}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// A string id serializes as a JSON string and deserializes back unchanged.
    #[test]
    fn test_jsonrpc_id_string() {
        let original = JsonRpcId::String("abc".into());
        let encoded = serde_json::to_string(&original).unwrap();
        assert_eq!(encoded, r#""abc""#);
        let decoded: JsonRpcId = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, JsonRpcId::String("abc".into()));
    }

    /// A numeric id serializes as a bare JSON number and deserializes back unchanged.
    #[test]
    fn test_jsonrpc_id_number() {
        let original = JsonRpcId::Number(42);
        let encoded = serde_json::to_string(&original).unwrap();
        assert_eq!(encoded, "42");
        let decoded: JsonRpcId = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, JsonRpcId::Number(42));
    }

    /// A request keeps its method and version across a serialize/deserialize cycle.
    #[test]
    fn test_jsonrpc_request_roundtrip() {
        let params = serde_json::json!({"key": "value"});
        let request = JsonRpcRequest::new(
            JsonRpcId::String("req-1".into()),
            methods::SEND_MESSAGE,
            Some(params),
        );
        let encoded = serde_json::to_string(&request).unwrap();
        let decoded: JsonRpcRequest = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.method, methods::SEND_MESSAGE);
        assert_eq!(decoded.jsonrpc, "2.0");
    }

    /// The null id serializes as JSON `null` and deserializes back unchanged.
    #[test]
    fn test_jsonrpc_id_null() {
        let original = JsonRpcId::Null;
        let encoded = serde_json::to_string(&original).unwrap();
        assert_eq!(encoded, "null");
        let decoded: JsonRpcId = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, JsonRpcId::Null);
    }

    /// An array is not an acceptable id.
    #[test]
    fn test_jsonrpc_id_invalid_type() {
        assert!(serde_json::from_str::<JsonRpcId>(r#"[1,2]"#).is_err());
    }

    /// A fractional number is not an acceptable id.
    #[test]
    fn test_jsonrpc_id_fractional_number() {
        assert!(serde_json::from_str::<JsonRpcId>("3.14").is_err());
    }

    /// `From<String>` yields the string variant.
    #[test]
    fn test_jsonrpc_id_from_string() {
        let converted = JsonRpcId::from(String::from("test"));
        assert_eq!(converted, JsonRpcId::String("test".into()));
    }

    /// `From<&str>` yields the string variant.
    #[test]
    fn test_jsonrpc_id_from_str() {
        let converted = JsonRpcId::from("test");
        assert_eq!(converted, JsonRpcId::String("test".into()));
    }

    /// `From<i64>` yields the number variant.
    #[test]
    fn test_jsonrpc_id_from_i64() {
        let converted = JsonRpcId::from(42i64);
        assert_eq!(converted, JsonRpcId::Number(42));
    }

    /// `success` fills `result` only, and the id survives a round trip.
    #[test]
    fn test_jsonrpc_response_success() {
        let payload = serde_json::json!({"status": "ok"});
        let response = JsonRpcResponse::success(JsonRpcId::Number(1), payload);
        assert_eq!(response.jsonrpc, "2.0");
        assert!(response.result.is_some());
        assert!(response.error.is_none());
        let encoded = serde_json::to_string(&response).unwrap();
        let decoded: JsonRpcResponse = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.id, JsonRpcId::Number(1));
    }

    /// `error` fills `error` only and preserves the error code.
    #[test]
    fn test_jsonrpc_response_error() {
        let failure = JsonRpcError {
            code: -32600,
            message: "invalid".into(),
            data: None,
        };
        let response = JsonRpcResponse::error(JsonRpcId::String("e1".into()), failure);
        assert!(response.result.is_none());
        assert!(response.error.is_some());
        assert_eq!(response.error.as_ref().unwrap().code, -32600);
    }

    /// An error's `data` member is kept across a round trip.
    #[test]
    fn test_jsonrpc_error_with_data() {
        let failure = JsonRpcError {
            code: -32000,
            message: "custom".into(),
            data: Some(serde_json::json!({"detail": "info"})),
        };
        let encoded = serde_json::to_string(&failure).unwrap();
        let decoded: JsonRpcError = serde_json::from_str(&encoded).unwrap();
        assert!(decoded.data.is_some());
    }

    /// A request without params omits the `params` member when serialized.
    #[test]
    fn test_jsonrpc_request_no_params() {
        let request = JsonRpcRequest::new(JsonRpcId::Number(1), methods::GET_TASK, None);
        let encoded = serde_json::to_string(&request).unwrap();
        assert!(!encoded.contains("params"));
    }

    /// Only the two streaming method names are classified as streaming.
    #[test]
    fn test_methods_is_streaming() {
        let streaming = [methods::SEND_STREAMING_MESSAGE, methods::SUBSCRIBE_TO_TASK];
        for &name in &streaming {
            assert!(methods::is_streaming(name), "{} should be streaming", name);
        }

        let not_streaming = [
            "message.stream",
            "tasks.resubscribe",
            methods::SEND_MESSAGE,
            methods::GET_TASK,
            "unknown",
        ];
        for &name in &not_streaming {
            assert!(!methods::is_streaming(name), "{} should not be streaming", name);
        }
    }

    /// Every declared method constant is valid; dotted and unknown names are not.
    #[test]
    fn test_methods_is_valid() {
        let accepted = [
            methods::SEND_MESSAGE,
            methods::SEND_STREAMING_MESSAGE,
            methods::GET_TASK,
            methods::LIST_TASKS,
            methods::CANCEL_TASK,
            methods::SUBSCRIBE_TO_TASK,
            methods::CREATE_PUSH_CONFIG,
            methods::GET_PUSH_CONFIG,
            methods::LIST_PUSH_CONFIGS,
            methods::DELETE_PUSH_CONFIG,
            methods::GET_EXTENDED_AGENT_CARD,
        ];
        for &name in &accepted {
            assert!(methods::is_valid(name), "{} should be valid", name);
        }

        let rejected = [
            "message.send",
            "message.stream",
            "tasks.get",
            "tasks.list",
            "tasks.cancel",
            "tasks.resubscribe",
            "push-config.set",
            "push-config.get",
            "push-config.list",
            "push-config.delete",
            "agent-card.extended.get",
            "unknown.method",
        ];
        for &name in &rejected {
            assert!(!methods::is_valid(name), "{} should not be valid", name);
        }
    }
}