use crate::{
error::{VisionError, VisionResult},
handle::LcgRng,
vit::vit_block::{ViTBlock, ViTBlockConfig, layer_norm},
};
/// Configuration for a [`ViTEncoder`]: one shared block configuration plus
/// the number of blocks to stack.
///
/// Prefer [`ViTEncoderConfig::new`], which validates `depth` and the block
/// parameters; the fields are public, so a value built by hand bypasses that
/// validation.
#[derive(Debug, Clone, PartialEq)]
pub struct ViTEncoderConfig {
/// Configuration cloned into every block created by [`ViTEncoder::new`].
pub block_cfg: ViTBlockConfig,
/// Number of blocks in the encoder; `ViTEncoderConfig::new` rejects `0`.
pub depth: usize,
}
impl ViTEncoderConfig {
    /// Builds an encoder configuration after validating its arguments.
    ///
    /// `embed_dim`, `n_heads` and `mlp_ratio` are forwarded unchanged to
    /// [`ViTBlockConfig::new`]; `depth` is the number of blocks to stack.
    ///
    /// # Errors
    ///
    /// Returns [`VisionError::Internal`] when `depth` is zero, and propagates
    /// any error reported by [`ViTBlockConfig::new`].
    pub fn new(
        embed_dim: usize,
        n_heads: usize,
        mlp_ratio: usize,
        depth: usize,
    ) -> VisionResult<Self> {
        // The depth check comes first, so a zero depth is reported even when
        // the block parameters would also be rejected.
        match depth {
            0 => Err(VisionError::Internal("encoder depth must be > 0".into())),
            _ => Ok(Self {
                block_cfg: ViTBlockConfig::new(embed_dim, n_heads, mlp_ratio)?,
                depth,
            }),
        }
    }
}
/// A stack of [`ViTBlock`]s followed by a final layer normalisation.
pub struct ViTEncoder {
/// Blocks applied in order by `forward`; each block's output feeds the next.
pub blocks: Vec<ViTBlock>,
/// Weight vector passed to the final `layer_norm`; `new` creates it with
/// `embed_dim` entries, all `1.0`.
pub final_ln_weight: Vec<f32>,
/// Bias vector passed to the final `layer_norm`; `new` creates it with
/// `embed_dim` entries, all `0.0`.
pub final_ln_bias: Vec<f32>,
}
impl ViTEncoder {
    /// Creates an encoder with `cfg.depth` blocks, each built from a clone of
    /// `cfg.block_cfg` and initialised from `rng` in block order.
    ///
    /// The final layer-norm weight is initialised to all ones and the bias to
    /// all zeros, each with `cfg.block_cfg.embed_dim` entries.
    ///
    /// # Errors
    ///
    /// This constructor currently always returns `Ok`; the `VisionResult`
    /// return type is part of the existing interface.
    pub fn new(cfg: ViTEncoderConfig, rng: &mut LcgRng) -> VisionResult<Self> {
        let e = cfg.block_cfg.embed_dim;
        let mut blocks = Vec::with_capacity(cfg.depth);
        for _ in 0..cfg.depth {
            blocks.push(ViTBlock::new(cfg.block_cfg.clone(), rng));
        }
        Ok(Self {
            blocks,
            final_ln_weight: vec![1.0f32; e],
            final_ln_bias: vec![0.0f32; e],
        })
    }

    /// Runs `tokens` through every block in order, then applies the final
    /// layer normalisation.
    ///
    /// `tokens` must contain exactly `n_tokens * embed_dim` values, where
    /// `embed_dim` is taken from the first block. The returned buffer is
    /// whatever `layer_norm` produces for `n_tokens` rows of width
    /// `embed_dim`.
    // NOTE(review): presumably one contiguous `embed_dim`-long row per token;
    // confirm against `ViTBlock::forward` / `layer_norm`.
    ///
    /// # Errors
    ///
    /// * [`VisionError::Internal`] if the encoder has no blocks.
    /// * [`VisionError::DimensionMismatch`] if `tokens.len()` differs from
    ///   `n_tokens * embed_dim`. When that product overflows `usize`, the
    ///   reported `expected` value saturates to `usize::MAX`.
    /// * [`VisionError::EmptyInput`] if `n_tokens` is zero.
    /// * Any error returned by a block's `forward`.
    pub fn forward(&self, tokens: &[f32], n_tokens: usize) -> VisionResult<Vec<f32>> {
        let e = self
            .blocks
            .first()
            .map(|b| b.config.embed_dim)
            .ok_or_else(|| VisionError::Internal("encoder has no blocks".into()))?;

        // A plain `n_tokens * e` panics on overflow in debug builds and wraps
        // in release builds, where a wrapped product can coincide with
        // `tokens.len()` and let an undersized buffer through. A slice can
        // never hold an overflowing element count, so overflow is always a
        // mismatch.
        let expected = n_tokens.checked_mul(e);
        if expected != Some(tokens.len()) {
            return Err(VisionError::DimensionMismatch {
                expected: expected.unwrap_or(usize::MAX),
                got: tokens.len(),
            });
        }
        // Checked after the length test, so `n_tokens == 0` with a non-empty
        // buffer is still reported as a dimension mismatch.
        if n_tokens == 0 {
            return Err(VisionError::EmptyInput("tokens"));
        }

        let mut h: Vec<f32> = tokens.to_vec();
        for block in &self.blocks {
            h = block.forward(&h, n_tokens)?;
        }

        // The trailing `1e-5` is presumably the layer-norm epsilon; confirm
        // against `layer_norm`'s signature.
        Ok(layer_norm(
            &h,
            &self.final_ln_weight,
            &self.final_ln_bias,
            n_tokens,
            e,
            1e-5,
        ))
    }

    /// Returns the embedding dimension of the first block, or `0` when the
    /// encoder has no blocks.
    #[must_use]
    pub fn embed_dim(&self) -> usize {
        self.blocks.first().map_or(0, |b| b.config.embed_dim)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an encoder with `embed_dim = 64`, `n_heads = 4`, `mlp_ratio = 4`
    /// and the requested depth, always seeding the RNG with 42.
    fn make_enc(depth: usize) -> ViTEncoder {
        let mut rng = LcgRng::new(42);
        let cfg = ViTEncoderConfig::new(64, 4, 4, depth).expect("valid encoder config");
        ViTEncoder::new(cfg, &mut rng).expect("encoder created")
    }

    /// Returns a `len`-element buffer filled by `LcgRng::fill_normal`.
    fn noise_tokens(mut rng: LcgRng, len: usize) -> Vec<f32> {
        let mut tokens = vec![0.0f32; len];
        rng.fill_normal(&mut tokens);
        tokens
    }

    /// Runs `forward` and checks the output holds `n_tokens * embed_dim` values.
    fn assert_shape_preserved(enc: &ViTEncoder, tokens: &[f32], n_tokens: usize) {
        let out = enc.forward(tokens, n_tokens).expect("forward ok");
        assert_eq!(out.len(), n_tokens * enc.embed_dim());
    }

    // ---- configuration -------------------------------------------------

    #[test]
    fn config_valid() {
        let cfg = ViTEncoderConfig::new(64, 4, 4, 3).expect("valid");
        assert_eq!(cfg.depth, 3);
        assert_eq!(cfg.block_cfg.embed_dim, 64);
    }

    #[test]
    fn config_depth_zero_errors() {
        assert!(matches!(
            ViTEncoderConfig::new(64, 4, 4, 0),
            Err(VisionError::Internal(_))
        ));
    }

    #[test]
    fn config_propagates_block_error() {
        // 65 is not divisible by the 4 heads, so the block config rejects it.
        assert!(matches!(
            ViTEncoderConfig::new(65, 4, 4, 2),
            Err(VisionError::HeadDimMismatch { .. })
        ));
    }

    // ---- forward: output shape and values -------------------------------

    #[test]
    fn depth1_output_shape() {
        let enc = make_enc(1);
        let n_tokens = 17;
        let tokens = vec![0.1f32; n_tokens * enc.embed_dim()];
        assert_shape_preserved(&enc, &tokens, n_tokens);
    }

    #[test]
    fn depth2_output_shape() {
        let enc = make_enc(2);
        let n_tokens = 17;
        let tokens = vec![0.1f32; n_tokens * enc.embed_dim()];
        assert_shape_preserved(&enc, &tokens, n_tokens);
    }

    #[test]
    fn depth4_output_shape() {
        let enc = make_enc(4);
        let n_tokens = 9;
        let tokens = noise_tokens(LcgRng::new(11), n_tokens * enc.embed_dim());
        assert_shape_preserved(&enc, &tokens, n_tokens);
    }

    #[test]
    fn output_finite_random_input() {
        let enc = make_enc(2);
        let n_tokens = 17;
        let tokens = noise_tokens(LcgRng::new(7), n_tokens * enc.embed_dim());
        let out = enc.forward(&tokens, n_tokens).expect("forward ok");
        assert!(
            !out.iter().any(|v| !v.is_finite()),
            "non-finite encoder output"
        );
    }

    // ---- final layer-norm parameters ------------------------------------

    #[test]
    fn final_ln_weight_bias_correct_size() {
        let enc = make_enc(1);
        let dim = enc.embed_dim();
        assert_eq!(enc.final_ln_weight.len(), dim);
        assert_eq!(enc.final_ln_bias.len(), dim);
    }

    #[test]
    fn final_ln_weight_initialised_one() {
        let enc = make_enc(1);
        for &w in &enc.final_ln_weight {
            assert!((w - 1.0).abs() < 1e-9);
        }
    }

    // ---- forward: input validation --------------------------------------

    #[test]
    fn dimension_mismatch_errors() {
        let enc = make_enc(1);
        // Buffer holds 3 tokens' worth of values but 5 tokens are claimed.
        let tokens = vec![0.0f32; 3 * enc.embed_dim()];
        assert!(matches!(
            enc.forward(&tokens, 5),
            Err(VisionError::DimensionMismatch { .. })
        ));
    }

    #[test]
    fn empty_tokens_errors() {
        let enc = make_enc(1);
        let no_tokens: Vec<f32> = Vec::new();
        assert!(matches!(
            enc.forward(&no_tokens, 0),
            Err(VisionError::EmptyInput(_))
        ));
    }

    // ---- construction ---------------------------------------------------

    #[test]
    fn correct_number_of_blocks() {
        for d in [1, 2, 4, 6, 12] {
            let block_count = make_enc(d).blocks.len();
            assert_eq!(block_count, d, "wrong block count for depth={d}");
        }
    }
}