#![allow(dead_code)]
use std::collections::HashMap;
use super::acoustic::DiffusionConfig;
use super::gpu::{GpuWeightCache, StyleTtsGpu};
use super::{MelFrontend, StyleEncoder, StyleTtsAcoustic};
use crate::backend::{Pipelines, WgpuCtx};
use crate::error::Result;
use crate::gguf::GgmlDtype;
use crate::gguf::GgufReader;
use crate::gguf::tensor::{
dequant_tensor_to_f16_async, dequant_tensor_to_f32, dequant_tensor_to_f32_async,
};
/// In-memory tensor store for the model, filled by one of the `load*`
/// constructors from the tensors of a `GgufReader`.
///
/// Keys in both maps are tensor names after they have been passed through
/// `remap`, not the names stored in the file.
pub struct StyleTtsModel {
/// Tensors dequantized to `f32`. Every loader fills this map; the CPU paths
/// (`encode_voice`, `synthesize`) read only from here.
w: HashMap<String, Vec<f32>>,
/// Tensors kept as 16-bit values. Only `load_streaming_f16` inserts into
/// this map; the other loaders leave it empty. It is handed to `StyleTtsGpu`
/// alongside `w`.
/// NOTE(review): presumably raw IEEE half-precision bit patterns — confirm
/// against `dequant_tensor_to_f16_async`.
w16: HashMap<String, Vec<u16>>,
}
fn remap(name: &str) -> String {
if let Some(rest) = name.strip_prefix("style_encoder.") {
return format!("acoustic.{}", remap_style(rest));
}
if let Some(rest) = name.strip_prefix("predictor_encoder.") {
return format!("prosodic.{}", remap_style(rest));
}
if let Some(rest) = name.strip_prefix("decoder.") {
return rest.to_string(); }
name.to_string() }
/// Rewrites the part of a style-encoder tensor name that follows the encoder
/// prefix.
///
/// * `unshared.<x>`  becomes `linear.<x>`
/// * `shared.0.<x>`  becomes `conv0.<x>`
/// * `shared.6.<x>`  becomes `conv_out.<x>`
/// * `shared.N.<x>` for `N` in `1..=4` becomes `blk{N-1}.<x'>`, where `<x'>` is
///   `<x>` with a leading `downsample_res.conv.` replaced by `down.` or a
///   leading `conv1x1.` replaced by `sc.`
/// * anything else (other indices, a non-numeric index, other prefixes) is
///   returned unchanged
///
/// When nothing follows the index (for example `shared.0`) the tail is empty,
/// so the result ends in a dot (`conv0.`).
fn remap_style(rest: &str) -> String {
    if let Some(field) = rest.strip_prefix("unshared.") {
        return format!("linear.{field}");
    }

    let Some(indexed) = rest.strip_prefix("shared.") else {
        return rest.to_string();
    };

    // Split "<index>.<tail>"; a missing dot means the whole remainder is the index.
    let (index, tail) = match indexed.split_once('.') {
        Some(parts) => parts,
        None => (indexed, ""),
    };

    match index.parse::<usize>() {
        Ok(0) => format!("conv0.{tail}"),
        Ok(6) => format!("conv_out.{tail}"),
        Ok(layer @ 1..=4) => {
            // Layers 1..=4 are renumbered from zero.
            let block = layer - 1;
            if let Some(param) = tail.strip_prefix("downsample_res.conv.") {
                format!("blk{block}.down.{param}")
            } else if let Some(param) = tail.strip_prefix("conv1x1.") {
                format!("blk{block}.sc.{param}")
            } else {
                format!("blk{block}.{tail}")
            }
        }
        // Unparseable or unmapped indices keep their original spelling.
        _ => rest.to_string(),
    }
}
impl StyleTtsModel {
    /// Builds a model by dequantizing every tensor of `reader` to `f32`.
    ///
    /// Each tensor is stored under `remap(name)`. `w16` is left empty.
    ///
    /// # Errors
    /// Propagates the first error returned by `dequant_tensor_to_f32`.
    pub fn load(reader: &GgufReader) -> Result<Self> {
        let w = reader
            .tensors()
            .iter()
            .map(|td| -> Result<(String, Vec<f32>)> {
                let values = dequant_tensor_to_f32(reader, &td.name)?;
                Ok((remap(&td.name), values))
            })
            .collect::<Result<HashMap<_, _>>>()?;

        Ok(Self {
            w,
            w16: HashMap::new(),
        })
    }

    /// Async counterpart of [`StyleTtsModel::load`]: tensors are dequantized
    /// one at a time through `dequant_tensor_to_f32_async`, awaiting each.
    ///
    /// Each tensor is stored under `remap(name)`. `w16` is left empty.
    ///
    /// # Errors
    /// Propagates the first error returned by `dequant_tensor_to_f32_async`.
    pub async fn load_streaming(reader: &GgufReader) -> Result<Self> {
        // Snapshot the names first, then walk the owned list.
        let tensor_names = reader
            .tensors()
            .iter()
            .map(|td| td.name.clone())
            .collect::<Vec<String>>();

        let mut weights = HashMap::new();
        for tensor_name in tensor_names {
            let values = dequant_tensor_to_f32_async(reader, &tensor_name).await?;
            weights.insert(remap(&tensor_name), values);
        }

        Ok(Self {
            w: weights,
            w16: HashMap::new(),
        })
    }

    /// Async loader that keeps some tensors in 16-bit form.
    ///
    /// A tensor goes into `w16` (via `dequant_tensor_to_f16_async`) when all of
    /// the following hold:
    /// * its dtype is `GgmlDtype::F16`,
    /// * it has 3 or 4 dimensions,
    /// * its remapped key starts with neither `text_encoder.` nor `predictor.`.
    ///
    /// Every other tensor is dequantized to `f32` and stored in `w`.
    ///
    /// # Errors
    /// Propagates the first error returned by either dequantization helper.
    pub async fn load_streaming_f16(reader: &GgufReader) -> Result<Self> {
        let descriptors = reader
            .tensors()
            .iter()
            .map(|td| (td.name.clone(), td.dtype, td.dims.len()))
            .collect::<Vec<(String, GgmlDtype, usize)>>();

        let mut full = HashMap::new();
        let mut half = HashMap::new();
        for (source_name, dtype, rank) in descriptors {
            let key = remap(&source_name);
            // NOTE(review): these two families are always kept as f32, presumably
            // because their convolutions run on the CPU — confirm with the GPU path.
            let forced_f32 = key.starts_with("text_encoder.") || key.starts_with("predictor.");
            let keep_half = dtype == GgmlDtype::F16 && matches!(rank, 3 | 4) && !forced_f32;

            if keep_half {
                let packed = dequant_tensor_to_f16_async(reader, &source_name).await?;
                half.insert(key, packed);
            } else {
                let values = dequant_tensor_to_f32_async(reader, &source_name).await?;
                full.insert(key, values);
            }
        }

        Ok(Self { w: full, w16: half })
    }

    /// Computes a voice vector from reference audio on the CPU.
    ///
    /// The spectrogram of `pcm24k` is run through the `"acoustic"` style
    /// encoder and then the `"prosodic"` one; the returned vector is the
    /// acoustic output followed by the prosodic output.
    ///
    /// `progress`, when present, is called with `(0.10, "computing spectrogram")`,
    /// `(0.35, "analyzing timbre")`, `(0.70, "analyzing prosody")` and
    /// `(1.0, "voice ready")`, in that order.
    ///
    /// # Panics
    /// Panics if `mel.window` or `mel.filterbank` is missing from the `f32`
    /// weight map.
    pub fn encode_voice(&self, pcm24k: &[f32], progress: Option<&dyn Fn(f32, &str)>) -> Vec<f32> {
        let notify = |fraction: f32, stage: &str| {
            if let Some(callback) = progress {
                callback(fraction, stage);
            }
        };

        notify(0.10, "computing spectrogram");
        let window = self.w.get("mel.window").expect("mel.window");
        let filterbank = self.w.get("mel.filterbank").expect("mel.filterbank");
        let frontend = MelFrontend::new(window, filterbank);
        let (mel, frames) = frontend.compute(pcm24k);

        // NOTE(review): 80 is presumably the number of mel bins — confirm
        // against `StyleEncoder::forward`.
        notify(0.35, "analyzing timbre");
        let acoustic = StyleEncoder::from_weights(&self.w, "acoustic").forward(&mel, 80, frames);

        notify(0.70, "analyzing prosody");
        let prosodic = StyleEncoder::from_weights(&self.w, "prosodic").forward(&mel, 80, frames);

        notify(1.0, "voice ready");
        acoustic.into_iter().chain(prosodic).collect()
    }

    /// Runs CPU synthesis by delegating to `StyleTtsAcoustic::synthesize`.
    ///
    /// All arguments are forwarded unchanged. After the delegate returns,
    /// `progress` (when present) is called once more with `(1.0, "done")`.
    pub fn synthesize(
        &self,
        ids: &[i64],
        voice: &[f32],
        diffuse: Option<DiffusionConfig>,
        progress: Option<&dyn Fn(f32, &str)>,
    ) -> Vec<f32> {
        let audio = StyleTtsAcoustic::new(&self.w).synthesize(ids, voice, diffuse, progress);
        if let Some(callback) = progress {
            callback(1.0, "done");
        }
        audio
    }

    /// Computes a voice vector from reference audio, running the encoder
    /// through `StyleTtsGpu::encode`.
    ///
    /// The spectrogram is still produced by `MelFrontend` from the `f32`
    /// weights; only the encoding step goes through `StyleTtsGpu`.
    ///
    /// `progress`, when present, is called with `(0.10, "computing spectrogram")`,
    /// `(0.30, "analyzing voice (GPU)")` and `(1.0, "voice ready")`, in that order.
    ///
    /// # Panics
    /// Panics if `mel.window` or `mel.filterbank` is missing from the `f32`
    /// weight map.
    pub async fn encode_voice_gpu(
        &self,
        ctx: &WgpuCtx,
        p: &Pipelines,
        wc: &mut GpuWeightCache,
        pcm24k: &[f32],
        progress: Option<&dyn Fn(f32, &str)>,
    ) -> Vec<f32> {
        let notify = |fraction: f32, stage: &str| {
            if let Some(callback) = progress {
                callback(fraction, stage);
            }
        };

        notify(0.10, "computing spectrogram");
        let window = self.w.get("mel.window").expect("mel.window");
        let filterbank = self.w.get("mel.filterbank").expect("mel.filterbank");
        let frontend = MelFrontend::new(window, filterbank);
        let (mel, frames) = frontend.compute(pcm24k);

        notify(0.30, "analyzing voice (GPU)");
        let voice = StyleTtsGpu::new(&self.w, &self.w16, ctx, p, wc)
            .encode(&mel, 80, frames)
            .await;

        notify(1.0, "voice ready");
        voice
    }

    /// Synthesizes audio with the diffusion sampling and decoding steps routed
    /// through `StyleTtsGpu`.
    ///
    /// Steps, in order:
    /// 1. `StyleTtsAcoustic::acoustic_prep` on `ids`.
    /// 2. If `diffuse` is `Some`, draw noise with `diffusion_noise`, sample a
    ///    style with `StyleTtsGpu::diffusion_sample`, and combine it with
    ///    `voice` through `blend_style`. If it is `None`, `voice` is used as is.
    /// 3. `StyleTtsAcoustic::acoustic_rest` with the style from step 2.
    /// 4. `StyleTtsGpu::decode` to produce the returned samples.
    ///
    /// `progress`, when present, is passed to both `StyleTtsAcoustic` calls and
    /// is additionally called here with `(0.16, "imagining delivery (GPU)")`
    /// (diffusion only), `(0.36, "generating audio (GPU)")` and `(1.0, "done")`.
    pub async fn synthesize_gpu(
        &self,
        ctx: &WgpuCtx,
        p: &Pipelines,
        wc: &mut GpuWeightCache,
        ids: &[i64],
        voice: &[f32],
        diffuse: Option<DiffusionConfig>,
        progress: Option<&dyn Fn(f32, &str)>,
    ) -> Vec<f32> {
        use crate::reference::styletts2::acoustic as reference;

        let notify = |fraction: f32, stage: &str| {
            if let Some(callback) = progress {
                callback(fraction, stage);
            }
        };

        let acoustic = StyleTtsAcoustic::new(&self.w);
        let (t_en, bert_out, t) = acoustic.acoustic_prep(ids, progress);

        let style = if let Some(cfg) = diffuse {
            notify(0.16, "imagining delivery (GPU)");
            let (noise_init, noises) = reference::diffusion_noise(&cfg);
            // NOTE(review): 0.2, 1e-4, 3.0 and 9.0 are presumably the sampler's
            // noise-schedule parameters — confirm names and order against
            // `StyleTtsGpu::diffusion_sample`.
            let predicted = StyleTtsGpu::new(&self.w, &self.w16, ctx, p, wc)
                .diffusion_sample(
                    &bert_out,
                    t,
                    voice,
                    &noise_init,
                    &noises,
                    0.2,
                    1e-4,
                    3.0,
                    9.0,
                    reference::DIFFUSION_STEPS,
                )
                .await;
            reference::blend_style(&predicted, voice, &cfg)
        } else {
            voice.to_vec()
        };

        let (asr, f0, n, r, f) = acoustic.acoustic_rest(&t_en, &bert_out, t, &style, progress);

        notify(0.36, "generating audio (GPU)");
        let audio = StyleTtsGpu::new(&self.w, &self.w16, ctx, p, wc)
            .decode(&asr, f, &f0, &n, &r)
            .await;

        notify(1.0, "done");
        audio
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the index of the first position where the two slices hold
    /// different IEEE-754 bit patterns, looking only at the common prefix.
    ///
    /// Comparing bits rather than values means identical NaNs count as equal
    /// and `0.0` / `-0.0` count as different, which is what "bit identical"
    /// requires.
    fn first_bit_mismatch(a: &[f32], b: &[f32]) -> Option<usize> {
        a.iter()
            .zip(b)
            .position(|(x, y)| x.to_bits() != y.to_bits())
    }

    /// Loads the same GGUF file through the bulk and the streaming loader and
    /// checks that both produce the same keys with bit-identical `f32` data.
    ///
    /// The test is skipped (with a message on stderr) unless the `ST2_GGUF`
    /// environment variable points at a model file.
    #[test]
    fn st2_streaming_load_is_bit_identical() {
        let Ok(path) = std::env::var("ST2_GGUF") else {
            eprintln!("skip: set ST2_GGUF to the styletts2 f32 gguf to run");
            return;
        };
        let reader = GgufReader::new(std::fs::read(&path).unwrap()).unwrap();
        let bulk = StyleTtsModel::load(&reader).unwrap();
        let streamed = pollster::block_on(StyleTtsModel::load_streaming(&reader)).unwrap();

        assert_eq!(bulk.w.len(), streamed.w.len(), "tensor count differs");
        for (k, v) in &bulk.w {
            let s = streamed
                .w
                .get(k)
                .unwrap_or_else(|| panic!("streamed missing {k}"));

            // Check lengths separately: the bitwise scan only covers the common prefix.
            assert_eq!(v.len(), s.len(), "weights differ for {k}: length mismatch");

            // Report only the first differing element instead of dumping two
            // whole tensors into the failure message.
            if let Some(i) = first_bit_mismatch(v, s) {
                panic!(
                    "weights differ for {k} at index {i}: bulk={:?} ({:#010x}) streamed={:?} ({:#010x})",
                    v[i],
                    v[i].to_bits(),
                    s[i],
                    s[i].to_bits(),
                );
            }
        }
    }
}