Skip to main content

aprender/synthetic/
mixup.rs

//! `MixUp` Data Augmentation.
//!
//! Implements `MixUp` (Zhang et al., 2018) for creating synthetic samples
//! via convex combinations in embedding space.
//!
//! # References
//!
//! Zhang, H., Cisse, M., Dauphin, Y. N., & Lopez-Paz, D. (2018).
//! mixup: Beyond Empirical Risk Minimization. ICLR.
10
11use super::{SyntheticConfig, SyntheticGenerator};
12use crate::error::Result;
13
// ============================================================================
// Embedding Trait
// ============================================================================
17
/// A sample that carries a vector-space representation.
///
/// `MixUp` interpolates between the embeddings of two samples, so any data
/// type that should be augmented this way implements this trait.
///
/// # Example
///
/// ```
/// use aprender::synthetic::mixup::Embeddable;
///
/// #[derive(Clone)]
/// struct TextSample {
///     text: String,
///     embedding: Vec<f32>,
/// }
///
/// impl Embeddable for TextSample {
///     fn embedding(&self) -> &[f32] {
///         &self.embedding
///     }
///
///     fn from_embedding(embedding: Vec<f32>, reference: &Self) -> Self {
///         TextSample {
///             text: format!("[mixup from: {}]", reference.text),
///             embedding,
///         }
///     }
/// }
/// ```
pub trait Embeddable: Clone {
    /// Borrow this sample's embedding vector.
    fn embedding(&self) -> &[f32];

    /// Build a new sample around `embedding`, taking any non-embedding
    /// metadata from `reference`.
    fn from_embedding(embedding: Vec<f32>, reference: &Self) -> Self;

    /// Number of components in this sample's embedding.
    ///
    /// The default implementation reports the length of [`Self::embedding`].
    fn embedding_dim(&self) -> usize {
        let components = self.embedding();
        components.len()
    }
}
58
// ============================================================================
// MixUp Configuration
// ============================================================================
62
/// Tunable parameters for `MixUp` augmentation.
#[derive(Debug, Clone)]
pub struct MixUpConfig {
    /// Shape parameter of the mixing distribution (higher = more uniform mixing).
    pub alpha: f32,
    /// Whether to mix samples from different classes only.
    pub cross_class_only: bool,
    /// Lower bound for the mixing coefficient lambda (avoids near-copies).
    pub lambda_min: f32,
    /// Upper bound for the mixing coefficient lambda.
    pub lambda_max: f32,
}

impl Default for MixUpConfig {
    /// Defaults: `alpha = 0.4`, any-class mixing, lambda in `[0.2, 0.8]`.
    fn default() -> Self {
        MixUpConfig {
            alpha: 0.4,
            cross_class_only: false,
            lambda_min: 0.2,
            lambda_max: 0.8,
        }
    }
}

impl MixUpConfig {
    /// Create a configuration holding the default values.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Set `alpha`; values below `0.01` are raised to `0.01`.
    #[must_use]
    pub fn with_alpha(self, alpha: f32) -> Self {
        Self {
            alpha: alpha.max(0.01),
            ..self
        }
    }

    /// Enable or disable cross-class-only mixing.
    #[must_use]
    pub fn with_cross_class_only(self, enabled: bool) -> Self {
        Self {
            cross_class_only: enabled,
            ..self
        }
    }

    /// Set the lambda range.
    ///
    /// Both bounds are clamped to `[0, 1]`; if they arrive in the wrong
    /// order they are exchanged so that `lambda_min <= lambda_max`.
    #[must_use]
    pub fn with_lambda_range(self, min: f32, max: f32) -> Self {
        let low = min.clamp(0.0, 1.0);
        let high = max.clamp(0.0, 1.0);
        let (lambda_min, lambda_max) = if low > high {
            (high, low)
        } else {
            (low, high)
        };
        Self {
            lambda_min,
            lambda_max,
            ..self
        }
    }
}
119
// ============================================================================
// Simple RNG for deterministic mixing
// ============================================================================
123
/// 64-bit linear congruential generator used to make mixing reproducible.
///
/// The same seed always yields the same stream. This is a statistical
/// helper only; it is not suitable for anything security-sensitive.
#[derive(Debug, Clone)]
struct SimpleRng {
    /// Current LCG state; also the most recently returned raw value.
    state: u64,
}

impl SimpleRng {
    /// Create a generator from `seed`.
    ///
    /// The stored state is `seed + 1` (wrapping).
    fn new(seed: u64) -> Self {
        Self {
            state: seed.wrapping_add(1),
        }
    }

    /// Advance the generator and return the new 64-bit state.
    fn next_u64(&mut self) -> u64 {
        // Multiplier and increment of Knuth's MMIX LCG (modulus 2^64 via
        // wrapping arithmetic).
        self.state = self
            .state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        self.state
    }

    /// Return a value in `[0, 1)` built from the top 24 bits of the next state.
    fn next_f32(&mut self) -> f32 {
        (self.next_u64() >> 40) as f32 / (1u64 << 24) as f32
    }

    /// Return an index in `0..max`, or `0` (without advancing) when `max == 0`.
    fn next_usize(&mut self, max: usize) -> usize {
        if max == 0 {
            return 0;
        }
        // Map the 64-bit output onto `0..max` with a multiply-shift so the
        // result is driven by the HIGH bits. The low k bits of a
        // power-of-two-modulus LCG repeat with period 2^k, so `% max` made
        // each draw a fixed function of the previous one whenever `max` was
        // a power of two. It also truncated differently on 32-bit targets.
        let scaled = u128::from(self.next_u64()) * max as u128;
        // `scaled >> 64` is strictly less than `max`, so it fits in `usize`.
        (scaled >> 64) as usize
    }

    /// Draw a mixing coefficient in `[0, 1]` for the given `alpha`.
    ///
    /// Stands in for a `Beta(alpha, alpha)` draw: the value is obtained by
    /// inverse-transform sampling of a Kumaraswamy(alpha, alpha)
    /// distribution, whose quantile function has a closed form. The result
    /// is U-shaped for `alpha < 1`, uniform for `alpha == 1` and unimodal
    /// for `alpha > 1`. Intended for `alpha` roughly in `[0.1, 2.0]`.
    fn beta(&mut self, alpha: f32) -> f32 {
        // Floor the uniform draw so it is never exactly zero.
        let u = self.next_f32().max(0.001);
        let a = alpha;
        let b = alpha;

        // Kumaraswamy quantile: x = (1 - (1 - u)^(1/b))^(1/a).
        let x = (1.0 - (1.0 - u).powf(1.0 / b)).powf(1.0 / a);
        x.clamp(0.0, 1.0)
    }
}
170
// ============================================================================
// MixUp Generator
// ============================================================================
174
175/// `MixUp` synthetic data generator.
176///
177/// Creates synthetic samples by interpolating between pairs of samples
178/// in embedding space: x' = λ*x1 + (1-λ)*x2
179///
180/// # Example
181///
182/// ```
183/// use aprender::synthetic::mixup::{MixUpGenerator, MixUpConfig, Embeddable};
184/// use aprender::synthetic::{SyntheticGenerator, SyntheticConfig};
185///
186/// #[derive(Clone, Debug)]
187/// struct Sample {
188///     data: Vec<f32>,
189/// }
190///
191/// impl Embeddable for Sample {
192///     fn embedding(&self) -> &[f32] { &self.data }
193///     fn from_embedding(embedding: Vec<f32>, _: &Self) -> Self {
194///         Sample { data: embedding }
195///     }
196/// }
197///
198/// let gen = MixUpGenerator::<Sample>::new();
199/// let seeds = vec![
200///     Sample { data: vec![1.0, 0.0] },
201///     Sample { data: vec![0.0, 1.0] },
202/// ];
203/// let config = SyntheticConfig::default().with_augmentation_ratio(1.0);
204/// let mixed = gen.generate(&seeds, &config).expect("generation failed");
205/// ```
206#[derive(Debug, Clone)]
207pub struct MixUpGenerator<T: Embeddable> {
208    config: MixUpConfig,
209    _phantom: std::marker::PhantomData<T>,
210}
211
212impl<T: Embeddable> MixUpGenerator<T> {
213    /// Create a new `MixUp` generator with default configuration.
214    #[must_use]
215    pub fn new() -> Self {
216        Self {
217            config: MixUpConfig::default(),
218            _phantom: std::marker::PhantomData,
219        }
220    }
221
222    /// Create with custom configuration.
223    #[must_use]
224    pub fn with_config(mut self, config: MixUpConfig) -> Self {
225        self.config = config;
226        self
227    }
228
229    /// Interpolate two embeddings.
230    fn interpolate(e1: &[f32], e2: &[f32], lambda: f32) -> Vec<f32> {
231        e1.iter()
232            .zip(e2.iter())
233            .map(|(&a, &b)| lambda * a + (1.0 - lambda) * b)
234            .collect()
235    }
236
237    /// Compute cosine similarity between two embeddings.
238    fn cosine_similarity(e1: &[f32], e2: &[f32]) -> f32 {
239        if e1.len() != e2.len() || e1.is_empty() {
240            return 0.0;
241        }
242
243        let dot: f32 = e1.iter().zip(e2.iter()).map(|(a, b)| a * b).sum();
244        let norm1: f32 = e1.iter().map(|x| x * x).sum::<f32>().sqrt();
245        let norm2: f32 = e2.iter().map(|x| x * x).sum::<f32>().sqrt();
246
247        if norm1 < f32::EPSILON || norm2 < f32::EPSILON {
248            return 0.0;
249        }
250
251        (dot / (norm1 * norm2)).clamp(-1.0, 1.0)
252    }
253
254    /// Compute embedding variance as diversity measure.
255    fn embedding_variance(embeddings: &[Vec<f32>]) -> f32 {
256        if embeddings.is_empty() || embeddings[0].is_empty() {
257            return 0.0;
258        }
259
260        let dim = embeddings[0].len();
261        let n = embeddings.len() as f32;
262
263        // Compute mean embedding
264        let mut mean = vec![0.0; dim];
265        for emb in embeddings {
266            for (i, &v) in emb.iter().enumerate() {
267                mean[i] += v / n;
268            }
269        }
270
271        // Compute variance
272        let mut var = 0.0;
273        for emb in embeddings {
274            for (i, &v) in emb.iter().enumerate() {
275                var += (v - mean[i]).powi(2);
276            }
277        }
278
279        var / (n * dim as f32)
280    }
281}
282
283impl<T: Embeddable> Default for MixUpGenerator<T> {
284    fn default() -> Self {
285        Self::new()
286    }
287}
288
289impl<T: Embeddable + std::fmt::Debug> SyntheticGenerator for MixUpGenerator<T> {
290    type Input = T;
291    type Output = T;
292
293    fn generate(&self, seeds: &[T], config: &SyntheticConfig) -> Result<Vec<T>> {
294        if seeds.len() < 2 {
295            return Ok(Vec::new());
296        }
297
298        let target = config.target_count(seeds.len());
299        let mut results = Vec::with_capacity(target);
300        let mut rng = SimpleRng::new(config.seed);
301
302        let mut attempts = 0;
303        let max_attempts = target * config.max_attempts;
304
305        while results.len() < target && attempts < max_attempts {
306            attempts += 1;
307
308            // Select two different samples
309            let i = rng.next_usize(seeds.len());
310            let mut j = rng.next_usize(seeds.len());
311            if j == i {
312                j = (j + 1) % seeds.len();
313            }
314
315            // Sample lambda from Beta distribution
316            let raw_lambda = rng.beta(self.config.alpha);
317            let lambda = self.config.lambda_min
318                + raw_lambda * (self.config.lambda_max - self.config.lambda_min);
319
320            // Interpolate embeddings
321            let e1 = seeds[i].embedding();
322            let e2 = seeds[j].embedding();
323
324            if e1.len() != e2.len() || e1.is_empty() {
325                continue;
326            }
327
328            let mixed_embedding = Self::interpolate(e1, e2, lambda);
329
330            // Create mixed sample using first seed as reference
331            let mixed = T::from_embedding(mixed_embedding, &seeds[i]);
332
333            // Quality check
334            let quality = self.quality_score(&mixed, &seeds[i]);
335            if config.meets_quality(quality) {
336                results.push(mixed);
337            }
338        }
339
340        Ok(results)
341    }
342
343    fn quality_score(&self, generated: &T, seed: &T) -> f32 {
344        // Quality is based on embedding similarity to seed
345        let sim = Self::cosine_similarity(generated.embedding(), seed.embedding());
346
347        // Transform to quality score: similarity should be moderate (not too high = copy)
348        // Ideal range: 0.3 to 0.9
349        let quality = if sim < 0.3 {
350            sim / 0.3 * 0.5 // Low similarity = lower quality
351        } else if sim > 0.9 {
352            1.0 - (sim - 0.9) / 0.1 * 0.5 // Too similar = lower quality
353        } else {
354            0.5 + (sim - 0.3) / 0.6 * 0.5 // Sweet spot
355        };
356
357        quality.clamp(0.0, 1.0)
358    }
359
360    fn diversity_score(&self, batch: &[T]) -> f32 {
361        if batch.is_empty() {
362            return 0.0;
363        }
364
365        let embeddings: Vec<Vec<f32>> = batch.iter().map(|s| s.embedding().to_vec()).collect();
366
367        // Use embedding variance as diversity measure
368        let variance = Self::embedding_variance(&embeddings);
369
370        // Normalize to [0, 1] assuming typical variance range
371        (variance * 10.0).min(1.0)
372    }
373}
374
// ============================================================================
// Tests
// ============================================================================
378
#[cfg(test)]
mod tests {
    use super::*;

    /// Generator specialised to the test sample type.
    type Gen = MixUpGenerator<TestSample>;

    /// Minimal labelled sample used to exercise the generator.
    #[derive(Clone, Debug, PartialEq)]
    struct TestSample {
        embedding: Vec<f32>,
        label: i32,
    }

    impl TestSample {
        fn new(embedding: Vec<f32>, label: i32) -> Self {
            TestSample { embedding, label }
        }
    }

    impl Embeddable for TestSample {
        fn embedding(&self) -> &[f32] {
            self.embedding.as_slice()
        }

        fn from_embedding(embedding: Vec<f32>, reference: &Self) -> Self {
            // The synthetic sample inherits the reference's label.
            TestSample::new(embedding, reference.label)
        }
    }

    /// True when `a` and `b` differ by less than `f32::EPSILON`.
    fn approx(a: f32, b: f32) -> bool {
        (a - b).abs() < f32::EPSILON
    }

    /// `dim` one-hot seeds of dimension `dim`, labelled `0..dim`.
    fn one_hot_seeds(dim: usize) -> Vec<TestSample> {
        (0..dim)
            .map(|axis| {
                let mut embedding = vec![0.0; dim];
                embedding[axis] = 1.0;
                TestSample::new(embedding, axis as i32)
            })
            .collect()
    }

    // ------------------------------------------------------------------------
    // MixUpConfig
    // ------------------------------------------------------------------------

    #[test]
    fn test_config_default() {
        let cfg = MixUpConfig::default();
        assert!(approx(cfg.alpha, 0.4));
        assert!(!cfg.cross_class_only);
        assert!(approx(cfg.lambda_min, 0.2));
        assert!(approx(cfg.lambda_max, 0.8));
    }

    #[test]
    fn test_config_with_alpha() {
        let cfg = MixUpConfig::new().with_alpha(1.0);
        assert!(approx(cfg.alpha, 1.0));

        // Anything below 0.01 is raised to 0.01.
        let cfg = MixUpConfig::new().with_alpha(-1.0);
        assert!(approx(cfg.alpha, 0.01));
    }

    #[test]
    fn test_config_with_cross_class() {
        let cfg = MixUpConfig::new().with_cross_class_only(true);
        assert!(cfg.cross_class_only);
    }

    #[test]
    fn test_config_with_lambda_range() {
        let cfg = MixUpConfig::new().with_lambda_range(0.3, 0.7);
        assert!(approx(cfg.lambda_min, 0.3));
        assert!(approx(cfg.lambda_max, 0.7));

        // Reversed bounds are exchanged.
        let cfg = MixUpConfig::new().with_lambda_range(0.8, 0.2);
        assert!(approx(cfg.lambda_min, 0.2));
        assert!(approx(cfg.lambda_max, 0.8));

        // Out-of-range bounds are clamped to [0, 1].
        let cfg = MixUpConfig::new().with_lambda_range(-0.5, 1.5);
        assert!(approx(cfg.lambda_min, 0.0));
        assert!(approx(cfg.lambda_max, 1.0));
    }

    // ------------------------------------------------------------------------
    // SimpleRng
    // ------------------------------------------------------------------------

    #[test]
    fn test_rng_deterministic() {
        let mut left = SimpleRng::new(42);
        let mut right = SimpleRng::new(42);

        let a: Vec<u64> = (0..10).map(|_| left.next_u64()).collect();
        let b: Vec<u64> = (0..10).map(|_| right.next_u64()).collect();
        assert_eq!(a, b);
    }

    #[test]
    fn test_rng_different_seeds() {
        let first = SimpleRng::new(42).next_u64();
        let second = SimpleRng::new(43).next_u64();
        assert_ne!(first, second);
    }

    #[test]
    fn test_rng_f32_range() {
        let mut rng = SimpleRng::new(12345);
        assert!((0..100).all(|_| (0.0..=1.0).contains(&rng.next_f32())));
    }

    #[test]
    fn test_rng_usize_range() {
        let mut rng = SimpleRng::new(12345);
        assert!((0..100).all(|_| rng.next_usize(10) < 10));

        // Edge case: an empty range yields 0.
        assert_eq!(rng.next_usize(0), 0);
    }

    #[test]
    fn test_rng_beta_range() {
        let mut rng = SimpleRng::new(12345);

        for alpha in [0.1, 0.5, 1.0, 2.0] {
            for _ in 0..50 {
                let draw = rng.beta(alpha);
                assert!((0.0..=1.0).contains(&draw));
            }
        }
    }

    // ------------------------------------------------------------------------
    // MixUpGenerator
    // ------------------------------------------------------------------------

    #[test]
    fn test_generator_new() {
        let generator = Gen::new();
        assert!(approx(generator.config.alpha, 0.4));
    }

    #[test]
    fn test_generator_with_config() {
        let generator = Gen::new().with_config(MixUpConfig::new().with_alpha(1.0));
        assert!(approx(generator.config.alpha, 1.0));
    }

    #[test]
    fn test_generator_default() {
        let generator = Gen::default();
        assert!(approx(generator.config.alpha, 0.4));
    }

    #[test]
    fn test_interpolate() {
        let e1 = vec![1.0, 0.0, 0.0];
        let e2 = vec![0.0, 1.0, 0.0];

        // lambda = 0.5 gives the midpoint.
        let mid = Gen::interpolate(&e1, &e2, 0.5);
        assert!(approx(mid[0], 0.5));
        assert!(approx(mid[1], 0.5));
        assert!(approx(mid[2], 0.0));

        // lambda = 1.0 reproduces e1.
        let all_first = Gen::interpolate(&e1, &e2, 1.0);
        assert!(approx(all_first[0], 1.0));
        assert!(approx(all_first[1], 0.0));

        // lambda = 0.0 reproduces e2.
        let all_second = Gen::interpolate(&e1, &e2, 0.0);
        assert!(approx(all_second[0], 0.0));
        assert!(approx(all_second[1], 1.0));
    }

    #[test]
    fn test_cosine_similarity() {
        let x_axis = vec![1.0, 0.0, 0.0];

        // Identical vectors.
        let sim = Gen::cosine_similarity(&x_axis, &x_axis);
        assert!((sim - 1.0).abs() < 0.001);

        // Orthogonal vectors.
        let y_axis = vec![0.0, 1.0, 0.0];
        let sim = Gen::cosine_similarity(&x_axis, &y_axis);
        assert!(sim.abs() < 0.001);

        // Opposite vectors.
        let neg_x = vec![-1.0, 0.0, 0.0];
        let sim = Gen::cosine_similarity(&x_axis, &neg_x);
        assert!((sim - (-1.0)).abs() < 0.001);

        // Empty vectors.
        let empty: Vec<f32> = Vec::new();
        let sim = Gen::cosine_similarity(&empty, &empty);
        assert!(approx(sim, 0.0));

        // Different lengths.
        let short = vec![1.0, 0.0];
        let sim = Gen::cosine_similarity(&x_axis, &short);
        assert!(approx(sim, 0.0));
    }

    #[test]
    fn test_embedding_variance() {
        // Single embedding.
        let var = Gen::embedding_variance(&[vec![1.0, 0.0]]);
        assert!(approx(var, 0.0));

        // Identical embeddings.
        let var = Gen::embedding_variance(&[vec![1.0, 0.0], vec![1.0, 0.0]]);
        assert!(approx(var, 0.0));

        // Different embeddings.
        let var = Gen::embedding_variance(&[vec![1.0, 0.0], vec![0.0, 1.0]]);
        assert!(var > 0.0);

        // Empty batch.
        let var = Gen::embedding_variance(&[]);
        assert!(approx(var, 0.0));
    }

    #[test]
    fn test_generate_basic() {
        let generator = Gen::new();
        let seeds = one_hot_seeds(3);

        let config = SyntheticConfig::default()
            .with_augmentation_ratio(1.0)
            .with_quality_threshold(0.1);

        let generated = generator
            .generate(&seeds, &config)
            .expect("generation failed");

        // Something must come out, and every sample keeps the seed dimension.
        assert!(!generated.is_empty());
        for sample in &generated {
            assert_eq!(sample.embedding().len(), 3);
        }
    }

    #[test]
    fn test_generate_insufficient_seeds() {
        let generator = Gen::new();

        // No seeds at all.
        let generated = generator
            .generate(&[], &SyntheticConfig::default())
            .expect("should succeed");
        assert!(generated.is_empty());

        // One seed: mixing needs at least two.
        let lone = vec![TestSample::new(vec![1.0, 0.0], 0)];
        let generated = generator
            .generate(&lone, &SyntheticConfig::default())
            .expect("should succeed");
        assert!(generated.is_empty());
    }

    #[test]
    fn test_generate_respects_target() {
        let generator = Gen::new();
        let seeds = one_hot_seeds(2);

        let config = SyntheticConfig::default()
            .with_augmentation_ratio(2.0) // Target: 4 samples
            .with_quality_threshold(0.0); // Accept all

        let generated = generator
            .generate(&seeds, &config)
            .expect("generation failed");

        // Never more than the target (possibly fewer because of quality).
        assert!(generated.len() <= 4);
    }

    #[test]
    fn test_generate_deterministic() {
        let generator = Gen::new();
        let seeds = vec![
            TestSample::new(vec![1.0, 0.0, 0.0], 0),
            TestSample::new(vec![0.0, 1.0, 0.0], 1),
        ];

        let config = SyntheticConfig::default()
            .with_augmentation_ratio(1.0)
            .with_quality_threshold(0.1)
            .with_seed(12345);

        let first_run = generator
            .generate(&seeds, &config)
            .expect("generation failed");
        let second_run = generator
            .generate(&seeds, &config)
            .expect("generation failed");

        assert_eq!(first_run.len(), second_run.len());
        for (a, b) in first_run.iter().zip(&second_run) {
            assert_eq!(a.embedding(), b.embedding());
        }
    }

    #[test]
    fn test_quality_score() {
        let generator = Gen::new();
        let seed = TestSample::new(vec![1.0, 0.0, 0.0], 0);

        // A copy of the seed is penalised for being too similar.
        let copy = TestSample::new(vec![1.0, 0.0, 0.0], 0);
        assert!(generator.quality_score(&copy, &seed) < 1.0);

        // A somewhat similar sample scores well.
        let near = TestSample::new(vec![0.8, 0.2, 0.0], 0);
        assert!(generator.quality_score(&near, &seed) > 0.3);

        // A very different sample still yields a score inside [0, 1].
        let far = TestSample::new(vec![0.0, 0.0, 1.0], 0);
        let far_score = generator.quality_score(&far, &seed);
        assert!((0.0..=1.0).contains(&far_score));
    }

    #[test]
    fn test_diversity_score() {
        let generator = Gen::new();

        // Empty batch.
        assert!(approx(generator.diversity_score(&[]), 0.0));

        // Single sample.
        let single = vec![TestSample::new(vec![1.0, 0.0], 0)];
        assert!(approx(generator.diversity_score(&single), 0.0));

        // Diverse batch.
        let diverse = vec![
            TestSample::new(vec![1.0, 0.0], 0),
            TestSample::new(vec![0.0, 1.0], 1),
            TestSample::new(vec![-1.0, 0.0], 2),
        ];
        let diverse_score = generator.diversity_score(&diverse);
        assert!(diverse_score > 0.0);

        // Homogeneous batch scores lower than the diverse one.
        let homogeneous = vec![TestSample::new(vec![1.0, 0.0], 0); 3];
        assert!(generator.diversity_score(&homogeneous) < diverse_score);
    }

    #[test]
    fn test_generate_mixed_embeddings() {
        let generator = Gen::new();
        let seeds = one_hot_seeds(2);

        let config = SyntheticConfig::default()
            .with_augmentation_ratio(2.0)
            .with_quality_threshold(0.0)
            .with_seed(42);

        let generated = generator
            .generate(&seeds, &config)
            .expect("generation failed");

        // Interpolating unit vectors keeps every component inside [0, 1].
        for sample in &generated {
            let components = sample.embedding();
            assert!((0.0..=1.0).contains(&components[0]));
            assert!((0.0..=1.0).contains(&components[1]));
        }
    }

    #[test]
    fn test_embeddable_trait() {
        let original = TestSample::new(vec![1.0, 2.0, 3.0], 5);

        assert_eq!(original.embedding(), &[1.0, 2.0, 3.0]);
        assert_eq!(original.embedding_dim(), 3);

        let rebuilt = TestSample::from_embedding(vec![4.0, 5.0, 6.0], &original);
        assert_eq!(rebuilt.embedding(), &[4.0, 5.0, 6.0]);
        assert_eq!(rebuilt.label, 5); // Kept from reference
    }

    // ------------------------------------------------------------------------
    // Integration
    // ------------------------------------------------------------------------

    #[test]
    fn test_full_mixup_pipeline() {
        let mixup_config = MixUpConfig::new()
            .with_alpha(0.5)
            .with_lambda_range(0.3, 0.7);
        let generator = Gen::new().with_config(mixup_config);

        // Four seeds with distinct one-hot embeddings.
        let seeds = one_hot_seeds(4);

        let config = SyntheticConfig::default()
            .with_augmentation_ratio(2.0)
            .with_quality_threshold(0.2)
            .with_seed(9999);

        let synthetic = generator
            .generate(&seeds, &config)
            .expect("generation failed");

        for sample in &synthetic {
            // Dimension is preserved.
            assert_eq!(sample.embedding_dim(), 4);

            // The score against the first seed is non-negative.
            assert!(generator.quality_score(sample, &seeds[0]) >= 0.0);
        }

        // Diversity, when defined, stays inside [0, 1].
        if !synthetic.is_empty() {
            let diversity = generator.diversity_score(&synthetic);
            assert!((0.0..=1.0).contains(&diversity));
        }
    }
}