use super::voice::VoiceStyle;
use crate::espeak::EspeakIpaTokenizer;
use ndarray::{Array1, Array2, CowArray, IxDyn};
use ort::{Environment, ExecutionProvider, GraphOptimizationLevel, Session, SessionBuilder, Value};
use std::error::Error;
use std::io::Cursor;
use std::path::Path;
use std::sync::Arc;
pub struct TTSConfig {
pub model_path: String,
pub tokenizer_path: String,
pub max_length: usize,
pub sample_rate: u32,
pub graph_level: GraphOptimizationLevel,
pub execution_provider: Vec<ExecutionProvider>,
}
impl TTSConfig {
pub fn new(model_path: &str, tokenizer_path: &str) -> Self {
TTSConfig {
model_path: model_path.to_string(),
tokenizer_path: tokenizer_path.to_string(),
max_length: 512,
sample_rate: 24000,
graph_level: GraphOptimizationLevel::Level3,
execution_provider: vec![],
}
}
pub fn with_max_tokens_length(mut self, max_length: usize) -> Self {
self.max_length = max_length;
self
}
pub fn with_sample_rate(mut self, sample_rate: u32) -> Self {
self.sample_rate = sample_rate;
self
}
pub fn with_graph_optimization_level(mut self, level: GraphOptimizationLevel) -> Self {
self.graph_level = level;
self
}
pub fn with_execution_providers(mut self, providers: Vec<ExecutionProvider>) -> Self {
self.execution_provider = providers;
self
}
}
pub struct GeneratedAudio {
pub samples: Vec<f32>,
pub sample_rate: u32,
pub duration_seconds: f32,
}
impl GeneratedAudio {
pub fn save_to_wav<P: AsRef<Path>>(&self, path: P) -> Result<(), Box<dyn Error>> {
let bytes = self.to_wav_bytes()?;
std::fs::write(path, bytes)?;
Ok(())
}
pub fn to_wav_bytes(&self) -> Result<Vec<u8>, Box<dyn Error>> {
let spec = hound::WavSpec {
channels: 1,
sample_rate: self.sample_rate,
bits_per_sample: 16,
sample_format: hound::SampleFormat::Int,
};
let mut cursor = Cursor::new(Vec::new());
{
let mut writer = hound::WavWriter::new(&mut cursor, spec)?;
let silence_samples = (self.sample_rate as f32 * 0.1) as usize;
for _ in 0..silence_samples {
writer.write_sample(0i16)?;
}
for &sample in &self.samples {
let clamped = sample.max(-1.0).min(1.0);
let amplitude = (clamped * i16::MAX as f32) as i16;
writer.write_sample(amplitude)?;
}
for _ in 0..silence_samples {
writer.write_sample(0i16)?;
}
writer.finalize()?;
}
Ok(cursor.into_inner())
}
}
pub struct KokoroTTS {
session: Session,
tokenizer: EspeakIpaTokenizer,
sample_rate: u32,
}
impl KokoroTTS {
pub fn with_config(config: TTSConfig) -> Result<Self, Box<dyn Error>> {
let TTSConfig {
model_path,
tokenizer_path,
max_length,
sample_rate,
graph_level,
execution_provider,
} = config;
let env = Arc::new(Environment::builder().with_name("kokoro_tts").build()?);
let optimization = match graph_level {
GraphOptimizationLevel::Disable => GraphOptimizationLevel::Disable,
GraphOptimizationLevel::Level1 => GraphOptimizationLevel::Level1,
GraphOptimizationLevel::Level2 => GraphOptimizationLevel::Level2,
GraphOptimizationLevel::Level3 => GraphOptimizationLevel::Level3,
};
let mut builder = SessionBuilder::new(&env)?
.with_optimization_level(optimization)?
.with_parallel_execution(true)?;
if !execution_provider.is_empty() {
builder = builder.with_execution_providers(&execution_provider)?;
}
let session = builder.with_model_from_file(&model_path)?;
let tokenizer_content = std::fs::read_to_string(&tokenizer_path)?;
let tokenizer_json: serde_json::Value = serde_json::from_str(&tokenizer_content)?;
let vocab_obj = tokenizer_json["model"]["vocab"]
.as_object()
.ok_or("No vocab found in tokenizer.json")?;
let mut vocab = std::collections::HashMap::new();
for (token, id) in vocab_obj {
vocab.insert(token.clone(), id.as_i64().unwrap_or(0));
}
let tokenizer = EspeakIpaTokenizer::new(vocab)?.with_model_max_length(max_length);
Ok(KokoroTTS {
session,
tokenizer,
sample_rate,
})
}
pub fn generate_speech_from_phonemes(
&self,
phonemes: &str,
voice_style: &VoiceStyle,
speed: f32,
) -> Result<GeneratedAudio, Box<dyn Error>> {
let tokens = self.tokenizer.encode_phonemes(phonemes, None)?;
self.generate_from_tokens(&tokens, voice_style, speed)
}
pub fn generate_speech(
&self,
text: &str,
voice_style: &VoiceStyle,
speed: f32,
) -> Result<GeneratedAudio, Box<dyn Error>> {
let tokens = self.tokenizer.encode(text, None)?;
self.generate_from_tokens(&tokens, voice_style, speed)
}
pub fn generate_from_tokens(
&self,
tokens: &[i64],
voice_style: &VoiceStyle,
speed: f32,
) -> Result<GeneratedAudio, Box<dyn Error>> {
let input_ids = Array2::<i64>::from_shape_vec((1, tokens.len()), tokens.to_vec())?;
let style_vector = voice_style.get_style_vector_for_token_length(tokens.len(), 256);
let style = Array2::<f32>::from_shape_vec((1, 256), style_vector)?;
let speed_array = Array1::<f32>::from_vec(vec![speed]);
let input_ids_cow: CowArray<i64, IxDyn> = CowArray::from(input_ids.into_dyn());
let style_cow: CowArray<f32, IxDyn> = CowArray::from(style.into_dyn());
let speed_cow: CowArray<f32, IxDyn> = CowArray::from(speed_array.into_dyn());
let input_ids_tensor = Value::from_array(self.session.allocator(), &input_ids_cow)?;
let style_tensor = Value::from_array(self.session.allocator(), &style_cow)?;
let speed_tensor = Value::from_array(self.session.allocator(), &speed_cow)?;
let outputs = self
.session
.run(vec![input_ids_tensor, style_tensor, speed_tensor])?;
if let Ok(output) = outputs[0].try_extract::<f32>() {
let view = output.view();
let samples = view.as_slice().unwrap().to_vec();
let duration_seconds = samples.len() as f32 / self.sample_rate as f32;
let audio = GeneratedAudio {
samples,
sample_rate: self.sample_rate,
duration_seconds,
};
Ok(audio)
} else {
Err("Failed to extract audio output".into())
}
}
#[allow(dead_code)]
pub fn speak(
&self,
text: &str,
voice_style: &VoiceStyle,
) -> Result<GeneratedAudio, Box<dyn Error>> {
self.generate_speech(text, voice_style, 1.0)
}
}