use candle_core::Tensor;
/// Kind of TTS model, used to decide which generation APIs may be called.
///
/// Each concrete kind permits exactly one API; [`TTSModelType::Unknown`]
/// permits all of them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TTSModelType {
    /// Parsed from `"base"`; the only concrete kind that permits voice cloning.
    Base,
    /// Parsed from `"custom_voice"` or `"customvoice"`; permits the custom-voice API.
    CustomVoice,
    /// Parsed from `"voice_design"` or `"voicedesign"`; permits the voice-design API.
    VoiceDesign,
    /// Fallback for every unrecognized name; no API is rejected for it.
    Unknown,
}

impl TTSModelType {
    /// Converts a model-type name into its variant.
    ///
    /// The comparison is case-insensitive (the input is lowercased first).
    /// The input is not trimmed, and any name outside the known spellings
    /// yields [`TTSModelType::Unknown`] rather than an error.
    pub fn parse(s: &str) -> Self {
        let lowered = s.to_lowercase();
        // Every accepted spelling, paired with the variant it selects.
        let known = [
            ("base", Self::Base),
            ("custom_voice", Self::CustomVoice),
            ("customvoice", Self::CustomVoice),
            ("voice_design", Self::VoiceDesign),
            ("voicedesign", Self::VoiceDesign),
        ];
        known
            .iter()
            .find(|(name, _)| *name == lowered.as_str())
            .map_or(Self::Unknown, |&(_, kind)| kind)
    }

    /// Returns `true` for `Base` and `Unknown`, `false` otherwise.
    pub fn allows_voice_clone(&self) -> bool {
        match self {
            Self::Base | Self::Unknown => true,
            Self::CustomVoice | Self::VoiceDesign => false,
        }
    }

    /// Returns `true` for `CustomVoice` and `Unknown`, `false` otherwise.
    pub fn allows_custom_voice(&self) -> bool {
        match self {
            Self::CustomVoice | Self::Unknown => true,
            Self::Base | Self::VoiceDesign => false,
        }
    }

    /// Returns `true` for `VoiceDesign` and `Unknown`, `false` otherwise.
    pub fn allows_voice_design(&self) -> bool {
        match self {
            Self::VoiceDesign | Self::Unknown => true,
            Self::Base | Self::CustomVoice => false,
        }
    }
}
/// Result of a generation call: an audio tensor, its sample rate, and
/// optional extra outputs.
#[derive(Debug)]
pub struct GenerationResult {
    /// Audio tensor handed to the constructor.
    pub audio: Tensor,
    /// Sample rate associated with `audio`.
    // NOTE(review): presumably in Hz — confirm against callers.
    pub sample_rate: usize,
    /// Codes tensor; `Some` when built with [`GenerationResult::with_codes`],
    /// `None` from the other constructors.
    pub codes: Option<Tensor>,
    /// Lengths vector; `Some` when built with
    /// [`GenerationResult::with_effective_lengths`], `None` from the other
    /// constructors.
    // NOTE(review): presumably the valid sample count per batch item — confirm.
    pub effective_lengths: Option<Vec<usize>>,
}

impl GenerationResult {
    /// Creates a result carrying only audio; both optional fields are `None`.
    pub fn new(audio: Tensor, sample_rate: usize) -> Self {
        let (codes, effective_lengths) = (None, None);
        Self {
            audio,
            sample_rate,
            codes,
            effective_lengths,
        }
    }

    /// Creates a result that also stores `codes`; `effective_lengths` stays `None`.
    pub fn with_codes(audio: Tensor, sample_rate: usize, codes: Tensor) -> Self {
        Self {
            codes: Some(codes),
            ..Self::new(audio, sample_rate)
        }
    }

    /// Creates a result that also stores `effective_lengths`; `codes` stays `None`.
    pub fn with_effective_lengths(
        audio: Tensor,
        sample_rate: usize,
        effective_lengths: Vec<usize>,
    ) -> Self {
        Self {
            effective_lengths: Some(effective_lengths),
            ..Self::new(audio, sample_rate)
        }
    }
}
/// Lowercase language identifiers listed as supported: `"auto"` followed by
/// seven named languages.
// NOTE(review): presumably "auto" requests automatic language detection and
// callers compare against these exact lowercase strings — confirm at the call sites.
pub const SUPPORTED_LANGUAGES: &[&str] = &[
"auto", "chinese", "english", "japanese", "korean", "french", "german", "spanish",
];
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a zero-filled tensor on the CPU device, panicking on failure.
    fn cpu_zeros(
        shape: impl Into<candle_core::Shape>,
        dtype: candle_core::DType,
    ) -> candle_core::Tensor {
        candle_core::Tensor::zeros(shape, dtype, &candle_core::Device::Cpu).unwrap()
    }

    // --- TTSModelType::parse -------------------------------------------------

    #[test]
    fn test_tts_model_type_parse_base() {
        for name in ["base", "Base", "BASE"] {
            assert_eq!(TTSModelType::parse(name), TTSModelType::Base);
        }
    }

    #[test]
    fn test_tts_model_type_parse_custom_voice() {
        for name in ["custom_voice", "customvoice", "CustomVoice", "CUSTOM_VOICE"] {
            assert_eq!(TTSModelType::parse(name), TTSModelType::CustomVoice);
        }
    }

    #[test]
    fn test_tts_model_type_parse_voice_design() {
        for name in ["voice_design", "voicedesign", "VoiceDesign", "VOICE_DESIGN"] {
            assert_eq!(TTSModelType::parse(name), TTSModelType::VoiceDesign);
        }
    }

    #[test]
    fn test_tts_model_type_parse_unknown() {
        for name in ["unknown", "", "some_other_type"] {
            assert_eq!(TTSModelType::parse(name), TTSModelType::Unknown);
        }
    }

    // --- SUPPORTED_LANGUAGES -------------------------------------------------

    #[test]
    fn test_supported_languages() {
        let expected = [
            "auto", "english", "chinese", "japanese", "korean", "french", "german", "spanish",
        ];
        for language in expected {
            assert!(SUPPORTED_LANGUAGES.contains(&language));
        }
        assert_eq!(SUPPORTED_LANGUAGES.len(), 8);
    }

    // --- GenerationResult constructors ---------------------------------------

    #[test]
    fn test_generation_result_new() {
        let audio = cpu_zeros((1000,), candle_core::DType::F32);
        let result = GenerationResult::new(audio, 24000);
        assert_eq!(result.sample_rate, 24000);
        assert!(result.codes.is_none());
        assert!(result.effective_lengths.is_none());
    }

    #[test]
    fn test_generation_result_with_codes() {
        let audio = cpu_zeros((1000,), candle_core::DType::F32);
        let codes = cpu_zeros((32, 100), candle_core::DType::I64);
        let result = GenerationResult::with_codes(audio, 24000, codes);
        assert_eq!(result.sample_rate, 24000);
        assert!(result.codes.is_some());
        assert!(result.effective_lengths.is_none());
    }

    #[test]
    fn test_generation_result_with_effective_lengths() {
        let audio = cpu_zeros((2, 1000), candle_core::DType::F32);
        let lengths = vec![800, 950];
        let result = GenerationResult::with_effective_lengths(audio, 24000, lengths.clone());
        assert_eq!(result.sample_rate, 24000);
        assert!(result.codes.is_none());
        assert_eq!(result.effective_lengths, Some(lengths));
    }

    // --- API permissions: what each model type allows ------------------------

    #[test]
    fn test_base_model_type_allows_voice_clone() {
        assert!(
            TTSModelType::Base.allows_voice_clone(),
            "Base model should allow voice cloning"
        );
    }

    #[test]
    fn test_custom_voice_type_allows_custom() {
        assert!(
            TTSModelType::CustomVoice.allows_custom_voice(),
            "CustomVoice model should allow custom voice API"
        );
    }

    #[test]
    fn test_voice_design_type_allows_design() {
        assert!(
            TTSModelType::VoiceDesign.allows_voice_design(),
            "VoiceDesign model should allow voice design API"
        );
    }

    #[test]
    fn test_unknown_type_allows_all() {
        let unknown = TTSModelType::Unknown;
        let checks = [
            (
                unknown.allows_voice_clone(),
                "Unknown model should allow voice cloning",
            ),
            (
                unknown.allows_custom_voice(),
                "Unknown model should allow custom voice",
            ),
            (
                unknown.allows_voice_design(),
                "Unknown model should allow voice design",
            ),
        ];
        for (allowed, message) in checks {
            assert!(allowed, "{}", message);
        }
    }

    // --- API permissions: what each model type rejects -----------------------

    #[test]
    fn test_base_type_rejects_custom_voice() {
        let base = TTSModelType::Base;
        let checks = [
            (
                base.allows_custom_voice(),
                "Base model should not allow custom voice API",
            ),
            (
                base.allows_voice_design(),
                "Base model should not allow voice design API",
            ),
        ];
        for (allowed, message) in checks {
            assert!(!allowed, "{}", message);
        }
    }

    #[test]
    fn test_custom_voice_type_rejects_voice_clone() {
        let custom_voice = TTSModelType::CustomVoice;
        let checks = [
            (
                custom_voice.allows_voice_clone(),
                "CustomVoice model should not allow voice cloning API",
            ),
            (
                custom_voice.allows_voice_design(),
                "CustomVoice model should not allow voice design API",
            ),
        ];
        for (allowed, message) in checks {
            assert!(!allowed, "{}", message);
        }
    }

    #[test]
    fn test_voice_design_type_rejects_others() {
        let voice_design = TTSModelType::VoiceDesign;
        let checks = [
            (
                voice_design.allows_voice_clone(),
                "VoiceDesign model should not allow voice cloning API",
            ),
            (
                voice_design.allows_custom_voice(),
                "VoiceDesign model should not allow custom voice API",
            ),
        ];
        for (allowed, message) in checks {
            assert!(!allowed, "{}", message);
        }
    }
}