ghostflow_nn/
diffusion.rs

1//! Diffusion Models (Stable Diffusion, DDPM, DDIM)
2//!
3//! Implements diffusion models for image generation:
4//! - DDPM (Denoising Diffusion Probabilistic Models)
5//! - DDIM (Denoising Diffusion Implicit Models)
6//! - Stable Diffusion architecture
7//! - Noise scheduling
8//! - Sampling algorithms
9
10use ghostflow_core::Tensor;
11use crate::conv::Conv2d;
12use crate::norm::GroupNorm;
13use crate::linear::Linear;
14use crate::Module;
15
16/// Diffusion model configuration
17#[derive(Debug, Clone)]
18pub struct DiffusionConfig {
19    /// Image size
20    pub image_size: usize,
21    /// Number of channels
22    pub in_channels: usize,
23    /// Model channels
24    pub model_channels: usize,
25    /// Number of residual blocks per resolution
26    pub num_res_blocks: usize,
27    /// Channel multipliers for each resolution
28    pub channel_mult: Vec<usize>,
29    /// Number of attention heads
30    pub num_heads: usize,
31    /// Number of diffusion timesteps
32    pub num_timesteps: usize,
33    /// Beta schedule type
34    pub beta_schedule: BetaSchedule,
35}
36
37impl Default for DiffusionConfig {
38    fn default() -> Self {
39        DiffusionConfig {
40            image_size: 256,
41            in_channels: 3,
42            model_channels: 128,
43            num_res_blocks: 2,
44            channel_mult: vec![1, 2, 4, 8],
45            num_heads: 8,
46            num_timesteps: 1000,
47            beta_schedule: BetaSchedule::Linear,
48        }
49    }
50}
51
52impl DiffusionConfig {
53    /// Stable Diffusion configuration
54    pub fn stable_diffusion() -> Self {
55        DiffusionConfig {
56            image_size: 512,
57            in_channels: 4, // Latent space
58            model_channels: 320,
59            num_res_blocks: 2,
60            channel_mult: vec![1, 2, 4, 4],
61            num_heads: 8,
62            num_timesteps: 1000,
63            beta_schedule: BetaSchedule::ScaledLinear,
64        }
65    }
66    
67    /// DDPM configuration (original paper)
68    pub fn ddpm() -> Self {
69        DiffusionConfig {
70            image_size: 256,
71            in_channels: 3,
72            model_channels: 128,
73            num_res_blocks: 2,
74            channel_mult: vec![1, 1, 2, 2, 4, 4],
75            num_heads: 4,
76            num_timesteps: 1000,
77            beta_schedule: BetaSchedule::Linear,
78        }
79    }
80    
81    /// Small model for testing
82    pub fn tiny() -> Self {
83        DiffusionConfig {
84            image_size: 32,
85            in_channels: 3,
86            model_channels: 64,
87            num_res_blocks: 1,
88            channel_mult: vec![1, 2],
89            num_heads: 2,
90            num_timesteps: 100,
91            beta_schedule: BetaSchedule::Linear,
92        }
93    }
94}
95
/// Shape of the per-step noise variance (beta) schedule built by `NoiseScheduler`.
///
/// The enum carries no data, so it is `Copy`, `Eq` and `Hash` and can be used
/// directly as a set or map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BetaSchedule {
    /// Betas increase linearly from `1e-4` to `0.02`.
    Linear,
    /// Betas derived from a squared-cosine cumulative-alpha curve, clipped at `0.999`.
    Cosine,
    /// The square roots of the betas increase linearly from `sqrt(0.00085)` to
    /// `sqrt(0.012)`; used by the Stable Diffusion preset.
    ScaledLinear,
}
103
/// Precomputed variance tables for the forward (noising) and reverse (denoising) process.
///
/// Every table holds exactly `num_timesteps` entries, indexed by timestep.
pub struct NoiseScheduler {
    /// Per-step noise variance `beta_t`.
    betas: Vec<f32>,
    /// Per-step signal retention `alpha_t = 1 - beta_t`.
    alphas: Vec<f32>,
    /// Running product `alpha_bar_t = alpha_0 * alpha_1 * ... * alpha_t`.
    alphas_cumprod: Vec<f32>,
    /// Schedule length; valid timesteps are `0..num_timesteps`.
    num_timesteps: usize,
}
115
116impl NoiseScheduler {
117    /// Create new noise scheduler
118    pub fn new(num_timesteps: usize, schedule: BetaSchedule) -> Self {
119        let betas = Self::compute_betas(num_timesteps, schedule);
120        let alphas: Vec<f32> = betas.iter().map(|b| 1.0 - b).collect();
121        
122        // Compute cumulative product
123        let mut alphas_cumprod = Vec::with_capacity(num_timesteps);
124        let mut prod = 1.0;
125        for &alpha in &alphas {
126            prod *= alpha;
127            alphas_cumprod.push(prod);
128        }
129        
130        NoiseScheduler {
131            betas,
132            alphas,
133            alphas_cumprod,
134            num_timesteps,
135        }
136    }
137    
138    /// Compute beta schedule
139    fn compute_betas(num_timesteps: usize, schedule: BetaSchedule) -> Vec<f32> {
140        match schedule {
141            BetaSchedule::Linear => {
142                let beta_start = 0.0001;
143                let beta_end = 0.02;
144                (0..num_timesteps)
145                    .map(|t| {
146                        let frac = t as f32 / (num_timesteps - 1) as f32;
147                        beta_start + (beta_end - beta_start) * frac
148                    })
149                    .collect()
150            }
151            BetaSchedule::Cosine => {
152                let s = 0.008;
153                (0..num_timesteps)
154                    .map(|t| {
155                        let t_frac = t as f32 / num_timesteps as f32;
156                        let alpha_bar = ((t_frac + s) / (1.0 + s) * std::f32::consts::PI / 2.0).cos().powi(2);
157                        let alpha_bar_prev = if t > 0 {
158                            (((t - 1) as f32 / num_timesteps as f32 + s) / (1.0 + s) * std::f32::consts::PI / 2.0).cos().powi(2)
159                        } else {
160                            1.0
161                        };
162                        (1.0 - alpha_bar / alpha_bar_prev).min(0.999)
163                    })
164                    .collect()
165            }
166            BetaSchedule::ScaledLinear => {
167                let beta_start = 0.00085_f32.sqrt();
168                let beta_end = 0.012_f32.sqrt();
169                (0..num_timesteps)
170                    .map(|t| {
171                        let frac = t as f32 / (num_timesteps - 1) as f32;
172                        let beta = beta_start + (beta_end - beta_start) * frac;
173                        beta * beta
174                    })
175                    .collect()
176            }
177        }
178    }
179    
180    /// Add noise to image (forward diffusion)
181    pub fn add_noise(&self, x0: &Tensor, noise: &Tensor, timestep: usize) -> Result<Tensor, String> {
182        if timestep >= self.num_timesteps {
183            return Err(format!("Timestep {} out of range", timestep));
184        }
185        
186        let alpha_cumprod = self.alphas_cumprod[timestep];
187        let sqrt_alpha = alpha_cumprod.sqrt();
188        let sqrt_one_minus_alpha = (1.0 - alpha_cumprod).sqrt();
189        
190        let x0_data = x0.data_f32();
191        let noise_data = noise.data_f32();
192        
193        let noisy: Vec<f32> = x0_data.iter()
194            .zip(noise_data.iter())
195            .map(|(&x, &n)| sqrt_alpha * x + sqrt_one_minus_alpha * n)
196            .collect();
197        
198        Tensor::from_slice(&noisy, x0.dims())
199            .map_err(|e| format!("Failed to create noisy tensor: {:?}", e))
200    }
201    
202    /// Predict original image from noisy image and predicted noise
203    pub fn predict_x0(&self, xt: &Tensor, noise_pred: &Tensor, timestep: usize) -> Result<Tensor, String> {
204        if timestep >= self.num_timesteps {
205            return Err(format!("Timestep {} out of range", timestep));
206        }
207        
208        let alpha_cumprod = self.alphas_cumprod[timestep];
209        let sqrt_alpha = alpha_cumprod.sqrt();
210        let sqrt_one_minus_alpha = (1.0 - alpha_cumprod).sqrt();
211        
212        let xt_data = xt.data_f32();
213        let noise_data = noise_pred.data_f32();
214        
215        let x0: Vec<f32> = xt_data.iter()
216            .zip(noise_data.iter())
217            .map(|(&x, &n)| (x - sqrt_one_minus_alpha * n) / sqrt_alpha)
218            .collect();
219        
220        Tensor::from_slice(&x0, xt.dims())
221            .map_err(|e| format!("Failed to create x0 tensor: {:?}", e))
222    }
223}
224
225/// Residual block for U-Net
226pub struct ResidualBlock {
227    conv1: Conv2d,
228    conv2: Conv2d,
229    norm1: GroupNorm,
230    norm2: GroupNorm,
231    time_emb_proj: Linear,
232    skip_conv: Option<Conv2d>,
233}
234
235impl ResidualBlock {
236    /// Create new residual block
237    pub fn new(in_channels: usize, out_channels: usize, time_emb_dim: usize) -> Self {
238        let conv1 = Conv2d::new(in_channels, out_channels, 3, 1, 1);
239        let conv2 = Conv2d::new(out_channels, out_channels, 3, 1, 1);
240        let norm1 = GroupNorm::new(32, in_channels);
241        let norm2 = GroupNorm::new(32, out_channels);
242        let time_emb_proj = Linear::new(time_emb_dim, out_channels);
243        
244        let skip_conv = if in_channels != out_channels {
245            Some(Conv2d::new(in_channels, out_channels, 1, 1, 0))
246        } else {
247            None
248        };
249        
250        ResidualBlock {
251            conv1,
252            conv2,
253            norm1,
254            norm2,
255            time_emb_proj,
256            skip_conv,
257        }
258    }
259    
260    /// Forward pass
261    pub fn forward(&self, x: &Tensor, time_emb: &Tensor) -> Tensor {
262        let h = self.norm1.forward(x);
263        let h = h.relu();
264        let h = self.conv1.forward(&h);
265        
266        // Add time embedding
267        let time_proj = self.time_emb_proj.forward(time_emb);
268        // Broadcast time embedding to spatial dimensions (simplified)
269        let h = h.add(&time_proj).unwrap_or(h);
270        
271        let h = self.norm2.forward(&h);
272        let h = h.relu();
273        let h = self.conv2.forward(&h);
274        
275        // Skip connection
276        let skip = if let Some(ref conv) = self.skip_conv {
277            conv.forward(x)
278        } else {
279            x.clone()
280        };
281        
282        h.add(&skip).unwrap_or(h)
283    }
284}
285
286/// U-Net for diffusion models
287pub struct UNet {
288    /// Time embedding dimension
289    time_emb_dim: usize,
290    /// Time embedding MLP
291    time_mlp: Vec<Linear>,
292    /// Encoder blocks
293    encoder_blocks: Vec<ResidualBlock>,
294    /// Decoder blocks
295    decoder_blocks: Vec<ResidualBlock>,
296    /// Output convolution
297    out_conv: Conv2d,
298}
299
300impl UNet {
301    /// Create new U-Net
302    pub fn new(config: &DiffusionConfig) -> Self {
303        let time_emb_dim = config.model_channels * 4;
304        
305        // Time embedding MLP
306        let time_mlp = vec![
307            Linear::new(config.model_channels, time_emb_dim),
308            Linear::new(time_emb_dim, time_emb_dim),
309        ];
310        
311        // Simplified encoder/decoder (real implementation would be more complex)
312        let mut encoder_blocks = Vec::new();
313        let mut in_ch = config.in_channels;
314        
315        for &mult in &config.channel_mult {
316            let out_ch = config.model_channels * mult;
317            for _ in 0..config.num_res_blocks {
318                encoder_blocks.push(ResidualBlock::new(in_ch, out_ch, time_emb_dim));
319                in_ch = out_ch;
320            }
321        }
322        
323        // Decoder (mirror of encoder)
324        let mut decoder_blocks = Vec::new();
325        for &mult in config.channel_mult.iter().rev() {
326            let out_ch = config.model_channels * mult;
327            for _ in 0..config.num_res_blocks {
328                decoder_blocks.push(ResidualBlock::new(in_ch, out_ch, time_emb_dim));
329                in_ch = out_ch;
330            }
331        }
332        
333        let out_conv = Conv2d::new(in_ch, config.in_channels, 3, 1, 1);
334        
335        UNet {
336            time_emb_dim,
337            time_mlp,
338            encoder_blocks,
339            decoder_blocks,
340            out_conv,
341        }
342    }
343    
344    /// Compute time embeddings
345    fn get_time_embedding(&self, timestep: usize) -> Tensor {
346        // Sinusoidal position embedding
347        let half_dim = self.time_emb_dim / 2;
348        let emb: Vec<f32> = (0..half_dim)
349            .flat_map(|i| {
350                let freq = (timestep as f32 * (-10000_f32.ln() * i as f32 / half_dim as f32).exp()).sin();
351                let freq_cos = (timestep as f32 * (-10000_f32.ln() * i as f32 / half_dim as f32).exp()).cos();
352                vec![freq, freq_cos]
353            })
354            .collect();
355        
356        Tensor::from_slice(&emb, &[self.time_emb_dim]).unwrap()
357    }
358    
359    /// Forward pass
360    pub fn forward(&self, x: &Tensor, timestep: usize) -> Tensor {
361        // Get time embedding
362        let mut time_emb = self.get_time_embedding(timestep);
363        
364        // Time MLP
365        for layer in &self.time_mlp {
366            time_emb = layer.forward(&time_emb);
367            time_emb = time_emb.relu();
368        }
369        
370        // Encoder
371        let mut h = x.clone();
372        for block in &self.encoder_blocks {
373            h = block.forward(&h, &time_emb);
374        }
375        
376        // Decoder
377        for block in &self.decoder_blocks {
378            h = block.forward(&h, &time_emb);
379        }
380        
381        // Output
382        self.out_conv.forward(&h)
383    }
384}
385
386/// DDPM (Denoising Diffusion Probabilistic Model)
387pub struct DDPM {
388    /// Configuration
389    config: DiffusionConfig,
390    /// U-Net model
391    unet: UNet,
392    /// Noise scheduler
393    scheduler: NoiseScheduler,
394}
395
396impl DDPM {
397    /// Create new DDPM
398    pub fn new(config: DiffusionConfig) -> Self {
399        let unet = UNet::new(&config);
400        let scheduler = NoiseScheduler::new(config.num_timesteps, config.beta_schedule);
401        
402        DDPM {
403            config,
404            unet,
405            scheduler,
406        }
407    }
408    
409    /// Training forward pass
410    pub fn forward(&self, x0: &Tensor, timestep: usize) -> Result<(Tensor, Tensor), String> {
411        // Sample noise
412        let noise = Tensor::randn(x0.dims());
413        
414        // Add noise to image
415        let xt = self.scheduler.add_noise(x0, &noise, timestep)?;
416        
417        // Predict noise
418        let noise_pred = self.unet.forward(&xt, timestep);
419        
420        Ok((noise_pred, noise))
421    }
422    
423    /// Sample (generate) images
424    pub fn sample(&self, batch_size: usize) -> Result<Tensor, String> {
425        // Start from pure noise
426        let mut xt = Tensor::randn(&[
427            batch_size,
428            self.config.in_channels,
429            self.config.image_size,
430            self.config.image_size,
431        ]);
432        
433        // Reverse diffusion process
434        for t in (0..self.config.num_timesteps).rev() {
435            // Predict noise
436            let noise_pred = self.unet.forward(&xt, t);
437            
438            // Denoise one step
439            xt = self.denoise_step(&xt, &noise_pred, t)?;
440        }
441        
442        Ok(xt)
443    }
444    
445    /// Single denoising step
446    fn denoise_step(&self, xt: &Tensor, noise_pred: &Tensor, timestep: usize) -> Result<Tensor, String> {
447        // Predict x0
448        let x0_pred = self.scheduler.predict_x0(xt, noise_pred, timestep)?;
449        
450        if timestep == 0 {
451            return Ok(x0_pred);
452        }
453        
454        // Add noise for next step (simplified DDPM sampling)
455        let alpha_t = self.scheduler.alphas_cumprod[timestep];
456        let alpha_t_prev = if timestep > 0 {
457            self.scheduler.alphas_cumprod[timestep - 1]
458        } else {
459            1.0
460        };
461        
462        let xt_data = xt.data_f32();
463        let x0_data = x0_pred.data_f32();
464        
465        let xt_prev: Vec<f32> = xt_data.iter()
466            .zip(x0_data.iter())
467            .map(|(&x, &x0)| {
468                let coef1 = alpha_t_prev.sqrt();
469                let coef2 = (1.0 - alpha_t_prev).sqrt();
470                coef1 * x0 + coef2 * ((x - alpha_t.sqrt() * x0) / (1.0 - alpha_t).sqrt())
471            })
472            .collect();
473        
474        Tensor::from_slice(&xt_prev, xt.dims())
475            .map_err(|e| format!("Failed to create denoised tensor: {:?}", e))
476    }
477}
478
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_noise_scheduler() {
        let steps = 100;
        let scheduler = NoiseScheduler::new(steps, BetaSchedule::Linear);

        for table in [&scheduler.betas, &scheduler.alphas, &scheduler.alphas_cumprod] {
            assert_eq!(table.len(), steps);
        }

        // The cumulative product must shrink strictly at every step.
        for (i, pair) in scheduler.alphas_cumprod.windows(2).enumerate() {
            assert!(
                pair[1] < pair[0],
                "alphas_cumprod not decreasing at index {}",
                i + 1
            );
        }
    }

    #[test]
    fn test_add_noise() {
        let shape: [usize; 4] = [2, 3, 32, 32];
        let scheduler = NoiseScheduler::new(100, BetaSchedule::Linear);
        let (x0, noise) = (Tensor::randn(&shape), Tensor::randn(&shape));

        let xt = scheduler.add_noise(&x0, &noise, 50).unwrap();
        assert_eq!(xt.dims(), &shape);
    }

    #[test]
    fn test_ddpm_config() {
        let tiny = DiffusionConfig::tiny();
        assert_eq!((tiny.image_size, tiny.num_timesteps), (32, 100));

        let sd = DiffusionConfig::stable_diffusion();
        assert_eq!((sd.image_size, sd.in_channels), (512, 4));
    }

    #[test]
    fn test_ddpm_forward() {
        let ddpm = DDPM::new(DiffusionConfig::tiny());
        let x0 = Tensor::randn(&[2, 3, 32, 32]);

        let (noise_pred, noise) = ddpm.forward(&x0, 50).unwrap();
        assert_eq!(noise_pred.dims(), noise.dims());
    }
}