1use crate::{
14 error::{VisionError, VisionResult},
15 handle::LcgRng,
16};
17
/// Static shape parameters of a single Vision Transformer encoder block.
///
/// Prefer [`ViTBlockConfig::new`], which validates `embed_dim` and `n_heads`.
/// The fields are public, so building the struct literally bypasses those
/// checks.
///
/// All fields are plain integers, so equality is total and the type derives
/// `Eq` alongside `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ViTBlockConfig {
    /// Width of every token vector (the embedding dimension).
    pub embed_dim: usize,
    /// Number of attention heads; `embed_dim` is split evenly between them.
    pub n_heads: usize,
    /// Hidden width of the MLP, expressed as a multiple of `embed_dim`.
    pub mlp_ratio: usize,
}
30
31impl ViTBlockConfig {
32 pub fn new(embed_dim: usize, n_heads: usize, mlp_ratio: usize) -> VisionResult<Self> {
39 if embed_dim == 0 {
40 return Err(VisionError::InvalidEmbedDim(embed_dim));
41 }
42 if n_heads == 0 {
43 return Err(VisionError::InvalidNumHeads(n_heads));
44 }
45 if embed_dim % n_heads != 0 {
46 return Err(VisionError::HeadDimMismatch { n_heads, embed_dim });
47 }
48 Ok(Self {
49 embed_dim,
50 n_heads,
51 mlp_ratio,
52 })
53 }
54
55 #[must_use]
57 #[inline]
58 pub fn head_dim(&self) -> usize {
59 self.embed_dim / self.n_heads
60 }
61
62 #[must_use]
64 #[inline]
65 pub fn mlp_dim(&self) -> usize {
66 self.mlp_ratio * self.embed_dim
67 }
68}
69
/// Learnable parameters of one transformer block, stored as flat `f32` buffers.
///
/// Weight matrices are row-major with one row per *output* feature, i.e. a
/// layer mapping `n_in` inputs to `n_out` outputs occupies `n_out * n_in`
/// values. Below, `E` is the embedding dimension and `M` the MLP hidden width.
///
/// The type derives `Debug`, `Clone` and `PartialEq` so weights can be
/// inspected, duplicated and compared in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct ViTBlockWeights {
    /// Fused query/key/value projection, `3E x E`. Each output row is laid out
    /// as the query third, then the key third, then the value third.
    pub qkv_weight: Vec<f32>,
    /// Bias of the fused QKV projection, length `3E`.
    pub qkv_bias: Vec<f32>,

    /// Attention output projection, `E x E`.
    pub out_weight: Vec<f32>,
    /// Bias of the attention output projection, length `E`.
    pub out_bias: Vec<f32>,

    /// First MLP layer, `M x E`.
    pub mlp1_weight: Vec<f32>,
    /// Bias of the first MLP layer, length `M`.
    pub mlp1_bias: Vec<f32>,

    /// Second MLP layer, `E x M`.
    pub mlp2_weight: Vec<f32>,
    /// Bias of the second MLP layer, length `E`.
    pub mlp2_bias: Vec<f32>,

    /// Scale of the layer norm applied before attention, length `E`.
    pub ln1_weight: Vec<f32>,
    /// Shift of the layer norm applied before attention, length `E`.
    pub ln1_bias: Vec<f32>,

    /// Scale of the layer norm applied before the MLP, length `E`.
    pub ln2_weight: Vec<f32>,
    /// Shift of the layer norm applied before the MLP, length `E`.
    pub ln2_bias: Vec<f32>,
}
107
108impl ViTBlockWeights {
109 pub fn default_init(cfg: &ViTBlockConfig, rng: &mut LcgRng) -> Self {
115 let e = cfg.embed_dim;
116 let mlp = cfg.mlp_dim();
117 let scale = 1.0 / (e as f32).sqrt();
118
119 let fill_scaled = |rng: &mut LcgRng, n: usize, sc: f32| -> Vec<f32> {
120 let mut v = vec![0.0f32; n];
121 rng.fill_normal(&mut v);
122 for x in &mut v {
123 *x *= sc;
124 }
125 v
126 };
127
128 let qkv_weight = fill_scaled(rng, 3 * e * e, scale);
129 let qkv_bias = vec![0.0f32; 3 * e];
130 let out_weight = fill_scaled(rng, e * e, scale);
131 let out_bias = vec![0.0f32; e];
132 let mlp1_weight = fill_scaled(rng, mlp * e, scale);
133 let mlp1_bias = vec![0.0f32; mlp];
134 let mlp2_weight = fill_scaled(rng, e * mlp, scale);
135 let mlp2_bias = vec![0.0f32; e];
136 let ln1_weight = vec![1.0f32; e];
137 let ln1_bias = vec![0.0f32; e];
138 let ln2_weight = vec![1.0f32; e];
139 let ln2_bias = vec![0.0f32; e];
140
141 Self {
142 qkv_weight,
143 qkv_bias,
144 out_weight,
145 out_bias,
146 mlp1_weight,
147 mlp1_bias,
148 mlp2_weight,
149 mlp2_bias,
150 ln1_weight,
151 ln1_bias,
152 ln2_weight,
153 ln2_bias,
154 }
155 }
156}
157
/// A single transformer encoder block: its shape configuration together with
/// the weights that `forward` reads.
pub struct ViTBlock {
    /// Shape parameters (embedding width, head count, MLP ratio).
    pub config: ViTBlockConfig,
    /// Parameter buffers used by the attention, MLP and layer-norm stages.
    pub weights: ViTBlockWeights,
}
165
166impl ViTBlock {
167 pub fn new(cfg: ViTBlockConfig, rng: &mut LcgRng) -> Self {
169 let weights = ViTBlockWeights::default_init(&cfg, rng);
170 Self {
171 config: cfg,
172 weights,
173 }
174 }
175
176 pub fn forward(&self, tokens: &[f32], n_tokens: usize) -> VisionResult<Vec<f32>> {
189 let e = self.config.embed_dim;
190 if tokens.len() != n_tokens * e {
191 return Err(VisionError::DimensionMismatch {
192 expected: n_tokens * e,
193 got: tokens.len(),
194 });
195 }
196 if n_tokens == 0 {
197 return Err(VisionError::EmptyInput("tokens"));
198 }
199
200 let w = &self.weights;
201 let cfg = &self.config;
202
203 let h = layer_norm(tokens, &w.ln1_weight, &w.ln1_bias, n_tokens, e, 1e-5);
205
206 let attn_out = mhsa(
208 &h,
209 n_tokens,
210 e,
211 cfg.n_heads,
212 cfg.head_dim(),
213 &w.qkv_weight,
214 &w.qkv_bias,
215 &w.out_weight,
216 &w.out_bias,
217 )?;
218
219 let mut h: Vec<f32> = attn_out
221 .iter()
222 .zip(tokens.iter())
223 .map(|(a, b)| a + b)
224 .collect();
225
226 let h2 = layer_norm(&h, &w.ln2_weight, &w.ln2_bias, n_tokens, e, 1e-5);
228
229 let mlp_dim = cfg.mlp_dim();
231 let mid = linear(&h2, &w.mlp1_weight, &w.mlp1_bias, e, mlp_dim);
232 let mid: Vec<f32> = mid.into_iter().map(gelu_exact).collect();
233 let mlp_out = linear(&mid, &w.mlp2_weight, &w.mlp2_bias, mlp_dim, e);
234
235 for (o, m) in h.iter_mut().zip(mlp_out.iter()) {
237 *o += m;
238 }
239
240 Ok(h)
241 }
242}
243
/// Applies layer normalisation independently to each of the `n` rows of `x`.
///
/// `x` is read as `n` consecutive rows of `d` values. Every row is shifted to
/// zero mean, divided by `sqrt(variance + eps)` (population variance, i.e.
/// divided by `d`), then scaled by `weight` and shifted by `bias`
/// element-wise.
///
/// Returns a new buffer of `n * d` values.
///
/// # Panics
///
/// Panics if `x` holds fewer than `n * d` values, or if `weight` or `bias`
/// hold fewer than `d` values while `n > 0`.
pub(crate) fn layer_norm(
    x: &[f32],
    weight: &[f32],
    bias: &[f32],
    n: usize,
    d: usize,
    eps: f32,
) -> Vec<f32> {
    let mut out = vec![0.0f32; n * d];
    for row_idx in 0..n {
        let span = row_idx * d..(row_idx + 1) * d;
        let src = &x[span.clone()];
        let dst = &mut out[span];
        // Slice inside the loop so that `n == 0` never touches the parameters.
        let gamma = &weight[..d];
        let beta = &bias[..d];

        let mean = src.iter().sum::<f32>() / d as f32;
        let var = src
            .iter()
            .map(|&v| {
                let centred = v - mean;
                centred * centred
            })
            .sum::<f32>()
            / d as f32;
        let inv_std = 1.0 / (var + eps).sqrt();

        for (((y, &v), &g), &b) in dst.iter_mut().zip(src).zip(gamma).zip(beta) {
            *y = (v - mean) * inv_std * g + b;
        }
    }
    out
}
275
/// Dense layer `y = W x + b` applied to every row of `x`.
///
/// `x` is read as `x.len() / n_in` rows of `n_in` values (any trailing
/// partial row is ignored). `w` is row-major with one row of `n_in` weights
/// per output feature, and `b` supplies one bias per output feature. The
/// result holds `n_out` values per input row.
///
/// # Panics
///
/// Panics if `n_in` is zero, or if `w` / `b` are too short for the requested
/// shape while there is at least one input row.
pub(crate) fn linear(x: &[f32], w: &[f32], b: &[f32], n_in: usize, n_out: usize) -> Vec<f32> {
    let rows = x.len() / n_in;
    let mut out = Vec::with_capacity(rows * n_out);
    for r in 0..rows {
        let input = &x[r * n_in..(r + 1) * n_in];
        for o in 0..n_out {
            let weights = &w[o * n_in..(o + 1) * n_in];
            // Start from the bias and accumulate left to right.
            let y = input
                .iter()
                .zip(weights)
                .fold(b[o], |acc, (&xi, &wi)| acc + xi * wi);
            out.push(y);
        }
    }
    out
}
299
/// GELU activation.
///
/// Despite the `exact` in its name, the body evaluates the tanh-based form
/// `0.5 * x * (1 + tanh(k * (x + c * x^3)))` with `k = 0.797_884_6` and
/// `c = 0.044_715`.
#[inline]
pub(crate) fn gelu_exact(x: f32) -> f32 {
    const SQRT_2_OVER_PI: f32 = 0.797_884_6;
    const COEFF: f32 = 0.044_715;

    let cubic_term = COEFF * x * x * x;
    let squashed = (SQRT_2_OVER_PI * (x + cubic_term)).tanh();
    let half_x = x * 0.5;
    half_x * (1.0 + squashed)
}
312
313#[allow(clippy::too_many_arguments)]
324pub(crate) fn mhsa(
325 tokens: &[f32],
326 n_tokens: usize,
327 embed_dim: usize,
328 n_heads: usize,
329 head_dim: usize,
330 qkv_weight: &[f32],
331 qkv_bias: &[f32],
332 out_weight: &[f32],
333 out_bias: &[f32],
334) -> VisionResult<Vec<f32>> {
335 let qkv = linear(tokens, qkv_weight, qkv_bias, embed_dim, 3 * embed_dim);
337
338 let mut q = vec![0.0f32; n_tokens * embed_dim];
340 let mut k = vec![0.0f32; n_tokens * embed_dim];
341 let mut v = vec![0.0f32; n_tokens * embed_dim];
342 for t in 0..n_tokens {
343 let src = &qkv[t * 3 * embed_dim..(t + 1) * 3 * embed_dim];
344 let qd = &mut q[t * embed_dim..(t + 1) * embed_dim];
345 let kd = &mut k[t * embed_dim..(t + 1) * embed_dim];
346 let vd = &mut v[t * embed_dim..(t + 1) * embed_dim];
347 qd.copy_from_slice(&src[..embed_dim]);
348 kd.copy_from_slice(&src[embed_dim..2 * embed_dim]);
349 vd.copy_from_slice(&src[2 * embed_dim..]);
350 }
351
352 let scale = 1.0 / (head_dim as f32).sqrt();
355 let mut concat = vec![0.0f32; n_tokens * embed_dim];
357
358 let mut scores = vec![0.0f32; n_tokens * n_tokens];
360
361 for h in 0..n_heads {
362 let hd_off = h * head_dim; for i in 0..n_tokens {
366 for j in 0..n_tokens {
367 let mut dot = 0.0f32;
368 for d in 0..head_dim {
369 dot += q[i * embed_dim + hd_off + d] * k[j * embed_dim + hd_off + d];
370 }
371 scores[i * n_tokens + j] = dot * scale;
372 }
373 }
374
375 softmax_rows(&mut scores, n_tokens, n_tokens);
377
378 for i in 0..n_tokens {
380 for d in 0..head_dim {
381 let mut acc = 0.0f32;
382 for j in 0..n_tokens {
383 acc += scores[i * n_tokens + j] * v[j * embed_dim + hd_off + d];
384 }
385 concat[i * embed_dim + hd_off + d] = acc;
386 }
387 }
388 }
389
390 let out = linear(&concat, out_weight, out_bias, embed_dim, embed_dim);
392
393 if out.iter().any(|v| !v.is_finite()) {
395 return Err(VisionError::NonFinite("mhsa output"));
396 }
397
398 Ok(out)
399}
400
/// In-place softmax over each of the first `n_rows` rows of `logits`, where a
/// row is `n_cols` consecutive values.
///
/// The row maximum is subtracted before exponentiating so that large logits
/// do not overflow. A row whose exponentials do not add up to a positive
/// number is left un-normalised.
///
/// # Panics
///
/// Panics if `logits` holds fewer than `n_rows * n_cols` values.
pub(crate) fn softmax_rows(logits: &mut [f32], n_rows: usize, n_cols: usize) {
    for r in 0..n_rows {
        let row = &mut logits[r * n_cols..(r + 1) * n_cols];
        let peak = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);

        let mut total = 0.0f32;
        row.iter_mut().for_each(|p| {
            let e = (*p - peak).exp();
            *p = e;
            total += e;
        });

        let norm = if total > 0.0 { 1.0 / total } else { 1.0 };
        row.iter_mut().for_each(|p| *p *= norm);
    }
}
420
// Unit tests for the block configuration, the numeric helpers and the full
// forward pass. Tests that need random data use fixed `LcgRng` seeds.
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixture: 64-wide embedding, 4 heads, MLP ratio 4.
    fn make_cfg() -> ViTBlockConfig {
        ViTBlockConfig::new(64, 4, 4).expect("valid config")
    }

    // --- configuration ---------------------------------------------------

    // Derived sizes: 64 / 4 heads = 16 per head, 4 * 64 = 256 hidden units.
    #[test]
    fn config_valid() {
        let cfg = make_cfg();
        assert_eq!(cfg.head_dim(), 16);
        assert_eq!(cfg.mlp_dim(), 256);
    }

    // A zero embedding width is rejected.
    #[test]
    fn config_invalid_embed_zero() {
        let r = ViTBlockConfig::new(0, 4, 4);
        assert!(matches!(r, Err(VisionError::InvalidEmbedDim(0))));
    }

    // A zero head count is rejected.
    #[test]
    fn config_invalid_heads_zero() {
        let r = ViTBlockConfig::new(64, 0, 4);
        assert!(matches!(r, Err(VisionError::InvalidNumHeads(0))));
    }

    // 64 is not divisible by 3, so the heads cannot split the embedding evenly.
    #[test]
    fn config_head_dim_mismatch() {
        let r = ViTBlockConfig::new(64, 3, 4);
        assert!(matches!(
            r,
            Err(VisionError::HeadDimMismatch {
                n_heads: 3,
                embed_dim: 64
            })
        ));
    }

    // --- layer norm ------------------------------------------------------

    // An all-zero row with unit scale and zero shift stays at zero.
    #[test]
    fn layer_norm_zero_input_with_identity_affine() {
        let x = vec![0.0f32; 8];
        let w = vec![1.0f32; 8];
        let b = vec![0.0f32; 8];
        let out = layer_norm(&x, &w, &b, 1, 8, 1e-5);
        assert!(
            out.iter().all(|&v| v.abs() < 1e-4),
            "expected near-zero: {out:?}"
        );
    }

    // A constant row has zero deviation from its mean, so it maps to zero.
    #[test]
    fn layer_norm_constant_row_normalises_to_zero() {
        let x = vec![5.0f32; 16];
        let w = vec![1.0f32; 16];
        let b = vec![0.0f32; 16];
        let out = layer_norm(&x, &w, &b, 1, 16, 1e-5);
        assert!(out.iter().all(|&v| v.abs() < 1e-4));
    }

    // Output has one value per input value (4 rows of 64).
    #[test]
    fn layer_norm_output_shape() {
        let x = vec![1.0f32; 4 * 64];
        let w = vec![1.0f32; 64];
        let b = vec![0.0f32; 64];
        let out = layer_norm(&x, &w, &b, 4, 64, 1e-5);
        assert_eq!(out.len(), 4 * 64);
    }

    // With identity scale/shift the output row has mean ~0 and variance ~1.
    #[test]
    fn layer_norm_standard_normal_output() {
        let mut rng = LcgRng::new(77);
        let mut x = vec![0.0f32; 128];
        rng.fill_normal(&mut x);
        let w = vec![1.0f32; 128];
        let b = vec![0.0f32; 128];
        let out = layer_norm(&x, &w, &b, 1, 128, 1e-5);
        let mean: f32 = out.iter().sum::<f32>() / 128.0;
        let var: f32 = out.iter().map(|&v| (v - mean) * (v - mean)).sum::<f32>() / 128.0;
        assert!(mean.abs() < 1e-4, "mean too large: {mean}");
        assert!((var - 1.0).abs() < 1e-3, "var not ~1: {var}");
    }

    // --- multi-head self-attention ---------------------------------------

    // Attention returns one embedding-sized row per input token.
    #[test]
    fn mhsa_output_shape() {
        let cfg = make_cfg();
        let e = cfg.embed_dim;
        let n_tokens = 17;
        let mut rng = LcgRng::new(1);
        let w = ViTBlockWeights::default_init(&cfg, &mut rng);
        let tokens = vec![0.1f32; n_tokens * e];
        let out = mhsa(
            &tokens,
            n_tokens,
            e,
            cfg.n_heads,
            cfg.head_dim(),
            &w.qkv_weight,
            &w.qkv_bias,
            &w.out_weight,
            &w.out_bias,
        )
        .expect("mhsa ok");
        assert_eq!(out.len(), n_tokens * e);
    }

    // Random inputs with default weights produce only finite values.
    // The tokens are drawn from the same generator after the weights.
    #[test]
    fn mhsa_output_finite() {
        let cfg = make_cfg();
        let e = cfg.embed_dim;
        let n_tokens = 10;
        let mut rng = LcgRng::new(2);
        let w = ViTBlockWeights::default_init(&cfg, &mut rng);
        let mut tokens = vec![0.0f32; n_tokens * e];
        rng.fill_normal(&mut tokens);
        let out = mhsa(
            &tokens,
            n_tokens,
            e,
            cfg.n_heads,
            cfg.head_dim(),
            &w.qkv_weight,
            &w.qkv_bias,
            &w.out_weight,
            &w.out_bias,
        )
        .expect("mhsa ok");
        assert!(
            out.iter().all(|v| v.is_finite()),
            "non-finite in mhsa output"
        );
    }

    // --- full block ------------------------------------------------------

    // The block preserves the token buffer shape.
    #[test]
    fn forward_output_shape() {
        let cfg = make_cfg();
        let e = cfg.embed_dim;
        let n_tokens = 17;
        let mut rng = LcgRng::new(3);
        let block = ViTBlock::new(cfg, &mut rng);
        let tokens = vec![0.0f32; n_tokens * e];
        let out = block.forward(&tokens, n_tokens).expect("forward ok");
        assert_eq!(out.len(), n_tokens * e);
    }

    // Random inputs through a default-initialised block stay finite.
    #[test]
    fn forward_output_finite() {
        let cfg = make_cfg();
        let e = cfg.embed_dim;
        let n_tokens = 17;
        let mut rng = LcgRng::new(4);
        let block = ViTBlock::new(cfg, &mut rng);
        let mut tokens = vec![0.0f32; n_tokens * e];
        rng.fill_normal(&mut tokens);
        let out = block.forward(&tokens, n_tokens).expect("forward ok");
        assert!(
            out.iter().all(|v| v.is_finite()),
            "non-finite in block output"
        );
    }

    // A buffer sized for 32-wide tokens does not match the 64-wide config.
    #[test]
    fn forward_dimension_mismatch_errors() {
        let cfg = make_cfg();
        let n_tokens = 5;
        let mut rng = LcgRng::new(5);
        let block = ViTBlock::new(cfg, &mut rng);
        let tokens = vec![0.0f32; n_tokens * 32];
        let r = block.forward(&tokens, n_tokens);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    // The block must not be a pure pass-through: the summed absolute
    // difference between output and input has to be non-zero.
    #[test]
    fn forward_residual_not_trivially_zero() {
        let cfg = make_cfg();
        let e = cfg.embed_dim;
        let n_tokens = 8;
        let mut rng = LcgRng::new(6);
        let block = ViTBlock::new(cfg, &mut rng);
        let mut tokens = vec![0.0f32; n_tokens * e];
        rng.fill_normal(&mut tokens);
        let out = block.forward(&tokens, n_tokens).expect("forward ok");
        let diff: f32 = out
            .iter()
            .zip(tokens.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(diff > 1e-6, "block did not change tokens (diff={diff})");
    }

    // --- GELU ------------------------------------------------------------

    // GELU(0) is 0.
    #[test]
    fn gelu_zero() {
        assert!((gelu_exact(0.0) - 0.0).abs() < 1e-6);
    }

    // For large positive inputs GELU is approximately the identity.
    #[test]
    fn gelu_large_positive_approx_identity() {
        let x = 10.0f32;
        assert!(
            (gelu_exact(x) - x).abs() < 1e-3,
            "GELU({x}) = {}",
            gelu_exact(x)
        );
    }

    // For large negative inputs GELU is approximately zero.
    #[test]
    fn gelu_large_negative_approx_zero() {
        let x = -10.0f32;
        assert!(gelu_exact(x).abs() < 1e-3, "GELU({x}) = {}", gelu_exact(x));
    }
}