use crate::{
error::{VisionError, VisionResult},
handle::LcgRng,
};
/// Geometry of a square-image patch embedding.
///
/// Build it with [`PatchEmbedConfig::new`], which guarantees that `patch_size`
/// tiles `img_size` exactly and that no dimension is zero. The fields are
/// public, so a value assembled by hand bypasses those checks.
#[derive(Debug, Clone, PartialEq)]
pub struct PatchEmbedConfig {
    /// Side length of the square input image, in pixels.
    pub img_size: usize,
    /// Side length of each square patch, in pixels.
    pub patch_size: usize,
    /// Number of input channels.
    pub in_chans: usize,
    /// Length of each output embedding vector.
    pub embed_dim: usize,
}

impl PatchEmbedConfig {
    /// Validates the geometry and builds a config.
    ///
    /// # Errors
    ///
    /// The checks run in this order, and the first failure is returned:
    /// 1. `VisionError::InvalidPatchSize`: `patch_size` is zero or does not
    ///    divide `img_size`.
    /// 2. `VisionError::InvalidEmbedDim`: `embed_dim` is zero.
    /// 3. `VisionError::InvalidImageSize`: `img_size` or `in_chans` is zero.
    pub fn new(
        img_size: usize,
        patch_size: usize,
        in_chans: usize,
        embed_dim: usize,
    ) -> VisionResult<Self> {
        // `&&` short-circuits, so the modulo never sees a zero divisor.
        let tiles_evenly = patch_size != 0 && img_size % patch_size == 0;
        if !tiles_evenly {
            return Err(VisionError::InvalidPatchSize {
                patch_size,
                img_size,
            });
        }

        if embed_dim == 0 {
            return Err(VisionError::InvalidEmbedDim(embed_dim));
        }

        let degenerate_image = img_size == 0 || in_chans == 0;
        if degenerate_image {
            return Err(VisionError::InvalidImageSize {
                height: img_size,
                width: img_size,
                channels: in_chans,
            });
        }

        Ok(Self {
            img_size,
            patch_size,
            in_chans,
            embed_dim,
        })
    }

    /// Number of patches along each side of the image.
    #[must_use]
    pub fn grid_size(&self) -> usize {
        self.img_size / self.patch_size
    }

    /// Total number of patches (`grid_size()` squared).
    #[must_use]
    pub fn n_patches(&self) -> usize {
        let side = self.grid_size();
        side * side
    }

    /// Number of kernel weights per output dimension:
    /// `in_chans * patch_size * patch_size`.
    #[must_use]
    pub fn kernel_vol(&self) -> usize {
        let p = self.patch_size;
        self.in_chans * p * p
    }
}
/// Learnable parameters of a patch embedding.
///
/// Layouts, as indexed by `PatchEmbed::forward`:
/// * `kernel`: `embed_dim * kernel_vol()` values, one contiguous row of
///   `kernel_vol()` weights per output dimension. Each row is ordered
///   (channel, patch row, patch column).
/// * `bias`: `embed_dim` values, one per output dimension.
/// * `cls_token`: `embed_dim` values. The forward pass does not read it;
///   presumably callers feed it to `prepend_cls` (confirm with callers).
///
/// Derives the same traits as `PatchEmbedConfig` so weights can be logged,
/// duplicated and compared in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct PatchEmbedWeights {
    pub kernel: Vec<f32>,
    pub bias: Vec<f32>,
    pub cls_token: Vec<f32>,
}

impl PatchEmbedWeights {
    /// Randomly initialises all parameters for `cfg` from `rng`.
    ///
    /// Samples come from `rng.fill_normal`, which is assumed to produce
    /// standard-normal draws (TODO confirm). They are then scaled as follows:
    /// * kernel: by `1 / sqrt(kernel_vol())`, the fan-in of each output.
    /// * bias: by `0.01`.
    /// * class token: by `0.02`.
    ///
    /// The draw order (kernel, then bias, then class token) is fixed, so a
    /// given RNG state always yields the same weights.
    pub fn default_init(cfg: &PatchEmbedConfig, rng: &mut LcgRng) -> Self {
        let kv = cfg.kernel_vol();
        let kernel = Self::scaled_normal(rng, cfg.embed_dim * kv, 1.0 / (kv as f32).sqrt());
        let bias = Self::scaled_normal(rng, cfg.embed_dim, 0.01);
        let cls_token = Self::scaled_normal(rng, cfg.embed_dim, 0.02);
        Self {
            kernel,
            bias,
            cls_token,
        }
    }

    /// Draws `len` samples from `rng` and multiplies each one by `scale`.
    fn scaled_normal(rng: &mut LcgRng, len: usize, scale: f32) -> Vec<f32> {
        let mut values = vec![0.0f32; len];
        rng.fill_normal(&mut values);
        values.iter_mut().for_each(|v| *v *= scale);
        values
    }
}
/// Patch-embedding layer.
///
/// Splits a square image into non-overlapping `patch_size x patch_size` tiles.
/// Each tile, across all input channels, is projected to an `embed_dim` vector
/// with a shared kernel plus bias. This is a strided convolution whose stride
/// equals its kernel size.
pub struct PatchEmbed {
    /// Image and patch geometry.
    pub config: PatchEmbedConfig,
    /// Projection parameters. They are public and may be replaced after
    /// construction, so `forward` checks their lengths against `config`.
    pub weights: PatchEmbedWeights,
}

impl PatchEmbed {
    /// Builds a patch embedding for `cfg` with weights freshly drawn from `rng`.
    pub fn new(cfg: PatchEmbedConfig, rng: &mut LcgRng) -> Self {
        let weights = PatchEmbedWeights::default_init(&cfg, rng);
        Self {
            config: cfg,
            weights,
        }
    }

    /// Projects every patch of `image` to an embedding.
    ///
    /// `image` is read channel-major: element `(c, row, col)` is at
    /// `c * img_size^2 + row * img_size + col`. The result holds
    /// `n_patches() * embed_dim` values: one `embed_dim` row per patch, with
    /// patches in row-major order over the patch grid.
    ///
    /// # Errors
    ///
    /// Returns `VisionError::DimensionMismatch` in any of these cases:
    /// * `image.len()` differs from `in_chans * img_size^2`.
    /// * `weights.kernel.len()` differs from `embed_dim * kernel_vol()`.
    /// * `weights.bias.len()` differs from `embed_dim`.
    ///
    /// # Panics
    ///
    /// Panics if `config.patch_size` is zero. That can only happen when the
    /// config was assembled by hand, bypassing `PatchEmbedConfig::new`.
    pub fn forward(&self, image: &[f32]) -> VisionResult<Vec<f32>> {
        let cfg = &self.config;
        let img = cfg.img_size;
        let plane = img * img;
        let expected = cfg.in_chans * plane;
        if image.len() != expected {
            return Err(VisionError::DimensionMismatch {
                expected,
                got: image.len(),
            });
        }

        let e = cfg.embed_dim;
        let kv = cfg.kernel_vol();
        let kernel = &self.weights.kernel;
        let bias = &self.weights.bias;
        // Report mismatched (e.g. externally loaded) weights as an error
        // instead of panicking on an out-of-bounds index below.
        if kernel.len() != e * kv {
            return Err(VisionError::DimensionMismatch {
                expected: e * kv,
                got: kernel.len(),
            });
        }
        if bias.len() != e {
            return Err(VisionError::DimensionMismatch {
                expected: e,
                got: bias.len(),
            });
        }

        let grid = cfg.grid_size();
        let p = cfg.patch_size;
        let mut out = vec![0.0f32; cfg.n_patches() * e];
        // One patch flattened in (channel, row, col) order, matching a kernel row.
        let mut patch = vec![0.0f32; kv];

        for ph in 0..grid {
            for pw in 0..grid {
                // Gather the patch once, copying one p-wide image row at a time,
                // rather than re-reading the image for every output dimension.
                for (r, dst) in patch.chunks_exact_mut(p).enumerate() {
                    let (ci, pi) = (r / p, r % p);
                    let start = ci * plane + (ph * p + pi) * img + pw * p;
                    dst.copy_from_slice(&image[start..start + p]);
                }

                let patch_idx = ph * grid + pw;
                let out_row = &mut out[patch_idx * e..(patch_idx + 1) * e];
                for (ed, (dst, &b)) in out_row.iter_mut().zip(bias.iter()).enumerate() {
                    let kernel_row = &kernel[ed * kv..(ed + 1) * kv];
                    // Start from the bias and accumulate in (channel, row, col)
                    // order, the same order as a direct nested-loop convolution.
                    *dst = kernel_row
                        .iter()
                        .zip(&patch)
                        .fold(b, |acc, (&w, &x)| acc + w * x);
                }
            }
        }
        Ok(out)
    }
}
/// Prepends a class token to a flattened token sequence.
///
/// `tokens` holds `n_tok` consecutive vectors of length `embed_dim`. The
/// result holds `n_tok + 1` such vectors, with `cls` first and the original
/// tokens after it, unchanged.
///
/// # Errors
///
/// * `VisionError::InvalidEmbedDim` if `embed_dim` is zero. There is no token
///   width to split `tokens` by; previously this hit an integer division by zero.
/// * `VisionError::DimensionMismatch` if `tokens.len()` is not a multiple of
///   `embed_dim`. `expected` is then the largest whole-token length below it.
/// * `VisionError::DimensionMismatch` if `cls.len() != embed_dim`.
pub fn prepend_cls(tokens: &[f32], cls: &[f32], embed_dim: usize) -> VisionResult<Vec<f32>> {
    if embed_dim == 0 {
        return Err(VisionError::InvalidEmbedDim(embed_dim));
    }
    let n_tok = tokens.len() / embed_dim;
    if tokens.len() % embed_dim != 0 {
        return Err(VisionError::DimensionMismatch {
            expected: n_tok * embed_dim,
            got: tokens.len(),
        });
    }
    if cls.len() != embed_dim {
        return Err(VisionError::DimensionMismatch {
            expected: embed_dim,
            got: cls.len(),
        });
    }
    let mut out = Vec::with_capacity((n_tok + 1) * embed_dim);
    out.extend_from_slice(cls);
    out.extend_from_slice(tokens);
    Ok(out)
}
#[cfg(test)]
mod tests {
    //! Unit tests for configuration validation, weight initialisation, the
    //! forward projection and class-token prepending.
    use super::*;
    use crate::handle::LcgRng;

    /// 16x16 three-channel image, 4x4 patches, 8-dimensional embeddings.
    fn make_cfg() -> PatchEmbedConfig {
        PatchEmbedConfig::new(16, 4, 3, 8).expect("valid config")
    }

    #[test]
    fn config_valid() {
        let cfg = make_cfg();
        assert_eq!(cfg.n_patches(), 16);
        assert_eq!(cfg.grid_size(), 4);
        assert_eq!(cfg.kernel_vol(), 3 * 4 * 4);
    }

    #[test]
    fn config_invalid_patch_size_not_dividing() {
        let r = PatchEmbedConfig::new(16, 5, 3, 8);
        assert!(matches!(r, Err(VisionError::InvalidPatchSize { .. })));
    }

    #[test]
    fn config_invalid_patch_size_zero() {
        let r = PatchEmbedConfig::new(16, 0, 3, 8);
        assert!(matches!(r, Err(VisionError::InvalidPatchSize { .. })));
    }

    #[test]
    fn config_invalid_embed_dim_zero() {
        let r = PatchEmbedConfig::new(16, 4, 3, 0);
        assert!(matches!(r, Err(VisionError::InvalidEmbedDim(0))));
    }

    #[test]
    fn config_invalid_image_size_zero() {
        let r = PatchEmbedConfig::new(0, 4, 3, 8);
        assert!(matches!(r, Err(VisionError::InvalidImageSize { .. })));
        let r = PatchEmbedConfig::new(16, 4, 0, 8);
        assert!(matches!(r, Err(VisionError::InvalidImageSize { .. })));
    }

    #[test]
    fn forward_output_shape() {
        let cfg = make_cfg();
        let mut rng = LcgRng::new(1);
        let pe = PatchEmbed::new(cfg.clone(), &mut rng);
        let image = vec![0.5f32; 3 * 16 * 16];
        let out = pe.forward(&image).expect("forward ok");
        assert_eq!(out.len(), cfg.n_patches() * cfg.embed_dim);
    }

    #[test]
    fn forward_wrong_image_size_errors() {
        let cfg = make_cfg();
        let mut rng = LcgRng::new(2);
        let pe = PatchEmbed::new(cfg, &mut rng);
        let image = vec![0.5f32; 10];
        let r = pe.forward(&image);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn forward_zero_image_is_bias() {
        let cfg = make_cfg();
        let mut rng = LcgRng::new(3);
        let pe = PatchEmbed::new(cfg.clone(), &mut rng);
        let image = vec![0.0f32; 3 * 16 * 16];
        let out = pe.forward(&image).expect("forward ok");
        assert_eq!(out.len(), cfg.n_patches() * cfg.embed_dim);
        // With an all-zero image, every patch's embedding must equal the bias.
        for (patch, row) in out.chunks_exact(cfg.embed_dim).enumerate() {
            for (ed, (&got, &want)) in row.iter().zip(&pe.weights.bias).enumerate() {
                assert!(
                    (got - want).abs() < 1e-6,
                    "patch {}, dim {}: expected bias={}, got {}",
                    patch,
                    ed,
                    want,
                    got
                );
            }
        }
    }

    #[test]
    fn forward_finite_random_input() {
        let cfg = PatchEmbedConfig::new(32, 4, 3, 64).expect("valid");
        let mut rng = LcgRng::new(7);
        let pe = PatchEmbed::new(cfg.clone(), &mut rng);
        let mut image = vec![0.0f32; 3 * 32 * 32];
        rng.fill_normal(&mut image);
        let out = pe.forward(&image).expect("forward ok");
        assert!(
            out.iter().all(|v| v.is_finite()),
            "output contains non-finite"
        );
    }

    #[test]
    fn prepend_cls_shape() {
        let tokens = vec![1.0f32; 16 * 8];
        let cls = vec![0.0f32; 8];
        let out = prepend_cls(&tokens, &cls, 8).expect("ok");
        assert_eq!(out.len(), 17 * 8);
        assert!(out[..8].iter().all(|&v| v == 0.0));
        assert_eq!(out[8..16], tokens[..8]);
    }

    #[test]
    fn prepend_cls_wrong_cls_dim_errors() {
        let tokens = vec![1.0f32; 16 * 8];
        let cls = vec![0.0f32; 4];
        let r = prepend_cls(&tokens, &cls, 8);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn prepend_cls_misaligned_tokens_errors() {
        // 11 values cannot be split into whole 8-wide tokens.
        let tokens = vec![1.0f32; 8 + 3];
        let cls = vec![0.0f32; 8];
        let r = prepend_cls(&tokens, &cls, 8);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn weights_default_init_correct_size() {
        let cfg = make_cfg();
        let mut rng = LcgRng::new(42);
        let w = PatchEmbedWeights::default_init(&cfg, &mut rng);
        assert_eq!(w.kernel.len(), cfg.embed_dim * cfg.kernel_vol());
        assert_eq!(w.bias.len(), cfg.embed_dim);
        assert_eq!(w.cls_token.len(), cfg.embed_dim);
    }

    #[test]
    fn weights_default_init_finite() {
        let cfg = make_cfg();
        let mut rng = LcgRng::new(99);
        let w = PatchEmbedWeights::default_init(&cfg, &mut rng);
        assert!(w.kernel.iter().all(|v| v.is_finite()));
        assert!(w.bias.iter().all(|v| v.is_finite()));
        assert!(w.cls_token.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn patch_embed_different_seeds_differ() {
        let cfg = make_cfg();
        let image = vec![0.5f32; 3 * 16 * 16];
        let mut rng1 = LcgRng::new(1);
        let mut rng2 = LcgRng::new(2);
        let pe1 = PatchEmbed::new(cfg.clone(), &mut rng1);
        let pe2 = PatchEmbed::new(cfg, &mut rng2);
        let out1 = pe1.forward(&image).expect("ok");
        let out2 = pe2.forward(&image).expect("ok");
        assert!(out1
            .iter()
            .zip(out2.iter())
            .any(|(a, b)| (a - b).abs() > 1e-6));
    }
}