Skip to main content

datasynth_core/diffusion/
utils.rs

1use rand::SeedableRng;
2use rand_chacha::ChaCha8Rng;
3use rand_distr::{Distribution, Normal};
4
5/// Add Gaussian noise to each element of x with given variance.
6pub fn add_gaussian_noise(x: &[f64], variance: f64, rng: &mut ChaCha8Rng) -> Vec<f64> {
7    let std_dev = variance.sqrt();
8    if let Ok(normal) = Normal::new(0.0, std_dev) {
9        x.iter().map(|&v| v + normal.sample(rng)).collect()
10    } else {
11        x.to_vec()
12    }
13}
14
/// Standardize each feature column to zero mean and unit variance.
///
/// Returns `(normalized_data, means, stds)`.
///
/// * The number of feature columns is taken from the first row.
/// * Means and (population) standard deviations are divided by the total
///   number of rows, even if some rows are shorter than the first one.
/// * Standard deviations are floored at `1e-8` so constant columns do not
///   cause a division by zero.
/// * Values in columns beyond the width of the first row are copied through
///   unchanged.
/// * Empty input yields three empty vectors.
pub fn normalize_features(data: &[Vec<f64>]) -> (Vec<Vec<f64>>, Vec<f64>, Vec<f64>) {
    let width = match data.first() {
        Some(first_row) => first_row.len(),
        None => return (Vec::new(), Vec::new(), Vec::new()),
    };
    let row_count = data.len() as f64;

    // Column sums -> column means. `zip` stops at the shorter side, so extra
    // columns in a long row are ignored and short rows simply contribute less.
    let mut means = vec![0.0_f64; width];
    for row in data {
        for (sum, &value) in means.iter_mut().zip(row) {
            *sum += value;
        }
    }
    means.iter_mut().for_each(|sum| *sum /= row_count);

    // Sum of squared deviations -> population standard deviation per column.
    let mut stds = vec![0.0_f64; width];
    for row in data {
        for ((sq_sum, &mean), &value) in stds.iter_mut().zip(&means).zip(row) {
            *sq_sum += (value - mean).powi(2);
        }
    }
    stds.iter_mut()
        .for_each(|sq_sum| *sq_sum = (*sq_sum / row_count).sqrt().max(1e-8)); // floor avoids /0

    // Rescale every value that has column statistics; pass the rest through.
    let normalized: Vec<Vec<f64>> = data
        .iter()
        .map(|row| {
            row.iter()
                .enumerate()
                .map(|(col, &value)| match (means.get(col), stds.get(col)) {
                    (Some(&mean), Some(&std)) => (value - mean) / std,
                    _ => value,
                })
                .collect()
        })
        .collect();

    (normalized, means, stds)
}
70
/// Map standardized features back to their original scale.
///
/// Each value in column `j` becomes `value * stds[j] + means[j]`. Columns that
/// lack an entry in either `means` or `stds` are copied through unchanged.
pub fn denormalize_features(data: &[Vec<f64>], means: &[f64], stds: &[f64]) -> Vec<Vec<f64>> {
    // Only columns covered by both statistics can be rescaled.
    let covered = means.len().min(stds.len());

    data.iter()
        .map(|row| {
            let (scaled, passthrough) = row.split_at(covered.min(row.len()));
            scaled
                .iter()
                .zip(stds)
                .zip(means)
                .map(|((&value, &std), &mean)| value * std + mean)
                .chain(passthrough.iter().copied())
                .collect()
        })
        .collect()
}
88
/// Clamp every value of `data` in place to the closed interval `[min, max]`.
///
/// # Panics
///
/// Inherits the panics of `f64::clamp`: if `min > max` or either bound is NaN,
/// the call panics as soon as the first element is processed.
pub fn clip_values(data: &mut [Vec<f64>], min: f64, max: f64) {
    data.iter_mut()
        .flatten()
        .for_each(|value| *value = (*value).clamp(min, max));
}
97
98/// Generate standard normal noise matrix.
99pub fn generate_noise(n_samples: usize, n_features: usize, seed: u64) -> Vec<Vec<f64>> {
100    let mut rng = ChaCha8Rng::seed_from_u64(seed);
101    if let Ok(normal) = Normal::new(0.0, 1.0) {
102        (0..n_samples)
103            .map(|_| (0..n_features).map(|_| normal.sample(&mut rng)).collect())
104            .collect()
105    } else {
106        vec![vec![0.0; n_features]; n_samples]
107    }
108}
109
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Noise with variance 0.01 keeps the length and stays close to the input.
    #[test]
    fn test_add_gaussian_noise() {
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let input = vec![1.0, 2.0, 3.0];
        let output = add_gaussian_noise(&input, 0.01, &mut rng);
        assert_eq!(output.len(), 3);
        // Noise should be small
        for idx in 0..input.len() {
            assert!((input[idx] - output[idx]).abs() < 1.0);
        }
    }

    /// Normalizing and then denormalizing reproduces the original values.
    #[test]
    fn test_normalize_denormalize_roundtrip() {
        let original = vec![vec![10.0, 20.0], vec![12.0, 22.0], vec![14.0, 24.0]];
        let (scaled, means, stds) = normalize_features(&original);
        let restored = denormalize_features(&scaled, &means, &stds);

        for (row_in, row_out) in original.iter().zip(&restored) {
            for (o, r) in row_in.iter().zip(row_out) {
                assert!((o - r).abs() < 1e-10, "Roundtrip failed: {} vs {}", o, r);
            }
        }
    }

    /// The first normalized column averages to (approximately) zero.
    #[test]
    fn test_normalize_zero_mean() {
        let data = vec![vec![10.0, 20.0], vec![20.0, 40.0]];
        let (normalized, _, _) = normalize_features(&data);
        let total: f64 = normalized.iter().map(|row| row[0]).sum();
        let mean = total / normalized.len() as f64;
        assert!(
            mean.abs() < 1e-10,
            "Normalized mean should be ~0, got {}",
            mean
        );
    }

    /// Values outside `[0, 1]` are pulled to the nearest bound.
    #[test]
    fn test_clip_values() {
        let mut data = vec![vec![-5.0, 10.0, 0.5]];
        clip_values(&mut data, 0.0, 1.0);
        assert_eq!(data[0], [0.0, 1.0, 0.5]);
    }

    /// The noise matrix has the requested number of rows and columns.
    #[test]
    fn test_generate_noise_shape() {
        let (rows, cols) = (100, 5);
        let noise = generate_noise(rows, cols, 42);
        assert_eq!(noise.len(), 100);
        assert_eq!(noise[0].len(), 5);
    }

    /// Empty input produces empty data, means and standard deviations.
    #[test]
    fn test_normalize_empty() {
        let (normalized, means, stds) = normalize_features(&[]);
        assert!(normalized.is_empty());
        assert!(means.is_empty());
        assert!(stds.is_empty());
    }
}
173}