use std::collections::HashMap;
use std::fs::{self, File};
use std::io::{Cursor, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use espeak_rs::text_to_phonemes;
use ndarray::{ArrayBase, IxDyn, OwnedRepr};
use ndarray_npy::NpzReader;
use ort::{
session::{builder::GraphOptimizationLevel, Session, SessionInputValue, SessionInputs},
value::{Tensor, Value},
};
// Logs to stderr only when the KOKORO_DEBUG environment variable is set to "1".
// Accepts the same arguments as `eprintln!`; a no-op otherwise.
macro_rules! debug_log {
    ($($arg:tt)*) => {
        if std::env::var("KOKORO_DEBUG").map(|v| v == "1").unwrap_or(false) {
            eprintln!($($arg)*);
        }
    };
}
// Remote locations of the ONNX model and the voice-style archive,
// downloaded into the cache directory on first run.
const MODEL_URL: &str = "https://github.com/8b-is/kokoro-tiny/raw/main/models/0.onnx";
const VOICES_URL: &str = "https://github.com/8b-is/kokoro-tiny/raw/main/models/0.bin";
// Output sample rate of the model, in Hz.
const SAMPLE_RATE: u32 = 24000;
// Defaults used when the caller does not specify a voice / language.
const DEFAULT_VOICE: &str = "af_sky";
// NOTE(review): DEFAULT_SPEED is not referenced in this file — presumably
// used by external callers; verify before removing.
const DEFAULT_SPEED: f32 = 1.0;
const DEFAULT_LANG: &str = "en";
// User-facing speed is multiplied by this factor before reaching the model.
const SPEED_SCALE: f32 = 0.65;
// Inputs longer than this many characters (or more than a few lines) are
// synthesized in chunks and stitched back together.
const LONG_TEXT_THRESHOLD: usize = 120;
// Upper bound on the size of each synthesized chunk.
const MAX_CHARS_PER_CHUNK: usize = 180;
// Duration of the crossfade between adjacent chunks, in milliseconds.
const CHUNK_CROSSFADE_MS: usize = 45;
// Hard clamp applied to the speed value passed to the model.
const MIN_ENGINE_SPEED: f32 = 0.35;
const MAX_ENGINE_SPEED: f32 = 2.2;
// Padding character that frames the phoneme string (token id 0 in the vocab).
// NOTE(review): synthesize_segment inserts "$" as a literal rather than
// referencing this constant.
const PAD_TOKEN: char = '$';
// Canned WAV message played when the model files are unavailable (fallback mode).
const FALLBACK_MESSAGE: &[u8] = include_bytes!("../assets/fallback.wav");
fn get_cache_dir() -> PathBuf {
let home = std::env::var("HOME")
.or_else(|_| std::env::var("USERPROFILE"))
.unwrap_or_else(|_| {
debug_log!("⚠️ Could not determine HOME directory, using current directory");
".".to_string()
});
Path::new(&home).join(".cache").join("k")
}
/// Text-to-speech engine backed by a Kokoro ONNX model.
pub struct TtsEngine {
    /// ONNX Runtime session; `None` when running in fallback mode.
    session: Option<Arc<Mutex<Session>>>,
    /// Voice name -> flattened style-embedding data loaded from the NPZ voices file.
    voices: HashMap<String, Vec<f32>>,
    /// Phoneme character -> token id mapping fed to the model.
    vocab: HashMap<char, i64>,
    /// True when the model files were unavailable; synthesis then returns a canned WAV.
    fallback_mode: bool,
}
impl TtsEngine {
    /// Creates an engine using the default cache directory (`~/.cache/k`),
    /// downloading the model and voices files there on first use.
    pub async fn new() -> Result<Self, String> {
        let cache_dir = get_cache_dir();
        let model_path = cache_dir.join("0.onnx");
        let voices_path = cache_dir.join("0.bin");
        Self::with_paths(
            // Fall back to bare filenames (current directory) if the cache
            // path is not valid UTF-8.
            model_path.to_str().unwrap_or("0.onnx"),
            voices_path.to_str().unwrap_or("0.bin"),
        )
        .await
    }
    /// Creates an engine from explicit model/voices file paths.
    ///
    /// Missing files are downloaded. If any download fails, the engine is
    /// constructed in fallback mode (no session, canned-WAV playback)
    /// rather than returning an error; `Err` is reserved for local I/O or
    /// model-loading failures.
    pub async fn with_paths(model_path: &str, voices_path: &str) -> Result<Self, String> {
        if let Some(parent) = Path::new(model_path).parent() {
            fs::create_dir_all(parent)
                .map_err(|e| format!("Failed to create cache directory: {}", e))?;
        }
        let need_download = !Path::new(model_path).exists() || !Path::new(voices_path).exists();
        if need_download {
            debug_log!("🎤 First time setup - downloading voice model...");
            debug_log!(" (This only happens once, files will be cached in ~/.cache/k)");
            let download_success = {
                let mut success = true;
                if !Path::new(model_path).exists() {
                    debug_log!(" 📥 Downloading model (310MB)...");
                    if let Err(e) = download_file(MODEL_URL, model_path).await {
                        debug_log!(" ❌ Failed to download model: {}", e);
                        success = false;
                    }
                }
                // Only attempt the voices download if the model succeeded.
                if success && !Path::new(voices_path).exists() {
                    debug_log!(" 📥 Downloading voices (27MB)...");
                    if let Err(e) = download_file(VOICES_URL, voices_path).await {
                        debug_log!(" ❌ Failed to download voices: {}", e);
                        success = false;
                    }
                }
                if success {
                    debug_log!(" ✅ Voice model downloaded successfully!");
                }
                success
            };
            if !download_success {
                debug_log!("\n⚠️ Using fallback mode. The model files are not available at:");
                debug_log!(" - {}", MODEL_URL);
                debug_log!(" - {}", VOICES_URL);
                debug_log!("\n💡 Please manually download the model files to ~/.cache/k/");
                // Degrade gracefully: no session, no voices; synthesize_*
                // will play the bundled fallback WAV instead.
                return Ok(Self {
                    session: None,
                    voices: HashMap::new(),
                    vocab: build_vocab(),
                    fallback_mode: true,
                });
            }
        }
        let model_bytes =
            std::fs::read(model_path).map_err(|e| format!("Failed to read model file: {}", e))?;
        let session = Session::builder()
            .map_err(|e| format!("Failed to create session builder: {}", e))?
            // Level3 enables the full set of graph optimizations.
            .with_optimization_level(GraphOptimizationLevel::Level3)
            .map_err(|e| format!("Failed to set optimization level: {}", e))?
            .commit_from_memory(&model_bytes)
            .map_err(|e| format!("Failed to load model: {}", e))?;
        let voices = load_voices(voices_path)?;
        Ok(Self {
            session: Some(Arc::new(Mutex::new(session))),
            voices,
            vocab: build_vocab(),
            fallback_mode: false,
        })
    }
    /// Lists available voice names ("fallback" when running in fallback mode).
    /// Order is unspecified (HashMap iteration order).
    pub fn voices(&self) -> Vec<String> {
        if self.fallback_mode {
            vec!["fallback".to_string()]
        } else {
            self.voices.keys().cloned().collect()
        }
    }
    /// Synthesizes `text` into mono f32 samples at SAMPLE_RATE.
    ///
    /// * `voice` - voice name, optionally blended ("a+b") or weighted
    ///   ("name.5" ~ 0.5 weight, see `parse_voice_style`); defaults to DEFAULT_VOICE.
    /// * `speed` - user-facing speed multiplier; scaled by SPEED_SCALE and
    ///   clamped to [MIN_ENGINE_SPEED, MAX_ENGINE_SPEED] before inference.
    /// * `gain` - linear amplitude multiplier applied after synthesis (1.0 = unchanged).
    /// * `lang` - espeak language code; defaults to DEFAULT_LANG.
    ///
    /// Long inputs (see `needs_chunking`) are split into sentence-sized
    /// chunks, synthesized separately, and joined with a short crossfade.
    /// In fallback mode the bundled WAV is returned regardless of arguments.
    pub fn synthesize_with_options(
        &mut self,
        text: &str,
        voice: Option<&str>,
        speed: f32,
        gain: f32,
        lang: Option<&str>,
    ) -> Result<Vec<f32>, String> {
        if self.fallback_mode {
            debug_log!("🎤 Playing fallback message while downloading voice model...");
            return wav_to_f32(FALLBACK_MESSAGE);
        }
        let session = self
            .session
            .as_ref()
            .ok_or_else(|| "TTS engine not initialized".to_string())?;
        let model_speed = speed * SPEED_SCALE;
        let clamped_speed = model_speed.clamp(MIN_ENGINE_SPEED, MAX_ENGINE_SPEED);
        let voice = voice.unwrap_or(DEFAULT_VOICE);
        // Short input: single inference pass, no chunking.
        if !needs_chunking(text) {
            let mut audio = self.synthesize_segment(session, voice, text, clamped_speed, lang)?;
            if gain != 1.0 {
                audio = amplify_audio(&audio, gain);
            }
            return Ok(audio);
        }
        let prepared_chunks: Vec<String> = split_text_for_tts(text, MAX_CHARS_PER_CHUNK)
            .into_iter()
            .filter(|chunk| !chunk.trim().is_empty())
            .collect();
        if prepared_chunks.is_empty() {
            return Err("No text provided for synthesis".to_string());
        }
        let chunk_count = prepared_chunks.len();
        debug_log!(
            "📚 Long-form synthesis enabled: {} chars -> {} chunk(s) (≤ {} chars each)",
            text.chars().count(),
            chunk_count,
            MAX_CHARS_PER_CHUNK
        );
        let overlap = chunk_crossfade_samples();
        let mut combined_audio = Vec::new();
        for (idx, chunk) in prepared_chunks.iter().enumerate() {
            debug_log!(
                " → Chunk {}/{} ({} chars)",
                idx + 1,
                chunk_count,
                chunk.chars().count()
            );
            let chunk_audio =
                self.synthesize_segment(session, voice, chunk, clamped_speed, lang)?;
            append_with_crossfade(&mut combined_audio, &chunk_audio, overlap);
        }
        if combined_audio.is_empty() {
            return Err("Failed to synthesize combined audio".to_string());
        }
        let mut final_audio = combined_audio;
        if gain != 1.0 {
            final_audio = amplify_audio(&final_audio, gain);
        }
        Ok(final_audio)
    }
    /// Synthesizes one chunk: text -> phonemes (espeak) -> token ids ->
    /// style vector -> ONNX inference.
    fn synthesize_segment(
        &self,
        session: &Arc<Mutex<Session>>,
        voice: &str,
        text: &str,
        speed: f32,
        lang: Option<&str>,
    ) -> Result<Vec<f32>, String> {
        let phonemes = text_to_phonemes(text, lang.unwrap_or(DEFAULT_LANG), None, true, false)
            .map_err(|e| format!("Failed to convert text to phonemes: {}", e))?;
        let mut phonemes_text = phonemes.join(" ");
        // Frame the phoneme string with the pad character (vocab id 0).
        // NOTE(review): writes "$" literally instead of using PAD_TOKEN.
        phonemes_text.insert_str(0, "$");
        phonemes_text.push_str("$");
        if text.len() > 50 {
            debug_log!(" Text length: {} chars", text.len());
            debug_log!(" Phonemes array: {} entries", phonemes.len());
            debug_log!(" Phoneme text length: {} chars", phonemes_text.len());
        }
        let tokens = self.tokenize(phonemes_text);
        let style = self.parse_voice_style(voice, tokens.len())?;
        self.run_inference(session, tokens, style, speed)
    }
    /// Writes `audio` (mono f32 samples in [-1.0, 1.0]) to `path` as a
    /// 16-bit PCM WAV at SAMPLE_RATE.
    pub fn save_wav(&self, path: &str, audio: &[f32]) -> Result<(), String> {
        let spec = hound::WavSpec {
            channels: 1,
            sample_rate: SAMPLE_RATE,
            bits_per_sample: 16,
            sample_format: hound::SampleFormat::Int,
        };
        let mut writer = hound::WavWriter::create(path, spec)
            .map_err(|e| format!("Failed to create WAV file: {}", e))?;
        for &sample in audio {
            // Scale to i16 range, clamping to avoid overflow on out-of-range samples.
            let sample_i16 = (sample * 32767.0).clamp(-32768.0, 32767.0) as i16;
            writer
                .write_sample(sample_i16)
                .map_err(|e| format!("Failed to write sample: {}", e))?;
        }
        writer
            .finalize()
            .map_err(|e| format!("Failed to finalize WAV: {}", e))?;
        Ok(())
    }
    /// Resolves a voice specifier into a 256-float style vector.
    ///
    /// Syntax: "name" for a single voice; "a+b" sums several voices;
    /// "name.N" weights a voice by N parsed as f32 divided by 10
    /// (so "name.5" contributes with weight 0.5).
    ///
    /// NOTE(review): the voice data appears to be laid out as consecutive
    /// 256-float rows with the row chosen by the input token count
    /// (clamped to the last full row) — confirm against the voices
    /// archive layout before relying on this description.
    fn parse_voice_style(&self, voice_str: &str, tokens_len: usize) -> Result<Vec<f32>, String> {
        if self.fallback_mode {
            // No voices loaded in fallback mode; a zero style keeps callers working.
            return Ok(vec![0.0; 256]);
        }
        let mut result = vec![0.0; 256];
        let parts: Vec<&str> = voice_str.split('+').collect();
        for part in parts {
            let (voice_name, weight) = if part.contains('.') {
                let pieces: Vec<&str> = part.split('.').collect();
                if pieces.len() != 2 {
                    return Err(format!("Invalid voice format: {}", part));
                }
                let weight = pieces[1]
                    .parse::<f32>()
                    .map_err(|_| format!("Invalid weight: {}", pieces[1]))?;
                (pieces[0], weight / 10.0)
            } else {
                // Unweighted voices contribute fully.
                (part, 1.0)
            };
            let voice_style = self
                .voices
                .get(voice_name)
                .ok_or_else(|| format!("Voice not found: {}", voice_name))?;
            let style_dim: usize = 256;
            // Index of the last complete 256-float row in the voice data.
            let max_idx = voice_style.len().saturating_sub(style_dim) / style_dim;
            let idx = tokens_len.min(max_idx);
            let offset = idx * style_dim;
            let slice_end = (offset + style_dim).min(voice_style.len());
            // Accumulate the weighted row into the blended result.
            for (i, val) in voice_style[offset..slice_end].iter().enumerate() {
                if i < result.len() {
                    result[i] += val * weight;
                }
            }
        }
        Ok(result)
    }
    /// Maps each character of the phoneme string to its vocab id;
    /// characters not in the vocab become token 0 (the pad token).
    fn tokenize(&self, text: String) -> Vec<i64> {
        text.chars()
            .map(|c| *self.vocab.get(&c).unwrap_or(&0))
            .collect()
    }
    /// Runs one forward pass of the model.
    ///
    /// Input tensors: "tokens" (1 x n, i64), "style" (1 x 256, f32),
    /// "speed" (scalar-ish [1], f32); the audio samples are read from the
    /// "audio" output. These names must match the ONNX graph.
    fn run_inference(
        &self,
        session: &Arc<Mutex<Session>>,
        tokens: Vec<i64>,
        style: Vec<f32>,
        speed: f32,
    ) -> Result<Vec<f32>, String> {
        // Serialize access to the session; ort sessions are run one call at a time here.
        let mut session = session
            .lock()
            .map_err(|e| format!("Failed to lock session: {}", e))?;
        let token_count = tokens.len();
        let tokens_array = ndarray::Array2::from_shape_vec((1, tokens.len()), tokens)
            .map_err(|e| format!("Failed to create tokens array: {}", e))?;
        let tokens_tensor = Tensor::from_array(tokens_array)
            .map_err(|e| format!("Failed to create tokens tensor: {}", e))?;
        let style_array = ndarray::Array2::from_shape_vec((1, style.len()), style)
            .map_err(|e| format!("Failed to create style array: {}", e))?;
        let style_tensor = Tensor::from_array(style_array)
            .map_err(|e| format!("Failed to create style tensor: {}", e))?;
        let speed_array = ndarray::Array1::from_vec(vec![speed]);
        let speed_tensor = Tensor::from_array(speed_array)
            .map_err(|e| format!("Failed to create speed tensor: {}", e))?;
        use std::borrow::Cow;
        let inputs = SessionInputs::from(vec![
            (
                Cow::Borrowed("tokens"),
                SessionInputValue::Owned(Value::from(tokens_tensor)),
            ),
            (
                Cow::Borrowed("style"),
                SessionInputValue::Owned(Value::from(style_tensor)),
            ),
            (
                Cow::Borrowed("speed"),
                SessionInputValue::Owned(Value::from(speed_tensor)),
            ),
        ]);
        let outputs = session
            .run(inputs)
            .map_err(|e| format!("Failed to run inference: {}", e))?;
        let (shape, data) = outputs["audio"]
            .try_extract_tensor::<f32>()
            .map_err(|e| format!("Failed to extract audio tensor: {}", e))?;
        let data_vec = data.to_vec();
        if token_count > 100 {
            debug_log!(
                " Output audio shape: {:?}, samples: {}",
                shape,
                data_vec.len()
            );
        }
        Ok(data_vec)
    }
}
/// Builds the phoneme-character -> token-id table used by `tokenize`.
///
/// Ids follow concatenation order: pad token first (id 0), then punctuation,
/// ASCII letters, and finally the IPA symbols. A character that occurs more
/// than once keeps its last (highest) id, matching the original collect
/// behavior on duplicates.
fn build_vocab() -> HashMap<char, i64> {
    let pad = "$";
    let punctuation = r#";:,.!?¡¿—…"«»“” "#;
    let letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    let letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ";
    let mut vocab = HashMap::new();
    let mut next_id: i64 = 0;
    for c in pad
        .chars()
        .chain(punctuation.chars())
        .chain(letters.chars())
        .chain(letters_ipa.chars())
    {
        vocab.insert(c, next_id);
        next_id += 1;
    }
    vocab
}
/// Loads every voice style table from the NPZ archive at `path`.
/// Map keys are the archive entry names with any trailing ".npy" stripped;
/// values are the arrays flattened into `Vec<f32>`.
fn load_voices(path: &str) -> Result<HashMap<String, Vec<f32>>, String> {
    let mut file = File::open(path).map_err(|e| format!("Failed to open voices file: {}", e))?;
    let mut reader =
        NpzReader::new(&mut file).map_err(|e| format!("Failed to create NPZ reader: {}", e))?;
    let names = reader
        .names()
        .map_err(|e| format!("Failed to read NPZ names: {:?}", e))?;
    let mut voices = HashMap::with_capacity(names.len());
    for name in names {
        let array: ArrayBase<OwnedRepr<f32>, IxDyn> = reader
            .by_name(&name)
            .map_err(|e| format!("Failed to read NPZ array {}: {:?}", name, e))?;
        let flattened: Vec<f32> = array.iter().copied().collect();
        voices.insert(name.trim_end_matches(".npy").to_owned(), flattened);
    }
    Ok(voices)
}
async fn download_file(url: &str, path: &str) -> Result<(), Box<dyn std::error::Error>> {
let response = reqwest::get(url).await?;
let bytes = response.bytes().await?;
let mut file = File::create(path)?;
file.write_all(&bytes)?;
Ok(())
}
/// Decodes 16-bit PCM WAV bytes into f32 samples normalized to [-1.0, 1.0).
/// Returns an error string on malformed WAV data; stops at the first bad sample.
fn wav_to_f32(wav_bytes: &[u8]) -> Result<Vec<f32>, String> {
    let mut reader = hound::WavReader::new(Cursor::new(wav_bytes))
        .map_err(|e| format!("Failed to read WAV: {}", e))?;
    let mut decoded = Vec::new();
    for sample in reader.samples::<i16>() {
        let value = sample.map_err(|e| format!("Failed to read samples: {}", e))?;
        decoded.push(value as f32 / 32768.0);
    }
    Ok(decoded)
}
/// Returns true when `text` is long enough — by character count or by line
/// count — to warrant chunked synthesis with crossfading.
fn needs_chunking(text: &str) -> bool {
    let many_chars = text.chars().count() > LONG_TEXT_THRESHOLD;
    let many_lines = text.lines().count() > 3;
    many_chars || many_lines
}
/// Number of audio samples spanned by the chunk crossfade
/// (CHUNK_CROSSFADE_MS worth of audio at SAMPLE_RATE).
fn chunk_crossfade_samples() -> usize {
    SAMPLE_RATE as usize * CHUNK_CROSSFADE_MS / 1000
}
/// Appends `next` to `buffer`, linearly crossfading the last
/// `overlap_samples` of `buffer` into the first samples of `next`.
/// The overlap is shrunk to fit whichever side is shorter; with no usable
/// overlap (empty buffer, zero request) `next` is appended verbatim.
fn append_with_crossfade(buffer: &mut Vec<f32>, next: &[f32], overlap_samples: usize) {
    if next.is_empty() {
        return;
    }
    let overlap = if buffer.is_empty() {
        0
    } else {
        overlap_samples.min(buffer.len()).min(next.len())
    };
    if overlap == 0 {
        buffer.extend_from_slice(next);
        return;
    }
    let tail_start = buffer.len() - overlap;
    for (offset, (&incoming, existing)) in next[..overlap]
        .iter()
        .zip(buffer[tail_start..].iter_mut())
        .enumerate()
    {
        // Ramp the new chunk in while ramping the old tail out.
        let fade_in = offset as f32 / overlap as f32;
        *existing = *existing * (1.0 - fade_in) + incoming * fade_in;
    }
    buffer.extend_from_slice(&next[overlap..]);
}
/// Splits `text` into TTS-friendly chunks of at most `max_chars` characters,
/// preferring sentence boundaries, then commas, then word boundaries.
/// Chunks are returned in original text order.
fn split_text_for_tts(text: &str, max_chars: usize) -> Vec<String> {
    let sentences: Vec<&str> = text
        .split_terminator(&['.', '!', '?'][..])
        .filter(|s| !s.trim().is_empty())
        .collect();
    let mut chunks = Vec::new();
    let mut current_chunk = String::new();
    for sentence in sentences {
        // Best-effort re-attachment of the terminator split_terminator dropped.
        // NOTE: `contains` can pick the wrong mark if identical sentence text
        // appears twice with different punctuation; acceptable for TTS output.
        let trimmed = sentence.trim();
        let full_sentence = if text.contains(&format!("{}.", trimmed)) {
            format!("{}.", trimmed)
        } else if text.contains(&format!("{}!", trimmed)) {
            format!("{}!", trimmed)
        } else if text.contains(&format!("{}?", trimmed)) {
            format!("{}?", trimmed)
        } else {
            trimmed.to_string()
        };
        if full_sentence.len() > max_chars {
            // BUGFIX: flush accumulated short sentences BEFORE emitting the
            // oversized sentence's pieces; previously they were appended at
            // the very end, so chunks came out of source order.
            if !current_chunk.is_empty() {
                chunks.push(current_chunk.trim().to_string());
                current_chunk = String::new();
            }
            // Oversized sentence: try comma boundaries first, then words.
            let parts: Vec<&str> = full_sentence.split(',').collect();
            if parts.len() > 1 {
                for part in parts {
                    if part.trim().len() > max_chars {
                        chunks.extend(split_by_words(part, max_chars));
                    } else if !part.trim().is_empty() {
                        chunks.push(part.trim().to_string());
                    }
                }
            } else {
                chunks.extend(split_by_words(&full_sentence, max_chars));
            }
        } else if !current_chunk.is_empty()
            && current_chunk.len() + full_sentence.len() + 1 > max_chars
        {
            // Adding this sentence would overflow: emit the chunk, start anew.
            chunks.push(current_chunk.trim().to_string());
            current_chunk = full_sentence;
        } else {
            // Accumulate short sentences into the current chunk.
            if !current_chunk.is_empty() {
                current_chunk.push(' ');
            }
            current_chunk.push_str(&full_sentence);
        }
    }
    if !current_chunk.is_empty() {
        chunks.push(current_chunk.trim().to_string());
    }
    // No sentence terminators at all: fall back to pure word packing.
    if chunks.is_empty() && !text.trim().is_empty() {
        chunks = split_by_words(text, max_chars);
    }
    chunks
}
/// Greedily packs whitespace-separated words into chunks of at most
/// `max_chars` characters (a single word longer than the limit still
/// becomes its own chunk).
fn split_by_words(text: &str, max_chars: usize) -> Vec<String> {
    let mut chunks = Vec::new();
    let mut current = String::new();
    for word in text.split_whitespace() {
        let would_overflow = !current.is_empty() && current.len() + word.len() + 1 > max_chars;
        if would_overflow {
            chunks.push(current.trim().to_string());
            current = word.to_string();
        } else {
            if !current.is_empty() {
                current.push(' ');
            }
            current.push_str(word);
        }
    }
    if !current.is_empty() {
        chunks.push(current.trim().to_string());
    }
    chunks
}
/// Scales every sample by `gain`, clamping each result to [-1.0, 1.0]
/// to prevent digital clipping artifacts. Returns a new buffer.
fn amplify_audio(audio: &[f32], gain: f32) -> Vec<f32> {
    let mut scaled = Vec::with_capacity(audio.len());
    for &sample in audio {
        scaled.push((sample * gain).clamp(-1.0, 1.0));
    }
    scaled
}
#[cfg(test)]
mod tests {
    use super::*;
    // Crossfading should merge the overlap region in place and append the
    // remainder, so the buffer grows by `next.len() - overlap` samples and
    // ends on the incoming chunk's final (silent) sample.
    #[test]
    fn crossfade_extends_buffer() {
        let mut buffer = vec![1.0, 1.0, 1.0];
        let next = vec![0.0, 0.0, 0.0];
        append_with_crossfade(&mut buffer, &next, 2);
        assert_eq!(buffer.len(), 4);
        assert!((buffer.last().copied().unwrap() - 0.0).abs() < f32::EPSILON);
    }
    // Chunking should trigger once the character count passes
    // LONG_TEXT_THRESHOLD, and stay off for short single-line input.
    #[test]
    fn detects_need_for_chunking() {
        let short = "hello world";
        assert!(!needs_chunking(short));
        let long = "This sentence is intentionally quite a bit longer than the \
short sample so that it exceeds the chunking threshold we set.";
        assert!(needs_chunking(long));
    }
}