Skip to main content

reve_rs/
data.rs

//! Data preparation for REVE inference.
//!
//! REVE input: a `(B, C, T)` signal plus `(B, C, 3)` channel positions.
4
5use burn::prelude::*;
6
7/// A single prepared input for the REVE model.
8pub struct InputBatch<B: Backend> {
9    /// EEG signal: [1, C, T].
10    pub signal: Tensor<B, 3>,
11    /// Channel 3D positions: [1, C, 3].
12    pub positions: Tensor<B, 3>,
13    /// Number of channels.
14    pub n_channels: usize,
15    /// Number of time samples.
16    pub n_samples: usize,
17}
18
19/// Channel-wise z-score normalisation (matching REVE preprocessing).
20///
21/// REVE expects z-scored input clipped at 15 standard deviations.
22pub fn channel_wise_normalize<B: Backend>(x: Tensor<B, 3>) -> Tensor<B, 3> {
23    let mean = x.clone().mean_dim(2); // [B, C, 1]
24    let diff = x.clone() - mean.clone();
25    let var = (diff.clone() * diff).mean_dim(2);
26    let std = (var + 1e-8).sqrt();
27    let normed = (x - mean) / std;
28    // Clip at ±15 std
29    normed.clamp(-15.0, 15.0)
30}
31
32/// Build InputBatch from raw arrays.
33pub fn build_batch<B: Backend>(
34    signal: Vec<f32>,      // [C, T] row-major
35    positions: Vec<f32>,   // [C, 3] row-major
36    n_channels: usize,
37    n_samples: usize,
38    device: &B::Device,
39) -> InputBatch<B> {
40    let signal = Tensor::<B, 2>::from_data(
41        TensorData::new(signal, vec![n_channels, n_samples]),
42        device,
43    )
44    .unsqueeze_dim::<3>(0); // [1, C, T]
45
46    let positions = Tensor::<B, 2>::from_data(
47        TensorData::new(positions, vec![n_channels, 3]),
48        device,
49    )
50    .unsqueeze_dim::<3>(0); // [1, C, 3]
51
52    InputBatch {
53        signal,
54        positions,
55        n_channels,
56        n_samples,
57    }
58}
59
60/// Build InputBatch from channel names using a position bank.
61pub fn build_batch_named<B: Backend>(
62    signal: Vec<f32>,
63    channel_names: &[&str],
64    n_samples: usize,
65    position_bank: &crate::position_bank::PositionBank,
66    device: &B::Device,
67) -> InputBatch<B> {
68    let n_channels = channel_names.len();
69    let positions = position_bank.get_positions(channel_names);
70    build_batch(signal, positions, n_channels, n_samples, device)
71}