Skip to main content

datasynth_core/diffusion/
hybrid.rs

1//! Hybrid generator that blends rule-based and diffusion-generated data.
2//!
3//! Supports multiple blending strategies:
4//! - **Interpolate**: weighted average of corresponding samples
5//! - **Select**: randomly pick source per record based on weight
6//! - **Ensemble**: use diffusion for specified columns, rules for others
7
8use rand::SeedableRng;
9use rand_chacha::ChaCha8Rng;
10use rand_distr::{Distribution, Uniform};
11
/// Strategy for blending rule-based and diffusion-generated data.
///
/// This is a fieldless, `Copy` enum, so it also derives `Eq` and `Hash`. That lets
/// callers use it as a key in maps and sets, e.g. to cache results per strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BlendStrategy {
    /// Weighted average of corresponding samples:
    /// `output = (1 - weight) * rule + weight * diffusion`.
    Interpolate,
    /// Randomly select the source per record: each row comes entirely from one source,
    /// taken from the diffusion data with probability `weight`.
    Select,
    /// Column-level blending: specified columns use diffusion, others use rule-based.
    ///
    /// Column indices can only be supplied through
    /// [`HybridGenerator::blend_ensemble`]. When this variant is passed to
    /// [`HybridGenerator::blend`], it falls back to element-wise interpolation.
    Ensemble,
}
22
/// Mixes rule-based and diffusion-generated data into a single output.
///
/// One `weight` decides how much each source contributes:
/// - `0.0` produces output that is purely rule-based
/// - `1.0` produces output that is purely diffusion-based
/// - anything in between mixes the two according to the selected strategy
#[derive(Clone, Debug)]
pub struct HybridGenerator {
    /// Share assigned to the diffusion source: 0.0 means rule-based only,
    /// 1.0 means diffusion only.
    weight: f64,
}
34
35impl HybridGenerator {
36    /// Create a new hybrid generator with the given weight.
37    ///
38    /// The weight is clamped to [0.0, 1.0].
39    pub fn new(weight: f64) -> Self {
40        Self {
41            weight: weight.clamp(0.0, 1.0),
42        }
43    }
44
45    /// Get the current blending weight.
46    pub fn weight(&self) -> f64 {
47        self.weight
48    }
49
50    /// Blend rule-based and diffusion-generated data using the specified strategy.
51    ///
52    /// Both input slices must have the same number of rows and columns.
53    /// For `Ensemble` strategy, all columns will use the weight for interpolation;
54    /// use `blend_ensemble` for column-level control.
55    ///
56    /// # Arguments
57    /// * `rule_based` - Data generated by rule-based methods
58    /// * `diffusion` - Data generated by diffusion methods
59    /// * `strategy` - How to blend the two sources
60    /// * `seed` - Random seed for deterministic blending
61    pub fn blend(
62        &self,
63        rule_based: &[Vec<f64>],
64        diffusion: &[Vec<f64>],
65        strategy: BlendStrategy,
66        seed: u64,
67    ) -> Vec<Vec<f64>> {
68        let n_rows = rule_based.len().min(diffusion.len());
69        if n_rows == 0 {
70            return vec![];
71        }
72
73        match strategy {
74            BlendStrategy::Interpolate => self.blend_interpolate(rule_based, diffusion, n_rows),
75            BlendStrategy::Select => self.blend_select(rule_based, diffusion, n_rows, seed),
76            BlendStrategy::Ensemble => {
77                // Without specific column indices, fall back to interpolation
78                self.blend_interpolate(rule_based, diffusion, n_rows)
79            }
80        }
81    }
82
83    /// Blend using column-level ensemble: specified columns use diffusion data,
84    /// remaining columns use rule-based data.
85    ///
86    /// # Arguments
87    /// * `rule_based` - Data generated by rule-based methods
88    /// * `diffusion` - Data generated by diffusion methods
89    /// * `diffusion_columns` - Column indices that should use diffusion output
90    pub fn blend_ensemble(
91        &self,
92        rule_based: &[Vec<f64>],
93        diffusion: &[Vec<f64>],
94        diffusion_columns: &[usize],
95    ) -> Vec<Vec<f64>> {
96        let n_rows = rule_based.len().min(diffusion.len());
97        if n_rows == 0 {
98            return vec![];
99        }
100
101        (0..n_rows)
102            .map(|i| {
103                let rule_row = &rule_based[i];
104                let diff_row = &diffusion[i];
105                let n_cols = rule_row.len().min(diff_row.len());
106
107                (0..n_cols)
108                    .map(|j| {
109                        if diffusion_columns.contains(&j) {
110                            diff_row[j]
111                        } else {
112                            rule_row[j]
113                        }
114                    })
115                    .collect()
116            })
117            .collect()
118    }
119
120    /// Interpolate: weighted average of each element.
121    fn blend_interpolate(
122        &self,
123        rule_based: &[Vec<f64>],
124        diffusion: &[Vec<f64>],
125        n_rows: usize,
126    ) -> Vec<Vec<f64>> {
127        let w = self.weight;
128        (0..n_rows)
129            .map(|i| {
130                let rule_row = &rule_based[i];
131                let diff_row = &diffusion[i];
132                let n_cols = rule_row.len().min(diff_row.len());
133                (0..n_cols)
134                    .map(|j| (1.0 - w) * rule_row[j] + w * diff_row[j])
135                    .collect()
136            })
137            .collect()
138    }
139
140    /// Select: randomly choose source per row.
141    fn blend_select(
142        &self,
143        rule_based: &[Vec<f64>],
144        diffusion: &[Vec<f64>],
145        n_rows: usize,
146        seed: u64,
147    ) -> Vec<Vec<f64>> {
148        let mut rng = ChaCha8Rng::seed_from_u64(seed);
149        let uniform = Uniform::new(0.0_f64, 1.0);
150
151        (0..n_rows)
152            .map(|i| {
153                let roll: f64 = uniform.sample(&mut rng);
154                if roll < self.weight {
155                    diffusion[i].clone()
156                } else {
157                    rule_based[i].clone()
158                }
159            })
160            .collect()
161    }
162}
163
#[cfg(test)]
// Tests may unwrap freely: a panic is exactly how a test reports failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_interpolation_produces_blended_output() {
        let gen = HybridGenerator::new(0.5);
        let rules = vec![vec![10.0, 20.0], vec![30.0, 40.0]];
        let diffusion = vec![vec![20.0, 40.0], vec![50.0, 60.0]];

        let blended = gen.blend(&rules, &diffusion, BlendStrategy::Interpolate, 0);
        assert_eq!(blended.len(), 2);
        // 0.5 * 10 + 0.5 * 20 = 15
        assert!((blended[0][0] - 15.0).abs() < 1e-10);
        assert!((blended[0][1] - 30.0).abs() < 1e-10);
        assert!((blended[1][0] - 40.0).abs() < 1e-10);
        assert!((blended[1][1] - 50.0).abs() < 1e-10);
    }

    #[test]
    fn test_select_picks_from_both_sources() {
        let gen = HybridGenerator::new(0.5);
        let rules = vec![vec![0.0]; 1000];
        let diffusion = vec![vec![1.0]; 1000];

        let blended = gen.blend(&rules, &diffusion, BlendStrategy::Select, 42);
        assert_eq!(blended.len(), 1000);

        let count_diffusion = blended.iter().filter(|r| r[0] > 0.5).count();
        let count_rule = blended.iter().filter(|r| r[0] < 0.5).count();

        // Both sources should be represented
        assert!(
            count_diffusion > 100,
            "Expected diffusion picks, got {}",
            count_diffusion
        );
        assert!(
            count_rule > 100,
            "Expected rule-based picks, got {}",
            count_rule
        );
    }

    #[test]
    fn test_select_is_deterministic_for_same_seed() {
        let generator = HybridGenerator::new(0.5);
        let rules: Vec<Vec<f64>> = (0..64_i32).map(|i| vec![f64::from(i)]).collect();
        let diffusion: Vec<Vec<f64>> = (0..64_i32).map(|i| vec![-f64::from(i) - 1.0]).collect();

        let first = generator.blend(&rules, &diffusion, BlendStrategy::Select, 7);
        let second = generator.blend(&rules, &diffusion, BlendStrategy::Select, 7);
        assert_eq!(first, second, "same seed must reproduce the same selection");
    }

    #[test]
    fn test_select_extreme_weights_pick_single_source() {
        let rules = vec![vec![0.0]; 100];
        let diffusion = vec![vec![1.0]; 100];

        let all_rule =
            HybridGenerator::new(0.0).blend(&rules, &diffusion, BlendStrategy::Select, 3);
        assert_eq!(all_rule, rules, "weight=0 must never select diffusion rows");

        let all_diffusion =
            HybridGenerator::new(1.0).blend(&rules, &diffusion, BlendStrategy::Select, 3);
        assert_eq!(all_diffusion, diffusion, "weight=1 must always select diffusion rows");
    }

    #[test]
    fn test_ensemble_uses_correct_columns() {
        let gen = HybridGenerator::new(0.5);
        let rules = vec![vec![1.0, 2.0, 3.0]];
        let diffusion = vec![vec![10.0, 20.0, 30.0]];
        let diffusion_cols = vec![1]; // Only column 1 from diffusion

        let blended = gen.blend_ensemble(&rules, &diffusion, &diffusion_cols);
        assert_eq!(blended.len(), 1);
        assert!(
            (blended[0][0] - 1.0).abs() < 1e-10,
            "Column 0 should be rule-based"
        );
        assert!(
            (blended[0][1] - 20.0).abs() < 1e-10,
            "Column 1 should be diffusion"
        );
        assert!(
            (blended[0][2] - 3.0).abs() < 1e-10,
            "Column 2 should be rule-based"
        );
    }

    #[test]
    fn test_ensemble_ignores_out_of_range_columns() {
        let generator = HybridGenerator::new(0.5);
        let rules = vec![vec![1.0, 2.0]];
        let diffusion = vec![vec![10.0, 20.0]];

        let blended = generator.blend_ensemble(&rules, &diffusion, &[5, usize::MAX]);
        assert_eq!(blended, rules);
    }

    #[test]
    fn test_ensemble_strategy_falls_back_to_interpolation() {
        let generator = HybridGenerator::new(0.25);
        let rules = vec![vec![4.0, 8.0]];
        let diffusion = vec![vec![8.0, 16.0]];

        let ensemble = generator.blend(&rules, &diffusion, BlendStrategy::Ensemble, 0);
        let interpolate = generator.blend(&rules, &diffusion, BlendStrategy::Interpolate, 0);
        assert_eq!(ensemble, interpolate);
    }

    #[test]
    fn test_interpolate_truncates_mismatched_shapes() {
        let generator = HybridGenerator::new(0.5);
        let rules = vec![vec![2.0, 4.0, 6.0], vec![8.0], vec![100.0]];
        let diffusion = vec![vec![4.0, 8.0], vec![10.0, 12.0]];

        let blended = generator.blend(&rules, &diffusion, BlendStrategy::Interpolate, 0);
        // Extra rows and columns of the longer input are dropped.
        assert_eq!(blended, vec![vec![3.0, 6.0], vec![9.0]]);
    }

    #[test]
    fn test_weight_zero_returns_rule_based() {
        let gen = HybridGenerator::new(0.0);
        let rules = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let diffusion = vec![vec![10.0, 20.0], vec![30.0, 40.0]];

        let blended = gen.blend(&rules, &diffusion, BlendStrategy::Interpolate, 0);
        for (rule_row, blend_row) in rules.iter().zip(blended.iter()) {
            for (&r, &b) in rule_row.iter().zip(blend_row.iter()) {
                assert!(
                    (r - b).abs() < 1e-10,
                    "weight=0 should return rule-based: {} vs {}",
                    r,
                    b
                );
            }
        }
    }

    #[test]
    fn test_weight_one_returns_diffusion() {
        let gen = HybridGenerator::new(1.0);
        let rules = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let diffusion = vec![vec![10.0, 20.0], vec![30.0, 40.0]];

        let blended = gen.blend(&rules, &diffusion, BlendStrategy::Interpolate, 0);
        for (diff_row, blend_row) in diffusion.iter().zip(blended.iter()) {
            for (&d, &b) in diff_row.iter().zip(blend_row.iter()) {
                assert!(
                    (d - b).abs() < 1e-10,
                    "weight=1 should return diffusion: {} vs {}",
                    d,
                    b
                );
            }
        }
    }

    #[test]
    fn test_empty_inputs() {
        let gen = HybridGenerator::new(0.5);
        let empty: Vec<Vec<f64>> = vec![];

        let result = gen.blend(&empty, &empty, BlendStrategy::Interpolate, 0);
        assert!(result.is_empty());

        let result = gen.blend_ensemble(&empty, &empty, &[0]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_weight_clamping() {
        let gen_low = HybridGenerator::new(-0.5);
        assert!((gen_low.weight() - 0.0).abs() < 1e-10);

        let gen_high = HybridGenerator::new(1.5);
        assert!((gen_high.weight() - 1.0).abs() < 1e-10);
    }
}