1use std::path::Path;
25
26use anyhow::{Context, Result};
27use burn::prelude::*;
28
29use crate::channel_positions::bipolar_channel_xyz;
30use crate::channel_vocab;
31use crate::data::InputBatch;
32
/// Summary of how a recording was turned into model input batches.
///
/// Returned alongside the batches by every loader in this module.
#[derive(Debug, Clone, PartialEq)]
pub struct PreprocInfo {
    /// Channel names shared by all produced epochs, in row order.
    pub ch_names: Vec<String>,
    /// Number of channels per epoch (equal to `ch_names.len()`).
    pub n_channels: usize,
    /// Number of epochs, and therefore of batches, produced.
    pub n_epochs: usize,
    /// Sampling rate of the source recording as reported by the file reader.
    pub src_sfreq: f32,
    /// Sampling rate the epochs are assumed to have after preprocessing.
    pub target_sfreq: f32,
    /// Length of a single epoch (presumably in seconds — confirm with the pipeline).
    pub epoch_dur: f32,
}
48
49fn epochs_to_batches<B: Backend>(
51 epochs: Vec<(ndarray::Array2<f32>, Vec<String>)>,
52 device: &B::Device,
53) -> Result<(Vec<InputBatch<B>>, PreprocInfo)> {
54 if epochs.is_empty() {
55 anyhow::bail!("preprocessing produced zero epochs");
56 }
57
58 let ch_names = epochs[0].1.clone();
59 let n_channels = ch_names.len();
60 let n_samples = epochs[0].0.ncols();
61
62 let positions: Vec<f32> = ch_names.iter()
64 .flat_map(|name| {
65 bipolar_channel_xyz(name)
66 .unwrap_or([0.0, 0.0, 0.0])
67 .to_vec()
68 })
69 .collect();
70
71 let vocab_indices: Option<Vec<i64>> = {
72 let indices: Vec<Option<usize>> = ch_names.iter()
73 .map(|n| channel_vocab::channel_index(n))
74 .collect();
75 if indices.iter().all(|i| i.is_some()) {
76 Some(indices.iter().map(|i| i.unwrap() as i64).collect())
77 } else {
78 None
79 }
80 };
81
82 let n_epochs = epochs.len();
83 let mut batches = Vec::with_capacity(n_epochs);
84
85 for (epoch_data, _names) in &epochs {
86 let signal: Vec<f32> = epoch_data.iter().copied().collect();
87 batches.push(crate::data::build_batch::<B>(
88 signal,
89 positions.clone(),
90 vocab_indices.clone(),
91 n_channels,
92 n_samples,
93 device,
94 ));
95 }
96
97 let info = PreprocInfo {
98 ch_names,
99 n_channels,
100 n_epochs,
101 src_sfreq: 256.0, target_sfreq: 256.0,
103 epoch_dur: 5.0,
104 };
105
106 Ok((batches, info))
107}
108
109pub fn load_edf<B: Backend>(
123 path: &Path,
124 device: &B::Device,
125) -> Result<(Vec<InputBatch<B>>, PreprocInfo)> {
126 load_edf_with_config(path, &exg_luna::LunaPipelineConfig::default(), device)
127}
128
129pub fn load_edf_with_config<B: Backend>(
131 path: &Path,
132 cfg: &exg_luna::LunaPipelineConfig,
133 device: &B::Device,
134) -> Result<(Vec<InputBatch<B>>, PreprocInfo)> {
135 let raw = exg::edf::open_raw_edf(path)
137 .with_context(|| format!("opening EDF: {}", path.display()))?;
138 let data = raw.read_all_data()
139 .with_context(|| format!("reading EDF data: {}", path.display()))?;
140 let ch_names: Vec<String> = raw.channel_names();
141 let sfreq = raw.header.sample_rate;
142
143 let epochs = exg_luna::preprocess_luna(data, &ch_names, sfreq, cfg)
145 .with_context(|| "LUNA preprocessing failed")?;
146
147 let mut info_result = epochs_to_batches(epochs, device)?;
148 info_result.1.src_sfreq = sfreq;
149 Ok(info_result)
150}
151
152pub fn load_fif<B: Backend>(
156 path: &Path,
157 device: &B::Device,
158) -> Result<(Vec<InputBatch<B>>, PreprocInfo)> {
159 load_fif_with_config(path, &exg_luna::LunaPipelineConfig::default(), device)
160}
161
162pub fn load_fif_with_config<B: Backend>(
164 path: &Path,
165 cfg: &exg_luna::LunaPipelineConfig,
166 device: &B::Device,
167) -> Result<(Vec<InputBatch<B>>, PreprocInfo)> {
168 let raw = exg::fiff::raw::open_raw(path)
169 .with_context(|| format!("opening FIF: {}", path.display()))?;
170 let data = raw.read_all_data()
171 .with_context(|| format!("reading FIF data: {}", path.display()))?;
172 let ch_names: Vec<String> = raw.info.chs.iter().map(|ch| ch.name.clone()).collect();
173 let sfreq = raw.info.sfreq as f32;
174 let data_f32 = data.mapv(|v| v as f32);
175
176 let epochs = exg_luna::preprocess_luna(data_f32, &ch_names, sfreq, cfg)
177 .with_context(|| "LUNA preprocessing failed")?;
178
179 let mut info_result = epochs_to_batches(epochs, device)?;
180 info_result.1.src_sfreq = sfreq;
181 Ok(info_result)
182}
183
184pub fn load_luna_epochs<B: Backend>(
188 path: &Path,
189 device: &B::Device,
190) -> Result<(Vec<InputBatch<B>>, PreprocInfo)> {
191 let epochs_data = exg_luna::load_luna_epochs(path)
192 .with_context(|| format!("loading LUNA epochs: {}", path.display()))?;
193
194 let epochs: Vec<(ndarray::Array2<f32>, Vec<String>)> = epochs_data.into_iter()
196 .map(|e| (e.signal, e.channel_names))
197 .collect();
198
199 epochs_to_batches(epochs, device)
200}
201
202pub fn load_csv_and_preprocess<B: Backend>(
208 path: &Path,
209 sample_rate: f32,
210 device: &B::Device,
211) -> Result<(Vec<InputBatch<B>>, PreprocInfo)> {
212 let (data, ch_names, detected_sfreq) = exg::csv::read_eeg(path)
213 .with_context(|| format!("reading CSV: {}", path.display()))?;
214
215 let sfreq = if detected_sfreq > 0.0 { detected_sfreq } else { sample_rate };
216 let ch_strings: Vec<String> = ch_names;
217
218 let cfg = exg_luna::LunaPipelineConfig::default();
219 let epochs = exg_luna::preprocess_luna(data, &ch_strings, sfreq, &cfg)
220 .with_context(|| "LUNA preprocessing of CSV failed")?;
221
222 let mut info_result = epochs_to_batches(epochs, device)?;
223 info_result.1.src_sfreq = sfreq;
224 Ok(info_result)
225}