rpc_it/
codecs.rs

pub mod framing {
    /// Delimiter-based framing: each frame ends at the next occurrence of a fixed byte sequence.
    #[cfg(feature = "delim-framing")]
    pub mod delim {
        use memchr::memmem;

        use crate::codec::{self, Framing};

        /// Splits a buffer into frames separated by a byte-sequence delimiter.
        ///
        /// The delimiter itself is excluded from frame data: `valid_data_end` points at the
        /// first delimiter byte and `next_frame_start` just past its last byte.
        #[derive(Debug)]
        struct DelimeterFraming {
            /// Precompiled substring searcher for the delimiter.
            finder: memmem::Finder<'static>,
            /// Absolute offset into the pending buffer at which the next search starts, so
            /// bytes already known not to begin a delimiter are not scanned again.
            cursor: usize,
        }

        impl Framing for DelimeterFraming {
            fn try_framing(
                &mut self,
                buffer: &[u8],
            ) -> Result<Option<codec::FramingAdvanceResult>, codec::FramingError> {
                let delim_len = self.finder.needle().len();

                // Never slice past the end, even if the buffer is shorter than the saved cursor.
                let search_start = self.cursor.min(buffer.len());

                let Some(pos) = self.finder.find(&buffer[search_start..]) else {
                    // No complete delimiter yet. Its first `delim_len - 1` bytes may already sit
                    // at the tail of `buffer`, so resume just before them next time. The cursor
                    // is assigned (not accumulated) as an absolute position, so repeated misses
                    // never push it past the end. Assumes the caller re-passes the same,
                    // possibly grown, buffer until `advance` is called — confirm.
                    let tail_start = buffer.len().saturating_sub(delim_len.saturating_sub(1));
                    self.cursor = tail_start.max(search_start);
                    return Ok(None);
                };

                let valid_data_end = search_start + pos;
                let next_frame_start = valid_data_end + delim_len;
                self.cursor = next_frame_start;
                Ok(Some(codec::FramingAdvanceResult { valid_data_end, next_frame_start }))
            }

            /// Resets the search cursor to the start of the buffer.
            ///
            /// NOTE(review): presumably invoked once the caller has consumed the reported
            /// frame, so the next buffer begins at a fresh frame — confirm against `Framing`.
            fn advance(&mut self) {
                self.cursor = 0;
            }

            /// Delimited frames carry no length prefix, so no read-size hint is offered.
            fn next_buffer_size(&self) -> Option<std::num::NonZeroUsize> {
                None
            }
        }

        /// Creates a framing that splits the stream on every occurrence of `delim`.
        ///
        /// The delimiter bytes are not part of the produced frames. Note that an empty `delim`
        /// matches at every position and therefore yields only empty frames.
        pub fn by_delim(delim: &[u8]) -> impl Framing {
            DelimeterFraming { cursor: 0, finder: memmem::Finder::new(delim).into_owned() }
        }
    }
}
47
#[cfg(feature = "msgpack-rpc")]
pub mod msgpack_rpc {
    use std::{
        borrow::Cow,
        num::{NonZeroU32, NonZeroU64},
    };

    use ::bytes::Buf;
    use bytes::{BufMut, BytesMut};
    use derive_setters::Setters;
    use serde::Deserialize;

    use crate::codec::{self, DecodeError::InvalidFormat, EncodeError};

    /// Codec for the [msgpack-rpc](https://github.com/msgpack-rpc/msgpack-rpc/blob/master/spec.md)
    /// protocol.
    ///
    /// Every message is a msgpack array:
    /// - request: `[0, msgid, method, params]`
    /// - response: `[1, msgid, error, result]`
    /// - notification: `[2, method, params]`
    ///
    /// Message ids are 32-bit unsigned integers on the wire.
    #[derive(Setters, Debug, Default)]
    #[setters(prefix = "with_")]
    pub struct Codec {
        /// If specified, the codec will wrap the provided parameter in a single-element array
        /// automatically. Otherwise, the caller should wrap the parameter within an array
        /// manually.
        ///
        /// > [Reference](https://github.com/msgpack-rpc/msgpack-rpc/blob/master/spec.md#params)
        auto_wrapping: bool,

        /// On deserialization, if the parameter is an array with a single element, the codec
        /// will unwrap the array and use the element as the parameter. Otherwise, the parameter
        /// will be deserialized as an array. When enabled, a `params` value that is not an array
        /// is rejected.
        unwrap_mono_param: bool,
    }

    /// Serializes `value` as one msgpack value into `write`, encoding structs as maps (with
    /// field names) rather than positional arrays.
    ///
    /// # Errors
    /// Returns [`EncodeError::SerializeError`] if `value`'s `Serialize` implementation fails.
    fn write_payload(
        value: &dyn erased_serde::Serialize,
        write: impl std::io::Write,
    ) -> Result<(), EncodeError> {
        let mut serializer = rmp_serde::Serializer::new(write).with_struct_map();
        value
            .erased_serialize(&mut <dyn erased_serde::Serializer>::erase(&mut serializer))
            .map(|_| ())
            .map_err(|e| EncodeError::SerializeError(e.into()))
    }

    /// Folds a 64-bit request-id hint into the non-zero 32-bit id space used on the wire.
    ///
    /// The low 32 bits are used as-is. If they are all zero (e.g. `1 << 32`), the high 32 bits
    /// are used instead; those are non-zero because the hint itself is non-zero.
    fn wire_req_id(hint: NonZeroU64) -> NonZeroU32 {
        let hint = hint.get();
        let low = hint as u32;
        let id = if low != 0 { low } else { (hint >> 32) as u32 };
        NonZeroU32::new(id).expect("a non-zero hint always folds to a non-zero id")
    }

    impl Codec {
        /// Writes the `params` slot, wrapping the value in a one-element array when
        /// `auto_wrapping` is enabled.
        fn write_params(
            &self,
            params: &dyn erased_serde::Serialize,
            mut write: impl std::io::Write,
        ) -> Result<(), EncodeError> {
            if self.auto_wrapping {
                rmp::encode::write_array_len(&mut write, 1).unwrap();
            }
            write_payload(params, write)
        }
    }

    // NOTE: the msgpack header writes below `unwrap()` because they target a growable
    // `BytesMut` writer, which is not expected to fail; only the user payload may fail.
    impl codec::Codec for Codec {
        /// Encodes `[2, method, params]`.
        fn encode_notify(
            &self,
            method: &str,
            params: &dyn erased_serde::Serialize,
            write: &mut BytesMut,
        ) -> Result<(), EncodeError> {
            use rmp::encode::*;

            let mut write = write.writer();
            write_array_len(&mut write, 3).unwrap();
            write_uint(&mut write, 2).unwrap();
            write_str(&mut write, method).unwrap();
            self.write_params(params, &mut write)
        }

        /// Encodes `[0, msgid, method, params]` and returns the 32-bit id actually written.
        fn encode_request(
            &self,
            method: &str,
            req_id_hint: NonZeroU64,
            params: &dyn erased_serde::Serialize,
            write: &mut BytesMut,
        ) -> Result<std::num::NonZeroU64, EncodeError> {
            use rmp::encode::*;

            let req_id = wire_req_id(req_id_hint);

            let mut write = write.writer();
            write_array_len(&mut write, 4).unwrap();
            write_uint(&mut write, 0).unwrap();
            write_u32(&mut write, req_id.get()).unwrap();
            write_str(&mut write, method).unwrap();
            self.write_params(params, &mut write)?;

            Ok(NonZeroU64::from(req_id))
        }

        /// Encodes `[1, msgid, error, result]`, placing `response` in the error slot when
        /// `encode_as_error` is set and in the result slot otherwise; the other slot is nil.
        ///
        /// # Errors
        /// [`EncodeError::UnsupportedDataFormat`] if `req_id` is not an integer id (checked
        /// before anything is written), or [`EncodeError::SerializeError`] from the payload.
        fn encode_response(
            &self,
            req_id: codec::ReqIdRef,
            encode_as_error: bool,
            response: &dyn erased_serde::Serialize,
            write: &mut BytesMut,
        ) -> Result<(), EncodeError> {
            use rmp::encode::*;

            let req_id = *req_id
                .as_u64()
                .ok_or(EncodeError::UnsupportedDataFormat("unsupported non-integer".into()))?;

            let mut write = write.writer();
            write_array_len(&mut write, 4).unwrap();
            write_uint(&mut write, 1).unwrap();
            // Ids are 32-bit on the wire, matching what `decode_inbound` reads.
            write_uint(&mut write, req_id as u32 as u64).unwrap();

            if encode_as_error {
                write_payload(response, &mut write)?;
                write_nil(&mut write).unwrap();
            } else {
                write_nil(&mut write).unwrap();
                write_payload(response, &mut write)?;
            }

            Ok(())
        }

        /// Classifies an inbound message and returns the byte range of its payload within
        /// `data` (params for requests/notifications, the error or result for responses).
        ///
        /// # Errors
        /// [`codec::DecodeError::InvalidFormat`] for anything that is not a well-formed
        /// msgpack-rpc array, including truncated method names.
        fn decode_inbound(
            &self,
            data: &[u8],
        ) -> Result<(codec::InboundFrameType, std::ops::Range<usize>), codec::DecodeError> {
            use rmp::decode::*;

            /// Builds a `map_err` adapter that discards the source error and reports `e`.
            fn efmt<T>(e: impl Into<Cow<'static, str>>) -> impl FnOnce(T) -> codec::DecodeError {
                |_| InvalidFormat(e.into())
            }

            // `rd` is a cursor: every read shrinks it from the front, so it always remains a
            // suffix of `data`.
            let mut rd = data;

            // Offset of the cursor `s` (a suffix of `data`) from the start of `data`.
            let offset_of = |s: &[u8]| s.as_ptr() as usize - data.as_ptr() as usize;

            // Consumes exactly one msgpack value from the cursor without materializing it.
            let skip_single_value = |rd: &mut &[u8]| {
                serde::de::IgnoredAny::deserialize(&mut rmp_serde::Deserializer::new(rd))
                    .map_err(efmt("parameter read failed"))
            };

            let arr_len = read_array_len(&mut rd).map_err(efmt("Non-msgpack array format"))?;
            if !(2..=4).contains(&arr_len) {
                return Err(InvalidFormat(format!("Invalid array length {arr_len}").into()));
            }

            let msg_type: u32 = read_int(&mut rd).map_err(efmt("Non-msgpack integer format"))?;
            match (arr_len, msg_type) {
                // Request `[0, msgid, method, params]` / notification `[2, method, params]`.
                (4, 0) | (3, 2) => {
                    let req_id = match msg_type {
                        0 => Some(u64::from(
                            read_int::<u32, _>(&mut rd).map_err(efmt("rd: not req_id"))?,
                        )),
                        _ => None,
                    };

                    let method_len =
                        read_str_len(&mut rd).map_err(efmt("rd: not method"))? as usize;
                    // The length prefix is untrusted; `Buf::advance` panics past the end.
                    if method_len > rd.len() {
                        return Err(InvalidFormat("rd: method name truncated".into()));
                    }
                    let method_offset = offset_of(rd);
                    rd.advance(method_len);
                    let method = method_offset..method_offset + method_len;

                    if self.unwrap_mono_param {
                        let mut peek = rd;
                        let param_count =
                            read_array_len(&mut peek).map_err(efmt("rd: non-array param"))?;
                        if param_count == 1 {
                            // Step past the array marker so the payload is the sole element.
                            rd = peek;
                        }
                    }

                    let obj_begin = offset_of(rd);
                    skip_single_value(&mut rd)?;
                    let obj_end = offset_of(rd);

                    let frame_type = match req_id {
                        Some(req_id) => codec::InboundFrameType::Request {
                            method,
                            req_id: codec::ReqId::U64(req_id),
                        },
                        None => codec::InboundFrameType::Notify { method },
                    };
                    Ok((frame_type, obj_begin..obj_end))
                }

                // Response `[1, msgid, error, result]`.
                (4, 1) => {
                    let req_id =
                        u64::from(read_int::<u32, _>(&mut rd).map_err(efmt("req_id error"))?);

                    // A nil error slot means success: skip it so the payload is the result.
                    // Otherwise the payload is the error object and the result is left unread.
                    let mut peek = rd;
                    let is_error = read_nil(&mut peek).is_err();
                    if !is_error {
                        rd = peek;
                    }

                    let obj_begin = offset_of(rd);
                    skip_single_value(&mut rd)?;
                    let obj_end = offset_of(rd);

                    Ok((
                        codec::InboundFrameType::Response {
                            req_id: codec::ReqId::U64(req_id),
                            req_id_hash: req_id,
                            is_error,
                        },
                        obj_begin..obj_end,
                    ))
                }

                (al, msg) => {
                    Err(InvalidFormat(format!("Invalid message type {msg}, with {al} args").into()))
                }
            }
        }

        /// Deserializes a payload range previously reported by `decode_inbound`.
        fn decode_payload<'a>(
            &self,
            payload: &'a [u8],
            decode: &mut dyn FnMut(
                &mut dyn erased_serde::Deserializer<'a>,
            ) -> Result<(), erased_serde::Error>,
        ) -> Result<(), codec::DecodeError> {
            decode(&mut <dyn erased_serde::Deserializer>::erase(
                &mut rmp_serde::Deserializer::new(payload),
            ))?;
            Ok(())
        }
    }
}
282
#[cfg(feature = "jsonrpc")]
pub mod jsonrpc {
    use std::num::NonZeroU64;

    use bytes::{BufMut, BytesMut};
    use serde_json::value::RawValue;

    use crate::codec::{self, InboundFrameType, ReqId, ReqIdRef};

    /// Codec for [JSON-RPC 2.0](https://www.jsonrpc.org/specification) messages.
    ///
    /// Outgoing request ids are integers kept within 53 bits so JavaScript peers can represent
    /// them exactly. Incoming requests may use integer or string ids; incoming responses must
    /// carry integer ids.
    #[derive(Debug, Default)]
    pub struct Codec {}

    /// Mask keeping request ids within JavaScript's max safe integer (`2^53 - 1`).
    const JS_SAFE_INT_MASK: u64 = (1 << 53) - 1;

    /// Folds a request-id hint into a non-zero id no larger than `2^53 - 1`.
    ///
    /// The low 53 bits are used as-is. If they are all zero (e.g. `1 << 53`), the remaining
    /// high bits are used instead: they are non-zero because the hint is non-zero, and
    /// `hint >> 53` is below `2^11`, so the result stays in range.
    fn wire_req_id(hint: NonZeroU64) -> NonZeroU64 {
        let hint = hint.get();
        let masked = hint & JS_SAFE_INT_MASK;
        let id = if masked != 0 { masked } else { hint >> 53 };
        NonZeroU64::new(id).expect("a non-zero hint always folds to a non-zero id")
    }

    /// Value of the `id` member.
    ///
    /// NOTE(review): serde decodes a JSON `null` inside `Option<MsgId>` as `None`, so `Null`
    /// is most likely only produced when serializing — confirm.
    #[derive(serde::Serialize, serde::Deserialize)]
    #[serde(untagged)]
    enum MsgId<'a> {
        Int(u64),
        Str(&'a str),
        Null,
    }

    /// Outgoing message; members left as `None` are omitted from the JSON output.
    #[derive(serde::Serialize)]
    struct SerMsg<'a, T: serde::Serialize + ?Sized> {
        jsonrpc: JsonRpcTag,

        #[serde(skip_serializing_if = "Option::is_none")]
        method: Option<&'a str>,

        #[serde(rename = "id", skip_serializing_if = "Option::is_none")]
        id: Option<MsgId<'a>>,

        #[serde(skip_serializing_if = "Option::is_none")]
        params: Option<&'a T>,

        #[serde(skip_serializing_if = "Option::is_none")]
        error: Option<SerErrObj<'a, T>>,

        #[serde(skip_serializing_if = "Option::is_none")]
        result: Option<&'a T>,
    }

    /// JSON-RPC error object; the caller's payload travels in `data`.
    #[derive(serde::Serialize)]
    struct SerErrObj<'a, T: serde::Serialize + ?Sized> {
        code: i64,
        message: &'a str,
        data: &'a T,
    }

    // Written by hand because `#[derive(Default)]` would demand `T: Default`.
    impl<'a, T: serde::Serialize + ?Sized> Default for SerMsg<'a, T> {
        fn default() -> Self {
            Self {
                jsonrpc: Default::default(),
                method: Default::default(),
                id: Default::default(),
                params: Default::default(),
                error: Default::default(),
                result: Default::default(),
            }
        }
    }

    /// The `"jsonrpc": "2.0"` version marker.
    #[derive(Default)]
    struct JsonRpcTag;

    impl serde::Serialize for JsonRpcTag {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            serializer.serialize_str("2.0")
        }
    }

    impl<'de> serde::Deserialize<'de> for JsonRpcTag {
        /// Accepts only the exact (borrowed) string `"2.0"`.
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            <&str>::deserialize(deserializer).and_then(|x| {
                if x == "2.0" {
                    Ok(JsonRpcTag)
                } else {
                    Err(serde::de::Error::custom("Invalid JSON-RPC version"))
                }
            })
        }
    }

    /// Inbound message view; every borrowed member points into the decoded input buffer.
    #[derive(Default, serde::Deserialize)]
    struct DeMsgFrame<'a> {
        #[serde(rename = "jsonrpc")]
        _jsonrpc: JsonRpcTag,

        #[serde(borrow, default)]
        method: Option<&'a str>,

        #[serde(borrow, default)]
        id: Option<MsgId<'a>>,

        #[serde(borrow, default)]
        params: Option<&'a RawValue>,

        #[serde(borrow, default)]
        error: Option<&'a RawValue>,

        #[serde(borrow, default)]
        result: Option<&'a RawValue>,
    }

    impl codec::Codec for Codec {
        /// Encodes a notification: `method` and `params`, no `id`.
        fn encode_notify(
            &self,
            method: &str,
            params: &dyn erased_serde::Serialize,
            write: &mut BytesMut,
        ) -> Result<(), codec::EncodeError> {
            serde_json::to_writer(
                write.writer(),
                &SerMsg { method: Some(method), params: Some(params), ..Default::default() },
            )
            .map_err(|e| codec::EncodeError::SerializeError(e.into()))?;
            Ok(())
        }

        /// Encodes a request and returns the (53-bit, non-zero) integer id actually written.
        fn encode_request(
            &self,
            method: &str,
            req_id_hint: NonZeroU64,
            params: &dyn erased_serde::Serialize,
            write: &mut BytesMut,
        ) -> Result<std::num::NonZeroU64, codec::EncodeError> {
            let req_id = wire_req_id(req_id_hint);

            serde_json::to_writer(
                write.writer(),
                &SerMsg {
                    method: Some(method),
                    id: Some(MsgId::Int(req_id.get())),
                    params: Some(params),
                    ..Default::default()
                },
            )
            .map_err(|e| codec::EncodeError::SerializeError(e.into()))?;
            Ok(req_id)
        }

        /// Encodes a response. With `encode_as_error`, `response` is wrapped as the `data` of
        /// an error object (code `-1`); otherwise it becomes the `result`.
        fn encode_response(
            &self,
            req_id: ReqIdRef,
            encode_as_error: bool,
            response: &dyn erased_serde::Serialize,
            write: &mut BytesMut,
        ) -> Result<(), codec::EncodeError> {
            let id = match req_id {
                ReqIdRef::U64(value) => MsgId::Int(value),
                // An id that is not valid UTF-8 cannot be echoed as a JSON string.
                ReqIdRef::Bytes(value) => {
                    std::str::from_utf8(value).map_or(MsgId::Null, MsgId::Str)
                }
            };

            let message = SerMsg {
                id: Some(id),
                error: encode_as_error.then_some(SerErrObj {
                    code: -1,
                    message: "Error from 'rpc_it::codecs::jsonrpc'",
                    data: response,
                }),
                result: (!encode_as_error).then_some(response),
                ..Default::default()
            };

            serde_json::to_writer(write.writer(), &message)
                .map_err(|e| codec::EncodeError::SerializeError(e.into()))?;
            Ok(())
        }

        fn encode_response_predefined(
            &self,
            req_id: ReqIdRef,
            response: &codec::PredefinedResponseError,
            write: &mut BytesMut,
        ) -> Result<(), codec::EncodeError> {
            // XXX: New type for predefined response error?
            self.encode_response(req_id, true, response, write)
        }

        fn try_decode_predef_error<'a>(
            &self,
            payload: &'a [u8],
        ) -> Option<codec::PredefinedResponseError> {
            // TODO: Support predefined error decoding
            let _ = payload;
            None
        }

        /// Classifies an inbound message and returns the byte range of its payload within
        /// `data`: the raw `params` (empty range when absent) for requests/notifications, or
        /// the raw `error`/`result` value for responses.
        fn decode_inbound(
            &self,
            data: &[u8],
        ) -> Result<(InboundFrameType, std::ops::Range<usize>), codec::DecodeError> {
            let f = serde_json::from_slice::<DeMsgFrame>(data)
                .map_err(|e| codec::DecodeError::Other(e.into()))?;

            // Every borrowed member of `f` points into `data`; map it back to a byte range.
            let data_range_of = |x: &[u8]| {
                let offset = x.as_ptr() as usize - data.as_ptr() as usize;
                offset..offset + x.len()
            };
            // Raw JSON text of an optional member, or an empty range when absent.
            let raw_range_of =
                |raw: Option<&RawValue>| raw.map_or(0..0, |v| data_range_of(v.get().as_bytes()));

            let method_range = f.method.map(|x| data_range_of(x.as_bytes()));
            let req_id = match &f.id {
                Some(MsgId::Int(x)) => Some(ReqId::U64(*x)),
                Some(MsgId::Str(x)) => Some(ReqId::Bytes(data_range_of(x.as_bytes()))),
                Some(MsgId::Null) => {
                    return Err(codec::DecodeError::InvalidFormat(
                        "Null request ID returned".into(),
                    ))
                }
                None => None,
            };

            Ok(match (req_id, method_range, f.params, f.error, f.result) {
                // ID, Method, (Params) => Request
                (Some(req_id), Some(method), params, None, None) => {
                    (InboundFrameType::Request { method, req_id }, raw_range_of(params))
                }

                // Method, (Params) => Notify
                (None, Some(method), params, None, None) => {
                    (InboundFrameType::Notify { method }, raw_range_of(params))
                }

                // ID, exactly one of (Error | Result) => Response
                (Some(id), None, None, error, result) if error.is_some() ^ result.is_some() => {
                    let ReqId::U64(req_id) = id else {
                        return Err(codec::DecodeError::InvalidFormat(
                            "We don't use string request ID types.".into(),
                        ));
                    };

                    (
                        InboundFrameType::Response {
                            req_id: ReqId::U64(req_id),
                            req_id_hash: req_id,
                            is_error: error.is_some(),
                        },
                        raw_range_of(error.or(result)),
                    )
                }

                _ => {
                    return Err(codec::DecodeError::InvalidFormat(
                        format!(
                            "Invalid json-rpc with fields: \
                             [id?={},method?={},params?={},error?={},result?={}]",
                            f.id.is_some(),
                            f.method.is_some(),
                            f.params.is_some(),
                            f.error.is_some(),
                            f.result.is_some()
                        )
                        .into(),
                    ))
                }
            })
        }

        /// Deserializes a payload range previously reported by `decode_inbound`.
        fn decode_payload<'a>(
            &self,
            payload: &'a [u8],
            decode: &mut dyn FnMut(
                &mut dyn erased_serde::Deserializer<'a>,
            ) -> Result<(), erased_serde::Error>,
        ) -> Result<(), codec::DecodeError> {
            decode(&mut <dyn erased_serde::Deserializer>::erase(
                &mut serde_json::Deserializer::from_slice(payload),
            ))?;
            Ok(())
        }
    }
}