1use crate::data::EegData;
11use std::collections::{BTreeSet, HashSet};
12
/// How the channel list of a harmonized set of recordings is chosen.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChannelStrategy {
    /// Keep only the channels that are present in every input.
    Intersect,
    /// Keep every channel that is present in at least one input.
    Union,
}
21
/// Describes the common channel layout and sampling rate that a group of
/// recordings should be brought to.
#[derive(Debug, Clone)]
pub struct HarmonizationPlan {
    /// Names of the channels in the harmonized output.
    pub channels: Vec<String>,
    /// Sampling rate (Hz) of the harmonized output.
    pub target_sr: f64,
    /// One row per input; entry `j` of a row tells whether that input
    /// contains `channels[j]`.
    pub channel_availability: Vec<Vec<bool>>,
}

impl std::fmt::Display for HarmonizationPlan {
    /// Writes a one-line summary: channel count, target rate, and (when the
    /// plan covers at least one input) how many channel/input pairs are
    /// missing.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let n_channels = self.channels.len();
        let sr = self.target_sr;
        write!(f, "HarmonizationPlan({n_channels} channels @ {sr:.1} Hz")?;

        // The availability note is only meaningful when inputs were recorded.
        if !self.channel_availability.is_empty() {
            let absent = self
                .channel_availability
                .iter()
                .flatten()
                .filter(|&&present| !present)
                .count();
            match absent {
                0 => f.write_str(", all channels present in all inputs")?,
                n => write!(f, ", {n} missing channel-input pairs")?,
            }
        }

        f.write_str(")")
    }
}
61
62#[must_use]
74pub fn plan_harmonization(
75 inputs: &[&EegData],
76 strategy: ChannelStrategy,
77 ignore_channels: &[&str],
78 sr_shift: f64,
79) -> HarmonizationPlan {
80 if inputs.is_empty() {
81 return HarmonizationPlan {
82 channels: Vec::new(),
83 target_sr: 0.0,
84 channel_availability: Vec::new(),
85 };
86 }
87
88 let ignore: HashSet<&str> = ignore_channels.iter().copied().collect();
89
90 let channel_sets: Vec<BTreeSet<String>> = inputs
92 .iter()
93 .map(|d| {
94 d.channel_labels
95 .iter()
96 .filter(|ch| !ignore.contains(ch.as_str()))
97 .cloned()
98 .collect()
99 })
100 .collect();
101
102 let channels: BTreeSet<String> = match strategy {
103 ChannelStrategy::Intersect => {
104 let mut result = channel_sets[0].clone();
105 for set in &channel_sets[1..] {
106 result = result.intersection(set).cloned().collect();
107 }
108 result
109 }
110 ChannelStrategy::Union => {
111 let mut result = BTreeSet::new();
112 for set in &channel_sets {
113 result = result.union(set).cloned().collect();
114 }
115 result
116 }
117 };
118
119 let channels: Vec<String> = channels.into_iter().collect();
120
121 let target_sr = inputs
123 .iter()
124 .filter_map(|d| d.sampling_rates.first().copied())
125 .fold(f64::INFINITY, f64::min)
126 + sr_shift;
127
128 let channel_availability = inputs
130 .iter()
131 .map(|d| {
132 let input_channels: HashSet<&str> =
133 d.channel_labels.iter().map(|s| s.as_str()).collect();
134 channels
135 .iter()
136 .map(|ch| input_channels.contains(ch.as_str()))
137 .collect()
138 })
139 .collect();
140
141 HarmonizationPlan {
142 channels,
143 target_sr,
144 channel_availability,
145 }
146}
147
148#[must_use]
153pub fn apply_harmonization(data: &EegData, plan: &HarmonizationPlan) -> EegData {
154 let n_samples = data.data.first().map_or(0, |ch| ch.len());
155 let channel_map: std::collections::HashMap<&str, usize> = data
156 .channel_labels
157 .iter()
158 .enumerate()
159 .map(|(i, name)| (name.as_str(), i))
160 .collect();
161
162 let mut new_data = Vec::with_capacity(plan.channels.len());
163 let mut new_rates = Vec::with_capacity(plan.channels.len());
164 let sr = data
165 .sampling_rates
166 .first()
167 .copied()
168 .unwrap_or(plan.target_sr);
169
170 for ch_name in &plan.channels {
171 if let Some(&idx) = channel_map.get(ch_name.as_str()) {
172 new_data.push(data.data[idx].clone());
173 new_rates.push(data.sampling_rates.get(idx).copied().unwrap_or(sr));
174 } else {
175 new_data.push(vec![0.0; n_samples]);
177 new_rates.push(sr);
178 }
179 }
180
181 let mut result = EegData {
182 channel_labels: plan.channels.clone(),
183 data: new_data,
184 sampling_rates: new_rates,
185 duration: data.duration,
186 annotations: data.annotations.clone(),
187 stim_channel_indices: Vec::new(), is_discontinuous: data.is_discontinuous,
189 record_onsets: data.record_onsets.clone(),
190 };
191
192 if (sr - plan.target_sr).abs() > 0.01 {
194 result = result.resample(plan.target_sr);
195 }
196
197 result
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::EegData;

    /// Builds a synthetic recording in which channel `i` holds the ramp
    /// `i, i + 1, ...` of length `n_samples`, every channel sampled at `sr`.
    fn make_data(channels: &[&str], sr: f64, n_samples: usize) -> EegData {
        let channel_labels = channels.iter().map(|&name| name.to_string()).collect();
        let data = (0..channels.len())
            .map(|offset| (offset..offset + n_samples).map(|v| v as f64).collect())
            .collect();

        EegData {
            channel_labels,
            data,
            sampling_rates: vec![sr; channels.len()],
            duration: n_samples as f64 / sr,
            annotations: Vec::new(),
            stim_channel_indices: Vec::new(),
            is_discontinuous: false,
            record_onsets: Vec::new(),
        }
    }

    /// Intersect keeps only shared, non-ignored channels and picks the lower rate.
    #[test]
    fn test_intersect_strategy() {
        let low_rate = make_data(&["Fz", "Cz", "Pz", "STIM"], 256.0, 256);
        let high_rate = make_data(&["Fz", "Cz", "Oz", "STIM"], 512.0, 512);
        let inputs = [&low_rate, &high_rate];

        let plan = plan_harmonization(&inputs, ChannelStrategy::Intersect, &["STIM"], 0.0);

        assert_eq!(plan.channels, vec!["Cz", "Fz"]);
        assert!((plan.target_sr - 256.0).abs() < 1e-10);
    }

    /// Union keeps every channel and records which input provides which one.
    #[test]
    fn test_union_strategy() {
        let first = make_data(&["Fz", "Cz"], 256.0, 256);
        let second = make_data(&["Cz", "Pz"], 256.0, 256);
        let inputs = [&first, &second];

        let plan = plan_harmonization(&inputs, ChannelStrategy::Union, &[], 0.0);

        assert_eq!(plan.channels, vec!["Cz", "Fz", "Pz"]);
        assert_eq!(plan.channel_availability[0], vec![true, true, false]);
        assert_eq!(plan.channel_availability[1], vec![true, false, true]);
    }

    /// Applying a plan reorders existing channels and zero-fills absent ones.
    #[test]
    fn test_apply_harmonization() {
        let recording = make_data(&["Fz", "Cz", "Pz"], 256.0, 256);
        let plan = HarmonizationPlan {
            channels: vec![String::from("Cz"), String::from("Fz"), String::from("Oz")],
            target_sr: 256.0,
            channel_availability: vec![vec![true, true, false]],
        };

        let harmonized = apply_harmonization(&recording, &plan);

        assert_eq!(harmonized.channel_labels, vec!["Cz", "Fz", "Oz"]);
        assert_eq!(harmonized.data.len(), 3);
        // "Oz" is not in the recording, so it must be all zeros.
        assert!(harmonized.data[2].iter().all(|&v| v == 0.0));
        // "Cz" comes from the recording and carries real samples.
        assert!(harmonized.data[0].iter().any(|&v| v != 0.0));
    }

    /// The shift is added to the lowest sampling rate.
    #[test]
    fn test_sr_shift() {
        let recording = make_data(&["Fz"], 128.0, 128);
        let inputs = [&recording];

        let plan = plan_harmonization(&inputs, ChannelStrategy::Intersect, &[], -0.5);

        assert!((plan.target_sr - 127.5).abs() < 1e-10);
    }
}