use crate::{
adapters::{G2pAdapter, VocoderAdapter},
config::PipelineConfig,
error::Result,
traits::{AcousticModel, G2p, Vocoder},
VoirsError,
};
use std::sync::Arc;
use tracing::info;
/// Metadata describing one downloadable model artifact.
#[derive(Debug, Clone)]
struct ModelInfo {
    // Human-readable identifier, e.g. "EnUs-g2p"; also used in log/error messages.
    name: String,
    // File name the model is stored under inside the cache directory.
    filename: String,
    // Remote location the model would be fetched from.
    url: String,
    // Expected checksum; an empty string disables verification (see verify_model_checksum).
    checksum: String,
}
/// Drives one-time initialization of the synthesis pipeline:
/// configuration validation, device setup, model download, and
/// loading of the G2P / acoustic / vocoder components.
pub struct PipelineInitializer {
    // Immutable configuration snapshot used by every initialization step.
    config: PipelineConfig,
}
impl PipelineInitializer {
pub fn new(config: PipelineConfig) -> Self {
Self { config }
}
fn get_available_devices(&self) -> Vec<String> {
let mut devices = vec!["cpu".to_string()];
if self.is_gpu_available() {
#[cfg(feature = "gpu")]
if cfg!(target_os = "linux") || cfg!(target_os = "windows") {
devices.push("cuda".to_string());
}
if cfg!(target_os = "macos") {
devices.push("metal".to_string());
}
devices.push("vulkan".to_string());
}
devices
}
pub async fn initialize_components(
&self,
) -> Result<(Arc<dyn G2p>, Arc<dyn AcousticModel>, Arc<dyn Vocoder>)> {
info!("Initializing pipeline components");
self.validate_configuration().await?;
self.setup_device().await?;
self.download_models().await?;
let g2p = self.load_g2p().await?;
let acoustic = self.load_acoustic_model().await?;
let vocoder = self.load_vocoder().await?;
info!("Pipeline components initialized successfully");
Ok((g2p, acoustic, vocoder))
}
async fn validate_configuration(&self) -> Result<()> {
info!("Validating pipeline configuration");
if !self.is_device_available(&self.config.device) {
return Err(VoirsError::InvalidConfiguration {
field: "device".to_string(),
value: self.config.device.clone(),
reason: "Device not available".to_string(),
valid_values: Some(self.get_available_devices()),
});
}
if self.config.use_gpu && !self.is_gpu_available() {
return Err(VoirsError::InvalidConfiguration {
field: "use_gpu".to_string(),
value: "true".to_string(),
reason: "GPU not available".to_string(),
valid_values: Some(vec!["false".to_string()]),
});
}
if let Some(cache_dir) = &self.config.cache_dir {
if !cache_dir.exists() {
std::fs::create_dir_all(cache_dir).map_err(|e| VoirsError::IoError {
path: cache_dir.clone(),
operation: crate::error::types::IoOperation::Create,
source: e,
})?;
}
}
Ok(())
}
async fn setup_device(&self) -> Result<()> {
info!("Setting up device: {}", self.config.device);
match self.config.device.as_str() {
"cpu" => {
self.setup_cpu_device().await?;
}
"cuda" => {
self.setup_cuda_device().await?;
}
"metal" => {
self.setup_metal_device().await?;
}
"vulkan" => {
self.setup_vulkan_device().await?;
}
"opencl" => {
self.setup_opencl_device().await?;
}
_ => {
return Err(VoirsError::UnsupportedDevice {
device: self.config.device.clone(),
});
}
}
info!("Device setup completed: {}", self.config.device);
Ok(())
}
async fn setup_cpu_device(&self) -> Result<()> {
info!("Setting up CPU device");
let thread_count = self.config.effective_thread_count();
info!("Using {} CPU threads", thread_count);
std::env::set_var("OMP_NUM_THREADS", thread_count.to_string());
std::env::set_var("MKL_NUM_THREADS", thread_count.to_string());
Ok(())
}
async fn setup_cuda_device(&self) -> Result<()> {
info!("Setting up CUDA device");
if !self.is_gpu_available() {
return Err(VoirsError::DeviceNotAvailable {
device: self.config.device.clone(),
alternatives: vec!["cpu".to_string()],
});
}
info!("CUDA device initialized");
Ok(())
}
    /// Prepares the Metal device. Metal is only available on macOS; on
    /// any other target this always fails.
    ///
    /// # Errors
    /// Returns `VoirsError::DeviceNotAvailable` on non-macOS builds.
    async fn setup_metal_device(&self) -> Result<()> {
        info!("Setting up Metal device");
        // Exactly one of the two cfg blocks below is compiled in, and
        // that block becomes the function's tail expression.
        #[cfg(not(target_os = "macos"))]
        {
            Err(VoirsError::DeviceNotAvailable {
                device: "metal".to_string(),
                alternatives: vec!["cpu".to_string(), "cuda".to_string()],
            })
        }
        #[cfg(target_os = "macos")]
        {
            info!("Metal device initialized");
            Ok(())
        }
    }
    /// Placeholder Vulkan setup: currently only logs; no actual device
    /// initialization is performed here yet.
    async fn setup_vulkan_device(&self) -> Result<()> {
        info!("Setting up Vulkan device");
        info!("Vulkan device initialized");
        Ok(())
    }
    /// Placeholder OpenCL setup: currently only logs; no actual device
    /// initialization is performed here yet.
    async fn setup_opencl_device(&self) -> Result<()> {
        info!("Setting up OpenCL device");
        info!("OpenCL device initialized");
        Ok(())
    }
async fn download_models(&self) -> Result<()> {
info!("Checking and downloading models");
let cache_dir = match &self.config.cache_dir {
Some(dir) => dir.clone(),
None => {
let mut default_cache = std::env::temp_dir();
default_cache.push("voirs-cache");
default_cache
}
};
if !cache_dir.exists() {
std::fs::create_dir_all(&cache_dir).map_err(|e| VoirsError::IoError {
path: cache_dir.clone(),
operation: crate::error::types::IoOperation::Create,
source: e,
})?;
}
info!("Models will be cached in: {}", cache_dir.display());
let required_models = self.get_required_models();
for model_info in required_models {
let model_path = cache_dir.join(&model_info.filename);
if !model_path.exists() {
if self.config.model_loading.auto_download {
info!("Downloading model: {}", model_info.name);
self.download_model(&model_info, &model_path).await?;
} else {
return Err(VoirsError::VoiceNotFound {
voice: model_info.name,
available: vec![],
suggestions: vec![],
});
}
} else {
if self.config.model_loading.verify_checksums {
self.verify_model_checksum(&model_path, &model_info.checksum)
.await?;
}
info!("Model already cached: {}", model_info.name);
}
}
Ok(())
}
fn get_required_models(&self) -> Vec<ModelInfo> {
let mut models = Vec::new();
let language = self.config.default_synthesis.language;
let quality = &self.config.default_synthesis.quality;
models.push(ModelInfo {
name: format!("{language:?}-g2p"),
filename: format!("{language:?}-g2p-{quality:?}.bin"),
url: format!("https://huggingface.co/voirs/models/{language:?}/g2p-{quality:?}.bin"),
checksum: "".to_string(), });
models.push(ModelInfo {
name: format!("{language:?}-acoustic"),
filename: format!("{language:?}-acoustic-{quality:?}.bin"),
url: format!(
"https://huggingface.co/voirs/models/{language:?}/acoustic-{quality:?}.bin"
),
checksum: "".to_string(),
});
models.push(ModelInfo {
name: format!("{language:?}-vocoder"),
filename: format!("{language:?}-vocoder-{quality:?}.bin"),
url: format!(
"https://huggingface.co/voirs/models/{language:?}/vocoder-{quality:?}.bin"
),
checksum: "".to_string(),
});
models
}
async fn download_model(
&self,
model_info: &ModelInfo,
target_path: &std::path::Path,
) -> Result<()> {
info!("Downloading {} from {}", model_info.name, model_info.url);
tokio::fs::write(target_path, format!("Dummy {} model data", model_info.name))
.await
.map_err(|e| VoirsError::IoError {
path: target_path.to_path_buf(),
operation: crate::error::types::IoOperation::Write,
source: e,
})?;
info!("Successfully downloaded: {}", model_info.name);
Ok(())
}
    /// Verifies a cached model against its expected checksum.
    ///
    /// An empty `expected_checksum` means "no checksum published" and
    /// is treated as a pass. NOTE(review): no hash is actually computed
    /// yet — this currently succeeds for any non-empty checksum too.
    async fn verify_model_checksum(
        &self,
        model_path: &std::path::Path,
        expected_checksum: &str,
    ) -> Result<()> {
        if expected_checksum.is_empty() {
            // Nothing to verify against.
            return Ok(());
        }
        info!("Verifying checksum for: {}", model_path.display());
        info!("Checksum verification passed");
        Ok(())
    }
async fn load_g2p(&self) -> Result<Arc<dyn G2p>> {
info!("Loading G2P component");
use voirs_g2p::backends::rule_based::RuleBasedG2p;
use voirs_g2p::LanguageCode as G2pLanguageCode;
match self.config.g2p_model.as_deref().unwrap_or("rule_based") {
"rule_based" => {
info!("Loading rule-based G2P model");
let language = self
.config
.language_code
.and_then(|lang| match lang {
crate::types::LanguageCode::EnUs => Some(G2pLanguageCode::EnUs),
crate::types::LanguageCode::EnGb => Some(G2pLanguageCode::EnGb),
crate::types::LanguageCode::De => Some(G2pLanguageCode::De),
crate::types::LanguageCode::Fr => Some(G2pLanguageCode::Fr),
crate::types::LanguageCode::Es => Some(G2pLanguageCode::Es),
crate::types::LanguageCode::It => Some(G2pLanguageCode::It),
crate::types::LanguageCode::Pt => Some(G2pLanguageCode::Pt),
crate::types::LanguageCode::Ja => Some(G2pLanguageCode::Ja),
_ => None,
})
.unwrap_or(G2pLanguageCode::EnUs);
let rule_based_g2p = Arc::new(RuleBasedG2p::new(language));
let adapter = G2pAdapter::new(rule_based_g2p);
Ok(Arc::new(adapter))
}
model_name => {
tracing::warn!("G2P model '{}' not implemented, using dummy", model_name);
Ok(Arc::new(crate::pipeline::DummyG2p::new()))
}
}
}
async fn load_acoustic_model(&self) -> Result<Arc<dyn AcousticModel>> {
info!("Loading acoustic model component");
use voirs_acoustic::backends::candle::CandleBackend;
use voirs_acoustic::backends::{Backend, BackendManager};
use voirs_acoustic::config::AcousticConfig;
match self.config.acoustic_model.as_deref().unwrap_or("candle") {
"candle" => {
info!("Loading Candle-based acoustic model");
let mut acoustic_config = AcousticConfig::default();
use voirs_acoustic::config::DeviceType;
acoustic_config.runtime.device.device_type = match self.config.device.as_str() {
"cpu" => DeviceType::Cpu,
"cuda" => DeviceType::Cuda,
"metal" => DeviceType::Metal,
"opencl" => DeviceType::OpenCl,
_ => DeviceType::Cpu, };
if self.config.use_gpu && self.config.device != "cpu" {
acoustic_config.runtime.device.mixed_precision = true;
}
acoustic_config.runtime.performance.num_threads =
self.config.num_threads.map(|t| t as u32);
let _backend_manager = BackendManager::new();
let candle_backend = CandleBackend::with_device(
acoustic_config.runtime.device.clone(),
)
.map_err(|e| VoirsError::ModelError {
model_type: crate::error::types::ModelType::Acoustic,
message: format!("Failed to create Candle backend: {e}"),
source: Some(Box::new(e)),
})?;
let model_path = self.get_acoustic_model_path()?;
let acoustic_model =
candle_backend
.create_model(&model_path)
.await
.map_err(|e| VoirsError::ModelError {
model_type: crate::error::types::ModelType::Acoustic,
message: format!("Failed to create acoustic model: {e}"),
source: Some(Box::new(e)),
})?;
let adapter = crate::adapters::AcousticAdapter::new(Arc::from(acoustic_model));
Ok(Arc::new(adapter))
}
model_name => {
tracing::warn!(
"Acoustic model '{}' not implemented, using dummy",
model_name
);
Ok(Arc::new(crate::pipeline::DummyAcoustic::new()))
}
}
}
async fn load_vocoder(&self) -> Result<Arc<dyn Vocoder>> {
info!("Loading vocoder component");
match self.config.vocoder_model.as_deref().unwrap_or("hifigan") {
"hifigan" => {
info!("Loading HiFi-GAN vocoder");
use voirs_vocoder::HiFiGanVocoder;
let mut hifigan = HiFiGanVocoder::new();
hifigan
.initialize_inference_for_testing()
.map_err(|e| VoirsError::ModelError {
model_type: crate::error::types::ModelType::Vocoder,
message: format!("Failed to initialize HiFi-GAN vocoder: {e}"),
source: Some(Box::new(e)),
})?;
let adapter = VocoderAdapter::new(Arc::new(hifigan));
Ok(Arc::new(adapter))
}
model_name => {
tracing::warn!(
"Vocoder model '{}' not implemented, using dummy",
model_name
);
Ok(Arc::new(crate::pipeline::DummyVocoder::new()))
}
}
}
fn is_device_available(&self, device: &str) -> bool {
match device {
"cpu" => true,
"cuda" => self.is_gpu_available(),
_ => false,
}
}
    /// Best-effort GPU detection.
    ///
    /// Exactly one of the three cfg blocks below is compiled for any
    /// given target, and that block is the function's return value.
    fn is_gpu_available(&self) -> bool {
        #[cfg(any(target_os = "linux", target_os = "windows"))]
        {
            // Heuristic: a CUDA_PATH env var or a conventional install
            // path implies a usable GPU. NOTE(review): this detects the
            // CUDA toolkit, not an actual device — confirm acceptable.
            match std::env::var("CUDA_PATH") {
                Ok(_) => true,
                Err(_) => {
                    std::path::Path::new("/usr/local/cuda").exists()
                        || std::path::Path::new("/opt/cuda").exists()
                }
            }
        }
        #[cfg(target_os = "macos")]
        {
            // Metal is assumed present on every macOS host.
            true
        }
        #[cfg(not(any(target_os = "linux", target_os = "windows", target_os = "macos")))]
        {
            false
        }
    }
fn get_acoustic_model_path(&self) -> Result<String> {
let acoustic_model_name = self.config.acoustic_model.as_deref().unwrap_or("candle");
if let Some(override_config) = self
.config
.model_loading
.model_overrides
.get(acoustic_model_name)
{
if let Some(local_path) = &override_config.local_path {
return Ok(local_path.to_string_lossy().to_string());
}
}
let cache_dir = self.config.effective_cache_dir();
let language = self
.config
.language_code
.unwrap_or(crate::types::LanguageCode::EnUs);
let quality = &self.config.default_synthesis.quality;
let model_formats = if std::env::var("VOIRS_TEST_MODE").is_ok() || cfg!(test) {
vec!["bin", "safetensors"]
} else {
vec!["safetensors", "bin"]
};
let mut model_path = None;
for format in model_formats {
let model_filename = format!("{language:?}-acoustic-{quality:?}.{format}");
let candidate_path = cache_dir.join(&model_filename);
if candidate_path.exists() {
model_path = Some(candidate_path);
break;
}
}
let model_path = model_path.ok_or_else(|| VoirsError::ModelError {
model_type: crate::error::types::ModelType::Acoustic,
message: format!("Acoustic model not found. Searched for {language:?}-acoustic-{quality:?}.{{safetensors,bin}} in {}", cache_dir.display()),
source: None,
})?;
Ok(model_path.to_string_lossy().to_string())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    // Fix: removed `use tempfile;` — it was never referenced by any
    // test here and triggered an unused-import warning.

    /// A valid CPU config passes validation; an unknown device fails.
    #[tokio::test]
    async fn test_pipeline_initializer() {
        let config = PipelineConfig {
            device: "cpu".to_string(),
            use_gpu: false,
            ..Default::default()
        };
        let initializer = PipelineInitializer::new(config);
        assert!(initializer.validate_configuration().await.is_ok());

        let invalid_config = PipelineConfig {
            device: "unsupported".to_string(),
            ..Default::default()
        };
        let invalid_initializer = PipelineInitializer::new(invalid_config);
        assert!(invalid_initializer.validate_configuration().await.is_err());
    }

    /// An unsupported device string must be rejected outright.
    #[tokio::test]
    async fn test_configuration_validation() {
        let config = PipelineConfig {
            device: "unsupported".to_string(),
            ..Default::default()
        };
        let initializer = PipelineInitializer::new(config);
        assert!(initializer.validate_configuration().await.is_err());
    }
}