Skip to main content

ferrotorch_diffusion/
scheduler.rs

1//! Deterministic DDIM scheduler matching `diffusers.schedulers.DDIMScheduler`
2//! for the Stable-Diffusion-1.5 sampling defaults.
3//!
4//! Phase F of real-artifact-driven development (#1163). The scheduler is
5//! the missing fourth component of the SD generation pipeline:
6//!
7//! ```text
8//! CLIP text encoder + UNet noise predictor + DDIM scheduler + VAE decoder
9//!                                            ^^^^^^^^^^^^^^^
10//!                                            this module
11//! ```
12//!
13//! Matches `diffusers.schedulers.DDIMScheduler` for the SD-1.5 defaults
14//! (`scaled_linear` beta schedule, `epsilon` prediction, `leading`
15//! timestep spacing, `clip_sample=false`, `set_alpha_to_one=false`,
16//! `init_noise_sigma=1.0` — i.e. η=0 deterministic sampling). Values
17//! mirrored byte-for-byte from the upstream defaults in
18//! `diffusers/schedulers/scheduling_ddim.py` as of `diffusers==0.38.0`.
19//!
20//! ## REQ status (per `.design/ferrotorch-diffusion/scheduler.md`)
21//!
22//! | REQ | Status | Evidence |
23//! |---|---|---|
//! | REQ-1 | SHIPPED | `DDIMScheduler::new` in `ferrotorch-diffusion/src/scheduler.rs`; consumer: `ferrotorch-diffusion/examples/sd_pipeline_dump.rs` builds the scheduler with `DDIMConfig::sd_v1_5()` |
//! | REQ-2 | SHIPPED | `DDIMScheduler::set_timesteps` in `ferrotorch-diffusion/src/scheduler.rs`; consumer: `ferrotorch-diffusion/src/pipeline.rs:194` calls it |
//! | REQ-3 | SHIPPED | `DDIMScheduler::step` in `ferrotorch-diffusion/src/scheduler.rs`; consumer: `ferrotorch-diffusion/src/pipeline.rs:212` calls it |
//! | REQ-4 | SHIPPED | `DDIMScheduler::init_noise_sigma` and `DDIMScheduler::scale_model_input` in `ferrotorch-diffusion/src/scheduler.rs`; consumer: `ferrotorch-diffusion/src/pipeline.rs:199` and `pipeline.rs:132` |
//! | REQ-5 | SHIPPED | `BetaSchedule::ScaledLinear` arm of `compute_betas` in `ferrotorch-diffusion/src/scheduler.rs`; consumer: `DDIMScheduler::new` invokes it |
//! | REQ-6 | SHIPPED | prediction-type guard inside `DDIMScheduler::set_timesteps` (`ferrotorch-diffusion/src/scheduler.rs`); consumer: `ferrotorch-diffusion/src/pipeline.rs:194` surfaces this error |
//!
//! Evidence inside this file is cited by symbol rather than by line number so
//! the table does not go stale whenever the file is edited.
30
31use ferrotorch_core::grad_fns::arithmetic::{add, mul, sub};
32use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
33
/// Which curve the per-timestep noise variances (`betas`) follow.
///
/// Only the two recipes Stable-Diffusion-1.5 can use are modelled.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BetaSchedule {
    /// Interpolate linearly in square-root space, then square:
    /// `betas = linspace(sqrt(beta_start), sqrt(beta_end), N)^2`.
    /// This is the SD-1.5 default.
    ScaledLinear,
    /// Interpolate linearly in beta space:
    /// `betas = linspace(beta_start, beta_end, N)`.
    Linear,
}
42
/// Discrete timestep spacing (subset matching SD-1.5).
///
/// In both variants the resulting schedule is descending, and any value that
/// would land at or past `num_train_timesteps` is clamped to
/// `num_train_timesteps - 1` by `DDIMScheduler::set_timesteps`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TimestepSpacing {
    /// `step_ratio = num_train_timesteps // num_inference_steps;
    /// timesteps = arange(0, num_inference_steps) * step_ratio`, reversed,
    /// then shifted by `steps_offset`.
    ///
    /// SD default; matches diffusers' `timestep_spacing="leading"`
    /// (e.g. 1000 train / 4 inference steps with `steps_offset=1` gives
    /// `[751, 501, 251, 1]`).
    Leading,
    /// `step = num_train_timesteps / num_inference_steps` (floating point);
    /// `timesteps = round(arange(num_inference_steps) * step)`, reversed.
    /// `steps_offset` is **not** applied on this path.
    ///
    /// NOTE: this is not numerically identical to diffusers'
    /// `timestep_spacing="linspace"`, which uses
    /// `linspace(0, num_train_timesteps - 1, num_inference_steps).round()`.
    /// For 1000 train / 4 inference steps this variant yields
    /// `[750, 500, 250, 0]`, whereas diffusers yields `[999, 666, 333, 0]`.
    Linspace,
}
54
55/// `DDIMScheduler` configuration (the subset that affects forward math).
56#[derive(Debug, Clone)]
57pub struct DDIMConfig {
58    /// `num_train_timesteps` — SD-1.5: 1000.
59    pub num_train_timesteps: usize,
60    /// `beta_start` — SD-1.5: 0.00085.
61    pub beta_start: f64,
62    /// `beta_end` — SD-1.5: 0.012.
63    pub beta_end: f64,
64    /// `beta_schedule` — SD-1.5: ScaledLinear.
65    pub beta_schedule: BetaSchedule,
66    /// `clip_sample` — SD-1.5: false (no clipping of predicted x0).
67    pub clip_sample: bool,
68    /// `set_alpha_to_one` — SD-1.5: false. When false, the
69    /// final-step `alpha_prev` is `alphas_cumprod[0]` (matches diffusers).
70    pub set_alpha_to_one: bool,
71    /// `prediction_type` — SD-1.5: "epsilon" (the UNet predicts the noise).
72    /// Only `"epsilon"` is implemented; anything else returns an error
73    /// at `set_timesteps` time.
74    pub prediction_type: PredictionType,
75    /// `timestep_spacing` — SD-1.5: Leading.
76    pub timestep_spacing: TimestepSpacing,
77    /// `steps_offset` — SD-1.5: 1 (diffusers adds this offset on the
78    /// `leading` path so the first inference step is `step_ratio` rather
79    /// than 0; consumed inside [`DDIMScheduler::set_timesteps`]).
80    pub steps_offset: usize,
81}
82
/// What quantity the denoising network is trained to output.
///
/// Restricted to the parameterisation SD-1.5 uses.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PredictionType {
    /// The network predicts the added Gaussian noise ε (SD default).
    Epsilon,
}
89
90impl Default for DDIMConfig {
91    fn default() -> Self {
92        Self {
93            num_train_timesteps: 1000,
94            beta_start: 0.000_85,
95            beta_end: 0.012,
96            beta_schedule: BetaSchedule::ScaledLinear,
97            clip_sample: false,
98            set_alpha_to_one: false,
99            prediction_type: PredictionType::Epsilon,
100            timestep_spacing: TimestepSpacing::Leading,
101            steps_offset: 1,
102        }
103    }
104}
105
106impl DDIMConfig {
107    /// SD-1.5 defaults (alias for [`Default::default`]).
108    pub fn sd_v1_5() -> Self {
109        Self::default()
110    }
111}
112
113/// Deterministic DDIM scheduler (η=0, no noise injection).
114///
115/// Pre-computes `betas`, `alphas`, and `alphas_cumprod` over the full
116/// training-time grid (`num_train_timesteps` entries). Inference picks
117/// a subset of timesteps via [`DDIMScheduler::set_timesteps`] and walks
118/// them in reverse with [`DDIMScheduler::step`].
119#[derive(Debug, Clone)]
120pub struct DDIMScheduler {
121    config: DDIMConfig,
122    /// `alphas_cumprod` of length `num_train_timesteps`.
123    alphas_cumprod: Vec<f64>,
124    /// `final_alpha_cumprod` — used when stepping into prev_timestep < 0
125    /// (the very last denoising step). When `set_alpha_to_one=false` this
126    /// is `alphas_cumprod[0]`; when true it's 1.0.
127    final_alpha_cumprod: f64,
128    /// Timesteps the user requested via [`Self::set_timesteps`], in the
129    /// order they will be consumed (descending). Empty until
130    /// `set_timesteps` is called.
131    timesteps: Vec<usize>,
132}
133
134impl DDIMScheduler {
135    /// Build a scheduler with the given config; runs the one-shot
136    /// `betas → alphas → alphas_cumprod` precomputation.
137    ///
138    /// # Errors
139    ///
140    /// Returns [`FerrotorchError::InvalidArgument`] for malformed
141    /// configurations (zero training steps, non-positive beta).
142    pub fn new(config: DDIMConfig) -> FerrotorchResult<Self> {
143        if config.num_train_timesteps == 0 {
144            return Err(FerrotorchError::InvalidArgument {
145                message: "DDIMScheduler::new: num_train_timesteps must be > 0".into(),
146            });
147        }
148        if !config.beta_start.is_finite() || !config.beta_end.is_finite() {
149            return Err(FerrotorchError::InvalidArgument {
150                message: format!(
151                    "DDIMScheduler::new: non-finite betas (beta_start={}, beta_end={})",
152                    config.beta_start, config.beta_end
153                ),
154            });
155        }
156        let n = config.num_train_timesteps;
157        let betas = compute_betas(config.beta_schedule, config.beta_start, config.beta_end, n);
158        // alphas[i] = 1 - betas[i].
159        // alphas_cumprod[i] = prod_{j=0..=i} alphas[j].
160        let mut alphas_cumprod = Vec::with_capacity(n);
161        let mut acc = 1.0_f64;
162        for &b in &betas {
163            let a = 1.0 - b;
164            acc *= a;
165            alphas_cumprod.push(acc);
166        }
167        let final_alpha_cumprod = if config.set_alpha_to_one {
168            1.0
169        } else {
170            alphas_cumprod[0]
171        };
172        Ok(Self {
173            config,
174            alphas_cumprod,
175            final_alpha_cumprod,
176            timesteps: Vec::new(),
177        })
178    }
179
180    /// Read-only access to the frozen configuration.
181    pub fn config(&self) -> &DDIMConfig {
182        &self.config
183    }
184
185    /// `init_noise_sigma` — the multiplier applied to the initial Gaussian
186    /// noise tensor before the first denoising step. DDIM with the SD-1.5
187    /// defaults uses 1.0 (no scaling).
188    pub fn init_noise_sigma(&self) -> f64 {
189        1.0
190    }
191
192    /// Set the inference-time discrete timesteps and return them.
193    ///
194    /// Matches `diffusers.schedulers.DDIMScheduler.set_timesteps` exactly
195    /// for the SD-1.5 defaults (Leading spacing + `steps_offset=1`).
196    ///
197    /// # Errors
198    ///
199    /// Returns [`FerrotorchError::InvalidArgument`] if `num_inference_steps`
200    /// is zero or exceeds `num_train_timesteps`, or if the configured
201    /// `prediction_type` is anything other than [`PredictionType::Epsilon`].
202    pub fn set_timesteps(&mut self, num_inference_steps: usize) -> FerrotorchResult<&[usize]> {
203        if num_inference_steps == 0 {
204            return Err(FerrotorchError::InvalidArgument {
205                message: "DDIMScheduler::set_timesteps: num_inference_steps must be > 0".into(),
206            });
207        }
208        if num_inference_steps > self.config.num_train_timesteps {
209            return Err(FerrotorchError::InvalidArgument {
210                message: format!(
211                    "DDIMScheduler::set_timesteps: num_inference_steps {num_inference_steps} \
212                     must be <= num_train_timesteps {}",
213                    self.config.num_train_timesteps
214                ),
215            });
216        }
217        if self.config.prediction_type != PredictionType::Epsilon {
218            return Err(FerrotorchError::InvalidArgument {
219                message:
220                    "DDIMScheduler::set_timesteps: only PredictionType::Epsilon is implemented"
221                        .into(),
222            });
223        }
224        let n_train = self.config.num_train_timesteps;
225        self.timesteps = match self.config.timestep_spacing {
226            TimestepSpacing::Leading => {
227                // step_ratio = num_train_timesteps // num_inference_steps
228                let step_ratio = n_train / num_inference_steps;
229                // ts = (arange(0, num_inference_steps) * step_ratio) reversed
230                //      + steps_offset.
231                let mut ts: Vec<usize> = (0..num_inference_steps)
232                    .rev()
233                    .map(|i| i * step_ratio + self.config.steps_offset)
234                    .collect();
235                // Clamp into [0, n_train - 1] (defensive — the standard
236                // SD-1.5 4-step recipe never hits this).
237                for t in &mut ts {
238                    if *t >= n_train {
239                        *t = n_train - 1;
240                    }
241                }
242                ts
243            }
244            TimestepSpacing::Linspace => {
245                let step = (n_train as f64) / (num_inference_steps as f64);
246                let mut ts: Vec<usize> = (0..num_inference_steps)
247                    .rev()
248                    .map(|i| ((i as f64 * step).round() as usize).min(n_train - 1))
249                    .collect();
250                for t in &mut ts {
251                    if *t >= n_train {
252                        *t = n_train - 1;
253                    }
254                }
255                ts
256            }
257        };
258        Ok(&self.timesteps)
259    }
260
261    /// Read-only access to the inference timesteps (empty until
262    /// `set_timesteps` has been called).
263    pub fn timesteps(&self) -> &[usize] {
264        &self.timesteps
265    }
266
267    /// Scale the model input. For DDIM with the SD-1.5 defaults this is
268    /// the identity (DDIM does not rescale the model input the way
269    /// `LMSDiscreteScheduler` does). Kept for pipeline-parity with the
270    /// diffusers API surface.
271    ///
272    /// # Errors
273    ///
274    /// Currently infallible; returned for forward-compat with non-DDIM
275    /// schedulers we may add later.
276    pub fn scale_model_input<T: Float>(
277        &self,
278        sample: &Tensor<T>,
279        _timestep: usize,
280    ) -> FerrotorchResult<Tensor<T>> {
281        Ok(sample.clone())
282    }
283
284    /// One DDIM step: predict the previous-sample given the model's
285    /// noise prediction `model_output` at `timestep`, applied to `sample`.
286    ///
287    /// Math (η = 0, deterministic):
288    ///
289    /// ```text
290    /// prev_timestep = timestep - step_size
291    /// alpha_t       = alphas_cumprod[timestep]
292    /// alpha_t_prev  = alphas_cumprod[prev_timestep] if prev_timestep >= 0
293    ///                else final_alpha_cumprod
294    /// beta_t        = 1 - alpha_t
295    ///
296    /// pred_x0       = (sample - sqrt(beta_t) * model_output) / sqrt(alpha_t)
297    /// pred_dir      = sqrt(1 - alpha_t_prev) * model_output
298    /// prev_sample   = sqrt(alpha_t_prev) * pred_x0 + pred_dir
299    /// ```
300    ///
301    /// # Errors
302    ///
303    /// Returns [`FerrotorchError::InvalidArgument`] if `timestep` is
304    /// outside `[0, num_train_timesteps)` or `set_timesteps` has not been
305    /// called, and propagates any underlying tensor-arithmetic error.
306    pub fn step<T: Float>(
307        &self,
308        model_output: &Tensor<T>,
309        timestep: usize,
310        sample: &Tensor<T>,
311    ) -> FerrotorchResult<Tensor<T>> {
312        if self.timesteps.is_empty() {
313            return Err(FerrotorchError::InvalidArgument {
314                message: "DDIMScheduler::step: set_timesteps must be called before step".into(),
315            });
316        }
317        if timestep >= self.config.num_train_timesteps {
318            return Err(FerrotorchError::InvalidArgument {
319                message: format!(
320                    "DDIMScheduler::step: timestep {timestep} out of range \
321                     [0, {})",
322                    self.config.num_train_timesteps
323                ),
324            });
325        }
326        if model_output.shape() != sample.shape() {
327            return Err(FerrotorchError::ShapeMismatch {
328                message: format!(
329                    "DDIMScheduler::step: model_output shape {:?} != sample shape {:?}",
330                    model_output.shape(),
331                    sample.shape()
332                ),
333            });
334        }
335        // diffusers: prev_timestep = timestep - num_train_timesteps //
336        //                              num_inference_steps
337        let step_ratio = self.config.num_train_timesteps / self.timesteps.len();
338        let prev_timestep_i = timestep as isize - step_ratio as isize;
339        let alpha_t = self.alphas_cumprod[timestep];
340        let alpha_t_prev = if prev_timestep_i >= 0 {
341            self.alphas_cumprod[prev_timestep_i as usize]
342        } else {
343            self.final_alpha_cumprod
344        };
345        let beta_t = 1.0 - alpha_t;
346
347        // Compute three scalar tensors (broadcastable against the [B,C,H,W] sample).
348        let sqrt_beta_t = scalar_f64::<T>(beta_t.sqrt())?;
349        let inv_sqrt_alpha_t = scalar_f64::<T>(1.0 / alpha_t.sqrt())?;
350        let sqrt_one_minus_alpha_t_prev = scalar_f64::<T>((1.0 - alpha_t_prev).sqrt())?;
351        let sqrt_alpha_t_prev = scalar_f64::<T>(alpha_t_prev.sqrt())?;
352
353        // pred_x0 = (sample - sqrt(beta_t) * model_output) / sqrt(alpha_t)
354        //         = (sample - sqrt(beta_t) * model_output) * (1 / sqrt(alpha_t))
355        let scaled_noise = mul(model_output, &sqrt_beta_t)?;
356        let diff = sub(sample, &scaled_noise)?;
357        let pred_x0 = mul(&diff, &inv_sqrt_alpha_t)?;
358
359        // Optional clip_sample. SD-1.5 default is false; gate kept so this
360        // module is reusable with non-SD configs.
361        let pred_x0 = if self.config.clip_sample {
362            clip_to_one::<T>(&pred_x0)?
363        } else {
364            pred_x0
365        };
366
367        // pred_dir = sqrt(1 - alpha_t_prev) * model_output
368        let pred_dir = mul(model_output, &sqrt_one_minus_alpha_t_prev)?;
369        // prev_sample = sqrt(alpha_t_prev) * pred_x0 + pred_dir
370        let x0_scaled = mul(&pred_x0, &sqrt_alpha_t_prev)?;
371        add(&x0_scaled, &pred_dir)
372    }
373}
374
375/// Compute `betas` according to the chosen schedule. Mirrors diffusers's
376/// `betas_for_alpha_bar` only for the two schedules SD-1.5 ever uses.
377fn compute_betas(schedule: BetaSchedule, beta_start: f64, beta_end: f64, n: usize) -> Vec<f64> {
378    let mut out = Vec::with_capacity(n);
379    if n == 0 {
380        return out;
381    }
382    if n == 1 {
383        out.push(beta_start);
384        return out;
385    }
386    let denom = (n - 1) as f64;
387    match schedule {
388        BetaSchedule::Linear => {
389            for i in 0..n {
390                let t = i as f64 / denom;
391                out.push(beta_start + t * (beta_end - beta_start));
392            }
393        }
394        BetaSchedule::ScaledLinear => {
395            // linspace(sqrt(beta_start), sqrt(beta_end), N) ^ 2
396            let a = beta_start.sqrt();
397            let b = beta_end.sqrt();
398            for i in 0..n {
399                let t = i as f64 / denom;
400                let lin = a + t * (b - a);
401                out.push(lin * lin);
402            }
403        }
404    }
405    out
406}
407
408/// Build a 1-element scalar tensor carrying `value` cast into the target Float.
409fn scalar_f64<T: Float>(value: f64) -> FerrotorchResult<Tensor<T>> {
410    let v = T::from(value).ok_or_else(|| FerrotorchError::InvalidArgument {
411        message: format!("DDIMScheduler: cannot represent f64 {value} as the requested Float"),
412    })?;
413    ferrotorch_core::scalar::<T>(v)
414}
415
416/// Clamp every entry of `t` into `[-1, 1]`. Used only when the user
417/// configures `clip_sample=true`; SD-1.5 default is false.
418fn clip_to_one<T: Float>(t: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
419    let lo = T::from(-1.0).ok_or_else(|| FerrotorchError::InvalidArgument {
420        message: "clip_to_one: cannot represent -1.0".into(),
421    })?;
422    let hi = T::from(1.0).ok_or_else(|| FerrotorchError::InvalidArgument {
423        message: "clip_to_one: cannot represent 1.0".into(),
424    })?;
425    ferrotorch_core::clamp(t, lo, hi)
426}
427
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an f32 tensor of the given `shape` with every element set to `value`.
    fn filled(value: f32, shape: Vec<usize>) -> Tensor<f32> {
        let numel: usize = shape.iter().product();
        Tensor::<f32>::from_storage(
            ferrotorch_core::TensorStorage::cpu(vec![value; numel]),
            shape,
            false,
        )
        .unwrap()
    }

    #[test]
    fn beta_schedule_scaled_linear_matches_diffusers_sd15() {
        // Mirror diffusers's `betas = linspace(sqrt(0.00085),
        // sqrt(0.012), 1000) ** 2`. Spot-check both endpoints and one
        // interior index.
        let betas = compute_betas(BetaSchedule::ScaledLinear, 0.000_85, 0.012, 1000);
        assert_eq!(betas.len(), 1000);
        // i=0 → beta_start exactly.
        assert!((betas[0] - 0.000_85).abs() < 1e-12, "betas[0]={}", betas[0]);
        // i=999 → beta_end exactly.
        assert!(
            (betas[999] - 0.012).abs() < 1e-12,
            "betas[999]={}",
            betas[999]
        );

        // A 1000-point linspace has no exact midpoint sample: index 499 sits
        // at t = 499/999 along the sqrt-space line.
        let lo = 0.000_85_f64.sqrt();
        let hi = 0.012_f64.sqrt();
        let root = lo + (499.0 / 999.0) * (hi - lo);
        let want = root * root;
        assert!(
            (betas[499] - want).abs() < 1e-12,
            "betas[499]={} vs {want}",
            betas[499]
        );

        // It is still within ~6e-6 of the sqrt-space midpoint.
        let mid_root = (lo + hi) * 0.5;
        let mid = mid_root * mid_root;
        assert!(
            (betas[499] - mid).abs() < 1e-5,
            "betas[499]={} vs midpoint {mid}",
            betas[499]
        );
    }

    #[test]
    fn beta_schedule_linear_hits_both_endpoints_and_increases() {
        let betas = compute_betas(BetaSchedule::Linear, 0.000_1, 0.02, 1000);
        assert_eq!(betas.len(), 1000);
        assert!((betas[0] - 0.000_1).abs() < 1e-12, "betas[0]={}", betas[0]);
        assert!((betas[999] - 0.02).abs() < 1e-12, "betas[999]={}", betas[999]);
        assert!(betas.windows(2).all(|w| w[0] < w[1]), "linear betas not increasing");
    }

    #[test]
    fn beta_schedule_degenerate_lengths() {
        assert!(compute_betas(BetaSchedule::ScaledLinear, 0.000_85, 0.012, 0).is_empty());
        assert_eq!(
            compute_betas(BetaSchedule::ScaledLinear, 0.000_85, 0.012, 1),
            vec![0.000_85]
        );
    }

    #[test]
    fn alphas_cumprod_is_monotone_decreasing() {
        let sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        let mut prev = 1.0_f64;
        for &a in &sched.alphas_cumprod {
            assert!(a < prev, "alphas_cumprod not strictly decreasing");
            assert!(a > 0.0, "alphas_cumprod must be > 0, got {a}");
            prev = a;
        }
        // The first entry is just alpha_0 = 1 - beta_start.
        assert!((sched.alphas_cumprod[0] - (1.0 - 0.000_85)).abs() < 1e-12);
        // Final value is approximately 0.0047 for SD-1.5 (reference).
        let last = *sched.alphas_cumprod.last().unwrap();
        assert!(
            (0.001..0.01).contains(&last),
            "SD-1.5 alphas_cumprod[-1] should be ~0.0047, got {last}"
        );
    }

    #[test]
    fn timesteps_leading_4_steps_sd15() {
        // step_ratio = 1000 // 4 = 250; offset = 1.
        // diffusers: timesteps = (arange(4) * 250) reversed + 1
        //           = [750, 500, 250, 0] + 1 = [751, 501, 251, 1].
        let mut sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        let ts = sched.set_timesteps(4).unwrap();
        assert_eq!(ts, [751, 501, 251, 1]);
    }

    #[test]
    fn timesteps_leading_50_steps_sd15_head() {
        let mut sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        let ts = sched.set_timesteps(50).unwrap();
        assert_eq!(ts.len(), 50);
        // step_ratio = 1000 // 50 = 20. first = 49*20+1 = 981; last = 1.
        assert_eq!(ts[0], 981);
        assert_eq!(ts[49], 1);
    }

    #[test]
    fn timesteps_single_step_and_accessor() {
        let mut sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        assert!(sched.timesteps().is_empty());
        // step_ratio = 1000; only i = 0 → 0 * 1000 + 1.
        let ts = sched.set_timesteps(1).unwrap();
        assert_eq!(ts, [1]);
        assert_eq!(sched.timesteps(), [1]);
    }

    #[test]
    fn set_timesteps_rejects_out_of_range_counts() {
        let mut sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        assert!(sched.set_timesteps(0).is_err());
        assert!(sched.set_timesteps(1001).is_err());
        // The largest legal count still works.
        assert_eq!(sched.set_timesteps(1000).unwrap().len(), 1000);
    }

    #[test]
    fn new_rejects_zero_train_steps_and_non_finite_betas() {
        let zero_steps = DDIMConfig {
            num_train_timesteps: 0,
            ..DDIMConfig::sd_v1_5()
        };
        assert!(DDIMScheduler::new(zero_steps).is_err());
        let nan_beta = DDIMConfig {
            beta_end: f64::NAN,
            ..DDIMConfig::sd_v1_5()
        };
        assert!(DDIMScheduler::new(nan_beta).is_err());
    }

    #[test]
    fn init_noise_sigma_is_one() {
        let sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        assert!((sched.init_noise_sigma() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn final_alpha_cumprod_is_alphas_cumprod_zero_when_set_alpha_to_one_false() {
        let sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        assert!((sched.final_alpha_cumprod - sched.alphas_cumprod[0]).abs() < 1e-12);
    }

    #[test]
    fn final_alpha_cumprod_is_one_when_set_alpha_to_one_true() {
        let cfg = DDIMConfig {
            set_alpha_to_one: true,
            ..DDIMConfig::sd_v1_5()
        };
        let sched = DDIMScheduler::new(cfg).unwrap();
        assert!((sched.final_alpha_cumprod - 1.0).abs() < 1e-12);
    }

    #[test]
    fn step_with_zero_noise_rescales_sample_by_alpha_ratio() {
        // With model_output = 0:
        //   pred_x0     = sample / sqrt(alpha_t)
        //   pred_dir    = 0
        //   prev_sample = sqrt(alpha_prev) * pred_x0
        // At timestep 1 with 4 inference steps, step_ratio = 250 and
        // 1 - 250 < 0, so alpha_prev = final_alpha_cumprod.
        let mut sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        sched.set_timesteps(4).unwrap();
        let sample = filled(0.5, vec![1, 1, 2, 2]);
        let noise = filled(0.0, vec![1, 1, 2, 2]);
        let out = sched.step(&noise, 1, &sample).unwrap();
        assert_eq!(out.shape(), &[1, 1, 2, 2]);

        let alpha_t = sched.alphas_cumprod[1];
        let alpha_prev = sched.final_alpha_cumprod;
        let expected = (alpha_prev.sqrt() * (0.5 / alpha_t.sqrt())) as f32;
        for &v in out.data().unwrap().iter() {
            assert!(v.is_finite(), "step produced non-finite value: {v}");
            assert!(
                (v - expected).abs() < 1e-5,
                "step produced {v}, expected {expected}"
            );
        }
    }

    #[test]
    fn step_rejects_bad_calls() {
        let x = filled(0.5, vec![1, 1, 2, 2]);
        let mut sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        // set_timesteps has not been called yet.
        assert!(sched.step(&x, 1, &x).is_err());
        sched.set_timesteps(4).unwrap();
        // Timestep past the end of the training grid.
        assert!(sched.step(&x, 1000, &x).is_err());
        // model_output and sample shapes disagree.
        let other = filled(0.0, vec![1, 1, 1, 4]);
        assert!(sched.step(&other, 1, &x).is_err());
    }
}