Skip to main content

datasynth_core/diffusion/
hybrid.rs

1//! Hybrid generator that blends rule-based and diffusion-generated data.
2//!
3//! Supports multiple blending strategies:
4//! - **Interpolate**: weighted average of corresponding samples
5//! - **Select**: randomly pick source per record based on weight
6//! - **Ensemble**: use diffusion for specified columns, rules for others
7
8use rand::SeedableRng;
9use rand_chacha::ChaCha8Rng;
10use rand_distr::{Distribution, Uniform};
11
/// Strategy for blending rule-based and diffusion-generated data.
///
/// This is a field-less tag, so besides being `Copy` it derives full `Eq` and `Hash`,
/// which lets callers use it as a `HashMap`/`HashSet` key (e.g. per-strategy statistics).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BlendStrategy {
    /// Element-wise weighted average of corresponding samples:
    /// `output = (1 - weight) * rule + weight * diffusion`.
    Interpolate,
    /// Per-record random selection: each output row is taken entirely from one source.
    Select,
    /// Column-level blending: the specified columns come from diffusion, all others from
    /// the rule-based data.
    Ensemble,
}
22
/// Blends two numeric datasets: one produced by rule-based generation and one produced
/// by a diffusion model.
///
/// A single `weight` steers how much each source contributes:
/// - `0.0` yields the rule-based data unchanged
/// - `1.0` yields the diffusion data unchanged
/// - anything in between mixes the two, in a way that depends on the blend strategy
#[derive(Debug, Clone)]
pub struct HybridGenerator {
    /// Share of the diffusion source in the output, from 0.0 (none) to 1.0 (all).
    weight: f64,
}
34
35impl HybridGenerator {
36    /// Create a new hybrid generator with the given weight.
37    ///
38    /// The weight is clamped to [0.0, 1.0].
39    pub fn new(weight: f64) -> Self {
40        Self {
41            weight: weight.clamp(0.0, 1.0),
42        }
43    }
44
45    /// Get the current blending weight.
46    pub fn weight(&self) -> f64 {
47        self.weight
48    }
49
50    /// Blend rule-based and diffusion-generated data using the specified strategy.
51    ///
52    /// Both input slices must have the same number of rows and columns.
53    /// For `Ensemble` strategy, all columns will use the weight for interpolation;
54    /// use `blend_ensemble` for column-level control.
55    ///
56    /// # Arguments
57    /// * `rule_based` - Data generated by rule-based methods
58    /// * `diffusion` - Data generated by diffusion methods
59    /// * `strategy` - How to blend the two sources
60    /// * `seed` - Random seed for deterministic blending
61    pub fn blend(
62        &self,
63        rule_based: &[Vec<f64>],
64        diffusion: &[Vec<f64>],
65        strategy: BlendStrategy,
66        seed: u64,
67    ) -> Vec<Vec<f64>> {
68        let n_rows = rule_based.len().min(diffusion.len());
69        if n_rows == 0 {
70            return vec![];
71        }
72
73        match strategy {
74            BlendStrategy::Interpolate => self.blend_interpolate(rule_based, diffusion, n_rows),
75            BlendStrategy::Select => self.blend_select(rule_based, diffusion, n_rows, seed),
76            BlendStrategy::Ensemble => {
77                // Without specific column indices, fall back to interpolation
78                self.blend_interpolate(rule_based, diffusion, n_rows)
79            }
80        }
81    }
82
83    /// Blend using column-level ensemble: specified columns use diffusion data,
84    /// remaining columns use rule-based data.
85    ///
86    /// # Arguments
87    /// * `rule_based` - Data generated by rule-based methods
88    /// * `diffusion` - Data generated by diffusion methods
89    /// * `diffusion_columns` - Column indices that should use diffusion output
90    pub fn blend_ensemble(
91        &self,
92        rule_based: &[Vec<f64>],
93        diffusion: &[Vec<f64>],
94        diffusion_columns: &[usize],
95    ) -> Vec<Vec<f64>> {
96        let n_rows = rule_based.len().min(diffusion.len());
97        if n_rows == 0 {
98            return vec![];
99        }
100
101        (0..n_rows)
102            .map(|i| {
103                let rule_row = &rule_based[i];
104                let diff_row = &diffusion[i];
105                let n_cols = rule_row.len().min(diff_row.len());
106
107                (0..n_cols)
108                    .map(|j| {
109                        if diffusion_columns.contains(&j) {
110                            diff_row[j]
111                        } else {
112                            rule_row[j]
113                        }
114                    })
115                    .collect()
116            })
117            .collect()
118    }
119
120    /// Interpolate: weighted average of each element.
121    fn blend_interpolate(
122        &self,
123        rule_based: &[Vec<f64>],
124        diffusion: &[Vec<f64>],
125        n_rows: usize,
126    ) -> Vec<Vec<f64>> {
127        let w = self.weight;
128        (0..n_rows)
129            .map(|i| {
130                let rule_row = &rule_based[i];
131                let diff_row = &diffusion[i];
132                let n_cols = rule_row.len().min(diff_row.len());
133                (0..n_cols)
134                    .map(|j| (1.0 - w) * rule_row[j] + w * diff_row[j])
135                    .collect()
136            })
137            .collect()
138    }
139
140    /// Select: randomly choose source per row.
141    fn blend_select(
142        &self,
143        rule_based: &[Vec<f64>],
144        diffusion: &[Vec<f64>],
145        n_rows: usize,
146        seed: u64,
147    ) -> Vec<Vec<f64>> {
148        let mut rng = ChaCha8Rng::seed_from_u64(seed);
149        let uniform = Uniform::new(0.0_f64, 1.0).expect("valid uniform params");
150
151        (0..n_rows)
152            .map(|i| {
153                let roll: f64 = uniform.sample(&mut rng);
154                if roll < self.weight {
155                    diffusion[i].clone()
156                } else {
157                    rule_based[i].clone()
158                }
159            })
160            .collect()
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    // The generator binding is called `generator` rather than `gen` throughout:
    // `gen` is a reserved keyword starting with the Rust 2024 edition.

    /// Interpolation at weight 0.5 yields the element-wise midpoint of both sources.
    #[test]
    fn test_interpolation_produces_blended_output() {
        let generator = HybridGenerator::new(0.5);
        let rules = vec![vec![10.0, 20.0], vec![30.0, 40.0]];
        let diffusion = vec![vec![20.0, 40.0], vec![50.0, 60.0]];

        let blended = generator.blend(&rules, &diffusion, BlendStrategy::Interpolate, 0);
        assert_eq!(blended.len(), 2);
        // 0.5 * 10 + 0.5 * 20 = 15
        assert!((blended[0][0] - 15.0).abs() < 1e-10);
        assert!((blended[0][1] - 30.0).abs() < 1e-10);
        assert!((blended[1][0] - 40.0).abs() < 1e-10);
        assert!((blended[1][1] - 50.0).abs() < 1e-10);
    }

    /// Select at weight 0.5 draws a substantial number of rows from each source.
    #[test]
    fn test_select_picks_from_both_sources() {
        let generator = HybridGenerator::new(0.5);
        let rules = vec![vec![0.0]; 1000];
        let diffusion = vec![vec![1.0]; 1000];

        let blended = generator.blend(&rules, &diffusion, BlendStrategy::Select, 42);
        assert_eq!(blended.len(), 1000);

        let count_diffusion = blended.iter().filter(|r| r[0] > 0.5).count();
        let count_rule = blended.iter().filter(|r| r[0] < 0.5).count();

        // Both sources should be represented
        assert!(
            count_diffusion > 100,
            "Expected diffusion picks, got {}",
            count_diffusion
        );
        assert!(
            count_rule > 100,
            "Expected rule-based picks, got {}",
            count_rule
        );
    }

    /// The same seed must reproduce the same selection.
    #[test]
    fn test_select_is_deterministic_for_seed() {
        let generator = HybridGenerator::new(0.5);
        let rules = vec![vec![0.0]; 200];
        let diffusion = vec![vec![1.0]; 200];

        let first = generator.blend(&rules, &diffusion, BlendStrategy::Select, 7);
        let second = generator.blend(&rules, &diffusion, BlendStrategy::Select, 7);
        assert_eq!(first, second);
    }

    /// Select at the extreme weights always picks a single source.
    #[test]
    fn test_select_extreme_weights() {
        let rules = vec![vec![0.0]; 100];
        let diffusion = vec![vec![1.0]; 100];

        let only_rules = HybridGenerator::new(0.0).blend(&rules, &diffusion, BlendStrategy::Select, 3);
        assert_eq!(only_rules, rules);

        let only_diffusion =
            HybridGenerator::new(1.0).blend(&rules, &diffusion, BlendStrategy::Select, 3);
        assert_eq!(only_diffusion, diffusion);
    }

    /// `blend_ensemble` takes only the listed columns from diffusion.
    #[test]
    fn test_ensemble_uses_correct_columns() {
        let generator = HybridGenerator::new(0.5);
        let rules = vec![vec![1.0, 2.0, 3.0]];
        let diffusion = vec![vec![10.0, 20.0, 30.0]];
        let diffusion_cols = vec![1]; // Only column 1 from diffusion

        let blended = generator.blend_ensemble(&rules, &diffusion, &diffusion_cols);
        assert_eq!(blended.len(), 1);
        assert!(
            (blended[0][0] - 1.0).abs() < 1e-10,
            "Column 0 should be rule-based"
        );
        assert!(
            (blended[0][1] - 20.0).abs() < 1e-10,
            "Column 1 should be diffusion"
        );
        assert!(
            (blended[0][2] - 3.0).abs() < 1e-10,
            "Column 2 should be rule-based"
        );
    }

    /// Column indices beyond the row width are ignored rather than panicking.
    #[test]
    fn test_ensemble_ignores_out_of_range_columns() {
        let generator = HybridGenerator::new(0.5);
        let rules = vec![vec![1.0, 2.0, 3.0]];
        let diffusion = vec![vec![10.0, 20.0, 30.0]];

        let blended = generator.blend_ensemble(&rules, &diffusion, &[5]);
        assert_eq!(blended, rules);
    }

    /// Through `blend`, the Ensemble strategy has no column list and interpolates instead.
    #[test]
    fn test_ensemble_strategy_falls_back_to_interpolation() {
        let generator = HybridGenerator::new(0.25);
        let rules = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let diffusion = vec![vec![10.0, 20.0], vec![30.0, 40.0]];

        let interpolated = generator.blend(&rules, &diffusion, BlendStrategy::Interpolate, 0);
        let ensemble = generator.blend(&rules, &diffusion, BlendStrategy::Ensemble, 0);
        assert_eq!(interpolated, ensemble);
    }

    /// Mismatched inputs are truncated to the shorter row count and row width.
    #[test]
    fn test_mismatched_shapes_truncate() {
        let generator = HybridGenerator::new(0.5);
        let rules = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let diffusion = vec![vec![3.0, 4.0, 9.0], vec![5.0, 6.0, 9.0]];

        let blended = generator.blend(&rules, &diffusion, BlendStrategy::Interpolate, 0);
        assert_eq!(blended.len(), 2);
        assert!(blended.iter().all(|row| row.len() == 2));
        assert!((blended[0][0] - 2.0).abs() < 1e-10);
        assert!((blended[1][1] - 5.0).abs() < 1e-10);
    }

    /// Weight 0 interpolation reproduces the rule-based data.
    #[test]
    fn test_weight_zero_returns_rule_based() {
        let generator = HybridGenerator::new(0.0);
        let rules = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let diffusion = vec![vec![10.0, 20.0], vec![30.0, 40.0]];

        let blended = generator.blend(&rules, &diffusion, BlendStrategy::Interpolate, 0);
        for (rule_row, blend_row) in rules.iter().zip(blended.iter()) {
            for (&r, &b) in rule_row.iter().zip(blend_row.iter()) {
                assert!(
                    (r - b).abs() < 1e-10,
                    "weight=0 should return rule-based: {} vs {}",
                    r,
                    b
                );
            }
        }
    }

    /// Weight 1 interpolation reproduces the diffusion data.
    #[test]
    fn test_weight_one_returns_diffusion() {
        let generator = HybridGenerator::new(1.0);
        let rules = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let diffusion = vec![vec![10.0, 20.0], vec![30.0, 40.0]];

        let blended = generator.blend(&rules, &diffusion, BlendStrategy::Interpolate, 0);
        for (diff_row, blend_row) in diffusion.iter().zip(blended.iter()) {
            for (&d, &b) in diff_row.iter().zip(blend_row.iter()) {
                assert!(
                    (d - b).abs() < 1e-10,
                    "weight=1 should return diffusion: {} vs {}",
                    d,
                    b
                );
            }
        }
    }

    /// Empty inputs produce empty outputs for both entry points.
    #[test]
    fn test_empty_inputs() {
        let generator = HybridGenerator::new(0.5);
        let empty: Vec<Vec<f64>> = vec![];

        let result = generator.blend(&empty, &empty, BlendStrategy::Interpolate, 0);
        assert!(result.is_empty());

        let result = generator.blend_ensemble(&empty, &empty, &[0]);
        assert!(result.is_empty());
    }

    /// Out-of-range weights are clamped into `[0.0, 1.0]`.
    #[test]
    fn test_weight_clamping() {
        let generator_low = HybridGenerator::new(-0.5);
        assert!((generator_low.weight() - 0.0).abs() < 1e-10);

        let generator_high = HybridGenerator::new(1.5);
        assert!((generator_high.weight() - 1.0).abs() < 1e-10);
    }
}