// oxigaf_diffusion/scheduler.rs
//! DDIM scheduler with v-prediction parameterisation.
//!
//! Implements the deterministic DDIM sampling loop used by Stable Diffusion 2.1
//! and the GAF multi-view diffusion model.

use candle_core::{DType, Device, Result, Tensor};

/// DDIM scheduler state.
///
/// Holds the precomputed noise schedule (`alphas_cumprod`) plus the
/// inference-time timestep grid produced by `set_timesteps`.
#[derive(Debug)]
pub struct DdimScheduler {
    /// Cumulative product of (1 - beta_t), indexed by training timestep;
    /// starts near 1 and decreases monotonically towards 0.
    alphas_cumprod: Vec<f64>,
    /// Total training timesteps (length of `alphas_cumprod`).
    num_train_timesteps: usize,
    /// Inference timesteps (reversed, evenly spaced). Empty until
    /// `set_timesteps` is called.
    timesteps: Vec<usize>,
    /// Whether the model predicts v (velocity) rather than noise.
    prediction_type: PredictionType,
}

/// What the model predicts at each denoising step.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PredictionType {
    /// Model predicts the noise ε added by the forward process.
    Epsilon,
    /// Model predicts the velocity v = α_t · ε − σ_t · x_0, where
    /// α_t = sqrt(alpha_cumprod_t) and σ_t = sqrt(1 − alpha_cumprod_t)
    /// (this is the convention used by `DdimScheduler::step`).
    VPrediction,
}

30impl DdimScheduler {
31    /// Create a new DDIM scheduler.
32    ///
33    /// Uses a "scaled linear" beta schedule matching SD 2.1 defaults:
34    /// `beta_start=0.00085`, `beta_end=0.012`, 1000 training steps.
35    pub fn new(num_train_timesteps: usize, prediction_type: PredictionType) -> Self {
36        let beta_start: f64 = 0.00085_f64.sqrt();
37        let beta_end: f64 = 0.012_f64.sqrt();
38
39        let mut alphas_cumprod = Vec::with_capacity(num_train_timesteps);
40        let mut cumprod = 1.0_f64;
41        for i in 0..num_train_timesteps {
42            let beta = beta_start
43                + (beta_end - beta_start) * (i as f64) / ((num_train_timesteps - 1) as f64);
44            let beta = beta * beta; // scaled-linear
45            let alpha = 1.0 - beta;
46            cumprod *= alpha;
47            alphas_cumprod.push(cumprod);
48        }
49
50        Self {
51            alphas_cumprod,
52            num_train_timesteps,
53            timesteps: Vec::new(),
54            prediction_type,
55        }
56    }
57
58    /// Configure evenly-spaced timesteps for a given number of inference steps.
59    pub fn set_timesteps(&mut self, num_inference_steps: usize) {
60        let step = self.num_train_timesteps / num_inference_steps;
61        self.timesteps = (0..num_inference_steps).rev().map(|i| i * step).collect();
62    }
63
64    /// Return the current list of timesteps (descending).
65    pub fn timesteps(&self) -> &[usize] {
66        &self.timesteps
67    }
68
69    /// Perform one DDIM step (deterministic, η=0).
70    ///
71    /// - `model_output`: the raw network prediction at timestep `t`.
72    /// - `t`: current timestep index.
73    /// - `sample`: the current noisy latent x_t.
74    ///
75    /// Returns the denoised latent x_{t-1}.
76    pub fn step(&self, model_output: &Tensor, t: usize, sample: &Tensor) -> Result<Tensor> {
77        let alpha_prod_t = self.alphas_cumprod[t];
78        let alpha_prod_t_prev = if t > 0 {
79            // find previous timestep
80            let step = self.num_train_timesteps / self.timesteps.len();
81            if t >= step {
82                self.alphas_cumprod[t - step]
83            } else {
84                1.0
85            }
86        } else {
87            1.0
88        };
89
90        let sqrt_alpha_prod = alpha_prod_t.sqrt();
91        let sqrt_one_minus_alpha_prod = (1.0 - alpha_prod_t).sqrt();
92
93        // Recover x_0 prediction depending on parameterisation
94        let pred_x0 = match self.prediction_type {
95            PredictionType::Epsilon => {
96                // x_0 = (x_t - sqrt(1-α) * ε) / sqrt(α)
97                ((sample - (model_output * sqrt_one_minus_alpha_prod)?)? * (1.0 / sqrt_alpha_prod))?
98            }
99            PredictionType::VPrediction => {
100                // x_0 = sqrt(α) * x_t - sqrt(1-α) * v
101                ((sample * sqrt_alpha_prod)? - (model_output * sqrt_one_minus_alpha_prod)?)?
102            }
103        };
104
105        // Predict noise direction
106        let pred_epsilon = match self.prediction_type {
107            PredictionType::Epsilon => model_output.clone(),
108            PredictionType::VPrediction => {
109                ((model_output * sqrt_alpha_prod)? + (sample * sqrt_one_minus_alpha_prod)?)?
110            }
111        };
112
113        // DDIM deterministic step (η = 0)
114        let sqrt_alpha_prod_prev = alpha_prod_t_prev.sqrt();
115        let sqrt_one_minus_alpha_prod_prev = (1.0 - alpha_prod_t_prev).sqrt();
116
117        // x_{t-1} = sqrt(α_{t-1}) · x_0 + sqrt(1-α_{t-1}) · ε
118        (&pred_x0 * sqrt_alpha_prod_prev)? + (&pred_epsilon * sqrt_one_minus_alpha_prod_prev)?
119    }
120
121    /// Add noise to latents for a given timestep (forward diffusion process).
122    ///
123    /// x_t = sqrt(α_t) · x_0 + sqrt(1-α_t) · noise
124    pub fn add_noise(&self, original: &Tensor, noise: &Tensor, timestep: usize) -> Result<Tensor> {
125        let alpha = self.alphas_cumprod[timestep];
126        let sqrt_alpha = alpha.sqrt();
127        let sqrt_one_minus_alpha = (1.0 - alpha).sqrt();
128        (original * sqrt_alpha)? + (noise * sqrt_one_minus_alpha)?
129    }
130
131    /// Create a tensor of timestep values on the given device.
132    pub fn timestep_tensor(&self, t: usize, batch_size: usize, device: &Device) -> Result<Tensor> {
133        Tensor::full(t as f32, (batch_size,), device)?.to_dtype(DType::F32)
134    }
135}
136
137#[cfg(test)]
138mod tests {
139    use super::*;
140
141    #[test]
142    fn test_alphas_cumprod_decreasing() {
143        let sched = DdimScheduler::new(1000, PredictionType::VPrediction);
144        assert!(sched.alphas_cumprod[0] > sched.alphas_cumprod[999]);
145        // First alpha should be close to 1
146        assert!(sched.alphas_cumprod[0] > 0.99);
147        // Last alpha should be small
148        assert!(sched.alphas_cumprod[999] < 0.01);
149    }
150
151    #[test]
152    fn test_set_timesteps() {
153        let mut sched = DdimScheduler::new(1000, PredictionType::Epsilon);
154        sched.set_timesteps(50);
155        assert_eq!(sched.timesteps().len(), 50);
156        // Should be descending
157        assert!(sched.timesteps()[0] > sched.timesteps()[49]);
158    }
159}