1use crate::{common::Id, RpcSend};
2use serde::{
3    de::{DeserializeOwned, MapAccess, Visitor},
4    ser::SerializeMap,
5    Deserialize, Deserializer, Serialize,
6};
7use serde_json::value::RawValue;
8use std::{
9    borrow::{Borrow, Cow},
10    fmt,
11    marker::PhantomData,
12};
13
14mod error;
15pub use error::{BorrowedErrorPayload, ErrorPayload};
16
17mod payload;
18pub use payload::{BorrowedResponsePayload, ResponsePayload};
19
/// A JSON-RPC 2.0 response object.
///
/// `Payload` is the type of a successful `result`, and `ErrData` is the type of the
/// data attached to an error (see `internal_error_with_obj`). Both default to
/// [`Box<RawValue>`], so the raw JSON can be kept and decoded later.
#[derive(Clone, Debug)]
pub struct Response<Payload = Box<RawValue>, ErrData = Box<RawValue>> {
    /// The response id. Deserializes to [`Id::None`] when the `id` key is missing or `null`.
    pub id: Id,
    /// Either the success value or the error object; never both.
    pub payload: ResponsePayload<Payload, ErrData>,
}
33
34pub type BorrowedResponse<'a> = Response<&'a RawValue, &'a RawValue>;
42
43impl BorrowedResponse<'_> {
44    pub fn into_owned(self) -> Response {
47        Response { id: self.id.clone(), payload: self.payload.into_owned() }
48    }
49}
50
51impl<Payload, ErrData> Response<Payload, ErrData> {
52    pub const fn parse_error(id: Id) -> Self {
54        Self { id, payload: ResponsePayload::parse_error() }
55    }
56
57    pub const fn invalid_request(id: Id) -> Self {
59        Self { id, payload: ResponsePayload::invalid_request() }
60    }
61
62    pub const fn method_not_found(id: Id) -> Self {
64        Self { id, payload: ResponsePayload::method_not_found() }
65    }
66
67    pub const fn invalid_params(id: Id) -> Self {
69        Self { id, payload: ResponsePayload::invalid_params() }
70    }
71
72    pub const fn internal_error(id: Id) -> Self {
74        Self { id, payload: ResponsePayload::internal_error() }
75    }
76
77    pub const fn internal_error_message(id: Id, message: Cow<'static, str>) -> Self {
79        Self {
80            id,
81            payload: ResponsePayload::Failure(ErrorPayload::internal_error_message(message)),
82        }
83    }
84
85    pub const fn payload(&self) -> &ResponsePayload<Payload, ErrData> {
87        &self.payload
88    }
89
90    pub const fn internal_error_with_obj(id: Id, data: ErrData) -> Self
92    where
93        ErrData: RpcSend,
94    {
95        Self { id, payload: ResponsePayload::Failure(ErrorPayload::internal_error_with_obj(data)) }
96    }
97
98    pub const fn internal_error_with_message_and_obj(
101        id: Id,
102        message: Cow<'static, str>,
103        data: ErrData,
104    ) -> Self
105    where
106        ErrData: RpcSend,
107    {
108        Self {
109            id,
110            payload: ResponsePayload::Failure(ErrorPayload::internal_error_with_message_and_obj(
111                message, data,
112            )),
113        }
114    }
115
116    pub const fn is_success(&self) -> bool {
118        self.payload.is_success()
119    }
120
121    pub const fn is_error(&self) -> bool {
123        self.payload.is_error()
124    }
125
126    pub fn error_code(&self) -> Option<i64> {
128        self.payload().error_code()
129    }
130}
131
132impl<Payload, ErrData> Response<Payload, ErrData>
133where
134    Payload: RpcSend,
135    ErrData: RpcSend,
136{
137    pub fn serialize_payload(&self) -> serde_json::Result<Response> {
139        self.payload.serialize_payload().map(|payload| Response { id: self.id.clone(), payload })
140    }
141}
142
143impl<'a, Payload, ErrData> Response<Payload, ErrData>
144where
145    Payload: AsRef<RawValue> + 'a,
146{
147    pub fn try_success_as<T: Deserialize<'a>>(&'a self) -> Option<serde_json::Result<T>> {
152        self.payload.try_success_as()
153    }
154
155    pub fn deser_success<T: DeserializeOwned>(self) -> Result<Response<T, ErrData>, Self> {
163        match self.payload.deserialize_success() {
164            Ok(payload) => Ok(Response { id: self.id, payload }),
165            Err(payload) => Err(Self { id: self.id, payload }),
166        }
167    }
168}
169
170impl<'a, Payload, ErrData> Response<Payload, ErrData>
171where
172    ErrData: Borrow<RawValue> + 'a,
173{
174    pub fn try_error_as<T: Deserialize<'a>>(&'a self) -> Option<serde_json::Result<T>> {
179        self.payload.try_error_as()
180    }
181
182    pub fn deser_err<T: DeserializeOwned>(self) -> Result<Response<Payload, T>, Self> {
190        match self.payload.deserialize_error() {
191            Ok(payload) => Ok(Response { id: self.id, payload }),
192            Err(payload) => Err(Self { id: self.id, payload }),
193        }
194    }
195}
196
197impl<'de, Payload, ErrData> Deserialize<'de> for Response<Payload, ErrData>
198where
199    Payload: Deserialize<'de>,
200    ErrData: Deserialize<'de>,
201{
202    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
203    where
204        D: serde::Deserializer<'de>,
205    {
206        enum Field {
207            Result,
208            Error,
209            Id,
210            Unknown,
211        }
212
213        impl<'de> Deserialize<'de> for Field {
214            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
215            where
216                D: Deserializer<'de>,
217            {
218                struct FieldVisitor;
219
220                impl serde::de::Visitor<'_> for FieldVisitor {
221                    type Value = Field;
222
223                    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
224                        formatter.write_str("`result`, `error` and `id`")
225                    }
226
227                    fn visit_str<E>(self, value: &str) -> Result<Field, E>
228                    where
229                        E: serde::de::Error,
230                    {
231                        match value {
232                            "result" => Ok(Field::Result),
233                            "error" => Ok(Field::Error),
234                            "id" => Ok(Field::Id),
235                            _ => Ok(Field::Unknown),
236                        }
237                    }
238                }
239                deserializer.deserialize_identifier(FieldVisitor)
240            }
241        }
242
243        struct JsonRpcResponseVisitor<T>(PhantomData<T>);
244
245        impl<'de, Payload, ErrData> Visitor<'de> for JsonRpcResponseVisitor<fn() -> (Payload, ErrData)>
246        where
247            Payload: Deserialize<'de>,
248            ErrData: Deserialize<'de>,
249        {
250            type Value = Response<Payload, ErrData>;
251
252            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253                formatter.write_str(
254                    "a JSON-RPC response object, consisting of either a result or an error",
255                )
256            }
257
258            fn visit_map<M>(self, mut map: M) -> Result<Self::Value, M::Error>
259            where
260                M: MapAccess<'de>,
261            {
262                let mut result = None;
263                let mut error = None;
264                let mut id: Option<Id> = None;
265
266                while let Some(key) = map.next_key()? {
267                    match key {
268                        Field::Result => {
269                            if result.is_some() {
270                                return Err(serde::de::Error::duplicate_field("result"));
271                            }
272                            result = Some(map.next_value()?);
273                        }
274                        Field::Error => {
275                            if error.is_some() {
276                                return Err(serde::de::Error::duplicate_field("error"));
277                            }
278                            error = Some(map.next_value()?);
279                        }
280                        Field::Id => {
281                            if id.is_some() {
282                                return Err(serde::de::Error::duplicate_field("id"));
283                            }
284                            id = Some(map.next_value()?);
285                        }
286                        Field::Unknown => {
287                            let _: serde::de::IgnoredAny = map.next_value()?; }
289                    }
290                }
291                let id = id.unwrap_or(Id::None);
292
293                match (result, error) {
294                    (Some(result), None) => {
295                        Ok(Response { id, payload: ResponsePayload::Success(result) })
296                    }
297                    (None, Some(error)) => {
298                        Ok(Response { id, payload: ResponsePayload::Failure(error) })
299                    }
300                    (None, None) => Err(serde::de::Error::missing_field("result or error")),
301                    (Some(_), Some(_)) => {
302                        Err(serde::de::Error::custom("result and error are mutually exclusive"))
303                    }
304                }
305            }
306        }
307
308        deserializer.deserialize_map(JsonRpcResponseVisitor(PhantomData))
309    }
310}
311
312impl<Payload, ErrData> Serialize for Response<Payload, ErrData>
313where
314    Payload: Serialize,
315    ErrData: Serialize,
316{
317    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
318    where
319        S: serde::Serializer,
320    {
321        let mut map = serializer.serialize_map(Some(3))?;
322        map.serialize_entry("jsonrpc", "2.0")?;
323        map.serialize_entry("id", &self.id)?;
324        match &self.payload {
325            ResponsePayload::Success(result) => {
326                map.serialize_entry("result", result)?;
327            }
328            ResponsePayload::Failure(error) => {
329                map.serialize_entry("error", error)?;
330            }
331        }
332        map.end()
333    }
334}
335
#[cfg(test)]
mod test {
    use super::{Id, Response, ResponsePayload};

    #[test]
    fn deser_success() {
        let response = r#"{
            "jsonrpc": "2.0",
            "result": "california",
            "id": 1
        }"#;
        let response: Response = serde_json::from_str(response).unwrap();
        assert_eq!(response.id, Id::Number(1));
        assert!(matches!(response.payload, ResponsePayload::Success(_)));
    }

    #[test]
    fn deser_err() {
        let response = r#"{
            "jsonrpc": "2.0",
            "error": {
                "code": -32600,
                "message": "Invalid Request"
            },
            "id": null
        }"#;
        let response: Response = serde_json::from_str(response).unwrap();
        assert_eq!(response.id, Id::None);
        assert!(matches!(response.payload, ResponsePayload::Failure(_)));
        assert_eq!(response.error_code(), Some(-32600));
    }

    #[test]
    fn deser_complex_success() {
        let response = r#"{
            "result": {
                "name": "california",
                "population": 39250000,
                "cities": [
                    "los angeles",
                    "san francisco"
                ]
            }
        }"#;
        let response: Response = serde_json::from_str(response).unwrap();
        assert_eq!(response.id, Id::None);
        assert!(matches!(response.payload, ResponsePayload::Success(_)));
    }

    // Neither `result` nor `error`: must be rejected.
    #[test]
    fn deser_rejects_missing_result_and_error() {
        let response = r#"{"jsonrpc": "2.0", "id": 1}"#;
        assert!(serde_json::from_str::<Response>(response).is_err());
    }

    // Both `result` and `error`: must be rejected as mutually exclusive.
    #[test]
    fn deser_rejects_result_with_error() {
        let response = r#"{
            "jsonrpc": "2.0",
            "id": 1,
            "result": "x",
            "error": {"code": -32603, "message": "Internal error"}
        }"#;
        assert!(serde_json::from_str::<Response>(response).is_err());
    }

    // A repeated key must be rejected rather than silently overwritten.
    #[test]
    fn deser_rejects_duplicate_id() {
        let response = r#"{"id": 1, "id": 2, "result": null}"#;
        assert!(serde_json::from_str::<Response>(response).is_err());
    }

    // Unrecognised keys, including nested values, are skipped.
    #[test]
    fn deser_ignores_unknown_fields() {
        let response = r#"{"jsonrpc": "2.0", "id": "abc", "result": true, "extra": {"nested": [1, 2, 3]}}"#;
        let response: Response = serde_json::from_str(response).unwrap();
        assert!(response.is_success());
    }

    // Serialization emits `jsonrpc`, `id`, `result` in that order and reproduces the input.
    #[test]
    fn ser_success_roundtrip() {
        let raw = r#"{"jsonrpc":"2.0","id":1,"result":"california"}"#;
        let response: Response = serde_json::from_str(raw).unwrap();
        assert_eq!(serde_json::to_string(&response).unwrap(), raw);
    }
}
382
383