1use burn::prelude::*;
6
7pub struct InputBatch<B: Backend> {
9 pub signal: Tensor<B, 3>,
11 pub n_channels: usize,
13 pub n_samples: usize,
15}
16
/// Embeddings produced for one epoch, stored as flat `f32` buffers.
#[derive(Debug, Clone, PartialEq)]
pub struct EpochEmbedding {
    /// CLS embedding; saved with shape `[embed_dim]`.
    pub cls_emb: Vec<f32>,
    /// Flattened patch embeddings; saved with shape `[num_patches, embed_dim]`.
    pub patch_embs: Vec<f32>,
    /// Length of a single embedding vector.
    pub embed_dim: usize,
    /// Number of patch embeddings held in `patch_embs`.
    pub num_patches: usize,
}

/// Output of an encoding run: one [`EpochEmbedding`] per epoch plus timings.
#[derive(Debug, Clone, PartialEq)]
pub struct EncodingResult {
    /// Per-epoch embeddings, in the order they are written by `save_safetensors`.
    pub epochs: Vec<EpochEmbedding>,
    /// Load time (presumably milliseconds, going by the field name).
    pub ms_load: f64,
    /// Encode time (presumably milliseconds, going by the field name).
    pub ms_encode: f64,
}
35
36impl EncodingResult {
37 pub fn save_safetensors(&self, path: &str) -> anyhow::Result<()> {
39 use safetensors::{Dtype, View};
40 use std::borrow::Cow;
41
42 struct RawTensor { data: Vec<u8>, shape: Vec<usize>, dtype: Dtype }
43 impl View for RawTensor {
44 fn dtype(&self) -> Dtype { self.dtype }
45 fn shape(&self) -> &[usize] { &self.shape }
46 fn data(&self) -> Cow<'_, [u8]> { Cow::Borrowed(&self.data) }
47 fn data_len(&self) -> usize { self.data.len() }
48 }
49
50 let f32_bytes = |v: &[f32]| -> Vec<u8> {
51 v.iter().flat_map(|f| f.to_le_bytes()).collect()
52 };
53
54 let mut keys: Vec<String> = Vec::new();
55 let mut tensors: Vec<RawTensor> = Vec::new();
56
57 for (i, ep) in self.epochs.iter().enumerate() {
58 keys.push(format!("cls_emb_{i}"));
59 tensors.push(RawTensor {
60 data: f32_bytes(&ep.cls_emb),
61 shape: vec![ep.embed_dim],
62 dtype: Dtype::F32,
63 });
64
65 keys.push(format!("patch_embs_{i}"));
66 tensors.push(RawTensor {
67 data: f32_bytes(&ep.patch_embs),
68 shape: vec![ep.num_patches, ep.embed_dim],
69 dtype: Dtype::F32,
70 });
71 }
72
73 let n = self.epochs.len() as f32;
74 keys.push("n_epochs".into());
75 tensors.push(RawTensor {
76 data: f32_bytes(&[n]),
77 shape: vec![1],
78 dtype: Dtype::F32,
79 });
80
81 let pairs: Vec<(&str, RawTensor)> = keys.iter()
82 .map(|s| s.as_str())
83 .zip(tensors)
84 .collect();
85 let bytes = safetensors::serialize(pairs, None)?;
86 std::fs::write(path, bytes)?;
87 Ok(())
88 }
89}
90
91pub fn build_batch<B: Backend>(
93 signal: Vec<f32>, n_channels: usize,
95 n_samples: usize,
96 device: &B::Device,
97) -> InputBatch<B> {
98 let signal = Tensor::<B, 2>::from_data(
99 TensorData::new(signal, vec![n_channels, n_samples]), device,
100 ).unsqueeze_dim::<3>(0); InputBatch { signal, n_channels, n_samples }
103}
104
105pub fn channel_wise_normalize<B: Backend>(x: Tensor<B, 3>) -> Tensor<B, 3> {
109 let mean = x.clone().mean_dim(2); let diff = x.clone() - mean.clone();
111 let var = (diff.clone() * diff).mean_dim(2);
112 let std = (var + 1e-8).sqrt();
113 (x - mean) / std
114}