Skip to main content

osf_rs/
data.rs

//! Data preparation for OSF inference.
//!
//! OSF input: `[B, 12, 1920]` — 12-channel PSG, 64 Hz × 30 seconds.
4
5use burn::prelude::*;
6
7/// A single prepared input batch for the OSF model.
8pub struct InputBatch<B: Backend> {
9    /// PSG signal: [1, C, T].
10    pub signal: Tensor<B, 3>,
11    /// Number of channels.
12    pub n_channels: usize,
13    /// Number of time samples.
14    pub n_samples: usize,
15}
16
/// Per-epoch embedding produced by OSF.
///
/// Both buffers are flat, row-major `f32` storage; `embed_dim` and
/// `num_patches` give the logical shapes used when the data is serialised.
#[derive(Debug, Clone, PartialEq)]
pub struct EpochEmbedding {
    /// CLS embedding values (row-major f32): `[D]`.
    pub cls_emb: Vec<f32>,
    /// Patch embedding values (row-major f32): `[N, D]`.
    pub patch_embs: Vec<f32>,
    /// Embedding dimension (`D`).
    pub embed_dim: usize,
    /// Number of patches (`N`).
    pub num_patches: usize,
}

/// Collection of per-epoch outputs.
#[derive(Debug, Clone, PartialEq)]
pub struct EncodingResult {
    /// One embedding record per encoded epoch, in encoding order.
    pub epochs: Vec<EpochEmbedding>,
    /// Load time — presumably milliseconds, judging by the `ms_` prefix
    /// (NOTE(review): confirm against the code that fills this in).
    pub ms_load: f64,
    /// Encode time — presumably milliseconds, judging by the `ms_` prefix
    /// (NOTE(review): confirm against the code that fills this in).
    pub ms_encode: f64,
}
35
36impl EncodingResult {
37    /// Save to safetensors file.
38    pub fn save_safetensors(&self, path: &str) -> anyhow::Result<()> {
39        use safetensors::{Dtype, View};
40        use std::borrow::Cow;
41
42        struct RawTensor { data: Vec<u8>, shape: Vec<usize>, dtype: Dtype }
43        impl View for RawTensor {
44            fn dtype(&self)    -> Dtype         { self.dtype }
45            fn shape(&self)    -> &[usize]      { &self.shape }
46            fn data(&self)     -> Cow<'_, [u8]> { Cow::Borrowed(&self.data) }
47            fn data_len(&self) -> usize         { self.data.len() }
48        }
49
50        let f32_bytes = |v: &[f32]| -> Vec<u8> {
51            v.iter().flat_map(|f| f.to_le_bytes()).collect()
52        };
53
54        let mut keys: Vec<String> = Vec::new();
55        let mut tensors: Vec<RawTensor> = Vec::new();
56
57        for (i, ep) in self.epochs.iter().enumerate() {
58            keys.push(format!("cls_emb_{i}"));
59            tensors.push(RawTensor {
60                data: f32_bytes(&ep.cls_emb),
61                shape: vec![ep.embed_dim],
62                dtype: Dtype::F32,
63            });
64
65            keys.push(format!("patch_embs_{i}"));
66            tensors.push(RawTensor {
67                data: f32_bytes(&ep.patch_embs),
68                shape: vec![ep.num_patches, ep.embed_dim],
69                dtype: Dtype::F32,
70            });
71        }
72
73        let n = self.epochs.len() as f32;
74        keys.push("n_epochs".into());
75        tensors.push(RawTensor {
76            data: f32_bytes(&[n]),
77            shape: vec![1],
78            dtype: Dtype::F32,
79        });
80
81        let pairs: Vec<(&str, RawTensor)> = keys.iter()
82            .map(|s| s.as_str())
83            .zip(tensors)
84            .collect();
85        let bytes = safetensors::serialize(pairs, None)?;
86        std::fs::write(path, bytes)?;
87        Ok(())
88    }
89}
90
91/// Build an InputBatch from raw signal data.
92pub fn build_batch<B: Backend>(
93    signal: Vec<f32>,    // [C, T] row-major
94    n_channels: usize,
95    n_samples: usize,
96    device: &B::Device,
97) -> InputBatch<B> {
98    let signal = Tensor::<B, 2>::from_data(
99        TensorData::new(signal, vec![n_channels, n_samples]), device,
100    ).unsqueeze_dim::<3>(0); // [1, C, T]
101
102    InputBatch { signal, n_channels, n_samples }
103}
104
105/// Channel-wise z-score normalisation.
106///
107/// Python equivalent: (x - mean) / (std + eps) per channel.
108pub fn channel_wise_normalize<B: Backend>(x: Tensor<B, 3>) -> Tensor<B, 3> {
109    let mean = x.clone().mean_dim(2); // [B, C, 1]
110    let diff = x.clone() - mean.clone();
111    let var = (diff.clone() * diff).mean_dim(2);
112    let std = (var + 1e-8).sqrt();
113    (x - mean) / std
114}