use crate::{
blocks::VisionRng,
error::{VisionError, VisionResult},
};
/// Shape parameters for a ViT-style patch embedding over square images.
///
/// Call [`VitPatchConfig::validate`] before relying on the derived sizes.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct VitPatchConfig {
    /// Side length of the (square) input image, in pixels.
    pub image_size: usize,
    /// Side length of each (square) patch, in pixels; must divide `image_size`.
    pub patch_size: usize,
    /// Number of input channels.
    pub n_channels: usize,
    /// Width of each output token.
    pub d_model: usize,
}
impl VitPatchConfig {
    /// Checks that this configuration describes a usable patch embedding.
    ///
    /// Besides rejecting zero / non-dividing sizes, this also rejects configurations
    /// whose derived buffer sizes would not fit in `usize`. Those products are later
    /// computed with plain (wrapping-in-release) arithmetic, so an oversized config
    /// would otherwise yield bogus buffer lengths instead of an error.
    ///
    /// # Errors
    /// - [`VisionError::InvalidPatchSize`] if `patch_size` is zero or does not evenly
    ///   divide `image_size`.
    /// - [`VisionError::InvalidEmbedDim`] if `d_model` is zero, or if the weight /
    ///   positional-embedding buffers sized by `d_model` would overflow `usize`.
    /// - [`VisionError::InvalidImageSize`] if `image_size` or `n_channels` is zero, or
    ///   if `n_channels * image_size * image_size` would overflow `usize`.
    pub fn validate(&self) -> VisionResult<()> {
        if self.patch_size == 0 || self.image_size % self.patch_size != 0 {
            return Err(VisionError::InvalidPatchSize {
                patch_size: self.patch_size,
                img_size: self.image_size,
            });
        }
        if self.d_model == 0 {
            return Err(VisionError::InvalidEmbedDim(self.d_model));
        }
        // Images are square, so height and width are both `image_size`.
        let image_error = || VisionError::InvalidImageSize {
            height: self.image_size,
            width: self.image_size,
            channels: self.n_channels,
        };
        if self.image_size == 0 || self.n_channels == 0 {
            return Err(image_error());
        }
        // `patch_size` divides a non-zero `image_size`, so `patch_dim()` and the
        // squared grid size are both bounded by the full image element count:
        // checking that one product keeps them from wrapping too.
        let image_len = self
            .n_channels
            .checked_mul(self.image_size)
            .and_then(|n| n.checked_mul(self.image_size));
        if image_len.is_none() {
            return Err(image_error());
        }
        // Sizes of the projection-weight and positional-embedding buffers
        // (`d_model * patch_dim` and `(grid^2 + 1) * d_model`).
        let grid = self.grid_size();
        let weights_len = self.patch_dim().checked_mul(self.d_model);
        let positions_len = grid
            .checked_mul(grid)
            .and_then(|g| g.checked_add(1))
            .and_then(|t| t.checked_mul(self.d_model));
        if weights_len.is_none() || positions_len.is_none() {
            return Err(VisionError::InvalidEmbedDim(self.d_model));
        }
        Ok(())
    }

    /// Number of patches along each side of the image (`image_size / patch_size`).
    ///
    /// # Panics
    /// Panics on division by zero if `patch_size` is zero; call
    /// [`VitPatchConfig::validate`] first.
    #[must_use]
    #[inline]
    pub fn grid_size(&self) -> usize {
        self.image_size / self.patch_size
    }

    /// Length of one flattened patch: `patch_size * patch_size * n_channels`.
    #[must_use]
    #[inline]
    pub fn patch_dim(&self) -> usize {
        self.patch_size * self.patch_size * self.n_channels
    }
}
/// Linear patch embedding with a prepended CLS token and learned positional
/// embeddings, as used at the input of a Vision Transformer.
#[derive(Debug, Clone)]
pub struct VitPatchEmbed {
    /// Projection weights, row-major `[d_model][patch_dim]`.
    proj_w: Vec<f32>,
    /// Projection bias, one value per output dimension (`d_model`).
    proj_b: Vec<f32>,
    /// CLS token emitted as the first output token (`d_model` values).
    cls_token: Vec<f32>,
    /// Positional embeddings for every output token, `[n_patches + 1][d_model]`.
    pos_emb: Vec<f32>,
    /// Validated shape parameters.
    config: VitPatchConfig,
}
impl VitPatchEmbed {
    /// Creates a patch embedding with freshly initialised parameters.
    ///
    /// `config` is validated first. Parameters are then drawn from `rng` in a fixed
    /// order (projection weights, projection bias, CLS token, positional embeddings).
    /// Each buffer is filled with `rng.fill_normal` and then multiplied by a
    /// per-buffer scale:
    ///
    /// - `proj_w` (`d_model * patch_dim` values): `1 / sqrt(patch_dim)`
    /// - `proj_b` (`d_model` values): `0.01`
    /// - `cls_token` (`d_model` values): `0.02`
    /// - `pos_emb` (`(n_patches + 1) * d_model` values): `0.02`
    ///
    /// # Errors
    /// Propagates any error returned by [`VitPatchConfig::validate`].
    pub fn new(config: VitPatchConfig, rng: &mut VisionRng) -> VisionResult<Self> {
        config.validate()?;
        let patch_dim = config.patch_dim();
        let d_model = config.d_model;
        let n_tokens = config.grid_size() * config.grid_size() + 1;
        // Keep this draw order: with a seeded rng it determines which parameters
        // a given seed produces.
        let proj_w = Self::scaled_normal(rng, d_model * patch_dim, 1.0 / (patch_dim as f32).sqrt());
        let proj_b = Self::scaled_normal(rng, d_model, 0.01);
        let cls_token = Self::scaled_normal(rng, d_model, 0.02);
        let pos_emb = Self::scaled_normal(rng, n_tokens * d_model, 0.02);
        Ok(Self {
            proj_w,
            proj_b,
            cls_token,
            pos_emb,
            config,
        })
    }

    /// Allocates `len` values, fills them via `rng.fill_normal`, and multiplies each
    /// by `scale`.
    fn scaled_normal(rng: &mut VisionRng, len: usize, scale: f32) -> Vec<f32> {
        let mut values = vec![0.0_f32; len];
        rng.fill_normal(&mut values);
        for v in &mut values {
            *v *= scale;
        }
        values
    }

    /// The configuration this embedding was built with.
    #[must_use]
    #[inline]
    pub fn config(&self) -> &VitPatchConfig {
        &self.config
    }

    /// Number of patch tokens produced per image (excluding the CLS token).
    #[must_use]
    #[inline]
    pub fn n_patches(&self) -> usize {
        self.config.grid_size() * self.config.grid_size()
    }

    /// Embeds one image into `n_patches() + 1` tokens of width `d_model`.
    ///
    /// `image` must hold `n_channels * image_size * image_size` values stored
    /// channel by channel. Each channel is a contiguous `image_size x image_size`
    /// plane, and the rows within a plane are contiguous.
    ///
    /// The result is laid out `[token][d_model]`. Token 0 is the CLS token, and
    /// tokens `1..` are the linearly projected patches in row-major grid order. The
    /// positional embedding is then added to every token.
    ///
    /// # Errors
    /// - [`VisionError::DimensionMismatch`] if `image` has the wrong length.
    /// - [`VisionError::NonFinite`] if any output value is NaN or infinite.
    pub fn forward(&self, image: &[f32]) -> VisionResult<Vec<f32>> {
        let cfg = &self.config;
        let expected = cfg.n_channels * cfg.image_size * cfg.image_size;
        if image.len() != expected {
            return Err(VisionError::DimensionMismatch {
                expected,
                got: image.len(),
            });
        }
        let grid = cfg.grid_size();
        let d_model = cfg.d_model;

        let mut out = vec![0.0_f32; (grid * grid + 1) * d_model];
        let (cls, patch_tokens) = out.split_at_mut(d_model);
        cls.copy_from_slice(&self.cls_token);

        // A single scratch buffer holds the flattened pixels of the current patch.
        let mut flat = vec![0.0_f32; cfg.patch_dim()];
        let cells = (0..grid).flat_map(|gy| (0..grid).map(move |gx| (gy, gx)));
        for ((gy, gx), token) in cells.zip(patch_tokens.chunks_exact_mut(d_model)) {
            self.gather_patch(image, gy, gx, &mut flat);
            self.project(&flat, token);
        }

        for (o, p) in out.iter_mut().zip(&self.pos_emb) {
            *o += *p;
        }
        if out.iter().any(|v| !v.is_finite()) {
            return Err(VisionError::NonFinite("ViT patch embedding output"));
        }
        Ok(out)
    }

    /// Copies grid cell `(gy, gx)` of `image` into `flat`, laid out as
    /// `[channel][row within patch][col within patch]`.
    fn gather_patch(&self, image: &[f32], gy: usize, gx: usize, flat: &mut [f32]) {
        let patch = self.config.patch_size;
        let img = self.config.image_size;
        let plane = img * img;
        // Chunk `i` of `flat` is row `i % patch` of channel `i / patch`.
        for (i, dst) in flat.chunks_exact_mut(patch).enumerate() {
            let (c, ph) = (i / patch, i % patch);
            let src = c * plane + (gy * patch + ph) * img + gx * patch;
            dst.copy_from_slice(&image[src..src + patch]);
        }
    }

    /// Writes `proj_w · flat + proj_b` into `token`, one value per row of `proj_w`.
    fn project(&self, flat: &[f32], token: &mut [f32]) {
        let rows = self.proj_w.chunks_exact(flat.len());
        for ((slot, w_row), &bias) in token.iter_mut().zip(rows).zip(&self.proj_b) {
            // Accumulate from the bias in index order, matching a plain dot-product loop.
            *slot = w_row.iter().zip(flat).fold(bias, |acc, (w, x)| acc + w * x);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::LcgRng;

    /// Shared 16x16 RGB config: a 4x4 grid of 4x4 patches, 8-wide tokens.
    fn cfg() -> VitPatchConfig {
        VitPatchConfig {
            image_size: 16,
            patch_size: 4,
            n_channels: 3,
            d_model: 8,
        }
    }

    /// Zeroes bias, CLS token and positional embeddings so `forward` exposes the
    /// raw projection `proj_w · patch` for each patch token.
    fn isolate_projection(pe: &mut VitPatchEmbed) {
        pe.proj_b.iter_mut().for_each(|b| *b = 0.0);
        pe.cls_token.iter_mut().for_each(|c| *c = 0.0);
        pe.pos_emb.iter_mut().for_each(|p| *p = 0.0);
    }

    #[test]
    fn n_patches_correct() {
        let mut rng = LcgRng::new(1);
        let pe = VitPatchEmbed::new(cfg(), &mut rng).expect("ok");
        assert_eq!(pe.n_patches(), 16);
    }

    #[test]
    fn forward_shape() {
        let mut rng = LcgRng::new(2);
        let pe = VitPatchEmbed::new(cfg(), &mut rng).expect("ok");
        let image = vec![0.5_f32; 3 * 16 * 16];
        let out = pe.forward(&image).expect("forward ok");
        assert_eq!(out.len(), 17 * 8);
    }

    #[test]
    fn forward_finite() {
        let mut rng = LcgRng::new(3);
        let pe = VitPatchEmbed::new(cfg(), &mut rng).expect("ok");
        let mut image = vec![0.0_f32; 3 * 16 * 16];
        rng.fill_normal(&mut image);
        let out = pe.forward(&image).expect("ok");
        assert!(out.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn cls_token_prepended() {
        let mut rng = LcgRng::new(4);
        let pe = VitPatchEmbed::new(cfg(), &mut rng).expect("ok");
        let image = vec![0.0_f32; 3 * 16 * 16];
        let out = pe.forward(&image).expect("ok");
        let d = pe.config().d_model;
        for (o, &out_o) in out.iter().enumerate().take(d) {
            let recovered = out_o - pe.pos_emb[o];
            assert!(
                (recovered - pe.cls_token[o]).abs() < 1e-5,
                "CLS token not recovered at dim {o}"
            );
        }
    }

    #[test]
    fn image_size_not_divisible_error() {
        let bad = VitPatchConfig {
            image_size: 15,
            patch_size: 4,
            n_channels: 3,
            d_model: 8,
        };
        let mut rng = LcgRng::new(5);
        let r = VitPatchEmbed::new(bad, &mut rng);
        assert!(matches!(r, Err(VisionError::InvalidPatchSize { .. })));
    }

    #[test]
    fn patch_size_0_error() {
        let bad = VitPatchConfig {
            image_size: 16,
            patch_size: 0,
            n_channels: 3,
            d_model: 8,
        };
        let mut rng = LcgRng::new(6);
        let r = VitPatchEmbed::new(bad, &mut rng);
        assert!(matches!(r, Err(VisionError::InvalidPatchSize { .. })));
    }

    #[test]
    fn different_images_different_embeds() {
        let mut rng = LcgRng::new(7);
        let pe = VitPatchEmbed::new(cfg(), &mut rng).expect("ok");
        let img_a = vec![0.2_f32; 3 * 16 * 16];
        let mut img_b = vec![0.2_f32; 3 * 16 * 16];
        img_b[0] = 5.0;
        let out_a = pe.forward(&img_a).expect("ok");
        let out_b = pe.forward(&img_b).expect("ok");
        let diff: f32 = out_a
            .iter()
            .zip(out_b.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(diff > 1e-6, "embeddings should differ for different images");
    }

    #[test]
    fn pos_emb_added() {
        let mut rng = LcgRng::new(8);
        let mut pe = VitPatchEmbed::new(cfg(), &mut rng).expect("ok");
        for w in &mut pe.proj_w {
            *w = 0.0;
        }
        for b in &mut pe.proj_b {
            *b = 0.0;
        }
        for c in &mut pe.cls_token {
            *c = 0.0;
        }
        let image = vec![3.0_f32; 3 * 16 * 16];
        let out = pe.forward(&image).expect("ok");
        for (o, p) in out.iter().zip(pe.pos_emb.iter()) {
            assert!((o - p).abs() < 1e-6, "output must equal pos_emb");
        }
    }

    #[test]
    fn d_model_0_error() {
        let bad = VitPatchConfig {
            image_size: 16,
            patch_size: 4,
            n_channels: 3,
            d_model: 0,
        };
        let mut rng = LcgRng::new(9);
        let r = VitPatchEmbed::new(bad, &mut rng);
        assert!(matches!(r, Err(VisionError::InvalidEmbedDim(0))));
    }

    #[test]
    fn single_patch() {
        let single = VitPatchConfig {
            image_size: 8,
            patch_size: 8,
            n_channels: 3,
            d_model: 4,
        };
        let mut rng = LcgRng::new(10);
        let pe = VitPatchEmbed::new(single, &mut rng).expect("ok");
        assert_eq!(pe.n_patches(), 1);
        let image = vec![0.5_f32; 3 * 8 * 8];
        let out = pe.forward(&image).expect("ok");
        assert_eq!(out.len(), 2 * 4);
        assert!(out.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn n_channels_0_error() {
        let bad = VitPatchConfig {
            image_size: 16,
            patch_size: 4,
            n_channels: 0,
            d_model: 8,
        };
        let mut rng = LcgRng::new(11);
        let r = VitPatchEmbed::new(bad, &mut rng);
        assert!(matches!(r, Err(VisionError::InvalidImageSize { .. })));
    }

    #[test]
    fn image_size_0_error() {
        // 0 % 4 == 0, so this passes the divisibility check and must be caught
        // by the explicit zero-size check instead.
        let bad = VitPatchConfig {
            image_size: 0,
            patch_size: 4,
            n_channels: 3,
            d_model: 8,
        };
        let mut rng = LcgRng::new(16);
        let r = VitPatchEmbed::new(bad, &mut rng);
        assert!(matches!(r, Err(VisionError::InvalidImageSize { .. })));
    }

    #[test]
    fn forward_wrong_image_len_error() {
        let mut rng = LcgRng::new(12);
        let pe = VitPatchEmbed::new(cfg(), &mut rng).expect("ok");
        let image = vec![0.5_f32; 10];
        let r = pe.forward(&image);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn patch_flatten_matches_manual() {
        let c = VitPatchConfig {
            image_size: 4,
            patch_size: 2,
            n_channels: 1,
            d_model: 1,
        };
        let mut rng = LcgRng::new(13);
        let mut pe = VitPatchEmbed::new(c, &mut rng).expect("ok");
        pe.proj_w = vec![1.0, 0.0, 0.0, 0.0];
        pe.proj_b = vec![0.0];
        pe.cls_token = vec![0.0];
        for p in &mut pe.pos_emb {
            *p = 0.0;
        }
        let mut image = vec![0.0_f32; 16];
        image[0] = 7.0;
        let out = pe.forward(&image).expect("ok");
        assert!((out[1] - 7.0).abs() < 1e-6, "got {}", out[1]);
    }

    #[test]
    fn patches_in_row_major_grid_order() {
        let c = VitPatchConfig {
            image_size: 4,
            patch_size: 2,
            n_channels: 1,
            d_model: 1,
        };
        let mut rng = LcgRng::new(14);
        let mut pe = VitPatchEmbed::new(c, &mut rng).expect("ok");
        // The single output dimension reads the top-left pixel of each patch.
        pe.proj_w = vec![1.0, 0.0, 0.0, 0.0];
        isolate_projection(&mut pe);
        let mut image = vec![0.0_f32; 16];
        image[2] = 9.0; // row 0, col 2 -> patch (0, 1)
        image[8] = 5.0; // row 2, col 0 -> patch (1, 0)
        image[10] = 3.0; // row 2, col 2 -> patch (1, 1)
        let out = pe.forward(&image).expect("ok");
        assert_eq!(out, vec![0.0, 0.0, 9.0, 5.0, 3.0]);
    }

    #[test]
    fn channels_flattened_channel_major() {
        let c = VitPatchConfig {
            image_size: 2,
            patch_size: 2,
            n_channels: 2,
            d_model: 2,
        };
        let mut rng = LcgRng::new(15);
        let mut pe = VitPatchEmbed::new(c, &mut rng).expect("ok");
        // patch_dim = 8: flat = [ch0 (4 px), ch1 (4 px)].
        let mut w = vec![0.0_f32; 2 * 8];
        w[4] = 1.0; // output 0 reads channel 1, pixel (0, 0)
        w[8 + 3] = 1.0; // output 1 reads channel 0, pixel (1, 1)
        pe.proj_w = w;
        isolate_projection(&mut pe);
        let image: Vec<f32> = (1..=8u8).map(f32::from).collect();
        let out = pe.forward(&image).expect("ok");
        assert_eq!(out, vec![0.0, 0.0, 5.0, 4.0]);
    }
}