Skip to main content

oxicuda_vision/vit/
vit_model.rs

//! Full Vision Transformer (ViT) model.
//!
//! ## Pipeline
//! ```text
//! image [C, H, W]
//!   → PatchEmbed         → [n_patches, embed_dim]
//!   → prepend_cls        → [n_patches + 1, embed_dim]
//!   → add_pos_embed      → [n_patches + 1, embed_dim]
//!   → ViTEncoder         → [n_patches + 1, embed_dim]
//!   → CLS token [0]      → [embed_dim]
//!   → Linear head        → [n_classes]   (logits)
//! ```
13
14use crate::{
15    error::{VisionError, VisionResult},
16    handle::LcgRng,
17    patch_embed::{LearnablePosEmbed, PatchEmbed, PatchEmbedConfig, add_pos_embed, prepend_cls},
18    vit::{
19        vit_block::linear,
20        vit_encoder::{ViTEncoder, ViTEncoderConfig},
21    },
22};
23
24// ─── Config ──────────────────────────────────────────────────────────────────
25
/// Top-level ViT configuration.
///
/// Every field is public, so a value written as a struct literal skips the
/// checks performed by `ViTConfig::new`; prefer the constructor when the
/// hyper-parameters come from outside the program.
///
/// All fields are plain `usize`, which is why the type can be compared for
/// full equality and hashed (e.g. to key a cache of built models).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ViTConfig {
    /// Side length of the square input image, in pixels (H = W).
    pub img_size: usize,
    /// Side length of one square patch (stride = kernel size = `patch_size`).
    pub patch_size: usize,
    /// Number of input channels (e.g. 3 for RGB).
    pub in_chans: usize,
    /// Width of every token embedding.
    pub embed_dim: usize,
    /// Number of transformer blocks in the encoder.
    pub depth: usize,
    /// Number of attention heads per block.
    pub n_heads: usize,
    /// MLP hidden-dim multiplier.
    pub mlp_ratio: usize,
    /// Number of output classes produced by the classification head.
    pub n_classes: usize,
}
46
47impl ViTConfig {
48    /// Tiny ViT: suitable for CIFAR-10 style 32×32 RGB images.
49    ///
50    /// `img_size=32`, `patch_size=4`, `in_chans=3`, `embed_dim=64`,
51    /// `depth=2`, `n_heads=4`, `mlp_ratio=4`, `n_classes=10`.
52    #[must_use]
53    pub fn tiny() -> Self {
54        Self {
55            img_size: 32,
56            patch_size: 4,
57            in_chans: 3,
58            embed_dim: 64,
59            depth: 2,
60            n_heads: 4,
61            mlp_ratio: 4,
62            n_classes: 10,
63        }
64    }
65
66    /// Validate and construct a `ViTConfig`.
67    ///
68    /// # Errors
69    /// - `n_classes == 0` → `InvalidNumClasses`
70    /// - patch or image size issues → propagated from `PatchEmbedConfig`
71    /// - `embed_dim % n_heads != 0` → `HeadDimMismatch`
72    pub fn new(
73        img_size: usize,
74        patch_size: usize,
75        in_chans: usize,
76        embed_dim: usize,
77        depth: usize,
78        n_heads: usize,
79        mlp_ratio: usize,
80        n_classes: usize,
81    ) -> VisionResult<Self> {
82        if n_classes == 0 {
83            return Err(VisionError::InvalidNumClasses(n_classes));
84        }
85        if depth == 0 {
86            return Err(VisionError::Internal("depth must be > 0".into()));
87        }
88        // Delegate patch / embed / head validation to their own constructors.
89        PatchEmbedConfig::new(img_size, patch_size, in_chans, embed_dim)?;
90        ViTEncoderConfig::new(embed_dim, n_heads, mlp_ratio, depth)?;
91
92        Ok(Self {
93            img_size,
94            patch_size,
95            in_chans,
96            embed_dim,
97            depth,
98            n_heads,
99            mlp_ratio,
100            n_classes,
101        })
102    }
103
104    /// Number of non-overlapping patches for this image / patch size.
105    #[must_use]
106    pub fn n_patches(&self) -> usize {
107        let grid = self.img_size / self.patch_size;
108        grid * grid
109    }
110
111    /// Total sequence length including the CLS token.
112    #[must_use]
113    pub fn seq_len(&self) -> usize {
114        self.n_patches() + 1
115    }
116}
117
118// ─── Weights ─────────────────────────────────────────────────────────────────
119
/// Learnable weights for the ViT classification head.
///
/// The head is a single linear projection from `embed_dim` to `n_classes`.
/// Both buffers are plain `Vec<f32>`, so the type can be printed, cloned
/// (e.g. to snapshot a checkpoint) and compared element-wise.
#[derive(Debug, Clone, PartialEq)]
pub struct ViTModelWeights {
    /// Head projection kernel: `[n_classes, embed_dim]` flat.
    pub head_weight: Vec<f32>,
    /// Head projection bias: `[n_classes]`.
    pub head_bias: Vec<f32>,
}
129
130impl ViTModelWeights {
131    fn default_init(cfg: &ViTConfig, rng: &mut LcgRng) -> Self {
132        let scale = 1.0 / (cfg.embed_dim as f32).sqrt();
133        let mut head_weight = vec![0.0f32; cfg.n_classes * cfg.embed_dim];
134        rng.fill_normal(&mut head_weight);
135        for v in &mut head_weight {
136            *v *= scale;
137        }
138        let head_bias = vec![0.0f32; cfg.n_classes];
139        Self {
140            head_weight,
141            head_bias,
142        }
143    }
144}
145
146// ─── ViTModel ─────────────────────────────────────────────────────────────────
147
148/// Full Vision Transformer model.
149pub struct ViTModel {
150    /// Top-level model configuration.
151    pub config: ViTConfig,
152    /// Strided conv2d patch embedder.
153    pub patch_embed: PatchEmbed,
154    /// Learnable positional embeddings for `seq_len` positions (CLS + patches).
155    pub pos_embed: LearnablePosEmbed,
156    /// Transformer encoder stack.
157    pub encoder: ViTEncoder,
158    /// Classification head weights.
159    pub weights: ViTModelWeights,
160}
161
162impl ViTModel {
163    /// Build and initialise a full ViT model.
164    ///
165    /// All sub-modules share the same `rng` stream for reproducibility.
166    pub fn new(cfg: ViTConfig, rng: &mut LcgRng) -> VisionResult<Self> {
167        let patch_cfg =
168            PatchEmbedConfig::new(cfg.img_size, cfg.patch_size, cfg.in_chans, cfg.embed_dim)?;
169        let patch_embed = PatchEmbed::new(patch_cfg, rng);
170
171        // Positional embedding covers CLS token + all patch tokens.
172        let seq_len = cfg.seq_len();
173        let pos_embed = LearnablePosEmbed::new(seq_len, cfg.embed_dim, rng)?;
174
175        let enc_cfg = ViTEncoderConfig::new(cfg.embed_dim, cfg.n_heads, cfg.mlp_ratio, cfg.depth)?;
176        let encoder = ViTEncoder::new(enc_cfg, rng)?;
177
178        let weights = ViTModelWeights::default_init(&cfg, rng);
179
180        Ok(Self {
181            config: cfg,
182            patch_embed,
183            pos_embed,
184            encoder,
185            weights,
186        })
187    }
188
189    /// Forward pass.
190    ///
191    /// `image` is flat CHW: `[in_chans, img_size, img_size]`.
192    /// Returns logits: `[n_classes]`.
193    ///
194    /// # Errors
195    /// Returns `DimensionMismatch` if `image.len()` does not match
196    /// `in_chans * img_size * img_size`.
197    pub fn forward(&self, image: &[f32]) -> VisionResult<Vec<f32>> {
198        let cfg = &self.config;
199        let expected_img = cfg.in_chans * cfg.img_size * cfg.img_size;
200        if image.len() != expected_img {
201            return Err(VisionError::DimensionMismatch {
202                expected: expected_img,
203                got: image.len(),
204            });
205        }
206
207        // Step 1: patch embedding → [n_patches, embed_dim]
208        let patch_tokens = self.patch_embed.forward(image)?;
209
210        // Step 2: prepend CLS token → [n_patches + 1, embed_dim]
211        let cls_token = &self.patch_embed.weights.cls_token;
212        let mut tokens = prepend_cls(&patch_tokens, cls_token, cfg.embed_dim)?;
213
214        // Step 3: add positional embeddings (all seq_len positions)
215        add_pos_embed(&mut tokens, &self.pos_embed.table, cfg.embed_dim)?;
216
217        // Step 4: transformer encoder → [seq_len, embed_dim]
218        let seq_len = cfg.seq_len();
219        let encoded = self.encoder.forward(&tokens, seq_len)?;
220
221        // Step 5: extract CLS token (first row)
222        let cls_repr = &encoded[..cfg.embed_dim];
223
224        // Step 6: classification head → [n_classes]
225        let logits = linear(
226            cls_repr,
227            &self.weights.head_weight,
228            &self.weights.head_bias,
229            cfg.embed_dim,
230            cfg.n_classes,
231        );
232
233        Ok(logits)
234    }
235}
236
237// ─── Tests ───────────────────────────────────────────────────────────────────
238
#[cfg(test)]
mod tests {
    use super::*;

    /// Flat CHW length of the image expected by `ViTConfig::tiny()`.
    const TINY_IMAGE_LEN: usize = 3 * 32 * 32;

    /// Build the tiny preset model from a fixed seed.
    fn make_tiny_model() -> ViTModel {
        let cfg = ViTConfig::tiny();
        let mut rng = LcgRng::new(42);
        ViTModel::new(cfg, &mut rng).expect("tiny model created")
    }

    /// Sum of absolute element-wise differences between two logit vectors.
    fn l1_distance(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum()
    }

    // ── Config ────────────────────────────────────────────────────────────────

    #[test]
    fn tiny_config_values() {
        let cfg = ViTConfig::tiny();
        assert_eq!(cfg.img_size, 32);
        assert_eq!(cfg.patch_size, 4);
        assert_eq!(cfg.in_chans, 3);
        assert_eq!(cfg.embed_dim, 64);
        assert_eq!(cfg.depth, 2);
        assert_eq!(cfg.n_heads, 4);
        assert_eq!(cfg.mlp_ratio, 4);
        assert_eq!(cfg.n_classes, 10);
    }

    #[test]
    fn tiny_config_n_patches() {
        let cfg = ViTConfig::tiny();
        // (32/4)^2 = 8^2 = 64 patches
        assert_eq!(cfg.n_patches(), 64);
        assert_eq!(cfg.seq_len(), 65);
    }

    #[test]
    fn config_zero_classes_errors() {
        let r = ViTConfig::new(32, 4, 3, 64, 2, 4, 4, 0);
        assert!(matches!(r, Err(VisionError::InvalidNumClasses(0))));
    }

    #[test]
    fn config_zero_depth_errors() {
        // depth is the fifth argument; every other value is valid.
        let r = ViTConfig::new(32, 4, 3, 64, 0, 4, 4, 10);
        assert!(matches!(r, Err(VisionError::Internal(_))));
    }

    #[test]
    fn config_invalid_patch_size_errors() {
        let r = ViTConfig::new(32, 5, 3, 64, 2, 4, 4, 10); // 32 % 5 != 0
        assert!(matches!(r, Err(VisionError::InvalidPatchSize { .. })));
    }

    #[test]
    fn config_head_dim_mismatch_errors() {
        let r = ViTConfig::new(32, 4, 3, 63, 2, 4, 4, 10); // 63 % 4 != 0
        assert!(matches!(r, Err(VisionError::HeadDimMismatch { .. })));
    }

    // ── Forward ───────────────────────────────────────────────────────────────

    #[test]
    fn forward_returns_ten_logits() {
        let model = make_tiny_model();
        let image = vec![0.0f32; TINY_IMAGE_LEN];
        let logits = model.forward(&image).expect("forward ok");
        assert_eq!(logits.len(), 10, "expected 10 logits, got {}", logits.len());
    }

    #[test]
    fn forward_logits_finite() {
        let model = make_tiny_model();
        let mut rng = LcgRng::new(7);
        let mut image = vec![0.0f32; TINY_IMAGE_LEN];
        rng.fill_normal(&mut image);
        let logits = model.forward(&image).expect("forward ok");
        assert!(
            logits.iter().all(|v| v.is_finite()),
            "non-finite logits: {logits:?}"
        );
    }

    #[test]
    fn forward_random_input_not_constant_logits() {
        // Different images → different logits
        let model = make_tiny_model();
        let mut rng = LcgRng::new(13);
        let mut img1 = vec![0.0f32; TINY_IMAGE_LEN];
        let mut img2 = vec![0.0f32; TINY_IMAGE_LEN];
        // Both images come from the same stream, so they differ.
        rng.fill_normal(&mut img1);
        rng.fill_normal(&mut img2);
        let l1 = model.forward(&img1).expect("ok");
        let l2 = model.forward(&img2).expect("ok");
        let diff = l1_distance(&l1, &l2);
        assert!(
            diff > 1e-6,
            "logits did not change between different images (diff={diff})"
        );
    }

    #[test]
    fn forward_wrong_image_size_errors() {
        let model = make_tiny_model();
        // Too small: one row short of the expected 3 * 32 * 32 values.
        let image = vec![0.0f32; 3 * 32 * 31];
        let r = model.forward(&image);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn forward_correct_image_size_passes() {
        let model = make_tiny_model();
        let image = vec![0.5f32; TINY_IMAGE_LEN];
        let logits = model
            .forward(&image)
            .expect("forward ok with constant image");
        assert_eq!(logits.len(), 10);
    }

    // ── Structural checks ─────────────────────────────────────────────────────

    #[test]
    fn pos_embed_has_correct_positions() {
        let model = make_tiny_model();
        // seq_len = 65 (64 patches + 1 CLS)
        assert_eq!(model.pos_embed.n_positions, 65);
        assert_eq!(model.pos_embed.embed_dim, 64);
    }

    #[test]
    fn encoder_has_correct_depth() {
        let model = make_tiny_model();
        assert_eq!(model.encoder.blocks.len(), 2);
    }

    #[test]
    fn head_weights_correct_size() {
        let model = make_tiny_model();
        assert_eq!(model.weights.head_weight.len(), 10 * 64);
        assert_eq!(model.weights.head_bias.len(), 10);
    }

    #[test]
    fn different_seeds_produce_different_outputs() {
        let cfg = ViTConfig::tiny();
        let mut rng_a = LcgRng::new(1);
        let mut rng_b = LcgRng::new(2);
        let model_a = ViTModel::new(cfg.clone(), &mut rng_a).expect("ok");
        let model_b = ViTModel::new(cfg, &mut rng_b).expect("ok");
        let image = vec![0.5f32; TINY_IMAGE_LEN];
        let la = model_a.forward(&image).expect("ok");
        let lb = model_b.forward(&image).expect("ok");
        let diff = l1_distance(&la, &lb);
        assert!(
            diff > 1e-6,
            "different seeds should yield different logits (diff={diff})"
        );
    }
}