use crate::{
error::{VisionError, VisionResult},
handle::LcgRng,
patch_embed::{LearnablePosEmbed, PatchEmbed, PatchEmbedConfig, add_pos_embed, prepend_cls},
vit::{ViTConfig, ViTEncoder, ViTEncoderConfig},
};
/// Configuration for [`ClipVisionEncoder`].
///
/// A thin wrapper around the ViT backbone settings the encoder is built from.
#[derive(Debug, Clone)]
pub struct ClipVisionConfig {
    /// Backbone hyper-parameters (image/patch size, input channels, embedding
    /// width, head count, MLP ratio and depth).
    pub vit_config: ViTConfig,
}

impl ClipVisionConfig {
    /// Wraps an existing [`ViTConfig`] without modification.
    #[must_use]
    pub fn new(vit_config: ViTConfig) -> Self {
        Self { vit_config }
    }

    /// Preset backed by [`ViTConfig::tiny`] (the configuration this module's tests use).
    #[must_use]
    pub fn tiny() -> Self {
        Self {
            vit_config: ViTConfig::tiny(),
        }
    }
}
/// CLIP-style image encoder built on a ViT backbone.
///
/// The image embedding is the encoder's output at the position of a learned
/// `[CLS]` token that is prepended to the patch sequence.
pub struct ClipVisionEncoder {
    /// Configuration the encoder was built from.
    pub config: ClipVisionConfig,
    /// Turns a flat image buffer into patch tokens.
    pub patch_embed: PatchEmbed,
    /// Learnable positional table with one position per patch plus one for `[CLS]`.
    pub pos_embed: LearnablePosEmbed,
    /// ViT encoder stack applied to the full token sequence.
    pub encoder: ViTEncoder,
    /// Learned `[CLS]` token of length `embed_dim`.
    pub cls_token: Vec<f32>,
}

impl ClipVisionEncoder {
    /// Builds an encoder with freshly initialised weights drawn from `rng`.
    ///
    /// Sub-modules draw from `rng` in a fixed order: patch embedding, then the
    /// positional table, then the encoder stack, then the `[CLS]` token. The
    /// resulting weights therefore depend only on the incoming `rng` state.
    ///
    /// # Errors
    ///
    /// Propagates any error from constructing the patch-embedding,
    /// positional-embedding or encoder configurations/modules.
    pub fn new(cfg: ClipVisionConfig, rng: &mut LcgRng) -> VisionResult<Self> {
        let vc = &cfg.vit_config;
        let pe_cfg = PatchEmbedConfig::new(vc.img_size, vc.patch_size, vc.in_chans, vc.embed_dim)?;
        // Read the patch count before `pe_cfg` is moved into `PatchEmbed`, so no
        // clone is needed. This draws nothing from `rng`, so the initialisation
        // sequence is the same as before.
        let n_patches = pe_cfg.n_patches();
        let patch_embed = PatchEmbed::new(pe_cfg, rng);
        // One extra position for the prepended [CLS] token.
        let pos_embed = LearnablePosEmbed::new(n_patches + 1, vc.embed_dim, rng)?;
        let enc_cfg = ViTEncoderConfig::new(vc.embed_dim, vc.n_heads, vc.mlp_ratio, vc.depth)?;
        let encoder = ViTEncoder::new(enc_cfg, rng)?;
        // Small-magnitude [CLS] init: values from `fill_normal`, scaled by 0.02.
        let mut cls_token = vec![0.0f32; vc.embed_dim];
        rng.fill_normal(&mut cls_token);
        cls_token.iter_mut().for_each(|v| *v *= 0.02);
        Ok(Self {
            config: cfg,
            patch_embed,
            pos_embed,
            encoder,
            cls_token,
        })
    }

    /// Encodes one image and returns its `embed_dim`-long `[CLS]` embedding.
    ///
    /// `image` is the flat buffer expected by [`PatchEmbed::forward`]. This is
    /// the same `in_chans * img_size * img_size` values per image that
    /// [`Self::forward_batch`] slices out.
    ///
    /// # Errors
    ///
    /// Returns [`VisionError::DimensionMismatch`] in two cases:
    /// - `image` has the wrong length (reported by the patch embedding).
    /// - The encoder output is shorter than one token.
    ///
    /// Any other error from the embedding or encoder stages is propagated.
    pub fn forward_single(&self, image: &[f32]) -> VisionResult<Vec<f32>> {
        let embed_dim = self.config.vit_config.embed_dim;
        let patch_tokens = self.patch_embed.forward(image)?;
        let mut tokens = prepend_cls(&patch_tokens, &self.cls_token, embed_dim)?;
        add_pos_embed(&mut tokens, &self.pos_embed.table, embed_dim)?;
        let n_tokens = tokens.len() / embed_dim;
        let encoded = self.encoder.forward(&tokens, n_tokens)?;
        // [CLS] was prepended, so its final state is the first `embed_dim` values.
        // Use a checked slice so that a malformed encoder output becomes an
        // error rather than a panic.
        encoded
            .get(..embed_dim)
            .map(|cls| cls.to_vec())
            .ok_or_else(|| VisionError::DimensionMismatch {
                expected: embed_dim,
                got: encoded.len(),
            })
    }

    /// Encodes `batch_size` images stored back to back in `images`.
    ///
    /// Each image occupies `in_chans * img_size * img_size` consecutive values.
    /// The result holds one `[CLS]` embedding per image, in input order.
    ///
    /// # Errors
    ///
    /// Returns [`VisionError::DimensionMismatch`] unless `images.len()` equals
    /// `batch_size * in_chans * img_size * img_size`. A zero batch therefore
    /// requires an empty slice. Also propagates the first error from
    /// [`Self::forward_single`].
    pub fn forward_batch(&self, images: &[f32], batch_size: usize) -> VisionResult<Vec<Vec<f32>>> {
        let vc = &self.config.vit_config;
        let single_len = vc.in_chans * vc.img_size * vc.img_size;
        // Saturate instead of wrapping. An overflowing product can never equal a
        // real `f32` slice length, so it is reported as a mismatch. A wrapped
        // product could instead collide with `images.len()`.
        let expected = batch_size.saturating_mul(single_len);
        if images.len() != expected {
            return Err(VisionError::DimensionMismatch {
                expected,
                got: images.len(),
            });
        }
        // `images.len() == batch_size * single_len` without overflow, so every
        // slice bound below is in range. Collecting into a Result stops at the
        // first failing image.
        (0..batch_size)
            .map(|b| self.forward_single(&images[b * single_len..(b + 1) * single_len]))
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::LcgRng;

    /// Builds the tiny-preset encoder seeded with `seed`; also returns its `embed_dim`.
    fn make_tiny_encoder(seed: u64) -> (ClipVisionEncoder, usize) {
        let mut rng = LcgRng::new(seed);
        let cfg = ClipVisionConfig::tiny();
        let embed_dim = cfg.vit_config.embed_dim;
        let encoder = ClipVisionEncoder::new(cfg, &mut rng).expect("tiny encoder ok");
        (encoder, embed_dim)
    }

    /// Deterministic ramp of `in_chans * img_size * img_size` values in `[0, 1)`.
    fn make_image(in_chans: usize, img_size: usize) -> Vec<f32> {
        let len = in_chans * img_size * img_size;
        (0..len).map(|i| i as f32 / len as f32).collect()
    }

    #[test]
    fn tiny_encoder_constructs() {
        let (enc, _) = make_tiny_encoder(1);
        let vc = &enc.config.vit_config;
        assert_eq!(enc.cls_token.len(), vc.embed_dim);
        let n_patches = (vc.img_size / vc.patch_size).pow(2);
        assert_eq!(enc.pos_embed.n_positions, n_patches + 1);
    }

    #[test]
    fn config_new_wraps_vit_config() {
        let vit_cfg = ViTConfig::tiny();
        let clip_cfg = ClipVisionConfig::new(vit_cfg.clone());
        assert_eq!(clip_cfg.vit_config.embed_dim, vit_cfg.embed_dim);
    }

    #[test]
    fn forward_single_output_shape() {
        let (enc, embed_dim) = make_tiny_encoder(2);
        let vc = &enc.config.vit_config;
        let img = make_image(vc.in_chans, vc.img_size);
        let z = enc.forward_single(&img).expect("forward_single ok");
        assert_eq!(
            z.len(),
            embed_dim,
            "forward_single output should be embed_dim"
        );
    }

    #[test]
    fn forward_single_output_finite() {
        let (enc, _) = make_tiny_encoder(3);
        let vc = &enc.config.vit_config;
        let img = make_image(vc.in_chans, vc.img_size);
        let z = enc.forward_single(&img).expect("ok");
        assert!(
            z.iter().all(|v| v.is_finite()),
            "forward_single output must be finite"
        );
    }

    #[test]
    fn forward_single_error_wrong_image_size() {
        let (enc, _) = make_tiny_encoder(4);
        let wrong_img = vec![0.0f32; 10];
        let r = enc.forward_single(&wrong_img);
        assert!(
            matches!(r, Err(VisionError::DimensionMismatch { .. })),
            "expected DimensionMismatch, got {:?}",
            r
        );
    }

    #[test]
    fn forward_single_deterministic() {
        let (enc, _) = make_tiny_encoder(5);
        let vc = &enc.config.vit_config;
        let img = make_image(vc.in_chans, vc.img_size);
        let z1 = enc.forward_single(&img).expect("ok");
        let z2 = enc.forward_single(&img).expect("ok");
        assert_eq!(z1, z2, "forward_single should be deterministic");
    }

    #[test]
    fn forward_batch_output_count() {
        let (enc, _) = make_tiny_encoder(6);
        let vc = &enc.config.vit_config;
        let single_len = vc.in_chans * vc.img_size * vc.img_size;
        let batch_size = 3_usize;
        // A ramp with `batch_size * in_chans` channels is exactly `batch_size`
        // images laid out back to back; no padding or resizing is needed.
        let flat = make_image(vc.in_chans * batch_size, vc.img_size);
        assert_eq!(flat.len(), batch_size * single_len);
        let results = enc
            .forward_batch(&flat, batch_size)
            .expect("forward_batch ok");
        assert_eq!(results.len(), batch_size, "batch result count mismatch");
    }

    #[test]
    fn forward_batch_each_embedding_has_embed_dim() {
        let (enc, embed_dim) = make_tiny_encoder(7);
        let vc = &enc.config.vit_config;
        let single_len = vc.in_chans * vc.img_size * vc.img_size;
        let batch_size = 4_usize;
        let flat = vec![0.5f32; batch_size * single_len];
        let results = enc.forward_batch(&flat, batch_size).expect("ok");
        for (i, z) in results.iter().enumerate() {
            assert_eq!(z.len(), embed_dim, "embedding {i} has wrong size");
        }
    }

    #[test]
    fn forward_batch_zero_batch_returns_empty() {
        let (enc, _) = make_tiny_encoder(8);
        let results = enc.forward_batch(&[], 0).expect("zero batch ok");
        assert!(results.is_empty(), "zero batch should return empty Vec");
    }

    #[test]
    fn forward_batch_error_wrong_total_length() {
        let (enc, _) = make_tiny_encoder(9);
        let vc = &enc.config.vit_config;
        let single_len = vc.in_chans * vc.img_size * vc.img_size;
        let flat = vec![0.0f32; 2 * single_len - 1];
        let r = enc.forward_batch(&flat, 2);
        assert!(
            matches!(r, Err(VisionError::DimensionMismatch { .. })),
            "expected DimensionMismatch, got {:?}",
            r
        );
    }

    #[test]
    fn forward_batch_matches_individual() {
        let (enc, embed_dim) = make_tiny_encoder(10);
        let vc = &enc.config.vit_config;
        let single_len = vc.in_chans * vc.img_size * vc.img_size;
        let batch_size = 2_usize;
        let flat: Vec<f32> = (0..batch_size * single_len)
            .map(|i| i as f32 / (batch_size * single_len) as f32)
            .collect();
        let batch_results = enc.forward_batch(&flat, batch_size).expect("batch ok");
        for (b, (batched, image)) in batch_results
            .iter()
            .zip(flat.chunks_exact(single_len))
            .enumerate()
        {
            let single = enc.forward_single(image).expect("single ok");
            assert_eq!(batched.len(), embed_dim);
            for (d, (x, y)) in batched.iter().zip(&single).enumerate() {
                assert!(
                    (x - y).abs() < 1e-6,
                    "batch[{b}][{d}] = {x} ≠ single[{d}] = {y}"
                );
            }
        }
    }
}