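//! eegpt-rs 0.0.1
//!
//! EEGPT EEG foundation model — inference in Rust with Burn ML.
//!
//! Documentation example: builds an EEGPT model on the selected backend and times a
//! single forward pass on a constant dummy input.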
use burn::prelude::*;
use std::time::Instant;

/// CPU backend built on `ndarray`.
///
/// This variant takes precedence: the wgpu variant is only compiled when the
/// `ndarray` feature is off.
#[cfg(feature = "ndarray")]
mod backend {
    use burn::backend::ndarray::NdArrayDevice;

    // Re-exported so the rest of the file can name the backend generically as `B`.
    pub use burn::backend::NdArray as B;

    /// Label printed at startup to identify the active backend.
    pub const NAME: &str = "CPU (NdArray)";

    /// Returns the CPU device on which all tensors are allocated.
    pub fn device() -> NdArrayDevice {
        NdArrayDevice::Cpu
    }
}
/// GPU backend built on `wgpu`.
///
/// Only compiled when `wgpu` is enabled and `ndarray` is not.
/// NOTE(review): if neither feature is enabled, no `backend` module exists and the
/// `use backend::...` import fails to compile.
#[cfg(all(feature = "wgpu", not(feature = "ndarray")))]
mod backend {
    use burn::backend::wgpu::WgpuDevice;

    // Re-exported so the rest of the file can name the backend generically as `B`.
    pub use burn::backend::Wgpu as B;

    /// Label printed at startup to identify the active backend.
    pub const NAME: &str = "GPU (wgpu)";

    /// Returns the `DefaultDevice` wgpu device.
    pub fn device() -> WgpuDevice {
        WgpuDevice::DefaultDevice
    }
}
use backend::{B, device};

/// Builds an EEGPT model on the selected backend and runs one forward pass on a
/// constant dummy input.
///
/// Prints the backend name, the output shape, and the wall-clock latency of the pass.
///
/// # Errors
///
/// Returns an error if the channel count cannot be represented as an `i64` channel id.
fn main() -> anyhow::Result<()> {
    let dev = device();
    println!("Backend: {}", backend::NAME);

    // Dummy recording geometry: 22 channels x 1000 time samples.
    let n_chans: usize = 22;
    let n_times: usize = 1000;

    // NOTE(review): the remaining positional arguments are EEGPT hyper-parameters
    // passed through verbatim; see `EEGPT::new` for the meaning of each one.
    let model = eegpt_rs::model::eegpt::EEGPT::<B>::new(
        4, n_chans, n_times, 64, 32, 4, 512, 2, 8, 4.0, true, 62, 16, 1e-6, &dev,
    );

    // Constant input of shape [1, n_chans, n_times]; every sample equals 0.1.
    let x = Tensor::<B, 3>::ones([1, n_chans, n_times], &dev).mul_scalar(0.1f32);

    // Channel ids 0..n_chans, shaped [1, n_chans] to match the single batch entry.
    // Checked conversion instead of a silent `as` cast.
    let ids: Vec<i64> = (0..i64::try_from(n_chans)?).collect();
    let chan_ids = Tensor::<B, 2, Int>::from_data(TensorData::new(ids, vec![1, n_chans]), &dev);

    let t0 = Instant::now();
    let out = model.forward(x, chan_ids);
    let dims = out.dims();
    // Burn backends may execute asynchronously (e.g. wgpu queues GPU work), so read the
    // result back before stopping the clock. Otherwise the measured time could cover
    // only enqueueing the computation, not running it.
    // NOTE(review): on GPU the first call may also include kernel compilation; consider
    // a warm-up pass if steady-state latency is what you want to report.
    let _ = out.into_data();
    let ms = t0.elapsed().as_secs_f64() * 1000.0;
    println!("Output shape: {dims:?}  ({ms:.1} ms)");
    Ok(())
}