burn_core/record/memory.rs

use super::{PrecisionSettings, Recorder, RecorderError, bin_config};
use alloc::string::ToString;
use alloc::vec::Vec;
use burn_tensor::backend::Backend;
use serde::{Serialize, de::DeserializeOwned};
5
6/// Recorder trait specialized to save and load data to and from bytes.
7///
8/// # Notes
9///
10/// This is especially useful in no_std environment where weights are stored directly in
11/// compiled binaries.
12pub trait BytesRecorder<
13    B: Backend,
14    L: AsRef<[u8]> + Send + Sync + core::fmt::Debug + Clone + core::default::Default,
15>: Recorder<B, RecordArgs = (), RecordOutput = Vec<u8>, LoadArgs = L>
16{
17}
18
19/// In memory recorder using the [bincode format](bincode).
20#[derive(new, Debug, Default, Clone)]
21pub struct BinBytesRecorder<
22    S: PrecisionSettings,
23    L: AsRef<[u8]> + Send + Sync + core::fmt::Debug + Clone + core::default::Default = Vec<u8>,
24> {
25    _settings: core::marker::PhantomData<S>,
26    _loadargs: core::marker::PhantomData<L>,
27}
28
29impl<
30    S: PrecisionSettings,
31    B: Backend,
32    L: AsRef<[u8]> + Send + Sync + core::fmt::Debug + Clone + core::default::Default,
33> BytesRecorder<B, L> for BinBytesRecorder<S, L>
34{
35}
36
37impl<
38    S: PrecisionSettings,
39    B: Backend,
40    L: AsRef<[u8]> + Send + Sync + core::fmt::Debug + Clone + core::default::Default,
41> Recorder<B> for BinBytesRecorder<S, L>
42{
43    type Settings = S;
44    type RecordArgs = ();
45    type RecordOutput = Vec<u8>;
46    type LoadArgs = L;
47
48    fn save_item<I: Serialize>(
49        &self,
50        item: I,
51        _args: Self::RecordArgs,
52    ) -> Result<Self::RecordOutput, RecorderError> {
53        Ok(bincode::serde::encode_to_vec(item, bin_config()).unwrap())
54    }
55
56    fn load_item<I: DeserializeOwned>(
57        &self,
58        args: &mut Self::LoadArgs,
59    ) -> Result<I, RecorderError> {
60        let state = bincode::borrow_decode_from_slice::<'_, bincode::serde::BorrowCompat<I>, _>(
61            args.as_ref(),
62            bin_config(),
63        )
64        .unwrap()
65        .0;
66        Ok(state.0)
67    }
68}
69
/// In-memory recorder that encodes records as [Named MessagePack](rmp_serde), writing
/// struct field names into the output.
///
/// Only available with the `std` feature.
#[cfg(feature = "std")]
#[derive(new, Debug, Default, Clone)]
pub struct NamedMpkBytesRecorder<S>
where
    S: PrecisionSettings,
{
    // Zero-sized marker tying the recorder to its precision settings type.
    _settings: core::marker::PhantomData<S>,
}
76
// Marker impl: the MessagePack recorder loads from an owned `Vec<u8>`.
#[cfg(feature = "std")]
impl<S, B> BytesRecorder<B, Vec<u8>> for NamedMpkBytesRecorder<S>
where
    S: PrecisionSettings,
    B: Backend,
{
}
79
/// Named-MessagePack-backed [`Recorder`] implementation.
#[cfg(feature = "std")]
impl<S, B> Recorder<B> for NamedMpkBytesRecorder<S>
where
    S: PrecisionSettings,
    B: Backend,
{
    type Settings = S;
    type RecordArgs = ();
    type RecordOutput = Vec<u8>;
    type LoadArgs = Vec<u8>;

    /// Serializes `item` to MessagePack, with struct fields written as a name-keyed map.
    ///
    /// # Errors
    ///
    /// Returns [`RecorderError::Unknown`] carrying the encoder's message on failure.
    fn save_item<I: Serialize>(
        &self,
        item: I,
        _args: Self::RecordArgs,
    ) -> Result<Self::RecordOutput, RecorderError> {
        let encoded = rmp_serde::encode::to_vec_named(&item);
        encoded.map_err(|error| RecorderError::Unknown(error.to_string()))
    }

    /// Deserializes an `I` from the MessagePack bytes in `args`.
    ///
    /// # Errors
    ///
    /// Returns [`RecorderError::Unknown`] carrying the decoder's message on failure.
    fn load_item<I: DeserializeOwned>(
        &self,
        args: &mut Self::LoadArgs,
    ) -> Result<I, RecorderError> {
        match rmp_serde::decode::from_slice(args.as_slice()) {
            Ok(item) => Ok(item),
            Err(error) => Err(RecorderError::Unknown(error.to_string())),
        }
    }
}
101
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::SimpleLinear;
    use crate::{
        TestBackend, module::Module, record::FullPrecisionSettings, tensor::backend::Backend,
    };

    /// Round-trips a model through the bincode bytes recorder.
    #[test]
    fn test_can_save_and_load_bin_format() {
        let recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder);
    }

    /// Round-trips a model through the named MessagePack bytes recorder.
    #[cfg(feature = "std")]
    #[test]
    fn test_can_save_and_load_named_mpk_format() {
        let recorder = NamedMpkBytesRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder);
    }

    /// Shared round-trip check for any bytes recorder.
    ///
    /// Two models are created and their records compared. The first model's bytes are then
    /// loaded into the second model, which must re-serialize to exactly the same bytes.
    fn test_can_save_and_load<R>(recorder: R)
    where
        R: BytesRecorder<TestBackend, Vec<u8>>,
    {
        let device = Default::default();
        let source = create_model::<TestBackend>(&device);
        let target = create_model::<TestBackend>(&device);

        let source_bytes = recorder.record(source.into_record(), ()).unwrap();
        let target_bytes = recorder.record(target.clone().into_record(), ()).unwrap();

        // Overwrite the target's weights with the source's, then serialize it again.
        let target_reloaded =
            target.load_record(recorder.load(source_bytes.clone(), &device).unwrap());
        let reloaded_bytes = recorder.record(target_reloaded.into_record(), ()).unwrap();

        assert_ne!(source_bytes, target_bytes);
        assert_eq!(source_bytes, reloaded_bytes);
    }

    /// Builds the small 32x32 linear model used by the round-trip tests.
    pub fn create_model<B: Backend>(device: &B::Device) -> SimpleLinear<B> {
        SimpleLinear::new(32, 32, device)
    }
}