use burn::prelude::*;
/// A batch of raw multichannel signal data ready to feed into a model.
///
/// When produced by `build_batch`, `signal` has shape
/// `[1, n_channels, n_samples]` (a single example with a leading batch axis).
pub struct InputBatch<B>
where
    B: Backend,
{
    /// Rank-3 signal tensor on the backend device.
    pub signal: Tensor<B, 3>,
    /// Number of channels in the example (size of dim 1 when built by `build_batch`).
    pub n_channels: usize,
    /// Number of samples per channel (size of dim 2 when built by `build_batch`).
    pub n_samples: usize,
}
/// Encoder output for a single epoch, stored as flat `f32` buffers.
///
/// `EncodingResult::save_safetensors` writes `cls_emb` with shape `[embed_dim]` and
/// `patch_embs` with shape `[num_patches, embed_dim]`, so these invariants must hold:
/// - `cls_emb.len() == embed_dim`
/// - `patch_embs.len() == num_patches * embed_dim`
#[derive(Debug, Clone, PartialEq, Default)]
pub struct EpochEmbedding {
    /// Class-token embedding, `embed_dim` values.
    pub cls_emb: Vec<f32>,
    /// Patch embeddings flattened from `[num_patches, embed_dim]`.
    pub patch_embs: Vec<f32>,
    /// Width of each embedding vector.
    pub embed_dim: usize,
    /// Number of patch embeddings in `patch_embs`.
    pub num_patches: usize,
}
/// Embeddings for a whole recording plus coarse timing information.
pub struct EncodingResult {
    /// Per-epoch embeddings; `save_safetensors` names each by its index here.
    pub epochs: Vec<EpochEmbedding>,
    /// Load-phase duration, presumably in milliseconds (per the `ms_` prefix) — confirm at the producer.
    pub ms_load: f64,
    /// Encode-phase duration, presumably in milliseconds (per the `ms_` prefix) — confirm at the producer.
    pub ms_encode: f64,
}
impl EncodingResult {
pub fn save_safetensors(&self, path: &str) -> anyhow::Result<()> {
use safetensors::{Dtype, View};
use std::borrow::Cow;
struct RawTensor { data: Vec<u8>, shape: Vec<usize>, dtype: Dtype }
impl View for RawTensor {
fn dtype(&self) -> Dtype { self.dtype }
fn shape(&self) -> &[usize] { &self.shape }
fn data(&self) -> Cow<'_, [u8]> { Cow::Borrowed(&self.data) }
fn data_len(&self) -> usize { self.data.len() }
}
let f32_bytes = |v: &[f32]| -> Vec<u8> {
v.iter().flat_map(|f| f.to_le_bytes()).collect()
};
let mut keys: Vec<String> = Vec::new();
let mut tensors: Vec<RawTensor> = Vec::new();
for (i, ep) in self.epochs.iter().enumerate() {
keys.push(format!("cls_emb_{i}"));
tensors.push(RawTensor {
data: f32_bytes(&ep.cls_emb),
shape: vec![ep.embed_dim],
dtype: Dtype::F32,
});
keys.push(format!("patch_embs_{i}"));
tensors.push(RawTensor {
data: f32_bytes(&ep.patch_embs),
shape: vec![ep.num_patches, ep.embed_dim],
dtype: Dtype::F32,
});
}
let n = self.epochs.len() as f32;
keys.push("n_epochs".into());
tensors.push(RawTensor {
data: f32_bytes(&[n]),
shape: vec![1],
dtype: Dtype::F32,
});
let pairs: Vec<(&str, RawTensor)> = keys.iter()
.map(|s| s.as_str())
.zip(tensors)
.collect();
let bytes = safetensors::serialize(pairs, None)?;
std::fs::write(path, bytes)?;
Ok(())
}
}
/// Wrap a flat signal buffer into a single-example [`InputBatch`] on `device`.
///
/// `signal` is handed to `TensorData::new` with shape `[n_channels, n_samples]`.
/// In burn's contiguous row-major layout, that means all of channel 0's samples come first.
/// A leading batch axis is then added, so the resulting `signal` tensor has shape
/// `[1, n_channels, n_samples]`.
///
/// # Panics
///
/// Panics if `signal.len()` is not exactly `n_channels * n_samples`, including
/// when that product overflows `usize`.
pub fn build_batch<B: Backend>(
    signal: Vec<f32>,
    n_channels: usize,
    n_samples: usize,
    device: &B::Device,
) -> InputBatch<B> {
    // Validate up front so a caller mistake is reported in terms of this
    // function's arguments rather than from deep inside the tensor constructor.
    assert!(
        n_channels.checked_mul(n_samples) == Some(signal.len()),
        "build_batch: signal has {} values, expected n_channels ({n_channels}) * n_samples ({n_samples})",
        signal.len(),
    );
    let signal = Tensor::<B, 2>::from_data(
        TensorData::new(signal, vec![n_channels, n_samples]),
        device,
    )
    .unsqueeze_dim::<3>(0);
    InputBatch { signal, n_channels, n_samples }
}
/// Standardize `x` to zero mean and unit variance along dim 2.
///
/// Statistics are computed independently for every `(dim0, dim1)` slice over
/// dim 2. For batches from [`build_batch`] (`[1, n_channels, n_samples]`), that
/// means each channel is normalized over its own samples.
///
/// Uses the population (biased) variance. `1e-8` is added before the square
/// root, so a constant channel maps to zeros instead of dividing by zero.
pub fn channel_wise_normalize<B: Backend>(x: Tensor<B, 3>) -> Tensor<B, 3> {
    // `mean_dim` keeps the reduced axis, so `mean` broadcasts back over dim 2.
    let mean = x.clone().mean_dim(2);
    // Compute the centred signal once and reuse it for both variance and output.
    let centered = x - mean;
    let var = (centered.clone() * centered.clone()).mean_dim(2);
    let std = (var + 1e-8).sqrt();
    centered / std
}