nip_55/json_rpc/
mod.rs

1use std::{
2    pin::Pin,
3    task::{Context, Poll},
4};
5
6use crate::{stream_helper::map_sender, uds_req_res::UdsResponse};
7use futures::StreamExt;
8use serde::{de::DeserializeOwned, Deserialize, Deserializer, Serialize, Serializer};
9
10pub trait JsonRpcServerTransport<SingleOrBatchRequest: AsRef<SingleOrBatch<JsonRpcRequest>>>:
11    futures::Stream<
12    Item = (
13        SingleOrBatchRequest,
14        futures::channel::oneshot::Sender<SingleOrBatch<JsonRpcResponse>>,
15    ),
16>
17{
18}
19
20#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)]
21#[serde(untagged)]
22pub enum SingleOrBatch<T> {
23    Single(T),
24    Batch(Vec<T>),
25}
26
27impl<T> SingleOrBatch<T> {
28    pub fn map<TOut, MapFn: Fn(T) -> TOut>(self, map_fn: MapFn) -> SingleOrBatch<TOut> {
29        match self {
30            Self::Single(request) => SingleOrBatch::Single(map_fn(request)),
31            Self::Batch(requests) => {
32                SingleOrBatch::Batch(requests.into_iter().map(map_fn).collect())
33            }
34        }
35    }
36}
37
38impl<T> UdsResponse for SingleOrBatch<T>
39where
40    T: Serialize + DeserializeOwned + Send + 'static,
41{
42    fn request_parse_error_response() -> Self {
43        // TODO: Implement this.
44        panic!()
45    }
46}
47
48pub struct JsonRpcServerStream<
49    SingleOrBatchRequest: AsRef<SingleOrBatch<JsonRpcRequest>> + Send + Sync + 'static,
50> {
51    #[allow(clippy::type_complexity)]
52    stream: Pin<
53        Box<
54            dyn futures::Stream<
55                    Item = (
56                        SingleOrBatchRequest,
57                        futures::channel::oneshot::Sender<SingleOrBatch<JsonRpcResponseData>>,
58                    ),
59                > + Send,
60        >,
61    >,
62}
63
64impl<SingleOrBatchRequest: AsRef<SingleOrBatch<JsonRpcRequest>> + Send + Sync + 'static>
65    futures::Stream for JsonRpcServerStream<SingleOrBatchRequest>
66{
67    type Item = (
68        SingleOrBatchRequest,
69        futures::channel::oneshot::Sender<SingleOrBatch<JsonRpcResponseData>>,
70    );
71
72    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
73        self.stream
74            .poll_next_unpin(cx)
75            .map(|next_item_or| next_item_or)
76    }
77}
78
79impl<SingleOrBatchRequest: AsRef<SingleOrBatch<JsonRpcRequest>> + Send + Sync + 'static>
80    JsonRpcServerStream<SingleOrBatchRequest>
81{
82    // TODO: Completely clean up this function. It's a mess.
83    pub fn start(
84        transport: impl JsonRpcServerTransport<SingleOrBatchRequest> + Send + 'static,
85    ) -> Self {
86        Self {
87            stream: Box::pin(transport.map(|(request, response_sender)| {
88                let single_request_id_or = match request.as_ref() {
89                    SingleOrBatch::Single(request) => Some(request.id().clone()),
90                    SingleOrBatch::Batch(_requests) => None,
91                };
92                let batch_request_ids_or: Option<Vec<JsonRpcId>> = match request.as_ref() {
93                    SingleOrBatch::Single(_request) => None,
94                    SingleOrBatch::Batch(requests) => Some(
95                        requests
96                            .iter()
97                            .map(|request| request.id().clone())
98                            .collect(),
99                    ),
100                };
101
102                let response_sender = map_sender(response_sender, |response| match response {
103                    SingleOrBatch::Single(response_data) => {
104                        let Some(request_id) = single_request_id_or else {
105                            panic!("Expected a single request, but got a batch of requests",)
106                        };
107                        SingleOrBatch::Single(JsonRpcResponse::new(response_data, request_id))
108                    }
109                    SingleOrBatch::Batch(responses) => {
110                        let Some(request_ids) = batch_request_ids_or else {
111                            panic!("Expected a batch of requests, but got a single request")
112                        };
113                        SingleOrBatch::Batch(
114                            responses
115                                .into_iter()
116                                .enumerate()
117                                .map(|(i, response_data)| {
118                                    let Some(request_id) = request_ids.get(i) else {
119                                        panic!("Expected a request at index {i}")
120                                    };
121                                    JsonRpcResponse::new(response_data, request_id.clone())
122                                })
123                                .collect(),
124                        )
125                    }
126                });
127
128                (request, response_sender)
129            })),
130        }
131    }
132}
133
134// TODO: Uncomment this and have Nip55Client implement it.
135// #[async_trait]
136// pub trait JsonRpcClientTransport<E> {
137//     async fn send_request(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse, E>;
138
139//     async fn send_batch_request(
140//         &self,
141//         requests: Vec<JsonRpcRequest>,
142//     ) -> Result<Vec<JsonRpcResponse>, E>;
143// }
144
145#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
146enum JsonRpcVersion {
147    #[serde(rename = "2.0")]
148    V2,
149}
150
151#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
152pub struct JsonRpcRequest {
153    jsonrpc: JsonRpcVersion,
154    method: String,
155    #[serde(skip_serializing_if = "Option::is_none")]
156    params: Option<JsonRpcStructuredValue>,
157    id: JsonRpcId,
158}
159
160impl AsRef<Self> for JsonRpcRequest {
161    fn as_ref(&self) -> &Self {
162        self
163    }
164}
165
166impl JsonRpcRequest {
167    pub const fn new(
168        method: String,
169        params: Option<JsonRpcStructuredValue>,
170        id: JsonRpcId,
171    ) -> Self {
172        Self {
173            jsonrpc: JsonRpcVersion::V2,
174            method,
175            params,
176            id,
177        }
178    }
179
180    pub fn method(&self) -> &str {
181        &self.method
182    }
183
184    pub const fn params(&self) -> Option<&JsonRpcStructuredValue> {
185        self.params.as_ref()
186    }
187
188    pub const fn id(&self) -> &JsonRpcId {
189        &self.id
190    }
191}
192
// TODO: Rename to `JsonRpcRequestId`.
/// The identifier of a JSON-RPC request, echoed back in its response.
///
/// Serialization and deserialization are implemented by hand for this type and
/// map each variant to the corresponding bare JSON value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JsonRpcId {
    /// An integer id that fits in an `i32`.
    Number(i32),
    /// A string id.
    String(String),
    /// The JSON `null` id.
    Null,
}
200
201impl JsonRpcId {
202    fn to_json_value(&self) -> serde_json::Value {
203        match self {
204            Self::Number(n) => serde_json::Value::Number((*n).into()),
205            Self::String(s) => serde_json::Value::String(s.clone()),
206            Self::Null => serde_json::Value::Null,
207        }
208    }
209}
210
211impl serde::Serialize for JsonRpcId {
212    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
213    where
214        S: Serializer,
215    {
216        self.to_json_value().serialize(serializer)
217    }
218}
219
220impl<'de> Deserialize<'de> for JsonRpcId {
221    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
222    where
223        D: Deserializer<'de>,
224    {
225        serde_json::Value::deserialize(deserializer).and_then(|value| {
226            if value.is_i64() {
227                Ok(Self::Number(
228                    i32::try_from(value.as_i64().unwrap()).unwrap(),
229                ))
230            } else if value.is_string() {
231                Ok(Self::String(value.as_str().unwrap().to_string()))
232            } else if value.is_null() {
233                Ok(Self::Null)
234            } else {
235                Err(serde::de::Error::custom("Invalid JSON-RPC ID"))
236            }
237        })
238    }
239}
240
241#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)]
242#[serde(untagged)]
243pub enum JsonRpcStructuredValue {
244    Object(serde_json::Map<String, serde_json::Value>),
245    Array(Vec<serde_json::Value>),
246}
247
248impl JsonRpcStructuredValue {
249    pub fn into_value(self) -> serde_json::Value {
250        match self {
251            Self::Object(object) => serde_json::Value::Object(object),
252            Self::Array(array) => serde_json::Value::Array(array),
253        }
254    }
255}
256
257#[derive(Serialize, Deserialize, PartialEq, Debug)]
258pub struct JsonRpcResponse {
259    jsonrpc: JsonRpcVersion,
260    #[serde(flatten)]
261    data: JsonRpcResponseData,
262    id: JsonRpcId,
263}
264
265impl JsonRpcResponse {
266    pub const fn new(data: JsonRpcResponseData, id: JsonRpcId) -> Self {
267        Self {
268            jsonrpc: JsonRpcVersion::V2,
269            data,
270            id,
271        }
272    }
273
274    pub const fn data(&self) -> &JsonRpcResponseData {
275        &self.data
276    }
277}
278
279#[derive(Serialize, Deserialize, PartialEq, Eq, Debug)]
280#[serde(untagged)]
281pub enum JsonRpcResponseData {
282    Success { result: serde_json::Value },
283    Error { error: JsonRpcError },
284}
285
286// TODO: Make these fields private.
287#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)]
288pub struct JsonRpcError {
289    code: JsonRpcErrorCode,
290    message: String,
291    #[serde(skip_serializing_if = "Option::is_none")]
292    data: Option<serde_json::Value>,
293}
294
295impl JsonRpcError {
296    pub const fn new(
297        code: JsonRpcErrorCode,
298        message: String,
299        data: Option<serde_json::Value>,
300    ) -> Self {
301        Self {
302            code,
303            message,
304            data,
305        }
306    }
307
308    pub const fn code(&self) -> JsonRpcErrorCode {
309        self.code
310    }
311}
312
/// A JSON-RPC error code.
///
/// The named variants are written as fixed integers by the hand-written
/// `Serialize` impl in this file; any other integer is carried in `Custom`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JsonRpcErrorCode {
    /// Serialized as `-32700`.
    ParseError,
    /// Serialized as `-32600`.
    InvalidRequest,
    /// Serialized as `-32601`.
    MethodNotFound,
    /// Serialized as `-32602`.
    InvalidParams,
    /// Serialized as `-32603`.
    InternalError,
    /// Any other code, serialized as the wrapped integer.
    Custom(i32), // TODO: Make it so that this can only be used for custom error codes, not the standard ones above.
}
322
323impl Serialize for JsonRpcErrorCode {
324    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
325    where
326        S: Serializer,
327    {
328        let code = match *self {
329            Self::ParseError => -32700,
330            Self::InvalidRequest => -32600,
331            Self::MethodNotFound => -32601,
332            Self::InvalidParams => -32602,
333            Self::InternalError => -32603,
334            Self::Custom(c) => c,
335        };
336        serializer.serialize_i32(code)
337    }
338}
339
340impl<'de> Deserialize<'de> for JsonRpcErrorCode {
341    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
342    where
343        D: serde::Deserializer<'de>,
344    {
345        let code = i32::deserialize(deserializer)?;
346        match code {
347            -32700 => Ok(Self::ParseError),
348            -32600 => Ok(Self::InvalidRequest),
349            -32601 => Ok(Self::MethodNotFound),
350            -32602 => Ok(Self::InvalidParams),
351            -32603 => Ok(Self::InternalError),
352            _ => Ok(Self::Custom(code)),
353        }
354    }
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts a full round trip: `json_string` must deserialize to `value`,
    /// and `value` must serialize back to exactly `json_string`.
    fn assert_json_serialization<'a, T>(value: T, json_string: &'a str)
    where
        T: Serialize + Deserialize<'a> + PartialEq + std::fmt::Debug,
    {
        let parsed = serde_json::from_str::<T>(json_string).unwrap();
        assert_eq!(parsed, value);

        let rendered = serde_json::to_string(&value).unwrap();
        assert_eq!(rendered, json_string);
    }

    #[test]
    fn serialize_and_deserialize_json_rpc_request() {
        // Test with no parameters and null ID.
        let without_params =
            JsonRpcRequest::new("get_public_key".to_string(), None, JsonRpcId::Null);
        assert_json_serialization(
            without_params,
            "{\"jsonrpc\":\"2.0\",\"method\":\"get_public_key\",\"id\":null}",
        );

        // Test with object parameters.
        let object_params =
            JsonRpcStructuredValue::Object(serde_json::from_str("{\"key_type\":\"rsa\"}").unwrap());
        assert_json_serialization(
            JsonRpcRequest::new(
                "get_public_key".to_string(),
                Some(object_params),
                JsonRpcId::Null,
            ),
            "{\"jsonrpc\":\"2.0\",\"method\":\"get_public_key\",\"params\":{\"key_type\":\"rsa\"},\"id\":null}",
        );

        // Test with array parameters.
        let array_params = JsonRpcStructuredValue::Array(vec![
            serde_json::from_str("1").unwrap(),
            serde_json::from_str("\"2\"").unwrap(),
            serde_json::from_str("{\"3\":true}").unwrap(),
        ]);
        assert_json_serialization(
            JsonRpcRequest::new(
                "fetch_values".to_string(),
                Some(array_params),
                JsonRpcId::Null,
            ),
            "{\"jsonrpc\":\"2.0\",\"method\":\"fetch_values\",\"params\":[1,\"2\",{\"3\":true}],\"id\":null}",
        );

        // Test with number ID.
        let with_number_id =
            JsonRpcRequest::new("get_public_key".to_string(), None, JsonRpcId::Number(1234));
        assert_json_serialization(
            with_number_id,
            "{\"jsonrpc\":\"2.0\",\"method\":\"get_public_key\",\"id\":1234}",
        );

        // Test with string ID.
        let with_string_id = JsonRpcRequest::new(
            "get_foo_string".to_string(),
            None,
            JsonRpcId::String("foo".to_string()),
        );
        assert_json_serialization(
            with_string_id,
            "{\"jsonrpc\":\"2.0\",\"method\":\"get_foo_string\",\"id\":\"foo\"}",
        );
    }

    #[test]
    fn serialize_and_deserialize_json_rpc_response() {
        // Test with result and null ID.
        let success = JsonRpcResponseData::Success {
            result: serde_json::from_str("\"foo\"").unwrap(),
        };
        assert_json_serialization(
            JsonRpcResponse::new(success, JsonRpcId::Null),
            "{\"jsonrpc\":\"2.0\",\"result\":\"foo\",\"id\":null}",
        );

        // Test with error (no data).
        let error_without_data = JsonRpcResponseData::Error {
            error: JsonRpcError::new(JsonRpcErrorCode::InternalError, "foo".to_string(), None),
        };
        assert_json_serialization(
            JsonRpcResponse::new(error_without_data, JsonRpcId::Null),
            "{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32603,\"message\":\"foo\"},\"id\":null}",
        );

        // Test with error (with data).
        let error_with_data = JsonRpcResponseData::Error {
            error: JsonRpcError::new(
                JsonRpcErrorCode::InternalError,
                "foo".to_string(),
                Some(serde_json::from_str("\"bar\"").unwrap()),
            ),
        };
        assert_json_serialization(
            JsonRpcResponse::new(error_with_data, JsonRpcId::Null),
            "{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32603,\"message\":\"foo\",\"data\":\"bar\"},\"id\":null}",
        );
    }

    #[test]
    fn serialize_deserialize_json_rpc_request_batch() {
        // Test with single request.
        let single = SingleOrBatch::Single(JsonRpcRequest::new(
            "get_public_key".to_string(),
            None,
            JsonRpcId::Null,
        ));
        assert_json_serialization(
            single,
            "{\"jsonrpc\":\"2.0\",\"method\":\"get_public_key\",\"id\":null}",
        );

        // Test with batch request.
        let batch = SingleOrBatch::Batch(vec![
            JsonRpcRequest::new("get_public_key".to_string(), None, JsonRpcId::Null),
            JsonRpcRequest::new(
                "get_foo_string".to_string(),
                None,
                JsonRpcId::String("foo".to_string()),
            ),
        ]);
        assert_json_serialization(
            batch,
            "[{\"jsonrpc\":\"2.0\",\"method\":\"get_public_key\",\"id\":null},{\"jsonrpc\":\"2.0\",\"method\":\"get_foo_string\",\"id\":\"foo\"}]",
        );
    }

    #[test]
    fn serialize_deserialize_json_rpc_response_batch() {
        // Test with single response.
        let single = SingleOrBatch::Single(JsonRpcResponse::new(
            JsonRpcResponseData::Success {
                result: serde_json::from_str("\"foo\"").unwrap(),
            },
            JsonRpcId::Null,
        ));
        assert_json_serialization(
            single,
            "{\"jsonrpc\":\"2.0\",\"result\":\"foo\",\"id\":null}",
        );

        // Test with batch response.
        let first = JsonRpcResponse::new(
            JsonRpcResponseData::Success {
                result: serde_json::from_str("\"foo\"").unwrap(),
            },
            JsonRpcId::Null,
        );
        let second = JsonRpcResponse::new(
            JsonRpcResponseData::Success {
                result: serde_json::from_str("\"bar\"").unwrap(),
            },
            JsonRpcId::String("foo".to_string()),
        );
        assert_json_serialization(
            SingleOrBatch::Batch(vec![first, second]),
            "[{\"jsonrpc\":\"2.0\",\"result\":\"foo\",\"id\":null},{\"jsonrpc\":\"2.0\",\"result\":\"bar\",\"id\":\"foo\"}]",
        );
    }

    #[test]
    fn serialize_and_deserialize_id() {
        // Test with number ID.
        assert_json_serialization(JsonRpcId::Number(1234), "1234");

        // Test with string ID.
        assert_json_serialization(JsonRpcId::String("foo".to_string()), "\"foo\"");

        // Test with null ID.
        assert_json_serialization(JsonRpcId::Null, "null");
    }

    #[test]
    fn serialize_and_deserialize_error_code() {
        // Test with ParseError.
        assert_json_serialization(JsonRpcErrorCode::ParseError, "-32700");

        // Test with InvalidRequest.
        assert_json_serialization(JsonRpcErrorCode::InvalidRequest, "-32600");

        // Test with MethodNotFound.
        assert_json_serialization(JsonRpcErrorCode::MethodNotFound, "-32601");

        // Test with InvalidParams.
        assert_json_serialization(JsonRpcErrorCode::InvalidParams, "-32602");

        // Test with InternalError.
        assert_json_serialization(JsonRpcErrorCode::InternalError, "-32603");

        // Test with Custom.
        assert_json_serialization(JsonRpcErrorCode::Custom(1234), "1234");
    }
}