Skip to main content

oxicuda_vision/clip/
vision_encoder.rs

1//! CLIP vision encoder.
2//!
3//! Wraps a ViT encoder with CLIP-specific construction conveniences.
4//! A single CLS token is prepended, positional embeddings are added, and
5//! the encoder output at the CLS position is returned as the image embedding.
6
7use crate::{
8    error::{VisionError, VisionResult},
9    handle::LcgRng,
10    patch_embed::{LearnablePosEmbed, PatchEmbed, PatchEmbedConfig, add_pos_embed, prepend_cls},
11    vit::{ViTConfig, ViTEncoder, ViTEncoderConfig},
12};
13
14// ─── ClipVisionConfig ────────────────────────────────────────────────────────
15
16/// Configuration for the CLIP vision encoder.
17///
18/// Wraps a [`ViTConfig`] so that CLIP can share the same architectural
19/// hyper-parameter vocabulary as a standalone ViT model.
20#[derive(Debug, Clone)]
21pub struct ClipVisionConfig {
22    /// Underlying ViT hyper-parameters (image size, patch size, depth, …).
23    pub vit_config: ViTConfig,
24}
25
26impl ClipVisionConfig {
27    /// Create a [`ClipVisionConfig`] from an existing [`ViTConfig`].
28    #[must_use]
29    pub fn new(vit_config: ViTConfig) -> Self {
30        Self { vit_config }
31    }
32
33    /// Convenience constructor for a tiny CLIP encoder suitable for tests.
34    ///
35    /// Delegates to [`ViTConfig::tiny`].
36    #[must_use]
37    pub fn tiny() -> Self {
38        Self::new(ViTConfig::tiny())
39    }
40}
41
42// ─── ClipVisionEncoder ───────────────────────────────────────────────────────
43
44/// CLIP vision encoder: ViT-backbone that produces a single `embed_dim`
45/// CLS-token embedding per image.
46///
47/// Pipeline:
48/// ```text
49/// image [C × H × W]
50///   → patch_embed    → [n_patches, embed_dim]
51///   → prepend_cls    → [n_patches + 1, embed_dim]
52///   → add_pos_embed  → [n_patches + 1, embed_dim]
53///   → encoder        → [n_patches + 1, embed_dim]
54///   → tokens[0]      → [embed_dim]   (CLS token output)
55/// ```
56pub struct ClipVisionEncoder {
57    /// Full configuration.
58    pub config: ClipVisionConfig,
59    /// Strided Conv2D patch embedder.
60    pub patch_embed: PatchEmbed,
61    /// Learnable positional embeddings: `n_patches + 1` positions (incl. CLS).
62    pub pos_embed: LearnablePosEmbed,
63    /// Stack of ViT transformer blocks with final layer-norm.
64    pub encoder: ViTEncoder,
65    /// CLS token: flat `[embed_dim]`, Gaussian-initialised with scale 0.02.
66    pub cls_token: Vec<f32>,
67}
68
69impl ClipVisionEncoder {
70    /// Construct a new CLIP vision encoder.
71    ///
72    /// Initialises:
73    /// - Patch embedder (Conv2D kernel, bias).
74    /// - Learnable positional embedding table with `n_patches + 1` rows.
75    /// - ViT encoder stack.
76    /// - CLS token vector (N(0, 0.02²)).
77    ///
78    /// # Errors
79    /// Propagates any errors from the sub-component constructors.
80    pub fn new(cfg: ClipVisionConfig, rng: &mut LcgRng) -> VisionResult<Self> {
81        let vc = &cfg.vit_config;
82
83        // ── Patch embedder ────────────────────────────────────────────────────
84        let pe_cfg = PatchEmbedConfig::new(vc.img_size, vc.patch_size, vc.in_chans, vc.embed_dim)?;
85        let patch_embed = PatchEmbed::new(pe_cfg.clone(), rng);
86
87        // ── Positional embeddings: n_patches + 1 positions (CLS slot at 0) ──
88        let n_patches = pe_cfg.n_patches();
89        let n_positions = n_patches + 1;
90        let pos_embed = LearnablePosEmbed::new(n_positions, vc.embed_dim, rng)?;
91
92        // ── Encoder stack ─────────────────────────────────────────────────────
93        let enc_cfg = ViTEncoderConfig::new(vc.embed_dim, vc.n_heads, vc.mlp_ratio, vc.depth)?;
94        let encoder = ViTEncoder::new(enc_cfg, rng)?;
95
96        // ── CLS token: N(0, 0.02²) ───────────────────────────────────────────
97        let mut cls_token = vec![0.0f32; vc.embed_dim];
98        rng.fill_normal(&mut cls_token);
99        for v in &mut cls_token {
100            *v *= 0.02;
101        }
102
103        Ok(Self {
104            config: cfg,
105            patch_embed,
106            pos_embed,
107            encoder,
108            cls_token,
109        })
110    }
111
112    /// Run the encoder on a single image and return the CLS embedding.
113    ///
114    /// # Parameters
115    /// - `image`: flat `[in_chans × img_size × img_size]` CHW buffer.
116    ///
117    /// # Returns
118    /// `[embed_dim]` CLS-token embedding.
119    ///
120    /// # Errors
121    /// Returns [`VisionError::DimensionMismatch`] if the image size does not
122    /// match the configured dimensions.
123    pub fn forward_single(&self, image: &[f32]) -> VisionResult<Vec<f32>> {
124        let embed_dim = self.config.vit_config.embed_dim;
125
126        // 1. Patch embedding → [n_patches, embed_dim]
127        let patch_tokens = self.patch_embed.forward(image)?;
128
129        // 2. Prepend CLS token → [n_patches + 1, embed_dim]
130        let mut tokens = prepend_cls(&patch_tokens, &self.cls_token, embed_dim)?;
131
132        // 3. Add positional embeddings in-place.
133        add_pos_embed(&mut tokens, &self.pos_embed.table, embed_dim)?;
134
135        // 4. ViT encoder → [n_patches + 1, embed_dim]
136        let n_tokens = tokens.len() / embed_dim;
137        let encoded = self.encoder.forward(&tokens, n_tokens)?;
138
139        // 5. Extract CLS token (first row).
140        let cls_out = encoded[..embed_dim].to_vec();
141
142        Ok(cls_out)
143    }
144
145    /// Run the encoder on a batch of images.
146    ///
147    /// # Parameters
148    /// - `images`: flat `[batch × in_chans × img_size × img_size]` buffer.
149    /// - `batch_size`: number of images.
150    ///
151    /// # Returns
152    /// `Vec<Vec<f32>>` of length `batch_size`, each element is `[embed_dim]`.
153    ///
154    /// # Errors
155    /// Returns [`VisionError::DimensionMismatch`] if the flat buffer length
156    /// does not match `batch_size × in_chans × img_size × img_size`, or if
157    /// any individual forward pass fails.
158    pub fn forward_batch(&self, images: &[f32], batch_size: usize) -> VisionResult<Vec<Vec<f32>>> {
159        let vc = &self.config.vit_config;
160        let single_len = vc.in_chans * vc.img_size * vc.img_size;
161
162        if batch_size == 0 {
163            return Ok(Vec::new());
164        }
165
166        let expected = batch_size * single_len;
167        if images.len() != expected {
168            return Err(VisionError::DimensionMismatch {
169                expected,
170                got: images.len(),
171            });
172        }
173
174        let mut results = Vec::with_capacity(batch_size);
175        for b in 0..batch_size {
176            let slice = &images[b * single_len..(b + 1) * single_len];
177            results.push(self.forward_single(slice)?);
178        }
179
180        Ok(results)
181    }
182}
183
184// ─── Tests ───────────────────────────────────────────────────────────────────
185
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::LcgRng;

    /// Build a tiny encoder from `seed`; also returns its `embed_dim`.
    fn make_tiny_encoder(seed: u64) -> (ClipVisionEncoder, usize) {
        let mut rng = LcgRng::new(seed);
        let cfg = ClipVisionConfig::tiny();
        let embed_dim = cfg.vit_config.embed_dim;
        let encoder = ClipVisionEncoder::new(cfg, &mut rng).expect("tiny encoder ok");
        (encoder, embed_dim)
    }

    /// Fill an image buffer with ramp values in `[0, 1)` (deterministic, finite).
    fn make_image(in_chans: usize, img_size: usize) -> Vec<f32> {
        let len = in_chans * img_size * img_size;
        (0..len).map(|i| i as f32 / len as f32).collect()
    }

    /// Number of `f32` values in one CHW image for `enc`'s configuration.
    fn single_image_len(enc: &ClipVisionEncoder) -> usize {
        let vc = &enc.config.vit_config;
        vc.in_chans * vc.img_size * vc.img_size
    }

    // ── Construction ─────────────────────────────────────────────────────────

    #[test]
    fn tiny_encoder_constructs() {
        let (enc, _) = make_tiny_encoder(1);
        let vc = &enc.config.vit_config;
        // CLS token has the right dimension.
        assert_eq!(enc.cls_token.len(), vc.embed_dim);
        // Positional embed has n_patches + 1 positions.
        let n_patches = (vc.img_size / vc.patch_size).pow(2);
        assert_eq!(enc.pos_embed.n_positions, n_patches + 1);
    }

    #[test]
    fn config_new_wraps_vit_config() {
        let vit_cfg = ViTConfig::tiny();
        let clip_cfg = ClipVisionConfig::new(vit_cfg.clone());
        assert_eq!(clip_cfg.vit_config.embed_dim, vit_cfg.embed_dim);
    }

    // ── forward_single ───────────────────────────────────────────────────────

    #[test]
    fn forward_single_output_shape() {
        let (enc, embed_dim) = make_tiny_encoder(2);
        let vc = &enc.config.vit_config;
        let img = make_image(vc.in_chans, vc.img_size);
        let z = enc.forward_single(&img).expect("forward_single ok");
        assert_eq!(
            z.len(),
            embed_dim,
            "forward_single output should be embed_dim"
        );
    }

    #[test]
    fn forward_single_output_finite() {
        let (enc, _) = make_tiny_encoder(3);
        let vc = &enc.config.vit_config;
        let img = make_image(vc.in_chans, vc.img_size);
        let z = enc.forward_single(&img).expect("ok");
        assert!(
            z.iter().all(|v| v.is_finite()),
            "forward_single output must be finite"
        );
    }

    #[test]
    fn forward_single_error_wrong_image_size() {
        let (enc, _) = make_tiny_encoder(4);
        let wrong_img = vec![0.0f32; 10]; // definitely wrong
        let r = enc.forward_single(&wrong_img);
        assert!(
            matches!(r, Err(VisionError::DimensionMismatch { .. })),
            "expected DimensionMismatch, got {r:?}"
        );
    }

    #[test]
    fn forward_single_deterministic() {
        // Same encoder + same image → same output.
        let (enc, _) = make_tiny_encoder(5);
        let vc = &enc.config.vit_config;
        let img = make_image(vc.in_chans, vc.img_size);
        let z1 = enc.forward_single(&img).expect("ok");
        let z2 = enc.forward_single(&img).expect("ok");
        assert_eq!(z1, z2, "forward_single should be deterministic");
    }

    // ── forward_batch ────────────────────────────────────────────────────────

    #[test]
    fn forward_batch_output_count() {
        let (enc, _) = make_tiny_encoder(6);
        let vc = &enc.config.vit_config;
        let batch_size = 3_usize;
        // Treating the batch as extra channels yields exactly
        // `batch_size × single_len` ramp values, so no padding is needed.
        let flat = make_image(vc.in_chans * batch_size, vc.img_size);
        let results = enc
            .forward_batch(&flat, batch_size)
            .expect("forward_batch ok");
        assert_eq!(results.len(), batch_size, "batch result count mismatch");
    }

    #[test]
    fn forward_batch_each_embedding_has_embed_dim() {
        let (enc, embed_dim) = make_tiny_encoder(7);
        let batch_size = 4_usize;
        let flat = vec![0.5f32; batch_size * single_image_len(&enc)];
        let results = enc.forward_batch(&flat, batch_size).expect("ok");
        for (i, z) in results.iter().enumerate() {
            assert_eq!(z.len(), embed_dim, "embedding {i} has wrong size");
        }
    }

    #[test]
    fn forward_batch_zero_batch_returns_empty() {
        let (enc, _) = make_tiny_encoder(8);
        let results = enc.forward_batch(&[], 0).expect("zero batch ok");
        assert!(results.is_empty(), "zero batch should return empty Vec");
    }

    #[test]
    fn forward_batch_error_wrong_total_length() {
        let (enc, _) = make_tiny_encoder(9);
        // One pixel too few for batch_size=2.
        let flat = vec![0.0f32; 2 * single_image_len(&enc) - 1];
        let r = enc.forward_batch(&flat, 2);
        assert!(
            matches!(r, Err(VisionError::DimensionMismatch { .. })),
            "expected DimensionMismatch, got {r:?}"
        );
    }

    #[test]
    fn forward_batch_matches_individual() {
        // batch forward should equal individual forward calls.
        let (enc, embed_dim) = make_tiny_encoder(10);
        let single_len = single_image_len(&enc);
        let batch_size = 2_usize;
        let total = batch_size * single_len;
        let flat: Vec<f32> = (0..total).map(|i| i as f32 / total as f32).collect();

        let batch_results = enc.forward_batch(&flat, batch_size).expect("batch ok");
        assert_eq!(batch_results.len(), batch_size, "batch result count mismatch");

        for (b, batched) in batch_results.iter().enumerate() {
            let single = enc
                .forward_single(&flat[b * single_len..(b + 1) * single_len])
                .expect("single ok");
            // Check lengths explicitly so the element-wise zip below cannot
            // silently skip missing entries.
            assert_eq!(batched.len(), embed_dim, "batch[{b}] has wrong size");
            assert_eq!(single.len(), embed_dim, "single[{b}] has wrong size");
            for (d, (from_batch, from_single)) in batched.iter().zip(&single).enumerate() {
                assert!(
                    (from_batch - from_single).abs() < 1e-6,
                    "batch[{b}][{d}] = {from_batch} ≠ single[{d}] = {from_single}"
                );
            }
        }
    }
}