oxicuda_vision/patch_embed/
conv2d_patch.rs1use crate::{
4 error::{VisionError, VisionResult},
5 handle::LcgRng,
6};
7
/// Geometry of a patch embedding over a square image.
///
/// Build values with [`PatchEmbedConfig::new`], which rejects zero-sized
/// dimensions and patch sizes that do not tile the image evenly. The fields
/// are public, so a value assembled by hand bypasses those checks.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so a configuration can be
/// used as a `HashMap`/`HashSet` key (for example to cache per-shape state).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PatchEmbedConfig {
    /// Side length of the square input image (used as both height and width).
    pub img_size: usize,
    /// Side length of each square patch.
    pub patch_size: usize,
    /// Number of input channels.
    pub in_chans: usize,
    /// Number of values produced per patch.
    pub embed_dim: usize,
}
27
28impl PatchEmbedConfig {
29 pub fn new(
31 img_size: usize,
32 patch_size: usize,
33 in_chans: usize,
34 embed_dim: usize,
35 ) -> VisionResult<Self> {
36 if patch_size == 0 || img_size % patch_size != 0 {
37 return Err(VisionError::InvalidPatchSize {
38 patch_size,
39 img_size,
40 });
41 }
42 if embed_dim == 0 {
43 return Err(VisionError::InvalidEmbedDim(embed_dim));
44 }
45 if img_size == 0 || in_chans == 0 {
46 return Err(VisionError::InvalidImageSize {
47 height: img_size,
48 width: img_size,
49 channels: in_chans,
50 });
51 }
52 Ok(Self {
53 img_size,
54 patch_size,
55 in_chans,
56 embed_dim,
57 })
58 }
59
60 #[must_use]
62 pub fn grid_size(&self) -> usize {
63 self.img_size / self.patch_size
64 }
65
66 #[must_use]
68 pub fn n_patches(&self) -> usize {
69 self.grid_size() * self.grid_size()
70 }
71
72 #[must_use]
74 pub fn kernel_vol(&self) -> usize {
75 self.in_chans * self.patch_size * self.patch_size
76 }
77}
78
/// Parameters of a patch embedding, each stored as a flat `f32` buffer.
///
/// `Debug`, `Clone` and `PartialEq` are derived so weights can be logged,
/// snapshotted and compared in tests like any other plain-data value.
#[derive(Debug, Clone, PartialEq)]
pub struct PatchEmbedWeights {
    /// Projection kernel. `PatchEmbed::forward` reads it as
    /// `[embed_dim][in_chans][patch_size][patch_size]`, flattened row-major.
    pub kernel: Vec<f32>,
    /// One bias value per embedding dimension.
    pub bias: Vec<f32>,
    /// Class token; `default_init` gives it `embed_dim` values.
    pub cls_token: Vec<f32>,
}
94
95impl PatchEmbedWeights {
96 pub fn default_init(cfg: &PatchEmbedConfig, rng: &mut LcgRng) -> Self {
98 let kv = cfg.kernel_vol();
99 let scale = 1.0 / (kv as f32).sqrt();
100 let n_kernel = cfg.embed_dim * kv;
101
102 let mut kernel = vec![0.0f32; n_kernel];
103 rng.fill_normal(&mut kernel);
104 for v in &mut kernel {
105 *v *= scale;
106 }
107
108 let mut bias = vec![0.0f32; cfg.embed_dim];
109 rng.fill_normal(&mut bias);
110 for v in &mut bias {
111 *v *= 0.01;
112 }
113
114 let mut cls_token = vec![0.0f32; cfg.embed_dim];
115 rng.fill_normal(&mut cls_token);
116 for v in &mut cls_token {
117 *v *= 0.02;
118 }
119
120 Self {
121 kernel,
122 bias,
123 cls_token,
124 }
125 }
126}
127
/// A patch-embedding layer: a [`PatchEmbedConfig`] together with the
/// [`PatchEmbedWeights`] that `forward` applies to an image.
///
/// Both fields are public; `forward` reads the weights using sizes taken from
/// `config`, so the two are expected to describe the same shapes.
pub struct PatchEmbed {
    /// Image, patch, channel and embedding dimensions.
    pub config: PatchEmbedConfig,
    /// Kernel, bias and class token.
    pub weights: PatchEmbedWeights,
}
136
137impl PatchEmbed {
138 pub fn new(cfg: PatchEmbedConfig, rng: &mut LcgRng) -> Self {
140 let weights = PatchEmbedWeights::default_init(&cfg, rng);
141 Self {
142 config: cfg,
143 weights,
144 }
145 }
146
147 pub fn forward(&self, image: &[f32]) -> VisionResult<Vec<f32>> {
151 let cfg = &self.config;
152 let expected = cfg.in_chans * cfg.img_size * cfg.img_size;
153 if image.len() != expected {
154 return Err(VisionError::DimensionMismatch {
155 expected,
156 got: image.len(),
157 });
158 }
159
160 let n_patches = cfg.n_patches();
161 let grid = cfg.grid_size();
162 let p = cfg.patch_size;
163 let c = cfg.in_chans;
164 let e = cfg.embed_dim;
165 let kv = cfg.kernel_vol(); let mut out = vec![0.0f32; n_patches * e];
168
169 for ph in 0..grid {
172 for pw in 0..grid {
173 let patch_idx = ph * grid + pw;
174 for ed in 0..e {
175 let mut acc = self.weights.bias[ed];
176 let k_off = ed * kv;
178 for ci in 0..c {
179 for pi in 0..p {
180 for pj in 0..p {
181 let k_idx = k_off + ci * p * p + pi * p + pj;
182 let img_row = ph * p + pi;
183 let img_col = pw * p + pj;
184 let img_idx = ci * cfg.img_size * cfg.img_size
185 + img_row * cfg.img_size
186 + img_col;
187 acc += self.weights.kernel[k_idx] * image[img_idx];
188 }
189 }
190 }
191 out[patch_idx * e + ed] = acc;
192 }
193 }
194 }
195
196 Ok(out)
197 }
198}
199
200pub fn prepend_cls(tokens: &[f32], cls: &[f32], embed_dim: usize) -> VisionResult<Vec<f32>> {
206 let n_tok = tokens.len() / embed_dim;
207 if tokens.len() != n_tok * embed_dim {
208 return Err(VisionError::DimensionMismatch {
209 expected: n_tok * embed_dim,
210 got: tokens.len(),
211 });
212 }
213 if cls.len() != embed_dim {
214 return Err(VisionError::DimensionMismatch {
215 expected: embed_dim,
216 got: cls.len(),
217 });
218 }
219 let mut out = Vec::with_capacity((n_tok + 1) * embed_dim);
220 out.extend_from_slice(cls);
221 out.extend_from_slice(tokens);
222 Ok(out)
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::LcgRng;

    /// 16x16 RGB image, 4x4 patches, 8-dimensional embedding.
    fn make_cfg() -> PatchEmbedConfig {
        PatchEmbedConfig::new(16, 4, 3, 8).expect("valid config")
    }

    #[test]
    fn config_valid() {
        let cfg = make_cfg();
        assert_eq!(cfg.n_patches(), 16);
        assert_eq!(cfg.grid_size(), 4);
        assert_eq!(cfg.kernel_vol(), 3 * 4 * 4);
    }

    #[test]
    fn config_invalid_patch_size_not_dividing() {
        let r = PatchEmbedConfig::new(16, 5, 3, 8);
        assert!(matches!(r, Err(VisionError::InvalidPatchSize { .. })));
    }

    #[test]
    fn config_invalid_patch_size_zero() {
        let r = PatchEmbedConfig::new(16, 0, 3, 8);
        assert!(matches!(r, Err(VisionError::InvalidPatchSize { .. })));
    }

    #[test]
    fn config_invalid_embed_dim_zero() {
        let r = PatchEmbedConfig::new(16, 4, 3, 0);
        assert!(matches!(r, Err(VisionError::InvalidEmbedDim(0))));
    }

    // The third validation branch had no coverage.
    #[test]
    fn config_invalid_zero_channels() {
        let r = PatchEmbedConfig::new(16, 4, 0, 8);
        assert!(matches!(
            r,
            Err(VisionError::InvalidImageSize {
                height: 16,
                width: 16,
                channels: 0
            })
        ));
    }

    #[test]
    fn forward_output_shape() {
        let cfg = make_cfg();
        let mut rng = LcgRng::new(1);
        let pe = PatchEmbed::new(cfg.clone(), &mut rng);
        let image = vec![0.5f32; 3 * 16 * 16];
        let out = pe.forward(&image).expect("forward ok");
        assert_eq!(out.len(), cfg.n_patches() * cfg.embed_dim);
    }

    #[test]
    fn forward_wrong_image_size_errors() {
        let cfg = make_cfg();
        let mut rng = LcgRng::new(2);
        let pe = PatchEmbed::new(cfg, &mut rng);
        let image = vec![0.5f32; 10];
        let r = pe.forward(&image);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    // With an all-zero image every output must equal the bias of its
    // embedding dimension. Checks every patch, not just the first value.
    #[test]
    fn forward_zero_image_is_bias() {
        let cfg = make_cfg();
        let mut rng = LcgRng::new(3);
        let pe = PatchEmbed::new(cfg.clone(), &mut rng);
        let image = vec![0.0f32; 3 * 16 * 16];
        let out = pe.forward(&image).expect("forward ok");
        assert_eq!(out.len(), cfg.n_patches() * cfg.embed_dim);

        for (patch, row) in out.chunks_exact(cfg.embed_dim).enumerate() {
            for (dim, (got, want)) in row.iter().zip(&pe.weights.bias).enumerate() {
                let diff = (got - want).abs();
                assert!(
                    diff < 1e-6,
                    "patch {}, dim {}: expected bias={}, got {}",
                    patch,
                    dim,
                    want,
                    got
                );
            }
        }
    }

    // Hand-computed projection: pins the kernel/image index mapping and the
    // patch ordering, which the randomized tests cannot detect.
    #[test]
    fn forward_known_values() {
        let pe = PatchEmbed {
            config: PatchEmbedConfig::new(4, 2, 1, 1).expect("valid config"),
            weights: PatchEmbedWeights {
                kernel: vec![1.0, 2.0, 3.0, 4.0],
                bias: vec![0.5],
                cls_token: vec![0.0],
            },
        };
        let image: Vec<f32> = (0..16).map(|v| v as f32).collect();
        let out = pe.forward(&image).expect("forward ok");
        assert_eq!(out, vec![34.5, 54.5, 114.5, 134.5]);
    }

    #[test]
    fn forward_finite_random_input() {
        let cfg = PatchEmbedConfig::new(32, 4, 3, 64).expect("valid");
        let mut rng = LcgRng::new(7);
        let pe = PatchEmbed::new(cfg, &mut rng);
        let mut image = vec![0.0f32; 3 * 32 * 32];
        rng.fill_normal(&mut image);
        let out = pe.forward(&image).expect("forward ok");
        assert!(
            out.iter().all(|v| v.is_finite()),
            "output contains non-finite"
        );
    }

    #[test]
    fn prepend_cls_shape() {
        let tokens = vec![1.0f32; 16 * 8];
        let cls = vec![0.0f32; 8];
        let out = prepend_cls(&tokens, &cls, 8).expect("ok");
        assert_eq!(out.len(), 17 * 8);
        // The class token comes first...
        assert!(out[..8].iter().all(|&v| v == 0.0));
        // ...followed by the original tokens, unchanged.
        assert_eq!(out[8..16], tokens[..8]);
    }

    #[test]
    fn prepend_cls_wrong_cls_dim_errors() {
        let tokens = vec![1.0f32; 16 * 8];
        let cls = vec![0.0f32; 4];
        let r = prepend_cls(&tokens, &cls, 8);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    // A token buffer that is not a whole number of vectors must be rejected.
    #[test]
    fn prepend_cls_ragged_tokens_errors() {
        let tokens = vec![1.0f32; 10];
        let cls = vec![0.0f32; 8];
        let r = prepend_cls(&tokens, &cls, 8);
        assert!(matches!(
            r,
            Err(VisionError::DimensionMismatch {
                expected: 8,
                got: 10
            })
        ));
    }

    #[test]
    fn weights_default_init_correct_size() {
        let cfg = make_cfg();
        let mut rng = LcgRng::new(42);
        let w = PatchEmbedWeights::default_init(&cfg, &mut rng);
        assert_eq!(w.kernel.len(), cfg.embed_dim * cfg.kernel_vol());
        assert_eq!(w.bias.len(), cfg.embed_dim);
        assert_eq!(w.cls_token.len(), cfg.embed_dim);
    }

    #[test]
    fn weights_default_init_finite() {
        let cfg = make_cfg();
        let mut rng = LcgRng::new(99);
        let w = PatchEmbedWeights::default_init(&cfg, &mut rng);
        assert!(w.kernel.iter().all(|v| v.is_finite()));
        assert!(w.bias.iter().all(|v| v.is_finite()));
        assert!(w.cls_token.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn patch_embed_different_seeds_differ() {
        let cfg = make_cfg();
        let image = vec![0.5f32; 3 * 16 * 16];
        let mut rng1 = LcgRng::new(1);
        let mut rng2 = LcgRng::new(2);
        let pe1 = PatchEmbed::new(cfg.clone(), &mut rng1);
        let pe2 = PatchEmbed::new(cfg, &mut rng2);
        let out1 = pe1.forward(&image).expect("ok");
        let out2 = pe2.forward(&image).expect("ok");
        assert!(
            out1.iter()
                .zip(out2.iter())
                .any(|(a, b)| (a - b).abs() > 1e-6)
        );
    }
}