use std::{fs, path::PathBuf};
use anyhow::Context;
use burn::prelude::*;
use clap::Parser;
use seizuretransformer::{load_model_from_file, SeizureTransformer, SeizureTransformerConfig};
#[cfg(all(feature = "wgpu", not(feature = "ndarray")))]
mod backend {
    //! wgpu backend selection; only compiled when `wgpu` is enabled and `ndarray` is not.

    pub use burn::backend::Wgpu as B;

    /// Device handle type used by the wgpu backend.
    pub type Device = burn::backend::wgpu::WgpuDevice;

    /// Returns the `DefaultDevice` variant of the wgpu device type.
    pub fn device() -> Device {
        Device::DefaultDevice
    }
}
#[cfg(feature = "ndarray")]
mod backend {
    //! ndarray backend selection; compiled whenever the `ndarray` feature is enabled.

    use burn::backend::ndarray::NdArrayDevice;

    pub use burn::backend::NdArray as B;

    /// Device handle type used by the ndarray backend.
    pub type Device = NdArrayDevice;

    /// Returns the `Cpu` variant of the ndarray device type.
    pub fn device() -> Device {
        NdArrayDevice::Cpu
    }
}
use backend::{device, B};
/// Run a SeizureTransformer forward pass over a raw `f32` tensor file.
#[derive(Parser, Debug)]
struct Args {
    /// Path to the model configuration, a JSON document describing a `SeizureTransformerConfig`.
    #[arg(long)]
    config: PathBuf,
    /// Path to the input tensor: raw `f32` values in native byte order, laid out as
    /// `[batch, channels, samples]`.
    #[arg(long)]
    input: PathBuf,
    /// Path the model output is written to, as raw `f32` values in native byte order.
    #[arg(long)]
    output: PathBuf,
    /// Batch dimension of the input tensor.
    #[arg(long)]
    batch: usize,
    /// Channel dimension of the input tensor.
    #[arg(long)]
    channels: usize,
    /// Sample (time) dimension of the input tensor.
    #[arg(long)]
    samples: usize,
    /// Optional weights file passed to `load_model_from_file`; when omitted the model is
    /// built with `SeizureTransformer::new` without loading any weights file.
    #[arg(long)]
    weights: Option<PathBuf>,
}
fn main() -> anyhow::Result<()> {
let args = Args::parse();
let dev = device();
let cfg_txt = fs::read_to_string(&args.config)?;
let cfg: SeizureTransformerConfig = serde_json::from_str(&cfg_txt)?;
let model: SeizureTransformer<B> = if let Some(weights) = args.weights.as_ref() {
load_model_from_file::<B>(&cfg, weights, &dev)?
} else {
SeizureTransformer::new(&cfg, &dev)
};
let raw = fs::read(&args.input)?;
let floats: Vec<f32> = bytemuck::cast_slice::<u8, f32>(&raw).to_vec();
let expected = args.batch * args.channels * args.samples;
anyhow::ensure!(
floats.len() == expected,
"input length mismatch: got {}, expected {}",
floats.len(),
expected
);
let x = Tensor::<B, 3>::from_data(
TensorData::new(floats, [args.batch, args.channels, args.samples]),
&dev,
);
let y = model.forward(x);
let out = y.into_data().to_vec::<f32>()?;
fs::write(&args.output, bytemuck::cast_slice(&out))?;
Ok(())
}