// eegdino_rs/lib.rs
1//! # eegdino-rs
2//!
3//! Rust inference crate for the
4//! [EEG-DINO](https://github.com/miraclefish/EEG-DINO) foundation model,
5//! built on the [Burn](https://burn.dev) ML framework.
6//!
7//! EEG-DINO learns robust EEG representations via hierarchical self-distillation
8//! on 9 000+ hours of EEG data. This crate provides a faithful port of the
9//! encoder architecture with verified numerical parity (NRMSE < 1e-6) against
10//! the original PyTorch implementation.
11//!
12//! ## Model sizes
13//!
14//! | Variant | Params | d_model | Heads | Layers | FFN dim |
15//! |---------|--------|---------|-------|--------|---------|
16//! | Small | 4.6 M | 200 | 8 | 12 | 512 |
17//! | Medium | 33 M | 512 | 16 | 16 | 1 024 |
18//! | Large | 201 M | 1 024 | 16 | 24 | 2 048 |
19//!
20//! ## Quick start (builder)
21//!
22//! ```rust,ignore
23//! use eegdino_rs::prelude::*;
24//! use burn::backend::NdArray;
25//!
26//! type B = NdArray;
27//!
28//! let encoder = EegDinoEncoder::<B>::builder()
29//! .weights("weights/eeg_dino_small.safetensors")
30//! .size(ModelSize::Small)
31//! .device(Default::default())
32//! .build()?;
33//!
34//! let signal = vec![0.0f32; 19 * 2000];
35//! let result = encoder.encode_raw(&signal, 1, 19, 2000)?;
36//! // result.shape == [1, 191, 200]
37//! ```
38//!
39//! ## Batch encoding
40//!
41//! ```rust,ignore
42//! let signals: Vec<Vec<f32>> = load_recordings();
43//! // Single batched forward pass (fastest):
44//! let result = encoder.encode_batch(&signals, 19, 2000)?;
45//! // Or one-by-one:
46//! let results = encoder.encode_many(&signals, 19, 2000);
47//! ```
48//!
49//! ## Backends
50//!
51//! | Feature | Backend | Notes |
52//! |---------|---------|-------|
53//! | `ndarray` (default) | CPU | Multi-threaded via Rayon + SIMD |
54//! | `blas-accelerate` | CPU + Accelerate | Recommended on Apple Silicon |
55//! | `wgpu` | GPU | Metal (macOS) / Vulkan (Linux) |
//! | `wgpu-f16` | GPU f16 | Half-precision, halves memory use |
57
58pub mod config;
59pub mod error;
60pub mod model;
61pub(crate) mod weights;
62pub mod inference;
63pub mod prelude;
64
65pub use config::{ModelConfig, ModelSize};
66pub use error::{EegDinoError, Result};
67pub use inference::{
68 EegDinoEncoder, EegDinoEncoderBuilder, EncodingResult,
69 EegDinoClassifier, ClassificationResult,
70 detect_model_size,
71};
72pub use model::encoder::EEGEncoder;
73pub use model::classifier::ClassificationModel;
74pub use model::embedding::{EmbeddingCache, PatchEmbedding};
75
76/// Configure the Rayon thread pool. Call once before model use.
77pub fn init_threads(n: Option<usize>) {
78 let mut builder = rayon::ThreadPoolBuilder::new();
79 if let Some(n) = n {
80 builder = builder.num_threads(n);
81 }
82 builder.build_global().ok();
83}