burn_core/record/
memory.rs1use super::{PrecisionSettings, Recorder, RecorderError, bin_config};
2use alloc::vec::Vec;
3use burn_tensor::backend::Backend;
4use serde::{Serialize, de::DeserializeOwned};
5
6pub trait BytesRecorder<
13 B: Backend,
14 L: AsRef<[u8]> + Send + Sync + core::fmt::Debug + Clone + core::default::Default,
15>: Recorder<B, RecordArgs = (), RecordOutput = Vec<u8>, LoadArgs = L>
16{
17}
18
19#[derive(new, Debug, Default, Clone)]
21pub struct BinBytesRecorder<
22 S: PrecisionSettings,
23 L: AsRef<[u8]> + Send + Sync + core::fmt::Debug + Clone + core::default::Default = Vec<u8>,
24> {
25 _settings: core::marker::PhantomData<S>,
26 _loadargs: core::marker::PhantomData<L>,
27}
28
29impl<
30 S: PrecisionSettings,
31 B: Backend,
32 L: AsRef<[u8]> + Send + Sync + core::fmt::Debug + Clone + core::default::Default,
33> BytesRecorder<B, L> for BinBytesRecorder<S, L>
34{
35}
36
37impl<
38 S: PrecisionSettings,
39 B: Backend,
40 L: AsRef<[u8]> + Send + Sync + core::fmt::Debug + Clone + core::default::Default,
41> Recorder<B> for BinBytesRecorder<S, L>
42{
43 type Settings = S;
44 type RecordArgs = ();
45 type RecordOutput = Vec<u8>;
46 type LoadArgs = L;
47
48 fn save_item<I: Serialize>(
49 &self,
50 item: I,
51 _args: Self::RecordArgs,
52 ) -> Result<Self::RecordOutput, RecorderError> {
53 Ok(bincode::serde::encode_to_vec(item, bin_config()).unwrap())
54 }
55
56 fn load_item<I: DeserializeOwned>(
57 &self,
58 args: &mut Self::LoadArgs,
59 ) -> Result<I, RecorderError> {
60 let state = bincode::borrow_decode_from_slice::<'_, bincode::serde::BorrowCompat<I>, _>(
61 args.as_ref(),
62 bin_config(),
63 )
64 .unwrap()
65 .0;
66 Ok(state.0)
67 }
68}
69
/// In-memory recorder that encodes records as MessagePack through `rmp_serde`,
/// keeping field names in the output.
///
/// Only compiled when the `std` feature is enabled.
#[cfg(feature = "std")]
#[derive(new, Debug, Default, Clone)]
pub struct NamedMpkBytesRecorder<S>
where
    S: PrecisionSettings,
{
    // Zero-sized marker tying the recorder to its precision settings type.
    _settings: core::marker::PhantomData<S>,
}
76
// Marker impl: the `Recorder<B>` impl for `NamedMpkBytesRecorder` already fixes
// `LoadArgs = Vec<u8>`, which is the byte container used here.
#[cfg(feature = "std")]
impl<S, B> BytesRecorder<B, Vec<u8>> for NamedMpkBytesRecorder<S>
where
    S: PrecisionSettings,
    B: Backend,
{
}
79
#[cfg(feature = "std")]
impl<S, B> Recorder<B> for NamedMpkBytesRecorder<S>
where
    S: PrecisionSettings,
    B: Backend,
{
    type Settings = S;
    type RecordArgs = ();
    type RecordOutput = Vec<u8>;
    type LoadArgs = Vec<u8>;

    /// Serializes `item` with `rmp_serde::encode::to_vec_named`.
    ///
    /// # Errors
    ///
    /// Encoding failures are reported as [`RecorderError::Unknown`] with the
    /// `rmp_serde` error message.
    fn save_item<I: Serialize>(
        &self,
        item: I,
        _args: Self::RecordArgs,
    ) -> Result<Self::RecordOutput, RecorderError> {
        let encoded = rmp_serde::encode::to_vec_named(&item);
        encoded.map_err(|error| RecorderError::Unknown(error.to_string()))
    }

    /// Deserializes an item of type `I` from the MessagePack bytes in `args`.
    ///
    /// # Errors
    ///
    /// Decoding failures are reported as [`RecorderError::Unknown`] with the
    /// `rmp_serde` error message.
    fn load_item<I: DeserializeOwned>(
        &self,
        args: &mut Self::LoadArgs,
    ) -> Result<I, RecorderError> {
        match rmp_serde::decode::from_slice(args.as_slice()) {
            Ok(item) => Ok(item),
            Err(error) => Err(RecorderError::Unknown(error.to_string())),
        }
    }
}
101
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        TestBackend, module::Module, nn, record::FullPrecisionSettings, tensor::backend::Backend,
    };

    #[test]
    fn test_can_save_and_load_bin_format() {
        let recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_can_save_and_load_named_mpk_format() {
        let recorder = NamedMpkBytesRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder);
    }

    /// Records two independently created models, loads the first model's bytes
    /// into the second, and checks that re-recording yields the first model's
    /// bytes exactly.
    fn test_can_save_and_load<R>(recorder: R)
    where
        R: BytesRecorder<TestBackend, Vec<u8>>,
    {
        let device = Default::default();
        let source = create_model::<TestBackend>(&device);
        let target = create_model::<TestBackend>(&device);

        let source_bytes = recorder.record(source.into_record(), ()).unwrap();
        let target_bytes = recorder.record(target.clone().into_record(), ()).unwrap();

        let loaded_record = recorder.load(source_bytes.clone(), &device).unwrap();
        let target_reloaded = target.load_record(loaded_record);
        let reloaded_bytes = recorder.record(target_reloaded.into_record(), ()).unwrap();

        // The two fresh models must differ; otherwise the round-trip equality
        // below would hold trivially and prove nothing.
        assert_ne!(source_bytes, target_bytes);
        assert_eq!(source_bytes, reloaded_bytes);
    }

    /// Builds a 32x32 linear layer with bias on `device`.
    pub fn create_model<B: Backend>(device: &B::Device) -> nn::Linear<B> {
        nn::LinearConfig::new(32, 32).with_bias(true).init(device)
    }
}