datasynth_core/diffusion/
schedule.rs1use super::backend::NoiseScheduleType;
2
/// Precomputed per-step coefficients of a diffusion noise schedule.
///
/// Every vector is indexed by step (0-based). When the schedule is built via
/// [`NoiseSchedule::from_betas`] (which [`NoiseSchedule::new`] delegates to),
/// all five vectors have the same length.
#[derive(Clone, Debug)]
pub struct NoiseSchedule {
    /// Per-step beta coefficients; the input every other field is derived from.
    pub betas: Vec<f64>,
    /// `1 - betas[t]` for each step `t`.
    pub alphas: Vec<f64>,
    /// Running product `alphas[0] * ... * alphas[t]` for each step `t`.
    pub alpha_bars: Vec<f64>,
    /// Element-wise `sqrt(alpha_bars[t])`.
    pub sqrt_alpha_bars: Vec<f64>,
    /// Element-wise `sqrt(1 - alpha_bars[t])`.
    pub sqrt_one_minus_alpha_bars: Vec<f64>,
}
17
18impl NoiseSchedule {
19 pub fn new(schedule_type: &NoiseScheduleType, n_steps: usize) -> Self {
21 let betas = match schedule_type {
22 NoiseScheduleType::Linear => Self::linear_schedule(n_steps),
23 NoiseScheduleType::Cosine => Self::cosine_schedule(n_steps),
24 NoiseScheduleType::Sigmoid => Self::sigmoid_schedule(n_steps),
25 };
26 Self::from_betas(betas)
27 }
28
29 pub fn from_betas(betas: Vec<f64>) -> Self {
31 let alphas: Vec<f64> = betas.iter().map(|b| 1.0 - b).collect();
32
33 let mut alpha_bars = Vec::with_capacity(alphas.len());
34 let mut cumulative = 1.0;
35 for &a in &alphas {
36 cumulative *= a;
37 alpha_bars.push(cumulative);
38 }
39
40 let sqrt_alpha_bars: Vec<f64> = alpha_bars.iter().map(|a| a.sqrt()).collect();
41 let sqrt_one_minus_alpha_bars: Vec<f64> =
42 alpha_bars.iter().map(|a| (1.0 - a).sqrt()).collect();
43
44 Self {
45 betas,
46 alphas,
47 alpha_bars,
48 sqrt_alpha_bars,
49 sqrt_one_minus_alpha_bars,
50 }
51 }
52
53 fn linear_schedule(n_steps: usize) -> Vec<f64> {
55 let beta_start = 0.0001;
56 let beta_end = 0.02;
57 (0..n_steps)
58 .map(|i| {
59 beta_start + (beta_end - beta_start) * (i as f64) / ((n_steps - 1).max(1) as f64)
60 })
61 .collect()
62 }
63
64 fn cosine_schedule(n_steps: usize) -> Vec<f64> {
66 let s = 0.008;
67 let mut alpha_bars = Vec::with_capacity(n_steps + 1);
68 for i in 0..=n_steps {
69 let t = i as f64 / n_steps as f64;
70 let val = ((t + s) / (1.0 + s) * std::f64::consts::FRAC_PI_2)
71 .cos()
72 .powi(2);
73 alpha_bars.push(val);
74 }
75
76 let mut betas = Vec::with_capacity(n_steps);
77 for i in 1..=n_steps {
78 let beta = 1.0 - alpha_bars[i] / alpha_bars[i - 1];
79 betas.push(beta.clamp(0.0001, 0.999));
80 }
81 betas
82 }
83
84 fn sigmoid_schedule(n_steps: usize) -> Vec<f64> {
86 let beta_start = 0.0001;
87 let beta_end = 0.02;
88 let range_start = -6.0;
89 let range_end = 6.0;
90
91 (0..n_steps)
92 .map(|i| {
93 let t = range_start
94 + (range_end - range_start) * (i as f64) / ((n_steps - 1).max(1) as f64);
95 let sigmoid = 1.0 / (1.0 + (-t).exp());
96 beta_start + (beta_end - beta_start) * sigmoid
97 })
98 .collect()
99 }
100
101 pub fn n_steps(&self) -> usize {
103 self.betas.len()
104 }
105}
106
#[cfg(test)]
// Tests use `unwrap` for brevity. Presumably `clippy::unwrap_used` is enabled
// for non-test code in this crate — TODO confirm.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Every schedule shape, for properties that must hold regardless of shape.
    const ALL_TYPES: [NoiseScheduleType; 3] = [
        NoiseScheduleType::Linear,
        NoiseScheduleType::Cosine,
        NoiseScheduleType::Sigmoid,
    ];

    /// Asserts that all five per-step vectors, and `n_steps()`, report length `n`.
    fn assert_len(schedule: &NoiseSchedule, n: usize) {
        assert_eq!(schedule.betas.len(), n);
        assert_eq!(schedule.alphas.len(), n);
        assert_eq!(schedule.alpha_bars.len(), n);
        assert_eq!(schedule.sqrt_alpha_bars.len(), n);
        assert_eq!(schedule.sqrt_one_minus_alpha_bars.len(), n);
        assert_eq!(schedule.n_steps(), n);
    }

    #[test]
    fn test_linear_schedule_monotonic_betas() {
        let schedule = NoiseSchedule::new(&NoiseScheduleType::Linear, 100);
        for i in 1..schedule.betas.len() {
            assert!(
                schedule.betas[i] >= schedule.betas[i - 1],
                "Linear betas should be monotonically increasing"
            );
        }
    }

    #[test]
    fn test_cosine_schedule_alpha_bar_decreasing() {
        let schedule = NoiseSchedule::new(&NoiseScheduleType::Cosine, 100);
        assert!(
            schedule.alpha_bars[0] > 0.9,
            "First alpha_bar should be near 1.0"
        );
        assert!(
            schedule.alpha_bars.last().copied().unwrap_or(1.0) < 0.1,
            "Last alpha_bar should be near 0.0"
        );
        for i in 1..schedule.alpha_bars.len() {
            assert!(
                schedule.alpha_bars[i] <= schedule.alpha_bars[i - 1],
                "Alpha bars should be monotonically decreasing"
            );
        }
    }

    #[test]
    fn test_sigmoid_schedule_bounded() {
        let schedule = NoiseSchedule::new(&NoiseScheduleType::Sigmoid, 100);
        for &beta in &schedule.betas {
            assert!(
                (0.0001..=0.02).contains(&beta),
                "Sigmoid betas should be within [0.0001, 0.02], got {}",
                beta
            );
        }
    }

    #[test]
    fn test_schedule_lengths() {
        // Previously only Linear was checked; every shape must honour `n_steps`.
        for schedule_type in &ALL_TYPES {
            for n in [1, 10, 100, 1000] {
                let schedule = NoiseSchedule::new(schedule_type, n);
                assert_len(&schedule, n);
            }
        }
    }

    #[test]
    fn test_zero_steps_yields_empty_schedule() {
        // Edge case: no steps must not panic (e.g. via `n_steps - 1` underflow).
        for schedule_type in &ALL_TYPES {
            let schedule = NoiseSchedule::new(schedule_type, 0);
            assert_len(&schedule, 0);
        }
    }

    #[test]
    fn test_linear_endpoints() {
        let schedule = NoiseSchedule::new(&NoiseScheduleType::Linear, 100);
        assert!((schedule.betas[0] - 0.0001).abs() < 1e-12);
        assert!((schedule.betas.last().copied().unwrap() - 0.02).abs() < 1e-12);
    }

    #[test]
    fn test_cosine_betas_clamped() {
        let schedule = NoiseSchedule::new(&NoiseScheduleType::Cosine, 100);
        for &beta in &schedule.betas {
            assert!(
                (0.0001..=0.999).contains(&beta),
                "Cosine betas should be within [0.0001, 0.999], got {}",
                beta
            );
        }
        // The cosine curve is ~0 at the final grid point, so the last beta
        // saturates at the upper clamp.
        assert!((schedule.betas.last().copied().unwrap() - 0.999).abs() < 1e-12);
    }

    #[test]
    fn test_from_betas_known_values() {
        let schedule = NoiseSchedule::from_betas(vec![0.5, 0.5]);
        assert_eq!(schedule.alphas, vec![0.5, 0.5]);
        assert_eq!(schedule.alpha_bars, vec![0.5, 0.25]);
        assert!((schedule.sqrt_alpha_bars[0] - 0.5_f64.sqrt()).abs() < 1e-12);
        assert!((schedule.sqrt_alpha_bars[1] - 0.5).abs() < 1e-12);
        assert!((schedule.sqrt_one_minus_alpha_bars[1] - 0.75_f64.sqrt()).abs() < 1e-12);
        assert_eq!(schedule.n_steps(), 2);
    }

    #[test]
    fn test_alpha_bar_product_correctness() {
        let schedule = NoiseSchedule::new(&NoiseScheduleType::Linear, 10);
        let mut product = 1.0;
        for i in 0..schedule.alphas.len() {
            product *= schedule.alphas[i];
            assert!(
                (schedule.alpha_bars[i] - product).abs() < 1e-10,
                "Alpha bar mismatch at step {}",
                i
            );
        }
    }

    #[test]
    fn test_sqrt_consistency() {
        let schedule = NoiseSchedule::new(&NoiseScheduleType::Linear, 50);
        for i in 0..schedule.alpha_bars.len() {
            assert!((schedule.sqrt_alpha_bars[i] - schedule.alpha_bars[i].sqrt()).abs() < 1e-10);
            assert!(
                (schedule.sqrt_one_minus_alpha_bars[i] - (1.0 - schedule.alpha_bars[i]).sqrt())
                    .abs()
                    < 1e-10
            );
        }
    }
}