1use std::{path::Path, time::Instant};
8
9use anyhow::Context;
10use burn::prelude::*;
11
12use crate::{
13 config::ModelConfig,
14 data::{InputBatch, EpochEmbedding},
15 model::vit::OsfViT,
16 weights::load_model,
17};
18
19pub struct OsfEncoder<B: Backend> {
21 model: OsfViT<B>,
22 pub model_cfg: ModelConfig,
23 device: B::Device,
24}
25
26impl<B: Backend> OsfEncoder<B> {
27 pub fn load(
29 config_path: &Path,
30 weights_path: &Path,
31 device: B::Device,
32 ) -> anyhow::Result<(Self, f64)> {
33 let cfg_str = std::fs::read_to_string(config_path)
34 .with_context(|| format!("config: {}", config_path.display()))?;
35 let model_cfg: ModelConfig = serde_json::from_str(&cfg_str)
36 .context("parsing model config")?;
37
38 let t = Instant::now();
39 let model = load_model::<B>(
40 &model_cfg,
41 weights_path.to_str().context("weights path not valid UTF-8")?,
42 &device,
43 )?;
44 let ms = t.elapsed().as_secs_f64() * 1000.0;
45
46 Ok((Self { model, model_cfg, device }, ms))
47 }
48
49 pub fn load_with_config(
51 model_cfg: ModelConfig,
52 weights_path: &Path,
53 device: B::Device,
54 ) -> anyhow::Result<(Self, f64)> {
55 let t = Instant::now();
56 let model = load_model::<B>(
57 &model_cfg,
58 weights_path.to_str().context("weights path not valid UTF-8")?,
59 &device,
60 )?;
61 let ms = t.elapsed().as_secs_f64() * 1000.0;
62
63 Ok((Self { model, model_cfg, device }, ms))
64 }
65
66 pub fn describe(&self) -> String {
67 let c = &self.model_cfg;
68 format!(
69 "OSF encoder={} width={} depth={} heads={} leads={} patch_time={} patch_ch={}",
70 c.encoder_name, c.width, c.depth, c.heads, c.num_leads,
71 c.patch_size_time, c.patch_size_ch,
72 )
73 }
74
75 pub fn run_batch(&self, batch: &InputBatch<B>) -> anyhow::Result<EpochEmbedding> {
79 let (cls, patches) = self.model.forward_encoding(batch.signal.clone());
80
81 let embed_dim = cls.dims()[2];
83 let num_patches = patches.dims()[1];
84
85 let cls_vec = cls.reshape([embed_dim])
86 .into_data()
87 .to_vec::<f32>()
88 .map_err(|e| anyhow::anyhow!("cls→vec: {e:?}"))?;
89
90 let patches_vec = patches.reshape([num_patches, embed_dim])
91 .into_data()
92 .to_vec::<f32>()
93 .map_err(|e| anyhow::anyhow!("patches→vec: {e:?}"))?;
94
95 Ok(EpochEmbedding {
96 cls_emb: cls_vec,
97 patch_embs: patches_vec,
98 embed_dim,
99 num_patches,
100 })
101 }
102
103 pub fn run_batches(&self, batches: &[InputBatch<B>]) -> anyhow::Result<Vec<EpochEmbedding>> {
105 batches.iter().map(|b| self.run_batch(b)).collect()
106 }
107
108 pub fn model(&self) -> &OsfViT<B> { &self.model }
110
111 pub fn device(&self) -> &B::Device { &self.device }
112}