use log::{debug, info};
use ndarray::ArrayD;
use std::collections::HashMap;
use std::path::Path;
use super::{ExecutionContext, ExecutionStrategy};
use crate::execution::executor::extract_tts_speed;
use crate::execution::modes::execute_tts_inference;
use crate::execution::session_factory::OnnxSessionFactory;
use crate::execution::template::{
ExecutionTemplate, ModelMetadata, PostprocessingStep, PreprocessingStep,
};
use crate::execution::types::{ExecutorResult, PreprocessedData, RawOutputs};
use crate::execution::voice_loader::TtsVoiceLoader;
use crate::execution::{postprocessing, preprocessing};
use crate::ir::{Envelope, EnvelopeKind};
use crate::runtime_adapter::onnx::{ExecutionProviderKind, SessionOptions};
use crate::runtime_adapter::AdapterError;
use crate::tracing as xybrid_trace;
/// Default upper bound on the amount of text synthesized in a single inference pass.
///
/// The bound is compared against `str::len`, so it is measured in UTF-8 bytes
/// (identical to characters for ASCII text).
const MAX_TTS_CHARS: usize = 350;

/// Duration of the silence inserted between the audio of two consecutive chunks.
const INTER_CHUNK_SILENCE_MS: u32 = 200;

/// Execution strategy for text-to-speech models, i.e. models whose
/// preprocessing pipeline contains a phonemize step.
///
/// Input text longer than the configured limit is split into chunks that are
/// synthesized one after another and concatenated, separated by
/// [`INTER_CHUNK_SILENCE_MS`] of silence.
pub struct TtsStrategy {
    /// Longest text (as measured by `str::len`) synthesized in one pass.
    max_chars: usize,
}
impl TtsStrategy {
pub fn new() -> Self {
Self {
max_chars: MAX_TTS_CHARS,
}
}
pub fn with_max_chars(max_chars: usize) -> Self {
Self { max_chars }
}
/// Reports whether `metadata` describes a TTS model.
///
/// The signal is the presence of at least one [`PreprocessingStep::Phonemize`]
/// step in the preprocessing pipeline.
fn is_tts_model(metadata: &ModelMetadata) -> bool {
    for step in &metadata.preprocessing {
        if let PreprocessingStep::Phonemize { .. } = step {
            return true;
        }
    }
    false
}
/// Returns the model file declared by the execution template.
///
/// # Errors
/// Returns [`AdapterError::InvalidInput`] unless the template is ONNX or
/// SafeTensors (a model-graph template, for example, is rejected).
fn get_model_file(metadata: &ModelMetadata) -> ExecutorResult<&str> {
    match &metadata.execution_template {
        ExecutionTemplate::Onnx { model_file }
        | ExecutionTemplate::SafeTensors { model_file, .. } => Ok(model_file),
        _ => Err(AdapterError::InvalidInput(
            "TTS strategy requires ONNX or SafeTensors model".to_string(),
        )),
    }
}
/// Converts `input` into [`PreprocessedData`] and feeds it through every
/// configured preprocessing step, in order.
///
/// When no steps are configured, the envelope conversion is returned directly
/// and no tracing span is opened.
fn run_preprocessing(
    &self,
    ctx: &ExecutionContext<'_>,
    metadata: &ModelMetadata,
    input: &Envelope,
) -> ExecutorResult<PreprocessedData> {
    let steps = &metadata.preprocessing;
    if steps.is_empty() {
        debug!(target: "xybrid_core", "No preprocessing steps configured");
        PreprocessedData::from_envelope(input)
    } else {
        // The span guard lives until the end of this branch, so it also covers
        // the initial envelope conversion.
        let _preprocess_span = xybrid_trace::SpanGuard::new("preprocessing");
        let mut current = PreprocessedData::from_envelope(input)?;
        for step in steps {
            current =
                preprocessing::apply_preprocessing_step(step, current, input, ctx.base_path)?;
        }
        Ok(current)
    }
}
/// Feeds the raw model outputs through every configured postprocessing step,
/// in order, and converts the result into an [`Envelope`].
///
/// When no steps are configured, the raw outputs are converted directly and
/// no tracing span is opened.
fn run_postprocessing(
    &self,
    ctx: &ExecutionContext<'_>,
    metadata: &ModelMetadata,
    outputs: RawOutputs,
) -> ExecutorResult<Envelope> {
    let steps = &metadata.postprocessing;
    if steps.is_empty() {
        debug!(target: "xybrid_core", "No postprocessing steps configured");
        outputs.to_envelope()
    } else {
        // The span guard is dropped after the tail expression runs, so the
        // final envelope conversion is included in the span.
        let _postprocess_span = xybrid_trace::SpanGuard::new("postprocessing");
        let mut current = outputs;
        for step in steps {
            current = postprocessing::apply_postprocessing_step(step, current, ctx.base_path)?;
        }
        current.to_envelope()
    }
}
/// Synthesizes `input` in a single inference pass.
///
/// The steps are:
/// 1. Preprocessing, which must yield phoneme IDs.
/// 2. Voice-embedding lookup sized to the phoneme count.
/// 3. A fresh ONNX session on the CPU provider.
/// 4. Inference.
/// 5. Optional removal of `metadata.trim_trailing_samples` trailing samples
///    from every output tensor.
/// 6. Postprocessing.
///
/// # Errors
/// Returns [`AdapterError::InvalidInput`] when preprocessing does not produce
/// phoneme IDs. It also propagates errors from voice loading, session
/// creation, inference and postprocessing.
fn execute_single(
    &self,
    ctx: &ExecutionContext<'_>,
    metadata: &ModelMetadata,
    input: &Envelope,
    model_path: &Path,
) -> ExecutorResult<Envelope> {
    let preprocessed = self.run_preprocessing(ctx, metadata, input)?;
    let phoneme_ids = preprocessed
        .as_phoneme_ids()
        .ok_or_else(|| AdapterError::InvalidInput("Expected phoneme IDs".to_string()))?;
    debug!(
        target: "xybrid_core",
        "TTS Single: Phoneme IDs count: {}, first 20: {:?}",
        phoneme_ids.len(),
        &phoneme_ids[..phoneme_ids.len().min(20)]
    );
    let voice_loader = TtsVoiceLoader::new(ctx.base_path);
    let voice_embedding =
        voice_loader.load_for_token_count(metadata, input, Some(phoneme_ids.len()))?;
    let session = OnnxSessionFactory::create_session(
        model_path,
        ExecutionProviderKind::Cpu,
        SessionOptions::default(),
    )?;
    let speed = extract_tts_speed(input);
    let mut raw_outputs = execute_tts_inference(&session, phoneme_ids, voice_embedding, speed)?;
    let trim_count = metadata.trim_trailing_samples.unwrap_or(0);
    if trim_count > 0 {
        for audio in raw_outputs.values_mut() {
            let len = audio.len();
            if len <= trim_count {
                // Too short to trim; keep the tensor untouched.
                continue;
            }
            if audio.ndim() == 1 {
                audio.slice_collapse(ndarray::s![..len - trim_count]);
            } else {
                // `slice_collapse` with a one-axis `s![..]` panics on a dynamic-rank
                // array with several axes (e.g. a `[1, N]` batch output). Flatten in
                // logical order and trim instead, exactly as the chunked path does,
                // so both paths hand postprocessing the same 1-D audio.
                let flat: Vec<f32> = audio.iter().copied().take(len - trim_count).collect();
                *audio = ndarray::Array1::from_vec(flat).into_dyn();
            }
        }
    }
    self.run_postprocessing(ctx, metadata, RawOutputs::TensorMap(raw_outputs))
}
/// Synthesizes `input`, splitting long text into chunks.
///
/// Text no longer than `self.max_chars` (as measured by `str::len`) is handed
/// to [`Self::execute_single`] unchanged. Longer text goes through these steps:
/// 1. It is split with [`Self::chunk_text`].
/// 2. Each chunk is synthesized with one shared CPU ONNX session.
/// 3. Each chunk's audio is trimmed by `metadata.trim_trailing_samples`.
/// 4. The chunks are concatenated with [`INTER_CHUNK_SILENCE_MS`] of silence
///    in between.
/// 5. The result goes through postprocessing as a single 1-D tensor.
///
/// # Errors
/// Returns [`AdapterError::InvalidInput`] for non-text input or when
/// preprocessing of a chunk does not produce phoneme IDs. It also propagates
/// errors from session creation, voice loading, inference and postprocessing.
fn execute_chunked(
    &self,
    ctx: &ExecutionContext<'_>,
    metadata: &ModelMetadata,
    input: &Envelope,
    model_path: &Path,
) -> ExecutorResult<Envelope> {
    let EnvelopeKind::Text(text) = &input.kind else {
        return Err(AdapterError::InvalidInput(
            "TTS requires text input".to_string(),
        ));
    };
    debug!(
        target: "xybrid_core",
        "TTS Chunked: Input text length: {} chars (max={})",
        text.len(),
        self.max_chars
    );
    if text.len() <= self.max_chars {
        return self.execute_single(ctx, metadata, input, model_path);
    }
    info!(
        target: "xybrid_core",
        "Text too long ({} chars), splitting into chunks",
        text.len()
    );
    let chunks = Self::chunk_text(text, self.max_chars);
    debug!(target: "xybrid_core", "TTS: Split into {} chunks", chunks.len());

    // Everything below is shared by all chunks; compute it once.
    let session = OnnxSessionFactory::create_session(
        model_path,
        ExecutionProviderKind::Cpu,
        SessionOptions::default(),
    )?;
    let speed = extract_tts_speed(input);
    let sample_rate = Self::get_sample_rate(metadata);
    let silence_samples = (sample_rate as usize * INTER_CHUNK_SILENCE_MS as usize) / 1000;
    let trim_count = metadata.trim_trailing_samples.unwrap_or(0);
    let voice_loader = TtsVoiceLoader::new(ctx.base_path);
    let output_names = session.output_names();
    let output_name = output_names.first().map(|s| s.as_str()).unwrap_or("audio");

    let mut all_audio: Vec<f32> = Vec::new();
    for (i, chunk) in chunks.iter().enumerate() {
        debug!(
            target: "xybrid_core",
            "TTS: Processing chunk {}/{}: {} chars",
            i + 1,
            chunks.len(),
            chunk.len()
        );
        if i > 0 && !all_audio.is_empty() {
            all_audio.extend(std::iter::repeat_n(0.0f32, silence_samples));
        }
        let chunk_input = Envelope {
            kind: EnvelopeKind::Text(chunk.clone()),
            metadata: input.metadata.clone(),
        };
        let preprocessed = self.run_preprocessing(ctx, metadata, &chunk_input)?;
        let phoneme_ids = preprocessed
            .as_phoneme_ids()
            .ok_or_else(|| AdapterError::InvalidInput("Expected phoneme IDs".to_string()))?;
        debug!(
            target: "xybrid_core",
            "TTS: Chunk {} has {} phoneme IDs",
            i + 1,
            phoneme_ids.len()
        );
        let voice_embedding = voice_loader.load_for_token_count(
            metadata,
            &chunk_input,
            Some(phoneme_ids.len()),
        )?;
        let raw_outputs = execute_tts_inference(&session, phoneme_ids, voice_embedding, speed)?;
        // HashMap iteration order is unspecified. Prefer the tensor stored under the
        // name we emit below, and fall back to any tensor (the previous behavior)
        // only when that name is absent.
        let audio_tensor = raw_outputs
            .get(output_name)
            .or_else(|| raw_outputs.values().next());
        if let Some(audio_tensor) = audio_tensor {
            let len = audio_tensor.len();
            let keep = if trim_count > 0 && len > trim_count {
                len - trim_count
            } else {
                len
            };
            all_audio.extend(audio_tensor.iter().copied().take(keep));
        }
    }
    debug!(target: "xybrid_core", "TTS: Total audio samples: {}", all_audio.len());

    let mut combined_outputs: HashMap<String, ArrayD<f32>> = HashMap::new();
    combined_outputs.insert(
        output_name.to_string(),
        ndarray::Array1::from_vec(all_audio).into_dyn(),
    );
    self.run_postprocessing(ctx, metadata, RawOutputs::TensorMap(combined_outputs))
}
/// Sample rate used to size inter-chunk silence.
///
/// This is the `sample_rate` of the first `TTSAudioEncode` postprocessing step,
/// or 24000 when no such step is configured.
fn get_sample_rate(metadata: &ModelMetadata) -> u32 {
    metadata
        .postprocessing
        .iter()
        .find_map(|step| {
            if let PostprocessingStep::TTSAudioEncode { sample_rate, .. } = step {
                Some(*sample_rate)
            } else {
                None
            }
        })
        .unwrap_or(24000)
}
/// Splits `text` into chunks of at most `max_chars` bytes for separate synthesis.
///
/// Packing rules:
/// - Sentences (each ending at `.`, `!` or `?`) are trimmed and packed
///   greedily, joined by a single space.
/// - A sentence longer than `max_chars` is cut at the last comma or whitespace
///   within the limit. A comma used as the cut point is dropped. When there is
///   no such break, the sentence is hard-cut at the limit.
/// - The tail of a cut sentence starts the next chunk.
///
/// Lengths are `str::len`, i.e. UTF-8 bytes (equal to characters for ASCII).
/// Cuts always land on char boundaries, so multi-byte text never panics. Every
/// cut advances by at least one character, so the function terminates even for
/// `max_chars == 0`. A single character wider than `max_chars` becomes its own
/// (over-long) chunk.
///
/// Text that already fits is returned unchanged as one chunk (even if empty).
fn chunk_text(text: &str, max_chars: usize) -> Vec<String> {
    // Byte offset in 1..=s.len() at which to cut the non-empty, over-long `s`.
    fn split_point(s: &str, max_bytes: usize) -> usize {
        // Largest char boundary not past the limit; offset 0 is always a boundary.
        let mut limit = max_bytes.min(s.len());
        while !s.is_char_boundary(limit) {
            limit -= 1;
        }
        match s[..limit].rfind(|c: char| c == ',' || c.is_whitespace()) {
            // Prefer a natural pause; a match at offset 0 would give an empty head.
            Some(pos) if pos > 0 => pos,
            // No usable pause: hard-cut at the limit.
            _ if limit > 0 => limit,
            // Limit is 0, or the first char is wider than it: take one char.
            _ => s.chars().next().map_or(s.len(), char::len_utf8),
        }
    }

    if text.len() <= max_chars {
        return vec![text.to_string()];
    }

    let mut chunks = Vec::new();
    let mut current_chunk = String::new();
    for sentence in text.split_inclusive(['.', '!', '?']) {
        let sentence = sentence.trim();
        if sentence.is_empty() {
            continue;
        }
        if sentence.len() > max_chars {
            // Flush pending text, then cut this sentence into limit-sized pieces.
            if !current_chunk.is_empty() {
                chunks.push(current_chunk.trim().to_string());
                current_chunk.clear();
            }
            let mut remaining = sentence;
            while remaining.len() > max_chars {
                let split_at = split_point(remaining, max_chars);
                let head = remaining[..split_at].trim();
                if !head.is_empty() {
                    chunks.push(head.to_string());
                }
                remaining = remaining[split_at..].trim_start_matches(',').trim();
            }
            if !remaining.is_empty() {
                current_chunk = remaining.to_string();
            }
        } else if current_chunk.len() + sentence.len() + 1 > max_chars {
            // Sentence (plus joining space) does not fit: start a new chunk.
            if !current_chunk.is_empty() {
                chunks.push(current_chunk.trim().to_string());
            }
            current_chunk = sentence.to_string();
        } else {
            if !current_chunk.is_empty() {
                current_chunk.push(' ');
            }
            current_chunk.push_str(sentence);
        }
    }
    if !current_chunk.is_empty() {
        chunks.push(current_chunk.trim().to_string());
    }
    chunks
}
}
impl Default for TtsStrategy {
fn default() -> Self {
Self::new()
}
}
impl ExecutionStrategy for TtsStrategy {
    /// Claims any model whose preprocessing pipeline contains a phonemize step.
    fn can_handle(&self, metadata: &ModelMetadata) -> bool {
        Self::is_tts_model(metadata)
    }

    /// Synthesizes speech for `input`, chunking long text.
    ///
    /// The model file comes from the ONNX/SafeTensors execution template and is
    /// resolved through `ctx`. A per-model `max_chunk_chars` overrides this
    /// strategy's limit. A configured value of `0` is treated as unset: no text
    /// fits in a zero-length chunk, and it previously made chunking loop forever.
    ///
    /// # Errors
    /// Returns [`AdapterError::InvalidInput`] for unsupported execution templates.
    /// It also propagates errors from the chunked execution path.
    fn execute(
        &self,
        ctx: &mut ExecutionContext<'_>,
        metadata: &ModelMetadata,
        input: &Envelope,
    ) -> ExecutorResult<Envelope> {
        let _span = xybrid_trace::SpanGuard::new("tts_execution");
        let model_file = Self::get_model_file(metadata)?;
        let model_path = ctx.resolve_path(model_file);
        let max_chars = metadata
            .max_chunk_chars
            .filter(|&limit| limit > 0)
            .unwrap_or(self.max_chars);
        let strategy = TtsStrategy::with_max_chars(max_chars);
        strategy.execute_chunked(ctx, metadata, input, &model_path)
    }

    fn name(&self) -> &'static str {
        "tts"
    }
}
// Unit tests for `TtsStrategy`. They cover:
// - model detection (`is_tts_model` / `can_handle`);
// - model-file resolution (`get_model_file`);
// - the constructors and `name`;
// - the pure `chunk_text` helper.
// No test in this module calls `execute`, so the inference paths are not
// exercised here.
#[cfg(test)]
mod tests {
use super::*;
// A `Phonemize` preprocessing step marks a model as TTS.
#[test]
fn test_is_tts_model_with_phonemize() {
let metadata = ModelMetadata::onnx("test-tts", "1.0", "model.onnx").with_preprocessing(
PreprocessingStep::Phonemize {
tokens_file: "tokens.txt".to_string(),
backend: Default::default(),
dict_file: None,
language: None,
add_padding: true,
normalize_text: false,
silence_tokens: None,
},
);
assert!(TtsStrategy::is_tts_model(&metadata));
}
// Without any preprocessing steps the model is not classified as TTS.
#[test]
fn test_is_tts_model_without_phonemize() {
let metadata = ModelMetadata::onnx("test-asr", "1.0", "model.onnx");
assert!(!TtsStrategy::is_tts_model(&metadata));
}
// `can_handle` routes phonemizing models to this strategy.
#[test]
fn test_can_handle_tts() {
let strategy = TtsStrategy::new();
let tts_metadata = ModelMetadata::onnx("test-tts", "1.0", "model.onnx").with_preprocessing(
PreprocessingStep::Phonemize {
tokens_file: "tokens.txt".to_string(),
backend: Default::default(),
dict_file: None,
language: None,
add_padding: true,
normalize_text: false,
silence_tokens: None,
},
);
assert!(strategy.can_handle(&tts_metadata));
}
#[test]
fn test_cannot_handle_non_tts() {
let strategy = TtsStrategy::new();
let other_metadata = ModelMetadata::onnx("test-other", "1.0", "model.onnx");
assert!(!strategy.can_handle(&other_metadata));
}
// Text within the limit comes back as a single, unchanged chunk.
#[test]
fn test_chunk_text_under_limit_returns_single() {
let chunks = TtsStrategy::chunk_text("Hello world.", 350);
assert_eq!(chunks, vec!["Hello world."]);
}
// Either "no chunks" or "one empty chunk" is accepted for empty input.
#[test]
fn test_chunk_text_empty_string() {
let chunks = TtsStrategy::chunk_text("", 350);
assert!(chunks.is_empty() || (chunks.len() == 1 && chunks[0].is_empty()));
}
// Each sentence fits alone but no two fit together, so each becomes its own chunk.
#[test]
fn test_chunk_text_splits_at_sentence_boundaries() {
let text = "First sentence. Second sentence. Third sentence.";
let chunks = TtsStrategy::chunk_text(text, 20);
assert_eq!(chunks.len(), 3);
assert_eq!(chunks[0], "First sentence.");
assert_eq!(chunks[1], "Second sentence.");
assert_eq!(chunks[2], "Third sentence.");
}
#[test]
fn test_chunk_text_combines_short_sentences() {
let text = "Hi. Hello. Hey there.";
let chunks = TtsStrategy::chunk_text(text, 50);
assert_eq!(chunks.len(), 1);
}
// The bound checked here allows 15 bytes of slack over `max_chars`.
#[test]
fn test_chunk_text_respects_max_chars() {
let text = "This is a test sentence. Here is another one. And a third.";
let max_chars = 30;
let chunks = TtsStrategy::chunk_text(text, max_chars);
for (i, chunk) in chunks.iter().enumerate() {
assert!(
chunk.len() <= max_chars + 15, "Chunk {} too long: {} chars (max {}): '{}'",
i,
chunk.len(),
max_chars,
chunk
);
}
}
// A single over-long sentence must be cut within the sentence.
#[test]
fn test_chunk_text_long_sentence_splits_at_comma() {
let text = "This is a very long sentence with many words, and it has a comma here, which should be a split point.";
let chunks = TtsStrategy::chunk_text(text, 50);
assert!(chunks.len() >= 2, "Long sentence should be split");
}
// With no commas, cuts fall on whitespace, so no chunk may start with a
// space or punctuation (letters or a double quote only).
#[test]
fn test_chunk_text_long_sentence_splits_at_space() {
let text = "This is a sentence without any commas that should still be split somewhere at a word boundary.";
let chunks = TtsStrategy::chunk_text(text, 40);
assert!(chunks.len() >= 2, "Should split at spaces");
for chunk in &chunks {
if let Some(c) = chunk.chars().next() {
assert!(
c.is_alphabetic() || c == '"',
"Chunk starts unexpectedly: '{}'",
chunk
);
}
}
}
#[test]
fn test_chunk_text_preserves_content() {
let text = "First sentence. Second sentence. Third sentence.";
let chunks = TtsStrategy::chunk_text(text, 20);
let rejoined = chunks.join(" ");
assert!(rejoined.contains("First"));
assert!(rejoined.contains("Second"));
assert!(rejoined.contains("Third"));
}
#[test]
fn test_chunk_text_handles_question_marks() {
let text = "Is this a question? Yes it is! And here is a statement.";
let chunks = TtsStrategy::chunk_text(text, 25);
assert!(chunks.len() >= 2, "Should split at ? and !");
}
#[test]
fn test_chunk_text_handles_exclamation_marks() {
let text = "Wow! Amazing! Incredible!";
let chunks = TtsStrategy::chunk_text(text, 10);
assert!(chunks.len() >= 2);
}
// Summed chunk lengths may fall short of the input by up to 30 bytes, since
// inter-sentence spaces are trimmed and the chunks are summed without separators.
#[test]
fn test_chunk_text_real_llm_output() {
let text = "Paris is the capital of France. France is a country in Western Europe. \
It is known for its art, culture, and cuisine. The Eiffel Tower is a famous landmark. \
Paris has a population of over 2 million people.";
let chunks = TtsStrategy::chunk_text(text, 100);
assert!(
chunks.len() >= 2,
"LLM output of {} chars should split with max=100",
text.len()
);
let total_chars: usize = chunks.iter().map(|c| c.len()).sum();
assert!(
total_chars >= text.len() - 30, "Content lost: {} vs {}",
total_chars,
text.len()
);
}
// Smoke test: chat-template control tokens left in LLM output must still
// produce at least one chunk.
#[test]
fn test_chunk_text_with_llm_special_tokens() {
let text = "Paris.<|im_end|><|im_start|>user\nWhat else?<|im_end|><|im_start|>assistant\n\
France has many cities.";
let chunks = TtsStrategy::chunk_text(text, 80);
assert!(!chunks.is_empty());
}
// 1100 bytes of short sentences at the default limit; each chunk is checked
// against a loose 400-byte bound.
#[test]
fn test_chunk_text_very_long_input() {
let text = "Paris is the capital. ".repeat(50); let chunks = TtsStrategy::chunk_text(&text, 350);
assert!(chunks.len() >= 3, "Should split into multiple chunks");
for chunk in &chunks {
assert!(
chunk.len() <= 400, "Chunk too long: {} chars",
chunk.len()
);
}
}
#[test]
fn test_custom_max_chars() {
let strategy = TtsStrategy::with_max_chars(100);
assert_eq!(strategy.max_chars, 100);
}
#[test]
fn test_default_max_chars_is_350() {
let strategy = TtsStrategy::new();
assert_eq!(strategy.max_chars, MAX_TTS_CHARS);
assert_eq!(strategy.max_chars, 350);
}
#[test]
fn test_strategy_name() {
let strategy = TtsStrategy::new();
assert_eq!(strategy.name(), "tts");
}
// ONNX and SafeTensors templates both expose their model file; other
// template kinds are rejected.
#[test]
fn test_get_model_file_onnx() {
let metadata = ModelMetadata::onnx("test", "1.0", "custom_model.onnx");
let result = TtsStrategy::get_model_file(&metadata);
assert_eq!(result.unwrap(), "custom_model.onnx");
}
#[test]
fn test_get_model_file_safetensors() {
let metadata = ModelMetadata::safetensors("test", "1.0", "model.safetensors", "whisper");
let result = TtsStrategy::get_model_file(&metadata);
assert_eq!(result.unwrap(), "model.safetensors");
}
#[test]
fn test_get_model_file_model_graph_unsupported() {
let metadata = ModelMetadata::model_graph("test", "1.0", vec![], vec![]);
let result = TtsStrategy::get_model_file(&metadata);
assert!(
result.is_err(),
"ModelGraph should not be supported for TTS"
);
}
}