Skip to main content

bids_eeg/
data.rs

1//! EEG signal data reading: EDF, BDF, and BrainVision formats.
2//!
3//! Reads raw signal data from all BIDS-EEG formats into [`EegData`], a
4//! channel × samples matrix with physical-unit values, annotations, and
5//! stimulus channel detection. Implements the [`TimeSeries`](bids_core::timeseries::TimeSeries)
6//! trait for modality-agnostic processing.
7
8use bids_core::error::{BidsError, Result};
9use std::io::{BufReader, Read, Seek, SeekFrom};
10use std::path::Path;
11
12// ─── Annotation ────────────────────────────────────────────────────────────────
13
/// One time-stamped event or marker attached to a recording.
///
/// Both EDF+/BDF+ TAL (Time-stamped Annotation List) entries and BrainVision
/// `.vmrk` markers are parsed into this type; it plays the role of MNE's
/// `raw.annotations`.
#[derive(Clone, Debug, PartialEq)]
pub struct Annotation {
    /// Start of the annotation, in seconds from the beginning of the recording.
    pub onset: f64,
    /// Length of the annotation in seconds; `0.0` marks an instantaneous event.
    pub duration: f64,
    /// Free-text label of the annotation.
    pub description: String,
}
27
28// ─── EegData ───────────────────────────────────────────────────────────────────
29
30/// Raw EEG signal data read from a data file.
31///
32/// Stores multichannel time-series data as a channel × samples matrix,
33/// where each inner `Vec<f64>` represents one channel's signal in physical units.
34/// Also carries annotations parsed from the file (EDF+ TAL, BDF+ status, or
35/// BrainVision markers).
36#[derive(Clone)]
37pub struct EegData {
38    /// Channel labels in order.
39    pub channel_labels: Vec<String>,
40    /// Signal data: one `Vec<f64>` per channel, all in physical units (e.g., µV).
41    pub data: Vec<Vec<f64>>,
42    /// Sampling rate per channel in Hz.
43    pub sampling_rates: Vec<f64>,
44    /// Total duration in seconds.
45    pub duration: f64,
46    /// Annotations parsed from the data file (EDF+ TAL, BDF+ status, .vmrk markers).
47    pub annotations: Vec<Annotation>,
48    /// Indices of channels detected as stimulus/trigger channels.
49    pub stim_channel_indices: Vec<usize>,
50    /// Whether this recording is from an EDF+D/BDF+D discontinuous file.
51    /// If true, there may be gaps in the time axis — use `record_onsets`
52    /// to reconstruct the true timeline.
53    pub is_discontinuous: bool,
54    /// Actual onset time of each data record in seconds (from EDF+ TAL).
55    /// For continuous recordings this is empty or `[0, dur, 2*dur, ...]`.
56    /// For discontinuous (EDF+D), these give the real timestamps of each record,
57    /// which may have gaps.
58    pub record_onsets: Vec<f64>,
59}
60
61impl std::fmt::Debug for EegData {
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        f.debug_struct("EegData")
64            .field("n_channels", &self.data.len())
65            .field(
66                "n_samples",
67                &self.data.first().map_or(0, std::vec::Vec::len),
68            )
69            .field("channel_labels", &self.channel_labels)
70            .field("sampling_rates", &self.sampling_rates)
71            .field("duration", &self.duration)
72            .field("annotations", &self.annotations.len())
73            .field("is_discontinuous", &self.is_discontinuous)
74            .finish()
75    }
76}
77
78impl std::fmt::Display for EegData {
79    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80        let sr = self.sampling_rates.first().copied().unwrap_or(0.0);
81        write!(
82            f,
83            "EegData({} ch × {} samples @ {:.0} Hz, {:.1}s)",
84            self.data.len(),
85            self.data.first().map_or(0, std::vec::Vec::len),
86            sr,
87            self.duration
88        )
89    }
90}
91
92impl EegData {
93    /// Number of channels.
94    #[inline]
95    pub fn n_channels(&self) -> usize {
96        self.data.len()
97    }
98
99    /// Number of samples for the given channel index.
100    #[inline]
101    pub fn n_samples(&self, channel: usize) -> usize {
102        self.data.get(channel).map_or(0, std::vec::Vec::len)
103    }
104
105    /// Get a single channel's data by index.
106    #[inline]
107    pub fn channel(&self, index: usize) -> Option<&[f64]> {
108        self.data.get(index).map(std::vec::Vec::as_slice)
109    }
110
111    /// Get a single channel's data by label.
112    pub fn channel_by_name(&self, name: &str) -> Option<&[f64]> {
113        let idx = self.channel_labels.iter().position(|l| l == name)?;
114        self.channel(idx)
115    }
116
117    /// Generate a times array in seconds for the given channel (like MNE's `raw.times`).
118    ///
119    /// Returns one f64 per sample: `[0.0, 1/sr, 2/sr, ...]`.
120    pub fn times(&self, channel: usize) -> Option<Vec<f64>> {
121        let n = self.n_samples(channel);
122        if n == 0 {
123            return None;
124        }
125        let sr = self.sampling_rates.get(channel).copied().unwrap_or(1.0);
126        Some((0..n).map(|i| i as f64 / sr).collect())
127    }
128
129    /// Get the data together with the times array (like MNE's `get_data(return_times=True)`).
130    pub fn get_data_with_times(&self) -> (Vec<Vec<f64>>, Vec<f64>) {
131        let times = self.times(0).unwrap_or_default();
132        (self.data.clone(), times)
133    }
134
135    /// Select a subset of channels by name (include), returning a new `EegData`.
136    #[must_use]
137    pub fn select_channels(&self, names: &[&str]) -> EegData {
138        let mut labels = Vec::with_capacity(names.len());
139        let mut data = Vec::with_capacity(names.len());
140        let mut rates = Vec::with_capacity(names.len());
141        let mut stim = Vec::new();
142        for name in names {
143            if let Some(idx) = self.channel_labels.iter().position(|l| l == *name) {
144                if self.stim_channel_indices.contains(&idx) {
145                    stim.push(labels.len());
146                }
147                labels.push(self.channel_labels[idx].clone());
148                data.push(self.data[idx].clone());
149                rates.push(self.sampling_rates[idx]);
150            }
151        }
152        EegData {
153            channel_labels: labels,
154            data,
155            sampling_rates: rates,
156            duration: self.duration,
157            annotations: self.annotations.clone(),
158            stim_channel_indices: stim,
159            is_discontinuous: false,
160            record_onsets: Vec::new(),
161        }
162    }
163
164    /// Exclude channels by name, returning a new `EegData` with all other channels.
165    #[must_use]
166    pub fn exclude_channels(&self, names: &[&str]) -> EegData {
167        let keep: Vec<&str> = self
168            .channel_labels
169            .iter()
170            .filter(|l| !names.contains(&l.as_str()))
171            .map(std::string::String::as_str)
172            .collect();
173        self.select_channels(&keep)
174    }
175
176    /// Extract a time window (in seconds) from all channels.
177    #[must_use]
178    pub fn time_slice(&self, start_sec: f64, end_sec: f64) -> EegData {
179        let mut data = Vec::with_capacity(self.data.len());
180        let mut rates = Vec::with_capacity(self.data.len());
181        for (i, ch_data) in self.data.iter().enumerate() {
182            let sr = self.sampling_rates[i];
183            let start_sample = ((start_sec * sr).round() as usize).min(ch_data.len());
184            let end_sample = ((end_sec * sr).round() as usize).min(ch_data.len());
185            data.push(ch_data[start_sample..end_sample].to_vec());
186            rates.push(sr);
187        }
188        // Filter annotations to the time window
189        let anns = self
190            .annotations
191            .iter()
192            .filter(|a| a.onset + a.duration >= start_sec && a.onset < end_sec)
193            .map(|a| Annotation {
194                onset: (a.onset - start_sec).max(0.0),
195                duration: a.duration,
196                description: a.description.clone(),
197            })
198            .collect();
199        EegData {
200            channel_labels: self.channel_labels.clone(),
201            data,
202            sampling_rates: rates,
203            duration: end_sec - start_sec,
204            annotations: anns,
205            stim_channel_indices: self.stim_channel_indices.clone(),
206            is_discontinuous: self.is_discontinuous,
207            record_onsets: self.record_onsets.clone(),
208        }
209    }
210
211    /// Convert channel data to different units by applying a scale factor.
212    ///
213    /// `unit_map` maps channel name → scale factor. For example, to convert
214    /// from µV to V: `{"EEG1": 1e-6}`.
215    pub fn convert_units(&mut self, unit_map: &std::collections::HashMap<String, f64>) {
216        for (i, label) in self.channel_labels.iter().enumerate() {
217            if let Some(&scale) = unit_map.get(label) {
218                for v in &mut self.data[i] {
219                    *v *= scale;
220                }
221            }
222        }
223    }
224
225    /// Select channels by type (like MNE's `pick_types`).
226    ///
227    /// `types` should be a list of `ChannelType` variants to keep. Requires
228    /// that `channel_types` is available (from channels.tsv).
229    #[must_use]
230    pub fn pick_types(
231        &self,
232        types: &[crate::ChannelType],
233        channel_types: &[crate::ChannelType],
234    ) -> EegData {
235        let names: Vec<&str> = self
236            .channel_labels
237            .iter()
238            .enumerate()
239            .filter(|(i, _)| channel_types.get(*i).is_some_and(|ct| types.contains(ct)))
240            .map(|(_, name)| name.as_str())
241            .collect();
242        self.select_channels(&names)
243    }
244
245    /// Concatenate another `EegData` in time (appending samples).
246    ///
247    /// Both must have the same channels in the same order. Annotations from
248    /// `other` are time-shifted by `self.duration`.
249    pub fn concatenate(&mut self, other: &EegData) -> std::result::Result<(), String> {
250        if self.channel_labels != other.channel_labels {
251            return Err("Channel labels must match for concatenation".into());
252        }
253        if self.data.len() != other.data.len() {
254            return Err("Channel count must match for concatenation".into());
255        }
256        let time_offset = self.duration;
257        for (i, ch) in self.data.iter_mut().enumerate() {
258            ch.extend_from_slice(&other.data[i]);
259        }
260        for ann in &other.annotations {
261            self.annotations.push(Annotation {
262                onset: ann.onset + time_offset,
263                duration: ann.duration,
264                description: ann.description.clone(),
265            });
266        }
267        self.duration += other.duration;
268        Ok(())
269    }
270
271    /// Remove (zero-out) data segments that overlap with annotations matching
272    /// a description pattern (like MNE's `reject_by_annotation`).
273    ///
274    /// Returns a new `EegData` where samples overlapping "BAD" annotations
275    /// (or annotations matching `pattern`) are replaced with `f64::NAN`.
276    #[must_use]
277    pub fn reject_by_annotation(&self, pattern: &str) -> EegData {
278        let mut new_data = self.data.clone();
279        let pattern_upper = pattern.to_uppercase();
280
281        for ann in &self.annotations {
282            if !ann.description.to_uppercase().contains(&pattern_upper) {
283                continue;
284            }
285            for (ch, ch_data) in new_data.iter_mut().enumerate() {
286                let sr = self.sampling_rates[ch];
287                let start = (ann.onset * sr).round() as usize;
288                let end = ((ann.onset + ann.duration) * sr).round() as usize;
289                let start = start.min(ch_data.len());
290                let end = end.min(ch_data.len());
291                for v in &mut ch_data[start..end] {
292                    *v = f64::NAN;
293                }
294            }
295        }
296
297        EegData {
298            channel_labels: self.channel_labels.clone(),
299            data: new_data,
300            sampling_rates: self.sampling_rates.clone(),
301            duration: self.duration,
302            annotations: self.annotations.clone(),
303            stim_channel_indices: self.stim_channel_indices.clone(),
304            is_discontinuous: self.is_discontinuous,
305            record_onsets: self.record_onsets.clone(),
306        }
307    }
308}
309
310// ─── MNE-inspired signal processing methods ────────────────────────────────────
311
312impl EegData {
313    /// Apply a bandpass filter to all channels (like MNE's `raw.filter(l_freq, h_freq)`).
314    ///
315    /// Uses a Butterworth IIR filter with zero-phase `filtfilt` application.
316    /// `l_freq` and `h_freq` are in Hz. If `l_freq` is `None`, applies lowpass only.
317    /// If `h_freq` is `None`, applies highpass only.
318    #[must_use]
319    pub fn filter(&self, l_freq: Option<f64>, h_freq: Option<f64>, order: usize) -> EegData {
320        let sr = self.sampling_rates.first().copied().unwrap_or(1.0);
321        let nyquist = sr / 2.0;
322        let mut new_data = self.data.clone();
323
324        for ch_data in &mut new_data {
325            let filtered = match (l_freq, h_freq) {
326                (Some(lo), Some(hi)) => {
327                    let (b, a) = bids_filter::butter_bandpass(
328                        order,
329                        (lo / nyquist).clamp(0.001, 0.999),
330                        (hi / nyquist).clamp(0.001, 0.999),
331                    );
332                    bids_filter::filtfilt(&b, &a, ch_data)
333                }
334                (Some(lo), None) => {
335                    let (b, a) =
336                        bids_filter::butter_highpass(order, (lo / nyquist).clamp(0.001, 0.999));
337                    bids_filter::filtfilt(&b, &a, ch_data)
338                }
339                (None, Some(hi)) => {
340                    let (b, a) =
341                        bids_filter::butter_lowpass(order, (hi / nyquist).clamp(0.001, 0.999));
342                    bids_filter::filtfilt(&b, &a, ch_data)
343                }
344                (None, None) => continue,
345            };
346            *ch_data = filtered;
347        }
348
349        EegData {
350            data: new_data,
351            ..self.clone()
352        }
353    }
354
355    /// Remove power line noise at `freq` Hz and its harmonics (like MNE's `raw.notch_filter()`).
356    #[must_use]
357    pub fn notch_filter(&self, freq: f64, quality: f64) -> EegData {
358        let sr = self.sampling_rates.first().copied().unwrap_or(1.0);
359        let mut new_data = self.data.clone();
360        let nyquist = sr / 2.0;
361
362        // Apply notch at fundamental and harmonics up to Nyquist
363        let mut f = freq;
364        while f < nyquist {
365            for ch in &mut new_data {
366                *ch = bids_filter::notch_filter(ch, f, sr, quality);
367            }
368            f += freq;
369        }
370
371        EegData {
372            data: new_data,
373            ..self.clone()
374        }
375    }
376
377    /// Resample all channels to a new sampling rate (like MNE's `raw.resample()`).
378    ///
379    /// Applies an anti-aliasing lowpass filter before downsampling.
380    #[must_use]
381    pub fn resample(&self, new_sr: f64) -> EegData {
382        let old_sr = self.sampling_rates.first().copied().unwrap_or(1.0);
383        let new_data: Vec<Vec<f64>> = self
384            .data
385            .iter()
386            .map(|ch| bids_filter::resample(ch, old_sr, new_sr))
387            .collect();
388        let new_duration = new_data.first().map_or(0.0, |ch| ch.len() as f64 / new_sr);
389
390        EegData {
391            data: new_data,
392            sampling_rates: vec![new_sr; self.channel_labels.len()],
393            duration: new_duration,
394            channel_labels: self.channel_labels.clone(),
395            annotations: self.annotations.clone(),
396            stim_channel_indices: self.stim_channel_indices.clone(),
397            is_discontinuous: self.is_discontinuous,
398            record_onsets: Vec::new(),
399        }
400    }
401
402    /// Re-reference to the average of all channels (like MNE's `raw.set_eeg_reference('average')`).
403    ///
404    /// Subtracts the mean across channels at each time point.
405    #[must_use]
406    pub fn set_average_reference(&self) -> EegData {
407        let n_ch = self.data.len();
408        let n_s = self.data.first().map_or(0, std::vec::Vec::len);
409        if n_ch == 0 || n_s == 0 {
410            return self.clone();
411        }
412
413        let mut new_data = self.data.clone();
414
415        for s in 0..n_s {
416            let mean: f64 = self.data.iter().map(|ch| ch[s]).sum::<f64>() / n_ch as f64;
417            for ch in &mut new_data {
418                ch[s] -= mean;
419            }
420        }
421
422        EegData {
423            data: new_data,
424            ..self.clone()
425        }
426    }
427
428    /// Re-reference to a specific channel (like MNE's `raw.set_eeg_reference([ch_name])`).
429    #[must_use]
430    pub fn set_reference(&self, ref_channel: &str) -> EegData {
431        let ref_idx = match self.channel_labels.iter().position(|l| l == ref_channel) {
432            Some(i) => i,
433            None => return self.clone(),
434        };
435        let ref_data = self.data[ref_idx].clone();
436        let mut new_data = self.data.clone();
437
438        for (i, ch) in new_data.iter_mut().enumerate() {
439            if i != ref_idx {
440                for (s, v) in ch.iter_mut().enumerate() {
441                    *v -= ref_data[s];
442                }
443            }
444        }
445
446        EegData {
447            data: new_data,
448            ..self.clone()
449        }
450    }
451
452    /// Extract epochs around events (like MNE's `mne.Epochs(raw, events, tmin, tmax)`).
453    ///
454    /// Returns a Vec of `EegData`, one per event matching `event_desc`.
455    /// `tmin` and `tmax` are relative to event onset in seconds.
456    /// If `event_desc` is `None`, epochs around all annotations.
457    pub fn epoch(&self, tmin: f64, tmax: f64, event_desc: Option<&str>) -> Vec<EegData> {
458        let sr = self.sampling_rates.first().copied().unwrap_or(1.0);
459        let events: Vec<&Annotation> = self
460            .annotations
461            .iter()
462            .filter(|a| event_desc.is_none_or(|d| a.description == d))
463            .collect();
464
465        let n_before = ((-tmin) * sr).round() as usize;
466        let n_after = (tmax * sr).round() as usize;
467        let epoch_len = n_before + n_after;
468
469        let mut epochs = Vec::with_capacity(events.len());
470        for event in &events {
471            let center = (event.onset * sr).round() as isize;
472            let start = center - n_before as isize;
473            let end = center + n_after as isize;
474
475            // Skip if epoch would go out of bounds
476            let n_samples = self.data.first().map_or(0, std::vec::Vec::len) as isize;
477            if start < 0 || end > n_samples {
478                continue;
479            }
480
481            let start = start as usize;
482            let data: Vec<Vec<f64>> = self
483                .data
484                .iter()
485                .map(|ch| ch[start..start + epoch_len].to_vec())
486                .collect();
487
488            epochs.push(EegData {
489                channel_labels: self.channel_labels.clone(),
490                data,
491                sampling_rates: self.sampling_rates.clone(),
492                duration: epoch_len as f64 / sr,
493                annotations: vec![Annotation {
494                    onset: -tmin,
495                    duration: event.duration,
496                    description: event.description.clone(),
497                }],
498                stim_channel_indices: self.stim_channel_indices.clone(),
499                is_discontinuous: false,
500                record_onsets: Vec::new(),
501            });
502        }
503        epochs
504    }
505
506    /// Average a list of epochs to compute an ERP (Event-Related Potential).
507    ///
508    /// All epochs must have the same shape. Like MNE's `epochs.average()`.
509    pub fn average_epochs(epochs: &[EegData]) -> Option<EegData> {
510        if epochs.is_empty() {
511            return None;
512        }
513        let n_ch = epochs[0].data.len();
514        let n_s = epochs[0].data.first().map_or(0, std::vec::Vec::len);
515        let n_epochs = epochs.len() as f64;
516
517        let mut avg_data = vec![vec![0.0; n_s]; n_ch];
518        for epoch in epochs {
519            for (ch, ch_data) in epoch.data.iter().enumerate() {
520                for (s, &v) in ch_data.iter().enumerate() {
521                    if ch < n_ch && s < n_s {
522                        avg_data[ch][s] += v;
523                    }
524                }
525            }
526        }
527        for ch in &mut avg_data {
528            for v in ch.iter_mut() {
529                *v /= n_epochs;
530            }
531        }
532
533        Some(EegData {
534            channel_labels: epochs[0].channel_labels.clone(),
535            data: avg_data,
536            sampling_rates: epochs[0].sampling_rates.clone(),
537            duration: epochs[0].duration,
538            annotations: Vec::new(),
539            stim_channel_indices: epochs[0].stim_channel_indices.clone(),
540            is_discontinuous: false,
541            record_onsets: Vec::new(),
542        })
543    }
544
545    /// Compute power spectral density using Welch's method (like MNE's `raw.compute_psd()`).
546    ///
547    /// Returns `(frequencies, psd)` where `psd[ch]` is the power spectrum for each channel.
548    /// `n_fft` is the FFT window size (default: sampling rate = 1 Hz resolution).
549    pub fn compute_psd(&self, n_fft: Option<usize>) -> (Vec<f64>, Vec<Vec<f64>>) {
550        let sr = self.sampling_rates.first().copied().unwrap_or(1.0);
551        let n_fft = n_fft.unwrap_or(sr as usize);
552        let n_freqs = n_fft / 2 + 1;
553
554        let freqs: Vec<f64> = (0..n_freqs).map(|i| i as f64 * sr / n_fft as f64).collect();
555
556        let psd: Vec<Vec<f64>> = self
557            .data
558            .iter()
559            .map(|ch| welch_psd(ch, n_fft, sr))
560            .collect();
561
562        (freqs, psd)
563    }
564}
565
/// Welch power-spectral-density estimate of `x`: one-sided, in units²/Hz.
///
/// The signal is cut into `n_fft`-sample segments with 50 % overlap. Each
/// segment is multiplied by a periodic Hann window, and its DFT is evaluated
/// directly (O(`n_fft`²) per segment). Squared magnitudes are averaged over
/// segments and scaled by `1 / (fs · Σw²)`. Every bin except DC and, for even
/// `n_fft`, Nyquist is doubled to fold in the negative frequencies. With this
/// scaling, `Σ psd[k] · fs / n_fft` equals the window-weighted mean power of
/// `x`. This is the same density convention as SciPy's
/// `welch(..., scaling="density")`, without detrending.
///
/// Returns `n_fft / 2 + 1` bins. All bins are zero when `n_fft < 2` (no
/// valid hop) or when `x` is shorter than one segment. A non-positive or
/// non-finite `fs` is treated as 1 Hz.
fn welch_psd(x: &[f64], n_fft: usize, fs: f64) -> Vec<f64> {
    let n_freqs = n_fft / 2 + 1;
    // n_fft < 2 would give a zero hop (division by zero below).
    if n_fft < 2 || x.len() < n_fft {
        return vec![0.0; n_freqs];
    }
    let fs = if fs.is_finite() && fs > 0.0 { fs } else { 1.0 };

    let hop = n_fft / 2; // 50% overlap
    let n_segments = (x.len() - n_fft) / hop + 1;
    let two_pi = 2.0 * std::f64::consts::PI;

    // Periodic Hann window. Unlike the symmetric form it is never all-zero
    // (the symmetric window is all zeros for n_fft = 2).
    let window: Vec<f64> = (0..n_fft)
        .map(|i| 0.5 * (1.0 - (two_pi * i as f64 / n_fft as f64).cos()))
        .collect();
    let window_power: f64 = window.iter().map(|w| w * w).sum();

    // exp(-2πi·k·n/N) only depends on (k·n) mod N: precompute the twiddles once
    // instead of calling sin/cos in the O(N²) inner loop.
    let (cos_tab, sin_tab): (Vec<f64>, Vec<f64>) = (0..n_fft)
        .map(|m| {
            let angle = two_pi * m as f64 / n_fft as f64;
            (angle.cos(), angle.sin())
        })
        .unzip();

    let mut psd = vec![0.0; n_freqs];
    let mut segment = vec![0.0; n_fft];
    for seg in 0..n_segments {
        let start = seg * hop;
        for ((s, &v), &w) in segment
            .iter_mut()
            .zip(&x[start..start + n_fft])
            .zip(&window)
        {
            *s = v * w;
        }
        for (k, bin) in psd.iter_mut().enumerate() {
            let mut re = 0.0;
            let mut im = 0.0;
            for (n, &v) in segment.iter().enumerate() {
                let m = (k * n) % n_fft;
                re += v * cos_tab[m];
                im -= v * sin_tab[m];
            }
            *bin += re * re + im * im;
        }
    }

    // Average over segments, convert to density, and fold to one-sided.
    let scale = 1.0 / (n_segments as f64 * fs * window_power);
    let nyquist_bin = (n_fft % 2 == 0).then_some(n_fft / 2);
    for (k, v) in psd.iter_mut().enumerate() {
        let one_sided = if k == 0 || Some(k) == nyquist_bin { 1.0 } else { 2.0 };
        *v *= scale * one_sided;
    }

    psd
}
610
#[cfg(feature = "safetensors")]
impl EegData {
    /// Save signal data as a safetensors file.
    ///
    /// Writes one tensor named `"data"` with shape `[n_channels, n_samples]`
    /// and dtype F64, encoded little-endian as the safetensors format requires.
    /// Also writes string metadata: `channel_names` (JSON array),
    /// `sampling_rate` (first channel's rate), `duration`, `n_channels` and
    /// `n_samples`.
    ///
    /// # Errors
    /// Returns `InvalidInput` if the channels have different lengths, because
    /// such data cannot be stored as a 2-D tensor. Also propagates any
    /// serialization error or I/O error from writing `path`.
    pub fn save_safetensors(&self, path: &std::path::Path) -> std::io::Result<()> {
        use safetensors::tensor::{Dtype, TensorView};
        use std::collections::HashMap;

        let n_ch = self.n_channels();
        let n_s = self.n_samples(0);
        if self.data.iter().any(|ch| ch.len() != n_s) {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "all channels must have the same number of samples to save as a 2-D tensor",
            ));
        }

        // Flatten channels × samples row-major. Encode explicitly as little-endian,
        // because a native-endian byte cast would be wrong on big-endian targets.
        let bytes_le: Vec<u8> = self
            .data
            .iter()
            .flatten()
            .flat_map(|v| v.to_le_bytes())
            .collect();

        let tensor = TensorView::new(Dtype::F64, vec![n_ch, n_s], &bytes_le)
            .map_err(|e| std::io::Error::other(e.to_string()))?;

        let mut tensors = HashMap::new();
        tensors.insert("data".to_string(), tensor);

        // Store metadata as strings in the safetensors header
        let mut metadata = HashMap::new();
        metadata.insert(
            "channel_names".to_string(),
            serde_json::to_string(&self.channel_labels).unwrap_or_default(),
        );
        metadata.insert(
            "sampling_rate".to_string(),
            self.sampling_rates
                .first()
                .map(|r| r.to_string())
                .unwrap_or_default(),
        );
        metadata.insert("duration".to_string(), self.duration.to_string());
        metadata.insert("n_channels".to_string(), n_ch.to_string());
        metadata.insert("n_samples".to_string(), n_s.to_string());

        let bytes = safetensors::tensor::serialize(&tensors, &Some(metadata))
            .map_err(|e| std::io::Error::other(e.to_string()))?;

        std::fs::write(path, bytes)
    }
}
661
#[cfg(feature = "ndarray")]
impl EegData {
    /// Convert to an ndarray `Array2<f64>` with shape `(n_channels, n_samples)`.
    ///
    /// `n_samples` is the length of the longest channel, so no sample is
    /// dropped; shorter channels are zero-padded at the end. For rectangular
    /// data this is simply a copy of the sample matrix.
    pub fn to_ndarray(&self) -> ndarray::Array2<f64> {
        let n_ch = self.data.len();
        // Sizing from channel 0 alone would panic when a later channel is longer.
        let n_s = self.data.iter().map(Vec::len).max().unwrap_or(0);
        ndarray::Array2::from_shape_fn((n_ch, n_s), |(i, j)| {
            self.data[i].get(j).copied().unwrap_or(0.0)
        })
    }

    /// Create from an ndarray `Array2<f64>` laid out as (n_channels × n_samples).
    ///
    /// Every channel gets `sampling_rate`, and the duration is
    /// `n_samples / sampling_rate`. `channel_labels` is stored as given and is
    /// not checked against the row count. There are no annotations or stimulus
    /// channels, and the recording is marked continuous.
    pub fn from_ndarray(
        arr: &ndarray::Array2<f64>,
        channel_labels: Vec<String>,
        sampling_rate: f64,
    ) -> Self {
        let n_ch = arr.nrows();
        let n_s = arr.ncols();
        let data: Vec<Vec<f64>> = arr.outer_iter().map(|row| row.to_vec()).collect();
        Self {
            channel_labels,
            data,
            sampling_rates: vec![sampling_rate; n_ch],
            duration: n_s as f64 / sampling_rate,
            annotations: Vec::new(),
            stim_channel_indices: Vec::new(),
            is_discontinuous: false,
            record_onsets: Vec::new(),
        }
    }
}
698
699impl bids_core::timeseries::TimeSeries for EegData {
700    fn n_channels(&self) -> usize {
701        self.data.len()
702    }
703    fn n_samples(&self) -> usize {
704        self.data.first().map_or(0, std::vec::Vec::len)
705    }
706    fn channel_names(&self) -> &[String] {
707        &self.channel_labels
708    }
709    fn sampling_rate(&self) -> f64 {
710        self.sampling_rates.first().copied().unwrap_or(1.0)
711    }
712    fn channel_data(&self, index: usize) -> Option<&[f64]> {
713        self.data.get(index).map(std::vec::Vec::as_slice)
714    }
715    fn duration(&self) -> f64 {
716        self.duration
717    }
718}
719
720// ─── ReadOptions ───────────────────────────────────────────────────────────────
721
/// Options controlling how an EEG data file is read.
///
/// Every field defaults to `None`, meaning "no restriction" or "use the
/// default behaviour".
#[derive(Clone, Debug, Default)]
pub struct ReadOptions {
    /// Channel names to read; `None` reads every channel (MNE's `include`).
    pub channels: Option<Vec<String>>,
    /// Channel names to drop, applied after `channels` (MNE's `exclude`).
    pub exclude: Option<Vec<String>>,
    /// Start of the time window to read, in seconds.
    pub start_time: Option<f64>,
    /// End of the time window to read, in seconds.
    pub end_time: Option<f64>,
    /// Explicit stimulus-channel names. `None` enables auto-detection
    /// (channels named "Status", "Trigger", "STI", case-insensitive).
    pub stim_channel: Option<Vec<String>>,
}
739
740impl ReadOptions {
741    pub fn new() -> Self {
742        Self::default()
743    }
744
745    /// Include only these channels (like MNE's `include` / `picks`).
746    pub fn with_channels(mut self, channels: Vec<String>) -> Self {
747        self.channels = Some(channels);
748        self
749    }
750
751    /// Exclude these channels (like MNE's `exclude`).
752    pub fn with_exclude(mut self, exclude: Vec<String>) -> Self {
753        self.exclude = Some(exclude);
754        self
755    }
756
757    pub fn with_time_range(mut self, start: f64, end: f64) -> Self {
758        self.start_time = Some(start);
759        self.end_time = Some(end);
760        self
761    }
762
763    /// Override stim channel names. Pass empty vec to disable auto-detection.
764    pub fn with_stim_channel(mut self, names: Vec<String>) -> Self {
765        self.stim_channel = Some(names);
766        self
767    }
768}
769
770// ─── EDF / BDF Reader ─────────────────────────────────────────────────────────
771
772/// Read EEG data from an EDF (European Data Format) or BDF (BioSemi) file.
773///
774/// Parses the header to determine channel layout, then reads the raw data records
775/// and converts digital values to physical units using the calibration parameters.
776///
777/// Handles EDF files where `n_records` is `-1` (unknown) by computing the
778/// record count from the file size.
779///
780/// Performance notes:
781/// - Uses buffered I/O and bulk reads (single `read_exact` for all data)
782/// - Pre-computes a channel→output lookup table (no hash or linear scan in hot loop)
783/// - Channel-major decode loop for sequential output writes (cache-friendly)
784/// - Branchless format dispatch: EDF (i16) and BDF (i24) use separate decode paths
785pub fn read_edf(path: &Path, opts: &ReadOptions) -> Result<EegData> {
786    let file = std::fs::File::open(path)?;
787    let mut reader = BufReader::with_capacity(256 * 1024, file);
788
789    let mut hdr_buf = [0u8; 256];
790    reader.read_exact(&mut hdr_buf)?;
791
792    let is_bdf = hdr_buf[0] == 0xFF;
793    let bytes_per_sample: usize = if is_bdf { 3 } else { 2 };
794
795    let n_channels: usize = parse_header_int(&hdr_buf[252..256])?;
796    let n_records: i64 = parse_header_int(&hdr_buf[236..244])?;
797    let record_duration: f64 = parse_header_f64(&hdr_buf[244..252])?;
798
799    if n_channels == 0 {
800        return Err(BidsError::DataFormat("EDF/BDF file has 0 channels".into()));
801    }
802    if n_channels > 10_000 {
803        return Err(BidsError::DataFormat(format!(
804            "EDF/BDF file claims {n_channels} channels — likely corrupt header"
805        )));
806    }
807
808    // Detect EDF+ / BDF+ from the reserved field
809    let reserved = String::from_utf8_lossy(&hdr_buf[192..236])
810        .trim()
811        .to_string();
812    let is_edf_plus = reserved.starts_with("EDF+");
813    let is_edf_plus_discontinuous = reserved.contains("EDF+D") || reserved.contains("BDF+D");
814
815    // Read extended header
816    let ext_size = n_channels * 256;
817    let mut ext = vec![0u8; ext_size];
818    reader.read_exact(&mut ext)?;
819
820    // Parse per-channel fields
821    let mut labels = Vec::with_capacity(n_channels);
822    let mut phys_min = Vec::with_capacity(n_channels);
823    let mut phys_max = Vec::with_capacity(n_channels);
824    let mut dig_min = Vec::with_capacity(n_channels);
825    let mut dig_max = Vec::with_capacity(n_channels);
826    let mut samples_per_record = Vec::with_capacity(n_channels);
827
828    for i in 0..n_channels {
829        labels.push(read_field(&ext, i, 16, 0));
830        phys_min.push(read_field_f64(&ext, i, 8, n_channels * 104)?);
831        phys_max.push(read_field_f64(&ext, i, 8, n_channels * 112)?);
832        dig_min.push(read_field_f64(&ext, i, 8, n_channels * 120)?);
833        dig_max.push(read_field_f64(&ext, i, 8, n_channels * 128)?);
834        samples_per_record.push(read_field_int(&ext, i, 8, n_channels * 216)?);
835    }
836
837    let sampling_rates: Vec<f64> = samples_per_record
838        .iter()
839        .map(|&s| {
840            if record_duration > 0.0 {
841                s as f64 / record_duration
842            } else {
843                s as f64
844            }
845        })
846        .collect();
847
848    // Identify EDF+ annotation channels and BDF status channels
849    let annotation_channel_indices: Vec<usize> = labels
850        .iter()
851        .enumerate()
852        .filter(|(_, l)| l.as_str() == "EDF Annotations" || l.as_str() == "BDF Status")
853        .map(|(i, _)| i)
854        .collect();
855
856    // Detect stim channels: user-specified, or auto-detect by name
857    let stim_names: Vec<String> = if let Some(ref names) = opts.stim_channel {
858        names.clone()
859    } else {
860        // Auto-detect: channels named "status", "trigger", "sti *" (case insensitive)
861        labels
862            .iter()
863            .filter(|l| {
864                let lower = l.to_lowercase();
865                lower == "status"
866                    || lower == "trigger"
867                    || lower.starts_with("sti ")
868                    || lower.starts_with("sti\t")
869            })
870            .cloned()
871            .collect()
872    };
873
874    // Determine which channels to read (include then exclude)
875    let mut channel_indices: Vec<usize> = if let Some(ref wanted) = opts.channels {
876        wanted
877            .iter()
878            .filter_map(|name| labels.iter().position(|l| l == name))
879            .collect()
880    } else {
881        // By default, exclude annotation channels from signal data
882        (0..n_channels)
883            .filter(|i| !annotation_channel_indices.contains(i))
884            .collect()
885    };
886
887    // Apply exclude
888    if let Some(ref excl) = opts.exclude {
889        let excl_indices: Vec<usize> = excl
890            .iter()
891            .filter_map(|name| labels.iter().position(|l| l == name))
892            .collect();
893        channel_indices.retain(|i| !excl_indices.contains(i));
894    }
895
896    // Map stim channel names to output indices
897    let stim_channel_indices: Vec<usize> = channel_indices
898        .iter()
899        .enumerate()
900        .filter(|(_, ch)| stim_names.iter().any(|n| n == &labels[**ch]))
901        .map(|(out_idx, _)| out_idx)
902        .collect();
903
904    // Pre-compute channel→output-index lookup table (O(1) per channel in hot loop)
905    let mut ch_to_out: Vec<usize> = vec![usize::MAX; n_channels];
906    for (out_idx, &ch) in channel_indices.iter().enumerate() {
907        ch_to_out[ch] = out_idx;
908    }
909
910    // Also need to read annotation channels even if excluded from output
911    let need_annotation_read = is_edf_plus
912        || is_bdf
913        || annotation_channel_indices
914            .iter()
915            .any(|i| labels[*i] == "BDF Status");
916    let mut ann_ch_to_out: Vec<usize> = vec![usize::MAX; n_channels];
917    if need_annotation_read {
918        for (ann_idx, &ch) in annotation_channel_indices.iter().enumerate() {
919            ann_ch_to_out[ch] = ann_idx;
920        }
921    }
922
923    // EDF spec: n_records == -1 means "unknown" — compute from file size.
924    let n_records = if n_records < 0 {
925        let record_bytes: usize = samples_per_record
926            .iter()
927            .map(|&s| s * bytes_per_sample)
928            .sum();
929        let file_len = reader.seek(SeekFrom::End(0))? as usize;
930        let header_size = 256 + ext_size;
931        if record_bytes > 0 && file_len > header_size {
932            ((file_len - header_size) / record_bytes) as i64
933        } else {
934            0
935        }
936    } else {
937        n_records
938    };
939    let total_duration = n_records as f64 * record_duration;
940
941    let start_record = opts
942        .start_time
943        .map(|t| ((t / record_duration).floor() as i64).clamp(0, n_records) as usize)
944        .unwrap_or(0);
945    let end_record = opts
946        .end_time
947        .map(|t| ((t / record_duration).ceil() as i64).clamp(0, n_records) as usize)
948        .unwrap_or(n_records as usize);
949
950    if start_record >= end_record {
951        return Ok(EegData {
952            channel_labels: channel_indices.iter().map(|&i| labels[i].clone()).collect(),
953            data: vec![Vec::new(); channel_indices.len()],
954            sampling_rates: channel_indices.iter().map(|&i| sampling_rates[i]).collect(),
955            duration: 0.0,
956            annotations: Vec::new(),
957            stim_channel_indices: stim_channel_indices.clone(),
958            is_discontinuous: false,
959            record_onsets: Vec::new(),
960        });
961    }
962
963    let records_to_read = end_record - start_record;
964
965    // Calculate record size in bytes and per-channel byte offsets within a record
966    let mut ch_byte_offsets = Vec::with_capacity(n_channels);
967    let mut offset = 0usize;
968    for (ch, &spr) in samples_per_record.iter().enumerate() {
969        ch_byte_offsets.push(offset);
970        let ch_bytes = spr.checked_mul(bytes_per_sample).ok_or_else(|| {
971            BidsError::DataFormat(format!(
972                "Channel {ch} samples_per_record ({spr}) overflows byte calculation"
973            ))
974        })?;
975        offset = offset
976            .checked_add(ch_bytes)
977            .ok_or_else(|| BidsError::DataFormat("Record size overflow".into()))?;
978    }
979    let record_byte_size = offset;
980
981    // Pre-allocate output with exact sizes
982    let mut out_data: Vec<Vec<f64>> = channel_indices
983        .iter()
984        .map(|&i| vec![0.0f64; samples_per_record[i] as usize * records_to_read])
985        .collect();
986
987    // Pre-compute gain and offset for each channel: physical = digital * gain + cal_offset
988    // where gain = (phys_max - phys_min) / (dig_max - dig_min)
989    //       cal_offset = phys_min - dig_min * gain
990    let mut gains = vec![0.0f64; n_channels];
991    let mut cal_offsets = vec![0.0f64; n_channels];
992    for i in 0..n_channels {
993        let dd = dig_max[i] - dig_min[i];
994        let pd = phys_max[i] - phys_min[i];
995        let g = if dd.abs() > f64::EPSILON {
996            pd / dd
997        } else {
998            1.0
999        };
1000        gains[i] = g;
1001        cal_offsets[i] = phys_min[i] - dig_min[i] * g;
1002    }
1003
1004    // Seek to start record — bulk read all needed data at once
1005    let data_start = 256 + ext_size + start_record * record_byte_size;
1006    reader.seek(SeekFrom::Start(data_start as u64))?;
1007
1008    let total_data_bytes = records_to_read * record_byte_size;
1009    let mut all_data = vec![0u8; total_data_bytes];
1010    reader.read_exact(&mut all_data)?;
1011
1012    // Decode — split by format to avoid branch in inner loop
1013    let params = DecodeParams {
1014        all_data: &all_data,
1015        records_to_read,
1016        n_channels,
1017        samples_per_record: &samples_per_record,
1018        ch_byte_offsets: &ch_byte_offsets,
1019        ch_to_out: &ch_to_out,
1020        gains: &gains,
1021        cal_offsets: &cal_offsets,
1022        record_byte_size,
1023    };
1024    if is_bdf {
1025        decode_records_bdf(&params, &mut out_data);
1026    } else {
1027        decode_records_edf(&params, &mut out_data);
1028    }
1029
1030    // Extract annotation channel raw bytes (if EDF+/BDF+)
1031    let mut annotations = Vec::new();
1032    let mut record_onsets: Vec<f64> = Vec::new(); // actual onset time of each record (for EDF+D)
1033    if need_annotation_read && !annotation_channel_indices.is_empty() {
1034        for &ann_ch in &annotation_channel_indices {
1035            let label = &labels[ann_ch];
1036            if label == "EDF Annotations" {
1037                // Parse TAL (Time-stamped Annotation Lists) from each record
1038                let spr = samples_per_record[ann_ch];
1039                let ch_bytes = spr * bytes_per_sample;
1040                for rec in 0..records_to_read {
1041                    let rec_base = rec * record_byte_size;
1042                    let start = rec_base + ch_byte_offsets[ann_ch];
1043                    let end = start + ch_bytes;
1044                    if end <= all_data.len() {
1045                        let tal_bytes = &all_data[start..end];
1046                        if let Some(onset) = parse_edf_tal(tal_bytes, &mut annotations) {
1047                            record_onsets.push(onset);
1048                        }
1049                    }
1050                }
1051            } else if label == "BDF Status" {
1052                // BDF status channel: extract trigger events from the 24-bit status word.
1053                // prev_val must persist across record boundaries to avoid duplicate
1054                // events when a trigger is held high across the boundary.
1055                let spr = samples_per_record[ann_ch];
1056                let mut prev_val: i32 = 0;
1057                for rec in 0..records_to_read {
1058                    let rec_base = rec * record_byte_size;
1059                    let src = &all_data[rec_base + ch_byte_offsets[ann_ch]..];
1060                    let t_base = (start_record + rec) as f64 * record_duration;
1061                    for s in 0..spr {
1062                        let off = s * 3;
1063                        let b0 = src[off] as u32;
1064                        let b1 = src[off + 1] as u32;
1065                        let b2 = src[off + 2] as u32;
1066                        let raw_val = (b0 | (b1 << 8) | (b2 << 16)) as i32;
1067                        // Lower 16 bits are the trigger value in BDF
1068                        let trigger = raw_val & 0xFFFF;
1069                        if trigger != 0 && trigger != prev_val {
1070                            let onset = t_base + s as f64 / (spr as f64 / record_duration);
1071                            annotations.push(Annotation {
1072                                onset,
1073                                duration: 0.0,
1074                                description: format!("{trigger}"),
1075                            });
1076                        }
1077                        prev_val = trigger;
1078                    }
1079                }
1080            }
1081        }
1082    }
1083
1084    // Sub-record precision trimming
1085    trim_time_range(
1086        &mut out_data,
1087        &channel_indices,
1088        &sampling_rates,
1089        record_duration,
1090        start_record,
1091        opts,
1092    );
1093
1094    let actual_duration = (opts.end_time.unwrap_or(total_duration)
1095        - opts.start_time.unwrap_or(0.0))
1096    .min(total_duration);
1097
1098    Ok(EegData {
1099        channel_labels: channel_indices.iter().map(|&i| labels[i].clone()).collect(),
1100        data: out_data,
1101        sampling_rates: channel_indices.iter().map(|&i| sampling_rates[i]).collect(),
1102        duration: actual_duration,
1103        annotations,
1104        stim_channel_indices,
1105        is_discontinuous: is_edf_plus_discontinuous,
1106        record_onsets,
1107    })
1108}
1109
/// Shared, borrowed inputs for the EDF/BDF record decoders
/// (`decode_records_edf` and `decode_records_bdf`).
///
/// All per-channel slices are indexed by file channel index (`0..n_channels`).
struct DecodeParams<'a> {
    /// Raw bytes of every data record being decoded, stored back to back.
    all_data: &'a [u8],
    /// Number of consecutive records contained in `all_data`.
    records_to_read: usize,
    /// Number of channels declared in the file header.
    n_channels: usize,
    /// Samples per data record for each channel.
    samples_per_record: &'a [usize],
    /// Byte offset of each channel's samples within a single record.
    ch_byte_offsets: &'a [usize],
    /// Output slot for each channel, or `usize::MAX` if it is not decoded.
    ch_to_out: &'a [usize],
    /// Calibration gain per channel (`physical = digital * gain + cal_offset`).
    gains: &'a [f64],
    /// Calibration offset per channel.
    cal_offsets: &'a [f64],
    /// Size in bytes of one data record across all channels.
    record_byte_size: usize,
}
1123
1124/// Bulk decode EDF (16-bit) records.
1125///
1126/// Iterates channels in the outer loop so that writes to each output
1127/// channel's `Vec<f64>` are sequential, giving much better cache locality
1128/// on the write side. The input reads stride by `record_byte_size` which
1129/// is typically in L2 cache for common channel counts.
1130#[inline(never)]
1131fn decode_records_edf(p: &DecodeParams, out_data: &mut [Vec<f64>]) {
1132    for ch in 0..p.n_channels {
1133        let out_idx = p.ch_to_out[ch];
1134        if out_idx == usize::MAX {
1135            continue;
1136        }
1137
1138        let n_samples = p.samples_per_record[ch];
1139        let gain = p.gains[ch];
1140        let cal_offset = p.cal_offsets[ch];
1141        let ch_off = p.ch_byte_offsets[ch];
1142        let dst = &mut out_data[out_idx];
1143
1144        for rec in 0..p.records_to_read {
1145            let src = &p.all_data[rec * p.record_byte_size + ch_off..];
1146            let dst_start = rec * n_samples;
1147
1148            for s in 0..n_samples {
1149                let off = s * 2;
1150                let digital = i16::from_le_bytes([src[off], src[off + 1]]) as f64;
1151                dst[dst_start + s] = digital * gain + cal_offset;
1152            }
1153        }
1154    }
1155}
1156
1157/// Bulk decode BDF (24-bit) records.
1158///
1159/// Channel-major iteration for sequential output writes.
1160#[inline(never)]
1161fn decode_records_bdf(p: &DecodeParams, out_data: &mut [Vec<f64>]) {
1162    for ch in 0..p.n_channels {
1163        let out_idx = p.ch_to_out[ch];
1164        if out_idx == usize::MAX {
1165            continue;
1166        }
1167
1168        let n_samples = p.samples_per_record[ch];
1169        let gain = p.gains[ch];
1170        let cal_offset = p.cal_offsets[ch];
1171        let ch_off = p.ch_byte_offsets[ch];
1172        let dst = &mut out_data[out_idx];
1173
1174        for rec in 0..p.records_to_read {
1175            let src = &p.all_data[rec * p.record_byte_size + ch_off..];
1176            let dst_start = rec * n_samples;
1177
1178            for s in 0..n_samples {
1179                let off = s * 3;
1180                let b0 = src[off] as u32;
1181                let b1 = src[off + 1] as u32;
1182                let b2 = src[off + 2] as u32;
1183                let val = b0 | (b1 << 8) | (b2 << 16);
1184                // Sign extend from 24-bit
1185                let digital = if val & 0x800000 != 0 {
1186                    (val | 0xFF000000) as i32
1187                } else {
1188                    val as i32
1189                } as f64;
1190                dst[dst_start + s] = digital * gain + cal_offset;
1191            }
1192        }
1193    }
1194}
1195
1196/// Trim samples at sub-record precision for time range requests.
1197fn trim_time_range(
1198    out_data: &mut [Vec<f64>],
1199    channel_indices: &[usize],
1200    sampling_rates: &[f64],
1201    record_duration: f64,
1202    start_record: usize,
1203    opts: &ReadOptions,
1204) {
1205    // Trim from the start
1206    if let Some(start_t) = opts.start_time {
1207        let record_start_t = start_record as f64 * record_duration;
1208        if start_t > record_start_t {
1209            for (out_idx, &ch) in channel_indices.iter().enumerate() {
1210                let skip = ((start_t - record_start_t) * sampling_rates[ch]).round() as usize;
1211                if skip > 0 && skip < out_data[out_idx].len() {
1212                    // In-place shift via drain — avoids reallocation
1213                    out_data[out_idx].drain(..skip);
1214                }
1215            }
1216        }
1217    }
1218    // Trim from the end
1219    if let Some(end_t) = opts.end_time {
1220        let actual_start = opts.start_time.unwrap_or(0.0);
1221        let desired_dur = end_t - actual_start;
1222        for (out_idx, &ch) in channel_indices.iter().enumerate() {
1223            let max_samples = (desired_dur * sampling_rates[ch]).round() as usize;
1224            out_data[out_idx].truncate(max_samples);
1225        }
1226    }
1227}
1228
1229// ─── EDF+ TAL Parser ───────────────────────────────────────────────────────────
1230
1231/// Parse EDF+ Time-stamped Annotation Lists (TAL) from raw annotation channel bytes.
1232///
1233/// TAL format per record:
1234/// `+T\x14\x14\x00` — time-keeping annotation (onset only, gives record start time)
1235/// `+T\x15D\x14description\x14\x00` — annotation with onset T, duration D, description
1236///
1237/// Returns the record onset time (from the first TAL entry without a description).
1238/// This is critical for EDF+D (discontinuous) files where records may not be contiguous.
1239fn parse_edf_tal(data: &[u8], annotations: &mut Vec<Annotation>) -> Option<f64> {
1240    let mut record_onset = None;
1241    // Split on \x00 to get individual TAL entries
1242    for entry in data.split(|&b| b == 0) {
1243        if entry.is_empty() {
1244            continue;
1245        }
1246
1247        let s = String::from_utf8_lossy(entry);
1248        let s = s.trim_matches(|c: char| c == '\x14' || c == '\x00' || c == '\x15');
1249        if s.is_empty() {
1250            continue;
1251        }
1252
1253        // Split on \x14 (annotation separator)
1254        let parts: Vec<&str> = s.split('\x14').collect();
1255        if parts.is_empty() {
1256            continue;
1257        }
1258
1259        // First part: onset and optional duration
1260        let onset_dur = parts[0];
1261        let (onset, duration) = if let Some(dur_sep) = onset_dur.find('\x15') {
1262            let onset_str = &onset_dur[..dur_sep];
1263            let dur_str = &onset_dur[dur_sep + 1..];
1264            (
1265                onset_str
1266                    .trim_start_matches('+')
1267                    .parse::<f64>()
1268                    .unwrap_or(0.0),
1269                dur_str.parse::<f64>().unwrap_or(0.0),
1270            )
1271        } else {
1272            (
1273                onset_dur
1274                    .trim_start_matches('+')
1275                    .parse::<f64>()
1276                    .unwrap_or(0.0),
1277                0.0,
1278            )
1279        };
1280
1281        // Remaining parts are descriptions
1282        let has_description = parts[1..].iter().any(|d| !d.trim().is_empty());
1283        if !has_description && record_onset.is_none() {
1284            // First TAL entry without description = record onset time
1285            record_onset = Some(onset);
1286        }
1287        for desc in &parts[1..] {
1288            let desc = desc.trim();
1289            if desc.is_empty() {
1290                continue;
1291            }
1292            annotations.push(Annotation {
1293                onset,
1294                duration,
1295                description: desc.to_string(),
1296            });
1297        }
1298    }
1299    record_onset
1300}
1301
1302// ─── BrainVision Reader ────────────────────────────────────────────────────────
1303
1304/// Read EEG data from BrainVision format (.vhdr + .eeg/.dat).
1305///
1306/// Parses the `.vhdr` header to determine data layout, then reads the binary
1307/// data file and applies channel-specific resolution scaling.
1308///
1309/// Performance notes:
1310/// - Single bulk read of entire binary file
1311/// - For INT_16/INT_32: batch decode with pre-computed resolution
1312/// - For IEEE_FLOAT_32: safe reinterpret via `from_le_bytes` batched over slices
1313/// - Vectorized layout gets direct contiguous slice access per channel
1314pub fn read_brainvision(vhdr_path: &Path, opts: &ReadOptions) -> Result<EegData> {
1315    let header_text = std::fs::read_to_string(vhdr_path)?;
1316    let bv = parse_vhdr(&header_text)?;
1317
1318    let parent = vhdr_path.parent().unwrap_or(Path::new("."));
1319    let data_path = parent.join(&bv.data_file);
1320    if !data_path.exists() {
1321        return Err(BidsError::Io(std::io::Error::new(
1322            std::io::ErrorKind::NotFound,
1323            format!("BrainVision data file not found: {}", data_path.display()),
1324        )));
1325    }
1326
1327    let n_channels = bv.channels.len();
1328    let bps = bv.bytes_per_sample();
1329    let raw_data = std::fs::read(&data_path)?;
1330    let total_samples = raw_data.len() / bps / n_channels;
1331    let sampling_rate = bv
1332        .sampling_interval_us
1333        .map(|us| 1_000_000.0 / us)
1334        .unwrap_or(1.0);
1335
1336    // Channel selection: include then exclude
1337    let mut channel_indices: Vec<usize> = if let Some(ref wanted) = opts.channels {
1338        wanted
1339            .iter()
1340            .filter_map(|name| bv.channels.iter().position(|c| c.name == *name))
1341            .collect()
1342    } else {
1343        (0..n_channels).collect()
1344    };
1345
1346    if let Some(ref excl) = opts.exclude {
1347        let excl_indices: Vec<usize> = excl
1348            .iter()
1349            .filter_map(|name| bv.channels.iter().position(|c| c.name == *name))
1350            .collect();
1351        channel_indices.retain(|i| !excl_indices.contains(i));
1352    }
1353
1354    // Stim channel detection
1355    let stim_names: Vec<String> = if let Some(ref names) = opts.stim_channel {
1356        names.clone()
1357    } else {
1358        bv.channels
1359            .iter()
1360            .filter(|c| {
1361                let lower = c.name.to_lowercase();
1362                lower == "status" || lower == "trigger" || lower.starts_with("sti ")
1363            })
1364            .map(|c| c.name.clone())
1365            .collect()
1366    };
1367    let stim_channel_indices: Vec<usize> = channel_indices
1368        .iter()
1369        .enumerate()
1370        .filter(|(_, ch)| stim_names.iter().any(|n| n == &bv.channels[**ch].name))
1371        .map(|(out_idx, _)| out_idx)
1372        .collect();
1373
1374    let start_sample = opts
1375        .start_time
1376        .map(|t| (t * sampling_rate).round() as usize)
1377        .unwrap_or(0)
1378        .min(total_samples);
1379    let end_sample = opts
1380        .end_time
1381        .map(|t| (t * sampling_rate).round() as usize)
1382        .unwrap_or(total_samples)
1383        .min(total_samples);
1384
1385    let n_out = end_sample.saturating_sub(start_sample);
1386
1387    let out_data = if bv.data_orientation == BvOrientation::Multiplexed {
1388        decode_bv_multiplexed(
1389            &raw_data,
1390            &bv,
1391            &channel_indices,
1392            start_sample,
1393            n_out,
1394            n_channels,
1395            bps,
1396        )
1397    } else {
1398        decode_bv_vectorized(
1399            &raw_data,
1400            &bv,
1401            &channel_indices,
1402            start_sample,
1403            n_out,
1404            total_samples,
1405            n_channels,
1406            bps,
1407        )
1408    };
1409
1410    // Read .vmrk marker file if it exists
1411    let annotations = if let Some(ref marker_file) = bv.marker_file {
1412        let vmrk_path = parent.join(marker_file);
1413        if vmrk_path.exists() {
1414            let vmrk_text = std::fs::read_to_string(&vmrk_path)?;
1415            parse_vmrk(&vmrk_text, sampling_rate)
1416        } else {
1417            Vec::new()
1418        }
1419    } else {
1420        // Try conventional name: same stem as .vhdr but .vmrk
1421        let vmrk_path = vhdr_path.with_extension("vmrk");
1422        if vmrk_path.exists() {
1423            let vmrk_text = std::fs::read_to_string(&vmrk_path)?;
1424            parse_vmrk(&vmrk_text, sampling_rate)
1425        } else {
1426            Vec::new()
1427        }
1428    };
1429
1430    Ok(EegData {
1431        channel_labels: channel_indices
1432            .iter()
1433            .map(|&i| bv.channels[i].name.clone())
1434            .collect(),
1435        data: out_data,
1436        sampling_rates: vec![sampling_rate; channel_indices.len()],
1437        duration: n_out as f64 / sampling_rate,
1438        annotations,
1439        stim_channel_indices,
1440        is_discontinuous: false,
1441        record_onsets: Vec::new(),
1442    })
1443}
1444
1445/// Decode multiplexed BrainVision data using cache-friendly tiled decoding.
1446/// Layout: [ch0_s0, ch1_s0, ..., chN_s0, ch0_s1, ch1_s1, ...]
1447///
1448/// For multiplexed data, per-channel iteration re-reads the entire file from
1449/// memory for each channel (N×file_size memory traffic), while per-sample
1450/// iteration scatters writes across N output buffers.
1451///
1452/// Tiled approach: process in blocks of TILE samples. Within each tile,
1453/// both the input block and output tile fit in L2 cache, giving good locality
1454/// for both reads and writes.
1455#[inline(never)]
1456fn decode_bv_multiplexed(
1457    raw: &[u8],
1458    bv: &BvHeader,
1459    indices: &[usize],
1460    start: usize,
1461    count: usize,
1462    n_ch: usize,
1463    bps: usize,
1464) -> Vec<Vec<f64>> {
1465    let mut out: Vec<Vec<f64>> = indices.iter().map(|_| vec![0.0f64; count]).collect();
1466
1467    // Build ch→(out_idx, resolution) lookup
1468    let mut ch_map: Vec<(usize, f64)> = vec![(usize::MAX, 0.0); n_ch];
1469    for (out_idx, &ch) in indices.iter().enumerate() {
1470        ch_map[ch] = (out_idx, bv.channels[ch].resolution);
1471    }
1472
1473    let frame_bytes = n_ch * bps;
1474
1475    // Tile size: chosen so tile_input + tile_output fits in L2 cache (~256KB).
1476    // tile_input = TILE * n_ch * bps, tile_output = TILE * n_out * 8
1477    // For 64ch × 2B: input = TILE*128, output = TILE*64*8 = TILE*512
1478    // Total = TILE * 640 → TILE = 256K/640 ≈ 400. Use 512 for power-of-2.
1479    const TILE: usize = 512;
1480    let n_out = indices.len();
1481
1482    match bv.data_format {
1483        BvDataFormat::Int16 => {
1484            let mut s = 0;
1485            while s < count {
1486                let tile_end = (s + TILE).min(count);
1487                for oi in 0..n_out {
1488                    let ch = indices[oi];
1489                    let res = ch_map[ch].1;
1490                    let dst = &mut out[oi][s..tile_end];
1491                    let mut base = (start + s) * frame_bytes + ch * 2;
1492                    for d in dst.iter_mut() {
1493                        *d = i16::from_le_bytes([raw[base], raw[base + 1]]) as f64 * res;
1494                        base += frame_bytes;
1495                    }
1496                }
1497                s = tile_end;
1498            }
1499        }
1500        BvDataFormat::Float32 => {
1501            let mut s = 0;
1502            while s < count {
1503                let tile_end = (s + TILE).min(count);
1504                for oi in 0..n_out {
1505                    let ch = indices[oi];
1506                    let res = ch_map[ch].1;
1507                    let dst = &mut out[oi][s..tile_end];
1508                    let mut base = (start + s) * frame_bytes + ch * 4;
1509                    for d in dst.iter_mut() {
1510                        *d = f32::from_le_bytes([
1511                            raw[base],
1512                            raw[base + 1],
1513                            raw[base + 2],
1514                            raw[base + 3],
1515                        ]) as f64
1516                            * res;
1517                        base += frame_bytes;
1518                    }
1519                }
1520                s = tile_end;
1521            }
1522        }
1523        BvDataFormat::Int32 => {
1524            let mut s = 0;
1525            while s < count {
1526                let tile_end = (s + TILE).min(count);
1527                for oi in 0..n_out {
1528                    let ch = indices[oi];
1529                    let res = ch_map[ch].1;
1530                    let dst = &mut out[oi][s..tile_end];
1531                    let mut base = (start + s) * frame_bytes + ch * 4;
1532                    for d in dst.iter_mut() {
1533                        *d = i32::from_le_bytes([
1534                            raw[base],
1535                            raw[base + 1],
1536                            raw[base + 2],
1537                            raw[base + 3],
1538                        ]) as f64
1539                            * res;
1540                        base += frame_bytes;
1541                    }
1542                }
1543                s = tile_end;
1544            }
1545        }
1546    }
1547    out
1548}
1549
1550/// Decode vectorized BrainVision data.
1551/// Layout: [ch0_s0, ch0_s1, ..., ch0_sN, ch1_s0, ch1_s1, ...]
1552#[inline(never)]
1553fn decode_bv_vectorized(
1554    raw: &[u8],
1555    bv: &BvHeader,
1556    indices: &[usize],
1557    start: usize,
1558    count: usize,
1559    total: usize,
1560    _n_ch: usize,
1561    bps: usize,
1562) -> Vec<Vec<f64>> {
1563    let mut out: Vec<Vec<f64>> = indices.iter().map(|_| vec![0.0f64; count]).collect();
1564
1565    let ch_stride = total * bps; // bytes per channel's contiguous block
1566
1567    match bv.data_format {
1568        BvDataFormat::Int16 => {
1569            for (out_idx, &ch) in indices.iter().enumerate() {
1570                let res = bv.channels[ch].resolution;
1571                let ch_base = ch * ch_stride + start * 2;
1572                let src = &raw[ch_base..];
1573                for (s, d) in out[out_idx].iter_mut().enumerate() {
1574                    let off = s * 2;
1575                    *d = i16::from_le_bytes([src[off], src[off + 1]]) as f64 * res;
1576                }
1577            }
1578        }
1579        BvDataFormat::Float32 => {
1580            for (out_idx, &ch) in indices.iter().enumerate() {
1581                let res = bv.channels[ch].resolution;
1582                let ch_base = ch * ch_stride + start * 4;
1583                let src = &raw[ch_base..];
1584                for (s, d) in out[out_idx].iter_mut().enumerate() {
1585                    let off = s * 4;
1586                    *d = f32::from_le_bytes([src[off], src[off + 1], src[off + 2], src[off + 3]])
1587                        as f64
1588                        * res;
1589                }
1590            }
1591        }
1592        BvDataFormat::Int32 => {
1593            for (out_idx, &ch) in indices.iter().enumerate() {
1594                let res = bv.channels[ch].resolution;
1595                let ch_base = ch * ch_stride + start * 4;
1596                let src = &raw[ch_base..];
1597                for (s, d) in out[out_idx].iter_mut().enumerate() {
1598                    let off = s * 4;
1599                    *d = i32::from_le_bytes([src[off], src[off + 1], src[off + 2], src[off + 3]])
1600                        as f64
1601                        * res;
1602                }
1603            }
1604        }
1605    }
1606    out
1607}
1608
1609// ─── Unified reader ────────────────────────────────────────────────────────────
1610
1611/// Detect format from file extension and read EEG data.
1612///
1613/// Supported formats:
1614/// - `.edf` — European Data Format
1615/// - `.bdf` — BioSemi Data Format
1616/// - `.vhdr` — BrainVision (reads companion `.eeg`/`.dat` file)
1617pub fn read_eeg_data(path: &Path, opts: &ReadOptions) -> Result<EegData> {
1618    let ext = path
1619        .extension()
1620        .and_then(|e| e.to_str())
1621        .unwrap_or("")
1622        .to_lowercase();
1623
1624    match ext.as_str() {
1625        "edf" | "bdf" => read_edf(path, opts),
1626        "vhdr" => read_brainvision(path, opts),
1627        "set" => read_eeglab_set(path, opts),
1628        _ => Err(BidsError::FileType(format!(
1629            "Unsupported EEG data format: .{ext}. Supported: .edf, .bdf, .vhdr, .set"
1630        ))),
1631    }
1632}
1633
1634// ─── Internal helpers ──────────────────────────────────────────────────────────
1635
1636fn parse_header_int<T: std::str::FromStr>(bytes: &[u8]) -> Result<T> {
1637    String::from_utf8_lossy(bytes)
1638        .trim()
1639        .parse::<T>()
1640        .map_err(|_| {
1641            BidsError::Csv(format!(
1642                "Failed to parse header field: '{}'",
1643                String::from_utf8_lossy(bytes).trim()
1644            ))
1645        })
1646}
1647
1648fn parse_header_f64(bytes: &[u8]) -> Result<f64> {
1649    parse_header_int(bytes)
1650}
1651
/// Read the `ch`-th fixed-width text field from a block of per-channel fields.
///
/// Fields for consecutive channels are stored back to back, starting at
/// `base_offset`, each `width` bytes wide. The text is decoded lossily and
/// trimmed. An empty string is returned when the field would extend past the
/// end of `ext`.
fn read_field(ext: &[u8], ch: usize, width: usize, base_offset: usize) -> String {
    let start = base_offset + ch * width;
    ext.get(start..start + width)
        .map(|raw| String::from_utf8_lossy(raw).trim().to_string())
        .unwrap_or_default()
}
1662
1663fn read_field_f64(ext: &[u8], ch: usize, width: usize, base_offset: usize) -> Result<f64> {
1664    let s = read_field(ext, ch, width, base_offset);
1665    s.parse::<f64>().map_err(|_| {
1666        BidsError::Csv(format!(
1667            "Failed to parse channel {ch} field at offset {base_offset}: '{s}'"
1668        ))
1669    })
1670}
1671
1672fn read_field_int(ext: &[u8], ch: usize, width: usize, base_offset: usize) -> Result<usize> {
1673    let s = read_field(ext, ch, width, base_offset);
1674    s.parse::<usize>().map_err(|_| {
1675        BidsError::Csv(format!(
1676            "Failed to parse channel {ch} field at offset {base_offset}: '{s}'"
1677        ))
1678    })
1679}
1680
1681// ─── BrainVision header parsing ────────────────────────────────────────────────
1682
/// Sample layout of a BrainVision binary data file (`DataOrientation` in the `.vhdr`).
///
/// `parse_vhdr` selects `Vectorized` when the value contains `VECTORIZED`
/// (case-insensitive). It uses `Multiplexed` otherwise, including when the key
/// is absent.
#[derive(Debug, Clone, PartialEq)]
enum BvOrientation {
    /// Samples are interleaved: each time point stores one value per channel in turn.
    Multiplexed,
    /// Each channel's samples are stored as one contiguous block.
    Vectorized,
}
1688
/// Sample encoding of a BrainVision binary data file (`BinaryFormat` in the `.vhdr`).
///
/// Raw values are decoded little-endian and multiplied by the channel's resolution.
#[derive(Debug, Clone, PartialEq)]
enum BvDataFormat {
    /// `INT_16`: 2-byte signed integer. This is the default when `BinaryFormat` is absent.
    Int16,
    /// `IEEE_FLOAT_32`: 4-byte IEEE-754 float.
    Float32,
    /// `INT_32`: 4-byte signed integer.
    Int32,
}
1695
/// One `Ch<n>=` entry from the `[Channel Infos]` section of a `.vhdr` header.
#[derive(Debug, Clone)]
struct BvChannel {
    /// Channel label: the first comma-separated field of the entry.
    name: String,
    /// Scale factor applied to every raw sample of this channel.
    /// Taken from the third field; 1.0 when that field is missing or unparsable.
    resolution: f64,
}
1701
/// The subset of a BrainVision `.vhdr` header needed to read its data and markers.
#[derive(Debug, Clone)]
struct BvHeader {
    /// `DataFile` from `[Common Infos]`: the binary data file name as written in the header.
    data_file: String,
    /// `MarkerFile` from `[Common Infos]`, if present.
    marker_file: Option<String>,
    /// Sample encoding, from `BinaryFormat` in `[Binary Infos]`.
    data_format: BvDataFormat,
    /// Sample layout, from `DataOrientation` in `[Common Infos]`.
    data_orientation: BvOrientation,
    /// Channels in the order their entries appear in `[Channel Infos]`.
    channels: Vec<BvChannel>,
    /// `SamplingInterval` from `[Common Infos]`, in microseconds
    /// (sampling rate = 1e6 / interval). `None` if absent or unparsable.
    sampling_interval_us: Option<f64>,
}
1711
1712impl BvHeader {
1713    fn bytes_per_sample(&self) -> usize {
1714        match self.data_format {
1715            BvDataFormat::Int16 => 2,
1716            BvDataFormat::Float32 | BvDataFormat::Int32 => 4,
1717        }
1718    }
1719}
1720
1721fn parse_vhdr(text: &str) -> Result<BvHeader> {
1722    let mut data_file = String::new();
1723    let mut marker_file: Option<String> = None;
1724    let mut data_format = BvDataFormat::Int16;
1725    let mut orientation = BvOrientation::Multiplexed;
1726    let mut sampling_interval: Option<f64> = None;
1727    let mut channels = Vec::new();
1728    let mut section = String::new();
1729
1730    for line in text.lines() {
1731        let line = line.trim();
1732        if line.is_empty() || line.starts_with(';') {
1733            continue;
1734        }
1735        if line.starts_with('[') && line.ends_with(']') {
1736            section = line[1..line.len() - 1].to_lowercase();
1737            continue;
1738        }
1739
1740        if let Some((key, value)) = line.split_once('=') {
1741            let key = key.trim();
1742            let value = value.trim();
1743            match section.as_str() {
1744                "common infos" => match key {
1745                    "DataFile" => data_file = value.to_string(),
1746                    "MarkerFile" => marker_file = Some(value.to_string()),
1747                    "DataOrientation" => {
1748                        orientation = if value.to_uppercase().contains("VECTORIZED") {
1749                            BvOrientation::Vectorized
1750                        } else {
1751                            BvOrientation::Multiplexed
1752                        };
1753                    }
1754                    "SamplingInterval" => {
1755                        sampling_interval = value.parse().ok();
1756                    }
1757                    _ => {}
1758                },
1759                "binary infos" => {
1760                    if key == "BinaryFormat" {
1761                        data_format = match value.to_uppercase().as_str() {
1762                            "IEEE_FLOAT_32" => BvDataFormat::Float32,
1763                            "INT_32" => BvDataFormat::Int32,
1764                            _ => BvDataFormat::Int16,
1765                        };
1766                    }
1767                }
1768                "channel infos" => {
1769                    if key.starts_with("Ch") || key.starts_with("ch") {
1770                        let parts: Vec<&str> = value.splitn(4, ',').collect();
1771                        let name = parts
1772                            .first()
1773                            .map(|s| s.trim().to_string())
1774                            .unwrap_or_default();
1775                        let resolution = parts
1776                            .get(2)
1777                            .and_then(|s| s.trim().parse::<f64>().ok())
1778                            .unwrap_or(1.0);
1779                        channels.push(BvChannel { name, resolution });
1780                    }
1781                }
1782                _ => {}
1783            }
1784        }
1785    }
1786
1787    if data_file.is_empty() {
1788        return Err(BidsError::Csv("BrainVision header missing DataFile".into()));
1789    }
1790    if channels.is_empty() {
1791        return Err(BidsError::Csv("BrainVision header has no channels".into()));
1792    }
1793
1794    Ok(BvHeader {
1795        data_file,
1796        marker_file,
1797        data_format,
1798        data_orientation: orientation,
1799        channels,
1800        sampling_interval_us: sampling_interval,
1801    })
1802}
1803
1804// ─── BrainVision .vmrk marker parser ───────────────────────────────────────────
1805
1806/// Parse BrainVision marker file (.vmrk) into annotations.
1807///
1808/// Marker format: `Mk<n>=<type>,<description>,<position>,<size>,<channel>,<date>`
1809/// where position is 1-indexed sample number.
1810fn parse_vmrk(text: &str, sampling_rate: f64) -> Vec<Annotation> {
1811    let mut annotations = Vec::new();
1812    let mut section = String::new();
1813
1814    for line in text.lines() {
1815        let line = line.trim();
1816        if line.is_empty() || line.starts_with(';') {
1817            continue;
1818        }
1819        if line.starts_with('[') && line.ends_with(']') {
1820            section = line[1..line.len() - 1].to_lowercase();
1821            continue;
1822        }
1823
1824        if section == "marker infos" {
1825            if let Some((key, value)) = line.split_once('=') {
1826                let key = key.trim();
1827                if key.starts_with("Mk") || key.starts_with("mk") {
1828                    // Mk1=Stimulus,S  1,26214,1,0
1829                    // type, description, position, size, channel[, date]
1830                    let parts: Vec<&str> = value.splitn(6, ',').collect();
1831                    if parts.len() >= 3 {
1832                        let marker_type = parts[0].trim();
1833                        let description = parts[1].trim();
1834                        let position: usize = parts[2].trim().parse().unwrap_or(1);
1835                        let size: usize = parts
1836                            .get(3)
1837                            .and_then(|s| s.trim().parse().ok())
1838                            .unwrap_or(1);
1839
1840                        // Position is 1-indexed sample number
1841                        let onset = (position.saturating_sub(1)) as f64 / sampling_rate;
1842                        let duration = if size > 1 {
1843                            size as f64 / sampling_rate
1844                        } else {
1845                            0.0
1846                        };
1847
1848                        // Build description like MNE: "type/description" or just description
1849                        let desc = if marker_type.is_empty()
1850                            || marker_type == "Stimulus"
1851                            || marker_type == "Response"
1852                            || marker_type == "Comment"
1853                        {
1854                            description.to_string()
1855                        } else {
1856                            format!("{marker_type}/{description}")
1857                        };
1858
1859                        if !desc.is_empty() {
1860                            annotations.push(Annotation {
1861                                onset,
1862                                duration,
1863                                description: desc,
1864                            });
1865                        }
1866                    }
1867                }
1868            }
1869        }
1870    }
1871
1872    annotations
1873}
1874
1875/// Read BrainVision markers from a .vmrk file directly.
1876///
1877/// This is a standalone function for reading markers without reading the
1878/// signal data. Useful for event-only analysis.
1879pub fn read_brainvision_markers(vmrk_path: &Path, sampling_rate: f64) -> Result<Vec<Annotation>> {
1880    let text = std::fs::read_to_string(vmrk_path)?;
1881    Ok(parse_vmrk(&text, sampling_rate))
1882}
1883
1884// ─── EEGLAB .set/.fdt reader ───────────────────────────────────────────────────
1885
1886/// Read EEG data from an EEGLAB `.set` file and its companion `.fdt` data file.
1887///
1888/// The `.set` file is a MATLAB MAT v5 file containing metadata (channel names,
1889/// sampling rate, etc.). The actual signal data is stored in a companion `.fdt`
1890/// file as a flat binary array of `f32` values in channels × samples order.
1891///
1892/// # Limitations
1893///
1894/// This reader handles the common case where `.set` metadata is paired with
1895/// a binary `.fdt` file. Complex `.set` files that embed data directly in
1896/// the MAT structure (MATLAB v7.3 / HDF5) are not supported — convert to
1897/// EDF or BrainVision first.
1898///
1899/// # Errors
1900///
1901/// Returns an error if the `.fdt` companion doesn't exist, or the `.set` file
1902/// can't be parsed for the required metadata fields.
1903pub fn read_eeglab_set(path: &Path, opts: &ReadOptions) -> Result<EegData> {
1904    // Read the MAT v5 file to extract basic metadata
1905    let set_bytes = std::fs::read(path)?;
1906
1907    // Parse minimal metadata from MAT v5 header + data elements
1908    let (n_channels, n_samples, srate, channel_labels) = parse_set_metadata(&set_bytes, path)?;
1909
1910    // Find companion .fdt file
1911    let fdt_path = path.with_extension("fdt");
1912    if !fdt_path.exists() {
1913        return Err(BidsError::DataFormat(format!(
1914            "Companion .fdt file not found for {}. \
1915             If the data is embedded in the .set file (MATLAB v7.3/HDF5), \
1916             convert to EDF or BrainVision format first.",
1917            path.display()
1918        )));
1919    }
1920
1921    // Read binary .fdt: float32, channels × samples, little-endian
1922    let fdt_bytes = std::fs::read(&fdt_path)?;
1923    let expected_size = n_channels * n_samples * 4;
1924    if fdt_bytes.len() < expected_size {
1925        return Err(BidsError::DataFormat(format!(
1926            ".fdt file too small: expected {} bytes ({} ch × {} samp × 4), got {}",
1927            expected_size,
1928            n_channels,
1929            n_samples,
1930            fdt_bytes.len()
1931        )));
1932    }
1933
1934    let mut data = vec![Vec::with_capacity(n_samples); n_channels];
1935    #[allow(clippy::needless_range_loop)]
1936    for s in 0..n_samples {
1937        for ch in 0..n_channels {
1938            let offset = (s * n_channels + ch) * 4;
1939            let val = f32::from_le_bytes([
1940                fdt_bytes[offset],
1941                fdt_bytes[offset + 1],
1942                fdt_bytes[offset + 2],
1943                fdt_bytes[offset + 3],
1944            ]);
1945            data[ch].push(val as f64);
1946        }
1947    }
1948
1949    // Apply channel selection from opts
1950    let (data, channel_labels) = if let Some(ref include) = opts.channels {
1951        let mut new_data = Vec::new();
1952        let mut new_labels = Vec::new();
1953        for label in include {
1954            if let Some(idx) = channel_labels.iter().position(|l| l == label) {
1955                new_data.push(data[idx].clone());
1956                new_labels.push(label.clone());
1957            }
1958        }
1959        (new_data, new_labels)
1960    } else if let Some(ref exclude) = opts.exclude {
1961        let mut new_data = Vec::new();
1962        let mut new_labels = Vec::new();
1963        for (idx, label) in channel_labels.iter().enumerate() {
1964            if !exclude.contains(label) {
1965                new_data.push(data[idx].clone());
1966                new_labels.push(label.clone());
1967            }
1968        }
1969        (new_data, new_labels)
1970    } else {
1971        (data, channel_labels)
1972    };
1973
1974    let n_ch = data.len();
1975    let duration = if srate > 0.0 {
1976        n_samples as f64 / srate
1977    } else {
1978        0.0
1979    };
1980
1981    Ok(EegData {
1982        channel_labels,
1983        data,
1984        sampling_rates: vec![srate; n_ch],
1985        duration,
1986        annotations: Vec::new(),
1987        stim_channel_indices: Vec::new(),
1988        is_discontinuous: false,
1989        record_onsets: Vec::new(),
1990    })
1991}
1992
1993/// Parse minimal metadata from a MAT v5 `.set` file.
1994///
1995/// Extracts: nbchan (number of channels), pnts (number of samples),
1996/// srate (sampling rate), and channel labels.
1997///
1998/// This is a minimal parser for the MATLAB Level 5 MAT-file format,
1999/// reading just enough to get the EEG struct's scalar fields and chanlocs.
2000fn parse_set_metadata(bytes: &[u8], path: &Path) -> Result<(usize, usize, f64, Vec<String>)> {
2001    // MAT v5 files start with a 128-byte header: 116 bytes text + 8 reserved + 4 version + 2 endian
2002    if bytes.len() < 128 {
2003        return Err(BidsError::DataFormat(format!(
2004            "{}: File too small to be a valid MAT v5 file",
2005            path.display()
2006        )));
2007    }
2008
2009    let header_text = String::from_utf8_lossy(&bytes[..116]);
2010    if !header_text.contains("MATLAB") {
2011        return Err(BidsError::DataFormat(format!(
2012            "{}: Not a MATLAB MAT v5 file (header doesn't contain 'MATLAB'). \
2013             If this is a MATLAB v7.3 (HDF5) file, convert with: \
2014             pop_saveset(EEG, 'filename', 'output.set', 'savemode', 'onefile', 'version', '7')",
2015            path.display()
2016        )));
2017    }
2018
2019    // Scan the binary for known field name patterns
2020    // Look for common EEGLAB field values as ASCII strings
2021    let text = String::from_utf8_lossy(bytes);
2022
2023    // Try to find nbchan, pnts, srate by scanning for field names
2024    // In MAT v5, struct field names are stored as arrays of fixed-width strings
2025    let mut n_channels = 0usize;
2026    let mut n_samples = 0usize;
2027    let mut srate = 0.0f64;
2028    let mut channel_labels = Vec::new();
2029
2030    // Heuristic: scan for ASCII patterns that encode the metadata.
2031    // This is simplified — a full MAT parser would decode the tag/data structure.
2032    // We look for the pattern: field_name followed by a numeric value.
2033    for window in bytes.windows(6) {
2034        if window == b"nbchan" {
2035            // Look for a double value in the next ~50 bytes
2036            if let Some(v) = find_next_double(
2037                bytes,
2038                bytes.len().min(offset_of(bytes, window) + 100),
2039                offset_of(bytes, window),
2040            ) {
2041                n_channels = v as usize;
2042            }
2043        }
2044        if window[..4] == *b"pnts" {
2045            if let Some(v) = find_next_double(
2046                bytes,
2047                bytes.len().min(offset_of(bytes, window) + 100),
2048                offset_of(bytes, window),
2049            ) {
2050                n_samples = v as usize;
2051            }
2052        }
2053        if window[..5] == *b"srate" {
2054            if let Some(v) = find_next_double(
2055                bytes,
2056                bytes.len().min(offset_of(bytes, window) + 100),
2057                offset_of(bytes, window),
2058            ) {
2059                srate = v;
2060            }
2061        }
2062    }
2063
2064    // Generate default channel labels if we couldn't parse them from chanlocs
2065    if channel_labels.is_empty() && n_channels > 0 {
2066        channel_labels = (0..n_channels)
2067            .map(|i| format!("EEG{:03}", i + 1))
2068            .collect();
2069    }
2070
2071    // Try to extract channel labels from chanlocs.labels
2072    // Look for sequences of short ASCII strings after "labels"
2073    if let Some(pos) = text.find("labels") {
2074        let search_region = &bytes[pos..bytes.len().min(pos + n_channels * 20 + 200)];
2075        let mut labels = Vec::new();
2076        let mut i = 0;
2077        while i < search_region.len() && labels.len() < n_channels {
2078            // Look for runs of printable ASCII that could be channel names
2079            if search_region[i].is_ascii_alphanumeric() {
2080                let start = i;
2081                while i < search_region.len()
2082                    && search_region[i].is_ascii_graphic()
2083                    && search_region[i] != 0
2084                {
2085                    i += 1;
2086                }
2087                let candidate = String::from_utf8_lossy(&search_region[start..i]).to_string();
2088                if candidate.len() >= 2 && candidate.len() <= 10 && candidate != "labels" {
2089                    labels.push(candidate);
2090                }
2091            } else {
2092                i += 1;
2093            }
2094        }
2095        if labels.len() == n_channels {
2096            channel_labels = labels;
2097        }
2098    }
2099
2100    if n_channels == 0 || n_samples == 0 || srate == 0.0 {
2101        return Err(BidsError::DataFormat(format!(
2102            "{}: Could not extract EEG metadata from .set file \
2103             (nbchan={}, pnts={}, srate={}). The file may use an unsupported \
2104             MAT format. Convert with EEGLAB: pop_saveset(EEG, 'savemode', 'onefile')",
2105            path.display(),
2106            n_channels,
2107            n_samples,
2108            srate
2109        )));
2110    }
2111
2112    Ok((n_channels, n_samples, srate, channel_labels))
2113}
2114
/// Byte offset at which `needle` starts within `haystack`.
///
/// `needle` must be a sub-slice of `haystack`, e.g. a window produced by
/// `haystack.windows(..)`. The offset is the difference of the two start
/// pointers, so an unrelated slice can underflow the subtraction.
fn offset_of(haystack: &[u8], needle: &[u8]) -> usize {
    let base_addr = haystack.as_ptr() as usize;
    let needle_addr = needle.as_ptr() as usize;
    needle_addr - base_addr
}
2118
/// Return the first plausible positive double in `bytes[start..end]`.
///
/// The range is read as consecutive 8-byte little-endian IEEE-754 values
/// starting exactly at `start`; a trailing partial word is ignored. `end` is
/// clamped to `bytes.len()`.
///
/// A value is accepted when it is finite and lies strictly between 0 and 1e9.
///
/// # Panics
///
/// Panics if `start` is greater than the clamped `end`.
fn find_next_double(bytes: &[u8], end: usize, start: usize) -> Option<f64> {
    let stop = if end < bytes.len() { end } else { bytes.len() };
    bytes[start..stop]
        .chunks_exact(8)
        .map(|word| {
            let mut raw = [0u8; 8];
            raw.copy_from_slice(word);
            f64::from_le_bytes(raw)
        })
        .find(|&candidate| candidate.is_finite() && candidate > 0.0 && candidate < 1e9)
}
2140
2141#[cfg(test)]
2142mod tests {
2143    use super::*;
2144    use std::io::Write;
2145
2146    fn create_test_edf(
2147        path: &Path,
2148        n_channels: usize,
2149        n_records: usize,
2150        samples_per_record: usize,
2151    ) {
2152        let mut file = std::fs::File::create(path).unwrap();
2153        let mut hdr = [b' '; 256];
2154        hdr[0..1].copy_from_slice(b"0");
2155        hdr[168..176].copy_from_slice(b"01.01.01");
2156        hdr[176..184].copy_from_slice(b"00.00.00");
2157        let hs = format!("{:<8}", 256 + n_channels * 256);
2158        hdr[184..192].copy_from_slice(hs.as_bytes());
2159        let nr = format!("{:<8}", n_records);
2160        hdr[236..244].copy_from_slice(nr.as_bytes());
2161        hdr[244..252].copy_from_slice(b"1       ");
2162        let nc = format!("{:<4}", n_channels);
2163        hdr[252..256].copy_from_slice(nc.as_bytes());
2164        file.write_all(&hdr).unwrap();
2165
2166        let mut ext = vec![b' '; n_channels * 256];
2167        for i in 0..n_channels {
2168            let label = format!("{:<16}", format!("EEG{}", i + 1));
2169            ext[i * 16..i * 16 + 16].copy_from_slice(label.as_bytes());
2170            let o = n_channels * 96 + i * 8;
2171            ext[o..o + 2].copy_from_slice(b"uV");
2172            let s = format!("{:<8}", "-3200");
2173            ext[n_channels * 104 + i * 8..n_channels * 104 + i * 8 + 8]
2174                .copy_from_slice(s.as_bytes());
2175            let s = format!("{:<8}", "3200");
2176            ext[n_channels * 112 + i * 8..n_channels * 112 + i * 8 + 8]
2177                .copy_from_slice(s.as_bytes());
2178            let s = format!("{:<8}", "-32768");
2179            ext[n_channels * 120 + i * 8..n_channels * 120 + i * 8 + 8]
2180                .copy_from_slice(s.as_bytes());
2181            let s = format!("{:<8}", "32767");
2182            ext[n_channels * 128 + i * 8..n_channels * 128 + i * 8 + 8]
2183                .copy_from_slice(s.as_bytes());
2184            let s = format!("{:<8}", samples_per_record);
2185            ext[n_channels * 216 + i * 8..n_channels * 216 + i * 8 + 8]
2186                .copy_from_slice(s.as_bytes());
2187        }
2188        file.write_all(&ext).unwrap();
2189
2190        // Write data records in bulk (one buffer per record)
2191        let rec_bytes = n_channels * samples_per_record * 2;
2192        let mut buf = vec![0u8; rec_bytes];
2193        for rec in 0..n_records {
2194            for ch in 0..n_channels {
2195                for s in 0..samples_per_record {
2196                    let t = rec as f64 + s as f64 / samples_per_record as f64;
2197                    let value = (1000.0
2198                        * (2.0 * std::f64::consts::PI * (ch as f64 + 1.0) * t).sin())
2199                        as i16;
2200                    let off = (ch * samples_per_record + s) * 2;
2201                    buf[off..off + 2].copy_from_slice(&value.to_le_bytes());
2202                }
2203            }
2204            file.write_all(&buf).unwrap();
2205        }
2206    }
2207
2208    #[test]
2209    fn test_read_edf_basic() {
2210        let dir = std::env::temp_dir().join("bids_eeg_data_test_edf");
2211        std::fs::create_dir_all(&dir).unwrap();
2212        let path = dir.join("test.edf");
2213        create_test_edf(&path, 3, 2, 256);
2214
2215        let data = read_edf(&path, &ReadOptions::default()).unwrap();
2216        assert_eq!(data.n_channels(), 3);
2217        assert_eq!(data.n_samples(0), 512);
2218        assert_eq!(data.channel_labels, vec!["EEG1", "EEG2", "EEG3"]);
2219        assert!((data.sampling_rates[0] - 256.0).abs() < 0.01);
2220        assert!((data.duration - 2.0).abs() < 0.01);
2221        for ch_data in &data.data {
2222            for &v in ch_data {
2223                assert!(v >= -3200.1 && v <= 3200.1, "Value {} out of range", v);
2224            }
2225        }
2226        std::fs::remove_dir_all(&dir).unwrap();
2227    }
2228
2229    #[test]
2230    fn test_read_edf_unknown_n_records() {
2231        // EDF spec allows n_records == -1 meaning "unknown".
2232        // The reader should compute the count from the file size.
2233        let dir = std::env::temp_dir().join("bids_eeg_data_test_nrec");
2234        std::fs::create_dir_all(&dir).unwrap();
2235        let path = dir.join("test.edf");
2236
2237        // Create a normal EDF with 2 channels, 3 records, 128 spr
2238        create_test_edf(&path, 2, 3, 128);
2239
2240        // Patch the header to set n_records = -1
2241        let mut bytes = std::fs::read(&path).unwrap();
2242        let neg1 = format!("{:<8}", "-1");
2243        bytes[236..244].copy_from_slice(neg1.as_bytes());
2244        std::fs::write(&path, &bytes).unwrap();
2245
2246        let data = read_edf(&path, &ReadOptions::default()).unwrap();
2247        assert_eq!(data.n_channels(), 2);
2248        assert_eq!(data.n_samples(0), 3 * 128); // should infer 3 records from file size
2249        assert!((data.duration - 3.0).abs() < 0.01);
2250
2251        std::fs::remove_dir_all(&dir).unwrap();
2252    }
2253
2254    #[test]
2255    fn test_read_edf_channel_select() {
2256        let dir = std::env::temp_dir().join("bids_eeg_data_test_chsel");
2257        std::fs::create_dir_all(&dir).unwrap();
2258        let path = dir.join("test.edf");
2259        create_test_edf(&path, 4, 1, 128);
2260
2261        let opts = ReadOptions::new().with_channels(vec!["EEG1".into(), "EEG3".into()]);
2262        let data = read_edf(&path, &opts).unwrap();
2263        assert_eq!(data.n_channels(), 2);
2264        assert_eq!(data.channel_labels, vec!["EEG1", "EEG3"]);
2265        assert_eq!(data.n_samples(0), 128);
2266        std::fs::remove_dir_all(&dir).unwrap();
2267    }
2268
2269    #[test]
2270    fn test_read_edf_time_range() {
2271        let dir = std::env::temp_dir().join("bids_eeg_data_test_time");
2272        std::fs::create_dir_all(&dir).unwrap();
2273        let path = dir.join("test.edf");
2274        create_test_edf(&path, 2, 4, 256);
2275
2276        let opts = ReadOptions::new().with_time_range(1.0, 3.0);
2277        let data = read_edf(&path, &opts).unwrap();
2278        assert_eq!(data.n_channels(), 2);
2279        assert_eq!(data.n_samples(0), 512);
2280        std::fs::remove_dir_all(&dir).unwrap();
2281    }
2282
2283    #[test]
2284    fn test_eeg_data_select_channels() {
2285        let data = EegData {
2286            channel_labels: vec!["Fp1".into(), "Fp2".into(), "Cz".into()],
2287            data: vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]],
2288            sampling_rates: vec![256.0; 3],
2289            duration: 0.0078125,
2290            annotations: Vec::new(),
2291            stim_channel_indices: Vec::new(),
2292            is_discontinuous: false,
2293            record_onsets: Vec::new(),
2294        };
2295        let subset = data.select_channels(&["Fp1", "Cz"]);
2296        assert_eq!(subset.n_channels(), 2);
2297        assert_eq!(subset.channel_labels, vec!["Fp1", "Cz"]);
2298        assert_eq!(subset.channel(0), Some(&[1.0, 2.0][..]));
2299        assert_eq!(subset.channel(1), Some(&[5.0, 6.0][..]));
2300    }
2301
2302    #[test]
2303    fn test_eeg_data_time_slice() {
2304        let data = EegData {
2305            channel_labels: vec!["Fp1".into()],
2306            data: vec![(0..256).map(|i| i as f64).collect()],
2307            sampling_rates: vec![256.0],
2308            duration: 1.0,
2309            annotations: Vec::new(),
2310            stim_channel_indices: Vec::new(),
2311            is_discontinuous: false,
2312            record_onsets: Vec::new(),
2313        };
2314        let slice = data.time_slice(0.0, 0.5);
2315        assert_eq!(slice.n_samples(0), 128);
2316        assert_eq!(slice.channel(0).unwrap()[0], 0.0);
2317        assert_eq!(slice.channel(0).unwrap()[127], 127.0);
2318    }
2319
2320    #[test]
2321    fn test_read_eeg_data_dispatch() {
2322        let dir = std::env::temp_dir().join("bids_eeg_data_test_dispatch");
2323        std::fs::create_dir_all(&dir).unwrap();
2324        let path = dir.join("test.edf");
2325        create_test_edf(&path, 2, 1, 128);
2326        let data = read_eeg_data(&path, &ReadOptions::default()).unwrap();
2327        assert_eq!(data.n_channels(), 2);
2328        let bad = dir.join("test.xyz");
2329        std::fs::write(&bad, b"").unwrap();
2330        assert!(read_eeg_data(&bad, &ReadOptions::default()).is_err());
2331        std::fs::remove_dir_all(&dir).unwrap();
2332    }
2333
2334    #[test]
2335    fn test_brainvision_header_parse() {
2336        let vhdr = r#"
2337Brain Vision Data Exchange Header File Version 1.0
2338; comment line
2339
2340[Common Infos]
2341DataFile=test.eeg
2342DataOrientation=MULTIPLEXED
2343SamplingInterval=3906.25
2344
2345[Binary Infos]
2346BinaryFormat=INT_16
2347
2348[Channel Infos]
2349Ch1=Fp1,,0.1
2350Ch2=Fp2,,0.1
2351Ch3=Cz,,0.1
2352"#;
2353        let hdr = parse_vhdr(vhdr).unwrap();
2354        assert_eq!(hdr.data_file, "test.eeg");
2355        assert_eq!(hdr.data_format, BvDataFormat::Int16);
2356        assert_eq!(hdr.data_orientation, BvOrientation::Multiplexed);
2357        assert_eq!(hdr.channels.len(), 3);
2358        assert_eq!(hdr.channels[0].name, "Fp1");
2359        assert!((hdr.channels[0].resolution - 0.1).abs() < 0.001);
2360        let sr = 1_000_000.0 / hdr.sampling_interval_us.unwrap();
2361        assert!((sr - 256.0).abs() < 0.01);
2362    }
2363
2364    #[test]
2365    fn test_read_brainvision() {
2366        let dir = std::env::temp_dir().join("bids_eeg_data_test_bv");
2367        std::fs::create_dir_all(&dir).unwrap();
2368
2369        let vhdr_path = dir.join("test.vhdr");
2370        std::fs::write(
2371            &vhdr_path,
2372            r#"Brain Vision Data Exchange Header File Version 1.0
2373
2374[Common Infos]
2375DataFile=test.eeg
2376DataOrientation=MULTIPLEXED
2377SamplingInterval=3906.25
2378
2379[Binary Infos]
2380BinaryFormat=INT_16
2381
2382[Channel Infos]
2383Ch1=Fp1,,0.1
2384Ch2=Fp2,,0.1
2385"#,
2386        )
2387        .unwrap();
2388
2389        let eeg_path = dir.join("test.eeg");
2390        let mut eeg_data = Vec::with_capacity(256 * 4);
2391        for s in 0..256 {
2392            let v1 = (1000.0 * (2.0 * std::f64::consts::PI * s as f64 / 256.0).sin()) as i16;
2393            let v2 = (500.0 * (2.0 * std::f64::consts::PI * 2.0 * s as f64 / 256.0).sin()) as i16;
2394            eeg_data.extend_from_slice(&v1.to_le_bytes());
2395            eeg_data.extend_from_slice(&v2.to_le_bytes());
2396        }
2397        std::fs::write(&eeg_path, &eeg_data).unwrap();
2398
2399        let data = read_brainvision(&vhdr_path, &ReadOptions::default()).unwrap();
2400        assert_eq!(data.n_channels(), 2);
2401        assert_eq!(data.n_samples(0), 256);
2402        assert_eq!(data.channel_labels, vec!["Fp1", "Fp2"]);
2403        assert!((data.sampling_rates[0] - 256.0).abs() < 0.01);
2404        assert!(data.data[0].iter().all(|v| v.abs() <= 3276.8));
2405        std::fs::remove_dir_all(&dir).unwrap();
2406    }
2407
2408    #[test]
2409    fn test_channel_by_name() {
2410        let data = EegData {
2411            channel_labels: vec!["Fp1".into(), "Fp2".into()],
2412            data: vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]],
2413            sampling_rates: vec![256.0; 2],
2414            duration: 0.01171875,
2415            annotations: Vec::new(),
2416            stim_channel_indices: Vec::new(),
2417            is_discontinuous: false,
2418            record_onsets: Vec::new(),
2419        };
2420        assert_eq!(data.channel_by_name("Fp1"), Some(&[1.0, 2.0, 3.0][..]));
2421        assert_eq!(data.channel_by_name("Fp2"), Some(&[4.0, 5.0, 6.0][..]));
2422        assert_eq!(data.channel_by_name("Cz"), None);
2423    }
2424
2425    /// Benchmark-style test: 64 channels, 600 seconds @ 2048 Hz (typical clinical EEG).
2426    /// ~150 MB of data. Ensures we can handle realistic sizes.
2427    #[test]
2428    fn test_read_edf_large() {
2429        let dir = std::env::temp_dir().join("bids_eeg_data_test_large");
2430        std::fs::create_dir_all(&dir).unwrap();
2431        let path = dir.join("large.edf");
2432
2433        let n_ch = 64;
2434        let n_rec = 60; // 60 seconds (keep test fast)
2435        let spr = 2048;
2436        create_test_edf(&path, n_ch, n_rec, spr);
2437
2438        let start = std::time::Instant::now();
2439        let data = read_edf(&path, &ReadOptions::default()).unwrap();
2440        let elapsed = start.elapsed();
2441
2442        assert_eq!(data.n_channels(), n_ch);
2443        assert_eq!(data.n_samples(0), n_rec * spr);
2444
2445        // Should be well under 1 second for 60s × 64ch × 2048Hz (~15MB)
2446        assert!(
2447            elapsed.as_millis() < 1000,
2448            "Reading took {}ms, expected < 1000ms",
2449            elapsed.as_millis()
2450        );
2451
2452        // Channel selection should be faster
2453        let start = std::time::Instant::now();
2454        let _data = read_edf(
2455            &path,
2456            &ReadOptions::new().with_channels(vec!["EEG1".into(), "EEG32".into()]),
2457        )
2458        .unwrap();
2459        let elapsed2 = start.elapsed();
2460        assert!(
2461            elapsed2 <= elapsed || elapsed2.as_millis() < 500,
2462            "Channel-select took {}ms",
2463            elapsed2.as_millis()
2464        );
2465
2466        std::fs::remove_dir_all(&dir).unwrap();
2467    }
2468
2469    #[test]
2470    fn test_times() {
2471        let data = EegData {
2472            channel_labels: vec!["Fp1".into()],
2473            data: vec![vec![0.0; 512]],
2474            sampling_rates: vec![256.0],
2475            duration: 2.0,
2476            annotations: Vec::new(),
2477            stim_channel_indices: Vec::new(),
2478            is_discontinuous: false,
2479            record_onsets: Vec::new(),
2480        };
2481        let times = data.times(0).unwrap();
2482        assert_eq!(times.len(), 512);
2483        assert!((times[0] - 0.0).abs() < 1e-10);
2484        assert!((times[1] - 1.0 / 256.0).abs() < 1e-10);
2485        assert!((times[511] - 511.0 / 256.0).abs() < 1e-10);
2486    }
2487
2488    #[test]
2489    fn test_exclude_channels() {
2490        let data = EegData {
2491            channel_labels: vec!["Fp1".into(), "Fp2".into(), "Cz".into()],
2492            data: vec![vec![1.0], vec![2.0], vec![3.0]],
2493            sampling_rates: vec![256.0; 3],
2494            duration: 0.00390625,
2495            annotations: Vec::new(),
2496            stim_channel_indices: Vec::new(),
2497            is_discontinuous: false,
2498            record_onsets: Vec::new(),
2499        };
2500        let excl = data.exclude_channels(&["Fp2"]);
2501        assert_eq!(excl.n_channels(), 2);
2502        assert_eq!(excl.channel_labels, vec!["Fp1", "Cz"]);
2503    }
2504
2505    #[test]
2506    fn test_convert_units() {
2507        let mut data = EegData {
2508            channel_labels: vec!["Fp1".into()],
2509            data: vec![vec![100.0, 200.0]],
2510            sampling_rates: vec![256.0],
2511            duration: 0.0078125,
2512            annotations: Vec::new(),
2513            stim_channel_indices: Vec::new(),
2514            is_discontinuous: false,
2515            record_onsets: Vec::new(),
2516        };
2517        let mut map = std::collections::HashMap::new();
2518        map.insert("Fp1".into(), 1e-6);
2519        data.convert_units(&map);
2520        assert!((data.data[0][0] - 100e-6).abs() < 1e-15);
2521        assert!((data.data[0][1] - 200e-6).abs() < 1e-15);
2522    }
2523
2524    #[test]
2525    fn test_reject_by_annotation() {
2526        let data = EegData {
2527            channel_labels: vec!["Fp1".into()],
2528            data: vec![(0..256).map(|i| i as f64).collect()],
2529            sampling_rates: vec![256.0],
2530            duration: 1.0,
2531            annotations: vec![Annotation {
2532                onset: 0.25,
2533                duration: 0.25,
2534                description: "BAD_segment".into(),
2535            }],
2536            stim_channel_indices: Vec::new(),
2537            is_discontinuous: false,
2538            record_onsets: Vec::new(),
2539        };
2540        let rejected = data.reject_by_annotation("BAD");
2541        // Samples from 0.25s to 0.5s (64..128) should be NAN
2542        assert!(!rejected.data[0][63].is_nan());
2543        assert!(rejected.data[0][64].is_nan());
2544        assert!(rejected.data[0][127].is_nan());
2545        assert!(!rejected.data[0][128].is_nan());
2546    }
2547
2548    #[test]
2549    fn test_edf_tal_parse() {
2550        let mut annotations = Vec::new();
2551        // TAL format: +onset\x14\x14\x00 (record onset) +onset\x14description\x14\x00
2552        let tal = b"+0.0\x14\x14\x00+0.5\x14stimulus\x14\x00+1.5\x150.5\x14response\x14\x00";
2553        let record_onset = parse_edf_tal(tal, &mut annotations);
2554        assert_eq!(record_onset, Some(0.0)); // first entry without description = record onset
2555        assert_eq!(annotations.len(), 2);
2556        assert!((annotations[0].onset - 0.5).abs() < 1e-10);
2557        assert_eq!(annotations[0].description, "stimulus");
2558        assert!((annotations[0].duration - 0.0).abs() < 1e-10);
2559        assert!((annotations[1].onset - 1.5).abs() < 1e-10);
2560        assert_eq!(annotations[1].description, "response");
2561        assert!((annotations[1].duration - 0.5).abs() < 1e-10);
2562    }
2563
2564    #[test]
2565    fn test_vmrk_parse() {
2566        let vmrk = r#"Brain Vision Data Exchange Marker File Version 1.0
2567
2568[Common Infos]
2569Codepage=UTF-8
2570DataFile=test.eeg
2571
2572[Marker Infos]
2573Mk1=Stimulus,S  1,512,1,0
2574Mk2=Stimulus,S  2,1024,1,0
2575Mk3=Response,R  1,2048,1,0
2576Mk4=Comment,hello world,3072,1,0
2577"#;
2578        let anns = parse_vmrk(vmrk, 256.0);
2579        assert_eq!(anns.len(), 4);
2580        // Mk1: position 512 → onset (512-1)/256 = 1.99609375
2581        assert!((anns[0].onset - 511.0 / 256.0).abs() < 1e-10);
2582        assert_eq!(anns[0].description, "S  1");
2583        assert!((anns[1].onset - 1023.0 / 256.0).abs() < 1e-10);
2584        assert_eq!(anns[1].description, "S  2");
2585        assert_eq!(anns[2].description, "R  1");
2586        assert_eq!(anns[3].description, "hello world");
2587    }
2588
2589    #[test]
2590    fn test_brainvision_with_markers() {
2591        let dir = std::env::temp_dir().join("bids_eeg_data_test_bv_vmrk");
2592        std::fs::create_dir_all(&dir).unwrap();
2593
2594        std::fs::write(
2595            dir.join("test.vhdr"),
2596            r#"Brain Vision Data Exchange Header File Version 1.0
2597
2598[Common Infos]
2599DataFile=test.eeg
2600MarkerFile=test.vmrk
2601DataOrientation=MULTIPLEXED
2602SamplingInterval=3906.25
2603
2604[Binary Infos]
2605BinaryFormat=INT_16
2606
2607[Channel Infos]
2608Ch1=Fp1,,0.1
2609Ch2=Fp2,,0.1
2610"#,
2611        )
2612        .unwrap();
2613
2614        std::fs::write(
2615            dir.join("test.vmrk"),
2616            r#"Brain Vision Data Exchange Marker File Version 1.0
2617
2618[Marker Infos]
2619Mk1=Stimulus,S1,50,1,0
2620Mk2=Stimulus,S2,150,1,0
2621"#,
2622        )
2623        .unwrap();
2624
2625        // Create binary data (2 ch × 256 samples × INT_16)
2626        let mut buf = Vec::with_capacity(256 * 2 * 2);
2627        for s in 0..256 {
2628            let v = (100.0 * (s as f64)).round() as i16;
2629            buf.extend_from_slice(&v.to_le_bytes());
2630            buf.extend_from_slice(&v.to_le_bytes());
2631        }
2632        std::fs::write(dir.join("test.eeg"), &buf).unwrap();
2633
2634        let data = read_brainvision(&dir.join("test.vhdr"), &ReadOptions::default()).unwrap();
2635        assert_eq!(data.annotations.len(), 2);
2636        assert_eq!(data.annotations[0].description, "S1");
2637        assert_eq!(data.annotations[1].description, "S2");
2638        // onset = (position - 1) / sampling_rate
2639        assert!((data.annotations[0].onset - 49.0 / 256.0).abs() < 1e-6);
2640
2641        std::fs::remove_dir_all(&dir).unwrap();
2642    }
2643
2644    #[test]
2645    fn test_read_edf_with_exclude() {
2646        let dir = std::env::temp_dir().join("bids_eeg_data_test_excl");
2647        std::fs::create_dir_all(&dir).unwrap();
2648        let path = dir.join("test.edf");
2649        create_test_edf(&path, 4, 1, 128);
2650
2651        let opts = ReadOptions::new().with_exclude(vec!["EEG2".into(), "EEG4".into()]);
2652        let data = read_edf(&path, &opts).unwrap();
2653        assert_eq!(data.n_channels(), 2);
2654        assert_eq!(data.channel_labels, vec!["EEG1", "EEG3"]);
2655
2656        std::fs::remove_dir_all(&dir).unwrap();
2657    }
2658
2659    #[test]
2660    fn test_stim_channel_detection() {
2661        let dir = std::env::temp_dir().join("bids_eeg_data_test_stim");
2662        std::fs::create_dir_all(&dir).unwrap();
2663        let path = dir.join("stim.edf");
2664
2665        // Create EDF with a "Status" channel
2666        let n_ch = 3;
2667        let spr = 128;
2668        let mut file = std::fs::File::create(&path).unwrap();
2669        let mut hdr = [b' '; 256];
2670        hdr[0..1].copy_from_slice(b"0");
2671        hdr[168..176].copy_from_slice(b"01.01.01");
2672        hdr[176..184].copy_from_slice(b"00.00.00");
2673        let hs = format!("{:<8}", 256 + n_ch * 256);
2674        hdr[184..192].copy_from_slice(hs.as_bytes());
2675        hdr[236..244].copy_from_slice(b"1       ");
2676        hdr[244..252].copy_from_slice(b"1       ");
2677        let nc = format!("{:<4}", n_ch);
2678        hdr[252..256].copy_from_slice(nc.as_bytes());
2679        file.write_all(&hdr).unwrap();
2680
2681        let mut ext = vec![b' '; n_ch * 256];
2682        let ch_labels = ["EEG1", "EEG2", "Status"];
2683        for i in 0..n_ch {
2684            let label = format!("{:<16}", ch_labels[i]);
2685            ext[i * 16..i * 16 + 16].copy_from_slice(label.as_bytes());
2686            ext[n_ch * 96 + i * 8..n_ch * 96 + i * 8 + 2].copy_from_slice(b"uV");
2687            let s = format!("{:<8}", "-3200");
2688            ext[n_ch * 104 + i * 8..n_ch * 104 + i * 8 + 8].copy_from_slice(s.as_bytes());
2689            let s = format!("{:<8}", "3200");
2690            ext[n_ch * 112 + i * 8..n_ch * 112 + i * 8 + 8].copy_from_slice(s.as_bytes());
2691            let s = format!("{:<8}", "-32768");
2692            ext[n_ch * 120 + i * 8..n_ch * 120 + i * 8 + 8].copy_from_slice(s.as_bytes());
2693            let s = format!("{:<8}", "32767");
2694            ext[n_ch * 128 + i * 8..n_ch * 128 + i * 8 + 8].copy_from_slice(s.as_bytes());
2695            let s = format!("{:<8}", spr);
2696            ext[n_ch * 216 + i * 8..n_ch * 216 + i * 8 + 8].copy_from_slice(s.as_bytes());
2697        }
2698        file.write_all(&ext).unwrap();
2699
2700        let rec_bytes = n_ch * spr * 2;
2701        let buf = vec![0u8; rec_bytes];
2702        file.write_all(&buf).unwrap();
2703        drop(file);
2704
2705        let data = read_edf(&path, &ReadOptions::default()).unwrap();
2706        assert_eq!(data.n_channels(), 3);
2707        // "Status" channel should be detected as stim
2708        assert_eq!(data.stim_channel_indices, vec![2]);
2709
2710        std::fs::remove_dir_all(&dir).unwrap();
2711    }
2712
2713    fn make_sine_data(freq: f64, sr: f64, duration: f64, n_ch: usize) -> EegData {
2714        let n = (duration * sr) as usize;
2715        let data: Vec<Vec<f64>> = (0..n_ch)
2716            .map(|_| {
2717                (0..n)
2718                    .map(|i| {
2719                        let t = i as f64 / sr;
2720                        (2.0 * std::f64::consts::PI * freq * t).sin()
2721                    })
2722                    .collect()
2723            })
2724            .collect();
2725        EegData {
2726            channel_labels: (0..n_ch).map(|i| format!("Ch{}", i + 1)).collect(),
2727            data,
2728            sampling_rates: vec![sr; n_ch],
2729            duration,
2730            annotations: Vec::new(),
2731            stim_channel_indices: Vec::new(),
2732            is_discontinuous: false,
2733            record_onsets: Vec::new(),
2734        }
2735    }
2736
2737    #[test]
2738    fn test_filter_lowpass() {
2739        // 10 Hz signal + 100 Hz noise at 500 Hz sampling
2740        let sr = 500.0;
2741        let n = 1000;
2742        let data = EegData {
2743            channel_labels: vec!["Ch1".into()],
2744            data: vec![
2745                (0..n)
2746                    .map(|i| {
2747                        let t = i as f64 / sr;
2748                        (2.0 * std::f64::consts::PI * 10.0 * t).sin()
2749                            + (2.0 * std::f64::consts::PI * 100.0 * t).sin()
2750                    })
2751                    .collect(),
2752            ],
2753            sampling_rates: vec![sr],
2754            duration: n as f64 / sr,
2755            annotations: Vec::new(),
2756            stim_channel_indices: Vec::new(),
2757            is_discontinuous: false,
2758            record_onsets: Vec::new(),
2759        };
2760
2761        let filtered = data.filter(None, Some(30.0), 5);
2762        assert_eq!(filtered.data[0].len(), n);
2763        // High-freq noise should be greatly reduced
2764        let orig_energy: f64 = data.data[0][n / 2..].iter().map(|v| v * v).sum::<f64>();
2765        let filt_energy: f64 = filtered.data[0][n / 2..].iter().map(|v| v * v).sum::<f64>();
2766        assert!(filt_energy < orig_energy * 0.7);
2767    }
2768
2769    #[test]
2770    fn test_notch_filter() {
2771        let sr = 500.0;
2772        let n = 2000;
2773        let data = EegData {
2774            channel_labels: vec!["Ch1".into()],
2775            data: vec![
2776                (0..n)
2777                    .map(|i| {
2778                        let t = i as f64 / sr;
2779                        (2.0 * std::f64::consts::PI * 10.0 * t).sin()
2780                            + 0.5 * (2.0 * std::f64::consts::PI * 50.0 * t).sin()
2781                    })
2782                    .collect(),
2783            ],
2784            sampling_rates: vec![sr],
2785            duration: n as f64 / sr,
2786            annotations: Vec::new(),
2787            stim_channel_indices: Vec::new(),
2788            is_discontinuous: false,
2789            record_onsets: Vec::new(),
2790        };
2791
2792        let filtered = data.notch_filter(50.0, 30.0);
2793        assert_eq!(filtered.data[0].len(), n);
2794    }
2795
2796    #[test]
2797    fn test_resample() {
2798        let data = make_sine_data(5.0, 1000.0, 1.0, 2);
2799        assert_eq!(data.data[0].len(), 1000);
2800
2801        let resampled = data.resample(250.0);
2802        assert_eq!(resampled.data[0].len(), 250);
2803        assert_eq!(resampled.data.len(), 2);
2804        assert!((resampled.sampling_rates[0] - 250.0).abs() < 1e-10);
2805    }
2806
2807    #[test]
2808    fn test_set_average_reference() {
2809        let data = EegData {
2810            channel_labels: vec!["Ch1".into(), "Ch2".into(), "Ch3".into()],
2811            data: vec![vec![3.0, 6.0], vec![1.0, 2.0], vec![2.0, 4.0]],
2812            sampling_rates: vec![256.0; 3],
2813            duration: 2.0 / 256.0,
2814            annotations: Vec::new(),
2815            stim_channel_indices: Vec::new(),
2816            is_discontinuous: false,
2817            record_onsets: Vec::new(),
2818        };
2819        let reref = data.set_average_reference();
2820        // Mean at t=0: (3+1+2)/3 = 2.0
2821        assert!((reref.data[0][0] - 1.0).abs() < 1e-10); // 3 - 2
2822        assert!((reref.data[1][0] - (-1.0)).abs() < 1e-10); // 1 - 2
2823        assert!((reref.data[2][0] - 0.0).abs() < 1e-10); // 2 - 2
2824    }
2825
2826    #[test]
2827    fn test_set_reference() {
2828        let data = EegData {
2829            channel_labels: vec!["Fp1".into(), "Cz".into(), "Pz".into()],
2830            data: vec![vec![10.0, 20.0], vec![5.0, 10.0], vec![8.0, 16.0]],
2831            sampling_rates: vec![256.0; 3],
2832            duration: 2.0 / 256.0,
2833            annotations: Vec::new(),
2834            stim_channel_indices: Vec::new(),
2835            is_discontinuous: false,
2836            record_onsets: Vec::new(),
2837        };
2838        let reref = data.set_reference("Cz");
2839        assert!((reref.data[0][0] - 5.0).abs() < 1e-10); // 10 - 5
2840        assert!((reref.data[1][0] - 5.0).abs() < 1e-10); // Cz unchanged
2841        assert!((reref.data[2][0] - 3.0).abs() < 1e-10); // 8 - 5
2842    }
2843
2844    #[test]
2845    fn test_epoch_and_average() {
2846        let sr = 100.0;
2847        let n = 500;
2848        let data = EegData {
2849            channel_labels: vec!["Ch1".into()],
2850            data: vec![
2851                (0..n)
2852                    .map(|i| (i as f64 / sr * 2.0 * std::f64::consts::PI).sin())
2853                    .collect(),
2854            ],
2855            sampling_rates: vec![sr],
2856            duration: n as f64 / sr,
2857            annotations: vec![
2858                Annotation {
2859                    onset: 1.0,
2860                    duration: 0.0,
2861                    description: "stim".into(),
2862                },
2863                Annotation {
2864                    onset: 2.0,
2865                    duration: 0.0,
2866                    description: "stim".into(),
2867                },
2868                Annotation {
2869                    onset: 3.0,
2870                    duration: 0.0,
2871                    description: "stim".into(),
2872                },
2873            ],
2874            stim_channel_indices: Vec::new(),
2875            is_discontinuous: false,
2876            record_onsets: Vec::new(),
2877        };
2878
2879        let epochs = data.epoch(-0.2, 0.5, Some("stim"));
2880        assert_eq!(epochs.len(), 3);
2881        assert_eq!(epochs[0].data[0].len(), 70); // 0.2 + 0.5 = 0.7s * 100 Hz = 70 samples
2882
2883        // Average
2884        let avg = EegData::average_epochs(&epochs).unwrap();
2885        assert_eq!(avg.data[0].len(), 70);
2886        assert_eq!(avg.data.len(), 1);
2887    }
2888
2889    #[test]
2890    fn test_compute_psd() {
2891        let data = make_sine_data(10.0, 256.0, 2.0, 1);
2892        let (freqs, psd) = data.compute_psd(Some(256));
2893        assert_eq!(freqs.len(), 129); // 256/2 + 1
2894        assert_eq!(psd.len(), 1);
2895        assert_eq!(psd[0].len(), 129);
2896        // Peak should be at ~10 Hz
2897        let peak_idx = psd[0]
2898            .iter()
2899            .enumerate()
2900            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
2901            .unwrap()
2902            .0;
2903        let peak_freq = freqs[peak_idx];
2904        assert!(
2905            (peak_freq - 10.0).abs() < 2.0,
2906            "PSD peak at {:.1} Hz, expected ~10 Hz",
2907            peak_freq
2908        );
2909    }
2910
2911    #[test]
2912    fn test_display_debug() {
2913        let data = make_sine_data(10.0, 256.0, 1.0, 4);
2914        let display = format!("{}", data);
2915        assert!(display.contains("4 ch"), "Display: {}", display);
2916        assert!(display.contains("256"), "Display: {}", display);
2917        let debug = format!("{:?}", data);
2918        assert!(debug.contains("n_channels: 4"), "Debug: {}", debug);
2919        // Should NOT contain raw sample data
2920        assert!(
2921            !debug.contains("0."),
2922            "Debug should not dump samples: {}",
2923            &debug[..100.min(debug.len())]
2924        );
2925    }
2926
2927    #[test]
2928    fn test_get_data_with_times() {
2929        let data = EegData {
2930            channel_labels: vec!["Fp1".into()],
2931            data: vec![vec![1.0, 2.0, 3.0, 4.0]],
2932            sampling_rates: vec![256.0],
2933            duration: 4.0 / 256.0,
2934            annotations: Vec::new(),
2935            stim_channel_indices: Vec::new(),
2936            is_discontinuous: false,
2937            record_onsets: Vec::new(),
2938        };
2939        let (d, t) = data.get_data_with_times();
2940        assert_eq!(d.len(), 1);
2941        assert_eq!(t.len(), 4);
2942        assert!((t[0] - 0.0).abs() < 1e-10);
2943        assert!((t[3] - 3.0 / 256.0).abs() < 1e-10);
2944    }
2945
2946    #[test]
2947    fn test_pick_types() {
2948        use crate::ChannelType;
2949        let data = EegData {
2950            channel_labels: vec!["Fp1".into(), "EOG1".into(), "Cz".into(), "ECG1".into()],
2951            data: vec![vec![1.0], vec![2.0], vec![3.0], vec![4.0]],
2952            sampling_rates: vec![256.0; 4],
2953            duration: 1.0 / 256.0,
2954            annotations: Vec::new(),
2955            stim_channel_indices: Vec::new(),
2956            is_discontinuous: false,
2957            record_onsets: Vec::new(),
2958        };
2959        let types = vec![
2960            ChannelType::EEG,
2961            ChannelType::EEG,
2962            ChannelType::EEG,
2963            ChannelType::ECG,
2964        ];
2965        let picked = data.pick_types(&[ChannelType::EEG], &types);
2966        assert_eq!(picked.n_channels(), 3);
2967        assert_eq!(picked.channel_labels, vec!["Fp1", "EOG1", "Cz"]);
2968    }
2969
2970    #[test]
2971    fn test_concatenate() {
2972        let mut data1 = EegData {
2973            channel_labels: vec!["Fp1".into(), "Fp2".into()],
2974            data: vec![vec![1.0, 2.0], vec![3.0, 4.0]],
2975            sampling_rates: vec![256.0; 2],
2976            duration: 2.0 / 256.0,
2977            annotations: vec![Annotation {
2978                onset: 0.0,
2979                duration: 0.0,
2980                description: "A".into(),
2981            }],
2982            stim_channel_indices: Vec::new(),
2983            is_discontinuous: false,
2984            record_onsets: Vec::new(),
2985        };
2986        let data2 = EegData {
2987            channel_labels: vec!["Fp1".into(), "Fp2".into()],
2988            data: vec![vec![5.0, 6.0], vec![7.0, 8.0]],
2989            sampling_rates: vec![256.0; 2],
2990            duration: 2.0 / 256.0,
2991            annotations: vec![Annotation {
2992                onset: 0.0,
2993                duration: 0.0,
2994                description: "B".into(),
2995            }],
2996            stim_channel_indices: Vec::new(),
2997            is_discontinuous: false,
2998            record_onsets: Vec::new(),
2999        };
3000        data1.concatenate(&data2).unwrap();
3001        assert_eq!(data1.n_samples(0), 4);
3002        assert_eq!(data1.data[0], vec![1.0, 2.0, 5.0, 6.0]);
3003        assert_eq!(data1.data[1], vec![3.0, 4.0, 7.0, 8.0]);
3004        assert_eq!(data1.annotations.len(), 2);
3005        assert_eq!(data1.annotations[1].description, "B");
3006        assert!((data1.annotations[1].onset - 2.0 / 256.0).abs() < 1e-10);
3007
3008        // Mismatched channels should fail
3009        let data3 = EegData {
3010            channel_labels: vec!["Cz".into()],
3011            data: vec![vec![9.0]],
3012            sampling_rates: vec![256.0],
3013            duration: 1.0 / 256.0,
3014            annotations: Vec::new(),
3015            stim_channel_indices: Vec::new(),
3016            is_discontinuous: false,
3017            record_onsets: Vec::new(),
3018        };
3019        assert!(data1.concatenate(&data3).is_err());
3020    }
3021}