use anyhow::Result;
use candle_core::Tensor;
use candle_transformers::models::stable_diffusion::ddim::DDIMSchedulerConfig;
use candle_transformers::models::stable_diffusion::schedulers::{BetaSchedule, TimestepSpacing};
/// Number of training timesteps the tests expect from `DDIMSchedulerConfig::default()`.
///
/// Compiled only for tests; the `upstream_ddim_defaults_match_baked_in_constants` test
/// compares it against the upstream default so that a change there is noticed.
#[cfg(test)]
const DEFAULT_TRAIN_TIMESTEPS: usize = 1000;
/// Lookup table of cumulative alphas together with the stride between sampling steps.
///
/// Instances are built by `from_default` / `from_config`; both fields are read by
/// `alphas_for_step`.
#[derive(Debug, Clone)]
pub(crate) struct DdimAlphaSchedule {
    /// Running product of `1 - beta` over the training timesteps: entry `t` is the
    /// product of all factors up to and including timestep `t`.
    alphas_cumprod: Vec<f64>,
    /// Training timesteps skipped per sampling step
    /// (`train_timesteps / max(inference_steps, 1)`, rounded down).
    step_ratio: usize,
}
impl DdimAlphaSchedule {
    /// Builds the schedule for `inference_steps` sampling steps from
    /// `DDIMSchedulerConfig::default()`.
    pub(crate) fn from_default(inference_steps: usize) -> Self {
        Self::from_config(inference_steps, &DDIMSchedulerConfig::default())
    }

    /// Builds the schedule for `inference_steps` sampling steps from `cfg`.
    ///
    /// Only `train_timesteps`, `beta_schedule`, `beta_start` and `beta_end` are read
    /// from `cfg`; every other field is ignored here.
    ///
    /// `step_ratio` is `train_timesteps / max(inference_steps, 1)` rounded down, so an
    /// `inference_steps` of 0 is treated like 1. When `inference_steps` exceeds
    /// `train_timesteps` the ratio is 0 and `alphas_for_step` returns the same alpha for
    /// the current and the previous step.
    pub(crate) fn from_config(inference_steps: usize, cfg: &DDIMSchedulerConfig) -> Self {
        let train_timesteps = cfg.train_timesteps;
        let step_ratio = train_timesteps / inference_steps.max(1);

        // No runtime effect: this only references `TimestepSpacing`.
        // NOTE(review): presumably kept so the import stays in use and to record that
        // the floor-division stride above corresponds to "leading" spacing — confirm.
        let _ = TimestepSpacing::Leading;

        let betas: Vec<f64> = match cfg.beta_schedule {
            // Interpolate linearly in sqrt(beta) space, then square back.
            BetaSchedule::ScaledLinear => {
                linspace(cfg.beta_start.sqrt(), cfg.beta_end.sqrt(), train_timesteps)
                    .into_iter()
                    .map(|x| x * x)
                    .collect()
            }
            BetaSchedule::Linear => linspace(cfg.beta_start, cfg.beta_end, train_timesteps),
            BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(train_timesteps, 0.999),
        };

        // Running product of (1 - beta), seeded with 1.0 so entry 0 equals 1 - betas[0].
        let alphas_cumprod: Vec<f64> = betas
            .iter()
            .scan(1.0_f64, |running, &beta| {
                *running *= 1.0 - beta;
                Some(*running)
            })
            .collect();

        Self {
            alphas_cumprod,
            step_ratio,
        }
    }

    /// Returns `(alpha_cumprod[t], alpha_cumprod[t_prev])` for `timestep`.
    ///
    /// `timestep` is clamped to the last table index, so any value at or beyond the
    /// table length maps to the final entry. `t_prev` is `t - step_ratio`, saturating
    /// at 0.
    ///
    /// # Panics
    ///
    /// Panics if the table is empty (the schedule was built with `train_timesteps == 0`).
    pub(crate) fn alphas_for_step(&self, timestep: usize) -> (f64, f64) {
        let last_index = self.alphas_cumprod.len().saturating_sub(1);
        let t = timestep.min(last_index);
        let prev_t = t.saturating_sub(self.step_ratio);
        (self.alphas_cumprod[t], self.alphas_cumprod[prev_t])
    }

    /// Performs one sampling update that uses two different noise estimates.
    ///
    /// With `a_t` / `a_prev` taken from `alphas_for_step(timestep)`:
    ///
    /// * `x0   = (x_t - sqrt(1 - a_t) * eps_guided) / sqrt(a_t)`
    /// * `prev = sqrt(a_prev) * x0 + sqrt(1 - a_prev) * eps_uncond`
    ///
    /// The guided estimate is used only to predict `x0`; the unconditional estimate is
    /// used only to re-noise it. Passing the same tensor for both reduces this to the
    /// single-estimate form of the update.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying tensor arithmetic.
    ///
    /// # Panics
    ///
    /// Panics under the same condition as `alphas_for_step`.
    pub(crate) fn cfg_plus_step(
        &self,
        x_t: &Tensor,
        eps_guided: &Tensor,
        eps_uncond: &Tensor,
        timestep: usize,
    ) -> Result<Tensor> {
        let (alpha_t, alpha_t_prev) = self.alphas_for_step(timestep);
        let beta_t = 1.0 - alpha_t;
        let beta_t_prev = 1.0 - alpha_t_prev;

        // Predicted clean sample from the guided noise estimate.
        let guided_noise = (eps_guided * beta_t.sqrt())?;
        let x0 = ((x_t - guided_noise)? * (1.0 / alpha_t.sqrt()))?;

        // Re-noise towards the previous timestep with the unconditional estimate.
        let scaled_x0 = (x0 * alpha_t_prev.sqrt())?;
        let renoise = (eps_uncond * beta_t_prev.sqrt())?;
        let prev = (scaled_x0 + renoise)?;
        Ok(prev)
    }
}
/// Returns `n` evenly spaced values from `start` to `end`, both ends included.
///
/// `n == 0` yields an empty vector and `n == 1` yields just `start`. For larger `n`
/// element `i` is `start + i * (end - start) / (n - 1)`.
fn linspace(start: f64, end: f64, n: usize) -> Vec<f64> {
    match n {
        0 => Vec::new(),
        1 => vec![start],
        _ => {
            // n - 1 intervals separate n points.
            let stride = (end - start) / (n - 1) as f64;
            let mut points = Vec::with_capacity(n);
            for idx in 0..n {
                points.push(start + stride * idx as f64);
            }
            points
        }
    }
}
/// Derives a beta schedule from the squared-cosine cumulative-alpha curve
/// `alpha_bar(t) = cos((t + 0.008) / 1.008 * pi / 2)^2`, with `t` in `[0, 1]`.
///
/// Entry `i` is `1 - alpha_bar((i + 1) / n) / alpha_bar(i / n)`, capped at `max_beta`,
/// where `n` is `num_diffusion_timesteps`. An `n` of 0 yields an empty vector.
///
/// The time fractions are computed in `f64`. Dividing the `usize` indices directly
/// truncates every fraction to 0 (or to 1 for the very last `t2`), which collapses the
/// schedule to all zeros followed by a single `max_beta`.
///
/// NOTE(review): if bit-for-bit parity with another implementation's cosine schedule is
/// required, confirm that it also uses fractional time.
fn betas_for_alpha_bar(num_diffusion_timesteps: usize, max_beta: f64) -> Vec<f64> {
    let alpha_bar =
        |t: f64| f64::cos((t + 0.008) / 1.008 * std::f64::consts::FRAC_PI_2).powi(2);
    let total = num_diffusion_timesteps as f64;
    (0..num_diffusion_timesteps)
        .map(|i| {
            let t1 = i as f64 / total;
            let t2 = (i + 1) as f64 / total;
            (1.0 - alpha_bar(t2) / alpha_bar(t1)).min(max_beta)
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    /// Builds a `(1, values.len())` `f32` tensor on the CPU.
    fn row(values: &[f32]) -> Tensor {
        Tensor::from_slice(values, (1, values.len()), &Device::Cpu).unwrap()
    }

    /// Flattens `tensor` and copies its elements out as `f32`.
    fn flat(tensor: &Tensor) -> Vec<f32> {
        tensor.flatten_all().unwrap().to_vec1().unwrap()
    }

    /// Pins the upstream default configuration values this file's tests rely on.
    #[test]
    fn upstream_ddim_defaults_match_baked_in_constants() {
        let defaults = DDIMSchedulerConfig::default();
        assert_eq!(defaults.train_timesteps, DEFAULT_TRAIN_TIMESTEPS);
        assert_eq!(defaults.beta_start, 0.00085);
        assert_eq!(defaults.beta_end, 0.012);
        assert!(matches!(defaults.beta_schedule, BetaSchedule::ScaledLinear));
        assert_eq!(defaults.eta, 0.0);
    }

    /// The cumulative-alpha table starts just below 1, strictly decreases, stays
    /// positive and ends close to 0.
    #[test]
    fn alphas_cumprod_monotone_decreasing() {
        let sched = DdimAlphaSchedule::from_default(50);
        let table = &sched.alphas_cumprod;
        assert_eq!(table.len(), DEFAULT_TRAIN_TIMESTEPS);

        let prev = table[0];
        assert!(
            prev < 1.0 && prev > 0.999,
            "alphas[0] should be ~ 1 - beta_start, got {prev}"
        );

        for pair in table.windows(2) {
            let (before, a) = (pair[0], pair[1]);
            assert!(a < before, "alphas_cumprod must be strictly decreasing");
            assert!(a > 0.0, "alphas_cumprod must stay positive");
        }

        let prev = table[table.len() - 1];
        assert!(
            prev < 0.01,
            "alphas_cumprod[final] should be ≈ 0, got {prev}"
        );
    }

    /// `step_ratio` is the floor of `train_timesteps / inference_steps`.
    #[test]
    fn step_ratio_floor_div() {
        for (steps, expected) in [(50_usize, 20_usize), (28, 35), (1, 1000)] {
            assert_eq!(DdimAlphaSchedule::from_default(steps).step_ratio, expected);
        }
    }

    /// Lookups at the top of the range, near zero, and one past the end.
    #[test]
    fn alphas_for_step_clamps_at_boundaries() {
        let sched = DdimAlphaSchedule::from_default(50);
        let table = &sched.alphas_cumprod;

        // (timestep, expected current index, expected previous index)
        for (timestep, cur_idx, prev_idx) in [(999_usize, 999_usize, 979_usize), (19, 19, 0)] {
            let (a, ap) = sched.alphas_for_step(timestep);
            assert_eq!(a, table[cur_idx]);
            assert_eq!(ap, table[prev_idx]);
        }

        // One past the end resolves to the final entry.
        let (a, _) = sched.alphas_for_step(DEFAULT_TRAIN_TIMESTEPS);
        assert_eq!(a, table[DEFAULT_TRAIN_TIMESTEPS - 1]);
    }

    /// With identical guided and unconditional noise the step equals the
    /// single-estimate update computed by hand.
    #[test]
    fn cfg_plus_step_collapses_to_standard_when_eps_uncond_eq_eps_guided() {
        let sched = DdimAlphaSchedule::from_default(50);
        let x = row(&[1.0f32, 2.0, 3.0, 4.0]);
        let eps = row(&[0.5f32, -0.3, 0.1, 0.7]);
        let timestep = 999;

        let cfg_plus = sched.cfg_plus_step(&x, &eps, &eps, timestep).unwrap();

        // Reference computed directly from the alphas.
        let (alpha_t, alpha_t_prev) = sched.alphas_for_step(timestep);
        let beta_t = 1.0 - alpha_t;
        let beta_t_prev = 1.0 - alpha_t_prev;
        let scaled_eps = (&eps * beta_t.sqrt()).unwrap();
        let x0 = ((&x - scaled_eps).unwrap() * (1.0 / alpha_t.sqrt())).unwrap();
        let scaled_x0 = (x0 * alpha_t_prev.sqrt()).unwrap();
        let renoise = (&eps * beta_t_prev.sqrt()).unwrap();
        let standard = (scaled_x0 + renoise).unwrap();

        let cfg_vec = flat(&cfg_plus);
        let std_vec = flat(&standard);
        for (a, b) in cfg_vec.iter().zip(std_vec.iter()) {
            assert!(
                (a - b).abs() < 1e-5,
                "cfg++ ≠ standard at degenerate eps_uncond=eps_guided"
            );
        }
    }

    /// With a large guidance scale, re-noising with the unconditional estimate gives a
    /// visibly different result from re-noising with the guided one.
    #[test]
    fn cfg_plus_step_diverges_from_standard_under_high_cfg() {
        let sched = DdimAlphaSchedule::from_default(28);
        let x = row(&[0.5f32; 8]);
        let eps_uncond = row(&[0.1f32; 8]);
        let eps_cond = row(&[0.4f32; 8]);
        let s = 7.5_f64;

        // eps_guided = eps_uncond + s * (eps_cond - eps_uncond)
        let scaled_delta = ((&eps_cond - &eps_uncond).unwrap() * s).unwrap();
        let eps_guided = (&eps_uncond + scaled_delta).unwrap();
        let timestep = 999;

        let cfg_plus = sched
            .cfg_plus_step(&x, &eps_guided, &eps_uncond, timestep)
            .unwrap();
        let standard = sched
            .cfg_plus_step(&x, &eps_guided, &eps_guided, timestep)
            .unwrap();

        let cfg_vec = flat(&cfg_plus);
        let std_vec = flat(&standard);
        let max_diff = cfg_vec
            .iter()
            .zip(std_vec.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);
        assert!(
            max_diff > 0.05,
            "cfg++ must diverge from standard at cfg=7.5, max_diff={max_diff}"
        );
    }

    /// The step produces finite values at the highest timestep and near zero.
    #[test]
    fn cfg_plus_step_finite_at_boundary_timesteps() {
        let sched = DdimAlphaSchedule::from_default(50);
        let x = row(&[0.5f32; 4]);
        let eps_g = row(&[0.3f32; 4]);
        let eps_u = row(&[0.1f32; 4]);

        for ts in [999_usize, 19] {
            let out = sched.cfg_plus_step(&x, &eps_g, &eps_u, ts).unwrap();
            for value in flat(&out) {
                assert!(
                    value.is_finite(),
                    "cfg_plus_step produced non-finite output at timestep {ts}"
                );
            }
        }
    }
}