use std::{fs, path::PathBuf, time::Instant};
use clap::Parser;
use seizuretransformer::{load_model_from_file, SeizureTransformer, SeizureTransformerConfig};
use burn::prelude::*;
#[cfg(all(feature = "wgpu", not(feature = "ndarray")))]
mod backend {
    //! GPU backend via wgpu; compiled out when the `ndarray` feature is on.
    //! Mirrors the ndarray module's shape (`type Device` alias) for consistency.
    pub use burn::backend::Wgpu as B;
    pub type Device = burn::backend::wgpu::WgpuDevice;

    /// Returns the default wgpu device (the backend selects an adapter).
    pub fn device() -> Device {
        Device::DefaultDevice
    }

    // Human-readable backend label; resolved at compile time from the
    // graphics-API feature flags, falling back to the portable WGSL path.
    #[cfg(feature = "metal")]
    pub const NAME: &str = "GPU (wgpu Metal)";
    #[cfg(feature = "vulkan")]
    pub const NAME: &str = "GPU (wgpu Vulkan)";
    #[cfg(not(any(feature = "metal", feature = "vulkan")))]
    pub const NAME: &str = "GPU (wgpu WGSL)";
}
#[cfg(feature = "ndarray")]
mod backend {
    //! Pure-Rust CPU backend used when the `ndarray` feature is enabled.
    pub use burn::backend::ndarray::NdArrayDevice as Device;
    pub use burn::backend::NdArray as B;

    /// The NdArray backend exposes a single device: the host CPU.
    pub fn device() -> Device {
        Device::Cpu
    }

    /// Human-readable backend label for log output.
    pub const NAME: &str = "CPU (NdArray)";
}
use backend::{device, B};
/// Command-line options for the benchmark binary.
///
/// The `///` comments on each field double as the `--help` text that
/// clap's derive macro generates.
#[derive(Parser, Debug)]
#[command(about = "SeizureTransformer inference (Burn 0.20)")]
struct Args {
    /// Path to a JSON model config; omit to use the built-in default config.
    #[arg(long)]
    config: Option<PathBuf>,
    /// Path to a weights file; omit to use a freshly initialized model.
    #[arg(long)]
    weights: Option<PathBuf>,
    /// Number of inputs per forward pass.
    #[arg(long, default_value_t = 1)]
    batch_size: usize,
    /// Untimed forward passes run before measurement begins.
    #[arg(long, default_value_t = 2)]
    warmup: usize,
    /// Timed forward passes to average over.
    #[arg(long, default_value_t = 10)]
    iters: usize,
}
/// Entry point: builds or loads the model, runs untimed warmup passes,
/// then times `iters` forward passes and reports the average latency.
fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    let dev = device();

    // Load config from a JSON file if one was given, else use the default.
    let cfg = if let Some(path) = args.config {
        let txt = fs::read_to_string(&path)?;
        serde_json::from_str::<SeizureTransformerConfig>(&txt)?
    } else {
        SeizureTransformerConfig::default()
    };
    println!("Backend: {}", backend::NAME);
    println!("Config : {:?}", cfg);

    // Time model construction / weight loading separately from inference.
    let t0 = Instant::now();
    let model: SeizureTransformer<B> = if let Some(weights) = args.weights {
        load_model_from_file::<B>(&cfg, &weights, &dev)?
    } else {
        SeizureTransformer::new(&cfg, &dev)
    };
    let load_ms = t0.elapsed().as_secs_f64() * 1000.0;

    let x = Tensor::<B, 3>::zeros([args.batch_size, cfg.in_channels, cfg.in_samples], &dev);

    // Untimed warmup passes (kernel compilation, caches, allocator warmup).
    for _ in 0..args.warmup {
        let _ = model.forward(x.clone());
    }

    // Clamp to at least one iteration: the loop below always runs the model
    // once, and `iters == 0` would otherwise divide by zero in the average.
    let iters = args.iters.max(1);
    // NOTE(review): on lazy/async backends (e.g. wgpu) work may still be
    // queued when the timer stops — confirm whether an explicit sync is
    // needed here for accurate wall-clock numbers.
    let t1 = Instant::now();
    let mut y = model.forward(x.clone());
    for _ in 1..iters {
        y = model.forward(x.clone());
    }
    let infer_total_ms = t1.elapsed().as_secs_f64() * 1000.0;
    let infer_avg_ms = infer_total_ms / iters as f64;

    println!("Output shape: {:?}", y.dims());
    println!(
        "Load: {:.1} ms | Infer(avg {} iters, {} warmup): {:.1} ms",
        load_ms, iters, args.warmup, infer_avg_ms
    );
    Ok(())
}