use anyhow::Context;
use ndarray::{s, Array1, Array4, ArrayView4};
use rand::Rng;
use super::{betas_for_alpha_bar, BetaSchedule, DiffusionScheduler, SchedulerStepOutput};
use crate::{SchedulerOptimizedDefaults, SchedulerPredictionType};
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub enum DPMSolverAlgorithmType {
#[default]
DPMSolverPlusPlus,
DPMSolver
}
#[derive(Default, Debug, Clone, PartialEq, Eq)]
#[allow(missing_docs)]
pub enum DPMSolverType {
#[default]
Midpoint,
Heun
}
#[derive(Debug, Clone)]
pub struct DPMSolverMultistepSchedulerConfig {
pub solver_order: usize,
pub thresholding: bool,
pub dynamic_thresholding_ratio: f32,
pub sample_max_value: f32,
pub algorithm_type: DPMSolverAlgorithmType,
pub solver_type: DPMSolverType,
pub lower_order_final: bool
}
impl Default for DPMSolverMultistepSchedulerConfig {
fn default() -> Self {
Self {
solver_order: 2,
thresholding: false,
dynamic_thresholding_ratio: 0.995,
sample_max_value: 1.0,
algorithm_type: DPMSolverAlgorithmType::DPMSolverPlusPlus,
solver_type: DPMSolverType::Midpoint,
lower_order_final: true
}
}
}
#[derive(Clone)]
pub struct DPMSolverMultistepScheduler {
alphas_cumprod: Array1<f32>,
alpha_t: Array1<f32>,
sigma_t: Array1<f32>,
lambda_t: Array1<f32>,
init_noise_sigma: f32,
timesteps: Array1<usize>,
num_train_timesteps: usize,
num_inference_steps: Option<usize>,
config: DPMSolverMultistepSchedulerConfig,
prediction_type: SchedulerPredictionType,
model_outputs: Vec<Option<Array4<f32>>>,
lower_order_nums: usize
}
impl Default for DPMSolverMultistepScheduler {
fn default() -> Self {
Self::new(1000, 0.0001, 0.02, &BetaSchedule::Linear, &SchedulerPredictionType::Epsilon, None).unwrap()
}
}
impl DPMSolverMultistepScheduler {
pub fn new(
num_train_timesteps: usize,
beta_start: f32,
beta_end: f32,
beta_schedule: &BetaSchedule,
prediction_type: &SchedulerPredictionType,
config: Option<DPMSolverMultistepSchedulerConfig>
) -> anyhow::Result<Self> {
if num_train_timesteps == 0 {
anyhow::bail!("num_train_timesteps ({num_train_timesteps}) must be >0");
}
if !beta_start.is_normal() || !beta_end.is_normal() {
anyhow::bail!("beta_start ({beta_start}) and beta_end ({beta_end}) must be normal (not zero, infinite, NaN, or subnormal)");
}
if beta_start >= beta_end {
anyhow::bail!("beta_start must be < beta_end");
}
let config = config.unwrap_or_default();
let betas = match beta_schedule {
BetaSchedule::TrainedBetas(betas) => betas.clone(),
BetaSchedule::Linear => Array1::linspace(beta_start, beta_end, num_train_timesteps),
BetaSchedule::ScaledLinear => {
let mut betas = Array1::linspace(beta_start.sqrt(), beta_end.sqrt(), num_train_timesteps);
betas.par_map_inplace(|f| *f = f.powi(2));
betas
}
BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(num_train_timesteps, 0.999),
_ => anyhow::bail!("{beta_schedule:?} not implemented for DDIMScheduler")
};
let alphas = 1.0 - betas;
let alphas_cumprod = alphas
.view()
.into_iter()
.scan(1.0, |prod, alpha| {
*prod *= *alpha;
Some(*prod)
})
.collect::<Array1<_>>();
let alpha_t = alphas_cumprod.map(|f| f.sqrt());
let sigma_t = alphas_cumprod.map(|f| (1.0 - f).sqrt());
let lambda_t = alpha_t.map(|f| f.log(std::f32::consts::E)) - sigma_t.map(|f| f.log(std::f32::consts::E));
let timesteps = Array1::linspace(0.0, num_train_timesteps as f32 - 1.0, num_train_timesteps)
.slice(s![..;-1])
.map(|f| *f as usize)
.to_owned();
let init_noise_sigma = 1.0;
Ok(Self {
alphas_cumprod,
alpha_t,
sigma_t,
lambda_t,
init_noise_sigma,
timesteps,
num_inference_steps: None,
num_train_timesteps,
prediction_type: *prediction_type,
config: config.clone(),
lower_order_nums: 0,
model_outputs: vec![None; config.solver_order]
})
}
fn convert_model_output(&self, model_output: Array4<f32>, timestep: usize, sample: ArrayView4<f32>) -> Array4<f32> {
match self.config.algorithm_type {
DPMSolverAlgorithmType::DPMSolverPlusPlus => {
let x0_pred = match self.prediction_type {
SchedulerPredictionType::Epsilon => {
let alpha_t = self.alpha_t[timestep];
let sigma_t = self.sigma_t[timestep];
(sample.to_owned() - sigma_t * model_output) / alpha_t
}
SchedulerPredictionType::Sample => model_output,
SchedulerPredictionType::VPrediction => {
let alpha_t = self.alpha_t[timestep];
let sigma_t = self.sigma_t[timestep];
alpha_t * sample.to_owned() - sigma_t * model_output
}
};
if self.config.thresholding {
todo!("thresholding not yet implemented for DPMSolverMultistepScheduler, please open an issue");
}
x0_pred
}
DPMSolverAlgorithmType::DPMSolver => match self.prediction_type {
SchedulerPredictionType::Epsilon => model_output,
SchedulerPredictionType::Sample => {
let alpha_t = self.alpha_t[timestep];
let sigma_t = self.sigma_t[timestep];
(sample.to_owned() - alpha_t * model_output) / sigma_t
}
SchedulerPredictionType::VPrediction => {
let alpha_t = self.alpha_t[timestep];
let sigma_t = self.sigma_t[timestep];
alpha_t * model_output + sigma_t * sample.to_owned()
}
}
}
}
fn dpm_solver_first_order_update(&self, model_output: Array4<f32>, timestep: usize, prev_timestep: usize, sample: ArrayView4<f32>) -> Array4<f32> {
let (lambda_t, lambda_s) = (self.lambda_t[prev_timestep], self.lambda_t[timestep]);
let (alpha_t, alpha_s) = (self.alpha_t[prev_timestep], self.alpha_t[timestep]);
let (sigma_t, sigma_s) = (self.sigma_t[prev_timestep], self.sigma_t[timestep]);
let h = lambda_t - lambda_s;
match self.config.algorithm_type {
DPMSolverAlgorithmType::DPMSolverPlusPlus => {
(sigma_t / sigma_s) * sample.to_owned() - (alpha_t * (std::f32::consts::E.powf(-h) - 1.0)) * model_output
}
DPMSolverAlgorithmType::DPMSolver => (alpha_t / alpha_s) * sample.to_owned() - (sigma_t * (std::f32::consts::E.powf(h) - 1.0)) * model_output
}
}
fn multistep_dpm_solver_second_order_update(
&self,
model_output_list: &Vec<Option<Array4<f32>>>,
timestep_list: [usize; 2],
prev_timestep: usize,
sample: ArrayView4<f32>
) -> Array4<f32> {
assert_eq!(timestep_list.len(), model_output_list.len());
let (t, s0, s1) = (prev_timestep, timestep_list[timestep_list.len() - 1], timestep_list[timestep_list.len() - 2]);
let (m0, m1) = (model_output_list[model_output_list.len() - 1].as_ref().unwrap(), model_output_list[model_output_list.len() - 2].as_ref().unwrap());
let (lambda_t, lambda_s0, lambda_s1) = (self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]);
let (alpha_t, alpha_s0) = (self.alpha_t[t], self.alpha_t[s0]);
let (sigma_t, sigma_s0) = (self.sigma_t[t], self.sigma_t[s0]);
let (h, h_0) = (lambda_t - lambda_s0, lambda_s0 - lambda_s1);
let r0 = h_0 / h;
let (d0, d1) = (m0, (1.0 / r0) * (m0 - m1));
match self.config.algorithm_type {
DPMSolverAlgorithmType::DPMSolverPlusPlus => match self.config.solver_type {
DPMSolverType::Midpoint => {
((sigma_t / sigma_s0) * sample.to_owned())
- (alpha_t * (std::f32::consts::E.powf(-h) - 1.0)) * d0
- 0.5 * (alpha_t * (std::f32::consts::E.powf(-h) - 1.0)) * d1
}
DPMSolverType::Heun => {
((sigma_t / sigma_s0) * sample.to_owned()) - (alpha_t * (std::f32::consts::E.powf(-h) - 1.0)) * d0
+ (alpha_t * ((std::f32::consts::E.powf(-h) - 1.0) / h + 1.0)) * d1
}
},
DPMSolverAlgorithmType::DPMSolver => match self.config.solver_type {
DPMSolverType::Midpoint => {
(alpha_t / alpha_s0) * sample.to_owned()
- (sigma_t * (std::f32::consts::E.powf(h) - 1.0)) * d0
- 0.5 * (sigma_t * (std::f32::consts::E.powf(h) - 1.0)) * d1
}
DPMSolverType::Heun => {
(alpha_t / alpha_s0) * sample.to_owned()
- (sigma_t * (std::f32::consts::E.powf(h) - 1.0)) * d0
- (sigma_t * ((std::f32::consts::E.powf(h) - 1.0) / h - 1.0)) * d1
}
}
}
}
fn multistep_dpm_solver_third_order_update(
&self,
model_output_list: &Vec<Option<Array4<f32>>>,
timestep_list: [usize; 3],
prev_timestep: usize,
sample: ArrayView4<f32>
) -> Array4<f32> {
assert_eq!(timestep_list.len(), model_output_list.len());
let (t, s0, s1, s2) =
(prev_timestep, timestep_list[timestep_list.len() - 1], timestep_list[timestep_list.len() - 2], timestep_list[timestep_list.len() - 3]);
let (m0, m1, m2) = (
model_output_list[model_output_list.len() - 1].as_ref().unwrap(),
model_output_list[model_output_list.len() - 2].as_ref().unwrap(),
model_output_list[model_output_list.len() - 3].as_ref().unwrap()
);
let (lambda_t, lambda_s0, lambda_s1, lambda_s2) = (self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1], self.lambda_t[s2]);
let (alpha_t, alpha_s0) = (self.alpha_t[t], self.alpha_t[s0]);
let (sigma_t, sigma_s0) = (self.sigma_t[t], self.sigma_t[s0]);
let (h, h_0, h_1) = (lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2);
let (r0, r1) = (h_0 / h, h_1 / h);
let d0 = m0;
let (d1_0, d1_1) = ((1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2));
let d1 = d1_0.clone() + (r0 / (r0 + r1)) * (&d1_0 - &d1_1);
let d2 = (1.0 / (r0 + r1)) * (d1_0 - d1_1);
match self.config.algorithm_type {
DPMSolverAlgorithmType::DPMSolverPlusPlus => {
(sigma_t / sigma_s0) * sample.to_owned() - (alpha_t * (std::f32::consts::E.powf(-h) - 1.0)) * d0
+ (alpha_t * ((std::f32::consts::E.powf(-h) - 1.0) / h + 1.0)) * d1
- (alpha_t * ((std::f32::consts::E.powf(-h) - 1.0 + h) / h.powi(2) - 0.5)) * d2
}
DPMSolverAlgorithmType::DPMSolver => {
(alpha_t / alpha_s0) * sample.to_owned()
- (sigma_t * (std::f32::consts::E.powf(h) - 1.0)) * d0
- (sigma_t * ((std::f32::consts::E.powf(h) - 1.0) / h - 1.0)) * d1
- (sigma_t * ((std::f32::consts::E.powf(h) - 1.0 - h) / h.powi(2) - 0.5)) * d2
}
}
}
}
impl DiffusionScheduler for DPMSolverMultistepScheduler {
type TimestepType = usize;
fn order() -> usize {
1
}
fn scale_model_input(&mut self, sample: ArrayView4<'_, f32>, _: usize) -> Array4<f32> {
sample.to_owned()
}
fn set_timesteps(&mut self, num_inference_steps: usize) {
self.num_inference_steps = Some(num_inference_steps);
let timesteps = Array1::linspace(0.0, (self.num_train_timesteps - 1) as f32, num_inference_steps + 1)
.slice(s![..;-1])
.slice(s![..num_inference_steps])
.map(|f| f.round() as usize)
.to_owned();
self.timesteps = timesteps;
self.model_outputs = vec![None; self.config.solver_order as _];
self.lower_order_nums = 0;
}
fn step<R: Rng + ?Sized>(&mut self, model_output: ArrayView4<'_, f32>, timestep: usize, sample: ArrayView4<'_, f32>, _: &mut R) -> SchedulerStepOutput {
let step_index = self
.timesteps
.iter()
.position(|&p| p == timestep)
.with_context(|| format!("timestep out of this schedulers bounds: {timestep}"))
.unwrap();
let prev_timestep = if step_index == self.timesteps.len() - 1 { 0 } else { self.timesteps[step_index + 1] };
let lower_order_final = (step_index == self.timesteps.len() - 1) && self.config.lower_order_final && self.timesteps.len() < 15;
let lower_order_second = (step_index == self.timesteps.len() - 2) && self.config.lower_order_final && self.timesteps.len() < 15;
let model_output = self.convert_model_output(model_output.to_owned(), timestep, sample);
for i in 0..self.config.solver_order - 1 {
self.model_outputs[i] = self.model_outputs[i + 1].clone();
}
let m_len = self.model_outputs.len();
self.model_outputs[m_len - 1] = Some(model_output.clone());
let prev_sample = if self.config.solver_order == 1 || self.lower_order_nums < 1 || lower_order_final {
self.dpm_solver_first_order_update(model_output, timestep, prev_timestep, sample)
} else if self.config.solver_order == 2 || self.lower_order_nums < 2 || lower_order_second {
let timestep_list = [self.timesteps[step_index - 1], timestep];
self.multistep_dpm_solver_second_order_update(&self.model_outputs, timestep_list, prev_timestep, sample)
} else {
let timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep];
self.multistep_dpm_solver_third_order_update(&self.model_outputs, timestep_list, prev_timestep, sample)
};
SchedulerStepOutput { prev_sample, ..Default::default() }
}
fn add_noise(&mut self, original_samples: ArrayView4<'_, f32>, noise: ArrayView4<'_, f32>, timestep: usize) -> Array4<f32> {
self.alphas_cumprod[timestep].sqrt() * original_samples.to_owned() + (1.0 - self.alphas_cumprod[timestep]).sqrt() * noise.to_owned()
}
fn timesteps(&self) -> ndarray::ArrayView1<'_, usize> {
self.timesteps.view()
}
fn init_noise_sigma(&self) -> f32 {
self.init_noise_sigma
}
fn len(&self) -> usize {
self.num_train_timesteps
}
}
impl SchedulerOptimizedDefaults for DPMSolverMultistepScheduler {
fn stable_diffusion_v1_optimized_default() -> anyhow::Result<Self>
where
Self: Sized
{
Self::new(
1000,
0.00085,
0.012,
&BetaSchedule::ScaledLinear,
&SchedulerPredictionType::Epsilon,
Some(DPMSolverMultistepSchedulerConfig {
solver_order: 2,
thresholding: false,
dynamic_thresholding_ratio: 0.995,
sample_max_value: 1.0,
algorithm_type: DPMSolverAlgorithmType::DPMSolverPlusPlus,
solver_type: DPMSolverType::Midpoint,
lower_order_final: true
})
)
}
}