ferrotorch_diffusion/scheduler.rs
1//! Deterministic DDIM scheduler matching `diffusers.schedulers.DDIMScheduler`
2//! for the Stable-Diffusion-1.5 sampling defaults.
3//!
4//! Phase F of real-artifact-driven development (#1163). The scheduler is
5//! the missing fourth component of the SD generation pipeline:
6//!
7//! ```text
8//! CLIP text encoder + UNet noise predictor + DDIM scheduler + VAE decoder
//!                                            ^^^^^^^^^^^^^^
//!                                            this module
11//! ```
12//!
13//! Matches `diffusers.schedulers.DDIMScheduler` for the SD-1.5 defaults
14//! (`scaled_linear` beta schedule, `epsilon` prediction, `leading`
//! timestep spacing, `clip_sample=false`, `set_alpha_to_one=false`,
//! `init_noise_sigma=1.0`), with deterministic η=0 sampling (no noise is
//! injected between steps). Values
17//! mirrored byte-for-byte from the upstream defaults in
18//! `diffusers/schedulers/scheduling_ddim.py` as of `diffusers==0.38.0`.
19
20use ferrotorch_core::grad_fns::arithmetic::{add, mul, sub};
21use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
22
/// Recipe used to lay out the training-time beta (noise-variance) schedule.
///
/// Only the two recipes exercised by Stable Diffusion 1.5 are provided.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BetaSchedule {
    /// Square of a linear ramp between the square roots of the endpoints:
    /// `betas = linspace(sqrt(beta_start), sqrt(beta_end), N)^2`.
    /// This is the Stable Diffusion 1.5 default.
    ScaledLinear,
    /// Plain linear ramp between the endpoints:
    /// `betas = linspace(beta_start, beta_end, N)`.
    Linear,
}
31
/// How the discrete inference timesteps are drawn from the training grid
/// (subset matching SD-1.5). `N = num_train_timesteps`,
/// `k = num_inference_steps`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TimestepSpacing {
    /// diffusers `"leading"`: `step_ratio = N // k;
    /// timesteps = arange(0, k) * step_ratio`, reversed, then
    /// `+ steps_offset`. SD default.
    Leading,
    /// diffusers `"linspace"`:
    /// `timesteps = round(linspace(0, N - 1, k))`, reversed, where the
    /// rounding is numpy's round-half-to-even. The grid runs from `N - 1`
    /// down to `0` (just `[0]` when `k == 1`); `steps_offset` is not applied.
    Linspace,
}
43
44/// `DDIMScheduler` configuration (the subset that affects forward math).
45#[derive(Debug, Clone)]
46pub struct DDIMConfig {
47 /// `num_train_timesteps` — SD-1.5: 1000.
48 pub num_train_timesteps: usize,
49 /// `beta_start` — SD-1.5: 0.00085.
50 pub beta_start: f64,
51 /// `beta_end` — SD-1.5: 0.012.
52 pub beta_end: f64,
53 /// `beta_schedule` — SD-1.5: ScaledLinear.
54 pub beta_schedule: BetaSchedule,
55 /// `clip_sample` — SD-1.5: false (no clipping of predicted x0).
56 pub clip_sample: bool,
57 /// `set_alpha_to_one` — SD-1.5: false. When false, the
58 /// final-step `alpha_prev` is `alphas_cumprod[0]` (matches diffusers).
59 pub set_alpha_to_one: bool,
60 /// `prediction_type` — SD-1.5: "epsilon" (the UNet predicts the noise).
61 /// Only `"epsilon"` is implemented; anything else returns an error
62 /// at `set_timesteps` time.
63 pub prediction_type: PredictionType,
64 /// `timestep_spacing` — SD-1.5: Leading.
65 pub timestep_spacing: TimestepSpacing,
66 /// `steps_offset` — SD-1.5: 1 (diffusers adds this offset on the
67 /// `leading` path so the first inference step is `step_ratio` rather
68 /// than 0; consumed inside [`DDIMScheduler::set_timesteps`]).
69 pub steps_offset: usize,
70}
71
/// What the denoising network's output is interpreted as (SD-1.5 subset).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PredictionType {
    /// The network predicts the added Gaussian noise ε. This is the Stable
    /// Diffusion 1.5 default and the only parameterisation supported here.
    Epsilon,
}
78
79impl Default for DDIMConfig {
80 fn default() -> Self {
81 Self {
82 num_train_timesteps: 1000,
83 beta_start: 0.000_85,
84 beta_end: 0.012,
85 beta_schedule: BetaSchedule::ScaledLinear,
86 clip_sample: false,
87 set_alpha_to_one: false,
88 prediction_type: PredictionType::Epsilon,
89 timestep_spacing: TimestepSpacing::Leading,
90 steps_offset: 1,
91 }
92 }
93}
94
95impl DDIMConfig {
96 /// SD-1.5 defaults (alias for [`Default::default`]).
97 pub fn sd_v1_5() -> Self {
98 Self::default()
99 }
100}
101
102/// Deterministic DDIM scheduler (η=0, no noise injection).
103///
104/// Pre-computes `betas`, `alphas`, and `alphas_cumprod` over the full
105/// training-time grid (`num_train_timesteps` entries). Inference picks
106/// a subset of timesteps via [`DDIMScheduler::set_timesteps`] and walks
107/// them in reverse with [`DDIMScheduler::step`].
108#[derive(Debug, Clone)]
109pub struct DDIMScheduler {
110 config: DDIMConfig,
111 /// `alphas_cumprod` of length `num_train_timesteps`.
112 alphas_cumprod: Vec<f64>,
113 /// `final_alpha_cumprod` — used when stepping into prev_timestep < 0
114 /// (the very last denoising step). When `set_alpha_to_one=false` this
115 /// is `alphas_cumprod[0]`; when true it's 1.0.
116 final_alpha_cumprod: f64,
117 /// Timesteps the user requested via [`Self::set_timesteps`], in the
118 /// order they will be consumed (descending). Empty until
119 /// `set_timesteps` is called.
120 timesteps: Vec<usize>,
121}
122
123impl DDIMScheduler {
124 /// Build a scheduler with the given config; runs the one-shot
125 /// `betas → alphas → alphas_cumprod` precomputation.
126 ///
127 /// # Errors
128 ///
129 /// Returns [`FerrotorchError::InvalidArgument`] for malformed
130 /// configurations (zero training steps, non-positive beta).
131 pub fn new(config: DDIMConfig) -> FerrotorchResult<Self> {
132 if config.num_train_timesteps == 0 {
133 return Err(FerrotorchError::InvalidArgument {
134 message: "DDIMScheduler::new: num_train_timesteps must be > 0".into(),
135 });
136 }
137 if !config.beta_start.is_finite() || !config.beta_end.is_finite() {
138 return Err(FerrotorchError::InvalidArgument {
139 message: format!(
140 "DDIMScheduler::new: non-finite betas (beta_start={}, beta_end={})",
141 config.beta_start, config.beta_end
142 ),
143 });
144 }
145 let n = config.num_train_timesteps;
146 let betas = compute_betas(config.beta_schedule, config.beta_start, config.beta_end, n);
147 // alphas[i] = 1 - betas[i].
148 // alphas_cumprod[i] = prod_{j=0..=i} alphas[j].
149 let mut alphas_cumprod = Vec::with_capacity(n);
150 let mut acc = 1.0_f64;
151 for &b in &betas {
152 let a = 1.0 - b;
153 acc *= a;
154 alphas_cumprod.push(acc);
155 }
156 let final_alpha_cumprod = if config.set_alpha_to_one {
157 1.0
158 } else {
159 alphas_cumprod[0]
160 };
161 Ok(Self {
162 config,
163 alphas_cumprod,
164 final_alpha_cumprod,
165 timesteps: Vec::new(),
166 })
167 }
168
169 /// Read-only access to the frozen configuration.
170 pub fn config(&self) -> &DDIMConfig {
171 &self.config
172 }
173
174 /// `init_noise_sigma` — the multiplier applied to the initial Gaussian
175 /// noise tensor before the first denoising step. DDIM with the SD-1.5
176 /// defaults uses 1.0 (no scaling).
177 pub fn init_noise_sigma(&self) -> f64 {
178 1.0
179 }
180
181 /// Set the inference-time discrete timesteps and return them.
182 ///
183 /// Matches `diffusers.schedulers.DDIMScheduler.set_timesteps` exactly
184 /// for the SD-1.5 defaults (Leading spacing + `steps_offset=1`).
185 ///
186 /// # Errors
187 ///
188 /// Returns [`FerrotorchError::InvalidArgument`] if `num_inference_steps`
189 /// is zero or exceeds `num_train_timesteps`, or if the configured
190 /// `prediction_type` is anything other than [`PredictionType::Epsilon`].
191 pub fn set_timesteps(&mut self, num_inference_steps: usize) -> FerrotorchResult<&[usize]> {
192 if num_inference_steps == 0 {
193 return Err(FerrotorchError::InvalidArgument {
194 message: "DDIMScheduler::set_timesteps: num_inference_steps must be > 0".into(),
195 });
196 }
197 if num_inference_steps > self.config.num_train_timesteps {
198 return Err(FerrotorchError::InvalidArgument {
199 message: format!(
200 "DDIMScheduler::set_timesteps: num_inference_steps {num_inference_steps} \
201 must be <= num_train_timesteps {}",
202 self.config.num_train_timesteps
203 ),
204 });
205 }
206 if self.config.prediction_type != PredictionType::Epsilon {
207 return Err(FerrotorchError::InvalidArgument {
208 message: "DDIMScheduler::set_timesteps: only PredictionType::Epsilon is implemented"
209 .into(),
210 });
211 }
212 let n_train = self.config.num_train_timesteps;
213 self.timesteps = match self.config.timestep_spacing {
214 TimestepSpacing::Leading => {
215 // step_ratio = num_train_timesteps // num_inference_steps
216 let step_ratio = n_train / num_inference_steps;
217 // ts = (arange(0, num_inference_steps) * step_ratio) reversed
218 // + steps_offset.
219 let mut ts: Vec<usize> = (0..num_inference_steps)
220 .rev()
221 .map(|i| i * step_ratio + self.config.steps_offset)
222 .collect();
223 // Clamp into [0, n_train - 1] (defensive — the standard
224 // SD-1.5 4-step recipe never hits this).
225 for t in &mut ts {
226 if *t >= n_train {
227 *t = n_train - 1;
228 }
229 }
230 ts
231 }
232 TimestepSpacing::Linspace => {
233 let step = (n_train as f64) / (num_inference_steps as f64);
234 let mut ts: Vec<usize> = (0..num_inference_steps)
235 .rev()
236 .map(|i| ((i as f64 * step).round() as usize).min(n_train - 1))
237 .collect();
238 for t in &mut ts {
239 if *t >= n_train {
240 *t = n_train - 1;
241 }
242 }
243 ts
244 }
245 };
246 Ok(&self.timesteps)
247 }
248
249 /// Read-only access to the inference timesteps (empty until
250 /// `set_timesteps` has been called).
251 pub fn timesteps(&self) -> &[usize] {
252 &self.timesteps
253 }
254
255 /// Scale the model input. For DDIM with the SD-1.5 defaults this is
256 /// the identity (DDIM does not rescale the model input the way
257 /// `LMSDiscreteScheduler` does). Kept for pipeline-parity with the
258 /// diffusers API surface.
259 ///
260 /// # Errors
261 ///
262 /// Currently infallible; returned for forward-compat with non-DDIM
263 /// schedulers we may add later.
264 pub fn scale_model_input<T: Float>(
265 &self,
266 sample: &Tensor<T>,
267 _timestep: usize,
268 ) -> FerrotorchResult<Tensor<T>> {
269 Ok(sample.clone())
270 }
271
272 /// One DDIM step: predict the previous-sample given the model's
273 /// noise prediction `model_output` at `timestep`, applied to `sample`.
274 ///
275 /// Math (η = 0, deterministic):
276 ///
277 /// ```text
278 /// prev_timestep = timestep - step_size
279 /// alpha_t = alphas_cumprod[timestep]
280 /// alpha_t_prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0
281 /// else final_alpha_cumprod
282 /// beta_t = 1 - alpha_t
283 ///
284 /// pred_x0 = (sample - sqrt(beta_t) * model_output) / sqrt(alpha_t)
285 /// pred_dir = sqrt(1 - alpha_t_prev) * model_output
286 /// prev_sample = sqrt(alpha_t_prev) * pred_x0 + pred_dir
287 /// ```
288 ///
289 /// # Errors
290 ///
291 /// Returns [`FerrotorchError::InvalidArgument`] if `timestep` is
292 /// outside `[0, num_train_timesteps)` or `set_timesteps` has not been
293 /// called, and propagates any underlying tensor-arithmetic error.
294 pub fn step<T: Float>(
295 &self,
296 model_output: &Tensor<T>,
297 timestep: usize,
298 sample: &Tensor<T>,
299 ) -> FerrotorchResult<Tensor<T>> {
300 if self.timesteps.is_empty() {
301 return Err(FerrotorchError::InvalidArgument {
302 message: "DDIMScheduler::step: set_timesteps must be called before step".into(),
303 });
304 }
305 if timestep >= self.config.num_train_timesteps {
306 return Err(FerrotorchError::InvalidArgument {
307 message: format!(
308 "DDIMScheduler::step: timestep {timestep} out of range \
309 [0, {})",
310 self.config.num_train_timesteps
311 ),
312 });
313 }
314 if model_output.shape() != sample.shape() {
315 return Err(FerrotorchError::ShapeMismatch {
316 message: format!(
317 "DDIMScheduler::step: model_output shape {:?} != sample shape {:?}",
318 model_output.shape(),
319 sample.shape()
320 ),
321 });
322 }
323 // diffusers: prev_timestep = timestep - num_train_timesteps //
324 // num_inference_steps
325 let step_ratio = self.config.num_train_timesteps / self.timesteps.len();
326 let prev_timestep_i = timestep as isize - step_ratio as isize;
327 let alpha_t = self.alphas_cumprod[timestep];
328 let alpha_t_prev = if prev_timestep_i >= 0 {
329 self.alphas_cumprod[prev_timestep_i as usize]
330 } else {
331 self.final_alpha_cumprod
332 };
333 let beta_t = 1.0 - alpha_t;
334
335 // Compute three scalar tensors (broadcastable against the [B,C,H,W] sample).
336 let sqrt_beta_t = scalar_f64::<T>(beta_t.sqrt())?;
337 let inv_sqrt_alpha_t = scalar_f64::<T>(1.0 / alpha_t.sqrt())?;
338 let sqrt_one_minus_alpha_t_prev = scalar_f64::<T>((1.0 - alpha_t_prev).sqrt())?;
339 let sqrt_alpha_t_prev = scalar_f64::<T>(alpha_t_prev.sqrt())?;
340
341 // pred_x0 = (sample - sqrt(beta_t) * model_output) / sqrt(alpha_t)
342 // = (sample - sqrt(beta_t) * model_output) * (1 / sqrt(alpha_t))
343 let scaled_noise = mul(model_output, &sqrt_beta_t)?;
344 let diff = sub(sample, &scaled_noise)?;
345 let pred_x0 = mul(&diff, &inv_sqrt_alpha_t)?;
346
347 // Optional clip_sample. SD-1.5 default is false; gate kept so this
348 // module is reusable with non-SD configs.
349 let pred_x0 = if self.config.clip_sample {
350 clip_to_one::<T>(&pred_x0)?
351 } else {
352 pred_x0
353 };
354
355 // pred_dir = sqrt(1 - alpha_t_prev) * model_output
356 let pred_dir = mul(model_output, &sqrt_one_minus_alpha_t_prev)?;
357 // prev_sample = sqrt(alpha_t_prev) * pred_x0 + pred_dir
358 let x0_scaled = mul(&pred_x0, &sqrt_alpha_t_prev)?;
359 add(&x0_scaled, &pred_dir)
360 }
361}
362
363/// Compute `betas` according to the chosen schedule. Mirrors diffusers's
364/// `betas_for_alpha_bar` only for the two schedules SD-1.5 ever uses.
365fn compute_betas(schedule: BetaSchedule, beta_start: f64, beta_end: f64, n: usize) -> Vec<f64> {
366 let mut out = Vec::with_capacity(n);
367 if n == 0 {
368 return out;
369 }
370 if n == 1 {
371 out.push(beta_start);
372 return out;
373 }
374 let denom = (n - 1) as f64;
375 match schedule {
376 BetaSchedule::Linear => {
377 for i in 0..n {
378 let t = i as f64 / denom;
379 out.push(beta_start + t * (beta_end - beta_start));
380 }
381 }
382 BetaSchedule::ScaledLinear => {
383 // linspace(sqrt(beta_start), sqrt(beta_end), N) ^ 2
384 let a = beta_start.sqrt();
385 let b = beta_end.sqrt();
386 for i in 0..n {
387 let t = i as f64 / denom;
388 let lin = a + t * (b - a);
389 out.push(lin * lin);
390 }
391 }
392 }
393 out
394}
395
396/// Build a 1-element scalar tensor carrying `value` cast into the target Float.
397fn scalar_f64<T: Float>(value: f64) -> FerrotorchResult<Tensor<T>> {
398 let v = T::from(value).ok_or_else(|| FerrotorchError::InvalidArgument {
399 message: format!("DDIMScheduler: cannot represent f64 {value} as the requested Float"),
400 })?;
401 ferrotorch_core::scalar::<T>(v)
402}
403
404/// Clamp every entry of `t` into `[-1, 1]`. Used only when the user
405/// configures `clip_sample=true`; SD-1.5 default is false.
406fn clip_to_one<T: Float>(t: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
407 let lo = T::from(-1.0).ok_or_else(|| FerrotorchError::InvalidArgument {
408 message: "clip_to_one: cannot represent -1.0".into(),
409 })?;
410 let hi = T::from(1.0).ok_or_else(|| FerrotorchError::InvalidArgument {
411 message: "clip_to_one: cannot represent 1.0".into(),
412 })?;
413 ferrotorch_core::clamp(t, lo, hi)
414}
415
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `[1, 1, 2, 2]` f32 tensor with every element set to `value`.
    fn filled_2x2(value: f32) -> Tensor<f32> {
        Tensor::<f32>::from_storage(
            ferrotorch_core::TensorStorage::cpu(vec![value; 4]),
            vec![1, 1, 2, 2],
            false,
        )
        .unwrap()
    }

    #[test]
    fn beta_schedule_scaled_linear_matches_diffusers_sd15() {
        // Mirror diffusers's `betas = linspace(sqrt(0.00085),
        // sqrt(0.012), 1000) ** 2`. Spot-check endpoints and an interior index.
        let betas = compute_betas(BetaSchedule::ScaledLinear, 0.000_85, 0.012, 1000);
        assert_eq!(betas.len(), 1000);
        // i=0 → beta_start exactly.
        assert!((betas[0] - 0.000_85).abs() < 1e-12, "betas[0]={}", betas[0]);
        // i=999 → beta_end exactly.
        assert!(
            (betas[999] - 0.012).abs() < 1e-12,
            "betas[999]={}",
            betas[999]
        );
        // i=499 sits at fraction 499/999 along the sqrt-space ramp.
        let lo = 0.000_85_f64.sqrt();
        let hi = 0.012_f64.sqrt();
        let root = lo + (499.0 / 999.0) * (hi - lo);
        let want = root * root;
        assert!(
            (betas[499] - want).abs() < 1e-12,
            "betas[499]={} vs linspace value {want}",
            betas[499]
        );
        // That index is ~5.6e-6 below the exact t=0.5 midpoint, so a 1e-5
        // slack is meaningful (the betas themselves are ~5e-3 here).
        let mid_root = (lo + hi) * 0.5;
        assert!((betas[499] - mid_root * mid_root).abs() < 1e-5);
        // The ramp is strictly increasing.
        assert!(betas.windows(2).all(|w| w[0] < w[1]));
    }

    #[test]
    fn alphas_cumprod_is_monotone_decreasing() {
        let sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        let mut prev = 1.0_f64;
        for &a in &sched.alphas_cumprod {
            assert!(a < prev, "alphas_cumprod not strictly decreasing");
            assert!(a > 0.0, "alphas_cumprod must be > 0, got {a}");
            prev = a;
        }
        // Final value is approximately 0.0047 for SD-1.5 (reference).
        let last = *sched.alphas_cumprod.last().unwrap();
        assert!(
            (0.001..0.01).contains(&last),
            "SD-1.5 alphas_cumprod[-1] should be ~0.0047, got {last}"
        );
    }

    #[test]
    fn timesteps_leading_4_steps_sd15() {
        // step_ratio = 1000 // 4 = 250; offset = 1.
        // diffusers: timesteps = (arange(4) * 250) reversed + 1
        //          = [750, 500, 250, 0] + 1 = [751, 501, 251, 1].
        let mut sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        let ts = sched.set_timesteps(4).unwrap();
        assert_eq!(ts, [751, 501, 251, 1]);
    }

    #[test]
    fn timesteps_leading_50_steps_sd15_head() {
        let mut sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        let ts = sched.set_timesteps(50).unwrap();
        assert_eq!(ts.len(), 50);
        // step_ratio = 1000 // 50 = 20. first = 49*20+1 = 981; last = 1.
        assert_eq!(ts[0], 981);
        assert_eq!(ts[49], 1);
    }

    #[test]
    fn set_timesteps_rejects_zero_and_too_many_steps() {
        let mut sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        assert!(sched.set_timesteps(0).is_err());
        assert!(sched.set_timesteps(1001).is_err());
        // Failed calls must not leave partial state behind.
        assert!(sched.timesteps().is_empty());
    }

    #[test]
    fn init_noise_sigma_is_one() {
        let sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        assert!((sched.init_noise_sigma() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn final_alpha_cumprod_is_alphas_cumprod_zero_when_set_alpha_to_one_false() {
        let sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        assert!((sched.final_alpha_cumprod - sched.alphas_cumprod[0]).abs() < 1e-12);
    }

    #[test]
    fn final_alpha_cumprod_is_one_when_set_alpha_to_one_true() {
        let config = DDIMConfig {
            set_alpha_to_one: true,
            ..DDIMConfig::sd_v1_5()
        };
        let sched = DDIMScheduler::new(config).unwrap();
        assert!((sched.final_alpha_cumprod - 1.0).abs() < 1e-12);
    }

    #[test]
    fn step_requires_set_timesteps() {
        let sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        let sample = filled_2x2(0.5);
        let noise = filled_2x2(0.0);
        assert!(sched.step(&noise, 1, &sample).is_err());
    }

    #[test]
    fn step_with_zero_noise_rescales_sample_by_alpha_ratio() {
        // With model_output == 0 the DDIM update collapses to
        //   prev_sample = sqrt(alpha_prev) * sample / sqrt(alpha_t).
        // At timestep 1 with 4 inference steps (step_ratio = 250) the previous
        // timestep is negative, so alpha_prev = final_alpha_cumprod, which is
        // alphas_cumprod[0] for the SD-1.5 config.
        let mut sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        sched.set_timesteps(4).unwrap();
        let sample = filled_2x2(0.5);
        let noise = filled_2x2(0.0);
        let out = sched.step(&noise, 1, &sample).unwrap();
        assert_eq!(out.shape(), &[1, 1, 2, 2]);
        let want = (0.5 * (sched.alphas_cumprod[0] / sched.alphas_cumprod[1]).sqrt()) as f32;
        for v in out.data().unwrap() {
            assert!(v.is_finite(), "step produced non-finite value: {v}");
            assert!((v - want).abs() < 1e-5, "step output {v} != expected {want}");
        }
    }
}
524}