use crate::{
error::{VisionError, VisionResult},
handle::LcgRng,
patch_embed::{LearnablePosEmbed, PatchEmbed, PatchEmbedConfig, add_pos_embed, prepend_cls},
vit::{
vit_block::linear,
vit_encoder::{ViTEncoder, ViTEncoderConfig},
},
};
/// Hyper-parameters describing a Vision Transformer classifier.
///
/// Prefer [`ViTConfig::new`] (validated) or [`ViTConfig::tiny`] over a struct literal:
/// the fields are public, so a literal bypasses validation.
#[derive(Debug, Clone, PartialEq)]
pub struct ViTConfig {
    /// Side length of the square input image; a forward pass expects
    /// `in_chans * img_size * img_size` values.
    pub img_size: usize,
    /// Side length of each square patch; the patch grid is `img_size / patch_size` per side.
    pub patch_size: usize,
    /// Number of input channels.
    pub in_chans: usize,
    /// Width of every token embedding.
    pub embed_dim: usize,
    /// Number of encoder blocks; must be > 0.
    pub depth: usize,
    /// Head count forwarded to `ViTEncoderConfig::new` (presumably attention heads;
    /// `embed_dim` must be compatible with it).
    pub n_heads: usize,
    /// Expansion factor forwarded to `ViTEncoderConfig::new`
    /// (presumably the MLP hidden width is `mlp_ratio * embed_dim` — confirm there).
    pub mlp_ratio: usize,
    /// Number of output classes, i.e. the length of the logits vector; must be > 0.
    pub n_classes: usize,
}
impl ViTConfig {
    /// Small configuration used by the tests: 32x32 images with 3 channels, 4x4 patches,
    /// 64-wide embeddings, 2 encoder blocks, 4 heads, MLP ratio 4 and 10 classes.
    #[must_use]
    pub fn tiny() -> Self {
        Self {
            img_size: 32,
            patch_size: 4,
            in_chans: 3,
            embed_dim: 64,
            depth: 2,
            n_heads: 4,
            mlp_ratio: 4,
            n_classes: 10,
        }
    }
    /// Builds a validated configuration.
    ///
    /// # Errors
    ///
    /// - [`VisionError::InvalidNumClasses`] if `n_classes == 0`.
    /// - [`VisionError::Internal`] if `depth == 0`.
    /// - Any error from [`PatchEmbedConfig::new`] (e.g. a patch size that does not tile
    ///   the image) or [`ViTEncoderConfig::new`] (e.g. `embed_dim` incompatible with
    ///   `n_heads`).
    // One positional argument per public field is the established API; keep it and
    // silence the lint rather than break callers.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        img_size: usize,
        patch_size: usize,
        in_chans: usize,
        embed_dim: usize,
        depth: usize,
        n_heads: usize,
        mlp_ratio: usize,
        n_classes: usize,
    ) -> VisionResult<Self> {
        if n_classes == 0 {
            return Err(VisionError::InvalidNumClasses(n_classes));
        }
        if depth == 0 {
            return Err(VisionError::Internal("depth must be > 0".into()));
        }
        // Constructed only for their validation; the results are discarded.
        PatchEmbedConfig::new(img_size, patch_size, in_chans, embed_dim)?;
        ViTEncoderConfig::new(embed_dim, n_heads, mlp_ratio, depth)?;
        Ok(Self {
            img_size,
            patch_size,
            in_chans,
            embed_dim,
            depth,
            n_heads,
            mlp_ratio,
            n_classes,
        })
    }
    /// Number of patches, `(img_size / patch_size)^2`.
    ///
    /// # Panics
    ///
    /// Panics if `patch_size == 0` (integer division by zero), which can only happen
    /// for a config built as a struct literal.
    #[must_use]
    pub fn n_patches(&self) -> usize {
        let grid = self.img_size / self.patch_size;
        grid * grid
    }
    /// Token sequence length seen by the encoder: all patches plus one prepended CLS token.
    ///
    /// # Panics
    ///
    /// Same condition as [`ViTConfig::n_patches`].
    #[must_use]
    pub fn seq_len(&self) -> usize {
        self.n_patches() + 1
    }
}
/// Parameters of the linear classification head applied to the CLS representation.
pub struct ViTModelWeights {
    /// Head weight matrix with `n_classes * embed_dim` entries (layout as expected by
    /// `vit_block::linear` — confirm there).
    pub head_weight: Vec<f32>,
    /// Head bias with `n_classes` entries.
    pub head_bias: Vec<f32>,
}
impl ViTModelWeights {
    /// Initialises the head from `rng`: weights are drawn with `rng.fill_normal` and
    /// multiplied by `1 / sqrt(embed_dim)`; biases start at zero.
    fn default_init(cfg: &ViTConfig, rng: &mut LcgRng) -> Self {
        let weight_count = cfg.n_classes * cfg.embed_dim;
        let mut weights = vec![0.0f32; weight_count];
        rng.fill_normal(&mut weights);
        // Shrink the random draws so head outputs stay moderate as embed_dim grows.
        let inv_sqrt_dim = 1.0 / (cfg.embed_dim as f32).sqrt();
        weights.iter_mut().for_each(|w| *w *= inv_sqrt_dim);
        Self {
            head_weight: weights,
            head_bias: vec![0.0f32; cfg.n_classes],
        }
    }
}
/// Vision Transformer image classifier.
///
/// The model applies patch embedding, a learnable position embedding and a transformer
/// encoder, then a linear head on the CLS token.
pub struct ViTModel {
    /// Configuration the model was built from.
    pub config: ViTConfig,
    /// Patch embedding; its weights also hold the CLS token prepended in `forward`.
    pub patch_embed: PatchEmbed,
    /// Learnable position table created for `config.seq_len()` positions.
    pub pos_embed: LearnablePosEmbed,
    /// Transformer encoder built from `config`.
    pub encoder: ViTEncoder,
    /// Classification head parameters.
    pub weights: ViTModelWeights,
}
impl ViTModel {
    /// Builds a model whose parameters are initialised from `rng`.
    ///
    /// `cfg` is re-validated here because `ViTConfig`'s fields are public and it may not
    /// have come from [`ViTConfig::new`].
    ///
    /// # Errors
    ///
    /// - [`VisionError::InvalidNumClasses`] if `cfg.n_classes == 0`.
    /// - [`VisionError::Internal`] if `cfg.depth == 0`.
    /// - Any error from building the patch embedding, position embedding or encoder.
    pub fn new(cfg: ViTConfig, rng: &mut LcgRng) -> VisionResult<Self> {
        // Mirror the checks in `ViTConfig::new`. No component constructor sees
        // `n_classes`, so a zero-class head would otherwise silently yield empty logits.
        // Done before any RNG draw so valid configs consume `rng` exactly as before.
        if cfg.n_classes == 0 {
            return Err(VisionError::InvalidNumClasses(cfg.n_classes));
        }
        if cfg.depth == 0 {
            return Err(VisionError::Internal("depth must be > 0".into()));
        }
        let patch_cfg =
            PatchEmbedConfig::new(cfg.img_size, cfg.patch_size, cfg.in_chans, cfg.embed_dim)?;
        let patch_embed = PatchEmbed::new(patch_cfg, rng);
        let seq_len = cfg.seq_len();
        let pos_embed = LearnablePosEmbed::new(seq_len, cfg.embed_dim, rng)?;
        let enc_cfg = ViTEncoderConfig::new(cfg.embed_dim, cfg.n_heads, cfg.mlp_ratio, cfg.depth)?;
        let encoder = ViTEncoder::new(enc_cfg, rng)?;
        let weights = ViTModelWeights::default_init(&cfg, rng);
        Ok(Self {
            config: cfg,
            patch_embed,
            pos_embed,
            encoder,
            weights,
        })
    }
    /// Runs a forward pass over one flattened image and returns the class logits.
    ///
    /// The stages are, in order:
    /// 1. patch embedding;
    /// 2. prepend the CLS token;
    /// 3. add position embeddings;
    /// 4. run the encoder;
    /// 5. apply the linear head to the first (CLS) token of the encoder output.
    ///
    /// # Errors
    ///
    /// - [`VisionError::DimensionMismatch`] if `image.len()` is not
    ///   `in_chans * img_size * img_size`, or if the encoder output is shorter than one
    ///   `embed_dim`-wide token.
    /// - Any error propagated from patch embedding, CLS prepending, position embedding
    ///   or the encoder.
    pub fn forward(&self, image: &[f32]) -> VisionResult<Vec<f32>> {
        let cfg = &self.config;
        let expected_img = cfg.in_chans * cfg.img_size * cfg.img_size;
        if image.len() != expected_img {
            return Err(VisionError::DimensionMismatch {
                expected: expected_img,
                got: image.len(),
            });
        }
        let patch_tokens = self.patch_embed.forward(image)?;
        let cls_token = &self.patch_embed.weights.cls_token;
        let mut tokens = prepend_cls(&patch_tokens, cls_token, cfg.embed_dim)?;
        add_pos_embed(&mut tokens, &self.pos_embed.table, cfg.embed_dim)?;
        let seq_len = cfg.seq_len();
        let encoded = self.encoder.forward(&tokens, seq_len)?;
        // The CLS token was prepended, so it occupies the first `embed_dim` values.
        // A checked slice turns a malformed encoder output into an error instead of a panic.
        let cls_repr = encoded
            .get(..cfg.embed_dim)
            .ok_or_else(|| VisionError::DimensionMismatch {
                expected: cfg.embed_dim,
                got: encoded.len(),
            })?;
        let logits = linear(
            cls_repr,
            &self.weights.head_weight,
            &self.weights.head_bias,
            cfg.embed_dim,
            cfg.n_classes,
        );
        Ok(logits)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the tiny reference model with a fixed seed so tests are reproducible.
    fn make_tiny_model() -> ViTModel {
        let cfg = ViTConfig::tiny();
        let mut rng = LcgRng::new(42);
        ViTModel::new(cfg, &mut rng).expect("tiny model created")
    }

    #[test]
    fn tiny_config_values() {
        let cfg = ViTConfig::tiny();
        assert_eq!(cfg.img_size, 32);
        assert_eq!(cfg.patch_size, 4);
        assert_eq!(cfg.in_chans, 3);
        assert_eq!(cfg.embed_dim, 64);
        assert_eq!(cfg.depth, 2);
        assert_eq!(cfg.n_heads, 4);
        assert_eq!(cfg.mlp_ratio, 4);
        assert_eq!(cfg.n_classes, 10);
    }

    #[test]
    fn tiny_config_n_patches() {
        let cfg = ViTConfig::tiny();
        assert_eq!(cfg.n_patches(), 64);
        assert_eq!(cfg.seq_len(), 65);
    }

    #[test]
    fn config_new_matches_tiny() {
        let cfg = ViTConfig::new(32, 4, 3, 64, 2, 4, 4, 10).expect("valid config");
        assert_eq!(cfg, ViTConfig::tiny());
    }

    #[test]
    fn config_zero_classes_errors() {
        let r = ViTConfig::new(32, 4, 3, 64, 2, 4, 4, 0);
        assert!(matches!(r, Err(VisionError::InvalidNumClasses(0))));
    }

    #[test]
    fn config_zero_depth_errors() {
        let r = ViTConfig::new(32, 4, 3, 64, 0, 4, 4, 10);
        assert!(matches!(r, Err(VisionError::Internal(_))));
    }

    #[test]
    fn config_invalid_patch_size_errors() {
        // 32 is not divisible by 5.
        let r = ViTConfig::new(32, 5, 3, 64, 2, 4, 4, 10);
        assert!(matches!(r, Err(VisionError::InvalidPatchSize { .. })));
    }

    #[test]
    fn config_head_dim_mismatch_errors() {
        // 63 is not divisible by 4 heads.
        let r = ViTConfig::new(32, 4, 3, 63, 2, 4, 4, 10);
        assert!(matches!(r, Err(VisionError::HeadDimMismatch { .. })));
    }

    #[test]
    fn forward_returns_ten_logits() {
        let model = make_tiny_model();
        let image = vec![0.0f32; 3 * 32 * 32];
        let logits = model.forward(&image).expect("forward ok");
        assert_eq!(logits.len(), 10, "expected 10 logits, got {}", logits.len());
    }

    #[test]
    fn forward_logits_finite() {
        let model = make_tiny_model();
        let mut rng = LcgRng::new(7);
        let mut image = vec![0.0f32; 3 * 32 * 32];
        rng.fill_normal(&mut image);
        let logits = model.forward(&image).expect("forward ok");
        assert!(
            logits.iter().all(|v| v.is_finite()),
            "non-finite logits: {logits:?}"
        );
    }

    #[test]
    fn forward_random_input_not_constant_logits() {
        let model = make_tiny_model();
        let mut rng = LcgRng::new(13);
        let mut img1 = vec![0.0f32; 3 * 32 * 32];
        let mut img2 = vec![0.0f32; 3 * 32 * 32];
        rng.fill_normal(&mut img1);
        rng.fill_normal(&mut img2);
        let l1 = model.forward(&img1).expect("ok");
        let l2 = model.forward(&img2).expect("ok");
        let diff: f32 = l1.iter().zip(l2.iter()).map(|(a, b)| (a - b).abs()).sum();
        assert!(
            diff > 1e-6,
            "logits did not change between different images (diff={diff})"
        );
    }

    #[test]
    fn forward_wrong_image_size_errors() {
        let model = make_tiny_model();
        // One row short of a full 32x32 image.
        let image = vec![0.0f32; 3 * 32 * 31];
        let r = model.forward(&image);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn forward_correct_image_size_passes() {
        let model = make_tiny_model();
        let image = vec![0.5f32; 3 * 32 * 32];
        let logits = model
            .forward(&image)
            .expect("forward ok with constant image");
        assert_eq!(logits.len(), 10);
    }

    #[test]
    fn same_seed_produces_identical_logits() {
        let model_a = make_tiny_model();
        let model_b = make_tiny_model();
        let image = vec![0.25f32; 3 * 32 * 32];
        let la = model_a.forward(&image).expect("ok");
        let lb = model_b.forward(&image).expect("ok");
        assert_eq!(la, lb, "same seed should yield identical logits");
    }

    #[test]
    fn pos_embed_has_correct_positions() {
        let model = make_tiny_model();
        assert_eq!(model.pos_embed.n_positions, 65);
        assert_eq!(model.pos_embed.embed_dim, 64);
    }

    #[test]
    fn encoder_has_correct_depth() {
        let model = make_tiny_model();
        assert_eq!(model.encoder.blocks.len(), 2);
    }

    #[test]
    fn head_weights_correct_size() {
        let model = make_tiny_model();
        assert_eq!(model.weights.head_weight.len(), 10 * 64);
        assert_eq!(model.weights.head_bias.len(), 10);
    }

    #[test]
    fn different_seeds_produce_different_outputs() {
        let cfg = ViTConfig::tiny();
        let mut rng_a = LcgRng::new(1);
        let mut rng_b = LcgRng::new(2);
        let model_a = ViTModel::new(cfg.clone(), &mut rng_a).expect("ok");
        let model_b = ViTModel::new(cfg, &mut rng_b).expect("ok");
        let image = vec![0.5f32; 3 * 32 * 32];
        let la = model_a.forward(&image).expect("ok");
        let lb = model_b.forward(&image).expect("ok");
        let diff: f32 = la.iter().zip(lb.iter()).map(|(a, b)| (a - b).abs()).sum();
        assert!(
            diff > 1e-6,
            "different seeds should yield different logits (diff={diff})"
        );
    }
}