ferrotorch_diffusion/scheduler.rs
1//! Deterministic DDIM scheduler matching `diffusers.schedulers.DDIMScheduler`
2//! for the Stable-Diffusion-1.5 sampling defaults.
3//!
4//! Phase F of real-artifact-driven development (#1163). The scheduler is
5//! the missing fourth component of the SD generation pipeline:
6//!
7//! ```text
8//! CLIP text encoder + UNet noise predictor + DDIM scheduler + VAE decoder
9//! ^^^^^^^^^^^^^^^
10//! this module
11//! ```
12//!
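//! Matches `diffusers.schedulers.DDIMScheduler` for the SD-1.5 sampling
//! configuration (`scaled_linear` beta schedule with `beta_start=0.00085` /
//! `beta_end=0.012`, `epsilon` prediction, `leading` timestep spacing with
//! `steps_offset=1`, `clip_sample=false`, `set_alpha_to_one=false`,
//! `init_noise_sigma=1.0`), sampled deterministically with η=0 (no noise is
//! injected between steps). These are SD-1.5's scheduler-config values, not
//! the bare `DDIMScheduler()` constructor defaults (which use a `linear`
//! schedule, `clip_sample=True`, `set_alpha_to_one=True`, `steps_offset=0`);
//! the step math follows `diffusers/schedulers/scheduling_ddim.py` as of
//! `diffusers==0.38.0`.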
19//!
20//! ## REQ status (per `.design/ferrotorch-diffusion/scheduler.md`)
21//!
22//! | REQ | Status | Evidence |
23//! |---|---|---|
24//! | REQ-1 | SHIPPED | `DDIMScheduler::new` at `ferrotorch-diffusion/src/scheduler.rs:131..167`; consumer: `ferrotorch-diffusion/examples/sd_pipeline_dump.rs` builds the scheduler with `DDIMConfig::sd_v1_5()` |
25//! | REQ-2 | SHIPPED | `set_timesteps` at `ferrotorch-diffusion/src/scheduler.rs:191..248`; consumer: `ferrotorch-diffusion/src/pipeline.rs:194` calls it |
26//! | REQ-3 | SHIPPED | `step` at `ferrotorch-diffusion/src/scheduler.rs:295..361`; consumer: `ferrotorch-diffusion/src/pipeline.rs:212` calls it |
27//! | REQ-4 | SHIPPED | `init_noise_sigma` at `ferrotorch-diffusion/src/scheduler.rs:177..179` and `scale_model_input` at `scheduler.rs:265..271`; consumer: `ferrotorch-diffusion/src/pipeline.rs:199` and `pipeline.rs:132` |
28//! | REQ-5 | SHIPPED | `compute_betas` ScaledLinear arm at `ferrotorch-diffusion/src/scheduler.rs:383..392`; consumer: `DDIMScheduler::new` at `scheduler.rs:146` invokes it |
29//! | REQ-6 | SHIPPED | prediction-type guard at `ferrotorch-diffusion/src/scheduler.rs:206..212`; consumer: `ferrotorch-diffusion/src/pipeline.rs:194` surfaces this error |
30
31use ferrotorch_core::grad_fns::arithmetic::{add, mul, sub};
32use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
33
/// Shape of the per-timestep noise-variance curve (`betas`).
///
/// Only the recipes needed for Stable-Diffusion-1.5 are provided.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BetaSchedule {
    /// Interpolate linearly between the *square roots* of the endpoints, then
    /// square each value: `linspace(sqrt(beta_start), sqrt(beta_end), N)^2`.
    /// This is the SD-1.5 default.
    ScaledLinear,
    /// Interpolate linearly between the raw endpoints:
    /// `linspace(beta_start, beta_end, N)`.
    Linear,
}
42
/// Discrete timestep spacing (subset matching SD-1.5).
///
/// Both variants produce timesteps in descending order and clamp every entry
/// to `num_train_timesteps - 1`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TimestepSpacing {
    /// `step_ratio = num_train_timesteps // num_inference_steps;
    /// timesteps = reversed(arange(0, num_inference_steps) * step_ratio) + steps_offset`.
    /// SD default: with `steps_offset = 1` and 4 steps this yields
    /// `[751, 501, 251, 1]`.
    Leading,
    /// `step = num_train_timesteps / num_inference_steps` (floating point);
    /// `timesteps = reversed(round(arange(num_inference_steps) * step))`.
    /// `steps_offset` is not applied on this path.
    ///
    /// NOTE(review): diffusers's DDIM `timestep_spacing="linspace"` uses
    /// `linspace(0, num_train_timesteps - 1, num_inference_steps).round()[::-1]`,
    /// which differs from this formula (4 steps: `[999, 666, 333, 0]` there vs
    /// `[750, 500, 250, 0]` here). Confirm before relying on diffusers parity.
    Linspace,
}
54
55/// `DDIMScheduler` configuration (the subset that affects forward math).
56#[derive(Debug, Clone)]
57pub struct DDIMConfig {
58 /// `num_train_timesteps` — SD-1.5: 1000.
59 pub num_train_timesteps: usize,
60 /// `beta_start` — SD-1.5: 0.00085.
61 pub beta_start: f64,
62 /// `beta_end` — SD-1.5: 0.012.
63 pub beta_end: f64,
64 /// `beta_schedule` — SD-1.5: ScaledLinear.
65 pub beta_schedule: BetaSchedule,
66 /// `clip_sample` — SD-1.5: false (no clipping of predicted x0).
67 pub clip_sample: bool,
68 /// `set_alpha_to_one` — SD-1.5: false. When false, the
69 /// final-step `alpha_prev` is `alphas_cumprod[0]` (matches diffusers).
70 pub set_alpha_to_one: bool,
71 /// `prediction_type` — SD-1.5: "epsilon" (the UNet predicts the noise).
72 /// Only `"epsilon"` is implemented; anything else returns an error
73 /// at `set_timesteps` time.
74 pub prediction_type: PredictionType,
75 /// `timestep_spacing` — SD-1.5: Leading.
76 pub timestep_spacing: TimestepSpacing,
77 /// `steps_offset` — SD-1.5: 1 (diffusers adds this offset on the
78 /// `leading` path so the first inference step is `step_ratio` rather
79 /// than 0; consumed inside [`DDIMScheduler::set_timesteps`]).
80 pub steps_offset: usize,
81}
82
/// Quantity the denoising network is trained to output.
///
/// Only the parameterisation used by Stable-Diffusion-1.5 is available.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub enum PredictionType {
    /// The network outputs the added noise ε (the SD default).
    Epsilon,
}
89
90impl Default for DDIMConfig {
91 fn default() -> Self {
92 Self {
93 num_train_timesteps: 1000,
94 beta_start: 0.000_85,
95 beta_end: 0.012,
96 beta_schedule: BetaSchedule::ScaledLinear,
97 clip_sample: false,
98 set_alpha_to_one: false,
99 prediction_type: PredictionType::Epsilon,
100 timestep_spacing: TimestepSpacing::Leading,
101 steps_offset: 1,
102 }
103 }
104}
105
106impl DDIMConfig {
107 /// SD-1.5 defaults (alias for [`Default::default`]).
108 pub fn sd_v1_5() -> Self {
109 Self::default()
110 }
111}
112
113/// Deterministic DDIM scheduler (η=0, no noise injection).
114///
115/// Pre-computes `betas`, `alphas`, and `alphas_cumprod` over the full
116/// training-time grid (`num_train_timesteps` entries). Inference picks
117/// a subset of timesteps via [`DDIMScheduler::set_timesteps`] and walks
118/// them in reverse with [`DDIMScheduler::step`].
119#[derive(Debug, Clone)]
120pub struct DDIMScheduler {
121 config: DDIMConfig,
122 /// `alphas_cumprod` of length `num_train_timesteps`.
123 alphas_cumprod: Vec<f64>,
124 /// `final_alpha_cumprod` — used when stepping into prev_timestep < 0
125 /// (the very last denoising step). When `set_alpha_to_one=false` this
126 /// is `alphas_cumprod[0]`; when true it's 1.0.
127 final_alpha_cumprod: f64,
128 /// Timesteps the user requested via [`Self::set_timesteps`], in the
129 /// order they will be consumed (descending). Empty until
130 /// `set_timesteps` is called.
131 timesteps: Vec<usize>,
132}
133
134impl DDIMScheduler {
135 /// Build a scheduler with the given config; runs the one-shot
136 /// `betas → alphas → alphas_cumprod` precomputation.
137 ///
138 /// # Errors
139 ///
140 /// Returns [`FerrotorchError::InvalidArgument`] for malformed
141 /// configurations (zero training steps, non-positive beta).
142 pub fn new(config: DDIMConfig) -> FerrotorchResult<Self> {
143 if config.num_train_timesteps == 0 {
144 return Err(FerrotorchError::InvalidArgument {
145 message: "DDIMScheduler::new: num_train_timesteps must be > 0".into(),
146 });
147 }
148 if !config.beta_start.is_finite() || !config.beta_end.is_finite() {
149 return Err(FerrotorchError::InvalidArgument {
150 message: format!(
151 "DDIMScheduler::new: non-finite betas (beta_start={}, beta_end={})",
152 config.beta_start, config.beta_end
153 ),
154 });
155 }
156 let n = config.num_train_timesteps;
157 let betas = compute_betas(config.beta_schedule, config.beta_start, config.beta_end, n);
158 // alphas[i] = 1 - betas[i].
159 // alphas_cumprod[i] = prod_{j=0..=i} alphas[j].
160 let mut alphas_cumprod = Vec::with_capacity(n);
161 let mut acc = 1.0_f64;
162 for &b in &betas {
163 let a = 1.0 - b;
164 acc *= a;
165 alphas_cumprod.push(acc);
166 }
167 let final_alpha_cumprod = if config.set_alpha_to_one {
168 1.0
169 } else {
170 alphas_cumprod[0]
171 };
172 Ok(Self {
173 config,
174 alphas_cumprod,
175 final_alpha_cumprod,
176 timesteps: Vec::new(),
177 })
178 }
179
180 /// Read-only access to the frozen configuration.
181 pub fn config(&self) -> &DDIMConfig {
182 &self.config
183 }
184
185 /// `init_noise_sigma` — the multiplier applied to the initial Gaussian
186 /// noise tensor before the first denoising step. DDIM with the SD-1.5
187 /// defaults uses 1.0 (no scaling).
188 pub fn init_noise_sigma(&self) -> f64 {
189 1.0
190 }
191
192 /// Set the inference-time discrete timesteps and return them.
193 ///
194 /// Matches `diffusers.schedulers.DDIMScheduler.set_timesteps` exactly
195 /// for the SD-1.5 defaults (Leading spacing + `steps_offset=1`).
196 ///
197 /// # Errors
198 ///
199 /// Returns [`FerrotorchError::InvalidArgument`] if `num_inference_steps`
200 /// is zero or exceeds `num_train_timesteps`, or if the configured
201 /// `prediction_type` is anything other than [`PredictionType::Epsilon`].
202 pub fn set_timesteps(&mut self, num_inference_steps: usize) -> FerrotorchResult<&[usize]> {
203 if num_inference_steps == 0 {
204 return Err(FerrotorchError::InvalidArgument {
205 message: "DDIMScheduler::set_timesteps: num_inference_steps must be > 0".into(),
206 });
207 }
208 if num_inference_steps > self.config.num_train_timesteps {
209 return Err(FerrotorchError::InvalidArgument {
210 message: format!(
211 "DDIMScheduler::set_timesteps: num_inference_steps {num_inference_steps} \
212 must be <= num_train_timesteps {}",
213 self.config.num_train_timesteps
214 ),
215 });
216 }
217 if self.config.prediction_type != PredictionType::Epsilon {
218 return Err(FerrotorchError::InvalidArgument {
219 message:
220 "DDIMScheduler::set_timesteps: only PredictionType::Epsilon is implemented"
221 .into(),
222 });
223 }
224 let n_train = self.config.num_train_timesteps;
225 self.timesteps = match self.config.timestep_spacing {
226 TimestepSpacing::Leading => {
227 // step_ratio = num_train_timesteps // num_inference_steps
228 let step_ratio = n_train / num_inference_steps;
229 // ts = (arange(0, num_inference_steps) * step_ratio) reversed
230 // + steps_offset.
231 let mut ts: Vec<usize> = (0..num_inference_steps)
232 .rev()
233 .map(|i| i * step_ratio + self.config.steps_offset)
234 .collect();
235 // Clamp into [0, n_train - 1] (defensive — the standard
236 // SD-1.5 4-step recipe never hits this).
237 for t in &mut ts {
238 if *t >= n_train {
239 *t = n_train - 1;
240 }
241 }
242 ts
243 }
244 TimestepSpacing::Linspace => {
245 let step = (n_train as f64) / (num_inference_steps as f64);
246 let mut ts: Vec<usize> = (0..num_inference_steps)
247 .rev()
248 .map(|i| ((i as f64 * step).round() as usize).min(n_train - 1))
249 .collect();
250 for t in &mut ts {
251 if *t >= n_train {
252 *t = n_train - 1;
253 }
254 }
255 ts
256 }
257 };
258 Ok(&self.timesteps)
259 }
260
261 /// Read-only access to the inference timesteps (empty until
262 /// `set_timesteps` has been called).
263 pub fn timesteps(&self) -> &[usize] {
264 &self.timesteps
265 }
266
267 /// Scale the model input. For DDIM with the SD-1.5 defaults this is
268 /// the identity (DDIM does not rescale the model input the way
269 /// `LMSDiscreteScheduler` does). Kept for pipeline-parity with the
270 /// diffusers API surface.
271 ///
272 /// # Errors
273 ///
274 /// Currently infallible; returned for forward-compat with non-DDIM
275 /// schedulers we may add later.
276 pub fn scale_model_input<T: Float>(
277 &self,
278 sample: &Tensor<T>,
279 _timestep: usize,
280 ) -> FerrotorchResult<Tensor<T>> {
281 Ok(sample.clone())
282 }
283
284 /// One DDIM step: predict the previous-sample given the model's
285 /// noise prediction `model_output` at `timestep`, applied to `sample`.
286 ///
287 /// Math (η = 0, deterministic):
288 ///
289 /// ```text
290 /// prev_timestep = timestep - step_size
291 /// alpha_t = alphas_cumprod[timestep]
292 /// alpha_t_prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0
293 /// else final_alpha_cumprod
294 /// beta_t = 1 - alpha_t
295 ///
296 /// pred_x0 = (sample - sqrt(beta_t) * model_output) / sqrt(alpha_t)
297 /// pred_dir = sqrt(1 - alpha_t_prev) * model_output
298 /// prev_sample = sqrt(alpha_t_prev) * pred_x0 + pred_dir
299 /// ```
300 ///
301 /// # Errors
302 ///
303 /// Returns [`FerrotorchError::InvalidArgument`] if `timestep` is
304 /// outside `[0, num_train_timesteps)` or `set_timesteps` has not been
305 /// called, and propagates any underlying tensor-arithmetic error.
306 pub fn step<T: Float>(
307 &self,
308 model_output: &Tensor<T>,
309 timestep: usize,
310 sample: &Tensor<T>,
311 ) -> FerrotorchResult<Tensor<T>> {
312 if self.timesteps.is_empty() {
313 return Err(FerrotorchError::InvalidArgument {
314 message: "DDIMScheduler::step: set_timesteps must be called before step".into(),
315 });
316 }
317 if timestep >= self.config.num_train_timesteps {
318 return Err(FerrotorchError::InvalidArgument {
319 message: format!(
320 "DDIMScheduler::step: timestep {timestep} out of range \
321 [0, {})",
322 self.config.num_train_timesteps
323 ),
324 });
325 }
326 if model_output.shape() != sample.shape() {
327 return Err(FerrotorchError::ShapeMismatch {
328 message: format!(
329 "DDIMScheduler::step: model_output shape {:?} != sample shape {:?}",
330 model_output.shape(),
331 sample.shape()
332 ),
333 });
334 }
335 // diffusers: prev_timestep = timestep - num_train_timesteps //
336 // num_inference_steps
337 let step_ratio = self.config.num_train_timesteps / self.timesteps.len();
338 let prev_timestep_i = timestep as isize - step_ratio as isize;
339 let alpha_t = self.alphas_cumprod[timestep];
340 let alpha_t_prev = if prev_timestep_i >= 0 {
341 self.alphas_cumprod[prev_timestep_i as usize]
342 } else {
343 self.final_alpha_cumprod
344 };
345 let beta_t = 1.0 - alpha_t;
346
347 // Compute three scalar tensors (broadcastable against the [B,C,H,W] sample).
348 let sqrt_beta_t = scalar_f64::<T>(beta_t.sqrt())?;
349 let inv_sqrt_alpha_t = scalar_f64::<T>(1.0 / alpha_t.sqrt())?;
350 let sqrt_one_minus_alpha_t_prev = scalar_f64::<T>((1.0 - alpha_t_prev).sqrt())?;
351 let sqrt_alpha_t_prev = scalar_f64::<T>(alpha_t_prev.sqrt())?;
352
353 // pred_x0 = (sample - sqrt(beta_t) * model_output) / sqrt(alpha_t)
354 // = (sample - sqrt(beta_t) * model_output) * (1 / sqrt(alpha_t))
355 let scaled_noise = mul(model_output, &sqrt_beta_t)?;
356 let diff = sub(sample, &scaled_noise)?;
357 let pred_x0 = mul(&diff, &inv_sqrt_alpha_t)?;
358
359 // Optional clip_sample. SD-1.5 default is false; gate kept so this
360 // module is reusable with non-SD configs.
361 let pred_x0 = if self.config.clip_sample {
362 clip_to_one::<T>(&pred_x0)?
363 } else {
364 pred_x0
365 };
366
367 // pred_dir = sqrt(1 - alpha_t_prev) * model_output
368 let pred_dir = mul(model_output, &sqrt_one_minus_alpha_t_prev)?;
369 // prev_sample = sqrt(alpha_t_prev) * pred_x0 + pred_dir
370 let x0_scaled = mul(&pred_x0, &sqrt_alpha_t_prev)?;
371 add(&x0_scaled, &pred_dir)
372 }
373}
374
375/// Compute `betas` according to the chosen schedule. Mirrors diffusers's
376/// `betas_for_alpha_bar` only for the two schedules SD-1.5 ever uses.
377fn compute_betas(schedule: BetaSchedule, beta_start: f64, beta_end: f64, n: usize) -> Vec<f64> {
378 let mut out = Vec::with_capacity(n);
379 if n == 0 {
380 return out;
381 }
382 if n == 1 {
383 out.push(beta_start);
384 return out;
385 }
386 let denom = (n - 1) as f64;
387 match schedule {
388 BetaSchedule::Linear => {
389 for i in 0..n {
390 let t = i as f64 / denom;
391 out.push(beta_start + t * (beta_end - beta_start));
392 }
393 }
394 BetaSchedule::ScaledLinear => {
395 // linspace(sqrt(beta_start), sqrt(beta_end), N) ^ 2
396 let a = beta_start.sqrt();
397 let b = beta_end.sqrt();
398 for i in 0..n {
399 let t = i as f64 / denom;
400 let lin = a + t * (b - a);
401 out.push(lin * lin);
402 }
403 }
404 }
405 out
406}
407
408/// Build a 1-element scalar tensor carrying `value` cast into the target Float.
409fn scalar_f64<T: Float>(value: f64) -> FerrotorchResult<Tensor<T>> {
410 let v = T::from(value).ok_or_else(|| FerrotorchError::InvalidArgument {
411 message: format!("DDIMScheduler: cannot represent f64 {value} as the requested Float"),
412 })?;
413 ferrotorch_core::scalar::<T>(v)
414}
415
416/// Clamp every entry of `t` into `[-1, 1]`. Used only when the user
417/// configures `clip_sample=true`; SD-1.5 default is false.
418fn clip_to_one<T: Float>(t: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
419 let lo = T::from(-1.0).ok_or_else(|| FerrotorchError::InvalidArgument {
420 message: "clip_to_one: cannot represent -1.0".into(),
421 })?;
422 let hi = T::from(1.0).ok_or_else(|| FerrotorchError::InvalidArgument {
423 message: "clip_to_one: cannot represent 1.0".into(),
424 })?;
425 ferrotorch_core::clamp(t, lo, hi)
426}
427
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a CPU `f32` tensor of `shape` with every element set to `value`.
    fn filled(value: f32, shape: Vec<usize>) -> Tensor<f32> {
        let numel: usize = shape.iter().product();
        Tensor::<f32>::from_storage(
            ferrotorch_core::TensorStorage::cpu(vec![value; numel]),
            shape,
            false,
        )
        .unwrap()
    }

    #[test]
    fn beta_schedule_scaled_linear_matches_diffusers_sd15() {
        // diffusers: betas = linspace(sqrt(0.00085), sqrt(0.012), 1000) ** 2.
        let betas = compute_betas(BetaSchedule::ScaledLinear, 0.000_85, 0.012, 1000);
        assert_eq!(betas.len(), 1000);
        // Endpoints land on beta_start / beta_end (up to sqrt/square rounding).
        assert!((betas[0] - 0.000_85).abs() < 1e-12, "betas[0]={}", betas[0]);
        assert!(
            (betas[999] - 0.012).abs() < 1e-12,
            "betas[999]={}",
            betas[999]
        );
        // "linspace of sqrt, then square" means the square roots are evenly
        // spaced; check every consecutive gap instead of one loose midpoint.
        let (lo, hi) = (0.000_85_f64.sqrt(), 0.012_f64.sqrt());
        let gap = (hi - lo) / 999.0;
        for (i, pair) in betas.windows(2).enumerate() {
            let got = pair[1].sqrt() - pair[0].sqrt();
            assert!((got - gap).abs() < 1e-12, "sqrt gap {i}: {got} vs {gap}");
        }
        // Index 499 sits at fraction 499/999 (not 0.5) along the sqrt line.
        let root = lo + (499.0 / 999.0) * (hi - lo);
        let want = root * root;
        assert!(
            (betas[499] - want).abs() < 1e-15,
            "betas[499]={} vs {want}",
            betas[499]
        );
    }

    #[test]
    fn beta_schedule_linear_and_degenerate_lengths() {
        let betas = compute_betas(BetaSchedule::Linear, 0.0, 1.0, 5);
        assert_eq!(betas, vec![0.0, 0.25, 0.5, 0.75, 1.0]);
        assert!(compute_betas(BetaSchedule::Linear, 0.1, 0.2, 0).is_empty());
        assert_eq!(
            compute_betas(BetaSchedule::ScaledLinear, 0.25, 0.5, 1),
            vec![0.25]
        );
    }

    #[test]
    fn alphas_cumprod_is_monotone_decreasing() {
        let sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        let betas = compute_betas(BetaSchedule::ScaledLinear, 0.000_85, 0.012, 1000);
        // The first cumulative product is just the first alpha.
        assert!((sched.alphas_cumprod[0] - (1.0 - betas[0])).abs() < 1e-15);
        let mut prev = 1.0_f64;
        for &a in &sched.alphas_cumprod {
            assert!(a < prev, "alphas_cumprod not strictly decreasing");
            assert!(a > 0.0, "alphas_cumprod must be > 0, got {a}");
            prev = a;
        }
        // Final value is approximately 0.0047 for SD-1.5 (reference).
        let last = *sched.alphas_cumprod.last().unwrap();
        assert!(
            (0.001..0.01).contains(&last),
            "SD-1.5 alphas_cumprod[-1] should be ~0.0047, got {last}"
        );
    }

    #[test]
    fn new_rejects_degenerate_configs() {
        let no_steps = DDIMConfig {
            num_train_timesteps: 0,
            ..DDIMConfig::sd_v1_5()
        };
        assert!(DDIMScheduler::new(no_steps).is_err());
        let nan_beta = DDIMConfig {
            beta_end: f64::NAN,
            ..DDIMConfig::sd_v1_5()
        };
        assert!(DDIMScheduler::new(nan_beta).is_err());
        let inf_beta = DDIMConfig {
            beta_start: f64::INFINITY,
            ..DDIMConfig::sd_v1_5()
        };
        assert!(DDIMScheduler::new(inf_beta).is_err());
    }

    #[test]
    fn timesteps_leading_4_steps_sd15() {
        // step_ratio = 1000 // 4 = 250; offset = 1.
        // diffusers: timesteps = (arange(4) * 250) reversed + 1
        //          = [750, 500, 250, 0] + 1 = [751, 501, 251, 1].
        let mut sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        let ts = sched.set_timesteps(4).unwrap();
        assert_eq!(ts, [751, 501, 251, 1]);
    }

    #[test]
    fn timesteps_leading_50_steps_sd15_head() {
        let mut sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        let ts = sched.set_timesteps(50).unwrap();
        assert_eq!(ts.len(), 50);
        // step_ratio = 1000 // 50 = 20. first = 49*20+1 = 981; last = 1.
        assert_eq!(ts[0], 981);
        assert_eq!(ts[49], 1);
    }

    #[test]
    fn set_timesteps_rejects_out_of_range_counts() {
        let mut sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        assert!(sched.set_timesteps(0).is_err());
        assert!(sched.set_timesteps(1001).is_err());
        // A rejected call leaves the (still empty) timestep list untouched.
        assert!(sched.timesteps().is_empty());
        let ts = sched.set_timesteps(1000).unwrap();
        assert_eq!(ts.len(), 1000);
        // 999 * 1 + steps_offset would be 1000; it is clamped to 999.
        assert_eq!(ts[0], 999);
        assert!(ts.iter().all(|&t| t < 1000));
    }

    #[test]
    fn init_noise_sigma_is_one() {
        let sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        assert!((sched.init_noise_sigma() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn final_alpha_cumprod_is_alphas_cumprod_zero_when_set_alpha_to_one_false() {
        let sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        assert!((sched.final_alpha_cumprod - sched.alphas_cumprod[0]).abs() < 1e-12);
    }

    #[test]
    fn final_alpha_cumprod_is_one_when_set_alpha_to_one_true() {
        let cfg = DDIMConfig {
            set_alpha_to_one: true,
            ..DDIMConfig::sd_v1_5()
        };
        let sched = DDIMScheduler::new(cfg).unwrap();
        assert!((sched.final_alpha_cumprod - 1.0).abs() < 1e-12);
    }

    #[test]
    fn step_rejects_misuse() {
        let mut sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        let x = filled(0.0, vec![1, 1, 2, 2]);
        // set_timesteps has not been called yet.
        assert!(sched.step(&x, 1, &x).is_err());
        sched.set_timesteps(4).unwrap();
        // Timestep outside [0, num_train_timesteps).
        assert!(sched.step(&x, 1000, &x).is_err());
        // model_output / sample shape mismatch.
        let wrong_shape = filled(0.0, vec![1, 1, 1, 4]);
        assert!(sched.step(&wrong_shape, 1, &x).is_err());
    }

    #[test]
    fn step_with_zero_noise_rescales_sample() {
        // With model_output = 0 the DDIM update collapses to
        //   prev = sqrt(alpha_prev) * sample / sqrt(alpha_t).
        // For 4 inference steps step_ratio = 250, so timestep 1 steps to
        // prev_timestep = -249 < 0 and alpha_prev = final_alpha_cumprod.
        let mut sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        sched.set_timesteps(4).unwrap();
        let sample = filled(0.5, vec![1, 1, 2, 2]);
        let noise = filled(0.0, vec![1, 1, 2, 2]);
        let out = sched.step(&noise, 1, &sample).unwrap();
        assert_eq!(out.shape(), &[1, 1, 2, 2]);
        let want = 0.5 * sched.final_alpha_cumprod.sqrt() / sched.alphas_cumprod[1].sqrt();
        let data = out.data().unwrap();
        for v in data.iter().copied() {
            assert!(v.is_finite(), "step produced non-finite value: {v}");
            assert!((f64::from(v) - want).abs() < 1e-5, "got {v}, want {want}");
        }
    }

    #[test]
    fn step_matches_closed_form_ddim_update() {
        // 751 -> 501 with 4 inference steps (step_ratio = 250).
        let mut sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        sched.set_timesteps(4).unwrap();
        let (s, eps) = (0.5_f64, 0.25_f64);
        let sample = filled(s as f32, vec![1, 1, 2, 2]);
        let noise = filled(eps as f32, vec![1, 1, 2, 2]);
        let out = sched.step(&noise, 751, &sample).unwrap();
        let (a_t, a_prev) = (sched.alphas_cumprod[751], sched.alphas_cumprod[501]);
        let pred_x0 = (s - (1.0 - a_t).sqrt() * eps) / a_t.sqrt();
        let want = a_prev.sqrt() * pred_x0 + (1.0 - a_prev).sqrt() * eps;
        let data = out.data().unwrap();
        for v in data.iter().copied() {
            assert!((f64::from(v) - want).abs() < 1e-4, "got {v}, want {want}");
        }
    }
}