1use burn::prelude::*;
6
7pub struct InputBatch<B: Backend> {
9 pub signal: Tensor<B, 3>,
11 pub channel_locations: Tensor<B, 3>,
13 pub channel_names: Option<Tensor<B, 2, Int>>,
15 pub n_channels: usize,
17 pub n_samples: usize,
19}
20
/// Summary metadata for a recording and how it is to be epoched.
///
/// NOTE(review): the name suggests this describes a FIF (MNE) file; the field
/// meanings below are inferred from their names and should be confirmed against
/// the code that fills this struct in.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct FifInfo {
    /// Channel names.
    pub ch_names: Vec<String>,
    /// One `[x, y, z]` triple per entry; the `_mm` suffix suggests millimetres — confirm.
    pub ch_pos_mm: Vec<[f32; 3]>,
    /// Presumably the original sampling rate in Hz — confirm.
    pub sfreq: f32,
    /// Presumably the number of raw samples per channel — confirm.
    pub n_times_raw: usize,
    /// Presumably the recording length in seconds — confirm.
    pub duration_s: f32,
    /// Number of epochs (presumably derived from the duration and `epoch_dur_s` — confirm).
    pub n_epochs: usize,
    /// Presumably the sampling rate (Hz) the signal is resampled to — confirm.
    pub target_sfreq: f32,
    /// Presumably the length of one epoch in seconds — confirm.
    pub epoch_dur_s: f32,
}
32
33pub fn channel_wise_normalize<B: Backend>(x: Tensor<B, 3>) -> Tensor<B, 3> {
35 let mean = x.clone().mean_dim(2); let diff = x.clone() - mean.clone();
37 let var = (diff.clone() * diff.clone()).mean_dim(2);
38 let std = (var + 1e-8).sqrt();
39 (x - mean) / std
40}
41
42pub fn build_batch<B: Backend>(
44 signal: Vec<f32>, positions: Vec<f32>, channel_indices: Option<Vec<i64>>, n_channels: usize,
48 n_samples: usize,
49 device: &B::Device,
50) -> InputBatch<B> {
51 let signal = Tensor::<B, 2>::from_data(
52 TensorData::new(signal, vec![n_channels, n_samples]), device,
53 ).unsqueeze_dim::<3>(0); let channel_locations = Tensor::<B, 2>::from_data(
56 TensorData::new(positions, vec![n_channels, 3]), device,
57 ).unsqueeze_dim::<3>(0); let channel_names = channel_indices.map(|idx| {
60 Tensor::<B, 1, Int>::from_data(
61 TensorData::new(idx, vec![n_channels]), device,
62 ).unsqueeze_dim::<2>(0) });
64
65 InputBatch {
66 signal,
67 channel_locations,
68 channel_names,
69 n_channels,
70 n_samples,
71 }
72}
73
74pub fn build_batch_named<B: Backend>(
82 signal: Vec<f32>, channel_names: &[&str], n_samples: usize,
85 device: &B::Device,
86) -> InputBatch<B> {
87 let n_channels = channel_names.len();
88
89 let indices = crate::channel_vocab::channel_indices_unwrap(channel_names);
91
92 let positions: Vec<f32> = channel_names.iter()
94 .flat_map(|name| {
95 crate::channel_positions::bipolar_channel_xyz(name)
96 .unwrap_or([0.0, 0.0, 0.0])
97 .to_vec()
98 })
99 .collect();
100
101 build_batch(signal, positions, Some(indices), n_channels, n_samples, device)
102}