// oxigaf_diffusion/scheduler.rs
use candle_core::{DType, Device, Result, Tensor};

/// Deterministic DDIM noise scheduler built on a precomputed table of
/// cumulative alpha products (alpha-bar) over the training timesteps.
#[derive(Debug)]
pub struct DdimScheduler {
    /// Cumulative product of `1 - beta` for each training timestep
    /// (alpha-bar), filled in by `new`.
    alphas_cumprod: Vec<f64>,
    /// Length of the training noise schedule (e.g. 1000).
    num_train_timesteps: usize,
    /// Selected inference timesteps in descending order; empty until
    /// `set_timesteps` is called.
    timesteps: Vec<usize>,
    /// How the model output passed to `step` is interpreted.
    prediction_type: PredictionType,
}
/// Parameterization of the diffusion model's output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PredictionType {
    /// The model predicts the noise `eps` that was added to the sample.
    Epsilon,
    /// The model predicts the velocity "v"; the stepping code converts it
    /// back into a clean-sample estimate and a noise estimate.
    VPrediction,
}
30impl DdimScheduler {
31 pub fn new(num_train_timesteps: usize, prediction_type: PredictionType) -> Self {
36 let beta_start: f64 = 0.00085_f64.sqrt();
37 let beta_end: f64 = 0.012_f64.sqrt();
38
39 let mut alphas_cumprod = Vec::with_capacity(num_train_timesteps);
40 let mut cumprod = 1.0_f64;
41 for i in 0..num_train_timesteps {
42 let beta = beta_start
43 + (beta_end - beta_start) * (i as f64) / ((num_train_timesteps - 1) as f64);
44 let beta = beta * beta; let alpha = 1.0 - beta;
46 cumprod *= alpha;
47 alphas_cumprod.push(cumprod);
48 }
49
50 Self {
51 alphas_cumprod,
52 num_train_timesteps,
53 timesteps: Vec::new(),
54 prediction_type,
55 }
56 }
57
58 pub fn set_timesteps(&mut self, num_inference_steps: usize) {
60 let step = self.num_train_timesteps / num_inference_steps;
61 self.timesteps = (0..num_inference_steps).rev().map(|i| i * step).collect();
62 }
63
64 pub fn timesteps(&self) -> &[usize] {
66 &self.timesteps
67 }
68
69 pub fn step(&self, model_output: &Tensor, t: usize, sample: &Tensor) -> Result<Tensor> {
77 let alpha_prod_t = self.alphas_cumprod[t];
78 let alpha_prod_t_prev = if t > 0 {
79 let step = self.num_train_timesteps / self.timesteps.len();
81 if t >= step {
82 self.alphas_cumprod[t - step]
83 } else {
84 1.0
85 }
86 } else {
87 1.0
88 };
89
90 let sqrt_alpha_prod = alpha_prod_t.sqrt();
91 let sqrt_one_minus_alpha_prod = (1.0 - alpha_prod_t).sqrt();
92
93 let pred_x0 = match self.prediction_type {
95 PredictionType::Epsilon => {
96 ((sample - (model_output * sqrt_one_minus_alpha_prod)?)? * (1.0 / sqrt_alpha_prod))?
98 }
99 PredictionType::VPrediction => {
100 ((sample * sqrt_alpha_prod)? - (model_output * sqrt_one_minus_alpha_prod)?)?
102 }
103 };
104
105 let pred_epsilon = match self.prediction_type {
107 PredictionType::Epsilon => model_output.clone(),
108 PredictionType::VPrediction => {
109 ((model_output * sqrt_alpha_prod)? + (sample * sqrt_one_minus_alpha_prod)?)?
110 }
111 };
112
113 let sqrt_alpha_prod_prev = alpha_prod_t_prev.sqrt();
115 let sqrt_one_minus_alpha_prod_prev = (1.0 - alpha_prod_t_prev).sqrt();
116
117 (&pred_x0 * sqrt_alpha_prod_prev)? + (&pred_epsilon * sqrt_one_minus_alpha_prod_prev)?
119 }
120
121 pub fn add_noise(&self, original: &Tensor, noise: &Tensor, timestep: usize) -> Result<Tensor> {
125 let alpha = self.alphas_cumprod[timestep];
126 let sqrt_alpha = alpha.sqrt();
127 let sqrt_one_minus_alpha = (1.0 - alpha).sqrt();
128 (original * sqrt_alpha)? + (noise * sqrt_one_minus_alpha)?
129 }
130
131 pub fn timestep_tensor(&self, t: usize, batch_size: usize, device: &Device) -> Result<Tensor> {
133 Tensor::full(t as f32, (batch_size,), device)?.to_dtype(DType::F32)
134 }
135}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_alphas_cumprod_decreasing() {
        let scheduler = DdimScheduler::new(1000, PredictionType::VPrediction);
        let first = scheduler.alphas_cumprod[0];
        let last = scheduler.alphas_cumprod[999];
        // The cumulative product shrinks across the schedule: close to 1 at
        // the start, close to 0 at the end.
        assert!(first > last);
        assert!(first > 0.99);
        assert!(last < 0.01);
    }

    #[test]
    fn test_set_timesteps() {
        let mut scheduler = DdimScheduler::new(1000, PredictionType::Epsilon);
        scheduler.set_timesteps(50);
        let ts = scheduler.timesteps();
        assert_eq!(ts.len(), 50);
        // Timesteps come out in descending order for denoising.
        assert!(ts[0] > ts[49]);
    }
}