use core::marker::PhantomData;
use std::path::PathBuf;
use burn::{
record::{PrecisionSettings, Record, Recorder, RecorderError},
tensor::backend::Backend,
};
use regex::Regex;
use serde::{Serialize, de::DeserializeOwned};
use super::reader::from_file;
/// Recorder that loads Burn records from PyTorch files.
///
/// This recorder is load-only: saving is not supported. The `PS` type parameter selects the
/// precision settings used when building the loaded record (it is exposed as the recorder's
/// `Settings` associated type).
#[derive(new, Debug, Default, Clone)]
pub struct PyTorchFileRecorder<PS: PrecisionSettings> {
    // Zero-sized marker that carries the precision-settings type parameter without storing a value.
    _settings: PhantomData<PS>,
}
impl<PS: PrecisionSettings, B: Backend> Recorder<B> for PyTorchFileRecorder<PS> {
type Settings = PS;
type RecordArgs = PathBuf;
type RecordOutput = ();
type LoadArgs = LoadArgs;
fn save_item<I: Serialize>(
&self,
_item: I,
_file: Self::RecordArgs,
) -> Result<(), RecorderError> {
unimplemented!("Save operations are not supported by PyTorchFileRecorder.")
}
fn load_item<I: DeserializeOwned>(
&self,
_file: &mut Self::LoadArgs,
) -> Result<I, RecorderError> {
unimplemented!("load_item is not implemented for PyTorchFileRecorder; use load instead.")
}
fn load<R: Record<B>>(
&self,
args: Self::LoadArgs,
device: &B::Device,
) -> Result<R, RecorderError> {
let item = from_file::<PS, R::Item<Self::Settings>, B>(
&args.file,
args.key_remap,
args.top_level_key.as_deref(), args.debug,
)?;
Ok(R::from_item(item, device))
}
}
/// Arguments controlling how `PyTorchFileRecorder` loads a file.
///
/// Build with [`LoadArgs::new`] (or a `From` conversion from a path or string) and refine with
/// the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct LoadArgs {
    /// Path of the file to read.
    pub file: PathBuf,
    /// `(pattern, replacement)` rules forwarded to the reader, kept in insertion order.
    /// NOTE(review): presumably used to rename keys read from the file; confirm against
    /// `reader::from_file`.
    pub key_remap: Vec<(Regex, String)>,
    /// Optional top-level key forwarded to the reader; `None` when not set.
    pub top_level_key: Option<String>,
    /// Debug flag forwarded to the reader; enabled by `with_debug_print`.
    pub debug: bool,
}
impl LoadArgs {
pub fn new(file: PathBuf) -> Self {
Self {
file,
key_remap: Vec::new(),
top_level_key: None,
debug: false,
}
}
pub fn with_key_remap(mut self, pattern: &str, replacement: &str) -> Self {
let regex = Regex::new(pattern).expect("Invalid regex pattern provided to with_key_remap");
self.key_remap.push((regex, replacement.into()));
self
}
pub fn with_top_level_key(mut self, key: &str) -> Self {
self.top_level_key = Some(key.into());
self
}
pub fn with_debug_print(mut self) -> Self {
self.debug = true;
self
}
}
impl From<PathBuf> for LoadArgs {
fn from(val: PathBuf) -> Self {
LoadArgs::new(val)
}
}
impl From<String> for LoadArgs {
fn from(val: String) -> Self {
LoadArgs::new(val.into())
}
}
impl From<&str> for LoadArgs {
fn from(val: &str) -> Self {
LoadArgs::new(val.into())
}
}