datasynth_core/diffusion/
statistical.rs1use rand::SeedableRng;
8use rand_chacha::ChaCha8Rng;
9use rand_distr::{Distribution, StandardNormal};
10
11use super::backend::{DiffusionBackend, DiffusionConfig};
12use super::schedule::NoiseSchedule;
13
14#[derive(Debug, Clone)]
20pub struct StatisticalDiffusionBackend {
21 means: Vec<f64>,
23 stds: Vec<f64>,
25 correlations: Option<Vec<Vec<f64>>>,
27 config: DiffusionConfig,
29 schedule: NoiseSchedule,
31}
32
33impl StatisticalDiffusionBackend {
34 pub fn new(means: Vec<f64>, stds: Vec<f64>, config: DiffusionConfig) -> Self {
41 let schedule = config.build_schedule();
42 Self {
43 means,
44 stds,
45 correlations: None,
46 config,
47 schedule,
48 }
49 }
50
51 pub fn with_correlations(mut self, corr_matrix: Vec<Vec<f64>>) -> Self {
57 self.correlations = Some(corr_matrix);
58 self
59 }
60
61 fn cholesky_decomposition(matrix: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
66 let n = matrix.len();
67 if n == 0 {
68 return Some(vec![]);
69 }
70
71 let mut l = vec![vec![0.0; n]; n];
72
73 for i in 0..n {
74 for j in 0..=i {
75 let sum: f64 = l[i]
76 .iter()
77 .zip(l[j].iter())
78 .take(j)
79 .map(|(a, b)| a * b)
80 .sum();
81
82 if i == j {
83 let diag = matrix[i][i] - sum;
84 if diag <= 0.0 {
85 return None;
86 }
87 l[i][j] = diag.sqrt();
88 } else {
89 if l[j][j].abs() < 1e-15 {
90 return None;
91 }
92 l[i][j] = (matrix[i][j] - sum) / l[j][j];
93 }
94 }
95 }
96
97 Some(l)
98 }
99
100 fn apply_correlation(samples: &mut [Vec<f64>], cholesky_l: &[Vec<f64>]) {
105 let n_features = cholesky_l.len();
106 for row in samples.iter_mut() {
107 let original: Vec<f64> = row.iter().copied().take(n_features).collect();
108 for i in 0..n_features.min(row.len()) {
109 let mut val = 0.0;
110 for j in 0..=i {
111 if j < original.len() {
112 val += cholesky_l[i][j] * original[j];
113 }
114 }
115 row[i] = val;
116 }
117 }
118 }
119}
120
121impl DiffusionBackend for StatisticalDiffusionBackend {
122 fn name(&self) -> &str {
123 "statistical"
124 }
125
126 fn forward(&self, x: &[Vec<f64>], t: usize) -> Vec<Vec<f64>> {
130 let t_clamped = t.min(self.schedule.n_steps().saturating_sub(1));
131 let sqrt_alpha_bar = self.schedule.sqrt_alpha_bars[t_clamped];
132 let sqrt_one_minus_alpha_bar = self.schedule.sqrt_one_minus_alpha_bars[t_clamped];
133
134 let n_features = x.first().map_or(0, std::vec::Vec::len);
135 let noise =
136 super::generate_noise(x.len(), n_features, self.config.seed.wrapping_add(t as u64));
137
138 x.iter()
139 .zip(noise.iter())
140 .map(|(row, noise_row)| {
141 row.iter()
142 .zip(noise_row.iter())
143 .map(|(&xi, &ni)| sqrt_alpha_bar * xi + sqrt_one_minus_alpha_bar * ni)
144 .collect()
145 })
146 .collect()
147 }
148
149 fn reverse(&self, x_t: &[Vec<f64>], t: usize) -> Vec<Vec<f64>> {
153 let t_clamped = t.min(self.schedule.n_steps().saturating_sub(1));
154 let beta_t = self.schedule.betas[t_clamped];
155
156 let step_size = beta_t;
158 let noise_scale = if t_clamped > 0 { beta_t.sqrt() } else { 0.0 };
160
161 let n_features = x_t.first().map_or(0, std::vec::Vec::len);
162 let noise = super::generate_noise(
163 x_t.len(),
164 n_features,
165 self.config
166 .seed
167 .wrapping_add(t as u64)
168 .wrapping_add(1_000_000),
169 );
170
171 x_t.iter()
172 .zip(noise.iter())
173 .map(|(row, noise_row)| {
174 row.iter()
175 .enumerate()
176 .map(|(j, &x_val)| {
177 let mu = if j < self.means.len() {
178 self.means[j]
179 } else {
180 0.0
181 };
182 let sigma = if j < self.stds.len() {
183 self.stds[j].max(1e-8)
184 } else {
185 1.0
186 };
187 let n = if j < noise_row.len() {
188 noise_row[j]
189 } else {
190 0.0
191 };
192
193 let drift = step_size * (x_val - mu) / (sigma * sigma);
195 x_val - drift + noise_scale * n
196 })
197 .collect()
198 })
199 .collect()
200 }
201
202 fn generate(&self, n_samples: usize, n_features: usize, seed: u64) -> Vec<Vec<f64>> {
211 if n_samples == 0 || n_features == 0 {
212 return vec![];
213 }
214
215 let mut rng = ChaCha8Rng::seed_from_u64(seed);
216
217 let normal = StandardNormal;
219 let mut samples: Vec<Vec<f64>> = (0..n_samples)
220 .map(|_| (0..n_features).map(|_| normal.sample(&mut rng)).collect())
221 .collect();
222
223 let n_steps = self.schedule.n_steps();
228 for t in (0..n_steps).rev() {
229 let beta_t = self.schedule.betas[t];
230 let alpha_t = self.schedule.alphas[t];
231 let alpha_bar_t = self.schedule.alpha_bars[t];
232
233 let alpha_bar_prev = if t > 0 {
235 self.schedule.alpha_bars[t - 1]
236 } else {
237 1.0
238 };
239
240 let blend = (alpha_bar_prev - alpha_bar_t).max(0.0) / (1.0 - alpha_bar_t).max(1e-12);
243 let blend = blend.clamp(0.0, 1.0);
244
245 let noise_scale = if t > 0 { beta_t.sqrt() * 0.5 } else { 0.0 };
246
247 for row in samples.iter_mut() {
248 for (j, x_val) in row.iter_mut().enumerate().take(n_features) {
249 let mu = if j < self.means.len() {
250 self.means[j]
251 } else {
252 0.0
253 };
254 let sigma = if j < self.stds.len() {
255 self.stds[j].max(1e-8)
256 } else {
257 1.0
258 };
259
260 let z: f64 = normal.sample(&mut rng);
262 let target_val = mu + sigma * z;
263
264 let denoised = (1.0 - blend) * *x_val + blend * target_val;
266
267 let noise_val: f64 = if t > 0 { normal.sample(&mut rng) } else { 0.0 };
269
270 let correction = beta_t / (2.0 * alpha_t.max(1e-12)) * (*x_val - mu)
272 / (sigma * sigma).max(1e-12);
273 *x_val = denoised - correction + noise_scale * noise_val;
274 }
275 }
276 }
277
278 if let Some(ref corr_matrix) = self.correlations {
280 if let Some(cholesky_l) = Self::cholesky_decomposition(corr_matrix) {
281 let mut standardized: Vec<Vec<f64>> = samples
283 .iter()
284 .map(|row| {
285 row.iter()
286 .enumerate()
287 .map(|(j, &val)| {
288 let mu = if j < self.means.len() {
289 self.means[j]
290 } else {
291 0.0
292 };
293 let sigma = if j < self.stds.len() {
294 self.stds[j].max(1e-8)
295 } else {
296 1.0
297 };
298 (val - mu) / sigma
299 })
300 .collect()
301 })
302 .collect();
303
304 Self::apply_correlation(&mut standardized, &cholesky_l);
306
307 samples = standardized
309 .iter()
310 .map(|row| {
311 row.iter()
312 .enumerate()
313 .map(|(j, &val)| {
314 let mu = if j < self.means.len() {
315 self.means[j]
316 } else {
317 0.0
318 };
319 let sigma = if j < self.stds.len() {
320 self.stds[j].max(1e-8)
321 } else {
322 1.0
323 };
324 val * sigma + mu
325 })
326 .collect()
327 })
328 .collect();
329 }
330 }
331
332 for row in samples.iter_mut() {
334 for (j, val) in row.iter_mut().enumerate() {
335 let mu = if j < self.means.len() {
336 self.means[j]
337 } else {
338 0.0
339 };
340 let sigma = if j < self.stds.len() {
341 self.stds[j]
342 } else {
343 1.0
344 };
345 let lo = mu - 4.0 * sigma;
346 let hi = mu + 4.0 * sigma;
347 *val = val.clamp(lo, hi);
348 }
349 }
350
351 samples
352 }
353}
354
355#[cfg(test)]
356mod tests {
357 use super::*;
358
359 fn make_config(n_steps: usize, seed: u64) -> DiffusionConfig {
360 DiffusionConfig {
361 n_steps,
362 schedule: super::super::NoiseScheduleType::Linear,
363 seed,
364 }
365 }
366
367 #[test]
368 fn test_output_dimensions() {
369 let means = vec![100.0, 200.0, 300.0];
370 let stds = vec![10.0, 20.0, 30.0];
371 let backend = StatisticalDiffusionBackend::new(means, stds, make_config(50, 42));
372
373 let samples = backend.generate(500, 3, 42);
374 assert_eq!(samples.len(), 500);
375 for row in &samples {
376 assert_eq!(row.len(), 3);
377 }
378 }
379
380 #[test]
381 fn test_deterministic_with_same_seed() {
382 let means = vec![50.0, 100.0];
383 let stds = vec![5.0, 10.0];
384 let backend = StatisticalDiffusionBackend::new(means, stds, make_config(50, 99));
385
386 let samples1 = backend.generate(100, 2, 123);
387 let samples2 = backend.generate(100, 2, 123);
388
389 for (row1, row2) in samples1.iter().zip(samples2.iter()) {
390 for (&v1, &v2) in row1.iter().zip(row2.iter()) {
391 assert!(
392 (v1 - v2).abs() < 1e-12,
393 "Determinism failed: {} vs {}",
394 v1,
395 v2
396 );
397 }
398 }
399 }
400
401 #[test]
402 fn test_mean_within_tolerance() {
403 let target_means = vec![100.0, 0.0, -50.0];
404 let target_stds = vec![10.0, 5.0, 20.0];
405 let backend = StatisticalDiffusionBackend::new(
406 target_means.clone(),
407 target_stds.clone(),
408 make_config(100, 42),
409 );
410
411 let samples = backend.generate(5000, 3, 42);
412
413 for feat in 0..3 {
415 let sample_mean: f64 =
416 samples.iter().map(|r| r[feat]).sum::<f64>() / samples.len() as f64;
417 let tolerance = target_stds[feat]; assert!(
419 (sample_mean - target_means[feat]).abs() < tolerance,
420 "Feature {} mean {} is more than 1 std ({}) from target {}",
421 feat,
422 sample_mean,
423 tolerance,
424 target_means[feat]
425 );
426 }
427 }
428
429 #[test]
430 fn test_forward_adds_noise() {
431 let means = vec![100.0, 200.0];
432 let stds = vec![10.0, 20.0];
433 let backend = StatisticalDiffusionBackend::new(means, stds, make_config(100, 42));
434
435 let original = vec![vec![100.0, 200.0]; 100];
436
437 let noised_early = backend.forward(&original, 5);
439 let dist_early: f64 = noised_early
440 .iter()
441 .zip(original.iter())
442 .map(|(n, o)| {
443 n.iter()
444 .zip(o.iter())
445 .map(|(a, b)| (a - b).powi(2))
446 .sum::<f64>()
447 })
448 .sum::<f64>()
449 .sqrt();
450
451 let noised_late = backend.forward(&original, 90);
453 let dist_late: f64 = noised_late
454 .iter()
455 .zip(original.iter())
456 .map(|(n, o)| {
457 n.iter()
458 .zip(o.iter())
459 .map(|(a, b)| (a - b).powi(2))
460 .sum::<f64>()
461 })
462 .sum::<f64>()
463 .sqrt();
464
465 assert!(
466 dist_late > dist_early,
467 "Later timestep should add more noise: early={}, late={}",
468 dist_early,
469 dist_late
470 );
471 }
472
473 #[test]
474 fn test_correlation_structure_preserved() {
475 let means = vec![0.0, 0.0];
476 let stds = vec![1.0, 1.0];
477 let corr = vec![vec![1.0, 0.9], vec![0.9, 1.0]];
479
480 let backend = StatisticalDiffusionBackend::new(means, stds, make_config(100, 42))
481 .with_correlations(corr);
482
483 let samples = backend.generate(5000, 2, 42);
484
485 let n = samples.len() as f64;
487 let mean0: f64 = samples.iter().map(|r| r[0]).sum::<f64>() / n;
488 let mean1: f64 = samples.iter().map(|r| r[1]).sum::<f64>() / n;
489 let std0: f64 = (samples.iter().map(|r| (r[0] - mean0).powi(2)).sum::<f64>() / n).sqrt();
490 let std1: f64 = (samples.iter().map(|r| (r[1] - mean1).powi(2)).sum::<f64>() / n).sqrt();
491 let cov01: f64 = samples
492 .iter()
493 .map(|r| (r[0] - mean0) * (r[1] - mean1))
494 .sum::<f64>()
495 / n;
496
497 let sample_corr = if std0 > 1e-8 && std1 > 1e-8 {
498 cov01 / (std0 * std1)
499 } else {
500 0.0
501 };
502
503 assert!(
505 sample_corr > 0.5,
506 "Expected positive correlation (target 0.9), got {}",
507 sample_corr
508 );
509 }
510
511 #[test]
512 fn test_cholesky_identity() {
513 let identity = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
514 let l = StatisticalDiffusionBackend::cholesky_decomposition(&identity);
515 assert!(l.is_some());
516 let l = l.unwrap();
517 assert!((l[0][0] - 1.0).abs() < 1e-10);
518 assert!((l[1][1] - 1.0).abs() < 1e-10);
519 assert!(l[0][1].abs() < 1e-10);
520 assert!(l[1][0].abs() < 1e-10);
521 }
522
523 #[test]
524 fn test_cholesky_non_positive_definite() {
525 let matrix = vec![vec![1.0, 2.0], vec![2.0, 1.0]];
527 let l = StatisticalDiffusionBackend::cholesky_decomposition(&matrix);
528 assert!(l.is_none());
529 }
530
531 #[test]
532 fn test_generate_empty() {
533 let backend = StatisticalDiffusionBackend::new(vec![], vec![], make_config(10, 0));
534 let samples = backend.generate(0, 0, 0);
535 assert!(samples.is_empty());
536 }
537
538 #[test]
539 fn test_values_clipped_to_range() {
540 let means = vec![0.0];
541 let stds = vec![1.0];
542 let backend = StatisticalDiffusionBackend::new(means, stds, make_config(50, 42));
543
544 let samples = backend.generate(1000, 1, 42);
545 for row in &samples {
546 assert!(
547 row[0] >= -4.0 && row[0] <= 4.0,
548 "Value {} out of clipping range [-4, 4]",
549 row[0]
550 );
551 }
552 }
553}