seizuretransformer 0.0.1

SeizureTransformer EEG model in Rust (Burn + wgpu)
Documentation
use std::{fs, path::PathBuf};

use anyhow::Context;
use burn::prelude::*;
use clap::Parser;
use seizuretransformer::{load_model_from_file, SeizureTransformer, SeizureTransformerConfig};

#[cfg(all(feature = "wgpu", not(feature = "ndarray")))]
mod backend {
    pub use burn::backend::{wgpu::WgpuDevice as Device, Wgpu as B};
    pub fn device() -> Device {
        Device::DefaultDevice
    }
}

/// CPU backend via `ndarray`; chosen whenever the `ndarray` feature is enabled,
/// even if `wgpu` is enabled as well.
#[cfg(feature = "ndarray")]
mod backend {
    pub use burn::backend::{ndarray::NdArrayDevice as Device, NdArray as B};

    /// Device used for model construction and inference: the ndarray `Cpu` device.
    pub fn device() -> Device {
        Device::Cpu
    }
}

use backend::{device, B};

/// Run one SeizureTransformer forward pass over a raw `f32` input file.
#[derive(Parser, Debug)]
struct Args {
    /// Model configuration file (JSON for `SeizureTransformerConfig`).
    #[arg(long)]
    config: PathBuf,
    /// Input file: raw native-endian f32 values, batch * channels * samples of them.
    #[arg(long)]
    input: PathBuf,
    /// Output file: the model output written as raw native-endian f32 values.
    #[arg(long)]
    output: PathBuf,
    /// Batch size (first input dimension).
    #[arg(long)]
    batch: usize,
    /// Number of channels (second input dimension).
    #[arg(long)]
    channels: usize,
    /// Number of samples per channel (third input dimension).
    #[arg(long)]
    samples: usize,
    /// Optional weights file; if omitted the model is built with `SeizureTransformer::new`.
    #[arg(long)]
    weights: Option<PathBuf>,
}

fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    let dev = device();

    let cfg_txt = fs::read_to_string(&args.config)?;
    let cfg: SeizureTransformerConfig = serde_json::from_str(&cfg_txt)?;

    let model: SeizureTransformer<B> = if let Some(weights) = args.weights.as_ref() {
        load_model_from_file::<B>(&cfg, weights, &dev)?
    } else {
        SeizureTransformer::new(&cfg, &dev)
    };

    let raw = fs::read(&args.input)?;
    let floats: Vec<f32> = bytemuck::cast_slice::<u8, f32>(&raw).to_vec();
    let expected = args.batch * args.channels * args.samples;
    anyhow::ensure!(
        floats.len() == expected,
        "input length mismatch: got {}, expected {}",
        floats.len(),
        expected
    );

    let x = Tensor::<B, 3>::from_data(
        TensorData::new(floats, [args.batch, args.channels, args.samples]),
        &dev,
    );
    let y = model.forward(x);
    let out = y.into_data().to_vec::<f32>()?;
    fs::write(&args.output, bytemuck::cast_slice(&out))?;
    Ok(())
}