Skip to main content

datasynth_core/diffusion/
schedule.rs

1use super::backend::NoiseScheduleType;
2
/// Precomputed noise schedule parameters for a diffusion process.
///
/// For `T` timesteps (indexed `0..T`), provides one value per step in each vector:
/// - `betas`: noise level `beta_t` at each step
/// - `alphas`: `1 - beta_t`
/// - `alpha_bars`: cumulative product of `alphas[0..=t]` (signal retention)
/// - `sqrt_alpha_bars`: element-wise `sqrt(alpha_bars[t])`
/// - `sqrt_one_minus_alpha_bars`: element-wise `sqrt(1 - alpha_bars[t])`
///
/// Values built through [`NoiseSchedule::new`] or [`NoiseSchedule::from_betas`] have all
/// five vectors of the same length. The fields are public, so a hand-assembled value is
/// not guaranteed to uphold that invariant; prefer the constructors.
#[derive(Debug, Clone)]
pub struct NoiseSchedule {
    /// Per-step noise level `beta_t`.
    pub betas: Vec<f64>,
    /// Per-step `1 - beta_t`.
    pub alphas: Vec<f64>,
    /// Running product of `alphas[0..=t]`.
    pub alpha_bars: Vec<f64>,
    /// Cached `sqrt(alpha_bars[t])`.
    pub sqrt_alpha_bars: Vec<f64>,
    /// Cached `sqrt(1 - alpha_bars[t])`.
    pub sqrt_one_minus_alpha_bars: Vec<f64>,
}
17
18impl NoiseSchedule {
19    /// Create a noise schedule of the given type and length.
20    pub fn new(schedule_type: &NoiseScheduleType, n_steps: usize) -> Self {
21        let betas = match schedule_type {
22            NoiseScheduleType::Linear => Self::linear_schedule(n_steps),
23            NoiseScheduleType::Cosine => Self::cosine_schedule(n_steps),
24            NoiseScheduleType::Sigmoid => Self::sigmoid_schedule(n_steps),
25        };
26        Self::from_betas(betas)
27    }
28
29    /// Build schedule from a vector of betas.
30    pub fn from_betas(betas: Vec<f64>) -> Self {
31        let alphas: Vec<f64> = betas.iter().map(|b| 1.0 - b).collect();
32
33        let mut alpha_bars = Vec::with_capacity(alphas.len());
34        let mut cumulative = 1.0;
35        for &a in &alphas {
36            cumulative *= a;
37            alpha_bars.push(cumulative);
38        }
39
40        let sqrt_alpha_bars: Vec<f64> = alpha_bars.iter().map(|a| a.sqrt()).collect();
41        let sqrt_one_minus_alpha_bars: Vec<f64> =
42            alpha_bars.iter().map(|a| (1.0 - a).sqrt()).collect();
43
44        Self {
45            betas,
46            alphas,
47            alpha_bars,
48            sqrt_alpha_bars,
49            sqrt_one_minus_alpha_bars,
50        }
51    }
52
53    /// Linear noise schedule: beta linearly interpolated from beta_start to beta_end.
54    fn linear_schedule(n_steps: usize) -> Vec<f64> {
55        let beta_start = 0.0001;
56        let beta_end = 0.02;
57        (0..n_steps)
58            .map(|i| {
59                beta_start + (beta_end - beta_start) * (i as f64) / ((n_steps - 1).max(1) as f64)
60            })
61            .collect()
62    }
63
64    /// Cosine noise schedule: alpha_bar_t = cos^2((t/T + s) / (1+s) * pi/2).
65    fn cosine_schedule(n_steps: usize) -> Vec<f64> {
66        let s = 0.008;
67        let mut alpha_bars = Vec::with_capacity(n_steps + 1);
68        for i in 0..=n_steps {
69            let t = i as f64 / n_steps as f64;
70            let val = ((t + s) / (1.0 + s) * std::f64::consts::FRAC_PI_2)
71                .cos()
72                .powi(2);
73            alpha_bars.push(val);
74        }
75
76        let mut betas = Vec::with_capacity(n_steps);
77        for i in 1..=n_steps {
78            let beta = 1.0 - alpha_bars[i] / alpha_bars[i - 1];
79            betas.push(beta.clamp(0.0001, 0.999));
80        }
81        betas
82    }
83
84    /// Sigmoid noise schedule: beta interpolated via sigmoid curve.
85    fn sigmoid_schedule(n_steps: usize) -> Vec<f64> {
86        let beta_start = 0.0001;
87        let beta_end = 0.02;
88        let range_start = -6.0;
89        let range_end = 6.0;
90
91        (0..n_steps)
92            .map(|i| {
93                let t = range_start
94                    + (range_end - range_start) * (i as f64) / ((n_steps - 1).max(1) as f64);
95                let sigmoid = 1.0 / (1.0 + (-t).exp());
96                beta_start + (beta_end - beta_start) * sigmoid
97            })
98            .collect()
99    }
100
101    /// Number of timesteps.
102    pub fn n_steps(&self) -> usize {
103        self.betas.len()
104    }
105}
106
#[cfg(test)]
// NOTE(review): presumably mirrors a crate-level `clippy::unwrap_used` deny; kept so test
// code may unwrap without tripping it. None of the tests below currently unwrap.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by the floating-point comparisons in this module.
    const EPS: f64 = 1e-10;

    /// Every schedule variant, so shape/consistency invariants are checked for all of them.
    fn all_schedule_types() -> [NoiseScheduleType; 3] {
        [
            NoiseScheduleType::Linear,
            NoiseScheduleType::Cosine,
            NoiseScheduleType::Sigmoid,
        ]
    }

    #[test]
    fn test_linear_schedule_monotonic_betas() {
        let schedule = NoiseSchedule::new(&NoiseScheduleType::Linear, 100);
        assert!(
            schedule.betas.windows(2).all(|w| w[1] >= w[0]),
            "Linear betas should be monotonically increasing"
        );
    }

    #[test]
    fn test_linear_schedule_endpoints() {
        let schedule = NoiseSchedule::new(&NoiseScheduleType::Linear, 100);
        assert!((schedule.betas[0] - 0.0001).abs() < EPS);
        assert!((schedule.betas[99] - 0.02).abs() < EPS);
    }

    #[test]
    fn test_single_step_schedule() {
        // With a single step the interpolation divisor is pinned to 1, leaving only beta_start.
        let schedule = NoiseSchedule::new(&NoiseScheduleType::Linear, 1);
        assert_eq!(schedule.n_steps(), 1);
        assert!((schedule.betas[0] - 0.0001).abs() < EPS);
    }

    #[test]
    fn test_zero_steps_yields_empty_schedule() {
        let types = all_schedule_types();
        for (k, kind) in types.iter().enumerate() {
            let schedule = NoiseSchedule::new(kind, 0);
            assert_eq!(schedule.n_steps(), 0, "schedule type #{} should be empty", k);
            assert!(schedule.alphas.is_empty());
            assert!(schedule.alpha_bars.is_empty());
            assert!(schedule.sqrt_alpha_bars.is_empty());
            assert!(schedule.sqrt_one_minus_alpha_bars.is_empty());
        }
    }

    #[test]
    fn test_cosine_schedule_alpha_bar_decreasing() {
        let schedule = NoiseSchedule::new(&NoiseScheduleType::Cosine, 100);
        // Alpha bars should decrease (signal retention decreases over time)
        assert!(
            schedule.alpha_bars[0] > 0.9,
            "First alpha_bar should be near 1.0"
        );
        assert!(
            schedule.alpha_bars.last().copied().unwrap_or(1.0) < 0.1,
            "Last alpha_bar should be near 0.0"
        );
        assert!(
            schedule.alpha_bars.windows(2).all(|w| w[1] <= w[0]),
            "Alpha bars should be monotonically decreasing"
        );
    }

    #[test]
    fn test_cosine_betas_clamped() {
        let schedule = NoiseSchedule::new(&NoiseScheduleType::Cosine, 100);
        for &beta in &schedule.betas {
            assert!(
                (0.0001..=0.999).contains(&beta),
                "Cosine betas should be within [0.0001, 0.999], got {}",
                beta
            );
        }
        // The final step drives the raw cosine alpha_bar to ~0, so its beta hits the upper clamp.
        assert_eq!(schedule.betas.last().copied(), Some(0.999));
    }

    #[test]
    fn test_sigmoid_schedule_bounded() {
        let schedule = NoiseSchedule::new(&NoiseScheduleType::Sigmoid, 100);
        for &beta in &schedule.betas {
            assert!(
                (0.0001..=0.02).contains(&beta),
                "Sigmoid betas should be within [0.0001, 0.02], got {}",
                beta
            );
        }
    }

    #[test]
    fn test_sigmoid_schedule_monotonic() {
        let schedule = NoiseSchedule::new(&NoiseScheduleType::Sigmoid, 100);
        assert!(
            schedule.betas.windows(2).all(|w| w[0] < w[1]),
            "Sigmoid betas should be strictly increasing"
        );
    }

    #[test]
    fn test_schedule_lengths() {
        let types = all_schedule_types();
        for kind in &types {
            for n in [10, 100, 1000] {
                let schedule = NoiseSchedule::new(kind, n);
                assert_eq!(schedule.n_steps(), n);
                assert_eq!(schedule.betas.len(), n);
                assert_eq!(schedule.alphas.len(), n);
                assert_eq!(schedule.alpha_bars.len(), n);
                assert_eq!(schedule.sqrt_alpha_bars.len(), n);
                assert_eq!(schedule.sqrt_one_minus_alpha_bars.len(), n);
            }
        }
    }

    #[test]
    fn test_from_betas_known_values() {
        let schedule = NoiseSchedule::from_betas(vec![0.5, 0.5]);
        assert_eq!(schedule.alphas, vec![0.5, 0.5]);
        assert_eq!(schedule.alpha_bars, vec![0.5, 0.25]);
        assert!((schedule.sqrt_alpha_bars[1] - 0.5).abs() < EPS);
        assert!((schedule.sqrt_one_minus_alpha_bars[1] - 0.75_f64.sqrt()).abs() < EPS);
    }

    #[test]
    fn test_alpha_bar_product_correctness() {
        let schedule = NoiseSchedule::new(&NoiseScheduleType::Linear, 10);
        assert_eq!(schedule.alphas.len(), schedule.alpha_bars.len());
        // Verify alpha_bar[i] = product of alphas[0..=i]
        let mut product = 1.0;
        for (i, (&alpha, &alpha_bar)) in schedule
            .alphas
            .iter()
            .zip(&schedule.alpha_bars)
            .enumerate()
        {
            product *= alpha;
            assert!(
                (alpha_bar - product).abs() < EPS,
                "Alpha bar mismatch at step {}",
                i
            );
        }
    }

    #[test]
    fn test_sqrt_consistency() {
        let types = all_schedule_types();
        for kind in &types {
            let schedule = NoiseSchedule::new(kind, 50);
            // Guard against zip silently truncating if lengths ever diverge.
            assert_eq!(schedule.sqrt_alpha_bars.len(), schedule.alpha_bars.len());
            assert_eq!(
                schedule.sqrt_one_minus_alpha_bars.len(),
                schedule.alpha_bars.len()
            );
            let rows = schedule
                .alpha_bars
                .iter()
                .zip(&schedule.sqrt_alpha_bars)
                .zip(&schedule.sqrt_one_minus_alpha_bars);
            for ((&alpha_bar, &sqrt_ab), &sqrt_one_minus_ab) in rows {
                assert!((sqrt_ab - alpha_bar.sqrt()).abs() < EPS);
                assert!((sqrt_one_minus_ab - (1.0 - alpha_bar).sqrt()).abs() < EPS);
            }
        }
    }
}