1use crate::data::EegData;
28use crate::events::EegEvent;
29use serde::{Deserialize, Serialize};
30use std::collections::HashMap;
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34pub enum Step {
35 SelectChannels(Vec<String>),
36 ExcludeChannels(Vec<String>),
37 AverageReference,
38 ChannelReference(String),
39 Bandpass(f64, f64, usize),
40 Highpass(f64, usize),
41 Lowpass(f64, usize),
42 Notch(f64),
43 Resample(f64),
44 Epoch(String, f64, f64),
46 Baseline(f64, f64),
48 ZScore,
49 MinMaxNormalize,
50}
51
/// Output of `Pipeline::transform`: epoched samples plus labels and metadata.
#[derive(Debug, Clone)]
pub struct PipelineResult {
    /// Samples indexed as `x[epoch][channel][sample]`.
    pub x: Vec<Vec<Vec<f64>>>,
    /// One label per epoch (parallel to `x`).
    pub y: Vec<String>,
    /// One string-keyed metadata map per epoch (parallel to `x`).
    pub metadata: Vec<HashMap<String, String>>,
    /// Labels of the channels in each epoch, in channel order.
    pub channel_names: Vec<String>,
    /// Sampling rate of the samples in `x`, in Hz.
    pub sampling_rate: f64,
}

impl PipelineResult {
    /// Number of epochs, i.e. the length of `x`.
    #[must_use]
    pub fn n_epochs(&self) -> usize {
        self.x.len()
    }

    /// Number of channels in the first epoch, or 0 when there are no epochs.
    #[must_use]
    pub fn n_channels(&self) -> usize {
        match self.x.first() {
            Some(first_epoch) => first_epoch.len(),
            None => 0,
        }
    }

    /// Number of samples in the first channel of the first epoch, or 0 when absent.
    #[must_use]
    pub fn n_samples(&self) -> usize {
        let Some(first_channel) = self.x.first().and_then(|epoch| epoch.first()) else {
            return 0;
        };
        first_channel.len()
    }

    /// `(epochs, channels, samples)` as reported by the three counters above.
    #[must_use]
    pub fn shape(&self) -> (usize, usize, usize) {
        let epochs = self.n_epochs();
        let channels = self.n_channels();
        let samples = self.n_samples();
        (epochs, channels, samples)
    }

    /// One feature row per epoch: that epoch's channels concatenated in order.
    #[must_use]
    pub fn to_flat_features(&self) -> Vec<Vec<f64>> {
        self.x.iter().map(|epoch| epoch.concat()).collect()
    }

    /// All samples in a single buffer, ordered epoch-major, then channel, then sample.
    #[must_use]
    pub fn to_contiguous(&self) -> Vec<f64> {
        let capacity = self.n_epochs() * self.n_channels() * self.n_samples();
        let mut flat = Vec::with_capacity(capacity);
        for channel in self.x.iter().flatten() {
            flat.extend_from_slice(channel);
        }
        flat
    }

    /// Distinct labels from `y`, sorted ascending.
    #[must_use]
    pub fn classes(&self) -> Vec<String> {
        self.y
            .iter()
            .collect::<std::collections::BTreeSet<&String>>()
            .into_iter()
            .cloned()
            .collect()
    }

    /// Each label in `y` replaced by its index in [`PipelineResult::classes`].
    #[must_use]
    pub fn y_encoded(&self) -> Vec<usize> {
        let sorted_classes = self.classes();
        self.y
            .iter()
            // `classes` is sorted and contains every label, so the lookup always succeeds.
            .map(|label| sorted_classes.binary_search(label).unwrap_or(0))
            .collect()
    }
}
143
144#[derive(Debug, Clone, Serialize, Deserialize)]
166pub struct Pipeline {
167 #[serde(default)]
169 pub name: String,
170 #[serde(default)]
172 pub description: String,
173 steps: Vec<Step>,
174}
175
176impl Default for Pipeline {
177 fn default() -> Self {
178 Self::new()
179 }
180}
181
182impl Pipeline {
183 #[must_use]
184 pub fn new() -> Self {
185 Self {
186 name: String::new(),
187 description: String::new(),
188 steps: Vec::new(),
189 }
190 }
191
192 #[must_use]
194 pub fn with_name(mut self, name: &str) -> Self {
195 self.name = name.into();
196 self
197 }
198
199 #[must_use]
201 pub fn with_description(mut self, desc: &str) -> Self {
202 self.description = desc.into();
203 self
204 }
205
206 pub fn save_json(&self, path: &str) -> std::io::Result<()> {
210 let json = serde_json::to_string_pretty(self)
211 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
212 std::fs::write(path, json)
213 }
214
215 pub fn load_json(path: &str) -> std::io::Result<Self> {
217 let json = std::fs::read_to_string(path)?;
218 serde_json::from_str(&json)
219 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
220 }
221
222 #[must_use]
224 pub fn to_json(&self) -> String {
225 serde_json::to_string_pretty(self).unwrap_or_default()
226 }
227
228 pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
230 serde_json::from_str(json)
231 }
232
233 #[must_use]
235 pub fn steps(&self) -> &[Step] {
236 &self.steps
237 }
238
239 #[must_use]
242 pub fn select_channels(mut self, names: &[&str]) -> Self {
243 self.steps.push(Step::SelectChannels(
244 names.iter().map(|s| (*s).to_string()).collect(),
245 ));
246 self
247 }
248
249 #[must_use]
250 pub fn exclude_channels(mut self, names: &[&str]) -> Self {
251 self.steps.push(Step::ExcludeChannels(
252 names.iter().map(|s| (*s).to_string()).collect(),
253 ));
254 self
255 }
256
257 #[must_use]
260 pub fn average_reference(mut self) -> Self {
261 self.steps.push(Step::AverageReference);
262 self
263 }
264
265 #[must_use]
266 pub fn reference_channel(mut self, name: &str) -> Self {
267 self.steps.push(Step::ChannelReference(name.into()));
268 self
269 }
270
271 #[must_use]
274 pub fn bandpass(mut self, l_freq: f64, h_freq: f64, order: usize) -> Self {
275 self.steps.push(Step::Bandpass(l_freq, h_freq, order));
276 self
277 }
278
279 #[must_use]
280 pub fn highpass(mut self, freq: f64, order: usize) -> Self {
281 self.steps.push(Step::Highpass(freq, order));
282 self
283 }
284
285 #[must_use]
286 pub fn lowpass(mut self, freq: f64, order: usize) -> Self {
287 self.steps.push(Step::Lowpass(freq, order));
288 self
289 }
290
291 #[must_use]
292 pub fn notch(mut self, freq: f64) -> Self {
293 self.steps.push(Step::Notch(freq));
294 self
295 }
296
297 #[must_use]
300 pub fn resample(mut self, target_hz: f64) -> Self {
301 self.steps.push(Step::Resample(target_hz));
302 self
303 }
304
305 #[must_use]
310 pub fn epoch(mut self, trial_type: &str, tmin: f64, tmax: f64) -> Self {
311 self.steps.push(Step::Epoch(trial_type.into(), tmin, tmax));
312 self
313 }
314
315 #[must_use]
319 pub fn baseline(mut self, bmin: f64, bmax: f64) -> Self {
320 self.steps.push(Step::Baseline(bmin, bmax));
321 self
322 }
323
324 #[must_use]
325 pub fn z_score(mut self) -> Self {
326 self.steps.push(Step::ZScore);
327 self
328 }
329
330 #[must_use]
331 pub fn min_max_normalize(mut self) -> Self {
332 self.steps.push(Step::MinMaxNormalize);
333 self
334 }
335
336 pub fn transform(&self, data: &EegData, events: &[EegEvent]) -> PipelineResult {
345 let mut pre_steps = Vec::new();
347 let mut epoch_specs: Vec<(String, f64, f64)> = Vec::new();
348 let mut post_steps = Vec::new();
349 let mut past_epoch = false;
350
351 for step in &self.steps {
352 match step {
353 Step::Epoch(tt, tmin, tmax) => {
354 epoch_specs.push((tt.clone(), *tmin, *tmax));
355 past_epoch = true;
356 }
357 other => {
358 if past_epoch {
359 post_steps.push(other.clone());
360 } else {
361 pre_steps.push(other.clone());
362 }
363 }
364 }
365 }
366
367 let mut processed = data.clone();
369 for step in &pre_steps {
370 processed = apply_step_to_data(processed, step);
371 }
372
373 let sr = processed.sampling_rates.first().copied().unwrap_or(1.0);
374 let channel_names = processed.channel_labels.clone();
375
376 let (mut epochs, labels, metas) = if epoch_specs.is_empty() {
378 (
380 vec![processed.data.clone()],
381 vec!["_whole_".into()],
382 vec![HashMap::new()],
383 )
384 } else {
385 extract_epochs(&processed, events, &epoch_specs)
386 };
387
388 let mut current_sr = sr;
390 for step in &post_steps {
391 for epoch in &mut epochs {
392 apply_step_to_epoch(epoch, step, current_sr);
393 }
394 if let Step::Resample(target) = step {
395 current_sr = *target;
396 }
397 }
398
399 let final_sr = current_sr;
400
401 PipelineResult {
402 x: epochs,
403 y: labels,
404 metadata: metas,
405 channel_names,
406 sampling_rate: final_sr,
407 }
408 }
409}
410
411fn apply_step_to_data(data: EegData, step: &Step) -> EegData {
414 match step {
415 Step::SelectChannels(names) => {
416 let refs: Vec<&str> = names.iter().map(|s| s.as_str()).collect();
417 data.select_channels(&refs)
418 }
419 Step::ExcludeChannels(names) => {
420 let refs: Vec<&str> = names.iter().map(|s| s.as_str()).collect();
421 data.exclude_channels(&refs)
422 }
423 Step::AverageReference => data.set_average_reference(),
424 Step::ChannelReference(ch) => data.set_reference(ch),
425 Step::Bandpass(lo, hi, order) => data.filter(Some(*lo), Some(*hi), *order),
426 Step::Highpass(freq, order) => data.filter(Some(*freq), None, *order),
427 Step::Lowpass(freq, order) => data.filter(None, Some(*freq), *order),
428 Step::Notch(freq) => data.notch_filter(*freq, 30.0),
429 Step::Resample(hz) => data.resample(*hz),
430 Step::ZScore | Step::MinMaxNormalize | Step::Baseline(..) | Step::Epoch(..) => data,
431 }
432}
433
434fn apply_step_to_epoch(epoch: &mut [Vec<f64>], step: &Step, sr: f64) {
435 match step {
436 Step::Baseline(bmin, bmax) => {
437 for ch in epoch.iter_mut() {
438 let start = (bmin.max(0.0) * sr).round() as usize;
439 let end = (bmax.max(0.0) * sr).round() as usize;
440 let end = end.min(ch.len());
441 let start = start.min(end);
442 if start < end {
443 let mean: f64 = ch[start..end].iter().sum::<f64>() / (end - start) as f64;
444 for v in ch.iter_mut() {
445 *v -= mean;
446 }
447 }
448 }
449 }
450 Step::ZScore => {
451 for ch in epoch.iter_mut() {
452 let n = ch.len() as f64;
453 if n < 2.0 {
454 continue;
455 }
456 let mean = ch.iter().sum::<f64>() / n;
457 let std = (ch.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n).sqrt();
458 let std = if std > f64::EPSILON { std } else { 1.0 };
459 for v in ch.iter_mut() {
460 *v = (*v - mean) / std;
461 }
462 }
463 }
464 Step::MinMaxNormalize => {
465 for ch in epoch.iter_mut() {
466 let min = ch.iter().copied().fold(f64::INFINITY, f64::min);
467 let max = ch.iter().copied().fold(f64::NEG_INFINITY, f64::max);
468 let range = max - min;
469 let range = if range > f64::EPSILON { range } else { 1.0 };
470 for v in ch.iter_mut() {
471 *v = (*v - min) / range;
472 }
473 }
474 }
475 Step::Resample(target_hz) => {
476 for ch in epoch.iter_mut() {
477 *ch = bids_filter::resample(ch, sr, *target_hz);
478 }
479 }
480 _ => {} }
482}
483
484type ExtractedEpochs = (
486 Vec<Vec<Vec<f64>>>,
487 Vec<String>,
488 Vec<HashMap<String, String>>,
489);
490
491fn extract_epochs(
493 data: &EegData,
494 events: &[EegEvent],
495 epoch_specs: &[(String, f64, f64)],
496) -> ExtractedEpochs {
497 let sr = data.sampling_rates.first().copied().unwrap_or(1.0);
498 let n_total = data.data.first().map_or(0, |ch| ch.len()) as isize;
499
500 let mut epochs = Vec::new();
501 let mut labels = Vec::new();
502 let mut metas = Vec::new();
503
504 for (trial_type, tmin, tmax) in epoch_specs {
505 let n_before = ((-tmin) * sr).round() as usize;
506 let n_after = (tmax * sr).round() as usize;
507 let epoch_len = n_before + n_after;
508
509 for event in events {
510 let tt = event.trial_type.as_deref().unwrap_or("");
511 if tt != trial_type {
512 continue;
513 }
514
515 let center = (event.onset * sr).round() as isize;
516 let start = center - n_before as isize;
517
518 if start < 0 || start + epoch_len as isize > n_total {
519 continue;
520 }
521 let start = start as usize;
522
523 let epoch: Vec<Vec<f64>> = data
524 .data
525 .iter()
526 .map(|ch| ch[start..start + epoch_len].to_vec())
527 .collect();
528
529 let mut meta = HashMap::new();
530 meta.insert("trial_type".into(), trial_type.clone());
531 meta.insert("onset".into(), event.onset.to_string());
532 if let Some(ref v) = event.value {
533 meta.insert("value".into(), v.clone());
534 }
535 for (k, v) in &event.extra {
536 meta.insert(k.clone(), v.clone());
537 }
538
539 epochs.push(epoch);
540 labels.push(trial_type.clone());
541 metas.push(meta);
542 }
543 }
544
545 (epochs, labels, metas)
546}
547
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::EegData;
    use crate::events::EegEvent;

    /// 10 s recording at 256 Hz: Fz, Cz and Pz carry pure 10, 20 and 5 Hz sines.
    fn make_test_data() -> EegData {
        let sr = 256.0;
        let dur = 10.0;
        let n = (sr * dur) as usize;
        let sine = |freq: f64| -> Vec<f64> {
            (0..n)
                .map(|i| (2.0 * std::f64::consts::PI * freq * i as f64 / sr).sin())
                .collect()
        };
        EegData {
            channel_labels: vec!["Fz".into(), "Cz".into(), "Pz".into()],
            data: vec![sine(10.0), sine(20.0), sine(5.0)],
            sampling_rates: vec![sr; 3],
            duration: dur,
            annotations: Vec::new(),
            stim_channel_indices: Vec::new(),
            is_discontinuous: false,
            record_onsets: Vec::new(),
        }
    }

    /// Zero-duration event of `trial_type` at `onset` seconds with no extra fields.
    fn make_event(onset: f64, trial_type: &str) -> EegEvent {
        EegEvent {
            onset,
            duration: 0.0,
            trial_type: Some(trial_type.into()),
            value: None,
            sample: None,
            response_time: None,
            extra: HashMap::new(),
        }
    }

    /// Alternating left/right-hand events at 1, 3, 5 and 7 s.
    fn make_test_events() -> Vec<EegEvent> {
        vec![
            make_event(1.0, "left_hand"),
            make_event(3.0, "right_hand"),
            make_event(5.0, "left_hand"),
            make_event(7.0, "right_hand"),
        ]
    }

    /// Removes the wrapped directory on drop so the test cleans up even when it fails.
    struct TempDirGuard(std::path::PathBuf);

    impl Drop for TempDirGuard {
        fn drop(&mut self) {
            // Best effort: a cleanup failure must not mask the real test outcome.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn test_basic_pipeline() {
        let data = make_test_data();
        let events = make_test_events();

        let pipeline = Pipeline::new()
            .select_channels(&["Fz", "Cz"])
            .bandpass(1.0, 40.0, 4)
            .epoch("left_hand", 0.0, 2.0)
            .epoch("right_hand", 0.0, 2.0)
            .baseline(0.0, 0.5)
            .z_score();

        let result = pipeline.transform(&data, &events);
        assert_eq!(result.n_epochs(), 4);
        assert_eq!(result.n_channels(), 2);
        assert_eq!(result.y.iter().filter(|l| *l == "left_hand").count(), 2);
        assert_eq!(result.y.iter().filter(|l| *l == "right_hand").count(), 2);
    }

    #[test]
    fn test_pipeline_with_resample() {
        let data = make_test_data();
        let events = make_test_events();

        let pipeline = Pipeline::new().epoch("left_hand", 0.0, 2.0).resample(128.0);

        let result = pipeline.transform(&data, &events);
        assert_eq!(result.n_epochs(), 2);
        assert_eq!(result.sampling_rate, 128.0);
        assert!((result.n_samples() as i32 - 256).abs() <= 2);
    }

    #[test]
    fn test_pipeline_result_helpers() {
        let data = make_test_data();
        let events = make_test_events();

        let pipeline = Pipeline::new()
            .epoch("left_hand", 0.0, 1.0)
            .epoch("right_hand", 0.0, 1.0);

        let result = pipeline.transform(&data, &events);
        let (ne, nc, ns) = result.shape();
        assert_eq!(ne, 4);
        assert_eq!(nc, 3);

        let flat = result.to_flat_features();
        assert_eq!(flat.len(), 4);
        assert_eq!(flat[0].len(), nc * ns);

        let classes = result.classes();
        assert_eq!(classes, vec!["left_hand", "right_hand"]);

        let encoded = result.y_encoded();
        assert_eq!(encoded.len(), 4);
    }

    #[test]
    fn test_no_epoch_pipeline() {
        let data = make_test_data();
        let pipeline = Pipeline::new().select_channels(&["Fz"]).z_score();

        let result = pipeline.transform(&data, &[]);
        assert_eq!(result.n_epochs(), 1);
        assert_eq!(result.n_channels(), 1);
    }

    #[test]
    fn test_json_roundtrip() {
        let pipeline = Pipeline::new()
            .with_name("motor_imagery_baseline")
            .with_description("Standard MI preprocessing")
            .select_channels(&["Fz", "Cz", "Pz", "C3", "C4"])
            .average_reference()
            .bandpass(8.0, 30.0, 5)
            .notch(50.0)
            .epoch("left_hand", -0.5, 3.5)
            .epoch("right_hand", -0.5, 3.5)
            .baseline(-0.5, 0.0)
            .resample(128.0)
            .z_score();

        let json = pipeline.to_json();
        assert!(json.contains("motor_imagery_baseline"));
        assert!(json.contains("Bandpass"));
        assert!(json.contains("Notch"));
        assert!(json.contains("left_hand"));

        let loaded = Pipeline::from_json(&json).unwrap();
        assert_eq!(loaded.name, "motor_imagery_baseline");
        assert_eq!(loaded.steps().len(), pipeline.steps().len());

        let data = make_test_data();
        let events = make_test_events();
        let r1 = pipeline.transform(&data, &events);
        let r2 = loaded.transform(&data, &events);
        assert_eq!(r1.shape(), r2.shape());
        assert_eq!(r1.y, r2.y);
    }

    #[test]
    fn test_save_load_json_file() {
        // Per-process directory so concurrent test runs cannot clobber each other's files.
        let dir = std::env::temp_dir().join(format!("bids_pipeline_test_{}", std::process::id()));
        std::fs::create_dir_all(&dir).unwrap();
        let _cleanup = TempDirGuard(dir.clone());
        let path = dir.join("pipeline.json");

        let pipeline = Pipeline::new()
            .with_name("test_save")
            .bandpass(1.0, 40.0, 4)
            .epoch("stimulus", 0.0, 1.0)
            .z_score();

        pipeline.save_json(path.to_str().unwrap()).unwrap();
        assert!(path.exists());

        let loaded = Pipeline::load_json(path.to_str().unwrap()).unwrap();
        assert_eq!(loaded.name, "test_save");
        assert_eq!(loaded.steps().len(), 3);

        let content = std::fs::read_to_string(&path).unwrap();
        assert!(content.contains("\"name\""));
        assert!(content.contains("test_save"));
        assert!(content.contains("Bandpass"));
    }

    #[test]
    fn test_step_visibility() {
        let pipeline = Pipeline::new()
            .highpass(1.0, 4)
            .lowpass(40.0, 4)
            .resample(256.0);

        let steps = pipeline.steps();
        assert_eq!(steps.len(), 3);
        assert!(matches!(steps[0], Step::Highpass(..)));
        assert!(matches!(steps[1], Step::Lowpass(..)));
        assert!(matches!(steps[2], Step::Resample(..)));
    }
}