use core::marker::PhantomData;
use std::path::PathBuf;
use burn::{
record::{PrecisionSettings, Record, Recorder, RecorderError},
tensor::backend::Backend,
};
use regex::Regex;
use serde::{Serialize, de::DeserializeOwned};
use super::reader::from_file;
/// Recorder for reading Burn records from a safetensors file.
///
/// Only [`Recorder::load`] is functional. It forwards the [`LoadArgs`] to the
/// `reader::from_file` helper and builds the record from the returned item.
/// `save_item` and `load_item` are unimplemented and panic if called.
///
/// `PS` is the precision setting used for the loaded record items.
#[derive(new, Debug, Default, Clone)]
pub struct SafetensorsFileRecorder<PS: PrecisionSettings> {
// Zero-sized marker: makes the recorder generic over `PS` without storing a value.
_settings: PhantomData<PS>,
}
impl<PS: PrecisionSettings, B: Backend> Recorder<B> for SafetensorsFileRecorder<PS> {
type Settings = PS;
type RecordArgs = PathBuf;
type RecordOutput = ();
type LoadArgs = LoadArgs;
fn save_item<I: Serialize>(
&self,
_item: I,
_file: Self::RecordArgs,
) -> Result<(), RecorderError> {
unimplemented!("save_item not implemented for SafetensorsFileRecorder")
}
fn load_item<I: DeserializeOwned>(
&self,
_file: &mut Self::LoadArgs,
) -> Result<I, RecorderError> {
unimplemented!("load_item not implemented for SafetensorsFileRecorder")
}
fn load<R: Record<B>>(
&self,
args: Self::LoadArgs,
device: &B::Device,
) -> Result<R, RecorderError> {
let item = from_file::<PS, R::Item<Self::Settings>, B>(
&args.file,
args.key_remap,
args.debug,
args.adapter_type,
)?;
Ok(R::from_item(item, device))
}
}
/// Arguments that control how `SafetensorsFileRecorder` loads a file.
///
/// Build with [`LoadArgs::new`] and the `with_*` builder methods. It can also be
/// converted from a `PathBuf`, `String`, or `&str`.
#[derive(Debug, Clone)]
pub struct LoadArgs {
/// Path of the file to load.
pub file: PathBuf,
/// Key-remapping rules as `(compiled regex, replacement)` pairs, kept in the
/// order they were added. NOTE(review): how the rules are applied is decided by
/// `reader::from_file`; confirm there.
pub key_remap: Vec<(Regex, String)>,
/// Debug flag forwarded to the reader (`false` by default, set by `with_debug_print`).
pub debug: bool,
/// Adapter type forwarded to the reader (defaults to `AdapterType::PyTorch`).
pub adapter_type: AdapterType,
}
/// Selects which adapter the reader applies while loading a record.
///
/// The enum is fieldless and `Copy`. It also derives `PartialEq`, `Eq` and
/// `Hash`, so callers can compare variants directly or use them as map keys.
///
/// NOTE(review): the concrete transformation for each variant is implemented in
/// `reader::from_file`. The descriptions below are inferred from the variant
/// names; confirm them there.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum AdapterType {
    /// The default adapter. Presumably adapts PyTorch-exported tensors to Burn's conventions.
    #[default]
    PyTorch,
    /// No adapter. Presumably items are loaded as stored.
    NoAdapter,
}
impl LoadArgs {
    /// Creates load arguments for `file`.
    ///
    /// The defaults are: no key remapping, debug disabled, and the default
    /// `AdapterType`.
    pub fn new(file: PathBuf) -> Self {
        Self {
            file,
            key_remap: Vec::new(),
            debug: false,
            adapter_type: Default::default(),
        }
    }

    /// Appends a key-remapping rule built from the regex `pattern` and `replacement`.
    ///
    /// Rules are stored in the order they are added.
    ///
    /// # Panics
    ///
    /// Panics if `pattern` is not a valid regular expression.
    pub fn with_key_remap(mut self, pattern: &str, replacement: &str) -> Self {
        self.key_remap.push((
            Regex::new(pattern).expect("Invalid regex pattern provided"),
            String::from(replacement),
        ));
        self
    }

    /// Turns on the reader's debug flag.
    pub fn with_debug_print(self) -> Self {
        Self { debug: true, ..self }
    }

    /// Replaces the adapter type forwarded to the reader.
    pub fn with_adapter_type(self, adapter_type: AdapterType) -> Self {
        Self {
            adapter_type,
            ..self
        }
    }
}
impl From<PathBuf> for LoadArgs {
fn from(val: PathBuf) -> Self {
LoadArgs::new(val)
}
}
impl From<String> for LoadArgs {
fn from(val: String) -> Self {
LoadArgs::new(val.into())
}
}
impl From<&str> for LoadArgs {
fn from(val: &str) -> Self {
LoadArgs::new(val.into())
}
}