osf-rs 0.0.1

OSF Sleep Foundation Model — inference in Rust with Burn ML
Documentation
//! Data preparation for OSF inference.
//!
//! OSF input: `[B, 12, 1920]` — 12-channel PSG sampled at 64 Hz over a 30-second epoch (64 × 30 = 1920 samples).

use burn::prelude::*;

/// One model-ready OSF input: a PSG signal tensor together with its dimensions.
pub struct InputBatch<B: Backend> {
    /// Signal tensor with layout `[1, C, T]` (batch of one, channels, time samples).
    pub signal: Tensor<B, 3>,
    /// Channel count `C` of `signal`.
    pub n_channels: usize,
    /// Time-sample count `T` of `signal`.
    pub n_samples: usize,
}

/// Per-epoch embedding produced by OSF.
///
/// `cls_emb` holds `embed_dim` values. `patch_embs` holds `num_patches * embed_dim`
/// values laid out row-major as `[num_patches, embed_dim]`.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct EpochEmbedding {
    /// CLS embedding values: `[D]`.
    pub cls_emb: Vec<f32>,
    /// Patch embedding values, row-major: `[N, D]`.
    pub patch_embs: Vec<f32>,
    /// Embedding dimension `D`, shared by the CLS vector and every patch row.
    pub embed_dim: usize,
    /// Number of patches `N`.
    pub num_patches: usize,
}

/// Collection of per-epoch outputs plus timing information.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct EncodingResult {
    /// One embedding per encoded epoch, in encoding order.
    pub epochs: Vec<EpochEmbedding>,
    /// Load time; presumably wall-clock milliseconds given the `ms_` prefix.
    /// NOTE(review): not set in this file — confirm units at the producer.
    pub ms_load: f64,
    /// Encode time; presumably wall-clock milliseconds given the `ms_` prefix.
    /// NOTE(review): not set in this file — confirm units at the producer.
    pub ms_encode: f64,
}

impl EncodingResult {
    /// Save to safetensors file.
    pub fn save_safetensors(&self, path: &str) -> anyhow::Result<()> {
        use safetensors::{Dtype, View};
        use std::borrow::Cow;

        struct RawTensor { data: Vec<u8>, shape: Vec<usize>, dtype: Dtype }
        impl View for RawTensor {
            fn dtype(&self)    -> Dtype         { self.dtype }
            fn shape(&self)    -> &[usize]      { &self.shape }
            fn data(&self)     -> Cow<'_, [u8]> { Cow::Borrowed(&self.data) }
            fn data_len(&self) -> usize         { self.data.len() }
        }

        let f32_bytes = |v: &[f32]| -> Vec<u8> {
            v.iter().flat_map(|f| f.to_le_bytes()).collect()
        };

        let mut keys: Vec<String> = Vec::new();
        let mut tensors: Vec<RawTensor> = Vec::new();

        for (i, ep) in self.epochs.iter().enumerate() {
            keys.push(format!("cls_emb_{i}"));
            tensors.push(RawTensor {
                data: f32_bytes(&ep.cls_emb),
                shape: vec![ep.embed_dim],
                dtype: Dtype::F32,
            });

            keys.push(format!("patch_embs_{i}"));
            tensors.push(RawTensor {
                data: f32_bytes(&ep.patch_embs),
                shape: vec![ep.num_patches, ep.embed_dim],
                dtype: Dtype::F32,
            });
        }

        let n = self.epochs.len() as f32;
        keys.push("n_epochs".into());
        tensors.push(RawTensor {
            data: f32_bytes(&[n]),
            shape: vec![1],
            dtype: Dtype::F32,
        });

        let pairs: Vec<(&str, RawTensor)> = keys.iter()
            .map(|s| s.as_str())
            .zip(tensors)
            .collect();
        let bytes = safetensors::serialize(pairs, None)?;
        std::fs::write(path, bytes)?;
        Ok(())
    }
}

/// Wrap a flat, row-major `[C, T]` signal buffer into an [`InputBatch`].
///
/// The buffer is read as `n_channels` rows of `n_samples` values each. It is
/// placed on `device` and given a leading batch axis of size 1, producing a
/// `[1, C, T]` tensor.
///
/// `signal.len()` is expected to equal `n_channels * n_samples`. This function
/// does not check it itself; any check is left to Burn's `TensorData::new`.
pub fn build_batch<B: Backend>(
    signal: Vec<f32>,
    n_channels: usize,
    n_samples: usize,
    device: &B::Device,
) -> InputBatch<B> {
    let data = TensorData::new(signal, vec![n_channels, n_samples]);
    let matrix = Tensor::<B, 2>::from_data(data, device); // [C, T]
    let batched = matrix.unsqueeze_dim::<3>(0); // [1, C, T]

    InputBatch {
        signal: batched,
        n_channels,
        n_samples,
    }
}

/// Channel-wise z-score normalisation over the time axis (dim 2).
///
/// For every batch item and channel, this subtracts the mean over dim 2 and
/// divides by the standard deviation over dim 2 plus `eps = 1e-8`. A constant
/// channel therefore maps to all zeros rather than NaN.
///
/// Python equivalent: `(x - mean) / (std + eps)` per channel.
pub fn channel_wise_normalize<B: Backend>(x: Tensor<B, 3>) -> Tensor<B, 3> {
    // Added to the standard deviation (not the variance) to match the formula above.
    const EPS: f64 = 1e-8;

    let mean = x.clone().mean_dim(2); // [B, C, 1]
    let centered = x - mean; // [B, C, T]

    // Population variance (mean of squares, i.e. divides by T).
    // NOTE(review): `torch.std` defaults to the unbiased (T - 1) estimator;
    // confirm which one the Python reference uses.
    let var = (centered.clone() * centered.clone()).mean_dim(2); // [B, C, 1]
    let std = var.sqrt();

    // Adding eps to std rather than var matters when var is tiny (e.g. volts-scale
    // PSG). There, sqrt(var + 1e-8) >= 1e-4 would dominate the true std.
    centered / (std + EPS)
}