// oxigaf_diffusion/config.rs
//! Configuration for the multi-view diffusion pipeline.
//!
//! # Classifier-Free Guidance (CFG)
//!
//! The pipeline uses CFG to control the strength of IP-Adapter conditioning.
//! CFG interpolates between conditional and unconditional predictions:
//!
//! ```text
//! prediction = unconditional + guidance_scale * (conditional - unconditional)
//! ```
//!
//! ## How CFG Works in GAF
//!
//! 1. **Conditional Pass**: U-Net forward pass WITH IP-Adapter tokens from
//!    the reference image (CLIP embeddings)
//! 2. **Unconditional Pass**: U-Net forward pass WITHOUT IP-Adapter tokens
//!    (skips reference conditioning)
//! 3. **Interpolation**: Combine predictions based on `guidance_scale`
//!
//! ## Guidance Scale Selection
//!
//! - **1.0**: Pure conditional (no guidance, equivalent to single forward pass)
//! - **3.0-7.5**: Balanced (recommended for GAF, default: 3.0)
//! - **>10.0**: Strong conditioning (may oversaturate or reduce diversity)
//!
//! # IP-Adapter Architecture
//!
//! IP-Adapter provides pixel-level identity preservation by conditioning on
//! CLIP image embeddings. The architecture includes:
//!
//! - **CLIP Encoder**: ViT-H/14 encodes reference image to 257×1280 embeddings
//! - **Projection**: Linear projection from 1280 → 1024 (cross_attention_dim)
//! - **IP Cross-Attention**: Dedicated `attn_ip` layer in each transformer block
//! - **Integration**: Each spatial position attends to image tokens
//!
//! This differs from text conditioning by providing direct visual features
//! rather than semantic embeddings.

use crate::upsampler::UpsamplerMode;

41/// Full configuration for the multi-view diffusion model.
42///
43/// Contains all hyperparameters for the diffusion pipeline, including U-Net
44/// architecture, attention settings, CFG parameters, and optional upsampling.
45///
46/// # Examples
47///
48/// ```rust
49/// use oxigaf_diffusion::DiffusionConfig;
50///
51/// // Use default configuration (256×256, guidance_scale=3.0)
52/// let config = DiffusionConfig::default();
53///
54/// // Customize guidance scale for stronger conditioning
55/// let mut config = DiffusionConfig::default();
56/// config.guidance_scale = 7.5;
57///
58/// // Enable upsampling for 512×512 output
59/// use oxigaf_diffusion::UpsamplerMode;
60/// config.upsampler_mode = Some(UpsamplerMode::SdX2);
61/// ```
62#[derive(Debug, Clone)]
63pub struct DiffusionConfig {
64    /// Number of views to generate simultaneously (default: 4).
65    pub num_views: usize,
66    /// Classifier-free guidance scale for IP-Adapter conditioning (default: 3.0).
67    ///
68    /// Controls the strength of reference image conditioning. Must be >= 1.0.
69    /// Higher values increase identity preservation but may reduce diversity.
70    ///
71    /// - **1.0**: No guidance (pure conditional)
72    /// - **3.0-7.5**: Balanced (recommended)
73    /// - **>10.0**: Strong conditioning (may oversaturate)
74    pub guidance_scale: f64,
75    /// Number of DDIM denoising steps (default: 50).
76    pub num_inference_steps: usize,
77    /// Number of latent upsampler denoising steps (default: 10).
78    pub upsampler_steps: usize,
79    /// Input/output image resolution before upscaling (default: 256).
80    pub image_size: usize,
81    /// Latent spatial size (image_size / 8).
82    pub latent_size: usize,
83    /// Number of latent channels produced by the VAE (default: 4).
84    pub latent_channels: usize,
85    /// U-Net input channels: latent_channels + normal-map latent channels (default: 8).
86    pub unet_in_channels: usize,
87    /// U-Net output channels (default: 4).
88    pub unet_out_channels: usize,
89    /// Cross-attention dimension (SD 2.1 = 1024).
90    pub cross_attention_dim: usize,
91    /// CLIP image embedding dimension (ViT-H/14 = 1280).
92    pub clip_embed_dim: usize,
93    /// Time embedding dimension (default: 1280).
94    pub time_embed_dim: usize,
95    /// Base channels for the U-Net (default: 320).
96    pub base_channels: usize,
97    /// Channel multipliers per U-Net stage.
98    pub channel_mult: Vec<usize>,
99    /// Layers per block in the U-Net.
100    pub layers_per_block: usize,
101    /// Number of attention heads per head-dim for each stage.
102    pub attention_head_dim: Vec<usize>,
103    /// Number of transformer blocks per attention stage.
104    pub transformer_layers_per_block: Vec<usize>,
105    /// Group-norm number of groups (default: 32).
106    pub norm_num_groups: usize,
107    /// Group-norm epsilon.
108    pub norm_eps: f64,
109    /// Camera pose input dimension (4×3 flattened = 12).
110    pub camera_pose_dim: usize,
111    /// Whether to use linear projection in spatial transformer.
112    pub use_linear_projection: bool,
113    /// VAE scaling factor for latent space.
114    pub vae_scale_factor: f64,
115    /// Whether to use flash attention for memory-efficient O(N) attention.
116    /// When enabled, uses block-wise computation with online softmax.
117    /// Falls back to standard O(N^2) attention when disabled.
118    /// Default: true (when feature is enabled).
119    pub use_flash_attention: bool,
120    /// Block size for flash attention tiled computation. Larger blocks use more
121    /// memory but may be faster due to better cache utilization. Default: 64.
122    pub flash_attention_block_size: usize,
123    /// Upsampler mode for latent upsampling (32×32 → 64×64).
124    /// - None: No upsampling, output is 256×256
125    /// - Some(SdX2): Use sd-x2-latent-upscaler, output is 512×512
126    /// - Some(BilinearVae): Use bilinear upsampling, output is 512×512
127    ///
128    /// Default: None (256×256 output).
129    pub upsampler_mode: Option<UpsamplerMode>,
130}
131
132impl Default for DiffusionConfig {
133    fn default() -> Self {
134        Self {
135            num_views: 4,
136            guidance_scale: 3.0,
137            num_inference_steps: 50,
138            upsampler_steps: 10,
139            image_size: 256,
140            latent_size: 32,
141            latent_channels: 4,
142            unet_in_channels: 8,
143            unet_out_channels: 4,
144            cross_attention_dim: 1024,
145            clip_embed_dim: 1280,
146            time_embed_dim: 1280,
147            base_channels: 320,
148            channel_mult: vec![1, 2, 4, 4],
149            layers_per_block: 2,
150            attention_head_dim: vec![5, 10, 20, 20],
151            transformer_layers_per_block: vec![1, 1, 1, 1],
152            norm_num_groups: 32,
153            norm_eps: 1e-5,
154            camera_pose_dim: 12,
155            use_linear_projection: true,
156            vae_scale_factor: 0.18215,
157            // Flash attention is enabled by default when the feature is available
158            #[cfg(feature = "flash_attention")]
159            use_flash_attention: true,
160            #[cfg(not(feature = "flash_attention"))]
161            use_flash_attention: false,
162            flash_attention_block_size: 64,
163            upsampler_mode: None,
164        }
165    }
166}
167
168impl DiffusionConfig {
169    /// Channel count for a given U-Net stage index.
170    pub fn stage_channels(&self, stage: usize) -> usize {
171        self.base_channels * self.channel_mult[stage]
172    }
173
174    /// Total number of U-Net stages.
175    pub fn num_stages(&self) -> usize {
176        self.channel_mult.len()
177    }
178}