use std::path::{Path, PathBuf};
use std::time::Instant;
use burn::prelude::*;
use crate::config::{ModelConfig, ModelSize};
use crate::error::{EegDinoError, Result};
use crate::model::embedding::EmbeddingCache;
use crate::model::encoder::EEGEncoder;
use crate::model::classifier::ClassificationModel;
use crate::weights;
/// Host-side result of encoding a raw signal with `EegDinoEncoder::encode_raw`.
#[derive(Debug, Clone, PartialEq)]
pub struct EncodingResult {
    /// Encoder output copied off the device as one flat `f32` buffer
    /// (presumably row-major with respect to `shape` — confirm against the backend).
    pub embeddings: Vec<f32>,
    /// Dimensions of the encoder output tensor that `embeddings` was flattened from.
    pub shape: Vec<usize>,
    /// Wall-clock time of the whole call in milliseconds, including input
    /// validation and the copy back to host memory.
    pub ms_encode: f64,
}
/// Host-side result of classifying a raw signal with `EegDinoClassifier::classify_raw`.
#[derive(Debug, Clone, PartialEq)]
pub struct ClassificationResult {
    /// Model output copied off the device as one flat `f32` buffer
    /// (presumably row-major with respect to `shape` — confirm against the backend).
    pub logits: Vec<f32>,
    /// Dimensions of the logits tensor that `logits` was flattened from.
    pub shape: Vec<usize>,
    /// Wall-clock time of the whole call in milliseconds, including input
    /// validation and the copy back to host memory.
    pub ms_infer: f64,
}
/// Collects the settings needed to construct an `EegDinoEncoder`.
///
/// Fill it in with the chained setter methods and finish with `build`.
pub struct EegDinoEncoderBuilder<B: Backend> {
    /// Location of the weights file; `build` fails when this is unset.
    weights_path: Option<PathBuf>,
    /// Explicit model configuration. When `None`, `build` detects the model
    /// size from the weights file instead.
    config: Option<ModelConfig>,
    /// Divisor applied to raw signals before they are encoded.
    normalization: f32,
    /// Device the model is loaded onto; `build` fails when this is unset.
    device: Option<B::Device>,
}
impl<B: Backend> Default for EegDinoEncoderBuilder<B> {
    /// Returns a builder with no weights path, configuration or device set
    /// and a normalization divisor of `100.0`.
    fn default() -> Self {
        Self {
            weights_path: None,
            config: None,
            normalization: 100.0,
            device: None,
        }
    }
}
impl<B: Backend> EegDinoEncoderBuilder<B> {
    /// Sets the path of the weights file to load. Required by [`Self::build`].
    pub fn weights(self, path: impl Into<PathBuf>) -> Self {
        Self { weights_path: Some(path.into()), ..self }
    }

    /// Uses the configuration derived from `size`, replacing any configuration set earlier.
    pub fn size(self, size: ModelSize) -> Self {
        Self { config: Some(ModelConfig::from_size(size)), ..self }
    }

    /// Uses an explicit configuration, replacing any configuration set earlier.
    pub fn config(self, cfg: ModelConfig) -> Self {
        Self { config: Some(cfg), ..self }
    }

    /// Sets the divisor applied to raw signals before encoding.
    ///
    /// The value must be finite and non-zero, otherwise [`Self::build`] fails.
    pub fn normalization(self, n: f32) -> Self {
        Self { normalization: n, ..self }
    }

    /// Sets the device the model is loaded onto. Required by [`Self::build`].
    pub fn device(self, device: B::Device) -> Self {
        Self { device: Some(device), ..self }
    }

    /// Loads the encoder weights and assembles the final `EegDinoEncoder`.
    ///
    /// When no configuration was supplied, the model size is detected from the
    /// weights file first.
    ///
    /// # Errors
    ///
    /// Returns `EegDinoError::Builder` when the weights path or device is
    /// missing, when the normalization divisor is zero or not finite, or when
    /// the weights path is not valid UTF-8. Errors from reading the weights
    /// file, detecting the model size, or loading the encoder are propagated.
    pub fn build(self) -> Result<EegDinoEncoder<B>> {
        let Self { weights_path, config, normalization, device } = self;

        let weights_path = weights_path
            .ok_or_else(|| EegDinoError::Builder("weights path is required".into()))?;
        let device =
            device.ok_or_else(|| EegDinoError::Builder("device is required".into()))?;

        // Raw signals are divided by this value, so a zero, NaN or infinite
        // divisor would feed inf/NaN/all-zero input to the model. Reject it
        // here, before any file I/O happens.
        if normalization == 0.0 || !normalization.is_finite() {
            return Err(EegDinoError::Builder(
                "normalization must be a finite, non-zero value".into(),
            ));
        }

        // The weights loader works on `&str` paths.
        let path_str = weights_path
            .to_str()
            .ok_or_else(|| EegDinoError::Builder("weights path is not valid UTF-8".into()))?;

        let cfg = match config {
            Some(explicit) => explicit,
            None => {
                let weight_map = weights::WeightMap::from_file(path_str)?;
                ModelConfig::from_size(weight_map.detect_model_size()?)
            }
        };

        let encoder = weights::load_encoder::<B>(&cfg, path_str, &device)?;
        let cache = EmbeddingCache::new(&cfg, &device);
        Ok(EegDinoEncoder { encoder, cache, config: cfg, normalization, device })
    }
}
/// A loaded EEG encoder bundled with its embedding cache, configuration and device.
pub struct EegDinoEncoder<B: Backend> {
    /// Encoder network loaded from the weights file.
    pub encoder: EEGEncoder<B>,
    /// Embedding cache built from `config`; handed to the encoder on every `encode` call.
    pub cache: EmbeddingCache<B>,
    /// Model configuration; its `patch_size` decides how raw samples are split into patches.
    pub config: ModelConfig,
    /// Divisor applied to raw signals by `encode_raw` (`encode` does not apply it).
    pub normalization: f32,
    /// Device on which input tensors are created.
    device: B::Device,
}
impl<B: Backend> EegDinoEncoder<B> {
    /// Returns a builder with default settings.
    pub fn builder() -> EegDinoEncoderBuilder<B> {
        EegDinoEncoderBuilder::default()
    }

    /// Loads an encoder from `weights_path` onto `device`.
    ///
    /// When `config` is `None` the builder detects the model size from the
    /// weights file. Returns the encoder together with the load time in
    /// milliseconds.
    ///
    /// # Errors
    ///
    /// Propagates every error reported by the builder's `build` step.
    pub fn load(
        weights_path: &Path,
        config: Option<ModelConfig>,
        device: B::Device,
    ) -> Result<(Self, f64)> {
        let t0 = Instant::now();
        let builder = Self::builder().weights(weights_path).device(device);
        let builder = match config {
            Some(cfg) => builder.config(cfg),
            None => builder,
        };
        let encoder = builder.build()?;
        Ok((encoder, t0.elapsed().as_secs_f64() * 1000.0))
    }

    /// Runs the encoder on an already prepared tensor using the embedding cache.
    ///
    /// No normalization is applied here. `encode_raw` passes a tensor shaped
    /// `[batch_size, num_channels, num_patches, patch_size]`.
    pub fn encode(&self, x: Tensor<B, 4>) -> Tensor<B, 3> {
        self.encoder.forward_cached(x, &self.cache)
    }

    /// Validates a raw-signal layout and returns the number of patches per channel.
    fn checked_num_patches(
        &self,
        signal_len: usize,
        batch_size: usize,
        num_channels: usize,
        num_samples: usize,
    ) -> Result<usize> {
        let patch_size = self.config.patch_size;
        // Guard the division below; `is_multiple_of(0)` is true for 0 samples.
        if patch_size == 0 {
            return Err(EegDinoError::InvalidInput(String::from(
                "patch_size must be non-zero",
            )));
        }
        if !num_samples.is_multiple_of(patch_size) {
            return Err(EegDinoError::InvalidInput(format!(
                "num_samples ({num_samples}) must be divisible by patch_size ({patch_size})"
            )));
        }
        // Checked arithmetic: a wrapped product could accidentally equal the
        // real length and let a bad layout through to `reshape`.
        let expected = batch_size
            .checked_mul(num_channels)
            .and_then(|n| n.checked_mul(num_samples))
            .ok_or_else(|| {
                EegDinoError::InvalidInput(format!(
                    "batch_size({batch_size}) * channels({num_channels}) * samples({num_samples}) overflows usize"
                ))
            })?;
        if signal_len != expected {
            return Err(EegDinoError::InvalidInput(format!(
                "signal length {} != batch_size({batch_size}) * channels({num_channels}) * samples({num_samples}) = {expected}",
                signal_len
            )));
        }
        Ok(num_samples / patch_size)
    }

    /// Encodes a flat buffer of raw samples.
    ///
    /// `signal` must hold exactly `batch_size * num_channels * num_samples`
    /// values. It is reshaped to
    /// `[batch_size, num_channels, num_samples / patch_size, patch_size]`,
    /// divided by `self.normalization`, and passed to [`Self::encode`].
    ///
    /// # Errors
    ///
    /// Returns `EegDinoError::InvalidInput` when the configured patch size is
    /// zero, when `num_samples` is not a multiple of the patch size, when the
    /// requested layout overflows `usize`, or when `signal.len()` does not
    /// match the layout.
    ///
    /// # Panics
    ///
    /// Panics only if the backend cannot return the converted output as `f32`
    /// values, which would indicate a bug rather than bad input.
    pub fn encode_raw(
        &self,
        signal: &[f32],
        batch_size: usize,
        num_channels: usize,
        num_samples: usize,
    ) -> Result<EncodingResult> {
        let t0 = Instant::now();
        let num_patches =
            self.checked_num_patches(signal.len(), batch_size, num_channels, num_samples)?;
        let patch_size = self.config.patch_size;

        let input = Tensor::<B, 1>::from_floats(signal, &self.device)
            .reshape([batch_size, num_channels, num_patches, patch_size]);
        let output = self.encode(input / self.normalization);

        let shape: Vec<usize> = output.dims().to_vec();
        let embeddings: Vec<f32> = output
            .to_data()
            .convert::<f32>()
            .to_vec()
            .expect("tensor data was just converted to f32");
        Ok(EncodingResult {
            embeddings,
            shape,
            ms_encode: t0.elapsed().as_secs_f64() * 1000.0,
        })
    }

    /// Encodes several equally sized signals as one batch.
    ///
    /// Every entry of `signals` must hold `num_channels * num_samples` values.
    /// The entries are concatenated in order and encoded with a batch size of
    /// `signals.len()`.
    ///
    /// # Errors
    ///
    /// Returns `EegDinoError::InvalidInput` when `num_channels * num_samples`
    /// overflows `usize` or when any signal has the wrong length (the first
    /// offending index is reported). Errors from [`Self::encode_raw`] are
    /// propagated.
    pub fn encode_batch(
        &self,
        signals: &[Vec<f32>],
        num_channels: usize,
        num_samples: usize,
    ) -> Result<EncodingResult> {
        let expected_len = num_channels.checked_mul(num_samples).ok_or_else(|| {
            EegDinoError::InvalidInput(format!(
                "channels({num_channels}) * samples({num_samples}) overflows usize"
            ))
        })?;
        // Validate before allocating so a bad request costs nothing.
        if let Some((i, s)) = signals
            .iter()
            .enumerate()
            .find(|(_, s)| s.len() != expected_len)
        {
            return Err(EegDinoError::InvalidInput(format!(
                "signal[{i}] length {} != {expected_len}", s.len()
            )));
        }
        let flat = signals.concat();
        self.encode_raw(&flat, signals.len(), num_channels, num_samples)
    }

    /// Encodes each signal on its own with a batch size of 1.
    ///
    /// Returns one `Result` per input signal, in input order, so a failure on
    /// one signal does not hide the results of the others.
    pub fn encode_many(
        &self,
        signals: &[Vec<f32>],
        num_channels: usize,
        num_samples: usize,
    ) -> Vec<Result<EncodingResult>> {
        signals
            .iter()
            .map(|s| self.encode_raw(s, 1, num_channels, num_samples))
            .collect()
    }

    /// Returns the device on which input tensors are created.
    pub fn device(&self) -> &B::Device {
        &self.device
    }
}
/// A loaded classification model bundled with its configuration and device.
pub struct EegDinoClassifier<B: Backend> {
    /// Classification network loaded from the weights file.
    pub model: ClassificationModel<B>,
    /// Model configuration; its `patch_size` decides how raw samples are split into patches.
    pub config: ModelConfig,
    /// Number of classes the model was loaded with.
    pub num_classes: usize,
    /// Divisor applied to raw signals by `classify_raw`; `load` initialises it to `100.0`.
    pub normalization: f32,
    /// Device on which input tensors are created.
    device: B::Device,
}
impl<B: Backend> EegDinoClassifier<B> {
    /// Loads a classifier with `num_classes` outputs from `weights_path` onto `device`.
    ///
    /// When `config` is `None` the model size is detected from the weights
    /// file. The normalization divisor is initialised to `100.0`. Returns the
    /// classifier together with the load time in milliseconds.
    ///
    /// # Errors
    ///
    /// Returns `EegDinoError::Builder` when the weights path is not valid
    /// UTF-8. Errors from reading the weights file, detecting the model size,
    /// or loading the classifier are propagated.
    pub fn load(
        weights_path: &Path,
        config: Option<ModelConfig>,
        num_classes: usize,
        device: B::Device,
    ) -> Result<(Self, f64)> {
        let t0 = Instant::now();
        // The weights loader works on `&str` paths.
        let path_str = weights_path
            .to_str()
            .ok_or_else(|| EegDinoError::Builder("weights path is not valid UTF-8".into()))?;
        let cfg = match config {
            Some(explicit) => explicit,
            None => {
                let weight_map = weights::WeightMap::from_file(path_str)?;
                ModelConfig::from_size(weight_map.detect_model_size()?)
            }
        };
        let model = weights::load_classifier::<B>(&cfg, num_classes, path_str, &device)?;
        let ms = t0.elapsed().as_secs_f64() * 1000.0;
        let classifier = Self { model, config: cfg, num_classes, normalization: 100.0, device };
        Ok((classifier, ms))
    }

    /// Validates a raw-signal layout and returns the number of patches per channel.
    fn checked_num_patches(
        &self,
        signal_len: usize,
        batch_size: usize,
        num_channels: usize,
        num_samples: usize,
    ) -> Result<usize> {
        let patch_size = self.config.patch_size;
        // Guard the division below; `is_multiple_of(0)` is true for 0 samples.
        if patch_size == 0 {
            return Err(EegDinoError::InvalidInput(String::from(
                "patch_size must be non-zero",
            )));
        }
        if !num_samples.is_multiple_of(patch_size) {
            return Err(EegDinoError::InvalidInput(format!(
                "num_samples ({num_samples}) must be divisible by patch_size ({patch_size})"
            )));
        }
        // Checked arithmetic: a wrapped product could accidentally equal the
        // real length and let a bad layout through to `reshape`.
        let expected = batch_size
            .checked_mul(num_channels)
            .and_then(|n| n.checked_mul(num_samples))
            .ok_or_else(|| {
                EegDinoError::InvalidInput(format!(
                    "batch_size({batch_size}) * channels({num_channels}) * samples({num_samples}) overflows usize"
                ))
            })?;
        // Without this check a mismatched buffer goes straight to `reshape`.
        if signal_len != expected {
            return Err(EegDinoError::InvalidInput(format!(
                "signal length {} != batch_size({batch_size}) * channels({num_channels}) * samples({num_samples}) = {expected}",
                signal_len
            )));
        }
        Ok(num_samples / patch_size)
    }

    /// Classifies a flat buffer of raw samples.
    ///
    /// `signal` must hold exactly `batch_size * num_channels * num_samples`
    /// values. It is reshaped to
    /// `[batch_size, num_channels, num_samples / patch_size, patch_size]`,
    /// divided by `self.normalization`, and passed to the model.
    ///
    /// # Errors
    ///
    /// Returns `EegDinoError::InvalidInput` when the configured patch size is
    /// zero, when `num_samples` is not a multiple of the patch size, when the
    /// requested layout overflows `usize`, or when `signal.len()` does not
    /// match the layout.
    ///
    /// # Panics
    ///
    /// Panics only if the backend cannot return the converted output as `f32`
    /// values, which would indicate a bug rather than bad input.
    pub fn classify_raw(
        &self,
        signal: &[f32],
        batch_size: usize,
        num_channels: usize,
        num_samples: usize,
    ) -> Result<ClassificationResult> {
        let t0 = Instant::now();
        let num_patches =
            self.checked_num_patches(signal.len(), batch_size, num_channels, num_samples)?;
        let patch_size = self.config.patch_size;

        let input = Tensor::<B, 1>::from_floats(signal, &self.device)
            .reshape([batch_size, num_channels, num_patches, patch_size]);
        let output = self.model.forward(input / self.normalization);

        let shape: Vec<usize> = output.dims().to_vec();
        let logits: Vec<f32> = output
            .to_data()
            .convert::<f32>()
            .to_vec()
            .expect("tensor data was just converted to f32");
        Ok(ClassificationResult {
            logits,
            shape,
            ms_infer: t0.elapsed().as_secs_f64() * 1000.0,
        })
    }

    /// Runs the model on an already prepared tensor; no normalization is applied here.
    pub fn classify(&self, x: Tensor<B, 4>) -> Tensor<B, 2> {
        self.model.forward(x)
    }
}
/// Reads the weights file at `weights_path` and reports which model size it holds.
///
/// # Errors
///
/// Returns `EegDinoError::Builder` when the path is not valid UTF-8. Errors
/// from reading the weights file or detecting the size are propagated.
pub fn detect_model_size(weights_path: &Path) -> Result<ModelSize> {
    // The weights loader works on `&str` paths.
    let Some(utf8_path) = weights_path.to_str() else {
        return Err(EegDinoError::Builder("weights path is not valid UTF-8".into()));
    };
    let weight_map = weights::WeightMap::from_file(utf8_path)?;
    weight_map.detect_model_size()
}