Skip to main content

oxicuda_vision/vit/
vit_encoder.rs

1//! ViT encoder: a stack of `depth` transformer blocks followed by a
2//! final layer normalisation.
3//!
4//! The encoder operates on flat `[n_tokens, embed_dim]` tensors and applies
5//! each `ViTBlock` sequentially with skip connections already handled inside
6//! each block.
7
8use crate::{
9    error::{VisionError, VisionResult},
10    handle::LcgRng,
11    vit::vit_block::{ViTBlock, ViTBlockConfig, layer_norm},
12};
13
14// ─── Config ──────────────────────────────────────────────────────────────────
15
16/// Configuration for the ViT encoder stack.
17#[derive(Debug, Clone, PartialEq)]
18pub struct ViTEncoderConfig {
19    /// Shared configuration for every transformer block.
20    pub block_cfg: ViTBlockConfig,
21    /// Number of transformer blocks (encoder depth).
22    pub depth: usize,
23}
24
25impl ViTEncoderConfig {
26    /// Create and validate a `ViTEncoderConfig`.
27    ///
28    /// # Errors
29    /// Propagates errors from `ViTBlockConfig::new` (embed/head validation).
30    /// Also returns `Internal` if `depth == 0`.
31    pub fn new(
32        embed_dim: usize,
33        n_heads: usize,
34        mlp_ratio: usize,
35        depth: usize,
36    ) -> VisionResult<Self> {
37        if depth == 0 {
38            return Err(VisionError::Internal("encoder depth must be > 0".into()));
39        }
40        let block_cfg = ViTBlockConfig::new(embed_dim, n_heads, mlp_ratio)?;
41        Ok(Self { block_cfg, depth })
42    }
43}
44
45// ─── ViTEncoder ──────────────────────────────────────────────────────────────
46
47/// Encoder stack: `depth` ViT blocks + a final layer norm.
48pub struct ViTEncoder {
49    /// The individual transformer blocks.
50    pub blocks: Vec<ViTBlock>,
51    /// Final LayerNorm scale: `[embed_dim]`.
52    pub final_ln_weight: Vec<f32>,
53    /// Final LayerNorm bias: `[embed_dim]`.
54    pub final_ln_bias: Vec<f32>,
55}
56
57impl ViTEncoder {
58    /// Construct the encoder: `depth` blocks with independent weight
59    /// initialisations from the shared `rng`.
60    ///
61    /// The final layer-norm is initialised with weight=1, bias=0.
62    pub fn new(cfg: ViTEncoderConfig, rng: &mut LcgRng) -> VisionResult<Self> {
63        let e = cfg.block_cfg.embed_dim;
64        let mut blocks = Vec::with_capacity(cfg.depth);
65        for _ in 0..cfg.depth {
66            blocks.push(ViTBlock::new(cfg.block_cfg.clone(), rng));
67        }
68        let final_ln_weight = vec![1.0f32; e];
69        let final_ln_bias = vec![0.0f32; e];
70        Ok(Self {
71            blocks,
72            final_ln_weight,
73            final_ln_bias,
74        })
75    }
76
77    /// Forward pass through all blocks then the final layer norm.
78    ///
79    /// `tokens`: flat `[n_tokens, embed_dim]`.
80    /// Returns `[n_tokens, embed_dim]`.
81    pub fn forward(&self, tokens: &[f32], n_tokens: usize) -> VisionResult<Vec<f32>> {
82        let e = self
83            .blocks
84            .first()
85            .map(|b| b.config.embed_dim)
86            .ok_or_else(|| VisionError::Internal("encoder has no blocks".into()))?;
87
88        if tokens.len() != n_tokens * e {
89            return Err(VisionError::DimensionMismatch {
90                expected: n_tokens * e,
91                got: tokens.len(),
92            });
93        }
94        if n_tokens == 0 {
95            return Err(VisionError::EmptyInput("tokens"));
96        }
97
98        // Apply each block sequentially
99        let mut h: Vec<f32> = tokens.to_vec();
100        for block in &self.blocks {
101            h = block.forward(&h, n_tokens)?;
102        }
103
104        // Final layer norm
105        let out = layer_norm(
106            &h,
107            &self.final_ln_weight,
108            &self.final_ln_bias,
109            n_tokens,
110            e,
111            1e-5,
112        );
113        Ok(out)
114    }
115
116    /// Embedding dimension (read from the first block config).
117    #[must_use]
118    pub fn embed_dim(&self) -> usize {
119        self.blocks.first().map_or(0, |b| b.config.embed_dim)
120    }
121}
122
123// ─── Tests ───────────────────────────────────────────────────────────────────
124
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an encoder with `embed_dim = 64`, 4 heads, MLP ratio 4 and a fixed seed.
    fn make_enc(depth: usize) -> ViTEncoder {
        let cfg = ViTEncoderConfig::new(64, 4, 4, depth).expect("valid encoder config");
        let mut rng = LcgRng::new(42);
        ViTEncoder::new(cfg, &mut rng).expect("encoder created")
    }

    // ── Config ────────────────────────────────────────────────────────────────

    #[test]
    fn config_valid() {
        let cfg = ViTEncoderConfig::new(64, 4, 4, 3).expect("valid");
        assert_eq!(cfg.depth, 3);
        assert_eq!(cfg.block_cfg.embed_dim, 64);
    }

    #[test]
    fn config_depth_zero_errors() {
        let r = ViTEncoderConfig::new(64, 4, 4, 0);
        assert!(matches!(r, Err(VisionError::Internal(_))));
    }

    #[test]
    fn config_propagates_block_error() {
        // embed_dim not divisible by n_heads
        let r = ViTEncoderConfig::new(65, 4, 4, 2);
        assert!(matches!(r, Err(VisionError::HeadDimMismatch { .. })));
    }

    // ── Forward shape ─────────────────────────────────────────────────────────

    #[test]
    fn depth1_output_shape() {
        let enc = make_enc(1);
        let e = enc.embed_dim();
        let n_tokens = 17;
        let tokens = vec![0.1f32; n_tokens * e];
        let out = enc.forward(&tokens, n_tokens).expect("forward ok");
        assert_eq!(out.len(), n_tokens * e);
    }

    #[test]
    fn depth2_output_shape() {
        let enc = make_enc(2);
        let e = enc.embed_dim();
        let n_tokens = 17;
        let tokens = vec![0.1f32; n_tokens * e];
        let out = enc.forward(&tokens, n_tokens).expect("forward ok");
        assert_eq!(out.len(), n_tokens * e);
    }

    #[test]
    fn depth4_output_shape() {
        let enc = make_enc(4);
        let e = enc.embed_dim();
        let n_tokens = 9;
        let mut rng = LcgRng::new(11);
        let mut tokens = vec![0.0f32; n_tokens * e];
        rng.fill_normal(&mut tokens);
        let out = enc.forward(&tokens, n_tokens).expect("forward ok");
        assert_eq!(out.len(), n_tokens * e);
    }

    // ── Finite outputs ────────────────────────────────────────────────────────

    #[test]
    fn output_finite_random_input() {
        let enc = make_enc(2);
        let e = enc.embed_dim();
        let n_tokens = 17;
        let mut rng = LcgRng::new(7);
        let mut tokens = vec![0.0f32; n_tokens * e];
        rng.fill_normal(&mut tokens);
        let out = enc.forward(&tokens, n_tokens).expect("forward ok");
        assert!(
            out.iter().all(|v| v.is_finite()),
            "non-finite encoder output"
        );
    }

    // ── Final LN runs ─────────────────────────────────────────────────────────

    #[test]
    fn final_ln_weight_bias_correct_size() {
        let enc = make_enc(1);
        assert_eq!(enc.final_ln_weight.len(), enc.embed_dim());
        assert_eq!(enc.final_ln_bias.len(), enc.embed_dim());
    }

    #[test]
    fn final_ln_weight_initialised_one() {
        let enc = make_enc(1);
        assert!(enc.final_ln_weight.iter().all(|&v| (v - 1.0).abs() < 1e-9));
    }

    #[test]
    fn final_ln_bias_initialised_zero() {
        let enc = make_enc(1);
        assert!(enc.final_ln_bias.iter().all(|&v| v.abs() < 1e-9));
    }

    #[test]
    fn final_ln_output_has_zero_token_mean() {
        // With weight = 1 and bias = 0, a layer norm centres every token, so
        // each row of the output should average to (numerically) zero.
        let enc = make_enc(2);
        let e = enc.embed_dim();
        let n_tokens = 5;
        let mut rng = LcgRng::new(3);
        let mut tokens = vec![0.0f32; n_tokens * e];
        rng.fill_normal(&mut tokens);
        let out = enc.forward(&tokens, n_tokens).expect("forward ok");
        for (t, row) in out.chunks_exact(e).enumerate() {
            let mean = row.iter().sum::<f32>() / e as f32;
            assert!(mean.abs() < 1e-3, "token {t}: mean {mean} not ~0 after final LN");
        }
    }

    // ── Error cases ───────────────────────────────────────────────────────────

    #[test]
    fn dimension_mismatch_errors() {
        let enc = make_enc(1);
        let e = enc.embed_dim();
        // Wrong token count (n_tokens=5 but slice says 3 tokens)
        let tokens = vec![0.0f32; 3 * e];
        let r = enc.forward(&tokens, 5);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn empty_tokens_errors() {
        let enc = make_enc(1);
        let tokens: Vec<f32> = vec![];
        let r = enc.forward(&tokens, 0);
        assert!(matches!(r, Err(VisionError::EmptyInput(_))));
    }

    #[test]
    fn zero_tokens_with_nonempty_slice_is_dimension_mismatch() {
        // The length check runs before the empty-input check, so a non-empty
        // slice with `n_tokens == 0` is reported as a mismatch.
        let enc = make_enc(1);
        let tokens = vec![0.0f32; 3];
        let r = enc.forward(&tokens, 0);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    // ── Block count ───────────────────────────────────────────────────────────

    #[test]
    fn correct_number_of_blocks() {
        for d in [1, 2, 4, 6, 12] {
            let enc = make_enc(d);
            assert_eq!(enc.blocks.len(), d, "wrong block count for depth={d}");
        }
    }
}