1use super::{PrecisionSettings, Recorder, RecorderError, bin_config};
2use burn_tensor::backend::Backend;
3use core::marker::PhantomData;
4use flate2::{Compression, read::GzDecoder, write::GzEncoder};
5use serde::{Serialize, de::DeserializeOwned};
6use std::io::{BufReader, BufWriter};
7use std::{fs::File, path::PathBuf};
8
9pub trait FileRecorder<B: Backend>:
11 Recorder<B, RecordArgs = PathBuf, RecordOutput = (), LoadArgs = PathBuf>
12{
13 fn file_extension() -> &'static str;
15}
16
17pub type DefaultFileRecorder<S> = NamedMpkFileRecorder<S>;
19
20#[derive(new, Debug, Default, Clone)]
22pub struct BinFileRecorder<S: PrecisionSettings> {
23 _settings: PhantomData<S>,
24}
25
26#[derive(new, Debug, Default, Clone)]
28pub struct BinGzFileRecorder<S: PrecisionSettings> {
29 _settings: PhantomData<S>,
30}
31
32#[derive(new, Debug, Default, Clone)]
34pub struct JsonGzFileRecorder<S: PrecisionSettings> {
35 _settings: PhantomData<S>,
36}
37
38#[derive(new, Debug, Default, Clone)]
40pub struct PrettyJsonFileRecorder<S: PrecisionSettings> {
41 _settings: PhantomData<S>,
42}
43
44#[derive(new, Debug, Default, Clone)]
46pub struct NamedMpkGzFileRecorder<S: PrecisionSettings> {
47 _settings: PhantomData<S>,
48}
49
50#[derive(new, Debug, Default, Clone)]
52pub struct NamedMpkFileRecorder<S: PrecisionSettings> {
53 _settings: PhantomData<S>,
54}
55
56impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for BinGzFileRecorder<S> {
57 fn file_extension() -> &'static str {
58 "bin.gz"
59 }
60}
61impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for BinFileRecorder<S> {
62 fn file_extension() -> &'static str {
63 "bin"
64 }
65}
66impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for JsonGzFileRecorder<S> {
67 fn file_extension() -> &'static str {
68 "json.gz"
69 }
70}
71impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for PrettyJsonFileRecorder<S> {
72 fn file_extension() -> &'static str {
73 "json"
74 }
75}
76
77impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for NamedMpkGzFileRecorder<S> {
78 fn file_extension() -> &'static str {
79 "mpk.gz"
80 }
81}
82
83impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for NamedMpkFileRecorder<S> {
84 fn file_extension() -> &'static str {
85 "mpk"
86 }
87}
88
// Replaces the extension of `$file` with the recorder's extension, opens that file for
// reading and evaluates to `Result<BufReader<File>, RecorderError>`.
//
// Must be expanded where `Self: FileRecorder<B>` and `B` are in scope. A missing file is
// reported as `RecorderError::FileNotFound`; any other open failure as
// `RecorderError::Unknown`.
macro_rules! str2reader {
    ($file:expr) => {{
        $file.set_extension(<Self as FileRecorder<B>>::file_extension());
        let source = $file.as_path();

        match File::open(source) {
            Ok(handle) => Ok(BufReader::new(handle)),
            Err(err) => Err(match err.kind() {
                std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()),
                _ => RecorderError::Unknown(err.to_string()),
            }),
        }
    }};
}
104
// Replaces the extension of `$file` with the recorder's extension, (re)creates that file and
// evaluates to `Result<BufWriter<File>, RecorderError>`.
//
// Must be expanded inside a function returning `Result<_, RecorderError>` (the expansion
// uses `?`) where `Self: FileRecorder<B>` and `B` are in scope.
macro_rules! str2writer {
    ($file:expr) => {{
        $file.set_extension(<Self as FileRecorder<B>>::file_extension());
        let target = $file.as_path();

        // Best effort: try to create missing parent directories and ignore the outcome.
        if let Some(dir) = target.parent() {
            std::fs::create_dir_all(dir).ok();
        }

        // An existing file is removed before a fresh one is created.
        if target.exists() {
            log::info!("File exists, replacing");
            std::fs::remove_file(target).map_err(|err| RecorderError::Unknown(err.to_string()))?;
        }

        match File::create(target) {
            Ok(handle) => Ok(BufWriter::new(handle)),
            Err(err) => Err(match err.kind() {
                std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()),
                _ => RecorderError::Unknown(err.to_string()),
            }),
        }
    }};
}
130
131impl<S: PrecisionSettings, B: Backend> Recorder<B> for BinGzFileRecorder<S> {
132 type Settings = S;
133 type RecordArgs = PathBuf;
134 type RecordOutput = ();
135 type LoadArgs = PathBuf;
136
137 fn save_item<I: Serialize>(
138 &self,
139 item: I,
140 mut file: Self::RecordArgs,
141 ) -> Result<(), RecorderError> {
142 let config = bin_config();
143 let writer = str2writer!(file)?;
144 let mut writer = GzEncoder::new(writer, Compression::default());
145
146 bincode::serde::encode_into_std_write(&item, &mut writer, config)
147 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
148
149 Ok(())
150 }
151
152 fn load_item<I: DeserializeOwned>(
153 &self,
154 file: &mut Self::LoadArgs,
155 ) -> Result<I, RecorderError> {
156 let reader = str2reader!(file)?;
157 let mut reader = GzDecoder::new(reader);
158 let state = bincode::serde::decode_from_std_read(&mut reader, bin_config())
159 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
160
161 Ok(state)
162 }
163}
164
165impl<S: PrecisionSettings, B: Backend> Recorder<B> for BinFileRecorder<S> {
166 type Settings = S;
167 type RecordArgs = PathBuf;
168 type RecordOutput = ();
169 type LoadArgs = PathBuf;
170
171 fn save_item<I: Serialize>(
172 &self,
173 item: I,
174 mut file: Self::RecordArgs,
175 ) -> Result<(), RecorderError> {
176 let config = bin_config();
177 let mut writer = str2writer!(file)?;
178 bincode::serde::encode_into_std_write(&item, &mut writer, config)
179 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
180 Ok(())
181 }
182
183 fn load_item<I: DeserializeOwned>(
184 &self,
185 file: &mut Self::LoadArgs,
186 ) -> Result<I, RecorderError> {
187 let mut reader = str2reader!(file)?;
188 let state = bincode::serde::decode_from_std_read(&mut reader, bin_config())
189 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
190 Ok(state)
191 }
192}
193
194impl<S: PrecisionSettings, B: Backend> Recorder<B> for JsonGzFileRecorder<S> {
195 type Settings = S;
196 type RecordArgs = PathBuf;
197 type RecordOutput = ();
198 type LoadArgs = PathBuf;
199
200 fn save_item<I: Serialize>(
201 &self,
202 item: I,
203 mut file: Self::RecordArgs,
204 ) -> Result<(), RecorderError> {
205 let writer = str2writer!(file)?;
206 let writer = GzEncoder::new(writer, Compression::default());
207 serde_json::to_writer(writer, &item)
208 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
209
210 Ok(())
211 }
212
213 fn load_item<I: DeserializeOwned>(
214 &self,
215 file: &mut Self::LoadArgs,
216 ) -> Result<I, RecorderError> {
217 let reader = str2reader!(file)?;
218 let reader = GzDecoder::new(reader);
219 let state = serde_json::from_reader(reader)
220 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
221
222 Ok(state)
223 }
224}
225
226impl<S: PrecisionSettings, B: Backend> Recorder<B> for PrettyJsonFileRecorder<S> {
227 type Settings = S;
228 type RecordArgs = PathBuf;
229 type RecordOutput = ();
230 type LoadArgs = PathBuf;
231
232 fn save_item<I: Serialize>(
233 &self,
234 item: I,
235 mut file: Self::RecordArgs,
236 ) -> Result<(), RecorderError> {
237 let writer = str2writer!(file)?;
238 serde_json::to_writer_pretty(writer, &item)
239 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
240 Ok(())
241 }
242
243 fn load_item<I: DeserializeOwned>(
244 &self,
245 file: &mut Self::LoadArgs,
246 ) -> Result<I, RecorderError> {
247 let reader = str2reader!(file)?;
248 let state = serde_json::from_reader(reader)
249 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
250
251 Ok(state)
252 }
253}
254
255impl<S: PrecisionSettings, B: Backend> Recorder<B> for NamedMpkGzFileRecorder<S> {
256 type Settings = S;
257 type RecordArgs = PathBuf;
258 type RecordOutput = ();
259 type LoadArgs = PathBuf;
260
261 fn save_item<I: Serialize>(
262 &self,
263 item: I,
264 mut file: Self::RecordArgs,
265 ) -> Result<(), RecorderError> {
266 let writer = str2writer!(file)?;
267 let mut writer = GzEncoder::new(writer, Compression::default());
268 rmp_serde::encode::write_named(&mut writer, &item)
269 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
270
271 Ok(())
272 }
273
274 fn load_item<I: DeserializeOwned>(
275 &self,
276 file: &mut Self::LoadArgs,
277 ) -> Result<I, RecorderError> {
278 let reader = str2reader!(file)?;
279 let reader = GzDecoder::new(reader);
280 let state = rmp_serde::decode::from_read(reader)
281 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
282
283 Ok(state)
284 }
285}
286
287impl<S: PrecisionSettings, B: Backend> Recorder<B> for NamedMpkFileRecorder<S> {
288 type Settings = S;
289 type RecordArgs = PathBuf;
290 type RecordOutput = ();
291 type LoadArgs = PathBuf;
292
293 fn save_item<I: Serialize>(
294 &self,
295 item: I,
296 mut file: Self::RecordArgs,
297 ) -> Result<(), RecorderError> {
298 let mut writer = str2writer!(file)?;
299
300 rmp_serde::encode::write_named(&mut writer, &item)
301 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
302
303 Ok(())
304 }
305
306 fn load_item<I: DeserializeOwned>(
307 &self,
308 file: &mut Self::LoadArgs,
309 ) -> Result<I, RecorderError> {
310 let reader = str2reader!(file)?;
311 let state = rmp_serde::decode::from_read(reader)
312 .map_err(|err| RecorderError::Unknown(err.to_string()))?;
313
314 Ok(state)
315 }
316}
317
#[cfg(test)]
mod tests {

    use burn_tensor::backend::Backend;

    use super::*;
    use crate::{
        TestBackend,
        module::Module,
        nn::{
            Linear, LinearConfig,
            conv::{Conv2d, Conv2dConfig},
        },
        record::{BinBytesRecorder, FullPrecisionSettings},
    };

    use crate as burn;

    /// Base path (without extension) in the system temp directory shared by all tests;
    /// each recorder sets its own extension on it.
    #[inline(always)]
    fn file_path() -> PathBuf {
        let mut base = std::env::temp_dir();
        base.push("burn_test_file_recorder");
        base
    }

    #[test]
    fn test_can_save_and_load_jsongz_format() {
        let recorder = JsonGzFileRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    #[test]
    fn test_can_save_and_load_bin_format() {
        let recorder = BinFileRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    #[test]
    fn test_can_save_and_load_bingz_format() {
        let recorder = BinGzFileRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    #[test]
    fn test_can_save_and_load_pretty_json_format() {
        let recorder = PrettyJsonFileRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    #[test]
    fn test_can_save_and_load_mpkgz_format() {
        let recorder = NamedMpkGzFileRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    #[test]
    fn test_can_save_and_load_mpk_format() {
        let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::default();
        test_can_save_and_load(recorder)
    }

    /// Saves a freshly created model with `recorder`, loads it back into a second model and
    /// checks that both models produce identical bytes when recorded with
    /// `BinBytesRecorder`.
    fn test_can_save_and_load<R>(recorder: R)
    where
        R: FileRecorder<TestBackend>,
    {
        let device = Default::default();
        let original = create_model(&device);
        recorder
            .record(original.clone().into_record(), file_path())
            .unwrap();

        // The second model is created first and then overwritten with the loaded record.
        let fresh = create_model(&device);
        let restored = fresh.load_record(recorder.load(file_path(), &device).unwrap());

        let byte_recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
        let bytes_before = byte_recorder.record(original.into_record(), ()).unwrap();
        let bytes_after = byte_recorder.record(restored.into_record(), ()).unwrap();

        assert_eq!(bytes_after, bytes_before);
    }

    /// Small model (one convolution, one linear layer) used as the round-trip payload.
    #[derive(Module, Debug)]
    pub struct Model<B: Backend> {
        conv2d1: Conv2d<B>,
        linear1: Linear<B>,
        phantom: core::marker::PhantomData<B>,
    }

    /// Builds a [`Model`] on `device` with newly initialized layers.
    pub fn create_model(device: &<TestBackend as Backend>::Device) -> Model<TestBackend> {
        Model {
            conv2d1: Conv2dConfig::new([1, 8], [3, 3]).init(device),
            linear1: LinearConfig::new(32, 32).with_bias(true).init(device),
            phantom: core::marker::PhantomData,
        }
    }
}