use super::{SyntheticConfig, SyntheticGenerator};
use crate::error::Result;
/// A value that exposes a dense `f32` embedding and can be rebuilt from one.
///
/// Seed types handed to [`MixUpGenerator`] implement this so their embeddings
/// can be interpolated and turned back into values of the same type.
pub trait Embeddable: Clone {
    /// Borrows this value's embedding vector.
    fn embedding(&self) -> &[f32];

    /// Builds a new value around `embedding`.
    ///
    /// `reference` is an existing value the new one is derived from (MixUp
    /// passes the first seed of the interpolated pair) — presumably so
    /// implementors can copy over their non-embedding fields.
    fn from_embedding(embedding: Vec<f32>, reference: &Self) -> Self;

    /// Number of components in [`embedding`](Self::embedding).
    fn embedding_dim(&self) -> usize {
        <[f32]>::len(self.embedding())
    }
}
/// Tuning knobs for [`MixUpGenerator`].
///
/// Each synthetic sample mixes two seeds with a weight that is first drawn in
/// `[0, 1]` and then rescaled into `[lambda_min, lambda_max]`.
#[derive(Debug, Clone)]
pub struct MixUpConfig {
    /// Shape parameter handed to the generator's weight sampler.
    /// [`MixUpConfig::with_alpha`] floors it at `0.01`; setting the field
    /// directly bypasses that floor.
    pub alpha: f32,
    /// Intended to restrict mixing to seeds of different classes.
    /// NOTE(review): `MixUpGenerator::generate` in this file never reads this
    /// flag, so it currently has no effect there.
    pub cross_class_only: bool,
    /// Lower bound of the rescaled weight applied to the first seed of a pair.
    pub lambda_min: f32,
    /// Upper bound of the rescaled weight applied to the first seed of a pair.
    pub lambda_max: f32,
}

impl Default for MixUpConfig {
    /// `alpha = 0.4`, weights in `[0.2, 0.8]`, `cross_class_only = false`.
    fn default() -> Self {
        Self {
            lambda_min: 0.2,
            lambda_max: 0.8,
            alpha: 0.4,
            cross_class_only: false,
        }
    }
}

impl MixUpConfig {
    /// Same as [`MixUpConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Sets `alpha`, raising anything below `0.01` (and NaN) to `0.01`.
    #[must_use]
    pub fn with_alpha(self, alpha: f32) -> Self {
        Self {
            alpha: alpha.max(0.01),
            ..self
        }
    }

    /// Enables or disables the `cross_class_only` flag.
    #[must_use]
    pub fn with_cross_class_only(self, enabled: bool) -> Self {
        Self {
            cross_class_only: enabled,
            ..self
        }
    }

    /// Sets the weight range. Both ends are clamped to `[0, 1]` and swapped if
    /// given in descending order. NaN inputs are stored unchanged, because
    /// `f32::clamp` propagates NaN.
    #[must_use]
    pub fn with_lambda_range(self, min: f32, max: f32) -> Self {
        let lo = min.clamp(0.0, 1.0);
        let hi = max.clamp(0.0, 1.0);
        let (lambda_min, lambda_max) = if lo > hi { (hi, lo) } else { (lo, hi) };
        Self {
            lambda_min,
            lambda_max,
            ..self
        }
    }
}
/// Minimal deterministic PRNG: a 64-bit linear congruential generator using
/// Knuth's MMIX multiplier and increment.
///
/// It exists so synthetic output is reproducible for a given seed. It is not
/// statistically strong and must not be used for anything security-related.
#[derive(Debug, Clone)]
struct SimpleRng {
    state: u64,
}

impl SimpleRng {
    /// Creates a generator whose initial state is `seed + 1` (wrapping).
    fn new(seed: u64) -> Self {
        Self {
            state: seed.wrapping_add(1),
        }
    }

    /// Advances the LCG one step and returns the new 64-bit state.
    fn next_u64(&mut self) -> u64 {
        self.state = self
            .state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        self.state
    }

    /// Uniform `f32` in `[0, 1)` built from the top 24 bits of the next state.
    /// 24 bits is the `f32` significand width, so every value is exact.
    fn next_f32(&mut self) -> f32 {
        (self.next_u64() >> 40) as f32 / (1u64 << 24) as f32
    }

    /// Integer in `0..max`; returns `0` when `max == 0`.
    ///
    /// Uses the multiply-high reduction `(x * max) >> 64`, which draws on the
    /// *high* bits of the state. The low bits of a power-of-two-modulus LCG
    /// are weak: the lowest `k` bits repeat with period `2^k`. Reducing with
    /// `x % max` would therefore make small power-of-two ranges almost
    /// deterministic. For example, with `max == 4` and three draws per MixUp
    /// attempt, the second index would always be the first minus one (mod 4).
    fn next_usize(&mut self, max: usize) -> usize {
        if max == 0 {
            return 0;
        }
        // The product is < 2^64 * max, so the shifted value is < max and fits in usize.
        ((u128::from(self.next_u64()) * max as u128) >> 64) as usize
    }

    /// Draws a weight in `[0, 1]` from a symmetric Beta-like distribution.
    ///
    /// This is inverse-CDF sampling of Kumaraswamy(`alpha`, `alpha`),
    /// `x = (1 - (1 - u)^(1/b))^(1/a)`. It is a cheap stand-in for
    /// Beta(`alpha`, `alpha`), not an exact Beta sample. `u` is floored at
    /// 0.001 to keep it away from zero. A non-positive `alpha` is not guarded
    /// against; the result may then be NaN, which `clamp` passes through.
    fn beta(&mut self, alpha: f32) -> f32 {
        let u = self.next_f32().max(0.001);
        let a = alpha;
        let b = alpha;
        let x = (1.0 - (1.0 - u).powf(1.0 / b)).powf(1.0 / a);
        x.clamp(0.0, 1.0)
    }
}
/// Synthetic-data generator implementing MixUp: each output is a weighted
/// interpolation of two seed embeddings, rebuilt into a `T` via
/// [`Embeddable::from_embedding`].
#[derive(Debug, Clone)]
pub struct MixUpGenerator<T: Embeddable> {
    /// Controls how the interpolation weight is sampled and its range.
    config: MixUpConfig,
    /// Ties the generator to the seed type `T` without storing any `T`.
    _phantom: std::marker::PhantomData<T>,
}
impl<T: Embeddable> MixUpGenerator<T> {
    /// Creates a generator using [`MixUpConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self {
            config: MixUpConfig::default(),
            _phantom: std::marker::PhantomData,
        }
    }

    /// Replaces the generator's configuration (builder style).
    #[must_use]
    pub fn with_config(mut self, config: MixUpConfig) -> Self {
        self.config = config;
        self
    }

    /// Element-wise `lambda * e1 + (1 - lambda) * e2`.
    ///
    /// The output has `min(e1.len(), e2.len())` elements; callers are expected
    /// to pass equal-length slices.
    fn interpolate(e1: &[f32], e2: &[f32], lambda: f32) -> Vec<f32> {
        e1.iter()
            .zip(e2)
            .map(|(&a, &b)| lambda * a + (1.0 - lambda) * b)
            .collect()
    }

    /// Cosine similarity, delegated to `crate::nn::functional::cosine_similarity_slice`.
    fn cosine_similarity(e1: &[f32], e2: &[f32]) -> f32 {
        crate::nn::functional::cosine_similarity_slice(e1, e2)
    }

    /// Mean per-dimension (population) variance of a set of embeddings.
    ///
    /// The dimensionality `dim` comes from the first embedding. The result is
    /// `sum((v - mean)^2) / (n * dim)`. Returns `0.0` when there are no
    /// embeddings or the first one is empty.
    ///
    /// Embeddings whose length differs from the first are tolerated instead of
    /// panicking. Longer ones are truncated to `dim`; shorter ones contribute
    /// only the components they have, while still counting toward `n`.
    fn embedding_variance(embeddings: &[Vec<f32>]) -> f32 {
        let dim = match embeddings.first() {
            Some(first) if !first.is_empty() => first.len(),
            _ => return 0.0,
        };
        let n = embeddings.len() as f32;

        let mut mean = vec![0.0_f32; dim];
        for emb in embeddings {
            // `zip` bounds the walk by `dim`; direct `mean[i]` indexing would
            // go out of bounds for an embedding longer than the first.
            for (m, &v) in mean.iter_mut().zip(emb) {
                *m += v / n;
            }
        }

        let sum_sq: f32 = embeddings
            .iter()
            .flat_map(|emb| emb.iter().zip(&mean).map(|(&v, &m)| (v - m).powi(2)))
            .sum();
        sum_sq / (n * dim as f32)
    }
}
impl<T: Embeddable> Default for MixUpGenerator<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: Embeddable + std::fmt::Debug> SyntheticGenerator for MixUpGenerator<T> {
    type Input = T;
    type Output = T;

    /// Produces up to `config.target_count(seeds.len())` MixUp samples.
    ///
    /// Each attempt works as follows:
    /// - Pick two distinct seeds `(i, j)` uniformly at random.
    /// - Draw a weight `lambda` rescaled into
    ///   `[self.config.lambda_min, self.config.lambda_max]`.
    /// - Mix the embeddings as `lambda * e_i + (1 - lambda) * e_j`.
    /// - Keep the result only if `config.meets_quality` accepts its quality
    ///   score against seed `i`.
    ///
    /// Pairs whose embeddings are empty or differ in length are skipped. At
    /// most `target * config.max_attempts` attempts are made (saturating), so
    /// fewer than `target` samples may come back. Output is deterministic for
    /// a given `config.seed`.
    ///
    /// Returns an empty vector when fewer than two seeds are supplied.
    ///
    /// NOTE(review): `MixUpConfig::cross_class_only` is not consulted here.
    ///
    /// # Errors
    /// This implementation never returns `Err`.
    fn generate(&self, seeds: &[T], config: &SyntheticConfig) -> Result<Vec<T>> {
        let n_seeds = seeds.len();
        if n_seeds < 2 {
            return Ok(Vec::new());
        }
        let target = config.target_count(n_seeds);
        let mut results = Vec::with_capacity(target);
        let mut rng = SimpleRng::new(config.seed);
        // Saturate so a large target or attempt budget cannot overflow `usize`.
        let max_attempts = target.saturating_mul(config.max_attempts);
        let mut attempts = 0;

        while results.len() < target && attempts < max_attempts {
            attempts += 1;

            let i = rng.next_usize(n_seeds);
            // Draw the partner from the other n - 1 seeds and skip over `i`,
            // so every ordered pair (i, j) with i != j is equally likely.
            let mut j = rng.next_usize(n_seeds - 1);
            if j >= i {
                j += 1;
            }

            let raw_lambda = rng.beta(self.config.alpha);
            let lambda = self.config.lambda_min
                + raw_lambda * (self.config.lambda_max - self.config.lambda_min);

            let e1 = seeds[i].embedding();
            let e2 = seeds[j].embedding();
            if e1.len() != e2.len() || e1.is_empty() {
                continue;
            }

            let mixed = T::from_embedding(Self::interpolate(e1, e2, lambda), &seeds[i]);
            let quality = self.quality_score(&mixed, &seeds[i]);
            if config.meets_quality(quality) {
                results.push(mixed);
            }
        }
        Ok(results)
    }

    /// Scores `generated` in `[0, 1]` from its cosine similarity `s` to `seed`.
    ///
    /// The score is piecewise-linear and continuous:
    /// - `s < 0.3`: rises from 0.0 (at `s = 0`; negative `s` clamps to 0) to 0.5.
    /// - `0.3 <= s <= 0.9`: rises from 0.5 to 1.0.
    /// - `s > 0.9`: falls from 1.0 back to 0.5 at `s = 1`.
    ///
    /// A NaN similarity scores 0.0.
    fn quality_score(&self, generated: &T, seed: &T) -> f32 {
        let sim = Self::cosine_similarity(generated.embedding(), seed.embedding());
        if sim.is_nan() {
            // `f32::clamp` propagates NaN, so without this guard callers
            // would receive a NaN score.
            return 0.0;
        }
        let quality = if sim < 0.3 {
            sim / 0.3 * 0.5
        } else if sim > 0.9 {
            1.0 - (sim - 0.9) / 0.1 * 0.5
        } else {
            0.5 + (sim - 0.3) / 0.6 * 0.5
        };
        quality.clamp(0.0, 1.0)
    }

    /// Diversity of a batch: ten times the mean per-dimension embedding
    /// variance, capped at 1.0. An empty batch scores 0.0.
    fn diversity_score(&self, batch: &[T]) -> f32 {
        if batch.is_empty() {
            return 0.0;
        }
        let embeddings: Vec<Vec<f32>> = batch.iter().map(|s| s.embedding().to_vec()).collect();
        let variance = Self::embedding_variance(&embeddings);
        (variance * 10.0).min(1.0)
    }
}
#[cfg(test)]
#[path = "mixup_tests.rs"]
mod tests;