datasynth_core/diffusion/statistical.rs

//! Statistical diffusion backend that generates data matching target distributions.
//!
//! Uses a Langevin-inspired reverse process to denoise samples toward
//! target means and standard deviations, with optional correlation structure
//! applied via Cholesky decomposition.

use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use rand_distr::{Distribution, StandardNormal};

use super::backend::{DiffusionBackend, DiffusionConfig};
use super::schedule::NoiseSchedule;
/// A diffusion backend that generates samples matching target statistical properties.
///
/// The forward process adds Gaussian noise according to the noise schedule.
/// The reverse process uses Langevin-inspired updates to guide samples toward
/// the target distribution (means, standard deviations, correlations).
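///
/// # Example
///
/// A minimal usage sketch. The `DiffusionConfig` literal mirrors the one built
/// in the tests below; the exact import paths depend on your crate layout, so
/// the example is marked `ignore`.
///
/// ```ignore
/// use datasynth_core::diffusion::backend::{DiffusionBackend, DiffusionConfig};
/// use datasynth_core::diffusion::statistical::StatisticalDiffusionBackend;
/// use datasynth_core::diffusion::NoiseScheduleType;
///
/// let config = DiffusionConfig {
///     n_steps: 50,
///     schedule: NoiseScheduleType::Linear,
///     seed: 42,
/// };
/// let backend = StatisticalDiffusionBackend::new(
///     vec![100.0, 200.0], // target means
///     vec![10.0, 20.0],   // target standard deviations
///     config,
/// );
/// let samples = backend.generate(1_000, 2, 42);
/// assert_eq!(samples.len(), 1_000);
/// ```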
#[derive(Debug, Clone)]
pub struct StatisticalDiffusionBackend {
    /// Target means for each feature.
    means: Vec<f64>,
    /// Target standard deviations for each feature.
    stds: Vec<f64>,
    /// Optional correlation matrix (n_features x n_features).
    correlations: Option<Vec<Vec<f64>>>,
    /// Diffusion configuration.
    config: DiffusionConfig,
    /// Precomputed noise schedule.
    schedule: NoiseSchedule,
}

impl StatisticalDiffusionBackend {
    /// Create a new statistical diffusion backend.
    ///
    /// # Arguments
    /// * `means` - Target means for each feature dimension
    /// * `stds` - Target standard deviations for each feature dimension
    /// * `config` - Diffusion configuration (steps, schedule type, seed)
    pub fn new(means: Vec<f64>, stds: Vec<f64>, config: DiffusionConfig) -> Self {
        let schedule = config.build_schedule();
        Self {
            means,
            stds,
            correlations: None,
            config,
            schedule,
        }
    }

    /// Set the correlation matrix for multi-dimensional generation.
    ///
    /// The matrix should be symmetric positive-definite with ones on the diagonal.
    /// After denoising, Cholesky decomposition is used to impose this correlation
    /// structure on the generated samples.
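    ///
    /// # Example
    ///
    /// A sketch reusing a `config` built as in the struct-level docs; the 2x2
    /// matrix below encodes a 0.9 correlation between the two features.
    ///
    /// ```ignore
    /// let corr = vec![vec![1.0, 0.9], vec![0.9, 1.0]];
    /// let backend = StatisticalDiffusionBackend::new(vec![0.0, 0.0], vec![1.0, 1.0], config)
    ///     .with_correlations(corr);
    /// ```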
    pub fn with_correlations(mut self, corr_matrix: Vec<Vec<f64>>) -> Self {
        self.correlations = Some(corr_matrix);
        self
    }

    /// Perform Cholesky decomposition of a symmetric positive-definite matrix.
    ///
    /// Returns the lower-triangular matrix L such that A = L * L^T.
    /// Returns `None` if the matrix is not positive-definite.
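    ///
    /// For example, the 2x2 correlation matrix `[[1, r], [r, 1]]` factors as
    /// `L = [[1, 0], [r, sqrt(1 - r^2)]]` for `|r| < 1`; multiplying `L` by
    /// its transpose reproduces the original matrix.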
    fn cholesky_decomposition(matrix: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
        let n = matrix.len();
        if n == 0 {
            return Some(vec![]);
        }

        let mut l = vec![vec![0.0; n]; n];

        for i in 0..n {
            for j in 0..=i {
                let sum: f64 = l[i]
                    .iter()
                    .zip(l[j].iter())
                    .take(j)
                    .map(|(a, b)| a * b)
                    .sum();

                if i == j {
                    let diag = matrix[i][i] - sum;
                    if diag <= 0.0 {
                        return None;
                    }
                    l[i][j] = diag.sqrt();
                } else if l[j][j].abs() < 1e-15 {
                    return None;
                } else {
                    l[i][j] = (matrix[i][j] - sum) / l[j][j];
                }
            }
        }

        Some(l)
    }

    /// Apply correlation structure to independent samples using Cholesky decomposition.
    ///
    /// Given independent standard normal samples, multiply by the Cholesky factor
    /// to produce correlated samples.
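    ///
    /// Each output row becomes `x' = L * x`; if `x` has identity covariance,
    /// `x'` has covariance `L * L^T`, i.e. the target correlation matrix.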
    fn apply_correlation(samples: &mut [Vec<f64>], cholesky_l: &[Vec<f64>]) {
        let n_features = cholesky_l.len();
        for row in samples.iter_mut() {
            let original: Vec<f64> = row.iter().copied().take(n_features).collect();
            for i in 0..n_features.min(row.len()) {
                let mut val = 0.0;
                for j in 0..=i {
                    if j < original.len() {
                        val += cholesky_l[i][j] * original[j];
                    }
                }
                row[i] = val;
            }
        }
    }
}

impl DiffusionBackend for StatisticalDiffusionBackend {
    fn name(&self) -> &str {
        "statistical"
    }

    /// Forward process: add noise at timestep t.
    ///
    /// x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
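    ///
    /// As t -> 0, alpha_bar_t -> 1 and the output stays close to x_0; as t
    /// approaches the final step, alpha_bar_t -> 0 and the output approaches
    /// pure Gaussian noise.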
    fn forward(&self, x: &[Vec<f64>], t: usize) -> Vec<Vec<f64>> {
        let t_clamped = t.min(self.schedule.n_steps().saturating_sub(1));
        let sqrt_alpha_bar = self.schedule.sqrt_alpha_bars[t_clamped];
        let sqrt_one_minus_alpha_bar = self.schedule.sqrt_one_minus_alpha_bars[t_clamped];

        let n_features = x.first().map_or(0, |row| row.len());
        let noise =
            super::generate_noise(x.len(), n_features, self.config.seed.wrapping_add(t as u64));

        x.iter()
            .zip(noise.iter())
            .map(|(row, noise_row)| {
                row.iter()
                    .zip(noise_row.iter())
                    .map(|(&xi, &ni)| sqrt_alpha_bar * xi + sqrt_one_minus_alpha_bar * ni)
                    .collect()
            })
            .collect()
    }

    /// Reverse process: denoise at timestep t using Langevin-inspired updates.
    ///
    /// x_{t-1} = x_t - step_size * (x_t - mu) / sigma^2 + noise_scale * noise
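    ///
    /// The drift term is the negative score of the target Gaussian, since
    /// d/dx log N(x; mu, sigma^2) = -(x - mu) / sigma^2; following it moves
    /// samples toward the target mean, which is what makes the update
    /// Langevin-style.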
    fn reverse(&self, x_t: &[Vec<f64>], t: usize) -> Vec<Vec<f64>> {
        let t_clamped = t.min(self.schedule.n_steps().saturating_sub(1));
        let beta_t = self.schedule.betas[t_clamped];

        // Step size derived from the noise schedule beta
        let step_size = beta_t;
        // Noise scale decreases as we approach t=0
        let noise_scale = if t_clamped > 0 { beta_t.sqrt() } else { 0.0 };

        let n_features = x_t.first().map_or(0, |row| row.len());
        let noise = super::generate_noise(
            x_t.len(),
            n_features,
            self.config
                .seed
                .wrapping_add(t as u64)
                .wrapping_add(1_000_000),
        );

        x_t.iter()
            .zip(noise.iter())
            .map(|(row, noise_row)| {
                row.iter()
                    .enumerate()
                    .map(|(j, &x_val)| {
                        let mu = if j < self.means.len() {
                            self.means[j]
                        } else {
                            0.0
                        };
                        let sigma = if j < self.stds.len() {
                            self.stds[j].max(1e-8)
                        } else {
                            1.0
                        };
                        let n = if j < noise_row.len() {
                            noise_row[j]
                        } else {
                            0.0
                        };

                        // Langevin-inspired drift toward target distribution
                        let drift = step_size * (x_val - mu) / (sigma * sigma);
                        x_val - drift + noise_scale * n
                    })
                    .collect()
            })
            .collect()
    }

    /// Generate `n_samples` samples with `n_features` features by starting from
    /// pure noise and iteratively denoising with the reverse process.
    ///
    /// At each reverse step t, the sample is updated via:
    ///   x_{t-1} = (1 - blend) * x_t + blend * (mu + sigma * z) - correction + noise
    /// where blend is derived from the schedule's signal-to-noise progression,
    /// z is a standard normal draw for stochastic variation, correction is a
    /// small Langevin drift toward the target mean, and the noise term decays
    /// to zero at t = 0.
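    ///
    /// Output is deterministic for a fixed `seed` (the tests below verify two
    /// runs agree to within 1e-12), so given any `backend` built as above:
    ///
    /// ```ignore
    /// let a = backend.generate(100, 2, 123);
    /// let b = backend.generate(100, 2, 123);
    /// assert_eq!(a, b);
    /// ```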
    fn generate(&self, n_samples: usize, n_features: usize, seed: u64) -> Vec<Vec<f64>> {
        if n_samples == 0 || n_features == 0 {
            return vec![];
        }

        let mut rng = ChaCha8Rng::seed_from_u64(seed);

        // Start from pure standard normal noise
        let normal = StandardNormal;
        let mut samples: Vec<Vec<f64>> = (0..n_samples)
            .map(|_| (0..n_features).map(|_| normal.sample(&mut rng)).collect())
            .collect();

        // Reverse process: denoise from t = T-1 down to 0
        // We use the schedule's alpha_bar to progressively blend toward the
        // target distribution. At each step, the blend factor increases as
        // more signal is recovered.
        let n_steps = self.schedule.n_steps();
        for t in (0..n_steps).rev() {
            let beta_t = self.schedule.betas[t];
            let alpha_t = self.schedule.alphas[t];
            let alpha_bar_t = self.schedule.alpha_bars[t];

            // Previous alpha_bar (at t-1); for t=0 this is 1.0 (fully denoised)
            let alpha_bar_prev = if t > 0 {
                self.schedule.alpha_bars[t - 1]
            } else {
                1.0
            };

            // Blend factor: how much to move toward the target at this step
            // As t decreases, alpha_bar increases, so we progressively reveal signal
            let blend = (alpha_bar_prev - alpha_bar_t).max(0.0) / (1.0 - alpha_bar_t).max(1e-12);
            let blend = blend.clamp(0.0, 1.0);

            let noise_scale = if t > 0 { beta_t.sqrt() * 0.5 } else { 0.0 };

            for row in samples.iter_mut() {
                for (j, x_val) in row.iter_mut().enumerate().take(n_features) {
                    let mu = if j < self.means.len() {
                        self.means[j]
                    } else {
                        0.0
                    };
                    let sigma = if j < self.stds.len() {
                        self.stds[j].max(1e-8)
                    } else {
                        1.0
                    };

                    // Target sample: draw from target distribution
                    let z: f64 = normal.sample(&mut rng);
                    let target_val = mu + sigma * z;

                    // Blend current noisy sample toward target
                    let denoised = (1.0 - blend) * *x_val + blend * target_val;

                    // Add small stochastic noise (diminishes to zero at t=0)
                    let noise_val: f64 = if t > 0 { normal.sample(&mut rng) } else { 0.0 };

                    // Update using DDPM-style posterior with Langevin correction
                    let correction = beta_t / (2.0 * alpha_t.max(1e-12)) * (*x_val - mu)
                        / (sigma * sigma).max(1e-12);
                    *x_val = denoised - correction + noise_scale * noise_val;
                }
            }
        }

        // Apply correlation structure via Cholesky decomposition if provided
        if let Some(ref corr_matrix) = self.correlations {
            if let Some(cholesky_l) = Self::cholesky_decomposition(corr_matrix) {
                // First standardize the samples (subtract mean, divide by std)
                let mut standardized: Vec<Vec<f64>> = samples
                    .iter()
                    .map(|row| {
                        row.iter()
                            .enumerate()
                            .map(|(j, &val)| {
                                let mu = if j < self.means.len() {
                                    self.means[j]
                                } else {
                                    0.0
                                };
                                let sigma = if j < self.stds.len() {
                                    self.stds[j].max(1e-8)
                                } else {
                                    1.0
                                };
                                (val - mu) / sigma
                            })
                            .collect()
                    })
                    .collect();

                // Apply correlation
                Self::apply_correlation(&mut standardized, &cholesky_l);

                // Denormalize back to target scale
                samples = standardized
                    .iter()
                    .map(|row| {
                        row.iter()
                            .enumerate()
                            .map(|(j, &val)| {
                                let mu = if j < self.means.len() {
                                    self.means[j]
                                } else {
                                    0.0
                                };
                                let sigma = if j < self.stds.len() {
                                    self.stds[j].max(1e-8)
                                } else {
                                    1.0
                                };
                                val * sigma + mu
                            })
                            .collect()
                    })
                    .collect();
            }
        }

        // Clip to reasonable ranges: mean +/- 4 * std
        for row in samples.iter_mut() {
            for (j, val) in row.iter_mut().enumerate() {
                let mu = if j < self.means.len() {
                    self.means[j]
                } else {
                    0.0
                };
                let sigma = if j < self.stds.len() {
                    self.stds[j]
                } else {
                    1.0
                };
                let lo = mu - 4.0 * sigma;
                let hi = mu + 4.0 * sigma;
                *val = val.clamp(lo, hi);
            }
        }

        samples
    }
}

#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    fn make_config(n_steps: usize, seed: u64) -> DiffusionConfig {
        DiffusionConfig {
            n_steps,
            schedule: super::super::NoiseScheduleType::Linear,
            seed,
        }
    }

    #[test]
    fn test_output_dimensions() {
        let means = vec![100.0, 200.0, 300.0];
        let stds = vec![10.0, 20.0, 30.0];
        let backend = StatisticalDiffusionBackend::new(means, stds, make_config(50, 42));

        let samples = backend.generate(500, 3, 42);
        assert_eq!(samples.len(), 500);
        for row in &samples {
            assert_eq!(row.len(), 3);
        }
    }

    #[test]
    fn test_deterministic_with_same_seed() {
        let means = vec![50.0, 100.0];
        let stds = vec![5.0, 10.0];
        let backend = StatisticalDiffusionBackend::new(means, stds, make_config(50, 99));

        let samples1 = backend.generate(100, 2, 123);
        let samples2 = backend.generate(100, 2, 123);

        for (row1, row2) in samples1.iter().zip(samples2.iter()) {
            for (&v1, &v2) in row1.iter().zip(row2.iter()) {
                assert!(
                    (v1 - v2).abs() < 1e-12,
                    "Determinism failed: {} vs {}",
                    v1,
                    v2
                );
            }
        }
    }

    #[test]
    fn test_mean_within_tolerance() {
        let target_means = vec![100.0, 0.0, -50.0];
        let target_stds = vec![10.0, 5.0, 20.0];
        let backend = StatisticalDiffusionBackend::new(
            target_means.clone(),
            target_stds.clone(),
            make_config(100, 42),
        );

        let samples = backend.generate(5000, 3, 42);

        // Compute sample means
        for feat in 0..3 {
            let sample_mean: f64 =
                samples.iter().map(|r| r[feat]).sum::<f64>() / samples.len() as f64;
            let tolerance = target_stds[feat]; // within 1 std of target
            assert!(
                (sample_mean - target_means[feat]).abs() < tolerance,
                "Feature {} mean {} is more than 1 std ({}) from target {}",
                feat,
                sample_mean,
                tolerance,
                target_means[feat]
            );
        }
    }

    #[test]
    fn test_forward_adds_noise() {
        let means = vec![100.0, 200.0];
        let stds = vec![10.0, 20.0];
        let backend = StatisticalDiffusionBackend::new(means, stds, make_config(100, 42));

        let original = vec![vec![100.0, 200.0]; 100];

        // At an early timestep, the added noise is small
        let noised_early = backend.forward(&original, 5);
        let dist_early: f64 = noised_early
            .iter()
            .zip(original.iter())
            .map(|(n, o)| {
                n.iter()
                    .zip(o.iter())
                    .map(|(a, b)| (a - b).powi(2))
                    .sum::<f64>()
            })
            .sum::<f64>()
            .sqrt();

        // At a late timestep, the added noise is large
        let noised_late = backend.forward(&original, 90);
        let dist_late: f64 = noised_late
            .iter()
            .zip(original.iter())
            .map(|(n, o)| {
                n.iter()
                    .zip(o.iter())
                    .map(|(a, b)| (a - b).powi(2))
                    .sum::<f64>()
            })
            .sum::<f64>()
            .sqrt();

        assert!(
            dist_late > dist_early,
            "Later timestep should add more noise: early={}, late={}",
            dist_early,
            dist_late
        );
    }

    #[test]
    fn test_correlation_structure_preserved() {
        let means = vec![0.0, 0.0];
        let stds = vec![1.0, 1.0];
        // Strong positive correlation
        let corr = vec![vec![1.0, 0.9], vec![0.9, 1.0]];

        let backend = StatisticalDiffusionBackend::new(means, stds, make_config(100, 42))
            .with_correlations(corr);

        let samples = backend.generate(5000, 2, 42);

        // Compute the sample correlation
        let n = samples.len() as f64;
        let mean0: f64 = samples.iter().map(|r| r[0]).sum::<f64>() / n;
        let mean1: f64 = samples.iter().map(|r| r[1]).sum::<f64>() / n;
        let std0: f64 = (samples.iter().map(|r| (r[0] - mean0).powi(2)).sum::<f64>() / n).sqrt();
        let std1: f64 = (samples.iter().map(|r| (r[1] - mean1).powi(2)).sum::<f64>() / n).sqrt();
        let cov01: f64 = samples
            .iter()
            .map(|r| (r[0] - mean0) * (r[1] - mean1))
            .sum::<f64>()
            / n;

        let sample_corr = if std0 > 1e-8 && std1 > 1e-8 {
            cov01 / (std0 * std1)
        } else {
            0.0
        };

        // Correlation should be positive and reasonably close to 0.9
        assert!(
            sample_corr > 0.5,
            "Expected positive correlation (target 0.9), got {}",
            sample_corr
        );
    }

    #[test]
    fn test_cholesky_identity() {
        let identity = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let l = StatisticalDiffusionBackend::cholesky_decomposition(&identity);
        assert!(l.is_some());
        let l = l.unwrap();
        assert!((l[0][0] - 1.0).abs() < 1e-10);
        assert!((l[1][1] - 1.0).abs() < 1e-10);
        assert!(l[0][1].abs() < 1e-10);
        assert!(l[1][0].abs() < 1e-10);
    }

    #[test]
    fn test_cholesky_non_positive_definite() {
        // Eigenvalues are 3 and -1, so this matrix is not positive-definite
        let matrix = vec![vec![1.0, 2.0], vec![2.0, 1.0]];
        let l = StatisticalDiffusionBackend::cholesky_decomposition(&matrix);
        assert!(l.is_none());
    }

    #[test]
    fn test_generate_empty() {
        let backend = StatisticalDiffusionBackend::new(vec![], vec![], make_config(10, 0));
        let samples = backend.generate(0, 0, 0);
        assert!(samples.is_empty());
    }

    #[test]
    fn test_values_clipped_to_range() {
        let means = vec![0.0];
        let stds = vec![1.0];
        let backend = StatisticalDiffusionBackend::new(means, stds, make_config(50, 42));

        let samples = backend.generate(1000, 1, 42);
        for row in &samples {
            assert!(
                row[0] >= -4.0 && row[0] <= 4.0,
                "Value {} out of clipping range [-4, 4]",
                row[0]
            );
        }
    }
}
554}