ferrotorch_diffusion/scheduler.rs
1//! Deterministic DDIM scheduler matching `diffusers.schedulers.DDIMScheduler`
2//! for the Stable-Diffusion-1.5 sampling defaults.
3//!
4//! Phase F of real-artifact-driven development (#1163). The scheduler is
5//! the missing fourth component of the SD generation pipeline:
6//!
7//! ```text
8//! CLIP text encoder + UNet noise predictor + DDIM scheduler + VAE decoder
9//! ^^^^^^^^^^^^^^^
10//! this module
11//! ```
12//!
13//! Matches `diffusers.schedulers.DDIMScheduler` for the SD-1.5 defaults
14//! (`scaled_linear` beta schedule, `epsilon` prediction, `leading`
15//! timestep spacing, `clip_sample=false`, `set_alpha_to_one=false`,
16//! `init_noise_sigma=1.0` — i.e. η=0 deterministic sampling). Values
17//! mirrored byte-for-byte from the upstream defaults in
18//! `diffusers/schedulers/scheduling_ddim.py` as of `diffusers==0.38.0`.
19
20use ferrotorch_core::grad_fns::arithmetic::{add, mul, sub};
21use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
22
/// Recipe for laying out the per-timestep `betas` (only the SD-1.5 subset).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BetaSchedule {
    /// Ramp linearly in square-root space, then square:
    /// `betas = linspace(sqrt(beta_start), sqrt(beta_end), N)^2`.
    /// This is the Stable-Diffusion default.
    ScaledLinear,
    /// Ramp linearly in beta space: `betas = linspace(beta_start, beta_end, N)`.
    Linear,
}
31
/// Discrete timestep spacing (subset of diffusers' `timestep_spacing`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TimestepSpacing {
    /// `step_ratio = num_train_timesteps // num_inference_steps;
    /// timesteps = arange(0, num_inference_steps) * step_ratio + steps_offset`,
    /// walked in descending order. SD default.
    Leading,
    /// Timesteps spread evenly over the training grid, rounded to integers
    /// and walked in descending order; `steps_offset` is not applied. The
    /// exact rounding rule lives in [`DDIMScheduler::set_timesteps`].
    Linspace,
}
43
44/// `DDIMScheduler` configuration (the subset that affects forward math).
45#[derive(Debug, Clone)]
46pub struct DDIMConfig {
47 /// `num_train_timesteps` — SD-1.5: 1000.
48 pub num_train_timesteps: usize,
49 /// `beta_start` — SD-1.5: 0.00085.
50 pub beta_start: f64,
51 /// `beta_end` — SD-1.5: 0.012.
52 pub beta_end: f64,
53 /// `beta_schedule` — SD-1.5: ScaledLinear.
54 pub beta_schedule: BetaSchedule,
55 /// `clip_sample` — SD-1.5: false (no clipping of predicted x0).
56 pub clip_sample: bool,
57 /// `set_alpha_to_one` — SD-1.5: false. When false, the
58 /// final-step `alpha_prev` is `alphas_cumprod[0]` (matches diffusers).
59 pub set_alpha_to_one: bool,
60 /// `prediction_type` — SD-1.5: "epsilon" (the UNet predicts the noise).
61 /// Only `"epsilon"` is implemented; anything else returns an error
62 /// at `set_timesteps` time.
63 pub prediction_type: PredictionType,
64 /// `timestep_spacing` — SD-1.5: Leading.
65 pub timestep_spacing: TimestepSpacing,
66 /// `steps_offset` — SD-1.5: 1 (diffusers adds this offset on the
67 /// `leading` path so the first inference step is `step_ratio` rather
68 /// than 0; consumed inside [`DDIMScheduler::set_timesteps`]).
69 pub steps_offset: usize,
70}
71
/// What the UNet output represents (only the SD-1.5 case is modelled).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PredictionType {
    /// The network predicts the added noise ε (diffusers `"epsilon"`).
    /// SD default.
    Epsilon,
}
78
79impl Default for DDIMConfig {
80 fn default() -> Self {
81 Self {
82 num_train_timesteps: 1000,
83 beta_start: 0.000_85,
84 beta_end: 0.012,
85 beta_schedule: BetaSchedule::ScaledLinear,
86 clip_sample: false,
87 set_alpha_to_one: false,
88 prediction_type: PredictionType::Epsilon,
89 timestep_spacing: TimestepSpacing::Leading,
90 steps_offset: 1,
91 }
92 }
93}
94
95impl DDIMConfig {
96 /// SD-1.5 defaults (alias for [`Default::default`]).
97 pub fn sd_v1_5() -> Self {
98 Self::default()
99 }
100}
101
102/// Deterministic DDIM scheduler (η=0, no noise injection).
103///
104/// Pre-computes `betas`, `alphas`, and `alphas_cumprod` over the full
105/// training-time grid (`num_train_timesteps` entries). Inference picks
106/// a subset of timesteps via [`DDIMScheduler::set_timesteps`] and walks
107/// them in reverse with [`DDIMScheduler::step`].
108#[derive(Debug, Clone)]
109pub struct DDIMScheduler {
110 config: DDIMConfig,
111 /// `alphas_cumprod` of length `num_train_timesteps`.
112 alphas_cumprod: Vec<f64>,
113 /// `final_alpha_cumprod` — used when stepping into prev_timestep < 0
114 /// (the very last denoising step). When `set_alpha_to_one=false` this
115 /// is `alphas_cumprod[0]`; when true it's 1.0.
116 final_alpha_cumprod: f64,
117 /// Timesteps the user requested via [`Self::set_timesteps`], in the
118 /// order they will be consumed (descending). Empty until
119 /// `set_timesteps` is called.
120 timesteps: Vec<usize>,
121}
122
123impl DDIMScheduler {
124 /// Build a scheduler with the given config; runs the one-shot
125 /// `betas → alphas → alphas_cumprod` precomputation.
126 ///
127 /// # Errors
128 ///
129 /// Returns [`FerrotorchError::InvalidArgument`] for malformed
130 /// configurations (zero training steps, non-positive beta).
131 pub fn new(config: DDIMConfig) -> FerrotorchResult<Self> {
132 if config.num_train_timesteps == 0 {
133 return Err(FerrotorchError::InvalidArgument {
134 message: "DDIMScheduler::new: num_train_timesteps must be > 0".into(),
135 });
136 }
137 if !config.beta_start.is_finite() || !config.beta_end.is_finite() {
138 return Err(FerrotorchError::InvalidArgument {
139 message: format!(
140 "DDIMScheduler::new: non-finite betas (beta_start={}, beta_end={})",
141 config.beta_start, config.beta_end
142 ),
143 });
144 }
145 let n = config.num_train_timesteps;
146 let betas = compute_betas(config.beta_schedule, config.beta_start, config.beta_end, n);
147 // alphas[i] = 1 - betas[i].
148 // alphas_cumprod[i] = prod_{j=0..=i} alphas[j].
149 let mut alphas_cumprod = Vec::with_capacity(n);
150 let mut acc = 1.0_f64;
151 for &b in &betas {
152 let a = 1.0 - b;
153 acc *= a;
154 alphas_cumprod.push(acc);
155 }
156 let final_alpha_cumprod = if config.set_alpha_to_one {
157 1.0
158 } else {
159 alphas_cumprod[0]
160 };
161 Ok(Self {
162 config,
163 alphas_cumprod,
164 final_alpha_cumprod,
165 timesteps: Vec::new(),
166 })
167 }
168
169 /// Read-only access to the frozen configuration.
170 pub fn config(&self) -> &DDIMConfig {
171 &self.config
172 }
173
174 /// `init_noise_sigma` — the multiplier applied to the initial Gaussian
175 /// noise tensor before the first denoising step. DDIM with the SD-1.5
176 /// defaults uses 1.0 (no scaling).
177 pub fn init_noise_sigma(&self) -> f64 {
178 1.0
179 }
180
181 /// Set the inference-time discrete timesteps and return them.
182 ///
183 /// Matches `diffusers.schedulers.DDIMScheduler.set_timesteps` exactly
184 /// for the SD-1.5 defaults (Leading spacing + `steps_offset=1`).
185 ///
186 /// # Errors
187 ///
188 /// Returns [`FerrotorchError::InvalidArgument`] if `num_inference_steps`
189 /// is zero or exceeds `num_train_timesteps`, or if the configured
190 /// `prediction_type` is anything other than [`PredictionType::Epsilon`].
191 pub fn set_timesteps(&mut self, num_inference_steps: usize) -> FerrotorchResult<&[usize]> {
192 if num_inference_steps == 0 {
193 return Err(FerrotorchError::InvalidArgument {
194 message: "DDIMScheduler::set_timesteps: num_inference_steps must be > 0".into(),
195 });
196 }
197 if num_inference_steps > self.config.num_train_timesteps {
198 return Err(FerrotorchError::InvalidArgument {
199 message: format!(
200 "DDIMScheduler::set_timesteps: num_inference_steps {num_inference_steps} \
201 must be <= num_train_timesteps {}",
202 self.config.num_train_timesteps
203 ),
204 });
205 }
206 if self.config.prediction_type != PredictionType::Epsilon {
207 return Err(FerrotorchError::InvalidArgument {
208 message:
209 "DDIMScheduler::set_timesteps: only PredictionType::Epsilon is implemented"
210 .into(),
211 });
212 }
213 let n_train = self.config.num_train_timesteps;
214 self.timesteps = match self.config.timestep_spacing {
215 TimestepSpacing::Leading => {
216 // step_ratio = num_train_timesteps // num_inference_steps
217 let step_ratio = n_train / num_inference_steps;
218 // ts = (arange(0, num_inference_steps) * step_ratio) reversed
219 // + steps_offset.
220 let mut ts: Vec<usize> = (0..num_inference_steps)
221 .rev()
222 .map(|i| i * step_ratio + self.config.steps_offset)
223 .collect();
224 // Clamp into [0, n_train - 1] (defensive — the standard
225 // SD-1.5 4-step recipe never hits this).
226 for t in &mut ts {
227 if *t >= n_train {
228 *t = n_train - 1;
229 }
230 }
231 ts
232 }
233 TimestepSpacing::Linspace => {
234 let step = (n_train as f64) / (num_inference_steps as f64);
235 let mut ts: Vec<usize> = (0..num_inference_steps)
236 .rev()
237 .map(|i| ((i as f64 * step).round() as usize).min(n_train - 1))
238 .collect();
239 for t in &mut ts {
240 if *t >= n_train {
241 *t = n_train - 1;
242 }
243 }
244 ts
245 }
246 };
247 Ok(&self.timesteps)
248 }
249
250 /// Read-only access to the inference timesteps (empty until
251 /// `set_timesteps` has been called).
252 pub fn timesteps(&self) -> &[usize] {
253 &self.timesteps
254 }
255
256 /// Scale the model input. For DDIM with the SD-1.5 defaults this is
257 /// the identity (DDIM does not rescale the model input the way
258 /// `LMSDiscreteScheduler` does). Kept for pipeline-parity with the
259 /// diffusers API surface.
260 ///
261 /// # Errors
262 ///
263 /// Currently infallible; returned for forward-compat with non-DDIM
264 /// schedulers we may add later.
265 pub fn scale_model_input<T: Float>(
266 &self,
267 sample: &Tensor<T>,
268 _timestep: usize,
269 ) -> FerrotorchResult<Tensor<T>> {
270 Ok(sample.clone())
271 }
272
273 /// One DDIM step: predict the previous-sample given the model's
274 /// noise prediction `model_output` at `timestep`, applied to `sample`.
275 ///
276 /// Math (η = 0, deterministic):
277 ///
278 /// ```text
279 /// prev_timestep = timestep - step_size
280 /// alpha_t = alphas_cumprod[timestep]
281 /// alpha_t_prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0
282 /// else final_alpha_cumprod
283 /// beta_t = 1 - alpha_t
284 ///
285 /// pred_x0 = (sample - sqrt(beta_t) * model_output) / sqrt(alpha_t)
286 /// pred_dir = sqrt(1 - alpha_t_prev) * model_output
287 /// prev_sample = sqrt(alpha_t_prev) * pred_x0 + pred_dir
288 /// ```
289 ///
290 /// # Errors
291 ///
292 /// Returns [`FerrotorchError::InvalidArgument`] if `timestep` is
293 /// outside `[0, num_train_timesteps)` or `set_timesteps` has not been
294 /// called, and propagates any underlying tensor-arithmetic error.
295 pub fn step<T: Float>(
296 &self,
297 model_output: &Tensor<T>,
298 timestep: usize,
299 sample: &Tensor<T>,
300 ) -> FerrotorchResult<Tensor<T>> {
301 if self.timesteps.is_empty() {
302 return Err(FerrotorchError::InvalidArgument {
303 message: "DDIMScheduler::step: set_timesteps must be called before step".into(),
304 });
305 }
306 if timestep >= self.config.num_train_timesteps {
307 return Err(FerrotorchError::InvalidArgument {
308 message: format!(
309 "DDIMScheduler::step: timestep {timestep} out of range \
310 [0, {})",
311 self.config.num_train_timesteps
312 ),
313 });
314 }
315 if model_output.shape() != sample.shape() {
316 return Err(FerrotorchError::ShapeMismatch {
317 message: format!(
318 "DDIMScheduler::step: model_output shape {:?} != sample shape {:?}",
319 model_output.shape(),
320 sample.shape()
321 ),
322 });
323 }
324 // diffusers: prev_timestep = timestep - num_train_timesteps //
325 // num_inference_steps
326 let step_ratio = self.config.num_train_timesteps / self.timesteps.len();
327 let prev_timestep_i = timestep as isize - step_ratio as isize;
328 let alpha_t = self.alphas_cumprod[timestep];
329 let alpha_t_prev = if prev_timestep_i >= 0 {
330 self.alphas_cumprod[prev_timestep_i as usize]
331 } else {
332 self.final_alpha_cumprod
333 };
334 let beta_t = 1.0 - alpha_t;
335
336 // Compute three scalar tensors (broadcastable against the [B,C,H,W] sample).
337 let sqrt_beta_t = scalar_f64::<T>(beta_t.sqrt())?;
338 let inv_sqrt_alpha_t = scalar_f64::<T>(1.0 / alpha_t.sqrt())?;
339 let sqrt_one_minus_alpha_t_prev = scalar_f64::<T>((1.0 - alpha_t_prev).sqrt())?;
340 let sqrt_alpha_t_prev = scalar_f64::<T>(alpha_t_prev.sqrt())?;
341
342 // pred_x0 = (sample - sqrt(beta_t) * model_output) / sqrt(alpha_t)
343 // = (sample - sqrt(beta_t) * model_output) * (1 / sqrt(alpha_t))
344 let scaled_noise = mul(model_output, &sqrt_beta_t)?;
345 let diff = sub(sample, &scaled_noise)?;
346 let pred_x0 = mul(&diff, &inv_sqrt_alpha_t)?;
347
348 // Optional clip_sample. SD-1.5 default is false; gate kept so this
349 // module is reusable with non-SD configs.
350 let pred_x0 = if self.config.clip_sample {
351 clip_to_one::<T>(&pred_x0)?
352 } else {
353 pred_x0
354 };
355
356 // pred_dir = sqrt(1 - alpha_t_prev) * model_output
357 let pred_dir = mul(model_output, &sqrt_one_minus_alpha_t_prev)?;
358 // prev_sample = sqrt(alpha_t_prev) * pred_x0 + pred_dir
359 let x0_scaled = mul(&pred_x0, &sqrt_alpha_t_prev)?;
360 add(&x0_scaled, &pred_dir)
361 }
362}
363
364/// Compute `betas` according to the chosen schedule. Mirrors diffusers's
365/// `betas_for_alpha_bar` only for the two schedules SD-1.5 ever uses.
366fn compute_betas(schedule: BetaSchedule, beta_start: f64, beta_end: f64, n: usize) -> Vec<f64> {
367 let mut out = Vec::with_capacity(n);
368 if n == 0 {
369 return out;
370 }
371 if n == 1 {
372 out.push(beta_start);
373 return out;
374 }
375 let denom = (n - 1) as f64;
376 match schedule {
377 BetaSchedule::Linear => {
378 for i in 0..n {
379 let t = i as f64 / denom;
380 out.push(beta_start + t * (beta_end - beta_start));
381 }
382 }
383 BetaSchedule::ScaledLinear => {
384 // linspace(sqrt(beta_start), sqrt(beta_end), N) ^ 2
385 let a = beta_start.sqrt();
386 let b = beta_end.sqrt();
387 for i in 0..n {
388 let t = i as f64 / denom;
389 let lin = a + t * (b - a);
390 out.push(lin * lin);
391 }
392 }
393 }
394 out
395}
396
397/// Build a 1-element scalar tensor carrying `value` cast into the target Float.
398fn scalar_f64<T: Float>(value: f64) -> FerrotorchResult<Tensor<T>> {
399 let v = T::from(value).ok_or_else(|| FerrotorchError::InvalidArgument {
400 message: format!("DDIMScheduler: cannot represent f64 {value} as the requested Float"),
401 })?;
402 ferrotorch_core::scalar::<T>(v)
403}
404
405/// Clamp every entry of `t` into `[-1, 1]`. Used only when the user
406/// configures `clip_sample=true`; SD-1.5 default is false.
407fn clip_to_one<T: Float>(t: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
408 let lo = T::from(-1.0).ok_or_else(|| FerrotorchError::InvalidArgument {
409 message: "clip_to_one: cannot represent -1.0".into(),
410 })?;
411 let hi = T::from(1.0).ok_or_else(|| FerrotorchError::InvalidArgument {
412 message: "clip_to_one: cannot represent 1.0".into(),
413 })?;
414 ferrotorch_core::clamp(t, lo, hi)
415}
416
417#[cfg(test)]
418mod tests {
419 use super::*;
420
421 #[test]
422 fn beta_schedule_scaled_linear_matches_diffusers_sd15() {
423 // Mirror diffusers's `betas = linspace(sqrt(0.00085),
424 // sqrt(0.012), 1000) ** 2`. Spot-check a handful of indices.
425 let betas = compute_betas(BetaSchedule::ScaledLinear, 0.000_85, 0.012, 1000);
426 assert_eq!(betas.len(), 1000);
427 // i=0 → beta_start exactly.
428 assert!((betas[0] - 0.000_85).abs() < 1e-12, "betas[0]={}", betas[0]);
429 // i=999 → beta_end exactly.
430 assert!(
431 (betas[999] - 0.012).abs() < 1e-12,
432 "betas[999]={}",
433 betas[999]
434 );
435 // i=500 → midpoint: ((sqrt(0.00085)+sqrt(0.012))/2)^2
436 let mid_root = (0.000_85_f64.sqrt() + 0.012_f64.sqrt()) * 0.5;
437 let want = mid_root * mid_root;
438 // i=500 with linspace of 1000 points is offset 500/999 ≠ 0.5; allow a small slack.
439 let approx_idx = 999 / 2;
440 assert!(
441 (betas[approx_idx] - want).abs() < 5e-3,
442 "betas[{approx_idx}]={} vs midpoint {want}",
443 betas[approx_idx]
444 );
445 }
446
447 #[test]
448 fn alphas_cumprod_is_monotone_decreasing() {
449 let sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
450 let mut prev = 1.0_f64;
451 for &a in &sched.alphas_cumprod {
452 assert!(a < prev, "alphas_cumprod not strictly decreasing");
453 assert!(a > 0.0, "alphas_cumprod must be > 0, got {a}");
454 prev = a;
455 }
456 // Final value is approximately 0.0047 for SD-1.5 (reference).
457 let last = *sched.alphas_cumprod.last().unwrap();
458 assert!(
459 (0.001..0.01).contains(&last),
460 "SD-1.5 alphas_cumprod[-1] should be ~0.0047, got {last}"
461 );
462 }
463
464 #[test]
465 fn timesteps_leading_4_steps_sd15() {
466 // step_ratio = 1000 // 4 = 250; offset = 1.
467 // diffusers: timesteps = (arange(4) * 250) reversed + 1
468 // = [750, 500, 250, 0] + 1 = [751, 501, 251, 1].
469 let mut sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
470 let ts = sched.set_timesteps(4).unwrap();
471 assert_eq!(ts, [751, 501, 251, 1]);
472 }
473
474 #[test]
475 fn timesteps_leading_50_steps_sd15_head() {
476 let mut sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
477 let ts = sched.set_timesteps(50).unwrap();
478 assert_eq!(ts.len(), 50);
479 // step_ratio = 1000 // 50 = 20. first = 49*20+1 = 981; last = 1.
480 assert_eq!(ts[0], 981);
481 assert_eq!(ts[49], 1);
482 }
483
484 #[test]
485 fn init_noise_sigma_is_one() {
486 let sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
487 assert!((sched.init_noise_sigma() - 1.0).abs() < 1e-12);
488 }
489
490 #[test]
491 fn final_alpha_cumprod_is_alphas_cumprod_zero_when_set_alpha_to_one_false() {
492 let sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
493 assert!((sched.final_alpha_cumprod - sched.alphas_cumprod[0]).abs() < 1e-12);
494 }
495
496 #[test]
497 fn step_recovers_zero_for_identity_noise() {
498 // If model_output is the zero tensor at timestep 0:
499 // pred_x0 = sample / sqrt(alpha_0)
500 // pred_dir = 0
501 // prev_sample = sqrt(alpha_prev) * pred_x0
502 // We can't really assert byte-level here without a reference, but
503 // we can assert non-finite values do not appear and shape is
504 // preserved.
505 let mut sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
506 sched.set_timesteps(4).unwrap();
507 let sample = Tensor::<f32>::from_storage(
508 ferrotorch_core::TensorStorage::cpu(vec![0.5_f32; 4]),
509 vec![1, 1, 2, 2],
510 false,
511 )
512 .unwrap();
513 let noise = Tensor::<f32>::from_storage(
514 ferrotorch_core::TensorStorage::cpu(vec![0.0_f32; 4]),
515 vec![1, 1, 2, 2],
516 false,
517 )
518 .unwrap();
519 let out = sched.step(&noise, 1, &sample).unwrap();
520 assert_eq!(out.shape(), &[1, 1, 2, 2]);
521 for v in out.data().unwrap() {
522 assert!(v.is_finite(), "step produced non-finite value: {v}");
523 }
524 }
525}