use std::{path::Path, time::Instant};
use anyhow::Context;
use burn::prelude::*;
use crate::{
config::ModelConfig,
data::{InputBatch, EpochEmbedding},
model::vit::OsfViT,
weights::load_model,
};
/// Inference wrapper around an [`OsfViT`] model.
///
/// Holds the loaded network, the [`ModelConfig`] it was built from, and the
/// device handle that was passed to the loader. Construct it with
/// [`OsfEncoder::load`] or [`OsfEncoder::load_with_config`].
pub struct OsfEncoder<B>
where
    B: Backend,
{
    /// The loaded encoder network.
    model: OsfViT<B>,
    /// Hyper-parameters the model was built from (also used by `describe`).
    pub model_cfg: ModelConfig,
    /// Device handle supplied at load time.
    device: B::Device,
}
impl<B: Backend> OsfEncoder<B> {
    /// Reads a JSON [`ModelConfig`] from `config_path`, then loads the weights.
    ///
    /// Returns the encoder together with the weight-loading time in
    /// milliseconds. Reading and parsing the config file are not included in
    /// that timing.
    ///
    /// # Errors
    ///
    /// Fails if the config file cannot be read or parsed, if `weights_path` is
    /// not valid UTF-8, or if `load_model` fails.
    pub fn load(
        config_path: &Path,
        weights_path: &Path,
        device: B::Device,
    ) -> anyhow::Result<(Self, f64)> {
        let cfg_str = std::fs::read_to_string(config_path)
            .with_context(|| format!("config: {}", config_path.display()))?;
        let model_cfg: ModelConfig =
            serde_json::from_str(&cfg_str).context("parsing model config")?;
        // Delegate so the loading/timing logic lives in exactly one place.
        Self::load_with_config(model_cfg, weights_path, device)
    }

    /// Loads the weights for an already-built [`ModelConfig`].
    ///
    /// Returns the encoder together with the time spent in `load_model`, in
    /// milliseconds.
    ///
    /// # Errors
    ///
    /// Fails if `weights_path` is not valid UTF-8 (the loader takes a `&str`)
    /// or if `load_model` fails.
    pub fn load_with_config(
        model_cfg: ModelConfig,
        weights_path: &Path,
        device: B::Device,
    ) -> anyhow::Result<(Self, f64)> {
        let t = Instant::now();
        let model = load_model::<B>(
            &model_cfg,
            weights_path.to_str().context("weights path not valid UTF-8")?,
            &device,
        )?;
        let ms = t.elapsed().as_secs_f64() * 1000.0;
        Ok((Self { model, model_cfg, device }, ms))
    }

    /// One-line human-readable summary of the model hyper-parameters.
    pub fn describe(&self) -> String {
        let c = &self.model_cfg;
        format!(
            "OSF encoder={} width={} depth={} heads={} leads={} patch_time={} patch_ch={}",
            c.encoder_name, c.width, c.depth, c.heads, c.num_leads,
            c.patch_size_time, c.patch_size_ch,
        )
    }

    /// Runs the encoder on `batch` and copies the outputs into flat `Vec<f32>`s.
    ///
    /// The CLS output is reshaped to `[embed_dim]` and the patch output to
    /// `[num_patches, embed_dim]` before flattening. The CLS tensor must
    /// therefore hold exactly one embedding vector.
    ///
    /// # Errors
    ///
    /// Fails if either output tensor does not have the element count required
    /// by those reshapes, for example presumably when the batch holds more
    /// than one sample. Also fails if the tensor data cannot be converted to
    /// `Vec<f32>`.
    pub fn run_batch(&self, batch: &InputBatch<B>) -> anyhow::Result<EpochEmbedding> {
        let (cls, patches) = self.model.forward_encoding(batch.signal.clone());
        // Axis 2 of `cls` is the embedding width; axis 1 of `patches` is the
        // patch count.
        let embed_dim = cls.dims()[2];
        let num_patches = patches.dims()[1];

        // burn's `reshape` panics when the element count changes. Validate
        // first so an unexpected shape surfaces as an `Err` instead.
        Self::check_element_count("cls", &cls.dims(), embed_dim)?;
        Self::check_element_count("patches", &patches.dims(), num_patches * embed_dim)?;

        let cls_vec = cls
            .reshape([embed_dim])
            .into_data()
            .to_vec::<f32>()
            .map_err(|e| anyhow::anyhow!("cls→vec: {e:?}"))?;
        let patches_vec = patches
            .reshape([num_patches, embed_dim])
            .into_data()
            .to_vec::<f32>()
            .map_err(|e| anyhow::anyhow!("patches→vec: {e:?}"))?;
        Ok(EpochEmbedding {
            cls_emb: cls_vec,
            patch_embs: patches_vec,
            embed_dim,
            num_patches,
        })
    }

    /// Returns an error unless a tensor with shape `dims` holds exactly
    /// `expected` elements (the precondition for reshaping it).
    fn check_element_count(name: &str, dims: &[usize], expected: usize) -> anyhow::Result<()> {
        let actual: usize = dims.iter().product();
        if actual == expected {
            Ok(())
        } else {
            Err(anyhow::anyhow!(
                "{name}: expected {expected} elements, got {actual} (shape {dims:?}); \
                 run_batch expects a single sample per batch"
            ))
        }
    }

    /// Runs [`Self::run_batch`] on each batch in order.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error produced by `run_batch`.
    pub fn run_batches(&self, batches: &[InputBatch<B>]) -> anyhow::Result<Vec<EpochEmbedding>> {
        batches.iter().map(|b| self.run_batch(b)).collect()
    }

    /// Borrows the underlying encoder network.
    pub fn model(&self) -> &OsfViT<B> {
        &self.model
    }

    /// Borrows the device handle the encoder was loaded with.
    pub fn device(&self) -> &B::Device {
        &self.device
    }
}