1use std::path::Path;
15use std::time::Instant;
16use burn::prelude::*;
17use clap::Parser;
18
// GPU backend: compiled when the `wgpu` feature is enabled and `ndarray` is not.
#[cfg(all(feature = "wgpu", not(feature = "ndarray")))]
mod backend {
    pub use burn::backend::{wgpu::WgpuDevice as Device, Wgpu as B};

    /// Returns the device the benchmark runs on (the wgpu default device).
    pub fn device() -> Device {
        Device::DefaultDevice
    }

    /// Human-readable backend label shown in the report and the JSON output.
    ///
    /// Selected with a single `cfg!` chain rather than one `const` per `#[cfg]`,
    /// so enabling both `metal` and `vulkan` (Cargo features are additive) no
    /// longer produces two conflicting definitions of `NAME`.
    /// NOTE(review): when both features are on, the label says Metal — confirm
    /// this matches the shader backend that is actually used in that case.
    pub const NAME: &str = if cfg!(feature = "metal") {
        "GPU (wgpu — Metal / MSL)"
    } else if cfg!(feature = "vulkan") {
        "GPU (wgpu — Vulkan / SPIR-V)"
    } else {
        "GPU (wgpu — WGSL)"
    };
}
31
// CPU backend: compiled whenever the `ndarray` feature is enabled.
#[cfg(feature = "ndarray")]
mod backend {
    pub use burn::backend::NdArray as B;

    /// Device type of the NdArray backend.
    pub type Device = burn::backend::ndarray::NdArrayDevice;

    /// Returns the device the benchmark runs on (the CPU).
    pub fn device() -> Device {
        Device::Cpu
    }

    /// Human-readable backend label shown in the report and the JSON output.
    ///
    /// Selected with a single `cfg!` chain rather than one `const` per `#[cfg]`,
    /// so enabling both `blas-accelerate` and `openblas-system` (Cargo features
    /// are additive) no longer produces two conflicting definitions of `NAME`.
    /// NOTE(review): when both features are on, the label says Accelerate —
    /// confirm this matches the BLAS implementation that is actually linked.
    pub const NAME: &str = if cfg!(feature = "blas-accelerate") {
        "CPU (NdArray + Apple Accelerate)"
    } else if cfg!(feature = "openblas-system") {
        "CPU (NdArray + OpenBLAS)"
    } else {
        "CPU (NdArray + Rayon)"
    };
}
44
45use backend::{B, device};
46
47#[derive(Parser, Debug)]
49#[command(about = "OSF — inference latency benchmark")]
50struct Args {
51 #[arg(long)]
53 weights: String,
54 #[arg(long)]
56 config: Option<String>,
57 #[arg(long, default_value_t = 3)]
59 warmup: usize,
60 #[arg(long, default_value_t = 10)]
62 runs: usize,
63 #[arg(long, default_value_t = false)]
65 json: bool,
66}
67
/// Builds a deterministic synthetic multi-channel signal.
///
/// The result holds `n_channels * n_samples` values in channel-major order:
/// sample `t` of channel `ch` lives at index `ch * n_samples + t`. Every channel
/// is a sine wave (frequency depends on the channel index and `seed`) plus
/// xorshift pseudo-random noise, so identical arguments always give identical
/// output.
fn generate_psg(n_channels: usize, n_samples: usize, seed: u32) -> Vec<f32> {
    // The sample index is divided by this to obtain the time axis
    // (presumably a 64 Hz sampling rate — TODO confirm).
    const TIME_SCALE: f32 = 64.0;
    const SINE_AMPLITUDE: f32 = 50e-6;
    const NOISE_SPAN: f32 = 10e-6;

    let mut samples = Vec::with_capacity(n_channels * n_samples);
    for channel in 0..n_channels {
        let freq = 1.0 + channel as f32 * 0.5 + seed as f32 * 0.01;
        let omega = 2.0 * std::f32::consts::PI * freq;
        // Per-channel xorshift32 state, derived from the channel index and seed.
        let mut rng = (channel as u32 + 1)
            .wrapping_mul(0xDEAD_BEEF)
            .wrapping_add(seed);

        samples.extend((0..n_samples).map(|step| {
            let tone = (omega * (step as f32 / TIME_SCALE)).sin() * SINE_AMPLITUDE;

            // Advance the generator by one xorshift32 step.
            rng ^= rng << 13;
            rng ^= rng >> 17;
            rng ^= rng << 5;
            // Map the state to [-0.5, 0.5] and scale it to the noise span.
            let jitter = (rng as f32 / u32::MAX as f32 - 0.5) * NOISE_SPAN;

            tone + jitter
        }));
    }
    samples
}
86
/// Runs `f` once and returns its result together with the wall-clock time the
/// call took, in milliseconds.
fn timed<F: FnOnce() -> R, R>(f: F) -> (R, f64) {
    let started = Instant::now();
    let output = f();
    let elapsed_ms = started.elapsed().as_secs_f64() * 1000.0;
    (output, elapsed_ms)
}
94
95fn main() -> anyhow::Result<()> {
96 let args = Args::parse();
97 let dev = device();
98 let json_mode = args.json;
99
100 if !json_mode {
101 eprintln!("╔══════════════════════════════════════════════════════════════╗");
102 eprintln!("║ OSF-RS — Inference Benchmark ║");
103 eprintln!("╚══════════════════════════════════════════════════════════════╝\n");
104 eprintln!(" Backend: {}", backend::NAME);
105 }
106
107 let model_cfg = if let Some(ref cfg_path) = args.config {
109 let s = std::fs::read_to_string(cfg_path)?;
110 serde_json::from_str(&s)?
111 } else {
112 osf_rs::ModelConfig::default()
113 };
114
115 let (encoder, ms_load) = osf_rs::OsfEncoder::<B>::load_with_config(
117 model_cfg.clone(),
118 Path::new(&args.weights),
119 dev.clone(),
120 )?;
121
122 if !json_mode {
123 eprintln!(" Model: {}", encoder.describe());
124 eprintln!(" Load: {ms_load:.0} ms\n");
125 }
126
127 let n_ch = osf_rs::NUM_PSG_CHANNELS; let n_t = osf_rs::EPOCH_SAMPLES; if !json_mode {
132 eprintln!(" ▸ Standard inference (B=1, {}ch × {} samples)", n_ch, n_t);
133 }
134
135 let signal = generate_psg(n_ch, n_t, 42);
136 let batch = osf_rs::build_batch::<B>(signal, n_ch, n_t, &dev);
137
138 for _ in 0..args.warmup {
140 let _ = encoder.run_batch(&batch)?;
141 }
142
143 let mut infer_times = Vec::with_capacity(args.runs);
145 for _ in 0..args.runs {
146 let (_, ms) = timed(|| encoder.run_batch(&batch));
147 infer_times.push(ms);
148 }
149
150 let infer_mean = infer_times.iter().sum::<f64>() / infer_times.len() as f64;
151 let infer_min = infer_times.iter().cloned().fold(f64::INFINITY, f64::min);
152 let infer_max = infer_times.iter().cloned().fold(0.0f64, f64::max);
153 let infer_std = (infer_times.iter().map(|t| (t - infer_mean).powi(2)).sum::<f64>()
154 / infer_times.len() as f64).sqrt();
155
156 if !json_mode {
157 eprintln!(" mean={infer_mean:.1}ms min={infer_min:.1}ms max={infer_max:.1}ms std={infer_std:.1}ms (n={})",
158 args.runs);
159 }
160
161 let batch_sizes = [1, 2, 4, 8, 16];
163 let mut batch_scaling: Vec<serde_json::Value> = Vec::new();
164
165 if !json_mode {
166 eprintln!("\n ▸ Batch size scaling ({}ch × {} samples):", n_ch, n_t);
167 eprintln!(" {:>6} {:>10} {:>12}", "Batch", "Mean (ms)", "Per-epoch (ms)");
168 }
169
170 for &bs in &batch_sizes {
171 let signal_batch: Vec<f32> = (0..bs).flat_map(|i| generate_psg(n_ch, n_t, 42 + i as u32)).collect();
173 let signal_tensor = burn::tensor::Tensor::<B, 2>::from_data(
174 burn::tensor::TensorData::new(signal_batch, vec![bs * n_ch, n_t]),
175 &dev,
176 ).reshape([bs, n_ch, n_t]);
177
178 let _ = encoder.model().forward_encoding(signal_tensor.clone());
180
181 let mut t_vec = Vec::new();
183 for _ in 0..5.max(args.runs / 2) {
184 let (_, ms) = timed(|| encoder.model().forward_encoding(signal_tensor.clone()));
185 t_vec.push(ms);
186 }
187 let avg = t_vec.iter().sum::<f64>() / t_vec.len() as f64;
188 let per_epoch = avg / bs as f64;
189 let bmin = t_vec.iter().cloned().fold(f64::INFINITY, f64::min);
190 let bmax = t_vec.iter().cloned().fold(0.0f64, f64::max);
191
192 if !json_mode {
193 eprintln!(" {:>6} {:>7.1} ms {:>9.1} ms", bs, avg, per_epoch);
194 }
195
196 batch_scaling.push(serde_json::json!({
197 "batch_size": bs,
198 "mean_ms": round2(avg),
199 "min_ms": round2(bmin),
200 "max_ms": round2(bmax),
201 "per_epoch_ms": round2(per_epoch),
202 "runs": t_vec,
203 }));
204 }
205
206 let throughput_epochs_sec = 1000.0 / infer_mean;
208
209 if !json_mode {
210 eprintln!("\n ▸ Throughput: {throughput_epochs_sec:.1} epochs/sec (single-epoch inference)");
211 }
212
213 let channel_subsets = [2, 4, 6, 8, 10, 12];
215 let mut channel_scaling: Vec<serde_json::Value> = Vec::new();
216
217 if !json_mode {
218 eprintln!("\n ▸ Channel count scaling (zero-padded to 12ch, T={}):", n_t);
219 eprintln!(" {:>6} {:>10}", "Active", "Mean (ms)");
220 }
221
222 for &active_ch in &channel_subsets {
223 let mut sig = vec![0.0f32; n_ch * n_t];
225 let active = generate_psg(active_ch, n_t, 100 + active_ch as u32);
226 for ch in 0..active_ch {
227 for t in 0..n_t {
228 sig[ch * n_t + t] = active[ch * n_t + t];
229 }
230 }
231 let b = osf_rs::build_batch::<B>(sig, n_ch, n_t, &dev);
232
233 let _ = encoder.run_batch(&b)?; let mut t_vec = Vec::new();
235 for _ in 0..5 {
236 let (_, ms) = timed(|| encoder.run_batch(&b));
237 t_vec.push(ms);
238 }
239 let avg = t_vec.iter().sum::<f64>() / t_vec.len() as f64;
240 let cmin = t_vec.iter().cloned().fold(f64::INFINITY, f64::min);
241 let cmax = t_vec.iter().cloned().fold(0.0f64, f64::max);
242
243 if !json_mode {
244 eprintln!(" {:>6} {:>7.1} ms", active_ch, avg);
245 }
246
247 channel_scaling.push(serde_json::json!({
248 "active_channels": active_ch,
249 "total_channels": n_ch,
250 "mean_ms": round2(avg),
251 "min_ms": round2(cmin),
252 "max_ms": round2(cmax),
253 "runs": t_vec,
254 }));
255 }
256
257 if !json_mode { eprintln!(); }
258
259 let result = serde_json::json!({
261 "backend": backend::NAME,
262 "model": {
263 "encoder_name": model_cfg.encoder_name,
264 "width": model_cfg.width,
265 "depth": model_cfg.depth,
266 "heads": model_cfg.heads,
267 "num_leads": model_cfg.num_leads,
268 "patch_size_time": model_cfg.patch_size_time,
269 "patch_size_ch": model_cfg.patch_size_ch,
270 "seq_len": model_cfg.seq_len,
271 },
272 "load_ms": round2(ms_load),
273 "inference": {
274 "channels": n_ch,
275 "samples": n_t,
276 "warmup": args.warmup,
277 "runs": args.runs,
278 "mean_ms": round2(infer_mean),
279 "min_ms": round2(infer_min),
280 "max_ms": round2(infer_max),
281 "std_ms": round2(infer_std),
282 "all_ms": infer_times,
283 },
284 "batch_scaling": batch_scaling,
285 "channel_scaling": channel_scaling,
286 "throughput_epochs_sec": round2(throughput_epochs_sec),
287 });
288
289 if json_mode {
290 println!("{}", serde_json::to_string_pretty(&result)?);
291 }
292
293 Ok(())
294}
295
/// Rounds `v` to two decimal places (nearest hundredth).
fn round2(v: f64) -> f64 {
    const SCALE: f64 = 100.0;
    let shifted = v * SCALE;
    shifted.round() / SCALE
}