use std::path::Path;
use burn::backend::NdArray as B;
use burn::prelude::*;
fn load_parity_vectors(
path: &str,
device: &burn::backend::ndarray::NdArrayDevice,
) -> (Tensor<B, 2>, Tensor<B, 1>, Tensor<B, 2>) {
let bytes = std::fs::read(path).expect("read parity vectors");
let st = safetensors::SafeTensors::deserialize(&bytes).expect("deserialize");
let to_f32 = |name: &str| -> Vec<f32> {
let view = st.tensor(name).unwrap_or_else(|_| panic!("key {name}"));
view.data().chunks_exact(4)
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
.collect()
};
let signal_view = st.tensor("input_signal").unwrap();
let signal_shape: Vec<usize> = signal_view.shape().to_vec();
let signal_data = to_f32("input_signal");
let signal = Tensor::<B, 2>::from_data(
TensorData::new(signal_data, signal_shape), device,
);
let cls_view = st.tensor("cls_emb").unwrap();
let cls_shape: Vec<usize> = cls_view.shape().to_vec();
let cls_data = to_f32("cls_emb");
let cls = Tensor::<B, 1>::from_data(
TensorData::new(cls_data, cls_shape), device,
);
let patch_view = st.tensor("patch_embs").unwrap();
let patch_shape: Vec<usize> = patch_view.shape().to_vec();
let patch_data = to_f32("patch_embs");
let patches = Tensor::<B, 2>::from_data(
TensorData::new(patch_data, patch_shape), device,
);
(signal, cls, patches)
}
/// Compares the encoder's outputs against reference vectors stored on disk.
///
/// The test is skipped (returns early, printing instructions to stderr) when
/// either the reference vectors or the model weights are missing. Otherwise it
/// runs `forward_encoding` on the stored input signal and asserts that the
/// maximum absolute error of both outputs stays below the tolerance.
#[test]
fn test_forward_parity_with_python() {
    /// Returns `(max, mean)` of the element-wise absolute difference.
    fn abs_error_stats<const D: usize>(
        actual: Tensor<B, D>,
        expected: Tensor<B, D>,
    ) -> (f32, f32) {
        let diff = (actual - expected).abs();
        let max = diff.clone().max().into_data().to_vec::<f32>().unwrap()[0];
        let mean = diff.mean().into_data().to_vec::<f32>().unwrap()[0];
        (max, mean)
    }

    let vectors_path = "tests/vectors/parity.safetensors";
    let weights_path = "data/osf_backbone.safetensors";
    if !Path::new(vectors_path).exists() || !Path::new(weights_path).exists() {
        eprintln!("Skipping parity test: missing {vectors_path} or {weights_path}");
        eprintln!("Run: python scripts/export_parity_vectors.py");
        eprintln!("Run: python scripts/export_safetensors.py --input ../OSF-Base/osf_backbone.pth --output {weights_path}");
        return;
    }

    let device = burn::backend::ndarray::NdArrayDevice::Cpu;
    let (signal, expected_cls, expected_patches) = load_parity_vectors(vectors_path, &device);

    let cfg = osf_rs::ModelConfig::default();
    let (encoder, _ms) = osf_rs::OsfEncoder::<B>::load_with_config(
        cfg,
        Path::new(weights_path),
        device.clone(),
    )
    .expect("load model");

    // Add a leading axis of size 1 before calling the model
    // (presumably the batch dimension; confirm against `forward_encoding`).
    let signal_3d = signal.unsqueeze_dim::<3>(0);
    let (cls_out, patches_out) = encoder.model().forward_encoding(signal_3d);

    // Reshape to the reference tensors' shapes rather than hard-coded sizes,
    // so the shapes stored in the vectors file are the single source of truth.
    // `reshape` still panics if the element counts disagree.
    let cls_out_1d = cls_out.reshape(expected_cls.dims());
    let (cls_max_err, cls_mean_err) = abs_error_stats(cls_out_1d, expected_cls);
    eprintln!("CLS max error: {cls_max_err:.6e}");
    eprintln!("CLS mean error: {cls_mean_err:.6e}");

    let patches_out_2d = patches_out.reshape(expected_patches.dims());
    let (patch_max_err, patch_mean_err) = abs_error_stats(patches_out_2d, expected_patches);
    eprintln!("Patch max error: {patch_max_err:.6e}");
    eprintln!("Patch mean error: {patch_mean_err:.6e}");

    // A NaN error also fails these assertions, since `NaN < tol` is false.
    let tol: f32 = 5e-4;
    assert!(
        cls_max_err < tol,
        "CLS max error {cls_max_err:.6e} exceeds tolerance {tol}"
    );
    assert!(
        patch_max_err < tol,
        "Patch max error {patch_max_err:.6e} exceeds tolerance {tol}"
    );
}