1use burn::prelude::*;
6
7pub struct InputBatch<B: Backend> {
9 pub signal: Tensor<B, 3>,
11 pub positions: Tensor<B, 3>,
13 pub n_channels: usize,
15 pub n_samples: usize,
17}
18
19pub fn channel_wise_normalize<B: Backend>(x: Tensor<B, 3>) -> Tensor<B, 3> {
23 let mean = x.clone().mean_dim(2); let diff = x.clone() - mean.clone();
25 let var = (diff.clone() * diff).mean_dim(2);
26 let std = (var + 1e-8).sqrt();
27 let normed = (x - mean) / std;
28 normed.clamp(-15.0, 15.0)
30}
31
32pub fn build_batch<B: Backend>(
34 signal: Vec<f32>, positions: Vec<f32>, n_channels: usize,
37 n_samples: usize,
38 device: &B::Device,
39) -> InputBatch<B> {
40 let signal = Tensor::<B, 2>::from_data(
41 TensorData::new(signal, vec![n_channels, n_samples]),
42 device,
43 )
44 .unsqueeze_dim::<3>(0); let positions = Tensor::<B, 2>::from_data(
47 TensorData::new(positions, vec![n_channels, 3]),
48 device,
49 )
50 .unsqueeze_dim::<3>(0); InputBatch {
53 signal,
54 positions,
55 n_channels,
56 n_samples,
57 }
58}
59
60pub fn build_batch_named<B: Backend>(
62 signal: Vec<f32>,
63 channel_names: &[&str],
64 n_samples: usize,
65 position_bank: &crate::position_bank::PositionBank,
66 device: &B::Device,
67) -> InputBatch<B> {
68 let n_channels = channel_names.len();
69 let positions = position_bank.get_positions(channel_names);
70 build_batch(signal, positions, n_channels, n_samples, device)
71}