use super::{interp, BetaSchedule, PredictionType};
use tch::{kind, IndexOp, Kind, Tensor};
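/// The configuration for the KDPM2 ancestral discrete scheduler. The default
/// values match those commonly used with Stable Diffusion.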
#[derive(Debug, Clone)]
pub struct KDPM2AncestralDiscreteSchedulerConfig {
    /// The value of beta at the beginning of training.
    pub beta_start: f64,
    /// The value of beta at the end of training.
    pub beta_end: f64,
    /// How beta evolved during training.
    pub beta_schedule: BetaSchedule,
    /// Number of diffusion steps used to train the model.
    pub train_timesteps: usize,
    /// Prediction type of the scheduler function.
    pub prediction_type: PredictionType,
}
impl Default for KDPM2AncestralDiscreteSchedulerConfig {
fn default() -> Self {
Self {
            beta_start: 0.00085,
            beta_end: 0.012,
beta_schedule: BetaSchedule::ScaledLinear,
train_timesteps: 1000,
prediction_type: PredictionType::Epsilon,
}
}
}
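/// 2nd-order ancestral scheduler ported from the `sample_dpm_2_ancestral`
/// sampler in k-diffusion, inspired by DPM-Solver-2 and Algorithm 2 of
/// Karras et al. (2022). The timestep schedule is interleaved with log-space
/// midpoints, so `step` alternates between a first-order half step to a
/// midpoint sigma and a second-order correction that injects ancestral noise.
///
/// A minimal sketch of a sampling loop (`unet` stands in for a denoising
/// model and is not part of this crate):
///
/// ```ignore
/// let mut scheduler = KDPM2AncestralDiscreteScheduler::new(50, Default::default());
/// // Start from pure noise scaled to the schedule's initial sigma.
/// let latents = Tensor::randn(&[1, 4, 64, 64], kind::FLOAT_CPU);
/// let mut sample = latents * scheduler.init_noise_sigma();
/// for t in scheduler.timesteps().to_vec() {
///     let input = scheduler.scale_model_input(sample.shallow_clone(), t);
///     let model_output = unet(&input, t); // hypothetical denoising model
///     sample = scheduler.step(&model_output, t, &sample);
/// }
/// ```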
pub struct KDPM2AncestralDiscreteScheduler {
timesteps: Vec<f64>,
sigmas: Vec<f64>,
sigmas_interpol: Vec<f64>,
sigmas_up: Vec<f64>,
sigmas_down: Vec<f64>,
init_noise_sigma: f64,
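    /// Sample buffered between the first- and second-order half steps.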
sample: Option<Tensor>,
pub config: KDPM2AncestralDiscreteSchedulerConfig,
}
impl KDPM2AncestralDiscreteScheduler {
pub fn new(inference_steps: usize, config: KDPM2AncestralDiscreteSchedulerConfig) -> Self {
let betas = match config.beta_schedule {
BetaSchedule::ScaledLinear => Tensor::linspace(
config.beta_start.sqrt(),
config.beta_end.sqrt(),
config.train_timesteps as i64,
kind::FLOAT_CPU,
)
.square(),
BetaSchedule::Linear => Tensor::linspace(
config.beta_start,
config.beta_end,
config.train_timesteps as i64,
kind::FLOAT_CPU,
),
_ => unimplemented!(
"KDPM2AncestralDiscreteScheduler only implements linear and scaled_linear betas."
),
};
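        // Standard DDPM quantities: alpha_t = 1 - beta_t and the cumulative
        // product of the alphas.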
let alphas: Tensor = 1. - betas;
let alphas_cumprod = alphas.cumprod(0, Kind::Double);
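        // Inference timesteps, evenly spaced from the last training timestep
        // down to 0.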
let timesteps = Tensor::linspace(
(config.train_timesteps - 1) as f64,
0.,
inference_steps as i64,
kind::FLOAT_CPU,
);
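        // Karras-style sigmas: sigma_t = sqrt((1 - alpha_cumprod_t) / alpha_cumprod_t).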
let sigmas = ((1. - &alphas_cumprod) as Tensor / alphas_cumprod).sqrt();
let log_sigmas = sigmas.log();
        let sigmas = interp(
            &timesteps,
            Tensor::range(0, sigmas.size1().unwrap() - 1, kind::FLOAT_CPU),
            sigmas,
        );
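        // Append a final sigma of 0 so the last step denoises completely.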
let sigmas = Tensor::concat(&[sigmas, [0.0].as_slice().into()], 0);
let sz = sigmas.size1().unwrap();
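        // Ancestral decomposition: split each transition variance into a
        // deterministic part (sigma_down) and a stochastic part (sigma_up)
        // such that sigma_down^2 + sigma_up^2 = sigma_next^2.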
let sigmas_next = sigmas.roll([-1], [0]);
let sigmas_next = sigmas_next.index_fill(0, &[sz - 1].as_slice().into(), 0.0);
let sigmas_up = (sigmas_next.square() * (sigmas.square() - sigmas_next.square())
/ sigmas.square())
.sqrt();
let sigmas_down = (sigmas_next.square() - sigmas_up.square()).sqrt();
let sigmas_down = sigmas_down.index_fill(0, &[sz - 1].as_slice().into(), 0.0);
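        // Interpolated sigmas: the log-space midpoint (geometric mean) of
        // each sigma and its sigma_down, used for the second-order half step.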
let sigmas_interpol = sigmas.log().lerp(&sigmas_down.log(), 0.5).exp();
let sigmas_interpol =
sigmas_interpol.index_fill(0, &[sz - 2, sz - 1].as_slice().into(), 0.0);
let timesteps_interpol = Self::sigma_to_t(&sigmas_interpol, log_sigmas);
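        // Interleave the interpolated timesteps with the original ones so the
        // solver visits a midpoint before completing each timestep.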
let interleaved_timesteps = Tensor::stack(
&[
timesteps_interpol.slice(0, 0, -2, 1).unsqueeze(-1),
timesteps.i(1..).unsqueeze(-1),
],
-1,
)
.flatten(0, -1);
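        // Duplicate the interior entries so there is one sigma per solver
        // half step.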
let sigmas = Tensor::cat(
&[
sigmas.i(..1),
sigmas.i(1..).repeat_interleave_self_int(2, 0, None),
                sigmas.i(-1..),
],
0,
);
let sigmas_interpol = Tensor::cat(
&[
sigmas_interpol.i(..1),
sigmas_interpol.i(1..).repeat_interleave_self_int(2, 0, None),
                sigmas_interpol.i(-1..),
],
0,
);
let sigmas_up = Tensor::cat(
&[
sigmas_up.i(..1),
sigmas_up.i(1..).repeat_interleave_self_int(2, 0, None),
                sigmas_up.i(-1..),
],
0,
);
let sigmas_down = Tensor::cat(
&[
sigmas_down.i(..1),
sigmas_down.i(1..).repeat_interleave_self_int(2, 0, None),
                sigmas_down.i(-1..),
],
0,
);
        let timesteps = Tensor::cat(&[timesteps.i(..1), interleaved_timesteps], 0);
let init_noise_sigma: f64 = sigmas.max().try_into().unwrap();
Self {
timesteps: timesteps.try_into().unwrap(),
sigmas: sigmas.try_into().unwrap(),
sigmas_interpol: sigmas_interpol.try_into().unwrap(),
sigmas_up: sigmas_up.try_into().unwrap(),
sigmas_down: sigmas_down.try_into().unwrap(),
init_noise_sigma,
sample: None,
config,
}
}
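    /// Converts sigmas back to fractional timesteps by piecewise-linear
    /// interpolation in log-sigma space.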
fn sigma_to_t(sigma: &Tensor, log_sigmas: Tensor) -> Tensor {
let log_sigma = sigma.log();
let dists = &log_sigma - log_sigmas.unsqueeze(-1);
let low_idx = dists
.ge(0)
.cumsum(0, Kind::Int64)
.argmax(0, false)
.clamp_max(log_sigmas.size1().unwrap() - 2);
let high_idx = &low_idx + 1;
let low = log_sigmas.index_select(0, &low_idx);
let high = log_sigmas.index_select(0, &high_idx);
let w = (&low - log_sigma) / (low - high);
let w = w.clamp(0., 1.);
let t: Tensor = (1 - &w) * low_idx + w * high_idx;
t.view(sigma.size().as_slice())
}
pub fn timesteps(&self) -> &[f64] {
self.timesteps.as_slice()
}
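    // A timestep value can occur twice in the interleaved schedule; pick the
    // occurrence matching the solver's current order.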
fn index_for_timestep(&self, timestep: f64) -> usize {
let indices = self
.timesteps
.iter()
.enumerate()
.filter_map(|(idx, &t)| (t == timestep).then_some(idx))
.collect::<Vec<_>>();
if self.state_in_first_order() {
*indices.last().unwrap()
} else {
indices[0]
}
}
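    /// Scales the denoising model input by `1 / (sigma^2 + 1).sqrt()` for the
    /// current timestep, matching the k-diffusion input scaling.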
pub fn scale_model_input(&self, sample: Tensor, timestep: f64) -> Tensor {
let step_index = self.index_for_timestep(timestep);
let step_index_minus_one =
if step_index == 0 { self.sigmas.len() - 1 } else { step_index - 1 };
let sigma = if self.state_in_first_order() {
self.sigmas[step_index]
} else {
self.sigmas_interpol[step_index_minus_one]
};
sample / (sigma.powi(2) + 1.).sqrt()
}
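    // The solver is in its first-order state when no sample is buffered from
    // a previous half step.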
fn state_in_first_order(&self) -> bool {
self.sample.is_none()
}
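    /// Predicts the previous sample from the model output. The first call for
    /// a timestep performs the first-order half step and buffers the sample;
    /// the second applies the second-order correction and adds the ancestral
    /// noise.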
pub fn step(&mut self, model_output: &Tensor, timestep: f64, sample: &Tensor) -> Tensor {
let step_index = self.index_for_timestep(timestep);
let step_index_minus_one =
if step_index == 0 { self.sigmas.len() - 1 } else { step_index - 1 };
let (sigma, sigma_interpol, sigma_up, sigma_down) = if self.state_in_first_order() {
(
self.sigmas[step_index],
self.sigmas_interpol[step_index],
self.sigmas_up[step_index],
self.sigmas_down[step_index_minus_one],
)
} else {
(
self.sigmas[step_index_minus_one],
self.sigmas_interpol[step_index_minus_one],
self.sigmas_up[step_index_minus_one],
self.sigmas_down[step_index_minus_one],
)
};
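        // gamma is always 0 for this scheduler (no churn), so sigma_hat == sigma.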
let gamma = 0.0;
let sigma_hat = sigma * (gamma + 1.);
let noise = model_output.randn_like();
let sigma_input = if self.state_in_first_order() { sigma_hat } else { sigma_interpol };
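        // Convert the model output into a prediction of the fully denoised
        // sample, according to the parameterization used at training time.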
let pred_original_sample = match self.config.prediction_type {
PredictionType::Epsilon => sample - sigma_input * model_output,
PredictionType::VPrediction => {
model_output * (-sigma_input / (sigma_input.powi(2) + 1.).sqrt())
+ (sample / (sigma_input.powi(2) + 1.))
}
_ => unimplemented!("Prediction type must be one of `epsilon` or `v_prediction`"),
};
        if self.state_in_first_order() {
            // First half step: Euler step from sigma_hat to the interpolated
            // sigma, buffering the current sample for the second-order pass.
            let derivative = (sample - pred_original_sample) / sigma_hat;
            let dt = sigma_interpol - sigma_hat;
            self.sample = Some(sample.shallow_clone());
            sample + derivative * dt
        } else {
            // Second half step: re-derive the slope at the interpolated sigma,
            // step the buffered sample down to sigma_down, then add the
            // ancestral noise scaled by sigma_up.
            let derivative = (sample - pred_original_sample) / sigma_interpol;
            let dt = sigma_down - sigma_hat;
            let sample = self.sample.take().unwrap();
            sample + derivative * dt + noise * sigma_up
        }
}
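    /// The standard deviation of the initial noise distribution (the largest
    /// sigma in the schedule).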
pub fn init_noise_sigma(&self) -> f64 {
self.init_noise_sigma
}
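    /// Adds noise scaled by the sigma for `timestep` to `original_samples`,
    /// e.g. to prepare latents for image-to-image sampling.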
pub fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: f64) -> Tensor {
let step_index = self.index_for_timestep(timestep);
let sigma = self.sigmas[step_index];
original_samples + noise * sigma
}
}