1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
use crate::{Decoder, Encoder};
use rkyv::de::deserializers::SharedDeserializeMap;
use rkyv::ser::serializers::AllocSerializer;
use rkyv::validation::validators::DefaultValidator;
use rkyv::{Archive, CheckBytes, Deserialize, Fallible, Serialize};
use std::error::Error;
use std::sync::Arc;

/// A codec that relies on `rkyv` to encode data in rkyv's own zero-copy archive format.
///
/// Encoding produces a `Vec<u8>` holding the archived value. Decoding validates the bytes
/// (via `CheckBytes`) before deserializing, so the archived form of the type must implement
/// `CheckBytes` — e.g. by deriving with `#[archive(check_bytes)]`.
///
/// This is only available with the **`rkyv` feature** enabled.
pub struct RkyvCodec;

/// Archives `T` with rkyv's `AllocSerializer` (scratch-space parameter of 1024) and hands the
/// archive back as an ordinary byte vector.
impl<T> Encoder<T> for RkyvCodec
where
    T: Serialize<AllocSerializer<1024>>,
{
    type Error = <AllocSerializer<1024> as Fallible>::Error;
    type Encoded = Vec<u8>;

    fn encode(val: &T) -> Result<Self::Encoded, Self::Error> {
        // `to_bytes` yields rkyv's aligned buffer type; copy its bytes into a plain `Vec<u8>`.
        rkyv::to_bytes::<T, 1024>(val).map(|archive| archive.to_vec())
    }
}

impl<T> Decoder<T> for RkyvCodec
where
    T: Archive,
    for<'a> T::Archived:
        'a + CheckBytes<DefaultValidator<'a>> + Deserialize<T, SharedDeserializeMap>,
{
    type Error = Arc<dyn Error>;
    type Encoded = [u8];

    fn decode(val: &Self::Encoded) -> Result<T, Self::Error> {
        rkyv::from_bytes::<T>(val).map_err(|e| Arc::new(e) as Arc<dyn Error>)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a small struct through `RkyvCodec` and checks it comes back unchanged.
    #[test]
    fn test_rkyv_codec() {
        #[derive(Clone, Debug, PartialEq, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
        #[archive(check_bytes)]
        struct Test {
            s: String,
            i: i32,
        }

        let original = Test {
            s: "party time 🎉".to_string(),
            i: 42,
        };

        let bytes = RkyvCodec::encode(&original).unwrap();
        let round_tripped: Test = RkyvCodec::decode(&bytes).unwrap();

        assert_eq!(round_tripped, original);
    }
}