Skip to main content

luna_rs/
data.rs

//! Data preparation for LUNA inference (burn 0.20.1).
//!
//! LUNA input: a `(B, C, T)` signal, `(B, C, 3)` channel locations, and optional
//! channel name indices.
4
5use burn::prelude::*;
6
7/// A single prepared input for the LUNA model.
8pub struct InputBatch<B: Backend> {
9    /// EEG signal: [1, C, T] — z-scored and normalised.
10    pub signal: Tensor<B, 3>,
11    /// Channel 3D positions in metres: [1, C, 3].
12    pub channel_locations: Tensor<B, 3>,
13    /// Channel name indices into the global vocabulary: [1, C].
14    pub channel_names: Option<Tensor<B, 2, Int>>,
15    /// Number of channels.
16    pub n_channels: usize,
17    /// Number of time samples.
18    pub n_samples: usize,
19}
20
/// Metadata from a FIF file.
///
/// Plain data, so it derives the common std traits for logging, copying and
/// comparison in tests.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct FifInfo {
    /// Channel names.
    pub ch_names: Vec<String>,
    /// Per-channel 3D positions, presumably in millimetres (per the `_mm`
    /// suffix) and parallel to `ch_names`. TODO confirm.
    pub ch_pos_mm: Vec<[f32; 3]>,
    /// Sampling frequency of the raw recording (presumably Hz).
    pub sfreq: f32,
    /// Number of raw time samples.
    pub n_times_raw: usize,
    /// Recording duration, presumably in seconds (per the `_s` suffix).
    pub duration_s: f32,
    /// Number of epochs.
    pub n_epochs: usize,
    /// Target sampling frequency (presumably Hz, after resampling). TODO confirm.
    pub target_sfreq: f32,
    /// Epoch duration, presumably in seconds (per the `_s` suffix).
    pub epoch_dur_s: f32,
}
32
33/// Channel-wise z-score normalisation (matching Python `ChannelWiseNormalize`).
34pub fn channel_wise_normalize<B: Backend>(x: Tensor<B, 3>) -> Tensor<B, 3> {
35    let mean = x.clone().mean_dim(2);  // [B, C, 1]
36    let diff = x.clone() - mean.clone();
37    let var = (diff.clone() * diff.clone()).mean_dim(2);
38    let std = (var + 1e-8).sqrt();
39    (x - mean) / std
40}
41
42/// Build InputBatch from raw arrays with explicit channel name indices.
43pub fn build_batch<B: Backend>(
44    signal: Vec<f32>,       // [C, T] row-major
45    positions: Vec<f32>,    // [C, 3] row-major
46    channel_indices: Option<Vec<i64>>,  // [C] indices into CHANNEL_VOCAB
47    n_channels: usize,
48    n_samples: usize,
49    device: &B::Device,
50) -> InputBatch<B> {
51    let signal = Tensor::<B, 2>::from_data(
52        TensorData::new(signal, vec![n_channels, n_samples]), device,
53    ).unsqueeze_dim::<3>(0);  // [1, C, T]
54
55    let channel_locations = Tensor::<B, 2>::from_data(
56        TensorData::new(positions, vec![n_channels, 3]), device,
57    ).unsqueeze_dim::<3>(0);  // [1, C, 3]
58
59    let channel_names = channel_indices.map(|idx| {
60        Tensor::<B, 1, Int>::from_data(
61            TensorData::new(idx, vec![n_channels]), device,
62        ).unsqueeze_dim::<2>(0)  // [1, C]
63    });
64
65    InputBatch {
66        signal,
67        channel_locations,
68        channel_names,
69        n_channels,
70        n_samples,
71    }
72}
73
74/// Build InputBatch from channel name strings.
75///
76/// Automatically looks up:
77/// - Channel vocabulary indices from `CHANNEL_VOCAB`
78/// - 3D electrode positions (bipolar midpoints for names like "FP1-F7")
79///
80/// This is the recommended way to build batches for LUNA inference.
81pub fn build_batch_named<B: Backend>(
82    signal: Vec<f32>,         // [C, T] row-major
83    channel_names: &[&str],   // e.g. ["FP1-F7", "F7-T3", ...]
84    n_samples: usize,
85    device: &B::Device,
86) -> InputBatch<B> {
87    let n_channels = channel_names.len();
88
89    // Look up channel vocabulary indices
90    let indices = crate::channel_vocab::channel_indices_unwrap(channel_names);
91
92    // Look up 3D positions (bipolar midpoint or unipolar)
93    let positions: Vec<f32> = channel_names.iter()
94        .flat_map(|name| {
95            crate::channel_positions::bipolar_channel_xyz(name)
96                .unwrap_or([0.0, 0.0, 0.0])
97                .to_vec()
98        })
99        .collect();
100
101    build_batch(signal, positions, Some(indices), n_channels, n_samples, device)
102}