// datasynth-core 2.4.0
//
// Core domain models, traits, and distributions for synthetic enterprise data generation.
// Documentation
//! Hybrid generator that blends rule-based and diffusion-generated data.
//!
//! Supports multiple blending strategies:
//! - **Interpolate**: weighted average of corresponding samples
//! - **Select**: randomly pick source per record based on weight
//! - **Ensemble**: use diffusion for specified columns, rules for others

use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use rand_distr::{Distribution, Uniform};

/// Strategy for blending rule-based and diffusion-generated data.
///
/// The enum is a plain, fieldless tag, so it is `Copy` and can be compared,
/// hashed, and used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BlendStrategy {
    /// Weighted average of corresponding samples:
    /// `output = (1 - weight) * rule + weight * diffusion`.
    Interpolate,
    /// Randomly select the source per record: each row comes entirely from
    /// one source, with the diffusion row chosen with probability `weight`.
    Select,
    /// Column-level blending: specified columns use diffusion, others use
    /// rule-based data.
    ///
    /// `HybridGenerator::blend` takes no column list, so passing this variant
    /// there falls back to interpolation; use `HybridGenerator::blend_ensemble`
    /// for real column-level control.
    Ensemble,
}

/// A hybrid generator that blends rule-based and diffusion-generated data.
///
/// The single `weight` parameter controls the balance between the sources:
/// - `0.0` = pure rule-based output
/// - `1.0` = pure diffusion output
/// - values in between blend according to the chosen strategy
///
/// Two generators compare equal when their weights compare equal.
#[derive(Debug, Clone, PartialEq)]
pub struct HybridGenerator {
    /// Blending weight: 0.0 (pure rule-based) to 1.0 (pure diffusion).
    /// The constructor clamps the supplied value into this range.
    weight: f64,
}

impl HybridGenerator {
    /// Create a new hybrid generator with the given weight.
    ///
    /// The weight is clamped to [0.0, 1.0]. A NaN weight is treated as `0.0`
    /// (pure rule-based) so the stored weight always lies in the documented
    /// range.
    pub fn new(weight: f64) -> Self {
        // `f64::clamp` returns NaN for a NaN input, which would silently turn
        // every interpolated value into NaN; fall back to the rule-based end.
        let weight = if weight.is_nan() {
            0.0
        } else {
            weight.clamp(0.0, 1.0)
        };
        Self { weight }
    }

    /// Get the current blending weight.
    pub fn weight(&self) -> f64 {
        self.weight
    }

    /// Blend rule-based and diffusion-generated data using the specified strategy.
    ///
    /// The two inputs are expected to have the same shape. When they do not,
    /// the output has as many rows as the shorter input. `Interpolate`
    /// additionally truncates each output row to the shorter of the two
    /// corresponding rows, while `Select` copies the chosen source row whole.
    ///
    /// For the `Ensemble` strategy no column list is available here, so all
    /// columns are interpolated with the weight; use `blend_ensemble` for
    /// column-level control.
    ///
    /// # Arguments
    /// * `rule_based` - Data generated by rule-based methods
    /// * `diffusion` - Data generated by diffusion methods
    /// * `strategy` - How to blend the two sources
    /// * `seed` - Random seed for deterministic blending (only used by `Select`)
    ///
    /// # Returns
    /// The blended rows; empty when either input is empty.
    pub fn blend(
        &self,
        rule_based: &[Vec<f64>],
        diffusion: &[Vec<f64>],
        strategy: BlendStrategy,
        seed: u64,
    ) -> Vec<Vec<f64>> {
        match strategy {
            // Without specific column indices, Ensemble falls back to interpolation.
            BlendStrategy::Interpolate | BlendStrategy::Ensemble => {
                self.blend_interpolate(rule_based, diffusion)
            }
            BlendStrategy::Select => self.blend_select(rule_based, diffusion, seed),
        }
    }

    /// Blend using column-level ensemble: specified columns use diffusion data,
    /// remaining columns use rule-based data.
    ///
    /// The output has as many rows as the shorter input, and each row is as
    /// wide as the shorter of the two corresponding rows. Indices in
    /// `diffusion_columns` that fall outside a row are ignored; duplicates are
    /// harmless. The blending weight is not used by this method.
    ///
    /// # Arguments
    /// * `rule_based` - Data generated by rule-based methods
    /// * `diffusion` - Data generated by diffusion methods
    /// * `diffusion_columns` - Column indices that should use diffusion output
    pub fn blend_ensemble(
        &self,
        rule_based: &[Vec<f64>],
        diffusion: &[Vec<f64>],
        diffusion_columns: &[usize],
    ) -> Vec<Vec<f64>> {
        // Widest output row; sizes the column mask so it covers every row.
        let width = rule_based
            .iter()
            .zip(diffusion)
            .map(|(rule_row, diff_row)| rule_row.len().min(diff_row.len()))
            .max()
            .unwrap_or(0);

        // Build the per-column choice once instead of scanning
        // `diffusion_columns` for every cell.
        let mut use_diffusion = vec![false; width];
        for &col in diffusion_columns {
            if let Some(slot) = use_diffusion.get_mut(col) {
                *slot = true;
            }
        }

        rule_based
            .iter()
            .zip(diffusion)
            .map(|(rule_row, diff_row)| {
                // The mask is at least as long as any row pair, so the zip is
                // bounded by the shorter of the two data rows.
                rule_row
                    .iter()
                    .zip(diff_row)
                    .zip(&use_diffusion)
                    .map(|((&rule, &diff), &from_diffusion)| {
                        if from_diffusion {
                            diff
                        } else {
                            rule
                        }
                    })
                    .collect()
            })
            .collect()
    }

    /// Interpolate: weighted average of each element.
    ///
    /// At the exact endpoints the ignored source is not touched at all, so a
    /// non-finite value there cannot leak into the output (`0.0 * inf` is NaN).
    fn blend_interpolate(&self, rule_based: &[Vec<f64>], diffusion: &[Vec<f64>]) -> Vec<Vec<f64>> {
        let w = self.weight;
        let mix = |rule: f64, diff: f64| -> f64 {
            if w <= 0.0 {
                rule
            } else if w >= 1.0 {
                diff
            } else {
                (1.0 - w) * rule + w * diff
            }
        };

        rule_based
            .iter()
            .zip(diffusion)
            .map(|(rule_row, diff_row)| {
                rule_row
                    .iter()
                    .zip(diff_row)
                    .map(|(&rule, &diff)| mix(rule, diff))
                    .collect()
            })
            .collect()
    }

    /// Select: randomly choose source per row.
    ///
    /// One uniform draw in `[0, 1)` is taken per output row, in row order, from
    /// a generator seeded with `seed`; the diffusion row is taken when the draw
    /// is below the weight, otherwise the rule-based row.
    fn blend_select(
        &self,
        rule_based: &[Vec<f64>],
        diffusion: &[Vec<f64>],
        seed: u64,
    ) -> Vec<Vec<f64>> {
        let mut rng = ChaCha8Rng::seed_from_u64(seed);
        let uniform = Uniform::new(0.0_f64, 1.0).expect("valid uniform params");

        rule_based
            .iter()
            .zip(diffusion)
            .map(|(rule_row, diff_row)| {
                let roll: f64 = uniform.sample(&mut rng);
                if roll < self.weight {
                    diff_row.clone()
                } else {
                    rule_row.clone()
                }
            })
            .collect()
    }
}

#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by every floating-point comparison below.
    const TOL: f64 = 1e-10;

    /// A 50/50 interpolation yields the element-wise midpoint of both inputs.
    #[test]
    fn test_interpolation_produces_blended_output() {
        let hybrid = HybridGenerator::new(0.5);
        let rule_rows = vec![vec![10.0, 20.0], vec![30.0, 40.0]];
        let diffusion_rows = vec![vec![20.0, 40.0], vec![50.0, 60.0]];

        let out = hybrid.blend(&rule_rows, &diffusion_rows, BlendStrategy::Interpolate, 0);
        assert_eq!(out.len(), 2);

        // e.g. 0.5 * 10 + 0.5 * 20 = 15
        let expected = [[15.0, 30.0], [40.0, 50.0]];
        for (i, want_row) in expected.iter().enumerate() {
            for (j, &want) in want_row.iter().enumerate() {
                assert!((out[i][j] - want).abs() < TOL);
            }
        }
    }

    /// With weight 0.5 over 1000 rows, both sources must show up in the output.
    #[test]
    fn test_select_picks_from_both_sources() {
        let hybrid = HybridGenerator::new(0.5);
        let rule_rows = vec![vec![0.0]; 1000];
        let diffusion_rows = vec![vec![1.0]; 1000];

        let out = hybrid.blend(&rule_rows, &diffusion_rows, BlendStrategy::Select, 42);
        assert_eq!(out.len(), 1000);

        // Rule rows carry 0.0 and diffusion rows 1.0, so 0.5 separates them.
        let mut count_diffusion = 0usize;
        let mut count_rule = 0usize;
        for row in &out {
            if row[0] > 0.5 {
                count_diffusion += 1;
            }
            if row[0] < 0.5 {
                count_rule += 1;
            }
        }

        // Both sources should be represented
        assert!(
            count_diffusion > 100,
            "Expected diffusion picks, got {}",
            count_diffusion
        );
        assert!(
            count_rule > 100,
            "Expected rule-based picks, got {}",
            count_rule
        );
    }

    /// Only the listed column is taken from diffusion; the rest stay rule-based.
    #[test]
    fn test_ensemble_uses_correct_columns() {
        let hybrid = HybridGenerator::new(0.5);
        let rule_rows = vec![vec![1.0, 2.0, 3.0]];
        let diffusion_rows = vec![vec![10.0, 20.0, 30.0]];

        // Only column 1 from diffusion
        let out = hybrid.blend_ensemble(&rule_rows, &diffusion_rows, &[1]);
        assert_eq!(out.len(), 1);

        let checks = [
            (0usize, 1.0, "Column 0 should be rule-based"),
            (1, 20.0, "Column 1 should be diffusion"),
            (2, 3.0, "Column 2 should be rule-based"),
        ];
        for (col, want, why) in checks {
            assert!((out[0][col] - want).abs() < TOL, "{}", why);
        }
    }

    /// Weight 0 reproduces the rule-based input element for element.
    #[test]
    fn test_weight_zero_returns_rule_based() {
        let hybrid = HybridGenerator::new(0.0);
        let rule_rows = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let diffusion_rows = vec![vec![10.0, 20.0], vec![30.0, 40.0]];

        let out = hybrid.blend(&rule_rows, &diffusion_rows, BlendStrategy::Interpolate, 0);
        let pairs = rule_rows
            .iter()
            .zip(out.iter())
            .flat_map(|(rule_row, out_row)| rule_row.iter().zip(out_row.iter()));
        for (&r, &b) in pairs {
            assert!(
                (r - b).abs() < TOL,
                "weight=0 should return rule-based: {} vs {}",
                r,
                b
            );
        }
    }

    /// Weight 1 reproduces the diffusion input element for element.
    #[test]
    fn test_weight_one_returns_diffusion() {
        let hybrid = HybridGenerator::new(1.0);
        let rule_rows = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let diffusion_rows = vec![vec![10.0, 20.0], vec![30.0, 40.0]];

        let out = hybrid.blend(&rule_rows, &diffusion_rows, BlendStrategy::Interpolate, 0);
        let pairs = diffusion_rows
            .iter()
            .zip(out.iter())
            .flat_map(|(diff_row, out_row)| diff_row.iter().zip(out_row.iter()));
        for (&d, &b) in pairs {
            assert!(
                (d - b).abs() < TOL,
                "weight=1 should return diffusion: {} vs {}",
                d,
                b
            );
        }
    }

    /// Empty inputs produce empty outputs for both public entry points.
    #[test]
    fn test_empty_inputs() {
        let hybrid = HybridGenerator::new(0.5);
        let no_rows: Vec<Vec<f64>> = Vec::new();

        assert!(hybrid
            .blend(&no_rows, &no_rows, BlendStrategy::Interpolate, 0)
            .is_empty());
        assert!(hybrid.blend_ensemble(&no_rows, &no_rows, &[0]).is_empty());
    }

    /// Out-of-range weights are pulled back to the nearest bound.
    #[test]
    fn test_weight_clamping() {
        for (input, want) in [(-0.5, 0.0), (1.5, 1.0)] {
            let hybrid = HybridGenerator::new(input);
            assert!((hybrid.weight() - want).abs() < TOL);
        }
    }
}