// osf-rs 0.0.1 — OSF Sleep Foundation Model: inference in Rust with Burn ML.
//! Benchmark — measure inference latency, load time, and batch/channel scaling.
//!
//! Demonstrates:
//!   - Weight loading time
//!   - Forward pass latency (CLS + patch embeddings)
//!   - Batch size scaling
//!   - Single-epoch throughput (epochs/sec)
//!   - Channel count scaling (inactive channels zero-padded)
//!
//! The human-readable report is written to stderr; `--json` suppresses it and prints a
//! machine-readable summary to stdout instead.
//!
//! Usage:
//! ```text
//! cargo run --example benchmark --release -- --weights data/osf_backbone.safetensors
//! cargo run --example benchmark --release -- --weights data/osf_backbone.safetensors --json
//! cargo run --example benchmark --release -- --weights data/osf_backbone.safetensors --warmup 5 --runs 20
//! ```

use std::path::Path;
use std::time::Instant;
use burn::prelude::*;
use clap::Parser;

// ── Backend ───────────────────────────────────────────────────────────────────
#[cfg(all(feature = "wgpu", not(feature = "ndarray")))]
mod backend {
    // GPU backend: compiled only when `wgpu` is enabled and `ndarray` is not.
    pub use burn::backend::{Wgpu as B, wgpu::WgpuDevice as Device};

    /// Device every tensor in the benchmark is created on (`WgpuDevice::DefaultDevice`).
    pub fn device() -> Device { Device::DefaultDevice }

    // Human-readable backend label (used in reports only). The cfgs are mutually
    // exclusive so that enabling both `metal` and `vulkan` does not define `NAME`
    // twice, which would be a compile error; the Metal label wins in that case.
    // NOTE(review): confirm which shader target Burn actually uses when both are on.
    #[cfg(feature = "metal")]
    pub const NAME: &str = "GPU (wgpu — Metal / MSL)";
    #[cfg(all(feature = "vulkan", not(feature = "metal")))]
    pub const NAME: &str = "GPU (wgpu — Vulkan / SPIR-V)";
    #[cfg(not(any(feature = "metal", feature = "vulkan")))]
    pub const NAME: &str = "GPU (wgpu — WGSL)";
}

#[cfg(feature = "ndarray")]
mod backend {
    // CPU backend. It is compiled whenever `ndarray` is enabled, so it takes priority
    // over `wgpu` if both features are on.
    pub use burn::backend::NdArray as B;
    pub type Device = burn::backend::ndarray::NdArrayDevice;

    /// Device every tensor in the benchmark is created on (`NdArrayDevice::Cpu`).
    pub fn device() -> Device { Device::Cpu }

    // Human-readable backend label (used in reports only). The cfgs are mutually
    // exclusive so that enabling both BLAS features does not define `NAME` twice,
    // which would be a compile error; the Accelerate label wins in that case.
    // NOTE(review): confirm which BLAS is actually linked when both are enabled.
    #[cfg(feature = "blas-accelerate")]
    pub const NAME: &str = "CPU (NdArray + Apple Accelerate)";
    #[cfg(all(feature = "openblas-system", not(feature = "blas-accelerate")))]
    pub const NAME: &str = "CPU (NdArray + OpenBLAS)";
    #[cfg(not(any(feature = "blas-accelerate", feature = "openblas-system")))]
    pub const NAME: &str = "CPU (NdArray + Rayon)";
}

use backend::{B, device};

// ── CLI ───────────────────────────────────────────────────────────────────────
// Command-line arguments for the benchmark.
// NOTE: the `///` comments on the fields below are rendered by clap as `--help` text,
// so they are user-facing output; explanatory notes use plain `//` comments instead.
#[derive(Parser, Debug)]
#[command(about = "OSF — inference latency benchmark")]
struct Args {
    // Path passed to `OsfEncoder::load_with_config`; required.
    /// Safetensors weights file.
    #[arg(long)]
    weights: String,
    // When set, the file is read and parsed as JSON into `osf_rs::ModelConfig`;
    // otherwise `ModelConfig::default()` is used.
    /// Optional config JSON (uses default OSF-Base if omitted).
    #[arg(long)]
    config: Option<String>,
    // Untimed forward passes run before the B=1 timed loop.
    /// Number of warmup runs.
    #[arg(long, default_value_t = 3)]
    warmup: usize,
    // Sample count of the B=1 timed loop; the batch-scaling sweep also derives its
    // sample count from this value. The parser itself accepts 0.
    /// Number of timed runs.
    #[arg(long, default_value_t = 10)]
    runs: usize,
    // When true, the stderr report is suppressed and a JSON summary goes to stdout.
    /// Output results as JSON.
    #[arg(long, default_value_t = false)]
    json: bool,
}

/// Build a deterministic synthetic PSG-like test signal.
///
/// Channel `ch` is a sine wave of frequency `1 + 0.5·ch + 0.01·seed` and amplitude
/// `50e-6` (presumably volts, i.e. 50 µV), evaluated on the time axis `t / 64.0`
/// (a 64 Hz sample rate if the unit is seconds). Uniform noise of roughly ±5e-6 is
/// added, drawn from a xorshift32 generator seeded from `ch` and `seed`. The same
/// arguments always produce the same samples.
///
/// Returns `n_channels * n_samples` values in channel-major order: sample `t` of
/// channel `ch` is at index `ch * n_samples + t`.
fn generate_psg(n_channels: usize, n_samples: usize, seed: u32) -> Vec<f32> {
    (0..n_channels)
        .flat_map(|ch| {
            let freq = 1.0 + ch as f32 * 0.5 + seed as f32 * 0.01;
            // Per-channel xorshift32 state; `ch + 1` keeps channel 0 from starting at `seed`.
            let mut rng: u32 = (ch as u32 + 1).wrapping_mul(0xDEAD_BEEF).wrapping_add(seed);
            (0..n_samples).map(move |t| {
                let time = t as f32 / 64.0;
                let sine = (2.0 * std::f32::consts::PI * freq * time).sin() * 50e-6;
                rng ^= rng << 13;
                rng ^= rng >> 17;
                rng ^= rng << 5;
                let noise = (rng as f32 / u32::MAX as f32 - 0.5) * 10e-6;
                sine + noise
            })
        })
        .collect()
}

/// Run `f` once and measure its wall-clock duration.
///
/// Returns `(value, elapsed_ms)`, where `value` is whatever `f` produced and
/// `elapsed_ms` is the elapsed time in milliseconds as `f64`. The returned value is
/// handed back to the caller, so dropping it happens outside the measured interval.
fn timed<F, R>(f: F) -> (R, f64)
where
    F: FnOnce() -> R,
{
    let started_at = Instant::now();
    let value = f();
    let elapsed_ms = started_at.elapsed().as_secs_f64() * 1e3;
    (value, elapsed_ms)
}

fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    let dev = device();
    let json_mode = args.json;

    if !json_mode {
        eprintln!("╔══════════════════════════════════════════════════════════════╗");
        eprintln!("║  OSF-RS — Inference Benchmark                               ║");
        eprintln!("╚══════════════════════════════════════════════════════════════╝\n");
        eprintln!("  Backend: {}", backend::NAME);
    }

    // Load config
    let model_cfg = if let Some(ref cfg_path) = args.config {
        let s = std::fs::read_to_string(cfg_path)?;
        serde_json::from_str(&s)?
    } else {
        osf_rs::ModelConfig::default()
    };

    // ── 1. Weight loading benchmark ─────────────────────────────────────────
    let (encoder, ms_load) = osf_rs::OsfEncoder::<B>::load_with_config(
        model_cfg.clone(),
        Path::new(&args.weights),
        dev.clone(),
    )?;

    if !json_mode {
        eprintln!("  Model:   {}", encoder.describe());
        eprintln!("  Load:    {ms_load:.0} ms\n");
    }

    let n_ch = osf_rs::NUM_PSG_CHANNELS;   // 12
    let n_t  = osf_rs::EPOCH_SAMPLES;      // 1920

    // ── 2. Standard inference benchmark (B=1) ───────────────────────────────
    if !json_mode {
        eprintln!("  ▸ Standard inference (B=1, {}ch × {} samples)", n_ch, n_t);
    }

    let signal = generate_psg(n_ch, n_t, 42);
    let batch = osf_rs::build_batch::<B>(signal, n_ch, n_t, &dev);

    // Warmup
    for _ in 0..args.warmup {
        let _ = encoder.run_batch(&batch)?;
    }

    // Timed runs
    let mut infer_times = Vec::with_capacity(args.runs);
    for _ in 0..args.runs {
        let (_, ms) = timed(|| encoder.run_batch(&batch));
        infer_times.push(ms);
    }

    let infer_mean = infer_times.iter().sum::<f64>() / infer_times.len() as f64;
    let infer_min = infer_times.iter().cloned().fold(f64::INFINITY, f64::min);
    let infer_max = infer_times.iter().cloned().fold(0.0f64, f64::max);
    let infer_std = (infer_times.iter().map(|t| (t - infer_mean).powi(2)).sum::<f64>()
        / infer_times.len() as f64).sqrt();

    if !json_mode {
        eprintln!("    mean={infer_mean:.1}ms  min={infer_min:.1}ms  max={infer_max:.1}ms  std={infer_std:.1}ms  (n={})",
            args.runs);
    }

    // ── 3. Batch size scaling ───────────────────────────────────────────────
    let batch_sizes = [1, 2, 4, 8, 16];
    let mut batch_scaling: Vec<serde_json::Value> = Vec::new();

    if !json_mode {
        eprintln!("\n  ▸ Batch size scaling ({}ch × {} samples):", n_ch, n_t);
        eprintln!("    {:>6}  {:>10}  {:>12}", "Batch", "Mean (ms)", "Per-epoch (ms)");
    }

    for &bs in &batch_sizes {
        // Build a batched input: [bs, 12, 1920]
        let signal_batch: Vec<f32> = (0..bs).flat_map(|i| generate_psg(n_ch, n_t, 42 + i as u32)).collect();
        let signal_tensor = burn::tensor::Tensor::<B, 2>::from_data(
            burn::tensor::TensorData::new(signal_batch, vec![bs * n_ch, n_t]),
            &dev,
        ).reshape([bs, n_ch, n_t]);

        // Warmup
        let _ = encoder.model().forward_encoding(signal_tensor.clone());

        // Timed
        let mut t_vec = Vec::new();
        for _ in 0..5.max(args.runs / 2) {
            let (_, ms) = timed(|| encoder.model().forward_encoding(signal_tensor.clone()));
            t_vec.push(ms);
        }
        let avg = t_vec.iter().sum::<f64>() / t_vec.len() as f64;
        let per_epoch = avg / bs as f64;
        let bmin = t_vec.iter().cloned().fold(f64::INFINITY, f64::min);
        let bmax = t_vec.iter().cloned().fold(0.0f64, f64::max);

        if !json_mode {
            eprintln!("    {:>6}  {:>7.1} ms  {:>9.1} ms", bs, avg, per_epoch);
        }

        batch_scaling.push(serde_json::json!({
            "batch_size": bs,
            "mean_ms": round2(avg),
            "min_ms": round2(bmin),
            "max_ms": round2(bmax),
            "per_epoch_ms": round2(per_epoch),
            "runs": t_vec,
        }));
    }

    // ── 4. Throughput (epochs/sec) ──────────────────────────────────────────
    let throughput_epochs_sec = 1000.0 / infer_mean;

    if !json_mode {
        eprintln!("\n  ▸ Throughput: {throughput_epochs_sec:.1} epochs/sec (single-epoch inference)");
    }

    // ── 5. Channel subset scaling (robustness test with padding) ────────────
    let channel_subsets = [2, 4, 6, 8, 10, 12];
    let mut channel_scaling: Vec<serde_json::Value> = Vec::new();

    if !json_mode {
        eprintln!("\n  ▸ Channel count scaling (zero-padded to 12ch, T={}):", n_t);
        eprintln!("    {:>6}  {:>10}", "Active", "Mean (ms)");
    }

    for &active_ch in &channel_subsets {
        // Create signal with `active_ch` active channels, rest zero-padded to 12
        let mut sig = vec![0.0f32; n_ch * n_t];
        let active = generate_psg(active_ch, n_t, 100 + active_ch as u32);
        for ch in 0..active_ch {
            for t in 0..n_t {
                sig[ch * n_t + t] = active[ch * n_t + t];
            }
        }
        let b = osf_rs::build_batch::<B>(sig, n_ch, n_t, &dev);

        let _ = encoder.run_batch(&b)?; // warmup
        let mut t_vec = Vec::new();
        for _ in 0..5 {
            let (_, ms) = timed(|| encoder.run_batch(&b));
            t_vec.push(ms);
        }
        let avg = t_vec.iter().sum::<f64>() / t_vec.len() as f64;
        let cmin = t_vec.iter().cloned().fold(f64::INFINITY, f64::min);
        let cmax = t_vec.iter().cloned().fold(0.0f64, f64::max);

        if !json_mode {
            eprintln!("    {:>6}  {:>7.1} ms", active_ch, avg);
        }

        channel_scaling.push(serde_json::json!({
            "active_channels": active_ch,
            "total_channels": n_ch,
            "mean_ms": round2(avg),
            "min_ms": round2(cmin),
            "max_ms": round2(cmax),
            "runs": t_vec,
        }));
    }

    if !json_mode { eprintln!(); }

    // ── JSON output ─────────────────────────────────────────────────────────
    let result = serde_json::json!({
        "backend": backend::NAME,
        "model": {
            "encoder_name": model_cfg.encoder_name,
            "width": model_cfg.width,
            "depth": model_cfg.depth,
            "heads": model_cfg.heads,
            "num_leads": model_cfg.num_leads,
            "patch_size_time": model_cfg.patch_size_time,
            "patch_size_ch": model_cfg.patch_size_ch,
            "seq_len": model_cfg.seq_len,
        },
        "load_ms": round2(ms_load),
        "inference": {
            "channels": n_ch,
            "samples": n_t,
            "warmup": args.warmup,
            "runs": args.runs,
            "mean_ms": round2(infer_mean),
            "min_ms": round2(infer_min),
            "max_ms": round2(infer_max),
            "std_ms": round2(infer_std),
            "all_ms": infer_times,
        },
        "batch_scaling": batch_scaling,
        "channel_scaling": channel_scaling,
        "throughput_epochs_sec": round2(throughput_epochs_sec),
    });

    if json_mode {
        println!("{}", serde_json::to_string_pretty(&result)?);
    }

    Ok(())
}

/// Round `v` to two decimal places (ties away from zero, as `f64::round` does) so that
/// JSON output stays compact. NaN and ±infinity pass through unchanged.
fn round2(v: f64) -> f64 {
    const HUNDREDTHS: f64 = 100.0;
    let scaled = (v * HUNDREDTHS).round();
    scaled / HUNDREDTHS
}