use std::{path::Path, time::Instant};
use anyhow::Context;
use burn::prelude::*;
use crate::{
config::{DataConfig, ModelConfig},
data::{load_batch, load_from_fif, FifInfo, InputBatch},
model::{encoder::EncoderTransformer, rope::RotaryEmbedding},
weights::load_encoder_weights,
};
/// Encoder output for a single epoch, copied from the device into host `Vec`s.
#[derive(Debug, Clone)]
pub struct EpochEmbedding {
    /// Flattened encoder output; its dimensions are given by `shape`.
    pub embeddings: Vec<f32>,
    /// Dimensions of `embeddings` (the encoder path fills this as `[tokens, output_dim]`).
    pub shape: Vec<usize>,
    /// Flattened token-index rows; serialized with shape `[n_channels * tc, 4]`.
    pub tok_idx: Vec<i64>,
    /// Flattened channel positions; serialized with shape `[n_channels, 3]`.
    pub chan_pos: Vec<f32>,
    /// Number of channels in this epoch.
    pub n_channels: usize,
    /// Tokens per channel, so the token count is `n_channels * tc`
    /// (presumably the number of time chunks per channel — confirm).
    pub tc: usize,
}
impl EpochEmbedding {
    /// Number of tokens in the epoch: `n_channels` times `tc`.
    #[inline]
    pub fn n_tokens(&self) -> usize {
        let (channels, per_channel) = (self.n_channels, self.tc);
        channels * per_channel
    }

    /// Width of each embedding row, i.e. the second entry of `shape`,
    /// or `0` when `shape` has fewer than two dimensions.
    #[inline]
    pub fn output_dim(&self) -> usize {
        match self.shape.as_slice() {
            [_, dim, ..] => *dim,
            _ => 0,
        }
    }
}
/// Result of one encoding run: per-epoch embeddings plus wall-clock timings.
pub struct EncodingResult {
    /// One embedding per input batch, in the order the batches were produced.
    pub epochs: Vec<EpochEmbedding>,
    /// Recording metadata: `Some` when the input was a FIF file,
    /// `None` when pre-built batches were loaded instead.
    pub fif_info: Option<FifInfo>,
    /// Milliseconds spent loading / preprocessing the input.
    pub ms_preproc: f64,
    /// Milliseconds spent running the encoder over all batches.
    pub ms_encode: f64,
}
impl EncodingResult {
    /// Serialize all epochs into one safetensors file at `path`.
    ///
    /// For each epoch `i` three tensors are written:
    /// - `embeddings_{i}` — F32, shape taken from the epoch's `shape`
    /// - `tok_idx_{i}`    — I64, shape `[n_tokens(), 4]`
    /// - `chan_pos_{i}`   — F32, shape `[n_channels, 3]`
    ///
    /// plus `n_samples`, a 1-element F32 tensor holding the number of epochs.
    ///
    /// # Errors
    /// Returns an error naming the offending tensor if any buffer's length does
    /// not match the shape it is written under. Also errors if safetensors
    /// serialization fails or the file cannot be written.
    pub fn save_safetensors(&self, path: &str) -> anyhow::Result<()> {
        use safetensors::{Dtype, View};
        use std::borrow::Cow;

        /// Owned little-endian byte buffer handed to safetensors through `View`.
        struct RawTensor {
            data: Vec<u8>,
            shape: Vec<usize>,
            dtype: Dtype,
        }

        impl RawTensor {
            /// Build a tensor after checking that `n_elems` values exactly fill
            /// `shape`. A mismatch is then reported with the tensor's name
            /// rather than as a generic safetensors header error.
            fn checked(
                name: &str,
                n_elems: usize,
                data: Vec<u8>,
                shape: Vec<usize>,
                dtype: Dtype,
            ) -> anyhow::Result<Self> {
                let expected: usize = shape.iter().product();
                anyhow::ensure!(
                    n_elems == expected,
                    "{}: {} elements do not fit shape {:?} ({} expected)",
                    name,
                    n_elems,
                    shape,
                    expected
                );
                Ok(Self { data, shape, dtype })
            }
        }

        impl View for RawTensor {
            fn dtype(&self) -> Dtype {
                self.dtype
            }
            fn shape(&self) -> &[usize] {
                &self.shape
            }
            fn data(&self) -> Cow<'_, [u8]> {
                Cow::Borrowed(&self.data)
            }
            fn data_len(&self) -> usize {
                self.data.len()
            }
        }

        // safetensors stores raw little-endian element bytes.
        fn f32_bytes(v: &[f32]) -> Vec<u8> {
            v.iter().flat_map(|x| x.to_le_bytes()).collect()
        }
        fn i64_bytes(v: &[i64]) -> Vec<u8> {
            v.iter().flat_map(|x| x.to_le_bytes()).collect()
        }

        // Three tensors per epoch plus the trailing `n_samples` entry.
        let mut pairs: Vec<(String, RawTensor)> = Vec::with_capacity(3 * self.epochs.len() + 1);
        for (i, ep) in self.epochs.iter().enumerate() {
            let key = format!("embeddings_{i}");
            let tensor = RawTensor::checked(
                &key,
                ep.embeddings.len(),
                f32_bytes(&ep.embeddings),
                ep.shape.clone(),
                Dtype::F32,
            )?;
            pairs.push((key, tensor));

            let key = format!("tok_idx_{i}");
            let tensor = RawTensor::checked(
                &key,
                ep.tok_idx.len(),
                i64_bytes(&ep.tok_idx),
                vec![ep.n_tokens(), 4],
                Dtype::I64,
            )?;
            pairs.push((key, tensor));

            let key = format!("chan_pos_{i}");
            let tensor = RawTensor::checked(
                &key,
                ep.chan_pos.len(),
                f32_bytes(&ep.chan_pos),
                vec![ep.n_channels, 3],
                Dtype::F32,
            )?;
            pairs.push((key, tensor));
        }

        // Epoch count as a 1-element F32 tensor (exact for counts below 2^24).
        let n = self.epochs.len() as f32;
        pairs.push((
            "n_samples".to_string(),
            RawTensor { data: f32_bytes(&[n]), shape: vec![1], dtype: Dtype::F32 },
        ));

        let bytes = safetensors::serialize(pairs, None).context("serializing safetensors")?;
        std::fs::write(path, bytes).with_context(|| format!("writing {path}"))?;
        Ok(())
    }
}
/// ZUNA encoder bundle: the transformer, its rotary-embedding tables, the
/// configuration it was built from, and the device it runs on.
pub struct ZunaEncoder<B: Backend> {
    /// Transformer that turns preprocessed input into embeddings.
    encoder: EncoderTransformer<B>,
    /// Rotary position embedding passed to every encoder forward call.
    rope: RotaryEmbedding<B>,
    /// Model hyper-parameters (read from the config file's `"model"` section by `load`).
    pub model_cfg: ModelConfig,
    /// Preprocessing settings; `load` sets `DataConfig::default()`,
    /// and the field is public so callers may override it.
    pub data_cfg: DataConfig,
    /// Device that holds the weights and receives input tensors.
    device: B::Device,
}
impl<B: Backend> ZunaEncoder<B> {
    /// Build an encoder from a JSON config file and a weights file.
    ///
    /// Model hyper-parameters are read from the config's top-level `"model"`
    /// object. The rotary-embedding tables are built from them, and the
    /// encoder weights are then loaded onto `device`. The detected head count
    /// is printed to stdout.
    ///
    /// Returns the encoder and the weight-loading time in milliseconds. Config
    /// parsing and RoPE construction are not included in that figure.
    ///
    /// # Errors
    /// Fails if:
    /// - the config cannot be read or is not valid JSON;
    /// - it has no `"model"` object, or that object does not deserialize into
    ///   `ModelConfig`;
    /// - `weights_path` is not valid UTF-8;
    /// - weight loading fails.
    pub fn load(
        config_path: &Path,
        weights_path: &Path,
        device: B::Device,
    ) -> anyhow::Result<(Self, f64)> {
        let cfg_str = std::fs::read_to_string(config_path)
            .with_context(|| format!("config: {}", config_path.display()))?;
        let mut hf_val: serde_json::Value = serde_json::from_str(&cfg_str)
            .with_context(|| format!("config JSON: {}", config_path.display()))?;
        // Move the "model" subtree out instead of cloning it. A missing key (or a
        // non-object root) now says what is missing instead of failing on `null`.
        let model_val = hf_val
            .get_mut("model")
            .map(serde_json::Value::take)
            .with_context(|| {
                format!("config has no \"model\" section: {}", config_path.display())
            })?;
        let model_cfg: ModelConfig =
            serde_json::from_value(model_val).context("parsing model config")?;

        let rope = RotaryEmbedding::<B>::new(
            model_cfg.head_dim,
            model_cfg.rope_dim,
            model_cfg.max_seqlen,
            model_cfg.rope_theta,
            &device,
        );

        let t = Instant::now();
        let (encoder, n_heads) = load_encoder_weights::<B>(
            &model_cfg,
            weights_path.to_str().context("weights path not valid UTF-8")?,
            &device,
        )?;
        let ms = Self::elapsed_ms(t);
        println!("Detected n_heads = {n_heads}");

        Ok((Self { encoder, rope, model_cfg, data_cfg: DataConfig::default(), device }, ms))
    }

    /// One-line human-readable summary of the model dimensions.
    pub fn describe(&self) -> String {
        let c = &self.model_cfg;
        format!(
            "ZUNA encoder dim={} layers={} head_dim={} out_dim={}",
            c.dim, c.n_layers, c.head_dim, c.encoder_output_dim,
        )
    }

    /// Preprocess a FIF recording and encode every resulting batch.
    ///
    /// `data_norm` is forwarded unchanged to the FIF loader. The result carries
    /// the recording's `FifInfo` and separate preprocessing / encoding timings.
    pub fn encode_fif(
        &self,
        fif_path: &Path,
        data_norm: f32,
    ) -> anyhow::Result<EncodingResult> {
        let t_pp = Instant::now();
        let (batches, fif_info) = self
            .preprocess_fif(fif_path, data_norm)
            .with_context(|| format!("exg on {}", fif_path.display()))?;
        let ms_preproc = Self::elapsed_ms(t_pp);

        let t_enc = Instant::now();
        let epochs = self.encode_inputs(batches)?;
        let ms_encode = Self::elapsed_ms(t_enc);

        Ok(EncodingResult { epochs, fif_info: Some(fif_info), ms_preproc, ms_encode })
    }

    /// Load pre-built input batches from `batch_path` and encode them.
    ///
    /// The result has `fif_info: None`, since no recording metadata is available.
    pub fn encode_batch(
        &self,
        batch_path: &Path,
    ) -> anyhow::Result<EncodingResult> {
        let t_pp = Instant::now();
        let batches = load_batch::<B>(
            batch_path.to_str().context("batch path not valid UTF-8")?,
            &self.data_cfg,
            &self.device,
        )
        .with_context(|| format!("loading batch {}", batch_path.display()))?;
        let ms_preproc = Self::elapsed_ms(t_pp);

        let t_enc = Instant::now();
        let epochs = self.encode_inputs(batches)?;
        let ms_encode = Self::elapsed_ms(t_enc);

        Ok(EncodingResult { epochs, fif_info: None, ms_preproc, ms_encode })
    }

    /// Run the encoder on one batch and return the raw rank-3 output tensor,
    /// without copying it to host memory.
    pub fn encode_tensor(&self, batch: &InputBatch<B>) -> Tensor<B, 3> {
        self.encoder.forward(
            batch.encoder_input.clone(),
            batch.tok_idx.clone(),
            &self.rope,
        )
    }

    /// Load and preprocess a FIF recording into encoder-ready batches on this
    /// encoder's device, using `data_cfg`. No timing or extra error context is added.
    pub fn preprocess_fif(
        &self,
        fif_path: &Path,
        data_norm: f32,
    ) -> anyhow::Result<(Vec<InputBatch<B>>, FifInfo)> {
        load_from_fif(fif_path, &self.data_cfg, data_norm, &self.device)
    }

    /// Encode already-prepared batches, one `EpochEmbedding` per batch.
    pub fn encode_batches(
        &self,
        batches: Vec<InputBatch<B>>,
    ) -> anyhow::Result<Vec<EpochEmbedding>> {
        self.encode_inputs(batches)
    }

    /// Device this encoder's weights live on.
    pub fn device(&self) -> &B::Device {
        &self.device
    }

    /// Encode each batch in order, stopping at the first failure.
    pub(crate) fn encode_inputs(
        &self,
        batches: Vec<InputBatch<B>>,
    ) -> anyhow::Result<Vec<EpochEmbedding>> {
        batches.into_iter().map(|b| self.encode_one(b)).collect()
    }

    /// Milliseconds elapsed since `start`.
    fn elapsed_ms(start: Instant) -> f64 {
        start.elapsed().as_secs_f64() * 1000.0
    }

    /// Encode a single batch and copy the embeddings, token indices and
    /// channel positions back into host vectors.
    ///
    /// # Errors
    /// Fails if the encoder output's leading (batch) dimension is not 1, or if
    /// any tensor cannot be converted to the expected element type.
    fn encode_one(&self, batch: InputBatch<B>) -> anyhow::Result<EpochEmbedding> {
        let n_channels = batch.n_channels;
        let tc = batch.tc;
        // tok_idx is consumed by forward(), so keep a handle for the host copy;
        // chan_pos is not used by forward() and can simply be moved out.
        let tok_idx_saved = batch.tok_idx.clone();
        let chan_pos_saved = batch.chan_pos;

        let enc_out = self.encoder.forward(batch.encoder_input, batch.tok_idx, &self.rope);
        let [n_batch, s, output_dim] = enc_out.dims();
        // The flattened data is only a valid `[s, output_dim]` matrix for a
        // single-item batch. Previously other sizes panicked in `squeeze`.
        anyhow::ensure!(
            n_batch == 1,
            "expected encoder batch dim 1 per InputBatch, got {}",
            n_batch
        );
        // `into_data` flattens in row-major order, so `[1, s, d]` is already laid
        // out as `[s, d]`. No squeeze is needed; the argument-less squeeze would
        // also drop `s` or `d` when either equals 1.
        let embeddings = enc_out
            .into_data()
            .to_vec::<f32>()
            .map_err(|e| anyhow::anyhow!("embedding→vec: {e:?}"))?;

        // Backends differ in their integer element type; accept i64 or i32.
        let tok_idx_data = tok_idx_saved.into_data();
        let tok_idx: Vec<i64> = tok_idx_data
            .to_vec::<i64>()
            .or_else(|_| {
                tok_idx_data
                    .to_vec::<i32>()
                    .map(|v| v.into_iter().map(i64::from).collect())
            })
            .map_err(|e| anyhow::anyhow!("tok_idx→vec: {e:?}"))?;

        let chan_pos = chan_pos_saved
            .into_data()
            .to_vec::<f32>()
            .map_err(|e| anyhow::anyhow!("chan_pos→vec: {e:?}"))?;

        Ok(EpochEmbedding {
            embeddings,
            shape: vec![s, output_dim],
            tok_idx,
            chan_pos,
            n_channels,
            tc,
        })
    }
}