use std::sync::Arc;
use crate::backend::{Pipelines, WgpuCtx};
use crate::error::Result;
use crate::gguf::GgufReader;
use crate::reference::kokoro::KokoroModel;
use crate::reference::kokoro::g2p::{Lexicon, g2p};
use crate::reference::kokoro::gpu_fast::WeightCache;
use crate::reference::kokoro::voice_train::voice_signature;
fn sig_l2(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum::<f32>() / a.len().max(1) as f32
}
struct VoiceTrainState {
target_sig: Vec<f32>,
ids: Vec<i64>,
style: Vec<f32>,
loss: f32,
step: f32,
rng: u64,
}
#[cfg(target_arch = "wasm32")]
use wasm_bindgen::prelude::*;
pub const SAMPLE_RATE: u32 = 24000;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
pub struct KokoroTts {
model: KokoroModel,
ctx: WgpuCtx,
pipes: Arc<Pipelines>,
lex: Option<Lexicon>,
wc: WeightCache,
train: Option<VoiceTrainState>,
}
impl KokoroTts {
pub async fn load_native(bytes: Vec<u8>) -> Result<Self> {
let reader = Arc::new(GgufReader::new(bytes)?);
let model = KokoroModel::new(reader)?;
let ctx = WgpuCtx::new().await?;
let pipes = Arc::new(Pipelines::new(&ctx.device));
Ok(Self {
model,
ctx,
pipes,
lex: None,
wc: WeightCache::new(),
train: None,
})
}
pub async fn train_begin_native(
&mut self,
target_pcm: &[f32],
ref_text: &str,
init_voice: &str,
step0: f32,
seed: u64,
) {
let target_sig = voice_signature(target_pcm);
let phonemes = {
let lex = self.lex.as_ref().expect("lexicon not set");
g2p(ref_text, lex).0
};
let ids = self.model.phonemes_to_ids(&phonemes);
let style = self.model.load_voice(init_voice, ids.len());
let audio = self
.model
.synthesize_gpu_fast(&self.ctx, &self.pipes, &mut self.wc, &ids, &style)
.await;
let loss = sig_l2(&voice_signature(&audio), &target_sig);
self.train = Some(VoiceTrainState {
target_sig,
ids,
style,
loss,
step: step0,
rng: seed | 1,
});
}
pub async fn train_step_native(&mut self) -> f32 {
let st = self.train.as_mut().expect("train_begin first");
self.model
.voice_train_step(
&self.ctx,
&self.pipes,
&mut self.wc,
&st.ids,
&st.target_sig,
&mut st.style,
&mut st.loss,
&mut st.step,
&mut st.rng,
)
.await
}
pub fn trained_voice_native(&self) -> Vec<f32> {
self.train
.as_ref()
.map(|s| s.style.clone())
.unwrap_or_default()
}
pub async fn synthesize_text_style_native(
&mut self,
text: &str,
style: &[f32],
progress: Option<&dyn Fn(f32, &str)>,
) -> Vec<f32> {
let phonemes = {
let lex = self.lex.as_ref().expect("lexicon not set");
g2p(text, lex).0
};
let ids = self.model.phonemes_to_ids(&phonemes);
self.model
.synthesize_gpu_fast_p(&self.ctx, &self.pipes, &mut self.wc, &ids, style, progress)
.await
}
pub fn set_lexicon_native(&mut self, gold: &[u8], silver: &[u8]) {
self.lex = Some(Lexicon::load(gold, silver));
}
pub async fn synthesize_native(
&mut self,
text: &str,
voice: &str,
progress: Option<&dyn Fn(f32, &str)>,
) -> (Vec<f32>, Vec<String>) {
let (phonemes, oov) = {
let lex = self.lex.as_ref().expect("lexicon not set");
g2p(text, lex)
};
let audio = self
.synthesize_phonemes_native(&phonemes, voice, progress)
.await;
(audio, oov)
}
pub async fn synthesize_phonemes_native(
&mut self,
phonemes: &str,
voice: &str,
progress: Option<&dyn Fn(f32, &str)>,
) -> Vec<f32> {
let ids = self.model.phonemes_to_ids(phonemes);
let ref_s = self.model.load_voice(voice, ids.len());
self.model
.synthesize_gpu_fast_p(&self.ctx, &self.pipes, &mut self.wc, &ids, &ref_s, progress)
.await
}
}
#[cfg(target_arch = "wasm32")]
#[wasm_bindgen]
impl KokoroTts {
#[wasm_bindgen(js_name = load)]
pub async fn load_js(bytes: Vec<u8>) -> std::result::Result<KokoroTts, JsError> {
Self::load_native(bytes)
.await
.map_err(|e| JsError::new(&format!("{e:?}")))
}
#[wasm_bindgen(js_name = setLexicon)]
pub fn set_lexicon_js(&mut self, gold: Vec<u8>, silver: Vec<u8>) {
self.set_lexicon_native(&gold, &silver);
}
#[wasm_bindgen(js_name = synthesize)]
pub async fn synthesize_js(
&mut self,
text: String,
voice: String,
on_progress: js_sys::Function,
) -> Vec<f32> {
let cb = |frac: f32, stage: &str| {
let _ = on_progress.call2(
&JsValue::NULL,
&JsValue::from_f64(frac as f64),
&JsValue::from_str(stage),
);
};
let pcm = self.synthesize_native(&text, &voice, Some(&cb)).await.0;
let _ = on_progress.call2(
&JsValue::NULL,
&JsValue::from_f64(1.0),
&JsValue::from_str("done"),
);
pcm
}
#[wasm_bindgen(js_name = synthesizePhonemes)]
pub async fn synthesize_phonemes_js(&mut self, phonemes: String, voice: String) -> Vec<f32> {
self.synthesize_phonemes_native(&phonemes, &voice, None)
.await
}
#[wasm_bindgen(js_name = sampleRate, getter)]
pub fn sample_rate_js(&self) -> u32 {
SAMPLE_RATE
}
#[wasm_bindgen(js_name = trainBegin)]
pub async fn train_begin_js(
&mut self,
target_pcm: Vec<f32>,
ref_text: String,
init_voice: String,
) {
self.train_begin_native(&target_pcm, &ref_text, &init_voice, 0.06, 1)
.await;
}
#[wasm_bindgen(js_name = trainStep)]
pub async fn train_step_js(&mut self) -> f32 {
self.train_step_native().await
}
#[wasm_bindgen(js_name = trainedVoice)]
pub fn trained_voice_js(&self) -> Vec<f32> {
self.trained_voice_native()
}
#[wasm_bindgen(js_name = synthesizeWithVoice)]
pub async fn synthesize_with_voice_js(&mut self, text: String, voice: Vec<f32>) -> Vec<f32> {
self.synthesize_text_style_native(&text, &voice, None).await
}
}