use burn::backend::NdArray as B;
use burn::prelude::*;
use std::collections::HashMap;
fn load_data() -> Option<HashMap<String, (Vec<f32>, Vec<usize>)>> {
let path = "/tmp/eegpt_parity.safetensors";
if !std::path::Path::new(path).exists() {
eprintln!("Skipping: {path} not found"); return None;
}
let bytes = std::fs::read(path).unwrap();
let st = safetensors::SafeTensors::deserialize(&bytes).unwrap();
let mut m = HashMap::new();
for (k, v) in st.tensors() {
let shape: Vec<usize> = v.shape().to_vec();
let f32s: Vec<f32> = v.data().chunks_exact(4)
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]])).collect();
m.insert(k.to_string(), (f32s, shape));
}
Some(m)
}
/// Generic fixture loader: fetch key `k` from the fixture map and build a
/// rank-`D` float tensor on `dev`. Panics if the key is missing — a missing
/// weight is a test-setup bug, not a recoverable condition.
fn t<const D: usize>(
    d: &HashMap<String, (Vec<f32>, Vec<usize>)>,
    k: &str,
    dev: &burn::backend::ndarray::NdArrayDevice,
) -> Tensor<B, D> {
    let (v, s) = &d[k];
    Tensor::from_data(TensorData::new(v.clone(), s.clone()), dev)
}
// Rank-specific wrappers kept so the macros and test body read naturally
// (and to preserve the existing call sites unchanged).
fn t1(d: &HashMap<String, (Vec<f32>, Vec<usize>)>, k: &str, dev: &burn::backend::ndarray::NdArrayDevice) -> Tensor<B, 1> { t::<1>(d, k, dev) }
fn t2(d: &HashMap<String, (Vec<f32>, Vec<usize>)>, k: &str, dev: &burn::backend::ndarray::NdArrayDevice) -> Tensor<B, 2> { t::<2>(d, k, dev) }
fn t3(d: &HashMap<String, (Vec<f32>, Vec<usize>)>, k: &str, dev: &burn::backend::ndarray::NdArrayDevice) -> Tensor<B, 3> { t::<3>(d, k, dev) }
fn t4(d: &HashMap<String, (Vec<f32>, Vec<usize>)>, k: &str, dev: &burn::backend::ndarray::NdArrayDevice) -> Tensor<B, 4> { t::<4>(d, k, dev) }
/// Copies a fixture weight/bias pair into a burn `Linear` layer `$l`.
///
/// `$wk`/`$bk` are the fixture keys for the weight matrix and bias vector.
/// The weight is transposed before assignment — presumably because the
/// fixture stores PyTorch's `[out, in]` layout while burn's `Linear`
/// expects `[in, out]`; confirm against the model definition.
/// The bias is only written when the layer actually has one.
macro_rules! set_linear_wb {
($d:expr, $l:expr, $wk:expr, $bk:expr, $dev:expr) => {
let w = t2($d, $wk, $dev);
let b = t1($d, $bk, $dev);
// `.map(|_| …)` replaces the param's tensor while keeping its param identity.
$l.weight = $l.weight.clone().map(|_| w.transpose());
if let Some(ref bias) = $l.bias { $l.bias = Some(bias.clone().map(|_| b)); }
};
}
/// Copies fixture scale/shift vectors into a burn LayerNorm-style module `$n`.
///
/// `$wk` maps to `gamma` (scale) and `$bk` to `beta` (shift). `beta` is only
/// written when the norm actually carries one.
macro_rules! set_ln {
($d:expr, $n:expr, $wk:expr, $bk:expr, $dev:expr) => {
let w = t1($d, $wk, $dev);
let b = t1($d, $bk, $dev);
// `.map(|_| …)` replaces the param's tensor while keeping its param identity.
$n.gamma = $n.gamma.clone().map(|_| w);
if let Some(ref beta) = $n.beta { $n.beta = Some(beta.clone().map(|_| b)); }
};
}
/// End-to-end parity test: loads weights exported from the Python EEGPT
/// implementation, runs the Rust model on the same input, and checks the
/// outputs agree element-wise. Skips (does not fail) when the fixture file
/// is absent.
#[test]
fn test_python_parity() {
    let dev = burn::backend::ndarray::NdArrayDevice::Cpu;
    let data = match load_data() { Some(d) => d, None => return };
    let (inp_data, inp_shape) = &data["_input"];
    let (out_data, _) = &data["_output"];
    // Input is (batch, channels, times); channel/time counts drive model size.
    let n_chans = inp_shape[1];
    let n_times = inp_shape[2];
    // Hyper-parameters must mirror the Python export script exactly.
    let mut model = eegpt_rs::model::eegpt::EEGPT::<B>::new(
        4, n_chans, n_times, 64, 32, 4, 512, 2, 8, 4.0, true, 62, 16, 1e-6, &dev,
    );
    // --- Target encoder: summary token, patch embed, channel embed ---
    let te = &mut model.target_encoder;
    let st = t3(&data, "target_encoder.summary_token", &dev);
    te.summary_token = te.summary_token.clone().map(|_| st);
    let w = t4(&data, "target_encoder.patch_embed.proj.weight", &dev);
    let b = t1(&data, "target_encoder.patch_embed.proj.bias", &dev);
    te.patch_embed.proj.weight = te.patch_embed.proj.weight.clone().map(|_| w);
    if let Some(ref bias) = te.patch_embed.proj.bias {
        te.patch_embed.proj.bias = Some(bias.clone().map(|_| b));
    }
    let w = t2(&data, "target_encoder.chan_embed.weight", &dev);
    te.chan_embed.weight = te.chan_embed.weight.clone().map(|_| w);
    // --- Transformer blocks ---
    for i in 0..2 {
        let block = &mut te.blocks[i];
        let p = format!("target_encoder.blocks.{i}");
        set_ln!(&data, block.norm1, &format!("{p}.norm1.weight"), &format!("{p}.norm1.bias"), &dev);
        set_linear_wb!(&data, block.attn.qkv, &format!("{p}.attn.qkv.weight"), &format!("{p}.attn.qkv.bias"), &dev);
        set_linear_wb!(&data, block.attn.proj, &format!("{p}.attn.proj.weight"), &format!("{p}.attn.proj.bias"), &dev);
        set_ln!(&data, block.norm2, &format!("{p}.norm2.weight"), &format!("{p}.norm2.bias"), &dev);
        set_linear_wb!(&data, block.mlp_fc1, &format!("{p}.mlp.fc1.weight"), &format!("{p}.mlp.fc1.bias"), &dev);
        set_linear_wb!(&data, block.mlp_fc2, &format!("{p}.mlp.fc2.weight"), &format!("{p}.mlp.fc2.bias"), &dev);
    }
    set_ln!(&data, te.norm, "target_encoder.norm.weight", "target_encoder.norm.bias", &dev);
    set_linear_wb!(&data, model.probe1, "final_layer.probe1.weight", "final_layer.probe1.bias", &dev);
    set_linear_wb!(&data, model.probe2, "final_layer.probe2.weight", "final_layer.probe2.bias", &dev);
    // `chans_id` is stored as i64, which `load_data` only yields as f32 data,
    // so re-read the file and decode this one tensor with its native dtype.
    let chan_ids = {
        let path = "/tmp/eegpt_parity.safetensors";
        let bytes = std::fs::read(path).unwrap();
        let st = safetensors::SafeTensors::deserialize(&bytes).unwrap();
        if let Ok(view) = st.tensor("chans_id") {
            let ids: Vec<i64> = view.data().chunks_exact(8)
                .map(|b| i64::from_le_bytes([b[0],b[1],b[2],b[3],b[4],b[5],b[6],b[7]]))
                .collect();
            let n = ids.len();
            Tensor::<B, 2, Int>::from_data(
                TensorData::new(ids, vec![1, n]), &dev,
            )
        } else {
            // Fixture without channel ids: fall back to the identity mapping.
            Tensor::<B, 2, Int>::from_data(
                TensorData::new((0..n_chans as i64).collect::<Vec<_>>(), vec![1, n_chans]), &dev,
            )
        }
    };
    let input = Tensor::<B, 3>::from_data(TensorData::new(inp_data.clone(), inp_shape.clone()), &dev);
    let output = model.forward(input, chan_ids);
    let out_vec = output.into_data().to_vec::<f32>().unwrap();
    eprintln!("Expected: {:?}", out_data);
    eprintln!("Got: {:?}", out_vec);
    // zip() silently truncates on length mismatch, which could mask a wrong
    // output shape — check the lengths explicitly first.
    assert_eq!(out_vec.len(), out_data.len(), "output length mismatch");
    // NaNs must fail loudly: f32::max returns the non-NaN operand, so a plain
    // max-fold would silently skip NaN diffs and let a broken model pass.
    assert!(
        out_vec.iter().all(|v| v.is_finite()),
        "output contains non-finite values"
    );
    let max_diff: f32 = out_vec.iter().zip(out_data.iter())
        .map(|(a, b)| (a - b).abs()).fold(0.0f32, f32::max);
    eprintln!("Max diff: {:.6e}", max_diff);
    assert!(max_diff < 0.01, "Parity failed: max_diff={:.6e}", max_diff);
}