1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
use super::{bin_config, PrecisionSettings, Recorder, RecorderError};
use alloc::vec::Vec;
use serde::{de::DeserializeOwned, Serialize};

/// Recorder trait specialized to save and load data to and from bytes.
///
/// An implementor is a [`Recorder`] whose associated types are fixed as follows:
///
/// - `RecordArgs = ()`: recording takes no extra arguments,
/// - `RecordOutput = Vec<u8>`: recording produces a byte vector,
/// - `LoadArgs = Vec<u8>`: loading consumes a byte vector.
///
/// The trait declares no methods of its own; it only names this combination of
/// associated types so it can be used as a single bound.
///
/// # Notes
///
/// This is especially useful in `no_std` environments where weights are stored directly in
/// compiled binaries.
pub trait BytesRecorder:
    Recorder<RecordArgs = (), RecordOutput = Vec<u8>, LoadArgs = Vec<u8>>
{
}

/// In-memory recorder using the [bincode format](bincode).
///
/// The struct stores no data of its own: the type parameter `S` is only carried at the
/// type level through a `PhantomData` marker.
#[derive(new, Debug, Default, Clone)]
pub struct BinBytesRecorder<S: PrecisionSettings> {
    // Zero-sized marker that ties the recorder to its precision settings `S`.
    _settings: core::marker::PhantomData<S>,
}

// Marker impl: the bincode recorder records to and loads from `Vec<u8>`.
impl<S: PrecisionSettings> BytesRecorder for BinBytesRecorder<S> {}

impl<S: PrecisionSettings> Recorder for BinBytesRecorder<S> {
    type Settings = S;
    type RecordArgs = ();
    type RecordOutput = Vec<u8>;
    type LoadArgs = Vec<u8>;

    fn save_item<I: Serialize>(
        &self,
        item: I,
        _args: Self::RecordArgs,
    ) -> Result<Self::RecordOutput, RecorderError> {
        Ok(bincode::serde::encode_to_vec(item, bin_config()).unwrap())
    }
    fn load_item<I: DeserializeOwned>(&self, args: Self::LoadArgs) -> Result<I, RecorderError> {
        let state = bincode::serde::decode_borrowed_from_slice(&args, bin_config()).unwrap();
        Ok(state)
    }
}

#[cfg(feature = "std")]
/// In-memory recorder using the [named MessagePack format](rmp_serde).
///
/// Only compiled when the `std` feature is enabled. The struct stores no data of its own:
/// the type parameter `S` is only carried at the type level through a `PhantomData` marker.
#[derive(new, Debug, Default, Clone)]
pub struct NamedMpkBytesRecorder<S: PrecisionSettings> {
    // Zero-sized marker that ties the recorder to its precision settings `S`.
    _settings: core::marker::PhantomData<S>,
}

// Marker impl: the named MessagePack recorder records to and loads from `Vec<u8>`.
#[cfg(feature = "std")]
impl<S: PrecisionSettings> BytesRecorder for NamedMpkBytesRecorder<S> {}

#[cfg(feature = "std")]
impl<S: PrecisionSettings> Recorder for NamedMpkBytesRecorder<S> {
    type Settings = S;
    type RecordArgs = ();
    type RecordOutput = Vec<u8>;
    type LoadArgs = Vec<u8>;

    /// Encodes `item` as named MessagePack and returns the resulting bytes.
    ///
    /// `_args` is unused because this recorder takes no recording arguments.
    ///
    /// # Errors
    ///
    /// Returns [`RecorderError::Unknown`] with the encoder's error message when
    /// serialization fails.
    fn save_item<I: Serialize>(
        &self,
        item: I,
        _args: Self::RecordArgs,
    ) -> Result<Self::RecordOutput, RecorderError> {
        match rmp_serde::encode::to_vec_named(&item) {
            Ok(encoded) => Ok(encoded),
            Err(err) => Err(RecorderError::Unknown(err.to_string())),
        }
    }

    /// Decodes an item of type `I` from the MessagePack bytes in `args`.
    ///
    /// # Errors
    ///
    /// Returns [`RecorderError::Unknown`] with the decoder's error message when the bytes
    /// cannot be deserialized as `I`.
    fn load_item<I: DeserializeOwned>(&self, args: Self::LoadArgs) -> Result<I, RecorderError> {
        let decoded: I = rmp_serde::decode::from_slice(&args)
            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
        Ok(decoded)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{module::Module, nn, record::FullPrecisionSettings, TestBackend};

    /// Round-trips a model through the bincode bytes recorder.
    #[test]
    fn test_can_save_and_load_bin_format() {
        let recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    /// Round-trips a model through the named MessagePack bytes recorder.
    #[cfg(feature = "std")]
    #[test]
    fn test_can_save_and_load_named_mpk_format() {
        let recorder = NamedMpkBytesRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    /// Shared round-trip check for any [`BytesRecorder`].
    ///
    /// Records two freshly created models, loads the first model's bytes into the second
    /// model, and verifies that re-recording the second model yields the first model's bytes.
    fn test_can_save_and_load<R: BytesRecorder>(recorder: R) {
        let source = create_model();
        let target = create_model();

        let source_bytes = recorder.record(source.into_record(), ()).unwrap();
        // `target` is cloned because it is still needed below to receive the loaded record.
        let target_bytes = recorder.record(target.clone().into_record(), ()).unwrap();

        // The bytes are cloned because `source_bytes` is compared against afterwards.
        let loaded_record = recorder.load(source_bytes.clone()).unwrap();
        let target_after = target.load_record(loaded_record);
        let target_bytes_after = recorder.record(target_after.into_record(), ()).unwrap();

        // The two models must start out different, otherwise the equality check below
        // would pass trivially.
        assert_ne!(source_bytes, target_bytes);
        assert_eq!(source_bytes, target_bytes_after);
    }

    /// Builds the 32x32 linear layer (with bias) used as the model under test.
    pub fn create_model() -> nn::Linear<TestBackend> {
        let config = nn::LinearConfig::new(32, 32).with_bias(true);
        config.init()
    }
}