Skip to main content

ferrotorch_diffusion/
scheduler.rs

1//! Deterministic DDIM scheduler matching `diffusers.schedulers.DDIMScheduler`
2//! for the Stable-Diffusion-1.5 sampling defaults.
3//!
4//! Phase F of real-artifact-driven development (#1163). The scheduler is
5//! the missing fourth component of the SD generation pipeline:
6//!
7//! ```text
8//! CLIP text encoder + UNet noise predictor + DDIM scheduler + VAE decoder
9//!                                            ^^^^^^^^^^^^^^^
10//!                                            this module
11//! ```
12//!
//! Matches `diffusers.schedulers.DDIMScheduler` for the SD-1.5 sampling
//! configuration (`scaled_linear` beta schedule, `epsilon` prediction,
//! `leading` timestep spacing with `steps_offset=1`, `clip_sample=false`,
//! `set_alpha_to_one=false`, `init_noise_sigma=1.0`), sampled with η=0,
//! i.e. deterministically. The update rules mirror
//! `diffusers/schedulers/scheduling_ddim.py` as of `diffusers==0.38.0`; the
//! configuration values are the SD-1.5 scheduler settings, several of which
//! differ from `DDIMScheduler`'s constructor defaults (e.g. `clip_sample`,
//! `set_alpha_to_one`, `steps_offset`).
19
20use ferrotorch_core::grad_fns::arithmetic::{add, mul, sub};
21use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
22
/// How the per-timestep noise variances (`betas`) are laid out over the
/// training grid. Only the two recipes needed for SD-1.5 are supported.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BetaSchedule {
    /// Interpolate linearly between `sqrt(beta_start)` and `sqrt(beta_end)`,
    /// then square each value: `linspace(sqrt(beta_start), sqrt(beta_end), N)^2`.
    /// This is what SD-1.5 uses.
    ScaledLinear,
    /// Interpolate linearly between the endpoints themselves:
    /// `linspace(beta_start, beta_end, N)`.
    Linear,
}
31
/// Discrete timestep spacing (subset matching SD-1.5).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TimestepSpacing {
    /// `step_ratio = num_train_timesteps // num_inference_steps;
    /// timesteps = arange(0, num_inference_steps) * step_ratio + steps_offset`,
    /// reversed into descending order and clamped to `num_train_timesteps - 1`.
    /// SD default.
    Leading,
    /// Evenly spaced timesteps across the training grid, in descending order;
    /// `steps_offset` is not applied. The exact formula is documented on
    /// [`DDIMScheduler::set_timesteps`].
    Linspace,
}
43
44/// `DDIMScheduler` configuration (the subset that affects forward math).
45#[derive(Debug, Clone)]
46pub struct DDIMConfig {
47    /// `num_train_timesteps` — SD-1.5: 1000.
48    pub num_train_timesteps: usize,
49    /// `beta_start` — SD-1.5: 0.00085.
50    pub beta_start: f64,
51    /// `beta_end` — SD-1.5: 0.012.
52    pub beta_end: f64,
53    /// `beta_schedule` — SD-1.5: ScaledLinear.
54    pub beta_schedule: BetaSchedule,
55    /// `clip_sample` — SD-1.5: false (no clipping of predicted x0).
56    pub clip_sample: bool,
57    /// `set_alpha_to_one` — SD-1.5: false. When false, the
58    /// final-step `alpha_prev` is `alphas_cumprod[0]` (matches diffusers).
59    pub set_alpha_to_one: bool,
60    /// `prediction_type` — SD-1.5: "epsilon" (the UNet predicts the noise).
61    /// Only `"epsilon"` is implemented; anything else returns an error
62    /// at `set_timesteps` time.
63    pub prediction_type: PredictionType,
64    /// `timestep_spacing` — SD-1.5: Leading.
65    pub timestep_spacing: TimestepSpacing,
66    /// `steps_offset` — SD-1.5: 1 (diffusers adds this offset on the
67    /// `leading` path so the first inference step is `step_ratio` rather
68    /// than 0; consumed inside [`DDIMScheduler::set_timesteps`]).
69    pub steps_offset: usize,
70}
71
/// What quantity the denoising network is trained to output. Only the
/// parameterisation used by SD-1.5 is available.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PredictionType {
    /// The network outputs the added noise ε (the SD-1.5 choice).
    Epsilon,
}
78
79impl Default for DDIMConfig {
80    fn default() -> Self {
81        Self {
82            num_train_timesteps: 1000,
83            beta_start: 0.000_85,
84            beta_end: 0.012,
85            beta_schedule: BetaSchedule::ScaledLinear,
86            clip_sample: false,
87            set_alpha_to_one: false,
88            prediction_type: PredictionType::Epsilon,
89            timestep_spacing: TimestepSpacing::Leading,
90            steps_offset: 1,
91        }
92    }
93}
94
95impl DDIMConfig {
96    /// SD-1.5 defaults (alias for [`Default::default`]).
97    pub fn sd_v1_5() -> Self {
98        Self::default()
99    }
100}
101
102/// Deterministic DDIM scheduler (η=0, no noise injection).
103///
104/// Pre-computes `betas`, `alphas`, and `alphas_cumprod` over the full
105/// training-time grid (`num_train_timesteps` entries). Inference picks
106/// a subset of timesteps via [`DDIMScheduler::set_timesteps`] and walks
107/// them in reverse with [`DDIMScheduler::step`].
108#[derive(Debug, Clone)]
109pub struct DDIMScheduler {
110    config: DDIMConfig,
111    /// `alphas_cumprod` of length `num_train_timesteps`.
112    alphas_cumprod: Vec<f64>,
113    /// `final_alpha_cumprod` — used when stepping into prev_timestep < 0
114    /// (the very last denoising step). When `set_alpha_to_one=false` this
115    /// is `alphas_cumprod[0]`; when true it's 1.0.
116    final_alpha_cumprod: f64,
117    /// Timesteps the user requested via [`Self::set_timesteps`], in the
118    /// order they will be consumed (descending). Empty until
119    /// `set_timesteps` is called.
120    timesteps: Vec<usize>,
121}
122
123impl DDIMScheduler {
124    /// Build a scheduler with the given config; runs the one-shot
125    /// `betas → alphas → alphas_cumprod` precomputation.
126    ///
127    /// # Errors
128    ///
129    /// Returns [`FerrotorchError::InvalidArgument`] for malformed
130    /// configurations (zero training steps, non-positive beta).
131    pub fn new(config: DDIMConfig) -> FerrotorchResult<Self> {
132        if config.num_train_timesteps == 0 {
133            return Err(FerrotorchError::InvalidArgument {
134                message: "DDIMScheduler::new: num_train_timesteps must be > 0".into(),
135            });
136        }
137        if !config.beta_start.is_finite() || !config.beta_end.is_finite() {
138            return Err(FerrotorchError::InvalidArgument {
139                message: format!(
140                    "DDIMScheduler::new: non-finite betas (beta_start={}, beta_end={})",
141                    config.beta_start, config.beta_end
142                ),
143            });
144        }
145        let n = config.num_train_timesteps;
146        let betas = compute_betas(config.beta_schedule, config.beta_start, config.beta_end, n);
147        // alphas[i] = 1 - betas[i].
148        // alphas_cumprod[i] = prod_{j=0..=i} alphas[j].
149        let mut alphas_cumprod = Vec::with_capacity(n);
150        let mut acc = 1.0_f64;
151        for &b in &betas {
152            let a = 1.0 - b;
153            acc *= a;
154            alphas_cumprod.push(acc);
155        }
156        let final_alpha_cumprod = if config.set_alpha_to_one {
157            1.0
158        } else {
159            alphas_cumprod[0]
160        };
161        Ok(Self {
162            config,
163            alphas_cumprod,
164            final_alpha_cumprod,
165            timesteps: Vec::new(),
166        })
167    }
168
169    /// Read-only access to the frozen configuration.
170    pub fn config(&self) -> &DDIMConfig {
171        &self.config
172    }
173
174    /// `init_noise_sigma` — the multiplier applied to the initial Gaussian
175    /// noise tensor before the first denoising step. DDIM with the SD-1.5
176    /// defaults uses 1.0 (no scaling).
177    pub fn init_noise_sigma(&self) -> f64 {
178        1.0
179    }
180
181    /// Set the inference-time discrete timesteps and return them.
182    ///
183    /// Matches `diffusers.schedulers.DDIMScheduler.set_timesteps` exactly
184    /// for the SD-1.5 defaults (Leading spacing + `steps_offset=1`).
185    ///
186    /// # Errors
187    ///
188    /// Returns [`FerrotorchError::InvalidArgument`] if `num_inference_steps`
189    /// is zero or exceeds `num_train_timesteps`, or if the configured
190    /// `prediction_type` is anything other than [`PredictionType::Epsilon`].
191    pub fn set_timesteps(&mut self, num_inference_steps: usize) -> FerrotorchResult<&[usize]> {
192        if num_inference_steps == 0 {
193            return Err(FerrotorchError::InvalidArgument {
194                message: "DDIMScheduler::set_timesteps: num_inference_steps must be > 0".into(),
195            });
196        }
197        if num_inference_steps > self.config.num_train_timesteps {
198            return Err(FerrotorchError::InvalidArgument {
199                message: format!(
200                    "DDIMScheduler::set_timesteps: num_inference_steps {num_inference_steps} \
201                     must be <= num_train_timesteps {}",
202                    self.config.num_train_timesteps
203                ),
204            });
205        }
206        if self.config.prediction_type != PredictionType::Epsilon {
207            return Err(FerrotorchError::InvalidArgument {
208                message:
209                    "DDIMScheduler::set_timesteps: only PredictionType::Epsilon is implemented"
210                        .into(),
211            });
212        }
213        let n_train = self.config.num_train_timesteps;
214        self.timesteps = match self.config.timestep_spacing {
215            TimestepSpacing::Leading => {
216                // step_ratio = num_train_timesteps // num_inference_steps
217                let step_ratio = n_train / num_inference_steps;
218                // ts = (arange(0, num_inference_steps) * step_ratio) reversed
219                //      + steps_offset.
220                let mut ts: Vec<usize> = (0..num_inference_steps)
221                    .rev()
222                    .map(|i| i * step_ratio + self.config.steps_offset)
223                    .collect();
224                // Clamp into [0, n_train - 1] (defensive — the standard
225                // SD-1.5 4-step recipe never hits this).
226                for t in &mut ts {
227                    if *t >= n_train {
228                        *t = n_train - 1;
229                    }
230                }
231                ts
232            }
233            TimestepSpacing::Linspace => {
234                let step = (n_train as f64) / (num_inference_steps as f64);
235                let mut ts: Vec<usize> = (0..num_inference_steps)
236                    .rev()
237                    .map(|i| ((i as f64 * step).round() as usize).min(n_train - 1))
238                    .collect();
239                for t in &mut ts {
240                    if *t >= n_train {
241                        *t = n_train - 1;
242                    }
243                }
244                ts
245            }
246        };
247        Ok(&self.timesteps)
248    }
249
250    /// Read-only access to the inference timesteps (empty until
251    /// `set_timesteps` has been called).
252    pub fn timesteps(&self) -> &[usize] {
253        &self.timesteps
254    }
255
256    /// Scale the model input. For DDIM with the SD-1.5 defaults this is
257    /// the identity (DDIM does not rescale the model input the way
258    /// `LMSDiscreteScheduler` does). Kept for pipeline-parity with the
259    /// diffusers API surface.
260    ///
261    /// # Errors
262    ///
263    /// Currently infallible; returned for forward-compat with non-DDIM
264    /// schedulers we may add later.
265    pub fn scale_model_input<T: Float>(
266        &self,
267        sample: &Tensor<T>,
268        _timestep: usize,
269    ) -> FerrotorchResult<Tensor<T>> {
270        Ok(sample.clone())
271    }
272
273    /// One DDIM step: predict the previous-sample given the model's
274    /// noise prediction `model_output` at `timestep`, applied to `sample`.
275    ///
276    /// Math (η = 0, deterministic):
277    ///
278    /// ```text
279    /// prev_timestep = timestep - step_size
280    /// alpha_t       = alphas_cumprod[timestep]
281    /// alpha_t_prev  = alphas_cumprod[prev_timestep] if prev_timestep >= 0
282    ///                else final_alpha_cumprod
283    /// beta_t        = 1 - alpha_t
284    ///
285    /// pred_x0       = (sample - sqrt(beta_t) * model_output) / sqrt(alpha_t)
286    /// pred_dir      = sqrt(1 - alpha_t_prev) * model_output
287    /// prev_sample   = sqrt(alpha_t_prev) * pred_x0 + pred_dir
288    /// ```
289    ///
290    /// # Errors
291    ///
292    /// Returns [`FerrotorchError::InvalidArgument`] if `timestep` is
293    /// outside `[0, num_train_timesteps)` or `set_timesteps` has not been
294    /// called, and propagates any underlying tensor-arithmetic error.
295    pub fn step<T: Float>(
296        &self,
297        model_output: &Tensor<T>,
298        timestep: usize,
299        sample: &Tensor<T>,
300    ) -> FerrotorchResult<Tensor<T>> {
301        if self.timesteps.is_empty() {
302            return Err(FerrotorchError::InvalidArgument {
303                message: "DDIMScheduler::step: set_timesteps must be called before step".into(),
304            });
305        }
306        if timestep >= self.config.num_train_timesteps {
307            return Err(FerrotorchError::InvalidArgument {
308                message: format!(
309                    "DDIMScheduler::step: timestep {timestep} out of range \
310                     [0, {})",
311                    self.config.num_train_timesteps
312                ),
313            });
314        }
315        if model_output.shape() != sample.shape() {
316            return Err(FerrotorchError::ShapeMismatch {
317                message: format!(
318                    "DDIMScheduler::step: model_output shape {:?} != sample shape {:?}",
319                    model_output.shape(),
320                    sample.shape()
321                ),
322            });
323        }
324        // diffusers: prev_timestep = timestep - num_train_timesteps //
325        //                              num_inference_steps
326        let step_ratio = self.config.num_train_timesteps / self.timesteps.len();
327        let prev_timestep_i = timestep as isize - step_ratio as isize;
328        let alpha_t = self.alphas_cumprod[timestep];
329        let alpha_t_prev = if prev_timestep_i >= 0 {
330            self.alphas_cumprod[prev_timestep_i as usize]
331        } else {
332            self.final_alpha_cumprod
333        };
334        let beta_t = 1.0 - alpha_t;
335
336        // Compute three scalar tensors (broadcastable against the [B,C,H,W] sample).
337        let sqrt_beta_t = scalar_f64::<T>(beta_t.sqrt())?;
338        let inv_sqrt_alpha_t = scalar_f64::<T>(1.0 / alpha_t.sqrt())?;
339        let sqrt_one_minus_alpha_t_prev = scalar_f64::<T>((1.0 - alpha_t_prev).sqrt())?;
340        let sqrt_alpha_t_prev = scalar_f64::<T>(alpha_t_prev.sqrt())?;
341
342        // pred_x0 = (sample - sqrt(beta_t) * model_output) / sqrt(alpha_t)
343        //         = (sample - sqrt(beta_t) * model_output) * (1 / sqrt(alpha_t))
344        let scaled_noise = mul(model_output, &sqrt_beta_t)?;
345        let diff = sub(sample, &scaled_noise)?;
346        let pred_x0 = mul(&diff, &inv_sqrt_alpha_t)?;
347
348        // Optional clip_sample. SD-1.5 default is false; gate kept so this
349        // module is reusable with non-SD configs.
350        let pred_x0 = if self.config.clip_sample {
351            clip_to_one::<T>(&pred_x0)?
352        } else {
353            pred_x0
354        };
355
356        // pred_dir = sqrt(1 - alpha_t_prev) * model_output
357        let pred_dir = mul(model_output, &sqrt_one_minus_alpha_t_prev)?;
358        // prev_sample = sqrt(alpha_t_prev) * pred_x0 + pred_dir
359        let x0_scaled = mul(&pred_x0, &sqrt_alpha_t_prev)?;
360        add(&x0_scaled, &pred_dir)
361    }
362}
363
364/// Compute `betas` according to the chosen schedule. Mirrors diffusers's
365/// `betas_for_alpha_bar` only for the two schedules SD-1.5 ever uses.
366fn compute_betas(schedule: BetaSchedule, beta_start: f64, beta_end: f64, n: usize) -> Vec<f64> {
367    let mut out = Vec::with_capacity(n);
368    if n == 0 {
369        return out;
370    }
371    if n == 1 {
372        out.push(beta_start);
373        return out;
374    }
375    let denom = (n - 1) as f64;
376    match schedule {
377        BetaSchedule::Linear => {
378            for i in 0..n {
379                let t = i as f64 / denom;
380                out.push(beta_start + t * (beta_end - beta_start));
381            }
382        }
383        BetaSchedule::ScaledLinear => {
384            // linspace(sqrt(beta_start), sqrt(beta_end), N) ^ 2
385            let a = beta_start.sqrt();
386            let b = beta_end.sqrt();
387            for i in 0..n {
388                let t = i as f64 / denom;
389                let lin = a + t * (b - a);
390                out.push(lin * lin);
391            }
392        }
393    }
394    out
395}
396
397/// Build a 1-element scalar tensor carrying `value` cast into the target Float.
398fn scalar_f64<T: Float>(value: f64) -> FerrotorchResult<Tensor<T>> {
399    let v = T::from(value).ok_or_else(|| FerrotorchError::InvalidArgument {
400        message: format!("DDIMScheduler: cannot represent f64 {value} as the requested Float"),
401    })?;
402    ferrotorch_core::scalar::<T>(v)
403}
404
405/// Clamp every entry of `t` into `[-1, 1]`. Used only when the user
406/// configures `clip_sample=true`; SD-1.5 default is false.
407fn clip_to_one<T: Float>(t: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
408    let lo = T::from(-1.0).ok_or_else(|| FerrotorchError::InvalidArgument {
409        message: "clip_to_one: cannot represent -1.0".into(),
410    })?;
411    let hi = T::from(1.0).ok_or_else(|| FerrotorchError::InvalidArgument {
412        message: "clip_to_one: cannot represent 1.0".into(),
413    })?;
414    ferrotorch_core::clamp(t, lo, hi)
415}
416
#[cfg(test)]
mod tests {
    use super::*;

    /// CPU `f32` tensor of `shape` with every element set to `value`.
    fn filled(value: f32, shape: &[usize]) -> Tensor<f32> {
        let numel = shape.iter().product::<usize>();
        Tensor::<f32>::from_storage(
            ferrotorch_core::TensorStorage::cpu(vec![value; numel]),
            shape.to_vec(),
            false,
        )
        .unwrap()
    }

    /// Scheduler built from the SD-1.5 defaults.
    fn sd15() -> DDIMScheduler {
        DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap()
    }

    #[test]
    fn beta_schedule_scaled_linear_matches_diffusers_sd15() {
        // diffusers: betas = linspace(sqrt(0.00085), sqrt(0.012), 1000) ** 2.
        let betas = compute_betas(BetaSchedule::ScaledLinear, 0.000_85, 0.012, 1000);
        assert_eq!(betas.len(), 1000);
        // Endpoints land on beta_start / beta_end.
        assert!((betas[0] - 0.000_85).abs() < 1e-12, "betas[0]={}", betas[0]);
        assert!(
            (betas[999] - 0.012).abs() < 1e-12,
            "betas[999]={}",
            betas[999]
        );
        // sqrt(betas) must be an arithmetic progression: every consecutive
        // difference equals (sqrt(beta_end) - sqrt(beta_start)) / 999.
        let want_step = (0.012_f64.sqrt() - 0.000_85_f64.sqrt()) / 999.0;
        for (i, pair) in betas.windows(2).enumerate() {
            let got = pair[1].sqrt() - pair[0].sqrt();
            assert!(
                (got - want_step).abs() < 1e-12,
                "sqrt step {i}: {got} vs {want_step}"
            );
        }
    }

    #[test]
    fn beta_schedule_linear_is_evenly_spaced() {
        let betas = compute_betas(BetaSchedule::Linear, 0.0001, 0.02, 5);
        let want = [0.0001, 0.005075, 0.01005, 0.015025, 0.02];
        assert_eq!(betas.len(), want.len());
        for (got, want) in betas.iter().zip(want.iter()) {
            assert!((got - want).abs() < 1e-12, "{got} vs {want}");
        }
    }

    #[test]
    fn compute_betas_handles_degenerate_lengths() {
        assert!(compute_betas(BetaSchedule::ScaledLinear, 0.000_85, 0.012, 0).is_empty());
        assert_eq!(
            compute_betas(BetaSchedule::ScaledLinear, 0.000_85, 0.012, 1),
            vec![0.000_85]
        );
    }

    #[test]
    fn alphas_cumprod_is_monotone_decreasing() {
        let sched = sd15();
        let mut prev = 1.0_f64;
        for &a in &sched.alphas_cumprod {
            assert!(a < prev, "alphas_cumprod not strictly decreasing");
            assert!(a > 0.0, "alphas_cumprod must be > 0, got {a}");
            prev = a;
        }
        // Final value is approximately 0.0047 for SD-1.5 (reference).
        let last = *sched.alphas_cumprod.last().unwrap();
        assert!(
            (0.001..0.01).contains(&last),
            "SD-1.5 alphas_cumprod[-1] should be ~0.0047, got {last}"
        );
    }

    #[test]
    fn new_rejects_zero_train_timesteps_and_non_finite_betas() {
        let zero_steps = DDIMConfig {
            num_train_timesteps: 0,
            ..DDIMConfig::sd_v1_5()
        };
        assert!(matches!(
            DDIMScheduler::new(zero_steps),
            Err(FerrotorchError::InvalidArgument { .. })
        ));
        let nan_beta = DDIMConfig {
            beta_end: f64::NAN,
            ..DDIMConfig::sd_v1_5()
        };
        assert!(matches!(
            DDIMScheduler::new(nan_beta),
            Err(FerrotorchError::InvalidArgument { .. })
        ));
    }

    #[test]
    fn timesteps_leading_4_steps_sd15() {
        // step_ratio = 1000 // 4 = 250; offset = 1.
        // diffusers: timesteps = (arange(4) * 250) reversed + 1
        //           = [750, 500, 250, 0] + 1 = [751, 501, 251, 1].
        let mut sched = sd15();
        let ts = sched.set_timesteps(4).unwrap();
        assert_eq!(ts, [751, 501, 251, 1]);
    }

    #[test]
    fn timesteps_leading_50_steps_sd15_head() {
        let mut sched = sd15();
        let ts = sched.set_timesteps(50).unwrap();
        assert_eq!(ts.len(), 50);
        // step_ratio = 1000 // 50 = 20. first = 49*20+1 = 981; last = 1.
        assert_eq!(ts[0], 981);
        assert_eq!(ts[49], 1);
    }

    #[test]
    fn timesteps_leading_without_offset() {
        let config = DDIMConfig {
            steps_offset: 0,
            ..DDIMConfig::sd_v1_5()
        };
        let mut sched = DDIMScheduler::new(config).unwrap();
        assert_eq!(sched.set_timesteps(4).unwrap(), [750, 500, 250, 0]);
    }

    #[test]
    fn set_timesteps_rejects_zero_and_too_many_steps() {
        let mut sched = sd15();
        assert!(matches!(
            sched.set_timesteps(0),
            Err(FerrotorchError::InvalidArgument { .. })
        ));
        assert!(matches!(
            sched.set_timesteps(1001),
            Err(FerrotorchError::InvalidArgument { .. })
        ));
        // Failed calls must not leave a partially-set schedule behind.
        assert!(sched.timesteps().is_empty());
    }

    #[test]
    fn init_noise_sigma_is_one() {
        let sched = sd15();
        assert!((sched.init_noise_sigma() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn final_alpha_cumprod_is_alphas_cumprod_zero_when_set_alpha_to_one_false() {
        let sched = sd15();
        assert!((sched.final_alpha_cumprod - sched.alphas_cumprod[0]).abs() < 1e-12);
    }

    #[test]
    fn step_requires_set_timesteps() {
        let sched = sd15();
        let x = filled(0.5, &[1, 1, 2, 2]);
        assert!(matches!(
            sched.step(&x, 1, &x),
            Err(FerrotorchError::InvalidArgument { .. })
        ));
    }

    #[test]
    fn step_rejects_out_of_range_timestep_and_shape_mismatch() {
        let mut sched = sd15();
        sched.set_timesteps(4).unwrap();
        let sample = filled(0.5, &[1, 1, 2, 2]);
        assert!(matches!(
            sched.step(&sample, 1000, &sample),
            Err(FerrotorchError::InvalidArgument { .. })
        ));
        let wrong_shape = filled(0.0, &[1, 1, 1, 4]);
        assert!(matches!(
            sched.step(&wrong_shape, 1, &sample),
            Err(FerrotorchError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn step_with_zero_noise_rescales_sample_on_final_step() {
        // With model_output = 0 at timestep 1 (the last SD-1.5 4-step timestep):
        //   prev_timestep = 1 - 250 < 0  => alpha_prev = final_alpha_cumprod
        //                                  = alphas_cumprod[0]
        //   pred_x0       = sample / sqrt(alphas_cumprod[1])
        //   pred_dir      = 0
        //   prev_sample   = sqrt(alpha_prev) * pred_x0
        let mut sched = sd15();
        sched.set_timesteps(4).unwrap();
        let sample = filled(0.5, &[1, 1, 2, 2]);
        let noise = filled(0.0, &[1, 1, 2, 2]);
        let out = sched.step(&noise, 1, &sample).unwrap();
        assert_eq!(out.shape(), &[1, 1, 2, 2]);
        let want = 0.5 * (sched.alphas_cumprod[0] / sched.alphas_cumprod[1]).sqrt();
        for &v in out.data().unwrap().iter() {
            assert!(v.is_finite(), "step produced non-finite value: {v}");
            assert!(
                (f64::from(v) - want).abs() < 1e-6,
                "step produced {v}, expected {want}"
            );
        }
    }
}