osf-rs 0.0.1

OSF Sleep Foundation Model — inference in Rust with Burn ML
Documentation
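//! OSF Sleep Foundation Model inference — thin CLI.
//!
//! Loads the encoder weights, runs one all-zero PSG epoch through the model,
//! and reports output shapes and timings (optionally saving the embeddings).
//!
//! Build — CPU (default):
//! ```text
//! cargo build --release
//! ```
//!
//! Build — GPU (macOS Metal):
//! ```text
//! cargo build --release --no-default-features --features metal
//! ```
//!
//! Usage:
//! ```text
//! infer --weights osf_backbone.safetensors [--config config.json]
//!       [--output embeddings.safetensors] [-v]
//! ```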

use std::{path::Path, time::Instant};
use clap::Parser;

// ── Backend ───────────────────────────────────────────────────────────────────
#[cfg(all(feature = "wgpu", not(feature = "ndarray")))]
mod backend {
    pub use burn::backend::{Wgpu as B, wgpu::WgpuDevice as Device};
    pub fn device() -> Device { Device::DefaultDevice }
    #[cfg(feature = "metal")]
    pub const NAME: &str = "GPU (wgpu — Metal / MSL)";
    #[cfg(feature = "vulkan")]
    pub const NAME: &str = "GPU (wgpu — Vulkan / SPIR-V)";
    #[cfg(not(any(feature = "metal", feature = "vulkan")))]
    pub const NAME: &str = "GPU (wgpu — WGSL)";
}

/// CPU backend (Burn `NdArray`), selected when the `ndarray` feature is enabled.
#[cfg(feature = "ndarray")]
mod backend {
    pub use burn::backend::NdArray as B;
    pub type Device = burn::backend::ndarray::NdArrayDevice;

    /// Returns the NdArray CPU device.
    pub fn device() -> Device { Device::Cpu }

    /// Human-readable backend label printed at startup.
    ///
    /// `blas-accelerate` and `openblas-system` can both be enabled at once,
    /// since Cargo features are additive (e.g. `--all-features`). Separate
    /// per-feature `const NAME` items would then collide. This single `cfg!`
    /// cascade always defines exactly one label and prefers Accelerate.
    pub const NAME: &str = if cfg!(feature = "blas-accelerate") {
        "CPU (NdArray + Apple Accelerate)"
    } else if cfg!(feature = "openblas-system") {
        "CPU (NdArray + OpenBLAS)"
    } else {
        "CPU (NdArray + Rayon)"
    };
}

use backend::{B, device};

// ── CLI ───────────────────────────────────────────────────────────────────────
// Command-line arguments parsed by clap's derive API.
// NOTE: the `///` lines on each field are clap help text (shown by `--help`),
// so they are user-visible output; the explanations here use `//` on purpose.
#[derive(Debug, Parser)]
#[command(about = "OSF Sleep Foundation Model inference (Burn 0.20.1)")]
struct Args {
    // Required: path to the model weights.
    /// Safetensors weights file.
    #[arg(long = "weights")]
    weights: String,

    // JSON model config; `None` means "use the built-in default".
    /// Optional config.json (uses default OSF-Base config if omitted).
    #[arg(long = "config")]
    config: Option<String>,

    // Where to write the embeddings; nothing is saved when absent.
    /// Output safetensors file for embeddings.
    #[arg(long = "output")]
    output: Option<String>,

    // `-v` / `--verbose`: enables the embedding statistics printout.
    /// Print details.
    #[arg(short = 'v', long = "verbose")]
    verbose: bool,
}

fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    let t0 = Instant::now();
    let dev = device();

    println!("Backend : {}", backend::NAME);

    // Load config
    let model_cfg = if let Some(ref cfg_path) = args.config {
        let cfg_str = std::fs::read_to_string(cfg_path)?;
        serde_json::from_str(&cfg_str)?
    } else {
        osf_rs::ModelConfig::default()
    };

    // Load model
    let (encoder, ms_weights) = osf_rs::OsfEncoder::<B>::load_with_config(
        model_cfg,
        Path::new(&args.weights),
        dev.clone(),
    )?;

    println!("Model   : {}  ({ms_weights:.0} ms)", encoder.describe());

    // Create dummy PSG input: 12 channels × 1920 samples (64 Hz × 30 s)
    let n_channels = osf_rs::NUM_PSG_CHANNELS;
    let n_samples = osf_rs::EPOCH_SAMPLES;
    let signal = vec![0.0f32; n_channels * n_samples];

    let batch = osf_rs::build_batch::<B>(signal, n_channels, n_samples, &dev);

    let t_inf = Instant::now();
    let result = encoder.run_batch(&batch)?;
    let ms_infer = t_inf.elapsed().as_secs_f64() * 1000.0;

    println!("Output  : cls=[{}]  patches=[{}, {}]  ({ms_infer:.1} ms)",
        result.embed_dim, result.num_patches, result.embed_dim);

    if args.verbose {
        let mean: f64 = result.cls_emb.iter().map(|&v| v as f64).sum::<f64>()
            / result.cls_emb.len() as f64;
        let std: f64 = (result.cls_emb.iter().map(|&v| {
            let d = v as f64 - mean; d * d
        }).sum::<f64>() / result.cls_emb.len() as f64).sqrt();
        println!("  CLS: mean={mean:+.6}  std={std:.6}");

        let p_mean: f64 = result.patch_embs.iter().map(|&v| v as f64).sum::<f64>()
            / result.patch_embs.len() as f64;
        let p_std: f64 = (result.patch_embs.iter().map(|&v| {
            let d = v as f64 - p_mean; d * d
        }).sum::<f64>() / result.patch_embs.len() as f64).sqrt();
        println!("  Patches: mean={p_mean:+.6}  std={p_std:.6}");
    }

    // Save if requested
    if let Some(ref out_path) = args.output {
        let encoding = osf_rs::EncodingResult {
            epochs: vec![result],
            ms_load: ms_weights,
            ms_encode: ms_infer,
        };
        if let Some(p) = Path::new(out_path).parent() {
            std::fs::create_dir_all(p)?;
        }
        encoding.save_safetensors(out_path)?;
        println!("Saved → {out_path}");
    }

    let ms_total = t0.elapsed().as_secs_f64() * 1000.0;
    println!("── Timing ───────────────────────────────────────────────────────");
    println!("  Weights  : {ms_weights:.0} ms");
    println!("  Infer    : {ms_infer:.0} ms");
    println!("  Total    : {ms_total:.0} ms");

    Ok(())
}