use super::{bin_config, PrecisionSettings, Recorder, RecorderError};
use burn_tensor::backend::Backend;
use core::marker::PhantomData;
use flate2::{read::GzDecoder, write::GzEncoder, Compression};
use serde::{de::DeserializeOwned, Serialize};
use std::io::{BufReader, BufWriter, Write};
use std::{fs::File, path::PathBuf};
8
9pub trait FileRecorder<B: Backend>:
11 Recorder<B, RecordArgs = PathBuf, RecordOutput = (), LoadArgs = PathBuf>
12{
13 fn file_extension() -> &'static str;
15}
16
17pub type DefaultFileRecorder<S> = NamedMpkFileRecorder<S>;
19
20#[derive(new, Debug, Default, Clone)]
22pub struct BinFileRecorder<S: PrecisionSettings> {
23 _settings: PhantomData<S>,
24}
25
26#[derive(new, Debug, Default, Clone)]
28pub struct BinGzFileRecorder<S: PrecisionSettings> {
29 _settings: PhantomData<S>,
30}
31
32#[derive(new, Debug, Default, Clone)]
34pub struct JsonGzFileRecorder<S: PrecisionSettings> {
35 _settings: PhantomData<S>,
36}
37
38#[derive(new, Debug, Default, Clone)]
40pub struct PrettyJsonFileRecorder<S: PrecisionSettings> {
41 _settings: PhantomData<S>,
42}
43
44#[derive(new, Debug, Default, Clone)]
46pub struct NamedMpkGzFileRecorder<S: PrecisionSettings> {
47 _settings: PhantomData<S>,
48}
49
50#[derive(new, Debug, Default, Clone)]
52pub struct NamedMpkFileRecorder<S: PrecisionSettings> {
53 _settings: PhantomData<S>,
54}
55
56impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for BinGzFileRecorder<S> {
57 fn file_extension() -> &'static str {
58 "bin.gz"
59 }
60}
61impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for BinFileRecorder<S> {
62 fn file_extension() -> &'static str {
63 "bin"
64 }
65}
66impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for JsonGzFileRecorder<S> {
67 fn file_extension() -> &'static str {
68 "json.gz"
69 }
70}
71impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for PrettyJsonFileRecorder<S> {
72 fn file_extension() -> &'static str {
73 "json"
74 }
75}
76
77impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for NamedMpkGzFileRecorder<S> {
78 fn file_extension() -> &'static str {
79 "mpk.gz"
80 }
81}
82
83impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for NamedMpkFileRecorder<S> {
84 fn file_extension() -> &'static str {
85 "mpk"
86 }
87}
88
/// Opens a recorder's file for buffered reading.
///
/// `$file` must be a mutable `PathBuf` binding. Its extension is replaced in place with
/// `<Self as FileRecorder<B>>::file_extension()`, so the macro can only be used inside an
/// `impl` where `Self: FileRecorder<B>`.
///
/// Evaluates to `Result<BufReader<File>, RecorderError>`: a missing file maps to
/// `RecorderError::FileNotFound`, any other I/O error to `RecorderError::Unknown`.
macro_rules! str2reader {
    ($file:expr) => {{
        $file.set_extension(<Self as FileRecorder<B>>::file_extension());
        let path = $file.as_path();

        File::open(path)
            .map_err(|err| match err.kind() {
                std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()),
                _ => RecorderError::Unknown(err.to_string()),
            })
            .map(BufReader::new)
    }};
}
104
/// Creates (or replaces) a recorder's file and opens it for buffered writing.
///
/// `$file` must be a mutable `PathBuf` binding. Its extension is replaced in place with
/// `<Self as FileRecorder<B>>::file_extension()`, so the macro can only be used inside an
/// `impl` where `Self: FileRecorder<B>`.
///
/// Missing parent directories are created, and an existing file at the target path is
/// removed first. Errors from those two steps are returned early, via `?`, from the
/// enclosing function, which must therefore return `Result<_, RecorderError>`.
///
/// Evaluates to `Result<BufWriter<File>, RecorderError>`: `NotFound` from `File::create`
/// maps to `RecorderError::FileNotFound`, any other I/O error to `RecorderError::Unknown`.
macro_rules! str2writer {
    ($file:expr) => {{
        $file.set_extension(<Self as FileRecorder<B>>::file_extension());
        let path = $file.as_path();

        if let Some(parent) = path.parent() {
            // Report the real cause (e.g. permission denied) here instead of letting
            // `File::create` fail afterwards with a misleading `NotFound`.
            // `create_dir_all` succeeds for an empty parent and for existing directories.
            std::fs::create_dir_all(parent)
                .map_err(|err| RecorderError::Unknown(err.to_string()))?;
        }

        if path.exists() {
            log::info!("File exists, replacing");
            std::fs::remove_file(path).map_err(|err| RecorderError::Unknown(err.to_string()))?;
        }

        File::create(path)
            .map_err(|err| match err.kind() {
                std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()),
                _ => RecorderError::Unknown(err.to_string()),
            })
            .map(BufWriter::new)
    }};
}
130
131impl<S: PrecisionSettings, B: Backend> Recorder<B> for BinGzFileRecorder<S> {
132 type Settings = S;
133 type RecordArgs = PathBuf;
134 type RecordOutput = ();
135 type LoadArgs = PathBuf;
136
137 fn save_item<I: Serialize>(
138 &self,
139 item: I,
140 mut file: Self::RecordArgs,
141 ) -> Result<(), RecorderError> {
142 let config = bin_config();
143 let writer = str2writer!(file)?;
144 let mut writer = GzEncoder::new(writer, Compression::default());
145
146 bincode::serde::encode_into_std_write(&item, &mut writer, config)
147 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
148
149 Ok(())
150 }
151
152 fn load_item<I: DeserializeOwned>(&self, mut file: Self::LoadArgs) -> Result<I, RecorderError> {
153 let reader = str2reader!(file)?;
154 let mut reader = GzDecoder::new(reader);
155 let state = bincode::serde::decode_from_std_read(&mut reader, bin_config())
156 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
157
158 Ok(state)
159 }
160}
161
162impl<S: PrecisionSettings, B: Backend> Recorder<B> for BinFileRecorder<S> {
163 type Settings = S;
164 type RecordArgs = PathBuf;
165 type RecordOutput = ();
166 type LoadArgs = PathBuf;
167
168 fn save_item<I: Serialize>(
169 &self,
170 item: I,
171 mut file: Self::RecordArgs,
172 ) -> Result<(), RecorderError> {
173 let config = bin_config();
174 let mut writer = str2writer!(file)?;
175 bincode::serde::encode_into_std_write(&item, &mut writer, config)
176 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
177 Ok(())
178 }
179
180 fn load_item<I: DeserializeOwned>(&self, mut file: Self::LoadArgs) -> Result<I, RecorderError> {
181 let mut reader = str2reader!(file)?;
182 let state = bincode::serde::decode_from_std_read(&mut reader, bin_config())
183 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
184 Ok(state)
185 }
186}
187
188impl<S: PrecisionSettings, B: Backend> Recorder<B> for JsonGzFileRecorder<S> {
189 type Settings = S;
190 type RecordArgs = PathBuf;
191 type RecordOutput = ();
192 type LoadArgs = PathBuf;
193
194 fn save_item<I: Serialize>(
195 &self,
196 item: I,
197 mut file: Self::RecordArgs,
198 ) -> Result<(), RecorderError> {
199 let writer = str2writer!(file)?;
200 let writer = GzEncoder::new(writer, Compression::default());
201 serde_json::to_writer(writer, &item)
202 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
203
204 Ok(())
205 }
206
207 fn load_item<I: DeserializeOwned>(&self, mut file: Self::LoadArgs) -> Result<I, RecorderError> {
208 let reader = str2reader!(file)?;
209 let reader = GzDecoder::new(reader);
210 let state = serde_json::from_reader(reader)
211 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
212
213 Ok(state)
214 }
215}
216
217impl<S: PrecisionSettings, B: Backend> Recorder<B> for PrettyJsonFileRecorder<S> {
218 type Settings = S;
219 type RecordArgs = PathBuf;
220 type RecordOutput = ();
221 type LoadArgs = PathBuf;
222
223 fn save_item<I: Serialize>(
224 &self,
225 item: I,
226 mut file: Self::RecordArgs,
227 ) -> Result<(), RecorderError> {
228 let writer = str2writer!(file)?;
229 serde_json::to_writer_pretty(writer, &item)
230 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
231 Ok(())
232 }
233
234 fn load_item<I: DeserializeOwned>(&self, mut file: Self::LoadArgs) -> Result<I, RecorderError> {
235 let reader = str2reader!(file)?;
236 let state = serde_json::from_reader(reader)
237 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
238
239 Ok(state)
240 }
241}
242
243impl<S: PrecisionSettings, B: Backend> Recorder<B> for NamedMpkGzFileRecorder<S> {
244 type Settings = S;
245 type RecordArgs = PathBuf;
246 type RecordOutput = ();
247 type LoadArgs = PathBuf;
248
249 fn save_item<I: Serialize>(
250 &self,
251 item: I,
252 mut file: Self::RecordArgs,
253 ) -> Result<(), RecorderError> {
254 let writer = str2writer!(file)?;
255 let mut writer = GzEncoder::new(writer, Compression::default());
256 rmp_serde::encode::write_named(&mut writer, &item)
257 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
258
259 Ok(())
260 }
261
262 fn load_item<I: DeserializeOwned>(&self, mut file: Self::LoadArgs) -> Result<I, RecorderError> {
263 let reader = str2reader!(file)?;
264 let reader = GzDecoder::new(reader);
265 let state = rmp_serde::decode::from_read(reader)
266 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
267
268 Ok(state)
269 }
270}
271
272impl<S: PrecisionSettings, B: Backend> Recorder<B> for NamedMpkFileRecorder<S> {
273 type Settings = S;
274 type RecordArgs = PathBuf;
275 type RecordOutput = ();
276 type LoadArgs = PathBuf;
277
278 fn save_item<I: Serialize>(
279 &self,
280 item: I,
281 mut file: Self::RecordArgs,
282 ) -> Result<(), RecorderError> {
283 let mut writer = str2writer!(file)?;
284
285 rmp_serde::encode::write_named(&mut writer, &item)
286 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
287
288 Ok(())
289 }
290
291 fn load_item<I: DeserializeOwned>(&self, mut file: Self::LoadArgs) -> Result<I, RecorderError> {
292 let reader = str2reader!(file)?;
293 let state = rmp_serde::decode::from_read(reader)
294 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
295
296 Ok(state)
297 }
298}
299
#[cfg(test)]
mod tests {

    use burn_tensor::backend::Backend;

    use super::*;
    use crate::{
        module::Module,
        nn::{
            conv::{Conv2d, Conv2dConfig},
            Linear, LinearConfig,
        },
        record::{BinBytesRecorder, FullPrecisionSettings},
        TestBackend,
    };

    use crate as burn;

    /// Base path (without extension) of the files written by the round-trip tests.
    ///
    /// The process id is part of the name so that concurrent test processes (e.g. the
    /// same suite run for several backends) do not overwrite each other's files. Within
    /// one process, every recorder appends a distinct extension, so tests don't collide.
    fn file_path() -> PathBuf {
        std::env::temp_dir().join(format!("burn_test_file_recorder_{}", std::process::id()))
    }

    #[test]
    fn test_can_save_and_load_jsongz_format() {
        test_can_save_and_load(JsonGzFileRecorder::<FullPrecisionSettings>::default())
    }

    #[test]
    fn test_can_save_and_load_bin_format() {
        test_can_save_and_load(BinFileRecorder::<FullPrecisionSettings>::default())
    }

    #[test]
    fn test_can_save_and_load_bingz_format() {
        test_can_save_and_load(BinGzFileRecorder::<FullPrecisionSettings>::default())
    }

    #[test]
    fn test_can_save_and_load_pretty_json_format() {
        test_can_save_and_load(PrettyJsonFileRecorder::<FullPrecisionSettings>::default())
    }

    #[test]
    fn test_can_save_and_load_mpkgz_format() {
        test_can_save_and_load(NamedMpkGzFileRecorder::<FullPrecisionSettings>::default())
    }

    #[test]
    fn test_can_save_and_load_mpk_format() {
        test_can_save_and_load(NamedMpkFileRecorder::<FullPrecisionSettings>::default())
    }

    /// Saves a model with `recorder`, loads it back into a fresh model, and checks that
    /// both models serialize to the same bytes.
    fn test_can_save_and_load<R>(recorder: R)
    where
        R: FileRecorder<TestBackend>,
    {
        let device = Default::default();
        let model_before = create_model(&device);
        recorder
            .record(model_before.clone().into_record(), file_path())
            .unwrap();

        let model_after =
            create_model(&device).load_record(recorder.load(file_path(), &device).unwrap());

        // The recorder replaced the extension of `file_path()`; remove the file it wrote so
        // repeated runs do not accumulate files in the temp directory. Best effort only.
        std::fs::remove_file(
            file_path().with_extension(<R as FileRecorder<TestBackend>>::file_extension()),
        )
        .ok();

        let byte_recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
        let model_bytes_before = byte_recorder
            .record(model_before.into_record(), ())
            .unwrap();
        let model_bytes_after = byte_recorder.record(model_after.into_record(), ()).unwrap();

        assert_eq!(model_bytes_after, model_bytes_before);
    }

    #[derive(Module, Debug)]
    pub struct Model<B: Backend> {
        conv2d1: Conv2d<B>,
        linear1: Linear<B>,
        phantom: core::marker::PhantomData<B>,
    }

    /// Builds a small model (one conv layer, one linear layer) used as round-trip input.
    pub fn create_model(device: &<TestBackend as Backend>::Device) -> Model<TestBackend> {
        let conv2d1 = Conv2dConfig::new([1, 8], [3, 3]).init(device);

        let linear1 = LinearConfig::new(32, 32).with_bias(true).init(device);

        Model {
            conv2d1,
            linear1,
            phantom: core::marker::PhantomData,
        }
    }
}