use burn::backend::NdArray as B;
use burn::prelude::*;
use std::collections::HashMap;
/// Loads the reference tensors exported by the Python parity script from
/// `/tmp/reve_parity.safetensors`.
///
/// Returns a map from tensor name to `(flat f32 data, shape)`, or `None`
/// (after printing a skip notice) when the fixture has not been generated,
/// so the parity test can skip instead of fail.
fn load_parity_data(
    // Unused today; kept so the call signature matches the other helpers.
    _device: &burn::backend::ndarray::NdArrayDevice,
) -> Option<HashMap<String, (Vec<f32>, Vec<usize>)>> {
    let path = "/tmp/reve_parity.safetensors";
    if !std::path::Path::new(path).exists() {
        eprintln!("Skipping parity test: {path} not found");
        eprintln!("Generate it by running the Python parity script first.");
        return None;
    }
    // Past this point a failure is a real error (unreadable/corrupt file),
    // not a "skip" condition — panic with context naming the file.
    let bytes = std::fs::read(path)
        .unwrap_or_else(|e| panic!("failed to read {path}: {e}"));
    let st = safetensors::SafeTensors::deserialize(&bytes)
        .unwrap_or_else(|e| panic!("failed to parse {path} as safetensors: {e}"));
    let mut tensors = HashMap::new();
    for (key, view) in st.tensors() {
        let shape: Vec<usize> = view.shape().to_vec();
        // safetensors hands back raw bytes; reinterpret as little-endian f32.
        // Assumes every exported tensor has f32 dtype — TODO confirm against
        // the Python export script.
        let f32s: Vec<f32> = view
            .data()
            .chunks_exact(4)
            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
            .collect();
        tensors.insert(key.to_string(), (f32s, shape));
    }
    Some(tensors)
}
/// End-to-end parity test: runs the Rust model on the inputs exported by the
/// Python parity script and checks the outputs agree within tolerance.
#[test]
fn test_python_parity() {
    let device = burn::backend::ndarray::NdArrayDevice::Cpu;
    // Skip (not fail) when the Python-generated fixture is absent.
    let data = match load_parity_data(&device) {
        Some(d) => d,
        None => return,
    };
    let (eeg_data, eeg_shape) = &data["_input_eeg"];
    let (pos_data, pos_shape) = &data["_input_pos"];
    let (out_data, out_shape) = &data["_output"];
    eprintln!("EEG shape: {:?}", eeg_shape);
    eprintln!("Pos shape: {:?}", pos_shape);
    eprintln!("Expected output shape: {:?}", out_shape);
    eprintln!("Expected output: {:?}", &out_data[..]);
    // EEG is indexed as [batch, channels, times]; output as [batch, outputs].
    let n_chans = eeg_shape[1];
    let n_times = eeg_shape[2];
    let n_outputs = out_shape[1];
    // Hyperparameters (dim=512, depth=2, heads=8, ...) must mirror the model
    // config used by the Python parity script — NOTE(review): a silent
    // mismatch here shows up as a numeric diff, not a shape error; keep the
    // two sides in sync.
    let mut model = reve_rs::model::reve::Reve::<B>::new(
        n_outputs,
        n_chans,
        n_times,
        512, 2, 8, 64, 2.66, true, 4, 200, 20, false, &device,
    );
    load_parity_weights(&data, &mut model, &device);
    let eeg = Tensor::<B, 3>::from_data(
        TensorData::new(eeg_data.clone(), eeg_shape.clone()),
        &device,
    );
    let pos = Tensor::<B, 3>::from_data(
        TensorData::new(pos_data.clone(), pos_shape.clone()),
        &device,
    );
    let output = model.forward(eeg, pos);
    let output_vec = output.into_data().to_vec::<f32>().unwrap();
    eprintln!("Rust output: {:?}", &output_vec[..]);
    // zip() silently truncates to the shorter side; without this check a
    // shape bug could masquerade as a partial (possibly passing) comparison.
    assert_eq!(
        output_vec.len(),
        out_data.len(),
        "output length mismatch: rust={}, expected={}",
        output_vec.len(),
        out_data.len(),
    );
    let max_diff: f32 = output_vec
        .iter()
        .zip(out_data.iter())
        .map(|(a, b)| (a - b).abs())
        .fold(0.0f32, f32::max);
    eprintln!("Max absolute difference: {:.6e}", max_diff);
    assert!(
        max_diff < 0.01,
        "Output difference too large: {:.6e}",
        max_diff
    );
}
/// Copies the PyTorch weights from the parity fixture into `model`, mapping
/// the Python module names (e.g. `transformer.layers.{i}.0.to_qkv.weight`)
/// onto the corresponding Rust fields.
///
/// Linear weights are `.transpose()`d on the way in: the exported tensors use
/// PyTorch's layout, which is transposed relative to what this model stores
/// (presumably (out, in) vs (in, out) — confirm against the model definition).
/// LayerNorm gammas/betas and biases are copied as-is.
///
/// Panics if an expected key is missing from `data` (indexing into the map).
fn load_parity_weights(
    data: &HashMap<String, (Vec<f32>, Vec<usize>)>,
    model: &mut reve_rs::model::reve::Reve<B>,
    device: &burn::backend::ndarray::NdArrayDevice,
) {
    /// Fetches `key` from the fixture as a rank-1 tensor.
    fn t1(
        data: &HashMap<String, (Vec<f32>, Vec<usize>)>,
        key: &str,
        device: &burn::backend::ndarray::NdArrayDevice,
    ) -> Tensor<B, 1> {
        let (d, s) = &data[key];
        Tensor::from_data(TensorData::new(d.clone(), s.clone()), device)
    }
    /// Fetches `key` from the fixture as a rank-2 tensor.
    fn t2(
        data: &HashMap<String, (Vec<f32>, Vec<usize>)>,
        key: &str,
        device: &burn::backend::ndarray::NdArrayDevice,
    ) -> Tensor<B, 2> {
        let (d, s) = &data[key];
        Tensor::from_data(TensorData::new(d.clone(), s.clone()), device)
    }
    // Patch embedding: Python `to_patch_embedding.0` (Linear with bias).
    let w = t2(data, "to_patch_embedding.0.weight", device);
    let b = t1(data, "to_patch_embedding.0.bias", device);
    model.patch_embed.weight = model.patch_embed.weight.clone().map(|_| w.transpose());
    if let Some(ref bias) = model.patch_embed.bias {
        model.patch_embed.bias = Some(bias.clone().map(|_| b));
    }
    // Positional MLP: `mlp4d.0` is a bias-free Linear, `mlp4d.2` a LayerNorm.
    let w = t2(data, "mlp4d.0.weight", device);
    model.mlp4d_linear.weight = model.mlp4d_linear.weight.clone().map(|_| w.transpose());
    let w = t1(data, "mlp4d.2.weight", device);
    let b = t1(data, "mlp4d.2.bias", device);
    model.mlp4d_ln.gamma = model.mlp4d_ln.gamma.clone().map(|_| w);
    if let Some(ref beta) = model.mlp4d_ln.beta {
        model.mlp4d_ln.beta = Some(beta.clone().map(|_| b));
    }
    // Post-positional LayerNorm: Python `ln`.
    let w = t1(data, "ln.weight", device);
    let b = t1(data, "ln.bias", device);
    model.pos_ln.gamma = model.pos_ln.gamma.clone().map(|_| w);
    if let Some(ref beta) = model.pos_ln.beta {
        model.pos_ln.beta = Some(beta.clone().map(|_| b));
    }
    // Transformer blocks. Iterate the model's actual depth rather than a
    // hard-coded count so this stays correct if the test's depth changes.
    for i in 0..model.transformer.layers.len() {
        let block = &mut model.transformer.layers[i];
        // Attention sub-block: `.0` = pre-norm + attention (bias-free Linears).
        let w = t1(data, &format!("transformer.layers.{i}.0.norm.weight"), device);
        block.attn.norm.weight = block.attn.norm.weight.clone().map(|_| w);
        let w = t2(
            data,
            &format!("transformer.layers.{i}.0.to_qkv.weight"),
            device,
        );
        block.attn.to_qkv.weight = block.attn.to_qkv.weight.clone().map(|_| w.transpose());
        let w = t2(
            data,
            &format!("transformer.layers.{i}.0.to_out.weight"),
            device,
        );
        block.attn.to_out.weight = block.attn.to_out.weight.clone().map(|_| w.transpose());
        // Feed-forward sub-block: `.1.net.0` = norm, `.1.net.{1,3}` = Linears.
        let w = t1(
            data,
            &format!("transformer.layers.{i}.1.net.0.weight"),
            device,
        );
        block.ff.norm.weight = block.ff.norm.weight.clone().map(|_| w);
        let w = t2(
            data,
            &format!("transformer.layers.{i}.1.net.1.weight"),
            device,
        );
        block.ff.linear1.weight = block.ff.linear1.weight.clone().map(|_| w.transpose());
        let w = t2(
            data,
            &format!("transformer.layers.{i}.1.net.3.weight"),
            device,
        );
        block.ff.linear2.weight = block.ff.linear2.weight.clone().map(|_| w.transpose());
    }
    // Head: `final_layer.1` = LayerNorm, `final_layer.2` = Linear with bias.
    let w = t1(data, "final_layer.1.weight", device);
    let b = t1(data, "final_layer.1.bias", device);
    model.final_ln.gamma = model.final_ln.gamma.clone().map(|_| w);
    if let Some(ref beta) = model.final_ln.beta {
        model.final_ln.beta = Some(beta.clone().map(|_| b));
    }
    let w = t2(data, "final_layer.2.weight", device);
    let b = t1(data, "final_layer.2.bias", device);
    model.final_linear.weight = model.final_linear.weight.clone().map(|_| w.transpose());
    if let Some(ref bias) = model.final_linear.bias {
        model.final_linear.bias = Some(bias.clone().map(|_| b));
    }
}