// reve_rs/encoder.rs

//! Standalone REVE encoder — produce EEG embeddings or classification outputs.

use std::{path::Path, time::Instant};

use anyhow::Context;
use burn::prelude::*;

use crate::{
    config::ModelConfig,
    data::{InputBatch, channel_wise_normalize},
    model::reve::Reve,
    weights::load_model,
};

/// Per-sample output from REVE.
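///
/// `output` is a flat, row-major buffer; index it with `shape`. A sketch for
/// a two-dimensional output (the `out`, `row`, and `col` bindings are
/// illustrative, not part of this API):
///
/// ```ignore
/// let cols = out.shape[1];
/// let value = out.output[row * cols + col];
/// ```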
pub struct ReveOutput {
    /// Output values (row-major f32).
    /// Classification mode: [n_outputs]
    pub output: Vec<f32>,
    /// Shape of the output, batch dimension removed.
    pub shape: Vec<usize>,
    /// Number of EEG channels in the source batch.
    pub n_channels: usize,
}

/// Collection of outputs plus load/encode timings.
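///
/// Intended assembly (a sketch; capturing the timings is up to the caller):
///
/// ```ignore
/// let t = Instant::now();
/// let outputs = encoder.run_batches(&batches)?;
/// let result = EncodingResult {
///     outputs,
///     ms_load,
///     ms_encode: t.elapsed().as_secs_f64() * 1000.0,
/// };
/// ```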
pub struct EncodingResult {
    pub outputs: Vec<ReveOutput>,
    /// Model load time, in milliseconds.
    pub ms_load: f64,
    /// Total encoding time, in milliseconds.
    pub ms_encode: f64,
}

/// REVE encoder for EEG signal processing.
pub struct ReveEncoder<B: Backend> {
    model: Reve<B>,
    pub model_cfg: ModelConfig,
    device: B::Device,
}

impl<B: Backend> ReveEncoder<B> {
    /// Load the model from a `config.json` and a safetensors weights file,
    /// returning the encoder and the load time in milliseconds.
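    ///
    /// A minimal usage sketch (the `NdArray` backend and the file names are
    /// assumptions for illustration, not fixed by this crate):
    ///
    /// ```ignore
    /// use burn::backend::NdArray;
    ///
    /// let device = Default::default();
    /// let (encoder, ms_load) = ReveEncoder::<NdArray>::load(
    ///     Path::new("config.json"),
    ///     Path::new("model.safetensors"),
    ///     device,
    /// )?;
    /// println!("loaded in {ms_load:.1} ms: {}", encoder.describe());
    /// ```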
    pub fn load(
        config_path: &Path,
        weights_path: &Path,
        device: B::Device,
    ) -> anyhow::Result<(Self, f64)> {
        let cfg_str = std::fs::read_to_string(config_path)
            .with_context(|| format!("config: {}", config_path.display()))?;
        let hf_val: serde_json::Value = serde_json::from_str(&cfg_str)?;
        // Accept either a bare model config or one nested under a "model" key.
        let model_cfg: ModelConfig = serde_json::from_value(
            hf_val.get("model").cloned().unwrap_or(hf_val.clone()),
        )
        .context("parsing model config")?;

        let t = Instant::now();
        let model = load_model::<B>(
            &model_cfg,
            weights_path.to_str().context("weights path not valid UTF-8")?,
            &device,
        )?;
        let ms = t.elapsed().as_secs_f64() * 1000.0;

        Ok((
            Self {
                model,
                model_cfg,
                device,
            },
            ms,
        ))
    }

    /// One-line summary of the loaded model configuration.
    pub fn describe(&self) -> String {
        let c = &self.model_cfg;
        format!(
            "REVE  embed_dim={}  depth={}  heads={}  head_dim={}  patch={}  outputs={}",
            c.embed_dim, c.depth, c.heads, c.head_dim, c.patch_size, c.n_outputs,
        )
    }

    /// Run inference on a prepared `InputBatch`.
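    ///
    /// The signal is normalized per channel before the forward pass; the
    /// output tensor is flattened to `Vec<f32>` and the batch dimension is
    /// dropped from the reported shape.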
    pub fn run_batch(&self, batch: &InputBatch<B>) -> anyhow::Result<ReveOutput> {
        let signal = channel_wise_normalize(batch.signal.clone());

        let output = self.model.forward(signal, batch.positions.clone());
        // Classification mode yields [B, n_outputs].

        let shape = output.dims().to_vec();
        let output_vec = output
            .into_data()
            .to_vec::<f32>()
            .map_err(|e| anyhow::anyhow!("output→vec: {e:?}"))?;

        Ok(ReveOutput {
            output: output_vec,
            shape: shape[1..].to_vec(), // drop the batch dim
            n_channels: batch.n_channels,
        })
    }

    /// Run inference on multiple batches, preserving input order.
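    ///
    /// Fails fast: the first batch that errors aborts the remaining ones.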
    pub fn run_batches(
        &self,
        batches: &[InputBatch<B>],
    ) -> anyhow::Result<Vec<ReveOutput>> {
        batches.iter().map(|b| self.run_batch(b)).collect()
    }

    /// Backend device the model runs on.
    pub fn device(&self) -> &B::Device {
        &self.device
    }
}