use std::io::{Cursor};
use crate::onn::ort_base::OrtBase;
use crate::onn::ort_koko::OrtKoko;
use crate::tts::phonemizer::Phonemizer;
use crate::tts::tokenize::tokenize;
use ndarray_npy::NpzReader;
use std::collections::HashMap;
use std::error::Error;
use std::fs::File;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
/// Process-wide flag recording that the list of loaded voice styles has already been
/// written to the log, so constructing further `TTSKoko` instances does not repeat it.
static VOICES_LOGGED: AtomicBool = AtomicBool::new(false);
/// Timing of a single word inside the audio returned by `TTSKoko::generate_full`.
///
/// The timings are estimates: `generate_full` spreads the words of each text chunk
/// evenly over the duration of that chunk's audio.
#[derive(Debug, Clone, PartialEq)]
pub struct WordAlignment {
    /// The word exactly as it appeared (whitespace-delimited) in the input text.
    pub word: String,
    /// Offset, in seconds from the start of the generated audio, where the word begins.
    pub start_sec: f32,
    /// Offset, in seconds from the start of the generated audio, where the word ends.
    pub end_sec: f32,
}
/// Selects which model family the engine drives. `Koko` is the only variant today.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelStrategy {
    /// The Koko model.
    Koko,
}

impl ModelStrategy {
    /// Returns the key string associated with this strategy's audio (`"audio"` for `Koko`).
    ///
    /// The key is a string literal, so the returned reference is `'static` and may
    /// outlive the `ModelStrategy` value it was obtained from.
    // NOTE(review): presumably this is the name of the model output that carries the
    // waveform — confirm against the inference code that consumes it.
    pub fn audio_key(&self) -> &'static str {
        match self {
            ModelStrategy::Koko => "audio",
        }
    }
}
/// Text-to-speech engine: owns the inference backend, the loaded voice styles and the
/// phonemizer needed to turn text into audio samples.
pub struct TTSKoko {
/// Inference backend; the mutex is locked around every `infer` call.
onn: Arc<Mutex<OrtKoko>>,
/// Voice name mapped to its style vector, as read from the voices `.npz` archive.
styles: HashMap<String, Vec<f32>>,
/// Model strategy handed to the backend on every inference call.
strategy: ModelStrategy,
/// Converts input text to phonemes before tokenization (constructed with the code `"a"`).
phonemizer: Phonemizer,
}
impl TTSKoko {
pub fn new(model_path: &str, voices_path: &str) -> Result<Self, Box<dyn Error>> {
let mut ort = OrtKoko::new();
ort.load_model(model_path.to_string())?;
let styles = Self::load_voices(voices_path)?;
let phonemizer = Phonemizer::new("a");
if !VOICES_LOGGED.load(Ordering::SeqCst) {
let keys: Vec<_> = styles.keys().collect();
tracing::info!("Loaded styles: {:?}", keys);
VOICES_LOGGED.store(true, Ordering::SeqCst);
}
Ok(Self {
onn: Arc::new(Mutex::new(ort)),
styles,
strategy: ModelStrategy::Koko,
phonemizer,
})
}
pub fn new_from_memory(model_bytes: &[u8], voices_bytes: &[u8]) -> Result<Self, Box<dyn Error>> {
let mut ort = OrtKoko::new();
ort.load_model_from_memory(model_bytes)?;
let styles = Self::load_voices_from_memory(voices_bytes)?;
let phonemizer = Phonemizer::new("a");
if !VOICES_LOGGED.load(Ordering::SeqCst) {
let keys: Vec<_> = styles.keys().collect();
tracing::info!("Loaded (baked) styles: {:?}", keys);
VOICES_LOGGED.store(true, Ordering::SeqCst);
}
Ok(Self {
onn: Arc::new(Mutex::new(ort)),
styles,
strategy: ModelStrategy::Koko,
phonemizer,
})
}
pub fn generate_full(
&self,
text: &str,
voice: &str,
speed: f32,
) -> Result<(Vec<f32>, Vec<WordAlignment>), Box<dyn Error>> {
let input_text = text.trim();
if input_text.is_empty() {
return Ok((vec![], vec![]));
}
let mut full_audio = Vec::new();
let mut full_alignments = Vec::new();
let mut text_chunks = Vec::new();
let sentences: Vec<_> = input_text.split_inclusive(|c| c == '.' || c == '!' || c == '?' || c == '\n' || c == ';' || c == ',').collect();
let mut current_chunk = String::new();
for s in sentences {
let next_text = format!("{}{}", current_chunk, s);
if tokenize(&self.phonemizer.phonemize(&next_text, true)).len() > 80 {
if !current_chunk.is_empty() { text_chunks.push(current_chunk.clone()); current_chunk.clear(); }
let words: Vec<_> = s.split_whitespace().collect();
let mut temp_text = String::new();
for w in words {
let next_temp = format!("{} {}", temp_text, w);
if tokenize(&self.phonemizer.phonemize(&next_temp, true)).len() > 80 {
text_chunks.push(temp_text.trim().to_string()); temp_text = w.to_string();
} else { temp_text = next_temp; }
}
current_chunk = temp_text;
} else { current_chunk = next_text; }
}
if !current_chunk.is_empty() { text_chunks.push(current_chunk); }
let mut current_offset = 0.0;
for chunk_text in text_chunks {
let mut phonemes = self.phonemizer.phonemize(&chunk_text, true);
if !phonemes.ends_with('.') && !phonemes.ends_with('!') && !phonemes.ends_with('?') {
phonemes.push('.');
}
let tokens = tokenize(&phonemes);
if tokens.is_empty() { continue; }
let style = self.styles.get(voice).ok_or("Voice style not found")?;
let tokens_batch = vec![tokens.iter().map(|&t| t as i64).collect::<Vec<i64>>()];
let mut onn = self.onn.lock().unwrap();
let res = onn.infer(tokens_batch, style, speed, &self.strategy);
drop(onn);
let audio_data = match res {
Ok(audio) => audio,
Err(e) => {
let err_str = e.to_string();
if err_str.contains("Expand node") || err_str.contains("invalid expand shape") {
tracing::warn!("Kokoro Engine: Suppressing ONNX synthesis glitch (Expand Error). Returning silence.");
vec![] } else {
return Err(e);
}
}
};
let start = audio_data.iter().position(|&s| s.abs() > 0.005).unwrap_or(0);
let end = audio_data.iter().rposition(|&s| s.abs() > 0.005).unwrap_or(audio_data.len());
let trimmed = &audio_data[start..end];
let sq_sum: f32 = trimmed.iter().map(|&s| s * s).sum();
let rms = (sq_sum / trimmed.len().max(1) as f32).sqrt();
let gain = if rms > 0.001 { (0.15f32 / rms).min(100.0f32) } else { 1.0f32 };
let mut chunk_audio: Vec<f32> = trimmed.iter().map(|&s| s * gain).collect();
let fade_len = 240;
if chunk_audio.len() > fade_len * 2 {
for i in 0..fade_len {
let alpha = i as f32 / fade_len as f32;
chunk_audio[i] *= alpha;
let end_idx = chunk_audio.len() - 1 - i;
chunk_audio[end_idx] *= alpha;
}
}
let chunk_words: Vec<_> = chunk_text.split_whitespace().collect();
let chunk_dur = chunk_audio.len() as f32 / 24000.0;
let word_dur = chunk_dur / chunk_words.len().max(1) as f32;
for (i, word) in chunk_words.iter().enumerate() {
full_alignments.push(WordAlignment {
word: word.to_string(),
start_sec: current_offset + (i as f32 * word_dur),
end_sec: current_offset + ((i + 1) as f32 * word_dur),
});
}
full_audio.extend(chunk_audio);
current_offset += chunk_dur;
}
tracing::info!("Sovereign Engine: Multi-Pass Vocalization Reset Success.");
Ok((full_audio, full_alignments))
}
#[allow(clippy::too_many_arguments)]
pub fn tts_raw_audio_streaming<F>(
&self,
text: &str,
_lang: &str,
voice: &str,
speed: f32,
_param1: Option<()>,
_param2: Option<()>,
_param3: Option<()>,
_param4: Option<()>,
mut callback: F,
) -> Result<(), Box<dyn Error>>
where
F: FnMut(Vec<f32>) -> Result<(), Box<dyn Error>>,
{
let input_text = text.trim();
if input_text.is_empty() { return Ok(()); }
let mut text_chunks = Vec::new();
let sentences: Vec<_> = input_text.split_inclusive(|c| c == '.' || c == '!' || c == '?' || c == '\n' || c == ';' || c == ',').collect();
let mut current_chunk = String::new();
for s in sentences {
let next_text = format!("{}{}", current_chunk, s);
if tokenize(&self.phonemizer.phonemize(&next_text, true)).len() > 80 {
if !current_chunk.is_empty() { text_chunks.push(current_chunk.clone()); current_chunk.clear(); }
let words: Vec<_> = s.split_whitespace().collect();
let mut temp_text = String::new();
for w in words {
let next_temp = format!("{} {}", temp_text, w);
if tokenize(&self.phonemizer.phonemize(&next_temp, true)).len() > 80 {
text_chunks.push(temp_text.trim().to_string()); temp_text = w.to_string();
} else { temp_text = next_temp; }
}
current_chunk = temp_text;
} else { current_chunk = next_text; }
}
if !current_chunk.is_empty() { text_chunks.push(current_chunk); }
for chunk_text in text_chunks {
let mut ph = self.phonemizer.phonemize(&chunk_text, true);
if !ph.ends_with('.') && !ph.ends_with('!') && !ph.ends_with('?') {
ph.push('.');
}
let tok = tokenize(&ph);
if tok.is_empty() { continue; }
let style = self.styles.get(voice).ok_or("Voice style not found")?;
let mut onn = self.onn.lock().unwrap();
let res = onn.infer(vec![tok.iter().map(|&t| t as i64).collect()], style, speed, &self.strategy);
drop(onn);
let raw_audio = match res {
Ok(audio) => audio,
Err(e) => {
let err_str = e.to_string();
if err_str.contains("Expand node") || err_str.contains("invalid expand shape") {
tracing::warn!("Kokoro Engine: Suppressing ONNX synthesis glitch (Streaming). Returning silence.");
vec![]
} else {
return Err(e);
}
}
};
let start = raw_audio.iter().position(|&s| s.abs() > 0.005).unwrap_or(0);
let end = raw_audio.iter().rposition(|&s| s.abs() > 0.005).unwrap_or(raw_audio.len());
let trimmed = &raw_audio[start..end];
let sq_sum: f32 = trimmed.iter().map(|&s| s * s).sum();
let rms = (sq_sum / trimmed.len().max(1) as f32).sqrt();
let gain = if rms > 0.001 { (0.15f32 / rms).min(100.0f32) } else { 1.0f32 };
let mut chunk_audio: Vec<f32> = trimmed.iter().map(|&s| s * gain).collect();
let fade_len = 240;
if chunk_audio.len() > fade_len * 2 {
for i in 0..fade_len {
let alpha = i as f32 / fade_len as f32;
chunk_audio[i] *= alpha;
let end_idx = chunk_audio.len() - 1 - i;
chunk_audio[end_idx] *= alpha;
}
}
callback(chunk_audio)?;
}
Ok(())
}
fn load_voices(path: &str) -> Result<HashMap<String, Vec<f32>>, Box<dyn Error>> {
let file = File::open(path)?;
let mut npz = NpzReader::new(file)?;
let mut styles = HashMap::new();
let names = npz.names()?;
for name in names {
if let Ok(array) = npz.by_name::<ndarray::OwnedRepr<f32>, ndarray::Ix3>(&name) {
let style_vec: Vec<f32> = array
.index_axis(ndarray::Axis(0), 0)
.iter()
.cloned()
.collect();
styles.insert(name.replace(".npy", ""), style_vec);
}
}
Ok(styles)
}
fn load_voices_from_memory(bytes: &[u8]) -> Result<HashMap<String, Vec<f32>>, Box<dyn Error>> {
let cursor = Cursor::new(bytes);
let mut npz = NpzReader::new(cursor)?;
let mut styles = HashMap::new();
let names = npz.names()?;
for name in names {
if let Ok(array) = npz.by_name::<ndarray::OwnedRepr<f32>, ndarray::Ix3>(&name) {
let style_vec: Vec<f32> = array
.index_axis(ndarray::Axis(0), 0)
.iter()
.cloned()
.collect();
styles.insert(name.replace(".npy", ""), style_vec);
}
}
Ok(styles)
}
}