use super::{betas_for_alpha_bar, BetaSchedule, PredictionType};
use tch::{kind, Kind, Tensor};
#[derive(Debug, Clone)]
pub struct PNDMSchedulerConfig {
pub beta_start: f64,
pub beta_end: f64,
pub beta_schedule: BetaSchedule,
pub set_alpha_to_one: bool,
pub prediction_type: PredictionType,
pub steps_offset: usize,
pub train_timesteps: usize,
}
impl Default for PNDMSchedulerConfig {
fn default() -> Self {
Self {
beta_start: 0.00085,
beta_end: 0.012,
beta_schedule: BetaSchedule::ScaledLinear,
set_alpha_to_one: false,
prediction_type: PredictionType::Epsilon,
steps_offset: 1,
train_timesteps: 1000,
}
}
}
pub struct PNDMScheduler {
alphas_cumprod: Vec<f64>,
final_alpha_cumprod: f64,
step_ratio: usize,
init_noise_sigma: f64,
counter: usize,
cur_sample: Option<Tensor>,
ets: Vec<Tensor>,
timesteps: Vec<usize>,
pub config: PNDMSchedulerConfig,
}
impl PNDMScheduler {
pub fn new(inference_steps: usize, config: PNDMSchedulerConfig) -> Self {
let betas = match config.beta_schedule {
BetaSchedule::ScaledLinear => Tensor::linspace(
config.beta_start.sqrt(),
config.beta_end.sqrt(),
config.train_timesteps as i64,
kind::FLOAT_CPU,
)
.square(),
BetaSchedule::Linear => Tensor::linspace(
config.beta_start,
config.beta_end,
config.train_timesteps as i64,
kind::FLOAT_CPU,
),
BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999),
};
let alphas: Tensor = 1. - betas;
let alphas_cumprod = Vec::<f64>::try_from(alphas.cumprod(0, Kind::Double)).unwrap();
let final_alpha_cumprod = if config.set_alpha_to_one { 1.0 } else { alphas_cumprod[0] };
let step_ratio = config.train_timesteps / inference_steps;
let timesteps: Vec<usize> =
(0..(inference_steps)).map(|s| s * step_ratio + config.steps_offset).collect();
let n_ts = timesteps.len();
let plms_timesteps =
[×teps[..n_ts - 2], &[timesteps[n_ts - 2]], ×teps[n_ts - 2..]]
.concat()
.into_iter()
.rev()
.collect();
Self {
alphas_cumprod,
final_alpha_cumprod,
step_ratio,
init_noise_sigma: 1.0,
counter: 0,
cur_sample: None,
ets: vec![],
timesteps: plms_timesteps,
config,
}
}
pub fn timesteps(&self) -> &[usize] {
self.timesteps.as_slice()
}
pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor {
sample
}
pub fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor {
self.step_plms(model_output, timestep, sample)
}
fn step_plms(&mut self, model_output: &Tensor, mut timestep: usize, sample: &Tensor) -> Tensor {
let mut prev_timestep = timestep as isize - self.step_ratio as isize;
if self.counter != 1 {
if self.ets.len() > 3 {
self.ets.drain(0..self.ets.len() - 3);
}
self.ets.push(model_output.shallow_clone());
} else {
prev_timestep = timestep as isize;
timestep += self.step_ratio;
}
let (ets_last, n_ets) = (self.ets.last().unwrap(), self.ets.len());
let (mut model_output, mut sample) = (model_output.shallow_clone(), sample.shallow_clone());
if n_ets == 1 && self.counter == 0 {
self.cur_sample = Some(sample.shallow_clone());
} else if n_ets == 1 && self.counter == 1 {
sample = self.cur_sample.as_ref().unwrap().shallow_clone();
self.cur_sample = None;
model_output = (model_output + ets_last) / 2.;
} else if n_ets == 2 {
model_output = (3. * ets_last - &self.ets[n_ets - 2]) / 2.;
} else if n_ets == 3 {
model_output =
(23. * ets_last - 16. * &self.ets[n_ets - 2] + 5. * &self.ets[n_ets - 3]) / 12.;
} else {
model_output = (1. / 24.)
* (55. * ets_last - 59. * &self.ets[n_ets - 2] + 37. * &self.ets[n_ets - 3]
- 9. * &self.ets[n_ets - 4]);
}
let prev_sample = self.get_prev_sample(sample, timestep, prev_timestep, model_output);
self.counter += 1;
prev_sample
}
fn get_prev_sample(
&self,
sample: Tensor,
timestep: usize,
prev_timestep: isize,
model_output: Tensor,
) -> Tensor {
let alpha_prod_t = self.alphas_cumprod[timestep];
let alpha_prod_t_prev = if prev_timestep >= 0 {
self.alphas_cumprod[prev_timestep as usize]
} else {
self.final_alpha_cumprod
};
let beta_prod_t = 1. - alpha_prod_t;
let beta_prod_t_prev = 1. - alpha_prod_t_prev;
let model_output = match self.config.prediction_type {
PredictionType::VPrediction => {
alpha_prod_t.sqrt() * model_output + beta_prod_t.sqrt() * &sample
}
PredictionType::Epsilon => model_output.shallow_clone(),
_ => unimplemented!("Prediction type must be one of `epsilon` or `v_prediction"),
};
let sample_coeff = (alpha_prod_t_prev / alpha_prod_t).sqrt();
let model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev.sqrt()
+ (alpha_prod_t * beta_prod_t * alpha_prod_t_prev).sqrt();
sample_coeff * sample
- (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff
}
pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Tensor {
let timestep = if timestep >= self.alphas_cumprod.len() { timestep - 1 } else { timestep };
let sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt();
let sqrt_one_minus_alpha_prod = (1.0 - self.alphas_cumprod[timestep]).sqrt();
sqrt_alpha_prod * original + sqrt_one_minus_alpha_prod * noise
}
pub fn init_noise_sigma(&self) -> f64 {
self.init_noise_sigma
}
}