use std::collections::HashMap;
use std::fs;
use std::path::Path;

use anyhow::Context;
use burn::module::RunningState;
use burn::prelude::*;
use half::{bf16, f16};
use safetensors::{Dtype, SafeTensors};

use crate::config::SeizureTransformerConfig;
use crate::model::{FusedMultiheadAttention, SeizureTransformer, TransformerEncoderLayer};
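
/// All tensors from a checkpoint, decoded to `f32` and keyed by their
/// original (PyTorch-style) parameter names.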
pub struct WeightMap {
    pub tensors: HashMap<String, (Vec<f32>, Vec<usize>)>,
}

impl WeightMap {
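    /// Read a safetensors file, decoding every tensor to `f32` (widening
    /// `f16`/`bf16` along the way).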
    pub fn from_file(path: impl AsRef<Path>) -> anyhow::Result<Self> {
        let bytes = fs::read(path)?;
        let st = SafeTensors::deserialize(&bytes)?;
        let mut tensors = HashMap::new();
        for name in st.names() {
            let t = st.tensor(name)?;
            let shape = t.shape().to_vec();
            // safetensors data is little-endian and the slice borrowed from
            // the file buffer carries no alignment guarantee, so decode
            // byte-wise rather than reinterpret-casting (a raw
            // `bytemuck::cast_slice` can panic on unaligned input and
            // assumes host endianness).
            let data: Vec<f32> = match t.dtype() {
                Dtype::F32 => t
                    .data()
                    .chunks_exact(4)
                    .map(|c| f32::from_le_bytes(c.try_into().unwrap()))
                    .collect(),
                Dtype::F16 => t
                    .data()
                    .chunks_exact(2)
                    .map(|c| f16::from_le_bytes([c[0], c[1]]).to_f32())
                    .collect(),
                Dtype::BF16 => t
                    .data()
                    .chunks_exact(2)
                    .map(|c| bf16::from_le_bytes([c[0], c[1]]).to_f32())
                    .collect(),
                dt => anyhow::bail!("unsupported dtype for {name}: {dt:?}"),
            };
            tensors.insert(name.to_string(), (data, shape));
        }
        Ok(Self { tensors })
    }
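
    /// Remove `key` from the map and upload it to `device` as a rank-`N`
    /// tensor. The entry is consumed even if the rank check then fails; a
    /// rank mismatch is reported as an error, never silently reshaped.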
    pub fn take<B: Backend, const N: usize>(
        &mut self,
        key: &str,
        device: &B::Device,
    ) -> anyhow::Result<Tensor<B, N>> {
        let (data, shape) = self
            .tensors
            .remove(key)
            .ok_or_else(|| anyhow::anyhow!("missing key: {key}"))?;
        if shape.len() != N {
            anyhow::bail!("rank mismatch for {key}: expected {N}, got {}", shape.len());
        }
        Ok(Tensor::from_data(TensorData::new(data, shape), device))
    }
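
    /// Like `take`, but maps any failure (missing key, rank mismatch) to
    /// `None` so callers can fall back to the parameter's initialised value.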
    pub fn maybe_take<B: Backend, const N: usize>(
        &mut self,
        key: &str,
        device: &B::Device,
    ) -> Option<Tensor<B, N>> {
        self.take::<B, N>(key, device).ok()
    }
}
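
/// Copy a PyTorch `nn.Linear` weight/bias pair into a Burn `Linear`.
/// PyTorch stores the weight as `[out, in]` while Burn expects `[in, out]`,
/// hence the transpose; if the Burn layer was built without a bias, the
/// checkpoint bias is ignored.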
fn set_linear_wb<B: Backend>(
    linear: &mut burn::nn::Linear<B>,
    w_torch: Tensor<B, 2>,
    b_torch: Tensor<B, 1>,
) {
    linear.weight = linear.weight.clone().map(|_| w_torch.transpose());
    if let Some(ref bias) = linear.bias {
        linear.bias = Some(bias.clone().map(|_| b_torch));
    }
}
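
/// Copy a PyTorch `nn.Conv1d` weight/bias pair into a Burn `Conv1d`. Both
/// frameworks use the `[out, in / groups, kernel]` layout, so the weight is
/// copied without transposition.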
fn set_conv1d_wb<B: Backend>(
    conv: &mut burn::nn::conv::Conv1d<B>,
    w: Tensor<B, 3>,
    b: Tensor<B, 1>,
) {
    conv.weight = conv.weight.clone().map(|_| w);
    if let Some(ref bias) = conv.bias {
        conv.bias = Some(bias.clone().map(|_| b));
    }
}
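
/// Map PyTorch `LayerNorm.{weight,bias}` onto Burn's `gamma`/`beta`; both
/// fields are always present on `burn::nn::LayerNorm`.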
fn set_layernorm<B: Backend>(ln: &mut burn::nn::LayerNorm<B>, w: Tensor<B, 1>, b: Tensor<B, 1>) {
    ln.gamma = ln.gamma.clone().map(|_| w);
    ln.beta = ln.beta.clone().map(|_| b);
}
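
/// Map PyTorch `BatchNorm1d` parameters and running statistics onto a Burn
/// `BatchNorm`. `D` is Burn's spatial-rank const generic (`1` for the conv
/// blocks here); the running statistics are wrapped in `RunningState` so
/// inference uses the checkpoint's estimates rather than fresh ones.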
fn set_batchnorm<B: Backend, const D: usize>(
    bn: &mut burn::nn::BatchNorm<B, D>,
    gamma: Tensor<B, 1>,
    beta: Tensor<B, 1>,
    running_mean: Tensor<B, 1>,
    running_var: Tensor<B, 1>,
) {
    bn.gamma = bn.gamma.clone().map(|_| gamma);
    bn.beta = bn.beta.clone().map(|_| beta);
    bn.running_mean = RunningState::new(running_mean);
    bn.running_var = RunningState::new(running_var);
}
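
/// Load a fused multi-head attention block. PyTorch's `nn.MultiheadAttention`
/// packs the query/key/value projections into a single `in_proj_weight` of
/// shape `[3 * d_model, d_model]` plus an `in_proj_bias`, which maps directly
/// onto the fused `in_proj` linear here; `out_proj` is an ordinary linear.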
fn load_mha<B: Backend>(
    wm: &mut WeightMap,
    mha: &mut FusedMultiheadAttention<B>,
    prefix: &str,
    device: &B::Device,
) -> anyhow::Result<()> {
    if let (Some(w), Some(b)) = (
        wm.maybe_take::<B, 2>(&format!("{prefix}.in_proj_weight"), device),
        wm.maybe_take::<B, 1>(&format!("{prefix}.in_proj_bias"), device),
    ) {
        set_linear_wb(&mut mha.in_proj, w, b);
    }
    if let (Some(w), Some(b)) = (
        wm.maybe_take::<B, 2>(&format!("{prefix}.out_proj.weight"), device),
        wm.maybe_take::<B, 1>(&format!("{prefix}.out_proj.bias"), device),
    ) {
        set_linear_wb(&mut mha.out_proj, w, b);
    }
    Ok(())
}
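
/// Load one transformer encoder layer: self-attention, the two feed-forward
/// linears, and the two layer norms, using PyTorch
/// `nn.TransformerEncoderLayer` key names.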
fn load_transformer_layer<B: Backend>(
    wm: &mut WeightMap,
    layer: &mut TransformerEncoderLayer<B>,
    prefix: &str,
    device: &B::Device,
) -> anyhow::Result<()> {
    load_mha(wm, &mut layer.self_attn, &format!("{prefix}.self_attn"), device)?;
    if let (Some(w), Some(b)) = (
        wm.maybe_take::<B, 2>(&format!("{prefix}.linear1.weight"), device),
        wm.maybe_take::<B, 1>(&format!("{prefix}.linear1.bias"), device),
    ) {
        set_linear_wb(&mut layer.linear1, w, b);
    }
    if let (Some(w), Some(b)) = (
        wm.maybe_take::<B, 2>(&format!("{prefix}.linear2.weight"), device),
        wm.maybe_take::<B, 1>(&format!("{prefix}.linear2.bias"), device),
    ) {
        set_linear_wb(&mut layer.linear2, w, b);
    }
    if let (Some(w), Some(b)) = (
        wm.maybe_take::<B, 1>(&format!("{prefix}.norm1.weight"), device),
        wm.maybe_take::<B, 1>(&format!("{prefix}.norm1.bias"), device),
    ) {
        set_layernorm(&mut layer.norm1, w, b);
    }
    if let (Some(w), Some(b)) = (
        wm.maybe_take::<B, 1>(&format!("{prefix}.norm2.weight"), device),
        wm.maybe_take::<B, 1>(&format!("{prefix}.norm2.bias"), device),
    ) {
        set_layernorm(&mut layer.norm2, w, b);
    }
    Ok(())
}
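
/// Populate a freshly initialised `SeizureTransformer` with checkpoint
/// weights, consuming matched tensors from `wm` as it goes. Every lookup
/// goes through `maybe_take`, so a key absent from the checkpoint silently
/// leaves that parameter at its random initialisation; strict callers can
/// inspect `wm.tensors` for leftovers afterwards.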
pub fn load_model<B: Backend>(
    cfg: &SeizureTransformerConfig,
    wm: &mut WeightMap,
    device: &B::Device,
) -> anyhow::Result<SeizureTransformer<B>> {
    let mut model = SeizureTransformer::new(cfg, device);
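
    // Encoder convolutions: PyTorch and Burn share the Conv1d weight layout,
    // so these tensors load without transposition.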
    for (i, conv) in model.encoder.convs.iter_mut().enumerate() {
        if let (Some(w), Some(b)) = (
            wm.maybe_take::<B, 3>(&format!("encoder.convs.{i}.weight"), device),
            wm.maybe_take::<B, 1>(&format!("encoder.convs.{i}.bias"), device),
        ) {
            set_conv1d_wb(conv, w, b);
        }
    }
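
    // Residual CNN stack. Note the name mismatch: the Burn field is `blocks`,
    // but the checkpoint stores these under `res_cnn_stack.members.{i}`.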
    for (i, block) in model.res_cnn_stack.blocks.iter_mut().enumerate() {
        if let (Some(w), Some(b)) = (
            wm.maybe_take::<B, 3>(&format!("res_cnn_stack.members.{i}.conv1.weight"), device),
            wm.maybe_take::<B, 1>(&format!("res_cnn_stack.members.{i}.conv1.bias"), device),
        ) {
            set_conv1d_wb(&mut block.conv1, w, b);
        }
        if let (Some(w), Some(b)) = (
            wm.maybe_take::<B, 3>(&format!("res_cnn_stack.members.{i}.conv2.weight"), device),
            wm.maybe_take::<B, 1>(&format!("res_cnn_stack.members.{i}.conv2.bias"), device),
        ) {
            set_conv1d_wb(&mut block.conv2, w, b);
        }
        if let (Some(g), Some(be), Some(rm), Some(rv)) = (
            wm.maybe_take::<B, 1>(&format!("res_cnn_stack.members.{i}.norm1.weight"), device),
            wm.maybe_take::<B, 1>(&format!("res_cnn_stack.members.{i}.norm1.bias"), device),
            wm.maybe_take::<B, 1>(
                &format!("res_cnn_stack.members.{i}.norm1.running_mean"),
                device,
            ),
            wm.maybe_take::<B, 1>(
                &format!("res_cnn_stack.members.{i}.norm1.running_var"),
                device,
            ),
        ) {
            set_batchnorm(&mut block.norm1, g, be, rm, rv);
        }
        if let (Some(g), Some(be), Some(rm), Some(rv)) = (
            wm.maybe_take::<B, 1>(&format!("res_cnn_stack.members.{i}.norm2.weight"), device),
            wm.maybe_take::<B, 1>(&format!("res_cnn_stack.members.{i}.norm2.bias"), device),
            wm.maybe_take::<B, 1>(
                &format!("res_cnn_stack.members.{i}.norm2.running_mean"),
                device,
            ),
            wm.maybe_take::<B, 1>(
                &format!("res_cnn_stack.members.{i}.norm2.running_var"),
                device,
            ),
        ) {
            set_batchnorm(&mut block.norm2, g, be, rm, rv);
        }
    }
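
    // Transformer encoder layers, keyed `transformer_encoder.layers.{i}` as
    // in PyTorch's `nn.TransformerEncoder`.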
    for (i, layer) in model.transformer_encoder.iter_mut().enumerate() {
        load_transformer_layer(wm, layer, &format!("transformer_encoder.layers.{i}"), device)?;
    }
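
    // Decoder convolutions, mirroring the encoder's layout.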
    for (i, conv) in model.decoder_d.convs.iter_mut().enumerate() {
        if let (Some(w), Some(b)) = (
            wm.maybe_take::<B, 3>(&format!("decoder_d.convs.{i}.weight"), device),
            wm.maybe_take::<B, 1>(&format!("decoder_d.convs.{i}.bias"), device),
        ) {
            set_conv1d_wb(conv, w, b);
        }
    }
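
    // Final output convolution.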
    if let (Some(w), Some(b)) = (
        wm.maybe_take::<B, 3>("conv_d.weight", device),
        wm.maybe_take::<B, 1>("conv_d.bias", device),
    ) {
        set_conv1d_wb(&mut model.conv_d, w, b);
    }
    Ok(model)
}
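
/// Convenience wrapper: read a safetensors checkpoint from disk and load it
/// into a freshly constructed `SeizureTransformer`.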
pub fn load_model_from_file<B: Backend>(
    cfg: &SeizureTransformerConfig,
    path: impl AsRef<Path>,
    device: &B::Device,
) -> anyhow::Result<SeizureTransformer<B>> {
    let mut wm = WeightMap::from_file(path.as_ref()).with_context(|| {
        format!(
            "failed loading safetensors from {}",
            path.as_ref().display()
        )
    })?;
    load_model(cfg, &mut wm, device)
}