//! seizuretransformer 0.0.1
//!
//! SeizureTransformer EEG model in Rust (Burn + wgpu).
use std::{fs, path::PathBuf, time::Instant};

use clap::Parser;
use seizuretransformer::{load_model_from_file, SeizureTransformer, SeizureTransformerConfig};

use burn::prelude::*;

#[cfg(all(feature = "wgpu", not(feature = "ndarray")))]
mod backend {
    //! GPU backend selection via wgpu.
    pub use burn::backend::wgpu::WgpuDevice as Device;
    pub use burn::backend::Wgpu as B;

    /// Returns the wgpu default device.
    pub fn device() -> Device {
        Device::DefaultDevice
    }

    // Human-readable backend label; the shader-language feature flags pick
    // exactly one of the three definitions at compile time.
    #[cfg(feature = "metal")]
    pub const NAME: &str = "GPU (wgpu Metal)";
    #[cfg(feature = "vulkan")]
    pub const NAME: &str = "GPU (wgpu Vulkan)";
    #[cfg(not(any(feature = "metal", feature = "vulkan")))]
    pub const NAME: &str = "GPU (wgpu WGSL)";
}

#[cfg(feature = "ndarray")]
mod backend {
    //! CPU fallback backend via NdArray.
    pub use burn::backend::ndarray::NdArrayDevice as Device;
    pub use burn::backend::NdArray as B;

    /// Returns the (only) CPU device.
    pub fn device() -> Device {
        Device::Cpu
    }

    // Human-readable backend label.
    pub const NAME: &str = "CPU (NdArray)";
}

use backend::{device, B};

// Command-line arguments. NOTE(review): plain `//` comments are used on the
// fields on purpose — clap-derive turns `///` doc comments into `--help`
// text, which would change the program's observable output.
#[derive(Parser, Debug)]
#[command(about = "SeizureTransformer inference (Burn 0.20)")]
struct Args {
    // Optional path to a JSON-serialized SeizureTransformerConfig;
    // falls back to SeizureTransformerConfig::default() when omitted.
    #[arg(long)]
    config: Option<PathBuf>,

    // Optional path to a weights file; a freshly initialized model is
    // used when omitted.
    #[arg(long)]
    weights: Option<PathBuf>,

    // Batch dimension of the dummy input tensor.
    #[arg(long, default_value_t = 1)]
    batch_size: usize,

    // Number of untimed warmup forward passes before measurement.
    #[arg(long, default_value_t = 2)]
    warmup: usize,

    // Number of timed forward passes; the reported latency is the mean.
    #[arg(long, default_value_t = 10)]
    iters: usize,
}

/// Entry point: builds (or loads) the model, runs `warmup` untimed forward
/// passes followed by `iters` timed ones on a zero-filled input, and prints
/// the model load time and mean per-iteration inference latency.
///
/// # Errors
/// Returns an error when the config/weights files cannot be read or parsed,
/// or when `--iters` is 0.
fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    // Fix: with `--iters 0` the old code still ran one timed forward pass
    // and then divided by zero, reporting an inf/NaN average. Reject early.
    anyhow::ensure!(args.iters >= 1, "--iters must be at least 1");

    let dev = device();

    // Load config from JSON if a path was given, otherwise use defaults.
    let cfg = if let Some(path) = args.config {
        let txt = fs::read_to_string(&path)?;
        serde_json::from_str::<SeizureTransformerConfig>(&txt)?
    } else {
        SeizureTransformerConfig::default()
    };

    println!("Backend: {}", backend::NAME);
    println!("Config : {:?}", cfg);

    // Either deserialize trained weights or initialize a fresh model,
    // timing whichever path is taken.
    let t0 = Instant::now();
    let model: SeizureTransformer<B> = if let Some(weights) = args.weights {
        load_model_from_file::<B>(&cfg, &weights, &dev)?
    } else {
        SeizureTransformer::new(&cfg, &dev)
    };
    let load_ms = t0.elapsed().as_secs_f64() * 1000.0;

    // Dummy input: only the shape matters for timing, so zeros suffice.
    let x = Tensor::<B, 3>::zeros([args.batch_size, cfg.in_channels, cfg.in_samples], &dev);

    // Untimed warmup passes (presumably amortizing first-run costs such as
    // shader compilation on the wgpu backend — backend-dependent).
    for _ in 0..args.warmup {
        let _ = model.forward(x.clone());
    }

    // Timed section: exactly `iters` forward passes (iters >= 1 guaranteed
    // by the guard above), keeping the last output for the shape printout.
    let t1 = Instant::now();
    let mut y = model.forward(x.clone());
    for _ in 1..args.iters {
        y = model.forward(x.clone());
    }
    let infer_total_ms = t1.elapsed().as_secs_f64() * 1000.0;
    let infer_avg_ms = infer_total_ms / args.iters as f64;

    println!("Output shape: {:?}", y.dims());
    println!(
        "Load: {:.1} ms | Infer(avg {} iters, {} warmup): {:.1} ms",
        load_ms, args.iters, args.warmup, infer_avg_ms
    );
    Ok(())
}