use std::collections::HashMap;
use crate::backend::{Pipelines, WgpuCtx};
use crate::error::Result;
use crate::gguf::GgufReader;
use crate::reference::kokoro::g2p::{Lexicon, g2p};
use crate::reference::styletts2::StyleTtsModel;
use crate::reference::styletts2::acoustic::DiffusionConfig;
use crate::reference::styletts2::gpu::GpuWeightCache;
#[cfg(target_arch = "wasm32")]
use wasm_bindgen::prelude::*;
/// Sample rate reported by the `sampleRate` getter (24 000; presumably Hz, matching the
/// `pcm24k` input naming — confirm).
pub const SAMPLE_RATE: u32 = 24000;
/// Symbol table embedded at compile time. A character's position in this string becomes its
/// token id when `StyleTtsClone::from_model` builds the lookup map.
const VOCAB: &str = include_str!("reference/styletts2/vocab.txt");
/// A loaded StyleTTS2 model bundled with the GPU state and lookup tables needed to run it.
///
/// Exported to JavaScript via `wasm_bindgen` when compiled for `wasm32`.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
pub struct StyleTtsClone {
    // NOTE: field order is kept as-is on purpose; it determines drop order.
    // The loaded model that performs voice encoding and synthesis.
    model: StyleTtsModel,
    // Lexicon for text-to-phoneme conversion; `None` until a `set_lexicon_*` call.
    lex: Option<Lexicon>,
    // Symbol -> token id, derived from `VOCAB`.
    vocab: HashMap<char, i64>,
    // GPU context created once at construction time.
    ctx: WgpuCtx,
    // Pipelines built from `ctx.device`.
    pipes: Pipelines,
    // Weight cache handed mutably to every GPU call.
    wc: GpuWeightCache,
}
/// Native (non-JS) API: loading, lexicon setup, voice encoding and synthesis.
impl StyleTtsClone {
    /// Wraps an already-loaded model with a new GPU context, pipelines and an empty weight
    /// cache. The lexicon starts out unset.
    ///
    /// # Errors
    /// Propagates a failure from `WgpuCtx::new`.
    async fn from_model(model: StyleTtsModel) -> Result<Self> {
        // Token id == position of the symbol inside VOCAB. If a symbol occurs more than once,
        // the later position wins because map inserts overwrite.
        // NOTE(review): if vocab.txt ends with a newline, '\n' receives an id too — confirm
        // that is intended.
        let symbol_ids = VOCAB.chars().zip(0i64..).collect();
        let gpu = WgpuCtx::new().await?;
        let pipelines = Pipelines::new(&gpu.device);
        Ok(Self {
            model,
            lex: None,
            vocab: symbol_ids,
            ctx: gpu,
            pipes: pipelines,
            wc: GpuWeightCache::new(),
        })
    }

    /// Builds an engine from a complete GGUF file held in memory.
    ///
    /// # Errors
    /// Propagates failures from `GgufReader::new`, `StyleTtsModel::load` and GPU setup.
    pub async fn load_native(bytes: Vec<u8>) -> Result<Self> {
        let gguf = GgufReader::new(bytes)?;
        let loaded = StyleTtsModel::load(&gguf)?;
        Self::from_model(loaded).await
    }

    /// Builds an engine through `StyleTtsModel::load_streaming` on the given reader.
    ///
    /// # Errors
    /// Propagates failures from model loading and GPU setup.
    pub async fn load_streaming(reader: &GgufReader) -> Result<Self> {
        let loaded = StyleTtsModel::load_streaming(reader).await?;
        Self::from_model(loaded).await
    }

    /// Same as [`Self::load_streaming`], but goes through
    /// `StyleTtsModel::load_streaming_f16`.
    ///
    /// # Errors
    /// Propagates failures from model loading and GPU setup.
    pub async fn load_streaming_f16(reader: &GgufReader) -> Result<Self> {
        let loaded = StyleTtsModel::load_streaming_f16(reader).await?;
        Self::from_model(loaded).await
    }

    /// Installs (or replaces) the lexicon built from the `gold` and `silver` byte buffers.
    pub fn set_lexicon_native(&mut self, gold: &[u8], silver: &[u8]) {
        let lexicon = Lexicon::load(gold, silver);
        self.lex = Some(lexicon);
    }

    /// Converts a phoneme string into token ids.
    ///
    /// The result always begins with id `0`; characters absent from the vocabulary are
    /// skipped without any error.
    fn phonemes_to_ids(&self, ps: &str) -> Vec<i64> {
        let known = ps.chars().filter_map(|ch| self.vocab.get(&ch).copied());
        std::iter::once(0i64).chain(known).collect()
    }

    /// Runs the model's GPU voice encoder over `pcm24k` and returns its output.
    ///
    /// `progress`, when given, is forwarded to the model unchanged.
    pub async fn encode_voice_native(
        &mut self,
        pcm24k: &[f32],
        progress: Option<&dyn Fn(f32, &str)>,
    ) -> Vec<f32> {
        let encoding =
            self.model
                .encode_voice_gpu(&self.ctx, &self.pipes, &mut self.wc, pcm24k, progress);
        encoding.await
    }

    /// Synthesizes `text` with the given `voice` vector.
    ///
    /// The text is converted to phonemes with `g2p` and the installed lexicon, then mapped to
    /// token ids; out-of-vocabulary information returned by `g2p` is discarded.
    ///
    /// # Panics
    /// Panics with "lexicon not set" if no lexicon has been installed.
    pub async fn synthesize_native(
        &mut self,
        text: &str,
        voice: &[f32],
        progress: Option<&dyn Fn(f32, &str)>,
    ) -> Vec<f32> {
        // Scoped so the lexicon borrow and the g2p outputs end before the mutable call below.
        let ids = {
            let lexicon = self.lex.as_ref().expect("lexicon not set");
            let (phonemes, _oov) = g2p(text, lexicon);
            self.phonemes_to_ids(&phonemes)
        };
        self.run_synthesis(ids, voice, progress).await
    }

    /// Synthesizes from an already-phonemized string, bypassing `g2p` and the lexicon.
    pub async fn synthesize_phonemes_native(
        &mut self,
        phonemes: &str,
        voice: &[f32],
        progress: Option<&dyn Fn(f32, &str)>,
    ) -> Vec<f32> {
        let ids = self.phonemes_to_ids(phonemes);
        self.run_synthesis(ids, voice, progress).await
    }

    /// Shared tail of both synthesis entry points: calls `synthesize_gpu` with the default
    /// diffusion configuration.
    async fn run_synthesis(
        &mut self,
        ids: Vec<i64>,
        voice: &[f32],
        progress: Option<&dyn Fn(f32, &str)>,
    ) -> Vec<f32> {
        self.model
            .synthesize_gpu(
                &self.ctx,
                &self.pipes,
                &mut self.wc,
                &ids,
                voice,
                Some(DiffusionConfig::default()),
                progress,
            )
            .await
    }
}
/// Converts any `Debug`-printable error into a `JsError` carrying its `{:?}` rendering.
#[cfg(target_arch = "wasm32")]
fn debug_js_error<E: std::fmt::Debug>(e: E) -> JsError {
    JsError::new(&format!("{e:?}"))
}

/// Adapts a JS function into a Rust progress callback.
///
/// Each call invokes `on_progress(fraction, stage)` with a null `this`; the call's result
/// (including any error it reports) is ignored.
#[cfg(target_arch = "wasm32")]
fn js_progress(on_progress: &js_sys::Function) -> impl Fn(f32, &str) + '_ {
    move |frac: f32, stage: &str| {
        let _ = on_progress.call2(
            &JsValue::NULL,
            &JsValue::from_f64(frac as f64),
            &JsValue::from_str(stage),
        );
    }
}

/// JavaScript-facing wrappers around the native API.
#[cfg(target_arch = "wasm32")]
#[wasm_bindgen]
impl StyleTtsClone {
    /// JS `load`: builds an engine from a complete GGUF file held in memory.
    ///
    /// Errors are reported as a `JsError` containing the error's debug rendering.
    #[wasm_bindgen(js_name = load)]
    pub async fn load_js(bytes: Vec<u8>) -> std::result::Result<StyleTtsClone, JsError> {
        Self::load_native(bytes).await.map_err(debug_js_error)
    }

    /// JS `loadStreaming`: builds an engine by reading tensors through `read_fn`.
    ///
    /// `total_bytes` must be finite and non-negative, otherwise an error is returned before
    /// any read is attempted.
    #[wasm_bindgen(js_name = loadStreaming)]
    pub async fn load_streaming_js(
        read_fn: js_sys::Function,
        total_bytes: f64,
    ) -> std::result::Result<StyleTtsClone, JsError> {
        use crate::gguf::{OpfsFetcher, TensorFetcher};
        use std::sync::Arc;

        if !total_bytes.is_finite() || total_bytes < 0.0 {
            return Err(JsError::new(
                "loadStreaming: total_bytes must be a non-negative finite number",
            ));
        }
        // Range-checked above; the cast drops any fractional part and saturates at u64::MAX.
        let size = total_bytes as u64;
        let fetcher: Arc<dyn TensorFetcher> = Arc::new(OpfsFetcher::new(read_fn, size));
        let gguf = GgufReader::new_streaming(fetcher)
            .await
            .map_err(debug_js_error)?;
        Self::load_streaming(&gguf).await.map_err(debug_js_error)
    }

    /// JS `loadStreamingF16`: like `loadStreaming`, but loads through
    /// `load_streaming_f16`.
    ///
    /// `total_bytes` must be finite and non-negative, otherwise an error is returned before
    /// any read is attempted.
    #[wasm_bindgen(js_name = loadStreamingF16)]
    pub async fn load_streaming_f16_js(
        read_fn: js_sys::Function,
        total_bytes: f64,
    ) -> std::result::Result<StyleTtsClone, JsError> {
        use crate::gguf::{OpfsFetcher, TensorFetcher};
        use std::sync::Arc;

        if !total_bytes.is_finite() || total_bytes < 0.0 {
            return Err(JsError::new(
                "loadStreamingF16: total_bytes must be a non-negative finite number",
            ));
        }
        // Range-checked above; the cast drops any fractional part and saturates at u64::MAX.
        let size = total_bytes as u64;
        let fetcher: Arc<dyn TensorFetcher> = Arc::new(OpfsFetcher::new(read_fn, size));
        let gguf = GgufReader::new_streaming(fetcher)
            .await
            .map_err(debug_js_error)?;
        Self::load_streaming_f16(&gguf).await.map_err(debug_js_error)
    }

    /// JS `setLexicon`: installs the lexicon from the `gold` and `silver` byte buffers.
    #[wasm_bindgen(js_name = setLexicon)]
    pub fn set_lexicon_js(&mut self, gold: Vec<u8>, silver: Vec<u8>) {
        self.set_lexicon_native(&gold, &silver);
    }

    /// JS `encodeVoice`: runs the voice encoder over `pcm24k`, reporting progress through
    /// `on_progress(fraction, stage)`.
    #[wasm_bindgen(js_name = encodeVoice)]
    pub async fn encode_voice_js(
        &mut self,
        pcm24k: Vec<f32>,
        on_progress: js_sys::Function,
    ) -> Vec<f32> {
        let report = js_progress(&on_progress);
        self.encode_voice_native(&pcm24k, Some(&report)).await
    }

    /// JS `synthesize`: synthesizes `text` with `voice`, reporting progress through
    /// `on_progress(fraction, stage)`.
    ///
    /// Panics (inside `synthesize_native`) if no lexicon has been installed.
    #[wasm_bindgen(js_name = synthesize)]
    pub async fn synthesize_js(
        &mut self,
        text: String,
        voice: Vec<f32>,
        on_progress: js_sys::Function,
    ) -> Vec<f32> {
        let report = js_progress(&on_progress);
        self.synthesize_native(&text, &voice, Some(&report)).await
    }

    /// JS `synthesizePhonemes`: synthesizes from a phoneme string; no progress is reported.
    #[wasm_bindgen(js_name = synthesizePhonemes)]
    pub async fn synthesize_phonemes_js(&mut self, phonemes: String, voice: Vec<f32>) -> Vec<f32> {
        let no_progress = None;
        self.synthesize_phonemes_native(&phonemes, &voice, no_progress)
            .await
    }

    /// JS getter `sampleRate`: returns `SAMPLE_RATE`.
    #[wasm_bindgen(js_name = sampleRate, getter)]
    pub fn sample_rate_js(&self) -> u32 {
        SAMPLE_RATE
    }
}