// zeph_a2a/jsonrpc.rs

1use serde::{Deserialize, Serialize, de::DeserializeOwned};
2
3use crate::types::Message;
4
/// JSON-RPC `method` name for sending a message (`message/send`).
pub const METHOD_SEND_MESSAGE: &str = "message/send";
/// JSON-RPC `method` name for sending a message with a streamed reply (`message/stream`).
pub const METHOD_SEND_STREAMING_MESSAGE: &str = "message/stream";
/// JSON-RPC `method` name for fetching a task (`tasks/get`).
pub const METHOD_GET_TASK: &str = "tasks/get";
/// JSON-RPC `method` name for cancelling a task (`tasks/cancel`).
pub const METHOD_CANCEL_TASK: &str = "tasks/cancel";

// Application-defined error codes. JSON-RPC 2.0 reserves -32000..=-32099 for
// implementation-defined server errors, which is where these live.

/// Error code signalling that the referenced task was not found.
pub const ERR_TASK_NOT_FOUND: i32 = -32001;
/// Error code signalling that the referenced task cannot be cancelled.
pub const ERR_TASK_NOT_CANCELABLE: i32 = -32002;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct JsonRpcRequest<P> {
15    pub jsonrpc: String,
16    pub id: serde_json::Value,
17    pub method: String,
18    pub params: P,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
22#[serde(bound(deserialize = "R: Deserialize<'de>"))]
23pub struct JsonRpcResponse<R> {
24    pub jsonrpc: String,
25    pub id: serde_json::Value,
26    #[serde(default, skip_serializing_if = "Option::is_none")]
27    pub result: Option<R>,
28    #[serde(default, skip_serializing_if = "Option::is_none")]
29    pub error: Option<JsonRpcError>,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct JsonRpcError {
34    pub code: i32,
35    pub message: String,
36    #[serde(default, skip_serializing_if = "Option::is_none")]
37    pub data: Option<serde_json::Value>,
38}
39
40impl std::fmt::Display for JsonRpcError {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        write!(f, "JSON-RPC error {}: {}", self.code, self.message)
43    }
44}
45
46impl std::error::Error for JsonRpcError {}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
49#[serde(rename_all = "camelCase")]
50pub struct SendMessageParams {
51    pub message: Message,
52    #[serde(default, skip_serializing_if = "Option::is_none")]
53    pub configuration: Option<TaskConfiguration>,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
57#[serde(rename_all = "camelCase")]
58pub struct TaskConfiguration {
59    #[serde(default, skip_serializing_if = "Option::is_none")]
60    pub blocking: Option<bool>,
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
64#[serde(rename_all = "camelCase")]
65pub struct TaskIdParams {
66    pub id: String,
67    #[serde(default, skip_serializing_if = "Option::is_none")]
68    pub history_length: Option<u32>,
69}
70
71impl<P: Serialize> JsonRpcRequest<P> {
72    #[must_use]
73    pub fn new(method: &str, params: P) -> Self {
74        Self {
75            jsonrpc: "2.0".into(),
76            id: serde_json::Value::String(uuid::Uuid::new_v4().to_string()),
77            method: method.into(),
78            params,
79        }
80    }
81}
82
83impl<R: DeserializeOwned> JsonRpcResponse<R> {
84    /// # Errors
85    /// Returns `JsonRpcError` if the response contains an error or no result.
86    pub fn into_result(self) -> Result<R, JsonRpcError> {
87        if let Some(err) = self.error {
88            return Err(err);
89        }
90        self.result.ok_or_else(|| JsonRpcError {
91            code: -32603,
92            message: "response contains neither result nor error".into(),
93            data: None,
94        })
95    }
96}
97
// Unit tests for the JSON-RPC envelope types, `into_result`, and their serde representation.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn request_new_sets_jsonrpc_and_uuid_id() {
        let req = JsonRpcRequest::new(
            METHOD_SEND_MESSAGE,
            TaskIdParams {
                id: "task-1".into(),
                history_length: None,
            },
        );
        assert_eq!(req.jsonrpc, "2.0");
        assert_eq!(req.method, "message/send");
        let id_str = req.id.as_str().unwrap();
        assert!(uuid::Uuid::parse_str(id_str).is_ok());
    }

    #[test]
    fn request_new_generates_distinct_ids() {
        let first = JsonRpcRequest::new(
            METHOD_CANCEL_TASK,
            TaskIdParams {
                id: "t".into(),
                history_length: None,
            },
        );
        let second = JsonRpcRequest::new(
            METHOD_CANCEL_TASK,
            TaskIdParams {
                id: "t".into(),
                history_length: None,
            },
        );
        assert_ne!(first.id, second.id);
    }

    #[test]
    fn method_constants_match_wire_names() {
        assert_eq!(METHOD_SEND_STREAMING_MESSAGE, "message/stream");
        assert_eq!(METHOD_CANCEL_TASK, "tasks/cancel");
    }

    #[test]
    fn request_serde_round_trip() {
        let req = JsonRpcRequest::new(
            METHOD_GET_TASK,
            TaskIdParams {
                id: "t-1".into(),
                history_length: Some(10),
            },
        );
        let json = serde_json::to_string(&req).unwrap();
        let back: JsonRpcRequest<TaskIdParams> = serde_json::from_str(&json).unwrap();
        assert_eq!(back.method, METHOD_GET_TASK);
        assert_eq!(back.params.id, "t-1");
        assert_eq!(back.params.history_length, Some(10));
    }

    #[test]
    fn response_into_result_ok() {
        let resp = JsonRpcResponse {
            jsonrpc: "2.0".into(),
            id: serde_json::Value::String("1".into()),
            result: Some(serde_json::json!({"id": "task-1"})),
            error: None,
        };
        let val: serde_json::Value = resp.into_result().unwrap();
        assert_eq!(val["id"], "task-1");
    }

    #[test]
    fn response_into_result_error() {
        let resp: JsonRpcResponse<serde_json::Value> = JsonRpcResponse {
            jsonrpc: "2.0".into(),
            id: serde_json::Value::String("1".into()),
            result: None,
            error: Some(JsonRpcError {
                code: ERR_TASK_NOT_FOUND,
                message: "task not found".into(),
                data: None,
            }),
        };
        let err = resp.into_result().unwrap_err();
        assert_eq!(err.code, ERR_TASK_NOT_FOUND);
    }

    #[test]
    fn response_into_result_error_takes_precedence_over_result() {
        let resp = JsonRpcResponse {
            jsonrpc: "2.0".into(),
            id: serde_json::json!(1),
            result: Some(serde_json::json!({"id": "task-1"})),
            error: Some(JsonRpcError {
                code: ERR_TASK_NOT_CANCELABLE,
                message: "task not cancelable".into(),
                data: None,
            }),
        };
        let err = resp.into_result().unwrap_err();
        assert_eq!(err.code, ERR_TASK_NOT_CANCELABLE);
    }

    #[test]
    fn response_into_result_neither() {
        let resp: JsonRpcResponse<serde_json::Value> = JsonRpcResponse {
            jsonrpc: "2.0".into(),
            id: serde_json::Value::String("1".into()),
            result: None,
            error: None,
        };
        let err = resp.into_result().unwrap_err();
        assert_eq!(err.code, -32603);
    }

    #[test]
    fn response_deserializes_error_without_result_field() {
        let json = r#"{"jsonrpc":"2.0","id":"1","error":{"code":-32001,"message":"task not found"}}"#;
        let resp: JsonRpcResponse<serde_json::Value> = serde_json::from_str(json).unwrap();
        assert!(resp.result.is_none());
        let err = resp.into_result().unwrap_err();
        assert_eq!(err.code, ERR_TASK_NOT_FOUND);
        assert!(err.data.is_none());
    }

    #[test]
    fn response_serialization_skips_absent_fields() {
        let resp: JsonRpcResponse<serde_json::Value> = JsonRpcResponse {
            jsonrpc: "2.0".into(),
            id: serde_json::Value::String("1".into()),
            result: Some(serde_json::json!(true)),
            error: None,
        };
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"result\""));
        assert!(!json.contains("\"error\""));
    }

    #[test]
    fn send_message_params_serde() {
        let params = SendMessageParams {
            message: Message::user_text("hello"),
            configuration: Some(TaskConfiguration {
                blocking: Some(true),
            }),
        };
        let json = serde_json::to_string(&params).unwrap();
        let back: SendMessageParams = serde_json::from_str(&json).unwrap();
        assert_eq!(back.message.text_content(), Some("hello"));
        assert_eq!(back.configuration.unwrap().blocking, Some(true));
    }

    #[test]
    fn send_message_params_skips_none_configuration() {
        let params = SendMessageParams {
            message: Message::user_text("hello"),
            configuration: None,
        };
        let json = serde_json::to_string(&params).unwrap();
        assert!(!json.contains("\"configuration\""));
    }

    #[test]
    fn task_id_params_skips_none() {
        let params = TaskIdParams {
            id: "t-1".into(),
            history_length: None,
        };
        let json = serde_json::to_string(&params).unwrap();
        assert!(!json.contains("historyLength"));
    }

    #[test]
    fn jsonrpc_error_skips_none_data() {
        let err = JsonRpcError {
            code: ERR_TASK_NOT_FOUND,
            message: "task not found".into(),
            data: None,
        };
        let json = serde_json::to_string(&err).unwrap();
        assert!(!json.contains("\"data\""));
    }

    #[test]
    fn jsonrpc_error_display() {
        let err = JsonRpcError {
            code: -32001,
            message: "not found".into(),
            data: None,
        };
        assert_eq!(err.to_string(), "JSON-RPC error -32001: not found");
    }
}