use super::discovery::{VoiceRegistry, VoiceSearchCriteria};
use crate::{
error::Result,
traits::VoiceManager,
types::{LanguageCode, VoiceConfig},
VoirsError,
};
use async_trait::async_trait;
use futures::StreamExt;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::io::AsyncWriteExt;
use tokio::sync::RwLock;
pub struct DefaultVoiceManager {
registry: VoiceRegistry,
models_dir: PathBuf,
current_voice: Option<String>,
download_enabled: bool,
switching_history: Vec<VoiceSwitch>,
test_mode: bool,
}
impl DefaultVoiceManager {
pub fn new(models_dir: impl Into<PathBuf>) -> Self {
Self {
registry: VoiceRegistry::new(),
models_dir: models_dir.into(),
current_voice: None,
download_enabled: !cfg!(test), switching_history: Vec::new(),
test_mode: cfg!(test), }
}
pub fn with_registry(models_dir: impl Into<PathBuf>, registry: VoiceRegistry) -> Self {
Self {
registry,
models_dir: models_dir.into(),
current_voice: None,
download_enabled: !cfg!(test), switching_history: Vec::new(),
test_mode: cfg!(test), }
}
pub fn add_voice(&mut self, voice: VoiceConfig) {
self.registry.register_voice(voice);
}
pub fn search(&self, criteria: &VoiceSearchCriteria) -> Vec<&VoiceConfig> {
self.registry.find_voices(criteria)
}
pub fn current_voice(&self) -> Option<&str> {
self.current_voice.as_deref()
}
pub fn set_current_voice(&mut self, voice_id: Option<String>) -> Result<()> {
if let Some(ref id) = voice_id {
if self.registry.get_voice(id).is_none() {
return Err(VoirsError::voice_not_found(
id.clone(),
self.registry
.list_voices()
.iter()
.map(|v| v.id.clone())
.collect(),
));
}
}
self.current_voice = voice_id;
Ok(())
}
pub async fn switch_to_voice(&mut self, voice_id: &str) -> Result<()> {
if self.registry.get_voice(voice_id).is_none() {
return Err(VoirsError::voice_not_found(
voice_id.to_string(),
self.registry
.list_voices()
.iter()
.map(|v| v.id.clone())
.collect(),
));
}
if !self.is_voice_available(voice_id) {
if self.download_enabled {
tracing::info!(
"Voice {} not available locally, attempting download",
voice_id
);
self.download_voice(voice_id).await?;
} else {
return Err(VoirsError::voice_not_found(
voice_id.to_string(),
self.get_available_voices(),
));
}
}
let voice_switch = VoiceSwitch {
timestamp: std::time::SystemTime::now(),
from_voice: self.current_voice.clone(),
to_voice: voice_id.to_string(),
successful: true,
error: None,
};
self.switching_history.push(voice_switch);
self.current_voice = Some(voice_id.to_string());
tracing::info!("Switched to voice: {}", voice_id);
Ok(())
}
pub async fn switch_to_language(&mut self, language: LanguageCode) -> Result<()> {
if let Some(default_voice_id) = self.default_voice_for_language(language) {
self.switch_to_voice(&default_voice_id).await
} else {
Err(VoirsError::voice_not_found(
format!("default voice for {language:?}"),
self.registry
.list_voices()
.iter()
.map(|v| v.id.clone())
.collect(),
))
}
}
pub fn get_available_voices(&self) -> Vec<String> {
self.registry
.list_voices()
.iter()
.filter(|voice| self.is_voice_available(&voice.id))
.map(|voice| voice.id.clone())
.collect()
}
pub fn get_unavailable_voices(&self) -> Vec<String> {
self.registry
.list_voices()
.iter()
.filter(|voice| !self.is_voice_available(&voice.id))
.map(|voice| voice.id.clone())
.collect()
}
pub fn has_available_voice_for_language(&self, language: LanguageCode) -> bool {
self.registry
.voices_for_language(language)
.iter()
.any(|voice| self.is_voice_available(&voice.id))
}
pub fn get_available_voice_for_language(&self, language: LanguageCode) -> Option<String> {
self.registry
.voices_for_language(language)
.iter()
.find(|voice| self.is_voice_available(&voice.id))
.map(|voice| voice.id.clone())
}
pub fn set_download_enabled(&mut self, enabled: bool) {
self.download_enabled = enabled;
}
pub fn is_download_enabled(&self) -> bool {
self.download_enabled
}
pub fn set_test_mode(&mut self, enabled: bool) {
self.test_mode = enabled;
if enabled {
self.download_enabled = false;
}
}
pub fn validate_voice_models(&self, voice_id: &str) -> Result<VoiceValidationResult> {
let voice = self.registry.get_voice(voice_id).ok_or_else(|| {
VoirsError::voice_not_found(
voice_id.to_string(),
self.registry
.list_voices()
.iter()
.map(|v| v.id.clone())
.collect(),
)
})?;
let mut result = VoiceValidationResult {
voice_id: voice_id.to_string(),
valid: true,
missing_files: Vec::new(),
invalid_files: Vec::new(),
warnings: Vec::new(),
};
let acoustic_path = self.models_dir.join(&voice.model_config.acoustic_model);
if !acoustic_path.exists() {
result
.missing_files
.push(voice.model_config.acoustic_model.clone());
result.valid = false;
}
let vocoder_path = self.models_dir.join(&voice.model_config.vocoder_model);
if !vocoder_path.exists() {
result
.missing_files
.push(voice.model_config.vocoder_model.clone());
result.valid = false;
}
if let Some(ref g2p_model) = voice.model_config.g2p_model {
let g2p_path = self.models_dir.join(g2p_model);
if !g2p_path.exists() {
result.missing_files.push(g2p_model.clone());
result.valid = false;
}
}
if voice.model_config.device_requirements.min_memory_mb > 2048 {
result
.warnings
.push("Voice requires more than 2GB of memory".to_string());
}
if !voice
.model_config
.device_requirements
.compute_capabilities
.contains(&"cpu".to_string())
{
result
.warnings
.push("Voice may not support CPU-only inference".to_string());
}
Ok(result)
}
async fn download_voice_models(&self, voice: &VoiceConfig) -> Result<()> {
let models = [
&voice.model_config.acoustic_model,
&voice.model_config.vocoder_model,
];
let mut all_models = models.to_vec();
if let Some(ref g2p_model) = voice.model_config.g2p_model {
all_models.push(g2p_model);
}
for model_path in all_models {
let full_path = self.models_dir.join(model_path);
if let Some(parent) = full_path.parent() {
tokio::fs::create_dir_all(parent)
.await
.map_err(|e| VoirsError::IoError {
path: parent.to_path_buf(),
operation: crate::error::types::IoOperation::Create,
source: e,
})?;
}
if cfg!(test) {
tracing::debug!("Test mode: creating dummy file for {}", model_path);
tokio::fs::write(&full_path, b"dummy model file")
.await
.map_err(|e| VoirsError::IoError {
path: full_path.clone(),
operation: crate::error::types::IoOperation::Write,
source: e,
})?;
continue;
}
tracing::info!("Downloading model: {}", model_path);
let download_url = self.get_model_download_url(&voice.id, model_path)?;
self.download_file(&download_url, &full_path).await?;
tracing::info!("Successfully downloaded: {}", model_path);
}
Ok(())
}
fn get_model_download_url(&self, voice_id: &str, model_path: &str) -> Result<String> {
if model_path.starts_with("http://") || model_path.starts_with("https://") {
return Ok(model_path.to_string());
}
if model_path.contains('/') && !model_path.starts_with("models/") {
return Ok(format!(
"https://huggingface.co/{voice_id}/resolve/main/{model_path}"
));
}
let base_url = std::env::var("VOIRS_MODEL_REPOSITORY_URL")
.unwrap_or_else(|_| "https://github.com/voirs-ai/models/releases/download".to_string());
let model_filename = std::path::Path::new(model_path)
.file_name()
.and_then(|f| f.to_str())
.unwrap_or(model_path);
Ok(format!("{base_url}/{voice_id}/{model_filename}"))
}
async fn download_file(&self, url: &str, local_path: &std::path::Path) -> Result<()> {
tracing::debug!("Downloading {} to {}", url, local_path.display());
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(300)) .build()
.map_err(|e| VoirsError::DownloadFailed {
url: url.to_string(),
reason: format!("Failed to create HTTP client: {e}"),
bytes_downloaded: 0,
total_bytes: None,
})?;
let response = client
.get(url)
.send()
.await
.map_err(|e| VoirsError::DownloadFailed {
url: url.to_string(),
reason: format!("Request failed: {e}"),
bytes_downloaded: 0,
total_bytes: None,
})?;
if !response.status().is_success() {
return Err(VoirsError::DownloadFailed {
url: url.to_string(),
reason: format!("HTTP error: {}", response.status()),
bytes_downloaded: 0,
total_bytes: response.content_length(),
});
}
let total_size = response.content_length();
let mut bytes_downloaded = 0u64;
let mut file =
tokio::fs::File::create(local_path)
.await
.map_err(|e| VoirsError::IoError {
path: local_path.to_path_buf(),
operation: crate::error::types::IoOperation::Create,
source: e,
})?;
let mut stream = response.bytes_stream();
while let Some(chunk) = stream.next().await {
let chunk = chunk.map_err(|e| VoirsError::DownloadFailed {
url: url.to_string(),
reason: format!("Stream error: {e}"),
bytes_downloaded,
total_bytes: total_size,
})?;
use tokio::io::AsyncWriteExt;
file.write_all(&chunk)
.await
.map_err(|e| VoirsError::IoError {
path: local_path.to_path_buf(),
operation: crate::error::types::IoOperation::Write,
source: e,
})?;
bytes_downloaded += chunk.len() as u64;
if let Some(total) = total_size {
if total > 10 * 1024 * 1024 {
let percent = (bytes_downloaded as f64 / total as f64) * 100.0;
if bytes_downloaded.is_multiple_of(1024 * 1024) {
tracing::debug!(
"Download progress: {:.1}% ({}/{})",
percent,
bytes_downloaded,
total
);
}
}
}
}
file.flush().await.map_err(|e| VoirsError::IoError {
path: local_path.to_path_buf(),
operation: crate::error::types::IoOperation::Write,
source: e,
})?;
tracing::debug!(
"Downloaded {} bytes to {}",
bytes_downloaded,
local_path.display()
);
Ok(())
}
pub fn get_switching_history(&self) -> Vec<VoiceSwitch> {
self.switching_history.clone()
}
pub fn clear_switching_history(&mut self) {
self.switching_history.clear();
}
}
#[async_trait]
impl VoiceManager for DefaultVoiceManager {
async fn list_voices(&self) -> Result<Vec<VoiceConfig>> {
Ok(self.registry.list_voices().into_iter().cloned().collect())
}
async fn get_voice(&self, voice_id: &str) -> Result<Option<VoiceConfig>> {
Ok(self.registry.get_voice(voice_id).cloned())
}
async fn download_voice(&self, voice_id: &str) -> Result<()> {
let voice = self.registry.get_voice(voice_id).ok_or_else(|| {
VoirsError::voice_not_found(
voice_id.to_string(),
self.registry
.list_voices()
.iter()
.map(|v| v.id.clone())
.collect(),
)
})?;
self.download_voice_models(voice).await?;
tracing::info!("Voice '{}' downloaded successfully", voice_id);
Ok(())
}
fn is_voice_available(&self, voice_id: &str) -> bool {
if self.test_mode {
return self.registry.get_voice(voice_id).is_some();
}
if let Some(voice) = self.registry.get_voice(voice_id) {
let models_exist = [
&voice.model_config.acoustic_model,
&voice.model_config.vocoder_model,
]
.iter()
.chain(voice.model_config.g2p_model.as_ref().iter())
.all(|model_path| {
let full_path = self.models_dir.join(model_path);
full_path.exists()
});
models_exist
} else {
false
}
}
fn default_voice_for_language(&self, lang: LanguageCode) -> Option<String> {
self.registry
.default_voice_for_language(lang)
.map(|voice| voice.id.clone())
}
}
pub struct ConcurrentVoiceManager {
inner: Arc<RwLock<DefaultVoiceManager>>,
}
impl ConcurrentVoiceManager {
pub fn new(models_dir: impl Into<PathBuf>) -> Self {
Self {
inner: Arc::new(RwLock::new(DefaultVoiceManager::new(models_dir))),
}
}
pub async fn read(&self) -> tokio::sync::RwLockReadGuard<'_, DefaultVoiceManager> {
self.inner.read().await
}
pub async fn write(&self) -> tokio::sync::RwLockWriteGuard<'_, DefaultVoiceManager> {
self.inner.write().await
}
pub async fn switch_to_voice(&self, voice_id: &str) -> Result<()> {
self.inner.write().await.switch_to_voice(voice_id).await
}
pub async fn current_voice(&self) -> Option<String> {
self.inner
.read()
.await
.current_voice()
.map(|s| s.to_string())
}
}
#[derive(Debug, Clone)]
pub struct VoiceValidationResult {
pub voice_id: String,
pub valid: bool,
pub missing_files: Vec<String>,
pub invalid_files: Vec<String>,
pub warnings: Vec<String>,
}
#[derive(Debug, Clone)]
pub struct VoiceSwitch {
pub timestamp: std::time::SystemTime,
pub from_voice: Option<String>,
pub to_voice: String,
pub successful: bool,
pub error: Option<String>,
}
#[cfg(test)]
mod tests {
use super::*;
use crate::Gender;
use tempfile::tempdir;
#[tokio::test]
async fn test_voice_manager_creation() {
let temp_dir = tempdir().unwrap();
let manager = DefaultVoiceManager::new(temp_dir.path());
let voices = manager.list_voices().await.unwrap();
assert!(!voices.is_empty());
}
#[tokio::test]
async fn test_voice_availability() {
let temp_dir = tempdir().unwrap();
let mut manager = DefaultVoiceManager::new(temp_dir.path());
manager.set_test_mode(false);
assert!(!manager.is_voice_available("en-US-female-calm"));
let available = manager.get_available_voices();
assert!(available.is_empty());
let unavailable = manager.get_unavailable_voices();
assert!(!unavailable.is_empty());
}
#[tokio::test]
async fn test_voice_switching() {
let temp_dir = tempdir().unwrap();
let mut manager = DefaultVoiceManager::new(temp_dir.path());
assert!(manager.current_voice().is_none());
manager.set_download_enabled(true);
let result = manager.switch_to_voice("en-US-female-calm").await;
assert!(result.is_ok());
assert_eq!(manager.current_voice(), Some("en-US-female-calm"));
assert!(manager.is_voice_available("en-US-female-calm"));
}
#[tokio::test]
async fn test_language_switching() {
let temp_dir = tempdir().unwrap();
let mut manager = DefaultVoiceManager::new(temp_dir.path());
manager.set_download_enabled(true);
let result = manager.switch_to_language(LanguageCode::EnUs).await;
assert!(result.is_ok());
let current = manager.current_voice().unwrap();
let voice = manager.get_voice(current).await.unwrap().unwrap();
assert_eq!(voice.language, LanguageCode::EnUs);
}
#[tokio::test]
async fn test_voice_validation() {
let temp_dir = tempdir().unwrap();
let manager = DefaultVoiceManager::new(temp_dir.path());
let result = manager.validate_voice_models("non-existent");
assert!(result.is_err());
let result = manager.validate_voice_models("en-US-female-calm").unwrap();
assert!(!result.valid);
assert!(!result.missing_files.is_empty());
}
#[tokio::test]
async fn test_concurrent_voice_manager() {
let temp_dir = tempdir().unwrap();
let manager = ConcurrentVoiceManager::new(temp_dir.path());
let current = manager.current_voice().await;
assert!(current.is_none());
{
let mut write_guard = manager.write().await;
write_guard.set_download_enabled(true);
}
{
let read_guard = manager.read().await;
assert!(read_guard.is_download_enabled());
}
}
#[tokio::test]
async fn test_voice_search() {
let temp_dir = tempdir().unwrap();
let manager = DefaultVoiceManager::new(temp_dir.path());
let criteria = VoiceSearchCriteria::new().gender(Gender::Female);
let results = manager.search(&criteria);
assert!(!results.is_empty());
for voice in results {
assert_eq!(voice.characteristics.gender, Some(Gender::Female));
}
}
#[test]
fn test_voice_validation_result() {
let result = VoiceValidationResult {
voice_id: "test-voice".to_string(),
valid: false,
missing_files: vec!["model1.bin".to_string(), "model2.bin".to_string()],
invalid_files: vec![],
warnings: vec!["High memory usage".to_string()],
};
assert!(!result.valid);
assert_eq!(result.missing_files.len(), 2);
assert_eq!(result.warnings.len(), 1);
}
}