//! seizuretransformer 0.0.1
//!
//! SeizureTransformer EEG model in Rust (Burn + wgpu).
//!
//! This module loads SafeTensors checkpoints into a `SeizureTransformer`.
use std::collections::HashMap;
use std::fs;
use std::path::Path;

use anyhow::Context;
use burn::module::RunningState;
use burn::prelude::*;
use half::{bf16, f16};
use safetensors::{Dtype, SafeTensors};

use crate::config::SeizureTransformerConfig;
use crate::model::{FusedMultiheadAttention, SeizureTransformer, TransformerEncoderLayer};

/// Host-side collection of named checkpoint tensors.
///
/// `WeightMap::from_file` fills it from a SafeTensors file with every tensor widened to
/// `f32`. The `take` family of methods removes entries as they are consumed, so entries still
/// present after loading were not used by the model.
#[derive(Default)]
pub struct WeightMap {
    /// Tensor name -> (flattened element data, shape).
    pub tensors: HashMap<String, (Vec<f32>, Vec<usize>)>,
}

impl WeightMap {
    pub fn from_file(path: impl AsRef<Path>) -> anyhow::Result<Self> {
        let bytes = fs::read(path)?;
        let st = SafeTensors::deserialize(&bytes)?;

        let mut tensors = HashMap::new();
        for name in st.names() {
            let t = st.tensor(name)?;
            let shape = t.shape().to_vec();
            let data = match t.dtype() {
                Dtype::F32 => bytemuck::cast_slice::<u8, f32>(t.data()).to_vec(),
                Dtype::F16 => {
                    let v = bytemuck::cast_slice::<u8, f16>(t.data());
                    v.iter().map(|x| x.to_f32()).collect()
                }
                Dtype::BF16 => {
                    let v = bytemuck::cast_slice::<u8, bf16>(t.data());
                    v.iter().map(|x| x.to_f32()).collect()
                }
                dt => anyhow::bail!("unsupported dtype for {name}: {dt:?}"),
            };
            tensors.insert(name.to_string(), (data, shape));
        }

        Ok(Self { tensors })
    }

    pub fn take<B: Backend, const N: usize>(
        &mut self,
        key: &str,
        device: &B::Device,
    ) -> anyhow::Result<Tensor<B, N>> {
        let (data, shape) = self
            .tensors
            .remove(key)
            .ok_or_else(|| anyhow::anyhow!("missing key: {key}"))?;
        if shape.len() != N {
            anyhow::bail!("rank mismatch for {key}: expected {N}, got {}", shape.len());
        }
        Ok(Tensor::from_data(TensorData::new(data, shape), device))
    }

    pub fn maybe_take<B: Backend, const N: usize>(
        &mut self,
        key: &str,
        device: &B::Device,
    ) -> Option<Tensor<B, N>> {
        self.take::<B, N>(key, device).ok()
    }
}

/// Overwrites a Burn [`burn::nn::Linear`] with a weight/bias pair taken from a PyTorch
/// checkpoint.
///
/// The weight is transposed before it is stored. PyTorch's `nn.Linear` keeps
/// `[out_features, in_features]`, while Burn's `Linear` weight is `[d_input, d_output]`.
/// If the Burn layer was built without a bias, `b_torch` is dropped.
fn set_linear_wb<B: Backend>(
    linear: &mut burn::nn::Linear<B>,
    w_torch: Tensor<B, 2>,
    b_torch: Tensor<B, 1>,
) {
    let burn_weight = w_torch.transpose();
    linear.weight = linear.weight.clone().map(move |_| burn_weight);
    linear.bias = linear.bias.take().map(|param| param.map(move |_| b_torch));
}

/// Overwrites a Burn [`burn::nn::conv::Conv1d`]'s kernel with `w` and, when the layer has
/// one, its bias with `b`.
///
/// The kernel is stored as given, without any transpose. This assumes the checkpoint layout
/// already matches Burn's `Conv1d` weight layout. If the conv has no bias, `b` is dropped.
fn set_conv1d_wb<B: Backend>(
    conv: &mut burn::nn::conv::Conv1d<B>,
    w: Tensor<B, 3>,
    b: Tensor<B, 1>,
) {
    let kernel = conv.weight.clone().map(move |_| w);
    let bias = conv.bias.take().map(|param| param.map(move |_| b));
    conv.weight = kernel;
    conv.bias = bias;
}

/// Overwrites a Burn [`burn::nn::LayerNorm`]'s scale (`gamma`) with `w` and, when the layer
/// has a shift parameter, its `beta` with `b`.
///
/// If the layer was built without `beta`, `b` is dropped.
fn set_layernorm<B: Backend>(ln: &mut burn::nn::LayerNorm<B>, w: Tensor<B, 1>, b: Tensor<B, 1>) {
    let scale = ln.gamma.clone();
    ln.gamma = scale.map(move |_| w);
    ln.beta = ln.beta.take().map(|shift| shift.map(move |_| b));
}

/// Overwrites every piece of state in a Burn [`burn::nn::BatchNorm`]: the learned
/// `gamma`/`beta` and the running mean/variance statistics.
fn set_batchnorm<B: Backend>(
    bn: &mut burn::nn::BatchNorm<B>,
    gamma: Tensor<B, 1>,
    beta: Tensor<B, 1>,
    running_mean: Tensor<B, 1>,
    running_var: Tensor<B, 1>,
) {
    // Learned affine parameters: reuse the existing `Param` wrappers and swap in new values.
    let (scale, shift) = (bn.gamma.clone(), bn.beta.clone());
    bn.gamma = scale.map(move |_| gamma);
    bn.beta = shift.map(move |_| beta);

    // Running statistics live in `RunningState` rather than `Param`, so rebuild them directly.
    bn.running_mean = RunningState::new(running_mean);
    bn.running_var = RunningState::new(running_var);
}

/// Loads the fused input projection and the output projection of a multi-head attention
/// module from checkpoint keys under `prefix`.
///
/// Expected keys are `{prefix}.in_proj_weight`/`{prefix}.in_proj_bias` and
/// `{prefix}.out_proj.weight`/`{prefix}.out_proj.bias`. Each pair is removed from `wm` when
/// present, but a projection is only updated when both its weight and bias were found.
/// Missing keys are skipped silently, so this currently always returns `Ok(())`.
fn load_mha<B: Backend>(
    wm: &mut WeightMap,
    mha: &mut FusedMultiheadAttention<B>,
    prefix: &str,
    device: &B::Device,
) -> anyhow::Result<()> {
    let projections = [
        ("in_proj_weight", "in_proj_bias", &mut mha.in_proj),
        ("out_proj.weight", "out_proj.bias", &mut mha.out_proj),
    ];

    for (weight_suffix, bias_suffix, projection) in projections {
        let weight = wm.maybe_take::<B, 2>(&format!("{prefix}.{weight_suffix}"), device);
        let bias = wm.maybe_take::<B, 1>(&format!("{prefix}.{bias_suffix}"), device);
        if let (Some(w), Some(b)) = (weight, bias) {
            set_linear_wb(projection, w, b);
        }
    }

    Ok(())
}

/// Loads one transformer encoder layer (self-attention, two feed-forward linears and two
/// layer norms) from checkpoint keys under `prefix`.
///
/// Each weight/bias pair is removed from `wm` when present, but a sub-layer is only updated
/// when both tensors of its pair were found; missing keys are skipped silently.
///
/// # Errors
///
/// Propagates any error from [`load_mha`].
fn load_transformer_layer<B: Backend>(
    wm: &mut WeightMap,
    layer: &mut TransformerEncoderLayer<B>,
    prefix: &str,
    device: &B::Device,
) -> anyhow::Result<()> {
    let attn_prefix = format!("{prefix}.self_attn");
    load_mha(wm, &mut layer.self_attn, &attn_prefix, device)?;

    for (name, linear) in [("linear1", &mut layer.linear1), ("linear2", &mut layer.linear2)] {
        let weight = wm.maybe_take::<B, 2>(&format!("{prefix}.{name}.weight"), device);
        let bias = wm.maybe_take::<B, 1>(&format!("{prefix}.{name}.bias"), device);
        if let (Some(w), Some(b)) = (weight, bias) {
            set_linear_wb(linear, w, b);
        }
    }

    for (name, norm) in [("norm1", &mut layer.norm1), ("norm2", &mut layer.norm2)] {
        let weight = wm.maybe_take::<B, 1>(&format!("{prefix}.{name}.weight"), device);
        let bias = wm.maybe_take::<B, 1>(&format!("{prefix}.{name}.bias"), device);
        if let (Some(w), Some(b)) = (weight, bias) {
            set_layernorm(norm, w, b);
        }
    }

    Ok(())
}

pub fn load_model<B: Backend>(
    cfg: &SeizureTransformerConfig,
    wm: &mut WeightMap,
    device: &B::Device,
) -> anyhow::Result<SeizureTransformer<B>> {
    let mut model = SeizureTransformer::new(cfg, device);

    for (i, conv) in model.encoder.convs.iter_mut().enumerate() {
        if let (Some(w), Some(b)) = (
            wm.maybe_take::<B, 3>(&format!("encoder.convs.{i}.weight"), device),
            wm.maybe_take::<B, 1>(&format!("encoder.convs.{i}.bias"), device),
        ) {
            set_conv1d_wb(conv, w, b);
        }
    }

    for (i, block) in model.res_cnn_stack.blocks.iter_mut().enumerate() {
        if let (Some(w), Some(b)) = (
            wm.maybe_take::<B, 3>(&format!("res_cnn_stack.members.{i}.conv1.weight"), device),
            wm.maybe_take::<B, 1>(&format!("res_cnn_stack.members.{i}.conv1.bias"), device),
        ) {
            set_conv1d_wb(&mut block.conv1, w, b);
        }
        if let (Some(w), Some(b)) = (
            wm.maybe_take::<B, 3>(&format!("res_cnn_stack.members.{i}.conv2.weight"), device),
            wm.maybe_take::<B, 1>(&format!("res_cnn_stack.members.{i}.conv2.bias"), device),
        ) {
            set_conv1d_wb(&mut block.conv2, w, b);
        }

        if let (Some(g), Some(be), Some(rm), Some(rv)) = (
            wm.maybe_take::<B, 1>(&format!("res_cnn_stack.members.{i}.norm1.weight"), device),
            wm.maybe_take::<B, 1>(&format!("res_cnn_stack.members.{i}.norm1.bias"), device),
            wm.maybe_take::<B, 1>(
                &format!("res_cnn_stack.members.{i}.norm1.running_mean"),
                device,
            ),
            wm.maybe_take::<B, 1>(
                &format!("res_cnn_stack.members.{i}.norm1.running_var"),
                device,
            ),
        ) {
            set_batchnorm(&mut block.norm1, g, be, rm, rv);
        }

        if let (Some(g), Some(be), Some(rm), Some(rv)) = (
            wm.maybe_take::<B, 1>(&format!("res_cnn_stack.members.{i}.norm2.weight"), device),
            wm.maybe_take::<B, 1>(&format!("res_cnn_stack.members.{i}.norm2.bias"), device),
            wm.maybe_take::<B, 1>(
                &format!("res_cnn_stack.members.{i}.norm2.running_mean"),
                device,
            ),
            wm.maybe_take::<B, 1>(
                &format!("res_cnn_stack.members.{i}.norm2.running_var"),
                device,
            ),
        ) {
            set_batchnorm(&mut block.norm2, g, be, rm, rv);
        }
    }

    for (i, layer) in model.transformer_encoder.iter_mut().enumerate() {
        load_transformer_layer(
            wm,
            layer,
            &format!("transformer_encoder.layers.{i}"),
            device,
        )?;
    }

    for (i, conv) in model.decoder_d.convs.iter_mut().enumerate() {
        if let (Some(w), Some(b)) = (
            wm.maybe_take::<B, 3>(&format!("decoder_d.convs.{i}.weight"), device),
            wm.maybe_take::<B, 1>(&format!("decoder_d.convs.{i}.bias"), device),
        ) {
            set_conv1d_wb(conv, w, b);
        }
    }

    if let (Some(w), Some(b)) = (
        wm.maybe_take::<B, 3>("conv_d.weight", device),
        wm.maybe_take::<B, 1>("conv_d.bias", device),
    ) {
        set_conv1d_wb(&mut model.conv_d, w, b);
    }

    Ok(model)
}

/// Reads a SafeTensors checkpoint from `path` and loads it into a freshly built
/// [`SeizureTransformer`] (see [`load_model`]).
///
/// # Errors
///
/// Fails with the file path attached as context if the checkpoint cannot be read or decoded,
/// and otherwise propagates errors from [`load_model`].
pub fn load_model_from_file<B: Backend>(
    cfg: &SeizureTransformerConfig,
    path: impl AsRef<Path>,
    device: &B::Device,
) -> anyhow::Result<SeizureTransformer<B>> {
    let path = path.as_ref();
    let mut weights = WeightMap::from_file(path)
        .with_context(|| format!("failed loading safetensors from {}", path.display()))?;
    load_model(cfg, &mut weights, device)
}