Skip to main content

ferrotorch_diffusion/
scheduler.rs

//! Deterministic DDIM scheduler matching `diffusers.schedulers.DDIMScheduler`
//! for the Stable-Diffusion-1.5 sampling defaults.
//!
//! Phase F of real-artifact-driven development (#1163). The scheduler is
//! the missing fourth component of the SD generation pipeline:
//!
//! ```text
//! CLIP text encoder + UNet noise predictor + DDIM scheduler + VAE decoder
//!                                            ^^^^^^^^^^^^^^^
//!                                            this module
//! ```
//!
//! Matches `diffusers.schedulers.DDIMScheduler` for the SD-1.5 defaults
//! (`scaled_linear` beta schedule, `epsilon` prediction, `leading`
//! timestep spacing, `clip_sample=false`, `set_alpha_to_one=false`,
//! `init_noise_sigma=1.0`) with η=0, i.e. deterministic sampling. Default
//! values are mirrored byte-for-byte from the upstream defaults in
//! `diffusers/schedulers/scheduling_ddim.py` as of `diffusers==0.38.0`.
19
20use ferrotorch_core::grad_fns::arithmetic::{add, mul, sub};
21use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
22
/// How the per-timestep `betas` are laid out across the training grid.
///
/// Only the two recipes Stable Diffusion 1.5 relies on are modelled.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BetaSchedule {
    /// Interpolate linearly between the *square roots* of the endpoints,
    /// then square: `linspace(sqrt(beta_start), sqrt(beta_end), N)^2`.
    /// This is the SD default.
    ScaledLinear,
    /// Plain linear interpolation: `linspace(beta_start, beta_end, N)`.
    Linear,
}
31
32/// Discrete timestep spacing (subset matching SD-1.5).
33#[derive(Debug, Clone, Copy, PartialEq, Eq)]
34pub enum TimestepSpacing {
35    /// `step_ratio = num_train_timesteps // num_inference_steps;
36    /// timesteps = arange(0, num_inference_steps) * step_ratio` then reversed.
37    /// SD default.
38    Leading,
39    /// `step_ratio = num_train_timesteps / num_inference_steps;
40    /// timesteps = arange(num_inference_steps) * step_ratio` (rounded) then reversed.
41    Linspace,
42}
43
44/// `DDIMScheduler` configuration (the subset that affects forward math).
45#[derive(Debug, Clone)]
46pub struct DDIMConfig {
47    /// `num_train_timesteps` — SD-1.5: 1000.
48    pub num_train_timesteps: usize,
49    /// `beta_start` — SD-1.5: 0.00085.
50    pub beta_start: f64,
51    /// `beta_end` — SD-1.5: 0.012.
52    pub beta_end: f64,
53    /// `beta_schedule` — SD-1.5: ScaledLinear.
54    pub beta_schedule: BetaSchedule,
55    /// `clip_sample` — SD-1.5: false (no clipping of predicted x0).
56    pub clip_sample: bool,
57    /// `set_alpha_to_one` — SD-1.5: false. When false, the
58    /// final-step `alpha_prev` is `alphas_cumprod[0]` (matches diffusers).
59    pub set_alpha_to_one: bool,
60    /// `prediction_type` — SD-1.5: "epsilon" (the UNet predicts the noise).
61    /// Only `"epsilon"` is implemented; anything else returns an error
62    /// at `set_timesteps` time.
63    pub prediction_type: PredictionType,
64    /// `timestep_spacing` — SD-1.5: Leading.
65    pub timestep_spacing: TimestepSpacing,
66    /// `steps_offset` — SD-1.5: 1 (diffusers adds this offset on the
67    /// `leading` path so the first inference step is `step_ratio` rather
68    /// than 0; consumed inside [`DDIMScheduler::set_timesteps`]).
69    pub steps_offset: usize,
70}
71
/// What quantity the denoising network is trained to output.
///
/// Only the parameterisation Stable Diffusion 1.5 uses is modelled here.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PredictionType {
    /// The network outputs the added noise ε (SD default).
    Epsilon,
}
78
79impl Default for DDIMConfig {
80    fn default() -> Self {
81        Self {
82            num_train_timesteps: 1000,
83            beta_start: 0.000_85,
84            beta_end: 0.012,
85            beta_schedule: BetaSchedule::ScaledLinear,
86            clip_sample: false,
87            set_alpha_to_one: false,
88            prediction_type: PredictionType::Epsilon,
89            timestep_spacing: TimestepSpacing::Leading,
90            steps_offset: 1,
91        }
92    }
93}
94
95impl DDIMConfig {
96    /// SD-1.5 defaults (alias for [`Default::default`]).
97    pub fn sd_v1_5() -> Self {
98        Self::default()
99    }
100}
101
102/// Deterministic DDIM scheduler (η=0, no noise injection).
103///
104/// Pre-computes `betas`, `alphas`, and `alphas_cumprod` over the full
105/// training-time grid (`num_train_timesteps` entries). Inference picks
106/// a subset of timesteps via [`DDIMScheduler::set_timesteps`] and walks
107/// them in reverse with [`DDIMScheduler::step`].
108#[derive(Debug, Clone)]
109pub struct DDIMScheduler {
110    config: DDIMConfig,
111    /// `alphas_cumprod` of length `num_train_timesteps`.
112    alphas_cumprod: Vec<f64>,
113    /// `final_alpha_cumprod` — used when stepping into prev_timestep < 0
114    /// (the very last denoising step). When `set_alpha_to_one=false` this
115    /// is `alphas_cumprod[0]`; when true it's 1.0.
116    final_alpha_cumprod: f64,
117    /// Timesteps the user requested via [`Self::set_timesteps`], in the
118    /// order they will be consumed (descending). Empty until
119    /// `set_timesteps` is called.
120    timesteps: Vec<usize>,
121}
122
123impl DDIMScheduler {
124    /// Build a scheduler with the given config; runs the one-shot
125    /// `betas → alphas → alphas_cumprod` precomputation.
126    ///
127    /// # Errors
128    ///
129    /// Returns [`FerrotorchError::InvalidArgument`] for malformed
130    /// configurations (zero training steps, non-positive beta).
131    pub fn new(config: DDIMConfig) -> FerrotorchResult<Self> {
132        if config.num_train_timesteps == 0 {
133            return Err(FerrotorchError::InvalidArgument {
134                message: "DDIMScheduler::new: num_train_timesteps must be > 0".into(),
135            });
136        }
137        if !config.beta_start.is_finite() || !config.beta_end.is_finite() {
138            return Err(FerrotorchError::InvalidArgument {
139                message: format!(
140                    "DDIMScheduler::new: non-finite betas (beta_start={}, beta_end={})",
141                    config.beta_start, config.beta_end
142                ),
143            });
144        }
145        let n = config.num_train_timesteps;
146        let betas = compute_betas(config.beta_schedule, config.beta_start, config.beta_end, n);
147        // alphas[i] = 1 - betas[i].
148        // alphas_cumprod[i] = prod_{j=0..=i} alphas[j].
149        let mut alphas_cumprod = Vec::with_capacity(n);
150        let mut acc = 1.0_f64;
151        for &b in &betas {
152            let a = 1.0 - b;
153            acc *= a;
154            alphas_cumprod.push(acc);
155        }
156        let final_alpha_cumprod = if config.set_alpha_to_one {
157            1.0
158        } else {
159            alphas_cumprod[0]
160        };
161        Ok(Self {
162            config,
163            alphas_cumprod,
164            final_alpha_cumprod,
165            timesteps: Vec::new(),
166        })
167    }
168
169    /// Read-only access to the frozen configuration.
170    pub fn config(&self) -> &DDIMConfig {
171        &self.config
172    }
173
174    /// `init_noise_sigma` — the multiplier applied to the initial Gaussian
175    /// noise tensor before the first denoising step. DDIM with the SD-1.5
176    /// defaults uses 1.0 (no scaling).
177    pub fn init_noise_sigma(&self) -> f64 {
178        1.0
179    }
180
181    /// Set the inference-time discrete timesteps and return them.
182    ///
183    /// Matches `diffusers.schedulers.DDIMScheduler.set_timesteps` exactly
184    /// for the SD-1.5 defaults (Leading spacing + `steps_offset=1`).
185    ///
186    /// # Errors
187    ///
188    /// Returns [`FerrotorchError::InvalidArgument`] if `num_inference_steps`
189    /// is zero or exceeds `num_train_timesteps`, or if the configured
190    /// `prediction_type` is anything other than [`PredictionType::Epsilon`].
191    pub fn set_timesteps(&mut self, num_inference_steps: usize) -> FerrotorchResult<&[usize]> {
192        if num_inference_steps == 0 {
193            return Err(FerrotorchError::InvalidArgument {
194                message: "DDIMScheduler::set_timesteps: num_inference_steps must be > 0".into(),
195            });
196        }
197        if num_inference_steps > self.config.num_train_timesteps {
198            return Err(FerrotorchError::InvalidArgument {
199                message: format!(
200                    "DDIMScheduler::set_timesteps: num_inference_steps {num_inference_steps} \
201                     must be <= num_train_timesteps {}",
202                    self.config.num_train_timesteps
203                ),
204            });
205        }
206        if self.config.prediction_type != PredictionType::Epsilon {
207            return Err(FerrotorchError::InvalidArgument {
208                message: "DDIMScheduler::set_timesteps: only PredictionType::Epsilon is implemented"
209                    .into(),
210            });
211        }
212        let n_train = self.config.num_train_timesteps;
213        self.timesteps = match self.config.timestep_spacing {
214            TimestepSpacing::Leading => {
215                // step_ratio = num_train_timesteps // num_inference_steps
216                let step_ratio = n_train / num_inference_steps;
217                // ts = (arange(0, num_inference_steps) * step_ratio) reversed
218                //      + steps_offset.
219                let mut ts: Vec<usize> = (0..num_inference_steps)
220                    .rev()
221                    .map(|i| i * step_ratio + self.config.steps_offset)
222                    .collect();
223                // Clamp into [0, n_train - 1] (defensive — the standard
224                // SD-1.5 4-step recipe never hits this).
225                for t in &mut ts {
226                    if *t >= n_train {
227                        *t = n_train - 1;
228                    }
229                }
230                ts
231            }
232            TimestepSpacing::Linspace => {
233                let step = (n_train as f64) / (num_inference_steps as f64);
234                let mut ts: Vec<usize> = (0..num_inference_steps)
235                    .rev()
236                    .map(|i| ((i as f64 * step).round() as usize).min(n_train - 1))
237                    .collect();
238                for t in &mut ts {
239                    if *t >= n_train {
240                        *t = n_train - 1;
241                    }
242                }
243                ts
244            }
245        };
246        Ok(&self.timesteps)
247    }
248
249    /// Read-only access to the inference timesteps (empty until
250    /// `set_timesteps` has been called).
251    pub fn timesteps(&self) -> &[usize] {
252        &self.timesteps
253    }
254
255    /// Scale the model input. For DDIM with the SD-1.5 defaults this is
256    /// the identity (DDIM does not rescale the model input the way
257    /// `LMSDiscreteScheduler` does). Kept for pipeline-parity with the
258    /// diffusers API surface.
259    ///
260    /// # Errors
261    ///
262    /// Currently infallible; returned for forward-compat with non-DDIM
263    /// schedulers we may add later.
264    pub fn scale_model_input<T: Float>(
265        &self,
266        sample: &Tensor<T>,
267        _timestep: usize,
268    ) -> FerrotorchResult<Tensor<T>> {
269        Ok(sample.clone())
270    }
271
272    /// One DDIM step: predict the previous-sample given the model's
273    /// noise prediction `model_output` at `timestep`, applied to `sample`.
274    ///
275    /// Math (η = 0, deterministic):
276    ///
277    /// ```text
278    /// prev_timestep = timestep - step_size
279    /// alpha_t       = alphas_cumprod[timestep]
280    /// alpha_t_prev  = alphas_cumprod[prev_timestep] if prev_timestep >= 0
281    ///                else final_alpha_cumprod
282    /// beta_t        = 1 - alpha_t
283    ///
284    /// pred_x0       = (sample - sqrt(beta_t) * model_output) / sqrt(alpha_t)
285    /// pred_dir      = sqrt(1 - alpha_t_prev) * model_output
286    /// prev_sample   = sqrt(alpha_t_prev) * pred_x0 + pred_dir
287    /// ```
288    ///
289    /// # Errors
290    ///
291    /// Returns [`FerrotorchError::InvalidArgument`] if `timestep` is
292    /// outside `[0, num_train_timesteps)` or `set_timesteps` has not been
293    /// called, and propagates any underlying tensor-arithmetic error.
294    pub fn step<T: Float>(
295        &self,
296        model_output: &Tensor<T>,
297        timestep: usize,
298        sample: &Tensor<T>,
299    ) -> FerrotorchResult<Tensor<T>> {
300        if self.timesteps.is_empty() {
301            return Err(FerrotorchError::InvalidArgument {
302                message: "DDIMScheduler::step: set_timesteps must be called before step".into(),
303            });
304        }
305        if timestep >= self.config.num_train_timesteps {
306            return Err(FerrotorchError::InvalidArgument {
307                message: format!(
308                    "DDIMScheduler::step: timestep {timestep} out of range \
309                     [0, {})",
310                    self.config.num_train_timesteps
311                ),
312            });
313        }
314        if model_output.shape() != sample.shape() {
315            return Err(FerrotorchError::ShapeMismatch {
316                message: format!(
317                    "DDIMScheduler::step: model_output shape {:?} != sample shape {:?}",
318                    model_output.shape(),
319                    sample.shape()
320                ),
321            });
322        }
323        // diffusers: prev_timestep = timestep - num_train_timesteps //
324        //                              num_inference_steps
325        let step_ratio = self.config.num_train_timesteps / self.timesteps.len();
326        let prev_timestep_i = timestep as isize - step_ratio as isize;
327        let alpha_t = self.alphas_cumprod[timestep];
328        let alpha_t_prev = if prev_timestep_i >= 0 {
329            self.alphas_cumprod[prev_timestep_i as usize]
330        } else {
331            self.final_alpha_cumprod
332        };
333        let beta_t = 1.0 - alpha_t;
334
335        // Compute three scalar tensors (broadcastable against the [B,C,H,W] sample).
336        let sqrt_beta_t = scalar_f64::<T>(beta_t.sqrt())?;
337        let inv_sqrt_alpha_t = scalar_f64::<T>(1.0 / alpha_t.sqrt())?;
338        let sqrt_one_minus_alpha_t_prev = scalar_f64::<T>((1.0 - alpha_t_prev).sqrt())?;
339        let sqrt_alpha_t_prev = scalar_f64::<T>(alpha_t_prev.sqrt())?;
340
341        // pred_x0 = (sample - sqrt(beta_t) * model_output) / sqrt(alpha_t)
342        //         = (sample - sqrt(beta_t) * model_output) * (1 / sqrt(alpha_t))
343        let scaled_noise = mul(model_output, &sqrt_beta_t)?;
344        let diff = sub(sample, &scaled_noise)?;
345        let pred_x0 = mul(&diff, &inv_sqrt_alpha_t)?;
346
347        // Optional clip_sample. SD-1.5 default is false; gate kept so this
348        // module is reusable with non-SD configs.
349        let pred_x0 = if self.config.clip_sample {
350            clip_to_one::<T>(&pred_x0)?
351        } else {
352            pred_x0
353        };
354
355        // pred_dir = sqrt(1 - alpha_t_prev) * model_output
356        let pred_dir = mul(model_output, &sqrt_one_minus_alpha_t_prev)?;
357        // prev_sample = sqrt(alpha_t_prev) * pred_x0 + pred_dir
358        let x0_scaled = mul(&pred_x0, &sqrt_alpha_t_prev)?;
359        add(&x0_scaled, &pred_dir)
360    }
361}
362
363/// Compute `betas` according to the chosen schedule. Mirrors diffusers's
364/// `betas_for_alpha_bar` only for the two schedules SD-1.5 ever uses.
365fn compute_betas(schedule: BetaSchedule, beta_start: f64, beta_end: f64, n: usize) -> Vec<f64> {
366    let mut out = Vec::with_capacity(n);
367    if n == 0 {
368        return out;
369    }
370    if n == 1 {
371        out.push(beta_start);
372        return out;
373    }
374    let denom = (n - 1) as f64;
375    match schedule {
376        BetaSchedule::Linear => {
377            for i in 0..n {
378                let t = i as f64 / denom;
379                out.push(beta_start + t * (beta_end - beta_start));
380            }
381        }
382        BetaSchedule::ScaledLinear => {
383            // linspace(sqrt(beta_start), sqrt(beta_end), N) ^ 2
384            let a = beta_start.sqrt();
385            let b = beta_end.sqrt();
386            for i in 0..n {
387                let t = i as f64 / denom;
388                let lin = a + t * (b - a);
389                out.push(lin * lin);
390            }
391        }
392    }
393    out
394}
395
396/// Build a 1-element scalar tensor carrying `value` cast into the target Float.
397fn scalar_f64<T: Float>(value: f64) -> FerrotorchResult<Tensor<T>> {
398    let v = T::from(value).ok_or_else(|| FerrotorchError::InvalidArgument {
399        message: format!("DDIMScheduler: cannot represent f64 {value} as the requested Float"),
400    })?;
401    ferrotorch_core::scalar::<T>(v)
402}
403
404/// Clamp every entry of `t` into `[-1, 1]`. Used only when the user
405/// configures `clip_sample=true`; SD-1.5 default is false.
406fn clip_to_one<T: Float>(t: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
407    let lo = T::from(-1.0).ok_or_else(|| FerrotorchError::InvalidArgument {
408        message: "clip_to_one: cannot represent -1.0".into(),
409    })?;
410    let hi = T::from(1.0).ok_or_else(|| FerrotorchError::InvalidArgument {
411        message: "clip_to_one: cannot represent 1.0".into(),
412    })?;
413    ferrotorch_core::clamp(t, lo, hi)
414}
415
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `[1, 1, 2, 2]` f32 tensor with every element set to `value`.
    fn filled_2x2(value: f32) -> Tensor<f32> {
        Tensor::<f32>::from_storage(
            ferrotorch_core::TensorStorage::cpu(vec![value; 4]),
            vec![1, 1, 2, 2],
            false,
        )
        .unwrap()
    }

    #[test]
    fn beta_schedule_scaled_linear_matches_diffusers_sd15() {
        // Mirror diffusers's `betas = linspace(sqrt(0.00085),
        // sqrt(0.012), 1000) ** 2`. Spot-check a handful of indices.
        let betas = compute_betas(BetaSchedule::ScaledLinear, 0.000_85, 0.012, 1000);
        assert_eq!(betas.len(), 1000);
        // i=0 → beta_start exactly.
        assert!((betas[0] - 0.000_85).abs() < 1e-12, "betas[0]={}", betas[0]);
        // i=999 → beta_end exactly.
        assert!(
            (betas[999] - 0.012).abs() < 1e-12,
            "betas[999]={}",
            betas[999]
        );
        // The analytic midpoint t = 0.5 is ((sqrt(0.00085)+sqrt(0.012))/2)^2.
        // Index 499 sits at t = 499/999, half a grid step below 0.5. The
        // slope d(beta)/dt ≈ 0.011 there, so the gap is ≈ 5.6e-6.
        let mid_root = (0.000_85_f64.sqrt() + 0.012_f64.sqrt()) * 0.5;
        let want = mid_root * mid_root;
        let idx = 999 / 2;
        assert!(
            (betas[idx] - want).abs() < 1e-5,
            "betas[{idx}]={} vs midpoint {want}",
            betas[idx]
        );
    }

    #[test]
    fn alphas_cumprod_is_monotone_decreasing() {
        let sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        let mut prev = 1.0_f64;
        for &a in &sched.alphas_cumprod {
            assert!(a < prev, "alphas_cumprod not strictly decreasing");
            assert!(a > 0.0, "alphas_cumprod must be > 0, got {a}");
            prev = a;
        }
        // Final value is approximately 0.0047 for SD-1.5 (reference).
        let last = *sched.alphas_cumprod.last().unwrap();
        assert!(
            (0.001..0.01).contains(&last),
            "SD-1.5 alphas_cumprod[-1] should be ~0.0047, got {last}"
        );
    }

    #[test]
    fn timesteps_leading_4_steps_sd15() {
        // step_ratio = 1000 // 4 = 250; offset = 1.
        // diffusers: timesteps = (arange(4) * 250) reversed + 1
        //           = [750, 500, 250, 0] + 1 = [751, 501, 251, 1].
        let mut sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        let ts = sched.set_timesteps(4).unwrap();
        assert_eq!(ts, [751, 501, 251, 1]);
    }

    #[test]
    fn timesteps_leading_50_steps_sd15_head() {
        let mut sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        let ts = sched.set_timesteps(50).unwrap();
        assert_eq!(ts.len(), 50);
        // step_ratio = 1000 // 50 = 20. first = 49*20+1 = 981; last = 1.
        assert_eq!(ts[0], 981);
        assert_eq!(ts[49], 1);
    }

    #[test]
    fn set_timesteps_rejects_zero_and_too_many_steps() {
        let mut sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        assert!(sched.set_timesteps(0).is_err());
        assert!(sched.set_timesteps(1001).is_err());
        assert!(
            sched.timesteps().is_empty(),
            "failed set_timesteps calls must not populate timesteps"
        );
    }

    #[test]
    fn step_requires_set_timesteps() {
        let sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        let x = filled_2x2(0.5);
        assert!(sched.step(&x, 1, &x).is_err());
    }

    #[test]
    fn init_noise_sigma_is_one() {
        let sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        assert!((sched.init_noise_sigma() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn final_alpha_cumprod_is_alphas_cumprod_zero_when_set_alpha_to_one_false() {
        let sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        assert!((sched.final_alpha_cumprod - sched.alphas_cumprod[0]).abs() < 1e-12);
    }

    #[test]
    fn step_with_zero_noise_rescales_sample() {
        // With model_output = 0 at timestep 1 on a 4-step schedule (step_ratio = 250):
        //   prev_timestep = 1 - 250 < 0 → alpha_prev = final_alpha_cumprod = alphas_cumprod[0]
        //   pred_x0       = sample / sqrt(alphas_cumprod[1])
        //   pred_dir      = 0
        //   prev_sample   = sqrt(alphas_cumprod[0]) * pred_x0
        let mut sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        sched.set_timesteps(4).unwrap();
        let sample = filled_2x2(0.5);
        let noise = filled_2x2(0.0);
        let out = sched.step(&noise, 1, &sample).unwrap();
        assert_eq!(out.shape(), &[1, 1, 2, 2]);
        let want = (0.5 * (sched.alphas_cumprod[0] / sched.alphas_cumprod[1]).sqrt()) as f32;
        let got = out.data().unwrap().to_vec();
        assert_eq!(got.len(), 4);
        for v in got {
            assert!(v.is_finite(), "step produced non-finite value: {v}");
            assert!((v - want).abs() < 1e-5, "step produced {v}, expected {want}");
        }
    }
}