datasynth_core/diffusion/statistical.rs

use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use rand_distr::{Distribution, StandardNormal};

use super::backend::{DiffusionBackend, DiffusionConfig};
use super::schedule::NoiseSchedule;

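/// Diffusion backend driven by per-feature summary statistics: samples are
/// denoised toward independent Gaussian marginals `N(means[j], stds[j]^2)`,
/// with an optional correlation structure imposed via a Cholesky factor.
///
/// A minimal usage sketch; the `DiffusionConfig` fields are taken from the
/// tests below, and module paths are assumed rather than confirmed:
///
/// ```ignore
/// let config = DiffusionConfig {
///     n_steps: 50,
///     schedule: NoiseScheduleType::Linear,
///     seed: 42,
/// };
/// let backend = StatisticalDiffusionBackend::new(
///     vec![100.0, 200.0],
///     vec![10.0, 20.0],
///     config,
/// );
/// let samples = backend.generate(500, 2, 42);
/// assert_eq!(samples.len(), 500);
/// ```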
#[derive(Debug, Clone)]
pub struct StatisticalDiffusionBackend {
    /// Target mean for each feature.
    means: Vec<f64>,
    /// Target standard deviation for each feature.
    stds: Vec<f64>,
    /// Optional target correlation matrix, applied after sampling.
    correlations: Option<Vec<Vec<f64>>>,
    /// Diffusion configuration (step count, schedule type, seed).
    config: DiffusionConfig,
    /// Noise schedule precomputed from `config`.
    schedule: NoiseSchedule,
}

impl StatisticalDiffusionBackend {
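    /// Builds a backend targeting the given per-feature `means` and `stds`;
    /// the noise schedule is derived from `config`.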
    pub fn new(means: Vec<f64>, stds: Vec<f64>, config: DiffusionConfig) -> Self {
        let schedule = config.build_schedule();
        Self {
            means,
            stds,
            correlations: None,
            config,
            schedule,
        }
    }

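    /// Attaches a target correlation matrix (row-major, one `Vec<f64>` per
    /// row). A minimal builder-style sketch, mirroring the usage in the
    /// tests below:
    ///
    /// ```ignore
    /// let backend = StatisticalDiffusionBackend::new(means, stds, config)
    ///     .with_correlations(vec![vec![1.0, 0.9], vec![0.9, 1.0]]);
    /// ```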
    pub fn with_correlations(mut self, corr_matrix: Vec<Vec<f64>>) -> Self {
        self.correlations = Some(corr_matrix);
        self
    }

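    /// Cholesky decomposition: for a symmetric positive-definite `matrix` A,
    /// returns the lower-triangular L with `L * L^T = A`. Returns `None` if a
    /// pivot is non-positive (A is not positive definite) or a diagonal entry
    /// of L is numerically zero.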
    fn cholesky_decomposition(matrix: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
        let n = matrix.len();
        if n == 0 {
            return Some(vec![]);
        }

        let mut l = vec![vec![0.0; n]; n];

        for i in 0..n {
            for j in 0..=i {
                let sum: f64 = l[i]
                    .iter()
                    .zip(l[j].iter())
                    .take(j)
                    .map(|(a, b)| a * b)
                    .sum();

                if i == j {
                    let diag = matrix[i][i] - sum;
                    if diag <= 0.0 {
                        return None;
                    }
                    l[i][j] = diag.sqrt();
                } else {
                    if l[j][j].abs() < 1e-15 {
                        return None;
                    }
                    l[i][j] = (matrix[i][j] - sum) / l[j][j];
                }
            }
        }

        Some(l)
    }

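    /// Applies the Cholesky factor in place: the first `n_features` entries
    /// of each row (assumed i.i.d. standard normal) are replaced by `L * z`,
    /// which has covariance `L * L^T`, i.e. the target correlation matrix.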
    fn apply_correlation(samples: &mut [Vec<f64>], cholesky_l: &[Vec<f64>]) {
        let n_features = cholesky_l.len();
        for row in samples.iter_mut() {
            let original: Vec<f64> = row.iter().copied().take(n_features).collect();
            for i in 0..n_features.min(row.len()) {
                let mut val = 0.0;
                for j in 0..=i {
                    if j < original.len() {
                        val += cholesky_l[i][j] * original[j];
                    }
                }
                row[i] = val;
            }
        }
    }
}

impl DiffusionBackend for StatisticalDiffusionBackend {
    fn name(&self) -> &str {
        "statistical"
    }

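    /// Forward (noising) process in closed form:
    /// `x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps`,
    /// with `eps` drawn deterministically from the configured seed plus `t`.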
    fn forward(&self, x: &[Vec<f64>], t: usize) -> Vec<Vec<f64>> {
        let t_clamped = t.min(self.schedule.n_steps().saturating_sub(1));
        let sqrt_alpha_bar = self.schedule.sqrt_alpha_bars[t_clamped];
        let sqrt_one_minus_alpha_bar = self.schedule.sqrt_one_minus_alpha_bars[t_clamped];

        let n_features = x.first().map_or(0, |row| row.len());
        let noise =
            super::generate_noise(x.len(), n_features, self.config.seed.wrapping_add(t as u64));

        x.iter()
            .zip(noise.iter())
            .map(|(row, noise_row)| {
                row.iter()
                    .zip(noise_row.iter())
                    .map(|(&xi, &ni)| sqrt_alpha_bar * xi + sqrt_one_minus_alpha_bar * ni)
                    .collect()
            })
            .collect()
    }

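    /// One reverse (denoising) step. The drift is the score of the
    /// per-feature Gaussian target scaled by `beta_t`:
    /// `x_{t-1} = x_t - beta_t * (x_t - mu) / sigma^2 + sqrt(beta_t) * eps`,
    /// with no noise injected at the final step (`t = 0`).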
    fn reverse(&self, x_t: &[Vec<f64>], t: usize) -> Vec<Vec<f64>> {
        let t_clamped = t.min(self.schedule.n_steps().saturating_sub(1));
        let beta_t = self.schedule.betas[t_clamped];

        let step_size = beta_t;
        let noise_scale = if t_clamped > 0 { beta_t.sqrt() } else { 0.0 };

        let n_features = x_t.first().map_or(0, |row| row.len());
        let noise = super::generate_noise(
            x_t.len(),
            n_features,
            self.config
                .seed
                .wrapping_add(t as u64)
                .wrapping_add(1_000_000),
        );

        x_t.iter()
            .zip(noise.iter())
            .map(|(row, noise_row)| {
                row.iter()
                    .enumerate()
                    .map(|(j, &x_val)| {
                        let mu = if j < self.means.len() {
                            self.means[j]
                        } else {
                            0.0
                        };
                        let sigma = if j < self.stds.len() {
                            self.stds[j].max(1e-8)
                        } else {
                            1.0
                        };
                        let n = if j < noise_row.len() {
                            noise_row[j]
                        } else {
                            0.0
                        };

                        let drift = step_size * (x_val - mu) / (sigma * sigma);
                        x_val - drift + noise_scale * n
                    })
                    .collect()
            })
            .collect()
    }

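    /// Full sampling loop: start from `N(0, I)` noise, walk the schedule
    /// backwards while blending each value toward a fresh draw from the
    /// target marginal `N(mu_j, sigma_j^2)`, then optionally impose the
    /// configured correlations and clamp to `mu +/- 4 * sigma`.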
    fn generate(&self, n_samples: usize, n_features: usize, seed: u64) -> Vec<Vec<f64>> {
        if n_samples == 0 || n_features == 0 {
            return vec![];
        }

        let mut rng = ChaCha8Rng::seed_from_u64(seed);

        let normal = StandardNormal;

        let mut samples: Vec<Vec<f64>> = (0..n_samples)
            .map(|_| (0..n_features).map(|_| normal.sample(&mut rng)).collect())
            .collect();

        let n_steps = self.schedule.n_steps();
        for t in (0..n_steps).rev() {
            let beta_t = self.schedule.betas[t];
            let alpha_t = self.schedule.alphas[t];
            let alpha_bar_t = self.schedule.alpha_bars[t];

            let alpha_bar_prev = if t > 0 {
                self.schedule.alpha_bars[t - 1]
            } else {
                1.0
            };

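            // Blend weight: the fraction of remaining noise variance removed
            // when stepping from t to t-1, (alpha_bar_{t-1} - alpha_bar_t) /
            // (1 - alpha_bar_t), guarded against division by zero and clamped
            // to [0, 1] below.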
            let blend = (alpha_bar_prev - alpha_bar_t).max(0.0) / (1.0 - alpha_bar_t).max(1e-12);
            let blend = blend.clamp(0.0, 1.0);

            let noise_scale = if t > 0 { beta_t.sqrt() * 0.5 } else { 0.0 };

            for row in samples.iter_mut() {
                for (j, x_val) in row.iter_mut().enumerate().take(n_features) {
                    let mu = if j < self.means.len() {
                        self.means[j]
                    } else {
                        0.0
                    };
                    let sigma = if j < self.stds.len() {
                        self.stds[j].max(1e-8)
                    } else {
                        1.0
                    };

                    let z: f64 = normal.sample(&mut rng);
                    let target_val = mu + sigma * z;

                    let denoised = (1.0 - blend) * *x_val + blend * target_val;

                    let noise_val: f64 = if t > 0 { normal.sample(&mut rng) } else { 0.0 };

                    let correction = beta_t / (2.0 * alpha_t.max(1e-12)) * (*x_val - mu)
                        / (sigma * sigma).max(1e-12);
                    *x_val = denoised - correction + noise_scale * noise_val;
                }
            }
        }

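        // Impose the configured correlation structure, if any: standardize
        // each feature, rotate the rows by the Cholesky factor `L` (giving
        // covariance `L * L^T`), then map back to the target mean and scale.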
        if let Some(ref corr_matrix) = self.correlations {
            if let Some(cholesky_l) = Self::cholesky_decomposition(corr_matrix) {
                let mut standardized: Vec<Vec<f64>> = samples
                    .iter()
                    .map(|row| {
                        row.iter()
                            .enumerate()
                            .map(|(j, &val)| {
                                let mu = if j < self.means.len() {
                                    self.means[j]
                                } else {
                                    0.0
                                };
                                let sigma = if j < self.stds.len() {
                                    self.stds[j].max(1e-8)
                                } else {
                                    1.0
                                };
                                (val - mu) / sigma
                            })
                            .collect()
                    })
                    .collect();

                Self::apply_correlation(&mut standardized, &cholesky_l);

                samples = standardized
                    .iter()
                    .map(|row| {
                        row.iter()
                            .enumerate()
                            .map(|(j, &val)| {
                                let mu = if j < self.means.len() {
                                    self.means[j]
                                } else {
                                    0.0
                                };
                                let sigma = if j < self.stds.len() {
                                    self.stds[j].max(1e-8)
                                } else {
                                    1.0
                                };
                                val * sigma + mu
                            })
                            .collect()
                    })
                    .collect();
            }
        }

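        // Clamp every value to `mu +/- 4 * sigma` to suppress extreme
        // outliers from the stochastic reverse process.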
        for row in samples.iter_mut() {
            for (j, val) in row.iter_mut().enumerate() {
                let mu = if j < self.means.len() {
                    self.means[j]
                } else {
                    0.0
                };
                let sigma = if j < self.stds.len() {
                    self.stds[j]
                } else {
                    1.0
                };
                let lo = mu - 4.0 * sigma;
                let hi = mu + 4.0 * sigma;
                *val = val.clamp(lo, hi);
            }
        }

        samples
    }
}

#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    fn make_config(n_steps: usize, seed: u64) -> DiffusionConfig {
        DiffusionConfig {
            n_steps,
            schedule: super::super::NoiseScheduleType::Linear,
            seed,
        }
    }

    #[test]
    fn test_output_dimensions() {
        let means = vec![100.0, 200.0, 300.0];
        let stds = vec![10.0, 20.0, 30.0];
        let backend = StatisticalDiffusionBackend::new(means, stds, make_config(50, 42));

        let samples = backend.generate(500, 3, 42);
        assert_eq!(samples.len(), 500);
        for row in &samples {
            assert_eq!(row.len(), 3);
        }
    }

    #[test]
    fn test_deterministic_with_same_seed() {
        let means = vec![50.0, 100.0];
        let stds = vec![5.0, 10.0];
        let backend = StatisticalDiffusionBackend::new(means, stds, make_config(50, 99));

        let samples1 = backend.generate(100, 2, 123);
        let samples2 = backend.generate(100, 2, 123);

        for (row1, row2) in samples1.iter().zip(samples2.iter()) {
            for (&v1, &v2) in row1.iter().zip(row2.iter()) {
                assert!(
                    (v1 - v2).abs() < 1e-12,
                    "Determinism failed: {} vs {}",
                    v1,
                    v2
                );
            }
        }
    }

    #[test]
    fn test_mean_within_tolerance() {
        let target_means = vec![100.0, 0.0, -50.0];
        let target_stds = vec![10.0, 5.0, 20.0];
        let backend = StatisticalDiffusionBackend::new(
            target_means.clone(),
            target_stds.clone(),
            make_config(100, 42),
        );

        let samples = backend.generate(5000, 3, 42);

        for feat in 0..3 {
            let sample_mean: f64 =
                samples.iter().map(|r| r[feat]).sum::<f64>() / samples.len() as f64;
            let tolerance = target_stds[feat];
            assert!(
                (sample_mean - target_means[feat]).abs() < tolerance,
                "Feature {} mean {} is more than 1 std ({}) from target {}",
                feat,
                sample_mean,
                tolerance,
                target_means[feat]
            );
        }
    }

    #[test]
    fn test_forward_adds_noise() {
        let means = vec![100.0, 200.0];
        let stds = vec![10.0, 20.0];
        let backend = StatisticalDiffusionBackend::new(means, stds, make_config(100, 42));

        let original = vec![vec![100.0, 200.0]; 100];

        let noised_early = backend.forward(&original, 5);
        let dist_early: f64 = noised_early
            .iter()
            .zip(original.iter())
            .map(|(n, o)| {
                n.iter()
                    .zip(o.iter())
                    .map(|(a, b)| (a - b).powi(2))
                    .sum::<f64>()
            })
            .sum::<f64>()
            .sqrt();

        let noised_late = backend.forward(&original, 90);
        let dist_late: f64 = noised_late
            .iter()
            .zip(original.iter())
            .map(|(n, o)| {
                n.iter()
                    .zip(o.iter())
                    .map(|(a, b)| (a - b).powi(2))
                    .sum::<f64>()
            })
            .sum::<f64>()
            .sqrt();

        assert!(
            dist_late > dist_early,
            "Later timestep should add more noise: early={}, late={}",
            dist_early,
            dist_late
        );
    }

    #[test]
    fn test_correlation_structure_preserved() {
        let means = vec![0.0, 0.0];
        let stds = vec![1.0, 1.0];
        let corr = vec![vec![1.0, 0.9], vec![0.9, 1.0]];

        let backend = StatisticalDiffusionBackend::new(means, stds, make_config(100, 42))
            .with_correlations(corr);

        let samples = backend.generate(5000, 2, 42);

        let n = samples.len() as f64;
        let mean0: f64 = samples.iter().map(|r| r[0]).sum::<f64>() / n;
        let mean1: f64 = samples.iter().map(|r| r[1]).sum::<f64>() / n;
        let std0: f64 = (samples.iter().map(|r| (r[0] - mean0).powi(2)).sum::<f64>() / n).sqrt();
        let std1: f64 = (samples.iter().map(|r| (r[1] - mean1).powi(2)).sum::<f64>() / n).sqrt();
        let cov01: f64 = samples
            .iter()
            .map(|r| (r[0] - mean0) * (r[1] - mean1))
            .sum::<f64>()
            / n;

        let sample_corr = if std0 > 1e-8 && std1 > 1e-8 {
            cov01 / (std0 * std1)
        } else {
            0.0
        };

        assert!(
            sample_corr > 0.5,
            "Expected positive correlation (target 0.9), got {}",
            sample_corr
        );
    }

    #[test]
    fn test_cholesky_identity() {
        let identity = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let l = StatisticalDiffusionBackend::cholesky_decomposition(&identity);
        assert!(l.is_some());
        let l = l.unwrap();
        assert!((l[0][0] - 1.0).abs() < 1e-10);
        assert!((l[1][1] - 1.0).abs() < 1e-10);
        assert!(l[0][1].abs() < 1e-10);
        assert!(l[1][0].abs() < 1e-10);
    }

    #[test]
    fn test_cholesky_non_positive_definite() {
        // [[1, 2], [2, 1]] is symmetric but has eigenvalues 3 and -1, so it
        // is not positive definite and the decomposition must return None.
        let matrix = vec![vec![1.0, 2.0], vec![2.0, 1.0]];
        let l = StatisticalDiffusionBackend::cholesky_decomposition(&matrix);
        assert!(l.is_none());
    }

    #[test]
    fn test_generate_empty() {
        let backend = StatisticalDiffusionBackend::new(vec![], vec![], make_config(10, 0));
        let samples = backend.generate(0, 0, 0);
        assert!(samples.is_empty());
    }

    #[test]
    fn test_values_clipped_to_range() {
        let means = vec![0.0];
        let stds = vec![1.0];
        let backend = StatisticalDiffusionBackend::new(means, stds, make_config(50, 42));

        let samples = backend.generate(1000, 1, 42);
        for row in &samples {
            assert!(
                row[0] >= -4.0 && row[0] <= 4.0,
                "Value {} out of clipping range [-4, 4]",
                row[0]
            );
        }
    }
}