burn_core/record/
file.rs

1use super::{PrecisionSettings, Recorder, RecorderError, bin_config};
2use burn_tensor::backend::Backend;
3use core::marker::PhantomData;
4use flate2::{Compression, read::GzDecoder, write::GzEncoder};
5use serde::{Serialize, de::DeserializeOwned};
6use std::io::{BufReader, BufWriter};
7use std::{fs::File, path::PathBuf};
8
9/// Recorder trait specialized to save and load data to and from files.
10pub trait FileRecorder<B: Backend>:
11    Recorder<B, RecordArgs = PathBuf, RecordOutput = (), LoadArgs = PathBuf>
12{
13    /// File extension of the format used by the recorder.
14    fn file_extension() -> &'static str;
15}
16
/// Default [file recorder](FileRecorder).
///
/// This is an alias for [`NamedMpkFileRecorder`] with the same settings parameter `S`.
pub type DefaultFileRecorder<S> = NamedMpkFileRecorder<S>;
19
20/// File recorder using the [bincode format](bincode).
21#[derive(new, Debug, Default, Clone)]
22pub struct BinFileRecorder<S: PrecisionSettings> {
23    _settings: PhantomData<S>,
24}
25
26/// File recorder using the [bincode format](bincode) compressed with gzip.
27#[derive(new, Debug, Default, Clone)]
28pub struct BinGzFileRecorder<S: PrecisionSettings> {
29    _settings: PhantomData<S>,
30}
31
32/// File recorder using the [json format](serde_json) compressed with gzip.
33#[derive(new, Debug, Default, Clone)]
34pub struct JsonGzFileRecorder<S: PrecisionSettings> {
35    _settings: PhantomData<S>,
36}
37
38/// File recorder using [pretty json format](serde_json) for easy readability.
39#[derive(new, Debug, Default, Clone)]
40pub struct PrettyJsonFileRecorder<S: PrecisionSettings> {
41    _settings: PhantomData<S>,
42}
43
44/// File recorder using the [named msgpack](rmp_serde) format compressed with gzip.
45#[derive(new, Debug, Default, Clone)]
46pub struct NamedMpkGzFileRecorder<S: PrecisionSettings> {
47    _settings: PhantomData<S>,
48}
49
50/// File recorder using the [named msgpack](rmp_serde) format.
51#[derive(new, Debug, Default, Clone)]
52pub struct NamedMpkFileRecorder<S: PrecisionSettings> {
53    _settings: PhantomData<S>,
54}
55
56impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for BinGzFileRecorder<S> {
57    fn file_extension() -> &'static str {
58        "bin.gz"
59    }
60}
61impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for BinFileRecorder<S> {
62    fn file_extension() -> &'static str {
63        "bin"
64    }
65}
66impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for JsonGzFileRecorder<S> {
67    fn file_extension() -> &'static str {
68        "json.gz"
69    }
70}
71impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for PrettyJsonFileRecorder<S> {
72    fn file_extension() -> &'static str {
73        "json"
74    }
75}
76
77impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for NamedMpkGzFileRecorder<S> {
78    fn file_extension() -> &'static str {
79        "mpk.gz"
80    }
81}
82
83impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for NamedMpkFileRecorder<S> {
84    fn file_extension() -> &'static str {
85        "mpk"
86    }
87}
88
// Opens `$file` for buffered reading.
//
// The extension of `$file` is first replaced (in place) with the extension of the
// recorder the macro is expanded in, so `Self` and `B` must be in scope at the call
// site. Evaluates to `Result<BufReader<File>, RecorderError>`: a missing file maps to
// `RecorderError::FileNotFound`, every other I/O failure to `RecorderError::Unknown`.
macro_rules! str2reader {
    ($file:expr) => {{
        $file.set_extension(<Self as FileRecorder<B>>::file_extension());

        match File::open($file.as_path()) {
            Ok(handle) => Ok(BufReader::new(handle)),
            Err(err) => Err(match err.kind() {
                std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()),
                _ => RecorderError::Unknown(err.to_string()),
            }),
        }
    }};
}
104
// Creates `$file` for buffered writing, replacing any file already at that path.
//
// The extension of `$file` is first replaced (in place) with the extension of the
// recorder the macro is expanded in, so `Self` and `B` must be in scope at the call
// site. The expansion contains a `?`, so it can only be used inside a function that
// returns `Result<_, RecorderError>`. Evaluates to
// `Result<BufWriter<File>, RecorderError>`.
macro_rules! str2writer {
    ($file:expr) => {{
        $file.set_extension(<Self as FileRecorder<B>>::file_extension());
        let target = $file.as_path();

        log::debug!("Writing to file: {:?}", target);

        // Best effort: create missing parent directories. A failure is ignored here
        // and surfaces when the file itself cannot be created below.
        if let Some(dir) = target.parent() {
            let _ = std::fs::create_dir_all(dir);
        }

        // Remove a pre-existing file first; a failed removal aborts the caller.
        if target.exists() {
            log::warn!("File exists, replacing");
            std::fs::remove_file(target).map_err(|err| RecorderError::Unknown(err.to_string()))?;
        }

        match File::create(target) {
            Ok(handle) => Ok(BufWriter::new(handle)),
            Err(err) => Err(match err.kind() {
                std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()),
                _ => RecorderError::Unknown(err.to_string()),
            }),
        }
    }};
}
132
133impl<S: PrecisionSettings, B: Backend> Recorder<B> for BinGzFileRecorder<S> {
134    type Settings = S;
135    type RecordArgs = PathBuf;
136    type RecordOutput = ();
137    type LoadArgs = PathBuf;
138
139    fn save_item<I: Serialize>(
140        &self,
141        item: I,
142        mut file: Self::RecordArgs,
143    ) -> Result<(), RecorderError> {
144        let config = bin_config();
145        let writer = str2writer!(file)?;
146        let mut writer = GzEncoder::new(writer, Compression::default());
147
148        bincode::serde::encode_into_std_write(&item, &mut writer, config)
149            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
150
151        Ok(())
152    }
153
154    fn load_item<I: DeserializeOwned>(
155        &self,
156        file: &mut Self::LoadArgs,
157    ) -> Result<I, RecorderError> {
158        let reader = str2reader!(file)?;
159        let mut reader = GzDecoder::new(reader);
160        let state = bincode::serde::decode_from_std_read(&mut reader, bin_config())
161            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
162
163        Ok(state)
164    }
165}
166
167impl<S: PrecisionSettings, B: Backend> Recorder<B> for BinFileRecorder<S> {
168    type Settings = S;
169    type RecordArgs = PathBuf;
170    type RecordOutput = ();
171    type LoadArgs = PathBuf;
172
173    fn save_item<I: Serialize>(
174        &self,
175        item: I,
176        mut file: Self::RecordArgs,
177    ) -> Result<(), RecorderError> {
178        let config = bin_config();
179        let mut writer = str2writer!(file)?;
180        bincode::serde::encode_into_std_write(&item, &mut writer, config)
181            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
182        Ok(())
183    }
184
185    fn load_item<I: DeserializeOwned>(
186        &self,
187        file: &mut Self::LoadArgs,
188    ) -> Result<I, RecorderError> {
189        let mut reader = str2reader!(file)?;
190        let state = bincode::serde::decode_from_std_read(&mut reader, bin_config())
191            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
192        Ok(state)
193    }
194}
195
196impl<S: PrecisionSettings, B: Backend> Recorder<B> for JsonGzFileRecorder<S> {
197    type Settings = S;
198    type RecordArgs = PathBuf;
199    type RecordOutput = ();
200    type LoadArgs = PathBuf;
201
202    fn save_item<I: Serialize>(
203        &self,
204        item: I,
205        mut file: Self::RecordArgs,
206    ) -> Result<(), RecorderError> {
207        let writer = str2writer!(file)?;
208        let writer = GzEncoder::new(writer, Compression::default());
209        serde_json::to_writer(writer, &item)
210            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
211
212        Ok(())
213    }
214
215    fn load_item<I: DeserializeOwned>(
216        &self,
217        file: &mut Self::LoadArgs,
218    ) -> Result<I, RecorderError> {
219        let reader = str2reader!(file)?;
220        let reader = GzDecoder::new(reader);
221        let state = serde_json::from_reader(reader)
222            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
223
224        Ok(state)
225    }
226}
227
228impl<S: PrecisionSettings, B: Backend> Recorder<B> for PrettyJsonFileRecorder<S> {
229    type Settings = S;
230    type RecordArgs = PathBuf;
231    type RecordOutput = ();
232    type LoadArgs = PathBuf;
233
234    fn save_item<I: Serialize>(
235        &self,
236        item: I,
237        mut file: Self::RecordArgs,
238    ) -> Result<(), RecorderError> {
239        let writer = str2writer!(file)?;
240        serde_json::to_writer_pretty(writer, &item)
241            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
242        Ok(())
243    }
244
245    fn load_item<I: DeserializeOwned>(
246        &self,
247        file: &mut Self::LoadArgs,
248    ) -> Result<I, RecorderError> {
249        let reader = str2reader!(file)?;
250        let state = serde_json::from_reader(reader)
251            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
252
253        Ok(state)
254    }
255}
256
257impl<S: PrecisionSettings, B: Backend> Recorder<B> for NamedMpkGzFileRecorder<S> {
258    type Settings = S;
259    type RecordArgs = PathBuf;
260    type RecordOutput = ();
261    type LoadArgs = PathBuf;
262
263    fn save_item<I: Serialize>(
264        &self,
265        item: I,
266        mut file: Self::RecordArgs,
267    ) -> Result<(), RecorderError> {
268        let writer = str2writer!(file)?;
269        let mut writer = GzEncoder::new(writer, Compression::default());
270        rmp_serde::encode::write_named(&mut writer, &item)
271            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
272
273        Ok(())
274    }
275
276    fn load_item<I: DeserializeOwned>(
277        &self,
278        file: &mut Self::LoadArgs,
279    ) -> Result<I, RecorderError> {
280        let reader = str2reader!(file)?;
281        let reader = GzDecoder::new(reader);
282        let state = rmp_serde::decode::from_read(reader)
283            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
284
285        Ok(state)
286    }
287}
288
289impl<S: PrecisionSettings, B: Backend> Recorder<B> for NamedMpkFileRecorder<S> {
290    type Settings = S;
291    type RecordArgs = PathBuf;
292    type RecordOutput = ();
293    type LoadArgs = PathBuf;
294
295    fn save_item<I: Serialize>(
296        &self,
297        item: I,
298        mut file: Self::RecordArgs,
299    ) -> Result<(), RecorderError> {
300        let mut writer = str2writer!(file)?;
301
302        rmp_serde::encode::write_named(&mut writer, &item)
303            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
304
305        Ok(())
306    }
307
308    fn load_item<I: DeserializeOwned>(
309        &self,
310        file: &mut Self::LoadArgs,
311    ) -> Result<I, RecorderError> {
312        let reader = str2reader!(file)?;
313        let state = rmp_serde::decode::from_read(reader)
314            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
315
316        Ok(state)
317    }
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate as burn;
    use crate::config::Config;
    use crate::module::Ignored;
    use crate::test_utils::SimpleLinear;
    use crate::{
        TestBackend,
        module::Module,
        record::{BinBytesRecorder, FullPrecisionSettings},
    };
    use burn_tensor::backend::Backend;

    /// Base path, inside the system temp directory, used by every round-trip test.
    /// Each recorder replaces the extension with its own before touching the disk.
    #[inline(always)]
    fn file_path() -> PathBuf {
        let tmp = std::env::temp_dir();
        tmp.join("burn_test_file_recorder")
    }

    #[test]
    fn test_can_save_and_load_jsongz_format() {
        let recorder = JsonGzFileRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    #[test]
    fn test_can_save_and_load_bin_format() {
        let recorder = BinFileRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    #[test]
    fn test_can_save_and_load_bingz_format() {
        let recorder = BinGzFileRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    #[test]
    fn test_can_save_and_load_pretty_json_format() {
        let recorder = PrettyJsonFileRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    #[test]
    fn test_can_save_and_load_mpkgz_format() {
        let recorder = NamedMpkGzFileRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    #[test]
    fn test_can_save_and_load_mpk_format() {
        let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    /// Saves a freshly created model with `recorder`, loads it back into a second
    /// model, and checks that both models serialize to the same bytes.
    fn test_can_save_and_load<R>(recorder: R)
    where
        R: FileRecorder<TestBackend>,
    {
        let device = Default::default();

        // Write the reference model to disk.
        let model_before = create_model(&device);
        let record_before = model_before.clone().into_record();
        recorder.record(record_before, file_path()).unwrap();

        // Build a second model and overwrite its state with what was just saved.
        let fresh_model = create_model(&device);
        let loaded_record = recorder.load(file_path(), &device).unwrap();
        let model_after = fresh_model.load_record(loaded_record);

        // Compare both models through an in-memory byte recorder.
        let byte_recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
        let model_bytes_before = byte_recorder
            .record(model_before.into_record(), ())
            .unwrap();
        let model_bytes_after = byte_recorder
            .record(model_after.into_record(), ())
            .unwrap();

        assert_eq!(model_bytes_after, model_bytes_before);
    }

    /// Small config enum stored in the test model through [`Ignored`].
    #[derive(Config, Debug)]
    pub enum PaddingConfig2d {
        Same,
        Valid,
        Explicit(usize, usize),
    }

    // Dummy model mixing several kinds of fields (module, marker, array, integer,
    // ignored value) so that different record types are exercised.
    #[derive(Module, Debug)]
    pub struct Model<B: Backend> {
        linear1: SimpleLinear<B>,
        phantom: PhantomData<B>,
        arr: [usize; 2],
        int: usize,
        ignore: Ignored<PaddingConfig2d>,
    }

    /// Builds the dummy [`Model`] on `device` with a 32x32 [`SimpleLinear`] layer.
    pub fn create_model(device: &<TestBackend as Backend>::Device) -> Model<TestBackend> {
        Model {
            linear1: SimpleLinear::new(32, 32, device),
            phantom: PhantomData,
            arr: [2, 2],
            int: 0,
            ignore: Ignored(PaddingConfig2d::Same),
        }
    }
}