Skip to main content

eegdino_rs/
lib.rs

//! # eegdino-rs
//!
//! Rust inference crate for the
//! [EEG-DINO](https://github.com/miraclefish/EEG-DINO) foundation model,
//! built on the [Burn](https://burn.dev) ML framework.
//!
//! EEG-DINO learns robust EEG representations via hierarchical self-distillation
//! on 9,000+ hours of EEG data. This crate provides a faithful port of the
//! encoder architecture with verified numerical parity (NRMSE < 1e-6) against
//! the original PyTorch implementation.
//!
//! ## Model sizes
//!
//! | Variant | Params | d_model | Heads | Layers | FFN dim |
//! |---------|--------|---------|-------|--------|---------|
//! | Small   | 4.6 M  | 200     | 8     | 12     | 512     |
//! | Medium  | 33 M   | 512     | 16    | 16     | 1,024   |
//! | Large   | 201 M  | 1,024   | 16    | 24     | 2,048   |
//!
//! ## Quick start (builder)
//!
//! ```rust,ignore
//! use eegdino_rs::prelude::*;
//! use burn::backend::NdArray;
//!
//! type B = NdArray;
//!
//! let encoder = EegDinoEncoder::<B>::builder()
//!     .weights("weights/eeg_dino_small.safetensors")
//!     .size(ModelSize::Small)
//!     .device(Default::default())
//!     .build()?;
//!
//! let signal = vec![0.0f32; 19 * 2000];
//! let result = encoder.encode_raw(&signal, 1, 19, 2000)?;
//! // result.shape == [1, 191, 200]
//! ```
//!
//! ## Batch encoding
//!
//! ```rust,ignore
//! let signals: Vec<Vec<f32>> = load_recordings();
//! // Single batched forward pass (fastest):
//! let result = encoder.encode_batch(&signals, 19, 2000)?;
//! // Or one recording at a time:
//! let results = encoder.encode_many(&signals, 19, 2000);
//! ```
//!
//! ## Backends
//!
//! | Feature | Backend | Notes |
//! |---------|---------|-------|
//! | `ndarray` (default) | CPU | Multi-threaded via Rayon + SIMD |
//! | `blas-accelerate` | CPU + Accelerate | Recommended on Apple Silicon |
//! | `wgpu` | GPU | Metal (macOS) / Vulkan (Linux) |
//! | `wgpu-f16` | GPU f16 | Half precision; roughly half the memory of f32 |

58pub mod config;
59pub mod error;
60pub mod model;
61pub(crate) mod weights;
62pub mod inference;
63pub mod prelude;
64
65pub use config::{ModelConfig, ModelSize};
66pub use error::{EegDinoError, Result};
67pub use inference::{
68    EegDinoEncoder, EegDinoEncoderBuilder, EncodingResult,
69    EegDinoClassifier, ClassificationResult,
70    detect_model_size,
71};
72pub use model::encoder::EEGEncoder;
73pub use model::classifier::ClassificationModel;
74pub use model::embedding::{EmbeddingCache, PatchEmbedding};
75
76/// Configure the Rayon thread pool.  Call once before model use.
77pub fn init_threads(n: Option<usize>) {
78    let mut builder = rayon::ThreadPoolBuilder::new();
79    if let Some(n) = n {
80        builder = builder.num_threads(n);
81    }
82    builder.build_global().ok();
83}