1use std::{path::Path, time::Instant};
10
11use anyhow::Context;
12use burn::prelude::*;
13
14use crate::{
15 config::{DataConfig, ModelConfig},
16 data::{InputBatch, FifInfo, channel_wise_normalize},
17 model::luna::Luna,
18 model::rope::RotaryEmbedding,
19 weights::load_model,
20};
21
/// Embedding produced by one forward pass over a single epoch/batch.
pub struct EpochEmbedding {
    /// Flattened model output; its logical dimensions are recorded in `shape`.
    pub output: Vec<f32>,
    /// Shape of `output` with the leading batch dimension stripped
    /// (see `LunaEncoder::run_batch`, which stores `dims()[1..]`).
    pub shape: Vec<usize>,
    /// Flattened channel locations; serialized with shape `[n_channels, 3]`
    /// (presumably 3-D coordinates per channel — TODO confirm).
    pub chan_pos: Vec<f32>,
    /// Number of channels contributing to this epoch.
    pub n_channels: usize,
}
34
/// Result of encoding a whole recording: one embedding per epoch plus
/// optional source metadata and timing figures.
pub struct EncodingResult {
    /// One embedding per processed epoch, in input order.
    pub epochs: Vec<EpochEmbedding>,
    /// Metadata from the source FIF file, when available; not persisted by
    /// `save_safetensors` (comes back `None` after `load_safetensors`).
    pub fif_info: Option<FifInfo>,
    /// Preprocessing wall time in milliseconds (0.0 when loaded from disk).
    pub ms_preproc: f64,
    /// Encoding wall time in milliseconds (0.0 when loaded from disk).
    pub ms_encode: f64,
}
42
43impl EncodingResult {
44 pub fn load_safetensors(path: &str) -> anyhow::Result<Self> {
46 let bytes = std::fs::read(path)?;
47 let st = safetensors::SafeTensors::deserialize(&bytes)?;
48
49 let n_samples = {
50 let v = st.tensor("n_samples")?;
51 f32::from_le_bytes(v.data()[..4].try_into().unwrap()) as usize
52 };
53
54 let mut epochs = Vec::with_capacity(n_samples);
55 for i in 0..n_samples {
56 let out_view = st.tensor(&format!("output_{i}"))?;
57 let shape: Vec<usize> = out_view.shape().to_vec();
58 let output: Vec<f32> = out_view.data().chunks_exact(4)
59 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
60 .collect();
61
62 let pos_view = st.tensor(&format!("chan_pos_{i}"))?;
63 let n_channels = pos_view.shape()[0];
64 let chan_pos: Vec<f32> = pos_view.data().chunks_exact(4)
65 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
66 .collect();
67
68 epochs.push(EpochEmbedding { output, shape, chan_pos, n_channels });
69 }
70
71 Ok(Self { epochs, fif_info: None, ms_preproc: 0.0, ms_encode: 0.0 })
72 }
73
74 pub fn save_safetensors(&self, path: &str) -> anyhow::Result<()> {
76 use safetensors::{Dtype, View};
77 use std::borrow::Cow;
78
79 struct RawTensor { data: Vec<u8>, shape: Vec<usize>, dtype: Dtype }
80 impl View for RawTensor {
81 fn dtype(&self) -> Dtype { self.dtype }
82 fn shape(&self) -> &[usize] { &self.shape }
83 fn data(&self) -> Cow<'_, [u8]> { Cow::Borrowed(&self.data) }
84 fn data_len(&self) -> usize { self.data.len() }
85 }
86
87 let f32_bytes = |v: &[f32]| -> Vec<u8> {
88 v.iter().flat_map(|f| f.to_le_bytes()).collect()
89 };
90
91 let mut keys: Vec<String> = Vec::new();
92 let mut tensors: Vec<RawTensor> = Vec::new();
93
94 for (i, ep) in self.epochs.iter().enumerate() {
95 keys.push(format!("output_{i}"));
96 tensors.push(RawTensor {
97 data: f32_bytes(&ep.output),
98 shape: ep.shape.clone(),
99 dtype: Dtype::F32,
100 });
101
102 keys.push(format!("chan_pos_{i}"));
103 tensors.push(RawTensor {
104 data: f32_bytes(&ep.chan_pos),
105 shape: vec![ep.n_channels, 3],
106 dtype: Dtype::F32,
107 });
108 }
109
110 let n = self.epochs.len() as f32;
111 keys.push("n_samples".into());
112 tensors.push(RawTensor {
113 data: f32_bytes(&[n]),
114 shape: vec![1],
115 dtype: Dtype::F32,
116 });
117
118 let pairs: Vec<(&str, RawTensor)> = keys.iter()
119 .map(|s| s.as_str())
120 .zip(tensors)
121 .collect();
122 let bytes = safetensors::serialize(pairs, None)?;
123 std::fs::write(path, bytes)?;
124 Ok(())
125 }
126}
127
/// A loaded LUNA model bundled with its rotary embedding, configuration,
/// and the device it runs on.
pub struct LunaEncoder<B: Backend> {
    /// The loaded LUNA network.
    model: Luna<B>,
    /// Precomputed rotary position embedding reused across forward passes.
    rope: RotaryEmbedding<B>,
    /// Model hyperparameters parsed from the config JSON.
    pub model_cfg: ModelConfig,
    /// Data/preprocessing configuration (initialized to defaults in `load`).
    pub data_cfg: DataConfig,
    /// Device all tensors are created on.
    device: B::Device,
}
138
139impl<B: Backend> LunaEncoder<B> {
140 pub fn load(
142 config_path: &Path,
143 weights_path: &Path,
144 device: B::Device,
145 ) -> anyhow::Result<(Self, f64)> {
146 let cfg_str = std::fs::read_to_string(config_path)
147 .with_context(|| format!("config: {}", config_path.display()))?;
148 let hf_val: serde_json::Value = serde_json::from_str(&cfg_str)?;
149 let model_cfg: ModelConfig = serde_json::from_value(
150 hf_val.get("model").cloned().unwrap_or(hf_val.clone())
151 ).context("parsing model config")?;
152
153 let max_seqlen = 1024; let head_dim = model_cfg.hidden_dim() / model_cfg.total_heads();
156 let rope = RotaryEmbedding::new(head_dim, max_seqlen, 10_000.0, &device);
157
158 let t = Instant::now();
159 let n_channel_names = 90; let model = load_model::<B>(
161 &model_cfg,
162 weights_path.to_str().context("weights path not valid UTF-8")?,
163 n_channel_names,
164 &device,
165 )?;
166 let ms = t.elapsed().as_secs_f64() * 1000.0;
167
168 Ok((Self { model, rope, model_cfg, data_cfg: DataConfig::default(), device }, ms))
169 }
170
171 pub fn describe(&self) -> String {
172 let c = &self.model_cfg;
173 format!(
174 "LUNA embed_dim={} queries={} depth={} heads={} patch={} classes={}",
175 c.embed_dim, c.num_queries, c.depth, c.num_heads, c.patch_size, c.num_classes,
176 )
177 }
178
179 pub fn run_batch(&self, batch: &InputBatch<B>) -> anyhow::Result<EpochEmbedding> {
181 use crate::model::luna::LunaOutput;
182
183 let signal = channel_wise_normalize(batch.signal.clone());
185
186 let luna_output = self.model.forward(
187 signal,
188 batch.channel_locations.clone(),
189 None, batch.channel_names.clone(),
191 &self.rope,
192 );
193
194 let output = match luna_output {
196 LunaOutput::Classification { logits, .. } => logits,
197 LunaOutput::Reconstruction { x_reconstructed, .. } => x_reconstructed,
198 };
199
200 let shape = output.dims().to_vec();
201 let output_vec = output.squeeze::<2>()
202 .into_data()
203 .to_vec::<f32>()
204 .map_err(|e| anyhow::anyhow!("output→vec: {e:?}"))?;
205
206 let chan_pos = batch.channel_locations.clone()
207 .squeeze::<2>()
208 .into_data()
209 .to_vec::<f32>()
210 .map_err(|e| anyhow::anyhow!("chan_pos→vec: {e:?}"))?;
211
212 Ok(EpochEmbedding {
213 output: output_vec,
214 shape: shape[1..].to_vec(), chan_pos,
216 n_channels: batch.n_channels,
217 })
218 }
219
220 pub fn run_batches(&self, batches: &[InputBatch<B>]) -> anyhow::Result<Vec<EpochEmbedding>> {
222 batches.iter().map(|b| self.run_batch(b)).collect()
223 }
224
225 pub fn device(&self) -> &B::Device { &self.device }
226}