use burn::{
module::Param,
record::{PrecisionSettings, Record},
tensor::{Tensor, backend::Backend},
};
use burn::record::serde::{
adapter::{BurnModuleAdapter, DefaultAdapter},
data::NestedValue,
ser::Serializer,
};
use serde::Serialize;
/// Adapter that rewrites PyTorch-style module state into the layout Burn
/// expects when deserializing records (see the `BurnModuleAdapter` impl
/// below: linear weights are transposed, and normalization parameters are
/// renamed from `weight`/`bias` to `gamma`/`beta`).
///
/// `PS` selects the precision used when (de)serializing tensors; `B` is the
/// tensor backend the weights are materialized on.
pub struct PyTorchAdapter<PS: PrecisionSettings, B: Backend> {
    // Zero-sized marker that ties the otherwise-unused `PS`/`B` type
    // parameters to the struct so the impl below can reference them.
    _precision_settings: std::marker::PhantomData<(PS, B)>,
}
impl<PS: PrecisionSettings, B: Backend> BurnModuleAdapter for PyTorchAdapter<PS, B> {
    /// Adapts a linear layer's state: the `weight` entry is deserialized as a
    /// rank-2 tensor, transposed, and serialized back in place. (PyTorch and
    /// Burn store the linear weight matrix transposed relative to each other.)
    ///
    /// # Panics
    /// Panics if `data` is not a map, if the `weight` key is absent, or if the
    /// weight cannot be deserialized.
    fn adapt_linear(data: NestedValue) -> NestedValue {
        let mut fields = data.as_map().expect("Failed to get map from NestedValue");

        // Pull the raw weight entry out of the state map.
        let raw_weight = fields
            .remove("weight")
            .expect("Failed to find 'weight' key in map");

        // Materialize it as a rank-2 tensor parameter on the default device.
        let param: Param<Tensor<B, 2>> = raw_weight
            .try_into_record::<_, PS, DefaultAdapter, B>(&B::Device::default())
            .expect("Failed to deserialize weight");

        // Gradients are irrelevant during import; transpose into Burn's layout
        // and re-serialize under the same key.
        let transposed = Param::from_tensor(param.set_require_grad(false).val().transpose());
        fields.insert("weight".to_owned(), serialize::<PS, _, 2>(transposed));

        NestedValue::Map(fields)
    }

    /// Group norm uses PyTorch's `weight`/`bias` names; rename to `gamma`/`beta`.
    fn adapt_group_norm(data: NestedValue) -> NestedValue {
        rename_weight_bias(data)
    }

    /// Batch norm uses PyTorch's `weight`/`bias` names; rename to `gamma`/`beta`.
    fn adapt_batch_norm(data: NestedValue) -> NestedValue {
        rename_weight_bias(data)
    }

    /// Layer norm uses PyTorch's `weight`/`bias` names; rename to `gamma`/`beta`.
    fn adapt_layer_norm(data: NestedValue) -> NestedValue {
        rename_weight_bias(data)
    }
}
/// Serializes a tensor parameter back into a `NestedValue` using the
/// precision settings `PS`.
///
/// # Panics
/// Panics if the converted item fails to serialize.
fn serialize<PS, B, const D: usize>(val: Param<Tensor<B, D>>) -> NestedValue
where
    B: Backend,
    PS: PrecisionSettings,
{
    // Convert the parameter into its serializable item form, then feed it
    // straight into a fresh NestedValue serializer.
    let item = val.into_item::<PS>();
    item.serialize(Serializer::new())
        .expect("Failed to serialize the item")
}
/// Renames PyTorch normalization parameters to Burn's names:
/// `weight` -> `gamma` and `bias` -> `beta`.
///
/// # Panics
/// Panics if `data` is not a map or if either key is missing.
fn rename_weight_bias(data: NestedValue) -> NestedValue {
    let mut params = data.as_map().expect("Failed to get map from NestedValue");

    // Extract both entries first, then re-insert them under the Burn names.
    // The four keys are pairwise distinct, so this ordering is equivalent to
    // renaming each entry in turn.
    let gamma = params
        .remove("weight")
        .expect("Failed to find 'weight' key in map");
    let beta = params
        .remove("bias")
        .expect("Failed to find 'bias' key in map");

    params.insert("gamma".to_owned(), gamma);
    params.insert("beta".to_owned(), beta);
    NestedValue::Map(params)
}