//! Python parity test for EEGPT (eegpt-rs 0.0.1).
//!
//! EEGPT EEG foundation model — inference in Rust with Burn ML.
//! Requires `/tmp/eegpt_parity.safetensors`, generated by the Python side.

use burn::backend::NdArray as B;
use burn::prelude::*;
use std::collections::HashMap;

fn load_data() -> Option<HashMap<String, (Vec<f32>, Vec<usize>)>> {
    let path = "/tmp/eegpt_parity.safetensors";
    if !std::path::Path::new(path).exists() {
        eprintln!("Skipping: {path} not found"); return None;
    }
    let bytes = std::fs::read(path).unwrap();
    let st = safetensors::SafeTensors::deserialize(&bytes).unwrap();
    let mut m = HashMap::new();
    for (k, v) in st.tensors() {
        let shape: Vec<usize> = v.shape().to_vec();
        let f32s: Vec<f32> = v.data().chunks_exact(4)
            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]])).collect();
        m.insert(k.to_string(), (f32s, shape));
    }
    Some(m)
}

/// Builds a rank-`D` float tensor from the named fixture entry.
///
/// Shared implementation behind [`t1`]..[`t4`]; panics with the offending
/// key name when the tensor is missing, instead of the anonymous panic a
/// bare map index would give.
fn tn<const D: usize>(
    d: &HashMap<String, (Vec<f32>, Vec<usize>)>,
    k: &str,
    dev: &burn::backend::ndarray::NdArrayDevice,
) -> Tensor<B, D> {
    let (v, s) = d.get(k).unwrap_or_else(|| panic!("missing tensor: {k}"));
    Tensor::from_data(TensorData::new(v.clone(), s.clone()), dev)
}

/// Rank-1 tensor from the fixture map.
fn t1(d: &HashMap<String, (Vec<f32>, Vec<usize>)>, k: &str, dev: &burn::backend::ndarray::NdArrayDevice) -> Tensor<B, 1> {
    tn::<1>(d, k, dev)
}
/// Rank-2 tensor from the fixture map.
fn t2(d: &HashMap<String, (Vec<f32>, Vec<usize>)>, k: &str, dev: &burn::backend::ndarray::NdArrayDevice) -> Tensor<B, 2> {
    tn::<2>(d, k, dev)
}
/// Rank-3 tensor from the fixture map.
fn t3(d: &HashMap<String, (Vec<f32>, Vec<usize>)>, k: &str, dev: &burn::backend::ndarray::NdArrayDevice) -> Tensor<B, 3> {
    tn::<3>(d, k, dev)
}
/// Rank-4 tensor from the fixture map.
fn t4(d: &HashMap<String, (Vec<f32>, Vec<usize>)>, k: &str, dev: &burn::backend::ndarray::NdArrayDevice) -> Tensor<B, 4> {
    tn::<4>(d, k, dev)
}

/// Copies a (weight, bias) pair from the fixture map into a Linear layer.
/// The weight is transposed on the way in — the exported layout appears to
/// be (out, in) while the layer stores (in, out); confirmed by the parity
/// assertion at the end of the test.
macro_rules! set_linear_wb {
    ($d:expr, $l:expr, $wk:expr, $bk:expr, $dev:expr) => {{
        let new_weight = t2($d, $wk, $dev);
        let new_bias = t1($d, $bk, $dev);
        $l.weight = $l.weight.clone().map(|_| new_weight.transpose());
        if let Some(ref old_bias) = $l.bias {
            $l.bias = Some(old_bias.clone().map(|_| new_bias));
        }
    }};
}

/// Copies LayerNorm scale/shift (gamma/beta) from the fixture map into a
/// norm layer; beta is only written when the layer actually carries one.
macro_rules! set_ln {
    ($d:expr, $n:expr, $wk:expr, $bk:expr, $dev:expr) => {{
        let new_gamma = t1($d, $wk, $dev);
        let new_beta = t1($d, $bk, $dev);
        $n.gamma = $n.gamma.clone().map(|_| new_gamma);
        if let Some(ref old_beta) = $n.beta {
            $n.beta = Some(old_beta.clone().map(|_| new_beta));
        }
    }};
}

/// End-to-end parity check: builds the Rust EEGPT, loads the weights exported
/// from Python, runs one forward pass on the exported input, and asserts the
/// output matches the exported output within 1e-2 (max absolute difference).
/// Skips silently when the fixture file is absent.
#[test]
fn test_python_parity() {
    let dev = burn::backend::ndarray::NdArrayDevice::Cpu;
    let data = match load_data() { Some(d) => d, None => return };

    // Reference input/output pair recorded by the Python exporter.
    let (inp_data, inp_shape) = &data["_input"];
    let (out_data, _) = &data["_output"];
    // Input layout assumed (batch, channels, times) — matches the Tensor<B,3>
    // construction below; confirm against the Python export script.
    let n_chans = inp_shape[1];
    let n_times = inp_shape[2];

    // Hyperparameters must mirror the Python model config exactly;
    // any mismatch shows up as a shape panic or a parity failure.
    let mut model = eegpt_rs::model::eegpt::EEGPT::<B>::new(
        4, n_chans, n_times, 64, 32, 4, 512, 2, 8, 4.0, true, 62, 16, 1e-6, &dev,
    );

    // Load weights
    let te = &mut model.target_encoder;

    // Summary token
    let st = t3(&data, "target_encoder.summary_token", &dev);
    te.summary_token = te.summary_token.clone().map(|_| st);

    // Patch embed conv
    let w = t4(&data, "target_encoder.patch_embed.proj.weight", &dev);
    let b = t1(&data, "target_encoder.patch_embed.proj.bias", &dev);
    te.patch_embed.proj.weight = te.patch_embed.proj.weight.clone().map(|_| w);
    if let Some(ref bias) = te.patch_embed.proj.bias {
        te.patch_embed.proj.bias = Some(bias.clone().map(|_| b));
    }

    // Channel embedding (no transpose: embedding tables share the same
    // (num_embeddings, dim) layout in both frameworks)
    let w = t2(&data, "target_encoder.chan_embed.weight", &dev);
    te.chan_embed.weight = te.chan_embed.weight.clone().map(|_| w);

    // Blocks — the loop bound 2 matches the depth passed to EEGPT::new above.
    for i in 0..2 {
        let block = &mut te.blocks[i];
        let p = format!("target_encoder.blocks.{i}");
        set_ln!(&data, block.norm1, &format!("{p}.norm1.weight"), &format!("{p}.norm1.bias"), &dev);
        set_linear_wb!(&data, block.attn.qkv, &format!("{p}.attn.qkv.weight"), &format!("{p}.attn.qkv.bias"), &dev);
        set_linear_wb!(&data, block.attn.proj, &format!("{p}.attn.proj.weight"), &format!("{p}.attn.proj.bias"), &dev);
        set_ln!(&data, block.norm2, &format!("{p}.norm2.weight"), &format!("{p}.norm2.bias"), &dev);
        set_linear_wb!(&data, block.mlp_fc1, &format!("{p}.mlp.fc1.weight"), &format!("{p}.mlp.fc1.bias"), &dev);
        set_linear_wb!(&data, block.mlp_fc2, &format!("{p}.mlp.fc2.weight"), &format!("{p}.mlp.fc2.bias"), &dev);
    }

    // Norm
    set_ln!(&data, te.norm, "target_encoder.norm.weight", "target_encoder.norm.bias", &dev);

    // Probe
    set_linear_wb!(&data, model.probe1, "final_layer.probe1.weight", "final_layer.probe1.bias", &dev);
    set_linear_wb!(&data, model.probe2, "final_layer.probe2.weight", "final_layer.probe2.bias", &dev);

    // Load chans_id — stored as int64 in safetensors, need special handling
    // (load_data decodes f32 only, so the file is re-read and re-parsed here
    // to pull the raw int64 bytes out directly).
    let chan_ids = {
        let path = "/tmp/eegpt_parity.safetensors";
        let bytes = std::fs::read(path).unwrap();
        let st = safetensors::SafeTensors::deserialize(&bytes).unwrap();
        if let Ok(view) = st.tensor("chans_id") {
            let ids: Vec<i64> = view.data().chunks_exact(8)
                .map(|b| i64::from_le_bytes([b[0],b[1],b[2],b[3],b[4],b[5],b[6],b[7]]))
                .collect();
            let n = ids.len();
            Tensor::<B, 2, Int>::from_data(
                TensorData::new(ids, vec![1, n]), &dev,
            )
        } else {
            // Fixture predates chans_id export: fall back to identity ids 0..n_chans.
            Tensor::<B, 2, Int>::from_data(
                TensorData::new((0..n_chans as i64).collect::<Vec<_>>(), vec![1, n_chans]), &dev,
            )
        }
    };

    // Run forward
    let input = Tensor::<B, 3>::from_data(TensorData::new(inp_data.clone(), inp_shape.clone()), &dev);
    let output = model.forward(input, chan_ids);
    let out_vec = output.into_data().to_vec::<f32>().unwrap();

    eprintln!("Expected: {:?}", out_data);
    eprintln!("Got:      {:?}", out_vec);

    // Max absolute elementwise difference; 1e-2 tolerance absorbs f32
    // accumulation-order differences between PyTorch and Burn/NdArray.
    let max_diff: f32 = out_vec.iter().zip(out_data.iter())
        .map(|(a, b)| (a - b).abs()).fold(0.0f32, f32::max);
    eprintln!("Max diff: {:.6e}", max_diff);
    assert!(max_diff < 0.01, "Parity failed: max_diff={:.6e}", max_diff);
}