1use core::any::type_name;
2use core::marker::PhantomData;
3
4use alloc::format;
5use alloc::string::{String, ToString};
6use burn_tensor::backend::Backend;
7use serde::{Deserialize, Serialize, de::DeserializeOwned};
8
9use super::{BinBytesRecorder, FullPrecisionSettings, PrecisionSettings, Record};
10
11#[cfg(feature = "std")]
12use super::{
13 BinFileRecorder, BinGzFileRecorder, DefaultFileRecorder, HalfPrecisionSettings,
14 PrettyJsonFileRecorder,
15};
16
17pub trait Recorder<B: Backend>:
19 Send + Sync + core::default::Default + core::fmt::Debug + Clone
20{
21 type Settings: PrecisionSettings;
23
24 type RecordArgs: Clone;
26
27 type RecordOutput;
29
30 type LoadArgs;
32
33 fn record<R>(
44 &self,
45 record: R,
46 args: Self::RecordArgs,
47 ) -> Result<Self::RecordOutput, RecorderError>
48 where
49 R: Record<B>,
50 {
51 let item = record.into_item::<Self::Settings>();
52 let item = BurnRecord::new::<Self>(item);
53
54 self.save_item(item, args)
55 }
56
57 fn load<R>(&self, mut args: Self::LoadArgs, device: &B::Device) -> Result<R, RecorderError>
59 where
60 R: Record<B>,
61 {
62 let item: BurnRecord<R::Item<Self::Settings>, B> =
63 self.load_item(&mut args).map_err(|err| {
64 if let Ok(record) = self.load_item::<BurnRecordNoItem>(&mut args) {
65 let mut message = "Unable to load record.".to_string();
66 let metadata = recorder_metadata::<Self, B>();
67 if metadata.float != record.metadata.float {
68 message += format!(
69 "\nMetadata has a different float type: Actual {:?}, Expected {:?}",
70 record.metadata.float, metadata.float
71 )
72 .as_str();
73 }
74 if metadata.int != record.metadata.int {
75 message += format!(
76 "\nMetadata has a different int type: Actual {:?}, Expected {:?}",
77 record.metadata.int, metadata.int
78 )
79 .as_str();
80 }
81 if metadata.format != record.metadata.format {
82 message += format!(
83 "\nMetadata has a different format: Actual {:?}, Expected {:?}",
84 record.metadata.format, metadata.format
85 )
86 .as_str();
87 }
88 if metadata.version != record.metadata.version {
89 message += format!(
90 "\nMetadata has a different Burn version: Actual {:?}, Expected {:?}",
91 record.metadata.version, metadata.version
92 )
93 .as_str();
94 }
95
96 message += format!("\nError: {err:?}").as_str();
97
98 return RecorderError::Unknown(message);
99 }
100
101 err
102 })?;
103
104 Ok(R::from_item(item.item, device))
105 }
106
107 fn save_item<I: Serialize>(
120 &self,
121 item: I,
122 args: Self::RecordArgs,
123 ) -> Result<Self::RecordOutput, RecorderError>;
124
125 fn load_item<I>(&self, args: &mut Self::LoadArgs) -> Result<I, RecorderError>
137 where
138 I: DeserializeOwned;
139}
140
141fn recorder_metadata<R, B>() -> BurnMetadata
142where
143 R: Recorder<B>,
144 B: Backend,
145{
146 BurnMetadata::new(
147 type_name::<<R::Settings as PrecisionSettings>::FloatElem>().to_string(),
148 type_name::<<R::Settings as PrecisionSettings>::IntElem>().to_string(),
149 type_name::<R>().to_string(),
150 env!("CARGO_PKG_VERSION").to_string(),
151 format!("{:?}", R::Settings::default()),
152 )
153}
154
/// Error returned by recorder operations.
///
/// Every variant carries a free-form message.
#[derive(Debug)]
pub enum RecorderError {
    /// A file could not be found.
    FileNotFound(String),

    /// Stored data could not be deserialized.
    DeserializeError(String),

    /// Any other failure. `Recorder::load` uses this variant for its
    /// metadata-mismatch report.
    Unknown(String),
}

impl core::fmt::Display for RecorderError {
    /// Writes the same text as the `Debug` representation.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Stream the `Debug` rendering straight into the formatter instead of
        // building a temporary `String` first; the emitted text is unchanged.
        write!(f, "{self:?}")
    }
}

impl core::error::Error for RecorderError {}
175
/// Returns bincode's standard configuration (`bincode::config::standard()`).
///
/// NOTE(review): presumably the single source of the bincode configuration for
/// the binary recorders in this crate — confirm against the callers.
pub(crate) fn bin_config() -> bincode::config::Configuration {
    bincode::config::standard()
}
179
/// Metadata stored alongside a recorded item.
///
/// `recorder_metadata` fills every field with a string describing the recorder
/// that produced the record. `Recorder::load` compares `float`, `int`, `format`
/// and `version` to explain a failed load.
#[derive(new, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct BurnMetadata {
    /// Type name of the float element type chosen by the recorder's settings.
    pub float: String,

    /// Type name of the int element type chosen by the recorder's settings.
    pub int: String,

    /// Type name of the recorder that produced the record.
    pub format: String,

    /// Crate version (`CARGO_PKG_VERSION`) at the time of recording.
    pub version: String,

    /// `Debug` rendering of the recorder's default precision settings.
    pub settings: String,
}
198
199#[derive(Serialize, Deserialize, Debug)]
201pub struct BurnRecord<I, B: Backend> {
202 pub metadata: BurnMetadata,
204
205 pub item: I,
207
208 _b: PhantomData<B>,
209}
210
211impl<I, B: Backend> BurnRecord<I, B> {
212 pub fn new<R: Recorder<B>>(item: I) -> Self {
222 let metadata = recorder_metadata::<R, B>();
223
224 Self {
225 metadata,
226 item,
227 _b: PhantomData,
228 }
229 }
230}
231
/// A record reduced to its metadata, with no item.
///
/// `Recorder::load` deserializes this shape after a full load has failed, so
/// that it can report which metadata fields differ.
#[derive(new, Debug, Serialize, Deserialize)]
pub struct BurnRecordNoItem {
    /// Metadata describing the recorder that created the record.
    pub metadata: BurnMetadata,
}
238
/// Default recorder: a [`DefaultFileRecorder`] with [`FullPrecisionSettings`].
///
/// Only available with the `std` feature.
#[cfg(feature = "std")]
pub type DefaultRecorder = DefaultFileRecorder<FullPrecisionSettings>;

/// A [`DefaultFileRecorder`] with [`HalfPrecisionSettings`].
///
/// Only available with the `std` feature.
/// NOTE(review): presumably trades precision for smaller output than
/// [`DefaultRecorder`] — confirm.
#[cfg(feature = "std")]
pub type CompactRecorder = DefaultFileRecorder<HalfPrecisionSettings>;

/// A [`BinGzFileRecorder`] with [`HalfPrecisionSettings`].
///
/// Only available with the `std` feature.
/// NOTE(review): the intended use case behind the "sensitive" name is not
/// visible in this file — confirm and document it.
#[cfg(feature = "std")]
pub type SensitiveCompactRecorder = BinGzFileRecorder<HalfPrecisionSettings>;

/// A [`BinFileRecorder`] with [`FullPrecisionSettings`].
///
/// Only available with the `std` feature.
/// NOTE(review): presumably produces files that [`NoStdInferenceRecorder`] can
/// read back — confirm.
#[cfg(feature = "std")]
pub type NoStdTrainingRecorder = BinFileRecorder<FullPrecisionSettings>;

/// A [`BinBytesRecorder`] with [`FullPrecisionSettings`] over `&'static [u8]`.
///
/// Unlike the other aliases here, this one is not gated on the `std` feature.
pub type NoStdInferenceRecorder = BinBytesRecorder<FullPrecisionSettings, &'static [u8]>;

/// A [`PrettyJsonFileRecorder`] with [`FullPrecisionSettings`].
///
/// Only available with the `std` feature.
#[cfg(feature = "std")]
pub type DebugRecordSettings = PrettyJsonFileRecorder<FullPrecisionSettings>;
275
#[cfg(all(test, feature = "std"))]
mod tests {
    // NOTE(review): hard-coded Unix path; consider a temp-dir based path if
    // these tests must run on other platforms.
    static FILE_PATH: &str = "/tmp/burn_test_record";

    use crate::TestBackend;

    use super::*;
    use burn_tensor::{Device, ElementConversion};

    /// Loading a record with different precision settings than it was written
    /// with must yield an error.
    ///
    /// The test asserts on the load result explicitly rather than relying on
    /// `#[should_panic]`: with `#[should_panic]`, a failure of the setup step
    /// (writing the record) would also make the test pass.
    #[test]
    fn err_when_invalid_item() {
        #[derive(new, Serialize, Deserialize)]
        struct Item<S: PrecisionSettings> {
            value: S::FloatElem,
        }

        impl<D, B> Record<B> for Item<D>
        where
            D: PrecisionSettings,
            B: Backend,
        {
            type Item<S: PrecisionSettings> = Item<S>;

            fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
                Item {
                    value: self.value.elem(),
                }
            }

            fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {
                Item {
                    value: item.value.elem(),
                }
            }
        }

        let item = Item::<FullPrecisionSettings>::new(16.elem());
        let device: Device<TestBackend> = Default::default();

        // Setup: write the record with full precision. This step must succeed.
        let recorder = DefaultFileRecorder::<FullPrecisionSettings>::new();
        Recorder::<TestBackend>::record(&recorder, item, FILE_PATH.into())
            .expect("recording with full precision settings should succeed");

        // Exercise: read it back with half precision settings.
        let recorder = DefaultFileRecorder::<HalfPrecisionSettings>::new();
        let loaded = Recorder::<TestBackend>::load::<Item<FullPrecisionSettings>>(
            &recorder,
            FILE_PATH.into(),
            &device,
        );

        assert!(
            loaded.is_err(),
            "loading with mismatched precision settings should fail"
        );
    }
}