1use std::{path::Path, time::Instant};
4
5use anyhow::Context;
6use burn::prelude::*;
7
8use crate::{
9 config::ModelConfig,
10 data::{InputBatch, channel_wise_normalize},
11 model::reve::Reve,
12 weights::load_model,
13};
14
/// Host-side result of running one input batch through the REVE model.
#[derive(Debug, Clone)]
pub struct ReveOutput {
    /// Model output tensor, flattened into a contiguous `f32` buffer.
    pub output: Vec<f32>,
    /// Dimensions of the output tensor with the leading dimension removed.
    pub shape: Vec<usize>,
    /// Channel count copied from the source batch's `n_channels`.
    pub n_channels: usize,
}
24
25pub struct EncodingResult {
27 pub outputs: Vec<ReveOutput>,
28 pub ms_load: f64,
29 pub ms_encode: f64,
30}
31
32pub struct ReveEncoder<B: Backend> {
34 model: Reve<B>,
35 pub model_cfg: ModelConfig,
36 device: B::Device,
37}
38
39impl<B: Backend> ReveEncoder<B> {
40 pub fn load(
42 config_path: &Path,
43 weights_path: &Path,
44 device: B::Device,
45 ) -> anyhow::Result<(Self, f64)> {
46 let cfg_str = std::fs::read_to_string(config_path)
47 .with_context(|| format!("config: {}", config_path.display()))?;
48 let hf_val: serde_json::Value = serde_json::from_str(&cfg_str)?;
49 let model_cfg: ModelConfig = serde_json::from_value(
50 hf_val.get("model").cloned().unwrap_or(hf_val.clone()),
51 )
52 .context("parsing model config")?;
53
54 let t = Instant::now();
55 let model = load_model::<B>(
56 &model_cfg,
57 weights_path.to_str().context("weights path not valid UTF-8")?,
58 &device,
59 )?;
60 let ms = t.elapsed().as_secs_f64() * 1000.0;
61
62 Ok((
63 Self {
64 model,
65 model_cfg,
66 device,
67 },
68 ms,
69 ))
70 }
71
72 pub fn describe(&self) -> String {
73 let c = &self.model_cfg;
74 format!(
75 "REVE embed_dim={} depth={} heads={} head_dim={} patch={} outputs={}",
76 c.embed_dim, c.depth, c.heads, c.head_dim, c.patch_size, c.n_outputs,
77 )
78 }
79
80 pub fn run_batch(&self, batch: &InputBatch<B>) -> anyhow::Result<ReveOutput> {
82 let signal = channel_wise_normalize(batch.signal.clone());
83
84 let output = self.model.forward(signal, batch.positions.clone());
85 let shape = output.dims().to_vec();
88 let output_vec = output
89 .into_data()
90 .to_vec::<f32>()
91 .map_err(|e| anyhow::anyhow!("output→vec: {e:?}"))?;
92
93 Ok(ReveOutput {
94 output: output_vec,
95 shape: shape[1..].to_vec(), n_channels: batch.n_channels,
97 })
98 }
99
100 pub fn run_batches(
102 &self,
103 batches: &[InputBatch<B>],
104 ) -> anyhow::Result<Vec<ReveOutput>> {
105 batches.iter().map(|b| self.run_batch(b)).collect()
106 }
107
108 pub fn device(&self) -> &B::Device {
109 &self.device
110 }
111}