use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use rullama::reference::kokoro::ops::max_abs_diff;
use rullama::reference::styletts2::decoder::{StyleTtsDecoder, source_signal};
/// Pearson correlation coefficient between two sample sequences.
///
/// Only the overlapping prefix (`min(a.len(), b.len())` samples) is compared, so the
/// means and the covariance/variance sums are always taken over the same samples.
/// Accumulation is done in `f64`: the result is reported to six decimals and compared
/// against a tight threshold, which `f32` running sums over long buffers cannot
/// reliably support.
///
/// Returns `0.0` when there is nothing to compare (either slice empty) or when one of
/// the inputs has zero variance (the small epsilon in the denominator prevents `0/0`).
fn corr(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len().min(b.len());
    if n == 0 {
        return 0.0;
    }
    let (a, b) = (&a[..n], &b[..n]);

    let mean = |v: &[f32]| v.iter().map(|&x| f64::from(x)).sum::<f64>() / n as f64;
    let (mean_a, mean_b) = (mean(a), mean(b));

    let (mut cov, mut var_a, mut var_b) = (0.0f64, 0.0f64, 0.0f64);
    for (&x, &y) in a.iter().zip(b) {
        let dx = f64::from(x) - mean_a;
        let dy = f64::from(y) - mean_b;
        cov += dx * dy;
        var_a += dx * dx;
        var_b += dy * dy;
    }

    // Epsilon keeps the quotient finite (and equal to 0) for constant inputs.
    (cov / (var_a.sqrt() * var_b.sqrt() + 1e-12)) as f32
}
/// Loads every `*.bin` file in `dir` as a flat array of little-endian `f32` values,
/// keyed by the file stem (file name without the `.bin` extension).
///
/// Bytes left over after the last complete 4-byte group are ignored, with a warning
/// on stderr so a truncated or wrongly-typed dump does not go unnoticed.
///
/// # Panics
/// Panics if the directory or one of its `.bin` files cannot be read, or if a `.bin`
/// file's stem is not valid UTF-8.
fn load_fixtures(dir: &std::path::Path) -> HashMap<String, Vec<f32>> {
    let mut tensors = HashMap::new();
    let entries = fs::read_dir(dir).unwrap_or_else(|e| panic!("cannot list {dir:?}: {e}"));
    for entry in entries {
        let path = entry
            .unwrap_or_else(|e| panic!("cannot read directory entry in {dir:?}: {e}"))
            .path();
        if path.extension().and_then(|ext| ext.to_str()) != Some("bin") {
            continue;
        }
        let bytes = fs::read(&path).unwrap_or_else(|e| panic!("cannot read {path:?}: {e}"));
        let words = bytes.chunks_exact(4);
        let leftover = words.remainder().len();
        if leftover != 0 {
            eprintln!("warning: {path:?} has {leftover} trailing byte(s) that do not form an f32; ignoring them");
        }
        let values: Vec<f32> = words
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect();
        let name = path
            .file_stem()
            .and_then(|stem| stem.to_str())
            .unwrap_or_else(|| panic!("fixture file name is not valid UTF-8: {path:?}"))
            .to_owned();
        tensors.insert(name, values);
    }
    tensors
}

/// Looks up a loaded fixture tensor by name.
///
/// # Panics
/// Panics with the missing tensor's name if it was not loaded.
fn fixture<'a>(tensors: &'a HashMap<String, Vec<f32>>, name: &str) -> &'a Vec<f32> {
    tensors
        .get(name)
        .unwrap_or_else(|| panic!("fixture tensor `{name}` not found"))
}

/// Parity check of the StyleTTS2 hifigan decoder against dumped reference fixtures.
///
/// Runs three stages, each of which panics on failure:
/// 1. the harmonic source signal from `source_signal` vs. the `har_source` fixture;
/// 2. the CPU decoder (`StyleTtsDecoder::forward`) vs. the `audio` fixture, followed by
///    the GPU decoder (`StyleTtsGpu::decode`) vs. the same fixture;
/// 3. a GPU decode of a synthetic 300-frame input, checked only for output length
///    and finiteness.
///
/// Fixtures are read from `$HOME/.cache/styletts2/fixtures/decoder/bin`.
fn main() {
    use rullama::backend::{Pipelines, WgpuCtx};
    use rullama::reference::styletts2::gpu::StyleTtsGpu;

    let home = std::env::var("HOME").expect("HOME must be set to locate the fixture cache");
    let dir = PathBuf::from(home).join(".cache/styletts2/fixtures/decoder/bin");
    assert!(
        dir.is_dir(),
        "run scripts/styletts2_dump_decoder_fixtures.py first ({dir:?})"
    );
    let w = load_fixtures(&dir);

    // --- Stage 1: harmonic source signal ---------------------------------------------
    let f0 = fixture(&w, "in_F0_curve");
    let lw = fixture(&w, "generator.m_source.l_linear.weight");
    let lb = fixture(&w, "generator.m_source.l_linear.bias")
        .first()
        .copied()
        .expect("generator.m_source.l_linear.bias is empty");
    // NOTE(review): 300 and 9 are passed straight to `source_signal`; their meaning is
    // defined there (presumably upsample factor and harmonic count — confirm).
    let har = source_signal(f0, 300, 9, lw, lb);
    let har_ref = fixture(&w, "har_source");
    let d = max_abs_diff(&har, har_ref);
    // Diagnostics are printed before asserting so they are visible on failure too.
    println!("har_source[{}] max_abs_diff = {d:.3e}", har.len());
    assert!(
        har.len() == har_ref.len(),
        "source length {} != ref {}",
        har.len(),
        har_ref.len()
    );
    assert!(d < 2e-3, "HnNSF source parity FAILED ({d:.3e})");
    println!("✅ hifigan HnNSF source matches PyTorch");

    // --- Stage 2a: CPU decoder end-to-end ----------------------------------------------
    // The fixtures are only read, so they are borrowed from `w` rather than cloned.
    let asr = fixture(&w, "in_asr");
    let nc = fixture(&w, "in_N");
    let style = fixture(&w, "in_style");
    let audio_ref = fixture(&w, "audio");
    let dec = StyleTtsDecoder::new(&w);
    // NOTE(review): 512 and 40 presumably describe the `in_asr` layout (values per
    // frame, frame count), matching the synthetic input built below — confirm.
    let audio = dec.forward(asr, 512, 40, f0, nc, style, None);
    let da = max_abs_diff(&audio, audio_ref);
    let c = corr(&audio, audio_ref);
    println!(
        "\naudio[{}] max_abs_diff = {da:.3e} corr = {c:.6}",
        audio.len()
    );
    assert!(
        audio.len() == audio_ref.len(),
        "audio len {} != {}",
        audio.len(),
        audio_ref.len()
    );
    assert!(
        c > 0.999 && da < 5e-3,
        "hifigan decoder parity FAILED (corr {c:.6}, max_abs {da:.3e})"
    );
    println!("✅ StyleTTS2 hifigan decoder (CPU) matches PyTorch end-to-end");

    // --- Stage 2b: GPU decoder vs. the same reference ----------------------------------
    let ctx = pollster::block_on(WgpuCtx::new()).expect("wgpu");
    let pipes = Pipelines::new(&ctx.device);
    let mut wc = HashMap::new();
    // No 16-bit tensors are supplied; an empty map is passed.
    let w16: HashMap<String, Vec<u16>> = HashMap::new();
    let gpu_audio = pollster::block_on(
        StyleTtsGpu::new(&w, &w16, &ctx, &pipes, &mut wc).decode(asr, 40, f0, nc, style),
    );
    let dg = max_abs_diff(&gpu_audio, audio_ref);
    let cg = corr(&gpu_audio, audio_ref);
    println!(
        "\nGPU decoder vs ref max_abs_diff = {dg:.3e} corr = {cg:.6} (len {} vs {})",
        gpu_audio.len(),
        audio_ref.len()
    );
    assert!(
        gpu_audio.len() == audio_ref.len(),
        "GPU audio len {} != {}",
        gpu_audio.len(),
        audio_ref.len()
    );
    // Only correlation is gated for the GPU path; max-abs difference is reported above.
    assert!(
        cg > 0.999,
        "GPU hifigan decoder parity FAILED (corr {cg:.6})"
    );
    println!("✅ StyleTTS2 hifigan decoder GPU matches CPU/PyTorch");

    // --- Stage 3: long synthetic utterance on the GPU ----------------------------------
    // Deterministic synthetic inputs: 512 values per frame for `asr`, two values per
    // frame for F0 and N. There is no reference output, so only shape and finiteness
    // are checked.
    let bigf = 300usize;
    let asr_big: Vec<f32> = (0..512 * bigf)
        .map(|i| ((i % 13) as f32 - 6.0) * 0.2)
        .collect();
    let f0_big: Vec<f32> = (0..2 * bigf).map(|i| 110.0 + (i % 40) as f32).collect();
    let n_big: Vec<f32> = (0..2 * bigf)
        .map(|i| ((i % 5) as f32 - 2.0) * 0.5)
        .collect();
    let mut wc2 = HashMap::new();
    let big = pollster::block_on(
        StyleTtsGpu::new(&w, &w16, &ctx, &pipes, &mut wc2)
            .decode(&asr_big, bigf, &f0_big, &n_big, style),
    );
    let finite = big.iter().all(|x| x.is_finite());
    println!(
        "2D-dispatch (f={bigf}): {} samples, all-finite = {finite}",
        big.len()
    );
    // Expected output length: 2 * 300 samples per input frame.
    assert!(
        big.len() == bigf * 2 * 300 && finite,
        "2D-dispatch path FAILED (len {}, finite {finite})",
        big.len()
    );
    println!("✅ 2D workgroup-grid path OK for long utterances");
}