use super::{PrecisionSettings, Recorder, RecorderError, bin_config};
use burn_tensor::backend::Backend;
use core::marker::PhantomData;
use flate2::{Compression, read::GzDecoder, write::GzEncoder};
use serde::{Serialize, de::DeserializeOwned};
use std::io::{BufReader, BufWriter, Write};
use std::{fs::File, path::PathBuf};
8
9pub trait FileRecorder<B: Backend>:
11 Recorder<B, RecordArgs = PathBuf, RecordOutput = (), LoadArgs = PathBuf>
12{
13 fn file_extension() -> &'static str;
15}
16
/// Default file recorder: named MessagePack files (`.mpk`), i.e. [`NamedMpkFileRecorder`].
pub type DefaultFileRecorder<S> = NamedMpkFileRecorder<S>;
19
20#[derive(new, Debug, Default, Clone)]
22pub struct BinFileRecorder<S: PrecisionSettings> {
23 _settings: PhantomData<S>,
24}
25
26#[derive(new, Debug, Default, Clone)]
28pub struct BinGzFileRecorder<S: PrecisionSettings> {
29 _settings: PhantomData<S>,
30}
31
32#[derive(new, Debug, Default, Clone)]
34pub struct JsonGzFileRecorder<S: PrecisionSettings> {
35 _settings: PhantomData<S>,
36}
37
38#[derive(new, Debug, Default, Clone)]
40pub struct PrettyJsonFileRecorder<S: PrecisionSettings> {
41 _settings: PhantomData<S>,
42}
43
44#[derive(new, Debug, Default, Clone)]
46pub struct NamedMpkGzFileRecorder<S: PrecisionSettings> {
47 _settings: PhantomData<S>,
48}
49
50#[derive(new, Debug, Default, Clone)]
52pub struct NamedMpkFileRecorder<S: PrecisionSettings> {
53 _settings: PhantomData<S>,
54}
55
56impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for BinGzFileRecorder<S> {
57 fn file_extension() -> &'static str {
58 "bin.gz"
59 }
60}
61impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for BinFileRecorder<S> {
62 fn file_extension() -> &'static str {
63 "bin"
64 }
65}
66impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for JsonGzFileRecorder<S> {
67 fn file_extension() -> &'static str {
68 "json.gz"
69 }
70}
71impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for PrettyJsonFileRecorder<S> {
72 fn file_extension() -> &'static str {
73 "json"
74 }
75}
76
77impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for NamedMpkGzFileRecorder<S> {
78 fn file_extension() -> &'static str {
79 "mpk.gz"
80 }
81}
82
83impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for NamedMpkFileRecorder<S> {
84 fn file_extension() -> &'static str {
85 "mpk"
86 }
87}
88
// Sets the recorder's extension on `$file` (a `PathBuf`, modified in place) and opens the
// resulting path for buffered reading.
//
// Evaluates to `Result<BufReader<File>, RecorderError>`: a missing file maps to
// `RecorderError::FileNotFound`, any other open failure to `RecorderError::Unknown`.
// Must be expanded inside an impl where `Self: FileRecorder<B>`.
macro_rules! str2reader {
    ($file:expr) => {{
        $file.set_extension(<Self as FileRecorder<B>>::file_extension());

        match File::open($file.as_path()) {
            Ok(handle) => Ok(BufReader::new(handle)),
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
                Err(RecorderError::FileNotFound(err.to_string()))
            }
            Err(err) => Err(RecorderError::Unknown(err.to_string())),
        }
    }};
}
104
// Sets the recorder's extension on `$file` (a `PathBuf`, modified in place), creates any
// missing parent directories, removes an existing file at that path, and creates a fresh file
// for buffered writing.
//
// Evaluates to `Result<BufWriter<File>, RecorderError>`; directory-creation and removal
// failures are returned early through `?`, so the enclosing function must return
// `Result<_, RecorderError>`. Must be expanded inside an impl where `Self: FileRecorder<B>`.
macro_rules! str2writer {
    (
        $file:expr
    ) => {{
        $file.set_extension(<Self as FileRecorder<B>>::file_extension());
        let path = $file.as_path();

        log::debug!("Writing to file: {:?}", path);

        if let Some(parent) = path.parent() {
            // Report the real cause instead of ignoring it; otherwise `File::create` below
            // fails with a misleading "not found" error. (`create_dir_all` is a no-op for an
            // empty parent or an existing directory.)
            std::fs::create_dir_all(parent)
                .map_err(|err| RecorderError::Unknown(err.to_string()))?;
        }

        if path.exists() {
            log::warn!("File exists, replacing");
            std::fs::remove_file(path).map_err(|err| RecorderError::Unknown(err.to_string()))?;
        }

        File::create(path)
            .map_err(|err| match err.kind() {
                std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()),
                _ => RecorderError::Unknown(err.to_string()),
            })
            .map(BufWriter::new)
    }};
}
132
133impl<S: PrecisionSettings, B: Backend> Recorder<B> for BinGzFileRecorder<S> {
134 type Settings = S;
135 type RecordArgs = PathBuf;
136 type RecordOutput = ();
137 type LoadArgs = PathBuf;
138
139 fn save_item<I: Serialize>(
140 &self,
141 item: I,
142 mut file: Self::RecordArgs,
143 ) -> Result<(), RecorderError> {
144 let config = bin_config();
145 let writer = str2writer!(file)?;
146 let mut writer = GzEncoder::new(writer, Compression::default());
147
148 bincode::serde::encode_into_std_write(&item, &mut writer, config)
149 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
150
151 Ok(())
152 }
153
154 fn load_item<I: DeserializeOwned>(
155 &self,
156 file: &mut Self::LoadArgs,
157 ) -> Result<I, RecorderError> {
158 let reader = str2reader!(file)?;
159 let mut reader = GzDecoder::new(reader);
160 let state = bincode::serde::decode_from_std_read(&mut reader, bin_config())
161 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
162
163 Ok(state)
164 }
165}
166
167impl<S: PrecisionSettings, B: Backend> Recorder<B> for BinFileRecorder<S> {
168 type Settings = S;
169 type RecordArgs = PathBuf;
170 type RecordOutput = ();
171 type LoadArgs = PathBuf;
172
173 fn save_item<I: Serialize>(
174 &self,
175 item: I,
176 mut file: Self::RecordArgs,
177 ) -> Result<(), RecorderError> {
178 let config = bin_config();
179 let mut writer = str2writer!(file)?;
180 bincode::serde::encode_into_std_write(&item, &mut writer, config)
181 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
182 Ok(())
183 }
184
185 fn load_item<I: DeserializeOwned>(
186 &self,
187 file: &mut Self::LoadArgs,
188 ) -> Result<I, RecorderError> {
189 let mut reader = str2reader!(file)?;
190 let state = bincode::serde::decode_from_std_read(&mut reader, bin_config())
191 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
192 Ok(state)
193 }
194}
195
196impl<S: PrecisionSettings, B: Backend> Recorder<B> for JsonGzFileRecorder<S> {
197 type Settings = S;
198 type RecordArgs = PathBuf;
199 type RecordOutput = ();
200 type LoadArgs = PathBuf;
201
202 fn save_item<I: Serialize>(
203 &self,
204 item: I,
205 mut file: Self::RecordArgs,
206 ) -> Result<(), RecorderError> {
207 let writer = str2writer!(file)?;
208 let writer = GzEncoder::new(writer, Compression::default());
209 serde_json::to_writer(writer, &item)
210 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
211
212 Ok(())
213 }
214
215 fn load_item<I: DeserializeOwned>(
216 &self,
217 file: &mut Self::LoadArgs,
218 ) -> Result<I, RecorderError> {
219 let reader = str2reader!(file)?;
220 let reader = GzDecoder::new(reader);
221 let state = serde_json::from_reader(reader)
222 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
223
224 Ok(state)
225 }
226}
227
228impl<S: PrecisionSettings, B: Backend> Recorder<B> for PrettyJsonFileRecorder<S> {
229 type Settings = S;
230 type RecordArgs = PathBuf;
231 type RecordOutput = ();
232 type LoadArgs = PathBuf;
233
234 fn save_item<I: Serialize>(
235 &self,
236 item: I,
237 mut file: Self::RecordArgs,
238 ) -> Result<(), RecorderError> {
239 let writer = str2writer!(file)?;
240 serde_json::to_writer_pretty(writer, &item)
241 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
242 Ok(())
243 }
244
245 fn load_item<I: DeserializeOwned>(
246 &self,
247 file: &mut Self::LoadArgs,
248 ) -> Result<I, RecorderError> {
249 let reader = str2reader!(file)?;
250 let state = serde_json::from_reader(reader)
251 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
252
253 Ok(state)
254 }
255}
256
257impl<S: PrecisionSettings, B: Backend> Recorder<B> for NamedMpkGzFileRecorder<S> {
258 type Settings = S;
259 type RecordArgs = PathBuf;
260 type RecordOutput = ();
261 type LoadArgs = PathBuf;
262
263 fn save_item<I: Serialize>(
264 &self,
265 item: I,
266 mut file: Self::RecordArgs,
267 ) -> Result<(), RecorderError> {
268 let writer = str2writer!(file)?;
269 let mut writer = GzEncoder::new(writer, Compression::default());
270 rmp_serde::encode::write_named(&mut writer, &item)
271 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
272
273 Ok(())
274 }
275
276 fn load_item<I: DeserializeOwned>(
277 &self,
278 file: &mut Self::LoadArgs,
279 ) -> Result<I, RecorderError> {
280 let reader = str2reader!(file)?;
281 let reader = GzDecoder::new(reader);
282 let state = rmp_serde::decode::from_read(reader)
283 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
284
285 Ok(state)
286 }
287}
288
289impl<S: PrecisionSettings, B: Backend> Recorder<B> for NamedMpkFileRecorder<S> {
290 type Settings = S;
291 type RecordArgs = PathBuf;
292 type RecordOutput = ();
293 type LoadArgs = PathBuf;
294
295 fn save_item<I: Serialize>(
296 &self,
297 item: I,
298 mut file: Self::RecordArgs,
299 ) -> Result<(), RecorderError> {
300 let mut writer = str2writer!(file)?;
301
302 rmp_serde::encode::write_named(&mut writer, &item)
303 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
304
305 Ok(())
306 }
307
308 fn load_item<I: DeserializeOwned>(
309 &self,
310 file: &mut Self::LoadArgs,
311 ) -> Result<I, RecorderError> {
312 let reader = str2reader!(file)?;
313 let state = rmp_serde::decode::from_read(reader)
314 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
315
316 Ok(state)
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate as burn;
    use crate::config::Config;
    use crate::module::Ignored;
    use crate::test_utils::SimpleLinear;
    use crate::{
        TestBackend,
        module::Module,
        record::{BinBytesRecorder, FullPrecisionSettings},
    };
    use burn_tensor::backend::Backend;

    /// Extension-less base path shared by the round-trip tests; each recorder sets its own
    /// extension on it, so the different formats land in different files.
    fn file_path() -> PathBuf {
        std::env::temp_dir().join("burn_test_file_recorder")
    }

    #[test]
    fn test_can_save_and_load_jsongz_format() {
        let recorder = JsonGzFileRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    #[test]
    fn test_can_save_and_load_bin_format() {
        let recorder = BinFileRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    #[test]
    fn test_can_save_and_load_bingz_format() {
        let recorder = BinGzFileRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    #[test]
    fn test_can_save_and_load_pretty_json_format() {
        let recorder = PrettyJsonFileRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    #[test]
    fn test_can_save_and_load_mpkgz_format() {
        let recorder = NamedMpkGzFileRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    #[test]
    fn test_can_save_and_load_mpk_format() {
        let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    /// Records a model with `recorder`, loads it back into a freshly built model, and checks
    /// that both models serialize to identical bytes.
    fn test_can_save_and_load<R>(recorder: R)
    where
        R: FileRecorder<TestBackend>,
    {
        let device = Default::default();

        let original = create_model(&device);
        recorder
            .record(original.clone().into_record(), file_path())
            .unwrap();

        let loaded_record = recorder.load(file_path(), &device).unwrap();
        let restored = create_model(&device).load_record(loaded_record);

        let bytes_recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
        let expected = bytes_recorder.record(original.into_record(), ()).unwrap();
        let actual = bytes_recorder.record(restored.into_record(), ()).unwrap();

        assert_eq!(actual, expected);
    }

    #[derive(Config, Debug)]
    pub enum PaddingConfig2d {
        Same,
        Valid,
        Explicit(usize, usize),
    }

    #[derive(Module, Debug)]
    pub struct Model<B: Backend> {
        linear1: SimpleLinear<B>,
        phantom: PhantomData<B>,
        arr: [usize; 2],
        int: usize,
        ignore: Ignored<PaddingConfig2d>,
    }

    /// Builds the fixed test model: a 32x32 `SimpleLinear` plus constant non-tensor fields.
    pub fn create_model(device: &<TestBackend as Backend>::Device) -> Model<TestBackend> {
        Model {
            linear1: SimpleLinear::new(32, 32, device),
            phantom: PhantomData,
            arr: [2, 2],
            int: 0,
            ignore: Ignored(PaddingConfig2d::Same),
        }
    }
}
421}