osf-rs 0.0.1

OSF Sleep Foundation Model — inference in Rust with Burn ML
Documentation
//! Standalone OSF encoder — produces sleep PSG embeddings.
//!
//! The encoder produces:
//! - CLS embedding: [B, 768] — global epoch-level representation
//! - Patch embeddings: [B, 90, 768] — local patch representations

use std::{path::Path, time::Instant};

use anyhow::Context;
use burn::prelude::*;

use crate::{
    config::ModelConfig,
    data::{InputBatch, EpochEmbedding},
    model::vit::OsfViT,
    weights::load_model,
};

/// High-level OSF encoder for PSG signal processing.
///
/// Bundles the loaded [`OsfViT`] model together with the [`ModelConfig`] it was
/// built from and the device it was loaded with.
pub struct OsfEncoder<B>
where
    B: Backend,
{
    /// The underlying `OsfViT` model.
    model: OsfViT<B>,
    /// Configuration that was passed to `load_model` to build `model`.
    pub model_cfg: ModelConfig,
    /// Device that was passed to `load_model` when the weights were loaded.
    device: B::Device,
}

impl<B: Backend> OsfEncoder<B> {
    /// Load the model from a JSON config file and a safetensors weights file.
    ///
    /// Returns the encoder together with the time spent loading the weights, in
    /// milliseconds. The timing does not include reading or parsing the config.
    ///
    /// # Errors
    ///
    /// Fails if the config file cannot be read, if it does not parse as a
    /// [`ModelConfig`], if `weights_path` is not valid UTF-8, or if loading the
    /// weights fails.
    pub fn load(
        config_path: &Path,
        weights_path: &Path,
        device: B::Device,
    ) -> anyhow::Result<(Self, f64)> {
        let cfg_str = std::fs::read_to_string(config_path)
            .with_context(|| format!("config: {}", config_path.display()))?;
        let model_cfg: ModelConfig =
            serde_json::from_str(&cfg_str).context("parsing model config")?;

        // Share the weight-loading (and timing) path with `load_with_config`.
        Self::load_with_config(model_cfg, weights_path, device)
    }

    /// Load the model from an already-built [`ModelConfig`] and a safetensors path.
    ///
    /// Returns the encoder together with the time spent loading the weights, in
    /// milliseconds.
    ///
    /// # Errors
    ///
    /// Fails if `weights_path` is not valid UTF-8 or if loading the weights fails.
    pub fn load_with_config(
        model_cfg: ModelConfig,
        weights_path: &Path,
        device: B::Device,
    ) -> anyhow::Result<(Self, f64)> {
        let t = Instant::now();
        // `load_model` takes the path as `&str`, so non-UTF-8 paths are rejected here.
        let weights = weights_path
            .to_str()
            .context("weights path not valid UTF-8")?;
        let model = load_model::<B>(&model_cfg, weights, &device)?;
        let ms = t.elapsed().as_secs_f64() * 1000.0;

        Ok((Self { model, model_cfg, device }, ms))
    }

    /// One-line human-readable summary of the model configuration.
    pub fn describe(&self) -> String {
        let c = &self.model_cfg;
        format!(
            "OSF  encoder={}  width={}  depth={}  heads={}  leads={}  patch_time={}  patch_ch={}",
            c.encoder_name, c.width, c.depth, c.heads, c.num_leads,
            c.patch_size_time, c.patch_size_ch,
        )
    }

    /// Run inference on a prepared [`InputBatch`] containing a single epoch.
    ///
    /// The batch signal is passed through the model's `forward_encoding`, and the
    /// resulting CLS and patch embeddings are copied out as flat `f32` vectors
    /// (`cls_emb` has `embed_dim` values, `patch_embs` has
    /// `num_patches * embed_dim` values in row-major patch order).
    ///
    /// # Errors
    ///
    /// Returns an error if the batch does not hold exactly one epoch (an
    /// [`EpochEmbedding`] describes a single epoch), or if the tensor data cannot
    /// be read back as `f32`.
    pub fn run_batch(&self, batch: &InputBatch<B>) -> anyhow::Result<EpochEmbedding> {
        let (cls, patches) = self.model.forward_encoding(batch.signal.clone());

        // cls: [B, 1, D], patches: [B, N, D]
        let batch_size = cls.dims()[0];
        let embed_dim = cls.dims()[2];
        let num_patches = patches.dims()[1];

        // The reshapes below drop the batch dimension, which is only valid for a
        // single epoch; reject other batch sizes with an error instead of letting
        // the reshape fail.
        anyhow::ensure!(
            batch_size == 1,
            "run_batch expects a batch of exactly 1 epoch, got {batch_size}"
        );

        let cls_emb = cls
            .reshape([embed_dim])
            .into_data()
            .to_vec::<f32>()
            .map_err(|e| anyhow::anyhow!("cls→vec: {e:?}"))?;

        let patch_embs = patches
            .reshape([num_patches, embed_dim])
            .into_data()
            .to_vec::<f32>()
            .map_err(|e| anyhow::anyhow!("patches→vec: {e:?}"))?;

        Ok(EpochEmbedding {
            cls_emb,
            patch_embs,
            embed_dim,
            num_patches,
        })
    }

    /// Run [`Self::run_batch`] on each batch in order.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error produced by `run_batch`.
    pub fn run_batches(&self, batches: &[InputBatch<B>]) -> anyhow::Result<Vec<EpochEmbedding>> {
        batches.iter().map(|b| self.run_batch(b)).collect()
    }

    /// Get the raw ViT model reference.
    pub fn model(&self) -> &OsfViT<B> {
        &self.model
    }

    /// Device the encoder was loaded with.
    pub fn device(&self) -> &B::Device {
        &self.device
    }
}