1use crate::{
8 error::{VisionError, VisionResult},
9 handle::LcgRng,
10 patch_embed::{LearnablePosEmbed, PatchEmbed, PatchEmbedConfig, add_pos_embed, prepend_cls},
11 vit::{ViTConfig, ViTEncoder, ViTEncoderConfig},
12};
13
14#[derive(Debug, Clone)]
21pub struct ClipVisionConfig {
22 pub vit_config: ViTConfig,
24}
25
26impl ClipVisionConfig {
27 #[must_use]
29 pub fn new(vit_config: ViTConfig) -> Self {
30 Self { vit_config }
31 }
32
33 #[must_use]
37 pub fn tiny() -> Self {
38 Self::new(ViTConfig::tiny())
39 }
40}
41
42pub struct ClipVisionEncoder {
57 pub config: ClipVisionConfig,
59 pub patch_embed: PatchEmbed,
61 pub pos_embed: LearnablePosEmbed,
63 pub encoder: ViTEncoder,
65 pub cls_token: Vec<f32>,
67}
68
69impl ClipVisionEncoder {
70 pub fn new(cfg: ClipVisionConfig, rng: &mut LcgRng) -> VisionResult<Self> {
81 let vc = &cfg.vit_config;
82
83 let pe_cfg = PatchEmbedConfig::new(vc.img_size, vc.patch_size, vc.in_chans, vc.embed_dim)?;
85 let patch_embed = PatchEmbed::new(pe_cfg.clone(), rng);
86
87 let n_patches = pe_cfg.n_patches();
89 let n_positions = n_patches + 1;
90 let pos_embed = LearnablePosEmbed::new(n_positions, vc.embed_dim, rng)?;
91
92 let enc_cfg = ViTEncoderConfig::new(vc.embed_dim, vc.n_heads, vc.mlp_ratio, vc.depth)?;
94 let encoder = ViTEncoder::new(enc_cfg, rng)?;
95
96 let mut cls_token = vec![0.0f32; vc.embed_dim];
98 rng.fill_normal(&mut cls_token);
99 for v in &mut cls_token {
100 *v *= 0.02;
101 }
102
103 Ok(Self {
104 config: cfg,
105 patch_embed,
106 pos_embed,
107 encoder,
108 cls_token,
109 })
110 }
111
112 pub fn forward_single(&self, image: &[f32]) -> VisionResult<Vec<f32>> {
124 let embed_dim = self.config.vit_config.embed_dim;
125
126 let patch_tokens = self.patch_embed.forward(image)?;
128
129 let mut tokens = prepend_cls(&patch_tokens, &self.cls_token, embed_dim)?;
131
132 add_pos_embed(&mut tokens, &self.pos_embed.table, embed_dim)?;
134
135 let n_tokens = tokens.len() / embed_dim;
137 let encoded = self.encoder.forward(&tokens, n_tokens)?;
138
139 let cls_out = encoded[..embed_dim].to_vec();
141
142 Ok(cls_out)
143 }
144
145 pub fn forward_batch(&self, images: &[f32], batch_size: usize) -> VisionResult<Vec<Vec<f32>>> {
159 let vc = &self.config.vit_config;
160 let single_len = vc.in_chans * vc.img_size * vc.img_size;
161
162 if batch_size == 0 {
163 return Ok(Vec::new());
164 }
165
166 let expected = batch_size * single_len;
167 if images.len() != expected {
168 return Err(VisionError::DimensionMismatch {
169 expected,
170 got: images.len(),
171 });
172 }
173
174 let mut results = Vec::with_capacity(batch_size);
175 for b in 0..batch_size {
176 let slice = &images[b * single_len..(b + 1) * single_len];
177 results.push(self.forward_single(slice)?);
178 }
179
180 Ok(results)
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::LcgRng;

    /// Builds the tiny encoder from `seed` and returns it with its `embed_dim`.
    fn make_tiny_encoder(seed: u64) -> (ClipVisionEncoder, usize) {
        let mut rng = LcgRng::new(seed);
        let cfg = ClipVisionConfig::tiny();
        let embed_dim = cfg.vit_config.embed_dim;
        let encoder = ClipVisionEncoder::new(cfg, &mut rng).expect("tiny encoder ok");
        (encoder, embed_dim)
    }

    /// Deterministic ramp image with values in `[0, 1)`.
    fn make_image(in_chans: usize, img_size: usize) -> Vec<f32> {
        let len = in_chans * img_size * img_size;
        (0..len).map(|i| i as f32 / len as f32).collect()
    }

    /// Number of `f32` values in one flattened image for `enc`'s configuration.
    fn single_image_len(enc: &ClipVisionEncoder) -> usize {
        let vc = &enc.config.vit_config;
        vc.in_chans * vc.img_size * vc.img_size
    }

    #[test]
    fn tiny_encoder_constructs() {
        let (enc, _) = make_tiny_encoder(1);
        let vc = &enc.config.vit_config;
        assert_eq!(enc.cls_token.len(), vc.embed_dim);
        let n_patches = (vc.img_size / vc.patch_size).pow(2);
        // One positional slot per patch plus one for [CLS].
        assert_eq!(enc.pos_embed.n_positions, n_patches + 1);
    }

    #[test]
    fn config_new_wraps_vit_config() {
        let vit_cfg = ViTConfig::tiny();
        let clip_cfg = ClipVisionConfig::new(vit_cfg.clone());
        assert_eq!(clip_cfg.vit_config.embed_dim, vit_cfg.embed_dim);
    }

    #[test]
    fn forward_single_output_shape() {
        let (enc, embed_dim) = make_tiny_encoder(2);
        let vc = &enc.config.vit_config;
        let img = make_image(vc.in_chans, vc.img_size);
        let z = enc.forward_single(&img).expect("forward_single ok");
        assert_eq!(
            z.len(),
            embed_dim,
            "forward_single output should be embed_dim"
        );
    }

    #[test]
    fn forward_single_output_finite() {
        let (enc, _) = make_tiny_encoder(3);
        let vc = &enc.config.vit_config;
        let img = make_image(vc.in_chans, vc.img_size);
        let z = enc.forward_single(&img).expect("ok");
        assert!(
            z.iter().all(|v| v.is_finite()),
            "forward_single output must be finite"
        );
    }

    #[test]
    fn forward_single_error_wrong_image_size() {
        let (enc, _) = make_tiny_encoder(4);
        let wrong_img = vec![0.0f32; 10];
        let r = enc.forward_single(&wrong_img);
        assert!(
            matches!(r, Err(VisionError::DimensionMismatch { .. })),
            "expected DimensionMismatch, got {:?}",
            r
        );
    }

    #[test]
    fn forward_single_deterministic() {
        let (enc, _) = make_tiny_encoder(5);
        let vc = &enc.config.vit_config;
        let img = make_image(vc.in_chans, vc.img_size);
        let z1 = enc.forward_single(&img).expect("ok");
        let z2 = enc.forward_single(&img).expect("ok");
        assert_eq!(z1, z2, "forward_single should be deterministic");
    }

    #[test]
    fn forward_batch_output_count() {
        let (enc, _) = make_tiny_encoder(6);
        let vc = &enc.config.vit_config;
        let batch_size = 3_usize;
        // Scaling the channel count by `batch_size` yields exactly
        // `batch_size * single_len` values: one flat buffer for the batch.
        let flat = make_image(vc.in_chans * batch_size, vc.img_size);
        assert_eq!(flat.len(), batch_size * single_image_len(&enc));
        let results = enc
            .forward_batch(&flat, batch_size)
            .expect("forward_batch ok");
        assert_eq!(results.len(), batch_size, "batch result count mismatch");
    }

    #[test]
    fn forward_batch_each_embedding_has_embed_dim() {
        let (enc, embed_dim) = make_tiny_encoder(7);
        let batch_size = 4_usize;
        let flat = vec![0.5f32; batch_size * single_image_len(&enc)];
        let results = enc.forward_batch(&flat, batch_size).expect("ok");
        for (i, z) in results.iter().enumerate() {
            assert_eq!(z.len(), embed_dim, "embedding {i} has wrong size");
        }
    }

    #[test]
    fn forward_batch_zero_batch_returns_empty() {
        let (enc, _) = make_tiny_encoder(8);
        let results = enc.forward_batch(&[], 0).expect("zero batch ok");
        assert!(results.is_empty(), "zero batch should return empty Vec");
    }

    #[test]
    fn forward_batch_error_wrong_total_length() {
        let (enc, _) = make_tiny_encoder(9);
        // One value short of two full images.
        let flat = vec![0.0f32; 2 * single_image_len(&enc) - 1];
        let r = enc.forward_batch(&flat, 2);
        assert!(
            matches!(r, Err(VisionError::DimensionMismatch { .. })),
            "expected DimensionMismatch, got {:?}",
            r
        );
    }

    #[test]
    fn forward_batch_matches_individual() {
        let (enc, embed_dim) = make_tiny_encoder(10);
        let single_len = single_image_len(&enc);
        let batch_size = 2_usize;
        let total = batch_size * single_len;
        let flat: Vec<f32> = (0..total).map(|i| i as f32 / total as f32).collect();

        let batch_results = enc.forward_batch(&flat, batch_size).expect("batch ok");

        for (b, image) in flat.chunks_exact(single_len).enumerate() {
            let single = enc.forward_single(image).expect("single ok");
            for d in 0..embed_dim {
                assert!(
                    (batch_results[b][d] - single[d]).abs() < 1e-6,
                    "batch[{b}][{d}] = {} ≠ single[{d}] = {}",
                    batch_results[b][d],
                    single[d]
                );
            }
        }
    }
}