use candle::{DType, Device, Result, Tensor};
/// Sample a standard-normal (mean 0, std 1) noise tensor of shape
/// `(batch_size, channels, height, width)` on the given device.
pub fn get_noise(
    batch_size: usize,
    channels: usize,
    height: usize,
    width: usize,
    device: &Device,
) -> Result<Tensor> {
    let shape = (batch_size, channels, height, width);
    Tensor::randn(0f32, 1.0, shape, device)
}
/// Build a descending timestep schedule of `num_steps + 1` values from 1.0
/// down to 0.0, with interior points warped by the `mu` time-shift:
/// `t' = e^mu / (e^mu + (1/t - 1))`. With `mu = 0` the schedule is uniform;
/// the endpoints 0.0 and 1.0 are always kept fixed.
pub fn get_schedule(num_steps: usize, mu: f64) -> Vec<f64> {
    let shift = mu.exp();
    (0..=num_steps)
        .rev()
        .map(|step| {
            let t = step as f64 / num_steps as f64;
            // Only warp strictly interior points; endpoints pass through.
            if t > 0.0 && t < 1.0 {
                shift / (shift + (1.0 / t - 1.0))
            } else {
                t
            }
        })
        .collect()
}
/// Convert a model output in `[-1, 1]` into `u8` pixel values in `[0, 255]`.
pub fn postprocess_image(image: &Tensor) -> Result<Tensor> {
    // Clamp first so out-of-range model outputs cannot map outside [0, 255]
    // and wrap around in the u8 cast.
    let image = image.clamp(-1.0, 1.0)?;
    // Affine map [-1, 1] -> [0, 255].
    let image = ((image + 1.0)? * 127.5)?;
    // Round to nearest before the dtype cast: the cast truncates toward zero,
    // which would otherwise bias every pixel darker by up to one level.
    image.round()?.to_dtype(DType::U8)
}
/// Configuration for classifier-free guidance (CFG); consumed by `apply_cfg`.
#[derive(Debug, Clone)]
pub struct CfgConfig {
    // Multiplier applied to the (positive - negative) prediction difference;
    // a value <= 0.0 disables guidance entirely.
    pub guidance_scale: f64,
    // Normalized-time threshold: when `t_norm` exceeds this value, guidance
    // is switched off for that step (scale forced to 0.0).
    pub cfg_truncation: f64,
    // When true, the guided prediction is rescaled so its L2 norm never
    // exceeds the L2 norm of the positive prediction.
    pub cfg_normalization: bool,
}
impl Default for CfgConfig {
fn default() -> Self {
Self {
guidance_scale: 5.0,
cfg_truncation: 1.0,
cfg_normalization: false,
}
}
}
pub fn apply_cfg(
pos_pred: &Tensor,
neg_pred: &Tensor,
cfg: &CfgConfig,
t_norm: f64,
) -> Result<Tensor> {
let current_scale = if t_norm > cfg.cfg_truncation {
0.0
} else {
cfg.guidance_scale
};
if current_scale <= 0.0 {
return Ok(pos_pred.clone());
}
let diff = (pos_pred - neg_pred)?;
let pred = (pos_pred + (diff * current_scale)?)?;
if cfg.cfg_normalization {
let ori_norm = pos_pred.sqr()?.sum_all()?.sqrt()?;
let new_norm = pred.sqr()?.sum_all()?.sqrt()?;
let ori_norm_val = ori_norm.to_scalar::<f32>()?;
let new_norm_val = new_norm.to_scalar::<f32>()?;
if new_norm_val > ori_norm_val {
let scale = ori_norm_val / new_norm_val;
return pred * scale as f64;
}
}
Ok(pred)
}
/// Multiply the noise tensor elementwise by the scalar `sigma`.
pub fn scale_noise(noise: &Tensor, sigma: f64) -> Result<Tensor> {
    let scaled = (noise * sigma)?;
    Ok(scaled)
}