//! osf-rs 0.0.1 — OSF Sleep Foundation Model: inference in Rust with Burn ML.
//!
//! Example: extract embeddings from synthetic PSG data.
//!
//! Usage:
//!
//! ```text
//! cargo run --example embed --release -- --weights data/osf_backbone.safetensors
//! ```

use std::path::Path;
use std::time::Instant;

use anyhow::Context;
use clap::Parser;

// ── Backend ───────────────────────────────────────────────────────────────────
// Compile-time backend selection: building with the `wgpu` feature uses the
// WGPU backend, otherwise the NdArray backend with its CPU device is used.
// Both variants expose the same surface (`B` and `device()`), so the rest of
// the example does not care which one was compiled in.
#[cfg(feature = "wgpu")]
mod backend {
    use burn::backend::wgpu::WgpuDevice;

    pub use burn::backend::Wgpu as B;

    /// Returns `WgpuDevice::default()`.
    pub fn device() -> WgpuDevice {
        WgpuDevice::default()
    }
}

#[cfg(not(feature = "wgpu"))]
mod backend {
    use burn::backend::ndarray::NdArrayDevice;

    pub use burn::backend::NdArray as B;

    /// Returns the NdArray CPU device.
    pub fn device() -> NdArrayDevice {
        NdArrayDevice::Cpu
    }
}

use backend::B;

// Command-line arguments for the embedding example. Field doc comments below
// are rendered by clap as the per-option `--help` text.
#[derive(Parser, Debug)]
#[command(about = "OSF — PSG embedding extraction")]
struct Args {
    /// Path to the OSF backbone weights (e.g. `data/osf_backbone.safetensors`).
    #[arg(long)]
    weights: String,
    /// Optional JSON model config; the default `ModelConfig` is used when omitted.
    #[arg(long)]
    config: Option<String>,
    /// Where to write the extracted embeddings (safetensors).
    #[arg(long, default_value = "data/embeddings.safetensors")]
    output: String,
    /// Print extra diagnostics (the first CLS values of epoch 0).
    #[arg(long, short = 'v')]
    verbose: bool,
}

/// Generate a deterministic, multi-channel synthetic PSG signal.
///
/// Channel `ch` is a sine wave at `1.0 + 0.5 * ch` Hz with amplitude `50e-6`
/// (presumably volts, i.e. 50 µV — confirm against the model's expected
/// units). Uniform xorshift32 noise in roughly `[-5e-6, 5e-6]` is added, and
/// the signal is sampled at 64 Hz.
///
/// The returned buffer has `n_channels * n_samples` values in channel-major
/// order: sample `t` of channel `ch` is at index `ch * n_samples + t`. The
/// output depends only on the two arguments, so repeated calls return
/// identical data.
fn generate_synthetic_psg(n_channels: usize, n_samples: usize) -> Vec<f32> {
    const SAMPLE_RATE_HZ: f32 = 64.0;
    const SINE_AMPLITUDE: f32 = 50e-6;
    const NOISE_SCALE: f32 = 10e-6;

    let mut signal = vec![0.0f32; n_channels * n_samples];
    // `chunks_exact_mut(0)` panics; with no samples there is nothing to fill.
    if n_samples == 0 {
        return signal;
    }

    for (ch, channel) in signal.chunks_exact_mut(n_samples).enumerate() {
        let freq = 1.0 + ch as f32 * 0.5;
        // Per-channel seed. Wrapping arithmetic is required: a plain `*`
        // overflows u32 for every ch >= 1 and panics in debug builds. Because
        // 0xDEAD_BEEF is odd, the seed is never 0 (which xorshift cannot
        // escape) for fewer than 2^32 channels.
        let mut noise_state = (ch as u32).wrapping_add(1).wrapping_mul(0xDEAD_BEEF);
        for (t, sample) in channel.iter_mut().enumerate() {
            let time = t as f32 / SAMPLE_RATE_HZ;
            let sine = (2.0 * std::f32::consts::PI * freq * time).sin() * SINE_AMPLITUDE;
            // xorshift32 step with shifts (13, 17, 5).
            noise_state ^= noise_state << 13;
            noise_state ^= noise_state >> 17;
            noise_state ^= noise_state << 5;
            let noise = (noise_state as f32 / u32::MAX as f32 - 0.5) * NOISE_SCALE;
            *sample = sine + noise;
        }
    }
    signal
}

fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    let t0 = Instant::now();
    let device = backend::device();

    println!("╔══════════════════════════════════════════════════════════════╗");
    println!("║  OSF — PSG Embedding Extraction                             ║");
    println!("╚══════════════════════════════════════════════════════════════╝\n");

    // 1. Load model
    let model_cfg = if let Some(ref cfg_path) = args.config {
        let cfg_str = std::fs::read_to_string(cfg_path)?;
        serde_json::from_str(&cfg_str)?
    } else {
        osf_rs::ModelConfig::default()
    };

    println!("▸ Loading OSF model …");
    let (encoder, ms_load) = osf_rs::OsfEncoder::<B>::load_with_config(
        model_cfg,
        Path::new(&args.weights),
        device.clone(),
    )?;
    println!("  {}  ({ms_load:.0} ms)\n", encoder.describe());

    // 2. Generate synthetic PSG epochs
    let n_channels = osf_rs::NUM_PSG_CHANNELS;
    let n_samples = osf_rs::EPOCH_SAMPLES;
    let n_epochs = 3;

    println!("▸ Generating {} synthetic PSG epochs ({} ch × {} samples each)\n",
        n_epochs, n_channels, n_samples);

    let mut all_outputs = Vec::new();

    for epoch_idx in 0..n_epochs {
        let signal = generate_synthetic_psg(n_channels, n_samples);
        let batch = osf_rs::build_batch::<B>(signal, n_channels, n_samples, &device);

        let t = Instant::now();
        let result = encoder.run_batch(&batch)?;
        let ms = t.elapsed().as_secs_f64() * 1000.0;

        // Stats
        let cls = &result.cls_emb;
        let mean: f64 = cls.iter().map(|&v| v as f64).sum::<f64>() / cls.len() as f64;
        let std: f64 = (cls.iter().map(|&v| { let d = v as f64 - mean; d * d }).sum::<f64>()
            / cls.len() as f64).sqrt();

        println!("  Epoch {epoch_idx}: cls=[{}]  patches=[{},{}]  mean={mean:+.4}  std={std:.4}  {ms:.1}ms",
            result.embed_dim, result.num_patches, result.embed_dim);

        if args.verbose && epoch_idx == 0 {
            println!("    First 5 CLS values: {:?}", &cls[..5.min(cls.len())]);
        }

        all_outputs.push(result);
    }

    // 3. Save
    let encoding = osf_rs::EncodingResult {
        epochs: all_outputs,
        ms_load,
        ms_encode: t0.elapsed().as_secs_f64() * 1000.0,
    };

    if let Some(p) = Path::new(&args.output).parent() {
        std::fs::create_dir_all(p)?;
    }
    encoding.save_safetensors(&args.output)?;
    println!("\n▸ Saved {} epochs → {}", n_epochs, args.output);

    let ms_total = t0.elapsed().as_secs_f64() * 1000.0;
    println!("\n  Total: {ms_total:.0} ms");
    Ok(())
}