use crate::VoirsError;
use std::sync::Arc;
use tokio::sync::RwLock;
pub use voirs_conversion::prelude::*;
pub use voirs_conversion::types::{AgeGroup, Gender};
/// High-level handle around `voirs_conversion::VoiceConverter`.
///
/// Every field sits behind an `Arc`, so clones of this handle share the same inner
/// engine, the same runtime settings and the same target cache.
#[derive(Clone, Debug)]
pub struct VoiceConverter {
    /// Shared underlying conversion engine.
    converter: Arc<voirs_conversion::VoiceConverter>,
    /// Runtime settings consulted on every conversion request.
    config: Arc<RwLock<VoiceConverterConfig>>,
    /// Conversion targets stored under caller-chosen string ids.
    target_cache: Arc<RwLock<std::collections::HashMap<String, ConversionTarget>>>,
}
/// Runtime settings used by `VoiceConverter` when it builds conversion requests.
#[derive(Clone, Debug)]
pub struct VoiceConverterConfig {
    /// When `false`, conversion requests are refused with a configuration error.
    pub enabled: bool,
    /// Conversion type used when a caller does not pass one explicitly.
    pub default_conversion_type: ConversionType,
    /// Forwarded to each request via `with_realtime`.
    pub realtime_enabled: bool,
    /// NOTE(review): not read anywhere in this module — confirm whether target caching
    /// is supposed to honour this flag.
    pub cache_targets: bool,
    /// Maximum number of cached targets before an existing entry is evicted.
    pub max_cache_size: usize,
    /// Forwarded to each request via `with_quality_level`; setters clamp it to `0.0..=1.0`.
    pub quality_level: f32,
}
impl VoiceConverter {
    /// Creates a converter around a default-constructed `voirs_conversion::VoiceConverter`.
    ///
    /// Settings start from [`VoiceConverterConfig::default`] and the target cache is empty.
    ///
    /// # Errors
    ///
    /// Returns a model error if the inner converter cannot be constructed.
    pub async fn new() -> crate::Result<Self> {
        let converter = voirs_conversion::VoiceConverter::new()
            .map_err(|e| VoirsError::model_error(format!("Voice converter: {}", e)))?;
        Ok(Self::from_inner(converter))
    }

    /// Creates a converter whose inner engine is built from `conversion_config`.
    ///
    /// Wrapper settings still start from [`VoiceConverterConfig::default`].
    ///
    /// # Errors
    ///
    /// Returns a model error if the inner converter cannot be built from the configuration.
    pub async fn with_config(conversion_config: ConversionConfig) -> crate::Result<Self> {
        let converter = voirs_conversion::VoiceConverter::with_config(conversion_config)
            .map_err(|e| VoirsError::model_error(format!("Voice converter: {}", e)))?;
        Ok(Self::from_inner(converter))
    }

    /// Wraps an inner converter with default settings and an empty target cache.
    fn from_inner(converter: voirs_conversion::VoiceConverter) -> Self {
        Self {
            converter: Arc::new(converter),
            config: Arc::new(RwLock::new(VoiceConverterConfig::default())),
            target_cache: Arc::new(RwLock::new(std::collections::HashMap::new())),
        }
    }

    /// Rejects scaling factors that are NaN, infinite, zero or negative.
    fn ensure_positive_factor(field: &str, value: f32) -> crate::Result<()> {
        if value.is_finite() && value > 0.0 {
            Ok(())
        } else {
            Err(VoirsError::ConfigError {
                field: field.to_string(),
                message: format!("{} must be a finite positive number, got {}", field, value),
            })
        }
    }

    /// Converts `source_audio` towards `target`.
    ///
    /// `conversion_type` overrides the configured default when `Some`. The realtime flag
    /// and quality level come from the current settings.
    ///
    /// # Errors
    ///
    /// Returns a configuration error when conversion is disabled, or an audio error when
    /// the inner converter fails.
    pub async fn convert_voice(
        &self,
        source_audio: Vec<f32>,
        source_sample_rate: u32,
        target: ConversionTarget,
        conversion_type: Option<ConversionType>,
    ) -> crate::Result<ConversionResult> {
        // Snapshot the settings and release the config lock before the conversion await:
        // holding the guard across it would block every config writer (and, because tokio's
        // RwLock is write-preferring, every later reader) for the whole conversion.
        let (conversion_type, realtime_enabled, quality_level) = {
            let config = self.config.read().await;
            if !config.enabled {
                return Err(VoirsError::ConfigError {
                    field: "feature".to_string(),
                    message: "Voice conversion is disabled".to_string(),
                });
            }
            (
                conversion_type.unwrap_or_else(|| config.default_conversion_type.clone()),
                config.realtime_enabled,
                config.quality_level,
            )
        };
        let request = ConversionRequest::new(
            format!("convert_{}", fastrand::u64(..)),
            source_audio,
            source_sample_rate,
            conversion_type,
            target,
        )
        .with_realtime(realtime_enabled)
        .with_quality_level(quality_level);
        self.converter
            .convert(request)
            .await
            .map_err(|e| VoirsError::audio_error(format!("Voice conversion: {}", e)))
    }

    /// Runs an `AgeTransformation` towards `VoiceCharacteristics::for_age(target_age)`.
    ///
    /// # Errors
    ///
    /// Same as [`VoiceConverter::convert_voice`].
    pub async fn convert_age(
        &self,
        source_audio: Vec<f32>,
        source_sample_rate: u32,
        target_age: AgeGroup,
    ) -> crate::Result<ConversionResult> {
        let characteristics = VoiceCharacteristics::for_age(target_age);
        let target = ConversionTarget::new(characteristics);
        self.convert_voice(
            source_audio,
            source_sample_rate,
            target,
            Some(ConversionType::AgeTransformation),
        )
        .await
    }

    /// Runs a `GenderTransformation` towards `VoiceCharacteristics::for_gender(target_gender)`.
    ///
    /// # Errors
    ///
    /// Same as [`VoiceConverter::convert_voice`].
    pub async fn convert_gender(
        &self,
        source_audio: Vec<f32>,
        source_sample_rate: u32,
        target_gender: Gender,
    ) -> crate::Result<ConversionResult> {
        let characteristics = VoiceCharacteristics::for_gender(target_gender);
        let target = ConversionTarget::new(characteristics);
        self.convert_voice(
            source_audio,
            source_sample_rate,
            target,
            Some(ConversionType::GenderTransformation),
        )
        .await
    }

    /// Runs a `PitchShift` whose target is a default `VoiceCharacteristics` with its mean
    /// F0 multiplied by `pitch_factor`.
    ///
    /// # Errors
    ///
    /// Returns a configuration error if `pitch_factor` is not a finite positive number;
    /// otherwise the same as [`VoiceConverter::convert_voice`].
    pub async fn pitch_shift(
        &self,
        source_audio: Vec<f32>,
        source_sample_rate: u32,
        pitch_factor: f32,
    ) -> crate::Result<ConversionResult> {
        // A NaN, zero or negative factor would yield a meaningless target F0.
        Self::ensure_positive_factor("pitch_factor", pitch_factor)?;
        let mut characteristics = VoiceCharacteristics::new();
        characteristics.pitch.mean_f0 *= pitch_factor;
        let target = ConversionTarget::new(characteristics);
        self.convert_voice(
            source_audio,
            source_sample_rate,
            target,
            Some(ConversionType::PitchShift),
        )
        .await
    }

    /// Runs a `SpeedTransformation` whose target speaking rate is `speed_factor`.
    ///
    /// # Errors
    ///
    /// Returns a configuration error if `speed_factor` is not a finite positive number;
    /// otherwise the same as [`VoiceConverter::convert_voice`].
    pub async fn change_speed(
        &self,
        source_audio: Vec<f32>,
        source_sample_rate: u32,
        speed_factor: f32,
    ) -> crate::Result<ConversionResult> {
        // A NaN, zero or negative rate cannot describe speech timing.
        Self::ensure_positive_factor("speed_factor", speed_factor)?;
        let mut characteristics = VoiceCharacteristics::new();
        characteristics.timing.speaking_rate = speed_factor;
        let target = ConversionTarget::new(characteristics);
        self.convert_voice(
            source_audio,
            source_sample_rate,
            target,
            Some(ConversionType::SpeedTransformation),
        )
        .await
    }

    /// Converts towards the target cached under `target_id`, using the default
    /// conversion type.
    ///
    /// # Errors
    ///
    /// Returns a configuration error if no target is cached under `target_id`; otherwise
    /// the same as [`VoiceConverter::convert_voice`].
    pub async fn convert_with_cached_target(
        &self,
        source_audio: Vec<f32>,
        source_sample_rate: u32,
        target_id: &str,
    ) -> crate::Result<ConversionResult> {
        // Clone the target out so the cache guard is dropped at the end of this statement.
        // Holding it across `convert_voice` (which takes the config lock) would invert the
        // lock order used by `cache_target` and can deadlock with a queued config writer.
        let cached = self.target_cache.read().await.get(target_id).cloned();
        let target = cached.ok_or_else(|| VoirsError::ConfigError {
            field: "cache".to_string(),
            message: format!("Cached target not found: {}", target_id),
        })?;
        self.convert_voice(source_audio, source_sample_rate, target, None)
            .await
    }

    /// Stores `target` under `target_id`, replacing any entry with the same id.
    ///
    /// When a new id is added to a cache already holding `max_cache_size` entries, existing
    /// entries are evicted first. Which entry is evicted is unspecified, because `HashMap`
    /// iteration order is arbitrary. A limit of 0 behaves like 1.
    ///
    /// NOTE(review): `VoiceConverterConfig::cache_targets` is not consulted here — confirm
    /// whether caching should be skipped when it is `false`.
    pub async fn cache_target(
        &self,
        target_id: String,
        target: ConversionTarget,
    ) -> crate::Result<()> {
        // Read the limit and release the config lock before taking the cache write lock,
        // so the two locks are never held at the same time.
        let limit = self.config.read().await.max_cache_size.max(1);
        let mut cache = self.target_cache.write().await;
        // Replacing an existing id does not grow the map, so only evict for new ids.
        if !cache.contains_key(&target_id) {
            while cache.len() >= limit {
                match cache.keys().next().cloned() {
                    Some(victim) => {
                        cache.remove(&victim);
                    }
                    None => break,
                }
            }
        }
        cache.insert(target_id, target);
        Ok(())
    }

    /// Returns the ids of all cached targets, in unspecified order.
    pub async fn list_cached_targets(&self) -> Vec<String> {
        let cache = self.target_cache.read().await;
        cache.keys().cloned().collect()
    }

    /// Removes the target cached under `target_id`; succeeds even if none was cached.
    pub async fn remove_cached_target(&self, target_id: &str) -> crate::Result<()> {
        self.target_cache.write().await.remove(target_id);
        Ok(())
    }

    /// Removes every cached target.
    pub async fn clear_cache(&self) -> crate::Result<()> {
        self.target_cache.write().await.clear();
        Ok(())
    }

    /// Enables or disables conversion requests.
    pub async fn set_enabled(&self, enabled: bool) -> crate::Result<()> {
        self.config.write().await.enabled = enabled;
        Ok(())
    }

    /// Returns whether conversion requests are currently accepted.
    pub async fn is_enabled(&self) -> bool {
        self.config.read().await.enabled
    }

    /// Sets the quality level, clamped to `0.0..=1.0`.
    ///
    /// # Errors
    ///
    /// Returns a configuration error if `level` is NaN (which `f32::clamp` would otherwise
    /// pass straight through); the stored level is left unchanged.
    pub async fn set_quality_level(&self, level: f32) -> crate::Result<()> {
        if level.is_nan() {
            return Err(VoirsError::ConfigError {
                field: "quality_level".to_string(),
                message: "Quality level must be a number, got NaN".to_string(),
            });
        }
        self.config.write().await.quality_level = level.clamp(0.0, 1.0);
        Ok(())
    }

    /// Returns the current quality level.
    pub async fn get_quality_level(&self) -> f32 {
        self.config.read().await.quality_level
    }

    /// Enables or disables realtime mode for subsequent requests.
    pub async fn set_realtime_enabled(&self, enabled: bool) -> crate::Result<()> {
        self.config.write().await.realtime_enabled = enabled;
        Ok(())
    }

    /// Returns whether realtime mode is enabled.
    pub async fn is_realtime_enabled(&self) -> bool {
        self.config.read().await.realtime_enabled
    }

    /// Returns the cache size and a single, consistent snapshot of the settings.
    pub async fn get_statistics(&self) -> crate::Result<ConversionStatistics> {
        let cached_targets = self.target_cache.read().await.len();
        // One guard for all config fields, so the snapshot cannot mix values from
        // before and after a concurrent setter.
        let config = self.config.read().await;
        Ok(ConversionStatistics {
            cached_targets,
            realtime_enabled: config.realtime_enabled,
            quality_level: config.quality_level,
            processing_enabled: config.enabled,
        })
    }

    /// Checks `audio` against basic input limits and reports every problem found.
    ///
    /// The limits are: non-empty, 0.1 s to 300 s long, 8 kHz to 96 kHz, all samples finite,
    /// and peak amplitude of at least 0.001. The reported `duration` is 0.0 when
    /// `sample_rate` is 0.
    pub async fn validate_audio(
        &self,
        audio: &[f32],
        sample_rate: u32,
    ) -> crate::Result<AudioValidationResult> {
        let mut issues = Vec::new();
        // Avoid dividing by zero: a zero rate would give an infinite (or NaN) duration.
        let duration = if sample_rate == 0 {
            0.0
        } else {
            audio.len() as f32 / sample_rate as f32
        };
        if audio.is_empty() {
            issues.push("Audio is empty".to_string());
        } else if sample_rate > 0 {
            // Duration limits are only meaningful with a usable sample rate; a zero rate
            // is reported by the sample-rate check below.
            if duration < 0.1 {
                issues.push("Audio too short (minimum 0.1 seconds)".to_string());
            } else if duration > 300.0 {
                issues.push("Audio too long (maximum 5 minutes)".to_string());
            }
        }
        if sample_rate < 8000 {
            issues.push("Sample rate too low (minimum 8kHz)".to_string());
        } else if sample_rate > 96000 {
            issues.push("Sample rate too high (maximum 96kHz)".to_string());
        }
        // `f32::max` ignores NaN, so non-finite samples must be detected explicitly.
        if audio.iter().any(|sample| !sample.is_finite()) {
            issues.push("Audio contains non-finite samples".to_string());
        }
        let max_amplitude = audio.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
        if max_amplitude < 0.001 {
            issues.push("Audio appears to be silent".to_string());
        }
        Ok(AudioValidationResult {
            valid: issues.is_empty(),
            issues,
            duration,
            sample_rate,
            max_amplitude,
        })
    }
}
impl Default for VoiceConverterConfig {
fn default() -> Self {
Self {
enabled: true,
default_conversion_type: ConversionType::SpeakerConversion,
realtime_enabled: false,
cache_targets: true,
max_cache_size: 50,
quality_level: 0.8,
}
}
}
/// Point-in-time summary of a converter's cache and settings.
#[derive(Clone, Debug)]
pub struct ConversionStatistics {
    /// Number of targets currently in the cache.
    pub cached_targets: usize,
    /// Whether realtime mode is enabled.
    pub realtime_enabled: bool,
    /// Current quality level.
    pub quality_level: f32,
    /// Whether conversion requests are accepted.
    pub processing_enabled: bool,
}
/// Outcome of validating an audio buffer before conversion.
#[derive(Clone, Debug)]
pub struct AudioValidationResult {
    /// `true` exactly when `issues` is empty.
    pub valid: bool,
    /// Human-readable description of each problem found.
    pub issues: Vec<String>,
    /// Buffer length in seconds.
    pub duration: f32,
    /// Sample rate that was validated.
    pub sample_rate: u32,
    /// Largest absolute sample value.
    pub max_amplitude: f32,
}
/// Builder that assembles wrapper settings and an optional inner-engine configuration
/// before constructing a `VoiceConverter`.
#[derive(Clone, Debug)]
pub struct VoiceConverterBuilder {
    /// Wrapper settings applied to the built converter.
    config: VoiceConverterConfig,
    /// Inner-engine configuration; `None` means the engine's own defaults.
    conversion_config: Option<ConversionConfig>,
}
impl VoiceConverterBuilder {
    /// Starts from [`VoiceConverterConfig::default`] and the inner engine's default
    /// configuration.
    #[must_use]
    pub fn new() -> Self {
        Self {
            config: VoiceConverterConfig::default(),
            conversion_config: None,
        }
    }

    /// Sets whether the built converter accepts conversion requests.
    #[must_use]
    pub fn enabled(mut self, enabled: bool) -> Self {
        self.config.enabled = enabled;
        self
    }

    /// Sets the conversion type used when a request does not specify one.
    #[must_use]
    pub fn default_conversion_type(mut self, conversion_type: ConversionType) -> Self {
        self.config.default_conversion_type = conversion_type;
        self
    }

    /// Sets whether requests are issued in realtime mode.
    #[must_use]
    pub fn realtime_enabled(mut self, enabled: bool) -> Self {
        self.config.realtime_enabled = enabled;
        self
    }

    /// Sets the quality level, clamped to `0.0..=1.0`.
    ///
    /// A NaN `level` is ignored and the previous value is kept: `f32::clamp` would pass
    /// NaN through unchanged, and this builder method has no way to report an error.
    #[must_use]
    pub fn quality_level(mut self, level: f32) -> Self {
        if !level.is_nan() {
            self.config.quality_level = level.clamp(0.0, 1.0);
        }
        self
    }

    /// Sets the maximum number of cached conversion targets.
    #[must_use]
    pub fn cache_size(mut self, size: usize) -> Self {
        self.config.max_cache_size = size;
        self
    }

    /// Supplies an explicit configuration for the inner conversion engine.
    #[must_use]
    pub fn conversion_config(mut self, config: ConversionConfig) -> Self {
        self.conversion_config = Some(config);
        self
    }

    /// Constructs the converter and installs the accumulated wrapper settings.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`VoiceConverter::with_config`] or
    /// [`VoiceConverter::new`] if the inner engine cannot be built.
    pub async fn build(self) -> crate::Result<VoiceConverter> {
        let converter = match self.conversion_config {
            Some(conversion_config) => VoiceConverter::with_config(conversion_config).await?,
            None => VoiceConverter::new().await?,
        };
        // The write guard is a temporary and is released at the end of this statement.
        *converter.config.write().await = self.config;
        Ok(converter)
    }
}
impl Default for VoiceConverterBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_voice_converter_creation() {
        let converter = VoiceConverter::new().await.unwrap();
        assert!(converter.is_enabled().await);
    }

    #[tokio::test]
    async fn test_age_conversion() {
        let converter = VoiceConverter::new().await.unwrap();
        let audio = vec![0.1; 22050];
        let result = converter
            .convert_age(audio, 22050, AgeGroup::Child)
            .await
            .unwrap();
        assert!(result.success);
    }

    #[tokio::test]
    async fn test_gender_conversion() {
        let converter = VoiceConverter::new().await.unwrap();
        let audio = vec![0.1; 22050];
        let result = converter
            .convert_gender(audio, 22050, Gender::Female)
            .await
            .unwrap();
        assert!(result.success);
    }

    #[tokio::test]
    async fn test_pitch_shift() {
        let converter = VoiceConverter::new().await.unwrap();
        let audio = vec![0.1; 22050];
        let result = converter.pitch_shift(audio, 22050, 1.2).await.unwrap();
        assert!(result.success);
    }

    #[tokio::test]
    async fn test_target_caching() {
        let converter = VoiceConverter::new().await.unwrap();
        let characteristics = VoiceCharacteristics::for_age(AgeGroup::Teen);
        let target = ConversionTarget::new(characteristics);
        converter
            .cache_target("teen_voice".to_string(), target)
            .await
            .unwrap();
        let targets = converter.list_cached_targets().await;
        assert!(targets.contains(&"teen_voice".to_string()));
    }

    #[tokio::test]
    async fn test_cache_remove_and_clear() {
        let converter = VoiceConverter::new().await.unwrap();
        let child = ConversionTarget::new(VoiceCharacteristics::for_age(AgeGroup::Child));
        let teen = ConversionTarget::new(VoiceCharacteristics::for_age(AgeGroup::Teen));
        converter
            .cache_target("child".to_string(), child)
            .await
            .unwrap();
        converter
            .cache_target("teen".to_string(), teen)
            .await
            .unwrap();
        assert_eq!(converter.list_cached_targets().await.len(), 2);

        converter.remove_cached_target("child").await.unwrap();
        assert_eq!(
            converter.list_cached_targets().await,
            vec!["teen".to_string()]
        );

        converter.clear_cache().await.unwrap();
        assert!(converter.list_cached_targets().await.is_empty());
    }

    #[tokio::test]
    async fn test_missing_cached_target_is_error() {
        let converter = VoiceConverter::new().await.unwrap();
        let audio = vec![0.1; 22050];
        let result = converter
            .convert_with_cached_target(audio, 22050, "missing")
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_disabled_converter_rejects_requests() {
        let converter = VoiceConverter::new().await.unwrap();
        converter.set_enabled(false).await.unwrap();
        assert!(!converter.is_enabled().await);
        let audio = vec![0.1; 22050];
        let result = converter.convert_age(audio, 22050, AgeGroup::Child).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_quality_level_is_clamped() {
        let converter = VoiceConverter::new().await.unwrap();
        converter.set_quality_level(1.5).await.unwrap();
        assert_eq!(converter.get_quality_level().await, 1.0);
        converter.set_quality_level(-0.5).await.unwrap();
        assert_eq!(converter.get_quality_level().await, 0.0);
    }

    #[tokio::test]
    async fn test_audio_validation() {
        let converter = VoiceConverter::new().await.unwrap();

        let valid_audio = vec![0.1; 22050];
        let result = converter
            .validate_audio(&valid_audio, 22050)
            .await
            .unwrap();
        assert!(result.valid);

        let invalid_audio: Vec<f32> = Vec::new();
        let result = converter
            .validate_audio(&invalid_audio, 22050)
            .await
            .unwrap();
        assert!(!result.valid);
        assert!(!result.issues.is_empty());
    }

    #[tokio::test]
    async fn test_audio_validation_low_sample_rate() {
        let converter = VoiceConverter::new().await.unwrap();
        let audio = vec![0.1; 4000];
        let result = converter.validate_audio(&audio, 4000).await.unwrap();
        assert!(!result.valid);
    }

    #[tokio::test]
    async fn test_converter_builder() {
        let converter = VoiceConverterBuilder::new()
            .enabled(true)
            .default_conversion_type(ConversionType::PitchShift)
            .realtime_enabled(true)
            .quality_level(0.9)
            .build()
            .await
            .unwrap();
        assert!(converter.is_enabled().await);
        assert!(converter.is_realtime_enabled().await);
        assert_eq!(converter.get_quality_level().await, 0.9);
    }
}