1use super::{SpeakerEmbedding, VoiceError, VoiceResult};
39
/// Configuration shared by the style encoders and style transfer models in this module.
///
/// The three `*_dim` fields give the lengths of the prosody, timbre and rhythm
/// segments of a [`StyleVector`]. Use [`StyleConfig::validate`] to check a
/// hand-built configuration before using it.
#[derive(Debug, Clone)]
pub struct StyleConfig {
    /// Number of elements in the prosody segment of a style vector (default 64).
    pub prosody_dim: usize,
    /// Number of elements in the timbre segment of a style vector (default 128).
    pub timbre_dim: usize,
    /// Number of elements in the rhythm segment of a style vector (default 32).
    pub rhythm_dim: usize,
    /// Audio sample rate (default 16000) — presumably in Hz; confirm against callers.
    pub sample_rate: u32,
    /// Frame shift (default 10) — milliseconds, going by the field name.
    pub frame_shift_ms: u32,
    /// Strength of the applied style; `validate` accepts values in `[0.0, 1.0]` (default 1.0).
    pub style_strength: f32,
    /// Flag presumably requesting that the source pitch contour be kept (default `false`).
    /// NOTE(review): no code in this file reads the flag — confirm its effect in the models.
    pub preserve_pitch_contour: bool,
}
62
63impl Default for StyleConfig {
64 fn default() -> Self {
65 Self {
66 prosody_dim: 64,
67 timbre_dim: 128,
68 rhythm_dim: 32,
69 sample_rate: 16000,
70 frame_shift_ms: 10,
71 style_strength: 1.0,
72 preserve_pitch_contour: false,
73 }
74 }
75}
76
77impl StyleConfig {
78 #[must_use]
80 pub fn prosody_only() -> Self {
81 Self {
82 style_strength: 0.5,
83 preserve_pitch_contour: true,
84 ..Self::default()
85 }
86 }
87
88 #[must_use]
90 pub fn full_conversion() -> Self {
91 Self {
92 style_strength: 1.0,
93 preserve_pitch_contour: false,
94 ..Self::default()
95 }
96 }
97
98 pub fn validate(&self) -> VoiceResult<()> {
100 if self.prosody_dim == 0 {
101 return Err(VoiceError::InvalidConfig(
102 "prosody_dim must be > 0".to_string(),
103 ));
104 }
105 if self.timbre_dim == 0 {
106 return Err(VoiceError::InvalidConfig(
107 "timbre_dim must be > 0".to_string(),
108 ));
109 }
110 if self.rhythm_dim == 0 {
111 return Err(VoiceError::InvalidConfig(
112 "rhythm_dim must be > 0".to_string(),
113 ));
114 }
115 if self.sample_rate == 0 {
116 return Err(VoiceError::InvalidConfig(
117 "sample_rate must be > 0".to_string(),
118 ));
119 }
120 if !(0.0..=1.0).contains(&self.style_strength) {
121 return Err(VoiceError::InvalidConfig(
122 "style_strength must be in [0.0, 1.0]".to_string(),
123 ));
124 }
125 Ok(())
126 }
127
128 #[must_use]
130 pub fn total_dim(&self) -> usize {
131 self.prosody_dim + self.timbre_dim + self.rhythm_dim
132 }
133}
134
/// A style representation split into three independent segments.
///
/// The flat layout used by `from_flat` / `to_flat` is prosody, then timbre,
/// then rhythm. The segment lengths are whatever the constructor was given;
/// they are not tied to a [`StyleConfig`] after construction.
#[derive(Debug, Clone)]
pub struct StyleVector {
    /// Prosody segment.
    prosody: Vec<f32>,
    /// Timbre segment.
    timbre: Vec<f32>,
    /// Rhythm segment.
    rhythm: Vec<f32>,
}
152
153impl StyleVector {
154 #[must_use]
156 pub fn new(prosody: Vec<f32>, timbre: Vec<f32>, rhythm: Vec<f32>) -> Self {
157 Self {
158 prosody,
159 timbre,
160 rhythm,
161 }
162 }
163
164 #[must_use]
166 pub fn zeros(config: &StyleConfig) -> Self {
167 Self {
168 prosody: vec![0.0; config.prosody_dim],
169 timbre: vec![0.0; config.timbre_dim],
170 rhythm: vec![0.0; config.rhythm_dim],
171 }
172 }
173
174 pub fn from_flat(vector: &[f32], config: &StyleConfig) -> VoiceResult<Self> {
179 let expected_len = config.total_dim();
180 if vector.len() != expected_len {
181 return Err(VoiceError::DimensionMismatch {
182 expected: expected_len,
183 got: vector.len(),
184 });
185 }
186
187 let prosody_end = config.prosody_dim;
188 let timbre_end = prosody_end + config.timbre_dim;
189
190 Ok(Self {
191 prosody: vector[..prosody_end].to_vec(),
192 timbre: vector[prosody_end..timbre_end].to_vec(),
193 rhythm: vector[timbre_end..].to_vec(),
194 })
195 }
196
197 #[must_use]
199 pub fn prosody(&self) -> &[f32] {
200 &self.prosody
201 }
202
203 #[must_use]
205 pub fn timbre(&self) -> &[f32] {
206 &self.timbre
207 }
208
209 #[must_use]
211 pub fn rhythm(&self) -> &[f32] {
212 &self.rhythm
213 }
214
215 #[must_use]
217 pub fn dim(&self) -> usize {
218 self.prosody.len() + self.timbre.len() + self.rhythm.len()
219 }
220
221 #[must_use]
223 pub fn to_flat(&self) -> Vec<f32> {
224 let mut flat = Vec::with_capacity(self.dim());
225 flat.extend_from_slice(&self.prosody);
226 flat.extend_from_slice(&self.timbre);
227 flat.extend_from_slice(&self.rhythm);
228 flat
229 }
230
231 pub fn interpolate(&self, other: &Self, t: f32) -> VoiceResult<Self> {
240 if self.prosody.len() != other.prosody.len() {
241 return Err(VoiceError::DimensionMismatch {
242 expected: self.prosody.len(),
243 got: other.prosody.len(),
244 });
245 }
246 if self.timbre.len() != other.timbre.len() {
247 return Err(VoiceError::DimensionMismatch {
248 expected: self.timbre.len(),
249 got: other.timbre.len(),
250 });
251 }
252 if self.rhythm.len() != other.rhythm.len() {
253 return Err(VoiceError::DimensionMismatch {
254 expected: self.rhythm.len(),
255 got: other.rhythm.len(),
256 });
257 }
258
259 let t = t.clamp(0.0, 1.0);
260 let one_minus_t = 1.0 - t;
261
262 let prosody = self
263 .prosody
264 .iter()
265 .zip(other.prosody.iter())
266 .map(|(a, b)| a * one_minus_t + b * t)
267 .collect();
268
269 let timbre = self
270 .timbre
271 .iter()
272 .zip(other.timbre.iter())
273 .map(|(a, b)| a * one_minus_t + b * t)
274 .collect();
275
276 let rhythm = self
277 .rhythm
278 .iter()
279 .zip(other.rhythm.iter())
280 .map(|(a, b)| a * one_minus_t + b * t)
281 .collect();
282
283 Ok(Self {
284 prosody,
285 timbre,
286 rhythm,
287 })
288 }
289
290 #[must_use]
292 pub fn l2_norm(&self) -> f32 {
293 let sum_sq: f32 = self
294 .prosody
295 .iter()
296 .chain(self.timbre.iter())
297 .chain(self.rhythm.iter())
298 .map(|x| x * x)
299 .sum();
300 sum_sq.sqrt()
301 }
302
303 pub fn normalize(&mut self) {
305 let norm = self.l2_norm();
306 if norm > f32::EPSILON {
307 for x in &mut self.prosody {
308 *x /= norm;
309 }
310 for x in &mut self.timbre {
311 *x /= norm;
312 }
313 for x in &mut self.rhythm {
314 *x /= norm;
315 }
316 }
317 }
318}
319
/// Interface for components that derive a [`StyleVector`] from audio samples.
pub trait StyleEncoder {
    /// Encodes `audio` into a style vector.
    ///
    /// NOTE(review): the expected sample format (rate, channel layout, value
    /// range) is not visible here — presumably it follows `config()`; confirm.
    ///
    /// # Errors
    ///
    /// Implementations report failures through [`VoiceError`].
    fn encode(&self, audio: &[f32]) -> VoiceResult<StyleVector>;

    /// Returns the configuration this encoder was built with.
    fn config(&self) -> &StyleConfig;
}
338
/// Interface for components that produce new audio from source audio and a target style.
pub trait StyleTransfer {
    /// Produces output samples from `source_audio` using an explicit `target_style`.
    ///
    /// # Errors
    ///
    /// Implementations report failures through [`VoiceError`].
    fn transfer(&self, source_audio: &[f32], target_style: &StyleVector) -> VoiceResult<Vec<f32>>;

    /// Produces output samples from `source_audio`, with the target style given
    /// as `reference_audio` instead of a precomputed vector.
    ///
    /// # Errors
    ///
    /// Implementations report failures through [`VoiceError`].
    fn transfer_from_reference(
        &self,
        source_audio: &[f32],
        reference_audio: &[f32],
    ) -> VoiceResult<Vec<f32>>;

    /// Returns the configuration this transfer model was built with.
    fn config(&self) -> &StyleConfig;
}
373
/// Style encoder placeholder ("GST" presumably stands for Global Style Tokens — confirm).
///
/// It only stores its configuration; `encode` in this file always returns an
/// error because no model weights are available.
#[derive(Debug)]
pub struct GstEncoder {
    /// Configuration handed to the constructor.
    config: StyleConfig,
}
388
389impl GstEncoder {
390 #[must_use]
392 pub fn new(config: StyleConfig) -> Self {
393 Self { config }
394 }
395
396 #[must_use]
398 pub fn default_config() -> Self {
399 Self::new(StyleConfig::default())
400 }
401}
402
403impl StyleEncoder for GstEncoder {
404 fn encode(&self, audio: &[f32]) -> VoiceResult<StyleVector> {
405 if audio.is_empty() {
406 return Err(VoiceError::InvalidAudio("empty audio".to_string()));
407 }
408 Err(VoiceError::NotImplemented(
409 "GST encoder requires model weights".to_string(),
410 ))
411 }
412
413 fn config(&self) -> &StyleConfig {
414 &self.config
415 }
416}
417
/// Style transfer placeholder named after AutoVC.
///
/// It only stores its configuration; both transfer methods in this file always
/// return an error because no model weights are available.
#[derive(Debug)]
pub struct AutoVcTransfer {
    /// Configuration handed to the constructor.
    config: StyleConfig,
}
428
429impl AutoVcTransfer {
430 #[must_use]
432 pub fn new(config: StyleConfig) -> Self {
433 Self { config }
434 }
435
436 #[must_use]
438 pub fn default_config() -> Self {
439 Self::new(StyleConfig::default())
440 }
441}
442
443impl StyleTransfer for AutoVcTransfer {
444 fn transfer(&self, source_audio: &[f32], _target_style: &StyleVector) -> VoiceResult<Vec<f32>> {
445 if source_audio.is_empty() {
446 return Err(VoiceError::InvalidAudio("empty source audio".to_string()));
447 }
448 Err(VoiceError::NotImplemented(
449 "AutoVC requires model weights".to_string(),
450 ))
451 }
452
453 fn transfer_from_reference(
454 &self,
455 source_audio: &[f32],
456 reference_audio: &[f32],
457 ) -> VoiceResult<Vec<f32>> {
458 if source_audio.is_empty() {
459 return Err(VoiceError::InvalidAudio("empty source audio".to_string()));
460 }
461 if reference_audio.is_empty() {
462 return Err(VoiceError::InvalidAudio(
463 "empty reference audio".to_string(),
464 ));
465 }
466 Err(VoiceError::NotImplemented(
467 "AutoVC requires model weights".to_string(),
468 ))
469 }
470
471 fn config(&self) -> &StyleConfig {
472 &self.config
473 }
474}
475
476#[must_use]
484pub fn prosody_distance(a: &StyleVector, b: &StyleVector) -> f32 {
485 if a.prosody.len() != b.prosody.len() {
486 return f32::MAX;
487 }
488 if a.rhythm.len() != b.rhythm.len() {
489 return f32::MAX;
490 }
491
492 let prosody_dist: f32 = a
493 .prosody
494 .iter()
495 .zip(b.prosody.iter())
496 .map(|(x, y)| (x - y).powi(2))
497 .sum();
498
499 let rhythm_dist: f32 = a
500 .rhythm
501 .iter()
502 .zip(b.rhythm.iter())
503 .map(|(x, y)| (x - y).powi(2))
504 .sum();
505
506 (prosody_dist + rhythm_dist).sqrt()
507}
508
509#[must_use]
513pub fn timbre_distance(a: &StyleVector, b: &StyleVector) -> f32 {
514 if a.timbre.len() != b.timbre.len() {
515 return f32::MAX;
516 }
517
518 let dist: f32 = a
519 .timbre
520 .iter()
521 .zip(b.timbre.iter())
522 .map(|(x, y)| (x - y).powi(2))
523 .sum();
524
525 dist.sqrt()
526}
527
528#[must_use]
530pub fn style_distance(a: &StyleVector, b: &StyleVector) -> f32 {
531 if a.dim() != b.dim() {
532 return f32::MAX;
533 }
534
535 let flat_a = a.to_flat();
536 let flat_b = b.to_flat();
537
538 let dist: f32 = flat_a
539 .iter()
540 .zip(flat_b.iter())
541 .map(|(x, y)| (x - y).powi(2))
542 .sum();
543
544 dist.sqrt()
545}
546
547#[must_use]
552pub fn style_from_embedding(embedding: &SpeakerEmbedding, config: &StyleConfig) -> StyleVector {
553 let emb_slice = embedding.as_slice();
554 let emb_len = emb_slice.len();
555
556 let prosody_len = config.prosody_dim.min(emb_len);
558 let timbre_len = config.timbre_dim.min(emb_len.saturating_sub(prosody_len));
559 let rhythm_len = config
560 .rhythm_dim
561 .min(emb_len.saturating_sub(prosody_len + timbre_len));
562
563 let mut prosody = vec![0.0_f32; config.prosody_dim];
564 let mut timbre = vec![0.0_f32; config.timbre_dim];
565 let mut rhythm = vec![0.0_f32; config.rhythm_dim];
566
567 prosody[..prosody_len].copy_from_slice(&emb_slice[..prosody_len]);
569
570 if timbre_len > 0 {
571 timbre[..timbre_len].copy_from_slice(&emb_slice[prosody_len..prosody_len + timbre_len]);
572 }
573
574 if rhythm_len > 0 {
575 rhythm[..rhythm_len].copy_from_slice(
576 &emb_slice[prosody_len + timbre_len..prosody_len + timbre_len + rhythm_len],
577 );
578 }
579
580 StyleVector::new(prosody, timbre, rhythm)
581}
582
583pub fn average_styles(styles: &[StyleVector]) -> VoiceResult<StyleVector> {
588 if styles.is_empty() {
589 return Err(VoiceError::InvalidConfig(
590 "cannot average empty style list".to_string(),
591 ));
592 }
593
594 let first = &styles[0];
595 let prosody_len = first.prosody.len();
596 let timbre_len = first.timbre.len();
597 let rhythm_len = first.rhythm.len();
598
599 for style in styles.iter().skip(1) {
601 if style.prosody.len() != prosody_len {
602 return Err(VoiceError::DimensionMismatch {
603 expected: prosody_len,
604 got: style.prosody.len(),
605 });
606 }
607 if style.timbre.len() != timbre_len {
608 return Err(VoiceError::DimensionMismatch {
609 expected: timbre_len,
610 got: style.timbre.len(),
611 });
612 }
613 if style.rhythm.len() != rhythm_len {
614 return Err(VoiceError::DimensionMismatch {
615 expected: rhythm_len,
616 got: style.rhythm.len(),
617 });
618 }
619 }
620
621 let count = styles.len() as f32;
622
623 let mut prosody = vec![0.0_f32; prosody_len];
624 let mut timbre = vec![0.0_f32; timbre_len];
625 let mut rhythm = vec![0.0_f32; rhythm_len];
626
627 for style in styles {
628 for (i, &v) in style.prosody.iter().enumerate() {
629 prosody[i] += v / count;
630 }
631 for (i, &v) in style.timbre.iter().enumerate() {
632 timbre[i] += v / count;
633 }
634 for (i, &v) in style.rhythm.iter().enumerate() {
635 rhythm[i] += v / count;
636 }
637 }
638
639 Ok(StyleVector::new(prosody, timbre, rhythm))
640}
641
// Unit tests for the style configuration, style vectors, the stub
// encoder/transfer implementations and the free helper functions.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- StyleConfig: defaults, presets and validation ----

    #[test]
    fn test_style_config_default() {
        let config = StyleConfig::default();
        assert_eq!(config.prosody_dim, 64);
        assert_eq!(config.timbre_dim, 128);
        assert_eq!(config.rhythm_dim, 32);
        assert_eq!(config.total_dim(), 224);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_style_config_prosody_only() {
        let config = StyleConfig::prosody_only();
        assert!(config.preserve_pitch_contour);
        assert!((config.style_strength - 0.5).abs() < f32::EPSILON);
    }

    #[test]
    fn test_style_config_full_conversion() {
        let config = StyleConfig::full_conversion();
        assert!(!config.preserve_pitch_contour);
        assert!((config.style_strength - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_style_config_validation_prosody() {
        let mut config = StyleConfig::default();
        config.prosody_dim = 0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_style_config_validation_strength() {
        // Both sides of the accepted [0.0, 1.0] range must be rejected.
        let mut config = StyleConfig::default();
        config.style_strength = 1.5;
        assert!(config.validate().is_err());

        config.style_strength = -0.1;
        assert!(config.validate().is_err());
    }

    // ---- StyleVector: construction, flattening, interpolation, norms ----

    #[test]
    fn test_style_vector_new() {
        let style = StyleVector::new(vec![1.0, 2.0], vec![3.0, 4.0, 5.0], vec![6.0]);
        assert_eq!(style.prosody().len(), 2);
        assert_eq!(style.timbre().len(), 3);
        assert_eq!(style.rhythm().len(), 1);
        assert_eq!(style.dim(), 6);
    }

    #[test]
    fn test_style_vector_zeros() {
        let config = StyleConfig::default();
        let style = StyleVector::zeros(&config);
        assert_eq!(style.prosody().len(), config.prosody_dim);
        assert_eq!(style.timbre().len(), config.timbre_dim);
        assert_eq!(style.rhythm().len(), config.rhythm_dim);
        assert!((style.l2_norm()).abs() < f32::EPSILON);
    }

    #[test]
    fn test_style_vector_from_flat() {
        let config = StyleConfig {
            prosody_dim: 2,
            timbre_dim: 3,
            rhythm_dim: 1,
            ..StyleConfig::default()
        };
        let flat = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let style = StyleVector::from_flat(&flat, &config).expect("from_flat failed");

        assert_eq!(style.prosody(), &[1.0, 2.0]);
        assert_eq!(style.timbre(), &[3.0, 4.0, 5.0]);
        assert_eq!(style.rhythm(), &[6.0]);
        assert_eq!(style.to_flat(), flat.to_vec());
    }

    #[test]
    fn test_style_vector_from_flat_wrong_size() {
        let config = StyleConfig::default();
        // Three elements cannot match the default total of 224.
        let flat = [1.0, 2.0, 3.0];
        assert!(StyleVector::from_flat(&flat, &config).is_err());
    }

    #[test]
    fn test_style_vector_interpolate() {
        let style_a = StyleVector::new(vec![0.0, 0.0], vec![0.0, 0.0, 0.0], vec![0.0]);
        let style_b = StyleVector::new(vec![1.0, 1.0], vec![1.0, 1.0, 1.0], vec![1.0]);

        let mid = style_a
            .interpolate(&style_b, 0.5)
            .expect("interpolate failed");
        assert!((mid.prosody()[0] - 0.5).abs() < 1e-6);
        assert!((mid.timbre()[0] - 0.5).abs() < 1e-6);
        assert!((mid.rhythm()[0] - 0.5).abs() < 1e-6);

        // t = 0 reproduces `style_a`.
        let start = style_a
            .interpolate(&style_b, 0.0)
            .expect("interpolate 0 failed");
        assert!((start.prosody()[0] - 0.0).abs() < 1e-6);

        // t = 1 reproduces `style_b`.
        let end = style_a
            .interpolate(&style_b, 1.0)
            .expect("interpolate 1 failed");
        assert!((end.prosody()[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_style_vector_interpolate_dimension_mismatch() {
        let style_a = StyleVector::new(vec![0.0], vec![0.0], vec![0.0]);
        let style_b = StyleVector::new(vec![1.0, 1.0], vec![1.0], vec![1.0]);
        assert!(style_a.interpolate(&style_b, 0.5).is_err());
    }

    #[test]
    fn test_style_vector_l2_norm() {
        // 3-4-5 triangle spread across two segments.
        let style = StyleVector::new(vec![3.0], vec![4.0], vec![0.0]);
        assert!((style.l2_norm() - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_style_vector_normalize() {
        let mut style = StyleVector::new(vec![3.0], vec![4.0], vec![0.0]);
        style.normalize();
        assert!((style.l2_norm() - 1.0).abs() < 1e-6);
        assert!((style.prosody()[0] - 0.6).abs() < 1e-6);
        assert!((style.timbre()[0] - 0.8).abs() < 1e-6);
    }

    // ---- Distance helpers ----

    #[test]
    fn test_prosody_distance() {
        let style_a = StyleVector::new(vec![0.0, 0.0], vec![0.0], vec![0.0]);
        let style_b = StyleVector::new(vec![3.0, 4.0], vec![0.0], vec![0.0]);
        let dist = prosody_distance(&style_a, &style_b);
        assert!((dist - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_prosody_distance_mismatch() {
        let style_a = StyleVector::new(vec![0.0], vec![0.0], vec![0.0]);
        let style_b = StyleVector::new(vec![0.0, 0.0], vec![0.0], vec![0.0]);
        assert_eq!(prosody_distance(&style_a, &style_b), f32::MAX);
    }

    #[test]
    fn test_timbre_distance() {
        let style_a = StyleVector::new(vec![0.0], vec![0.0, 0.0], vec![0.0]);
        let style_b = StyleVector::new(vec![0.0], vec![3.0, 4.0], vec![0.0]);
        let dist = timbre_distance(&style_a, &style_b);
        assert!((dist - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_style_distance() {
        let style_a = StyleVector::new(vec![0.0], vec![0.0], vec![0.0]);
        let style_b = StyleVector::new(vec![3.0], vec![4.0], vec![0.0]);
        let dist = style_distance(&style_a, &style_b);
        assert!((dist - 5.0).abs() < 1e-6);
    }

    // ---- style_from_embedding ----

    #[test]
    fn test_style_from_embedding() {
        let config = StyleConfig {
            prosody_dim: 64,
            timbre_dim: 64,
            rhythm_dim: 64,
            ..StyleConfig::default()
        };
        let embedding = SpeakerEmbedding::from_vec(vec![1.0; 192]);
        let style = style_from_embedding(&embedding, &config);

        assert_eq!(style.prosody().len(), 64);
        assert_eq!(style.timbre().len(), 64);
        assert_eq!(style.rhythm().len(), 64);
        assert!((style.prosody()[0] - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_style_from_embedding_small() {
        // Default config needs 224 values; the embedding only supplies 100.
        let config = StyleConfig::default();
        let embedding = SpeakerEmbedding::from_vec(vec![1.0; 100]);
        let style = style_from_embedding(&embedding, &config);

        // The result still has the configured size.
        assert_eq!(style.dim(), config.total_dim());
    }

    // ---- average_styles ----

    #[test]
    fn test_average_styles() {
        let style_a = StyleVector::new(vec![0.0, 0.0], vec![0.0], vec![0.0]);
        let style_b = StyleVector::new(vec![1.0, 1.0], vec![1.0], vec![1.0]);
        let styles = vec![style_a, style_b];

        let avg = average_styles(&styles).expect("average_styles failed");
        assert!((avg.prosody()[0] - 0.5).abs() < 1e-6);
        assert!((avg.timbre()[0] - 0.5).abs() < 1e-6);
        assert!((avg.rhythm()[0] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_average_styles_empty() {
        let styles: Vec<StyleVector> = vec![];
        assert!(average_styles(&styles).is_err());
    }

    #[test]
    fn test_average_styles_dimension_mismatch() {
        let style_a = StyleVector::new(vec![0.0], vec![0.0], vec![0.0]);
        let style_b = StyleVector::new(vec![1.0, 1.0], vec![1.0], vec![1.0]);
        let styles = vec![style_a, style_b];
        assert!(average_styles(&styles).is_err());
    }

    // ---- Stub encoder / transfer implementations ----

    #[test]
    fn test_gst_encoder_stub() {
        let encoder = GstEncoder::default_config();
        let audio = vec![0.0_f32; 16000];
        assert!(encoder.encode(&audio).is_err());
    }

    #[test]
    fn test_gst_encoder_empty_audio() {
        let encoder = GstEncoder::default_config();
        let result = encoder.encode(&[]);
        assert!(matches!(result, Err(VoiceError::InvalidAudio(_))));
    }

    #[test]
    fn test_autovc_transfer_stub() {
        let transfer = AutoVcTransfer::default_config();
        let source = vec![0.0_f32; 16000];
        let style = StyleVector::zeros(&StyleConfig::default());
        assert!(transfer.transfer(&source, &style).is_err());
    }

    #[test]
    fn test_autovc_transfer_empty_source() {
        let transfer = AutoVcTransfer::default_config();
        let style = StyleVector::zeros(&StyleConfig::default());
        let result = transfer.transfer(&[], &style);
        assert!(matches!(result, Err(VoiceError::InvalidAudio(_))));
    }

    #[test]
    fn test_autovc_transfer_from_reference() {
        let transfer = AutoVcTransfer::default_config();
        let source = vec![0.0_f32; 16000];
        let reference = vec![0.0_f32; 16000];
        assert!(transfer
            .transfer_from_reference(&source, &reference)
            .is_err());
    }

    #[test]
    fn test_autovc_transfer_from_reference_empty() {
        let transfer = AutoVcTransfer::default_config();
        let source = vec![0.0_f32; 16000];

        // Empty source audio.
        let result = transfer.transfer_from_reference(&[], &source);
        assert!(matches!(result, Err(VoiceError::InvalidAudio(_))));

        // Empty reference audio.
        let result = transfer.transfer_from_reference(&source, &[]);
        assert!(matches!(result, Err(VoiceError::InvalidAudio(_))));
    }

    // ---- Additional coverage: error messages, mismatch details, edge cases ----

    #[test]
    fn test_style_config_validate_timbre_zero() {
        let config = StyleConfig {
            timbre_dim: 0,
            ..StyleConfig::default()
        };
        let err = config.validate().unwrap_err();
        match err {
            VoiceError::InvalidConfig(msg) => assert!(msg.contains("timbre_dim")),
            other => panic!("Expected InvalidConfig, got {other:?}"),
        }
    }

    #[test]
    fn test_style_config_validate_rhythm_zero() {
        let config = StyleConfig {
            rhythm_dim: 0,
            ..StyleConfig::default()
        };
        let err = config.validate().unwrap_err();
        match err {
            VoiceError::InvalidConfig(msg) => assert!(msg.contains("rhythm_dim")),
            other => panic!("Expected InvalidConfig, got {other:?}"),
        }
    }

    #[test]
    fn test_style_config_validate_sample_rate_zero() {
        let config = StyleConfig {
            sample_rate: 0,
            ..StyleConfig::default()
        };
        let err = config.validate().unwrap_err();
        match err {
            VoiceError::InvalidConfig(msg) => assert!(msg.contains("sample_rate")),
            other => panic!("Expected InvalidConfig, got {other:?}"),
        }
    }

    #[test]
    fn test_style_config_debug_clone() {
        let config = StyleConfig::default();
        let cloned = config.clone();
        let debug_str = format!("{config:?}");
        assert!(!debug_str.is_empty());
        assert_eq!(cloned.prosody_dim, config.prosody_dim);
    }

    #[test]
    fn test_style_vector_interpolate_timbre_mismatch() {
        let style_a = StyleVector::new(vec![0.0], vec![0.0], vec![0.0]);
        let style_b = StyleVector::new(vec![0.0], vec![0.0, 0.0], vec![0.0]);
        let result = style_a.interpolate(&style_b, 0.5);
        assert!(result.is_err());
        match result.unwrap_err() {
            VoiceError::DimensionMismatch { expected, got } => {
                assert_eq!(expected, 1);
                assert_eq!(got, 2);
            }
            other => panic!("Expected DimensionMismatch, got {other:?}"),
        }
    }

    #[test]
    fn test_style_vector_interpolate_rhythm_mismatch() {
        let style_a = StyleVector::new(vec![0.0], vec![0.0], vec![0.0]);
        let style_b = StyleVector::new(vec![0.0], vec![0.0], vec![0.0, 0.0]);
        let result = style_a.interpolate(&style_b, 0.5);
        assert!(result.is_err());
        match result.unwrap_err() {
            VoiceError::DimensionMismatch { expected, got } => {
                assert_eq!(expected, 1);
                assert_eq!(got, 2);
            }
            other => panic!("Expected DimensionMismatch, got {other:?}"),
        }
    }

    #[test]
    fn test_style_vector_interpolate_clamp_beyond_range() {
        let style_a = StyleVector::new(vec![0.0], vec![0.0], vec![0.0]);
        let style_b = StyleVector::new(vec![1.0], vec![1.0], vec![1.0]);

        // t below 0 is clamped to 0, giving `style_a`.
        let result = style_a.interpolate(&style_b, -0.5).expect("clamp low");
        assert!((result.prosody()[0] - 0.0).abs() < 1e-6);

        // t above 1 is clamped to 1, giving `style_b`.
        let result = style_a.interpolate(&style_b, 1.5).expect("clamp high");
        assert!((result.prosody()[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_style_vector_normalize_zero_vector() {
        let mut style = StyleVector::new(vec![0.0], vec![0.0], vec![0.0]);
        style.normalize();
        assert_eq!(style.l2_norm(), 0.0);
        // Normalizing a zero vector must leave it unchanged (no division by zero).
        assert!((style.prosody()[0] - 0.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_style_vector_debug_clone() {
        let style = StyleVector::new(vec![1.0, 2.0], vec![3.0], vec![4.0]);
        let cloned = style.clone();
        let debug_str = format!("{style:?}");
        assert!(!debug_str.is_empty());
        assert_eq!(cloned.dim(), style.dim());
    }

    #[test]
    fn test_style_vector_to_flat_directly() {
        let style = StyleVector::new(vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0]);
        let flat = style.to_flat();
        assert_eq!(flat, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_style_distance_dimension_mismatch() {
        let a = StyleVector::new(vec![1.0], vec![2.0], vec![3.0]);
        let b = StyleVector::new(vec![1.0, 2.0], vec![3.0], vec![4.0]);
        let dist = style_distance(&a, &b);
        assert_eq!(dist, f32::MAX);
    }

    #[test]
    fn test_style_distance_same() {
        let a = StyleVector::new(vec![1.0, 2.0], vec![3.0], vec![4.0]);
        let dist = style_distance(&a, &a);
        assert!((dist - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_timbre_distance_mismatch() {
        let a = StyleVector::new(vec![0.0], vec![1.0], vec![0.0]);
        let b = StyleVector::new(vec![0.0], vec![1.0, 2.0], vec![0.0]);
        let dist = timbre_distance(&a, &b);
        assert_eq!(dist, f32::MAX);
    }

    #[test]
    fn test_timbre_distance_same() {
        let a = StyleVector::new(vec![0.0], vec![3.0, 4.0], vec![0.0]);
        let dist = timbre_distance(&a, &a);
        assert!((dist - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_prosody_distance_rhythm_mismatch() {
        let a = StyleVector::new(vec![0.0], vec![0.0], vec![1.0]);
        let b = StyleVector::new(vec![0.0], vec![0.0], vec![1.0, 2.0]);
        let dist = prosody_distance(&a, &b);
        assert_eq!(dist, f32::MAX);
    }

    #[test]
    fn test_average_styles_timbre_mismatch() {
        let styles = vec![
            StyleVector::new(vec![0.0], vec![0.0], vec![0.0]),
            StyleVector::new(vec![0.0], vec![0.0, 0.0], vec![0.0]),
        ];
        let result = average_styles(&styles);
        assert!(result.is_err());
        match result.unwrap_err() {
            VoiceError::DimensionMismatch { expected, got } => {
                assert_eq!(expected, 1);
                assert_eq!(got, 2);
            }
            other => panic!("Expected DimensionMismatch, got {other:?}"),
        }
    }

    #[test]
    fn test_average_styles_rhythm_mismatch() {
        let styles = vec![
            StyleVector::new(vec![0.0], vec![0.0], vec![0.0]),
            StyleVector::new(vec![0.0], vec![0.0], vec![0.0, 0.0]),
        ];
        let result = average_styles(&styles);
        assert!(result.is_err());
        match result.unwrap_err() {
            VoiceError::DimensionMismatch { expected, got } => {
                assert_eq!(expected, 1);
                assert_eq!(got, 2);
            }
            other => panic!("Expected DimensionMismatch, got {other:?}"),
        }
    }

    #[test]
    fn test_average_styles_single() {
        // The average of a single vector is that vector.
        let styles = vec![StyleVector::new(vec![3.0, 4.0], vec![5.0], vec![6.0])];
        let avg = average_styles(&styles).expect("single avg");
        assert!((avg.prosody()[0] - 3.0).abs() < 1e-6);
        assert!((avg.timbre()[0] - 5.0).abs() < 1e-6);
        assert!((avg.rhythm()[0] - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_gst_encoder_new_direct() {
        let config = StyleConfig {
            prosody_dim: 32,
            ..StyleConfig::default()
        };
        let encoder = GstEncoder::new(config);
        assert_eq!(encoder.config().prosody_dim, 32);
    }

    #[test]
    fn test_gst_encoder_debug() {
        let encoder = GstEncoder::default_config();
        let debug_str = format!("{encoder:?}");
        assert!(!debug_str.is_empty());
    }

    #[test]
    fn test_gst_encoder_config_accessor() {
        let encoder = GstEncoder::default_config();
        let config = encoder.config();
        assert_eq!(config.prosody_dim, 64);
        assert_eq!(config.timbre_dim, 128);
    }

    #[test]
    fn test_autovc_new_direct() {
        let config = StyleConfig {
            timbre_dim: 64,
            ..StyleConfig::default()
        };
        let transfer = AutoVcTransfer::new(config);
        assert_eq!(transfer.config().timbre_dim, 64);
    }

    #[test]
    fn test_autovc_debug() {
        let transfer = AutoVcTransfer::default_config();
        let debug_str = format!("{transfer:?}");
        assert!(!debug_str.is_empty());
    }

    #[test]
    fn test_autovc_config_accessor() {
        let transfer = AutoVcTransfer::default_config();
        let config = transfer.config();
        assert_eq!(config.prosody_dim, 64);
        assert_eq!(config.sample_rate, 16000);
    }

    #[test]
    fn test_style_from_embedding_zero_length() {
        let config = StyleConfig::default();
        let embedding = SpeakerEmbedding::from_vec(vec![]);
        let style = style_from_embedding(&embedding, &config);
        assert_eq!(style.dim(), config.total_dim());
        // With nothing to copy, every prosody value stays zero.
        for &v in style.prosody() {
            assert!((v - 0.0).abs() < f32::EPSILON);
        }
    }

    #[test]
    fn test_style_from_embedding_exact_match() {
        let config = StyleConfig {
            prosody_dim: 2,
            timbre_dim: 2,
            rhythm_dim: 2,
            ..StyleConfig::default()
        };
        let embedding = SpeakerEmbedding::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let style = style_from_embedding(&embedding, &config);
        assert_eq!(style.prosody(), &[1.0, 2.0]);
        assert_eq!(style.timbre(), &[3.0, 4.0]);
        assert_eq!(style.rhythm(), &[5.0, 6.0]);
    }

    #[test]
    fn test_style_from_embedding_larger_than_needed() {
        let config = StyleConfig {
            prosody_dim: 2,
            timbre_dim: 2,
            rhythm_dim: 2,
            ..StyleConfig::default()
        };
        let embedding = SpeakerEmbedding::from_vec(vec![1.0; 100]);
        let style = style_from_embedding(&embedding, &config);
        assert_eq!(style.prosody().len(), 2);
        assert_eq!(style.timbre().len(), 2);
        assert_eq!(style.rhythm().len(), 2);
    }

    #[test]
    fn test_style_vector_from_flat_round_trip() {
        let config = StyleConfig {
            prosody_dim: 3,
            timbre_dim: 4,
            rhythm_dim: 2,
            ..StyleConfig::default()
        };
        let original = StyleVector::new(
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0, 7.0],
            vec![8.0, 9.0],
        );
        let flat = original.to_flat();
        let reconstructed = StyleVector::from_flat(&flat, &config).expect("from_flat");
        assert_eq!(reconstructed.prosody(), original.prosody());
        assert_eq!(reconstructed.timbre(), original.timbre());
        assert_eq!(reconstructed.rhythm(), original.rhythm());
    }

    #[test]
    fn test_prosody_distance_zero() {
        let a = StyleVector::new(vec![1.0, 2.0], vec![0.0], vec![3.0]);
        let dist = prosody_distance(&a, &a);
        assert!((dist - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_style_config_total_dim() {
        let config = StyleConfig {
            prosody_dim: 10,
            timbre_dim: 20,
            rhythm_dim: 30,
            ..StyleConfig::default()
        };
        assert_eq!(config.total_dim(), 60);
    }

    #[test]
    fn test_style_vector_l2_norm_all_components() {
        let style = StyleVector::new(vec![1.0, 0.0], vec![0.0, 2.0], vec![0.0]);
        // sqrt(1^2 + 2^2) = sqrt(5)
        let expected_norm = (1.0_f32 + 4.0).sqrt();
        assert!((style.l2_norm() - expected_norm).abs() < 1e-6);
    }

    #[test]
    fn test_average_styles_three() {
        let styles = vec![
            StyleVector::new(vec![3.0], vec![6.0], vec![9.0]),
            StyleVector::new(vec![0.0], vec![0.0], vec![0.0]),
            StyleVector::new(vec![0.0], vec![0.0], vec![0.0]),
        ];
        let avg = average_styles(&styles).expect("three avg");
        assert!((avg.prosody()[0] - 1.0).abs() < 1e-6);
        assert!((avg.timbre()[0] - 2.0).abs() < 1e-6);
        assert!((avg.rhythm()[0] - 3.0).abs() < 1e-6);
    }
}