Skip to main content

luna_rs/
preprocessing.rs

1//! End-to-end EEG preprocessing for LUNA inference.
2//!
3//! Bridges `exg` / `exg-luna` preprocessing with `luna-rs` `InputBatch`:
4//!
5//! ```text
6//! .edf / .fif / .csv
7//!   │
8//!   ├─ exg: read raw data
9//!   ├─ exg-luna: bandpass → notch → resample → bipolar → epoch
10//!   └─ luna-rs: channel vocab lookup → InputBatch
11//! ```
12//!
13//! # Quick start
14//!
15//! ```rust,ignore
//! use std::path::Path;
//!
//! use luna_rs::preprocessing::load_edf;
//!
//! let (batches, info) = load_edf::<B>(Path::new("recording.edf"), &device)?;
19//! for batch in &batches {
20//!     let result = encoder.run_batch(batch)?;
21//! }
22//! ```
23
24use std::path::Path;
25
26use anyhow::{Context, Result};
27use burn::prelude::*;
28
29use crate::channel_positions::bipolar_channel_xyz;
30use crate::channel_vocab;
31use crate::data::InputBatch;
32
/// Preprocessing metadata returned alongside batches.
///
/// Produced by `epochs_to_batches` and handed back by every loader in this
/// module so callers can inspect what the pipeline emitted without digging
/// into the tensors themselves.
#[derive(Debug, Clone, PartialEq)]
pub struct PreprocInfo {
    /// Channel names after bipolar montage (e.g. "FP1-F7", ...), taken from
    /// the first epoch.
    pub ch_names: Vec<String>,
    /// Number of channels after montage (equals `ch_names.len()`).
    pub n_channels: usize,
    /// Number of epochs produced (one `InputBatch` per epoch).
    pub n_epochs: usize,
    /// Source sampling rate (Hz).
    ///
    /// Loaders that read a raw recording overwrite this with the rate of the
    /// input file; otherwise it keeps the 256.0 default set when the batches
    /// are built.
    pub src_sfreq: f32,
    /// Target sampling rate (Hz) — currently always reported as 256.
    pub target_sfreq: f32,
    /// Epoch duration (seconds) — currently always reported as 5.0.
    pub epoch_dur: f32,
}
48
49/// Build `InputBatch`es from exg-luna preprocessed epochs.
50fn epochs_to_batches<B: Backend>(
51    epochs: Vec<(ndarray::Array2<f32>, Vec<String>)>,
52    device: &B::Device,
53) -> Result<(Vec<InputBatch<B>>, PreprocInfo)> {
54    if epochs.is_empty() {
55        anyhow::bail!("preprocessing produced zero epochs");
56    }
57
58    let ch_names = epochs[0].1.clone();
59    let n_channels = ch_names.len();
60    let n_samples = epochs[0].0.ncols();
61
62    // Look up channel positions and vocab indices
63    let positions: Vec<f32> = ch_names.iter()
64        .flat_map(|name| {
65            bipolar_channel_xyz(name)
66                .unwrap_or([0.0, 0.0, 0.0])
67                .to_vec()
68        })
69        .collect();
70
71    let vocab_indices: Option<Vec<i64>> = {
72        let indices: Vec<Option<usize>> = ch_names.iter()
73            .map(|n| channel_vocab::channel_index(n))
74            .collect();
75        if indices.iter().all(|i| i.is_some()) {
76            Some(indices.iter().map(|i| i.unwrap() as i64).collect())
77        } else {
78            None
79        }
80    };
81
82    let n_epochs = epochs.len();
83    let mut batches = Vec::with_capacity(n_epochs);
84
85    for (epoch_data, _names) in &epochs {
86        let signal: Vec<f32> = epoch_data.iter().copied().collect();
87        batches.push(crate::data::build_batch::<B>(
88            signal,
89            positions.clone(),
90            vocab_indices.clone(),
91            n_channels,
92            n_samples,
93            device,
94        ));
95    }
96
97    let info = PreprocInfo {
98        ch_names,
99        n_channels,
100        n_epochs,
101        src_sfreq: 256.0, // after preprocessing
102        target_sfreq: 256.0,
103        epoch_dur: 5.0,
104    };
105
106    Ok((batches, info))
107}
108
109/// Load and preprocess an EDF file for LUNA inference.
110///
111/// Applies the full LUNA pipeline:
112/// 1. Read EDF → raw signal + channel names
113/// 2. Channel rename (strip "EEG ", "-REF", "-LE")
114/// 3. Pick standard 10-20 channels
115/// 4. Bandpass filter 0.1–75 Hz
116/// 5. Notch filter 60 Hz
117/// 6. Resample to 256 Hz
118/// 7. TCP bipolar montage (22 channels)
119/// 8. Epoch into 5s windows (1280 samples)
120///
121/// Returns `InputBatch`es ready for `LunaEncoder::run_batch()`.
122pub fn load_edf<B: Backend>(
123    path: &Path,
124    device: &B::Device,
125) -> Result<(Vec<InputBatch<B>>, PreprocInfo)> {
126    load_edf_with_config(path, &exg_luna::LunaPipelineConfig::default(), device)
127}
128
129/// Load and preprocess an EDF file with custom pipeline config.
130pub fn load_edf_with_config<B: Backend>(
131    path: &Path,
132    cfg: &exg_luna::LunaPipelineConfig,
133    device: &B::Device,
134) -> Result<(Vec<InputBatch<B>>, PreprocInfo)> {
135    // Read EDF
136    let raw = exg::edf::open_raw_edf(path)
137        .with_context(|| format!("opening EDF: {}", path.display()))?;
138    let data = raw.read_all_data()
139        .with_context(|| format!("reading EDF data: {}", path.display()))?;
140    let ch_names: Vec<String> = raw.channel_names();
141    let sfreq = raw.header.sample_rate;
142
143    // Run LUNA preprocessing pipeline
144    let epochs = exg_luna::preprocess_luna(data, &ch_names, sfreq, cfg)
145        .with_context(|| "LUNA preprocessing failed")?;
146
147    let mut info_result = epochs_to_batches(epochs, device)?;
148    info_result.1.src_sfreq = sfreq;
149    Ok(info_result)
150}
151
152/// Load and preprocess a FIF file for LUNA inference.
153///
154/// Same pipeline as [`load_edf`] but reads FIF format.
155pub fn load_fif<B: Backend>(
156    path: &Path,
157    device: &B::Device,
158) -> Result<(Vec<InputBatch<B>>, PreprocInfo)> {
159    load_fif_with_config(path, &exg_luna::LunaPipelineConfig::default(), device)
160}
161
162/// Load and preprocess a FIF file with custom pipeline config.
163pub fn load_fif_with_config<B: Backend>(
164    path: &Path,
165    cfg: &exg_luna::LunaPipelineConfig,
166    device: &B::Device,
167) -> Result<(Vec<InputBatch<B>>, PreprocInfo)> {
168    let raw = exg::fiff::raw::open_raw(path)
169        .with_context(|| format!("opening FIF: {}", path.display()))?;
170    let data = raw.read_all_data()
171        .with_context(|| format!("reading FIF data: {}", path.display()))?;
172    let ch_names: Vec<String> = raw.info.chs.iter().map(|ch| ch.name.clone()).collect();
173    let sfreq = raw.info.sfreq as f32;
174    let data_f32 = data.mapv(|v| v as f32);
175
176    let epochs = exg_luna::preprocess_luna(data_f32, &ch_names, sfreq, cfg)
177        .with_context(|| "LUNA preprocessing failed")?;
178
179    let mut info_result = epochs_to_batches(epochs, device)?;
180    info_result.1.src_sfreq = sfreq;
181    Ok(info_result)
182}
183
184/// Load preprocessed LUNA epochs from exg-luna safetensors format.
185///
186/// This reads files exported by `exg_luna::export_luna_epochs`.
187pub fn load_luna_epochs<B: Backend>(
188    path: &Path,
189    device: &B::Device,
190) -> Result<(Vec<InputBatch<B>>, PreprocInfo)> {
191    let epochs_data = exg_luna::load_luna_epochs(path)
192        .with_context(|| format!("loading LUNA epochs: {}", path.display()))?;
193
194    // Convert exg_luna::LunaEpoch → (Array2, Vec<String>)
195    let epochs: Vec<(ndarray::Array2<f32>, Vec<String>)> = epochs_data.into_iter()
196        .map(|e| (e.signal, e.channel_names))
197        .collect();
198
199    epochs_to_batches(epochs, device)
200}
201
202/// Load EEG from a CSV file, preprocess, and produce InputBatches.
203///
204/// The CSV must have channel names in the header that match the
205/// standard 10-20 electrode names (possibly with "EEG " prefix and "-REF" suffix).
206/// Applies the full LUNA pipeline (bandpass, notch, resample, bipolar montage, epoch).
207pub fn load_csv_and_preprocess<B: Backend>(
208    path: &Path,
209    sample_rate: f32,
210    device: &B::Device,
211) -> Result<(Vec<InputBatch<B>>, PreprocInfo)> {
212    let (data, ch_names, detected_sfreq) = exg::csv::read_eeg(path)
213        .with_context(|| format!("reading CSV: {}", path.display()))?;
214
215    let sfreq = if detected_sfreq > 0.0 { detected_sfreq } else { sample_rate };
216    let ch_strings: Vec<String> = ch_names;
217
218    let cfg = exg_luna::LunaPipelineConfig::default();
219    let epochs = exg_luna::preprocess_luna(data, &ch_strings, sfreq, &cfg)
220        .with_context(|| "LUNA preprocessing of CSV failed")?;
221
222    let mut info_result = epochs_to_batches(epochs, device)?;
223    info_result.1.src_sfreq = sfreq;
224    Ok(info_result)
225}