use std::{path::Path, time::Instant};
use anyhow::Context;
use burn::prelude::*;
use crate::{
config::{DataConfig, ModelConfig},
data::{load_batch, load_from_fif, invert_reshape, FifInfo, InputBatch},
encoder::{EncodingResult, EpochEmbedding},
model::{encoder_decoder::EncoderDecoder, rope::RotaryEmbedding},
weights::load_model,
};
/// Reconstruction produced for a single epoch.
pub struct EpochOutput {
/// Reconstructed values as a flat `f32` buffer; `shape` gives the dimensions
/// of the tensor this buffer was read from.
pub reconstructed: Vec<f32>,
/// Dimensions of the reconstructed tensor.
pub shape: Vec<usize>,
/// Channel positions as a flat `f32` buffer. `InferenceResult::save_safetensors`
/// writes it with shape `[n_channels, 3]`.
pub chan_pos: Vec<f32>,
/// Number of channels in this epoch.
pub n_channels: usize,
}
/// Output of an inference run: per-epoch reconstructions plus timing information.
pub struct InferenceResult {
/// One reconstruction per processed epoch, in input order.
pub epochs: Vec<EpochOutput>,
/// Metadata of the source FIF file; `None` when the input did not come from a FIF
/// file (safetensors batches, pre-computed embeddings).
pub fif_info: Option<FifInfo>,
/// Wall-clock time spent loading / preprocessing the input, in milliseconds
/// (`0.0` when decoding pre-computed embeddings).
pub ms_preproc: f64,
/// Wall-clock time spent running the model, in milliseconds.
pub ms_infer: f64,
}
impl InferenceResult {
    /// Writes every epoch of this result to a safetensors file at `path`.
    ///
    /// For epoch `i` two `F32` tensors are stored: `reconstructed_{i}` (with the
    /// epoch's `shape`) and `chan_pos_{i}` (with shape `[n_channels, 3]`). A final
    /// one-element tensor `n_samples` holds the number of epochs as an `f32`.
    ///
    /// # Errors
    /// Returns an error if an epoch's buffer length does not match the shape it
    /// would be stored with, if serialization fails, or if the file cannot be
    /// written.
    pub fn save_safetensors(&self, path: &str) -> anyhow::Result<()> {
        use safetensors::{Dtype, View};
        use std::borrow::Cow;

        /// Owned little-endian `f32` payload together with its logical shape.
        struct F32Tensor {
            data: Vec<u8>,
            shape: Vec<usize>,
        }

        impl View for F32Tensor {
            fn dtype(&self) -> Dtype {
                Dtype::F32
            }
            fn shape(&self) -> &[usize] {
                &self.shape
            }
            fn data(&self) -> Cow<'_, [u8]> {
                Cow::Borrowed(&self.data)
            }
            fn data_len(&self) -> usize {
                self.data.len()
            }
        }

        /// Serializes `values` as consecutive little-endian `f32` words.
        fn to_bytes(values: &[f32]) -> Vec<u8> {
            let mut out = Vec::with_capacity(values.len() * std::mem::size_of::<f32>());
            for value in values {
                out.extend_from_slice(&value.to_le_bytes());
            }
            out
        }

        let mut named: Vec<(String, F32Tensor)> = Vec::with_capacity(self.epochs.len() * 2 + 1);
        for (i, ep) in self.epochs.iter().enumerate() {
            // A buffer whose length disagrees with its declared shape would be written
            // with inconsistent metadata, so reject it before serializing anything.
            let expected = ep
                .shape
                .iter()
                .try_fold(1usize, |acc, &dim| acc.checked_mul(dim));
            anyhow::ensure!(
                expected == Some(ep.reconstructed.len()),
                "epoch {i}: reconstructed has {} values but shape is {:?}",
                ep.reconstructed.len(),
                ep.shape
            );
            anyhow::ensure!(
                ep.n_channels.checked_mul(3) == Some(ep.chan_pos.len()),
                "epoch {i}: chan_pos has {} values, expected 3 per channel for {} channels",
                ep.chan_pos.len(),
                ep.n_channels
            );

            named.push((
                format!("reconstructed_{i}"),
                F32Tensor { data: to_bytes(&ep.reconstructed), shape: ep.shape.clone() },
            ));
            named.push((
                format!("chan_pos_{i}"),
                F32Tensor { data: to_bytes(&ep.chan_pos), shape: vec![ep.n_channels, 3] },
            ));
        }

        // The epoch count is stored as a float tensor so the file only contains F32 data.
        let n = self.epochs.len() as f32;
        named.push(("n_samples".into(), F32Tensor { data: to_bytes(&[n]), shape: vec![1] }));

        let bytes = safetensors::serialize(named, None)?;
        std::fs::write(path, bytes).with_context(|| format!("writing {path}"))?;
        Ok(())
    }
}
/// Loaded model plus everything needed to run encoding / decoding on backend `B`.
pub struct ZunaInference<B: Backend> {
/// Encoder/decoder network with its weights loaded.
model: EncoderDecoder<B>,
/// Rotary embedding built from the model configuration; passed to every forward call.
rope: RotaryEmbedding<B>,
/// Model hyper-parameters parsed from the `"model"` entry of the config file.
pub model_cfg: ModelConfig,
/// Data / preprocessing configuration (`DataConfig::default()` after `load`).
pub data_cfg: DataConfig,
/// Device on which tensors are created.
device: B::Device,
}
impl<B: Backend> ZunaInference<B> {
/// Builds an inference engine from a JSON config file and a weights file.
///
/// The model configuration is taken from the `"model"` entry of the JSON document
/// at `config_path`; the data configuration is `DataConfig::default()`.
///
/// Returns the engine together with the wall-clock time, in milliseconds, spent
/// loading the weights (config parsing and rotary-embedding setup are not timed).
///
/// # Errors
/// Fails if the config cannot be read or parsed, if `weights_path` is not valid
/// UTF-8, or if `load_model` fails.
pub fn load(
    config_path: &Path,
    weights_path: &Path,
    device: B::Device,
) -> anyhow::Result<(Self, f64)> {
    let raw_cfg = std::fs::read_to_string(config_path)
        .with_context(|| format!("config: {}", config_path.display()))?;
    let root: serde_json::Value = serde_json::from_str(&raw_cfg)?;
    let model_section = root["model"].clone();
    let model_cfg: ModelConfig =
        serde_json::from_value(model_section).context("parsing model config")?;

    let rope = RotaryEmbedding::<B>::new(
        model_cfg.head_dim,
        model_cfg.rope_dim,
        model_cfg.max_seqlen,
        model_cfg.rope_theta,
        &device,
    );

    // Only the weight loading is timed.
    let started = Instant::now();
    let weights_file = weights_path.to_str().context("weights path not valid UTF-8")?;
    let model = load_model::<B>(&model_cfg, weights_file, &device)?;
    let ms_weights = started.elapsed().as_secs_f64() * 1000.0;

    let engine = Self {
        model,
        rope,
        model_cfg,
        data_cfg: DataConfig::default(),
        device,
    };
    Ok((engine, ms_weights))
}
/// Returns a one-line, human-readable summary of the model hyper-parameters.
pub fn describe(&self) -> String {
    let m = &self.model_cfg;
    let ffn_hidden = m.ffn_hidden_dim();
    format!(
        "ZUNA dim={} layers={} head_dim={} ffn_hidden={} \
rope_dim={} max_seqlen={}",
        m.dim, m.n_layers, m.head_dim, ffn_hidden, m.rope_dim, m.max_seqlen,
    )
}
/// Loads a FIF recording, runs the full model on every epoch and returns the
/// reconstructions together with the file metadata and timings.
///
/// `steps` and `cfg` are forwarded to the sampler; `data_norm` is passed to the
/// loader and is also used to scale the reconstruction.
///
/// # Errors
/// Fails if the file cannot be loaded or if any epoch fails to run.
pub fn run_fif(
    &self,
    fif_path: &Path,
    steps: usize,
    cfg: f32,
    data_norm: f32,
) -> anyhow::Result<InferenceResult> {
    let millis_since = |start: Instant| start.elapsed().as_secs_f64() * 1000.0;

    let preproc_clock = Instant::now();
    let loaded = load_from_fif::<B>(fif_path, &self.data_cfg, data_norm, &self.device);
    let (batches, fif_info) =
        loaded.with_context(|| format!("exg on {}", fif_path.display()))?;
    let ms_preproc = millis_since(preproc_clock);

    let infer_clock = Instant::now();
    let epochs = self.run_batches(batches, steps, cfg, data_norm)?;
    let ms_infer = millis_since(infer_clock);

    Ok(InferenceResult {
        epochs,
        fif_info: Some(fif_info),
        ms_preproc,
        ms_infer,
    })
}
/// Runs the full model on epochs stored in a safetensors batch file.
///
/// Behaves like the FIF entry point except that no file metadata is available,
/// so `fif_info` is `None` in the result.
///
/// # Errors
/// Fails if `batch_path` is not valid UTF-8, if the batch cannot be loaded, or if
/// any epoch fails to run.
pub fn run_safetensors_batch(
    &self,
    batch_path: &Path,
    steps: usize,
    cfg: f32,
    data_norm: f32,
) -> anyhow::Result<InferenceResult> {
    let load_started = Instant::now();
    let path_str = batch_path.to_str().context("batch path not valid UTF-8")?;
    let batches = load_batch::<B>(path_str, &self.data_cfg, &self.device)?;
    let ms_preproc = load_started.elapsed().as_secs_f64() * 1000.0;

    let infer_started = Instant::now();
    let epochs = self.run_batches(batches, steps, cfg, data_norm)?;
    let ms_infer = infer_started.elapsed().as_secs_f64() * 1000.0;

    let result = InferenceResult { epochs, fif_info: None, ms_preproc, ms_infer };
    Ok(result)
}
/// Loads a FIF recording and runs only the encoder on each epoch.
///
/// Returns the per-epoch embeddings together with the file metadata and the
/// preprocessing / encoding times in milliseconds.
///
/// # Errors
/// Fails if the file cannot be loaded or if encoding any epoch fails.
pub fn encode_fif(
    &self,
    fif_path: &Path,
    data_norm: f32,
) -> anyhow::Result<EncodingResult> {
    let millis_since = |start: Instant| start.elapsed().as_secs_f64() * 1000.0;

    let preproc_clock = Instant::now();
    let loaded = load_from_fif::<B>(fif_path, &self.data_cfg, data_norm, &self.device);
    let (batches, fif_info) =
        loaded.with_context(|| format!("exg on {}", fif_path.display()))?;
    let ms_preproc = millis_since(preproc_clock);

    let encode_clock = Instant::now();
    let epochs = self.encode_inputs(batches)?;
    let ms_encode = millis_since(encode_clock);

    Ok(EncodingResult {
        epochs,
        fif_info: Some(fif_info),
        ms_preproc,
        ms_encode,
    })
}
/// Runs only the encoder on epochs stored in a safetensors batch file.
///
/// No file metadata is available for this input, so `fif_info` is `None`.
///
/// # Errors
/// Fails if `batch_path` is not valid UTF-8, if the batch cannot be loaded, or if
/// encoding any epoch fails.
pub fn encode_batch(&self, batch_path: &Path) -> anyhow::Result<EncodingResult> {
    let load_started = Instant::now();
    let path_str = batch_path.to_str().context("batch path not valid UTF-8")?;
    let batches = load_batch::<B>(path_str, &self.data_cfg, &self.device)?;
    let ms_preproc = load_started.elapsed().as_secs_f64() * 1000.0;

    let encode_started = Instant::now();
    let epochs = self.encode_inputs(batches)?;
    let ms_encode = encode_started.elapsed().as_secs_f64() * 1000.0;

    let result = EncodingResult { epochs, fif_info: None, ms_preproc, ms_encode };
    Ok(result)
}
/// Encodes several FIF recordings, preprocessing the files in parallel.
///
/// Preprocessing runs on the rayon thread pool; the resulting epochs of all files
/// are then encoded together and split back into one `EncodingResult` per input
/// file, in the order of `fif_paths`.
///
/// Both timings cover the whole call, not a single file: every returned result
/// carries the same `ms_preproc` and `ms_encode`.
///
/// # Errors
/// Returns the first preprocessing error in input order (annotated with the file
/// path), or any error raised while encoding.
pub fn encode_fif_parallel(
    &self,
    fif_paths: &[impl AsRef<Path> + Sync],
    data_norm: f32,
) -> anyhow::Result<Vec<EncodingResult>> {
    use crate::data::{preprocess_fif_cpu, preprocessed_to_batch, PreprocessedFif};
    use rayon::prelude::*;

    let data_cfg = self.data_cfg.clone();

    let preproc_clock = Instant::now();
    let preprocessed: Vec<anyhow::Result<PreprocessedFif>> = fif_paths
        .par_iter()
        .map(|path| preprocess_fif_cpu(path.as_ref(), &data_cfg, data_norm))
        .collect();
    let ms_preproc = preproc_clock.elapsed().as_secs_f64() * 1000.0;

    // Flatten all epochs into one list, remembering how many belong to each file.
    let mut per_file: Vec<(usize, FifInfo)> = Vec::with_capacity(fif_paths.len());
    let mut all_batches: Vec<InputBatch<B>> = Vec::new();
    for (path, outcome) in fif_paths.iter().zip(preprocessed) {
        let pfif = outcome.with_context(|| {
            format!("preprocessing file {}", path.as_ref().display())
        })?;
        per_file.push((pfif.epochs.len(), pfif.info));
        all_batches.extend(
            pfif.epochs
                .into_iter()
                .map(|ep| preprocessed_to_batch(ep, &self.device)),
        );
    }

    let encode_clock = Instant::now();
    let all_embeddings = self.encode_inputs(all_batches)?;
    let ms_encode = encode_clock.elapsed().as_secs_f64() * 1000.0;

    // Hand the embeddings back out file by file, in the original order.
    let mut remaining = all_embeddings.into_iter();
    let mut results = Vec::with_capacity(per_file.len());
    for (count, info) in per_file {
        let epochs: Vec<EpochEmbedding> = remaining.by_ref().take(count).collect();
        results.push(EncodingResult {
            epochs,
            fif_info: Some(info),
            ms_preproc,
            ms_encode,
        });
    }
    Ok(results)
}
/// Encodes a list of epochs, using a single batched forward pass when that is safe.
///
/// The batched path feeds the encoder ONE token-index tensor (the first epoch's)
/// for the whole stack, so it is only taken when every epoch has
/// * the same encoder-input dimensions after the leading one, and
/// * exactly the same token indices as the first epoch.
///
/// Otherwise (including zero or one epoch, or token indices that cannot be read
/// back as integers) each epoch is encoded on its own.
///
/// Checking the token indices copies them to host memory once per epoch.
///
/// # Errors
/// Propagates any error from encoding an epoch.
fn encode_inputs(
    &self,
    batches: Vec<InputBatch<B>>,
) -> anyhow::Result<Vec<EpochEmbedding>> {
    // Host copy of an epoch's token indices, widened to i64; `None` if unreadable.
    let host_tok_idx = |batch: &InputBatch<B>| -> Option<Vec<i64>> {
        let data = batch.tok_idx.clone().into_data();
        data.to_vec::<i64>()
            .or_else(|_| {
                data.to_vec::<i32>()
                    .map(|v| v.into_iter().map(i64::from).collect::<Vec<i64>>())
            })
            .ok()
    };

    let batchable = match batches.split_first() {
        Some((first, rest)) if !rest.is_empty() => {
            let first_dims = first.encoder_input.dims();
            let same_shape = rest
                .iter()
                .all(|b| b.encoder_input.dims()[1..] == first_dims[1..]);
            // Equal sequence length alone is not enough: epochs with different
            // token indices would be encoded using the first epoch's indices.
            same_shape
                && match host_tok_idx(first) {
                    Some(reference) => rest
                        .iter()
                        .all(|b| host_tok_idx(b).as_ref() == Some(&reference)),
                    None => false,
                }
        }
        _ => false,
    };

    if batchable {
        self.encode_batched(batches)
    } else {
        batches.into_iter().map(|b| self.encode_one(b)).collect()
    }
}
/// Runs the encoder on a single epoch and copies the results to host memory.
///
/// The returned embedding keeps the epoch's token indices (widened to `i64`),
/// channel positions, channel count and `tc` alongside the flattened encoder
/// output and its `[sequence, feature]` shape.
///
/// # Errors
/// Fails if any of the tensors cannot be converted to a host vector.
fn encode_one(&self, batch: InputBatch<B>) -> anyhow::Result<EpochEmbedding> {
    let n_channels = batch.n_channels;
    let tc = batch.tc;
    let positions = batch.chan_pos;
    // The encoder consumes the token indices, so keep a handle for the output.
    let indices = batch.tok_idx.clone();

    let encoded = self
        .model
        .encoder
        .forward(batch.encoder_input, batch.tok_idx, &self.rope);
    let [_, seq_len, feat_dim] = encoded.dims();

    let embeddings = tensor_data_to_f32(encoded.squeeze::<2>().into_data())
        .map_err(|e| anyhow::anyhow!("embedding→vec: {e}"))?;

    // Token indices may be stored as i64 or i32 depending on the backend.
    let raw_indices = indices.into_data();
    let tok_idx: Vec<i64> = match raw_indices.to_vec::<i64>() {
        Ok(values) => values,
        Err(_) => raw_indices
            .to_vec::<i32>()
            .map(|v| v.into_iter().map(|x| x as i64).collect::<Vec<i64>>())
            .map_err(|e| anyhow::anyhow!("tok_idx→vec: {e:?}"))?,
    };

    let chan_pos = tensor_data_to_f32(positions.into_data())
        .map_err(|e| anyhow::anyhow!("chan_pos→vec: {e}"))?;

    Ok(EpochEmbedding {
        embeddings,
        shape: vec![seq_len, feat_dim],
        tok_idx,
        chan_pos,
        n_channels,
        tc,
    })
}
fn encode_batched(&self, batches: Vec<InputBatch<B>>) -> anyhow::Result<Vec<EpochEmbedding>> {
let n = batches.len();
let metadata: Vec<_> = batches.iter().map(|b| {
(b.n_channels, b.tc, b.tok_idx.clone(), b.chan_pos.clone())
}).collect();
let inputs: Vec<Tensor<B, 3>> = batches.into_iter().map(|b| b.encoder_input).collect();
let stacked = Tensor::cat(inputs, 0); let tok_idx = metadata[0].2.clone();
let enc_out = self.model.encoder.forward(stacked, tok_idx, &self.rope);
(0..n).map(|i| {
let enc = enc_out.clone().narrow(0, i, 1);
let [_, s, od] = enc.dims();
let (n_channels, tc, ref tok_idx_saved, ref chan_pos_saved) = metadata[i];
let embeddings = tensor_data_to_f32(enc.squeeze::<2>().into_data())
.map_err(|e| anyhow::anyhow!("embedding→vec: {e}"))?;
let tok_idx_data = tok_idx_saved.clone().into_data();
let tok_idx: Vec<i64> = tok_idx_data.to_vec::<i64>()
.or_else(|_| tok_idx_data.to_vec::<i32>()
.map(|v| v.into_iter().map(|x| x as i64).collect()))
.map_err(|e| anyhow::anyhow!("tok_idx→vec: {e:?}"))?;
let chan_pos = tensor_data_to_f32(chan_pos_saved.clone().into_data())
.map_err(|e| anyhow::anyhow!("chan_pos→vec: {e}"))?;
Ok(EpochEmbedding { embeddings, shape: vec![s, od], tok_idx, chan_pos, n_channels, tc })
}).collect()
}
}
/// Converts tensor data of any supported element type into a `Vec<f32>`.
///
/// Three strategies are tried in order:
/// 1. read the buffer directly as `f32`;
/// 2. convert a copy of the data to `f32` and read that;
/// 3. if the byte length is even, reinterpret the raw bytes as little-endian `f16`
///    values and widen each to `f32`.
///
/// Returns an error message only when all three fail (odd byte length).
///
/// NOTE(review): step 3 does not look at the data's dtype, so any even-length
/// buffer that reaches it is decoded as `f16` — confirm this is only reachable for
/// half-precision data.
fn tensor_data_to_f32(data: burn::tensor::TensorData) -> Result<Vec<f32>, String> {
    data.to_vec::<f32>()
        .or_else(|_| data.clone().convert::<f32>().to_vec::<f32>())
        .or_else(|_| {
            let raw = &data.bytes;
            if raw.len() % 2 != 0 {
                return Err(format!(
                    "cannot convert tensor data ({} bytes) to f32",
                    raw.len()
                ));
            }
            Ok(raw
                .chunks_exact(2)
                .map(|pair| half::f16::from_le_bytes([pair[0], pair[1]]).to_f32())
                .collect())
        })
}
impl<B: Backend> ZunaInference<B> {
/// Decodes every epoch of a previously computed `EncodingResult`.
///
/// Epochs are decoded one after another with `decode_epoch`; decoding stops at
/// the first failure. The result has no file metadata (`fif_info` is `None`) and
/// `ms_preproc` is `0.0`; `ms_infer` is the total decoding time in milliseconds.
///
/// # Errors
/// Propagates the first error returned by `decode_epoch`.
pub fn decode_embeddings(
    &self,
    embeddings: &EncodingResult,
    steps: usize,
    cfg: f32,
    data_norm: f32,
) -> anyhow::Result<InferenceResult> {
    let started = Instant::now();

    let mut epochs = Vec::with_capacity(embeddings.epochs.len());
    for ep in &embeddings.epochs {
        epochs.push(self.decode_epoch(ep, steps, cfg, data_norm)?);
    }

    let ms_infer = started.elapsed().as_secs_f64() * 1000.0;
    Ok(InferenceResult {
        epochs,
        fif_info: None,
        ms_preproc: 0.0,
        ms_infer,
    })
}
/// Decodes one stored embedding back into a reconstructed epoch.
///
/// Starting from Gaussian noise (standard deviation `model.global_sigma`), the
/// decoder output `v` is applied `steps` times as `z ← z − v / steps`, with the
/// time value going from `1` down to `1 / steps`. The final `z` is reshaped with
/// `invert_reshape` and multiplied by `data_norm`.
///
/// # Parameters
/// * `ep` – embedding, token indices and layout of the epoch to decode.
/// * `steps` – number of update steps; must be at least 1.
/// * `cfg` – guidance scale. When it differs from `1.0` by more than `1e-4`, each
///   step also runs the decoder with an all-zero conditioning tensor and combines
///   both outputs as `v_u + cfg · (v − v_u)`.
/// * `data_norm` – factor applied to the reconstruction.
///
/// # Errors
/// Fails if `steps` is 0, if the embedding buffers are inconsistent with their
/// declared shapes, or if the reconstruction cannot be copied to host memory.
pub fn decode_epoch(
    &self,
    ep: &EpochEmbedding,
    steps: usize,
    cfg: f32,
    data_norm: f32,
) -> anyhow::Result<EpochOutput> {
    use burn::tensor::Distribution;

    // With zero steps the loop below never runs and pure noise would be returned.
    anyhow::ensure!(steps > 0, "decode_epoch: steps must be at least 1");

    // Validate the host buffers up front so malformed embeddings produce an error
    // instead of reaching tensor construction with mismatched sizes.
    let n_tokens = ep.n_tokens();
    anyhow::ensure!(
        ep.shape.len() == 2,
        "decode_epoch: embedding shape {:?} is not 2-dimensional",
        ep.shape
    );
    let expected_len = ep
        .shape
        .iter()
        .try_fold(1usize, |acc, &dim| acc.checked_mul(dim));
    anyhow::ensure!(
        expected_len == Some(ep.embeddings.len()),
        "decode_epoch: {} embedding values do not match shape {:?}",
        ep.embeddings.len(),
        ep.shape
    );
    anyhow::ensure!(
        n_tokens.checked_mul(4) == Some(ep.tok_idx.len()),
        "decode_epoch: {} token-index values do not match {} tokens x 4",
        ep.tok_idx.len(),
        n_tokens
    );

    let dc = &self.data_cfg;
    let enc_out = Tensor::<B, 2>::from_data(
        TensorData::new(ep.embeddings.clone(), ep.shape.clone()),
        &self.device,
    )
    .unsqueeze_dim::<3>(0);
    let tok_idx = Tensor::<B, 2, Int>::from_data(
        TensorData::new(ep.tok_idx.clone(), vec![n_tokens, 4]),
        &self.device,
    );

    let [b, s, d] = enc_out.dims();
    let dt = 1.0_f32 / steps as f32;
    let mut z = Tensor::<B, 3>::random(
        [b, s, d],
        Distribution::Normal(0.0, self.model.global_sigma as f64),
        &self.device,
    );

    // The all-zero conditioning tensor does not change between steps, so build it
    // once (and only when guidance is actually enabled).
    let use_guidance = (cfg - 1.0).abs() > 1e-4;
    let enc_zeros = use_guidance.then(|| Tensor::<B, 3>::zeros([b, s, d], &self.device));

    for i in (1..=steps).rev() {
        let t_val = dt * i as f32;
        let time_t = Tensor::<B, 3>::full([b, 1, 1], t_val, &self.device);
        let vc = self.model.decoder.forward(
            z.clone(), enc_out.clone(), time_t.clone(), tok_idx.clone(), &self.rope,
        );
        let vc = match &enc_zeros {
            Some(zeros) => {
                let vc_u = self.model.decoder.forward(
                    z.clone(), zeros.clone(), time_t, tok_idx.clone(), &self.rope,
                );
                vc_u.clone() + (vc - vc_u).mul_scalar(cfg)
            }
            None => vc,
        };
        z = z - vc.mul_scalar(dt);
    }

    let [_, s, tf] = z.dims();
    let recon = invert_reshape(
        z.reshape([s, tf]), ep.n_channels, ep.tc, dc.num_fine_time_pts,
    );
    let recon = recon.mul_scalar(data_norm);
    let shape = recon.dims().to_vec();
    let reconstructed = recon.into_data().convert::<f32>().to_vec::<f32>()
        .map_err(|e| anyhow::anyhow!("recon→vec: {e:?}"))?;

    Ok(EpochOutput {
        reconstructed,
        shape,
        chan_pos: ep.chan_pos.clone(),
        n_channels: ep.n_channels,
    })
}
/// Runs the full model on each epoch and converts the outputs to host buffers.
///
/// For every epoch the sampled tensor is flattened to two dimensions, passed
/// through `invert_reshape`, multiplied by `data_norm` and copied out as `f32`
/// together with the epoch's channel positions. Processing stops at the first
/// failure.
///
/// # Errors
/// Fails if a reconstruction or channel-position tensor cannot be converted to a
/// host vector.
fn run_batches(
    &self,
    batches: Vec<InputBatch<B>>,
    steps: usize,
    cfg: f32,
    data_norm: f32,
) -> anyhow::Result<Vec<EpochOutput>> {
    let fine_pts = self.data_cfg.num_fine_time_pts;

    let mut outputs = Vec::with_capacity(batches.len());
    for batch in batches {
        let n_channels = batch.n_channels;

        let sampled = self.model.sample(
            batch.encoder_input,
            batch.tok_idx,
            &self.rope,
            steps,
            cfg,
        );
        let [_, rows, cols] = sampled.dims();
        let flat = sampled.reshape([rows, cols]);

        let recon = invert_reshape(flat, n_channels, batch.tc, fine_pts).mul_scalar(data_norm);
        let shape = recon.dims().to_vec();

        let reconstructed = recon
            .into_data()
            .convert::<f32>()
            .to_vec::<f32>()
            .map_err(|e| anyhow::anyhow!("tensor→vec: {e:?}"))?;
        let chan_pos = batch
            .chan_pos
            .into_data()
            .convert::<f32>()
            .to_vec::<f32>()
            .map_err(|e| anyhow::anyhow!("chan_pos→vec: {e:?}"))?;

        outputs.push(EpochOutput { reconstructed, shape, chan_pos, n_channels });
    }
    Ok(outputs)
}
}