Skip to main content

oxicuda_vision/vit/
vit_patch.rs

1//! Vision Transformer (ViT) patch embedding — Dosovitskiy et al. 2021.
2//!
3//! "An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale"
4//! (Dosovitskiy et al., ICLR 2021) splits an image into a regular grid of
5//! non-overlapping square patches, flattens each patch into a vector, and
6//! projects it linearly into a `d_model`-dimensional embedding. A learnable
7//! `[CLS]` token is prepended, and a learnable positional embedding is added to
8//! the resulting `n_patches + 1` token sequence.
9//!
10//! ## Pipeline
11//! ```text
12//! image [C × H × W]
13//!   ──split──▶ patches [(H/P)·(W/P) × (C·P·P)]
14//!   ──proj───▶ tokens  [n_patches × d_model]   (linear: W·patch + b)
15//!   ──cls────▶        [(n_patches+1) × d_model] (prepend [CLS] token)
16//!   ──pos────▶        [(n_patches+1) × d_model] (+ learnable position embed)
17//! ```
18//!
19//! Unlike [`crate::patch_embed::PatchEmbed`] — which exposes the strided-conv
20//! view and matches the `patch_embed_ptx` kernel layout — this module follows
21//! the original ViT formulation literally: an explicit per-patch flatten
22//! (`C·P·P`) followed by a dense projection, the `[CLS]` prepend, and the
23//! positional embedding fused into a single [`VitPatchEmbed::forward`] call.
24//!
25//! All tensors use flat row-major `Vec<f32>` layouts and the forward pass runs
26//! on the CPU.
27
28use crate::{
29    blocks::VisionRng,
30    error::{VisionError, VisionResult},
31};
32
33// ─── Config ──────────────────────────────────────────────────────────────────
34
/// Hyper-parameters describing a [`VitPatchEmbed`] layer.
///
/// An input image has shape `[n_channels × image_size × image_size]` and is
/// cut into a `grid × grid` array of non-overlapping square patches, where
/// `grid = image_size / patch_size`. Each patch is flattened into
/// `n_channels · patch_size²` values before being projected to `d_model`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct VitPatchConfig {
    /// Side length of the (square) input image, in pixels.
    pub image_size: usize,
    /// Side length of each (square) patch; patches are tiled with stride
    /// equal to this value, so they never overlap.
    pub patch_size: usize,
    /// Channel count of the input image (3 for RGB).
    pub n_channels: usize,
    /// Width of each output token embedding.
    pub d_model: usize,
}
51
52impl VitPatchConfig {
53    /// Validate the configuration.
54    ///
55    /// # Errors
56    /// - [`VisionError::InvalidPatchSize`] if `patch_size == 0` or `image_size`
57    ///   is not an exact multiple of `patch_size`.
58    /// - [`VisionError::InvalidEmbedDim`] if `d_model == 0`.
59    /// - [`VisionError::InvalidImageSize`] if `image_size == 0` or
60    ///   `n_channels == 0`.
61    pub fn validate(&self) -> VisionResult<()> {
62        if self.patch_size == 0 || self.image_size % self.patch_size != 0 {
63            return Err(VisionError::InvalidPatchSize {
64                patch_size: self.patch_size,
65                img_size: self.image_size,
66            });
67        }
68        if self.d_model == 0 {
69            return Err(VisionError::InvalidEmbedDim(self.d_model));
70        }
71        if self.image_size == 0 || self.n_channels == 0 {
72            return Err(VisionError::InvalidImageSize {
73                height: self.image_size,
74                width: self.image_size,
75                channels: self.n_channels,
76            });
77        }
78        Ok(())
79    }
80
81    /// Patches along one spatial axis: `image_size / patch_size`.
82    #[must_use]
83    #[inline]
84    pub fn grid_size(&self) -> usize {
85        self.image_size / self.patch_size
86    }
87
88    /// Flattened patch length: `patch_size² · n_channels`.
89    #[must_use]
90    #[inline]
91    pub fn patch_dim(&self) -> usize {
92        self.patch_size * self.patch_size * self.n_channels
93    }
94}
95
96// ─── VitPatchEmbed ───────────────────────────────────────────────────────────
97
98/// ViT patch-embedding layer with `[CLS]` token and learnable positions.
99pub struct VitPatchEmbed {
100    /// Projection weight, flat `[d_model × (patch_size² · n_channels)]`,
101    /// row-major (output-major).
102    proj_w: Vec<f32>,
103    /// Projection bias, flat `[d_model]`.
104    proj_b: Vec<f32>,
105    /// Learnable class token, flat `[d_model]`.
106    cls_token: Vec<f32>,
107    /// Learnable positional embedding, flat `[(n_patches + 1) × d_model]`.
108    pos_emb: Vec<f32>,
109    /// Validated configuration.
110    config: VitPatchConfig,
111}
112
113impl VitPatchEmbed {
114    /// Construct a new patch embedder with random parameters.
115    ///
116    /// The projection weight is initialised `N(0, 1/√patch_dim)` (the fan-in
117    /// scaling used by the reference implementations); biases are near-zero,
118    /// the class token is `N(0, 0.02)` (ViT's truncated-normal scale), and the
119    /// positional embedding is `N(0, 0.02)`.
120    ///
121    /// # Errors
122    /// Propagates [`VitPatchConfig::validate`] failures.
123    pub fn new(config: VitPatchConfig, rng: &mut VisionRng) -> VisionResult<Self> {
124        config.validate()?;
125
126        let patch_dim = config.patch_dim();
127        let d_model = config.d_model;
128        let n_tokens = config.grid_size() * config.grid_size() + 1;
129
130        let scale = 1.0 / (patch_dim as f32).sqrt();
131        let mut proj_w = vec![0.0_f32; d_model * patch_dim];
132        rng.fill_normal(&mut proj_w);
133        for w in &mut proj_w {
134            *w *= scale;
135        }
136
137        let mut proj_b = vec![0.0_f32; d_model];
138        rng.fill_normal(&mut proj_b);
139        for b in &mut proj_b {
140            *b *= 0.01;
141        }
142
143        let mut cls_token = vec![0.0_f32; d_model];
144        rng.fill_normal(&mut cls_token);
145        for c in &mut cls_token {
146            *c *= 0.02;
147        }
148
149        let mut pos_emb = vec![0.0_f32; n_tokens * d_model];
150        rng.fill_normal(&mut pos_emb);
151        for p in &mut pos_emb {
152            *p *= 0.02;
153        }
154
155        Ok(Self {
156            proj_w,
157            proj_b,
158            cls_token,
159            pos_emb,
160            config,
161        })
162    }
163
164    /// Read-only access to the configuration.
165    #[must_use]
166    #[inline]
167    pub fn config(&self) -> &VitPatchConfig {
168        &self.config
169    }
170
171    /// Number of image patches: `(image_size / patch_size)²`.
172    #[must_use]
173    #[inline]
174    pub fn n_patches(&self) -> usize {
175        self.config.grid_size() * self.config.grid_size()
176    }
177
178    /// Embed an image into a token sequence.
179    ///
180    /// `image` must be `[n_channels × image_size × image_size]` row-major (CHW).
181    /// The returned sequence is `[(n_patches + 1) × d_model]`: a prepended
182    /// `[CLS]` token followed by the `n_patches` projected patch tokens, with
183    /// the learnable positional embedding added element-wise.
184    ///
185    /// # Errors
186    /// - [`VisionError::DimensionMismatch`] if `image.len()` does not equal
187    ///   `n_channels · image_size²`.
188    /// - [`VisionError::NonFinite`] if a non-finite value is produced.
189    pub fn forward(&self, image: &[f32]) -> VisionResult<Vec<f32>> {
190        let cfg = &self.config;
191        let expected = cfg.n_channels * cfg.image_size * cfg.image_size;
192        if image.len() != expected {
193            return Err(VisionError::DimensionMismatch {
194                expected,
195                got: image.len(),
196            });
197        }
198
199        let grid = cfg.grid_size();
200        let patch = cfg.patch_size;
201        let n_patches = grid * grid;
202        let patch_dim = cfg.patch_dim();
203        let d_model = cfg.d_model;
204        let img = cfg.image_size;
205        let plane = img * img;
206
207        // Output: [(n_patches + 1) × d_model]; token 0 is CLS.
208        let mut out = vec![0.0_f32; (n_patches + 1) * d_model];
209
210        // ── CLS token (row 0) ────────────────────────────────────────────────
211        out[..d_model].copy_from_slice(&self.cls_token);
212
213        // ── Patch projection (rows 1..=n_patches) ────────────────────────────
214        // Scratch buffer for one flattened patch in (C, ph, pw) order so the
215        // projection weight rows line up with `[c · P² + ph · P + pw]`.
216        let mut flat = vec![0.0_f32; patch_dim];
217
218        for gy in 0..grid {
219            for gx in 0..grid {
220                // Flatten patch (gy, gx) → `flat[c·P² + ph·P + pw]`.
221                for c in 0..cfg.n_channels {
222                    let chan_base = c * plane;
223                    let dst_chan = c * patch * patch;
224                    for ph in 0..patch {
225                        let row = gy * patch + ph;
226                        let src_row = chan_base + row * img + gx * patch;
227                        let dst_row = dst_chan + ph * patch;
228                        flat[dst_row..dst_row + patch]
229                            .copy_from_slice(&image[src_row..src_row + patch]);
230                    }
231                }
232
233                // Linear projection: token = W · flat + b.
234                let patch_idx = gy * grid + gx;
235                let out_base = (patch_idx + 1) * d_model;
236                for o in 0..d_model {
237                    let w_row = &self.proj_w[o * patch_dim..(o + 1) * patch_dim];
238                    let mut acc = self.proj_b[o];
239                    for (wv, fv) in w_row.iter().zip(flat.iter()) {
240                        acc += wv * fv;
241                    }
242                    out[out_base + o] = acc;
243                }
244            }
245        }
246
247        // ── Positional embedding (element-wise add over all tokens) ──────────
248        for (o, p) in out.iter_mut().zip(self.pos_emb.iter()) {
249            *o += *p;
250        }
251
252        if out.iter().any(|v| !v.is_finite()) {
253            return Err(VisionError::NonFinite("ViT patch embedding output"));
254        }
255        Ok(out)
256    }
257}
258
259// ─── Tests ───────────────────────────────────────────────────────────────────
260
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::LcgRng;

    fn cfg() -> VitPatchConfig {
        VitPatchConfig {
            image_size: 16,
            patch_size: 4,
            n_channels: 3,
            d_model: 8,
        }
    }

    #[test]
    fn n_patches_correct() {
        let mut rng = LcgRng::new(1);
        let pe = VitPatchEmbed::new(cfg(), &mut rng).expect("ok");
        // (16/4)² = 16
        assert_eq!(pe.n_patches(), 16);
    }

    #[test]
    fn forward_shape() {
        let mut rng = LcgRng::new(2);
        let pe = VitPatchEmbed::new(cfg(), &mut rng).expect("ok");
        let image = vec![0.5_f32; 3 * 16 * 16];
        let out = pe.forward(&image).expect("forward ok");
        // (16 patches + 1 CLS) × d_model(8)
        assert_eq!(out.len(), 17 * 8);
    }

    #[test]
    fn forward_finite() {
        let mut rng = LcgRng::new(3);
        let pe = VitPatchEmbed::new(cfg(), &mut rng).expect("ok");
        let mut image = vec![0.0_f32; 3 * 16 * 16];
        rng.fill_normal(&mut image);
        let out = pe.forward(&image).expect("ok");
        assert!(out.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn cls_token_prepended() {
        // With a zero image the projection of every patch is just `proj_b`, so
        // row 0 equals `cls_token + pos_emb[0]` and every patch row equals
        // `proj_b + pos_emb[row]`. Subtracting the positional embedding must
        // therefore recover `cls_token` at row 0 and `proj_b` at every other row.
        let mut rng = LcgRng::new(4);
        let pe = VitPatchEmbed::new(cfg(), &mut rng).expect("ok");
        let image = vec![0.0_f32; 3 * 16 * 16];
        let out = pe.forward(&image).expect("ok");
        let d = pe.config().d_model;
        for (row, token) in out.chunks_exact(d).enumerate() {
            let want_row: &[f32] = if row == 0 { &pe.cls_token } else { &pe.proj_b };
            for (o, (&out_o, &want)) in token.iter().zip(want_row).enumerate() {
                let recovered = out_o - pe.pos_emb[row * d + o];
                assert!(
                    (recovered - want).abs() < 1e-5,
                    "row {row} dim {o}: recovered {recovered}, expected {want}"
                );
            }
        }
    }

    #[test]
    fn image_size_not_divisible_error() {
        let bad = VitPatchConfig {
            image_size: 15,
            patch_size: 4,
            n_channels: 3,
            d_model: 8,
        };
        let mut rng = LcgRng::new(5);
        let r = VitPatchEmbed::new(bad, &mut rng);
        assert!(matches!(r, Err(VisionError::InvalidPatchSize { .. })));
    }

    #[test]
    fn patch_size_0_error() {
        let bad = VitPatchConfig {
            image_size: 16,
            patch_size: 0,
            n_channels: 3,
            d_model: 8,
        };
        let mut rng = LcgRng::new(6);
        let r = VitPatchEmbed::new(bad, &mut rng);
        assert!(matches!(r, Err(VisionError::InvalidPatchSize { .. })));
    }

    #[test]
    fn different_images_different_embeds() {
        let mut rng = LcgRng::new(7);
        let pe = VitPatchEmbed::new(cfg(), &mut rng).expect("ok");
        let img_a = vec![0.2_f32; 3 * 16 * 16];
        let mut img_b = vec![0.2_f32; 3 * 16 * 16];
        img_b[0] = 5.0; // perturb a single pixel
        let out_a = pe.forward(&img_a).expect("ok");
        let out_b = pe.forward(&img_b).expect("ok");
        let diff: f32 = out_a
            .iter()
            .zip(out_b.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(diff > 1e-6, "embeddings should differ for different images");
    }

    #[test]
    fn pos_emb_added() {
        // Build a layer, then zero the projection + cls + bias and re-run on a
        // constant image: the output must equal exactly the positional embedding.
        let mut rng = LcgRng::new(8);
        let mut pe = VitPatchEmbed::new(cfg(), &mut rng).expect("ok");
        for w in &mut pe.proj_w {
            *w = 0.0;
        }
        for b in &mut pe.proj_b {
            *b = 0.0;
        }
        for c in &mut pe.cls_token {
            *c = 0.0;
        }
        let image = vec![3.0_f32; 3 * 16 * 16];
        let out = pe.forward(&image).expect("ok");
        for (o, p) in out.iter().zip(pe.pos_emb.iter()) {
            assert!((o - p).abs() < 1e-6, "output must equal pos_emb");
        }
    }

    #[test]
    fn d_model_0_error() {
        let bad = VitPatchConfig {
            image_size: 16,
            patch_size: 4,
            n_channels: 3,
            d_model: 0,
        };
        let mut rng = LcgRng::new(9);
        let r = VitPatchEmbed::new(bad, &mut rng);
        assert!(matches!(r, Err(VisionError::InvalidEmbedDim(0))));
    }

    #[test]
    fn single_patch() {
        // image_size == patch_size ⇒ exactly one patch.
        let single = VitPatchConfig {
            image_size: 8,
            patch_size: 8,
            n_channels: 3,
            d_model: 4,
        };
        let mut rng = LcgRng::new(10);
        let pe = VitPatchEmbed::new(single, &mut rng).expect("ok");
        assert_eq!(pe.n_patches(), 1);
        let image = vec![0.5_f32; 3 * 8 * 8];
        let out = pe.forward(&image).expect("ok");
        assert_eq!(out.len(), 2 * 4); // (1 patch + CLS) × 4
        assert!(out.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn n_channels_0_error() {
        let bad = VitPatchConfig {
            image_size: 16,
            patch_size: 4,
            n_channels: 0,
            d_model: 8,
        };
        let mut rng = LcgRng::new(11);
        let r = VitPatchEmbed::new(bad, &mut rng);
        assert!(matches!(r, Err(VisionError::InvalidImageSize { .. })));
    }

    #[test]
    fn forward_wrong_image_len_error() {
        let mut rng = LcgRng::new(12);
        let pe = VitPatchEmbed::new(cfg(), &mut rng).expect("ok");
        let image = vec![0.5_f32; 10]; // wrong length
        let r = pe.forward(&image);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn patch_flatten_matches_manual() {
        // d_model = 1 with a weight row that selects element 0 of the flattened
        // patch (its top-left pixel of channel 0), and zero bias/cls/pos. Each
        // patch token must then equal that patch's top-left pixel.
        let c = VitPatchConfig {
            image_size: 4,
            patch_size: 2,
            n_channels: 1,
            d_model: 1,
        };
        let mut rng = LcgRng::new(13);
        let mut pe = VitPatchEmbed::new(c, &mut rng).expect("ok");
        // patch_dim = 2*2*1 = 4; weight row picks element 0.
        pe.proj_w = vec![1.0, 0.0, 0.0, 0.0];
        pe.proj_b = vec![0.0];
        pe.cls_token = vec![0.0];
        for p in &mut pe.pos_emb {
            *p = 0.0;
        }
        let mut image = vec![0.0_f32; 16];
        image[0] = 7.0; // top-left pixel of patch (0, 0)
        image[2 * 4 + 2] = 9.0; // top-left pixel of patch (1, 1): row 2, col 2
        let out = pe.forward(&image).expect("ok");
        // Row 0 is CLS; rows 1..=4 are patches (0,0), (0,1), (1,0), (1,1).
        assert!(out[0].abs() < 1e-6, "CLS got {}", out[0]);
        assert!((out[1] - 7.0).abs() < 1e-6, "patch (0,0) got {}", out[1]);
        assert!(out[2].abs() < 1e-6, "patch (0,1) got {}", out[2]);
        assert!(out[3].abs() < 1e-6, "patch (1,0) got {}", out[3]);
        assert!((out[4] - 9.0).abs() < 1e-6, "patch (1,1) got {}", out[4]);
    }

    #[test]
    fn patch_flatten_is_channel_major() {
        // Two channels, one patch: flat = [ch0 (2×2), ch1 (2×2)], so element 4
        // is the top-left pixel of channel 1.
        let c = VitPatchConfig {
            image_size: 2,
            patch_size: 2,
            n_channels: 2,
            d_model: 1,
        };
        let mut rng = LcgRng::new(14);
        let mut pe = VitPatchEmbed::new(c, &mut rng).expect("ok");
        pe.proj_w = vec![0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0];
        pe.proj_b = vec![0.0];
        pe.cls_token = vec![0.0];
        for p in &mut pe.pos_emb {
            *p = 0.0;
        }
        let mut image = vec![0.0_f32; 2 * 2 * 2];
        image[0] = 3.0; // channel 0 top-left: must be ignored by the weight
        image[4] = 5.0; // channel 1 top-left: selected by the weight
        let out = pe.forward(&image).expect("ok");
        assert!((out[1] - 5.0).abs() < 1e-6, "got {}", out[1]);
    }
}