//! High-level inference APIs for EEG-DINO.
//!
//! - [`EegDinoEncoder`] --- encoder-only embeddings
//! - [`EegDinoClassifier`] --- full classification pipeline
//! - [`EegDinoEncoderBuilder`] --- ergonomic construction via builder pattern
//!
//! All public methods return [`Result<T, EegDinoError>`](crate::EegDinoError).
8use std::path::{Path, PathBuf};
9use std::time::Instant;
10
11use burn::prelude::*;
12
13use crate::config::{ModelConfig, ModelSize};
14use crate::error::{EegDinoError, Result};
15use crate::model::embedding::EmbeddingCache;
16use crate::model::encoder::EEGEncoder;
17use crate::model::classifier::ClassificationModel;
18use crate::weights;
19
20// ── Result types ────────────────────────────────────────────────────────────
21
/// Result of encoding: per-sample embeddings.
///
/// Derives `Debug`/`Clone`/`PartialEq` so results can be logged, duplicated,
/// and compared in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct EncodingResult {
    /// Raw embeddings `[B, 1+C*P, D]` flattened to `Vec<f32>`.
    pub embeddings: Vec<f32>,
    /// Shape of the embeddings tensor.
    pub shape: Vec<usize>,
    /// Encode time in milliseconds.
    pub ms_encode: f64,
}
31
/// Classification result.
///
/// Derives `Debug`/`Clone`/`PartialEq` so results can be logged, duplicated,
/// and compared in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct ClassificationResult {
    /// Logits `[B, num_classes]` flattened to `Vec<f32>`.
    pub logits: Vec<f32>,
    /// Shape of the logits tensor.
    pub shape: Vec<usize>,
    /// Inference time in milliseconds.
    pub ms_infer: f64,
}
41
42// ── Builder ─────────────────────────────────────────────────────────────────
43
/// Builder for [`EegDinoEncoder`].
///
/// # Example
///
/// ```rust,ignore
/// let encoder = EegDinoEncoder::<B>::builder()
///     .weights("weights/eeg_dino_small.safetensors")
///     .size(ModelSize::Small)       // optional --- auto-detected from weights
///     .normalization(100.0)         // optional --- default 100.0
///     .device(device)
///     .build()?;
/// ```
pub struct EegDinoEncoderBuilder<B: Backend> {
    // Path to the safetensors weight file; required by `build`.
    weights_path: Option<PathBuf>,
    // Explicit model config; when `None`, the size is auto-detected from the weights.
    config: Option<ModelConfig>,
    // Divisor applied to raw signals in `encode_raw`; defaults to 100.0.
    normalization: f32,
    // Device the model is placed on; required by `build`.
    device: Option<B::Device>,
}
62
63impl<B: Backend> Default for EegDinoEncoderBuilder<B> {
64    fn default() -> Self {
65        Self { weights_path: None, config: None, normalization: 100.0, device: None }
66    }
67}
68
69impl<B: Backend> EegDinoEncoderBuilder<B> {
70    /// Path to the safetensors weight file (required).
71    pub fn weights(mut self, path: impl Into<PathBuf>) -> Self {
72        self.weights_path = Some(path.into());
73        self
74    }
75
76    /// Model size.  If omitted, auto-detected from the weight file.
77    pub fn size(mut self, size: ModelSize) -> Self {
78        self.config = Some(ModelConfig::from_size(size));
79        self
80    }
81
82    /// Full model config.  Overrides [`size`](Self::size).
83    pub fn config(mut self, cfg: ModelConfig) -> Self {
84        self.config = Some(cfg);
85        self
86    }
87
88    /// Signal normalization divisor applied in [`EegDinoEncoder::encode_raw`].
89    /// Default: `100.0`.
90    pub fn normalization(mut self, n: f32) -> Self {
91        self.normalization = n;
92        self
93    }
94
95    /// Device to place the model on (required).
96    pub fn device(mut self, device: B::Device) -> Self {
97        self.device = Some(device);
98        self
99    }
100
101    /// Build the encoder, loading weights and creating the on-device cache.
102    pub fn build(self) -> Result<EegDinoEncoder<B>> {
103        let weights_path = self.weights_path
104            .ok_or_else(|| EegDinoError::Builder("weights path is required".into()))?;
105        let device = self.device
106            .ok_or_else(|| EegDinoError::Builder("device is required".into()))?;
107
108        let path_str = weights_path.to_str()
109            .ok_or_else(|| EegDinoError::Builder("weights path is not valid UTF-8".into()))?;
110
111        let cfg = match self.config {
112            Some(c) => c,
113            None => {
114                let w = weights::WeightMap::from_file(path_str)?;
115                ModelConfig::from_size(w.detect_model_size()?)
116            }
117        };
118
119        let encoder = weights::load_encoder::<B>(&cfg, path_str, &device)?;
120        let cache = EmbeddingCache::new(&cfg, &device);
121
122        Ok(EegDinoEncoder { encoder, cache, config: cfg, normalization: self.normalization, device })
123    }
124}
125
126// ── Encoder ─────────────────────────────────────────────────────────────────
127
/// Encoder-only wrapper with on-device cache for fast repeated inference.
///
/// Construct via [`EegDinoEncoder::builder`] or [`EegDinoEncoder::load`].
pub struct EegDinoEncoder<B: Backend> {
    /// The underlying encoder module.
    pub encoder: EEGEncoder<B>,
    /// On-device cached DFT basis and channel one-hot tensors.
    pub cache: EmbeddingCache<B>,
    /// Model configuration.
    pub config: ModelConfig,
    /// Divisor applied to raw signals in [`encode_raw`](Self::encode_raw).
    pub normalization: f32,
    // Device all inputs are placed on; exposed read-only via `device()`.
    device: B::Device,
}
142
143impl<B: Backend> EegDinoEncoder<B> {
144    /// Create a builder.
145    pub fn builder() -> EegDinoEncoderBuilder<B> {
146        EegDinoEncoderBuilder::default()
147    }
148
149    /// Load encoder from a safetensors file (convenience shorthand).
150    ///
151    /// Returns `(encoder, load_time_ms)`.
152    pub fn load(
153        weights_path: &Path,
154        config: Option<ModelConfig>,
155        device: B::Device,
156    ) -> Result<(Self, f64)> {
157        let t0 = Instant::now();
158        let mut b = Self::builder().weights(weights_path).device(device);
159        if let Some(c) = config { b = b.config(c); }
160        let enc = b.build()?;
161        Ok((enc, t0.elapsed().as_secs_f64() * 1000.0))
162    }
163
164    /// Encode a pre-shaped tensor `[B, C, P, L]` -> `[B, 1+C*P, D]`.
165    pub fn encode(&self, x: Tensor<B, 4>) -> Tensor<B, 3> {
166        self.encoder.forward_cached(x, &self.cache)
167    }
168
169    /// Encode from a flat `&[f32]` signal.
170    ///
171    /// The signal is interpreted as `[batch_size, num_channels, num_samples]`
172    /// in row-major order, divided by [`normalization`](Self::normalization),
173    /// and reshaped into patches.
174    pub fn encode_raw(
175        &self,
176        signal: &[f32],
177        batch_size: usize,
178        num_channels: usize,
179        num_samples: usize,
180    ) -> Result<EncodingResult> {
181        let t0 = Instant::now();
182        let patch_size = self.config.patch_size;
183
184        if !num_samples.is_multiple_of(patch_size) {
185            return Err(EegDinoError::InvalidInput(format!(
186                "num_samples ({num_samples}) must be divisible by patch_size ({patch_size})"
187            )));
188        }
189        let expected = batch_size * num_channels * num_samples;
190        if signal.len() != expected {
191            return Err(EegDinoError::InvalidInput(format!(
192                "signal length {} != batch_size({batch_size}) * channels({num_channels}) * samples({num_samples}) = {expected}",
193                signal.len()
194            )));
195        }
196
197        let num_patches = num_samples / patch_size;
198        let x = Tensor::<B, 1>::from_floats(signal, &self.device)
199            .reshape([batch_size, num_channels, num_patches, patch_size]);
200        let x = x / self.normalization;
201
202        let output = self.encode(x);
203        let shape: Vec<usize> = output.dims().to_vec();
204        let data: Vec<f32> = output.to_data().convert::<f32>().to_vec().unwrap();
205
206        Ok(EncodingResult { embeddings: data, shape, ms_encode: t0.elapsed().as_secs_f64() * 1000.0 })
207    }
208
209    /// Encode multiple signals as a single batched tensor (fastest path).
210    ///
211    /// All signals must have length `num_channels * num_samples`.
212    pub fn encode_batch(
213        &self,
214        signals: &[Vec<f32>],
215        num_channels: usize,
216        num_samples: usize,
217    ) -> Result<EncodingResult> {
218        let expected_len = num_channels * num_samples;
219        let mut flat = Vec::with_capacity(signals.len() * expected_len);
220        for (i, s) in signals.iter().enumerate() {
221            if s.len() != expected_len {
222                return Err(EegDinoError::InvalidInput(format!(
223                    "signal[{i}] length {} != {expected_len}", s.len()
224                )));
225            }
226            flat.extend_from_slice(s);
227        }
228        self.encode_raw(&flat, signals.len(), num_channels, num_samples)
229    }
230
231    /// Encode multiple signals sequentially.
232    pub fn encode_many(
233        &self,
234        signals: &[Vec<f32>],
235        num_channels: usize,
236        num_samples: usize,
237    ) -> Vec<Result<EncodingResult>> {
238        signals.iter()
239            .map(|s| self.encode_raw(s, 1, num_channels, num_samples))
240            .collect()
241    }
242
243    /// Reference to the underlying device.
244    pub fn device(&self) -> &B::Device { &self.device }
245}
246
247// ── Classifier ──────────────────────────────────────────────────────────────
248
/// Full classification model: encoder + pooling + MLP head.
pub struct EegDinoClassifier<B: Backend> {
    /// The underlying classification module.
    pub model: ClassificationModel<B>,
    /// Model configuration.
    pub config: ModelConfig,
    /// Number of output classes.
    pub num_classes: usize,
    /// Divisor applied to raw signals.
    pub normalization: f32,
    // Device all inputs are placed on; set at load time.
    device: B::Device,
}
261
262impl<B: Backend> EegDinoClassifier<B> {
263    /// Load a finetuned classification model.
264    pub fn load(
265        weights_path: &Path,
266        config: Option<ModelConfig>,
267        num_classes: usize,
268        device: B::Device,
269    ) -> Result<(Self, f64)> {
270        let t0 = Instant::now();
271
272        let path_str = weights_path.to_str()
273            .ok_or_else(|| EegDinoError::Builder("weights path is not valid UTF-8".into()))?;
274
275        let cfg = match config {
276            Some(c) => c,
277            None => {
278                let w = weights::WeightMap::from_file(path_str)?;
279                ModelConfig::from_size(w.detect_model_size()?)
280            }
281        };
282
283        let model = weights::load_classifier::<B>(&cfg, num_classes, path_str, &device)?;
284        let ms = t0.elapsed().as_secs_f64() * 1000.0;
285        Ok((Self { model, config: cfg, num_classes, normalization: 100.0, device }, ms))
286    }
287
288    /// Classify raw EEG signals.
289    pub fn classify_raw(
290        &self,
291        signal: &[f32],
292        batch_size: usize,
293        num_channels: usize,
294        num_samples: usize,
295    ) -> Result<ClassificationResult> {
296        let t0 = Instant::now();
297        let patch_size = self.config.patch_size;
298
299        if !num_samples.is_multiple_of(patch_size) {
300            return Err(EegDinoError::InvalidInput(format!(
301                "num_samples ({num_samples}) must be divisible by patch_size ({patch_size})"
302            )));
303        }
304        let num_patches = num_samples / patch_size;
305
306        let x = Tensor::<B, 1>::from_floats(signal, &self.device)
307            .reshape([batch_size, num_channels, num_patches, patch_size]);
308        let x = x / self.normalization;
309
310        let logits = self.model.forward(x);
311        let shape: Vec<usize> = logits.dims().to_vec();
312        let data: Vec<f32> = logits.to_data().convert::<f32>().to_vec().unwrap();
313
314        Ok(ClassificationResult { logits: data, shape, ms_infer: t0.elapsed().as_secs_f64() * 1000.0 })
315    }
316
317    /// Classify a pre-shaped tensor `[B, C, P, L]`.
318    pub fn classify(&self, x: Tensor<B, 4>) -> Tensor<B, 2> {
319        self.model.forward(x)
320    }
321}
322
323// ── Convenience ─────────────────────────────────────────────────────────────
324
325/// Detect the model size from a safetensors file without loading all weights.
326pub fn detect_model_size(weights_path: &Path) -> Result<ModelSize> {
327    let path_str = weights_path.to_str()
328        .ok_or_else(|| EegDinoError::Builder("weights path is not valid UTF-8".into()))?;
329    let w = weights::WeightMap::from_file(path_str)?;
330    w.detect_model_size()
331}