Skip to main content

oxicuda_vision/patch_embed/
conv2d_patch.rs

1//! Patch embedder: strided Conv2D producing `[N_patches, embed_dim]`.
2
3use crate::{
4    error::{VisionError, VisionResult},
5    handle::LcgRng,
6};
7
8// ─── Config ──────────────────────────────────────────────────────────────────
9
/// Configuration for the patch embedder.
///
/// The image is assumed to have shape `[in_chans, img_size, img_size]`.
/// It is split into `(img_size / patch_size)²` non-overlapping patches,
/// each of size `in_chans × patch_size × patch_size`, which are projected
/// linearly to `embed_dim`.
///
/// Every field is a plain `usize`, so equality is total (`Eq`) and the type
/// is hashable; a configuration can therefore be used as a `HashMap` /
/// `HashSet` key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PatchEmbedConfig {
    /// Square image spatial dimension (H = W).
    pub img_size: usize,
    /// Patch spatial dimension (P × P windows, stride = P).
    pub patch_size: usize,
    /// Number of input channels (e.g., 3 for RGB).
    pub in_chans: usize,
    /// Output token embedding dimension.
    pub embed_dim: usize,
}
27
28impl PatchEmbedConfig {
29    /// Create and validate a `PatchEmbedConfig`.
30    pub fn new(
31        img_size: usize,
32        patch_size: usize,
33        in_chans: usize,
34        embed_dim: usize,
35    ) -> VisionResult<Self> {
36        if patch_size == 0 || img_size % patch_size != 0 {
37            return Err(VisionError::InvalidPatchSize {
38                patch_size,
39                img_size,
40            });
41        }
42        if embed_dim == 0 {
43            return Err(VisionError::InvalidEmbedDim(embed_dim));
44        }
45        if img_size == 0 || in_chans == 0 {
46            return Err(VisionError::InvalidImageSize {
47                height: img_size,
48                width: img_size,
49                channels: in_chans,
50            });
51        }
52        Ok(Self {
53            img_size,
54            patch_size,
55            in_chans,
56            embed_dim,
57        })
58    }
59
60    /// Number of patches along one spatial dimension.
61    #[must_use]
62    pub fn grid_size(&self) -> usize {
63        self.img_size / self.patch_size
64    }
65
66    /// Total number of patches (CLS token not counted here).
67    #[must_use]
68    pub fn n_patches(&self) -> usize {
69        self.grid_size() * self.grid_size()
70    }
71
72    /// Kernel volume for one filter: `in_chans × patch_size²`.
73    #[must_use]
74    pub fn kernel_vol(&self) -> usize {
75        self.in_chans * self.patch_size * self.patch_size
76    }
77}
78
79// ─── Weights ─────────────────────────────────────────────────────────────────
80
/// Learnable weights for the patch embedder.
///
/// `kernel` has layout `[embed_dim, in_chans, patch_size, patch_size]`
/// (row-major, C-contiguous): filter `e` occupies
/// `kernel[e * kernel_vol .. (e+1) * kernel_vol]`.
///
/// Derives `Debug`, `Clone` and `PartialEq` so weights can be logged,
/// snapshotted and compared, matching the traits available on
/// `PatchEmbedConfig`. (`Eq` is not derivable because the fields hold `f32`.)
#[derive(Debug, Clone, PartialEq)]
pub struct PatchEmbedWeights {
    /// Conv2D kernel: flat `[embed_dim × in_chans × P × P]`.
    pub kernel: Vec<f32>,
    /// Bias: flat `[embed_dim]`.
    pub bias: Vec<f32>,
    /// CLS token: flat `[embed_dim]`.
    pub cls_token: Vec<f32>,
}
94
95impl PatchEmbedWeights {
96    /// Xavier/He-style default init: N(0, 1/√(kernel_vol)).
97    pub fn default_init(cfg: &PatchEmbedConfig, rng: &mut LcgRng) -> Self {
98        let kv = cfg.kernel_vol();
99        let scale = 1.0 / (kv as f32).sqrt();
100        let n_kernel = cfg.embed_dim * kv;
101
102        let mut kernel = vec![0.0f32; n_kernel];
103        rng.fill_normal(&mut kernel);
104        for v in &mut kernel {
105            *v *= scale;
106        }
107
108        let mut bias = vec![0.0f32; cfg.embed_dim];
109        rng.fill_normal(&mut bias);
110        for v in &mut bias {
111            *v *= 0.01;
112        }
113
114        let mut cls_token = vec![0.0f32; cfg.embed_dim];
115        rng.fill_normal(&mut cls_token);
116        for v in &mut cls_token {
117            *v *= 0.02;
118        }
119
120        Self {
121            kernel,
122            bias,
123            cls_token,
124        }
125    }
126}
127
128// ─── PatchEmbed ──────────────────────────────────────────────────────────────
129
130/// Patch embedder: converts a CHW image to a `[N_patches, embed_dim]` token
131/// sequence via a strided Conv2D with `stride = kernel_size = patch_size`.
132pub struct PatchEmbed {
133    pub config: PatchEmbedConfig,
134    pub weights: PatchEmbedWeights,
135}
136
137impl PatchEmbed {
138    /// Create a new `PatchEmbed` with Xavier-initialised weights.
139    pub fn new(cfg: PatchEmbedConfig, rng: &mut LcgRng) -> Self {
140        let weights = PatchEmbedWeights::default_init(&cfg, rng);
141        Self {
142            config: cfg,
143            weights,
144        }
145    }
146
147    /// Forward pass: `image` is flat `[in_chans, img_size, img_size]` CHW.
148    ///
149    /// Returns `[n_patches, embed_dim]` flat row-major.
150    pub fn forward(&self, image: &[f32]) -> VisionResult<Vec<f32>> {
151        let cfg = &self.config;
152        let expected = cfg.in_chans * cfg.img_size * cfg.img_size;
153        if image.len() != expected {
154            return Err(VisionError::DimensionMismatch {
155                expected,
156                got: image.len(),
157            });
158        }
159
160        let n_patches = cfg.n_patches();
161        let grid = cfg.grid_size();
162        let p = cfg.patch_size;
163        let c = cfg.in_chans;
164        let e = cfg.embed_dim;
165        let kv = cfg.kernel_vol(); // c * p * p
166
167        let mut out = vec![0.0f32; n_patches * e];
168
169        // For each patch (ph, pw) and each output channel ed:
170        //   out[ph*grid + pw, ed] = bias[ed] + Σ_{ci,pi,pj} kernel[ed, ci, pi, pj] * image[ci, ph*p+pi, pw*p+pj]
171        for ph in 0..grid {
172            for pw in 0..grid {
173                let patch_idx = ph * grid + pw;
174                for ed in 0..e {
175                    let mut acc = self.weights.bias[ed];
176                    // Kernel for output channel `ed`: slice [ed*kv .. (ed+1)*kv]
177                    let k_off = ed * kv;
178                    for ci in 0..c {
179                        for pi in 0..p {
180                            for pj in 0..p {
181                                let k_idx = k_off + ci * p * p + pi * p + pj;
182                                let img_row = ph * p + pi;
183                                let img_col = pw * p + pj;
184                                let img_idx = ci * cfg.img_size * cfg.img_size
185                                    + img_row * cfg.img_size
186                                    + img_col;
187                                acc += self.weights.kernel[k_idx] * image[img_idx];
188                            }
189                        }
190                    }
191                    out[patch_idx * e + ed] = acc;
192                }
193            }
194        }
195
196        Ok(out)
197    }
198}
199
200// ─── CLS prepend ─────────────────────────────────────────────────────────────
201
202/// Prepend the CLS token to a `[n_patches, embed_dim]` token sequence.
203///
204/// Returns `[(n_patches+1) * embed_dim]` flat, with the CLS token at index 0.
205pub fn prepend_cls(tokens: &[f32], cls: &[f32], embed_dim: usize) -> VisionResult<Vec<f32>> {
206    let n_tok = tokens.len() / embed_dim;
207    if tokens.len() != n_tok * embed_dim {
208        return Err(VisionError::DimensionMismatch {
209            expected: n_tok * embed_dim,
210            got: tokens.len(),
211        });
212    }
213    if cls.len() != embed_dim {
214        return Err(VisionError::DimensionMismatch {
215            expected: embed_dim,
216            got: cls.len(),
217        });
218    }
219    let mut out = Vec::with_capacity((n_tok + 1) * embed_dim);
220    out.extend_from_slice(cls);
221    out.extend_from_slice(tokens);
222    Ok(out)
223}
224
225// ─── Tests ───────────────────────────────────────────────────────────────────
226
#[cfg(test)]
mod tests {
    //! Unit tests for configuration validation, the forward pass and CLS
    //! prepending.

    use super::*;
    use crate::handle::LcgRng;

    /// 16×16 RGB image, 4×4 patches, 8-dimensional embedding.
    fn make_cfg() -> PatchEmbedConfig {
        PatchEmbedConfig::new(16, 4, 3, 8).expect("valid config")
    }

    #[test]
    fn config_valid() {
        let cfg = make_cfg();
        assert_eq!(cfg.n_patches(), 16); // (16/4)^2
        assert_eq!(cfg.grid_size(), 4);
        assert_eq!(cfg.kernel_vol(), 3 * 4 * 4); // c*p*p = 48
    }

    #[test]
    fn config_invalid_patch_size_not_dividing() {
        let r = PatchEmbedConfig::new(16, 5, 3, 8);
        assert!(matches!(r, Err(VisionError::InvalidPatchSize { .. })));
    }

    #[test]
    fn config_invalid_patch_size_zero() {
        let r = PatchEmbedConfig::new(16, 0, 3, 8);
        assert!(matches!(r, Err(VisionError::InvalidPatchSize { .. })));
    }

    #[test]
    fn config_invalid_embed_dim_zero() {
        let r = PatchEmbedConfig::new(16, 4, 3, 0);
        assert!(matches!(r, Err(VisionError::InvalidEmbedDim(0))));
    }

    #[test]
    fn config_invalid_img_size_zero() {
        // 0 % 4 == 0, so the patch-size check passes and the image-size check fires.
        let r = PatchEmbedConfig::new(0, 4, 3, 8);
        assert!(matches!(r, Err(VisionError::InvalidImageSize { .. })));
    }

    #[test]
    fn config_invalid_in_chans_zero() {
        let r = PatchEmbedConfig::new(16, 4, 0, 8);
        assert!(matches!(
            r,
            Err(VisionError::InvalidImageSize { channels: 0, .. })
        ));
    }

    #[test]
    fn forward_output_shape() {
        let cfg = make_cfg(); // 16×16×3, p=4 → 16 patches, embed=8
        let mut rng = LcgRng::new(1);
        let pe = PatchEmbed::new(cfg.clone(), &mut rng);
        let image = vec![0.5f32; 3 * 16 * 16];
        let out = pe.forward(&image).expect("forward ok");
        assert_eq!(out.len(), cfg.n_patches() * cfg.embed_dim);
    }

    #[test]
    fn forward_wrong_image_size_errors() {
        let cfg = make_cfg();
        let mut rng = LcgRng::new(2);
        let pe = PatchEmbed::new(cfg, &mut rng);
        let image = vec![0.5f32; 10]; // wrong
        let r = pe.forward(&image);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn forward_zero_image_is_bias() {
        let cfg = make_cfg();
        let mut rng = LcgRng::new(3);
        let pe = PatchEmbed::new(cfg.clone(), &mut rng);
        let image = vec![0.0f32; 3 * 16 * 16];
        let out = pe.forward(&image).expect("forward ok");
        // With zero input every token must equal the bias vector, not just
        // the first element of the first token.
        for (patch, token) in out.chunks(cfg.embed_dim).enumerate() {
            for (ed, (got, want)) in token.iter().zip(pe.weights.bias.iter()).enumerate() {
                let diff = (got - want).abs();
                assert!(
                    diff < 1e-6,
                    "patch {} channel {}: expected bias={}, got {}",
                    patch,
                    ed,
                    want,
                    got
                );
            }
        }
    }

    #[test]
    fn forward_matches_hand_computed_values() {
        // 4×4 single-channel image, 2×2 patches, two output channels.
        let cfg = PatchEmbedConfig::new(4, 2, 1, 2).expect("valid config");
        let mut rng = LcgRng::new(5);
        let mut pe = PatchEmbed::new(cfg, &mut rng);
        // Filter 0 sums the patch; filter 1 picks its top-left pixel.
        pe.weights.kernel = vec![1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0];
        pe.weights.bias = vec![0.5, -1.0];
        let image: Vec<f32> = (0..16u8).map(f32::from).collect();
        let out = pe.forward(&image).expect("forward ok");
        // All values are small integers or halves, so equality is exact.
        assert_eq!(out, vec![10.5, -1.0, 18.5, 1.0, 42.5, 7.0, 50.5, 9.0]);
    }

    #[test]
    fn forward_finite_random_input() {
        let cfg = PatchEmbedConfig::new(32, 4, 3, 64).expect("valid");
        let mut rng = LcgRng::new(7);
        let pe = PatchEmbed::new(cfg.clone(), &mut rng);
        let mut image = vec![0.0f32; 3 * 32 * 32];
        rng.fill_normal(&mut image);
        let out = pe.forward(&image).expect("forward ok");
        assert!(
            out.iter().all(|v| v.is_finite()),
            "output contains non-finite"
        );
    }

    #[test]
    fn prepend_cls_shape() {
        let tokens = vec![1.0f32; 16 * 8]; // 16 patches, embed=8
        let cls = vec![0.0f32; 8];
        let out = prepend_cls(&tokens, &cls, 8).expect("ok");
        assert_eq!(out.len(), 17 * 8);
        // First row is the CLS token
        assert!(out[..8].iter().all(|&v| v == 0.0));
        // Everything after it is the untouched token sequence
        assert_eq!(out[8..], tokens[..]);
    }

    #[test]
    fn prepend_cls_wrong_cls_dim_errors() {
        let tokens = vec![1.0f32; 16 * 8];
        let cls = vec![0.0f32; 4]; // wrong
        let r = prepend_cls(&tokens, &cls, 8);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn prepend_cls_ragged_tokens_errors() {
        let tokens = vec![1.0f32; 15]; // not a multiple of embed_dim
        let cls = vec![0.0f32; 8];
        let r = prepend_cls(&tokens, &cls, 8);
        assert!(matches!(
            r,
            Err(VisionError::DimensionMismatch {
                expected: 8,
                got: 15
            })
        ));
    }

    #[test]
    fn weights_default_init_correct_size() {
        let cfg = make_cfg();
        let mut rng = LcgRng::new(42);
        let w = PatchEmbedWeights::default_init(&cfg, &mut rng);
        assert_eq!(w.kernel.len(), cfg.embed_dim * cfg.kernel_vol());
        assert_eq!(w.bias.len(), cfg.embed_dim);
        assert_eq!(w.cls_token.len(), cfg.embed_dim);
    }

    #[test]
    fn weights_default_init_finite() {
        let cfg = make_cfg();
        let mut rng = LcgRng::new(99);
        let w = PatchEmbedWeights::default_init(&cfg, &mut rng);
        assert!(w.kernel.iter().all(|v| v.is_finite()));
        assert!(w.bias.iter().all(|v| v.is_finite()));
        assert!(w.cls_token.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn patch_embed_different_seeds_differ() {
        let cfg = make_cfg();
        let image = vec![0.5f32; 3 * 16 * 16];
        let mut rng1 = LcgRng::new(1);
        let mut rng2 = LcgRng::new(2);
        let pe1 = PatchEmbed::new(cfg.clone(), &mut rng1);
        let pe2 = PatchEmbed::new(cfg, &mut rng2);
        let out1 = pe1.forward(&image).expect("ok");
        let out2 = pe2.forward(&image).expect("ok");
        // Different kernels should yield different outputs
        assert!(
            out1.iter()
                .zip(out2.iter())
                .any(|(a, b)| (a - b).abs() > 1e-6)
        );
    }
}