burn_core/record/
file.rs

1use super::{PrecisionSettings, Recorder, RecorderError, bin_config};
2use burn_tensor::backend::Backend;
3use core::marker::PhantomData;
4use flate2::{Compression, read::GzDecoder, write::GzEncoder};
5use serde::{Serialize, de::DeserializeOwned};
6use std::io::{BufReader, BufWriter};
7use std::{fs::File, path::PathBuf};
8
/// Recorder trait specialized to save and load data to and from files.
///
/// Implementors are [`Recorder`]s whose record and load arguments are both a [`PathBuf`]
/// and whose record output is `()`.
pub trait FileRecorder<B: Backend>:
    Recorder<B, RecordArgs = PathBuf, RecordOutput = (), LoadArgs = PathBuf>
{
    /// File extension of the format used by the recorder.
    ///
    /// The recorders in this file apply it with `set_extension` to the path they are given
    /// before opening or creating the file, so it is written without a leading dot.
    fn file_extension() -> &'static str;
}
16
/// Default [file recorder](FileRecorder).
///
/// Currently an alias of [`NamedMpkFileRecorder`] (named msgpack, no compression).
pub type DefaultFileRecorder<S> = NamedMpkFileRecorder<S>;
19
/// File recorder using the [bincode format](bincode).
///
/// Files handled by this recorder get the `bin` extension. The type stores no runtime
/// data; `S` only selects the precision settings at the type level.
#[derive(new, Debug, Default, Clone)]
pub struct BinFileRecorder<S: PrecisionSettings> {
    // Marker tying the recorder to its precision settings type `S`.
    _settings: PhantomData<S>,
}
25
/// File recorder using the [bincode format](bincode) compressed with gzip.
///
/// Files handled by this recorder get the `bin.gz` extension. The type stores no runtime
/// data; `S` only selects the precision settings at the type level.
#[derive(new, Debug, Default, Clone)]
pub struct BinGzFileRecorder<S: PrecisionSettings> {
    // Marker tying the recorder to its precision settings type `S`.
    _settings: PhantomData<S>,
}
31
/// File recorder using the [json format](serde_json) compressed with gzip.
///
/// Files handled by this recorder get the `json.gz` extension. The type stores no runtime
/// data; `S` only selects the precision settings at the type level.
#[derive(new, Debug, Default, Clone)]
pub struct JsonGzFileRecorder<S: PrecisionSettings> {
    // Marker tying the recorder to its precision settings type `S`.
    _settings: PhantomData<S>,
}
37
/// File recorder using [pretty json format](serde_json) for easy readability.
///
/// Files handled by this recorder get the `json` extension and are not compressed. The
/// type stores no runtime data; `S` only selects the precision settings at the type level.
#[derive(new, Debug, Default, Clone)]
pub struct PrettyJsonFileRecorder<S: PrecisionSettings> {
    // Marker tying the recorder to its precision settings type `S`.
    _settings: PhantomData<S>,
}
43
/// File recorder using the [named msgpack](rmp_serde) format compressed with gzip.
///
/// Files handled by this recorder get the `mpk.gz` extension. The type stores no runtime
/// data; `S` only selects the precision settings at the type level.
#[derive(new, Debug, Default, Clone)]
pub struct NamedMpkGzFileRecorder<S: PrecisionSettings> {
    // Marker tying the recorder to its precision settings type `S`.
    _settings: PhantomData<S>,
}
49
/// File recorder using the [named msgpack](rmp_serde) format.
///
/// Files handled by this recorder get the `mpk` extension and are not compressed. The
/// type stores no runtime data; `S` only selects the precision settings at the type level.
#[derive(new, Debug, Default, Clone)]
pub struct NamedMpkFileRecorder<S: PrecisionSettings> {
    // Marker tying the recorder to its precision settings type `S`.
    _settings: PhantomData<S>,
}
55
impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for BinGzFileRecorder<S> {
    /// Returns `"bin.gz"`: bincode payload wrapped in gzip.
    fn file_extension() -> &'static str {
        "bin.gz"
    }
}
impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for BinFileRecorder<S> {
    /// Returns `"bin"`: uncompressed bincode payload.
    fn file_extension() -> &'static str {
        "bin"
    }
}
impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for JsonGzFileRecorder<S> {
    /// Returns `"json.gz"`: JSON payload wrapped in gzip.
    fn file_extension() -> &'static str {
        "json.gz"
    }
}
impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for PrettyJsonFileRecorder<S> {
    /// Returns `"json"`: uncompressed, pretty-printed JSON payload.
    fn file_extension() -> &'static str {
        "json"
    }
}
76
impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for NamedMpkGzFileRecorder<S> {
    /// Returns `"mpk.gz"`: named msgpack payload wrapped in gzip.
    fn file_extension() -> &'static str {
        "mpk.gz"
    }
}
82
impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for NamedMpkFileRecorder<S> {
    /// Returns `"mpk"`: uncompressed named msgpack payload.
    fn file_extension() -> &'static str {
        "mpk"
    }
}
88
// Opens the file designated by `$file` for buffered reading.
//
// The path's extension is first replaced (in place) with the extension of the recorder
// the macro is expanded in, so `Self` and `B` must be in scope at the call site.
// Evaluates to `Result<BufReader<File>, RecorderError>`: a missing file is reported as
// `FileNotFound`, every other I/O failure as `Unknown`.
macro_rules! str2reader {
    ($file:expr) => {{
        $file.set_extension(<Self as FileRecorder<B>>::file_extension());

        match File::open($file.as_path()) {
            Ok(handle) => Ok(BufReader::new(handle)),
            Err(err) => Err(match err.kind() {
                std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()),
                _ => RecorderError::Unknown(err.to_string()),
            }),
        }
    }};
}
104
// Creates the file designated by `$file` for buffered writing.
//
// The path's extension is first replaced (in place) with the extension of the recorder
// the macro is expanded in, so `Self` and `B` must be in scope at the call site, and the
// enclosing function must return a `Result<_, RecorderError>` because a failed removal of
// an existing file is propagated with `?`.
// Evaluates to `Result<BufWriter<File>, RecorderError>`.
macro_rules! str2writer {
    ($file:expr) => {{
        $file.set_extension(<Self as FileRecorder<B>>::file_extension());
        let target = $file.as_path();

        // Best effort only: if the directories cannot be created, `File::create` below
        // reports the failure.
        if let Some(dir) = target.parent() {
            std::fs::create_dir_all(dir).ok();
        }

        // An already existing file is removed before being created anew.
        if target.exists() {
            log::info!("File exists, replacing");
            std::fs::remove_file(target)
                .map_err(|err| RecorderError::Unknown(err.to_string()))?;
        }

        match File::create(target) {
            Ok(handle) => Ok(BufWriter::new(handle)),
            Err(err) => Err(match err.kind() {
                std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()),
                _ => RecorderError::Unknown(err.to_string()),
            }),
        }
    }};
}
130
131impl<S: PrecisionSettings, B: Backend> Recorder<B> for BinGzFileRecorder<S> {
132    type Settings = S;
133    type RecordArgs = PathBuf;
134    type RecordOutput = ();
135    type LoadArgs = PathBuf;
136
137    fn save_item<I: Serialize>(
138        &self,
139        item: I,
140        mut file: Self::RecordArgs,
141    ) -> Result<(), RecorderError> {
142        let config = bin_config();
143        let writer = str2writer!(file)?;
144        let mut writer = GzEncoder::new(writer, Compression::default());
145
146        bincode::serde::encode_into_std_write(&item, &mut writer, config)
147            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
148
149        Ok(())
150    }
151
152    fn load_item<I: DeserializeOwned>(
153        &self,
154        file: &mut Self::LoadArgs,
155    ) -> Result<I, RecorderError> {
156        let reader = str2reader!(file)?;
157        let mut reader = GzDecoder::new(reader);
158        let state = bincode::serde::decode_from_std_read(&mut reader, bin_config())
159            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
160
161        Ok(state)
162    }
163}
164
165impl<S: PrecisionSettings, B: Backend> Recorder<B> for BinFileRecorder<S> {
166    type Settings = S;
167    type RecordArgs = PathBuf;
168    type RecordOutput = ();
169    type LoadArgs = PathBuf;
170
171    fn save_item<I: Serialize>(
172        &self,
173        item: I,
174        mut file: Self::RecordArgs,
175    ) -> Result<(), RecorderError> {
176        let config = bin_config();
177        let mut writer = str2writer!(file)?;
178        bincode::serde::encode_into_std_write(&item, &mut writer, config)
179            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
180        Ok(())
181    }
182
183    fn load_item<I: DeserializeOwned>(
184        &self,
185        file: &mut Self::LoadArgs,
186    ) -> Result<I, RecorderError> {
187        let mut reader = str2reader!(file)?;
188        let state = bincode::serde::decode_from_std_read(&mut reader, bin_config())
189            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
190        Ok(state)
191    }
192}
193
194impl<S: PrecisionSettings, B: Backend> Recorder<B> for JsonGzFileRecorder<S> {
195    type Settings = S;
196    type RecordArgs = PathBuf;
197    type RecordOutput = ();
198    type LoadArgs = PathBuf;
199
200    fn save_item<I: Serialize>(
201        &self,
202        item: I,
203        mut file: Self::RecordArgs,
204    ) -> Result<(), RecorderError> {
205        let writer = str2writer!(file)?;
206        let writer = GzEncoder::new(writer, Compression::default());
207        serde_json::to_writer(writer, &item)
208            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
209
210        Ok(())
211    }
212
213    fn load_item<I: DeserializeOwned>(
214        &self,
215        file: &mut Self::LoadArgs,
216    ) -> Result<I, RecorderError> {
217        let reader = str2reader!(file)?;
218        let reader = GzDecoder::new(reader);
219        let state = serde_json::from_reader(reader)
220            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
221
222        Ok(state)
223    }
224}
225
226impl<S: PrecisionSettings, B: Backend> Recorder<B> for PrettyJsonFileRecorder<S> {
227    type Settings = S;
228    type RecordArgs = PathBuf;
229    type RecordOutput = ();
230    type LoadArgs = PathBuf;
231
232    fn save_item<I: Serialize>(
233        &self,
234        item: I,
235        mut file: Self::RecordArgs,
236    ) -> Result<(), RecorderError> {
237        let writer = str2writer!(file)?;
238        serde_json::to_writer_pretty(writer, &item)
239            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
240        Ok(())
241    }
242
243    fn load_item<I: DeserializeOwned>(
244        &self,
245        file: &mut Self::LoadArgs,
246    ) -> Result<I, RecorderError> {
247        let reader = str2reader!(file)?;
248        let state = serde_json::from_reader(reader)
249            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
250
251        Ok(state)
252    }
253}
254
255impl<S: PrecisionSettings, B: Backend> Recorder<B> for NamedMpkGzFileRecorder<S> {
256    type Settings = S;
257    type RecordArgs = PathBuf;
258    type RecordOutput = ();
259    type LoadArgs = PathBuf;
260
261    fn save_item<I: Serialize>(
262        &self,
263        item: I,
264        mut file: Self::RecordArgs,
265    ) -> Result<(), RecorderError> {
266        let writer = str2writer!(file)?;
267        let mut writer = GzEncoder::new(writer, Compression::default());
268        rmp_serde::encode::write_named(&mut writer, &item)
269            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
270
271        Ok(())
272    }
273
274    fn load_item<I: DeserializeOwned>(
275        &self,
276        file: &mut Self::LoadArgs,
277    ) -> Result<I, RecorderError> {
278        let reader = str2reader!(file)?;
279        let reader = GzDecoder::new(reader);
280        let state = rmp_serde::decode::from_read(reader)
281            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
282
283        Ok(state)
284    }
285}
286
287impl<S: PrecisionSettings, B: Backend> Recorder<B> for NamedMpkFileRecorder<S> {
288    type Settings = S;
289    type RecordArgs = PathBuf;
290    type RecordOutput = ();
291    type LoadArgs = PathBuf;
292
293    fn save_item<I: Serialize>(
294        &self,
295        item: I,
296        mut file: Self::RecordArgs,
297    ) -> Result<(), RecorderError> {
298        let mut writer = str2writer!(file)?;
299
300        rmp_serde::encode::write_named(&mut writer, &item)
301            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
302
303        Ok(())
304    }
305
306    fn load_item<I: DeserializeOwned>(
307        &self,
308        file: &mut Self::LoadArgs,
309    ) -> Result<I, RecorderError> {
310        let reader = str2reader!(file)?;
311        let state = rmp_serde::decode::from_read(reader)
312            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
313
314        Ok(state)
315    }
316}
317
#[cfg(test)]
mod tests {

    use burn_tensor::backend::Backend;

    use super::*;
    use crate::{
        TestBackend,
        module::Module,
        nn::{
            Linear, LinearConfig,
            conv::{Conv2d, Conv2dConfig},
        },
        record::{BinBytesRecorder, FullPrecisionSettings},
    };

    use crate as burn;

    /// Extension-less path, inside the system temporary directory, shared by all tests.
    /// Each recorder sets its own extension on it, so the formats end up in distinct files.
    #[inline(always)]
    fn file_path() -> PathBuf {
        let mut path = std::env::temp_dir();
        path.push("burn_test_file_recorder");
        path
    }

    #[test]
    fn test_can_save_and_load_jsongz_format() {
        let recorder = JsonGzFileRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    #[test]
    fn test_can_save_and_load_bin_format() {
        let recorder = BinFileRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    #[test]
    fn test_can_save_and_load_bingz_format() {
        let recorder = BinGzFileRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    #[test]
    fn test_can_save_and_load_pretty_json_format() {
        let recorder = PrettyJsonFileRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    #[test]
    fn test_can_save_and_load_mpkgz_format() {
        let recorder = NamedMpkGzFileRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    #[test]
    fn test_can_save_and_load_mpk_format() {
        let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    /// Saves a freshly created model with `recorder`, loads it back into another model and
    /// checks that both models serialize to the same bytes.
    fn test_can_save_and_load<R>(recorder: R)
    where
        R: FileRecorder<TestBackend>,
    {
        let device = Default::default();

        // Round trip through the file system.
        let model_before = create_model(&device);
        let record_before = model_before.clone().into_record();
        recorder.record(record_before, file_path()).unwrap();

        let fresh_model = create_model(&device);
        let model_after = fresh_model.load_record(recorder.load(file_path(), &device).unwrap());

        // Compare both models through their in-memory binary representation.
        let byte_recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
        let bytes_before = byte_recorder
            .record(model_before.into_record(), ())
            .unwrap();
        let bytes_after = byte_recorder
            .record(model_after.into_record(), ())
            .unwrap();

        assert_eq!(bytes_after, bytes_before);
    }

    /// Small module made of one convolution and one linear layer, used as test payload.
    #[derive(Module, Debug)]
    pub struct Model<B: Backend> {
        conv2d1: Conv2d<B>,
        linear1: Linear<B>,
        phantom: core::marker::PhantomData<B>,
    }

    /// Builds a [`Model`] on `device` with newly initialized layers.
    pub fn create_model(device: &<TestBackend as Backend>::Device) -> Model<TestBackend> {
        Model {
            conv2d1: Conv2dConfig::new([1, 8], [3, 3]).init(device),
            linear1: LinearConfig::new(32, 32).with_bias(true).init(device),
            phantom: core::marker::PhantomData,
        }
    }
}