1use std::path::{Path, PathBuf};
9use std::time::Instant;
10
11use burn::prelude::*;
12
13use crate::config::{ModelConfig, ModelSize};
14use crate::error::{EegDinoError, Result};
15use crate::model::embedding::EmbeddingCache;
16use crate::model::encoder::EEGEncoder;
17use crate::model::classifier::ClassificationModel;
18use crate::weights;
19
/// Result of running the encoder on raw signal data.
///
/// Produced by [`EegDinoEncoder::encode_raw`] (and by the methods that delegate to it).
#[derive(Debug, Clone, PartialEq)]
pub struct EncodingResult {
    /// Encoder output tensor flattened into one buffer via burn's `to_data().to_vec()`.
    /// NOTE(review): element order presumably row-major with respect to `shape` — confirm.
    pub embeddings: Vec<f32>,
    /// Dimensions of the encoder output tensor (rank 3).
    pub shape: Vec<usize>,
    /// Wall-clock time spent inside `encode_raw`, in milliseconds (validation included).
    pub ms_encode: f64,
}
31
/// Result of running the classification model on raw signal data.
///
/// Produced by [`EegDinoClassifier::classify_raw`].
#[derive(Debug, Clone, PartialEq)]
pub struct ClassificationResult {
    /// Model output (logits) flattened into one buffer via burn's `to_data().to_vec()`.
    /// NOTE(review): element order presumably row-major with respect to `shape` — confirm.
    pub logits: Vec<f32>,
    /// Dimensions of the model output tensor (rank 2).
    pub shape: Vec<usize>,
    /// Wall-clock time spent inside `classify_raw`, in milliseconds (validation included).
    pub ms_infer: f64,
}
41
/// Builder for [`EegDinoEncoder`].
///
/// `weights_path` and `device` must be set before `build` succeeds. `config` is optional:
/// when absent, `build` detects the model size from the weights file. `normalization`
/// starts at `100.0` via the `Default` impl.
pub struct EegDinoEncoderBuilder<B: Backend> {
    // Path of the weights file; `build` fails if this is still `None`.
    weights_path: Option<PathBuf>,
    // Explicit model configuration; `None` means "detect the size from the weights file".
    config: Option<ModelConfig>,
    // Divisor applied to raw signals in `encode_raw` before encoding.
    normalization: f32,
    // Device passed to weight loading and embedding-cache construction; required by `build`.
    device: Option<B::Device>,
}
62
63impl<B: Backend> Default for EegDinoEncoderBuilder<B> {
64 fn default() -> Self {
65 Self { weights_path: None, config: None, normalization: 100.0, device: None }
66 }
67}
68
69impl<B: Backend> EegDinoEncoderBuilder<B> {
70 pub fn weights(mut self, path: impl Into<PathBuf>) -> Self {
72 self.weights_path = Some(path.into());
73 self
74 }
75
76 pub fn size(mut self, size: ModelSize) -> Self {
78 self.config = Some(ModelConfig::from_size(size));
79 self
80 }
81
82 pub fn config(mut self, cfg: ModelConfig) -> Self {
84 self.config = Some(cfg);
85 self
86 }
87
88 pub fn normalization(mut self, n: f32) -> Self {
91 self.normalization = n;
92 self
93 }
94
95 pub fn device(mut self, device: B::Device) -> Self {
97 self.device = Some(device);
98 self
99 }
100
101 pub fn build(self) -> Result<EegDinoEncoder<B>> {
103 let weights_path = self.weights_path
104 .ok_or_else(|| EegDinoError::Builder("weights path is required".into()))?;
105 let device = self.device
106 .ok_or_else(|| EegDinoError::Builder("device is required".into()))?;
107
108 let path_str = weights_path.to_str()
109 .ok_or_else(|| EegDinoError::Builder("weights path is not valid UTF-8".into()))?;
110
111 let cfg = match self.config {
112 Some(c) => c,
113 None => {
114 let w = weights::WeightMap::from_file(path_str)?;
115 ModelConfig::from_size(w.detect_model_size()?)
116 }
117 };
118
119 let encoder = weights::load_encoder::<B>(&cfg, path_str, &device)?;
120 let cache = EmbeddingCache::new(&cfg, &device);
121
122 Ok(EegDinoEncoder { encoder, cache, config: cfg, normalization: self.normalization, device })
123 }
124}
125
/// Encoder ready for inference: loaded encoder weights plus a precomputed embedding cache.
///
/// Construct with [`EegDinoEncoder::builder`] or [`EegDinoEncoder::load`].
pub struct EegDinoEncoder<B: Backend> {
    /// Encoder module loaded by `weights::load_encoder`; run via `forward_cached`.
    pub encoder: EEGEncoder<B>,
    /// Embedding cache built from the config and device at build time; passed to `forward_cached`.
    pub cache: EmbeddingCache<B>,
    /// Model configuration the weights were loaded with (`patch_size` is read from it).
    pub config: ModelConfig,
    /// Divisor applied to raw signals in `encode_raw`. It is public, so it can be changed
    /// after construction without the builder's validation.
    pub normalization: f32,
    // Device on which `encode_raw` creates its input tensors.
    device: B::Device,
}
142
143impl<B: Backend> EegDinoEncoder<B> {
144 pub fn builder() -> EegDinoEncoderBuilder<B> {
146 EegDinoEncoderBuilder::default()
147 }
148
149 pub fn load(
153 weights_path: &Path,
154 config: Option<ModelConfig>,
155 device: B::Device,
156 ) -> Result<(Self, f64)> {
157 let t0 = Instant::now();
158 let mut b = Self::builder().weights(weights_path).device(device);
159 if let Some(c) = config { b = b.config(c); }
160 let enc = b.build()?;
161 Ok((enc, t0.elapsed().as_secs_f64() * 1000.0))
162 }
163
164 pub fn encode(&self, x: Tensor<B, 4>) -> Tensor<B, 3> {
166 self.encoder.forward_cached(x, &self.cache)
167 }
168
169 pub fn encode_raw(
175 &self,
176 signal: &[f32],
177 batch_size: usize,
178 num_channels: usize,
179 num_samples: usize,
180 ) -> Result<EncodingResult> {
181 let t0 = Instant::now();
182 let patch_size = self.config.patch_size;
183
184 if !num_samples.is_multiple_of(patch_size) {
185 return Err(EegDinoError::InvalidInput(format!(
186 "num_samples ({num_samples}) must be divisible by patch_size ({patch_size})"
187 )));
188 }
189 let expected = batch_size * num_channels * num_samples;
190 if signal.len() != expected {
191 return Err(EegDinoError::InvalidInput(format!(
192 "signal length {} != batch_size({batch_size}) * channels({num_channels}) * samples({num_samples}) = {expected}",
193 signal.len()
194 )));
195 }
196
197 let num_patches = num_samples / patch_size;
198 let x = Tensor::<B, 1>::from_floats(signal, &self.device)
199 .reshape([batch_size, num_channels, num_patches, patch_size]);
200 let x = x / self.normalization;
201
202 let output = self.encode(x);
203 let shape: Vec<usize> = output.dims().to_vec();
204 let data: Vec<f32> = output.to_data().convert::<f32>().to_vec().unwrap();
205
206 Ok(EncodingResult { embeddings: data, shape, ms_encode: t0.elapsed().as_secs_f64() * 1000.0 })
207 }
208
209 pub fn encode_batch(
213 &self,
214 signals: &[Vec<f32>],
215 num_channels: usize,
216 num_samples: usize,
217 ) -> Result<EncodingResult> {
218 let expected_len = num_channels * num_samples;
219 let mut flat = Vec::with_capacity(signals.len() * expected_len);
220 for (i, s) in signals.iter().enumerate() {
221 if s.len() != expected_len {
222 return Err(EegDinoError::InvalidInput(format!(
223 "signal[{i}] length {} != {expected_len}", s.len()
224 )));
225 }
226 flat.extend_from_slice(s);
227 }
228 self.encode_raw(&flat, signals.len(), num_channels, num_samples)
229 }
230
231 pub fn encode_many(
233 &self,
234 signals: &[Vec<f32>],
235 num_channels: usize,
236 num_samples: usize,
237 ) -> Vec<Result<EncodingResult>> {
238 signals.iter()
239 .map(|s| self.encode_raw(s, 1, num_channels, num_samples))
240 .collect()
241 }
242
243 pub fn device(&self) -> &B::Device { &self.device }
245}
246
/// Classification model ready for inference.
///
/// Construct with [`EegDinoClassifier::load`].
pub struct EegDinoClassifier<B: Backend> {
    /// Classification model loaded by `weights::load_classifier`.
    pub model: ClassificationModel<B>,
    /// Model configuration the weights were loaded with (`patch_size` is read from it).
    pub config: ModelConfig,
    /// Number of output classes the model was loaded for.
    pub num_classes: usize,
    /// Divisor applied to raw signals in `classify_raw`; `load` sets it to `100.0`.
    pub normalization: f32,
    // Device on which `classify_raw` creates its input tensors.
    device: B::Device,
}
261
262impl<B: Backend> EegDinoClassifier<B> {
263 pub fn load(
265 weights_path: &Path,
266 config: Option<ModelConfig>,
267 num_classes: usize,
268 device: B::Device,
269 ) -> Result<(Self, f64)> {
270 let t0 = Instant::now();
271
272 let path_str = weights_path.to_str()
273 .ok_or_else(|| EegDinoError::Builder("weights path is not valid UTF-8".into()))?;
274
275 let cfg = match config {
276 Some(c) => c,
277 None => {
278 let w = weights::WeightMap::from_file(path_str)?;
279 ModelConfig::from_size(w.detect_model_size()?)
280 }
281 };
282
283 let model = weights::load_classifier::<B>(&cfg, num_classes, path_str, &device)?;
284 let ms = t0.elapsed().as_secs_f64() * 1000.0;
285 Ok((Self { model, config: cfg, num_classes, normalization: 100.0, device }, ms))
286 }
287
288 pub fn classify_raw(
290 &self,
291 signal: &[f32],
292 batch_size: usize,
293 num_channels: usize,
294 num_samples: usize,
295 ) -> Result<ClassificationResult> {
296 let t0 = Instant::now();
297 let patch_size = self.config.patch_size;
298
299 if !num_samples.is_multiple_of(patch_size) {
300 return Err(EegDinoError::InvalidInput(format!(
301 "num_samples ({num_samples}) must be divisible by patch_size ({patch_size})"
302 )));
303 }
304 let num_patches = num_samples / patch_size;
305
306 let x = Tensor::<B, 1>::from_floats(signal, &self.device)
307 .reshape([batch_size, num_channels, num_patches, patch_size]);
308 let x = x / self.normalization;
309
310 let logits = self.model.forward(x);
311 let shape: Vec<usize> = logits.dims().to_vec();
312 let data: Vec<f32> = logits.to_data().convert::<f32>().to_vec().unwrap();
313
314 Ok(ClassificationResult { logits: data, shape, ms_infer: t0.elapsed().as_secs_f64() * 1000.0 })
315 }
316
317 pub fn classify(&self, x: Tensor<B, 4>) -> Tensor<B, 2> {
319 self.model.forward(x)
320 }
321}
322
323pub fn detect_model_size(weights_path: &Path) -> Result<ModelSize> {
327 let path_str = weights_path.to_str()
328 .ok_or_else(|| EegDinoError::Builder("weights path is not valid UTF-8".into()))?;
329 let w = weights::WeightMap::from_file(path_str)?;
330 w.detect_model_size()
331}