oxicuda_vision/vit/
vit_encoder.rs1use crate::{
9 error::{VisionError, VisionResult},
10 handle::LcgRng,
11 vit::vit_block::{ViTBlock, ViTBlockConfig, layer_norm},
12};
13
14#[derive(Debug, Clone, PartialEq)]
18pub struct ViTEncoderConfig {
19 pub block_cfg: ViTBlockConfig,
21 pub depth: usize,
23}
24
25impl ViTEncoderConfig {
26 pub fn new(
32 embed_dim: usize,
33 n_heads: usize,
34 mlp_ratio: usize,
35 depth: usize,
36 ) -> VisionResult<Self> {
37 if depth == 0 {
38 return Err(VisionError::Internal("encoder depth must be > 0".into()));
39 }
40 let block_cfg = ViTBlockConfig::new(embed_dim, n_heads, mlp_ratio)?;
41 Ok(Self { block_cfg, depth })
42 }
43}
44
45pub struct ViTEncoder {
49 pub blocks: Vec<ViTBlock>,
51 pub final_ln_weight: Vec<f32>,
53 pub final_ln_bias: Vec<f32>,
55}
56
57impl ViTEncoder {
58 pub fn new(cfg: ViTEncoderConfig, rng: &mut LcgRng) -> VisionResult<Self> {
63 let e = cfg.block_cfg.embed_dim;
64 let mut blocks = Vec::with_capacity(cfg.depth);
65 for _ in 0..cfg.depth {
66 blocks.push(ViTBlock::new(cfg.block_cfg.clone(), rng));
67 }
68 let final_ln_weight = vec![1.0f32; e];
69 let final_ln_bias = vec![0.0f32; e];
70 Ok(Self {
71 blocks,
72 final_ln_weight,
73 final_ln_bias,
74 })
75 }
76
77 pub fn forward(&self, tokens: &[f32], n_tokens: usize) -> VisionResult<Vec<f32>> {
82 let e = self
83 .blocks
84 .first()
85 .map(|b| b.config.embed_dim)
86 .ok_or_else(|| VisionError::Internal("encoder has no blocks".into()))?;
87
88 if tokens.len() != n_tokens * e {
89 return Err(VisionError::DimensionMismatch {
90 expected: n_tokens * e,
91 got: tokens.len(),
92 });
93 }
94 if n_tokens == 0 {
95 return Err(VisionError::EmptyInput("tokens"));
96 }
97
98 let mut h: Vec<f32> = tokens.to_vec();
100 for block in &self.blocks {
101 h = block.forward(&h, n_tokens)?;
102 }
103
104 let out = layer_norm(
106 &h,
107 &self.final_ln_weight,
108 &self.final_ln_bias,
109 n_tokens,
110 e,
111 1e-5,
112 );
113 Ok(out)
114 }
115
116 #[must_use]
118 pub fn embed_dim(&self) -> usize {
119 self.blocks.first().map_or(0, |b| b.config.embed_dim)
120 }
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoder with `embed_dim = 64`, 4 heads, MLP ratio 4, weights seeded with 42.
    fn make_enc(depth: usize) -> ViTEncoder {
        let config = ViTEncoderConfig::new(64, 4, 4, depth).expect("valid encoder config");
        let mut seed_rng = LcgRng::new(42);
        ViTEncoder::new(config, &mut seed_rng).expect("encoder created")
    }

    /// A `len`-element buffer filled by `LcgRng::fill_normal` from `rng`.
    fn random_tokens(mut rng: LcgRng, len: usize) -> Vec<f32> {
        let mut buf = vec![0.0f32; len];
        rng.fill_normal(&mut buf);
        buf
    }

    /// Runs `enc` and checks the output is exactly `n_tokens * embed_dim` long.
    fn check_output_len(enc: &ViTEncoder, input: &[f32], n_tokens: usize) {
        let output = enc.forward(input, n_tokens).expect("forward ok");
        assert_eq!(output.len(), n_tokens * enc.embed_dim());
    }

    #[test]
    fn config_valid() {
        let config = ViTEncoderConfig::new(64, 4, 4, 3).expect("valid");
        assert_eq!((config.depth, config.block_cfg.embed_dim), (3, 64));
    }

    #[test]
    fn config_depth_zero_errors() {
        assert!(matches!(
            ViTEncoderConfig::new(64, 4, 4, 0),
            Err(VisionError::Internal(_))
        ));
    }

    #[test]
    fn config_propagates_block_error() {
        // 65 cannot be split evenly across 4 heads.
        assert!(matches!(
            ViTEncoderConfig::new(65, 4, 4, 2),
            Err(VisionError::HeadDimMismatch { .. })
        ));
    }

    #[test]
    fn depth1_output_shape() {
        let enc = make_enc(1);
        let count = 17;
        let input = vec![0.1f32; count * enc.embed_dim()];
        check_output_len(&enc, &input, count);
    }

    #[test]
    fn depth2_output_shape() {
        let enc = make_enc(2);
        let count = 17;
        let input = vec![0.1f32; count * enc.embed_dim()];
        check_output_len(&enc, &input, count);
    }

    #[test]
    fn depth4_output_shape() {
        let enc = make_enc(4);
        let count = 9;
        let input = random_tokens(LcgRng::new(11), count * enc.embed_dim());
        check_output_len(&enc, &input, count);
    }

    #[test]
    fn output_finite_random_input() {
        let enc = make_enc(2);
        let count = 17;
        let input = random_tokens(LcgRng::new(7), count * enc.embed_dim());
        let output = enc.forward(&input, count).expect("forward ok");
        assert!(
            output.iter().all(|v| v.is_finite()),
            "non-finite encoder output"
        );
    }

    #[test]
    fn final_ln_weight_bias_correct_size() {
        let enc = make_enc(1);
        let width = enc.embed_dim();
        for params in [&enc.final_ln_weight, &enc.final_ln_bias] {
            assert_eq!(params.len(), width);
        }
    }

    #[test]
    fn final_ln_weight_initialised_one() {
        let enc = make_enc(1);
        let all_one = enc
            .final_ln_weight
            .iter()
            .all(|&w| (w - 1.0).abs() < 1e-9);
        assert!(all_one);
    }

    #[test]
    fn dimension_mismatch_errors() {
        let enc = make_enc(1);
        // The buffer holds 3 tokens, but the call claims there are 5.
        let input = vec![0.0f32; 3 * enc.embed_dim()];
        assert!(matches!(
            enc.forward(&input, 5),
            Err(VisionError::DimensionMismatch { .. })
        ));
    }

    #[test]
    fn empty_tokens_errors() {
        let enc = make_enc(1);
        let input: Vec<f32> = Vec::new();
        assert!(matches!(
            enc.forward(&input, 0),
            Err(VisionError::EmptyInput(_))
        ));
    }

    #[test]
    fn correct_number_of_blocks() {
        for d in [1, 2, 4, 6, 12] {
            let block_count = make_enc(d).blocks.len();
            assert_eq!(block_count, d, "wrong block count for depth={d}");
        }
    }
}