use rlx_runtime::{AotCache, CompileOptions, DType};
use rlx_tiny_tts::model::import_graph;
use std::path::PathBuf;
/// Loads the dump file `/tmp/ttsdump/<name>.f32` and decodes it as a flat
/// sequence of little-endian `f32` values.
///
/// Trailing bytes that do not make up a whole 4-byte value are ignored.
///
/// # Panics
///
/// Panics when the file cannot be read. The message contains the full path
/// and the underlying I/O error so a missing dump is easy to identify.
fn rd(name: &str) -> Vec<f32> {
    let path = format!("/tmp/ttsdump/{name}.f32");
    let bytes = match std::fs::read(&path) {
        Ok(bytes) => bytes,
        // Positional arguments keep the message formatted on every edition.
        Err(err) => panic!("read {}: {}", path, err),
    };
    bytes
        .chunks_exact(4)
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect()
}
/// Serializes a slice of `f32` values into one contiguous byte buffer,
/// four little-endian bytes per value, in slice order.
fn f32_bytes(v: &[f32]) -> Vec<u8> {
    let mut encoded = Vec::with_capacity(v.len() * 4);
    for value in v {
        encoded.extend_from_slice(&value.to_le_bytes());
    }
    encoded
}
/// Decoder probe: compiles the ONNX decoder graph for one device, runs it on
/// the dumped `z` / `g` inputs and writes every `F32` output to
/// `/tmp/rlx_dec/out<i>.f32`.
///
/// Inputs:
/// * first CLI argument — device name handed to `rlx_runtime::parse_device`
///   (`"mlx"` when absent);
/// * `/tmp/ttsdump/z.f32` and `/tmp/ttsdump/g.f32` — raw little-endian `f32`
///   dumps;
/// * `RLX_TAP` (optional env var) — name of a graph node to expose as an
///   additional output.
///
/// # Errors
///
/// Returns an error when the device name is rejected, the graph cannot be
/// imported or compiled, or an output file cannot be written.
///
/// # Panics
///
/// Panics (inside `rd`) when one of the input dumps cannot be read.
fn main() -> anyhow::Result<()> {
    // Build the default lazily so nothing is allocated when an argument is given.
    let device_arg = std::env::args().nth(1).unwrap_or_else(|| "mlx".into());
    let device = rlx_runtime::parse_device(&device_arg).map_err(|e| anyhow::anyhow!("{e}"))?;

    let z = rd("z");
    let g = rd("g");
    // NOTE(review): assumes `z` holds 80 values per step — TODO confirm; any
    // remainder of the division is silently discarded.
    let length = z.len() / 80;
    eprintln!("[dec] length={length}");

    let path = PathBuf::from("/Users/Shared/rlx-models/weights/tiny-tts-rlx/onnx/decoder.onnx");
    // NOTE(review): the meaning of the trailing `true` flag is not visible here.
    let (mut hir, params, _r) = import_graph(&path, "decoder", length, true)?;

    if let Ok(tap) = std::env::var("RLX_TAP") {
        // The scan is not stopped early: when several nodes share the
        // requested name, the last one wins.
        let mut found = None;
        for (i, n) in hir.nodes().iter().enumerate() {
            if n.name.as_deref() == Some(tap.as_str()) {
                // The node id is built from the node's position in `hir.nodes()`.
                found = Some(rlx_ir::hir::HirNodeId(i as u32));
            }
        }
        match found {
            Some(id) => {
                eprintln!("[tap] {tap} shape {:?}", hir.node(id).shape.dims());
                hir.outputs.push(id);
            }
            None => eprintln!("[tap] NOT FOUND {tap}"),
        }
    }

    // Wipe the cache directory BEFORE constructing the cache, so the cache is
    // never left pointing at a directory that was removed underneath it.
    // NOTE(review): what `AotCache::new` does with the directory is not
    // visible here; wiping first is safe either way. The result is ignored
    // because the directory may legitimately not exist.
    let cache_dir = PathBuf::from("/tmp/dec_probe_aot");
    let _ = std::fs::remove_dir_all(&cache_dir);
    let cache = AotCache::new(cache_dir);

    let mut c = cache
        .compile_hir_cached(
            &format!("dec_{device:?}_{length}"),
            device,
            hir,
            &CompileOptions::default(),
        )
        .map_err(|e| anyhow::anyhow!("{e}"))?;

    // Hand every imported parameter to the compiled graph, then finalize.
    for (n, d) in &params {
        c.set_param(n, d);
    }
    c.finalize_params();

    let out = c.run_typed(&[
        ("z", &f32_bytes(&z), DType::F32),
        ("g", &f32_bytes(&g), DType::F32),
    ]);

    std::fs::create_dir_all("/tmp/rlx_dec")?;
    for (i, (b, dt)) in out.iter().enumerate() {
        // Only F32 outputs are dumped; others are skipped but still consume
        // an index, so file numbering follows the output position.
        if *dt != DType::F32 {
            continue;
        }
        let v: Vec<f32> = b
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect();
        std::fs::write(format!("/tmp/rlx_dec/out{i}.f32"), f32_bytes(&v))?;
        eprintln!(
            " out{i}: len={} absmean={:.4}",
            v.len(),
            // `max(1)` keeps the mean at 0 instead of NaN for an empty output.
            v.iter().map(|x| x.abs()).sum::<f32>() / v.len().max(1) as f32
        );
    }
    Ok(())
}