use std::collections::HashMap;
use std::fs::{self, File};
use std::io::Write;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use espeak_rs::text_to_phonemes;
use ndarray::{ArrayBase, IxDyn, OwnedRepr};
use ndarray_npy::NpzReader;
use ort::{
session::{builder::GraphOptimizationLevel, Session, SessionInputs, SessionInputValue},
value::{Tensor, Value},
};
#[cfg(feature = "playback")]
use rodio::{Decoder, OutputStream, Sink, Source};
#[cfg(feature = "playback")]
use std::io::Cursor;
const MODEL_URL: &str = "https://github.com/thewh1teagle/kokoro-onnx/releases/download/model-files-v1.0/kokoro-v1.0.onnx";
const VOICES_URL: &str = "https://github.com/thewh1teagle/kokoro-onnx/releases/download/model-files-v1.0/voices-v1.0.bin";
const SAMPLE_RATE: u32 = 24000;
const DEFAULT_VOICE: &str = "af_sky";
const DEFAULT_SPEED: f32 = 1.0;
fn get_cache_dir() -> PathBuf {
let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
Path::new(&home).join(".cache").join("kokoros")
}
pub struct TtsEngine {
session: Arc<Mutex<Session>>,
voices: HashMap<String, Vec<f32>>,
vocab: HashMap<char, i64>,
}
impl TtsEngine {
pub async fn new() -> Result<Self, String> {
let cache_dir = get_cache_dir();
let model_path = cache_dir.join("kokoro-v1.0.onnx");
let voices_path = cache_dir.join("voices-v1.0.bin");
Self::with_paths(
model_path.to_str().unwrap_or("kokoro-v1.0.onnx"),
voices_path.to_str().unwrap_or("voices-v1.0.bin")
).await
}
pub async fn with_paths(model_path: &str, voices_path: &str) -> Result<Self, String> {
if let Some(parent) = Path::new(model_path).parent() {
fs::create_dir_all(parent)
.map_err(|e| format!("Failed to create cache directory: {}", e))?;
}
if !Path::new(model_path).exists() {
download_file(MODEL_URL, model_path).await
.map_err(|e| format!("Failed to download model: {}", e))?;
}
if !Path::new(voices_path).exists() {
download_file(VOICES_URL, voices_path).await
.map_err(|e| format!("Failed to download voices: {}", e))?;
}
let model_bytes = std::fs::read(model_path)
.map_err(|e| format!("Failed to read model file: {}", e))?;
let session = Session::builder()
.map_err(|e| format!("Failed to create session builder: {}", e))?
.with_optimization_level(GraphOptimizationLevel::Level3)
.map_err(|e| format!("Failed to set optimization level: {}", e))?
.with_intra_threads(4)
.map_err(|e| format!("Failed to set intra threads: {}", e))?
.commit_from_memory(&model_bytes)
.map_err(|e| format!("Failed to load model: {}", e))?;
let voices = load_voices(voices_path)
.map_err(|e| format!("Failed to load voices: {}", e))?;
let vocab = build_vocab();
Ok(Self {
session: Arc::new(Mutex::new(session)),
voices,
vocab,
})
}
pub fn synthesize(&mut self, text: &str, voice: Option<&str>) -> Result<Vec<f32>, String> {
let voice = voice.unwrap_or(DEFAULT_VOICE);
let style = self.voices.get(voice)
.ok_or_else(|| format!("Voice '{}' not found", voice))?
.clone();
let phonemes = text_to_phonemes(text, "en", None, true, false)
.map_err(|e| format!("Failed to convert text to phonemes: {:?}", e))?;
let phonemes_str = phonemes.join("");
let tokens = self.tokenize(&phonemes_str);
let audio = self.infer(tokens, style, DEFAULT_SPEED)?;
Ok(audio)
}
#[cfg(feature = "playback")]
pub fn play(&self, audio: &[f32], volume: f32) -> Result<(), String> {
let wav_data = self.to_wav_bytes(audio)?;
let (_stream, stream_handle) = OutputStream::try_default()
.map_err(|e| format!("Failed to get audio output: {}", e))?;
let sink = Sink::try_new(&stream_handle)
.map_err(|e| format!("Failed to create audio sink: {}", e))?;
let cursor = Cursor::new(wav_data);
let source = Decoder::new(cursor)
.map_err(|e| format!("Failed to decode audio: {}", e))?;
sink.set_volume(volume.clamp(0.0, 1.0));
sink.append(source);
sink.sleep_until_end();
Ok(())
}
#[cfg(feature = "playback")]
pub fn list_devices() -> Vec<String> {
use cpal::traits::{HostTrait, DeviceTrait};
if let Ok(devices) = cpal::default_host().output_devices() {
devices.filter_map(|device| device.name().ok()).collect()
} else {
vec!["default".to_string()]
}
}
fn to_wav_bytes(&self, audio: &[f32]) -> Result<Vec<u8>, String> {
let mut buffer = Vec::new();
let spec = hound::WavSpec {
channels: 1,
sample_rate: SAMPLE_RATE,
bits_per_sample: 16,
sample_format: hound::SampleFormat::Int,
};
{
let mut writer = hound::WavWriter::new(Cursor::new(&mut buffer), spec)
.map_err(|e| format!("Failed to create WAV writer: {}", e))?;
for &sample in audio {
let sample_i16 = (sample * 32767.0).clamp(-32768.0, 32767.0) as i16;
writer.write_sample(sample_i16)
.map_err(|e| format!("Failed to write sample: {}", e))?;
}
writer.finalize()
.map_err(|e| format!("Failed to finalize WAV: {}", e))?;
}
Ok(buffer)
}
pub fn voices(&self) -> Vec<String> {
self.voices.keys().cloned().collect()
}
pub fn save_wav(&self, path: &str, audio: &[f32]) -> Result<(), String> {
if let Some(parent) = Path::new(path).parent() {
fs::create_dir_all(parent)
.map_err(|e| format!("Failed to create directory: {}", e))?;
}
let spec = hound::WavSpec {
channels: 1,
sample_rate: SAMPLE_RATE,
bits_per_sample: 16,
sample_format: hound::SampleFormat::Int,
};
let mut writer = hound::WavWriter::create(path, spec)
.map_err(|e| format!("Failed to create WAV file: {}", e))?;
for &sample in audio {
let sample_i16 = (sample * 32767.0).clamp(-32768.0, 32767.0) as i16;
writer.write_sample(sample_i16)
.map_err(|e| format!("Failed to write sample: {}", e))?;
}
writer.finalize()
.map_err(|e| format!("Failed to finalize WAV: {}", e))?;
Ok(())
}
#[cfg(feature = "mp3")]
pub fn save_mp3(&self, path: &str, audio: &[f32]) -> Result<(), String> {
use mp3lame_encoder::{Builder, Encoder, FlushNoGap};
if let Some(parent) = Path::new(path).parent() {
fs::create_dir_all(parent)
.map_err(|e| format!("Failed to create directory: {}", e))?;
}
let samples: Vec<i16> = audio.iter()
.map(|&s| (s * 32767.0).clamp(-32768.0, 32767.0) as i16)
.collect();
let mut encoder = Builder::new()
.map_err(|e| format!("Failed to create MP3 encoder: {:?}", e))?
.set_num_channels(1)
.map_err(|e| format!("Failed to set channels: {:?}", e))?
.set_sample_rate(SAMPLE_RATE)
.map_err(|e| format!("Failed to set sample rate: {:?}", e))?
.set_brate(mp3lame_encoder::Bitrate::Kbps128)
.map_err(|e| format!("Failed to set bitrate: {:?}", e))?
.set_quality(mp3lame_encoder::Quality::Best)
.map_err(|e| format!("Failed to set quality: {:?}", e))?
.build()
.map_err(|e| format!("Failed to build encoder: {:?}", e))?;
let mut mp3_data = Vec::new();
let encoded = encoder.encode(&samples)
.map_err(|e| format!("Failed to encode: {:?}", e))?;
mp3_data.extend_from_slice(&encoded);
let encoded = encoder.flush::<FlushNoGap>()
.map_err(|e| format!("Failed to flush: {:?}", e))?;
mp3_data.extend_from_slice(&encoded);
let mut file = File::create(path)
.map_err(|e| format!("Failed to create file: {}", e))?;
file.write_all(&mp3_data)
.map_err(|e| format!("Failed to write file: {}", e))?;
Ok(())
}
#[cfg(feature = "opus-format")]
pub fn save_opus(&self, path: &str, audio: &[f32], bitrate: i32) -> Result<(), String> {
use audiopus::{coder::Encoder, Channels, SampleRate, Application};
if let Some(parent) = Path::new(path).parent() {
fs::create_dir_all(parent)
.map_err(|e| format!("Failed to create directory: {}", e))?;
}
let samples: Vec<i16> = audio.iter()
.map(|&s| (s * 32767.0).clamp(-32768.0, 32767.0) as i16)
.collect();
let mut encoder = Encoder::new(
SampleRate::Hz24000,
Channels::Mono,
Application::Audio
).map_err(|e| format!("Failed to create OPUS encoder: {:?}", e))?;
encoder.set_bitrate(bitrate)
.map_err(|e| format!("Failed to set bitrate: {:?}", e))?;
let frame_size = 480; let mut opus_data = Vec::new();
for chunk in samples.chunks(frame_size) {
let mut encoded = vec![0u8; 4000];
let len = encoder.encode(chunk, &mut encoded)
.map_err(|e| format!("Failed to encode OPUS: {:?}", e))?;
opus_data.extend_from_slice(&encoded[..len]);
}
let mut file = File::create(path)
.map_err(|e| format!("Failed to create file: {}", e))?;
file.write_all(&opus_data)
.map_err(|e| format!("Failed to write file: {}", e))?;
Ok(())
}
#[cfg(feature = "flac-format")]
pub fn save_flac(&self, path: &str, audio: &[f32]) -> Result<(), String> {
use flac::StreamWriter;
if let Some(parent) = Path::new(path).parent() {
fs::create_dir_all(parent)
.map_err(|e| format!("Failed to create directory: {}", e))?;
}
let samples: Vec<i32> = audio.iter()
.map(|&s| (s * 8388607.0).clamp(-8388608.0, 8388607.0) as i32)
.collect();
let file = File::create(path)
.map_err(|e| format!("Failed to create file: {}", e))?;
let mut writer = StreamWriter::new(file, 24)
.map_err(|e| format!("Failed to create FLAC writer: {:?}", e))?;
for sample in samples {
writer.write_sample(sample)
.map_err(|e| format!("Failed to write FLAC sample: {:?}", e))?;
}
writer.finalize()
.map_err(|e| format!("Failed to finalize FLAC: {:?}", e))?;
Ok(())
}
pub fn save_audio(&self, path: &str, audio: &[f32]) -> Result<(), String> {
let extension = Path::new(path)
.extension()
.and_then(|ext| ext.to_str())
.unwrap_or("")
.to_lowercase();
match extension.as_str() {
"wav" => self.save_wav(path, audio),
#[cfg(feature = "mp3")]
"mp3" => self.save_mp3(path, audio),
#[cfg(feature = "opus-format")]
"opus" | "ogg" => self.save_opus(path, audio, 24000),
#[cfg(feature = "flac-format")]
"flac" => self.save_flac(path, audio),
_ => Err(format!("Unsupported audio format: {}", extension))
}
}
fn tokenize(&self, phonemes: &str) -> Vec<Vec<i64>> {
let tokens: Vec<i64> = phonemes
.chars()
.filter_map(|c| self.vocab.get(&c).copied())
.collect();
vec![tokens]
}
fn infer(&mut self, tokens: Vec<Vec<i64>>, style: Vec<f32>, speed: f32) -> Result<Vec<f32>, String> {
let mut session = self.session.lock().unwrap();
let tokens_shape = [tokens.len(), tokens[0].len()];
let tokens_flat: Vec<i64> = tokens.into_iter().flatten().collect();
let tokens_tensor = Tensor::from_array((tokens_shape, tokens_flat))
.map_err(|e| format!("Failed to create tokens tensor: {}", e))?;
let style_shape = [1, style.len()];
let style_tensor = Tensor::from_array((style_shape, style))
.map_err(|e| format!("Failed to create style tensor: {}", e))?;
let speed_tensor = Tensor::from_array(([1], vec![speed]))
.map_err(|e| format!("Failed to create speed tensor: {}", e))?;
use std::borrow::Cow;
let inputs = SessionInputs::from(vec![
(Cow::Borrowed("tokens"), SessionInputValue::Owned(Value::from(tokens_tensor))),
(Cow::Borrowed("style"), SessionInputValue::Owned(Value::from(style_tensor))),
(Cow::Borrowed("speed"), SessionInputValue::Owned(Value::from(speed_tensor))),
]);
let outputs = session.run(inputs)
.map_err(|e| format!("Failed to run inference: {}", e))?;
let (_shape, data) = outputs["audio"]
.try_extract_tensor::<f32>()
.map_err(|e| format!("Failed to extract audio tensor: {}", e))?;
Ok(data.to_vec())
}
}
fn build_vocab() -> HashMap<char, i64> {
let pad = "$";
let punctuation = r#";:,.!?¡¿—…"«»"" "#;
let letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
let letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ";
let symbols: String = [pad, punctuation, letters, letters_ipa].concat();
symbols
.chars()
.enumerate()
.map(|(idx, c)| (c, idx as i64))
.collect()
}
fn load_voices(path: &str) -> Result<HashMap<String, Vec<f32>>, String> {
let mut npz = NpzReader::new(File::open(path).map_err(|e| format!("Failed to open voices file: {}", e))?)
.map_err(|e| format!("Failed to read NPZ: {:?}", e))?;
let mut voices = HashMap::new();
for name in npz.names().map_err(|e| format!("Failed to get NPZ names: {:?}", e))? {
let arr: ArrayBase<OwnedRepr<f32>, IxDyn> = npz.by_name(&name)
.map_err(|e| format!("Failed to read voice {}: {:?}", name, e))?;
let shape = arr.shape();
if shape.len() == 3 && shape[1] == 1 && shape[2] == 256 {
let start = 0;
let end = 256;
let data = arr.as_slice()
.ok_or_else(|| format!("Failed to get slice for voice {}", name))?[start..end]
.to_vec();
voices.insert(name.trim_end_matches(".npy").to_string(), data);
}
}
Ok(voices)
}
async fn download_file(url: &str, path: &str) -> Result<(), Box<dyn std::error::Error>> {
if let Some(parent) = Path::new(path).parent() {
fs::create_dir_all(parent)?;
}
println!("Downloading {} to {}...", url, path);
let response = reqwest::get(url).await?;
let bytes = response.bytes().await?;
let mut file = File::create(path)?;
file.write_all(&bytes)?;
println!("Downloaded successfully!");
Ok(())
}
pub struct TtsBuilder {
model_path: String,
voices_path: String,
}
impl Default for TtsBuilder {
fn default() -> Self {
let cache_dir = get_cache_dir();
Self {
model_path: cache_dir.join("kokoro-v1.0.onnx").to_str().unwrap_or("kokoro-v1.0.onnx").to_string(),
voices_path: cache_dir.join("voices-v1.0.bin").to_str().unwrap_or("voices-v1.0.bin").to_string(),
}
}
}
impl TtsBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn model_path(mut self, path: &str) -> Self {
self.model_path = path.to_string();
self
}
pub fn voices_path(mut self, path: &str) -> Self {
self.voices_path = path.to_string();
self
}
pub async fn build(self) -> Result<TtsEngine, String> {
TtsEngine::with_paths(&self.model_path, &self.voices_path).await
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_engine_creation() {
}
#[test]
fn test_builder_pattern() {
let _builder = TtsBuilder::new()
.model_path("custom_model.onnx")
.voices_path("custom_voices.bin");
}
}