use burn::prelude::*;
/// A single model input batch: a multichannel signal plus per-channel
/// spatial positions (as constructed by [`build_batch`]).
pub struct InputBatch<B: Backend> {
// Signal tensor; `build_batch` produces shape [1, n_channels, n_samples].
pub signal: Tensor<B, 3>,
// Per-channel 3-D coordinates; `build_batch` produces shape [1, n_channels, 3].
pub positions: Tensor<B, 3>,
// Number of channels in `signal` / rows in `positions`.
pub n_channels: usize,
// Number of time samples per channel.
pub n_samples: usize,
}
/// Normalizes `x` to zero mean and unit variance along the last dimension
/// (dim 2, the sample axis), then clamps the result to [-15, 15] to bound
/// outliers.
///
/// `mean_dim` keeps the reduced dimension (size 1), so `mean`/`std`
/// broadcast back over `x`. A small epsilon (1e-8) is added to the variance
/// before the square root to avoid division by zero on constant channels.
pub fn channel_wise_normalize<B: Backend>(x: Tensor<B, 3>) -> Tensor<B, 3> {
    let mean = x.clone().mean_dim(2);
    // Center once and reuse: the original recomputed `x - mean` a second
    // time for the final quotient.
    let diff = x - mean;
    let var = (diff.clone() * diff.clone()).mean_dim(2);
    let std = (var + 1e-8).sqrt();
    (diff / std).clamp(-15.0, 15.0)
}
/// Builds an [`InputBatch`] from flat row-major buffers.
///
/// `signal` must hold `n_channels * n_samples` values (channel-major) and
/// `positions` must hold `n_channels * 3` values (x, y, z per channel).
/// Both tensors are unsqueezed to add a leading batch dimension of 1,
/// yielding shapes `[1, n_channels, n_samples]` and `[1, n_channels, 3]`.
///
/// # Panics
/// Panics with a descriptive message if either buffer length does not match
/// the declared dimensions (previously this surfaced as an opaque failure
/// inside tensor construction).
pub fn build_batch<B: Backend>(
    signal: Vec<f32>,
    positions: Vec<f32>,
    n_channels: usize,
    n_samples: usize,
    device: &B::Device,
) -> InputBatch<B> {
    assert_eq!(
        signal.len(),
        n_channels * n_samples,
        "signal length {} != n_channels ({}) * n_samples ({})",
        signal.len(),
        n_channels,
        n_samples,
    );
    assert_eq!(
        positions.len(),
        n_channels * 3,
        "positions length {} != n_channels ({}) * 3",
        positions.len(),
        n_channels,
    );
    let signal = Tensor::<B, 2>::from_data(
        TensorData::new(signal, vec![n_channels, n_samples]),
        device,
    )
    .unsqueeze_dim::<3>(0);
    let positions = Tensor::<B, 2>::from_data(
        TensorData::new(positions, vec![n_channels, 3]),
        device,
    )
    .unsqueeze_dim::<3>(0);
    InputBatch {
        signal,
        positions,
        n_channels,
        n_samples,
    }
}
/// Builds an [`InputBatch`] by resolving channel positions from their names.
///
/// Looks up the 3-D coordinates for each entry of `channel_names` in
/// `position_bank`, then delegates to [`build_batch`] with the channel
/// count taken from the name list.
pub fn build_batch_named<B: Backend>(
    signal: Vec<f32>,
    channel_names: &[&str],
    n_samples: usize,
    position_bank: &crate::position_bank::PositionBank,
    device: &B::Device,
) -> InputBatch<B> {
    build_batch(
        signal,
        position_bank.get_positions(channel_names),
        channel_names.len(),
        n_samples,
        device,
    )
}