Skip to main content

bids_eeg/
harmonize.rs

//! Channel harmonization across datasets.
//!
//! When combining data from multiple datasets for cross-dataset ML, channels
//! and sampling rates must be aligned. This module provides functions to
//! find a common channel set (intersection or union), choose a common
//! sampling rate, and standardize channel ordering.
//!
//! Inspired by MOABB's `BaseProcessing.match_all()`.
9
10use crate::data::EegData;
11use std::collections::{BTreeSet, HashSet};
12
/// How the per-dataset channel sets are merged into one plan.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ChannelStrategy {
    /// Only channels that every dataset has; no channel is ever missing.
    Intersect,
    /// Every channel that any dataset has; inputs lacking a channel need it
    /// filled in (e.g. by interpolation).
    Union,
}
21
/// Result of channel harmonization analysis, as produced by
/// [`plan_harmonization`].
#[derive(Debug, Clone)]
pub struct HarmonizationPlan {
    /// Final channel list to use. `plan_harmonization` emits it in sorted
    /// (lexicographic) order.
    pub channels: Vec<String>,
    /// Target sampling rate in Hz: the minimum sampling rate across all
    /// inputs plus the caller-supplied `sr_shift` (`0.0` when there are no
    /// inputs).
    pub target_sr: f64,
    /// One row per input (in input order), each parallel to `channels`:
    /// `true` if that input has the channel, `false` if it is missing.
    pub channel_availability: Vec<Vec<bool>>,
}
32
33impl std::fmt::Display for HarmonizationPlan {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        write!(
36            f,
37            "HarmonizationPlan({} channels @ {:.1} Hz",
38            self.channels.len(),
39            self.target_sr
40        )?;
41        let n_inputs = self.channel_availability.len();
42        if n_inputs > 0 {
43            let all_present = self
44                .channel_availability
45                .iter()
46                .all(|avail| avail.iter().all(|&a| a));
47            if all_present {
48                write!(f, ", all channels present in all inputs")?;
49            } else {
50                let missing: usize = self
51                    .channel_availability
52                    .iter()
53                    .map(|avail| avail.iter().filter(|&&a| !a).count())
54                    .sum();
55                write!(f, ", {missing} missing channel-input pairs")?;
56            }
57        }
58        write!(f, ")")
59    }
60}
61
62/// Analyze multiple EegData recordings and produce a harmonization plan.
63///
64/// Examines all inputs to find the common channel set (or union) and
65/// the minimum sampling rate.
66///
67/// # Arguments
68/// - `inputs`: slice of `EegData` references to harmonize
69/// - `strategy`: how to combine channel sets
70/// - `ignore_channels`: channel names to exclude (e.g., `["STIM", "STATUS"]`)
71/// - `sr_shift`: small negative shift to avoid off-by-one sample counts
72///   (MOABB uses `-0.5` for this; pass `0.0` if not needed)
73#[must_use]
74pub fn plan_harmonization(
75    inputs: &[&EegData],
76    strategy: ChannelStrategy,
77    ignore_channels: &[&str],
78    sr_shift: f64,
79) -> HarmonizationPlan {
80    if inputs.is_empty() {
81        return HarmonizationPlan {
82            channels: Vec::new(),
83            target_sr: 0.0,
84            channel_availability: Vec::new(),
85        };
86    }
87
88    let ignore: HashSet<&str> = ignore_channels.iter().copied().collect();
89
90    // Collect channel sets
91    let channel_sets: Vec<BTreeSet<String>> = inputs
92        .iter()
93        .map(|d| {
94            d.channel_labels
95                .iter()
96                .filter(|ch| !ignore.contains(ch.as_str()))
97                .cloned()
98                .collect()
99        })
100        .collect();
101
102    let channels: BTreeSet<String> = match strategy {
103        ChannelStrategy::Intersect => {
104            let mut result = channel_sets[0].clone();
105            for set in &channel_sets[1..] {
106                result = result.intersection(set).cloned().collect();
107            }
108            result
109        }
110        ChannelStrategy::Union => {
111            let mut result = BTreeSet::new();
112            for set in &channel_sets {
113                result = result.union(set).cloned().collect();
114            }
115            result
116        }
117    };
118
119    let channels: Vec<String> = channels.into_iter().collect();
120
121    // Find minimum sampling rate
122    let target_sr = inputs
123        .iter()
124        .filter_map(|d| d.sampling_rates.first().copied())
125        .fold(f64::INFINITY, f64::min)
126        + sr_shift;
127
128    // Build availability map
129    let channel_availability = inputs
130        .iter()
131        .map(|d| {
132            let input_channels: HashSet<&str> =
133                d.channel_labels.iter().map(|s| s.as_str()).collect();
134            channels
135                .iter()
136                .map(|ch| input_channels.contains(ch.as_str()))
137                .collect()
138        })
139        .collect();
140
141    HarmonizationPlan {
142        channels,
143        target_sr,
144        channel_availability,
145    }
146}
147
148/// Apply a harmonization plan to a single `EegData`, producing a new one
149/// with the specified channels (in order) and sampling rate.
150///
151/// Missing channels are filled with zeros.
152#[must_use]
153pub fn apply_harmonization(data: &EegData, plan: &HarmonizationPlan) -> EegData {
154    let n_samples = data.data.first().map_or(0, |ch| ch.len());
155    let channel_map: std::collections::HashMap<&str, usize> = data
156        .channel_labels
157        .iter()
158        .enumerate()
159        .map(|(i, name)| (name.as_str(), i))
160        .collect();
161
162    let mut new_data = Vec::with_capacity(plan.channels.len());
163    let mut new_rates = Vec::with_capacity(plan.channels.len());
164    let sr = data
165        .sampling_rates
166        .first()
167        .copied()
168        .unwrap_or(plan.target_sr);
169
170    for ch_name in &plan.channels {
171        if let Some(&idx) = channel_map.get(ch_name.as_str()) {
172            new_data.push(data.data[idx].clone());
173            new_rates.push(data.sampling_rates.get(idx).copied().unwrap_or(sr));
174        } else {
175            // Missing channel — fill with zeros
176            new_data.push(vec![0.0; n_samples]);
177            new_rates.push(sr);
178        }
179    }
180
181    let mut result = EegData {
182        channel_labels: plan.channels.clone(),
183        data: new_data,
184        sampling_rates: new_rates,
185        duration: data.duration,
186        annotations: data.annotations.clone(),
187        stim_channel_indices: Vec::new(), // stim channels may not be in the plan
188        is_discontinuous: data.is_discontinuous,
189        record_onsets: data.record_onsets.clone(),
190    };
191
192    // Resample if needed
193    if (sr - plan.target_sr).abs() > 0.01 {
194        result = result.resample(plan.target_sr);
195    }
196
197    result
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::EegData;

    /// Build a synthetic recording where channel `i` holds the ramp
    /// `i, i + 1, …`, so distinct channels have distinct contents.
    fn make_data(channels: &[&str], sr: f64, n_samples: usize) -> EegData {
        EegData {
            channel_labels: channels.iter().map(|s| (*s).to_string()).collect(),
            data: channels
                .iter()
                .enumerate()
                .map(|(i, _)| (0..n_samples).map(|s| (s + i) as f64).collect())
                .collect(),
            sampling_rates: vec![sr; channels.len()],
            duration: n_samples as f64 / sr,
            annotations: Vec::new(),
            stim_channel_indices: Vec::new(),
            is_discontinuous: false,
            record_onsets: Vec::new(),
        }
    }

    #[test]
    fn test_intersect_strategy() {
        let d1 = make_data(&["Fz", "Cz", "Pz", "STIM"], 256.0, 256);
        let d2 = make_data(&["Fz", "Cz", "Oz", "STIM"], 512.0, 512);

        let plan = plan_harmonization(&[&d1, &d2], ChannelStrategy::Intersect, &["STIM"], 0.0);
        assert_eq!(plan.channels, vec!["Cz", "Fz"]); // sorted, intersected, STIM excluded
        assert!((plan.target_sr - 256.0).abs() < 1e-10); // min SR
        assert_eq!(plan.channel_availability, vec![vec![true, true], vec![true, true]]);
    }

    #[test]
    fn test_union_strategy() {
        let d1 = make_data(&["Fz", "Cz"], 256.0, 256);
        let d2 = make_data(&["Cz", "Pz"], 256.0, 256);

        let plan = plan_harmonization(&[&d1, &d2], ChannelStrategy::Union, &[], 0.0);
        assert_eq!(plan.channels, vec!["Cz", "Fz", "Pz"]);
        // d1 has Fz, Cz but not Pz
        assert_eq!(plan.channel_availability[0], vec![true, true, false]);
        // d2 has Cz, Pz but not Fz
        assert_eq!(plan.channel_availability[1], vec![true, false, true]);
    }

    #[test]
    fn test_empty_inputs() {
        let plan = plan_harmonization(&[], ChannelStrategy::Union, &[], -0.5);
        assert!(plan.channels.is_empty());
        assert!(plan.channel_availability.is_empty());
        assert!(plan.target_sr.abs() < 1e-10);
    }

    #[test]
    fn test_display_summary() {
        let d1 = make_data(&["Fz", "Cz"], 256.0, 256);
        let d2 = make_data(&["Cz", "Pz"], 256.0, 256);

        let union = plan_harmonization(&[&d1, &d2], ChannelStrategy::Union, &[], 0.0);
        assert_eq!(
            union.to_string(),
            "HarmonizationPlan(3 channels @ 256.0 Hz, 2 missing channel-input pairs)"
        );

        let common = plan_harmonization(&[&d1, &d2], ChannelStrategy::Intersect, &[], 0.0);
        assert_eq!(
            common.to_string(),
            "HarmonizationPlan(1 channels @ 256.0 Hz, all channels present in all inputs)"
        );
    }

    #[test]
    fn test_apply_harmonization() {
        let d1 = make_data(&["Fz", "Cz", "Pz"], 256.0, 256);
        let plan = HarmonizationPlan {
            channels: vec!["Cz".into(), "Fz".into(), "Oz".into()],
            target_sr: 256.0,
            channel_availability: vec![vec![true, true, false]],
        };

        let result = apply_harmonization(&d1, &plan);
        assert_eq!(result.channel_labels, vec!["Cz", "Fz", "Oz"]);
        assert_eq!(result.data.len(), 3);
        // Present channels are reordered to plan order with their own data.
        assert_eq!(result.data[0], d1.data[1]); // Cz
        assert_eq!(result.data[1], d1.data[0]); // Fz
        // Oz is missing: zero-filled with the same length as the source rows.
        assert_eq!(result.data[2].len(), 256);
        assert!(result.data[2].iter().all(|&v| v == 0.0));
        assert_eq!(result.sampling_rates, vec![256.0; 3]);
    }

    #[test]
    fn test_sr_shift() {
        let d1 = make_data(&["Fz"], 128.0, 128);
        let plan = plan_harmonization(&[&d1], ChannelStrategy::Intersect, &[], -0.5);
        assert!((plan.target_sr - 127.5).abs() < 1e-10);
    }
}
269}