1use crate::{
15 error::{VisionError, VisionResult},
16 handle::LcgRng,
17 patch_embed::{LearnablePosEmbed, PatchEmbed, PatchEmbedConfig, add_pos_embed, prepend_cls},
18 vit::{
19 vit_block::linear,
20 vit_encoder::{ViTEncoder, ViTEncoderConfig},
21 },
22};
23
/// Hyper-parameters describing a Vision Transformer classifier.
///
/// Every field is a plain `usize`, so the type is totally comparable and
/// hashable; it can be used as a key in maps or sets (e.g. for caching models
/// per configuration).
///
/// Fields are public, so a value built with a struct literal bypasses the
/// checks performed by `ViTConfig::new`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ViTConfig {
    /// Side length of the square input image.
    pub img_size: usize,
    /// Side length of one square patch.
    pub patch_size: usize,
    /// Number of input channels.
    pub in_chans: usize,
    /// Width of every token embedding.
    pub embed_dim: usize,
    /// Number of encoder blocks.
    pub depth: usize,
    /// Number of attention heads handed to the encoder configuration.
    pub n_heads: usize,
    /// MLP ratio handed to the encoder configuration
    /// (presumably the hidden-width multiplier — confirm in `ViTEncoderConfig`).
    pub mlp_ratio: usize,
    /// Number of output classes, i.e. the number of logits produced.
    pub n_classes: usize,
}
46
47impl ViTConfig {
48 #[must_use]
53 pub fn tiny() -> Self {
54 Self {
55 img_size: 32,
56 patch_size: 4,
57 in_chans: 3,
58 embed_dim: 64,
59 depth: 2,
60 n_heads: 4,
61 mlp_ratio: 4,
62 n_classes: 10,
63 }
64 }
65
66 pub fn new(
73 img_size: usize,
74 patch_size: usize,
75 in_chans: usize,
76 embed_dim: usize,
77 depth: usize,
78 n_heads: usize,
79 mlp_ratio: usize,
80 n_classes: usize,
81 ) -> VisionResult<Self> {
82 if n_classes == 0 {
83 return Err(VisionError::InvalidNumClasses(n_classes));
84 }
85 if depth == 0 {
86 return Err(VisionError::Internal("depth must be > 0".into()));
87 }
88 PatchEmbedConfig::new(img_size, patch_size, in_chans, embed_dim)?;
90 ViTEncoderConfig::new(embed_dim, n_heads, mlp_ratio, depth)?;
91
92 Ok(Self {
93 img_size,
94 patch_size,
95 in_chans,
96 embed_dim,
97 depth,
98 n_heads,
99 mlp_ratio,
100 n_classes,
101 })
102 }
103
104 #[must_use]
106 pub fn n_patches(&self) -> usize {
107 let grid = self.img_size / self.patch_size;
108 grid * grid
109 }
110
111 #[must_use]
113 pub fn seq_len(&self) -> usize {
114 self.n_patches() + 1
115 }
116}
117
/// Parameters of the classification head applied to the class-token output.
///
/// The type holds only plain `f32` buffers, so it derives `Debug`, `Clone`
/// and `PartialEq` to make it inspectable, copyable and comparable in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct ViTModelWeights {
    /// Flat weight buffer holding `n_classes * embed_dim` values, in the
    /// layout expected by `vit_block::linear`.
    pub head_weight: Vec<f32>,
    /// One bias value per class (`n_classes` entries).
    pub head_bias: Vec<f32>,
}
129
130impl ViTModelWeights {
131 fn default_init(cfg: &ViTConfig, rng: &mut LcgRng) -> Self {
132 let scale = 1.0 / (cfg.embed_dim as f32).sqrt();
133 let mut head_weight = vec![0.0f32; cfg.n_classes * cfg.embed_dim];
134 rng.fill_normal(&mut head_weight);
135 for v in &mut head_weight {
136 *v *= scale;
137 }
138 let head_bias = vec![0.0f32; cfg.n_classes];
139 Self {
140 head_weight,
141 head_bias,
142 }
143 }
144}
145
146pub struct ViTModel {
150 pub config: ViTConfig,
152 pub patch_embed: PatchEmbed,
154 pub pos_embed: LearnablePosEmbed,
156 pub encoder: ViTEncoder,
158 pub weights: ViTModelWeights,
160}
161
162impl ViTModel {
163 pub fn new(cfg: ViTConfig, rng: &mut LcgRng) -> VisionResult<Self> {
167 let patch_cfg =
168 PatchEmbedConfig::new(cfg.img_size, cfg.patch_size, cfg.in_chans, cfg.embed_dim)?;
169 let patch_embed = PatchEmbed::new(patch_cfg, rng);
170
171 let seq_len = cfg.seq_len();
173 let pos_embed = LearnablePosEmbed::new(seq_len, cfg.embed_dim, rng)?;
174
175 let enc_cfg = ViTEncoderConfig::new(cfg.embed_dim, cfg.n_heads, cfg.mlp_ratio, cfg.depth)?;
176 let encoder = ViTEncoder::new(enc_cfg, rng)?;
177
178 let weights = ViTModelWeights::default_init(&cfg, rng);
179
180 Ok(Self {
181 config: cfg,
182 patch_embed,
183 pos_embed,
184 encoder,
185 weights,
186 })
187 }
188
189 pub fn forward(&self, image: &[f32]) -> VisionResult<Vec<f32>> {
198 let cfg = &self.config;
199 let expected_img = cfg.in_chans * cfg.img_size * cfg.img_size;
200 if image.len() != expected_img {
201 return Err(VisionError::DimensionMismatch {
202 expected: expected_img,
203 got: image.len(),
204 });
205 }
206
207 let patch_tokens = self.patch_embed.forward(image)?;
209
210 let cls_token = &self.patch_embed.weights.cls_token;
212 let mut tokens = prepend_cls(&patch_tokens, cls_token, cfg.embed_dim)?;
213
214 add_pos_embed(&mut tokens, &self.pos_embed.table, cfg.embed_dim)?;
216
217 let seq_len = cfg.seq_len();
219 let encoded = self.encoder.forward(&tokens, seq_len)?;
220
221 let cls_repr = &encoded[..cfg.embed_dim];
223
224 let logits = linear(
226 cls_repr,
227 &self.weights.head_weight,
228 &self.weights.head_bias,
229 cfg.embed_dim,
230 cfg.n_classes,
231 );
232
233 Ok(logits)
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Number of values in an image matching `ViTConfig::tiny()`.
    const TINY_IMAGE_LEN: usize = 3 * 32 * 32;

    /// Builds the tiny preset model from a fixed seed.
    fn make_tiny_model() -> ViTModel {
        let mut rng = LcgRng::new(42);
        ViTModel::new(ViTConfig::tiny(), &mut rng).expect("tiny model created")
    }

    /// Image of the tiny preset's size with every value set to `value`.
    fn constant_image(value: f32) -> Vec<f32> {
        vec![value; TINY_IMAGE_LEN]
    }

    /// Image of the tiny preset's size filled from `rng`.
    fn random_image(rng: &mut LcgRng) -> Vec<f32> {
        let mut image = constant_image(0.0);
        rng.fill_normal(&mut image);
        image
    }

    /// Sum of absolute element-wise differences.
    fn l1_distance(lhs: &[f32], rhs: &[f32]) -> f32 {
        lhs.iter().zip(rhs).map(|(a, b)| (a - b).abs()).sum()
    }

    // --- configuration -----------------------------------------------------

    #[test]
    fn tiny_config_values() {
        let cfg = ViTConfig::tiny();
        let actual = [
            cfg.img_size,
            cfg.patch_size,
            cfg.in_chans,
            cfg.embed_dim,
            cfg.depth,
            cfg.n_heads,
            cfg.mlp_ratio,
            cfg.n_classes,
        ];
        assert_eq!(actual, [32, 4, 3, 64, 2, 4, 4, 10]);
    }

    #[test]
    fn tiny_config_n_patches() {
        let cfg = ViTConfig::tiny();
        // (32 / 4)^2 patches, plus one class token.
        assert_eq!(cfg.n_patches(), 64);
        assert_eq!(cfg.seq_len(), 65);
    }

    #[test]
    fn config_zero_classes_errors() {
        let r = ViTConfig::new(32, 4, 3, 64, 2, 4, 4, 0);
        assert!(matches!(r, Err(VisionError::InvalidNumClasses(0))));
    }

    #[test]
    fn config_invalid_patch_size_errors() {
        // 32 is not divisible by 5.
        let r = ViTConfig::new(32, 5, 3, 64, 2, 4, 4, 10);
        assert!(matches!(r, Err(VisionError::InvalidPatchSize { .. })));
    }

    #[test]
    fn config_head_dim_mismatch_errors() {
        // 63 is not divisible by 4 heads.
        let r = ViTConfig::new(32, 4, 3, 63, 2, 4, 4, 10);
        assert!(matches!(r, Err(VisionError::HeadDimMismatch { .. })));
    }

    // --- forward pass ------------------------------------------------------

    #[test]
    fn forward_returns_ten_logits() {
        let logits = make_tiny_model()
            .forward(&constant_image(0.0))
            .expect("forward ok");
        assert_eq!(logits.len(), 10, "expected 10 logits, got {}", logits.len());
    }

    #[test]
    fn forward_logits_finite() {
        let model = make_tiny_model();
        let image = random_image(&mut LcgRng::new(7));
        let logits = model.forward(&image).expect("forward ok");
        assert!(
            logits.iter().all(|v| v.is_finite()),
            "non-finite logits: {logits:?}"
        );
    }

    #[test]
    fn forward_random_input_not_constant_logits() {
        let model = make_tiny_model();
        // Both images come from the same generator, one after the other.
        let mut rng = LcgRng::new(13);
        let img1 = random_image(&mut rng);
        let img2 = random_image(&mut rng);
        let l1 = model.forward(&img1).expect("ok");
        let l2 = model.forward(&img2).expect("ok");
        let diff = l1_distance(&l1, &l2);
        assert!(
            diff > 1e-6,
            "logits did not change between different images (diff={diff})"
        );
    }

    #[test]
    fn forward_wrong_image_size_errors() {
        let model = make_tiny_model();
        // One row short of a full 3x32x32 image.
        let image = vec![0.0f32; 3 * 32 * 31];
        let r = model.forward(&image);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn forward_correct_image_size_passes() {
        let logits = make_tiny_model()
            .forward(&constant_image(0.5))
            .expect("forward ok with constant image");
        assert_eq!(logits.len(), 10);
    }

    // --- model structure ---------------------------------------------------

    #[test]
    fn pos_embed_has_correct_positions() {
        let model = make_tiny_model();
        assert_eq!(model.pos_embed.n_positions, 65);
        assert_eq!(model.pos_embed.embed_dim, 64);
    }

    #[test]
    fn encoder_has_correct_depth() {
        assert_eq!(make_tiny_model().encoder.blocks.len(), 2);
    }

    #[test]
    fn head_weights_correct_size() {
        let model = make_tiny_model();
        assert_eq!(model.weights.head_weight.len(), 10 * 64);
        assert_eq!(model.weights.head_bias.len(), 10);
    }

    #[test]
    fn different_seeds_produce_different_outputs() {
        let cfg = ViTConfig::tiny();
        let mut rng_a = LcgRng::new(1);
        let mut rng_b = LcgRng::new(2);
        let model_a = ViTModel::new(cfg.clone(), &mut rng_a).expect("ok");
        let model_b = ViTModel::new(cfg, &mut rng_b).expect("ok");

        let image = constant_image(0.5);
        let la = model_a.forward(&image).expect("ok");
        let lb = model_b.forward(&image).expect("ok");
        let diff = l1_distance(&la, &lb);
        assert!(
            diff > 1e-6,
            "different seeds should yield different logits (diff={diff})"
        );
    }
}