use std::path::Path;
use crate::error::{AudioError, AudioResult};
/// Model-ready inputs produced by a [`Preprocessor`] from raw text.
#[derive(Debug, Clone)]
pub struct PreprocessorOutput {
// Token ids to feed to the model (i64, as expected by ONNX int64 inputs).
pub token_ids: Vec<i64>,
// Per-voice style vector when the model uses one (Kokoro sets this;
// the plain tokenizer path leaves it `None`).
pub style_embedding: Option<Vec<f32>>,
// Speech speed; the tokenizer path omits it when speed is ~1.0,
// while the Kokoro path always forwards it.
pub speed: Option<f32>,
}
/// Converts raw text into model-ready TTS inputs (token ids plus optional
/// style embedding and speed). Implementations must be thread-safe.
pub trait Preprocessor: Send + Sync {
/// Preprocesses `text` for the given `voice_id` at the requested `speed`.
/// `model_dir` points at the model's asset directory for implementations
/// that need to read additional files from it.
///
/// # Errors
/// Returns an `AudioError` when tokenization/phonemization fails or
/// produces no usable output.
fn preprocess(
&self,
text: &str,
voice_id: &str,
speed: f32,
model_dir: &Path,
) -> AudioResult<PreprocessorOutput>;
/// Stable human-readable name of this implementation.
fn name(&self) -> &str;
}
/// [`Preprocessor`] backed by a HuggingFace `tokenizers` tokenizer loaded
/// from a model directory's `tokenizer.json`.
pub struct TokenizerPreprocessor {
tokenizer: tokenizers::Tokenizer,
}
impl TokenizerPreprocessor {
    /// Builds a preprocessor from the `tokenizer.json` inside `model_dir`.
    ///
    /// # Errors
    /// Returns `AudioError::Tts` when the tokenizer file cannot be loaded.
    pub fn from_model_dir(model_dir: &Path) -> AudioResult<Self> {
        let tokenizer_path = model_dir.join("tokenizer.json");
        match tokenizers::Tokenizer::from_file(&tokenizer_path) {
            Ok(tokenizer) => Ok(Self { tokenizer }),
            Err(e) => Err(AudioError::Tts {
                provider: "ONNX".into(),
                message: format!("failed to load tokenizer from {}: {e}", tokenizer_path.display()),
            }),
        }
    }
}
impl Preprocessor for TokenizerPreprocessor {
    /// Encodes `text` (including special tokens) into i64 token ids.
    /// No style embedding is produced; `speed` is only forwarded when it
    /// differs from the 1.0 default.
    fn preprocess(
        &self,
        text: &str,
        _voice_id: &str,
        speed: f32,
        _model_dir: &Path,
    ) -> AudioResult<PreprocessorOutput> {
        let encoding = match self.tokenizer.encode(text, true) {
            Ok(enc) => enc,
            Err(e) => {
                return Err(AudioError::Tts {
                    provider: "ONNX".into(),
                    message: format!("tokenization failed: {e}"),
                })
            }
        };
        let token_ids: Vec<i64> = encoding.get_ids().iter().map(|&id| i64::from(id)).collect();
        if token_ids.is_empty() {
            return Err(AudioError::Tts {
                provider: "ONNX".into(),
                message: "tokenization produced no tokens".into(),
            });
        }
        // Omit the speed when it is effectively the default of 1.0.
        let speed = if (speed - 1.0).abs() > f32::EPSILON { Some(speed) } else { None };
        Ok(PreprocessorOutput { token_ids, style_embedding: None, speed })
    }

    fn name(&self) -> &str {
        "TokenizerPreprocessor"
    }
}
#[cfg(feature = "kokoro")]
/// Builds the Kokoro character-to-id vocabulary: the pad symbol `$` at
/// index 0, then punctuation, ASCII letters, and IPA/phonetic symbols,
/// numbered in order of appearance.
fn build_kokoro_vocab() -> std::collections::HashMap<char, usize> {
    const PAD: &str = "$";
    const PUNCTUATION: &str = ";:,.!?¡¿—…\"«»\u{201c}\u{201d} ";
    const LETTERS: &str = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const LETTERS_IPA: &str = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘\u{2018}̩\u{2019}ᵻ";
    PAD.chars()
        .chain(PUNCTUATION.chars())
        .chain(LETTERS.chars())
        .chain(LETTERS_IPA.chars())
        .enumerate()
        .map(|(id, symbol)| (symbol, id))
        .collect()
}
#[cfg(feature = "kokoro")]
/// Maps each phoneme character to its vocabulary id, silently skipping
/// characters that have no entry in `vocab`.
fn kokoro_tokenize(phonemes: &str, vocab: &std::collections::HashMap<char, usize>) -> Vec<i64> {
    let mut ids = Vec::with_capacity(phonemes.len());
    for ch in phonemes.chars() {
        if let Some(&id) = vocab.get(&ch) {
            ids.push(id as i64);
        }
    }
    ids
}
#[cfg(feature = "kokoro")]
/// [`Preprocessor`] for Kokoro-style models: phonemizes text with
/// espeak-ng, tokenizes phonemes against a fixed character vocabulary,
/// and attaches a per-voice style embedding.
pub struct KokoroPreprocessor {
// Character-to-id map built by `build_kokoro_vocab`.
vocab: std::collections::HashMap<char, usize>,
// Voice style embeddings loaded from the voices archive.
voices: KokoroVoices,
// Language code handed to espeak-ng for phonemization.
language: String,
}
#[cfg(feature = "kokoro")]
/// Voice style embeddings loaded from a Kokoro `voices.npz` archive,
/// keyed by voice id. Each voice maps to a list of 256-dim vectors
/// selected by token length (see `get_style`).
pub struct KokoroVoices {
styles: std::collections::HashMap<String, Vec<[f32; 256]>>,
}
#[cfg(feature = "kokoro")]
impl KokoroVoices {
    /// Loads all voice style embeddings from a Kokoro npz voices archive.
    ///
    /// Each entry is read as a 3-D `f32` array; the first axis is indexed
    /// by token length in `get_style`, and the innermost axis holds the
    /// 256-dim style vector (middle axis index 0 is taken — presumably a
    /// singleton dim; confirm against the archive layout).
    ///
    /// # Errors
    /// Returns `AudioError::Tts` if the file cannot be opened, is not a
    /// valid npz archive, or any contained array cannot be read.
    pub fn load(voices_path: &Path) -> AudioResult<Self> {
        use ndarray::Array3;
        use ndarray_npy::NpzReader;
        use std::fs::File;
        let file = File::open(voices_path).map_err(|e| AudioError::Tts {
            provider: "ONNX/Kokoro".into(),
            message: format!("failed to open voices file {}: {e}", voices_path.display()),
        })?;
        let mut npz = NpzReader::new(file).map_err(|e| AudioError::Tts {
            provider: "ONNX/Kokoro".into(),
            message: format!("failed to read npz voices file: {e}"),
        })?;
        let names = npz.names().map_err(|e| AudioError::Tts {
            provider: "ONNX/Kokoro".into(),
            message: format!("failed to list voices in npz: {e}"),
        })?;
        let mut styles = std::collections::HashMap::new();
        for name in names {
            let arr: Array3<f32> = npz.by_name(&name).map_err(|e| AudioError::Tts {
                provider: "ONNX/Kokoro".into(),
                message: format!("failed to read voice '{name}': {e}"),
            })?;
            let n = arr.shape()[0];
            let mut embeddings = Vec::with_capacity(n);
            for i in 0..n {
                let mut emb = [0.0f32; 256];
                // Zip against the fixed-size buffer so a malformed archive
                // with an inner dim > 256 cannot index out of bounds
                // (extras are ignored; missing values stay 0.0).
                for (dst, src) in emb.iter_mut().zip(arr.slice(ndarray::s![i, 0, ..]).iter()) {
                    *dst = *src;
                }
                embeddings.push(emb);
            }
            styles.insert(name, embeddings);
        }
        tracing::info!("loaded {} Kokoro voice styles", styles.len());
        Ok(Self { styles })
    }

    /// Returns the 256-dim style embedding for `voice_id`, selected by
    /// `token_len` and clamped to the last available embedding.
    ///
    /// # Errors
    /// Returns `AudioError::Tts` when the voice is unknown or (defensively)
    /// when its embedding list is empty.
    pub fn get_style(&self, voice_id: &str, token_len: usize) -> AudioResult<Vec<f32>> {
        let embeddings = self.styles.get(voice_id).ok_or_else(|| AudioError::Tts {
            provider: "ONNX/Kokoro".into(),
            message: format!(
                "voice '{voice_id}' not found. Available: {:?}",
                self.available_voices()
            ),
        })?;
        // Guard: indexing an empty embedding list would otherwise panic.
        if embeddings.is_empty() {
            return Err(AudioError::Tts {
                provider: "ONNX/Kokoro".into(),
                message: format!("voice '{voice_id}' has no style embeddings"),
            });
        }
        let idx = token_len.min(embeddings.len() - 1);
        Ok(embeddings[idx].to_vec())
    }

    /// All loaded voice ids, sorted alphabetically.
    pub fn available_voices(&self) -> Vec<String> {
        let mut v: Vec<String> = self.styles.keys().cloned().collect();
        v.sort();
        v
    }
}
#[cfg(feature = "kokoro")]
impl KokoroPreprocessor {
    /// Creates a Kokoro preprocessor for `language`, loading voice style
    /// embeddings from `voices_path`.
    ///
    /// # Errors
    /// Propagates any failure from loading the voices archive.
    pub fn new(voices_path: &Path, language: &str) -> AudioResult<Self> {
        Ok(Self {
            vocab: build_kokoro_vocab(),
            voices: KokoroVoices::load(voices_path)?,
            language: language.to_string(),
        })
    }

    /// Read access to the loaded voice styles.
    pub fn voices(&self) -> &KokoroVoices {
        &self.voices
    }
}
#[cfg(feature = "kokoro")]
impl Preprocessor for KokoroPreprocessor {
    /// Phonemizes `text` with espeak-ng, tokenizes the phonemes against the
    /// Kokoro vocabulary, pads both ends with the boundary token (id 0),
    /// and looks up the voice style for the unpadded token count.
    fn preprocess(
        &self,
        text: &str,
        voice_id: &str,
        speed: f32,
        _model_dir: &Path,
    ) -> AudioResult<PreprocessorOutput> {
        let phoneme_parts = espeak_rs::text_to_phonemes(text, &self.language, None, true, false)
            .map_err(|e| AudioError::Tts {
                provider: "ONNX/Kokoro".into(),
                message: format!("espeak-ng phonemization failed: {e}"),
            })?;
        let phonemes = phoneme_parts.join("");
        if phonemes.is_empty() {
            return Err(AudioError::Tts {
                provider: "ONNX/Kokoro".into(),
                message: "phonemization produced empty output".into(),
            });
        }
        let inner_tokens = kokoro_tokenize(&phonemes, &self.vocab);
        if inner_tokens.is_empty() {
            return Err(AudioError::Tts {
                provider: "ONNX/Kokoro".into(),
                message: "tokenization produced no tokens from phonemes".into(),
            });
        }
        // Style embeddings are indexed by the token count *before* padding.
        let style = self.voices.get_style(voice_id, inner_tokens.len())?;
        // Surround the sequence with the pad/boundary token (id 0).
        let mut token_ids = Vec::with_capacity(inner_tokens.len() + 2);
        token_ids.push(0);
        token_ids.extend(inner_tokens);
        token_ids.push(0);
        Ok(PreprocessorOutput { token_ids, style_embedding: Some(style), speed: Some(speed) })
    }

    fn name(&self) -> &str {
        "KokoroPreprocessor"
    }
}