use super::{betas_for_alpha_bar, BetaSchedule, PredictionType};
use tch::{kind, Kind, Tensor};
/// Selects which variance-related quantity the DDPM scheduler computes for the noise
/// it injects at each denoising step.
///
/// The enum carries no data, so it is `Copy` and `Hash` and can be passed by value or
/// used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DDPMVarianceType {
    /// `(1 - alpha_prod_prev) / (1 - alpha_prod_t) * beta_t`, clamped below at `1e-20`.
    FixedSmall,
    /// Square root of the clamped `FixedSmall` variance, computed through its logarithm.
    FixedSmallLog,
    /// `beta_t` itself.
    FixedLarge,
    /// Natural logarithm of `beta_t`.
    FixedLargeLog,
    /// Unclamped `(1 - alpha_prod_prev) / (1 - alpha_prod_t) * beta_t`.
    Learned,
}

impl Default for DDPMVarianceType {
    /// Defaults to [`DDPMVarianceType::FixedSmall`].
    fn default() -> Self {
        Self::FixedSmall
    }
}
/// Settings that define the beta schedule and step behaviour of [`DDPMScheduler`].
#[derive(Debug, Clone)]
pub struct DDPMSchedulerConfig {
    /// Beta at the first training timestep (not used by `SquaredcosCapV2`).
    pub beta_start: f64,
    /// Beta at the last training timestep (not used by `SquaredcosCapV2`).
    pub beta_end: f64,
    /// Curve used to interpolate the betas over `train_timesteps` points.
    pub beta_schedule: BetaSchedule,
    /// When `true`, the predicted original sample is clamped to `[-1, 1]` in each step.
    pub clip_sample: bool,
    /// Which variance quantity is used for the per-step noise.
    pub variance_type: DDPMVarianceType,
    /// How the model output is interpreted when reconstructing the original sample.
    pub prediction_type: PredictionType,
    /// Number of training timesteps the beta schedule is evaluated over.
    pub train_timesteps: usize,
}

impl Default for DDPMSchedulerConfig {
    /// 1000 training steps, scaled-linear betas from `0.00085` to `0.012`,
    /// epsilon prediction, `FixedSmall` variance and no sample clipping.
    fn default() -> Self {
        DDPMSchedulerConfig {
            train_timesteps: 1000,
            beta_schedule: BetaSchedule::ScaledLinear,
            beta_start: 0.00085,
            beta_end: 0.012,
            prediction_type: PredictionType::Epsilon,
            variance_type: DDPMVarianceType::FixedSmall,
            clip_sample: false,
        }
    }
}
/// Denoising Diffusion Probabilistic Model scheduler.
///
/// It holds the precomputed cumulative alpha products and the descending list of
/// inference timesteps. Every field derives `Debug` and `Clone`, so the scheduler does too.
#[derive(Debug, Clone)]
pub struct DDPMScheduler {
    /// Cumulative products of `1 - beta_t`, indexed by training timestep.
    alphas_cumprod: Vec<f64>,
    /// Scale applied to the initial noise; exposed through `init_noise_sigma()`.
    init_noise_sigma: f64,
    /// Inference timesteps, in descending order.
    timesteps: Vec<usize>,
    /// Distance between consecutive inference timesteps, in training timesteps.
    step_ratio: usize,
    /// Configuration the scheduler was built from.
    pub config: DDPMSchedulerConfig,
}
impl DDPMScheduler {
pub fn new(inference_steps: usize, config: DDPMSchedulerConfig) -> Self {
let betas = match config.beta_schedule {
BetaSchedule::ScaledLinear => Tensor::linspace(
config.beta_start.sqrt(),
config.beta_end.sqrt(),
config.train_timesteps as i64,
kind::FLOAT_CPU,
)
.square(),
BetaSchedule::Linear => Tensor::linspace(
config.beta_start,
config.beta_end,
config.train_timesteps as i64,
kind::FLOAT_CPU,
),
BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999),
};
let alphas: Tensor = 1. - betas;
let alphas_cumprod = Vec::<f64>::try_from(alphas.cumprod(0, Kind::Double)).unwrap();
let inference_steps = inference_steps.min(config.train_timesteps);
let step_ratio = config.train_timesteps / inference_steps;
let timesteps: Vec<usize> = (0..inference_steps).map(|s| s * step_ratio).rev().collect();
Self { alphas_cumprod, init_noise_sigma: 1.0, timesteps, step_ratio, config }
}
fn get_variance(&self, timestep: usize) -> f64 {
let prev_t = timestep as isize - self.step_ratio as isize;
let alpha_prod_t = self.alphas_cumprod[timestep];
let alpha_prod_t_prev =
if prev_t >= 0 { self.alphas_cumprod[prev_t as usize] } else { 1.0 };
let current_beta_t = 1. - alpha_prod_t / alpha_prod_t_prev;
let variance = (1. - alpha_prod_t_prev) / (1. - alpha_prod_t) * current_beta_t;
match self.config.variance_type {
DDPMVarianceType::FixedSmall => variance.max(1e-20),
DDPMVarianceType::FixedSmallLog => {
let variance = variance.max(1e-20).ln();
(variance * 0.5).exp()
}
DDPMVarianceType::FixedLarge => current_beta_t,
DDPMVarianceType::FixedLargeLog => current_beta_t.ln(),
DDPMVarianceType::Learned => variance,
}
}
pub fn timesteps(&self) -> &[usize] {
self.timesteps.as_slice()
}
pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor {
sample
}
pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor {
let prev_t = timestep as isize - self.step_ratio as isize;
let alpha_prod_t = self.alphas_cumprod[timestep];
let alpha_prod_t_prev =
if prev_t >= 0 { self.alphas_cumprod[prev_t as usize] } else { 1.0 };
let beta_prod_t = 1. - alpha_prod_t;
let beta_prod_t_prev = 1. - alpha_prod_t_prev;
let current_alpha_t = alpha_prod_t / alpha_prod_t_prev;
let current_beta_t = 1. - current_alpha_t;
let mut pred_original_sample = match self.config.prediction_type {
PredictionType::Epsilon => {
(sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
}
PredictionType::Sample => model_output.shallow_clone(),
PredictionType::VPrediction => {
alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
}
};
if self.config.clip_sample {
pred_original_sample = pred_original_sample.clamp(-1., 1.);
}
let pred_original_sample_coeff = (alpha_prod_t_prev.sqrt() * current_beta_t) / beta_prod_t;
let current_sample_coeff = current_alpha_t.sqrt() * beta_prod_t_prev / beta_prod_t;
let pred_prev_sample =
pred_original_sample_coeff * &pred_original_sample + current_sample_coeff * sample;
let mut variance = model_output.zeros_like();
if timestep > 0 {
let variance_noise = model_output.randn_like();
if self.config.variance_type == DDPMVarianceType::FixedSmallLog {
variance = self.get_variance(timestep) * variance_noise;
} else {
variance = self.get_variance(timestep).sqrt() * variance_noise;
}
}
&pred_prev_sample + variance
}
pub fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor {
self.alphas_cumprod[timestep].sqrt() * original_samples
+ (1. - self.alphas_cumprod[timestep]).sqrt() * noise
}
pub fn init_noise_sigma(&self) -> f64 {
self.init_noise_sigma
}
}