use ferrotorch_core::grad_fns::arithmetic::{add, mul, sub};
use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
/// How the per-timestep betas are interpolated between `DDIMConfig::beta_start` and
/// `DDIMConfig::beta_end` (see `compute_betas`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BetaSchedule {
    /// Interpolate linearly between `sqrt(beta_start)` and `sqrt(beta_end)`, then square
    /// each interpolated value.
    ScaledLinear,
    /// Interpolate linearly between `beta_start` and `beta_end`.
    Linear,
}
/// Strategy used by `DDIMScheduler::set_timesteps` to choose the inference timesteps.
///
/// Both strategies produce timesteps in descending order.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TimestepSpacing {
    /// Multiples of `num_train_timesteps / num_inference_steps` (integer division), each
    /// shifted by `DDIMConfig::steps_offset` and capped at `num_train_timesteps - 1`.
    Leading,
    /// Rounded points of a floating-point linear spacing across the training range;
    /// `steps_offset` is not added.
    Linspace,
}
/// Configuration consumed by `DDIMScheduler::new`.
#[derive(Debug, Clone)]
pub struct DDIMConfig {
    // Number of discrete training timesteps; `DDIMScheduler::new` rejects 0.
    pub num_train_timesteps: usize,
    // Beta at the first training timestep; must be finite.
    pub beta_start: f64,
    // Beta at the last training timestep; must be finite.
    pub beta_end: f64,
    // Interpolation used between `beta_start` and `beta_end`.
    pub beta_schedule: BetaSchedule,
    // When true, `DDIMScheduler::step` clamps the predicted original sample to [-1, 1].
    pub clip_sample: bool,
    // When true, the cumulative alpha used once the previous timestep would be negative
    // is 1.0; otherwise it is `alphas_cumprod[0]`.
    pub set_alpha_to_one: bool,
    // What the model output represents; only `PredictionType::Epsilon` exists.
    pub prediction_type: PredictionType,
    // How `DDIMScheduler::set_timesteps` spaces the inference timesteps.
    pub timestep_spacing: TimestepSpacing,
    // Added to every timestep under `TimestepSpacing::Leading`.
    pub steps_offset: usize,
}
/// Interpretation of the `model_output` passed to `DDIMScheduler::step`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PredictionType {
    /// The model output is treated as the noise (epsilon) term: `step` scales it by
    /// `sqrt(1 - alpha_cumprod)` and subtracts it from the sample.
    Epsilon,
}
impl Default for DDIMConfig {
fn default() -> Self {
Self {
num_train_timesteps: 1000,
beta_start: 0.000_85,
beta_end: 0.012,
beta_schedule: BetaSchedule::ScaledLinear,
clip_sample: false,
set_alpha_to_one: false,
prediction_type: PredictionType::Epsilon,
timestep_spacing: TimestepSpacing::Leading,
steps_offset: 1,
}
}
}
impl DDIMConfig {
    /// Preset named for Stable Diffusion v1.5.
    ///
    /// Returns exactly the same values as [`DDIMConfig::default`]; it exists to give call
    /// sites a self-describing name.
    pub fn sd_v1_5() -> Self {
        <Self as Default>::default()
    }
}
/// DDIM scheduler: precomputes cumulative alphas from a [`DDIMConfig`] and, in `step`,
/// maps a model's prediction at one timestep to the sample at the previous inference
/// timestep. `step` adds no random noise.
#[derive(Debug, Clone)]
pub struct DDIMScheduler {
    // Configuration the scheduler was built from.
    config: DDIMConfig,
    // alphas_cumprod[t] = (1 - beta_0) * ... * (1 - beta_t), one entry per training timestep.
    alphas_cumprod: Vec<f64>,
    // Cumulative alpha used by `step` when the previous timestep would be negative:
    // 1.0 if `set_alpha_to_one`, else `alphas_cumprod[0]`.
    final_alpha_cumprod: f64,
    // Inference timesteps chosen by `set_timesteps`; empty until it is called.
    timesteps: Vec<usize>,
}
impl DDIMScheduler {
    /// Creates a scheduler from `config` and precomputes the cumulative alpha products.
    ///
    /// `alphas_cumprod[t]` is `(1 - beta_0) * ... * (1 - beta_t)`. The betas come from
    /// `compute_betas` for `config.beta_schedule`.
    ///
    /// # Errors
    ///
    /// Returns [`FerrotorchError::InvalidArgument`] in any of these cases:
    /// - `num_train_timesteps` is 0;
    /// - either beta endpoint is non-finite or lies outside `[0, 1)`;
    /// - any cumulative alpha product is not strictly positive.
    ///
    /// The last check matters because [`DDIMScheduler::step`] square-roots and divides by
    /// these products.
    pub fn new(config: DDIMConfig) -> FerrotorchResult<Self> {
        if config.num_train_timesteps == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: "DDIMScheduler::new: num_train_timesteps must be > 0".into(),
            });
        }
        if !config.beta_start.is_finite() || !config.beta_end.is_finite() {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "DDIMScheduler::new: non-finite betas (beta_start={}, beta_end={})",
                    config.beta_start, config.beta_end
                ),
            });
        }
        // Both schedules interpolate between the endpoints, so bounding the endpoints bounds
        // every beta. A beta >= 1 makes some alpha = 1 - beta non-positive, and a negative
        // beta also breaks the `sqrt` in the scaled-linear schedule; either way `step` would
        // silently produce NaN/inf.
        for (name, beta) in [("beta_start", config.beta_start), ("beta_end", config.beta_end)] {
            if !(0.0..1.0).contains(&beta) {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!("DDIMScheduler::new: {name} must be in [0, 1), got {beta}"),
                });
            }
        }

        let n = config.num_train_timesteps;
        let alphas_cumprod: Vec<f64> =
            compute_betas(config.beta_schedule, config.beta_start, config.beta_end, n)
                .into_iter()
                .scan(1.0_f64, |acc, beta| {
                    *acc *= 1.0 - beta;
                    Some(*acc)
                })
                .collect();
        // Even with valid betas, the running product can underflow to 0 for very long
        // schedules; `step` divides by `sqrt(alphas_cumprod[t])`.
        if !alphas_cumprod.iter().all(|&a| a > 0.0) {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "DDIMScheduler::new: alphas_cumprod underflowed to zero \
                     (num_train_timesteps={n})"
                ),
            });
        }

        let final_alpha_cumprod = if config.set_alpha_to_one {
            1.0
        } else {
            alphas_cumprod[0]
        };
        Ok(Self {
            config,
            alphas_cumprod,
            final_alpha_cumprod,
            timesteps: Vec::new(),
        })
    }

    /// Returns the configuration this scheduler was built from.
    pub fn config(&self) -> &DDIMConfig {
        &self.config
    }

    /// Scale applied to the initial noise; always `1.0` for this scheduler.
    pub fn init_noise_sigma(&self) -> f64 {
        1.0
    }

    /// Chooses `num_inference_steps` timesteps, stores them, and returns them.
    /// The timesteps are in descending order.
    ///
    /// * [`TimestepSpacing::Leading`]: the timesteps are
    ///   `i * (num_train_timesteps / num_inference_steps) + steps_offset` for
    ///   `i = num_inference_steps - 1, ..., 0`, using integer division. Each value is capped
    ///   at `num_train_timesteps - 1`.
    /// * [`TimestepSpacing::Linspace`]: the timesteps are `num_inference_steps` evenly
    ///   spaced points over `[0, num_train_timesteps - 1]`, both ends included. They are
    ///   rounded half-to-even, which matches NumPy's `np.linspace(...).round()`.
    ///   `steps_offset` is not applied.
    ///
    /// # Errors
    ///
    /// Returns [`FerrotorchError::InvalidArgument`] in any of these cases:
    /// - `num_inference_steps` is 0;
    /// - `num_inference_steps` exceeds `num_train_timesteps`;
    /// - the configured prediction type is not implemented.
    ///
    /// On error the previously stored timesteps are left untouched.
    pub fn set_timesteps(&mut self, num_inference_steps: usize) -> FerrotorchResult<&[usize]> {
        if num_inference_steps == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: "DDIMScheduler::set_timesteps: num_inference_steps must be > 0".into(),
            });
        }
        if num_inference_steps > self.config.num_train_timesteps {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "DDIMScheduler::set_timesteps: num_inference_steps {num_inference_steps} \
                     must be <= num_train_timesteps {}",
                    self.config.num_train_timesteps
                ),
            });
        }
        if self.config.prediction_type != PredictionType::Epsilon {
            return Err(FerrotorchError::InvalidArgument {
                message:
                    "DDIMScheduler::set_timesteps: only PredictionType::Epsilon is implemented"
                        .into(),
            });
        }

        let n_train = self.config.num_train_timesteps;
        let last_train_timestep = n_train - 1;
        self.timesteps = match self.config.timestep_spacing {
            TimestepSpacing::Leading => {
                let step_ratio = n_train / num_inference_steps;
                let offset = self.config.steps_offset;
                // `i * step_ratio < n_train` cannot overflow, but a huge user-supplied
                // offset could, so saturate before capping.
                (0..num_inference_steps)
                    .rev()
                    .map(|i| {
                        (i * step_ratio)
                            .saturating_add(offset)
                            .min(last_train_timestep)
                    })
                    .collect()
            }
            TimestepSpacing::Linspace => {
                // Nearest-integer rounding with exact `.5` ties sent to the even neighbour,
                // i.e. NumPy's `round`, so e.g. 166.5 -> 166 rather than 167.
                fn round_half_to_even(x: f64) -> f64 {
                    if (x - x.trunc()).abs() == 0.5 {
                        2.0 * (x / 2.0).round()
                    } else {
                        x.round()
                    }
                }

                if num_inference_steps == 1 {
                    // `np.linspace(0, T - 1, 1)` is `[0.0]`.
                    vec![0]
                } else {
                    let spacing = last_train_timestep as f64 / (num_inference_steps - 1) as f64;
                    (0..num_inference_steps)
                        .rev()
                        .map(|i| {
                            (round_half_to_even(i as f64 * spacing) as usize)
                                .min(last_train_timestep)
                        })
                        .collect()
                }
            }
        };
        Ok(&self.timesteps)
    }

    /// Returns the timesteps chosen by the last successful `set_timesteps` call; the slice
    /// is empty if `set_timesteps` has not been called.
    pub fn timesteps(&self) -> &[usize] {
        &self.timesteps
    }

    /// Returns a clone of `sample`: this scheduler does not rescale the model input, and
    /// the timestep is ignored.
    pub fn scale_model_input<T: Float>(
        &self,
        sample: &Tensor<T>,
        _timestep: usize,
    ) -> FerrotorchResult<Tensor<T>> {
        Ok(sample.clone())
    }

    /// Performs one DDIM update, with no added noise, from `timestep` to the previous
    /// inference timestep.
    ///
    /// Let `a_t = alphas_cumprod[timestep]`. Let `a_prev` be the cumulative alpha at
    /// `timestep - num_train_timesteps / num_inference_steps`, or `final_alpha_cumprod`
    /// when that index would be negative. Then:
    ///
    /// ```text
    /// x0   = (sample - sqrt(1 - a_t) * model_output) / sqrt(a_t)   // clamped to [-1, 1] if clip_sample
    /// prev = sqrt(a_prev) * x0 + sqrt(1 - a_prev) * model_output
    /// ```
    ///
    /// # Errors
    ///
    /// - [`FerrotorchError::InvalidArgument`] if `set_timesteps` has not been called, or if
    ///   `timestep >= num_train_timesteps`.
    /// - [`FerrotorchError::ShapeMismatch`] if `model_output` and `sample` have different
    ///   shapes.
    /// - Any error from the underlying tensor operations.
    pub fn step<T: Float>(
        &self,
        model_output: &Tensor<T>,
        timestep: usize,
        sample: &Tensor<T>,
    ) -> FerrotorchResult<Tensor<T>> {
        if self.timesteps.is_empty() {
            return Err(FerrotorchError::InvalidArgument {
                message: "DDIMScheduler::step: set_timesteps must be called before step".into(),
            });
        }
        if timestep >= self.config.num_train_timesteps {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "DDIMScheduler::step: timestep {timestep} out of range \
                     [0, {})",
                    self.config.num_train_timesteps
                ),
            });
        }
        if model_output.shape() != sample.shape() {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "DDIMScheduler::step: model_output shape {:?} != sample shape {:?}",
                    model_output.shape(),
                    sample.shape()
                ),
            });
        }

        // `set_timesteps` guarantees 1 <= len <= num_train_timesteps, so step_ratio >= 1.
        let step_ratio = self.config.num_train_timesteps / self.timesteps.len();
        let alpha_t = self.alphas_cumprod[timestep];
        // Stepping back past timestep 0 has no training alpha; use the configured final one.
        let alpha_t_prev = timestep
            .checked_sub(step_ratio)
            .map_or(self.final_alpha_cumprod, |prev| self.alphas_cumprod[prev]);
        let beta_t = 1.0 - alpha_t;

        let sqrt_beta_t = scalar_f64::<T>(beta_t.sqrt())?;
        let inv_sqrt_alpha_t = scalar_f64::<T>(1.0 / alpha_t.sqrt())?;
        let sqrt_one_minus_alpha_t_prev = scalar_f64::<T>((1.0 - alpha_t_prev).sqrt())?;
        let sqrt_alpha_t_prev = scalar_f64::<T>(alpha_t_prev.sqrt())?;

        // Predicted clean sample x0 from the epsilon prediction.
        let scaled_noise = mul(model_output, &sqrt_beta_t)?;
        let diff = sub(sample, &scaled_noise)?;
        let pred_x0 = mul(&diff, &inv_sqrt_alpha_t)?;
        let pred_x0 = if self.config.clip_sample {
            clip_to_one::<T>(&pred_x0)?
        } else {
            pred_x0
        };

        // Deterministic (eta = 0) update: re-noise x0 towards the previous timestep.
        let pred_dir = mul(model_output, &sqrt_one_minus_alpha_t_prev)?;
        let x0_scaled = mul(&pred_x0, &sqrt_alpha_t_prev)?;
        add(&x0_scaled, &pred_dir)
    }
}
/// Builds the `n` per-timestep betas for `schedule`.
///
/// * `BetaSchedule::Linear` interpolates linearly from `beta_start` to `beta_end`.
/// * `BetaSchedule::ScaledLinear` interpolates linearly from `sqrt(beta_start)` to
///   `sqrt(beta_end)` and squares each point.
///
/// Returns an empty vector for `n == 0` and `[beta_start]` for `n == 1`.
fn compute_betas(schedule: BetaSchedule, beta_start: f64, beta_end: f64, n: usize) -> Vec<f64> {
    match n {
        0 => Vec::new(),
        1 => vec![beta_start],
        _ => {
            let last_index = (n - 1) as f64;
            // Same float operations, in the same order, for both schedules:
            // lo + (i / (n - 1)) * (hi - lo).
            let lerp = |lo: f64, hi: f64, i: usize| lo + (i as f64 / last_index) * (hi - lo);
            match schedule {
                BetaSchedule::Linear => (0..n).map(|i| lerp(beta_start, beta_end, i)).collect(),
                BetaSchedule::ScaledLinear => {
                    let (root_lo, root_hi) = (beta_start.sqrt(), beta_end.sqrt());
                    (0..n)
                        .map(|i| {
                            let root = lerp(root_lo, root_hi, i);
                            root * root
                        })
                        .collect()
                }
            }
        }
    }
}
/// Converts the `f64` constant `value` into element type `T` and wraps it with
/// `ferrotorch_core::scalar`. The result is presumably a scalar tensor that broadcasts in
/// the arithmetic ops; confirm against ferrotorch_core.
///
/// # Errors
///
/// Returns [`FerrotorchError::InvalidArgument`] when `value` cannot be represented as `T`.
/// Any error from `ferrotorch_core::scalar` is propagated unchanged.
fn scalar_f64<T: Float>(value: f64) -> FerrotorchResult<Tensor<T>> {
    match T::from(value) {
        Some(converted) => ferrotorch_core::scalar::<T>(converted),
        None => Err(FerrotorchError::InvalidArgument {
            message: format!("DDIMScheduler: cannot represent f64 {value} as the requested Float"),
        }),
    }
}
/// Clamps every element of `t` into `[-1, 1]` via `ferrotorch_core::clamp`.
///
/// # Errors
///
/// Returns [`FerrotorchError::InvalidArgument`] if `-1.0` or `1.0` cannot be represented
/// as `T`; the lower bound is checked first. Errors from `ferrotorch_core::clamp` are
/// propagated.
fn clip_to_one<T: Float>(t: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let bound = |v: f64, message: &str| {
        T::from(v).ok_or_else(|| FerrotorchError::InvalidArgument {
            message: message.into(),
        })
    };
    let lower = bound(-1.0, "clip_to_one: cannot represent -1.0")?;
    let upper = bound(1.0, "clip_to_one: cannot represent 1.0")?;
    ferrotorch_core::clamp(t, lower, upper)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `[1, 1, 2, 2]` CPU `f32` tensor holding `values`.
    fn tensor_1x1x2x2(values: [f32; 4]) -> Tensor<f32> {
        Tensor::<f32>::from_storage(
            ferrotorch_core::TensorStorage::cpu(values.to_vec()),
            vec![1, 1, 2, 2],
            false,
        )
        .unwrap()
    }

    #[test]
    fn beta_schedule_scaled_linear_matches_diffusers_sd15() {
        let betas = compute_betas(BetaSchedule::ScaledLinear, 0.000_85, 0.012, 1000);
        assert_eq!(betas.len(), 1000);
        assert!((betas[0] - 0.000_85).abs() < 1e-12, "betas[0]={}", betas[0]);
        assert!(
            (betas[999] - 0.012).abs() < 1e-12,
            "betas[999]={}",
            betas[999]
        );
        // Index 499 sits at t = 499/999 ~ 0.5, so it should be close to the squared
        // midpoint of the square roots.
        let mid_root = (0.000_85_f64.sqrt() + 0.012_f64.sqrt()) * 0.5;
        let want = mid_root * mid_root;
        let approx_idx = 999 / 2;
        assert!(
            (betas[approx_idx] - want).abs() < 5e-3,
            "betas[{approx_idx}]={} vs midpoint {want}",
            betas[approx_idx]
        );
    }

    #[test]
    fn beta_schedule_linear_interpolates_endpoints() {
        let betas = compute_betas(BetaSchedule::Linear, 0.1, 0.3, 3);
        assert_eq!(betas.len(), 3);
        assert!((betas[0] - 0.1).abs() < 1e-12);
        assert!((betas[1] - 0.2).abs() < 1e-12);
        assert!((betas[2] - 0.3).abs() < 1e-12);
    }

    #[test]
    fn alphas_cumprod_is_monotone_decreasing() {
        let sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        let mut prev = 1.0_f64;
        for &a in &sched.alphas_cumprod {
            assert!(a < prev, "alphas_cumprod not strictly decreasing");
            assert!(a > 0.0, "alphas_cumprod must be > 0, got {a}");
            prev = a;
        }
        let last = *sched.alphas_cumprod.last().unwrap();
        assert!(
            (0.001..0.01).contains(&last),
            "SD-1.5 alphas_cumprod[-1] should be ~0.0047, got {last}"
        );
    }

    #[test]
    fn new_rejects_zero_train_timesteps() {
        let config = DDIMConfig {
            num_train_timesteps: 0,
            ..DDIMConfig::sd_v1_5()
        };
        assert!(DDIMScheduler::new(config).is_err());
    }

    #[test]
    fn timesteps_leading_4_steps_sd15() {
        let mut sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        let ts = sched.set_timesteps(4).unwrap();
        assert_eq!(ts, [751, 501, 251, 1]);
    }

    #[test]
    fn timesteps_leading_50_steps_sd15_head() {
        let mut sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        let ts = sched.set_timesteps(50).unwrap();
        assert_eq!(ts.len(), 50);
        assert_eq!(ts[0], 981);
        assert_eq!(ts[49], 1);
    }

    #[test]
    fn set_timesteps_rejects_out_of_range_step_counts() {
        let mut sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        assert!(sched.set_timesteps(0).is_err());
        assert!(sched.set_timesteps(1001).is_err());
        assert!(sched.timesteps().is_empty());
    }

    #[test]
    fn init_noise_sigma_is_one() {
        let sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        assert!((sched.init_noise_sigma() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn final_alpha_cumprod_is_alphas_cumprod_zero_when_set_alpha_to_one_false() {
        let sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        assert!((sched.final_alpha_cumprod - sched.alphas_cumprod[0]).abs() < 1e-12);
    }

    #[test]
    fn step_requires_set_timesteps() {
        let sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        let x = tensor_1x1x2x2([0.5; 4]);
        assert!(sched.step(&x, 1, &x).is_err());
    }

    #[test]
    fn step_rejects_out_of_range_timestep() {
        let mut sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        sched.set_timesteps(4).unwrap();
        let x = tensor_1x1x2x2([0.5; 4]);
        assert!(sched.step(&x, 1000, &x).is_err());
    }

    #[test]
    fn step_rejects_shape_mismatch() {
        let mut sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        sched.set_timesteps(4).unwrap();
        let sample = tensor_1x1x2x2([0.5; 4]);
        let flat = Tensor::<f32>::from_storage(
            ferrotorch_core::TensorStorage::cpu(vec![0.0_f32; 4]),
            vec![1, 4],
            false,
        )
        .unwrap();
        assert!(sched.step(&flat, 1, &sample).is_err());
    }

    #[test]
    fn step_with_zero_noise_rescales_sample() {
        let mut sched = DDIMScheduler::new(DDIMConfig::sd_v1_5()).unwrap();
        sched.set_timesteps(4).unwrap();
        let sample = tensor_1x1x2x2([0.5; 4]);
        let noise = tensor_1x1x2x2([0.0; 4]);
        let out = sched.step(&noise, 1, &sample).unwrap();
        assert_eq!(out.shape(), &[1, 1, 2, 2]);

        // With 4 inference steps the step ratio is 250, so from timestep 1 the previous
        // index is negative and alpha_prev = final_alpha_cumprod = alphas_cumprod[0].
        // With zero noise: out = sqrt(a0) * (0.5 / sqrt(a1)) = 0.5 * sqrt(a0 / a1).
        let a0 = sched.alphas_cumprod[0];
        let a1 = sched.alphas_cumprod[1];
        let expected = (0.5 * (a0 / a1).sqrt()) as f32;
        for v in out.data().unwrap() {
            assert!(v.is_finite(), "step produced non-finite value: {v}");
            assert!(
                (v - expected).abs() < 1e-5,
                "step produced {v}, expected {expected}"
            );
        }
    }
}