use candle::{Result, Tensor};
/// Deserializable configuration for the flow-match Euler scheduler.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct SchedulerConfig {
    /// Number of discretization steps used at training time (serde default: 1000).
    #[serde(default = "default_num_train_timesteps")]
    pub num_train_timesteps: usize,
    /// Static sigma-shift factor applied when dynamic shifting is off
    /// (serde default: 3.0).
    #[serde(default = "default_shift")]
    pub shift: f64,
    /// When true, the static `shift` is skipped at construction and a
    /// `mu`-dependent shift is applied in `set_timesteps` instead
    /// (serde default: false).
    #[serde(default)]
    pub use_dynamic_shifting: bool,
}
/// Serde default for [`SchedulerConfig::num_train_timesteps`].
fn default_num_train_timesteps() -> usize {
    1_000
}
/// Serde default for [`SchedulerConfig::shift`].
fn default_shift() -> f64 {
    3.0
}
impl Default for SchedulerConfig {
fn default() -> Self {
Self {
num_train_timesteps: default_num_train_timesteps(),
shift: default_shift(),
use_dynamic_shifting: false,
}
}
}
impl SchedulerConfig {
pub fn z_image_turbo() -> Self {
Self {
num_train_timesteps: 1000,
shift: 3.0,
use_dynamic_shifting: false,
}
}
}
/// Euler-discrete sampler state for flow-matching models.
///
/// Holds the sigma/timestep schedule plus a cursor into it; call
/// `set_timesteps` before sampling and `step` once per denoising iteration.
#[derive(Debug, Clone)]
pub struct FlowMatchEulerDiscreteScheduler {
    /// Configuration this scheduler was constructed from.
    pub config: SchedulerConfig,
    /// Timesteps (sigma * num_train_timesteps), in descending order.
    pub timesteps: Vec<f64>,
    /// Noise levels, descending; after `set_timesteps` this has one more
    /// entry than `timesteps` (a trailing terminal 0.0).
    pub sigmas: Vec<f64>,
    /// Smallest sigma of the training schedule (set in `new`, not updated by
    /// `set_timesteps`).
    pub sigma_min: f64,
    /// Largest sigma of the training schedule (set in `new`, not updated by
    /// `set_timesteps`).
    pub sigma_max: f64,
    // Cursor into the schedule: index of the next step to take.
    step_index: usize,
}
impl FlowMatchEulerDiscreteScheduler {
    /// Builds a scheduler over the full training discretization.
    ///
    /// Sigmas start near 1.0 and decrease to `1 / num_train_timesteps`. When
    /// `use_dynamic_shifting` is disabled the static `shift` transform is
    /// applied here; otherwise shifting is deferred to `set_timesteps`, where
    /// the resolution-dependent `mu` becomes available.
    pub fn new(config: SchedulerConfig) -> Self {
        let num_train_timesteps = config.num_train_timesteps;
        let shift = config.shift;
        // Descending training timesteps: num_train_timesteps, ..., 1.
        let timesteps: Vec<f64> = (1..=num_train_timesteps).rev().map(|t| t as f64).collect();
        let sigmas: Vec<f64> = timesteps
            .iter()
            .map(|&t| t / num_train_timesteps as f64)
            .collect();
        let sigmas: Vec<f64> = if !config.use_dynamic_shifting {
            // Static shift: s -> shift * s / (1 + (shift - 1) * s).
            sigmas
                .iter()
                .map(|&s| shift * s / (1.0 + (shift - 1.0) * s))
                .collect()
        } else {
            sigmas
        };
        // Re-derive timesteps from the (possibly shifted) sigmas so the two
        // sequences stay consistent.
        let timesteps: Vec<f64> = sigmas
            .iter()
            .map(|&s| s * num_train_timesteps as f64)
            .collect();
        let sigma_max = sigmas[0];
        let sigma_min = *sigmas.last().unwrap_or(&0.0);
        Self {
            config,
            timesteps,
            sigmas,
            sigma_min,
            sigma_max,
            step_index: 0,
        }
    }

    /// Prepares an inference schedule of `num_inference_steps` steps and
    /// resets the step cursor.
    ///
    /// `mu` is only consulted when `config.use_dynamic_shifting` is set; it is
    /// the resolution-dependent log-shift (see [`calculate_shift`]).
    ///
    /// NOTE(review): in the static path the endpoints come from `self.sigmas`,
    /// which were already shifted in `new`, and the shift is applied again to
    /// the interpolated values. This mirrors the diffusers reference
    /// implementation — confirm it is intentional before changing.
    pub fn set_timesteps(&mut self, num_inference_steps: usize, mu: Option<f64>) {
        let sigma_max = self.sigmas[0];
        let sigma_min = *self.sigmas.last().unwrap_or(&0.0);
        // Linear ramp from sigma_max towards sigma_min.
        // NOTE(review): dividing by `num_inference_steps` (not `n - 1`) means
        // sigma_min itself is never reached; the schedule instead ends on the
        // appended terminal 0.0. Confirm against the reference schedule.
        let mut sigmas: Vec<f64> = (0..num_inference_steps)
            .map(|i| {
                let t = i as f64 / num_inference_steps as f64;
                sigma_max * (1.0 - t) + sigma_min * t
            })
            .collect();
        if let Some(mu) = mu {
            if self.config.use_dynamic_shifting {
                // Dynamic shift: s -> e^mu / (e^mu + (1/s - 1)); 0 maps to 0.
                sigmas = sigmas
                    .iter()
                    .map(|&s| {
                        if s <= 0.0 {
                            0.0
                        } else {
                            let e_mu = mu.exp();
                            e_mu / (e_mu + (1.0 / s - 1.0))
                        }
                    })
                    .collect();
            }
        } else if !self.config.use_dynamic_shifting {
            let shift = self.config.shift;
            sigmas = sigmas
                .iter()
                .map(|&s| shift * s / (1.0 + (shift - 1.0) * s))
                .collect();
        }
        // Fix: derive the stored timesteps from the *shifted* sigmas, matching
        // `new`. The previous version stored pre-shift timesteps, leaving
        // `timesteps` and `sigmas` inconsistent whenever a shift was applied.
        let timesteps: Vec<f64> = sigmas
            .iter()
            .map(|&s| s * self.config.num_train_timesteps as f64)
            .collect();
        // Terminal sigma so the final Euler step integrates down to 0.
        sigmas.push(0.0);
        self.timesteps = timesteps;
        self.sigmas = sigmas;
        self.step_index = 0;
    }

    /// Sigma at the current cursor (the terminal 0.0 once the schedule set by
    /// `set_timesteps` is complete).
    pub fn current_sigma(&self) -> f64 {
        self.sigmas[self.step_index]
    }

    /// Current timestep mapped to [0, 1] and reversed (near 0.0 at the start
    /// of sampling, 1.0 once the schedule is exhausted).
    pub fn current_timestep_normalized(&self) -> f64 {
        let t = self.timesteps.get(self.step_index).copied().unwrap_or(0.0);
        // Normalize by the configured training discretization rather than a
        // hard-coded 1000.0; identical for the presets in this file.
        let n = self.config.num_train_timesteps as f64;
        (n - t) / n
    }

    /// Performs one Euler step, x' = x + v * (sigma_next - sigma), and
    /// advances the internal cursor.
    ///
    /// # Errors
    /// Propagates tensor arithmetic errors (e.g. shape/device mismatch).
    ///
    /// # Panics
    /// Panics if called more times than there are scheduled steps.
    pub fn step(&mut self, model_output: &Tensor, sample: &Tensor) -> Result<Tensor> {
        let sigma = self.sigmas[self.step_index];
        let sigma_next = self.sigmas[self.step_index + 1];
        let dt = sigma_next - sigma;
        let prev_sample = (sample + (model_output * dt)?)?;
        self.step_index += 1;
        Ok(prev_sample)
    }

    /// Rewinds the step cursor without touching the schedule.
    pub fn reset(&mut self) {
        self.step_index = 0;
    }

    /// Number of steps in the current schedule.
    pub fn num_inference_steps(&self) -> usize {
        self.timesteps.len()
    }

    /// Index of the next step to be taken.
    pub fn step_index(&self) -> usize {
        self.step_index
    }

    /// True once every scheduled step has been consumed.
    pub fn is_complete(&self) -> bool {
        self.step_index >= self.timesteps.len()
    }
}
/// Linearly maps an image token-sequence length to a sigma-shift (`mu`),
/// interpolating through the anchors (`base_seq_len`, `base_shift`) and
/// (`max_seq_len`, `max_shift`).
pub fn calculate_shift(
    image_seq_len: usize,
    base_seq_len: usize,
    max_seq_len: usize,
    base_shift: f64,
    max_shift: f64,
) -> f64 {
    // Slope and intercept of the line through the two anchor points.
    let slope = (max_shift - base_shift) / (max_seq_len - base_seq_len) as f64;
    let intercept = base_shift - slope * base_seq_len as f64;
    slope * image_seq_len as f64 + intercept
}
/// Anchor sequence length at which [`calculate_shift`] yields [`BASE_SHIFT`].
pub const BASE_IMAGE_SEQ_LEN: usize = 256;
/// Anchor sequence length at which [`calculate_shift`] yields [`MAX_SHIFT`].
pub const MAX_IMAGE_SEQ_LEN: usize = 4096;
/// Shift value at the base sequence length.
pub const BASE_SHIFT: f64 = 0.5;
/// Shift value at the max sequence length.
pub const MAX_SHIFT: f64 = 1.15;