1use std::{fmt, marker::PhantomData};
2
3use serde::{
4 de::{self, DeserializeOwned},
5 Deserialize, Serialize,
6};
7use serde_json::Value;
8
9use crate::{
10 error::{Error, ErrorCode},
11 id::Id,
12};
13
14#[derive(Clone, Debug, Eq, PartialEq, Serialize)]
16#[serde(deny_unknown_fields)]
17pub struct Output<T = Value> {
18 pub result: Option<T>,
20 pub error: Option<Error>,
22 pub id: Option<Id>,
29}
30
31impl<T: Serialize> fmt::Display for Output<T> {
32 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33 let json = serde_json::to_string(self).expect("`Output` is serializable");
34 write!(f, "{}", json)
35 }
36}
37
38impl<'de, T: Deserialize<'de>> de::Deserialize<'de> for Output<T> {
39 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
40 where
41 D: de::Deserializer<'de>,
42 {
43 use self::response_field::{Field, FIELDS};
44
45 struct Visitor<'de, T> {
46 marker: PhantomData<Output<T>>,
47 lifetime: PhantomData<&'de ()>,
48 }
49 impl<'de, T: Deserialize<'de>> de::Visitor<'de> for Visitor<'de, T> {
50 type Value = Output<T>;
51
52 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
53 formatter.write_str("struct Output")
54 }
55
56 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
57 where
58 A: de::MapAccess<'de>,
59 {
60 let mut result = Option::<Option<T>>::None;
61 let mut error = Option::<Option<Error>>::None;
62 let mut id = Option::<Option<Id>>::None;
63
64 while let Some(key) = de::MapAccess::next_key::<Field>(&mut map)? {
65 match key {
66 Field::Result => {
67 if result.is_some() {
68 return Err(de::Error::duplicate_field("result"));
69 }
70 result = Some(de::MapAccess::next_value::<Option<T>>(&mut map)?)
71 }
72 Field::Error => {
73 if error.is_some() {
74 return Err(de::Error::duplicate_field("error"));
75 }
76 error = Some(de::MapAccess::next_value::<Option<Error>>(&mut map)?)
77 }
78 Field::Id => {
79 if id.is_some() {
80 return Err(de::Error::duplicate_field("id"));
81 }
82 id = Some(de::MapAccess::next_value::<Option<Id>>(&mut map)?)
83 }
84 }
85 }
86
87 let result = result.ok_or_else(|| de::Error::missing_field("result"))?;
88 let error = error.ok_or_else(|| de::Error::missing_field("error"))?;
89 let id = id.ok_or_else(|| de::Error::missing_field("id"))?;
90 let (result, error, id) = match (result, error, id) {
91 (Some(value), None, Some(id)) => (Some(value), None, Some(id)),
92 (None, Some(error), id) => (None, Some(error), id),
93 _ => return Err(de::Error::custom("Invalid JSON-RPC 1.0 response")),
94 };
95 Ok(Output { result, error, id })
96 }
97 }
98
99 de::Deserializer::deserialize_struct(
100 deserializer,
101 "Output",
102 FIELDS,
103 Visitor {
104 marker: PhantomData::<Output<T>>,
105 lifetime: PhantomData,
106 },
107 )
108 }
109}
110
111impl<T: Serialize + DeserializeOwned> Output<T> {
112 pub fn success(result: T, id: Id) -> Self {
114 Self {
115 result: Some(result),
116 error: None,
117 id: Some(id),
118 }
119 }
120
121 pub fn failure(error: Error, id: Option<Id>) -> Self {
123 Self {
124 result: None,
125 error: Some(error),
126 id,
127 }
128 }
129
130 pub fn invalid_request(id: Option<Id>) -> Self {
132 Output::failure(Error::new(ErrorCode::InvalidRequest), id)
133 }
134}
135
136impl<T: Serialize + DeserializeOwned> From<Output<T>> for Result<T, Error> {
137 fn from(output: Output<T>) -> Result<T, Error> {
140 match (output.result, output.error) {
141 (Some(result), None) => Ok(result),
142 (None, Some(error)) => Err(error),
143 _ => unreachable!("Invalid JSON-RPC 1.0 Response"),
144 }
145 }
146}
147
148#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
150#[serde(deny_unknown_fields)]
151#[serde(untagged)]
152pub enum Response<T = Value> {
153 Single(Output<T>),
155 Batch(Vec<Output<T>>),
157}
158
159impl<T: Serialize> fmt::Display for Response<T> {
160 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
161 let json = serde_json::to_string(self).expect("`Response` is serializable");
162 write!(f, "{}", json)
163 }
164}
165
166mod response_field {
167 use super::*;
168
169 pub const FIELDS: &[&str] = &["result", "error", "id"];
170 pub enum Field {
171 Result,
172 Error,
173 Id,
174 }
175
176 impl<'de> de::Deserialize<'de> for Field {
177 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
178 where
179 D: de::Deserializer<'de>,
180 {
181 de::Deserializer::deserialize_identifier(deserializer, FieldVisitor)
182 }
183 }
184
185 struct FieldVisitor;
186 impl<'de> de::Visitor<'de> for FieldVisitor {
187 type Value = Field;
188
189 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
190 formatter.write_str("field identifier")
191 }
192
193 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
194 where
195 E: de::Error,
196 {
197 match v {
198 "result" => Ok(Field::Result),
199 "error" => Ok(Field::Error),
200 "id" => Ok(Field::Id),
201 _ => Err(de::Error::unknown_field(v, &FIELDS)),
202 }
203 }
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Pairs of an `Output` value and its exact JSON encoding, shared by the
    /// round-trip tests below.
    fn response_output_cases() -> Vec<(Output, &'static str)> {
        vec![
            (
                // Success: result set, error null, id set.
                Output {
                    result: Some(Value::Bool(true)),
                    error: None,
                    id: Some(Id::Num(1)),
                },
                r#"{"result":true,"error":null,"id":1}"#,
            ),
            (
                // Failure with an id.
                Output {
                    result: None,
                    error: Some(Error::parse_error()),
                    id: Some(Id::Num(1)),
                },
                r#"{"result":null,"error":{"code":-32700,"message":"Parse error"},"id":1}"#,
            ),
            (
                // Failure without an id.
                Output {
                    result: None,
                    error: Some(Error::parse_error()),
                    id: None,
                },
                r#"{"result":null,"error":{"code":-32700,"message":"Parse error"},"id":null}"#,
            ),
        ]
    }

    #[test]
    fn response_output_serialization() {
        for (output, expect) in response_output_cases() {
            let ser = serde_json::to_string(&output).unwrap();
            assert_eq!(ser, expect);
            let de = serde_json::from_str::<Output>(expect).unwrap();
            assert_eq!(de, output);
        }
    }

    #[test]
    fn response_serialization() {
        for (output, expect) in response_output_cases() {
            let response = Response::Single(output);
            assert_eq!(serde_json::to_string(&response).unwrap(), expect);
            assert_eq!(serde_json::from_str::<Response>(expect).unwrap(), response);
        }

        let batch_response = Response::Batch(vec![
            Output {
                result: Some(Value::Bool(true)),
                error: None,
                id: Some(Id::Num(1)),
            },
            Output {
                result: Some(Value::Bool(false)),
                error: None,
                id: Some(Id::Num(2)),
            },
        ]);
        let batch_expect =
            r#"[{"result":true,"error":null,"id":1},{"result":false,"error":null,"id":2}]"#;
        assert_eq!(serde_json::to_string(&batch_response).unwrap(), batch_expect);
        assert_eq!(
            serde_json::from_str::<Response>(batch_expect).unwrap(),
            batch_response
        );
    }

    #[test]
    fn invalid_response() {
        let cases = vec![
            // Well-formed JSON with an extra member. The key must be quoted,
            // otherwise the case fails as a syntax error and never reaches the
            // unknown-field check it is meant to exercise.
            r#"{"result":true,"error":null,"id":1,"unknown":[]}"#,
            // Both `result` and `error` are set.
            r#"{"result":true,"error":{"code": -32700,"message": "Parse error"},"id":1}"#,
            // Missing `id`.
            r#"{"result":true,"error":{"code": -32700,"message": "Parse error"}}"#,
            // Missing `error`.
            r#"{"result":true,"id":1}"#,
            // Missing `result`.
            r#"{"error":{"code": -32700,"message": "Parse error"},"id":1}"#,
            // Only an unknown member.
            r#"{"unknown":[]}"#,
        ];

        for case in cases {
            let response = serde_json::from_str::<Response>(case);
            assert!(response.is_err(), "expected rejection of: {}", case);
        }
    }

    #[test]
    fn valid_response() {
        let cases = vec![
            // Success.
            r#"{"result":true,"error":null,"id":1}"#,
            // Failure.
            r#"{"result":null,"error":{"code": -32700,"message": "Parse error"},"id":1}"#,
        ];

        for case in cases {
            let response = serde_json::from_str::<Response>(case);
            assert!(response.is_ok(), "expected acceptance of: {}", case);
        }
    }
}