Skip to main content

osf_rs/
encoder.rs

//! Standalone OSF encoder — produces sleep PSG embeddings.
//!
//! The encoder produces:
//! - CLS embedding: [B, 768] — global epoch-level representation
//! - Patch embeddings: [B, 90, 768] — local patch representations
6
7use std::{path::Path, time::Instant};
8
9use anyhow::Context;
10use burn::prelude::*;
11
12use crate::{
13    config::ModelConfig,
14    data::{InputBatch, EpochEmbedding},
15    model::vit::OsfViT,
16    weights::load_model,
17};
18
19/// High-level OSF encoder for PSG signal processing.
20pub struct OsfEncoder<B: Backend> {
21    model: OsfViT<B>,
22    pub model_cfg: ModelConfig,
23    device: B::Device,
24}
25
26impl<B: Backend> OsfEncoder<B> {
27    /// Load model from config JSON and safetensors weights.
28    pub fn load(
29        config_path: &Path,
30        weights_path: &Path,
31        device: B::Device,
32    ) -> anyhow::Result<(Self, f64)> {
33        let cfg_str = std::fs::read_to_string(config_path)
34            .with_context(|| format!("config: {}", config_path.display()))?;
35        let model_cfg: ModelConfig = serde_json::from_str(&cfg_str)
36            .context("parsing model config")?;
37
38        let t = Instant::now();
39        let model = load_model::<B>(
40            &model_cfg,
41            weights_path.to_str().context("weights path not valid UTF-8")?,
42            &device,
43        )?;
44        let ms = t.elapsed().as_secs_f64() * 1000.0;
45
46        Ok((Self { model, model_cfg, device }, ms))
47    }
48
49    /// Load model from a ModelConfig and safetensors path directly.
50    pub fn load_with_config(
51        model_cfg: ModelConfig,
52        weights_path: &Path,
53        device: B::Device,
54    ) -> anyhow::Result<(Self, f64)> {
55        let t = Instant::now();
56        let model = load_model::<B>(
57            &model_cfg,
58            weights_path.to_str().context("weights path not valid UTF-8")?,
59            &device,
60        )?;
61        let ms = t.elapsed().as_secs_f64() * 1000.0;
62
63        Ok((Self { model, model_cfg, device }, ms))
64    }
65
66    pub fn describe(&self) -> String {
67        let c = &self.model_cfg;
68        format!(
69            "OSF  encoder={}  width={}  depth={}  heads={}  leads={}  patch_time={}  patch_ch={}",
70            c.encoder_name, c.width, c.depth, c.heads, c.num_leads,
71            c.patch_size_time, c.patch_size_ch,
72        )
73    }
74
75    /// Run inference on a prepared InputBatch.
76    ///
77    /// Returns an EpochEmbedding with CLS and patch embeddings.
78    pub fn run_batch(&self, batch: &InputBatch<B>) -> anyhow::Result<EpochEmbedding> {
79        let (cls, patches) = self.model.forward_encoding(batch.signal.clone());
80
81        // cls: [B, 1, D], patches: [B, N, D]
82        let embed_dim = cls.dims()[2];
83        let num_patches = patches.dims()[1];
84
85        let cls_vec = cls.reshape([embed_dim])
86            .into_data()
87            .to_vec::<f32>()
88            .map_err(|e| anyhow::anyhow!("cls→vec: {e:?}"))?;
89
90        let patches_vec = patches.reshape([num_patches, embed_dim])
91            .into_data()
92            .to_vec::<f32>()
93            .map_err(|e| anyhow::anyhow!("patches→vec: {e:?}"))?;
94
95        Ok(EpochEmbedding {
96            cls_emb: cls_vec,
97            patch_embs: patches_vec,
98            embed_dim,
99            num_patches,
100        })
101    }
102
103    /// Run on multiple batches.
104    pub fn run_batches(&self, batches: &[InputBatch<B>]) -> anyhow::Result<Vec<EpochEmbedding>> {
105        batches.iter().map(|b| self.run_batch(b)).collect()
106    }
107
108    /// Get the raw ViT model reference.
109    pub fn model(&self) -> &OsfViT<B> { &self.model }
110
111    pub fn device(&self) -> &B::Device { &self.device }
112}