use super::backend::NoiseScheduleType;
/// Precomputed per-step noise-schedule tables.
///
/// Every table is indexed by step `t` and is derived from `betas`:
///
/// * `alphas[t] = 1 - betas[t]`
/// * `alpha_bars[t] = alphas[0] * alphas[1] * ... * alphas[t]`
/// * `sqrt_alpha_bars[t] = sqrt(alpha_bars[t])`
/// * `sqrt_one_minus_alpha_bars[t] = sqrt(1 - alpha_bars[t])`
///
/// When built through [`NoiseSchedule::new`] or [`NoiseSchedule::from_betas`],
/// all five tables have the same length.
#[derive(Debug, Clone)]
pub struct NoiseSchedule {
    /// Per-step beta values.
    pub betas: Vec<f64>,
    /// Per-step `1 - beta`.
    pub alphas: Vec<f64>,
    /// Running product of `alphas` up to and including each step.
    pub alpha_bars: Vec<f64>,
    /// Element-wise square root of `alpha_bars`.
    pub sqrt_alpha_bars: Vec<f64>,
    /// Element-wise square root of `1 - alpha_bars`.
    pub sqrt_one_minus_alpha_bars: Vec<f64>,
}

impl NoiseSchedule {
    /// Builds a schedule with `n_steps` steps whose betas follow `schedule_type`.
    ///
    /// `n_steps == 0` produces a schedule whose tables are all empty.
    pub fn new(schedule_type: &NoiseScheduleType, n_steps: usize) -> Self {
        let make_betas: fn(usize) -> Vec<f64> = match schedule_type {
            NoiseScheduleType::Linear => Self::linear_schedule,
            NoiseScheduleType::Cosine => Self::cosine_schedule,
            NoiseScheduleType::Sigmoid => Self::sigmoid_schedule,
        };
        Self::from_betas(make_betas(n_steps))
    }

    /// Derives every table from an explicit list of betas (one per step).
    ///
    /// The betas are not validated. Values outside `[0, 1]` yield alphas outside
    /// `[0, 1]`, and the square-root tables may then contain NaN.
    pub fn from_betas(betas: Vec<f64>) -> Self {
        let alphas: Vec<f64> = betas.iter().map(|&beta| 1.0 - beta).collect();

        // Running product, seeded with 1.0 so entry t covers alphas[0..=t].
        let alpha_bars: Vec<f64> = alphas
            .iter()
            .scan(1.0_f64, |running, &alpha| {
                *running *= alpha;
                Some(*running)
            })
            .collect();

        let sqrt_alpha_bars: Vec<f64> = alpha_bars.iter().map(|ab| ab.sqrt()).collect();
        let sqrt_one_minus_alpha_bars: Vec<f64> =
            alpha_bars.iter().map(|ab| (1.0 - ab).sqrt()).collect();

        Self {
            betas,
            alphas,
            alpha_bars,
            sqrt_alpha_bars,
            sqrt_one_minus_alpha_bars,
        }
    }

    /// Divisor that maps step indices `0..n_steps` onto `[0, 1]`.
    ///
    /// Clamped to at least 1, so a single-step schedule does not divide by zero.
    /// `saturating_sub` keeps `n_steps == 0` from underflowing, even though no
    /// step is generated in that case.
    fn last_index_divisor(n_steps: usize) -> f64 {
        n_steps.saturating_sub(1).max(1) as f64
    }

    /// Betas spaced evenly from `0.0001` at the first step to `0.02` at the last.
    ///
    /// A single-step schedule contains only `0.0001`.
    fn linear_schedule(n_steps: usize) -> Vec<f64> {
        const BETA_START: f64 = 0.0001;
        const BETA_END: f64 = 0.02;
        let span = BETA_END - BETA_START;
        let divisor = Self::last_index_divisor(n_steps);
        (0..n_steps)
            .map(|step| BETA_START + span * step as f64 / divisor)
            .collect()
    }

    /// Betas derived from a cosine `alpha_bar` curve with offset `s = 0.008`.
    ///
    /// The curve is `f(t) = cos²(((t / n_steps) + s) / (1 + s) · π/2)`, the form
    /// used by Nichol & Dhariwal. It is sampled at `n_steps + 1` points, and each
    /// beta is `1 - f(t) / f(t - 1)`, clamped to `[0.0001, 0.999]`. The upper
    /// clamp matters at the final step, where `f` is ~0 and the raw beta would
    /// be ~1.
    fn cosine_schedule(n_steps: usize) -> Vec<f64> {
        const OFFSET: f64 = 0.008;
        let curve: Vec<f64> = (0..=n_steps)
            .map(|step| {
                let t = step as f64 / n_steps as f64;
                ((t + OFFSET) / (1.0 + OFFSET) * std::f64::consts::FRAC_PI_2)
                    .cos()
                    .powi(2)
            })
            .collect();
        // Each adjacent pair (f(t-1), f(t)) yields one beta. A single sample
        // (n_steps == 0) has no pairs and gives an empty schedule.
        curve
            .windows(2)
            .map(|pair| (1.0 - pair[1] / pair[0]).clamp(0.0001, 0.999))
            .collect()
    }

    /// Betas following a logistic curve over evenly spaced points in `[-6, 6]`.
    ///
    /// The logistic output is rescaled into the range `0.0001..0.02`.
    fn sigmoid_schedule(n_steps: usize) -> Vec<f64> {
        const BETA_START: f64 = 0.0001;
        const BETA_END: f64 = 0.02;
        const RANGE_START: f64 = -6.0;
        const RANGE_END: f64 = 6.0;
        let divisor = Self::last_index_divisor(n_steps);
        let logistic = |x: f64| 1.0 / (1.0 + (-x).exp());
        (0..n_steps)
            .map(|step| {
                let x = RANGE_START + (RANGE_END - RANGE_START) * step as f64 / divisor;
                BETA_START + (BETA_END - BETA_START) * logistic(x)
            })
            .collect()
    }

    /// Number of steps in the schedule, i.e. `betas.len()`.
    pub fn n_steps(&self) -> usize {
        self.betas.len()
    }
}
#[cfg(test)]
// Tests may `unwrap` freely. Presumably `clippy::unwrap_used` is denied
// elsewhere in the crate; confirm against the crate-level lint config.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Every schedule shape, paired with a label for assertion messages.
    fn all_schedule_types() -> [(&'static str, NoiseScheduleType); 3] {
        [
            ("linear", NoiseScheduleType::Linear),
            ("cosine", NoiseScheduleType::Cosine),
            ("sigmoid", NoiseScheduleType::Sigmoid),
        ]
    }

    #[test]
    fn test_linear_schedule_monotonic_betas() {
        let schedule = NoiseSchedule::new(&NoiseScheduleType::Linear, 100);
        for i in 1..schedule.betas.len() {
            assert!(
                schedule.betas[i] >= schedule.betas[i - 1],
                "Linear betas should be monotonically increasing"
            );
        }
    }

    #[test]
    fn test_linear_schedule_endpoints() {
        let schedule = NoiseSchedule::new(&NoiseScheduleType::Linear, 100);
        assert_eq!(schedule.betas[0], 0.0001, "First linear beta should be beta_start");
        let last = *schedule.betas.last().unwrap();
        assert!(
            (last - 0.02).abs() < 1e-12,
            "Last linear beta should be beta_end, got {}",
            last
        );
    }

    #[test]
    fn test_single_step_linear_is_beta_start() {
        let schedule = NoiseSchedule::new(&NoiseScheduleType::Linear, 1);
        assert_eq!(schedule.betas, vec![0.0001]);
    }

    #[test]
    fn test_cosine_schedule_alpha_bar_decreasing() {
        let schedule = NoiseSchedule::new(&NoiseScheduleType::Cosine, 100);
        assert!(
            schedule.alpha_bars[0] > 0.9,
            "First alpha_bar should be near 1.0"
        );
        assert!(
            schedule.alpha_bars.last().copied().unwrap_or(1.0) < 0.1,
            "Last alpha_bar should be near 0.0"
        );
        for i in 1..schedule.alpha_bars.len() {
            assert!(
                schedule.alpha_bars[i] <= schedule.alpha_bars[i - 1],
                "Alpha bars should be monotonically decreasing"
            );
        }
    }

    #[test]
    fn test_cosine_betas_clamped() {
        for n in [1, 10, 1000] {
            let schedule = NoiseSchedule::new(&NoiseScheduleType::Cosine, n);
            for &beta in &schedule.betas {
                assert!(
                    (0.0001..=0.999).contains(&beta),
                    "Cosine betas should be within [0.0001, 0.999], got {} (n = {})",
                    beta,
                    n
                );
            }
        }
    }

    #[test]
    fn test_sigmoid_schedule_bounded() {
        let schedule = NoiseSchedule::new(&NoiseScheduleType::Sigmoid, 100);
        for &beta in &schedule.betas {
            assert!(
                (0.0001..=0.02).contains(&beta),
                "Sigmoid betas should be within [0.0001, 0.02], got {}",
                beta
            );
        }
    }

    #[test]
    fn test_sigmoid_betas_non_decreasing() {
        let schedule = NoiseSchedule::new(&NoiseScheduleType::Sigmoid, 100);
        for pair in schedule.betas.windows(2) {
            assert!(
                pair[1] >= pair[0],
                "Sigmoid betas should be non-decreasing, got {} then {}",
                pair[0],
                pair[1]
            );
        }
    }

    #[test]
    fn test_schedule_lengths() {
        for (label, schedule_type) in all_schedule_types() {
            for n in [1, 10, 100, 1000] {
                let schedule = NoiseSchedule::new(&schedule_type, n);
                assert_eq!(schedule.betas.len(), n, "{} betas", label);
                assert_eq!(schedule.alphas.len(), n, "{} alphas", label);
                assert_eq!(schedule.alpha_bars.len(), n, "{} alpha_bars", label);
                assert_eq!(schedule.sqrt_alpha_bars.len(), n, "{} sqrt_alpha_bars", label);
                assert_eq!(
                    schedule.sqrt_one_minus_alpha_bars.len(),
                    n,
                    "{} sqrt_one_minus_alpha_bars",
                    label
                );
                assert_eq!(schedule.n_steps(), n, "{} n_steps()", label);
            }
        }
    }

    #[test]
    fn test_zero_steps_yields_empty_schedule() {
        for (label, schedule_type) in all_schedule_types() {
            let schedule = NoiseSchedule::new(&schedule_type, 0);
            assert!(schedule.betas.is_empty(), "{} betas should be empty", label);
            assert!(schedule.alpha_bars.is_empty(), "{} alpha_bars should be empty", label);
            assert_eq!(schedule.n_steps(), 0, "{} n_steps()", label);
        }
    }

    #[test]
    fn test_alpha_bar_product_correctness() {
        let schedule = NoiseSchedule::new(&NoiseScheduleType::Linear, 10);
        let mut product = 1.0;
        for i in 0..schedule.alphas.len() {
            product *= schedule.alphas[i];
            assert!(
                (schedule.alpha_bars[i] - product).abs() < 1e-10,
                "Alpha bar mismatch at step {}",
                i
            );
        }
    }

    #[test]
    fn test_from_betas_known_values() {
        let schedule = NoiseSchedule::from_betas(vec![0.1, 0.2, 0.5]);
        let expected_alphas = [0.9, 0.8, 0.5];
        let expected_alpha_bars = [0.9, 0.72, 0.36];
        for (got, want) in schedule.alphas.iter().zip(expected_alphas) {
            assert!((got - want).abs() < 1e-12, "alpha {} != {}", got, want);
        }
        for (got, want) in schedule.alpha_bars.iter().zip(expected_alpha_bars) {
            assert!((got - want).abs() < 1e-12, "alpha_bar {} != {}", got, want);
        }
        // sqrt(0.36) = 0.6 and sqrt(1 - 0.36) = 0.8 pin both sqrt tables.
        assert!((schedule.sqrt_alpha_bars[2] - 0.6).abs() < 1e-12);
        assert!((schedule.sqrt_one_minus_alpha_bars[2] - 0.8).abs() < 1e-12);
        assert_eq!(schedule.n_steps(), 3);
    }

    #[test]
    fn test_sqrt_consistency() {
        let schedule = NoiseSchedule::new(&NoiseScheduleType::Linear, 50);
        for i in 0..schedule.alpha_bars.len() {
            assert!((schedule.sqrt_alpha_bars[i] - schedule.alpha_bars[i].sqrt()).abs() < 1e-10);
            assert!(
                (schedule.sqrt_one_minus_alpha_bars[i] - (1.0 - schedule.alpha_bars[i]).sqrt())
                    .abs()
                    < 1e-10
            );
        }
    }
}