use std::path::Path;
use std::time::Instant;
use burn::prelude::*;
use clap::Parser;
/// GPU backend selection, compiled when the `wgpu` feature is on and `ndarray` is off.
///
/// Exposes the burn backend type as `B`, its device type as `Device`, a `device()`
/// constructor, and a human-readable `NAME` used in the benchmark report.
#[cfg(all(feature = "wgpu", not(feature = "ndarray")))]
mod backend {
    pub use burn::backend::{wgpu::WgpuDevice as Device, Wgpu as B};

    /// Returns the wgpu default device.
    pub fn device() -> Device {
        Device::DefaultDevice
    }

    // Exactly one `NAME` must be defined. Cargo features are additive, so `metal`
    // and `vulkan` can both end up enabled through feature unification; the cfgs
    // below are therefore mutually exclusive instead of producing a duplicate
    // definition error.
    // NOTE(review): when both are enabled the label says Metal — confirm this
    // matches the shader target burn actually selects in that configuration.
    #[cfg(feature = "metal")]
    pub const NAME: &str = "GPU (wgpu — Metal / MSL)";
    #[cfg(all(feature = "vulkan", not(feature = "metal")))]
    pub const NAME: &str = "GPU (wgpu — Vulkan / SPIR-V)";
    #[cfg(not(any(feature = "metal", feature = "vulkan")))]
    pub const NAME: &str = "GPU (wgpu — WGSL)";
}
/// CPU backend selection, compiled whenever the `ndarray` feature is on.
///
/// Exposes the burn backend type as `B`, its device type as `Device`, a `device()`
/// constructor, and a human-readable `NAME` used in the benchmark report.
#[cfg(feature = "ndarray")]
mod backend {
    pub use burn::backend::NdArray as B;
    pub type Device = burn::backend::ndarray::NdArrayDevice;

    /// Returns the CPU device.
    pub fn device() -> Device {
        Device::Cpu
    }

    // Exactly one `NAME` must be defined. Cargo features are additive, so both
    // BLAS features can be enabled at once; the cfgs below are mutually exclusive
    // so that configuration still compiles.
    // NOTE(review): when both are enabled the label says Accelerate — confirm
    // which BLAS is actually linked in that configuration.
    #[cfg(feature = "blas-accelerate")]
    pub const NAME: &str = "CPU (NdArray + Apple Accelerate)";
    #[cfg(all(feature = "openblas-system", not(feature = "blas-accelerate")))]
    pub const NAME: &str = "CPU (NdArray + OpenBLAS)";
    #[cfg(not(any(feature = "blas-accelerate", feature = "openblas-system")))]
    pub const NAME: &str = "CPU (NdArray + Rayon)";
}
use backend::{B, device};
// Command-line options for the benchmark binary.
// (Plain `//` comment on purpose: a `///` doc comment on the struct would feed
// clap's about/long-about text, which is set explicitly below.)
#[derive(Parser, Debug)]
#[command(about = "OSF — inference latency benchmark")]
struct Args {
    /// Path to the model weights to load
    #[arg(long)]
    weights: String,
    /// Optional path to a JSON model config; the default config is used when omitted
    #[arg(long)]
    config: Option<String>,
    /// Number of untimed warm-up inferences before the timed runs
    #[arg(long, default_value_t = 3)]
    warmup: usize,
    /// Number of timed single-epoch inference runs
    #[arg(long, default_value_t = 10)]
    runs: usize,
    /// Print results as JSON on stdout instead of the human-readable report on stderr
    #[arg(long, default_value_t = false)]
    json: bool,
}
/// Generates a deterministic synthetic multi-channel signal.
///
/// The result is a flat, channel-major buffer of `n_channels * n_samples` values:
/// sample `t` of channel `ch` lives at index `ch * n_samples + t`. Each channel is
/// a sine wave (amplitude `50e-6`, frequency `1.0 + 0.5 * ch + 0.01 * seed`, with
/// time computed as `t / 64`) plus xorshift32 pseudo-random noise in
/// `[-5e-6, 5e-6]`.
///
/// The same `(n_channels, n_samples, seed)` always yields the same output.
/// Returns an empty vector when either dimension is zero.
fn generate_psg(n_channels: usize, n_samples: usize, seed: u32) -> Vec<f32> {
    /// Divisor turning a sample index into the time value fed to the sine.
    const SAMPLE_RATE: f32 = 64.0;
    /// Replacement xorshift state used when the derived state would be zero.
    const NONZERO_FALLBACK_STATE: u32 = 0x9E37_79B9;

    let mut signal = vec![0.0f32; n_channels * n_samples];
    if n_samples == 0 {
        // `chunks_exact_mut(0)` panics; there is nothing to fill anyway.
        return signal;
    }

    for (ch, row) in signal.chunks_exact_mut(n_samples).enumerate() {
        let freq = 1.0 + ch as f32 * 0.5 + seed as f32 * 0.01;

        let mut noise_state: u32 = (ch as u32 + 1)
            .wrapping_mul(0xDEAD_BEEF)
            .wrapping_add(seed);
        // Zero is a fixed point of xorshift: the generator would emit 0 forever
        // and the "noise" would degenerate into a constant offset.
        if noise_state == 0 {
            noise_state = NONZERO_FALLBACK_STATE;
        }

        for (t, sample) in row.iter_mut().enumerate() {
            let time = t as f32 / SAMPLE_RATE;
            let sine = (2.0 * std::f32::consts::PI * freq * time).sin() * 50e-6;

            // xorshift32 step (13, 17, 5).
            noise_state ^= noise_state << 13;
            noise_state ^= noise_state >> 17;
            noise_state ^= noise_state << 5;
            let noise = (noise_state as f32 / u32::MAX as f32 - 0.5) * 10e-6;

            *sample = sine + noise;
        }
    }
    signal
}
/// Runs `f` once and returns its result together with the wall-clock time the
/// call took, in milliseconds.
fn timed<F, R>(f: F) -> (R, f64)
where
    F: FnOnce() -> R,
{
    let started = Instant::now();
    let output = f();
    let elapsed_ms = started.elapsed().as_secs_f64() * 1000.0;
    (output, elapsed_ms)
}
fn main() -> anyhow::Result<()> {
let args = Args::parse();
let dev = device();
let json_mode = args.json;
if !json_mode {
eprintln!("╔══════════════════════════════════════════════════════════════╗");
eprintln!("║ OSF-RS — Inference Benchmark ║");
eprintln!("╚══════════════════════════════════════════════════════════════╝\n");
eprintln!(" Backend: {}", backend::NAME);
}
let model_cfg = if let Some(ref cfg_path) = args.config {
let s = std::fs::read_to_string(cfg_path)?;
serde_json::from_str(&s)?
} else {
osf_rs::ModelConfig::default()
};
let (encoder, ms_load) = osf_rs::OsfEncoder::<B>::load_with_config(
model_cfg.clone(),
Path::new(&args.weights),
dev.clone(),
)?;
if !json_mode {
eprintln!(" Model: {}", encoder.describe());
eprintln!(" Load: {ms_load:.0} ms\n");
}
let n_ch = osf_rs::NUM_PSG_CHANNELS; let n_t = osf_rs::EPOCH_SAMPLES;
if !json_mode {
eprintln!(" ▸ Standard inference (B=1, {}ch × {} samples)", n_ch, n_t);
}
let signal = generate_psg(n_ch, n_t, 42);
let batch = osf_rs::build_batch::<B>(signal, n_ch, n_t, &dev);
for _ in 0..args.warmup {
let _ = encoder.run_batch(&batch)?;
}
let mut infer_times = Vec::with_capacity(args.runs);
for _ in 0..args.runs {
let (_, ms) = timed(|| encoder.run_batch(&batch));
infer_times.push(ms);
}
let infer_mean = infer_times.iter().sum::<f64>() / infer_times.len() as f64;
let infer_min = infer_times.iter().cloned().fold(f64::INFINITY, f64::min);
let infer_max = infer_times.iter().cloned().fold(0.0f64, f64::max);
let infer_std = (infer_times.iter().map(|t| (t - infer_mean).powi(2)).sum::<f64>()
/ infer_times.len() as f64).sqrt();
if !json_mode {
eprintln!(" mean={infer_mean:.1}ms min={infer_min:.1}ms max={infer_max:.1}ms std={infer_std:.1}ms (n={})",
args.runs);
}
let batch_sizes = [1, 2, 4, 8, 16];
let mut batch_scaling: Vec<serde_json::Value> = Vec::new();
if !json_mode {
eprintln!("\n ▸ Batch size scaling ({}ch × {} samples):", n_ch, n_t);
eprintln!(" {:>6} {:>10} {:>12}", "Batch", "Mean (ms)", "Per-epoch (ms)");
}
for &bs in &batch_sizes {
let signal_batch: Vec<f32> = (0..bs).flat_map(|i| generate_psg(n_ch, n_t, 42 + i as u32)).collect();
let signal_tensor = burn::tensor::Tensor::<B, 2>::from_data(
burn::tensor::TensorData::new(signal_batch, vec![bs * n_ch, n_t]),
&dev,
).reshape([bs, n_ch, n_t]);
let _ = encoder.model().forward_encoding(signal_tensor.clone());
let mut t_vec = Vec::new();
for _ in 0..5.max(args.runs / 2) {
let (_, ms) = timed(|| encoder.model().forward_encoding(signal_tensor.clone()));
t_vec.push(ms);
}
let avg = t_vec.iter().sum::<f64>() / t_vec.len() as f64;
let per_epoch = avg / bs as f64;
let bmin = t_vec.iter().cloned().fold(f64::INFINITY, f64::min);
let bmax = t_vec.iter().cloned().fold(0.0f64, f64::max);
if !json_mode {
eprintln!(" {:>6} {:>7.1} ms {:>9.1} ms", bs, avg, per_epoch);
}
batch_scaling.push(serde_json::json!({
"batch_size": bs,
"mean_ms": round2(avg),
"min_ms": round2(bmin),
"max_ms": round2(bmax),
"per_epoch_ms": round2(per_epoch),
"runs": t_vec,
}));
}
let throughput_epochs_sec = 1000.0 / infer_mean;
if !json_mode {
eprintln!("\n ▸ Throughput: {throughput_epochs_sec:.1} epochs/sec (single-epoch inference)");
}
let channel_subsets = [2, 4, 6, 8, 10, 12];
let mut channel_scaling: Vec<serde_json::Value> = Vec::new();
if !json_mode {
eprintln!("\n ▸ Channel count scaling (zero-padded to 12ch, T={}):", n_t);
eprintln!(" {:>6} {:>10}", "Active", "Mean (ms)");
}
for &active_ch in &channel_subsets {
let mut sig = vec![0.0f32; n_ch * n_t];
let active = generate_psg(active_ch, n_t, 100 + active_ch as u32);
for ch in 0..active_ch {
for t in 0..n_t {
sig[ch * n_t + t] = active[ch * n_t + t];
}
}
let b = osf_rs::build_batch::<B>(sig, n_ch, n_t, &dev);
let _ = encoder.run_batch(&b)?; let mut t_vec = Vec::new();
for _ in 0..5 {
let (_, ms) = timed(|| encoder.run_batch(&b));
t_vec.push(ms);
}
let avg = t_vec.iter().sum::<f64>() / t_vec.len() as f64;
let cmin = t_vec.iter().cloned().fold(f64::INFINITY, f64::min);
let cmax = t_vec.iter().cloned().fold(0.0f64, f64::max);
if !json_mode {
eprintln!(" {:>6} {:>7.1} ms", active_ch, avg);
}
channel_scaling.push(serde_json::json!({
"active_channels": active_ch,
"total_channels": n_ch,
"mean_ms": round2(avg),
"min_ms": round2(cmin),
"max_ms": round2(cmax),
"runs": t_vec,
}));
}
if !json_mode { eprintln!(); }
let result = serde_json::json!({
"backend": backend::NAME,
"model": {
"encoder_name": model_cfg.encoder_name,
"width": model_cfg.width,
"depth": model_cfg.depth,
"heads": model_cfg.heads,
"num_leads": model_cfg.num_leads,
"patch_size_time": model_cfg.patch_size_time,
"patch_size_ch": model_cfg.patch_size_ch,
"seq_len": model_cfg.seq_len,
},
"load_ms": round2(ms_load),
"inference": {
"channels": n_ch,
"samples": n_t,
"warmup": args.warmup,
"runs": args.runs,
"mean_ms": round2(infer_mean),
"min_ms": round2(infer_min),
"max_ms": round2(infer_max),
"std_ms": round2(infer_std),
"all_ms": infer_times,
},
"batch_scaling": batch_scaling,
"channel_scaling": channel_scaling,
"throughput_epochs_sec": round2(throughput_epochs_sec),
});
if json_mode {
println!("{}", serde_json::to_string_pretty(&result)?);
}
Ok(())
}
/// Rounds `v` to two decimal places (half away from zero, as `f64::round` does).
fn round2(v: f64) -> f64 {
    const SCALE: f64 = 100.0;
    let scaled = v * SCALE;
    scaled.round() / SCALE
}