osf-rs 0.0.1

OSF Sleep Foundation Model — inference in Rust with Burn ML
Documentation
//! Python parity test: compare the Rust forward pass against reference vectors
//! exported from the Python model.
//!
//! Prerequisites: generate the reference vectors and convert the model weights:
//!
//! ```text
//! python scripts/export_parity_vectors.py
//! python scripts/export_safetensors.py --input ../OSF-Base/osf_backbone.pth --output data/osf_backbone.safetensors
//! ```

use std::path::Path;

use burn::backend::NdArray as B;
use burn::prelude::*;

fn load_parity_vectors(
    path: &str,
    device: &burn::backend::ndarray::NdArrayDevice,
) -> (Tensor<B, 2>, Tensor<B, 1>, Tensor<B, 2>) {
    let bytes = std::fs::read(path).expect("read parity vectors");
    let st = safetensors::SafeTensors::deserialize(&bytes).expect("deserialize");

    let to_f32 = |name: &str| -> Vec<f32> {
        let view = st.tensor(name).unwrap_or_else(|_| panic!("key {name}"));
        view.data().chunks_exact(4)
            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
            .collect()
    };

    let signal_view = st.tensor("input_signal").unwrap();
    let signal_shape: Vec<usize> = signal_view.shape().to_vec();
    let signal_data = to_f32("input_signal");
    let signal = Tensor::<B, 2>::from_data(
        TensorData::new(signal_data, signal_shape), device,
    );

    let cls_view = st.tensor("cls_emb").unwrap();
    let cls_shape: Vec<usize> = cls_view.shape().to_vec();
    let cls_data = to_f32("cls_emb");
    let cls = Tensor::<B, 1>::from_data(
        TensorData::new(cls_data, cls_shape), device,
    );

    let patch_view = st.tensor("patch_embs").unwrap();
    let patch_shape: Vec<usize> = patch_view.shape().to_vec();
    let patch_data = to_f32("patch_embs");
    let patches = Tensor::<B, 2>::from_data(
        TensorData::new(patch_data, patch_shape), device,
    );

    (signal, cls, patches)
}

#[test]
fn test_forward_parity_with_python() {
    let vectors_path = "tests/vectors/parity.safetensors";
    let weights_path = "data/osf_backbone.safetensors";

    if !Path::new(vectors_path).exists() || !Path::new(weights_path).exists() {
        eprintln!("Skipping parity test: missing {vectors_path} or {weights_path}");
        eprintln!("Run: python scripts/export_parity_vectors.py");
        eprintln!("Run: python scripts/export_safetensors.py --input ../OSF-Base/osf_backbone.pth --output {weights_path}");
        return;
    }

    let device = burn::backend::ndarray::NdArrayDevice::Cpu;

    // Load parity vectors
    let (signal, expected_cls, expected_patches) = load_parity_vectors(vectors_path, &device);

    // Load model
    let cfg = osf_rs::ModelConfig::default();
    let (encoder, _ms) = osf_rs::OsfEncoder::<B>::load_with_config(
        cfg,
        Path::new(weights_path),
        device.clone(),
    ).expect("load model");

    // Run inference
    let signal_3d = signal.unsqueeze_dim::<3>(0); // [1, 12, 1920]
    let (cls_out, patches_out) = encoder.model().forward_encoding(signal_3d);

    // Compare CLS embedding: cls_out is [1, 1, 768]
    let cls_out_1d = cls_out.reshape([768]);
    let cls_diff = (cls_out_1d.clone() - expected_cls.clone()).abs();
    let cls_max_err = cls_diff.clone().max().into_data().to_vec::<f32>().unwrap()[0];
    let cls_mean_err = cls_diff.mean().into_data().to_vec::<f32>().unwrap()[0];

    eprintln!("CLS max error:  {cls_max_err:.6e}");
    eprintln!("CLS mean error: {cls_mean_err:.6e}");

    // Compare patch embeddings
    let patches_out_2d = patches_out.reshape([90, 768]);
    let patch_diff = (patches_out_2d.clone() - expected_patches.clone()).abs();
    let patch_max_err = patch_diff.clone().max().into_data().to_vec::<f32>().unwrap()[0];
    let patch_mean_err = patch_diff.mean().into_data().to_vec::<f32>().unwrap()[0];

    eprintln!("Patch max error:  {patch_max_err:.6e}");
    eprintln!("Patch mean error: {patch_mean_err:.6e}");

    // Tolerance: f32 accumulation through 12 transformer blocks
    let tol = 5e-4;
    assert!(cls_max_err < tol,
        "CLS max error {cls_max_err:.6e} exceeds tolerance {tol}");
    assert!(patch_max_err < tol,
        "Patch max error {patch_max_err:.6e} exceeds tolerance {tol}");
}