pub mod config;
pub mod custom_voice;
pub mod generate;
pub mod loader;
pub mod options;
pub mod types;
pub mod voice_clone;
pub mod voice_design;
use crate::{
audio::{
mel::{MelSpectrogramConfig, mel_spectrogram},
tokenizer::v2::TokenizerV2,
},
model::{
config::GenerateConfig,
types::{SUPPORTED_LANGUAGES, TTSModelType},
voice_clone::VoiceClonePromptItem,
},
nn::generation::ConditionalGeneration,
text::processing::{PaddingSide, TextProcessor, TokenizerOutput},
};
use candle_core::{DType, Device, IndexOp, Result, Tensor};
use std::collections::HashMap;
pub fn validate_language_value(language: &str) -> Result<()> {
if SUPPORTED_LANGUAGES.contains(&language) {
Ok(())
} else {
Err(candle_core::Error::Msg(format!(
"Unsupported language '{}'. Supported: {:?}",
language, SUPPORTED_LANGUAGES
)))
}
}
pub fn validate_speaker_value(
speaker: &str,
spk_id: &Option<HashMap<String, usize>>,
) -> Result<()> {
match spk_id {
Some(spk_map) if spk_map.contains_key(speaker) => Ok(()),
Some(spk_map) => {
let available: Vec<_> = spk_map.keys().collect();
Err(candle_core::Error::Msg(format!(
"Unknown speaker '{}'. Available: {:?}",
speaker, available
)))
}
None => Err(candle_core::Error::Msg(
"No speakers defined in model config (spk_id is None)".to_string(),
)),
}
}
/// High-level TTS wrapper that bundles the generation network with its
/// optional tokenizers and the settings used to prepare its inputs.
pub struct Model {
/// Underlying network; used in this file for speaker encoding and config lookup.
model: ConditionalGeneration,
/// Optional audio tokenizer; when present and it has an encoder, reference
/// audio is encoded into codes for voice-clone prompts.
audio_tokenizer: Option<TokenizerV2>,
/// Optional text tokenizer; `None` until loaded from a model directory or
/// installed with `set_text_processor`.
text_processor: Option<TextProcessor>,
/// Settings passed to `mel_spectrogram` when computing speaker-encoder input.
mel_config: MelSpectrogramConfig,
/// Device on which tensors created by this wrapper are placed.
device: Device,
/// Dtype that mel spectrograms and tokenizer input audio are cast to.
dtype: DType,
/// Default generation settings supplied at construction.
/// NOTE(review): not read in this file — presumably consumed by the
/// generation submodules; confirm.
generate_defaults: GenerateConfig,
/// Model variant parsed from the config's `tts_model_type`
/// (`TTSModelType::Unknown` when the config has none).
model_type: TTSModelType,
}
/// Sample rate reported by `Model::speaker_encoder_sample_rate`.
/// NOTE(review): presumably in Hz and the rate reference audio is expected to
/// be at — nothing in this file resamples or checks it; confirm with callers.
const SPEAKER_ENCODER_SAMPLE_RATE: u32 = 24000;
impl Model {
pub fn new(
model: ConditionalGeneration,
audio_tokenizer: Option<TokenizerV2>,
device: Device,
dtype: DType,
) -> Self {
let model_type = model
.get_config()
.tts_model_type
.as_ref()
.map(|s| TTSModelType::parse(s))
.unwrap_or(TTSModelType::Unknown);
Self {
model,
audio_tokenizer,
text_processor: None,
mel_config: MelSpectrogramConfig::default(),
device,
dtype,
generate_defaults: GenerateConfig::default(),
model_type,
}
}
/// Creates a `Model` that uses the caller-supplied `mel_config`.
///
/// No text processor is attached and generation settings are the defaults.
/// The model type is parsed from the config's `tts_model_type`, or is
/// `TTSModelType::Unknown` when the config does not specify one.
pub fn with_mel_config(
    model: ConditionalGeneration,
    audio_tokenizer: Option<TokenizerV2>,
    mel_config: MelSpectrogramConfig,
    device: Device,
    dtype: DType,
) -> Self {
    let model_type = model
        .get_config()
        .tts_model_type
        .as_ref()
        .map_or(TTSModelType::Unknown, |name| TTSModelType::parse(name));
    Self {
        model,
        audio_tokenizer,
        text_processor: None,
        mel_config,
        device,
        dtype,
        generate_defaults: GenerateConfig::default(),
        model_type,
    }
}
pub fn with_generate_config(
model: ConditionalGeneration,
audio_tokenizer: Option<TokenizerV2>,
generate_defaults: GenerateConfig,
device: Device,
dtype: DType,
) -> Self {
let model_type = model
.get_config()
.tts_model_type
.as_ref()
.map(|s| TTSModelType::parse(s))
.unwrap_or(TTSModelType::Unknown);
Self {
model,
audio_tokenizer,
text_processor: None,
mel_config: MelSpectrogramConfig::default(),
device,
dtype,
generate_defaults,
model_type,
}
}
/// Creates a `Model` from explicitly supplied tokenizers and generation
/// defaults, using the default mel-spectrogram configuration.
///
/// The model type is parsed from the config's `tts_model_type`; a config
/// without that field yields `TTSModelType::Unknown`.
pub fn with_all(
    model: ConditionalGeneration,
    audio_tokenizer: Option<TokenizerV2>,
    text_processor: Option<TextProcessor>,
    generate_defaults: GenerateConfig,
    device: Device,
    dtype: DType,
) -> Self {
    let model_type = match model.get_config().tts_model_type.as_ref() {
        Some(name) => TTSModelType::parse(name),
        None => TTSModelType::Unknown,
    };
    let mel_config = MelSpectrogramConfig::default();
    Self {
        model,
        audio_tokenizer,
        text_processor,
        mel_config,
        device,
        dtype,
        generate_defaults,
        model_type,
    }
}
/// Creates a `Model`, pulling generation defaults and (if available) a text
/// processor from `model_dir`.
///
/// Generation defaults come from `GenerateConfig::from_model_dir`. The text
/// processor is optional: if it cannot be loaded from the directory the
/// model is still returned, with `text_processor` left as `None`.
pub fn from_model_dir(
    model: ConditionalGeneration,
    audio_tokenizer: Option<TokenizerV2>,
    model_dir: impl AsRef<std::path::Path>,
    device: Device,
    dtype: DType,
) -> Self {
    let defaults = GenerateConfig::from_model_dir(&model_dir);
    let processor = Self::try_load_text_processor(&model_dir);
    Self::with_all(model, audio_tokenizer, processor, defaults, device, dtype)
}
/// Attempts to load a `TextProcessor` from `model_dir`.
///
/// Success is logged at info level. Any load error is logged at debug level
/// and turned into `None`, so a missing tokenizer is never fatal.
fn try_load_text_processor(model_dir: impl AsRef<std::path::Path>) -> Option<TextProcessor> {
    let dir = model_dir.as_ref();
    TextProcessor::from_pretrained(dir)
        .map(|processor| {
            tracing::info!("Loaded text tokenizer from {}", dir.display());
            processor
        })
        .map_err(|e| {
            tracing::debug!("No text tokenizer found in {}: {}", dir.display(), e);
        })
        .ok()
}
/// Returns a shared reference to the wrapped generation network.
pub fn model(&self) -> &ConditionalGeneration {
&self.model
}
/// Returns the audio tokenizer, or `None` if the model was built without one.
pub fn audio_tokenizer(&self) -> Option<&TokenizerV2> {
self.audio_tokenizer.as_ref()
}
/// Returns the text processor, or `None` if none has been loaded or set.
pub fn text_processor(&self) -> Option<&TextProcessor> {
self.text_processor.as_ref()
}
/// Returns `true` when a text processor is attached to this model.
pub fn has_text_processor(&self) -> bool {
self.text_processor.is_some()
}
/// Installs `processor` as the text processor, replacing any existing one.
pub fn set_text_processor(&mut self, processor: TextProcessor) {
self.text_processor = Some(processor);
}
/// Tokenizes a single text with the attached text processor's
/// `tokenize_for_tts`.
///
/// # Errors
///
/// Returns an error string when no text processor is attached.
pub fn tokenize_text(&self, text: &str) -> std::result::Result<Vec<u32>, String> {
    let processor = self.text_processor.as_ref().ok_or_else(|| {
        "No text processor loaded. Load a tokenizer.json file first.".to_string()
    })?;
    Ok(processor.tokenize_for_tts(text))
}
/// Batch-tokenizes `texts` using left padding.
///
/// Equivalent to calling `tokenize_texts_with_padding` with
/// `PaddingSide::Left`.
///
/// # Errors
///
/// Returns an error string when no text processor is attached, or the
/// stringified tokenizer error when batch tokenization fails.
pub fn tokenize_texts(&self, texts: &[&str]) -> std::result::Result<TokenizerOutput, String> {
    self.tokenize_texts_with_padding(texts, PaddingSide::Left)
}
/// Batch-tokenizes `texts`, padding on the requested `padding_side`.
///
/// # Errors
///
/// Returns an error string when no text processor is attached, or the
/// stringified tokenizer error when batch tokenization fails.
pub fn tokenize_texts_with_padding(
    &self,
    texts: &[&str],
    padding_side: PaddingSide,
) -> std::result::Result<TokenizerOutput, String> {
    let Some(processor) = &self.text_processor else {
        return Err("No text processor loaded. Load a tokenizer.json file first.".to_string());
    };
    processor
        .batch_tokenize_for_tts(texts, padding_side)
        .map_err(|e| e.to_string())
}
pub fn tokens_to_tensor(&self, output: &TokenizerOutput) -> Result<Tensor> {
let batch_size = output.input_ids.len();
if batch_size == 0 {
return Tensor::zeros((0, 0), DType::U32, &self.device);
}
let seq_len = output.input_ids[0].len();
let flat: Vec<u32> = output.input_ids.iter().flatten().copied().collect();
Tensor::from_vec(flat, (batch_size, seq_len), &self.device)
}
pub fn attention_mask_to_tensor(&self, output: &TokenizerOutput) -> Result<Tensor> {
let batch_size = output.attention_mask.len();
if batch_size == 0 {
return Tensor::zeros((0, 0), DType::U32, &self.device);
}
let seq_len = output.attention_mask[0].len();
let flat: Vec<u32> = output.attention_mask.iter().flatten().copied().collect();
Tensor::from_vec(flat, (batch_size, seq_len), &self.device)
}
/// Returns the device this model was constructed with.
pub fn device(&self) -> &Device {
&self.device
}
/// Returns the dtype this model was constructed with.
pub fn dtype(&self) -> DType {
self.dtype
}
/// Returns the mel-spectrogram configuration used by `compute_mel_spectrogram`.
pub fn mel_config(&self) -> &MelSpectrogramConfig {
&self.mel_config
}
/// Returns the fixed `SPEAKER_ENCODER_SAMPLE_RATE` constant; the value does
/// not depend on the instance or its mel configuration.
pub fn speaker_encoder_sample_rate(&self) -> u32 {
SPEAKER_ENCODER_SAMPLE_RATE
}
/// Returns the model variant determined at construction time.
pub fn model_type(&self) -> TTSModelType {
self.model_type
}
/// Returns `true` for `Base` models and also for `Unknown` ones, so a model
/// whose type could not be determined is not rejected by this check.
pub fn is_base_model(&self) -> bool {
matches!(self.model_type, TTSModelType::Base | TTSModelType::Unknown)
}
/// Returns `true` for `CustomVoice` models and also for `Unknown` ones, so a
/// model whose type could not be determined is not rejected by this check.
pub fn is_custom_voice_model(&self) -> bool {
matches!(
self.model_type,
TTSModelType::CustomVoice | TTSModelType::Unknown
)
}
/// Returns `true` for `VoiceDesign` models and also for `Unknown` ones, so a
/// model whose type could not be determined is not rejected by this check.
pub fn is_voice_design_model(&self) -> bool {
matches!(
self.model_type,
TTSModelType::VoiceDesign | TTSModelType::Unknown
)
}
fn require_base_model(&self) -> Result<()> {
if !self.is_base_model() {
return Err(candle_core::Error::Msg(format!(
"Model type {:?} does not support generate_voice_clone(). \
Only Base models support voice cloning.",
self.model_type
)));
}
Ok(())
}
fn require_custom_voice_model(&self) -> Result<()> {
if !self.is_custom_voice_model() {
return Err(candle_core::Error::Msg(format!(
"Model type {:?} does not support generate_custom_voice(). \
Only CustomVoice models support predefined speakers.",
self.model_type
)));
}
Ok(())
}
fn require_voice_design_model(&self) -> Result<()> {
if !self.is_voice_design_model() {
return Err(candle_core::Error::Msg(format!(
"Model type {:?} does not support generate_voice_design(). \
Only VoiceDesign models support text-based voice descriptions.",
self.model_type
)));
}
Ok(())
}
/// Computes the mel spectrogram of `audio` with this model's mel
/// configuration, swaps the last two axes, and casts to the model dtype.
///
/// NOTE(review): the `(0, 2, 1)` permutation presumably turns
/// `(batch, n_mels, time)` into `(batch, time, n_mels)` — confirm against
/// `mel_spectrogram`'s output layout.
///
/// # Errors
///
/// Propagates any error from the spectrogram, permute or cast steps.
pub fn compute_mel_spectrogram(&self, audio: &Tensor) -> Result<Tensor> {
    mel_spectrogram(audio, &self.mel_config, &self.device)?
        .permute((0, 2, 1))?
        .to_dtype(self.dtype)
}
/// Runs the speaker encoder on `audio` and returns its embedding.
///
/// When `audio` is 1-D, the leading dimension of the encoder output is
/// squeezed away; for any other rank the encoder output is returned as is.
///
/// # Errors
///
/// Propagates errors from mel computation, speaker encoding or the squeeze.
pub fn extract_speaker_embedding(&self, audio: &Tensor) -> Result<Tensor> {
    let input_rank = audio.dims().len();
    let mel = self.compute_mel_spectrogram(audio)?;
    let embedding = self.model.encode_speaker(&mel)?;
    match input_rank {
        1 => embedding.squeeze(0),
        _ => Ok(embedding),
    }
}
/// Builds a voice-clone prompt from a precomputed mel spectrogram.
///
/// A 2-D `mel` gets a leading dimension added before speaker encoding;
/// any other rank is passed through unchanged. The leading dimension of the
/// encoder output is then squeezed away.
///
/// With `x_vector_only_mode` set, the prompt carries only the speaker
/// embedding and `ref_text` is dropped. Otherwise the prompt is built with
/// no reference codes (none can be derived from a mel alone), the given
/// flag, and `ref_text`.
///
/// # Errors
///
/// Propagates errors from the unsqueeze, speaker encoding or squeeze steps.
pub fn create_voice_clone_prompt_from_mel(
    &self,
    mel: &Tensor,
    ref_text: Option<String>,
    x_vector_only_mode: bool,
) -> Result<VoiceClonePromptItem> {
    let batched_mel = match mel.dims().len() {
        2 => mel.unsqueeze(0)?,
        _ => mel.clone(),
    };
    let speaker_embed = self.model.encode_speaker(&batched_mel)?.squeeze(0)?;
    let prompt = if x_vector_only_mode {
        VoiceClonePromptItem::x_vector_only(speaker_embed)
    } else {
        VoiceClonePromptItem::new(None, speaker_embed, x_vector_only_mode, ref_text)
    };
    Ok(prompt)
}
/// Builds a voice-clone prompt from raw reference audio.
///
/// Steps:
/// 1. Compute the mel spectrogram and (at DEBUG level only) log its shape,
///    dtype, min/max/mean and first five values.
/// 2. Encode the speaker embedding and squeeze its leading dimension.
/// 3. If `x_vector_only_mode` is set, return an x-vector-only prompt
///    (`ref_text` is dropped).
/// 4. Otherwise try to encode the audio into reference codes with the audio
///    tokenizer. A missing tokenizer, a tokenizer without an encoder, an
///    encode failure or a reshape failure is logged as a warning and leaves
///    the codes absent instead of failing the call.
/// 5. ICL mode is used only when both reference codes and `ref_text` are
///    available; in every other case the prompt falls back to x-vector-only
///    mode and carries no reference codes.
///
/// # Errors
///
/// Propagates errors from mel computation, speaker encoding, the squeeze,
/// and from batching/casting the audio for the tokenizer.
pub fn create_voice_clone_prompt_from_audio(
    &mut self,
    audio: &Tensor,
    ref_text: Option<String>,
    x_vector_only_mode: bool,
) -> Result<VoiceClonePromptItem> {
    let mel = self.compute_mel_spectrogram(audio)?;

    // Diagnostics only: each step is best-effort and skipped if it fails.
    if tracing::enabled!(tracing::Level::DEBUG) {
        tracing::debug!(shape = ?mel.dims(), dtype = ?mel.dtype(), "Mel spectrogram");

        let mel_f32 = mel.to_dtype(DType::F32).ok();

        let all_values = mel_f32
            .as_ref()
            .and_then(|m| m.flatten_all().ok())
            .and_then(|flat| flat.to_vec1::<f32>().ok());
        if let Some(values) = all_values {
            let min = values.iter().copied().fold(f32::INFINITY, f32::min);
            let max = values.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let sum: f32 = values.iter().sum();
            let mean = sum / values.len() as f32;
            tracing::debug!(
                min = format!("{:.4}", min),
                max = format!("{:.4}", max),
                mean_val = format!("{:.4}", mean),
                "Mel stats"
            );
        }

        let leading_values = mel_f32
            .as_ref()
            .and_then(|m| m.i((0, 0, ..5)).ok())
            .and_then(|frame| frame.to_vec1::<f32>().ok());
        if let Some(first5) = leading_values {
            tracing::debug!(?first5, "Mel first 5 values (time=0)");
        }
    }

    let speaker_embed = self.model.encode_speaker(&mel)?.squeeze(0)?;
    if x_vector_only_mode {
        return Ok(VoiceClonePromptItem::x_vector_only(speaker_embed));
    }

    // Try to obtain reference codes; every failure path warns and yields None.
    let ref_code = 'encode: {
        let Some(tokenizer) = self.audio_tokenizer.as_mut() else {
            tracing::warn!(
                "No audio tokenizer available for ICL mode. Falling back to x-vector only mode."
            );
            break 'encode None;
        };
        if !tokenizer.has_encoder() {
            tracing::warn!(
                "Audio tokenizer encoder not available for ICL mode. Falling back to x-vector only mode."
            );
            break 'encode None;
        }

        // The tokenizer receives a batch: 1-D audio gets a leading dimension.
        let audio_batched = match audio.dims().len() {
            1 => audio.unsqueeze(0)?.to_dtype(self.dtype)?,
            _ => audio.to_dtype(self.dtype)?,
        };

        let codes = match tokenizer.encode(&audio_batched) {
            Ok(codes) => codes,
            Err(e) => {
                tracing::warn!(
                    error = %e,
                    "Failed to encode audio for ICL mode. Falling back to x-vector only mode."
                );
                break 'encode None;
            }
        };

        // Drop the leading dimension, then swap the remaining two axes.
        match codes.squeeze(0).and_then(|c| c.transpose(0, 1)) {
            Ok(transposed) => {
                tracing::debug!(
                    shape = ?transposed.dims(),
                    "Encoded reference audio to codes"
                );
                Some(transposed)
            }
            Err(e) => {
                tracing::warn!(
                    error = %e,
                    "Failed to reshape encoded codes. Falling back to x-vector only mode."
                );
                None
            }
        }
    };

    // ICL needs both the codes and the transcript of the reference audio.
    let icl_mode = match (&ref_code, &ref_text) {
        (Some(_), Some(_)) => true,
        (Some(_), None) => {
            tracing::warn!(
                "Reference audio was encoded but no --ref-text provided. \
                 Falling back to x-vector only mode. For higher quality voice cloning, \
                 provide --ref-text with the transcript of the reference audio."
            );
            false
        }
        (None, _) => false,
    };

    Ok(VoiceClonePromptItem {
        ref_code: ref_code.filter(|_| icl_mode),
        ref_spk_embedding: speaker_embed,
        x_vector_only_mode: !icl_mode,
        icl_mode,
        ref_text,
    })
}
/// Builds a voice-clone prompt from raw audio; identical to
/// `create_voice_clone_prompt_from_audio`.
///
/// NOTE(review): `_sample_rate` is currently ignored — the audio is forwarded
/// unchanged and is neither resampled nor checked against the speaker-encoder
/// rate. Callers presumably must supply audio at the expected rate; confirm.
pub fn create_voice_clone_prompt_with_sample_rate(
&mut self,
audio: &Tensor,
_sample_rate: u32, ref_text: Option<String>,
x_vector_only_mode: bool,
) -> Result<VoiceClonePromptItem> {
self.create_voice_clone_prompt_from_audio(audio, ref_text, x_vector_only_mode)
}
/// Method form of `validate_language_value`; does not depend on the
/// instance's state.
pub fn validate_language(&self, language: &str) -> Result<()> {
validate_language_value(language)
}
/// Checks `speaker` against the speaker table in the model config
/// (`talker_config.spk_id`).
///
/// # Errors
///
/// Returns the error from `validate_speaker_value` when the table is absent
/// or has no entry for `speaker`.
pub fn validate_speaker(&self, speaker: &str) -> Result<()> {
    validate_speaker_value(speaker, &self.model.get_config().talker_config.spk_id)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Some(..)` speaker table that maps each name to its index in `names`.
    fn speaker_table(names: &[&str]) -> Option<HashMap<String, usize>> {
        Some(
            names
                .iter()
                .enumerate()
                .map(|(idx, name)| (name.to_string(), idx))
                .collect(),
        )
    }

    /// Asserts that `result` is an error and returns its rendered message.
    fn error_text(result: Result<()>) -> String {
        assert!(result.is_err());
        result.unwrap_err().to_string()
    }

    // Every entry of the supported list must be accepted.
    #[test]
    fn test_validate_language_all_supported() {
        for lang in SUPPORTED_LANGUAGES.iter() {
            let outcome = validate_language_value(lang);
            assert!(
                outcome.is_ok(),
                "Language '{}' should be valid but got: {:?}",
                lang,
                outcome
            );
        }
    }

    // The rejection message names the bad value and lists the alternatives.
    #[test]
    fn test_validate_language_invalid() {
        let message = error_text(validate_language_value("esperanto"));
        assert!(
            message.contains("Unsupported language"),
            "Error should mention unsupported language"
        );
        assert!(
            message.contains("esperanto"),
            "Error should contain the invalid language"
        );
        assert!(
            message.contains("english"),
            "Error should list supported languages"
        );
    }

    #[test]
    fn test_validate_language_case_sensitive() {
        for candidate in ["English", "CHINESE"] {
            assert!(
                validate_language_value(candidate).is_err(),
                "Language validation should be case-sensitive"
            );
        }
    }

    #[test]
    fn test_validate_language_empty_string() {
        assert!(
            validate_language_value("").is_err(),
            "Empty string should not be a valid language"
        );
    }

    #[test]
    fn test_validate_speaker_valid() {
        let table = speaker_table(&["alice", "bob"]);
        for name in ["alice", "bob"] {
            assert!(
                validate_speaker_value(name, &table).is_ok(),
                "Valid speaker should pass validation"
            );
        }
    }

    // The rejection message names the bad speaker and lists known ones.
    #[test]
    fn test_validate_speaker_invalid() {
        let table = speaker_table(&["alice", "bob"]);
        let message = error_text(validate_speaker_value("charlie", &table));
        assert!(
            message.contains("Unknown speaker"),
            "Error should mention unknown speaker"
        );
        assert!(
            message.contains("charlie"),
            "Error should contain the invalid speaker"
        );
        assert!(
            message.contains("alice") || message.contains("bob"),
            "Error should list available speakers"
        );
    }

    // A missing table is reported differently from an unknown speaker.
    #[test]
    fn test_validate_speaker_none_spk_id() {
        let table: Option<HashMap<String, usize>> = None;
        let message = error_text(validate_speaker_value("anyone", &table));
        assert!(
            message.contains("No speakers defined"),
            "Error should indicate no speakers are defined"
        );
    }

    // An empty (but present) table behaves like any other miss.
    #[test]
    fn test_validate_speaker_empty_map() {
        let table = speaker_table(&[]);
        let message = error_text(validate_speaker_value("anyone", &table));
        assert!(
            message.contains("Unknown speaker"),
            "Error should mention unknown speaker"
        );
    }

    #[test]
    fn test_validate_speaker_case_sensitive() {
        let table = speaker_table(&["alice"]);
        for candidate in ["Alice", "ALICE"] {
            assert!(
                validate_speaker_value(candidate, &table).is_err(),
                "Speaker validation should be case-sensitive"
            );
        }
    }
}