use std::{path::Path, time::Instant};
use anyhow::Context;
use burn::{prelude::*, tensor::Distribution};
use crate::{
config::{DataConfig, ModelConfig},
data::invert_reshape,
encoder::{EncodingResult, EpochEmbedding},
inference::{EpochOutput, InferenceResult},
model::{decoder::DecoderTransformer, rope::RotaryEmbedding},
weights::load_decoder_weights,
};
/// Decoder that reconstructs epochs from encoder embeddings.
///
/// Reconstruction starts from Gaussian noise and repeatedly applies
/// `DecoderTransformer`, conditioned on the embeddings (see the
/// `decode_embeddings` / `decode_tensor` methods).
pub struct ZunaDecoder<B: Backend> {
// Transformer evaluated once (or twice, with guidance) per integration step.
decoder: DecoderTransformer<B>,
// Rotary positional embedding passed to every decoder forward call.
rope: RotaryEmbedding<B>,
/// Model hyperparameters, parsed from the `"model"` section of the config JSON.
pub model_cfg: ModelConfig,
/// Data layout settings; `load` initialises this with `DataConfig::default()`.
pub data_cfg: DataConfig,
/// Standard deviation of the initial noise; `load` copies it from
/// `model_cfg.stft_global_sigma`.
pub global_sigma: f32,
// Device on which `decode_one` creates its input tensors.
device: B::Device,
}
impl<B: Backend> ZunaDecoder<B> {
    /// Loads the decoder from a JSON config file and a weights file.
    ///
    /// The model configuration is read from the `"model"` object of the JSON
    /// at `config_path`. Rotary embeddings are built from that configuration,
    /// and the decoder weights are loaded from `weights_path` onto `device`.
    ///
    /// Returns the decoder together with the weight-loading time in
    /// milliseconds. Config parsing and RoPE construction are not included
    /// in that time.
    ///
    /// # Errors
    /// Fails if the config cannot be read or is not valid JSON, if it has no
    /// `"model"` section, if that section does not deserialize into
    /// `ModelConfig`, if `weights_path` is not valid UTF-8, or if weight
    /// loading fails.
    pub fn load(
        config_path: &Path,
        weights_path: &Path,
        device: B::Device,
    ) -> anyhow::Result<(Self, f64)> {
        let cfg_str = std::fs::read_to_string(config_path)
            .with_context(|| format!("config: {}", config_path.display()))?;
        let hf_val: serde_json::Value = serde_json::from_str(&cfg_str)
            .with_context(|| format!("parsing JSON config: {}", config_path.display()))?;
        // Indexing a missing key with `[]` yields `Null`, which used to surface
        // as an opaque "invalid type: null" error; report the missing section.
        let model_val = hf_val.get("model").cloned().with_context(|| {
            format!("config has no \"model\" section: {}", config_path.display())
        })?;
        let model_cfg: ModelConfig =
            serde_json::from_value(model_val).context("parsing model config")?;
        let rope = RotaryEmbedding::<B>::new(
            model_cfg.head_dim,
            model_cfg.rope_dim,
            model_cfg.max_seqlen,
            model_cfg.rope_theta,
            &device,
        );
        let t = Instant::now();
        let (decoder, n_heads) = load_decoder_weights::<B>(
            &model_cfg,
            weights_path.to_str().context("weights path not valid UTF-8")?,
            &device,
        )?;
        let ms = t.elapsed().as_secs_f64() * 1000.0;
        println!("Detected n_heads = {n_heads}");
        let global_sigma = model_cfg.stft_global_sigma as f32;
        Ok((
            Self {
                decoder,
                rope,
                model_cfg,
                data_cfg: DataConfig::default(),
                global_sigma,
                device,
            },
            ms,
        ))
    }

    /// One-line human-readable summary of the main model dimensions and the
    /// noise sigma.
    pub fn describe(&self) -> String {
        let c = &self.model_cfg;
        format!(
            "ZUNA decoder dim={} layers={} head_dim={} t_dim={} σ={}",
            c.dim, c.n_layers, c.head_dim, c.t_dim, self.global_sigma,
        )
    }

    /// Decodes every epoch in `embeddings` and collects the results.
    ///
    /// * `steps` — number of integration steps per epoch; must be positive.
    /// * `cfg` — guidance scale forwarded to [`Self::decode_tensor`].
    /// * `data_norm` — factor each reconstruction is multiplied by.
    ///
    /// `ms_infer` in the result is the wall-clock time spent decoding all
    /// epochs. `ms_preproc` is always `0.0` and `fif_info` is always `None`
    /// here.
    ///
    /// # Errors
    /// Fails if `steps == 0` or if any epoch fails to decode (see
    /// `decode_one`). Decoding stops at the first failing epoch.
    pub fn decode_embeddings(
        &self,
        embeddings: &EncodingResult,
        steps: usize,
        cfg: f32,
        data_norm: f32,
    ) -> anyhow::Result<InferenceResult> {
        // With zero steps the sampler never runs and would silently return
        // pure noise as the "reconstruction".
        anyhow::ensure!(steps > 0, "number of decoding steps must be positive (got 0)");
        let t_dec = Instant::now();
        let epochs = embeddings
            .epochs
            .iter()
            .map(|ep| self.decode_one(ep, steps, cfg, data_norm))
            .collect::<anyhow::Result<Vec<_>>>()?;
        let ms_infer = t_dec.elapsed().as_secs_f64() * 1000.0;
        Ok(InferenceResult {
            epochs,
            fif_info: None,
            ms_preproc: 0.0,
            ms_infer,
        })
    }

    /// Runs the iterative sampler conditioned on `enc_out`.
    ///
    /// Starts from `z ~ N(0, global_sigma)` with the same shape as `enc_out`.
    /// It then takes `steps` explicit Euler steps of size `dt = 1/steps`,
    /// going from `t = 1` down to `t = 0`: `z ← z − dt · v(z, t)`. Here `v`
    /// is the decoder output.
    ///
    /// When `cfg` differs from 1 (beyond 1e-4), classifier-free guidance is
    /// applied. The decoder is also run with an all-zero context, and
    /// `v = v_uncond + cfg · (v_cond − v_uncond)`.
    ///
    /// With `steps == 0` no step is taken and the initial noise is returned.
    pub fn decode_tensor(
        &self,
        enc_out: Tensor<B, 3>,
        tok_idx: Tensor<B, 2, Int>,
        steps: usize,
        cfg: f32,
    ) -> Tensor<B, 3> {
        let device = enc_out.device();
        let [b, s, d] = enc_out.dims();
        let dt = 1.0_f32 / steps as f32;
        let sigma = self.global_sigma as f64;
        let mut z = Tensor::<B, 3>::random([b, s, d], Distribution::Normal(0.0, sigma), &device);
        // With cfg ≈ 1 guidance is a no-op (uncond + 1·(cond − uncond) = cond),
        // so the second forward pass is skipped. The zero context is
        // loop-invariant: build it once instead of once per step.
        let uncond_ctx =
            ((cfg - 1.0).abs() > 1e-4).then(|| Tensor::<B, 3>::zeros([b, s, d], &device));
        for i in (1..=steps).rev() {
            let t_val = dt * i as f32;
            let time_t = Tensor::<B, 3>::full([b, 1, 1], t_val, &device);
            let vc = self.decoder.forward(
                z.clone(),
                enc_out.clone(),
                time_t.clone(),
                tok_idx.clone(),
                &self.rope,
            );
            let vc = match &uncond_ctx {
                Some(zeros) => {
                    let vc_uncond = self.decoder.forward(
                        z.clone(),
                        zeros.clone(),
                        time_t,
                        tok_idx.clone(),
                        &self.rope,
                    );
                    vc_uncond.clone() + (vc - vc_uncond).mul_scalar(cfg)
                }
                None => vc,
            };
            z = z - vc.mul_scalar(dt);
        }
        z
    }

    /// Decodes a single epoch.
    ///
    /// Steps:
    /// 1. Build a batch-of-one tensor from the flat `ep.embeddings` buffer
    ///    (interpreted with `ep.shape`).
    /// 2. Run [`Self::decode_tensor`].
    /// 3. Map the result back through `invert_reshape`.
    /// 4. Scale it by `data_norm`.
    ///
    /// # Errors
    /// Fails when:
    /// * `ep.shape` is not rank 2 or disagrees with the length of
    ///   `ep.embeddings`;
    /// * `ep.tok_idx` does not hold `n_tokens * 4` values;
    /// * the reconstruction cannot be read back as `f32`.
    fn decode_one(
        &self,
        ep: &EpochEmbedding,
        steps: usize,
        cfg: f32,
        data_norm: f32,
    ) -> anyhow::Result<EpochOutput> {
        // Number of index columns per token in `tok_idx`.
        const TOK_IDX_COLS: usize = 4;
        let n_tokens = ep.n_tokens();
        let dc = &self.data_cfg;
        // `TensorData::new` / `from_data` panic on length or rank mismatches;
        // check up front so malformed input becomes an error, not a crash.
        let n_emb: usize = ep.shape.iter().product();
        anyhow::ensure!(
            ep.shape.len() == 2 && ep.embeddings.len() == n_emb,
            "embedding shape {:?} does not match {} values (expected rank 2)",
            ep.shape,
            ep.embeddings.len(),
        );
        anyhow::ensure!(
            ep.tok_idx.len() == n_tokens * TOK_IDX_COLS,
            "tok_idx has {} values, expected {} ({} tokens x {})",
            ep.tok_idx.len(),
            n_tokens * TOK_IDX_COLS,
            n_tokens,
            TOK_IDX_COLS,
        );
        let enc_out = Tensor::<B, 2>::from_data(
            TensorData::new(ep.embeddings.clone(), ep.shape.clone()),
            &self.device,
        )
        .unsqueeze_dim::<3>(0);
        let tok_idx = Tensor::<B, 2, Int>::from_data(
            TensorData::new(ep.tok_idx.clone(), vec![n_tokens, TOK_IDX_COLS]),
            &self.device,
        );
        let z = self.decode_tensor(enc_out, tok_idx, steps, cfg);
        // Batch dimension is 1 (added by `unsqueeze_dim` above); drop it.
        let [_, s, tf] = z.dims();
        // NOTE(review): `invert_reshape` is defined elsewhere; presumably it
        // undoes the tokenisation into (channels, time) — confirm in `data`.
        let recon = invert_reshape(
            z.reshape([s, tf]),
            ep.n_channels,
            ep.tc,
            dc.num_fine_time_pts,
        );
        let recon = recon.mul_scalar(data_norm);
        let shape = recon.dims().to_vec();
        let reconstructed = recon
            .into_data()
            .convert::<f32>()
            .to_vec::<f32>()
            .map_err(|e| anyhow::anyhow!("recon→vec: {e:?}"))?;
        let chan_pos = ep.chan_pos.clone();
        Ok(EpochOutput {
            reconstructed,
            shape,
            chan_pos,
            n_channels: ep.n_channels,
        })
    }
}