1use super::{SyntheticConfig, SyntheticGenerator};
12use crate::error::Result;
13
/// A sample type that carries a dense `f32` embedding and can be rebuilt
/// around a new one.
///
/// Implementors must also be [`Clone`].
pub trait Embeddable: Clone {
    /// Borrows the embedding vector of this sample.
    fn embedding(&self) -> &[f32];

    /// Builds a new sample that owns `embedding`.
    ///
    /// `reference` is an existing sample the implementor may consult for any
    /// non-embedding state it wants to carry over.
    fn from_embedding(embedding: Vec<f32>, reference: &Self) -> Self;

    /// Number of components in the embedding.
    ///
    /// The provided implementation is the length of [`Embeddable::embedding`].
    fn embedding_dim(&self) -> usize {
        let components = self.embedding();
        components.len()
    }
}
58
/// Tunable parameters for MixUp-style interpolation between embeddings.
#[derive(Clone, Debug)]
pub struct MixUpConfig {
    /// Shape parameter handed to the Beta-like sampler that draws the raw
    /// mixing coefficient.
    pub alpha: f32,
    /// Presumably restricts mixing to pairs drawn from different classes.
    /// NOTE(review): the generator code visible alongside this struct never
    /// reads this flag — confirm the intended semantics.
    pub cross_class_only: bool,
    /// Lower end of the interval the raw coefficient is rescaled into.
    pub lambda_min: f32,
    /// Upper end of the interval the raw coefficient is rescaled into.
    pub lambda_max: f32,
}
75
76impl Default for MixUpConfig {
77 fn default() -> Self {
78 Self {
79 alpha: 0.4,
80 cross_class_only: false,
81 lambda_min: 0.2,
82 lambda_max: 0.8,
83 }
84 }
85}
86
87impl MixUpConfig {
88 #[must_use]
90 pub fn new() -> Self {
91 Self::default()
92 }
93
94 #[must_use]
96 pub fn with_alpha(mut self, alpha: f32) -> Self {
97 self.alpha = alpha.max(0.01);
98 self
99 }
100
101 #[must_use]
103 pub fn with_cross_class_only(mut self, enabled: bool) -> Self {
104 self.cross_class_only = enabled;
105 self
106 }
107
108 #[must_use]
110 pub fn with_lambda_range(mut self, min: f32, max: f32) -> Self {
111 self.lambda_min = min.clamp(0.0, 1.0);
112 self.lambda_max = max.clamp(0.0, 1.0);
113 if self.lambda_min > self.lambda_max {
114 std::mem::swap(&mut self.lambda_min, &mut self.lambda_max);
115 }
116 self
117 }
118}
119
/// State of a small deterministic pseudo-random number generator.
#[derive(Clone, Debug)]
struct SimpleRng {
    /// Current 64-bit generator state; each draw overwrites it.
    state: u64,
}
129
130impl SimpleRng {
131 fn new(seed: u64) -> Self {
132 Self {
133 state: seed.wrapping_add(1),
134 }
135 }
136
137 fn next_u64(&mut self) -> u64 {
138 self.state = self.state.wrapping_mul(6_364_136_223_846_793_005);
140 self.state = self.state.wrapping_add(1_442_695_040_888_963_407);
141 self.state
142 }
143
144 fn next_f32(&mut self) -> f32 {
145 (self.next_u64() >> 40) as f32 / (1u64 << 24) as f32
146 }
147
148 fn next_usize(&mut self, max: usize) -> usize {
149 if max == 0 {
150 return 0;
151 }
152 (self.next_u64() as usize) % max
153 }
154
155 fn beta(&mut self, alpha: f32) -> f32 {
158 let u = self.next_f32().max(0.001);
162 let a = alpha;
163 let b = alpha;
164
165 let x = (1.0 - (1.0 - u).powf(1.0 / b)).powf(1.0 / a);
167 x.clamp(0.0, 1.0)
168 }
169}
170
171#[derive(Debug, Clone)]
207pub struct MixUpGenerator<T: Embeddable> {
208 config: MixUpConfig,
209 _phantom: std::marker::PhantomData<T>,
210}
211
212impl<T: Embeddable> MixUpGenerator<T> {
213 #[must_use]
215 pub fn new() -> Self {
216 Self {
217 config: MixUpConfig::default(),
218 _phantom: std::marker::PhantomData,
219 }
220 }
221
222 #[must_use]
224 pub fn with_config(mut self, config: MixUpConfig) -> Self {
225 self.config = config;
226 self
227 }
228
229 fn interpolate(e1: &[f32], e2: &[f32], lambda: f32) -> Vec<f32> {
231 e1.iter()
232 .zip(e2.iter())
233 .map(|(&a, &b)| lambda * a + (1.0 - lambda) * b)
234 .collect()
235 }
236
237 fn cosine_similarity(e1: &[f32], e2: &[f32]) -> f32 {
239 if e1.len() != e2.len() || e1.is_empty() {
240 return 0.0;
241 }
242
243 let dot: f32 = e1.iter().zip(e2.iter()).map(|(a, b)| a * b).sum();
244 let norm1: f32 = e1.iter().map(|x| x * x).sum::<f32>().sqrt();
245 let norm2: f32 = e2.iter().map(|x| x * x).sum::<f32>().sqrt();
246
247 if norm1 < f32::EPSILON || norm2 < f32::EPSILON {
248 return 0.0;
249 }
250
251 (dot / (norm1 * norm2)).clamp(-1.0, 1.0)
252 }
253
254 fn embedding_variance(embeddings: &[Vec<f32>]) -> f32 {
256 if embeddings.is_empty() || embeddings[0].is_empty() {
257 return 0.0;
258 }
259
260 let dim = embeddings[0].len();
261 let n = embeddings.len() as f32;
262
263 let mut mean = vec![0.0; dim];
265 for emb in embeddings {
266 for (i, &v) in emb.iter().enumerate() {
267 mean[i] += v / n;
268 }
269 }
270
271 let mut var = 0.0;
273 for emb in embeddings {
274 for (i, &v) in emb.iter().enumerate() {
275 var += (v - mean[i]).powi(2);
276 }
277 }
278
279 var / (n * dim as f32)
280 }
281}
282
283impl<T: Embeddable> Default for MixUpGenerator<T> {
284 fn default() -> Self {
285 Self::new()
286 }
287}
288
289impl<T: Embeddable + std::fmt::Debug> SyntheticGenerator for MixUpGenerator<T> {
290 type Input = T;
291 type Output = T;
292
293 fn generate(&self, seeds: &[T], config: &SyntheticConfig) -> Result<Vec<T>> {
294 if seeds.len() < 2 {
295 return Ok(Vec::new());
296 }
297
298 let target = config.target_count(seeds.len());
299 let mut results = Vec::with_capacity(target);
300 let mut rng = SimpleRng::new(config.seed);
301
302 let mut attempts = 0;
303 let max_attempts = target * config.max_attempts;
304
305 while results.len() < target && attempts < max_attempts {
306 attempts += 1;
307
308 let i = rng.next_usize(seeds.len());
310 let mut j = rng.next_usize(seeds.len());
311 if j == i {
312 j = (j + 1) % seeds.len();
313 }
314
315 let raw_lambda = rng.beta(self.config.alpha);
317 let lambda = self.config.lambda_min
318 + raw_lambda * (self.config.lambda_max - self.config.lambda_min);
319
320 let e1 = seeds[i].embedding();
322 let e2 = seeds[j].embedding();
323
324 if e1.len() != e2.len() || e1.is_empty() {
325 continue;
326 }
327
328 let mixed_embedding = Self::interpolate(e1, e2, lambda);
329
330 let mixed = T::from_embedding(mixed_embedding, &seeds[i]);
332
333 let quality = self.quality_score(&mixed, &seeds[i]);
335 if config.meets_quality(quality) {
336 results.push(mixed);
337 }
338 }
339
340 Ok(results)
341 }
342
343 fn quality_score(&self, generated: &T, seed: &T) -> f32 {
344 let sim = Self::cosine_similarity(generated.embedding(), seed.embedding());
346
347 let quality = if sim < 0.3 {
350 sim / 0.3 * 0.5 } else if sim > 0.9 {
352 1.0 - (sim - 0.9) / 0.1 * 0.5 } else {
354 0.5 + (sim - 0.3) / 0.6 * 0.5 };
356
357 quality.clamp(0.0, 1.0)
358 }
359
360 fn diversity_score(&self, batch: &[T]) -> f32 {
361 if batch.is_empty() {
362 return 0.0;
363 }
364
365 let embeddings: Vec<Vec<f32>> = batch.iter().map(|s| s.embedding().to_vec()).collect();
366
367 let variance = Self::embedding_variance(&embeddings);
369
370 (variance * 10.0).min(1.0)
372 }
373}
374
/// Unit tests for the MixUp configuration, the internal RNG, and the
/// generator.
///
/// Locals are named `generator` rather than `gen`, because `gen` is a
/// reserved keyword from the Rust 2024 edition onward.
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal `Embeddable` implementation used as a fixture.
    #[derive(Clone, Debug, PartialEq)]
    struct TestSample {
        embedding: Vec<f32>,
        label: i32,
    }

    impl TestSample {
        fn new(embedding: Vec<f32>, label: i32) -> Self {
            Self { embedding, label }
        }
    }

    impl Embeddable for TestSample {
        fn embedding(&self) -> &[f32] {
            &self.embedding
        }

        fn from_embedding(embedding: Vec<f32>, reference: &Self) -> Self {
            Self {
                embedding,
                // The label is inherited from the reference sample.
                label: reference.label,
            }
        }
    }

    // ---- MixUpConfig ----

    #[test]
    fn test_config_default() {
        let config = MixUpConfig::default();
        assert!((config.alpha - 0.4).abs() < f32::EPSILON);
        assert!(!config.cross_class_only);
        assert!((config.lambda_min - 0.2).abs() < f32::EPSILON);
        assert!((config.lambda_max - 0.8).abs() < f32::EPSILON);
    }

    #[test]
    fn test_config_with_alpha() {
        let config = MixUpConfig::new().with_alpha(1.0);
        assert!((config.alpha - 1.0).abs() < f32::EPSILON);

        // Values below the floor are raised to 0.01.
        let config = MixUpConfig::new().with_alpha(-1.0);
        assert!((config.alpha - 0.01).abs() < f32::EPSILON);
    }

    #[test]
    fn test_config_with_cross_class() {
        let config = MixUpConfig::new().with_cross_class_only(true);
        assert!(config.cross_class_only);
    }

    #[test]
    fn test_config_with_lambda_range() {
        let config = MixUpConfig::new().with_lambda_range(0.3, 0.7);
        assert!((config.lambda_min - 0.3).abs() < f32::EPSILON);
        assert!((config.lambda_max - 0.7).abs() < f32::EPSILON);

        // Reversed bounds are swapped.
        let config = MixUpConfig::new().with_lambda_range(0.8, 0.2);
        assert!((config.lambda_min - 0.2).abs() < f32::EPSILON);
        assert!((config.lambda_max - 0.8).abs() < f32::EPSILON);

        // Out-of-range bounds are clamped to [0, 1].
        let config = MixUpConfig::new().with_lambda_range(-0.5, 1.5);
        assert!((config.lambda_min - 0.0).abs() < f32::EPSILON);
        assert!((config.lambda_max - 1.0).abs() < f32::EPSILON);
    }

    // ---- SimpleRng ----

    #[test]
    fn test_rng_deterministic() {
        let mut rng1 = SimpleRng::new(42);
        let mut rng2 = SimpleRng::new(42);

        for _ in 0..10 {
            assert_eq!(rng1.next_u64(), rng2.next_u64());
        }
    }

    #[test]
    fn test_rng_different_seeds() {
        let mut rng1 = SimpleRng::new(42);
        let mut rng2 = SimpleRng::new(43);

        assert_ne!(rng1.next_u64(), rng2.next_u64());
    }

    #[test]
    fn test_rng_f32_range() {
        let mut rng = SimpleRng::new(12345);

        for _ in 0..100 {
            let v = rng.next_f32();
            assert!((0.0..=1.0).contains(&v));
        }
    }

    #[test]
    fn test_rng_usize_range() {
        let mut rng = SimpleRng::new(12345);

        for _ in 0..100 {
            let v = rng.next_usize(10);
            assert!(v < 10);
        }

        // A zero bound yields zero.
        assert_eq!(rng.next_usize(0), 0);
    }

    #[test]
    fn test_rng_beta_range() {
        let mut rng = SimpleRng::new(12345);

        for alpha in [0.1, 0.5, 1.0, 2.0] {
            for _ in 0..50 {
                let v = rng.beta(alpha);
                assert!((0.0..=1.0).contains(&v));
            }
        }
    }

    // ---- MixUpGenerator construction and helpers ----

    #[test]
    fn test_generator_new() {
        let generator = MixUpGenerator::<TestSample>::new();
        assert!((generator.config.alpha - 0.4).abs() < f32::EPSILON);
    }

    #[test]
    fn test_generator_with_config() {
        let config = MixUpConfig::new().with_alpha(1.0);
        let generator = MixUpGenerator::<TestSample>::new().with_config(config);
        assert!((generator.config.alpha - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_generator_default() {
        let generator = MixUpGenerator::<TestSample>::default();
        assert!((generator.config.alpha - 0.4).abs() < f32::EPSILON);
    }

    #[test]
    fn test_interpolate() {
        let e1 = vec![1.0, 0.0, 0.0];
        let e2 = vec![0.0, 1.0, 0.0];

        // Equal weights.
        let result = MixUpGenerator::<TestSample>::interpolate(&e1, &e2, 0.5);
        assert!((result[0] - 0.5).abs() < f32::EPSILON);
        assert!((result[1] - 0.5).abs() < f32::EPSILON);
        assert!((result[2] - 0.0).abs() < f32::EPSILON);

        // lambda = 1 reproduces the first input.
        let result = MixUpGenerator::<TestSample>::interpolate(&e1, &e2, 1.0);
        assert!((result[0] - 1.0).abs() < f32::EPSILON);
        assert!((result[1] - 0.0).abs() < f32::EPSILON);

        // lambda = 0 reproduces the second input.
        let result = MixUpGenerator::<TestSample>::interpolate(&e1, &e2, 0.0);
        assert!((result[0] - 0.0).abs() < f32::EPSILON);
        assert!((result[1] - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_cosine_similarity() {
        // Identical vectors.
        let e1 = vec![1.0, 0.0, 0.0];
        let sim = MixUpGenerator::<TestSample>::cosine_similarity(&e1, &e1);
        assert!((sim - 1.0).abs() < 0.001);

        // Orthogonal vectors.
        let e2 = vec![0.0, 1.0, 0.0];
        let sim = MixUpGenerator::<TestSample>::cosine_similarity(&e1, &e2);
        assert!(sim.abs() < 0.001);

        // Opposite vectors.
        let e3 = vec![-1.0, 0.0, 0.0];
        let sim = MixUpGenerator::<TestSample>::cosine_similarity(&e1, &e3);
        assert!((sim - (-1.0)).abs() < 0.001);

        // Empty input.
        let empty: Vec<f32> = vec![];
        let sim = MixUpGenerator::<TestSample>::cosine_similarity(&empty, &empty);
        assert!((sim - 0.0).abs() < f32::EPSILON);

        // Mismatched lengths.
        let e4 = vec![1.0, 0.0];
        let sim = MixUpGenerator::<TestSample>::cosine_similarity(&e1, &e4);
        assert!((sim - 0.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_embedding_variance() {
        // A single embedding has no spread.
        let embeddings = vec![vec![1.0, 0.0]];
        let var = MixUpGenerator::<TestSample>::embedding_variance(&embeddings);
        assert!((var - 0.0).abs() < f32::EPSILON);

        // Identical embeddings have no spread.
        let embeddings = vec![vec![1.0, 0.0], vec![1.0, 0.0]];
        let var = MixUpGenerator::<TestSample>::embedding_variance(&embeddings);
        assert!((var - 0.0).abs() < f32::EPSILON);

        // Different embeddings have positive spread.
        let embeddings = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let var = MixUpGenerator::<TestSample>::embedding_variance(&embeddings);
        assert!(var > 0.0);

        // Empty input.
        let var = MixUpGenerator::<TestSample>::embedding_variance(&[]);
        assert!((var - 0.0).abs() < f32::EPSILON);
    }

    // ---- SyntheticGenerator implementation ----

    #[test]
    fn test_generate_basic() {
        let generator = MixUpGenerator::<TestSample>::new();
        let seeds = vec![
            TestSample::new(vec![1.0, 0.0, 0.0], 0),
            TestSample::new(vec![0.0, 1.0, 0.0], 1),
            TestSample::new(vec![0.0, 0.0, 1.0], 2),
        ];

        let config = SyntheticConfig::default()
            .with_augmentation_ratio(1.0)
            .with_quality_threshold(0.1);

        let result = generator
            .generate(&seeds, &config)
            .expect("generation failed");

        assert!(!result.is_empty());

        // Every generated sample keeps the seed dimensionality.
        for sample in &result {
            assert_eq!(sample.embedding().len(), 3);
        }
    }

    #[test]
    fn test_generate_insufficient_seeds() {
        let generator = MixUpGenerator::<TestSample>::new();

        // No seeds at all.
        let result = generator
            .generate(&[], &SyntheticConfig::default())
            .expect("should succeed");
        assert!(result.is_empty());

        // A single seed cannot form a pair.
        let seeds = vec![TestSample::new(vec![1.0, 0.0], 0)];
        let result = generator
            .generate(&seeds, &SyntheticConfig::default())
            .expect("should succeed");
        assert!(result.is_empty());
    }

    #[test]
    fn test_generate_respects_target() {
        let generator = MixUpGenerator::<TestSample>::new();
        let seeds = vec![
            TestSample::new(vec![1.0, 0.0], 0),
            TestSample::new(vec![0.0, 1.0], 1),
        ];

        let config = SyntheticConfig::default()
            .with_augmentation_ratio(2.0)
            .with_quality_threshold(0.0);
        let result = generator
            .generate(&seeds, &config)
            .expect("generation failed");

        assert!(result.len() <= 4);
    }

    #[test]
    fn test_generate_deterministic() {
        let generator = MixUpGenerator::<TestSample>::new();
        let seeds = vec![
            TestSample::new(vec![1.0, 0.0, 0.0], 0),
            TestSample::new(vec![0.0, 1.0, 0.0], 1),
        ];

        let config = SyntheticConfig::default()
            .with_augmentation_ratio(1.0)
            .with_quality_threshold(0.1)
            .with_seed(12345);

        let result1 = generator
            .generate(&seeds, &config)
            .expect("generation failed");
        let result2 = generator
            .generate(&seeds, &config)
            .expect("generation failed");

        assert_eq!(result1.len(), result2.len());
        for (r1, r2) in result1.iter().zip(result2.iter()) {
            assert_eq!(r1.embedding(), r2.embedding());
        }
    }

    #[test]
    fn test_quality_score() {
        let generator = MixUpGenerator::<TestSample>::new();

        let seed = TestSample::new(vec![1.0, 0.0, 0.0], 0);

        // An exact copy is scored below the maximum.
        let identical = TestSample::new(vec![1.0, 0.0, 0.0], 0);
        let score = generator.quality_score(&identical, &seed);
        assert!(score < 1.0);

        let similar = TestSample::new(vec![0.8, 0.2, 0.0], 0);
        let score = generator.quality_score(&similar, &seed);
        assert!(score > 0.3);

        let different = TestSample::new(vec![0.0, 0.0, 1.0], 0);
        let score = generator.quality_score(&different, &seed);
        assert!((0.0..=1.0).contains(&score));
    }

    #[test]
    fn test_diversity_score() {
        let generator = MixUpGenerator::<TestSample>::new();

        // Empty batch.
        let score = generator.diversity_score(&[]);
        assert!((score - 0.0).abs() < f32::EPSILON);

        // Single-sample batch.
        let single = vec![TestSample::new(vec![1.0, 0.0], 0)];
        let score = generator.diversity_score(&single);
        assert!((score - 0.0).abs() < f32::EPSILON);

        // Spread-out batch.
        let diverse = vec![
            TestSample::new(vec![1.0, 0.0], 0),
            TestSample::new(vec![0.0, 1.0], 1),
            TestSample::new(vec![-1.0, 0.0], 2),
        ];
        let score = generator.diversity_score(&diverse);
        assert!(score > 0.0);

        // A batch of identical samples scores lower than the diverse one.
        let homogeneous = vec![
            TestSample::new(vec![1.0, 0.0], 0),
            TestSample::new(vec![1.0, 0.0], 0),
            TestSample::new(vec![1.0, 0.0], 0),
        ];
        let homo_score = generator.diversity_score(&homogeneous);
        assert!(homo_score < score);
    }

    #[test]
    fn test_generate_mixed_embeddings() {
        let generator = MixUpGenerator::<TestSample>::new();
        let seeds = vec![
            TestSample::new(vec![1.0, 0.0], 0),
            TestSample::new(vec![0.0, 1.0], 1),
        ];

        let config = SyntheticConfig::default()
            .with_augmentation_ratio(2.0)
            .with_quality_threshold(0.0)
            .with_seed(42);

        let result = generator
            .generate(&seeds, &config)
            .expect("generation failed");

        // Mixing unit basis vectors keeps every component within [0, 1].
        for sample in &result {
            let e = sample.embedding();
            assert!((0.0..=1.0).contains(&e[0]));
            assert!((0.0..=1.0).contains(&e[1]));
        }
    }

    #[test]
    fn test_embeddable_trait() {
        let sample = TestSample::new(vec![1.0, 2.0, 3.0], 5);

        assert_eq!(sample.embedding(), &[1.0, 2.0, 3.0]);
        assert_eq!(sample.embedding_dim(), 3);

        let new_emb = vec![4.0, 5.0, 6.0];
        let new_sample = TestSample::from_embedding(new_emb.clone(), &sample);
        assert_eq!(new_sample.embedding(), &[4.0, 5.0, 6.0]);
        // The label is carried over from the reference sample.
        assert_eq!(new_sample.label, 5);
    }

    // ---- End-to-end ----

    #[test]
    fn test_full_mixup_pipeline() {
        let generator = MixUpGenerator::new().with_config(
            MixUpConfig::new()
                .with_alpha(0.5)
                .with_lambda_range(0.3, 0.7),
        );

        let seeds = vec![
            TestSample::new(vec![1.0, 0.0, 0.0, 0.0], 0),
            TestSample::new(vec![0.0, 1.0, 0.0, 0.0], 1),
            TestSample::new(vec![0.0, 0.0, 1.0, 0.0], 2),
            TestSample::new(vec![0.0, 0.0, 0.0, 1.0], 3),
        ];

        let config = SyntheticConfig::default()
            .with_augmentation_ratio(2.0)
            .with_quality_threshold(0.2)
            .with_seed(9999);

        let synthetic = generator
            .generate(&seeds, &config)
            .expect("generation failed");

        for sample in &synthetic {
            assert_eq!(sample.embedding_dim(), 4);

            let quality = generator.quality_score(sample, &seeds[0]);
            assert!(quality >= 0.0);
        }

        if !synthetic.is_empty() {
            let diversity = generator.diversity_score(&synthetic);
            assert!((0.0..=1.0).contains(&diversity));
        }
    }
}