use std::path::PathBuf;
use std::time::Instant;
use burn::prelude::{Backend, ElementConversion};
use clap::{Parser, ValueEnum};
use zuna_rs::{EpochEmbedding, EncodingResult, ZunaEncoder};
// Compute backend selectable through the `--device` option.
//
// `main` dispatches on this value: `Cpu` -> `run_cpu`, `Gpu` -> `run_gpu`,
// `GpuF16` -> `run_gpu_f16`. The command-line spellings are generated by
// clap's `ValueEnum` derive (the `--device` default is written as "cpu";
// `GpuF16` is presumably exposed as "gpu-f16" — confirm against clap's
// renaming rules).
//
// NOTE: plain `//` comments are used on purpose; `///` doc comments on clap
// derive items may be picked up as help text and change the `--help` output.
#[derive(Debug, Clone, ValueEnum)]
enum Device { Cpu, Gpu, GpuF16 }
// Command-line options of the example binary.
//
// The struct-level description is supplied through `about = ...` below, so
// this header is a plain `//` comment to leave clap's generated text untouched.
// The `///` comments on the fields are intentional: clap turns them into the
// per-option help shown by `--help`.
#[derive(Parser, Debug)]
#[command(name = "embedding_api", about = "ZUNA EEG — minimal embedding API example")]
struct Args {
    /// Compute backend to run the encoder on
    #[arg(long, default_value = "cpu")]
    device: Device,

    /// Hugging Face repository to fetch weights and config from when --weights/--config are not given
    #[arg(long, default_value = "Zyphra/ZUNA")]
    repo: String,

    /// Local model weights file (must be passed together with --config)
    #[arg(long)]
    weights: Option<PathBuf>,

    /// Local model config file (must be passed together with --weights)
    #[arg(long)]
    config: Option<PathBuf>,

    /// Input FIF recording to encode
    #[arg(long, default_value = concat!(env!("CARGO_MANIFEST_DIR"), "/data/sample1_raw.fif"))]
    fif: PathBuf,

    /// Destination safetensors file for the embeddings
    #[arg(long, default_value = concat!(env!("CARGO_MANIFEST_DIR"), "/data/api_embeddings.safetensors"))]
    output: PathBuf,

    /// Thread count handed to the library's thread-pool initialisation
    #[arg(long, env = "RAYON_NUM_THREADS")]
    threads: Option<usize>,
}
fn main() -> anyhow::Result<()> {
let args = Args::parse();
let _n = zuna_rs::init_threads(args.threads);
match args.device {
Device::Cpu => run_cpu(args),
Device::Gpu => run_gpu(args),
Device::GpuF16 => run_gpu_f16(args),
}
}
/// Runs the example on the NdArray CPU backend (compiled only with the
/// `ndarray` feature).
#[cfg(feature = "ndarray")]
fn run_cpu(args: Args) -> anyhow::Result<()> {
    type CpuBackend = burn::backend::NdArray;
    let device = burn::backend::ndarray::NdArrayDevice::Cpu;
    run::<CpuBackend>(device, args)
}
/// Stand-in used when the `ndarray` feature is disabled: always fails with a
/// message telling the user how to rebuild.
#[cfg(not(feature = "ndarray"))]
fn run_cpu(_args: Args) -> anyhow::Result<()> {
    Err(anyhow::anyhow!(
        "CPU backend not compiled — rebuild with `--features ndarray`"
    ))
}
/// Runs the example on the default wgpu device using the `Wgpu` backend with
/// its default type parameters (compiled with `wgpu` or `wgpu-f16`).
#[cfg(any(feature = "wgpu", feature = "wgpu-f16"))]
fn run_gpu(args: Args) -> anyhow::Result<()> {
    type GpuBackend = burn::backend::Wgpu;
    let device = burn::backend::wgpu::WgpuDevice::DefaultDevice;
    run::<GpuBackend>(device, args)
}
/// Stand-in used when neither GPU feature is enabled: always fails with a
/// message telling the user how to rebuild.
#[cfg(not(any(feature = "wgpu", feature = "wgpu-f16")))]
fn run_gpu(_args: Args) -> anyhow::Result<()> {
    Err(anyhow::anyhow!(
        "GPU backend not compiled — rebuild with `--no-default-features --features wgpu`"
    ))
}
/// Runs the example on the default wgpu device with `half::f16` as the float
/// element type (`i32` / `u32` for the remaining backend type parameters).
#[cfg(any(feature = "wgpu-f16", feature = "wgpu"))]
fn run_gpu_f16(args: Args) -> anyhow::Result<()> {
    use burn::backend::wgpu::{Wgpu, WgpuDevice};
    run::<Wgpu<half::f16, i32, u32>>(WgpuDevice::DefaultDevice, args)
}
/// Stand-in used when neither GPU feature is enabled: always fails with a
/// message telling the user how to rebuild with f16 support.
#[cfg(not(any(feature = "wgpu-f16", feature = "wgpu")))]
fn run_gpu_f16(_args: Args) -> anyhow::Result<()> {
    Err(anyhow::anyhow!(
        "GPU f16 backend not compiled — rebuild with `--no-default-features --features wgpu-f16`"
    ))
}
/// Loads the encoder on `device` and demonstrates three ways of producing
/// embeddings from the FIF file given by `args.fif`:
///
/// * **A** – `encode_fif` in one call, then save the result to `args.output`;
/// * **B** – `preprocess_fif` followed by `encode_batches`;
/// * **C** – `encode_tensor` on the first preprocessed batch, with mean / std
///   computed on the raw tensor.
///
/// # Errors
///
/// Returns an error when weights cannot be resolved or loaded, when
/// preprocessing, encoding or saving fails, when the output directory cannot
/// be created, or when `args.output` is not valid UTF-8 (the save call is
/// given a `&str`, so such a path cannot be passed through faithfully).
fn run<B: Backend>(device: B::Device, args: Args) -> anyhow::Result<()> {
    let (weights, config) = resolve_weights(&args.repo, args.weights, args.config)?;
    println!("Weights : {}", weights.display());
    println!("Config : {}", config.display());

    let t = Instant::now();
    let (encoder, _) = ZunaEncoder::<B>::load(&config, &weights, device)?;
    println!("Loaded : {} ({:.0} ms)\n", encoder.describe(), t.elapsed().as_secs_f64() * 1000.0);

    // Constant handed to every preprocessing call below.
    // NOTE(review): its exact meaning is defined by zuna_rs — confirm there.
    let data_norm: f32 = 10.0;

    println!("── A: one-shot encode ──────────────────────────────────────────");
    {
        // Reject a non-UTF-8 output path before doing any work. Falling back
        // to some other file name would write the embeddings to a location the
        // user did not ask for while still reporting the requested path.
        let out_path = match args.output.to_str() {
            Some(p) => p,
            None => anyhow::bail!(
                "output path is not valid UTF-8: {}",
                args.output.display()
            ),
        };

        let t = Instant::now();
        let r: EncodingResult = encoder.encode_fif(&args.fif, data_norm)?;
        println!("encode_fif : {:.1} ms", t.elapsed().as_secs_f64() * 1000.0);
        print_result(&r);

        if let Some(p) = args.output.parent() {
            std::fs::create_dir_all(p)?;
        }
        r.save_safetensors(out_path)?;
        println!("Saved : {}\n", args.output.display());
        // `r` is dropped here; later sections do not use it.
    }

    println!("── B: two-step (preprocess → encode) ───────────────────────────");
    {
        let t = Instant::now();
        let (batches, fif_info) = encoder.preprocess_fif(&args.fif, data_norm)?;
        println!(
            "preprocess : {:.1} ms │ {} epochs │ {} ch │ {:.0}→{:.0} Hz",
            t.elapsed().as_secs_f64() * 1000.0,
            batches.len(),
            fif_info.ch_names.len(),
            fif_info.sfreq,
            fif_info.target_sfreq,
        );

        let t = Instant::now();
        let epochs: Vec<EpochEmbedding> = encoder.encode_batches(batches)?;
        println!("encode : {:.1} ms │ {} epochs", t.elapsed().as_secs_f64() * 1000.0, epochs.len());
        if let Some(ep) = epochs.first() {
            print_epoch(0, ep);
        }
        println!();
    }

    println!("── C: per-tensor (raw Burn tensor) ─────────────────────────────");
    {
        // Preprocess again: the batches from section B were consumed by
        // `encode_batches`.
        let (batches, _) = encoder.preprocess_fif(&args.fif, data_norm)?;
        if let Some(batch) = batches.into_iter().next() {
            let t = Instant::now();
            let tensor = encoder.encode_tensor(&batch);
            println!("encode_tensor : {:.1} ms │ shape {:?}", t.elapsed().as_secs_f64() * 1000.0, tensor.dims());

            // Collapse dims 0..=2 into one axis (assumes a rank-3 tensor —
            // TODO confirm) and compute the population mean / std on it.
            let flat = tensor.flatten::<1>(0, 2);
            let mean_t = flat.clone().mean();
            let mean = mean_t.clone().into_scalar().elem::<f32>();
            let diff = flat - mean_t;
            let std = (diff.clone() * diff).mean().into_scalar().elem::<f32>().sqrt();
            println!(" mean = {mean:+.4} std = {std:.4} (ideal ≈ 0.0 and ≈ 1.0 via MMD)");
        }
    }

    println!("\nDone.");
    Ok(())
}
/// Decides where the model files come from.
///
/// * both `weights` and `config` given – they are returned unchanged;
/// * neither given – the files are obtained through `hf_download(repo)`;
/// * only one given – an error asking for both or neither.
///
/// Returns `(weights_path, config_path)`.
fn resolve_weights(
    repo: &str,
    weights: Option<PathBuf>,
    config: Option<PathBuf>,
) -> anyhow::Result<(PathBuf, PathBuf)> {
    match (weights, config) {
        (Some(weights_path), Some(config_path)) => Ok((weights_path, config_path)),
        (None, None) => hf_download(repo),
        (Some(_), None) | (None, Some(_)) => Err(anyhow::anyhow!(
            "supply both --weights and --config together, or neither"
        )),
    }
}
/// Fetches the weights file and `config.json` of `repo` through the
/// synchronous `hf_hub` API (progress display enabled) and returns their
/// paths as `(weights, config)`.
#[cfg(feature = "hf-download")]
fn hf_download(repo: &str) -> anyhow::Result<(PathBuf, PathBuf)> {
    let api = hf_hub::api::sync::ApiBuilder::new()
        .with_progress(true)
        .build()?;
    let model_repo = api.model(repo.to_string());

    // Weights first, then the config, each propagated on failure.
    let weights_path = model_repo.get("model-00001-of-00001.safetensors")?;
    let config_path = model_repo.get("config.json")?;
    Ok((weights_path, config_path))
}
/// Stand-in used when the `hf-download` feature is disabled: always fails and
/// points the user at the feature flag or the explicit path options.
#[cfg(not(feature = "hf-download"))]
fn hf_download(_repo: &str) -> anyhow::Result<(PathBuf, PathBuf)> {
    Err(anyhow::anyhow!(
        "Add `--features hf-download` to fetch weights automatically, \
         or pass --weights and --config explicitly."
    ))
}
/// Prints the preprocessing / encoding timings stored in `r`, the number of
/// epochs, and the statistics of the first epoch when there is one.
fn print_result(r: &EncodingResult) {
    let epoch_count = r.epochs.len();
    println!(" preproc : {:.1} ms │ {} epochs", r.ms_preproc, epoch_count);
    println!(" encode : {:.1} ms", r.ms_encode);

    if let Some(first) = r.epochs.first() {
        print_epoch(0, first);
    }
}
/// Prints a one-line summary of epoch `idx`: token count, output dimension,
/// and the mean, population standard deviation, minimum and maximum of all
/// embedding values (mean / std accumulated in `f64`).
fn print_epoch(idx: usize, ep: &EpochEmbedding) {
    let values = &ep.embeddings;
    let count = values.len() as f64;

    let avg: f64 = values.iter().map(|&v| v as f64).sum::<f64>() / count;

    // Population variance: mean of squared deviations from `avg`.
    let variance: f64 = values
        .iter()
        .map(|&v| {
            let d = v as f64 - avg;
            d * d
        })
        .sum::<f64>()
        / count;
    let sigma = variance.sqrt();

    // Track both extremes in a single pass.
    let (lo, hi) = values
        .iter()
        .copied()
        .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), v| {
            (f32::min(lo, v), f32::max(hi, v))
        });

    println!(
        " epoch[{idx}]: {} tokens × {} dims mean={mean:+.4} std={std:.4} [{min:+.3}, {max:+.3}]",
        ep.n_tokens(),
        ep.output_dim(),
        mean = avg,
        std = sigma,
        min = lo,
        max = hi,
    );
}