datasynth_core/diffusion/
hybrid.rs1use rand::SeedableRng;
9use rand_chacha::ChaCha8Rng;
10use rand_distr::{Distribution, Uniform};
11
/// Strategy used by [`HybridGenerator::blend`] to combine rule-based rows
/// with diffusion-generated rows.
///
/// The enum is a plain tag with no payload, so it is `Copy` and can be used
/// as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BlendStrategy {
    /// Element-wise weighted average of both sources:
    /// `(1 - w) * rule + w * diffusion`, where `w` is the generator weight.
    Interpolate,
    /// Per-row random choice: each output row is copied whole from one of the
    /// two sources, picking the diffusion row when a uniform draw falls below
    /// the generator weight.
    Select,
    /// Column-wise mixing of the two sources.
    ///
    /// [`HybridGenerator::blend`] receives no column list, so it handles this
    /// variant the same way as [`BlendStrategy::Interpolate`]; use
    /// [`HybridGenerator::blend_ensemble`] to name the diffusion columns.
    Ensemble,
}
22
/// Blends rule-based rows with diffusion-generated rows.
///
/// Both inputs are row-major tables (`&[Vec<f64>]`). How they are combined is
/// chosen per call through a `BlendStrategy`.
#[derive(Debug, Clone, PartialEq)]
pub struct HybridGenerator {
    /// Share given to the diffusion source: `0.0` keeps only rule-based
    /// values, `1.0` keeps only diffusion values. The constructor clamps it
    /// into `[0.0, 1.0]`.
    weight: f64,
}
34
35impl HybridGenerator {
36 pub fn new(weight: f64) -> Self {
40 Self {
41 weight: weight.clamp(0.0, 1.0),
42 }
43 }
44
45 pub fn weight(&self) -> f64 {
47 self.weight
48 }
49
50 pub fn blend(
62 &self,
63 rule_based: &[Vec<f64>],
64 diffusion: &[Vec<f64>],
65 strategy: BlendStrategy,
66 seed: u64,
67 ) -> Vec<Vec<f64>> {
68 let n_rows = rule_based.len().min(diffusion.len());
69 if n_rows == 0 {
70 return vec![];
71 }
72
73 match strategy {
74 BlendStrategy::Interpolate => self.blend_interpolate(rule_based, diffusion, n_rows),
75 BlendStrategy::Select => self.blend_select(rule_based, diffusion, n_rows, seed),
76 BlendStrategy::Ensemble => {
77 self.blend_interpolate(rule_based, diffusion, n_rows)
79 }
80 }
81 }
82
83 pub fn blend_ensemble(
91 &self,
92 rule_based: &[Vec<f64>],
93 diffusion: &[Vec<f64>],
94 diffusion_columns: &[usize],
95 ) -> Vec<Vec<f64>> {
96 let n_rows = rule_based.len().min(diffusion.len());
97 if n_rows == 0 {
98 return vec![];
99 }
100
101 (0..n_rows)
102 .map(|i| {
103 let rule_row = &rule_based[i];
104 let diff_row = &diffusion[i];
105 let n_cols = rule_row.len().min(diff_row.len());
106
107 (0..n_cols)
108 .map(|j| {
109 if diffusion_columns.contains(&j) {
110 diff_row[j]
111 } else {
112 rule_row[j]
113 }
114 })
115 .collect()
116 })
117 .collect()
118 }
119
120 fn blend_interpolate(
122 &self,
123 rule_based: &[Vec<f64>],
124 diffusion: &[Vec<f64>],
125 n_rows: usize,
126 ) -> Vec<Vec<f64>> {
127 let w = self.weight;
128 (0..n_rows)
129 .map(|i| {
130 let rule_row = &rule_based[i];
131 let diff_row = &diffusion[i];
132 let n_cols = rule_row.len().min(diff_row.len());
133 (0..n_cols)
134 .map(|j| (1.0 - w) * rule_row[j] + w * diff_row[j])
135 .collect()
136 })
137 .collect()
138 }
139
140 fn blend_select(
142 &self,
143 rule_based: &[Vec<f64>],
144 diffusion: &[Vec<f64>],
145 n_rows: usize,
146 seed: u64,
147 ) -> Vec<Vec<f64>> {
148 let mut rng = ChaCha8Rng::seed_from_u64(seed);
149 let uniform = Uniform::new(0.0_f64, 1.0).expect("valid uniform params");
150
151 (0..n_rows)
152 .map(|i| {
153 let roll: f64 = uniform.sample(&mut rng);
154 if roll < self.weight {
155 diffusion[i].clone()
156 } else {
157 rule_based[i].clone()
158 }
159 })
160 .collect()
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by every floating-point comparison below.
    const TOLERANCE: f64 = 1e-10;

    /// Returns `true` when `actual` lies within [`TOLERANCE`] of `expected`.
    fn approx_eq(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    /// A weight of 0.5 must yield the midpoint of every pair of cells.
    #[test]
    fn test_interpolation_produces_blended_output() {
        let rule_rows = vec![vec![10.0, 20.0], vec![30.0, 40.0]];
        let diffusion_rows = vec![vec![20.0, 40.0], vec![50.0, 60.0]];

        let output = HybridGenerator::new(0.5).blend(
            &rule_rows,
            &diffusion_rows,
            BlendStrategy::Interpolate,
            0,
        );

        assert_eq!(output.len(), 2);
        let expected = [[15.0, 30.0], [40.0, 50.0]];
        for (i, expected_row) in expected.iter().enumerate() {
            for (j, &want) in expected_row.iter().enumerate() {
                assert!(approx_eq(output[i][j], want));
            }
        }
    }

    /// With a weight of 0.5 both sources must contribute a fair number of rows.
    #[test]
    fn test_select_picks_from_both_sources() {
        let rule_rows = vec![vec![0.0]; 1000];
        let diffusion_rows = vec![vec![1.0]; 1000];

        let output =
            HybridGenerator::new(0.5).blend(&rule_rows, &diffusion_rows, BlendStrategy::Select, 42);
        assert_eq!(output.len(), 1000);

        // Rule rows hold 0.0 and diffusion rows hold 1.0, so 0.5 separates them.
        let count_diffusion = output
            .iter()
            .map(|row| row[0])
            .filter(|&value| value > 0.5)
            .count();
        let count_rule = output
            .iter()
            .map(|row| row[0])
            .filter(|&value| value < 0.5)
            .count();

        assert!(
            count_diffusion > 100,
            "Expected diffusion picks, got {}",
            count_diffusion
        );
        assert!(
            count_rule > 100,
            "Expected rule-based picks, got {}",
            count_rule
        );
    }

    /// Only the listed column may come from the diffusion source.
    #[test]
    fn test_ensemble_uses_correct_columns() {
        let rule_rows = vec![vec![1.0, 2.0, 3.0]];
        let diffusion_rows = vec![vec![10.0, 20.0, 30.0]];
        let diffusion_cols: Vec<usize> = vec![1];

        let output =
            HybridGenerator::new(0.5).blend_ensemble(&rule_rows, &diffusion_rows, &diffusion_cols);
        assert_eq!(output.len(), 1);

        let checks = [
            (0_usize, 1.0, "Column 0 should be rule-based"),
            (1_usize, 20.0, "Column 1 should be diffusion"),
            (2_usize, 3.0, "Column 2 should be rule-based"),
        ];
        for &(column, want, message) in &checks {
            assert!(approx_eq(output[0][column], want), "{}", message);
        }
    }

    /// A weight of 0 must reproduce the rule-based table.
    #[test]
    fn test_weight_zero_returns_rule_based() {
        let rule_rows = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let diffusion_rows = vec![vec![10.0, 20.0], vec![30.0, 40.0]];

        let output = HybridGenerator::new(0.0).blend(
            &rule_rows,
            &diffusion_rows,
            BlendStrategy::Interpolate,
            0,
        );

        rule_rows.iter().zip(output.iter()).for_each(|(rule_row, out_row)| {
            rule_row.iter().zip(out_row.iter()).for_each(|(&r, &b)| {
                assert!(
                    approx_eq(r, b),
                    "weight=0 should return rule-based: {} vs {}",
                    r,
                    b
                );
            });
        });
    }

    /// A weight of 1 must reproduce the diffusion table.
    #[test]
    fn test_weight_one_returns_diffusion() {
        let rule_rows = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let diffusion_rows = vec![vec![10.0, 20.0], vec![30.0, 40.0]];

        let output = HybridGenerator::new(1.0).blend(
            &rule_rows,
            &diffusion_rows,
            BlendStrategy::Interpolate,
            0,
        );

        diffusion_rows.iter().zip(output.iter()).for_each(|(diff_row, out_row)| {
            diff_row.iter().zip(out_row.iter()).for_each(|(&d, &b)| {
                assert!(
                    approx_eq(d, b),
                    "weight=1 should return diffusion: {} vs {}",
                    d,
                    b
                );
            });
        });
    }

    /// Empty tables must produce empty output from both public entry points.
    #[test]
    fn test_empty_inputs() {
        let blender = HybridGenerator::new(0.5);
        let no_rows: Vec<Vec<f64>> = Vec::new();

        assert!(blender
            .blend(&no_rows, &no_rows, BlendStrategy::Interpolate, 0)
            .is_empty());
        assert!(blender.blend_ensemble(&no_rows, &no_rows, &[0]).is_empty());
    }

    /// Out-of-range weights are pulled back to the nearest bound.
    #[test]
    fn test_weight_clamping() {
        assert!(approx_eq(HybridGenerator::new(-0.5).weight(), 0.0));
        assert!(approx_eq(HybridGenerator::new(1.5).weight(), 1.0));
    }
}