//! reve-rs 0.0.1: REVE EEG Foundation Model, inference in Rust with the Burn ML framework.
//!
//! Python parity test: load weights and inputs exported from Python and compare the outputs.
//!
//! Run: `cargo test --test python_parity -- --nocapture`
//!
//! Requires `/tmp/reve_parity.safetensors`, generated by the Python parity script.
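//!
//! The file is expected to contain the test inputs and reference output under
//! the keys `_input_eeg`, `_input_pos`, and `_output`, plus the model weights
//! under their PyTorch `state_dict` names (e.g. `to_patch_embedding.0.weight`),
//! all stored as little-endian f32.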

use burn::backend::NdArray as B;
use burn::prelude::*;
use std::collections::HashMap;

fn load_parity_data() -> Option<HashMap<String, (Vec<f32>, Vec<usize>)>> {
    let path = "/tmp/reve_parity.safetensors";
    if !std::path::Path::new(path).exists() {
        eprintln!("Skipping parity test: {path} not found");
        eprintln!("Generate it by running the Python parity script first.");
        return None;
    }

    let bytes = std::fs::read(path).expect("failed to read parity file");
    let st = safetensors::SafeTensors::deserialize(&bytes).expect("invalid safetensors file");
    let mut tensors = HashMap::new();

    for (key, view) in st.tensors() {
        // The Python export is assumed to store every tensor as little-endian f32.
        assert_eq!(view.dtype(), safetensors::Dtype::F32, "unexpected dtype for {key}");
        let shape: Vec<usize> = view.shape().to_vec();
        let f32s: Vec<f32> = view
            .data()
            .chunks_exact(4)
            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
            .collect();
        tensors.insert(key.to_string(), (f32s, shape));
    }

    Some(tensors)
}
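
/// Debug helper (a small sketch, not used by the test itself): lists the tensor
/// names stored in the parity file; handy when a key lookup below panics.
#[allow(dead_code)]
fn list_parity_keys() {
    let bytes = std::fs::read("/tmp/reve_parity.safetensors").expect("failed to read parity file");
    let st = safetensors::SafeTensors::deserialize(&bytes).expect("invalid safetensors file");
    for name in st.names() {
        eprintln!("{name}");
    }
}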

#[test]
fn test_python_parity() {
    let device = burn::backend::ndarray::NdArrayDevice::Cpu;

    let data = match load_parity_data() {
        Some(d) => d,
        None => return,
    };

    // Load input
    let (eeg_data, eeg_shape) = &data["_input_eeg"];
    let (pos_data, pos_shape) = &data["_input_pos"];
    let (out_data, out_shape) = &data["_output"];

    eprintln!("EEG shape: {:?}", eeg_shape);
    eprintln!("Pos shape: {:?}", pos_shape);
    eprintln!("Expected output shape: {:?}", out_shape);
    eprintln!("Expected output: {:?}", &out_data[..]);

    let n_chans = eeg_shape[1];
    let n_times = eeg_shape[2];
    let n_outputs = out_shape[1];

    // Build model with same config as Python
    let mut model = reve_rs::model::reve::Reve::<B>::new(
        n_outputs,
        n_chans,
        n_times,
        512,   // embed_dim
        2,     // depth
        8,     // heads
        64,    // head_dim
        2.66,  // mlp_dim_ratio
        true,  // geglu
        4,     // freqs
        200,   // patch_size
        20,    // patch_overlap
        false, // no attention pooling
        &device,
    );

    // Load weights from parity data
    load_parity_weights(&data, &mut model, &device);

    // Build input tensors
    let eeg = Tensor::<B, 3>::from_data(
        TensorData::new(eeg_data.clone(), eeg_shape.clone()),
        &device,
    );
    let pos = Tensor::<B, 3>::from_data(
        TensorData::new(pos_data.clone(), pos_shape.clone()),
        &device,
    );

    // Forward pass
    let output = model.forward(eeg, pos);
    let output_vec = output.into_data().to_vec::<f32>().unwrap();

    eprintln!("Rust output: {:?}", &output_vec[..]);

    // Compare
    let max_diff: f32 = output_vec
        .iter()
        .zip(out_data.iter())
        .map(|(a, b)| (a - b).abs())
        .fold(0.0f32, f32::max);

    eprintln!("Max absolute difference: {:.6e}", max_diff);

    // Allow some tolerance: f32 accumulation order differs between PyTorch and
    // the NdArray backend, so bit-exact agreement is not expected.
    assert!(
        max_diff < 0.01,
        "Output difference too large: {:.6e}",
        max_diff
    );
}

fn load_parity_weights(
    data: &HashMap<String, (Vec<f32>, Vec<usize>)>,
    model: &mut reve_rs::model::reve::Reve<B>,
    device: &burn::backend::ndarray::NdArrayDevice,
) {
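    // Burn wraps learnable tensors in `Param`; each weight is replaced by
    // rebuilding the parameter with `Param::map`, which swaps in the new
    // tensor value on the existing parameter.
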
    // Helpers to load rank-1 and rank-2 tensors by key.
    fn t1(
        data: &HashMap<String, (Vec<f32>, Vec<usize>)>,
        key: &str,
        device: &burn::backend::ndarray::NdArrayDevice,
    ) -> Tensor<B, 1> {
        let (d, s) = &data[key];
        Tensor::from_data(TensorData::new(d.clone(), s.clone()), device)
    }
    fn t2(
        data: &HashMap<String, (Vec<f32>, Vec<usize>)>,
        key: &str,
        device: &burn::backend::ndarray::NdArrayDevice,
    ) -> Tensor<B, 2> {
        let (d, s) = &data[key];
        Tensor::from_data(TensorData::new(d.clone(), s.clone()), device)
    }

    // Patch embedding. PyTorch stores Linear weights as [out, in]; Burn
    // expects [in, out], hence the transpose (same for every Linear below).
    let w = t2(data, "to_patch_embedding.0.weight", device);
    let b = t1(data, "to_patch_embedding.0.bias", device);
    model.patch_embed.weight = model.patch_embed.weight.clone().map(|_| w.transpose());
    if let Some(ref bias) = model.patch_embed.bias {
        model.patch_embed.bias = Some(bias.clone().map(|_| b));
    }

    // MLP4D linear (no bias)
    let w = t2(data, "mlp4d.0.weight", device);
    model.mlp4d_linear.weight = model.mlp4d_linear.weight.clone().map(|_| w.transpose());

    // MLP4D LayerNorm
    let w = t1(data, "mlp4d.2.weight", device);
    let b = t1(data, "mlp4d.2.bias", device);
    model.mlp4d_ln.gamma = model.mlp4d_ln.gamma.clone().map(|_| w);
    if let Some(ref beta) = model.mlp4d_ln.beta {
        model.mlp4d_ln.beta = Some(beta.clone().map(|_| b));
    }

    // 4DPE LayerNorm (ln)
    let w = t1(data, "ln.weight", device);
    let b = t1(data, "ln.bias", device);
    model.pos_ln.gamma = model.pos_ln.gamma.clone().map(|_| w);
    if let Some(ref beta) = model.pos_ln.beta {
        model.pos_ln.beta = Some(beta.clone().map(|_| b));
    }

    // Transformer layers (depth = 2, matching the config above)
    for i in 0..2 {
        let block = &mut model.transformer.layers[i];

        // Attention norm (RMSNorm)
        let w = t1(data, &format!("transformer.layers.{i}.0.norm.weight"), device);
        block.attn.norm.weight = block.attn.norm.weight.clone().map(|_| w);

        // Attention to_qkv (no bias)
        let w = t2(
            data,
            &format!("transformer.layers.{i}.0.to_qkv.weight"),
            device,
        );
        block.attn.to_qkv.weight = block.attn.to_qkv.weight.clone().map(|_| w.transpose());

        // Attention to_out (no bias)
        let w = t2(
            data,
            &format!("transformer.layers.{i}.0.to_out.weight"),
            device,
        );
        block.attn.to_out.weight = block.attn.to_out.weight.clone().map(|_| w.transpose());

        // FF norm (RMSNorm)
        let w = t1(
            data,
            &format!("transformer.layers.{i}.1.net.0.weight"),
            device,
        );
        block.ff.norm.weight = block.ff.norm.weight.clone().map(|_| w);

        // FF linear1 (no bias)
        let w = t2(
            data,
            &format!("transformer.layers.{i}.1.net.1.weight"),
            device,
        );
        block.ff.linear1.weight = block.ff.linear1.weight.clone().map(|_| w.transpose());

        // FF linear2 (no bias)
        let w = t2(
            data,
            &format!("transformer.layers.{i}.1.net.3.weight"),
            device,
        );
        block.ff.linear2.weight = block.ff.linear2.weight.clone().map(|_| w.transpose());
    }

    // Final layer (Flatten → LayerNorm → Linear)
    let w = t1(data, "final_layer.1.weight", device);
    let b = t1(data, "final_layer.1.bias", device);
    model.final_ln.gamma = model.final_ln.gamma.clone().map(|_| w);
    if let Some(ref beta) = model.final_ln.beta {
        model.final_ln.beta = Some(beta.clone().map(|_| b));
    }

    let w = t2(data, "final_layer.2.weight", device);
    let b = t1(data, "final_layer.2.bias", device);
    model.final_linear.weight = model.final_linear.weight.clone().map(|_| w.transpose());
    if let Some(ref bias) = model.final_linear.bias {
        model.final_linear.bias = Some(bias.clone().map(|_| b));
    }
}