Skip to main content

benchmark/
benchmark.rs

//! Benchmark — measure inference latency, load time, and batch/channel scaling.
//!
//! Demonstrates:
//!   - Weight loading time
//!   - Forward pass latency (CLS + patch embeddings)
//!   - Batch size scaling
//!   - Channel count scaling (zero-padded inputs)
//!   - Single-epoch throughput (epochs/sec)
//!
//! Usage:
//!   cargo run --example benchmark --release -- --weights data/osf_backbone.safetensors
//!   cargo run --example benchmark --release -- --weights data/osf_backbone.safetensors --json
//!   cargo run --example benchmark --release -- --weights data/osf_backbone.safetensors --warmup 5 --runs 20
13
14use std::path::Path;
15use std::time::Instant;
16use burn::prelude::*;
17use clap::Parser;
18
// ── Backend ───────────────────────────────────────────────────────────────────
// The wgpu backend is compiled only when `ndarray` is off; `ndarray` takes precedence.
#[cfg(all(feature = "wgpu", not(feature = "ndarray")))]
mod backend {
    pub use burn::backend::{Wgpu as B, wgpu::WgpuDevice as Device};

    /// Device used by the benchmark.
    pub fn device() -> Device { Device::DefaultDevice }

    // Exactly one `NAME` must exist for every feature combination. Enabling both
    // `metal` and `vulkan` would otherwise define it twice (E0428); `metal` wins the
    // label in that case. NOTE(review): confirm which shader path wgpu actually uses
    // when both features are on.
    #[cfg(feature = "metal")]
    pub const NAME: &str = "GPU (wgpu — Metal / MSL)";
    #[cfg(all(feature = "vulkan", not(feature = "metal")))]
    pub const NAME: &str = "GPU (wgpu — Vulkan / SPIR-V)";
    #[cfg(not(any(feature = "metal", feature = "vulkan")))]
    pub const NAME: &str = "GPU (wgpu — WGSL)";
}
31
32#[cfg(feature = "ndarray")]
33mod backend {
34    pub use burn::backend::NdArray as B;
35    pub type Device = burn::backend::ndarray::NdArrayDevice;
36    pub fn device() -> Device { Device::Cpu }
37    #[cfg(feature = "blas-accelerate")]
38    pub const NAME: &str = "CPU (NdArray + Apple Accelerate)";
39    #[cfg(feature = "openblas-system")]
40    pub const NAME: &str = "CPU (NdArray + OpenBLAS)";
41    #[cfg(not(any(feature = "blas-accelerate", feature = "openblas-system")))]
42    pub const NAME: &str = "CPU (NdArray + Rayon)";
43}
44
45use backend::{B, device};
46
47// ── CLI ───────────────────────────────────────────────────────────────────────
48#[derive(Parser, Debug)]
49#[command(about = "OSF — inference latency benchmark")]
50struct Args {
51    /// Safetensors weights file.
52    #[arg(long)]
53    weights: String,
54    /// Optional config JSON (uses default OSF-Base if omitted).
55    #[arg(long)]
56    config: Option<String>,
57    /// Number of warmup runs.
58    #[arg(long, default_value_t = 3)]
59    warmup: usize,
60    /// Number of timed runs.
61    #[arg(long, default_value_t = 10)]
62    runs: usize,
63    /// Output results as JSON.
64    #[arg(long, default_value_t = false)]
65    json: bool,
66}
67
/// Build a deterministic synthetic PSG-like signal.
///
/// The result holds `n_channels * n_samples` values laid out channel by channel
/// (all samples of channel 0, then channel 1, ...). Channel `c` is a sine of
/// amplitude `50e-6` at `1 + 0.5·c + 0.01·seed` cycles per 64 samples, plus uniform
/// noise in roughly `±5e-6` drawn from a per-channel xorshift32 stream seeded from
/// `c` and `seed`. Identical arguments always yield identical output.
fn generate_psg(n_channels: usize, n_samples: usize, seed: u32) -> Vec<f32> {
    let two_pi = 2.0 * std::f32::consts::PI;
    let mut samples = Vec::with_capacity(n_channels * n_samples);

    for channel in 0..n_channels {
        let freq_hz = 1.0 + channel as f32 * 0.5 + seed as f32 * 0.01;
        let mut rng: u32 = (channel as u32 + 1)
            .wrapping_mul(0xDEAD_BEEF)
            .wrapping_add(seed);

        samples.extend((0..n_samples).map(|i| {
            // xorshift32 step: advance the per-channel noise stream.
            rng ^= rng << 13;
            rng ^= rng >> 17;
            rng ^= rng << 5;

            let t_sec = i as f32 / 64.0;
            let tone = (two_pi * freq_hz * t_sec).sin() * 50e-6;
            let jitter = (rng as f32 / u32::MAX as f32 - 0.5) * 10e-6;
            tone + jitter
        }));
    }

    samples
}
86
/// Run `f` exactly once and measure it.
///
/// Returns the closure's output paired with the elapsed wall-clock time in
/// milliseconds, as measured by [`Instant`].
fn timed<F, R>(f: F) -> (R, f64)
where
    F: FnOnce() -> R,
{
    let start = Instant::now();
    let output = f();
    let elapsed_ms = start.elapsed().as_secs_f64() * 1e3;
    (output, elapsed_ms)
}
94
95fn main() -> anyhow::Result<()> {
96    let args = Args::parse();
97    let dev = device();
98    let json_mode = args.json;
99
100    if !json_mode {
101        eprintln!("╔══════════════════════════════════════════════════════════════╗");
102        eprintln!("║  OSF-RS — Inference Benchmark                               ║");
103        eprintln!("╚══════════════════════════════════════════════════════════════╝\n");
104        eprintln!("  Backend: {}", backend::NAME);
105    }
106
107    // Load config
108    let model_cfg = if let Some(ref cfg_path) = args.config {
109        let s = std::fs::read_to_string(cfg_path)?;
110        serde_json::from_str(&s)?
111    } else {
112        osf_rs::ModelConfig::default()
113    };
114
115    // ── 1. Weight loading benchmark ─────────────────────────────────────────
116    let (encoder, ms_load) = osf_rs::OsfEncoder::<B>::load_with_config(
117        model_cfg.clone(),
118        Path::new(&args.weights),
119        dev.clone(),
120    )?;
121
122    if !json_mode {
123        eprintln!("  Model:   {}", encoder.describe());
124        eprintln!("  Load:    {ms_load:.0} ms\n");
125    }
126
127    let n_ch = osf_rs::NUM_PSG_CHANNELS;   // 12
128    let n_t  = osf_rs::EPOCH_SAMPLES;      // 1920
129
130    // ── 2. Standard inference benchmark (B=1) ───────────────────────────────
131    if !json_mode {
132        eprintln!("  ▸ Standard inference (B=1, {}ch × {} samples)", n_ch, n_t);
133    }
134
135    let signal = generate_psg(n_ch, n_t, 42);
136    let batch = osf_rs::build_batch::<B>(signal, n_ch, n_t, &dev);
137
138    // Warmup
139    for _ in 0..args.warmup {
140        let _ = encoder.run_batch(&batch)?;
141    }
142
143    // Timed runs
144    let mut infer_times = Vec::with_capacity(args.runs);
145    for _ in 0..args.runs {
146        let (_, ms) = timed(|| encoder.run_batch(&batch));
147        infer_times.push(ms);
148    }
149
150    let infer_mean = infer_times.iter().sum::<f64>() / infer_times.len() as f64;
151    let infer_min = infer_times.iter().cloned().fold(f64::INFINITY, f64::min);
152    let infer_max = infer_times.iter().cloned().fold(0.0f64, f64::max);
153    let infer_std = (infer_times.iter().map(|t| (t - infer_mean).powi(2)).sum::<f64>()
154        / infer_times.len() as f64).sqrt();
155
156    if !json_mode {
157        eprintln!("    mean={infer_mean:.1}ms  min={infer_min:.1}ms  max={infer_max:.1}ms  std={infer_std:.1}ms  (n={})",
158            args.runs);
159    }
160
161    // ── 3. Batch size scaling ───────────────────────────────────────────────
162    let batch_sizes = [1, 2, 4, 8, 16];
163    let mut batch_scaling: Vec<serde_json::Value> = Vec::new();
164
165    if !json_mode {
166        eprintln!("\n  ▸ Batch size scaling ({}ch × {} samples):", n_ch, n_t);
167        eprintln!("    {:>6}  {:>10}  {:>12}", "Batch", "Mean (ms)", "Per-epoch (ms)");
168    }
169
170    for &bs in &batch_sizes {
171        // Build a batched input: [bs, 12, 1920]
172        let signal_batch: Vec<f32> = (0..bs).flat_map(|i| generate_psg(n_ch, n_t, 42 + i as u32)).collect();
173        let signal_tensor = burn::tensor::Tensor::<B, 2>::from_data(
174            burn::tensor::TensorData::new(signal_batch, vec![bs * n_ch, n_t]),
175            &dev,
176        ).reshape([bs, n_ch, n_t]);
177
178        // Warmup
179        let _ = encoder.model().forward_encoding(signal_tensor.clone());
180
181        // Timed
182        let mut t_vec = Vec::new();
183        for _ in 0..5.max(args.runs / 2) {
184            let (_, ms) = timed(|| encoder.model().forward_encoding(signal_tensor.clone()));
185            t_vec.push(ms);
186        }
187        let avg = t_vec.iter().sum::<f64>() / t_vec.len() as f64;
188        let per_epoch = avg / bs as f64;
189        let bmin = t_vec.iter().cloned().fold(f64::INFINITY, f64::min);
190        let bmax = t_vec.iter().cloned().fold(0.0f64, f64::max);
191
192        if !json_mode {
193            eprintln!("    {:>6}  {:>7.1} ms  {:>9.1} ms", bs, avg, per_epoch);
194        }
195
196        batch_scaling.push(serde_json::json!({
197            "batch_size": bs,
198            "mean_ms": round2(avg),
199            "min_ms": round2(bmin),
200            "max_ms": round2(bmax),
201            "per_epoch_ms": round2(per_epoch),
202            "runs": t_vec,
203        }));
204    }
205
206    // ── 4. Throughput (epochs/sec) ──────────────────────────────────────────
207    let throughput_epochs_sec = 1000.0 / infer_mean;
208
209    if !json_mode {
210        eprintln!("\n  ▸ Throughput: {throughput_epochs_sec:.1} epochs/sec (single-epoch inference)");
211    }
212
213    // ── 5. Channel subset scaling (robustness test with padding) ────────────
214    let channel_subsets = [2, 4, 6, 8, 10, 12];
215    let mut channel_scaling: Vec<serde_json::Value> = Vec::new();
216
217    if !json_mode {
218        eprintln!("\n  ▸ Channel count scaling (zero-padded to 12ch, T={}):", n_t);
219        eprintln!("    {:>6}  {:>10}", "Active", "Mean (ms)");
220    }
221
222    for &active_ch in &channel_subsets {
223        // Create signal with `active_ch` active channels, rest zero-padded to 12
224        let mut sig = vec![0.0f32; n_ch * n_t];
225        let active = generate_psg(active_ch, n_t, 100 + active_ch as u32);
226        for ch in 0..active_ch {
227            for t in 0..n_t {
228                sig[ch * n_t + t] = active[ch * n_t + t];
229            }
230        }
231        let b = osf_rs::build_batch::<B>(sig, n_ch, n_t, &dev);
232
233        let _ = encoder.run_batch(&b)?; // warmup
234        let mut t_vec = Vec::new();
235        for _ in 0..5 {
236            let (_, ms) = timed(|| encoder.run_batch(&b));
237            t_vec.push(ms);
238        }
239        let avg = t_vec.iter().sum::<f64>() / t_vec.len() as f64;
240        let cmin = t_vec.iter().cloned().fold(f64::INFINITY, f64::min);
241        let cmax = t_vec.iter().cloned().fold(0.0f64, f64::max);
242
243        if !json_mode {
244            eprintln!("    {:>6}  {:>7.1} ms", active_ch, avg);
245        }
246
247        channel_scaling.push(serde_json::json!({
248            "active_channels": active_ch,
249            "total_channels": n_ch,
250            "mean_ms": round2(avg),
251            "min_ms": round2(cmin),
252            "max_ms": round2(cmax),
253            "runs": t_vec,
254        }));
255    }
256
257    if !json_mode { eprintln!(); }
258
259    // ── JSON output ─────────────────────────────────────────────────────────
260    let result = serde_json::json!({
261        "backend": backend::NAME,
262        "model": {
263            "encoder_name": model_cfg.encoder_name,
264            "width": model_cfg.width,
265            "depth": model_cfg.depth,
266            "heads": model_cfg.heads,
267            "num_leads": model_cfg.num_leads,
268            "patch_size_time": model_cfg.patch_size_time,
269            "patch_size_ch": model_cfg.patch_size_ch,
270            "seq_len": model_cfg.seq_len,
271        },
272        "load_ms": round2(ms_load),
273        "inference": {
274            "channels": n_ch,
275            "samples": n_t,
276            "warmup": args.warmup,
277            "runs": args.runs,
278            "mean_ms": round2(infer_mean),
279            "min_ms": round2(infer_min),
280            "max_ms": round2(infer_max),
281            "std_ms": round2(infer_std),
282            "all_ms": infer_times,
283        },
284        "batch_scaling": batch_scaling,
285        "channel_scaling": channel_scaling,
286        "throughput_epochs_sec": round2(throughput_epochs_sec),
287    });
288
289    if json_mode {
290        println!("{}", serde_json::to_string_pretty(&result)?);
291    }
292
293    Ok(())
294}
295
/// Round `v` to two decimal places (half-way cases round away from zero).
///
/// Used to keep the millisecond figures in the JSON report short; NaN and
/// infinities pass through unchanged.
fn round2(v: f64) -> f64 {
    const SCALE: f64 = 100.0;
    f64::round(v * SCALE) / SCALE
}