burn_core/record/
file.rs

1use super::{bin_config, PrecisionSettings, Recorder, RecorderError};
2use burn_tensor::backend::Backend;
3use core::marker::PhantomData;
4use flate2::{read::GzDecoder, write::GzEncoder, Compression};
5use serde::{de::DeserializeOwned, Serialize};
6use std::io::{BufReader, BufWriter};
7use std::{fs::File, path::PathBuf};
8
9/// Recorder trait specialized to save and load data to and from files.
10pub trait FileRecorder<B: Backend>:
11    Recorder<B, RecordArgs = PathBuf, RecordOutput = (), LoadArgs = PathBuf>
12{
13    /// File extension of the format used by the recorder.
14    fn file_extension() -> &'static str;
15}
16
17/// Default [file recorder](FileRecorder).
18pub type DefaultFileRecorder<S> = NamedMpkFileRecorder<S>;
19
20/// File recorder using the [bincode format](bincode).
21#[derive(new, Debug, Default, Clone)]
22pub struct BinFileRecorder<S: PrecisionSettings> {
23    _settings: PhantomData<S>,
24}
25
26/// File recorder using the [bincode format](bincode) compressed with gzip.
27#[derive(new, Debug, Default, Clone)]
28pub struct BinGzFileRecorder<S: PrecisionSettings> {
29    _settings: PhantomData<S>,
30}
31
32/// File recorder using the [json format](serde_json) compressed with gzip.
33#[derive(new, Debug, Default, Clone)]
34pub struct JsonGzFileRecorder<S: PrecisionSettings> {
35    _settings: PhantomData<S>,
36}
37
38/// File recorder using [pretty json format](serde_json) for easy readability.
39#[derive(new, Debug, Default, Clone)]
40pub struct PrettyJsonFileRecorder<S: PrecisionSettings> {
41    _settings: PhantomData<S>,
42}
43
44/// File recorder using the [named msgpack](rmp_serde) format compressed with gzip.
45#[derive(new, Debug, Default, Clone)]
46pub struct NamedMpkGzFileRecorder<S: PrecisionSettings> {
47    _settings: PhantomData<S>,
48}
49
50/// File recorder using the [named msgpack](rmp_serde) format.
51#[derive(new, Debug, Default, Clone)]
52pub struct NamedMpkFileRecorder<S: PrecisionSettings> {
53    _settings: PhantomData<S>,
54}
55
56impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for BinGzFileRecorder<S> {
57    fn file_extension() -> &'static str {
58        "bin.gz"
59    }
60}
61impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for BinFileRecorder<S> {
62    fn file_extension() -> &'static str {
63        "bin"
64    }
65}
66impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for JsonGzFileRecorder<S> {
67    fn file_extension() -> &'static str {
68        "json.gz"
69    }
70}
71impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for PrettyJsonFileRecorder<S> {
72    fn file_extension() -> &'static str {
73        "json"
74    }
75}
76
77impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for NamedMpkGzFileRecorder<S> {
78    fn file_extension() -> &'static str {
79        "mpk.gz"
80    }
81}
82
83impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for NamedMpkFileRecorder<S> {
84    fn file_extension() -> &'static str {
85        "mpk"
86    }
87}
88
/// Opens `$file` for buffered reading.
///
/// `$file` must be a mutable `PathBuf`: its extension is first replaced in place with the
/// extension of the enclosing [`FileRecorder`] implementation, so any extension already on
/// the path is overwritten. The macro relies on `Self` and the backend parameter `B` of the
/// surrounding `impl`.
///
/// Evaluates to `Result<BufReader<File>, RecorderError>`: a missing file is reported as
/// `RecorderError::FileNotFound`, every other I/O failure as `RecorderError::Unknown`.
macro_rules! str2reader {
    ($file:expr) => {{
        $file.set_extension(<Self as FileRecorder<B>>::file_extension());

        match File::open($file.as_path()) {
            Ok(handle) => Ok(BufReader::new(handle)),
            Err(err) => Err(match err.kind() {
                std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()),
                _ => RecorderError::Unknown(err.to_string()),
            }),
        }
    }};
}
104
/// Creates `$file` for buffered writing, replacing any file already at that location.
///
/// `$file` must be a mutable `PathBuf`: its extension is first replaced in place with the
/// extension of the enclosing [`FileRecorder`] implementation, so any extension already on
/// the path is overwritten. The macro relies on `Self` and the backend parameter `B` of the
/// surrounding `impl`, and it uses `?`, so it can only be expanded inside a function that
/// returns a `Result` with a `RecorderError`-compatible error.
///
/// Evaluates to `Result<BufWriter<File>, RecorderError>`.
macro_rules! str2writer {
    ($file:expr) => {{
        $file.set_extension(<Self as FileRecorder<B>>::file_extension());
        let target = $file.as_path();

        // Best effort: create missing parent directories. A failure here is ignored on
        // purpose; if the directory is really unusable, `File::create` below reports it.
        if let Some(dir) = target.parent() {
            let _ = std::fs::create_dir_all(dir);
        }

        // Remove a previous file first so the new one is created from scratch.
        if target.exists() {
            log::info!("File exists, replacing");
            std::fs::remove_file(target).map_err(|err| RecorderError::Unknown(err.to_string()))?;
        }

        match File::create(target) {
            Ok(handle) => Ok(BufWriter::new(handle)),
            Err(err) => Err(match err.kind() {
                std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()),
                _ => RecorderError::Unknown(err.to_string()),
            }),
        }
    }};
}
130
131impl<S: PrecisionSettings, B: Backend> Recorder<B> for BinGzFileRecorder<S> {
132    type Settings = S;
133    type RecordArgs = PathBuf;
134    type RecordOutput = ();
135    type LoadArgs = PathBuf;
136
137    fn save_item<I: Serialize>(
138        &self,
139        item: I,
140        mut file: Self::RecordArgs,
141    ) -> Result<(), RecorderError> {
142        let config = bin_config();
143        let writer = str2writer!(file)?;
144        let mut writer = GzEncoder::new(writer, Compression::default());
145
146        bincode::serde::encode_into_std_write(&item, &mut writer, config)
147            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
148
149        Ok(())
150    }
151
152    fn load_item<I: DeserializeOwned>(&self, mut file: Self::LoadArgs) -> Result<I, RecorderError> {
153        let reader = str2reader!(file)?;
154        let mut reader = GzDecoder::new(reader);
155        let state = bincode::serde::decode_from_std_read(&mut reader, bin_config())
156            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
157
158        Ok(state)
159    }
160}
161
162impl<S: PrecisionSettings, B: Backend> Recorder<B> for BinFileRecorder<S> {
163    type Settings = S;
164    type RecordArgs = PathBuf;
165    type RecordOutput = ();
166    type LoadArgs = PathBuf;
167
168    fn save_item<I: Serialize>(
169        &self,
170        item: I,
171        mut file: Self::RecordArgs,
172    ) -> Result<(), RecorderError> {
173        let config = bin_config();
174        let mut writer = str2writer!(file)?;
175        bincode::serde::encode_into_std_write(&item, &mut writer, config)
176            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
177        Ok(())
178    }
179
180    fn load_item<I: DeserializeOwned>(&self, mut file: Self::LoadArgs) -> Result<I, RecorderError> {
181        let mut reader = str2reader!(file)?;
182        let state = bincode::serde::decode_from_std_read(&mut reader, bin_config())
183            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
184        Ok(state)
185    }
186}
187
188impl<S: PrecisionSettings, B: Backend> Recorder<B> for JsonGzFileRecorder<S> {
189    type Settings = S;
190    type RecordArgs = PathBuf;
191    type RecordOutput = ();
192    type LoadArgs = PathBuf;
193
194    fn save_item<I: Serialize>(
195        &self,
196        item: I,
197        mut file: Self::RecordArgs,
198    ) -> Result<(), RecorderError> {
199        let writer = str2writer!(file)?;
200        let writer = GzEncoder::new(writer, Compression::default());
201        serde_json::to_writer(writer, &item)
202            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
203
204        Ok(())
205    }
206
207    fn load_item<I: DeserializeOwned>(&self, mut file: Self::LoadArgs) -> Result<I, RecorderError> {
208        let reader = str2reader!(file)?;
209        let reader = GzDecoder::new(reader);
210        let state = serde_json::from_reader(reader)
211            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
212
213        Ok(state)
214    }
215}
216
217impl<S: PrecisionSettings, B: Backend> Recorder<B> for PrettyJsonFileRecorder<S> {
218    type Settings = S;
219    type RecordArgs = PathBuf;
220    type RecordOutput = ();
221    type LoadArgs = PathBuf;
222
223    fn save_item<I: Serialize>(
224        &self,
225        item: I,
226        mut file: Self::RecordArgs,
227    ) -> Result<(), RecorderError> {
228        let writer = str2writer!(file)?;
229        serde_json::to_writer_pretty(writer, &item)
230            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
231        Ok(())
232    }
233
234    fn load_item<I: DeserializeOwned>(&self, mut file: Self::LoadArgs) -> Result<I, RecorderError> {
235        let reader = str2reader!(file)?;
236        let state = serde_json::from_reader(reader)
237            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
238
239        Ok(state)
240    }
241}
242
243impl<S: PrecisionSettings, B: Backend> Recorder<B> for NamedMpkGzFileRecorder<S> {
244    type Settings = S;
245    type RecordArgs = PathBuf;
246    type RecordOutput = ();
247    type LoadArgs = PathBuf;
248
249    fn save_item<I: Serialize>(
250        &self,
251        item: I,
252        mut file: Self::RecordArgs,
253    ) -> Result<(), RecorderError> {
254        let writer = str2writer!(file)?;
255        let mut writer = GzEncoder::new(writer, Compression::default());
256        rmp_serde::encode::write_named(&mut writer, &item)
257            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
258
259        Ok(())
260    }
261
262    fn load_item<I: DeserializeOwned>(&self, mut file: Self::LoadArgs) -> Result<I, RecorderError> {
263        let reader = str2reader!(file)?;
264        let reader = GzDecoder::new(reader);
265        let state = rmp_serde::decode::from_read(reader)
266            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
267
268        Ok(state)
269    }
270}
271
272impl<S: PrecisionSettings, B: Backend> Recorder<B> for NamedMpkFileRecorder<S> {
273    type Settings = S;
274    type RecordArgs = PathBuf;
275    type RecordOutput = ();
276    type LoadArgs = PathBuf;
277
278    fn save_item<I: Serialize>(
279        &self,
280        item: I,
281        mut file: Self::RecordArgs,
282    ) -> Result<(), RecorderError> {
283        let mut writer = str2writer!(file)?;
284
285        rmp_serde::encode::write_named(&mut writer, &item)
286            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
287
288        Ok(())
289    }
290
291    fn load_item<I: DeserializeOwned>(&self, mut file: Self::LoadArgs) -> Result<I, RecorderError> {
292        let reader = str2reader!(file)?;
293        let state = rmp_serde::decode::from_read(reader)
294            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
295
296        Ok(state)
297    }
298}
299
#[cfg(test)]
mod tests {
    use burn_tensor::backend::Backend;

    use super::*;
    use crate::{
        module::Module,
        nn::{
            conv::{Conv2d, Conv2dConfig},
            Linear, LinearConfig,
        },
        record::{BinBytesRecorder, FullPrecisionSettings},
        TestBackend,
    };

    // NOTE(review): presumably required by the `Module` derive, which appears to refer to
    // this crate as `burn` — confirm before removing.
    use crate as burn;

    /// Extension-less path in the system temp directory used by every test below; each
    /// recorder sets its own extension on it.
    fn file_path() -> PathBuf {
        std::env::temp_dir().join("burn_test_file_recorder")
    }

    #[test]
    fn test_can_save_and_load_jsongz_format() {
        let recorder = JsonGzFileRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    #[test]
    fn test_can_save_and_load_bin_format() {
        let recorder = BinFileRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    #[test]
    fn test_can_save_and_load_bingz_format() {
        let recorder = BinGzFileRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    #[test]
    fn test_can_save_and_load_pretty_json_format() {
        let recorder = PrettyJsonFileRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    #[test]
    fn test_can_save_and_load_mpkgz_format() {
        let recorder = NamedMpkGzFileRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    #[test]
    fn test_can_save_and_load_mpk_format() {
        let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    /// Records a freshly created model with `recorder`, loads it back into a second model
    /// and asserts that both models serialize to the same bytes.
    fn test_can_save_and_load<R>(recorder: R)
    where
        R: FileRecorder<TestBackend>,
    {
        let device = Default::default();

        // Save the reference model to disk.
        let original = create_model(&device);
        recorder
            .record(original.clone().into_record(), file_path())
            .unwrap();

        // Load the saved record into a second, independently initialized model.
        let fresh = create_model(&device);
        let restored = fresh.load_record(recorder.load(file_path(), &device).unwrap());

        // Compare both models through their in-memory byte representation.
        let byte_recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
        let expected_bytes = byte_recorder.record(original.into_record(), ()).unwrap();
        let restored_bytes = byte_recorder.record(restored.into_record(), ()).unwrap();

        assert_eq!(restored_bytes, expected_bytes);
    }

    /// Small model with one convolution and one linear layer, used as the recorded payload.
    #[derive(Module, Debug)]
    pub struct Model<B: Backend> {
        conv2d1: Conv2d<B>,
        linear1: Linear<B>,
        phantom: core::marker::PhantomData<B>,
    }

    /// Builds a [`Model`] on `device` with freshly initialized layers.
    pub fn create_model(device: &<TestBackend as Backend>::Device) -> Model<TestBackend> {
        Model {
            conv2d1: Conv2dConfig::new([1, 8], [3, 3]).init(device),
            linear1: LinearConfig::new(32, 32).with_bias(true).init(device),
            phantom: core::marker::PhantomData,
        }
    }
}