use super::{SpeakerEmbedding, VoiceError, VoiceResult};
/// Configuration shared by style encoders and style-transfer implementations.
///
/// The three `*_dim` fields give the lengths of the prosody, timbre and rhythm
/// components of a [`StyleVector`]; [`StyleConfig::total_dim`] is their sum and
/// equals the length of a flattened style vector.
#[derive(Debug, Clone)]
pub struct StyleConfig {
    /// Length of the prosody component (must be > 0, see `validate`).
    pub prosody_dim: usize,
    /// Length of the timbre component (must be > 0, see `validate`).
    pub timbre_dim: usize,
    /// Length of the rhythm component (must be > 0, see `validate`).
    pub rhythm_dim: usize,
    /// Audio sample rate (default 16000; presumably Hz — confirm). Must be > 0.
    pub sample_rate: u32,
    /// Frame shift in milliseconds (default 10).
    pub frame_shift_ms: u32,
    /// How strongly the target style is applied; must lie in `[0.0, 1.0]`.
    pub style_strength: f32,
    /// Whether the source pitch contour should be kept; presumably honoured by
    /// the transfer implementation — confirm.
    pub preserve_pitch_contour: bool,
}
impl Default for StyleConfig {
    /// Baseline configuration: 64 / 128 / 32 prosody / timbre / rhythm
    /// dimensions, a 16000 sample rate, a 10 ms frame shift, full style
    /// strength and no pitch-contour preservation.
    fn default() -> Self {
        // Audio framing parameters.
        let sample_rate = 16_000;
        let frame_shift_ms = 10;
        Self {
            sample_rate,
            frame_shift_ms,
            prosody_dim: 64,
            timbre_dim: 128,
            rhythm_dim: 32,
            preserve_pitch_contour: false,
            style_strength: 1.0,
        }
    }
}
impl StyleConfig {
    /// Preset that applies the style at half strength while keeping the
    /// source pitch contour. All other fields come from [`StyleConfig::default`].
    #[must_use]
    pub fn prosody_only() -> Self {
        Self {
            style_strength: 0.5,
            preserve_pitch_contour: true,
            ..Self::default()
        }
    }

    /// Preset that applies the style at full strength without preserving the
    /// pitch contour. All other fields come from [`StyleConfig::default`]; the
    /// two fields are spelled out so the preset stays fixed even if the
    /// defaults change.
    #[must_use]
    pub fn full_conversion() -> Self {
        Self {
            style_strength: 1.0,
            preserve_pitch_contour: false,
            ..Self::default()
        }
    }

    /// Checks that the configuration is usable.
    ///
    /// # Errors
    ///
    /// Returns [`VoiceError::InvalidConfig`] when any component dimension,
    /// `sample_rate` or `frame_shift_ms` is zero, or when `style_strength` is
    /// outside `[0.0, 1.0]` (NaN included). The first failing check, in the
    /// order listed in the code, determines the message.
    pub fn validate(&self) -> VoiceResult<()> {
        // Every size / rate field must be non-zero. A zero frame shift would
        // mean frames never advance, so it is rejected like the other fields.
        let zero_checks = [
            (self.prosody_dim == 0, "prosody_dim must be > 0"),
            (self.timbre_dim == 0, "timbre_dim must be > 0"),
            (self.rhythm_dim == 0, "rhythm_dim must be > 0"),
            (self.sample_rate == 0, "sample_rate must be > 0"),
            (self.frame_shift_ms == 0, "frame_shift_ms must be > 0"),
        ];
        if let Some(&(_, message)) = zero_checks.iter().find(|(is_zero, _)| *is_zero) {
            return Err(VoiceError::InvalidConfig(message.to_string()));
        }
        // `RangeInclusive::contains` is false for NaN, so NaN is rejected too.
        if !(0.0..=1.0).contains(&self.style_strength) {
            return Err(VoiceError::InvalidConfig(
                "style_strength must be in [0.0, 1.0]".to_string(),
            ));
        }
        Ok(())
    }

    /// Length of a flattened style vector: the sum of the prosody, timbre and
    /// rhythm dimensions.
    #[must_use]
    pub fn total_dim(&self) -> usize {
        self.prosody_dim + self.timbre_dim + self.rhythm_dim
    }
}
/// A style embedding stored as three separate components.
///
/// Fields are private; read them through the accessor methods. When
/// flattened (`to_flat` / `from_flat`) the layout is
/// `[prosody | timbre | rhythm]`.
#[derive(Debug, Clone)]
pub struct StyleVector {
    /// Prosody component.
    prosody: Vec<f32>,
    /// Timbre component.
    timbre: Vec<f32>,
    /// Rhythm component.
    rhythm: Vec<f32>,
}
impl StyleVector {
    /// Assembles a style vector from already-computed components.
    ///
    /// The components are stored as given; their lengths are not checked
    /// against any [`StyleConfig`].
    #[must_use]
    pub fn new(prosody: Vec<f32>, timbre: Vec<f32>, rhythm: Vec<f32>) -> Self {
        Self { prosody, timbre, rhythm }
    }

    /// Returns an all-zero style vector whose component lengths are taken from
    /// `config.prosody_dim`, `config.timbre_dim` and `config.rhythm_dim`.
    #[must_use]
    pub fn zeros(config: &StyleConfig) -> Self {
        let zeroed = |len: usize| vec![0.0_f32; len];
        Self::new(
            zeroed(config.prosody_dim),
            zeroed(config.timbre_dim),
            zeroed(config.rhythm_dim),
        )
    }

    /// Splits a flat `[prosody | timbre | rhythm]` slice into its components,
    /// copying each part.
    ///
    /// # Errors
    ///
    /// Returns [`VoiceError::DimensionMismatch`] with
    /// `expected = config.total_dim()` when `vector` has any other length.
    pub fn from_flat(vector: &[f32], config: &StyleConfig) -> VoiceResult<Self> {
        let wanted = config.total_dim();
        if vector.len() != wanted {
            return Err(VoiceError::DimensionMismatch {
                expected: wanted,
                got: vector.len(),
            });
        }
        // Both split points are in bounds because the total length matched.
        let (prosody, tail) = vector.split_at(config.prosody_dim);
        let (timbre, rhythm) = tail.split_at(config.timbre_dim);
        Ok(Self::new(prosody.to_vec(), timbre.to_vec(), rhythm.to_vec()))
    }

    /// Borrowed view of the prosody component.
    #[must_use]
    pub fn prosody(&self) -> &[f32] {
        self.prosody.as_slice()
    }

    /// Borrowed view of the timbre component.
    #[must_use]
    pub fn timbre(&self) -> &[f32] {
        self.timbre.as_slice()
    }

    /// Borrowed view of the rhythm component.
    #[must_use]
    pub fn rhythm(&self) -> &[f32] {
        self.rhythm.as_slice()
    }

    /// Total number of elements across all three components.
    #[must_use]
    pub fn dim(&self) -> usize {
        [self.prosody.len(), self.timbre.len(), self.rhythm.len()]
            .iter()
            .sum()
    }

    /// Concatenates the components into a new `[prosody | timbre | rhythm]` vector.
    #[must_use]
    pub fn to_flat(&self) -> Vec<f32> {
        [
            self.prosody.as_slice(),
            self.timbre.as_slice(),
            self.rhythm.as_slice(),
        ]
        .concat()
    }

    /// Component-wise linear blend `self * (1 - t) + other * t`.
    ///
    /// `t` is clamped to `[0.0, 1.0]`. `f32::clamp` returns NaN for a NaN
    /// input, so a NaN `t` propagates into the result.
    ///
    /// # Errors
    ///
    /// Returns [`VoiceError::DimensionMismatch`] for the first component
    /// (checked in prosody, timbre, rhythm order) whose lengths differ, with
    /// `expected` taken from `self` and `got` from `other`.
    pub fn interpolate(&self, other: &Self, t: f32) -> VoiceResult<Self> {
        /// Blends two equally long slices element by element.
        fn blend(lhs: &[f32], rhs: &[f32], keep: f32, weight: f32) -> Vec<f32> {
            lhs.iter()
                .zip(rhs)
                .map(|(a, b)| a * keep + b * weight)
                .collect()
        }

        let component_pairs = [
            (&self.prosody, &other.prosody),
            (&self.timbre, &other.timbre),
            (&self.rhythm, &other.rhythm),
        ];
        for (mine, theirs) in component_pairs.iter() {
            if mine.len() != theirs.len() {
                return Err(VoiceError::DimensionMismatch {
                    expected: mine.len(),
                    got: theirs.len(),
                });
            }
        }

        let weight = t.clamp(0.0, 1.0);
        let keep = 1.0 - weight;
        Ok(Self {
            prosody: blend(&self.prosody, &other.prosody, keep, weight),
            timbre: blend(&self.timbre, &other.timbre, keep, weight),
            rhythm: blend(&self.rhythm, &other.rhythm, keep, weight),
        })
    }

    /// Euclidean norm over all elements of all three components.
    #[must_use]
    pub fn l2_norm(&self) -> f32 {
        let parts = [&self.prosody, &self.timbre, &self.rhythm];
        let energy: f32 = parts
            .iter()
            .flat_map(|part| part.iter())
            .map(|v| v * v)
            .sum();
        energy.sqrt()
    }

    /// Scales the vector in place to unit L2 norm.
    ///
    /// Vectors whose norm is not greater than `f32::EPSILON` (including a NaN
    /// norm) are left unchanged.
    pub fn normalize(&mut self) {
        let norm = self.l2_norm();
        if norm > f32::EPSILON {
            self.prosody
                .iter_mut()
                .chain(self.timbre.iter_mut())
                .chain(self.rhythm.iter_mut())
                .for_each(|value| *value /= norm);
        }
    }
}
/// Extracts a [`StyleVector`] from audio.
pub trait StyleEncoder {
    /// Encodes `audio` into a style vector.
    ///
    /// `audio` is presumably mono samples at `config().sample_rate` — TODO confirm.
    ///
    /// # Errors
    ///
    /// Implementation-defined. For example, `GstEncoder` in this file rejects
    /// empty input and otherwise reports that it is not implemented.
    fn encode(&self, audio: &[f32]) -> VoiceResult<StyleVector>;
    /// Returns the configuration this encoder holds.
    fn config(&self) -> &StyleConfig;
}
/// Applies a target speaking style to source audio.
pub trait StyleTransfer {
    /// Re-renders `source_audio` with `target_style` and returns a new sample
    /// buffer (presumably at the same sample rate as the input — confirm).
    ///
    /// # Errors
    ///
    /// Implementation-defined.
    fn transfer(&self, source_audio: &[f32], target_style: &StyleVector) -> VoiceResult<Vec<f32>>;
    /// Like `transfer`, but the target style comes from `reference_audio`
    /// rather than a precomputed [`StyleVector`]. Presumably the reference is
    /// encoded first — confirm in the implementation.
    ///
    /// # Errors
    ///
    /// Implementation-defined.
    fn transfer_from_reference(
        &self,
        source_audio: &[f32],
        reference_audio: &[f32],
    ) -> VoiceResult<Vec<f32>>;
    /// Returns the configuration this transfer holds.
    fn config(&self) -> &StyleConfig;
}
/// Style encoder named after GST (presumably Global Style Tokens — confirm).
///
/// Its `StyleEncoder::encode` currently always returns an error because no
/// model weights are available.
#[derive(Debug)]
pub struct GstEncoder {
    /// Configuration returned by `StyleEncoder::config`.
    config: StyleConfig,
}
impl GstEncoder {
#[must_use]
pub fn new(config: StyleConfig) -> Self {
Self { config }
}
#[must_use]
pub fn default_config() -> Self {
Self::new(StyleConfig::default())
}
}
impl StyleEncoder for GstEncoder {
    /// Always fails, because no GST model weights are loaded.
    ///
    /// Empty input is reported as [`VoiceError::InvalidAudio`]. Any other
    /// input yields [`VoiceError::NotImplemented`].
    fn encode(&self, audio: &[f32]) -> VoiceResult<StyleVector> {
        let failure = if audio.is_empty() {
            VoiceError::InvalidAudio("empty audio".to_string())
        } else {
            VoiceError::NotImplemented("GST encoder requires model weights".to_string())
        };
        Err(failure)
    }

    /// Returns the configuration the encoder was created with.
    fn config(&self) -> &StyleConfig {
        &self.config
    }
}
/// Style-transfer type named after AutoVC (presumably the AutoVC
/// voice-conversion model — confirm).
///
/// Its `StyleTransfer` implementation is not in this section; possibly it
/// lives in the included `compute.rs` — TODO confirm.
#[derive(Debug)]
pub struct AutoVcTransfer {
    /// Configuration used by the transfer.
    config: StyleConfig,
}
impl AutoVcTransfer {
#[must_use]
pub fn new(config: StyleConfig) -> Self {
Self { config }
}
#[must_use]
pub fn default_config() -> Self {
Self::new(StyleConfig::default())
}
}
include!("compute.rs");
include!("style_tests.rs");