// oxigaf_diffusion/config.rs
1//! Configuration for the multi-view diffusion pipeline.
2//!
3//! # Classifier-Free Guidance (CFG)
4//!
5//! The pipeline uses CFG to control the strength of IP-Adapter conditioning.
6//! CFG interpolates between conditional and unconditional predictions:
7//!
8//! ```text
9//! prediction = unconditional + guidance_scale * (conditional - unconditional)
10//! ```
11//!
12//! ## How CFG Works in GAF
13//!
14//! 1. **Conditional Pass**: U-Net forward pass WITH IP-Adapter tokens from
15//! the reference image (CLIP embeddings)
16//! 2. **Unconditional Pass**: U-Net forward pass WITHOUT IP-Adapter tokens
17//! (skips reference conditioning)
18//! 3. **Interpolation**: Combine predictions based on `guidance_scale`
19//!
20//! ## Guidance Scale Selection
21//!
22//! - **1.0**: Pure conditional (no guidance, equivalent to single forward pass)
23//! - **3.0-7.5**: Balanced (recommended for GAF, default: 3.0)
24//! - **>10.0**: Strong conditioning (may oversaturate or reduce diversity)
25//!
26//! # IP-Adapter Architecture
27//!
28//! IP-Adapter provides pixel-level identity preservation by conditioning on
29//! CLIP image embeddings. The architecture includes:
30//!
31//! - **CLIP Encoder**: ViT-H/14 encodes reference image to 257×1280 embeddings
32//! - **Projection**: Linear projection from 1280 → 1024 (cross_attention_dim)
33//! - **IP Cross-Attention**: Dedicated `attn_ip` layer in each transformer block
34//! - **Integration**: Each spatial position attends to image tokens
35//!
36//! This differs from text conditioning by providing direct visual features
37//! rather than semantic embeddings.
38
39use crate::upsampler::UpsamplerMode;
40
41/// Full configuration for the multi-view diffusion model.
42///
43/// Contains all hyperparameters for the diffusion pipeline, including U-Net
44/// architecture, attention settings, CFG parameters, and optional upsampling.
45///
46/// # Examples
47///
48/// ```rust
49/// use oxigaf_diffusion::DiffusionConfig;
50///
51/// // Use default configuration (256×256, guidance_scale=3.0)
52/// let config = DiffusionConfig::default();
53///
54/// // Customize guidance scale for stronger conditioning
55/// let mut config = DiffusionConfig::default();
56/// config.guidance_scale = 7.5;
57///
58/// // Enable upsampling for 512×512 output
59/// use oxigaf_diffusion::UpsamplerMode;
60/// config.upsampler_mode = Some(UpsamplerMode::SdX2);
61/// ```
#[derive(Debug, Clone)]
pub struct DiffusionConfig {
    /// Number of views to generate simultaneously (default: 4).
    pub num_views: usize,
    /// Classifier-free guidance scale for IP-Adapter conditioning (default: 3.0).
    ///
    /// Controls the strength of reference image conditioning. Must be >= 1.0.
    /// Higher values increase identity preservation but may reduce diversity.
    ///
    /// - **1.0**: No guidance (pure conditional)
    /// - **3.0-7.5**: Balanced (recommended)
    /// - **>10.0**: Strong conditioning (may oversaturate)
    pub guidance_scale: f64,
    /// Number of DDIM denoising steps (default: 50).
    pub num_inference_steps: usize,
    /// Number of latent upsampler denoising steps (default: 10).
    /// Only used when `upsampler_mode` is `Some(_)`.
    pub upsampler_steps: usize,
    /// Input/output image resolution before upscaling (default: 256).
    pub image_size: usize,
    /// Latent spatial size (image_size / 8).
    ///
    /// Stored independently of `image_size`; callers that change one must keep
    /// the other in sync — this struct does not derive it automatically.
    pub latent_size: usize,
    /// Number of latent channels produced by the VAE (default: 4).
    pub latent_channels: usize,
    /// U-Net input channels: latent_channels + normal-map latent channels (default: 8).
    pub unet_in_channels: usize,
    /// U-Net output channels (default: 4).
    pub unet_out_channels: usize,
    /// Cross-attention dimension (SD 2.1 = 1024).
    pub cross_attention_dim: usize,
    /// CLIP image embedding dimension (ViT-H/14 = 1280).
    pub clip_embed_dim: usize,
    /// Time embedding dimension (default: 1280).
    pub time_embed_dim: usize,
    /// Base channels for the U-Net (default: 320).
    pub base_channels: usize,
    /// Channel multipliers per U-Net stage.
    pub channel_mult: Vec<usize>,
    /// Layers per block in the U-Net.
    pub layers_per_block: usize,
    /// Per-stage attention head configuration (diffusers-style naming).
    ///
    /// NOTE(review): despite the name, the SD 2.1-style defaults `[5, 10, 20, 20]`
    /// look like head *counts* per stage (stage channels / heads = 64-dim heads),
    /// not head dimensions — confirm against the U-Net attention implementation.
    pub attention_head_dim: Vec<usize>,
    /// Number of transformer blocks per attention stage.
    pub transformer_layers_per_block: Vec<usize>,
    /// Group-norm number of groups (default: 32).
    pub norm_num_groups: usize,
    /// Group-norm epsilon.
    pub norm_eps: f64,
    /// Camera pose input dimension (4×3 flattened = 12).
    pub camera_pose_dim: usize,
    /// Whether to use linear projection in spatial transformer.
    pub use_linear_projection: bool,
    /// VAE scaling factor for latent space.
    pub vae_scale_factor: f64,
    /// Whether to use flash attention for memory-efficient O(N) attention.
    /// When enabled, uses block-wise computation with online softmax.
    /// Falls back to standard O(N^2) attention when disabled.
    /// Default: true (when feature is enabled).
    pub use_flash_attention: bool,
    /// Block size for flash attention tiled computation. Larger blocks use more
    /// memory but may be faster due to better cache utilization. Default: 64.
    pub flash_attention_block_size: usize,
    /// Upsampler mode for latent upsampling (32×32 → 64×64).
    /// - None: No upsampling, output is 256×256
    /// - Some(SdX2): Use sd-x2-latent-upscaler, output is 512×512
    /// - Some(BilinearVae): Use bilinear upsampling, output is 512×512
    ///
    /// Default: None (256×256 output).
    pub upsampler_mode: Option<UpsamplerMode>,
}
131
132impl Default for DiffusionConfig {
133 fn default() -> Self {
134 Self {
135 num_views: 4,
136 guidance_scale: 3.0,
137 num_inference_steps: 50,
138 upsampler_steps: 10,
139 image_size: 256,
140 latent_size: 32,
141 latent_channels: 4,
142 unet_in_channels: 8,
143 unet_out_channels: 4,
144 cross_attention_dim: 1024,
145 clip_embed_dim: 1280,
146 time_embed_dim: 1280,
147 base_channels: 320,
148 channel_mult: vec![1, 2, 4, 4],
149 layers_per_block: 2,
150 attention_head_dim: vec![5, 10, 20, 20],
151 transformer_layers_per_block: vec![1, 1, 1, 1],
152 norm_num_groups: 32,
153 norm_eps: 1e-5,
154 camera_pose_dim: 12,
155 use_linear_projection: true,
156 vae_scale_factor: 0.18215,
157 // Flash attention is enabled by default when the feature is available
158 #[cfg(feature = "flash_attention")]
159 use_flash_attention: true,
160 #[cfg(not(feature = "flash_attention"))]
161 use_flash_attention: false,
162 flash_attention_block_size: 64,
163 upsampler_mode: None,
164 }
165 }
166}
167
168impl DiffusionConfig {
169 /// Channel count for a given U-Net stage index.
170 pub fn stage_channels(&self, stage: usize) -> usize {
171 self.base_channels * self.channel_mult[stage]
172 }
173
174 /// Total number of U-Net stages.
175 pub fn num_stages(&self) -> usize {
176 self.channel_mult.len()
177 }
178}