1use std::{
2 pin::Pin,
3 task::{Context, Poll},
4};
5
6use crate::{stream_helper::map_sender, uds_req_res::UdsResponse};
7use futures::StreamExt;
8use serde::{de::DeserializeOwned, Deserialize, Deserializer, Serialize, Serializer};
9
10pub trait JsonRpcServerTransport<SingleOrBatchRequest: AsRef<SingleOrBatch<JsonRpcRequest>>>:
11 futures::Stream<
12 Item = (
13 SingleOrBatchRequest,
14 futures::channel::oneshot::Sender<SingleOrBatch<JsonRpcResponse>>,
15 ),
16>
17{
18}
19
20#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)]
21#[serde(untagged)]
22pub enum SingleOrBatch<T> {
23 Single(T),
24 Batch(Vec<T>),
25}
26
27impl<T> SingleOrBatch<T> {
28 pub fn map<TOut, MapFn: Fn(T) -> TOut>(self, map_fn: MapFn) -> SingleOrBatch<TOut> {
29 match self {
30 Self::Single(request) => SingleOrBatch::Single(map_fn(request)),
31 Self::Batch(requests) => {
32 SingleOrBatch::Batch(requests.into_iter().map(map_fn).collect())
33 }
34 }
35 }
36}
37
38impl<T> UdsResponse for SingleOrBatch<T>
39where
40 T: Serialize + DeserializeOwned + Send + 'static,
41{
42 fn request_parse_error_response() -> Self {
43 panic!()
45 }
46}
47
48pub struct JsonRpcServerStream<
49 SingleOrBatchRequest: AsRef<SingleOrBatch<JsonRpcRequest>> + Send + Sync + 'static,
50> {
51 #[allow(clippy::type_complexity)]
52 stream: Pin<
53 Box<
54 dyn futures::Stream<
55 Item = (
56 SingleOrBatchRequest,
57 futures::channel::oneshot::Sender<SingleOrBatch<JsonRpcResponseData>>,
58 ),
59 > + Send,
60 >,
61 >,
62}
63
64impl<SingleOrBatchRequest: AsRef<SingleOrBatch<JsonRpcRequest>> + Send + Sync + 'static>
65 futures::Stream for JsonRpcServerStream<SingleOrBatchRequest>
66{
67 type Item = (
68 SingleOrBatchRequest,
69 futures::channel::oneshot::Sender<SingleOrBatch<JsonRpcResponseData>>,
70 );
71
72 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
73 self.stream
74 .poll_next_unpin(cx)
75 .map(|next_item_or| next_item_or)
76 }
77}
78
79impl<SingleOrBatchRequest: AsRef<SingleOrBatch<JsonRpcRequest>> + Send + Sync + 'static>
80 JsonRpcServerStream<SingleOrBatchRequest>
81{
82 pub fn start(
84 transport: impl JsonRpcServerTransport<SingleOrBatchRequest> + Send + 'static,
85 ) -> Self {
86 Self {
87 stream: Box::pin(transport.map(|(request, response_sender)| {
88 let single_request_id_or = match request.as_ref() {
89 SingleOrBatch::Single(request) => Some(request.id().clone()),
90 SingleOrBatch::Batch(_requests) => None,
91 };
92 let batch_request_ids_or: Option<Vec<JsonRpcId>> = match request.as_ref() {
93 SingleOrBatch::Single(_request) => None,
94 SingleOrBatch::Batch(requests) => Some(
95 requests
96 .iter()
97 .map(|request| request.id().clone())
98 .collect(),
99 ),
100 };
101
102 let response_sender = map_sender(response_sender, |response| match response {
103 SingleOrBatch::Single(response_data) => {
104 let Some(request_id) = single_request_id_or else {
105 panic!("Expected a single request, but got a batch of requests",)
106 };
107 SingleOrBatch::Single(JsonRpcResponse::new(response_data, request_id))
108 }
109 SingleOrBatch::Batch(responses) => {
110 let Some(request_ids) = batch_request_ids_or else {
111 panic!("Expected a batch of requests, but got a single request")
112 };
113 SingleOrBatch::Batch(
114 responses
115 .into_iter()
116 .enumerate()
117 .map(|(i, response_data)| {
118 let Some(request_id) = request_ids.get(i) else {
119 panic!("Expected a request at index {i}")
120 };
121 JsonRpcResponse::new(response_data, request_id.clone())
122 })
123 .collect(),
124 )
125 }
126 });
127
128 (request, response_sender)
129 })),
130 }
131 }
132}
133
134#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
146enum JsonRpcVersion {
147 #[serde(rename = "2.0")]
148 V2,
149}
150
151#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
152pub struct JsonRpcRequest {
153 jsonrpc: JsonRpcVersion,
154 method: String,
155 #[serde(skip_serializing_if = "Option::is_none")]
156 params: Option<JsonRpcStructuredValue>,
157 id: JsonRpcId,
158}
159
160impl AsRef<Self> for JsonRpcRequest {
161 fn as_ref(&self) -> &Self {
162 self
163 }
164}
165
166impl JsonRpcRequest {
167 pub const fn new(
168 method: String,
169 params: Option<JsonRpcStructuredValue>,
170 id: JsonRpcId,
171 ) -> Self {
172 Self {
173 jsonrpc: JsonRpcVersion::V2,
174 method,
175 params,
176 id,
177 }
178 }
179
180 pub fn method(&self) -> &str {
181 &self.method
182 }
183
184 pub const fn params(&self) -> Option<&JsonRpcStructuredValue> {
185 self.params.as_ref()
186 }
187
188 pub const fn id(&self) -> &JsonRpcId {
189 &self.id
190 }
191}
192
/// A JSON-RPC id: an integer, a string, or `null`.
///
/// Serialization is hand-written (not derived) so the id appears as the bare
/// JSON value; numeric ids are held as `i32`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum JsonRpcId {
    /// Integer id.
    Number(i32),
    /// String id.
    String(String),
    /// Explicit `null` id.
    Null,
}
200
201impl JsonRpcId {
202 fn to_json_value(&self) -> serde_json::Value {
203 match self {
204 Self::Number(n) => serde_json::Value::Number((*n).into()),
205 Self::String(s) => serde_json::Value::String(s.clone()),
206 Self::Null => serde_json::Value::Null,
207 }
208 }
209}
210
211impl serde::Serialize for JsonRpcId {
212 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
213 where
214 S: Serializer,
215 {
216 self.to_json_value().serialize(serializer)
217 }
218}
219
220impl<'de> Deserialize<'de> for JsonRpcId {
221 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
222 where
223 D: Deserializer<'de>,
224 {
225 serde_json::Value::deserialize(deserializer).and_then(|value| {
226 if value.is_i64() {
227 Ok(Self::Number(
228 i32::try_from(value.as_i64().unwrap()).unwrap(),
229 ))
230 } else if value.is_string() {
231 Ok(Self::String(value.as_str().unwrap().to_string()))
232 } else if value.is_null() {
233 Ok(Self::Null)
234 } else {
235 Err(serde::de::Error::custom("Invalid JSON-RPC ID"))
236 }
237 })
238 }
239}
240
241#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)]
242#[serde(untagged)]
243pub enum JsonRpcStructuredValue {
244 Object(serde_json::Map<String, serde_json::Value>),
245 Array(Vec<serde_json::Value>),
246}
247
248impl JsonRpcStructuredValue {
249 pub fn into_value(self) -> serde_json::Value {
250 match self {
251 Self::Object(object) => serde_json::Value::Object(object),
252 Self::Array(array) => serde_json::Value::Array(array),
253 }
254 }
255}
256
257#[derive(Serialize, Deserialize, PartialEq, Debug)]
258pub struct JsonRpcResponse {
259 jsonrpc: JsonRpcVersion,
260 #[serde(flatten)]
261 data: JsonRpcResponseData,
262 id: JsonRpcId,
263}
264
265impl JsonRpcResponse {
266 pub const fn new(data: JsonRpcResponseData, id: JsonRpcId) -> Self {
267 Self {
268 jsonrpc: JsonRpcVersion::V2,
269 data,
270 id,
271 }
272 }
273
274 pub const fn data(&self) -> &JsonRpcResponseData {
275 &self.data
276 }
277}
278
279#[derive(Serialize, Deserialize, PartialEq, Eq, Debug)]
280#[serde(untagged)]
281pub enum JsonRpcResponseData {
282 Success { result: serde_json::Value },
283 Error { error: JsonRpcError },
284}
285
286#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)]
288pub struct JsonRpcError {
289 code: JsonRpcErrorCode,
290 message: String,
291 #[serde(skip_serializing_if = "Option::is_none")]
292 data: Option<serde_json::Value>,
293}
294
295impl JsonRpcError {
296 pub const fn new(
297 code: JsonRpcErrorCode,
298 message: String,
299 data: Option<serde_json::Value>,
300 ) -> Self {
301 Self {
302 code,
303 message,
304 data,
305 }
306 }
307
308 pub const fn code(&self) -> JsonRpcErrorCode {
309 self.code
310 }
311}
312
313#[derive(PartialEq, Eq, Debug, Copy, Clone)]
314pub enum JsonRpcErrorCode {
315 ParseError,
316 InvalidRequest,
317 MethodNotFound,
318 InvalidParams,
319 InternalError,
320 Custom(i32), }
322
323impl Serialize for JsonRpcErrorCode {
324 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
325 where
326 S: Serializer,
327 {
328 let code = match *self {
329 Self::ParseError => -32700,
330 Self::InvalidRequest => -32600,
331 Self::MethodNotFound => -32601,
332 Self::InvalidParams => -32602,
333 Self::InternalError => -32603,
334 Self::Custom(c) => c,
335 };
336 serializer.serialize_i32(code)
337 }
338}
339
340impl<'de> Deserialize<'de> for JsonRpcErrorCode {
341 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
342 where
343 D: serde::Deserializer<'de>,
344 {
345 let code = i32::deserialize(deserializer)?;
346 match code {
347 -32700 => Ok(Self::ParseError),
348 -32600 => Ok(Self::InvalidRequest),
349 -32601 => Ok(Self::MethodNotFound),
350 -32602 => Ok(Self::InvalidParams),
351 -32603 => Ok(Self::InternalError),
352 _ => Ok(Self::Custom(code)),
353 }
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `json_string` deserializes to `value` and that `value`
    /// serializes back to exactly `json_string`.
    fn assert_json_serialization<
        'a,
        T: Serialize + Deserialize<'a> + PartialEq + std::fmt::Debug,
    >(
        value: T,
        json_string: &'a str,
    ) {
        assert_eq!(serde_json::from_str::<T>(json_string).unwrap(), value);
        assert_eq!(serde_json::to_string(&value).unwrap(), json_string);
    }

    #[test]
    fn serialize_and_deserialize_json_rpc_request() {
        assert_json_serialization(
            JsonRpcRequest::new("get_public_key".to_string(), None, JsonRpcId::Null),
            r#"{"jsonrpc":"2.0","method":"get_public_key","id":null}"#,
        );

        assert_json_serialization(
            JsonRpcRequest::new(
                "get_public_key".to_string(),
                Some(JsonRpcStructuredValue::Object(
                    serde_json::from_str(r#"{"key_type":"rsa"}"#).unwrap(),
                )),
                JsonRpcId::Null,
            ),
            r#"{"jsonrpc":"2.0","method":"get_public_key","params":{"key_type":"rsa"},"id":null}"#,
        );

        assert_json_serialization(
            JsonRpcRequest::new(
                "fetch_values".to_string(),
                Some(JsonRpcStructuredValue::Array(vec![
                    serde_json::from_str("1").unwrap(),
                    serde_json::from_str(r#""2""#).unwrap(),
                    serde_json::from_str(r#"{"3":true}"#).unwrap(),
                ])),
                JsonRpcId::Null,
            ),
            r#"{"jsonrpc":"2.0","method":"fetch_values","params":[1,"2",{"3":true}],"id":null}"#,
        );

        assert_json_serialization(
            JsonRpcRequest::new("get_public_key".to_string(), None, JsonRpcId::Number(1234)),
            r#"{"jsonrpc":"2.0","method":"get_public_key","id":1234}"#,
        );

        assert_json_serialization(
            JsonRpcRequest::new(
                "get_foo_string".to_string(),
                None,
                JsonRpcId::String("foo".to_string()),
            ),
            r#"{"jsonrpc":"2.0","method":"get_foo_string","id":"foo"}"#,
        );
    }

    #[test]
    fn serialize_and_deserialize_json_rpc_response() {
        assert_json_serialization(
            JsonRpcResponse::new(
                JsonRpcResponseData::Success {
                    result: serde_json::from_str(r#""foo""#).unwrap(),
                },
                JsonRpcId::Null,
            ),
            r#"{"jsonrpc":"2.0","result":"foo","id":null}"#,
        );

        assert_json_serialization(
            JsonRpcResponse::new(
                JsonRpcResponseData::Error {
                    error: JsonRpcError::new(
                        JsonRpcErrorCode::InternalError,
                        "foo".to_string(),
                        None,
                    ),
                },
                JsonRpcId::Null,
            ),
            r#"{"jsonrpc":"2.0","error":{"code":-32603,"message":"foo"},"id":null}"#,
        );

        assert_json_serialization(
            JsonRpcResponse::new(
                JsonRpcResponseData::Error {
                    error: JsonRpcError::new(
                        JsonRpcErrorCode::InternalError,
                        "foo".to_string(),
                        Some(serde_json::from_str(r#""bar""#).unwrap()),
                    ),
                },
                JsonRpcId::Null,
            ),
            r#"{"jsonrpc":"2.0","error":{"code":-32603,"message":"foo","data":"bar"},"id":null}"#,
        );
    }

    #[test]
    fn serialize_deserialize_json_rpc_request_batch() {
        assert_json_serialization(
            SingleOrBatch::Single(JsonRpcRequest::new(
                "get_public_key".to_string(),
                None,
                JsonRpcId::Null,
            )),
            r#"{"jsonrpc":"2.0","method":"get_public_key","id":null}"#,
        );

        assert_json_serialization(
            SingleOrBatch::Batch(vec![
                JsonRpcRequest::new("get_public_key".to_string(), None, JsonRpcId::Null),
                JsonRpcRequest::new(
                    "get_foo_string".to_string(),
                    None,
                    JsonRpcId::String("foo".to_string()),
                ),
            ]),
            r#"[{"jsonrpc":"2.0","method":"get_public_key","id":null},{"jsonrpc":"2.0","method":"get_foo_string","id":"foo"}]"#,
        );
    }

    #[test]
    fn serialize_deserialize_json_rpc_response_batch() {
        assert_json_serialization(
            SingleOrBatch::Single(JsonRpcResponse::new(
                JsonRpcResponseData::Success {
                    result: serde_json::from_str(r#""foo""#).unwrap(),
                },
                JsonRpcId::Null,
            )),
            r#"{"jsonrpc":"2.0","result":"foo","id":null}"#,
        );

        assert_json_serialization(
            SingleOrBatch::Batch(vec![
                JsonRpcResponse::new(
                    JsonRpcResponseData::Success {
                        result: serde_json::from_str(r#""foo""#).unwrap(),
                    },
                    JsonRpcId::Null,
                ),
                JsonRpcResponse::new(
                    JsonRpcResponseData::Success {
                        result: serde_json::from_str(r#""bar""#).unwrap(),
                    },
                    JsonRpcId::String("foo".to_string()),
                ),
            ]),
            r#"[{"jsonrpc":"2.0","result":"foo","id":null},{"jsonrpc":"2.0","result":"bar","id":"foo"}]"#,
        );
    }

    #[test]
    fn serialize_and_deserialize_id() {
        assert_json_serialization(JsonRpcId::Number(1234), "1234");
        assert_json_serialization(JsonRpcId::String("foo".to_string()), r#""foo""#);
        assert_json_serialization(JsonRpcId::Null, "null");
    }

    #[test]
    fn serialize_and_deserialize_id_at_i32_bounds() {
        assert_json_serialization(JsonRpcId::Number(i32::MAX), "2147483647");
        assert_json_serialization(JsonRpcId::Number(i32::MIN), "-2147483648");
    }

    #[test]
    fn deserialize_id_rejects_non_integer_values() {
        for json in ["1.5", "true", "[1]", r#"{"id":1}"#, "18446744073709551615"] {
            assert!(
                serde_json::from_str::<JsonRpcId>(json).is_err(),
                "{json} should be rejected as a JSON-RPC id"
            );
        }
    }

    #[test]
    fn serialize_and_deserialize_error_code() {
        assert_json_serialization(JsonRpcErrorCode::ParseError, "-32700");
        assert_json_serialization(JsonRpcErrorCode::InvalidRequest, "-32600");
        assert_json_serialization(JsonRpcErrorCode::MethodNotFound, "-32601");
        assert_json_serialization(JsonRpcErrorCode::InvalidParams, "-32602");
        assert_json_serialization(JsonRpcErrorCode::InternalError, "-32603");
        assert_json_serialization(JsonRpcErrorCode::Custom(1234), "1234");
    }

    #[test]
    fn custom_error_code_with_reserved_value_deserializes_as_named_variant() {
        let json = serde_json::to_string(&JsonRpcErrorCode::Custom(-32700)).unwrap();
        assert_eq!(json, "-32700");
        assert_eq!(
            serde_json::from_str::<JsonRpcErrorCode>(&json).unwrap(),
            JsonRpcErrorCode::ParseError
        );
    }

    #[test]
    fn single_or_batch_map_preserves_shape_and_order() {
        assert_eq!(
            SingleOrBatch::Single(2).map(|x| x * 10),
            SingleOrBatch::Single(20)
        );
        assert_eq!(
            SingleOrBatch::Batch(vec![1, 2, 3]).map(|x| x + 1),
            SingleOrBatch::Batch(vec![2, 3, 4])
        );
        assert_eq!(
            SingleOrBatch::<i32>::Batch(vec![]).map(|x| x + 1),
            SingleOrBatch::Batch(vec![])
        );
    }
}