// burn_core/record/recorder.rs

1use core::any::type_name;
2use core::marker::PhantomData;
3
4use alloc::format;
5use alloc::string::{String, ToString};
6use burn_tensor::backend::Backend;
7use serde::{Deserialize, Serialize, de::DeserializeOwned};
8
9use super::{BinBytesRecorder, FullPrecisionSettings, PrecisionSettings, Record};
10
11#[cfg(feature = "std")]
12use super::{
13    BinFileRecorder, BinGzFileRecorder, DefaultFileRecorder, HalfPrecisionSettings,
14    PrettyJsonFileRecorder,
15};
16
17/// Record any item implementing [Serialize](Serialize) and [DeserializeOwned](DeserializeOwned).
18pub trait Recorder<B: Backend>:
19    Send + Sync + core::default::Default + core::fmt::Debug + Clone
20{
21    /// Type of the settings used by the recorder.
22    type Settings: PrecisionSettings;
23
24    /// Arguments used to record objects.
25    type RecordArgs: Clone;
26
27    /// Record output type.
28    type RecordOutput;
29
30    /// Arguments used to load recorded objects.
31    type LoadArgs;
32
33    /// Records an item.
34    ///
35    /// # Arguments
36    ///
37    /// * `record` - The item to record.
38    /// * `args` - Arguments used to record the item.
39    ///
40    /// # Returns
41    ///
42    /// The output of the recording.
43    fn record<R>(
44        &self,
45        record: R,
46        args: Self::RecordArgs,
47    ) -> Result<Self::RecordOutput, RecorderError>
48    where
49        R: Record<B>,
50    {
51        let item = record.into_item::<Self::Settings>();
52        let item = BurnRecord::new::<Self>(item);
53
54        self.save_item(item, args)
55    }
56
57    /// Load an item from the given arguments.
58    fn load<R>(&self, mut args: Self::LoadArgs, device: &B::Device) -> Result<R, RecorderError>
59    where
60        R: Record<B>,
61    {
62        let item: BurnRecord<R::Item<Self::Settings>, B> =
63            self.load_item(&mut args).map_err(|err| {
64                if let Ok(record) = self.load_item::<BurnRecordNoItem>(&mut args) {
65                    let mut message = "Unable to load record.".to_string();
66                    let metadata = recorder_metadata::<Self, B>();
67                    if metadata.float != record.metadata.float {
68                        message += format!(
69                            "\nMetadata has a different float type: Actual {:?}, Expected {:?}",
70                            record.metadata.float, metadata.float
71                        )
72                        .as_str();
73                    }
74                    if metadata.int != record.metadata.int {
75                        message += format!(
76                            "\nMetadata has a different int type: Actual {:?}, Expected {:?}",
77                            record.metadata.int, metadata.int
78                        )
79                        .as_str();
80                    }
81                    if metadata.format != record.metadata.format {
82                        message += format!(
83                            "\nMetadata has a different format: Actual {:?}, Expected {:?}",
84                            record.metadata.format, metadata.format
85                        )
86                        .as_str();
87                    }
88                    if metadata.version != record.metadata.version {
89                        message += format!(
90                            "\nMetadata has a different Burn version: Actual {:?}, Expected {:?}",
91                            record.metadata.version, metadata.version
92                        )
93                        .as_str();
94                    }
95
96                    message += format!("\nError: {err:?}").as_str();
97
98                    return RecorderError::Unknown(message);
99                }
100
101                err
102            })?;
103
104        Ok(R::from_item(item.item, device))
105    }
106
107    /// Saves an item.
108    ///
109    /// This method is used by [record](Recorder::record) to save the item.
110    ///
111    /// # Arguments
112    ///
113    /// * `item` - Item to save.
114    /// * `args` - Arguments to use to save the item.
115    ///
116    /// # Returns
117    ///
118    /// The output of the save operation.
119    fn save_item<I: Serialize>(
120        &self,
121        item: I,
122        args: Self::RecordArgs,
123    ) -> Result<Self::RecordOutput, RecorderError>;
124
125    /// Loads an item.
126    ///
127    /// This method is used by [load](Recorder::load) to load the item.
128    ///
129    /// # Arguments
130    ///
131    /// * `args` - Arguments to use to load the item.
132    ///
133    /// # Returns
134    ///
135    /// The loaded item.
136    fn load_item<I>(&self, args: &mut Self::LoadArgs) -> Result<I, RecorderError>
137    where
138        I: DeserializeOwned;
139}
140
141fn recorder_metadata<R, B>() -> BurnMetadata
142where
143    R: Recorder<B>,
144    B: Backend,
145{
146    BurnMetadata::new(
147        type_name::<<R::Settings as PrecisionSettings>::FloatElem>().to_string(),
148        type_name::<<R::Settings as PrecisionSettings>::IntElem>().to_string(),
149        type_name::<R>().to_string(),
150        env!("CARGO_PKG_VERSION").to_string(),
151        format!("{:?}", R::Settings::default()),
152    )
153}
154
/// Error that can occur when using a [Recorder](Recorder).
#[derive(Debug)]
pub enum RecorderError {
    /// File not found.
    FileNotFound(String),

    /// Failed to read file.
    DeserializeError(String),

    /// Other error.
    ///
    /// [Recorder::load](Recorder::load) also uses this variant to report a load failure together
    /// with the metadata mismatches that explain it.
    Unknown(String),
}

impl core::fmt::Display for RecorderError {
    /// Formats the error exactly like its `Debug` representation, e.g. `FileNotFound("a.mpk")`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Write straight into the formatter instead of allocating an intermediate `String`.
        write!(f, "{self:?}")
    }
}

impl core::error::Error for RecorderError {}
175
176pub(crate) fn bin_config() -> bincode::config::Configuration {
177    bincode::config::standard()
178}
179
180/// Metadata of a record.
181#[derive(new, Debug, Serialize, Deserialize, PartialEq, Eq)]
182pub struct BurnMetadata {
183    /// Float type used to record the item.
184    pub float: String,
185
186    /// Int type used to record the item.
187    pub int: String,
188
189    /// Format used to record the item.
190    pub format: String,
191
192    /// Burn record version used to record the item.
193    pub version: String,
194
195    /// Settings used to record the item.
196    pub settings: String,
197}
198
199/// Record that can be saved by a [Recorder](Recorder).
200#[derive(Serialize, Deserialize, Debug)]
201pub struct BurnRecord<I, B: Backend> {
202    /// Metadata of the record.
203    pub metadata: BurnMetadata,
204
205    /// Item to record.
206    pub item: I,
207
208    _b: PhantomData<B>,
209}
210
211impl<I, B: Backend> BurnRecord<I, B> {
212    /// Creates a new record.
213    ///
214    /// # Arguments
215    ///
216    /// * `item` - Item to record.
217    ///
218    /// # Returns
219    ///
220    /// The new record.
221    pub fn new<R: Recorder<B>>(item: I) -> Self {
222        let metadata = recorder_metadata::<R, B>();
223
224        Self {
225            metadata,
226            item,
227            _b: PhantomData,
228        }
229    }
230}
231
232/// Record that can be saved by a [Recorder](Recorder) without the item.
233#[derive(new, Debug, Serialize, Deserialize)]
234pub struct BurnRecordNoItem {
235    /// Metadata of the record.
236    pub metadata: BurnMetadata,
237}
238
239/// Default recorder.
240///
241/// It uses the [named msgpack](rmp_serde) format for serialization with full precision.
242#[cfg(feature = "std")]
243pub type DefaultRecorder = DefaultFileRecorder<FullPrecisionSettings>;
244
245/// Recorder optimized for compactness.
246///
247/// It uses the [named msgpack](rmp_serde) format for serialization with half precision.
248/// If you are looking for the recorder that offers the smallest file size, have a look at
249/// [sensitive compact recorder](SensitiveCompactRecorder).
250#[cfg(feature = "std")]
251pub type CompactRecorder = DefaultFileRecorder<HalfPrecisionSettings>;
252
253/// Recorder optimized for compactness making it a good choice for model deployment.
254///
255/// It uses the [bincode](bincode) format for serialization and half precision.
256/// This format is not resilient to type changes since no metadata is encoded.
257/// Favor [default recorder](DefaultRecorder) or [compact recorder](CompactRecorder)
258/// for long term data storage.
259#[cfg(feature = "std")]
260pub type SensitiveCompactRecorder = BinGzFileRecorder<HalfPrecisionSettings>;
261
262/// Training recorder compatible with no-std inference.
263#[cfg(feature = "std")]
264pub type NoStdTrainingRecorder = BinFileRecorder<FullPrecisionSettings>;
265
266/// Inference recorder compatible with no-std.
267pub type NoStdInferenceRecorder = BinBytesRecorder<FullPrecisionSettings, &'static [u8]>;
268
269/// Debug recorder.
270///
271/// It uses the [pretty json](serde_json) format for serialization with full precision making it
272/// human readable.
273#[cfg(feature = "std")]
274pub type DebugRecordSettings = PrettyJsonFileRecorder<FullPrecisionSettings>;
275
#[cfg(all(test, feature = "std"))]
mod tests {
    use crate::TestBackend;

    use super::*;
    use burn_tensor::{Device, ElementConversion};

    /// Scratch location for the record written by the tests in this module.
    ///
    /// Built from [`std::env::temp_dir`] instead of a hard-coded `/tmp`, so the tests also run
    /// on platforms without a `/tmp` directory.
    fn file_path() -> std::path::PathBuf {
        std::env::temp_dir().join("burn_test_record")
    }

    #[test]
    fn err_when_invalid_item() {
        #[derive(new, Serialize, Deserialize)]
        struct Item<S: PrecisionSettings> {
            value: S::FloatElem,
        }

        impl<D, B> Record<B> for Item<D>
        where
            D: PrecisionSettings,
            B: Backend,
        {
            type Item<S: PrecisionSettings> = Item<S>;

            fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
                Item {
                    value: self.value.elem(),
                }
            }

            fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {
                Item {
                    value: item.value.elem(),
                }
            }
        }

        let item = Item::<FullPrecisionSettings>::new(16.elem());
        let device: Device<TestBackend> = Default::default();

        // Serialize in f32. This is test setup: a failure here must fail the test rather than
        // be mistaken for the expected load error.
        let recorder = DefaultFileRecorder::<FullPrecisionSettings>::new();
        Recorder::<TestBackend>::record(&recorder, item, file_path().into())
            .expect("recording with full precision settings should succeed");

        // Can't deserialize f32 into f16.
        let recorder = DefaultFileRecorder::<HalfPrecisionSettings>::new();
        let result = Recorder::<TestBackend>::load::<Item<FullPrecisionSettings>>(
            &recorder,
            file_path().into(),
            &device,
        );

        assert!(
            result.is_err(),
            "loading an f32 record with half precision settings should fail"
        );
    }
}
329}