use crate::error::{Result, TrustformersError};
use crate::pipeline::{BasePipeline, Pipeline};
use serde::{Deserialize, Serialize};
use trustformers_core::traits::{Model, Tokenizer};
use trustformers_core::Tensor;
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TextToSpeechConfig {
pub voice: String,
pub speaking_rate: f32,
pub pitch: f32,
pub volume: f32,
pub sample_rate: u32,
pub output_format: AudioFormat,
pub max_duration: Option<f64>,
pub prosody_control: bool,
pub emotion_control: bool,
pub target_emotion: Option<String>,
}
impl Default for TextToSpeechConfig {
fn default() -> Self {
Self {
voice: "default".to_string(),
speaking_rate: 1.0,
pitch: 1.0,
volume: 1.0,
sample_rate: 22050,
output_format: AudioFormat::Wav,
max_duration: Some(60.0),
prosody_control: false,
emotion_control: false,
target_emotion: None,
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum AudioFormat {
Wav,
Mp3,
Flac,
Ogg,
Raw,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TextToSpeechInput {
pub text: String,
pub voice: Option<String>,
pub speaking_rate: Option<f32>,
pub pitch: Option<f32>,
pub volume: Option<f32>,
pub emotion: Option<String>,
pub prosody_markers: Option<Vec<ProsodyMarker>>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ProsodyMarker {
pub start: usize,
pub end: usize,
pub prosody_type: ProsodyType,
pub intensity: f32,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum ProsodyType {
Emphasis,
Pause,
Speed,
Pitch,
Volume,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TextToSpeechOutput {
pub audio_data: Vec<f32>,
pub sample_rate: u32,
pub duration: f64,
pub format: AudioFormat,
pub voice: String,
pub text: String,
pub phonemes: Option<Vec<String>>,
pub phoneme_timings: Option<Vec<PhonemeTimings>>,
pub prosody_info: Option<ProsodyInfo>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PhonemeTimings {
pub phoneme: String,
pub start_time: f64,
pub end_time: f64,
pub confidence: f32,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ProsodyInfo {
pub avg_pitch: f32,
pub pitch_range: f32,
pub speaking_rate: f32,
pub pauses: Vec<PauseInfo>,
pub emphasis: Vec<EmphasisInfo>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PauseInfo {
pub start_time: f64,
pub duration: f64,
pub pause_type: PauseType,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum PauseType {
Sentence,
Phrase,
Comma,
Breath,
Emphasis,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EmphasisInfo {
pub start_time: f64,
pub end_time: f64,
pub intensity: f32,
pub emphasis_type: EmphasisType,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum EmphasisType {
Stress,
Pitch,
Volume,
Duration,
}
pub struct TextToSpeechPipeline<M, T>
where
M: Model + Clone + Send + Sync + 'static,
T: Tokenizer + Clone + Send + Sync + 'static,
{
base: BasePipeline<M, T>,
config: TextToSpeechConfig,
available_voices: Vec<String>,
phoneme_converter: Option<PhonemeConverter>,
prosody_analyzer: Option<ProsodyAnalyzer>,
}
impl<M, T> TextToSpeechPipeline<M, T>
where
M: Model<Input = Tensor, Output = Tensor> + Clone + Send + Sync + 'static,
T: Tokenizer + Clone + Send + Sync + 'static,
{
pub fn new(model: M, tokenizer: T) -> Result<Self> {
let base = BasePipeline::new(model, tokenizer);
let config = TextToSpeechConfig::default();
let available_voices = Self::get_available_voices();
Ok(Self {
base,
config,
available_voices,
phoneme_converter: None,
prosody_analyzer: None,
})
}
pub fn with_config(mut self, config: TextToSpeechConfig) -> Self {
self.config = config;
self
}
pub fn with_voice(mut self, voice: String) -> Self {
self.config.voice = voice;
self
}
pub fn with_speaking_rate(mut self, rate: f32) -> Self {
self.config.speaking_rate = rate.clamp(0.5, 2.0);
self
}
pub fn with_pitch(mut self, pitch: f32) -> Self {
self.config.pitch = pitch.clamp(0.5, 2.0);
self
}
pub fn with_volume(mut self, volume: f32) -> Self {
self.config.volume = volume.clamp(0.0, 1.0);
self
}
pub fn with_output_format(mut self, format: AudioFormat) -> Self {
self.config.output_format = format;
self
}
pub fn with_prosody_control(mut self, enable: bool) -> Self {
self.config.prosody_control = enable;
if enable && self.prosody_analyzer.is_none() {
self.prosody_analyzer = Some(ProsodyAnalyzer::new());
}
self
}
pub fn with_emotion_control(mut self, enable: bool) -> Self {
self.config.emotion_control = enable;
self
}
pub fn with_target_emotion(mut self, emotion: String) -> Self {
self.config.target_emotion = Some(emotion);
self.config.emotion_control = true;
self
}
pub fn get_available_voices() -> Vec<String> {
vec![
"default".to_string(),
"male-neutral".to_string(),
"female-neutral".to_string(),
"male-young".to_string(),
"female-young".to_string(),
"male-elderly".to_string(),
"female-elderly".to_string(),
"child".to_string(),
"narrator".to_string(),
"robot".to_string(),
]
}
pub fn get_supported_emotions() -> Vec<String> {
vec![
"neutral".to_string(),
"happy".to_string(),
"sad".to_string(),
"angry".to_string(),
"excited".to_string(),
"calm".to_string(),
"surprised".to_string(),
"fearful".to_string(),
"disgusted".to_string(),
"confident".to_string(),
"whispering".to_string(),
"shouting".to_string(),
]
}
pub fn synthesize(&self, input: TextToSpeechInput) -> Result<TextToSpeechOutput> {
if input.text.is_empty() {
return Err(TrustformersError::invalid_input_simple(
"Text input cannot be empty. Expected: non-empty text string, received: empty string".to_string()
));
}
let voice = input.voice.unwrap_or_else(|| self.config.voice.clone());
let speaking_rate = input.speaking_rate.unwrap_or(self.config.speaking_rate);
let pitch = input.pitch.unwrap_or(self.config.pitch);
let volume = input.volume.unwrap_or(self.config.volume);
if !self.available_voices.contains(&voice) {
return Err(TrustformersError::invalid_input_simple(
format!("Voice '{}' is not available. Parameter: voice, Expected: one of {:?}, Received: {}", voice, self.available_voices, voice)
));
}
let processed_text = self.preprocess_text(&input.text)?;
let tokenized = self.base.tokenizer.encode(&processed_text)?;
let input_ids_f32: Vec<f32> = tokenized.input_ids.iter().map(|&x| x as f32).collect();
let input_tensor = Tensor::from_vec(input_ids_f32, &[1, tokenized.input_ids.len()])?;
let output = self.base.model.forward(input_tensor)?;
let audio_data = self.tensor_to_audio(&output, &voice, speaking_rate, pitch, volume)?;
let duration = audio_data.len() as f64 / self.config.sample_rate as f64;
let phonemes = if let Some(converter) = &self.phoneme_converter {
Some(converter.text_to_phonemes(&processed_text)?)
} else {
None
};
let phoneme_timings = if let Some(phonemes) = &phonemes {
Some(self.generate_phoneme_timings(phonemes, duration)?)
} else {
None
};
let prosody_info = if self.config.prosody_control {
if let Some(analyzer) = &self.prosody_analyzer {
Some(analyzer.analyze(&processed_text, &audio_data, self.config.sample_rate)?)
} else {
None
}
} else {
None
};
Ok(TextToSpeechOutput {
audio_data,
sample_rate: self.config.sample_rate,
duration,
format: self.config.output_format.clone(),
voice,
text: processed_text,
phonemes,
phoneme_timings,
prosody_info,
})
}
fn preprocess_text(&self, text: &str) -> Result<String> {
let mut processed = text.to_string();
processed = processed.split_whitespace().collect::<Vec<_>>().join(" ");
processed = self.expand_abbreviations(&processed);
processed = self.normalize_numbers(&processed);
processed = self.normalize_punctuation(&processed);
Ok(processed)
}
fn expand_abbreviations(&self, text: &str) -> String {
let abbreviations = [
("Dr.", "Doctor"),
("Mr.", "Mister"),
("Mrs.", "Missus"),
("Ms.", "Miss"),
("Prof.", "Professor"),
("St.", "Street"),
("Ave.", "Avenue"),
("Blvd.", "Boulevard"),
("etc.", "et cetera"),
("vs.", "versus"),
("Inc.", "Incorporated"),
("Corp.", "Corporation"),
("Ltd.", "Limited"),
("Co.", "Company"),
("USA", "United States of America"),
("UK", "United Kingdom"),
("CEO", "Chief Executive Officer"),
("CTO", "Chief Technology Officer"),
("CFO", "Chief Financial Officer"),
("AI", "Artificial Intelligence"),
("ML", "Machine Learning"),
("TTS", "Text To Speech"),
("API", "Application Programming Interface"),
];
let mut result = text.to_string();
for (abbr, expansion) in &abbreviations {
result = result.replace(abbr, expansion);
}
result
}
fn normalize_numbers(&self, text: &str) -> String {
let number_words = [
("0", "zero"),
("1", "one"),
("2", "two"),
("3", "three"),
("4", "four"),
("5", "five"),
("6", "six"),
("7", "seven"),
("8", "eight"),
("9", "nine"),
("10", "ten"),
("11", "eleven"),
("12", "twelve"),
("13", "thirteen"),
("14", "fourteen"),
("15", "fifteen"),
("16", "sixteen"),
("17", "seventeen"),
("18", "eighteen"),
("19", "nineteen"),
("20", "twenty"),
];
let mut result = text.to_string();
for (num, word) in &number_words {
result = result.replace(&format!(" {} ", num), &format!(" {} ", word));
}
result
}
fn normalize_punctuation(&self, text: &str) -> String {
text.replace("...", " pause ")
.replace(".", " period ")
.replace("!", " exclamation ")
.replace("?", " question ")
.replace(",", " comma ")
.replace(";", " semicolon ")
.replace(":", " colon ")
.replace("-", " dash ")
.replace("(", " open parenthesis ")
.replace(")", " close parenthesis ")
.replace("\"", " quote ")
.replace("'", " apostrophe ")
}
fn tensor_to_audio(
&self,
tensor: &Tensor,
voice: &str,
speaking_rate: f32,
pitch: f32,
volume: f32,
) -> Result<Vec<f32>> {
let tensor_data = tensor.data()?;
let audio_length = (tensor_data.len() as f32 * speaking_rate) as usize;
let mut audio_data = Vec::with_capacity(audio_length);
for i in 0..audio_length {
let t = i as f32 / self.config.sample_rate as f32;
let tensor_index = (i * tensor_data.len() / audio_length).min(tensor_data.len() - 1);
let base_freq = 220.0 * pitch; let amplitude = tensor_data[tensor_index] * volume;
let fundamental = (2.0 * std::f32::consts::PI * base_freq * t).sin();
let harmonic2 = 0.5 * (2.0 * std::f32::consts::PI * base_freq * 2.0 * t).sin();
let harmonic3 = 0.25 * (2.0 * std::f32::consts::PI * base_freq * 3.0 * t).sin();
let sample = amplitude * (fundamental + harmonic2 + harmonic3);
audio_data.push(sample);
}
self.apply_voice_characteristics(&mut audio_data, voice);
Ok(audio_data)
}
fn apply_voice_characteristics(&self, audio_data: &mut [f32], voice: &str) {
match voice {
"male-neutral" => {
for sample in audio_data.iter_mut() {
*sample *= 0.9;
}
},
"female-neutral" => {
for sample in audio_data.iter_mut() {
*sample *= 1.1;
}
},
"child" => {
for sample in audio_data.iter_mut() {
*sample *= 1.3;
}
},
"elderly" => {
for (i, sample) in audio_data.iter_mut().enumerate() {
let tremolo = 1.0 + 0.1 * (i as f32 * 0.01).sin();
*sample *= 0.8 * tremolo;
}
},
"robot" => {
for sample in audio_data.iter_mut() {
*sample = (*sample * 10.0).round() / 10.0; }
},
_ => {
},
}
}
fn generate_phoneme_timings(
&self,
phonemes: &[String],
total_duration: f64,
) -> Result<Vec<PhonemeTimings>> {
let mut timings = Vec::new();
let avg_duration = total_duration / phonemes.len() as f64;
for (i, phoneme) in phonemes.iter().enumerate() {
let start_time = i as f64 * avg_duration;
let end_time = start_time + avg_duration;
timings.push(PhonemeTimings {
phoneme: phoneme.clone(),
start_time,
end_time,
confidence: 0.8, });
}
Ok(timings)
}
}
impl<M, T> Pipeline for TextToSpeechPipeline<M, T>
where
M: Model<Input = Tensor, Output = Tensor> + Clone + Send + Sync + 'static,
T: Tokenizer + Clone + Send + Sync + 'static,
{
type Input = TextToSpeechInput;
type Output = TextToSpeechOutput;
fn __call__(&self, input: Self::Input) -> Result<Self::Output> {
self.synthesize(input)
}
}
pub struct PhonemeConverter {
phoneme_dict: std::collections::HashMap<String, Vec<String>>,
}
impl Default for PhonemeConverter {
fn default() -> Self {
Self::new()
}
}
impl PhonemeConverter {
pub fn new() -> Self {
let mut phoneme_dict = std::collections::HashMap::new();
phoneme_dict.insert(
"hello".to_string(),
vec![
"h".to_string(),
"É™".to_string(),
"ˈl".to_string(),
"oÊŠ".to_string(),
],
);
phoneme_dict.insert(
"world".to_string(),
vec![
"w".to_string(),
"Éœr".to_string(),
"l".to_string(),
"d".to_string(),
],
);
phoneme_dict.insert("the".to_string(), vec!["ð".to_string(), "ə".to_string()]);
phoneme_dict.insert("a".to_string(), vec!["É™".to_string()]);
phoneme_dict.insert("an".to_string(), vec!["æ".to_string(), "n".to_string()]);
Self { phoneme_dict }
}
pub fn text_to_phonemes(&self, text: &str) -> Result<Vec<String>> {
let mut phonemes = Vec::new();
for word in text.split_whitespace() {
let clean_word =
word.to_lowercase().trim_matches(|c: char| !c.is_alphabetic()).to_string();
if let Some(word_phonemes) = self.phoneme_dict.get(&clean_word) {
phonemes.extend(word_phonemes.clone());
} else {
for ch in clean_word.chars() {
phonemes.push(ch.to_string());
}
}
}
Ok(phonemes)
}
}
pub struct ProsodyAnalyzer;
impl ProsodyAnalyzer {
pub fn new() -> Self {
Self
}
pub fn analyze(
&self,
_text: &str,
audio_data: &[f32],
sample_rate: u32,
) -> Result<ProsodyInfo> {
let avg_pitch = self.calculate_average_pitch(audio_data, sample_rate);
let pitch_range = self.calculate_pitch_range(audio_data, sample_rate);
let speaking_rate = 150.0;
let pauses = self.detect_pauses(audio_data, sample_rate)?;
let emphasis = self.detect_emphasis(audio_data, sample_rate)?;
Ok(ProsodyInfo {
avg_pitch,
pitch_range,
speaking_rate,
pauses,
emphasis,
})
}
fn calculate_average_pitch(&self, audio_data: &[f32], _sample_rate: u32) -> f32 {
let mut sum = 0.0;
for sample in audio_data {
sum += sample.abs();
}
sum / audio_data.len() as f32 * 440.0 }
fn calculate_pitch_range(&self, audio_data: &[f32], _sample_rate: u32) -> f32 {
let min_val = audio_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
let max_val = audio_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
(max_val - min_val) * 100.0 }
fn detect_pauses(&self, audio_data: &[f32], sample_rate: u32) -> Result<Vec<PauseInfo>> {
let mut pauses = Vec::new();
let silence_threshold = 0.01;
let min_pause_duration = 0.1;
let mut in_pause = false;
let mut pause_start = 0.0;
for (i, &sample) in audio_data.iter().enumerate() {
let time = i as f64 / sample_rate as f64;
if sample.abs() < silence_threshold {
if !in_pause {
pause_start = time;
in_pause = true;
}
} else if in_pause {
let duration = time - pause_start;
if duration >= min_pause_duration {
pauses.push(PauseInfo {
start_time: pause_start,
duration,
pause_type: PauseType::Phrase, });
}
in_pause = false;
}
}
Ok(pauses)
}
fn detect_emphasis(&self, audio_data: &[f32], sample_rate: u32) -> Result<Vec<EmphasisInfo>> {
let mut emphasis = Vec::new();
let emphasis_threshold = 0.5;
let window_size = sample_rate as usize / 10;
for (i, window) in audio_data.windows(window_size).enumerate() {
let avg_amplitude = window.iter().map(|&x| x.abs()).sum::<f32>() / window.len() as f32;
if avg_amplitude > emphasis_threshold {
let start_time = i as f64 / sample_rate as f64 * window_size as f64;
let end_time = start_time + window_size as f64 / sample_rate as f64;
emphasis.push(EmphasisInfo {
start_time,
end_time,
intensity: avg_amplitude,
emphasis_type: EmphasisType::Volume,
});
}
}
Ok(emphasis)
}
}
impl Default for ProsodyAnalyzer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::core::traits::{Model, TokenizedInput, Tokenizer};
use crate::AutoConfig;
use std::collections::HashMap;
use trustformers_core::Tensor;
#[derive(Clone)]
struct MockModel {
config: AutoConfig,
}
impl MockModel {
fn new() -> Self {
MockModel {
config: {
#[cfg(feature = "bert")]
{
AutoConfig::Bert(Default::default())
}
#[cfg(all(not(feature = "bert"), feature = "roberta"))]
{
AutoConfig::Roberta(Default::default())
}
#[cfg(all(not(feature = "bert"), not(feature = "roberta"), feature = "gpt2"))]
{
AutoConfig::Gpt2(Default::default())
}
#[cfg(all(
not(feature = "bert"),
not(feature = "roberta"),
not(feature = "gpt2"),
feature = "gpt_neo"
))]
{
AutoConfig::GptNeo(Default::default())
}
#[cfg(all(
not(feature = "bert"),
not(feature = "roberta"),
not(feature = "gpt2"),
not(feature = "gpt_neo"),
feature = "gpt_j"
))]
{
AutoConfig::GptJ(Default::default())
}
#[cfg(all(
not(feature = "bert"),
not(feature = "roberta"),
not(feature = "gpt2"),
not(feature = "gpt_neo"),
not(feature = "gpt_j"),
feature = "t5"
))]
{
AutoConfig::T5(Default::default())
}
#[cfg(all(
not(feature = "bert"),
not(feature = "roberta"),
not(feature = "gpt2"),
not(feature = "gpt_neo"),
not(feature = "gpt_j"),
not(feature = "t5"),
feature = "albert"
))]
{
AutoConfig::Albert(Default::default())
}
#[cfg(not(any(
feature = "bert",
feature = "roberta",
feature = "gpt2",
feature = "gpt_neo",
feature = "gpt_j",
feature = "t5",
feature = "albert"
)))]
{
compile_error!("At least one model feature must be enabled for tests (bert, roberta, gpt2, gpt_neo, gpt_j, t5, or albert)")
}
},
}
}
}
impl Model for MockModel {
type Input = Tensor;
type Output = Tensor;
type Config = AutoConfig;
fn forward(&self, _input: Self::Input) -> trustformers_core::errors::Result<Self::Output> {
Tensor::zeros(&[1, 10])
}
fn num_parameters(&self) -> usize {
1000 }
fn load_pretrained(
&mut self,
_reader: &mut dyn std::io::Read,
) -> trustformers_core::errors::Result<()> {
Ok(()) }
fn get_config(&self) -> &Self::Config {
&self.config
}
}
#[derive(Clone)]
struct MockTokenizer;
impl MockTokenizer {
fn new() -> Self {
MockTokenizer
}
}
impl Tokenizer for MockTokenizer {
fn encode(&self, _text: &str) -> trustformers_core::errors::Result<TokenizedInput> {
Ok(TokenizedInput {
input_ids: vec![1, 2, 3], attention_mask: vec![1, 1, 1],
token_type_ids: Some(vec![0, 0, 0]),
offset_mapping: None,
special_tokens_mask: None,
overflowing_tokens: None,
})
}
fn encode_pair(
&self,
_text_a: &str,
_text_b: &str,
) -> trustformers_core::errors::Result<TokenizedInput> {
Ok(TokenizedInput {
input_ids: vec![1, 2, 3, 4, 5], attention_mask: vec![1, 1, 1, 1, 1],
token_type_ids: Some(vec![0, 0, 0, 1, 1]),
offset_mapping: None,
special_tokens_mask: None,
overflowing_tokens: None,
})
}
fn decode(&self, _token_ids: &[u32]) -> trustformers_core::errors::Result<String> {
Ok("mock decoded text".to_string())
}
fn vocab_size(&self) -> usize {
1000
}
fn get_vocab(&self) -> HashMap<String, u32> {
let mut vocab = HashMap::new();
vocab.insert("test".to_string(), 1);
vocab.insert("mock".to_string(), 2);
vocab.insert("token".to_string(), 3);
vocab
}
fn token_to_id(&self, token: &str) -> Option<u32> {
match token {
"test" => Some(1),
"mock" => Some(2),
"token" => Some(3),
_ => None,
}
}
fn id_to_token(&self, id: u32) -> Option<String> {
match id {
1 => Some("test".to_string()),
2 => Some("mock".to_string()),
3 => Some("token".to_string()),
_ => None,
}
}
}
#[test]
fn test_text_to_speech_pipeline_creation() {
let model = MockModel::new();
let tokenizer = MockTokenizer::new();
let pipeline = TextToSpeechPipeline::new(model, tokenizer);
assert!(pipeline.is_ok());
}
#[test]
fn test_text_to_speech_config() {
let config = TextToSpeechConfig::default();
assert_eq!(config.voice, "default");
assert_eq!(config.speaking_rate, 1.0);
assert_eq!(config.pitch, 1.0);
assert_eq!(config.volume, 1.0);
}
#[test]
fn test_available_voices() {
let voices = TextToSpeechPipeline::<MockModel, MockTokenizer>::get_available_voices();
assert!(voices.contains(&"default".to_string()));
assert!(voices.contains(&"male-neutral".to_string()));
assert!(voices.contains(&"female-neutral".to_string()));
}
#[test]
fn test_supported_emotions() {
let emotions = TextToSpeechPipeline::<MockModel, MockTokenizer>::get_supported_emotions();
assert!(emotions.contains(&"neutral".to_string()));
assert!(emotions.contains(&"happy".to_string()));
assert!(emotions.contains(&"sad".to_string()));
}
#[test]
fn test_phoneme_converter() {
let converter = PhonemeConverter::new();
let phonemes = converter.text_to_phonemes("hello world").expect("operation failed in test");
assert!(!phonemes.is_empty());
}
#[test]
fn test_prosody_analyzer() {
let analyzer = ProsodyAnalyzer::new();
let audio_data = vec![0.1, 0.2, 0.0, 0.0, 0.3, 0.4];
let prosody =
analyzer.analyze("test", &audio_data, 22050).expect("operation failed in test");
assert!(prosody.avg_pitch > 0.0);
}
#[test]
fn test_text_preprocessing() {
let model = MockModel::new();
let tokenizer = MockTokenizer::new();
let pipeline =
TextToSpeechPipeline::new(model, tokenizer).expect("operation failed in test");
let processed = pipeline
.preprocess_text("Dr. Smith said 5 words.")
.expect("operation failed in test");
assert!(processed.contains("Doctor"));
assert!(processed.contains("five"));
}
#[test]
fn test_pipeline_configuration() {
let model = MockModel::new();
let tokenizer = MockTokenizer::new();
let pipeline = TextToSpeechPipeline::new(model, tokenizer)
.expect("operation failed in test")
.with_voice("female-neutral".to_string())
.with_speaking_rate(1.5)
.with_pitch(1.2)
.with_volume(0.8);
assert_eq!(pipeline.config.voice, "female-neutral");
assert_eq!(pipeline.config.speaking_rate, 1.5);
assert_eq!(pipeline.config.pitch, 1.2);
assert_eq!(pipeline.config.volume, 0.8);
}
#[test]
fn test_default_config_sample_rate() {
let cfg = TextToSpeechConfig::default();
assert_eq!(cfg.sample_rate, 22050);
}
#[test]
fn test_default_config_prosody_disabled() {
let cfg = TextToSpeechConfig::default();
assert!(!cfg.prosody_control);
assert!(!cfg.emotion_control);
assert!(cfg.target_emotion.is_none());
}
#[test]
fn test_default_config_max_duration_positive() {
let cfg = TextToSpeechConfig::default();
if let Some(d) = cfg.max_duration {
assert!(d > 0.0, "max_duration must be positive when set");
}
}
#[test]
fn test_audio_format_variants_constructable() {
let formats = [
AudioFormat::Wav,
AudioFormat::Mp3,
AudioFormat::Flac,
AudioFormat::Ogg,
AudioFormat::Raw,
];
assert_eq!(formats.len(), 5);
}
#[test]
fn test_prosody_type_variants_constructable() {
let types = [
ProsodyType::Emphasis,
ProsodyType::Pause,
ProsodyType::Speed,
ProsodyType::Pitch,
ProsodyType::Volume,
];
assert_eq!(types.len(), 5);
}
#[test]
fn test_prosody_marker_fields() {
let marker = ProsodyMarker {
start: 0,
end: 5,
prosody_type: ProsodyType::Emphasis,
intensity: 0.8,
};
assert_eq!(marker.start, 0);
assert_eq!(marker.end, 5);
assert!(marker.intensity >= 0.0 && marker.intensity <= 1.0);
}
#[test]
fn test_tts_input_with_all_fields() {
let input = TextToSpeechInput {
text: "Hello world".to_string(),
voice: Some("male-neutral".to_string()),
speaking_rate: Some(1.2),
pitch: Some(0.9),
volume: Some(0.7),
emotion: Some("happy".to_string()),
prosody_markers: Some(vec![ProsodyMarker {
start: 0,
end: 5,
prosody_type: ProsodyType::Emphasis,
intensity: 0.5,
}]),
};
assert_eq!(input.text, "Hello world");
assert!(input.voice.is_some());
assert!(input.speaking_rate.is_some());
assert!(input.prosody_markers.is_some());
}
#[test]
fn test_phoneme_converter_known_words() {
let converter = PhonemeConverter::new();
let hello_phonemes =
converter.text_to_phonemes("hello").expect("phoneme conversion should succeed");
assert_eq!(hello_phonemes.len(), 4, "hello → h ə ˈl oʊ (4 phonemes)");
}
#[test]
fn test_phoneme_converter_unknown_word_fallback() {
let converter = PhonemeConverter::new();
let phonemes = converter
.text_to_phonemes("zzz")
.expect("phoneme conversion should succeed for unknown word");
assert_eq!(phonemes.len(), 3, "unknown word → 1 phoneme per character");
}
#[test]
fn test_expand_abbreviations_dr() {
let model = MockModel::new();
let tokenizer = MockTokenizer::new();
let pipeline = TextToSpeechPipeline::new(model, tokenizer).expect("pipeline creation ok");
let result = pipeline.expand_abbreviations("Dr. Smith");
assert!(result.contains("Doctor"), "Dr. should expand to Doctor");
}
#[test]
fn test_expand_abbreviations_ai() {
let model = MockModel::new();
let tokenizer = MockTokenizer::new();
let pipeline = TextToSpeechPipeline::new(model, tokenizer).expect("pipeline creation ok");
let result = pipeline.expand_abbreviations("AI systems");
assert!(
result.contains("Artificial Intelligence"),
"AI should expand to Artificial Intelligence"
);
}
#[test]
fn test_normalize_numbers_single_digit() {
let model = MockModel::new();
let tokenizer = MockTokenizer::new();
let pipeline = TextToSpeechPipeline::new(model, tokenizer).expect("pipeline creation ok");
let result = pipeline.normalize_numbers(" 5 steps");
assert!(result.contains("five"), "5 should become 'five'");
}
#[test]
fn test_normalize_punctuation_period() {
let model = MockModel::new();
let tokenizer = MockTokenizer::new();
let pipeline = TextToSpeechPipeline::new(model, tokenizer).expect("pipeline creation ok");
let result = pipeline.normalize_punctuation("End.");
assert!(
result.contains("period"),
"period should be expanded to 'period'"
);
}
#[test]
fn test_normalize_punctuation_question_mark() {
let model = MockModel::new();
let tokenizer = MockTokenizer::new();
let pipeline = TextToSpeechPipeline::new(model, tokenizer).expect("pipeline creation ok");
let result = pipeline.normalize_punctuation("Really?");
assert!(result.contains("question"), "? should expand to 'question'");
}
#[test]
fn test_prosody_analyzer_avg_pitch_non_negative() {
let analyzer = ProsodyAnalyzer::new();
let mut s = 42u64;
let samples: Vec<f32> = (0..100)
.map(|_| {
s = s.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
(s % 1000) as f32 / 1000.0
})
.collect();
let prosody = analyzer.analyze("test text", &samples, 22050).expect("prosody ok");
assert!(prosody.avg_pitch >= 0.0, "avg_pitch must be non-negative");
}
#[test]
fn test_prosody_analyzer_speaking_rate_positive() {
let analyzer = ProsodyAnalyzer::new();
let samples = vec![0.3_f32; 200];
let prosody = analyzer.analyze("test", &samples, 22050).expect("prosody should succeed");
assert!(
prosody.speaking_rate > 0.0,
"speaking_rate must be positive"
);
}
#[test]
fn test_phoneme_timings_ordering() {
let model = MockModel::new();
let tokenizer = MockTokenizer::new();
let pipeline = TextToSpeechPipeline::new(model, tokenizer).expect("pipeline ok");
let phonemes = vec!["h".to_string(), "e".to_string(), "l".to_string()];
let timings =
pipeline.generate_phoneme_timings(&phonemes, 1.5).expect("phoneme timings ok");
assert_eq!(timings.len(), 3, "one timing per phoneme");
for w in timings.windows(2) {
assert!(
w[1].start_time >= w[0].start_time,
"phoneme start times should be non-decreasing"
);
}
}
#[test]
fn test_phoneme_timings_confidence_range() {
let model = MockModel::new();
let tokenizer = MockTokenizer::new();
let pipeline = TextToSpeechPipeline::new(model, tokenizer).expect("pipeline ok");
let phonemes = vec!["a".to_string(), "b".to_string()];
let timings =
pipeline.generate_phoneme_timings(&phonemes, 1.0).expect("phoneme timings ok");
for t in &timings {
assert!(
t.confidence >= 0.0 && t.confidence <= 1.0,
"confidence must be in [0, 1]"
);
}
}
}