oxigaf_diffusion/clip.rs

//! CLIP image encoder for reference-image conditioning.
//!
//! Implements the ViT-H/14 CLIP image encoder used to extract per-image
//! embeddings that condition the multi-view diffusion model via IP-adapter
//! cross-attention.
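//!
//! A minimal usage sketch; the weight file name and the `vision_model`
//! prefix are assumptions about the checkpoint layout:
//!
//! ```no_run
//! use candle_core::{DType, Device, Tensor};
//! use candle_nn::VarBuilder;
//! use oxigaf_diffusion::clip::{ClipImageEncoder, ClipVisionConfig};
//!
//! fn run() -> candle_core::Result<()> {
//!     let device = Device::Cpu;
//!     // Hypothetical weight file; any safetensors with HF CLIP vision keys works.
//!     let vb = unsafe {
//!         VarBuilder::from_mmaped_safetensors(&["clip_vision.safetensors"], DType::F32, &device)?
//!     };
//!     let encoder =
//!         ClipImageEncoder::new(vb.pp("vision_model"), &ClipVisionConfig::default(), None)?;
//!     // (B, 3, 224, 224) CLIP-normalised pixels -> (B, 257, 1280) patch embeddings.
//!     let pixels = Tensor::zeros((1, 3, 224, 224), DType::F32, &device)?;
//!     let embeds = encoder.forward(&pixels)?;
//!     assert_eq!(embeds.dims(), &[1, 257, 1280]);
//!     Ok(())
//! }
//! ```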

use candle_core::{Result, Tensor, D};
use candle_nn as nn;
use candle_nn::Module;

use crate::config::DiffusionConfig;

// ---------------------------------------------------------------------------
// Vision Transformer components
// ---------------------------------------------------------------------------

/// Multi-head self-attention for the CLIP ViT.
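///
/// Input `(B, S, E)` is split into `num_heads` heads of size `head_dim`,
/// attended with scaled dot-product attention, and merged back to `(B, S, E)`.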
#[derive(Debug)]
struct ClipAttention {
    q_proj: nn::Linear,
    k_proj: nn::Linear,
    v_proj: nn::Linear,
    out_proj: nn::Linear,
    num_heads: usize,
    head_dim: usize,
    scale: f64,
}

impl ClipAttention {
    fn new(vs: nn::VarBuilder, embed_dim: usize, num_heads: usize) -> Result<Self> {
        let head_dim = embed_dim / num_heads;
        let scale = 1.0 / (head_dim as f64).sqrt();
        let q_proj = nn::linear(embed_dim, embed_dim, vs.pp("q_proj"))?;
        let k_proj = nn::linear(embed_dim, embed_dim, vs.pp("k_proj"))?;
        let v_proj = nn::linear(embed_dim, embed_dim, vs.pp("v_proj"))?;
        let out_proj = nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            out_proj,
            num_heads,
            head_dim,
            scale,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (b, seq_len, _) = xs.dims3()?;
        let q = self.q_proj.forward(xs)?;
        let k = self.k_proj.forward(xs)?;
        let v = self.v_proj.forward(xs)?;

        // Candle's matmul rejects the strided layout a transpose produces,
        // so make the (B, heads, seq, head_dim) views contiguous.
        let reshape = |t: Tensor| -> Result<Tensor> {
            t.reshape((b, seq_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?
                .contiguous()
        };

        let q = reshape(q)?;
        let k = reshape(k)?;
        let v = reshape(v)?;

        let attn = (q.matmul(&k.transpose(D::Minus2, D::Minus1)?)? * self.scale)?;
        let attn = nn::ops::softmax_last_dim(&attn)?;
        let out = attn.matmul(&v)?;

        let out = out.transpose(1, 2)?.reshape((b, seq_len, ()))?;
        self.out_proj.forward(&out)
    }
}

/// A single CLIP ViT encoder layer (pre-norm style).
#[derive(Debug)]
struct ClipEncoderLayer {
    layer_norm1: nn::LayerNorm,
    self_attn: ClipAttention,
    layer_norm2: nn::LayerNorm,
    fc1: nn::Linear,
    fc2: nn::Linear,
}

impl ClipEncoderLayer {
    fn new(
        vs: nn::VarBuilder,
        embed_dim: usize,
        num_heads: usize,
        intermediate_size: usize,
    ) -> Result<Self> {
        let layer_norm1 = nn::layer_norm(embed_dim, 1e-5, vs.pp("layer_norm1"))?;
        let self_attn = ClipAttention::new(vs.pp("self_attn"), embed_dim, num_heads)?;
        let layer_norm2 = nn::layer_norm(embed_dim, 1e-5, vs.pp("layer_norm2"))?;
        let fc1 = nn::linear(embed_dim, intermediate_size, vs.pp("mlp.fc1"))?;
        let fc2 = nn::linear(intermediate_size, embed_dim, vs.pp("mlp.fc2"))?;
        Ok(Self {
            layer_norm1,
            self_attn,
            layer_norm2,
            fc1,
            fc2,
        })
    }
}

impl Module for ClipEncoderLayer {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let residual = xs;
        let xs = self.layer_norm1.forward(xs)?;
        let xs = self.self_attn.forward(&xs)?;
        let xs = (xs + residual)?;

        let residual = &xs;
        let h = self.layer_norm2.forward(&xs)?;
        let h = self.fc1.forward(&h)?.gelu()?;
        let h = self.fc2.forward(&h)?;
        h + residual
    }
}

// ---------------------------------------------------------------------------
// CLIP Vision Model
// ---------------------------------------------------------------------------

/// CLIP vision model configuration (ViT-H/14 defaults).
#[derive(Debug, Clone)]
pub struct ClipVisionConfig {
    pub embed_dim: usize,
    pub num_heads: usize,
    pub num_layers: usize,
    pub intermediate_size: usize,
    pub image_size: usize,
    pub patch_size: usize,
}

impl Default for ClipVisionConfig {
    fn default() -> Self {
        Self {
            embed_dim: 1280,
            num_heads: 16,
            num_layers: 32,
            intermediate_size: 5120,
            image_size: 224,
            patch_size: 14,
        }
    }
}

impl ClipVisionConfig {
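    /// Total number of image patches, `(image_size / patch_size)^2`; with the
    /// ViT-H/14 defaults this is `(224 / 14)^2 = 256` (257 sequence positions
    /// once the CLS token is prepended).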
    pub fn num_patches(&self) -> usize {
        (self.image_size / self.patch_size).pow(2)
    }
}

/// CLIP ViT image encoder.
///
/// Produces per-patch embeddings suitable for IP-adapter cross-attention.
#[derive(Debug)]
pub struct ClipImageEncoder {
    patch_embedding: nn::Conv2d,
    position_embedding: nn::Embedding,
    class_embedding: Tensor,
    pre_layernorm: nn::LayerNorm,
    encoder_layers: Vec<ClipEncoderLayer>,
    post_layernorm: nn::LayerNorm,
    /// Optional projection to map to cross-attention dimension.
    ip_projection: Option<nn::Linear>,
    config: ClipVisionConfig,
}

impl ClipImageEncoder {
    /// Build a CLIP image encoder from a VarBuilder.
    pub fn new(
        vs: nn::VarBuilder,
        clip_config: &ClipVisionConfig,
        project_to: Option<usize>,
    ) -> Result<Self> {
        let embed_dim = clip_config.embed_dim;

        // HF CLIP checkpoints store the patch embedding without a bias term,
        // hence the bias-free conv.
        let patch_embedding = nn::conv2d_no_bias(
            3,
            embed_dim,
            clip_config.patch_size,
            nn::Conv2dConfig {
                stride: clip_config.patch_size,
                ..Default::default()
            },
            vs.pp("embeddings.patch_embedding"),
        )?;

        let num_positions = clip_config.num_patches() + 1; // +1 for CLS token
        let position_embedding = nn::embedding(
            num_positions,
            embed_dim,
            vs.pp("embeddings.position_embedding"),
        )?;

        // HF checkpoints store the class embedding as a flat (embed_dim,)
        // vector; `forward` broadcasts it to (B, 1, embed_dim).
        let class_embedding = vs.get(embed_dim, "embeddings.class_embedding")?;

        // "pre_layrnorm" is not a typo here: it mirrors the misspelled key
        // used by HF CLIP vision checkpoints.
        let pre_layernorm = nn::layer_norm(embed_dim, 1e-5, vs.pp("pre_layrnorm"))?;

        let vs_layers = vs.pp("encoder.layers");
        let mut encoder_layers = Vec::with_capacity(clip_config.num_layers);
        for i in 0..clip_config.num_layers {
            encoder_layers.push(ClipEncoderLayer::new(
                vs_layers.pp(i.to_string()),
                embed_dim,
                clip_config.num_heads,
                clip_config.intermediate_size,
            )?);
        }

        let post_layernorm = nn::layer_norm(embed_dim, 1e-5, vs.pp("post_layernorm"))?;

        let ip_projection = project_to
            .map(|target_dim| nn::linear(embed_dim, target_dim, vs.pp("ip_projection")))
            .transpose()?;

        Ok(Self {
            patch_embedding,
            position_embedding,
            class_embedding,
            pre_layernorm,
            encoder_layers,
            post_layernorm,
            ip_projection,
            config: clip_config.clone(),
        })
    }

    /// Encode an image batch into patch-level embeddings.
    ///
    /// - `pixel_values`: `(B, 3, H, W)` normalised image tensor.
    ///
    /// Returns `(B, num_patches + 1, D)`, where `D` is `embed_dim`, or the
    /// target dimension when an IP projection is configured.
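    ///
    /// A shape sketch, assuming the ViT-H/14 defaults and no projection:
    ///
    /// ```no_run
    /// # use candle_core::{DType, Device, Tensor};
    /// # fn demo(encoder: &oxigaf_diffusion::clip::ClipImageEncoder) -> candle_core::Result<()> {
    /// let pixels = Tensor::zeros((2, 3, 224, 224), DType::F32, &Device::Cpu)?;
    /// let out = encoder.forward(&pixels)?;
    /// assert_eq!(out.dims(), &[2, 257, 1280]); // 256 patches + CLS
    /// # Ok(())
    /// # }
    /// ```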
    pub fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
        let batch_size = pixel_values.dim(0)?;
        let device = pixel_values.device();
        let dtype = pixel_values.dtype();

        // Patch embedding: (B, 3, H, W) -> (B, embed_dim, H/P, W/P)
        let patches = self.patch_embedding.forward(pixel_values)?;
        let (_, _, h, w) = patches.dims4()?;
        let num_patches = h * w;

        // Flatten spatial dims and move channels last:
        // (B, embed_dim, H/P, W/P) -> (B, num_patches, embed_dim)
        let patches = patches.flatten(2, 3)?.transpose(1, 2)?;

        // Prepend CLS token
        let cls = self
            .class_embedding
            .broadcast_as((batch_size, 1, self.config.embed_dim))?;
        let embeddings = Tensor::cat(&[cls.to_dtype(dtype)?, patches], 1)?;

        // Add position embeddings, broadcasting over the batch (candle's
        // plain `+` requires identical shapes).
        let position_ids = Tensor::arange(0u32, (num_patches + 1) as u32, device)?;
        let pos_embeds = self.position_embedding.forward(&position_ids)?;
        let embeddings = embeddings.broadcast_add(&pos_embeds.unsqueeze(0)?)?;

        // Pre-layernorm
        let mut hidden = self.pre_layernorm.forward(&embeddings)?;

        // Encoder layers
        for layer in &self.encoder_layers {
            hidden = layer.forward(&hidden)?;
        }

        // Post-layernorm
        hidden = self.post_layernorm.forward(&hidden)?;

        // Optional IP projection
        if let Some(ref proj) = self.ip_projection {
            hidden = proj.forward(&hidden)?;
        }

        Ok(hidden)
    }
}

/// Build a CLIP encoder from a DiffusionConfig with default ViT-H/14 settings.
pub fn build_clip_encoder(
    vs: nn::VarBuilder,
    config: &DiffusionConfig,
) -> Result<ClipImageEncoder> {
    let clip_config = ClipVisionConfig::default();
    ClipImageEncoder::new(vs, &clip_config, Some(config.cross_attention_dim))
}