1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
use core::any::type_name;

use alloc::format;
use alloc::string::{String, ToString};
use serde::{de::DeserializeOwned, Deserialize, Serialize};

use super::{BinBytesRecorder, FullPrecisionSettings, PrecisionSettings, Record};
#[cfg(feature = "std")]
use super::{
    BinFileRecorder, BinGzFileRecorder, DefaultFileRecorder, HalfPrecisionSettings,
    PrettyJsonFileRecorder,
};

/// Record any item implementing [Serialize](Serialize) and [DeserializeOwned](DeserializeOwned).
pub trait Recorder: Send + Sync + core::default::Default + core::fmt::Debug + Clone {
    /// Type of the settings used by the recorder.
    type Settings: PrecisionSettings;

    /// Arguments used to record objects.
    type RecordArgs: Clone;

    /// Record output type.
    type RecordOutput;

    /// Arguments used to load recorded objects.
    type LoadArgs: Clone;

    /// Records an item.
    ///
    /// # Arguments
    ///
    /// * `record` - The item to record.
    /// * `args` - Arguments used to record the item.
    ///
    /// # Returns
    ///
    /// The output of the recording.
    fn record<R: Record>(
        &self,
        record: R,
        args: Self::RecordArgs,
    ) -> Result<Self::RecordOutput, RecorderError> {
        let item = record.into_item::<Self::Settings>();
        let item = BurnRecord::new::<Self>(item);

        self.save_item(item, args)
    }

    /// Load an item from the given arguments.
    fn load<R: Record>(&self, args: Self::LoadArgs) -> Result<R, RecorderError> {
        let item: BurnRecord<R::Item<Self::Settings>> =
            self.load_item(args.clone()).map_err(|err| {
                if let Ok(record) = self.load_item::<BurnRecordNoItem>(args.clone()) {
                    let mut message = "Unable to load record.".to_string();
                    let metadata = recorder_metadata::<Self>();
                    if metadata.float != record.metadata.float {
                        message += format!(
                            "\nMetadata has a different float type: Actual {:?}, Expected {:?}",
                            record.metadata.float, metadata.float
                        )
                        .as_str();
                    }
                    if metadata.int != record.metadata.int {
                        message += format!(
                            "\nMetadata has a different int type: Actual {:?}, Expected {:?}",
                            record.metadata.int, metadata.int
                        )
                        .as_str();
                    }
                    if metadata.format != record.metadata.format {
                        message += format!(
                            "\nMetadata has a different format: Actual {:?}, Expected {:?}",
                            record.metadata.format, metadata.format
                        )
                        .as_str();
                    }
                    if metadata.version != record.metadata.version {
                        message += format!(
                            "\nMetadata has a different Burn version: Actual {:?}, Expected {:?}",
                            record.metadata.version, metadata.version
                        )
                        .as_str();
                    }

                    message += format!("\nError: {:?}", err).as_str();

                    return RecorderError::Unknown(message);
                }

                err
            })?;

        Ok(R::from_item(item.item))
    }

    /// Saves an item.
    ///
    /// This method is used by [record](Recorder::record) to save the item.
    ///
    /// # Arguments
    ///
    /// * `item` - Item to save.
    /// * `args` - Arguments to use to save the item.
    ///
    /// # Returns
    ///
    /// The output of the save operation.
    fn save_item<I: Serialize>(
        &self,
        item: I,
        args: Self::RecordArgs,
    ) -> Result<Self::RecordOutput, RecorderError>;

    /// Loads an item.
    ///
    /// This method is used by [load](Recorder::load) to load the item.
    ///
    /// # Arguments
    ///
    /// * `args` - Arguments to use to load the item.
    ///
    /// # Returns
    ///
    /// The loaded item.
    fn load_item<I: DeserializeOwned>(&self, args: Self::LoadArgs) -> Result<I, RecorderError>;
}

fn recorder_metadata<R: Recorder>() -> BurnMetadata {
    BurnMetadata::new(
        type_name::<<R::Settings as PrecisionSettings>::FloatElem>().to_string(),
        type_name::<<R::Settings as PrecisionSettings>::IntElem>().to_string(),
        type_name::<R>().to_string(),
        env!("CARGO_PKG_VERSION").to_string(),
        format!("{:?}", R::Settings::default()),
    )
}

/// Error that can occur when using a [Recorder](Recorder).
///
/// Each variant carries a free-form, human-readable message.
#[derive(Debug)]
pub enum RecorderError {
    /// File not found.
    FileNotFound(String),

    /// Other error.
    Unknown(String),
}

impl core::fmt::Display for RecorderError {
    /// Writes the same text as the `Debug` representation.
    ///
    /// The output goes straight to the formatter, so no intermediate `String` is allocated.
    /// Width, padding and alternate flags of `f` are not applied.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "{self:?}")
    }
}

// TODO: Move from std to core after Error is core (see https://github.com/rust-lang/rust/issues/103765)
// NOTE(review): `core::error::Error` has since been stabilized (Rust 1.81); confirm the crate's
// minimum supported Rust version before dropping the `std` gate.
// The impl is empty: it relies on the trait's default methods plus the `Debug` and `Display`
// impls of `RecorderError`.
#[cfg(feature = "std")]
impl std::error::Error for RecorderError {}

/// Returns the bincode configuration shared across this crate.
///
/// This is bincode's `standard()` configuration with no further customization.
pub(crate) fn bin_config() -> bincode::config::Configuration {
    bincode::config::standard()
}

/// Metadata of a record.
///
/// It is stored next to the item in a [BurnRecord](BurnRecord). When a load fails,
/// [Recorder::load](Recorder::load) compares it against the metadata of the loading recorder
/// to explain the failure.
///
/// The field descriptions below say what this module writes when it builds the metadata for a
/// recorder. The derived `new` constructor accepts arbitrary strings.
#[derive(new, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct BurnMetadata {
    /// Float type used to record the item (type name of the settings' float element).
    pub float: String,

    /// Int type used to record the item (type name of the settings' int element).
    pub int: String,

    /// Format used to record the item (type name of the recorder).
    pub format: String,

    /// Burn record version used to record the item (the crate's `CARGO_PKG_VERSION`).
    pub version: String,

    /// Settings used to record the item (`Debug` rendering of the default settings value).
    pub settings: String,
}

/// Record that can be saved by a [Recorder](Recorder).
///
/// It pairs the item with the [BurnMetadata](BurnMetadata) of the recorder that produced it.
/// The `metadata` field is declared first and the type derives `Serialize`/`Deserialize`.
#[derive(Serialize, Deserialize)]
pub struct BurnRecord<I> {
    /// Metadata of the record.
    pub metadata: BurnMetadata,

    /// Item to record.
    pub item: I,
}

impl<I> BurnRecord<I> {
    /// Wraps `item` in a record tagged with the metadata of recorder `R`.
    ///
    /// # Arguments
    ///
    /// * `item` - Item to record.
    ///
    /// # Returns
    ///
    /// A record holding `item` and the metadata built for `R`.
    pub fn new<R: Recorder>(item: I) -> Self {
        Self {
            metadata: recorder_metadata::<R>(),
            item,
        }
    }
}

/// Metadata-only view of a record saved by a [Recorder](Recorder).
///
/// It has the `metadata` field of a [BurnRecord](BurnRecord) but no item.
/// [Recorder::load](Recorder::load) deserializes into it after a failed load so it can read
/// the stored metadata and report how it differs from the expected one.
#[derive(new, Debug, Serialize, Deserialize)]
pub struct BurnRecordNoItem {
    /// Metadata of the record.
    pub metadata: BurnMetadata,
}

/// Default recorder.
///
/// It uses the [named msgpack](rmp_serde) format for serialization with full precision.
///
/// Alias for `DefaultFileRecorder<FullPrecisionSettings>`; only available with the `std` feature.
#[cfg(feature = "std")]
pub type DefaultRecorder = DefaultFileRecorder<FullPrecisionSettings>;

/// Recorder optimized for compactness.
///
/// It uses the [named msgpack](rmp_serde) format for serialization with half precision.
/// If you are looking for the recorder that offers the smallest file size, have a look at
/// [sensitive compact recorder](SensitiveCompactRecorder).
///
/// Alias for `DefaultFileRecorder<HalfPrecisionSettings>`; only available with the `std` feature.
#[cfg(feature = "std")]
pub type CompactRecorder = DefaultFileRecorder<HalfPrecisionSettings>;

/// Recorder optimized for compactness making it a good choice for model deployment.
///
/// It uses the [bincode](bincode) format for serialization and half precision.
/// This format is not resilient to type changes since no metadata is encoded.
/// Favor [default recorder](DefaultRecorder) or [compact recorder](CompactRecorder)
/// for long term data storage.
///
/// Alias for `BinGzFileRecorder<HalfPrecisionSettings>`; only available with the `std` feature.
// NOTE(review): the name `BinGzFileRecorder` suggests the bincode output is also
// gzip-compressed; confirm and mention it in the docs above if so.
#[cfg(feature = "std")]
pub type SensitiveCompactRecorder = BinGzFileRecorder<HalfPrecisionSettings>;

/// Training recorder compatible with no-std inference.
///
/// Alias for `BinFileRecorder<FullPrecisionSettings>`; only available with the `std` feature.
#[cfg(feature = "std")]
pub type NoStdTrainingRecorder = BinFileRecorder<FullPrecisionSettings>;

/// Inference recorder compatible with no-std.
///
/// Alias for `BinBytesRecorder<FullPrecisionSettings>`. Unlike the other aliases in this file,
/// it is not gated behind the `std` feature.
pub type NoStdInferenceRecorder = BinBytesRecorder<FullPrecisionSettings>;

/// Debug recorder.
///
/// It uses the [pretty json](serde_json) format for serialization with full precision making it
/// human readable.
///
/// Alias for `PrettyJsonFileRecorder<FullPrecisionSettings>`; only available with the `std`
/// feature.
// NOTE(review): despite the `Settings` suffix this alias names a recorder, not a settings type.
#[cfg(feature = "std")]
pub type DebugRecordSettings = PrettyJsonFileRecorder<FullPrecisionSettings>;

#[cfg(all(test, feature = "std"))]
mod tests {
    // NOTE(review): hard-coded `/tmp` location; presumably this breaks on platforms without a
    // `/tmp` directory. Consider `std::env::temp_dir()`.
    static FILE_PATH: &str = "/tmp/burn_test_record";

    use super::*;
    use burn_tensor::ElementConversion;

    /// Records an item with full precision settings, then loads the same file with half
    /// precision settings and expects the load to fail.
    // NOTE(review): `#[should_panic]` is satisfied by any panic in the body, including the
    // `unwrap` on `record`. A failure to write the file would therefore also make this test
    // pass. Asserting on the `Result` of `load` would be stricter.
    #[test]
    #[should_panic]
    fn err_when_invalid_item() {
        // Minimal record holding a single float element of the chosen precision settings.
        #[derive(new, Serialize, Deserialize)]
        struct Item<S: PrecisionSettings> {
            value: S::FloatElem,
        }

        impl<D: PrecisionSettings> Record for Item<D> {
            type Item<S: PrecisionSettings> = Item<S>;

            // Converts the value to the float element type of the target settings.
            fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
                Item {
                    value: self.value.elem(),
                }
            }

            // Converts the value back to the float element type of this record's settings.
            fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
                Item {
                    value: item.value.elem(),
                }
            }
        }

        let item = Item::<FullPrecisionSettings>::new(16.elem());

        // Serialize in f32.
        let recorder = DefaultFileRecorder::<FullPrecisionSettings>::new();
        recorder.record(item, FILE_PATH.into()).unwrap();

        // Can't deserialize f32 into f16.
        let recorder = DefaultFileRecorder::<HalfPrecisionSettings>::new();
        recorder
            .load::<Item<FullPrecisionSettings>>(FILE_PATH.into())
            .unwrap();
    }
}