1use crate::{
29 blocks::VisionRng,
30 error::{VisionError, VisionResult},
31};
32
/// Shape parameters for a ViT-style patch embedding over square images.
///
/// The image is treated as `image_size x image_size` pixels with `n_channels`
/// channels, cut into non-overlapping `patch_size x patch_size` patches. Each
/// patch is projected to a `d_model`-dimensional token.
///
/// The fields are public, so a value of this type is not guaranteed to be
/// consistent; call `validate` before relying on the derived quantities.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VitPatchConfig {
    /// Side length of the square input image, in pixels.
    pub image_size: usize,
    /// Side length of each square patch, in pixels. Must be non-zero and
    /// divide `image_size` evenly for the config to validate.
    pub patch_size: usize,
    /// Number of channels in the input image.
    pub n_channels: usize,
    /// Dimension of every output token embedding.
    pub d_model: usize,
}
51
52impl VitPatchConfig {
53 pub fn validate(&self) -> VisionResult<()> {
62 if self.patch_size == 0 || self.image_size % self.patch_size != 0 {
63 return Err(VisionError::InvalidPatchSize {
64 patch_size: self.patch_size,
65 img_size: self.image_size,
66 });
67 }
68 if self.d_model == 0 {
69 return Err(VisionError::InvalidEmbedDim(self.d_model));
70 }
71 if self.image_size == 0 || self.n_channels == 0 {
72 return Err(VisionError::InvalidImageSize {
73 height: self.image_size,
74 width: self.image_size,
75 channels: self.n_channels,
76 });
77 }
78 Ok(())
79 }
80
81 #[must_use]
83 #[inline]
84 pub fn grid_size(&self) -> usize {
85 self.image_size / self.patch_size
86 }
87
88 #[must_use]
90 #[inline]
91 pub fn patch_dim(&self) -> usize {
92 self.patch_size * self.patch_size * self.n_channels
93 }
94}
95
96pub struct VitPatchEmbed {
100 proj_w: Vec<f32>,
103 proj_b: Vec<f32>,
105 cls_token: Vec<f32>,
107 pos_emb: Vec<f32>,
109 config: VitPatchConfig,
111}
112
113impl VitPatchEmbed {
114 pub fn new(config: VitPatchConfig, rng: &mut VisionRng) -> VisionResult<Self> {
124 config.validate()?;
125
126 let patch_dim = config.patch_dim();
127 let d_model = config.d_model;
128 let n_tokens = config.grid_size() * config.grid_size() + 1;
129
130 let scale = 1.0 / (patch_dim as f32).sqrt();
131 let mut proj_w = vec![0.0_f32; d_model * patch_dim];
132 rng.fill_normal(&mut proj_w);
133 for w in &mut proj_w {
134 *w *= scale;
135 }
136
137 let mut proj_b = vec![0.0_f32; d_model];
138 rng.fill_normal(&mut proj_b);
139 for b in &mut proj_b {
140 *b *= 0.01;
141 }
142
143 let mut cls_token = vec![0.0_f32; d_model];
144 rng.fill_normal(&mut cls_token);
145 for c in &mut cls_token {
146 *c *= 0.02;
147 }
148
149 let mut pos_emb = vec![0.0_f32; n_tokens * d_model];
150 rng.fill_normal(&mut pos_emb);
151 for p in &mut pos_emb {
152 *p *= 0.02;
153 }
154
155 Ok(Self {
156 proj_w,
157 proj_b,
158 cls_token,
159 pos_emb,
160 config,
161 })
162 }
163
164 #[must_use]
166 #[inline]
167 pub fn config(&self) -> &VitPatchConfig {
168 &self.config
169 }
170
171 #[must_use]
173 #[inline]
174 pub fn n_patches(&self) -> usize {
175 self.config.grid_size() * self.config.grid_size()
176 }
177
178 pub fn forward(&self, image: &[f32]) -> VisionResult<Vec<f32>> {
190 let cfg = &self.config;
191 let expected = cfg.n_channels * cfg.image_size * cfg.image_size;
192 if image.len() != expected {
193 return Err(VisionError::DimensionMismatch {
194 expected,
195 got: image.len(),
196 });
197 }
198
199 let grid = cfg.grid_size();
200 let patch = cfg.patch_size;
201 let n_patches = grid * grid;
202 let patch_dim = cfg.patch_dim();
203 let d_model = cfg.d_model;
204 let img = cfg.image_size;
205 let plane = img * img;
206
207 let mut out = vec![0.0_f32; (n_patches + 1) * d_model];
209
210 out[..d_model].copy_from_slice(&self.cls_token);
212
213 let mut flat = vec![0.0_f32; patch_dim];
217
218 for gy in 0..grid {
219 for gx in 0..grid {
220 for c in 0..cfg.n_channels {
222 let chan_base = c * plane;
223 let dst_chan = c * patch * patch;
224 for ph in 0..patch {
225 let row = gy * patch + ph;
226 let src_row = chan_base + row * img + gx * patch;
227 let dst_row = dst_chan + ph * patch;
228 flat[dst_row..dst_row + patch]
229 .copy_from_slice(&image[src_row..src_row + patch]);
230 }
231 }
232
233 let patch_idx = gy * grid + gx;
235 let out_base = (patch_idx + 1) * d_model;
236 for o in 0..d_model {
237 let w_row = &self.proj_w[o * patch_dim..(o + 1) * patch_dim];
238 let mut acc = self.proj_b[o];
239 for (wv, fv) in w_row.iter().zip(flat.iter()) {
240 acc += wv * fv;
241 }
242 out[out_base + o] = acc;
243 }
244 }
245 }
246
247 for (o, p) in out.iter_mut().zip(self.pos_emb.iter()) {
249 *o += *p;
250 }
251
252 if out.iter().any(|v| !v.is_finite()) {
253 return Err(VisionError::NonFinite("ViT patch embedding output"));
254 }
255 Ok(out)
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::LcgRng;

    /// Number of scalars in an image matching [`cfg`]: 3 channels of 16x16.
    const IMAGE_LEN: usize = 3 * 16 * 16;

    /// Baseline configuration: 16x16 image, 4x4 patches, 3 channels, 8-dim
    /// tokens (so a 4x4 grid of 16 patches).
    fn cfg() -> VitPatchConfig {
        VitPatchConfig {
            image_size: 16,
            patch_size: 4,
            n_channels: 3,
            d_model: 8,
        }
    }

    /// Image matching [`cfg`] with every element set to `value`.
    fn uniform_image(value: f32) -> Vec<f32> {
        vec![value; IMAGE_LEN]
    }

    /// A 16/4 configuration yields 16 patches.
    #[test]
    fn n_patches_correct() {
        let mut rng = LcgRng::new(1);
        let pe = VitPatchEmbed::new(cfg(), &mut rng).expect("ok");
        assert_eq!(pe.n_patches(), 16);
    }

    /// Output holds 16 patch tokens plus the CLS token, 8 values each.
    #[test]
    fn forward_shape() {
        let mut rng = LcgRng::new(2);
        let pe = VitPatchEmbed::new(cfg(), &mut rng).expect("ok");
        let out = pe.forward(&uniform_image(0.5)).expect("forward ok");
        assert_eq!(out.len(), 17 * 8);
    }

    /// A random image produces only finite outputs.
    #[test]
    fn forward_finite() {
        let mut rng = LcgRng::new(3);
        let pe = VitPatchEmbed::new(cfg(), &mut rng).expect("ok");
        let mut image = uniform_image(0.0);
        rng.fill_normal(&mut image);
        let out = pe.forward(&image).expect("ok");
        assert!(out.iter().all(|v| v.is_finite()));
    }

    /// The first `d_model` outputs are the CLS token plus its position
    /// embedding, so subtracting the latter recovers the former.
    #[test]
    fn cls_token_prepended() {
        let mut rng = LcgRng::new(4);
        let pe = VitPatchEmbed::new(cfg(), &mut rng).expect("ok");
        let out = pe.forward(&uniform_image(0.0)).expect("ok");
        let d = pe.config().d_model;
        let first_token = out.iter().zip(&pe.pos_emb).zip(&pe.cls_token).take(d);
        for (o, ((out_o, pos), cls)) in first_token.enumerate() {
            let recovered = out_o - pos;
            assert!(
                (recovered - cls).abs() < 1e-5,
                "CLS token not recovered at dim {o}"
            );
        }
    }

    /// An image size the patch size does not divide is rejected.
    #[test]
    fn image_size_not_divisible_error() {
        let bad = VitPatchConfig {
            image_size: 15,
            ..cfg()
        };
        let mut rng = LcgRng::new(5);
        let r = VitPatchEmbed::new(bad, &mut rng);
        assert!(matches!(r, Err(VisionError::InvalidPatchSize { .. })));
    }

    /// A zero patch size is rejected rather than dividing by zero.
    #[test]
    fn patch_size_0_error() {
        let bad = VitPatchConfig {
            patch_size: 0,
            ..cfg()
        };
        let mut rng = LcgRng::new(6);
        let r = VitPatchEmbed::new(bad, &mut rng);
        assert!(matches!(r, Err(VisionError::InvalidPatchSize { .. })));
    }

    /// Changing a single pixel changes the embedding.
    #[test]
    fn different_images_different_embeds() {
        let mut rng = LcgRng::new(7);
        let pe = VitPatchEmbed::new(cfg(), &mut rng).expect("ok");
        let img_a = uniform_image(0.2);
        let mut img_b = uniform_image(0.2);
        img_b[0] = 5.0;
        let out_a = pe.forward(&img_a).expect("ok");
        let out_b = pe.forward(&img_b).expect("ok");
        let mut diff = 0.0_f32;
        for (a, b) in out_a.iter().zip(out_b.iter()) {
            diff += (a - b).abs();
        }
        assert!(diff > 1e-6, "embeddings should differ for different images");
    }

    /// With weights, bias and CLS token zeroed, the output is exactly the
    /// position embedding.
    #[test]
    fn pos_emb_added() {
        let mut rng = LcgRng::new(8);
        let mut pe = VitPatchEmbed::new(cfg(), &mut rng).expect("ok");
        pe.proj_w.fill(0.0);
        pe.proj_b.fill(0.0);
        pe.cls_token.fill(0.0);
        let out = pe.forward(&uniform_image(3.0)).expect("ok");
        for (o, p) in out.iter().zip(&pe.pos_emb) {
            assert!((o - p).abs() < 1e-6, "output must equal pos_emb");
        }
    }

    /// A zero embedding dimension is rejected.
    #[test]
    fn d_model_0_error() {
        let bad = VitPatchConfig {
            d_model: 0,
            ..cfg()
        };
        let mut rng = LcgRng::new(9);
        let r = VitPatchEmbed::new(bad, &mut rng);
        assert!(matches!(r, Err(VisionError::InvalidEmbedDim(0))));
    }

    /// A patch covering the whole image yields one patch token plus CLS.
    #[test]
    fn single_patch() {
        let single = VitPatchConfig {
            image_size: 8,
            patch_size: 8,
            n_channels: 3,
            d_model: 4,
        };
        let mut rng = LcgRng::new(10);
        let pe = VitPatchEmbed::new(single, &mut rng).expect("ok");
        assert_eq!(pe.n_patches(), 1);
        let image = vec![0.5_f32; 3 * 8 * 8];
        let out = pe.forward(&image).expect("ok");
        assert_eq!(out.len(), 2 * 4);
        assert!(out.iter().all(|v| v.is_finite()));
    }

    /// A zero channel count is rejected.
    #[test]
    fn n_channels_0_error() {
        let bad = VitPatchConfig {
            n_channels: 0,
            ..cfg()
        };
        let mut rng = LcgRng::new(11);
        let r = VitPatchEmbed::new(bad, &mut rng);
        assert!(matches!(r, Err(VisionError::InvalidImageSize { .. })));
    }

    /// An image buffer of the wrong length is rejected by `forward`.
    #[test]
    fn forward_wrong_image_len_error() {
        let mut rng = LcgRng::new(12);
        let pe = VitPatchEmbed::new(cfg(), &mut rng).expect("ok");
        let image = vec![0.5_f32; 10];
        let r = pe.forward(&image);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    /// With a one-hot weight on the first flattened element and everything
    /// else zeroed, the first patch token equals the top-left pixel.
    #[test]
    fn patch_flatten_matches_manual() {
        let c = VitPatchConfig {
            image_size: 4,
            patch_size: 2,
            n_channels: 1,
            d_model: 1,
        };
        let mut rng = LcgRng::new(13);
        let mut pe = VitPatchEmbed::new(c, &mut rng).expect("ok");
        pe.proj_w = vec![1.0, 0.0, 0.0, 0.0];
        pe.proj_b = vec![0.0];
        pe.cls_token = vec![0.0];
        pe.pos_emb.fill(0.0);
        let mut image = vec![0.0_f32; 16];
        image[0] = 7.0;
        let out = pe.forward(&image).expect("ok");
        assert!((out[1] - 7.0).abs() < 1e-6, "got {}", out[1]);
    }
}