// datasynth_core/diffusion/hybrid.rs
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use rand_distr::{Distribution, Uniform};
11
/// How [`HybridGenerator::blend`] combines rule-based rows with diffusion rows.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BlendStrategy {
    /// Weighted average of the two sources, cell by cell:
    /// `(1 - weight) * rule + weight * diffusion`.
    Interpolate,
    /// Per-row random pick: each output row is copied whole from one of the
    /// two sources, choosing the diffusion row when a uniform draw is below
    /// the generator's weight.
    Select,
    /// Column-wise mix of the two sources. `HybridGenerator::blend` has no
    /// column list to work with and handles this variant by interpolating;
    /// `HybridGenerator::blend_ensemble` performs the actual column selection.
    Ensemble,
}
22
/// Blends rule-based output rows with diffusion-model output rows.
///
/// Rows are `Vec<f64>` records; the two sources are combined according to a
/// single scalar weight.
#[derive(Debug, Clone)]
pub struct HybridGenerator {
    /// Share given to the diffusion source: `0.0` means rule-based only,
    /// `1.0` means diffusion only. `new` clamps it into `[0.0, 1.0]`.
    weight: f64,
}
34
35impl HybridGenerator {
36 pub fn new(weight: f64) -> Self {
40 Self {
41 weight: weight.clamp(0.0, 1.0),
42 }
43 }
44
45 pub fn weight(&self) -> f64 {
47 self.weight
48 }
49
50 pub fn blend(
62 &self,
63 rule_based: &[Vec<f64>],
64 diffusion: &[Vec<f64>],
65 strategy: BlendStrategy,
66 seed: u64,
67 ) -> Vec<Vec<f64>> {
68 let n_rows = rule_based.len().min(diffusion.len());
69 if n_rows == 0 {
70 return vec![];
71 }
72
73 match strategy {
74 BlendStrategy::Interpolate => self.blend_interpolate(rule_based, diffusion, n_rows),
75 BlendStrategy::Select => self.blend_select(rule_based, diffusion, n_rows, seed),
76 BlendStrategy::Ensemble => {
77 self.blend_interpolate(rule_based, diffusion, n_rows)
79 }
80 }
81 }
82
83 pub fn blend_ensemble(
91 &self,
92 rule_based: &[Vec<f64>],
93 diffusion: &[Vec<f64>],
94 diffusion_columns: &[usize],
95 ) -> Vec<Vec<f64>> {
96 let n_rows = rule_based.len().min(diffusion.len());
97 if n_rows == 0 {
98 return vec![];
99 }
100
101 (0..n_rows)
102 .map(|i| {
103 let rule_row = &rule_based[i];
104 let diff_row = &diffusion[i];
105 let n_cols = rule_row.len().min(diff_row.len());
106
107 (0..n_cols)
108 .map(|j| {
109 if diffusion_columns.contains(&j) {
110 diff_row[j]
111 } else {
112 rule_row[j]
113 }
114 })
115 .collect()
116 })
117 .collect()
118 }
119
120 fn blend_interpolate(
122 &self,
123 rule_based: &[Vec<f64>],
124 diffusion: &[Vec<f64>],
125 n_rows: usize,
126 ) -> Vec<Vec<f64>> {
127 let w = self.weight;
128 (0..n_rows)
129 .map(|i| {
130 let rule_row = &rule_based[i];
131 let diff_row = &diffusion[i];
132 let n_cols = rule_row.len().min(diff_row.len());
133 (0..n_cols)
134 .map(|j| (1.0 - w) * rule_row[j] + w * diff_row[j])
135 .collect()
136 })
137 .collect()
138 }
139
140 fn blend_select(
142 &self,
143 rule_based: &[Vec<f64>],
144 diffusion: &[Vec<f64>],
145 n_rows: usize,
146 seed: u64,
147 ) -> Vec<Vec<f64>> {
148 let mut rng = ChaCha8Rng::seed_from_u64(seed);
149 let uniform = Uniform::new(0.0_f64, 1.0);
150
151 (0..n_rows)
152 .map(|i| {
153 let roll: f64 = uniform.sample(&mut rng);
154 if roll < self.weight {
155 diffusion[i].clone()
156 } else {
157 rule_based[i].clone()
158 }
159 })
160 .collect()
161 }
162}
163
#[cfg(test)]
// Kept from the original; none of the tests below currently calls `unwrap`.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Absolute tolerance used for every floating-point comparison below.
    const EPS: f64 = 1e-10;

    /// Interpolation with weight 0.5 yields the midpoint of each cell pair.
    #[test]
    fn test_interpolation_produces_blended_output() {
        let generator = HybridGenerator::new(0.5);
        let rules = vec![vec![10.0, 20.0], vec![30.0, 40.0]];
        let diffusion = vec![vec![20.0, 40.0], vec![50.0, 60.0]];

        let blended = generator.blend(&rules, &diffusion, BlendStrategy::Interpolate, 0);
        assert_eq!(blended.len(), 2);
        assert!((blended[0][0] - 15.0).abs() < EPS);
        assert!((blended[0][1] - 30.0).abs() < EPS);
        assert!((blended[1][0] - 40.0).abs() < EPS);
        assert!((blended[1][1] - 50.0).abs() < EPS);
    }

    /// With weight 0.5, row selection draws a meaningful number of rows from
    /// each source.
    #[test]
    fn test_select_picks_from_both_sources() {
        let generator = HybridGenerator::new(0.5);
        let rules = vec![vec![0.0]; 1000];
        let diffusion = vec![vec![1.0]; 1000];

        let blended = generator.blend(&rules, &diffusion, BlendStrategy::Select, 42);
        assert_eq!(blended.len(), 1000);

        // Rule rows hold 0.0 and diffusion rows hold 1.0, so 0.5 separates them.
        let count_diffusion = blended.iter().filter(|r| r[0] > 0.5).count();
        let count_rule = blended.iter().filter(|r| r[0] < 0.5).count();

        assert!(
            count_diffusion > 100,
            "Expected diffusion picks, got {}",
            count_diffusion
        );
        assert!(
            count_rule > 100,
            "Expected rule-based picks, got {}",
            count_rule
        );
    }

    /// Only the listed columns come from the diffusion rows.
    #[test]
    fn test_ensemble_uses_correct_columns() {
        let generator = HybridGenerator::new(0.5);
        let rules = vec![vec![1.0, 2.0, 3.0]];
        let diffusion = vec![vec![10.0, 20.0, 30.0]];
        let diffusion_cols = vec![1];

        let blended = generator.blend_ensemble(&rules, &diffusion, &diffusion_cols);
        assert_eq!(blended.len(), 1);
        assert!(
            (blended[0][0] - 1.0).abs() < EPS,
            "Column 0 should be rule-based"
        );
        assert!(
            (blended[0][1] - 20.0).abs() < EPS,
            "Column 1 should be diffusion"
        );
        assert!(
            (blended[0][2] - 3.0).abs() < EPS,
            "Column 2 should be rule-based"
        );
    }

    /// Weight 0 reproduces the rule-based rows.
    #[test]
    fn test_weight_zero_returns_rule_based() {
        let generator = HybridGenerator::new(0.0);
        let rules = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let diffusion = vec![vec![10.0, 20.0], vec![30.0, 40.0]];

        let blended = generator.blend(&rules, &diffusion, BlendStrategy::Interpolate, 0);
        for (rule_row, blend_row) in rules.iter().zip(blended.iter()) {
            for (&r, &b) in rule_row.iter().zip(blend_row.iter()) {
                assert!(
                    (r - b).abs() < EPS,
                    "weight=0 should return rule-based: {} vs {}",
                    r,
                    b
                );
            }
        }
    }

    /// Weight 1 reproduces the diffusion rows.
    #[test]
    fn test_weight_one_returns_diffusion() {
        let generator = HybridGenerator::new(1.0);
        let rules = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let diffusion = vec![vec![10.0, 20.0], vec![30.0, 40.0]];

        let blended = generator.blend(&rules, &diffusion, BlendStrategy::Interpolate, 0);
        for (diff_row, blend_row) in diffusion.iter().zip(blended.iter()) {
            for (&d, &b) in diff_row.iter().zip(blend_row.iter()) {
                assert!(
                    (d - b).abs() < EPS,
                    "weight=1 should return diffusion: {} vs {}",
                    d,
                    b
                );
            }
        }
    }

    /// Empty inputs give empty outputs for both entry points.
    #[test]
    fn test_empty_inputs() {
        let generator = HybridGenerator::new(0.5);
        let empty: Vec<Vec<f64>> = vec![];

        let result = generator.blend(&empty, &empty, BlendStrategy::Interpolate, 0);
        assert!(result.is_empty());

        let result = generator.blend_ensemble(&empty, &empty, &[0]);
        assert!(result.is_empty());
    }

    /// Out-of-range weights are clamped to the nearest bound.
    #[test]
    fn test_weight_clamping() {
        let gen_low = HybridGenerator::new(-0.5);
        assert!((gen_low.weight() - 0.0).abs() < EPS);

        let gen_high = HybridGenerator::new(1.5);
        assert!((gen_high.weight() - 1.0).abs() < EPS);
    }
}