use std::{collections::HashMap, path::Path, sync::Mutex};
use anyhow::{Context, Result};
use ort::{session::Session, value::Tensor};
use crate::{
npz::{load_npz, NpyArray},
tokenize::ipa_to_ids,
};
#[cfg(feature = "espeak")]
use crate::{phonemize::phonemize, preprocess::TextPreprocessor};
/// Number of trailing samples removed from every inference result in `infer_ipa`.
const TAIL_TRIM: usize = 2_000;
/// Sample rate, in Hz, written into the header of WAV files produced by `write_wav`.
pub const SAMPLE_RATE: u32 = 24_000;
/// Chunk length limit that `generate` hands to `chunk_text`.
#[cfg(feature = "espeak")]
const CHUNK_MAX_CHARS: usize = 400;
#[cfg(feature = "espeak")]
/// Trims `text` and makes sure the result ends with a punctuation mark.
///
/// If the trimmed text is empty it is returned as an empty string. If it
/// already ends with one of `.`, `!`, `?`, `,`, `;` or `:` it is returned
/// unchanged; otherwise a single `,` is appended.
fn ensure_punctuation(text: &str) -> String {
    let trimmed = text.trim();
    let already_punctuated =
        trimmed.ends_with(|c: char| matches!(c, '.' | '!' | '?' | ',' | ';' | ':'));

    if trimmed.is_empty() || already_punctuated {
        return trimmed.to_owned();
    }

    let mut out = String::with_capacity(trimmed.len() + 1);
    out.push_str(trimmed);
    out.push(',');
    out
}
#[cfg(feature = "espeak")]
/// Splits `text` into punctuated chunks of at most `max_len` characters.
///
/// The text is first split into sentences on `.`, `!` and `?` (the
/// terminator itself is dropped). Each non-empty sentence that fits within
/// `max_len` becomes one chunk. Longer sentences are broken on whitespace and
/// words are greedily packed into chunks that stay within `max_len`.
/// Every chunk is passed through `ensure_punctuation`, so a chunk may be one
/// character longer than `max_len` once the trailing comma is added.
///
/// Lengths are measured in Unicode scalar values (`char`s), not UTF-8 bytes,
/// so text containing multibyte characters is not split earlier than the
/// limit requires.
///
/// A single word that is longer than `max_len` is never split; it is emitted
/// as its own over-long chunk.
fn chunk_text(text: &str, max_len: usize) -> Vec<String> {
    let mut chunks = Vec::new();

    for sentence in text.split_terminator(['.', '!', '?']) {
        let sentence = sentence.trim();
        if sentence.is_empty() {
            continue;
        }

        if sentence.chars().count() <= max_len {
            chunks.push(ensure_punctuation(sentence));
            continue;
        }

        // Sentence is too long: pack whole words greedily. The running
        // character count avoids rescanning `current` for every word.
        let mut current = String::new();
        let mut current_chars = 0usize;
        for word in sentence.split_whitespace() {
            let word_chars = word.chars().count();
            if current.is_empty() {
                // The first word of a chunk is always accepted, even when it
                // alone exceeds the limit.
                current.push_str(word);
                current_chars = word_chars;
            } else if current_chars + 1 + word_chars > max_len {
                chunks.push(ensure_punctuation(&current));
                current.clear();
                current.push_str(word);
                current_chars = word_chars;
            } else {
                current.push(' ');
                current.push_str(word);
                current_chars += 1 + word_chars;
            }
        }
        if !current.is_empty() {
            chunks.push(ensure_punctuation(&current));
        }
    }

    chunks
}
/// Style table for one voice, stored as a flat row-major buffer of
/// `nrows` rows with `ncols` values each.
struct Voice {
    /// Number of rows reported by the source array.
    nrows: usize,
    /// Number of values per row reported by the source array.
    ncols: usize,
    /// Flat sample storage taken over from the source array.
    data: Vec<f32>,
}

impl Voice {
    /// Takes ownership of a loaded array, recording its row and column counts.
    fn from_npy(arr: NpyArray) -> Self {
        let (nrows, ncols) = (arr.nrows(), arr.ncols());
        Self {
            nrows,
            ncols,
            data: arr.data,
        }
    }

    /// Returns the row selected by `text_len`, clamped to the last row.
    ///
    /// Panics if the selected range lies outside `data` (for example when the
    /// table has no rows but a non-zero column count).
    fn style_row(&self, text_len: usize) -> &[f32] {
        let last_row = self.nrows.saturating_sub(1);
        let row = std::cmp::min(text_len, last_row);
        let start = row * self.ncols;
        let end = (row + 1) * self.ncols;
        &self.data[start..end]
    }
}
/// Text-to-speech engine wrapping an ONNX Runtime session together with the
/// voice tables and per-voice settings needed to run it.
pub struct KittenTtsOnnx {
/// Inference session; guarded by a mutex because running it needs mutable
/// access while the public methods only take `&self`.
session: Mutex<Session>,
/// Style tables keyed by voice name, as loaded from the voices file.
voices: HashMap<String, Voice>,
/// Per-voice multiplier applied to the requested speed; voices without an
/// entry use `1.0`.
speed_priors: HashMap<String, f32>,
/// Maps an alias to the voice key it stands for; names without an entry are
/// used as-is.
voice_aliases: HashMap<String, String>,
/// Text cleaner applied by `generate` when `clean_text` is set.
#[cfg(feature = "espeak")]
preprocessor: TextPreprocessor,
/// Names of all voices found in the voices file passed to `load`.
pub available_voices: Vec<String>,
}
impl KittenTtsOnnx {
/// Builds an engine from an ONNX model file and a voices archive.
///
/// `speed_priors` and `voice_aliases` are stored unchanged and consulted
/// later when a voice is resolved and its effective speed is computed.
///
/// # Errors
///
/// Fails if the session builder cannot be created, the model cannot be
/// loaded from `model_path`, or the voices cannot be loaded from
/// `voices_path`.
pub fn load(
    model_path: &Path,
    voices_path: &Path,
    speed_priors: HashMap<String, f32>,
    voice_aliases: HashMap<String, String>,
) -> Result<Self> {
    let session = Session::builder()
        .context("Failed to create ORT session builder")?
        .commit_from_file(model_path)
        .with_context(|| format!("Cannot load ONNX model: {}", model_path.display()))?;

    let archive = load_npz(voices_path)
        .with_context(|| format!("Cannot load voices: {}", voices_path.display()))?;

    // Record the names first; the archive is consumed while building the tables.
    let available_voices = archive.keys().cloned().collect::<Vec<String>>();
    let mut voices: HashMap<String, Voice> = HashMap::new();
    for (name, array) in archive {
        voices.insert(name, Voice::from_npy(array));
    }

    let engine = Self {
        session: Mutex::new(session),
        voices,
        speed_priors,
        voice_aliases,
        #[cfg(feature = "espeak")]
        preprocessor: TextPreprocessor::new(),
        available_voices,
    };
    Ok(engine)
}
/// Translates `voice` through the alias table, returning the aliased key
/// when one exists and `voice` itself otherwise.
fn resolve_voice<'a>(&'a self, voice: &'a str) -> &'a str {
    match self.voice_aliases.get(voice) {
        Some(target) => target.as_str(),
        None => voice,
    }
}
/// Runs the model on an IPA string and returns the generated samples.
///
/// * `ipa` – phoneme string, converted to token ids with `ipa_to_ids`.
/// * `style_idx` – row of the voice's style table to use (clamped to the
///   last row by `Voice::style_row`).
/// * `voice_key` – already-resolved voice name; must be a key of `voices`.
/// * `effective_speed` – speed value passed to the model as-is.
///
/// The last `TAIL_TRIM` samples of the model output are dropped; if the
/// output is shorter than that, an empty vector is returned.
///
/// # Errors
///
/// Fails if the voice is unknown, an input tensor cannot be built, the
/// inference run fails, or the output cannot be read as `f32` data.
///
/// # Panics
///
/// Panics if the session mutex is poisoned.
fn infer_ipa(
    &self,
    ipa: &str,
    style_idx: usize,
    voice_key: &str,
    effective_speed: f32,
) -> Result<Vec<f32>> {
    let voice_data = self.voices.get(voice_key).with_context(|| {
        format!("Voice '{}' not found. Available: {:?}", voice_key, self.available_voices)
    })?;
    let ids = ipa_to_ids(ipa);
    let seq_len = ids.len();
    let style_slice = voice_data.style_row(style_idx);
    let style_dim = style_slice.len();

    // Build all inputs before taking the lock so it is held only for the run
    // itself and for copying the result out.
    let t_input_ids = Tensor::<i64>::from_array(([1usize, seq_len], ids))
        .context("Failed to build input_ids tensor")?;
    let t_style = Tensor::<f32>::from_array(([1usize, style_dim], style_slice.to_vec()))
        .context("Failed to build style tensor")?;
    let t_speed = Tensor::<f32>::from_array(([1usize], vec![effective_speed]))
        .context("Failed to build speed tensor")?;

    let mut session = self.session.lock().expect("ORT session mutex poisoned");
    let outputs = session
        .run(ort::inputs![t_input_ids, t_style, t_speed])
        .context("ONNX inference failed")?;
    let (_shape, audio_data) = outputs[0]
        .try_extract_tensor::<f32>()
        .context("Failed to extract audio tensor")?;

    // Copy the output once and shorten it in place rather than copying the
    // whole buffer and then copying the trimmed prefix a second time.
    let mut audio: Vec<f32> = audio_data.to_vec();
    let trimmed_len = audio.len().saturating_sub(TAIL_TRIM);
    audio.truncate(trimmed_len);
    Ok(audio)
}
/// Synthesises a single piece of text with the given voice and speed.
///
/// The voice name is resolved through the alias table and `speed` is
/// multiplied by the voice's speed prior (`1.0` when none is configured).
/// The text is phonemised and the byte length of `text` is used as the
/// style row index.
///
/// # Errors
///
/// Fails if phonemisation fails or if inference fails (including when the
/// resolved voice is unknown).
#[cfg(feature = "espeak")]
pub fn generate_chunk(&self, text: &str, voice: &str, speed: f32) -> Result<Vec<f32>> {
    let voice_key = self.resolve_voice(voice);
    let prior = match self.speed_priors.get(voice_key) {
        Some(&value) => value,
        None => 1.0,
    };
    let effective_speed = speed * prior;

    let phonemes =
        phonemize(text).with_context(|| format!("Phonemisation failed for {:?}", text))?;
    let style_idx = text.len();
    self.infer_ipa(&phonemes, style_idx, voice_key, effective_speed)
}
/// Synthesises audio directly from an IPA string.
///
/// The voice name is resolved through the alias table and `speed` is
/// multiplied by the voice's speed prior (`1.0` when none is configured).
/// `style_idx` selects the row of the voice's style table.
///
/// # Errors
///
/// Fails if inference fails, including when the resolved voice is unknown.
pub fn generate_from_ipa(
    &self,
    ipa: &str,
    voice: &str,
    speed: f32,
    style_idx: usize,
) -> Result<Vec<f32>> {
    let voice_key = self.resolve_voice(voice);
    let prior = match self.speed_priors.get(voice_key) {
        Some(&value) => value,
        None => 1.0,
    };
    self.infer_ipa(ipa, style_idx, voice_key, speed * prior)
}
/// Synthesises several IPA chunks in order and concatenates their audio.
///
/// The voice is validated once up front, so an unknown voice is reported
/// even when `chunks` is empty. Each chunk uses its own byte length as the
/// style row index.
///
/// # Errors
///
/// Fails if the voice is unknown or if synthesis of any chunk fails; the
/// first failure stops processing.
pub fn generate_from_ipa_chunks(
    &self,
    chunks: &[&str],
    voice: &str,
    speed: f32,
) -> Result<Vec<f32>> {
    let voice_key = self.resolve_voice(voice);
    anyhow::ensure!(
        self.voices.contains_key(voice_key),
        "Unknown voice '{}'. Available: {:?}",
        voice,
        self.available_voices
    );

    chunks.iter().try_fold(
        Vec::new(),
        |mut audio: Vec<f32>, &ipa: &&str| -> Result<Vec<f32>> {
            let samples = self.generate_from_ipa(ipa, voice, speed, ipa.len())?;
            audio.extend(samples);
            Ok(audio)
        },
    )
}
/// Synthesises audio from an IPA string and writes it to `output_path` as a
/// WAV file.
///
/// # Errors
///
/// Fails if synthesis fails or the WAV file cannot be written.
pub fn generate_to_file_from_ipa(
    &self,
    ipa: &str,
    output_path: &Path,
    voice: &str,
    speed: f32,
    style_idx: usize,
) -> Result<()> {
    self.generate_from_ipa(ipa, voice, speed, style_idx)
        .and_then(|audio| self.write_wav(&audio, output_path))
}
/// Writes `audio` to `output_path` as a mono, 16-bit integer WAV file at
/// `SAMPLE_RATE`, then prints a one-line summary to stdout.
///
/// Each sample is scaled by `i16::MAX` and clamped to the `i16` range before
/// conversion.
///
/// # Errors
///
/// Fails if the file cannot be created, a sample cannot be written, or the
/// file cannot be finalised.
pub fn write_wav(&self, audio: &[f32], output_path: &Path) -> Result<()> {
    let spec = hound::WavSpec {
        channels: 1,
        sample_rate: SAMPLE_RATE,
        bits_per_sample: 16,
        sample_format: hound::SampleFormat::Int,
    };
    let mut writer = hound::WavWriter::create(output_path, spec)
        .with_context(|| format!("Cannot create WAV: {}", output_path.display()))?;

    let floor = f32::from(i16::MIN);
    let ceiling = f32::from(i16::MAX);
    for sample in audio.iter().copied() {
        let scaled = sample * ceiling;
        let pcm = scaled.clamp(floor, ceiling) as i16;
        writer.write_sample(pcm).context("WAV write error")?;
    }
    writer.finalize().context("WAV finalise error")?;

    let sample_count = audio.len();
    let seconds = sample_count as f32 / SAMPLE_RATE as f32;
    println!(
        "Saved {} samples ({} s) to {}",
        sample_count,
        seconds,
        output_path.display()
    );
    Ok(())
}
/// Synthesises arbitrary text with the given voice and speed.
///
/// When `clean_text` is set the text is first run through the preprocessor.
/// The text is then split with `chunk_text` using `CHUNK_MAX_CHARS`, each
/// chunk is synthesised with `generate_chunk`, and the audio is concatenated
/// in order. Text that yields no chunks produces an empty vector.
///
/// # Errors
///
/// Fails if the voice is unknown or if synthesis of any chunk fails; the
/// first failure stops processing.
#[cfg(feature = "espeak")]
pub fn generate(
    &self,
    text: &str,
    voice: &str,
    speed: f32,
    clean_text: bool,
) -> Result<Vec<f32>> {
    // Reject unknown voices before doing any text processing.
    let known_voice = self.voices.contains_key(self.resolve_voice(voice));
    if !known_voice {
        return Err(anyhow::anyhow!(
            "Unknown voice '{}'. Available: {:?}",
            voice,
            self.available_voices
        ));
    }

    let processed = if clean_text {
        self.preprocessor.process(text)
    } else {
        text.to_string()
    };

    let mut audio = Vec::new();
    for chunk in chunk_text(&processed, CHUNK_MAX_CHARS) {
        let samples = self.generate_chunk(&chunk, voice, speed)?;
        audio.extend(samples);
    }
    Ok(audio)
}
/// Synthesises `text` and writes the result to `output_path` as a WAV file.
///
/// # Errors
///
/// Fails if synthesis fails or the WAV file cannot be written.
#[cfg(feature = "espeak")]
pub fn generate_to_file(
    &self,
    text: &str,
    output_path: &Path,
    voice: &str,
    speed: f32,
    clean_text: bool,
) -> Result<()> {
    self.generate(text, voice, speed, clean_text)
        .and_then(|audio| self.write_wav(&audio, output_path))
}
}
#[cfg(all(test, feature = "espeak"))]
mod tests {
    use super::*;

    /// A single short sentence yields one chunk; its terminator is dropped by
    /// the split and a comma is appended instead.
    #[test]
    fn test_chunk_short() {
        let chunks = chunk_text("Hello world.", 400);
        assert_eq!(chunks, vec!["Hello world,"]);
    }

    /// Each sentence becomes its own chunk.
    #[test]
    fn test_chunk_multiple_sentences() {
        let chunks = chunk_text("Hello. World. Foo.", 400);
        assert_eq!(chunks.len(), 3);
    }

    /// An over-long sentence is broken into several bounded chunks.
    #[test]
    fn test_chunk_long_sentence() {
        let long = ["word"; 200].join(" ");
        let chunks = chunk_text(&long, 400);
        assert!(chunks.len() > 1);
        assert!(chunks.iter().all(|chunk| chunk.len() <= 405));
    }

    /// Missing punctuation is added; existing punctuation and empty input
    /// are left alone.
    #[test]
    fn test_ensure_punctuation() {
        let cases = [("hello", "hello,"), ("hello.", "hello."), ("", "")];
        for (input, expected) in cases {
            assert_eq!(ensure_punctuation(input), expected);
        }
    }
}