optical_embeddings/vision/
clip.rs

1//! CLIP vision encoder used by DeepSeek-OCR.
2//!
3//! This module implements the CLIP (Contrastive Language-Image Pre-training) vision encoder
4//! component of the DeepSeek-OCR pipeline. CLIP provides global attention processing for
5//! semantic understanding of visual features.
6//!
7//! # Architecture
8//!
9//! The CLIP encoder consists of:
10//! - **Vision embeddings**: Patch embedding with class token and position embeddings
11//! - **Pre-LayerNorm**: Applied before transformer blocks
12//! - **Transformer blocks**: Stack of self-attention layers with Quick GELU activation
13//!
14//! # Configuration
15//!
16//! Supports multiple model sizes:
17//! - **CLIP-Base**: 768-dim, 12 layers, 12 heads
18//! - **CLIP-Large**: 1024-dim, 24 layers, 16 heads (default)
19//!
20//! # Example
21//!
22//! ```ignore
23//! use optical_embeddings::vision::ClipEncoder;
24//! use optical_embeddings::config::ClipConfig;
25//!
26//! let config = ClipConfig::large();
27//! let encoder = ClipEncoder::new(&config, &device);
28//! let features = encoder.forward(image_tensor, None);
29//! ```
30
31use burn::nn::{
32    conv::{Conv2d, Conv2dConfig},
33    Embedding, EmbeddingConfig, LayerNorm, LayerNormConfig,
34};
35use burn::prelude::*;
36use burn::tensor::{backend::Backend, Distribution, Tensor};
37use log::{debug, info, trace};
38
39use super::attention::Attention;
40use crate::config::ClipConfig;
41
/// CLIP vision embeddings layer.
///
/// Converts input images to patch embeddings with class token and position encodings.
/// Implements the embedding strategy used in Vision Transformers (ViT).
///
/// # Components
///
/// - `class_embedding`: Learnable class token prepended to patch sequence
/// - `patch_embedding`: Convolutional layer that splits image into patches
/// - `position_embedding`: Learnable positional encodings for each patch + class token
#[derive(Module, Debug)]
struct ClipVisionEmbeddings<B: Backend> {
    /// Learnable class token of shape `[embed_dim]`, expanded per batch in `forward`.
    class_embedding: Tensor<B, 1>,
    /// Conv2d with kernel == stride == patch_size: each output spatial cell is one patch.
    patch_embedding: Conv2d<B>,
    /// Lookup table of `num_patches + 1` position vectors (one extra for the class token).
    position_embedding: Embedding<B>,
    /// Hidden size, cached from the config for reshapes in `forward`.
    embed_dim: usize,
}
59
60impl<B: Backend> ClipVisionEmbeddings<B> {
61    /// Creates a new CLIP vision embeddings layer.
62    ///
63    /// # Arguments
64    ///
65    /// * `cfg` - CLIP configuration specifying embedding dimensions and patch size
66    /// * `device` - Device to initialize tensors on (CPU/GPU)
67    ///
68    /// # Returns
69    ///
70    /// Initialized embeddings layer ready for forward pass
71    fn new(cfg: &ClipConfig, device: &B::Device) -> Self {
72        info!(
73            "Creating CLIP vision embeddings: hidden_size={}, image_size={}, patch_size={}",
74            cfg.hidden_size, cfg.image_size, cfg.patch_size
75        );
76
77        let embed_dim = cfg.hidden_size;
78        let cls = Tensor::random([embed_dim], Distribution::Normal(0.0, 0.02), device);
79        let pe = Conv2dConfig::new([3, embed_dim], [cfg.patch_size, cfg.patch_size])
80            .with_stride([cfg.patch_size, cfg.patch_size])
81            .with_bias(false)
82            .init(device);
83        let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
84
85        debug!(
86            "CLIP embeddings: num_patches={}, embed_dim={}",
87            num_patches, embed_dim
88        );
89
90        let pos = EmbeddingConfig::new(num_patches + 1, embed_dim).init(device);
91        Self {
92            class_embedding: cls,
93            patch_embedding: pe,
94            position_embedding: pos,
95            embed_dim,
96        }
97    }
98
99    /// Forward pass through embeddings layer.
100    ///
101    /// # Arguments
102    ///
103    /// * `pixels` - Input image tensor of shape `[B, 3, H, W]`
104    /// * `patch_embeds` - Optional pre-computed patch embeddings (currently unused)
105    ///
106    /// # Returns
107    ///
108    /// Embedded sequence of shape `[B, num_patches+1, embed_dim]` where the first
109    /// token is the class token and remaining tokens are patch embeddings with
110    /// position encodings added.
111    fn forward(&self, pixels: Tensor<B, 4>, patch_embeds: Option<Tensor<B, 4>>) -> Tensor<B, 3> {
112        trace!("CLIP embeddings forward: input shape={:?}", pixels.dims());
113
114        let _ = pixels.dims()[0];
115
116        // Get patch embeddings: either provided or compute from pixels
117        let patches = if let Some(_) = patch_embeds {
118            debug!("Using provided patch embeddings");
119            // If SAM features provided, they need to be resized/projected to match embed_dim
120            // For now, we'll compute fresh patch embeddings
121            self.patch_embedding.forward(pixels)
122        } else {
123            trace!("Computing fresh patch embeddings");
124            self.patch_embedding.forward(pixels)
125        };
126
127        // patches: [B, embed_dim, H_patches, W_patches]
128        let shp = patches.dims();
129        let batch = shp[0];
130        let embed_dim = shp[1];
131        let h_patches = shp[2];
132        let w_patches = shp[3];
133        let num_patches = h_patches * w_patches;
134
135        debug!("Patch embedding output: batch={}, embed_dim={}, h_patches={}, w_patches={}, total_patches={}", 
136               batch, embed_dim, h_patches, w_patches, num_patches);
137
138        // Reshape to [B, embed_dim, num_patches] then transpose to [B, num_patches, embed_dim]
139        let patches = patches
140            .reshape([batch, embed_dim, num_patches])
141            .swap_dims(1, 2);
142
143        // Prepare class token: [embed_dim] -> [1, 1, embed_dim] -> [B, 1, embed_dim]
144        let cls = self
145            .class_embedding
146            .clone()
147            .reshape([1, 1, self.embed_dim])
148            .repeat(&[batch]);
149
150        // Concatenate class token with patches: [B, num_patches+1, embed_dim]
151        let embs = Tensor::cat(vec![cls, patches], 1);
152        let seq_len = embs.dims()[1];
153
154        debug!(
155            "After concatenation: seq_len={} (1 cls + {} patches)",
156            seq_len, num_patches
157        );
158
159        // Create position IDs: [0, 1, 2, ..., seq_len-1]
160        let pos_ids = Tensor::arange(0..(seq_len as i64), &embs.device())
161            .reshape([1, seq_len])
162            .repeat(&[batch]);
163
164        // Add position embeddings
165        let result = embs + self.position_embedding.forward(pos_ids);
166        trace!("CLIP embeddings output shape: {:?}", result.dims());
167        result
168    }
169}
170
/// CLIP transformer block implementing pre-LayerNorm architecture.
///
/// Each block consists of:
/// 1. LayerNorm → Multi-head self-attention → Residual connection
/// 2. LayerNorm → Feed-forward network (with Quick GELU) → Residual connection
///
/// # Quick GELU
///
/// Uses the approximation: `x * sigmoid(1.702 * x)` instead of standard GELU
/// for computational efficiency as specified in CLIP.
#[derive(Module, Debug)]
struct ClipTransformerBlock<B: Backend> {
    /// LayerNorm applied before self-attention.
    ln1: LayerNorm<B>,
    /// Multi-head self-attention sublayer.
    attn: Attention<B>,
    /// LayerNorm applied before the feed-forward network.
    ln2: LayerNorm<B>,
    /// FFN up-projection: hidden_size → ffn_hidden_size.
    ff1: burn::nn::Linear<B>,
    /// FFN down-projection: ffn_hidden_size → hidden_size.
    ff2: burn::nn::Linear<B>,
}
189
190impl<B: Backend> ClipTransformerBlock<B> {
191    /// Creates a new CLIP transformer block.
192    ///
193    /// # Arguments
194    ///
195    /// * `cfg` - CLIP configuration with layer dimensions and hyperparameters
196    /// * `device` - Device for tensor initialization
197    fn new(cfg: &ClipConfig, device: &B::Device) -> Self {
198        trace!(
199            "Creating CLIP transformer block: hidden_size={}, num_heads={}, ffn_hidden={}",
200            cfg.hidden_size,
201            cfg.num_attention_heads,
202            cfg.ffn_hidden_size
203        );
204
205        let ln1 = LayerNormConfig::new(cfg.hidden_size)
206            .with_epsilon(cfg.layernorm_epsilon as f64)
207            .init(device);
208        let ln2 = LayerNormConfig::new(cfg.hidden_size)
209            .with_epsilon(cfg.layernorm_epsilon as f64)
210            .init(device);
211        let attn = Attention::new(cfg.hidden_size, cfg.num_attention_heads, device);
212        let ff1 = burn::nn::LinearConfig::new(cfg.hidden_size, cfg.ffn_hidden_size)
213            .with_bias(true)
214            .init(device);
215        let ff2 = burn::nn::LinearConfig::new(cfg.ffn_hidden_size, cfg.hidden_size)
216            .with_bias(true)
217            .init(device);
218        Self {
219            ln1,
220            attn,
221            ln2,
222            ff1,
223            ff2,
224        }
225    }
226
227    /// Quick GELU activation function.
228    ///
229    /// Approximation of GELU using: `x * sigmoid(1.702 * x)`
230    ///
231    /// # Arguments
232    ///
233    /// * `x` - Input tensor
234    ///
235    /// # Returns
236    ///
237    /// Activated tensor with same shape as input
238    fn quick_gelu(x: Tensor<B, 3>) -> Tensor<B, 3> {
239        let s = burn::tensor::activation::sigmoid(x.clone() * 1.702);
240        x * s
241    }
242
243    /// Forward pass through transformer block.
244    ///
245    /// Applies pre-LayerNorm architecture:
246    /// 1. x = x + Attention(LayerNorm(x))
247    /// 2. x = x + FFN(LayerNorm(x))
248    ///
249    /// # Arguments
250    ///
251    /// * `x` - Input tensor of shape `[B, seq_len, hidden_size]`
252    ///
253    /// # Returns
254    ///
255    /// Output tensor of same shape as input
256    fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
257        trace!("CLIP transformer block forward: input shape={:?}", x.dims());
258        let h = x.clone() + self.attn.forward(self.ln1.forward(x.clone()));
259        let y = self.ff2.forward(Self::quick_gelu(
260            self.ff1.forward(self.ln2.forward(h.clone())),
261        ));
262        h + y
263    }
264}
265
/// CLIP vision encoder for DeepSeek-OCR.
///
/// Implements the CLIP-Large architecture (or configurable variants) for global
/// semantic understanding of visual features. Processes patch embeddings through
/// a stack of transformer blocks with pre-LayerNorm and Quick GELU activation.
///
/// # Architecture Details
///
/// - **Input**: RGB images `[B, 3, H, W]`
/// - **Patch embedding**: Splits image into non-overlapping patches
/// - **Class token**: Learnable token prepended to sequence
/// - **Position encoding**: Added to all tokens
/// - **Transformer**: Stack of attention + FFN blocks
/// - **Output**: Sequence of embeddings `[B, num_patches+1, hidden_size]`
///
/// # Default Configuration (CLIP-Large)
///
/// - 1024-dimensional embeddings
/// - 24 transformer layers
/// - 16 attention heads per layer
/// - 4096-dimensional feed-forward hidden layer
/// - 14×14 patch size on 224×224 images (256 patches)
///
/// # Example
///
/// ```ignore
/// let config = ClipConfig::large();
/// let encoder = ClipEncoder::new(&config, &device);
///
/// let image = Tensor::zeros([1, 3, 224, 224], &device);
/// let features = encoder.forward(image, None);
/// // features shape: [1, 257, 1024] (256 patches + 1 class token)
/// ```
#[derive(Module, Debug)]
pub struct ClipEncoder<B: Backend> {
    /// Patch + class-token + position embeddings.
    emb: ClipVisionEmbeddings<B>,
    /// LayerNorm applied to embeddings before the transformer stack.
    pre_ln: LayerNorm<B>,
    /// Transformer blocks, applied in order (`cfg.num_layers` of them).
    blocks: Vec<ClipTransformerBlock<B>>,
}
305
306impl<B: Backend> ClipEncoder<B> {
307    /// Creates a new CLIP encoder.
308    ///
309    /// Initializes all layers including embeddings, pre-LayerNorm, and
310    /// transformer blocks according to the provided configuration.
311    ///
312    /// # Arguments
313    ///
314    /// * `cfg` - CLIP configuration specifying model architecture
315    /// * `device` - Device for tensor allocation (CPU/GPU)
316    ///
317    /// # Returns
318    ///
319    /// Fully initialized CLIP encoder ready for inference
320    pub fn new(cfg: &ClipConfig, device: &B::Device) -> Self {
321        info!("Creating CLIP encoder with {} layers", cfg.num_layers);
322
323        let emb = ClipVisionEmbeddings::new(cfg, device);
324        let pre_ln = LayerNormConfig::new(cfg.hidden_size)
325            .with_epsilon(cfg.layernorm_epsilon as f64)
326            .init(device);
327        let mut blocks = Vec::new();
328        for i in 0..cfg.num_layers {
329            debug!(
330                "Initializing CLIP transformer block {}/{}",
331                i + 1,
332                cfg.num_layers
333            );
334            blocks.push(ClipTransformerBlock::new(cfg, device));
335        }
336
337        info!("CLIP encoder created successfully");
338        Self {
339            emb,
340            pre_ln,
341            blocks,
342        }
343    }
344
345    /// Forward pass through CLIP encoder.
346    ///
347    /// Processes input images through patch embedding, position encoding,
348    /// and transformer layers to produce semantic visual features.
349    ///
350    /// # Arguments
351    ///
352    /// * `x` - Input image tensor of shape `[B, 3, H, W]`
353    /// * `_patch_embeds` - Optional pre-computed patch embeddings (currently unused,
354    ///   reserved for future integration with SAM features)
355    ///
356    /// # Returns
357    ///
358    /// Feature tensor of shape `[B, num_patches+1, hidden_size]` where the first
359    /// token (index 0) is the class token and remaining tokens are patch features.
360    ///
361    /// # Note
362    ///
363    /// The `_patch_embeds` parameter is currently ignored. Full integration with
364    /// SAM features would require a projection layer to match embedding dimensions.
365    pub fn forward(&self, x: Tensor<B, 4>, _patch_embeds: Option<Tensor<B, 4>>) -> Tensor<B, 3> {
366        info!("CLIP encoder forward pass: input shape={:?}", x.dims());
367
368        // Note: We ignore patch_embeds for now as dimension matching is complex
369        // In full implementation, would need projection layer to match dimensions
370        let mut h = self.pre_ln.forward(self.emb.forward(x, None));
371
372        debug!("After embeddings and pre-LN: shape={:?}", h.dims());
373
374        for (i, blk) in self.blocks.iter().enumerate() {
375            trace!(
376                "Processing CLIP transformer block {}/{}",
377                i + 1,
378                self.blocks.len()
379            );
380            h = blk.forward(h);
381        }
382
383        info!("CLIP encoder output shape: {:?}", h.dims());
384        h
385    }
386
387    /// Forward pass excluding class token.
388    ///
389    /// Convenience method that runs the encoder and strips the class token
390    /// from the output sequence, returning only patch feature embeddings.
391    ///
392    /// # Arguments
393    ///
394    /// * `x` - Input image tensor of shape `[B, 3, H, W]`
395    /// * `patch_embeds` - Optional pre-computed patch embeddings
396    ///
397    /// # Returns
398    ///
399    /// Patch features of shape `[B, num_patches, hidden_size]` with the
400    /// class token removed (excludes first token in sequence).
401    pub fn forward_features(
402        &self,
403        x: Tensor<B, 4>,
404        patch_embeds: Option<Tensor<B, 4>>,
405    ) -> Tensor<B, 3> {
406        debug!("CLIP encoder forward_features (excluding class token)");
407        let h = self.forward(x, patch_embeds);
408        let d = h.dims();
409        h.slice([0..d[0], 1..d[1], 0..d[2]])
410    }
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use burn_ndarray::NdArray;
    type TB = NdArray<f32>;

    /// Embedding layer output must be [batch, num_patches + 1, hidden_size].
    #[test]
    fn embedding_shapes() {
        let device = Default::default();
        let mut config = ClipConfig::base();
        config.image_size = 32;
        config.patch_size = 8;
        config.hidden_size = 128;

        let embeddings = ClipVisionEmbeddings::<TB>::new(&config, &device);
        let input = Tensor::<TB, 4>::zeros([1, 3, 32, 32], &device);
        let output = embeddings.forward(input, None);

        // 32/8 = 4 patches per side → 16 patches, plus one class token = 17.
        assert_eq!(output.dims(), [1, 17, 128]);
    }

    /// A shrunk-down encoder preserves the expected token/feature shape.
    #[test]
    fn shapes_tiny() {
        let device = Default::default();
        let mut config = ClipConfig::base();
        config.num_layers = 2;
        config.hidden_size = 128;
        config.num_attention_heads = 4;
        config.ffn_hidden_size = 512;
        config.image_size = 32;
        config.patch_size = 8;

        let encoder = ClipEncoder::<TB>::new(&config, &device);
        let input = Tensor::<TB, 4>::zeros([1, 3, 32, 32], &device);
        assert_eq!(encoder.forward(input, None).dims(), [1, 17, 128]);
    }

    /// Full-size (CLIP-Large) shape check; slow, so ignored by default.
    #[test]
    #[ignore]
    fn shapes_full() {
        let device = Default::default();
        let config = ClipConfig::default();
        let encoder = ClipEncoder::<TB>::new(&config, &device);
        let input = Tensor::<TB, 4>::zeros([1, 3, 224, 224], &device);
        // 224/14 = 16 per side → 256 patches + 1 class token = 257 tokens.
        assert_eq!(encoder.forward(input, None).dims(), [1, 257, 1024]);
    }
}