use ghostflow_core::Tensor;
use crate::conv::Conv2d;
use crate::norm::GroupNorm;
use crate::linear::Linear;
use crate::Module;

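/// Hyperparameters for a DDPM-style diffusion model and its UNet backbone.
/// `channel_mult` lists the per-stage channel multipliers applied to
/// `model_channels`; `num_heads` is reserved for attention blocks, which the
/// simplified UNet below does not yet build.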
#[derive(Debug, Clone)]
pub struct DiffusionConfig {
    pub image_size: usize,
    pub in_channels: usize,
    pub model_channels: usize,
    pub num_res_blocks: usize,
    pub channel_mult: Vec<usize>,
    pub num_heads: usize,
    pub num_timesteps: usize,
    pub beta_schedule: BetaSchedule,
}

impl Default for DiffusionConfig {
    fn default() -> Self {
        DiffusionConfig {
            image_size: 256,
            in_channels: 3,
            model_channels: 128,
            num_res_blocks: 2,
            channel_mult: vec![1, 2, 4, 8],
            num_heads: 8,
            num_timesteps: 1000,
            beta_schedule: BetaSchedule::Linear,
        }
    }
}

impl DiffusionConfig {
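    /// Preset matching the Stable Diffusion latent UNet dimensions
    /// (4 latent channels, 320 base channels, scaled-linear betas).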
    pub fn stable_diffusion() -> Self {
        DiffusionConfig {
            image_size: 512,
            in_channels: 4,
            model_channels: 320,
            num_res_blocks: 2,
            channel_mult: vec![1, 2, 4, 4],
            num_heads: 8,
            num_timesteps: 1000,
            beta_schedule: BetaSchedule::ScaledLinear,
        }
    }

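    /// Preset following the 256x256 UNet configuration from the original
    /// DDPM paper (Ho et al., 2020).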
    pub fn ddpm() -> Self {
        DiffusionConfig {
            image_size: 256,
            in_channels: 3,
            model_channels: 128,
            num_res_blocks: 2,
            channel_mult: vec![1, 1, 2, 2, 4, 4],
            num_heads: 4,
            num_timesteps: 1000,
            beta_schedule: BetaSchedule::Linear,
        }
    }

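    /// Small configuration intended for smoke tests and quick experiments.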
    pub fn tiny() -> Self {
        DiffusionConfig {
            image_size: 32,
            in_channels: 3,
            model_channels: 64,
            num_res_blocks: 1,
            channel_mult: vec![1, 2],
            num_heads: 2,
            num_timesteps: 100,
            beta_schedule: BetaSchedule::Linear,
        }
    }
}

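/// Variance schedule for the forward (noising) process.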
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum BetaSchedule {
    /// Betas interpolated linearly between fixed endpoints (Ho et al., 2020).
    Linear,
    /// Cosine alpha-bar schedule (Nichol & Dhariwal, 2021).
    Cosine,
    /// Linear in sqrt(beta) space; the schedule used by Stable Diffusion.
    ScaledLinear,
}

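/// Precomputed noise-schedule tables: betas, alphas (1 - beta), and their
/// running product alpha_bar (`alphas_cumprod`).
///
/// A minimal usage sketch:
///
/// ```ignore
/// let sched = NoiseScheduler::new(1000, BetaSchedule::Linear);
/// let x0 = Tensor::randn(&[1, 3, 32, 32]);
/// let noise = Tensor::randn(&[1, 3, 32, 32]);
/// let xt = sched.add_noise(&x0, &noise, 500)?;
/// ```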
pub struct NoiseScheduler {
    betas: Vec<f32>,
    alphas: Vec<f32>,
    alphas_cumprod: Vec<f32>,
    num_timesteps: usize,
}

impl NoiseScheduler {
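    /// Builds the schedule tables: alpha_t = 1 - beta_t and the running
    /// product alpha_bar_t = prod_{s <= t} alpha_s.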
    pub fn new(num_timesteps: usize, schedule: BetaSchedule) -> Self {
        let betas = Self::compute_betas(num_timesteps, schedule);
        let alphas: Vec<f32> = betas.iter().map(|b| 1.0 - b).collect();

        let mut alphas_cumprod = Vec::with_capacity(num_timesteps);
        let mut prod = 1.0;
        for &alpha in &alphas {
            prod *= alpha;
            alphas_cumprod.push(prod);
        }

        NoiseScheduler {
            betas,
            alphas,
            alphas_cumprod,
            num_timesteps,
        }
    }

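    /// Per-step beta values for the chosen schedule.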
    fn compute_betas(num_timesteps: usize, schedule: BetaSchedule) -> Vec<f32> {
        // Guard against division by zero when num_timesteps == 1.
        let denom = num_timesteps.saturating_sub(1).max(1) as f32;
        match schedule {
            BetaSchedule::Linear => {
                let beta_start = 0.0001;
                let beta_end = 0.02;
                (0..num_timesteps)
                    .map(|t| {
                        let frac = t as f32 / denom;
                        beta_start + (beta_end - beta_start) * frac
                    })
                    .collect()
            }
            BetaSchedule::Cosine => {
                // Cosine alpha-bar schedule: beta_t = 1 - alpha_bar(t) / alpha_bar(t-1),
                // clipped at 0.999 to avoid singularities at the end of the chain.
                let s = 0.008;
                let half_pi = std::f32::consts::PI / 2.0;
                let alpha_bar = |frac: f32| ((frac + s) / (1.0 + s) * half_pi).cos().powi(2);
                (0..num_timesteps)
                    .map(|t| {
                        let curr = alpha_bar(t as f32 / num_timesteps as f32);
                        let prev = if t > 0 {
                            alpha_bar((t - 1) as f32 / num_timesteps as f32)
                        } else {
                            1.0
                        };
                        (1.0 - curr / prev).min(0.999)
                    })
                    .collect()
            }
            BetaSchedule::ScaledLinear => {
                // Interpolate linearly in sqrt(beta) space, then square.
                let beta_start = 0.00085_f32.sqrt();
                let beta_end = 0.012_f32.sqrt();
                (0..num_timesteps)
                    .map(|t| {
                        let frac = t as f32 / denom;
                        let beta = beta_start + (beta_end - beta_start) * frac;
                        beta * beta
                    })
                    .collect()
            }
        }
    }

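    /// Forward process q(x_t | x_0) = N(sqrt(alpha_bar_t) * x0, (1 - alpha_bar_t) * I),
    /// sampled here as sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise.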
    pub fn add_noise(&self, x0: &Tensor, noise: &Tensor, timestep: usize) -> Result<Tensor, String> {
        if timestep >= self.num_timesteps {
            return Err(format!(
                "Timestep {} out of range (num_timesteps = {})",
                timestep, self.num_timesteps
            ));
        }

        let alpha_cumprod = self.alphas_cumprod[timestep];
        let sqrt_alpha = alpha_cumprod.sqrt();
        let sqrt_one_minus_alpha = (1.0 - alpha_cumprod).sqrt();

        let x0_data = x0.data_f32();
        let noise_data = noise.data_f32();

        let noisy: Vec<f32> = x0_data
            .iter()
            .zip(noise_data.iter())
            .map(|(&x, &n)| sqrt_alpha * x + sqrt_one_minus_alpha * n)
            .collect();

        Tensor::from_slice(&noisy, x0.dims())
            .map_err(|e| format!("Failed to create noisy tensor: {:?}", e))
    }

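    /// Inverts `add_noise` given a noise estimate:
    /// x0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t).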
    pub fn predict_x0(&self, xt: &Tensor, noise_pred: &Tensor, timestep: usize) -> Result<Tensor, String> {
        if timestep >= self.num_timesteps {
            return Err(format!(
                "Timestep {} out of range (num_timesteps = {})",
                timestep, self.num_timesteps
            ));
        }

        let alpha_cumprod = self.alphas_cumprod[timestep];
        let sqrt_alpha = alpha_cumprod.sqrt();
        let sqrt_one_minus_alpha = (1.0 - alpha_cumprod).sqrt();

        let xt_data = xt.data_f32();
        let noise_data = noise_pred.data_f32();

        let x0: Vec<f32> = xt_data
            .iter()
            .zip(noise_data.iter())
            .map(|(&x, &n)| (x - sqrt_one_minus_alpha * n) / sqrt_alpha)
            .collect();

        Tensor::from_slice(&x0, xt.dims())
            .map_err(|e| format!("Failed to create x0 tensor: {:?}", e))
    }
}

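/// GroupNorm/ReLU/Conv residual block with an additive timestep embedding;
/// a 1x1 convolution projects the skip path when channel counts differ.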
pub struct ResidualBlock {
    conv1: Conv2d,
    conv2: Conv2d,
    norm1: GroupNorm,
    norm2: GroupNorm,
    time_emb_proj: Linear,
    skip_conv: Option<Conv2d>,
}

impl ResidualBlock {
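    /// Note: GroupNorm is constructed with 32 groups, which assumes both
    /// channel counts are divisible by 32.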
    pub fn new(in_channels: usize, out_channels: usize, time_emb_dim: usize) -> Self {
        let conv1 = Conv2d::new(in_channels, out_channels, 3, 1, 1);
        let conv2 = Conv2d::new(out_channels, out_channels, 3, 1, 1);
        let norm1 = GroupNorm::new(32, in_channels);
        let norm2 = GroupNorm::new(32, out_channels);
        let time_emb_proj = Linear::new(time_emb_dim, out_channels);

        // Project the residual path with a 1x1 conv when channel counts differ.
        let skip_conv = if in_channels != out_channels {
            Some(Conv2d::new(in_channels, out_channels, 1, 1, 0))
        } else {
            None
        };

        ResidualBlock {
            conv1,
            conv2,
            norm1,
            norm2,
            time_emb_proj,
            skip_conv,
        }
    }

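    /// Norm -> ReLU -> Conv twice, with the projected timestep embedding
    /// added between the two convolutions, plus a residual connection.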
    pub fn forward(&self, x: &Tensor, time_emb: &Tensor) -> Tensor {
        let h = self.norm1.forward(x);
        let h = h.relu();
        let h = self.conv1.forward(&h);

        // Inject the timestep embedding; if broadcasting the [out_channels]
        // projection onto [N, C, H, W] fails, fall back to h unchanged.
        let time_proj = self.time_emb_proj.forward(time_emb);
        let h = h.add(&time_proj).unwrap_or(h);

        let h = self.norm2.forward(&h);
        let h = h.relu();
        let h = self.conv2.forward(&h);

        let skip = if let Some(ref conv) = self.skip_conv {
            conv.forward(x)
        } else {
            x.clone()
        };

        h.add(&skip).unwrap_or(h)
    }
}

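/// Simplified UNet-style noise predictor: stacked residual blocks conditioned
/// on a sinusoidal timestep embedding. Unlike a full UNet, this version has no
/// spatial down/upsampling and no encoder-to-decoder skip connections.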
pub struct UNet {
    time_emb_dim: usize,
    time_mlp: Vec<Linear>,
    in_conv: Conv2d,
    encoder_blocks: Vec<ResidualBlock>,
    decoder_blocks: Vec<ResidualBlock>,
    out_conv: Conv2d,
}

impl UNet {
    pub fn new(config: &DiffusionConfig) -> Self {
        let time_emb_dim = config.model_channels * 4;

        // get_time_embedding produces a time_emb_dim-sized vector, so the
        // first MLP layer must accept time_emb_dim inputs.
        let time_mlp = vec![
            Linear::new(time_emb_dim, time_emb_dim),
            Linear::new(time_emb_dim, time_emb_dim),
        ];

        // Stem conv lifts the input to model_channels before the residual
        // blocks, so the first GroupNorm never sees the raw image channels.
        let in_conv = Conv2d::new(config.in_channels, config.model_channels, 3, 1, 1);

        let mut encoder_blocks = Vec::new();
        let mut in_ch = config.model_channels;

        for &mult in &config.channel_mult {
            let out_ch = config.model_channels * mult;
            for _ in 0..config.num_res_blocks {
                encoder_blocks.push(ResidualBlock::new(in_ch, out_ch, time_emb_dim));
                in_ch = out_ch;
            }
        }

        let mut decoder_blocks = Vec::new();
        for &mult in config.channel_mult.iter().rev() {
            let out_ch = config.model_channels * mult;
            for _ in 0..config.num_res_blocks {
                decoder_blocks.push(ResidualBlock::new(in_ch, out_ch, time_emb_dim));
                in_ch = out_ch;
            }
        }

        let out_conv = Conv2d::new(in_ch, config.in_channels, 3, 1, 1);

        UNet {
            time_emb_dim,
            time_mlp,
            in_conv,
            encoder_blocks,
            decoder_blocks,
            out_conv,
        }
    }

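    /// Sinusoidal timestep embedding: interleaved sin/cos of the timestep at
    /// geometrically spaced frequencies exp(-ln(10000) * i / half_dim).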
    fn get_time_embedding(&self, timestep: usize) -> Tensor {
        let half_dim = self.time_emb_dim / 2;
        let emb: Vec<f32> = (0..half_dim)
            .flat_map(|i| {
                let freq = (-(10000_f32.ln()) * i as f32 / half_dim as f32).exp();
                let angle = timestep as f32 * freq;
                vec![angle.sin(), angle.cos()]
            })
            .collect();

        Tensor::from_slice(&emb, &[self.time_emb_dim]).unwrap()
    }

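    /// Predicts the noise component of `x` at the given timestep.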
    pub fn forward(&self, x: &Tensor, timestep: usize) -> Tensor {
        let mut time_emb = self.get_time_embedding(timestep);

        for layer in &self.time_mlp {
            time_emb = layer.forward(&time_emb);
            time_emb = time_emb.relu();
        }

        let mut h = self.in_conv.forward(x);
        for block in &self.encoder_blocks {
            h = block.forward(&h, &time_emb);
        }

        for block in &self.decoder_blocks {
            h = block.forward(&h, &time_emb);
        }

        self.out_conv.forward(&h)
    }
}

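/// DDPM tying together the UNet noise predictor and the noise scheduler.
///
/// A minimal training-step sketch (illustrative only; the loss computation is
/// assumed, not part of this crate):
///
/// ```ignore
/// let ddpm = DDPM::new(DiffusionConfig::tiny());
/// let x0 = Tensor::randn(&[2, 3, 32, 32]);
/// let t = 50; // in training, t would be sampled uniformly from 0..num_timesteps
/// let (noise_pred, noise) = ddpm.forward(&x0, t)?;
/// // optimize the UNet by regressing noise_pred onto noise (e.g. MSE).
/// ```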
pub struct DDPM {
    config: DiffusionConfig,
    unet: UNet,
    scheduler: NoiseScheduler,
}

impl DDPM {
    pub fn new(config: DiffusionConfig) -> Self {
        let unet = UNet::new(&config);
        let scheduler = NoiseScheduler::new(config.num_timesteps, config.beta_schedule);

        DDPM {
            config,
            unet,
            scheduler,
        }
    }

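    /// Training forward pass: noises `x0` at `timestep` and returns
    /// `(noise_pred, noise)` so the caller can regress one onto the other.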
    pub fn forward(&self, x0: &Tensor, timestep: usize) -> Result<(Tensor, Tensor), String> {
        let noise = Tensor::randn(x0.dims());

        let xt = self.scheduler.add_noise(x0, &noise, timestep)?;

        let noise_pred = self.unet.forward(&xt, timestep);

        Ok((noise_pred, noise))
    }

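    /// Generates samples by starting from pure Gaussian noise and iterating
    /// the reverse process from t = num_timesteps - 1 down to 0.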
    pub fn sample(&self, batch_size: usize) -> Result<Tensor, String> {
        let mut xt = Tensor::randn(&[
            batch_size,
            self.config.in_channels,
            self.config.image_size,
            self.config.image_size,
        ]);

        for t in (0..self.config.num_timesteps).rev() {
            let noise_pred = self.unet.forward(&xt, t);

            xt = self.denoise_step(&xt, &noise_pred, t)?;
        }

        Ok(xt)
    }

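    /// One reverse step. This is the deterministic DDIM-style update (eta = 0):
    /// x_{t-1} = sqrt(alpha_bar_{t-1}) * x0_pred + sqrt(1 - alpha_bar_{t-1}) * eps,
    /// with eps recovered from x_t and the x0 estimate.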
    fn denoise_step(&self, xt: &Tensor, noise_pred: &Tensor, timestep: usize) -> Result<Tensor, String> {
        let x0_pred = self.scheduler.predict_x0(xt, noise_pred, timestep)?;

        if timestep == 0 {
            return Ok(x0_pred);
        }

        let alpha_t = self.scheduler.alphas_cumprod[timestep];
        // timestep == 0 returned above, so timestep - 1 is always in bounds.
        let alpha_t_prev = self.scheduler.alphas_cumprod[timestep - 1];

        let xt_data = xt.data_f32();
        let x0_data = x0_pred.data_f32();

        let coef1 = alpha_t_prev.sqrt();
        let coef2 = (1.0 - alpha_t_prev).sqrt();

        let xt_prev: Vec<f32> = xt_data
            .iter()
            .zip(x0_data.iter())
            .map(|(&x, &x0)| {
                // Noise direction implied by x_t and the current x0 estimate.
                let eps = (x - alpha_t.sqrt() * x0) / (1.0 - alpha_t).sqrt();
                coef1 * x0 + coef2 * eps
            })
            .collect();

        Tensor::from_slice(&xt_prev, xt.dims())
            .map_err(|e| format!("Failed to create denoised tensor: {:?}", e))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_noise_scheduler() {
        let scheduler = NoiseScheduler::new(100, BetaSchedule::Linear);
        assert_eq!(scheduler.betas.len(), 100);
        assert_eq!(scheduler.alphas.len(), 100);
        assert_eq!(scheduler.alphas_cumprod.len(), 100);

        // alpha_bar is a running product of values < 1, so it must decrease.
        for i in 1..100 {
            assert!(scheduler.alphas_cumprod[i] < scheduler.alphas_cumprod[i - 1]);
        }
    }

    #[test]
    fn test_add_noise() {
        let scheduler = NoiseScheduler::new(100, BetaSchedule::Linear);
        let x0 = Tensor::randn(&[2, 3, 32, 32]);
        let noise = Tensor::randn(&[2, 3, 32, 32]);

        let xt = scheduler.add_noise(&x0, &noise, 50).unwrap();
        assert_eq!(xt.dims(), &[2, 3, 32, 32]);
    }

    #[test]
    fn test_ddpm_config() {
        let config = DiffusionConfig::tiny();
        assert_eq!(config.image_size, 32);
        assert_eq!(config.num_timesteps, 100);

        let config = DiffusionConfig::stable_diffusion();
        assert_eq!(config.image_size, 512);
        assert_eq!(config.in_channels, 4);
    }

    #[test]
    fn test_ddpm_forward() {
        let config = DiffusionConfig::tiny();
        let ddpm = DDPM::new(config);

        let x0 = Tensor::randn(&[2, 3, 32, 32]);
        let (noise_pred, noise) = ddpm.forward(&x0, 50).unwrap();

        assert_eq!(noise_pred.dims(), noise.dims());
    }
}