// optical_embeddings/vision/clip.rs
1//! CLIP vision encoder used by DeepSeek-OCR.
2//!
3//! This module implements the CLIP (Contrastive Language-Image Pre-training) vision encoder
4//! component of the DeepSeek-OCR pipeline. CLIP provides global attention processing for
5//! semantic understanding of visual features.
6//!
7//! # Architecture
8//!
9//! The CLIP encoder consists of:
10//! - **Vision embeddings**: Patch embedding with class token and position embeddings
11//! - **Pre-LayerNorm**: Applied before transformer blocks
12//! - **Transformer blocks**: Stack of self-attention layers with Quick GELU activation
13//!
14//! # Configuration
15//!
16//! Supports multiple model sizes:
17//! - **CLIP-Base**: 768-dim, 12 layers, 12 heads
18//! - **CLIP-Large**: 1024-dim, 24 layers, 16 heads (default)
19//!
20//! # Example
21//!
22//! ```ignore
23//! use optical_embeddings::vision::ClipEncoder;
24//! use optical_embeddings::config::ClipConfig;
25//!
26//! let config = ClipConfig::large();
27//! let encoder = ClipEncoder::new(&config, &device);
28//! let features = encoder.forward(image_tensor, None);
29//! ```
30
31use burn::nn::{
32 conv::{Conv2d, Conv2dConfig},
33 Embedding, EmbeddingConfig, LayerNorm, LayerNormConfig,
34};
35use burn::prelude::*;
36use burn::tensor::{backend::Backend, Distribution, Tensor};
37use log::{debug, info, trace};
38
39use super::attention::Attention;
40use crate::config::ClipConfig;
41
42/// CLIP vision embeddings layer.
43///
44/// Converts input images to patch embeddings with class token and position encodings.
45/// Implements the embedding strategy used in Vision Transformers (ViT).
46///
47/// # Components
48///
49/// - `class_embedding`: Learnable class token prepended to patch sequence
50/// - `patch_embedding`: Convolutional layer that splits image into patches
51/// - `position_embedding`: Learnable positional encodings for each patch + class token
#[derive(Module, Debug)]
struct ClipVisionEmbeddings<B: Backend> {
    /// Learnable class token of shape `[embed_dim]`, prepended to the patch sequence.
    class_embedding: Tensor<B, 1>,
    /// Non-overlapping patch projection: Conv2d with kernel == stride == patch_size, no bias.
    patch_embedding: Conv2d<B>,
    /// Learnable position table with `num_patches + 1` entries (class token + patches).
    position_embedding: Embedding<B>,
    /// Encoder hidden size; cached here to reshape the class token in `forward`.
    embed_dim: usize,
}
59
60impl<B: Backend> ClipVisionEmbeddings<B> {
61 /// Creates a new CLIP vision embeddings layer.
62 ///
63 /// # Arguments
64 ///
65 /// * `cfg` - CLIP configuration specifying embedding dimensions and patch size
66 /// * `device` - Device to initialize tensors on (CPU/GPU)
67 ///
68 /// # Returns
69 ///
70 /// Initialized embeddings layer ready for forward pass
71 fn new(cfg: &ClipConfig, device: &B::Device) -> Self {
72 info!(
73 "Creating CLIP vision embeddings: hidden_size={}, image_size={}, patch_size={}",
74 cfg.hidden_size, cfg.image_size, cfg.patch_size
75 );
76
77 let embed_dim = cfg.hidden_size;
78 let cls = Tensor::random([embed_dim], Distribution::Normal(0.0, 0.02), device);
79 let pe = Conv2dConfig::new([3, embed_dim], [cfg.patch_size, cfg.patch_size])
80 .with_stride([cfg.patch_size, cfg.patch_size])
81 .with_bias(false)
82 .init(device);
83 let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
84
85 debug!(
86 "CLIP embeddings: num_patches={}, embed_dim={}",
87 num_patches, embed_dim
88 );
89
90 let pos = EmbeddingConfig::new(num_patches + 1, embed_dim).init(device);
91 Self {
92 class_embedding: cls,
93 patch_embedding: pe,
94 position_embedding: pos,
95 embed_dim,
96 }
97 }
98
99 /// Forward pass through embeddings layer.
100 ///
101 /// # Arguments
102 ///
103 /// * `pixels` - Input image tensor of shape `[B, 3, H, W]`
104 /// * `patch_embeds` - Optional pre-computed patch embeddings (currently unused)
105 ///
106 /// # Returns
107 ///
108 /// Embedded sequence of shape `[B, num_patches+1, embed_dim]` where the first
109 /// token is the class token and remaining tokens are patch embeddings with
110 /// position encodings added.
111 fn forward(&self, pixels: Tensor<B, 4>, patch_embeds: Option<Tensor<B, 4>>) -> Tensor<B, 3> {
112 trace!("CLIP embeddings forward: input shape={:?}", pixels.dims());
113
114 let _ = pixels.dims()[0];
115
116 // Get patch embeddings: either provided or compute from pixels
117 let patches = if let Some(_) = patch_embeds {
118 debug!("Using provided patch embeddings");
119 // If SAM features provided, they need to be resized/projected to match embed_dim
120 // For now, we'll compute fresh patch embeddings
121 self.patch_embedding.forward(pixels)
122 } else {
123 trace!("Computing fresh patch embeddings");
124 self.patch_embedding.forward(pixels)
125 };
126
127 // patches: [B, embed_dim, H_patches, W_patches]
128 let shp = patches.dims();
129 let batch = shp[0];
130 let embed_dim = shp[1];
131 let h_patches = shp[2];
132 let w_patches = shp[3];
133 let num_patches = h_patches * w_patches;
134
135 debug!("Patch embedding output: batch={}, embed_dim={}, h_patches={}, w_patches={}, total_patches={}",
136 batch, embed_dim, h_patches, w_patches, num_patches);
137
138 // Reshape to [B, embed_dim, num_patches] then transpose to [B, num_patches, embed_dim]
139 let patches = patches
140 .reshape([batch, embed_dim, num_patches])
141 .swap_dims(1, 2);
142
143 // Prepare class token: [embed_dim] -> [1, 1, embed_dim] -> [B, 1, embed_dim]
144 let cls = self
145 .class_embedding
146 .clone()
147 .reshape([1, 1, self.embed_dim])
148 .repeat(&[batch]);
149
150 // Concatenate class token with patches: [B, num_patches+1, embed_dim]
151 let embs = Tensor::cat(vec![cls, patches], 1);
152 let seq_len = embs.dims()[1];
153
154 debug!(
155 "After concatenation: seq_len={} (1 cls + {} patches)",
156 seq_len, num_patches
157 );
158
159 // Create position IDs: [0, 1, 2, ..., seq_len-1]
160 let pos_ids = Tensor::arange(0..(seq_len as i64), &embs.device())
161 .reshape([1, seq_len])
162 .repeat(&[batch]);
163
164 // Add position embeddings
165 let result = embs + self.position_embedding.forward(pos_ids);
166 trace!("CLIP embeddings output shape: {:?}", result.dims());
167 result
168 }
169}
170
171/// CLIP transformer block implementing pre-LayerNorm architecture.
172///
173/// Each block consists of:
174/// 1. LayerNorm → Multi-head self-attention → Residual connection
175/// 2. LayerNorm → Feed-forward network (with Quick GELU) → Residual connection
176///
177/// # Quick GELU
178///
179/// Uses the approximation: `x * sigmoid(1.702 * x)` instead of standard GELU
180/// for computational efficiency as specified in CLIP.
#[derive(Module, Debug)]
struct ClipTransformerBlock<B: Backend> {
    /// LayerNorm applied before self-attention (pre-LN architecture).
    ln1: LayerNorm<B>,
    /// Multi-head self-attention.
    attn: Attention<B>,
    /// LayerNorm applied before the feed-forward network.
    ln2: LayerNorm<B>,
    /// FFN expansion: hidden_size -> ffn_hidden_size.
    ff1: burn::nn::Linear<B>,
    /// FFN projection back: ffn_hidden_size -> hidden_size.
    ff2: burn::nn::Linear<B>,
}
189
190impl<B: Backend> ClipTransformerBlock<B> {
191 /// Creates a new CLIP transformer block.
192 ///
193 /// # Arguments
194 ///
195 /// * `cfg` - CLIP configuration with layer dimensions and hyperparameters
196 /// * `device` - Device for tensor initialization
197 fn new(cfg: &ClipConfig, device: &B::Device) -> Self {
198 trace!(
199 "Creating CLIP transformer block: hidden_size={}, num_heads={}, ffn_hidden={}",
200 cfg.hidden_size,
201 cfg.num_attention_heads,
202 cfg.ffn_hidden_size
203 );
204
205 let ln1 = LayerNormConfig::new(cfg.hidden_size)
206 .with_epsilon(cfg.layernorm_epsilon as f64)
207 .init(device);
208 let ln2 = LayerNormConfig::new(cfg.hidden_size)
209 .with_epsilon(cfg.layernorm_epsilon as f64)
210 .init(device);
211 let attn = Attention::new(cfg.hidden_size, cfg.num_attention_heads, device);
212 let ff1 = burn::nn::LinearConfig::new(cfg.hidden_size, cfg.ffn_hidden_size)
213 .with_bias(true)
214 .init(device);
215 let ff2 = burn::nn::LinearConfig::new(cfg.ffn_hidden_size, cfg.hidden_size)
216 .with_bias(true)
217 .init(device);
218 Self {
219 ln1,
220 attn,
221 ln2,
222 ff1,
223 ff2,
224 }
225 }
226
227 /// Quick GELU activation function.
228 ///
229 /// Approximation of GELU using: `x * sigmoid(1.702 * x)`
230 ///
231 /// # Arguments
232 ///
233 /// * `x` - Input tensor
234 ///
235 /// # Returns
236 ///
237 /// Activated tensor with same shape as input
238 fn quick_gelu(x: Tensor<B, 3>) -> Tensor<B, 3> {
239 let s = burn::tensor::activation::sigmoid(x.clone() * 1.702);
240 x * s
241 }
242
243 /// Forward pass through transformer block.
244 ///
245 /// Applies pre-LayerNorm architecture:
246 /// 1. x = x + Attention(LayerNorm(x))
247 /// 2. x = x + FFN(LayerNorm(x))
248 ///
249 /// # Arguments
250 ///
251 /// * `x` - Input tensor of shape `[B, seq_len, hidden_size]`
252 ///
253 /// # Returns
254 ///
255 /// Output tensor of same shape as input
256 fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
257 trace!("CLIP transformer block forward: input shape={:?}", x.dims());
258 let h = x.clone() + self.attn.forward(self.ln1.forward(x.clone()));
259 let y = self.ff2.forward(Self::quick_gelu(
260 self.ff1.forward(self.ln2.forward(h.clone())),
261 ));
262 h + y
263 }
264}
265
266/// CLIP vision encoder for DeepSeek-OCR.
267///
268/// Implements the CLIP-Large architecture (or configurable variants) for global
269/// semantic understanding of visual features. Processes patch embeddings through
270/// a stack of transformer blocks with pre-LayerNorm and Quick GELU activation.
271///
272/// # Architecture Details
273///
274/// - **Input**: RGB images `[B, 3, H, W]`
275/// - **Patch embedding**: Splits image into non-overlapping patches
276/// - **Class token**: Learnable token prepended to sequence
277/// - **Position encoding**: Added to all tokens
278/// - **Transformer**: Stack of attention + FFN blocks
279/// - **Output**: Sequence of embeddings `[B, num_patches+1, hidden_size]`
280///
281/// # Default Configuration (CLIP-Large)
282///
283/// - 1024-dimensional embeddings
284/// - 24 transformer layers
285/// - 16 attention heads per layer
286/// - 4096-dimensional feed-forward hidden layer
287/// - 14×14 patch size on 224×224 images (256 patches)
288///
289/// # Example
290///
291/// ```ignore
292/// let config = ClipConfig::large();
293/// let encoder = ClipEncoder::new(&config, &device);
294///
/// let image = Tensor::zeros([1, 3, 224, 224], &device);
/// let features = encoder.forward(image, None);
/// // features shape: [1, 257, 1024] (256 patches + 1 class token)
298/// ```
#[derive(Module, Debug)]
pub struct ClipEncoder<B: Backend> {
    /// Patch embedding + class token + position encodings.
    emb: ClipVisionEmbeddings<B>,
    /// LayerNorm applied once after the embeddings, before the transformer stack.
    pre_ln: LayerNorm<B>,
    /// Stack of `num_layers` pre-LN transformer blocks, applied in order.
    blocks: Vec<ClipTransformerBlock<B>>,
}
305
306impl<B: Backend> ClipEncoder<B> {
307 /// Creates a new CLIP encoder.
308 ///
309 /// Initializes all layers including embeddings, pre-LayerNorm, and
310 /// transformer blocks according to the provided configuration.
311 ///
312 /// # Arguments
313 ///
314 /// * `cfg` - CLIP configuration specifying model architecture
315 /// * `device` - Device for tensor allocation (CPU/GPU)
316 ///
317 /// # Returns
318 ///
319 /// Fully initialized CLIP encoder ready for inference
320 pub fn new(cfg: &ClipConfig, device: &B::Device) -> Self {
321 info!("Creating CLIP encoder with {} layers", cfg.num_layers);
322
323 let emb = ClipVisionEmbeddings::new(cfg, device);
324 let pre_ln = LayerNormConfig::new(cfg.hidden_size)
325 .with_epsilon(cfg.layernorm_epsilon as f64)
326 .init(device);
327 let mut blocks = Vec::new();
328 for i in 0..cfg.num_layers {
329 debug!(
330 "Initializing CLIP transformer block {}/{}",
331 i + 1,
332 cfg.num_layers
333 );
334 blocks.push(ClipTransformerBlock::new(cfg, device));
335 }
336
337 info!("CLIP encoder created successfully");
338 Self {
339 emb,
340 pre_ln,
341 blocks,
342 }
343 }
344
345 /// Forward pass through CLIP encoder.
346 ///
347 /// Processes input images through patch embedding, position encoding,
348 /// and transformer layers to produce semantic visual features.
349 ///
350 /// # Arguments
351 ///
352 /// * `x` - Input image tensor of shape `[B, 3, H, W]`
353 /// * `_patch_embeds` - Optional pre-computed patch embeddings (currently unused,
354 /// reserved for future integration with SAM features)
355 ///
356 /// # Returns
357 ///
358 /// Feature tensor of shape `[B, num_patches+1, hidden_size]` where the first
359 /// token (index 0) is the class token and remaining tokens are patch features.
360 ///
361 /// # Note
362 ///
363 /// The `_patch_embeds` parameter is currently ignored. Full integration with
364 /// SAM features would require a projection layer to match embedding dimensions.
365 pub fn forward(&self, x: Tensor<B, 4>, _patch_embeds: Option<Tensor<B, 4>>) -> Tensor<B, 3> {
366 info!("CLIP encoder forward pass: input shape={:?}", x.dims());
367
368 // Note: We ignore patch_embeds for now as dimension matching is complex
369 // In full implementation, would need projection layer to match dimensions
370 let mut h = self.pre_ln.forward(self.emb.forward(x, None));
371
372 debug!("After embeddings and pre-LN: shape={:?}", h.dims());
373
374 for (i, blk) in self.blocks.iter().enumerate() {
375 trace!(
376 "Processing CLIP transformer block {}/{}",
377 i + 1,
378 self.blocks.len()
379 );
380 h = blk.forward(h);
381 }
382
383 info!("CLIP encoder output shape: {:?}", h.dims());
384 h
385 }
386
387 /// Forward pass excluding class token.
388 ///
389 /// Convenience method that runs the encoder and strips the class token
390 /// from the output sequence, returning only patch feature embeddings.
391 ///
392 /// # Arguments
393 ///
394 /// * `x` - Input image tensor of shape `[B, 3, H, W]`
395 /// * `patch_embeds` - Optional pre-computed patch embeddings
396 ///
397 /// # Returns
398 ///
399 /// Patch features of shape `[B, num_patches, hidden_size]` with the
400 /// class token removed (excludes first token in sequence).
401 pub fn forward_features(
402 &self,
403 x: Tensor<B, 4>,
404 patch_embeds: Option<Tensor<B, 4>>,
405 ) -> Tensor<B, 3> {
406 debug!("CLIP encoder forward_features (excluding class token)");
407 let h = self.forward(x, patch_embeds);
408 let d = h.dims();
409 h.slice([0..d[0], 1..d[1], 0..d[2]])
410 }
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use burn_ndarray::NdArray;
    type TB = NdArray<f32>;

    /// Embeddings alone produce `[B, num_patches + 1, hidden]` for a tiny config.
    #[test]
    fn embedding_shapes() {
        let device = Default::default();

        let mut cfg = ClipConfig::base();
        cfg.hidden_size = 128;
        cfg.image_size = 32;
        cfg.patch_size = 8;

        let embeddings = ClipVisionEmbeddings::<TB>::new(&cfg, &device);
        let input = Tensor::<TB, 4>::zeros([1, 3, 32, 32], &device);
        let output = embeddings.forward(input, None);

        // (32 / 8)^2 = 16 patches, plus one class token -> 17 tokens of width 128.
        assert_eq!(output.dims(), [1, 17, 128]);
    }

    /// Full encoder on a shrunken config preserves the expected sequence shape.
    #[test]
    fn shapes_tiny() {
        let device = Default::default();

        let mut cfg = ClipConfig::base();
        cfg.hidden_size = 128;
        cfg.num_attention_heads = 4;
        cfg.ffn_hidden_size = 512;
        cfg.num_layers = 2;
        cfg.image_size = 32;
        cfg.patch_size = 8;

        let encoder = ClipEncoder::<TB>::new(&cfg, &device);
        let input = Tensor::<TB, 4>::zeros([1, 3, 32, 32], &device);
        let output = encoder.forward(input, None);

        // 16 patches + class token = 17 tokens of width 128.
        assert_eq!(output.dims(), [1, 17, 128]);
    }

    /// Full-size default config; slow, so ignored by default.
    #[test]
    #[ignore]
    fn shapes_full() {
        let device = Default::default();
        let cfg = ClipConfig::default();

        let encoder = ClipEncoder::<TB>::new(&cfg, &device);
        let input = Tensor::<TB, 4>::zeros([1, 3, 224, 224], &device);
        let output = encoder.forward(input, None);

        // 256 patches + 1 class token at 1024 dims for CLIP-Large.
        assert_eq!(output.dims(), [1, 257, 1024]);
    }
}
459}