1use crate::{
33 error::{VisionError, VisionResult},
34 handle::LcgRng,
35 vit::{
36 vit_block::{ViTBlock, ViTBlockConfig, layer_norm, linear},
37 vit_encoder::ViTEncoderConfig,
38 },
39};
40
/// Hyper-parameters of a masked autoencoder (MAE): image/patch geometry plus
/// the sizes of the encoder and decoder transformer stacks.
///
/// Prefer [`MaeConfig::new`], which validates the values. All fields are
/// public, so a struct literal or a later field assignment bypasses that
/// validation.
#[derive(Debug, Clone, PartialEq)]
pub struct MaeConfig {
    /// Side length of the input image; must be a non-zero multiple of `patch_size`.
    pub img_size: usize,
    /// Side length of one patch.
    pub patch_size: usize,
    /// Number of image channels.
    pub in_channels: usize,
    /// Token width inside the encoder.
    pub encoder_dim: usize,
    /// Number of transformer blocks in the encoder.
    pub encoder_depth: usize,
    /// Attention heads per encoder block.
    pub encoder_heads: usize,
    /// Token width inside the decoder.
    pub decoder_dim: usize,
    /// Number of transformer blocks in the decoder.
    pub decoder_depth: usize,
    /// Attention heads per decoder block.
    pub decoder_heads: usize,
    /// MLP ratio handed to the transformer block configuration
    /// (shared by encoder and decoder).
    pub mlp_ratio: usize,
    /// Fraction of patches hidden from the encoder, in `[0, 1]`.
    pub mask_ratio: f32,
}
69
70impl MaeConfig {
71 #[allow(clippy::too_many_arguments)]
80 pub fn new(
81 img_size: usize,
82 patch_size: usize,
83 in_channels: usize,
84 encoder_dim: usize,
85 encoder_depth: usize,
86 encoder_heads: usize,
87 decoder_dim: usize,
88 decoder_depth: usize,
89 decoder_heads: usize,
90 mlp_ratio: usize,
91 mask_ratio: f32,
92 ) -> VisionResult<Self> {
93 if patch_size == 0 || img_size == 0 || img_size % patch_size != 0 {
94 return Err(VisionError::InvalidPatchSize {
95 patch_size,
96 img_size,
97 });
98 }
99 if in_channels == 0 {
100 return Err(VisionError::EmptyInput("in_channels"));
101 }
102 if encoder_dim == 0 {
103 return Err(VisionError::InvalidEmbedDim(encoder_dim));
104 }
105 if decoder_dim == 0 {
106 return Err(VisionError::InvalidEmbedDim(decoder_dim));
107 }
108 if encoder_depth == 0 {
109 return Err(VisionError::Internal("encoder_depth must be > 0".into()));
110 }
111 if decoder_depth == 0 {
112 return Err(VisionError::Internal("decoder_depth must be > 0".into()));
113 }
114 if !(0.0..=1.0).contains(&mask_ratio) || !mask_ratio.is_finite() {
115 return Err(VisionError::Internal(format!(
116 "mask_ratio {mask_ratio} not in [0, 1]"
117 )));
118 }
119 let _ = ViTBlockConfig::new(encoder_dim, encoder_heads, mlp_ratio)?;
121 let _ = ViTBlockConfig::new(decoder_dim, decoder_heads, mlp_ratio)?;
122 Ok(Self {
123 img_size,
124 patch_size,
125 in_channels,
126 encoder_dim,
127 encoder_depth,
128 encoder_heads,
129 decoder_dim,
130 decoder_depth,
131 decoder_heads,
132 mlp_ratio,
133 mask_ratio,
134 })
135 }
136
137 #[must_use]
139 pub fn n_patches(&self) -> usize {
140 let grid = self.img_size / self.patch_size;
141 grid * grid
142 }
143
144 #[must_use]
146 pub fn patch_pixels(&self) -> usize {
147 self.patch_size * self.patch_size * self.in_channels
148 }
149}
150
/// Which patch indices the encoder sees and which are hidden from it.
///
/// [`generate_random_mask`] returns both lists sorted in ascending order and
/// together covering every index in `0..n_patches` exactly once. The fields
/// are public, so hand-built values carry no such guarantee.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MaskMeta {
    /// Indices of the patches that stay visible to the encoder.
    pub visible_ids: Vec<usize>,
    /// Indices of the patches that are masked out.
    pub masked_ids: Vec<usize>,
}
164
165pub fn generate_random_mask(
180 n_patches: usize,
181 mask_ratio: f32,
182 rng: &mut LcgRng,
183) -> VisionResult<MaskMeta> {
184 if n_patches == 0 {
185 return Err(VisionError::EmptyInput("n_patches"));
186 }
187 if !(0.0..=1.0).contains(&mask_ratio) || !mask_ratio.is_finite() {
188 return Err(VisionError::Internal(format!(
189 "mask_ratio {mask_ratio} not in [0, 1]"
190 )));
191 }
192 let mut ids: Vec<usize> = (0..n_patches).collect();
193 let n_masked = (mask_ratio * (n_patches as f32)).round() as usize;
194 let n_masked = n_masked.min(n_patches);
195
196 for i in 0..n_masked {
200 let remaining = n_patches - i;
201 let j = i + rng.next_usize(remaining);
203 ids.swap(i, j);
204 }
205
206 let mut masked_ids: Vec<usize> = ids[..n_masked].to_vec();
207 let mut visible_ids: Vec<usize> = ids[n_masked..].to_vec();
208 masked_ids.sort_unstable();
209 visible_ids.sort_unstable();
210 Ok(MaskMeta {
211 visible_ids,
212 masked_ids,
213 })
214}
215
216pub struct Mae {
223 pub config: MaeConfig,
225 pub patch_embed_weights: Vec<f32>,
227 pub patch_embed_bias: Vec<f32>,
229 pub encoder_pos_embed: Vec<f32>,
231 pub encoder_blocks: Vec<ViTBlock>,
233 pub encoder_norm_gamma: Vec<f32>,
235 pub encoder_norm_beta: Vec<f32>,
237 pub decoder_embed_weights: Vec<f32>,
239 pub decoder_embed_bias: Vec<f32>,
241 pub mask_token: Vec<f32>,
243 pub decoder_pos_embed: Vec<f32>,
245 pub decoder_blocks: Vec<ViTBlock>,
247 pub decoder_norm_gamma: Vec<f32>,
249 pub decoder_norm_beta: Vec<f32>,
251 pub decoder_pred_weights: Vec<f32>,
253 pub decoder_pred_bias: Vec<f32>,
255}
256
257#[inline]
262fn safe_centered_uniform(rng: &mut LcgRng) -> f32 {
263 (rng.next_u32() as f32) / 4_294_967_296.0 - 0.5
264}
265
266fn fill_centered_uniform(buf: &mut [f32], scale: f32, rng: &mut LcgRng) {
269 for v in buf.iter_mut() {
270 *v = safe_centered_uniform(rng) * 2.0 * scale;
271 }
272}
273
274impl Mae {
275 pub fn new(cfg: MaeConfig, rng: &mut LcgRng) -> VisionResult<Self> {
288 let n_patches = cfg.n_patches();
289 let pp = cfg.patch_pixels();
290 let edim = cfg.encoder_dim;
291 let ddim = cfg.decoder_dim;
292
293 let enc_scale = 1.0 / (pp as f32).sqrt();
294 let mut patch_embed_weights = vec![0.0f32; edim * pp];
295 fill_centered_uniform(&mut patch_embed_weights, enc_scale, rng);
296 let patch_embed_bias = vec![0.0f32; edim];
297
298 let pos_scale = 0.02f32; let mut encoder_pos_embed = vec![0.0f32; n_patches * edim];
300 fill_centered_uniform(&mut encoder_pos_embed, pos_scale, rng);
301
302 let enc_block_cfg =
304 ViTEncoderConfig::new(edim, cfg.encoder_heads, cfg.mlp_ratio, cfg.encoder_depth)?;
305 let mut encoder_blocks = Vec::with_capacity(cfg.encoder_depth);
306 for _ in 0..cfg.encoder_depth {
307 encoder_blocks.push(ViTBlock::new(enc_block_cfg.block_cfg.clone(), rng));
308 }
309 let encoder_norm_gamma = vec![1.0f32; edim];
310 let encoder_norm_beta = vec![0.0f32; edim];
311
312 let dec_in_scale = 1.0 / (edim as f32).sqrt();
313 let mut decoder_embed_weights = vec![0.0f32; ddim * edim];
314 fill_centered_uniform(&mut decoder_embed_weights, dec_in_scale, rng);
315 let decoder_embed_bias = vec![0.0f32; ddim];
316
317 let mut mask_token = vec![0.0f32; ddim];
318 fill_centered_uniform(&mut mask_token, pos_scale, rng);
319
320 let mut decoder_pos_embed = vec![0.0f32; n_patches * ddim];
321 fill_centered_uniform(&mut decoder_pos_embed, pos_scale, rng);
322
323 let dec_block_cfg =
324 ViTEncoderConfig::new(ddim, cfg.decoder_heads, cfg.mlp_ratio, cfg.decoder_depth)?;
325 let mut decoder_blocks = Vec::with_capacity(cfg.decoder_depth);
326 for _ in 0..cfg.decoder_depth {
327 decoder_blocks.push(ViTBlock::new(dec_block_cfg.block_cfg.clone(), rng));
328 }
329 let decoder_norm_gamma = vec![1.0f32; ddim];
330 let decoder_norm_beta = vec![0.0f32; ddim];
331
332 let pred_scale = 1.0 / (ddim as f32).sqrt();
333 let mut decoder_pred_weights = vec![0.0f32; pp * ddim];
334 fill_centered_uniform(&mut decoder_pred_weights, pred_scale, rng);
335 let decoder_pred_bias = vec![0.0f32; pp];
336
337 Ok(Self {
338 config: cfg,
339 patch_embed_weights,
340 patch_embed_bias,
341 encoder_pos_embed,
342 encoder_blocks,
343 encoder_norm_gamma,
344 encoder_norm_beta,
345 decoder_embed_weights,
346 decoder_embed_bias,
347 mask_token,
348 decoder_pos_embed,
349 decoder_blocks,
350 decoder_norm_gamma,
351 decoder_norm_beta,
352 decoder_pred_weights,
353 decoder_pred_bias,
354 })
355 }
356
357 pub fn encode(
367 &self,
368 image_patches: &[f32],
369 rng: &mut LcgRng,
370 ) -> VisionResult<(Vec<f32>, MaskMeta)> {
371 let n_patches = self.config.n_patches();
372 let pp = self.config.patch_pixels();
373 let edim = self.config.encoder_dim;
374
375 if n_patches == 0 {
376 return Err(VisionError::EmptyInput("n_patches"));
377 }
378 let expected = n_patches * pp;
379 if image_patches.len() != expected {
380 return Err(VisionError::DimensionMismatch {
381 expected,
382 got: image_patches.len(),
383 });
384 }
385
386 let mut embedded = linear(
388 image_patches,
389 &self.patch_embed_weights,
390 &self.patch_embed_bias,
391 pp,
392 edim,
393 );
394
395 for (i, v) in embedded.iter_mut().enumerate() {
397 *v += self
398 .encoder_pos_embed
399 .get(i)
400 .copied()
401 .ok_or(VisionError::Internal(
402 "encoder_pos_embed shorter than embedded".into(),
403 ))?;
404 }
405
406 let mask_meta = generate_random_mask(n_patches, self.config.mask_ratio, rng)?;
408 let n_visible = mask_meta.visible_ids.len();
409
410 let mut visible_tokens = vec![0.0f32; n_visible * edim];
411 for (out_i, &src_i) in mask_meta.visible_ids.iter().enumerate() {
412 let src = embedded
413 .get(src_i * edim..(src_i + 1) * edim)
414 .ok_or(VisionError::Internal("visible idx out of range".into()))?;
415 let dst = visible_tokens
416 .get_mut(out_i * edim..(out_i + 1) * edim)
417 .ok_or(VisionError::Internal(
418 "visible_tokens slice out of range".into(),
419 ))?;
420 dst.copy_from_slice(src);
421 }
422
423 let encoded = if n_visible == 0 {
427 Vec::new()
428 } else {
429 let mut h = visible_tokens;
430 for block in &self.encoder_blocks {
431 h = block.forward(&h, n_visible)?;
432 }
433 layer_norm(
434 &h,
435 &self.encoder_norm_gamma,
436 &self.encoder_norm_beta,
437 n_visible,
438 edim,
439 1e-5,
440 )
441 };
442
443 Ok((encoded, mask_meta))
444 }
445
446 pub fn decode(&self, encoded_visible: &[f32], mask_meta: &MaskMeta) -> VisionResult<Vec<f32>> {
455 let n_patches = self.config.n_patches();
456 let edim = self.config.encoder_dim;
457 let ddim = self.config.decoder_dim;
458 let pp = self.config.patch_pixels();
459 let n_visible = mask_meta.visible_ids.len();
460 let n_masked = mask_meta.masked_ids.len();
461
462 if n_visible + n_masked != n_patches {
463 return Err(VisionError::Internal(
464 "MaskMeta visible + masked sizes do not sum to n_patches".into(),
465 ));
466 }
467 if encoded_visible.len() != n_visible * edim {
468 return Err(VisionError::DimensionMismatch {
469 expected: n_visible * edim,
470 got: encoded_visible.len(),
471 });
472 }
473
474 let visible_dec = if n_visible == 0 {
476 Vec::new()
477 } else {
478 linear(
479 encoded_visible,
480 &self.decoder_embed_weights,
481 &self.decoder_embed_bias,
482 edim,
483 ddim,
484 )
485 };
486
487 let mut full = vec![0.0f32; n_patches * ddim];
490 for (vis_i, &dst_i) in mask_meta.visible_ids.iter().enumerate() {
491 let src = visible_dec
492 .get(vis_i * ddim..(vis_i + 1) * ddim)
493 .ok_or(VisionError::Internal("visible_dec slice".into()))?;
494 let dst = full
495 .get_mut(dst_i * ddim..(dst_i + 1) * ddim)
496 .ok_or(VisionError::Internal("full slice (visible)".into()))?;
497 dst.copy_from_slice(src);
498 }
499 for &dst_i in &mask_meta.masked_ids {
500 let dst = full
501 .get_mut(dst_i * ddim..(dst_i + 1) * ddim)
502 .ok_or(VisionError::Internal("full slice (masked)".into()))?;
503 dst.copy_from_slice(&self.mask_token);
504 }
505
506 for (i, v) in full.iter_mut().enumerate() {
508 *v += self
509 .decoder_pos_embed
510 .get(i)
511 .copied()
512 .ok_or(VisionError::Internal("decoder_pos_embed".into()))?;
513 }
514
515 let mut h = full;
517 for block in &self.decoder_blocks {
518 h = block.forward(&h, n_patches)?;
519 }
520 let post_norm = layer_norm(
521 &h,
522 &self.decoder_norm_gamma,
523 &self.decoder_norm_beta,
524 n_patches,
525 ddim,
526 1e-5,
527 );
528
529 let reconstructed = linear(
531 &post_norm,
532 &self.decoder_pred_weights,
533 &self.decoder_pred_bias,
534 ddim,
535 pp,
536 );
537 Ok(reconstructed)
538 }
539}
540
541pub fn mae_loss(
555 reconstructed: &[f32],
556 ground_truth_patches: &[f32],
557 mask_meta: &MaskMeta,
558) -> VisionResult<f32> {
559 if reconstructed.len() != ground_truth_patches.len() {
560 return Err(VisionError::DimensionMismatch {
561 expected: reconstructed.len(),
562 got: ground_truth_patches.len(),
563 });
564 }
565 let n_patches = mask_meta.visible_ids.len() + mask_meta.masked_ids.len();
566 if n_patches == 0 {
567 return Err(VisionError::EmptyInput("mask_meta n_patches"));
568 }
569 if reconstructed.len() % n_patches != 0 {
570 return Err(VisionError::DimensionMismatch {
571 expected: n_patches,
572 got: reconstructed.len(),
573 });
574 }
575 let pp = reconstructed.len() / n_patches;
576 if mask_meta.masked_ids.is_empty() {
577 return Ok(0.0);
578 }
579 let mut sum_sq = 0.0f64;
580 let mut count: u64 = 0;
581 for &mi in &mask_meta.masked_ids {
582 let r = reconstructed
583 .get(mi * pp..(mi + 1) * pp)
584 .ok_or(VisionError::Internal("loss: masked idx".into()))?;
585 let g = ground_truth_patches
586 .get(mi * pp..(mi + 1) * pp)
587 .ok_or(VisionError::Internal("loss: masked idx (gt)".into()))?;
588 for (rv, gv) in r.iter().zip(g.iter()) {
589 let d = (*rv - *gv) as f64;
590 sum_sq += d * d;
591 count += 1;
592 }
593 }
594 let mean = if count == 0 {
595 0.0
596 } else {
597 sum_sq / (count as f64)
598 };
599 Ok(mean as f32)
600}
601
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// 8x8 image with 4x4 patches (4 patches), mask ratio 0.5.
    fn make_tiny_cfg() -> MaeConfig {
        MaeConfig::new(8, 4, 3, 16, 2, 4, 8, 1, 4, 2, 0.5).expect("valid tiny cfg")
    }

    /// 16x16 image with 4x4 patches (16 patches), mask ratio 0.75.
    fn make_medium_cfg() -> MaeConfig {
        MaeConfig::new(16, 4, 3, 32, 2, 4, 16, 1, 4, 2, 0.75).expect("valid med cfg")
    }

    // ----- generate_random_mask -----

    #[test]
    fn mask_union_and_disjoint() {
        let mut rng = LcgRng::new(1);
        let n = 16;
        let m = generate_random_mask(n, 0.5, &mut rng).expect("ok");
        let v: HashSet<usize> = m.visible_ids.iter().copied().collect();
        let k: HashSet<usize> = m.masked_ids.iter().copied().collect();
        assert!(v.is_disjoint(&k));
        let union: HashSet<usize> = v.union(&k).copied().collect();
        let expected: HashSet<usize> = (0..n).collect();
        assert_eq!(union, expected);
        assert_eq!(v.len() + k.len(), n);
    }

    #[test]
    fn mask_count_matches_round() {
        let mut rng = LcgRng::new(2);
        let m = generate_random_mask(100, 0.75, &mut rng).expect("ok");
        assert_eq!(m.masked_ids.len(), 75);
        assert_eq!(m.visible_ids.len(), 25);
    }

    #[test]
    fn mask_count_rounds_correctly() {
        let mut rng = LcgRng::new(3);
        let m = generate_random_mask(10, 0.7, &mut rng).expect("ok");
        assert_eq!(m.masked_ids.len(), 7);
    }

    #[test]
    fn mask_ratio_zero_all_visible() {
        let mut rng = LcgRng::new(4);
        let m = generate_random_mask(8, 0.0, &mut rng).expect("ok");
        assert_eq!(m.masked_ids.len(), 0);
        assert_eq!(m.visible_ids.len(), 8);
        assert_eq!(m.visible_ids, (0..8).collect::<Vec<_>>());
    }

    #[test]
    fn mask_ratio_one_all_masked() {
        let mut rng = LcgRng::new(5);
        let m = generate_random_mask(8, 1.0, &mut rng).expect("ok");
        assert_eq!(m.masked_ids.len(), 8);
        assert_eq!(m.visible_ids.len(), 0);
        assert_eq!(m.masked_ids, (0..8).collect::<Vec<_>>());
    }

    #[test]
    fn mask_deterministic_same_seed() {
        let mut a = LcgRng::new(42);
        let mut b = LcgRng::new(42);
        let ma = generate_random_mask(64, 0.75, &mut a).expect("ok");
        let mb = generate_random_mask(64, 0.75, &mut b).expect("ok");
        assert_eq!(ma, mb);
    }

    #[test]
    fn mask_sorted_ascending() {
        let mut rng = LcgRng::new(6);
        let m = generate_random_mask(50, 0.6, &mut rng).expect("ok");
        for w in m.visible_ids.windows(2) {
            assert!(w[0] < w[1]);
        }
        for w in m.masked_ids.windows(2) {
            assert!(w[0] < w[1]);
        }
    }

    #[test]
    fn mask_invalid_ratio_errors() {
        let mut rng = LcgRng::new(7);
        assert!(generate_random_mask(8, -0.1, &mut rng).is_err());
        assert!(generate_random_mask(8, 1.5, &mut rng).is_err());
        assert!(generate_random_mask(8, f32::NAN, &mut rng).is_err());
    }

    #[test]
    fn mask_n_patches_zero_errors() {
        let mut rng = LcgRng::new(8);
        let r = generate_random_mask(0, 0.5, &mut rng);
        assert!(matches!(r, Err(VisionError::EmptyInput(_))));
    }

    // ----- MaeConfig validation -----

    #[test]
    fn cfg_patch_not_divisible_errors() {
        let r = MaeConfig::new(7, 4, 3, 16, 1, 4, 8, 1, 4, 2, 0.5);
        assert!(matches!(r, Err(VisionError::InvalidPatchSize { .. })));
    }

    #[test]
    fn cfg_zero_channels_errors() {
        let r = MaeConfig::new(8, 4, 0, 16, 1, 4, 8, 1, 4, 2, 0.5);
        assert!(r.is_err());
    }

    #[test]
    fn cfg_zero_encoder_dim_errors() {
        let r = MaeConfig::new(8, 4, 3, 0, 1, 4, 8, 1, 4, 2, 0.5);
        assert!(matches!(r, Err(VisionError::InvalidEmbedDim(0))));
    }

    #[test]
    fn cfg_zero_decoder_dim_errors() {
        let r = MaeConfig::new(8, 4, 3, 16, 1, 4, 0, 1, 4, 2, 0.5);
        assert!(matches!(r, Err(VisionError::InvalidEmbedDim(0))));
    }

    #[test]
    fn cfg_zero_depth_errors() {
        let r1 = MaeConfig::new(8, 4, 3, 16, 0, 4, 8, 1, 4, 2, 0.5);
        let r2 = MaeConfig::new(8, 4, 3, 16, 1, 4, 8, 0, 4, 2, 0.5);
        assert!(r1.is_err());
        assert!(r2.is_err());
    }

    #[test]
    fn cfg_mask_ratio_out_of_range_errors() {
        let r1 = MaeConfig::new(8, 4, 3, 16, 1, 4, 8, 1, 4, 2, -0.1);
        let r2 = MaeConfig::new(8, 4, 3, 16, 1, 4, 8, 1, 4, 2, 1.5);
        assert!(r1.is_err());
        assert!(r2.is_err());
    }

    #[test]
    fn cfg_n_patches_and_pixels() {
        let cfg = make_tiny_cfg();
        assert_eq!(cfg.n_patches(), 4);
        assert_eq!(cfg.patch_pixels(), 4 * 4 * 3);
    }

    // ----- encode / decode shapes and determinism -----

    #[test]
    fn encode_shape() {
        let cfg = make_medium_cfg();
        let mut rng = LcgRng::new(11);
        let mae = Mae::new(cfg.clone(), &mut rng).expect("ok");
        let n_patches = cfg.n_patches();
        let pp = cfg.patch_pixels();
        let edim = cfg.encoder_dim;
        let patches = vec![0.1f32; n_patches * pp];
        let mut rng2 = LcgRng::new(99);
        let (enc, mask) = mae.encode(&patches, &mut rng2).expect("ok");
        assert_eq!(enc.len(), mask.visible_ids.len() * edim);
    }

    #[test]
    fn decode_shape_matches_patches() {
        let cfg = make_medium_cfg();
        let mut rng = LcgRng::new(13);
        let mae = Mae::new(cfg.clone(), &mut rng).expect("ok");
        let n_patches = cfg.n_patches();
        let pp = cfg.patch_pixels();
        let patches = vec![0.1f32; n_patches * pp];
        let mut rng2 = LcgRng::new(101);
        let (enc, mask) = mae.encode(&patches, &mut rng2).expect("ok");
        let recon = mae.decode(&enc, &mask).expect("ok");
        assert_eq!(recon.len(), n_patches * pp);
    }

    #[test]
    fn full_pipeline_deterministic_same_seed() {
        let cfg = make_medium_cfg();
        let mut rng_a = LcgRng::new(33);
        let mae_a = Mae::new(cfg.clone(), &mut rng_a).expect("ok");
        let mut rng_b = LcgRng::new(33);
        let mae_b = Mae::new(cfg.clone(), &mut rng_b).expect("ok");

        let n_patches = cfg.n_patches();
        let pp = cfg.patch_pixels();
        let mut patches = vec![0.0f32; n_patches * pp];
        let mut rin = LcgRng::new(5);
        for v in patches.iter_mut() {
            *v = (rin.next_u32() as f32) / 4_294_967_296.0;
        }

        let mut r_a = LcgRng::new(77);
        let mut r_b = LcgRng::new(77);
        let (ea, ma) = mae_a.encode(&patches, &mut r_a).expect("ok");
        let (eb, mb) = mae_b.encode(&patches, &mut r_b).expect("ok");
        assert_eq!(ma, mb);
        for (a, b) in ea.iter().zip(eb.iter()) {
            assert!((a - b).abs() < 1e-6, "encode differs: {a} vs {b}");
        }
        let recon_a = mae_a.decode(&ea, &ma).expect("ok");
        let recon_b = mae_b.decode(&eb, &mb).expect("ok");
        for (a, b) in recon_a.iter().zip(recon_b.iter()) {
            assert!((a - b).abs() < 1e-6, "decode differs: {a} vs {b}");
        }
    }

    // ----- encode / decode error paths -----

    #[test]
    fn encode_dimension_mismatch_errors() {
        let cfg = make_tiny_cfg();
        let mut rng = LcgRng::new(15);
        let mae = Mae::new(cfg.clone(), &mut rng).expect("ok");
        let pp = cfg.patch_pixels();
        // One patch short of what the config requires.
        let patches = vec![0.0f32; (cfg.n_patches() - 1) * pp];
        let mut rng2 = LcgRng::new(16);
        let r = mae.encode(&patches, &mut rng2);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn decode_wrong_visible_length_errors() {
        let cfg = make_tiny_cfg();
        let mut rng = LcgRng::new(17);
        let mae = Mae::new(cfg.clone(), &mut rng).expect("ok");
        let mut rng_m = LcgRng::new(18);
        let mask = generate_random_mask(cfg.n_patches(), 0.5, &mut rng_m).expect("ok");
        let wrong = vec![0.0f32; mask.visible_ids.len() * cfg.encoder_dim - 1];
        let r = mae.decode(&wrong, &mask);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    // ----- mae_loss -----

    #[test]
    fn loss_zero_when_match_on_masked_positions() {
        let mask = MaskMeta {
            visible_ids: vec![0, 2],
            masked_ids: vec![1, 3],
        };
        let pp = 5;
        let n_patches = 4;
        let mut gt = vec![0.0f32; n_patches * pp];
        let mut recon = vec![0.0f32; n_patches * pp];
        for (i, g) in gt.iter_mut().enumerate() {
            *g = i as f32;
        }
        // Masked patches reproduce the ground truth exactly.
        for &mi in &mask.masked_ids {
            for k in 0..pp {
                recon[mi * pp + k] = gt[mi * pp + k];
            }
        }
        // Visible patches are wildly off; they must not affect the loss.
        for k in 0..pp {
            recon[k] = 999.0;
            recon[2 * pp + k] = -777.0;
        }
        let loss = mae_loss(&recon, &gt, &mask).expect("ok");
        assert!(
            loss.abs() < 1e-6,
            "loss should be 0 when masked match: {loss}"
        );
    }

    #[test]
    fn loss_independent_of_visible_positions() {
        let mask = MaskMeta {
            visible_ids: vec![0, 2],
            masked_ids: vec![1, 3],
        };
        let pp = 3;
        let n_patches = 4;
        let mut gt = vec![0.0f32; n_patches * pp];
        let mut recon_a = vec![0.0f32; n_patches * pp];
        let mut recon_b = vec![0.0f32; n_patches * pp];
        for i in 0..n_patches * pp {
            gt[i] = (i as f32) * 0.1;
            recon_a[i] = gt[i] + 0.5;
            recon_b[i] = gt[i] + 0.5;
        }
        // Only the visible patches of `recon_b` are perturbed.
        for &vi in &mask.visible_ids {
            for k in 0..pp {
                recon_b[vi * pp + k] = 1234.0;
            }
        }
        let la = mae_loss(&recon_a, &gt, &mask).expect("ok");
        let lb = mae_loss(&recon_b, &gt, &mask).expect("ok");
        assert!(
            (la - lb).abs() < 1e-6,
            "loss depends on visible: {la} vs {lb}"
        );
    }

    #[test]
    fn loss_dimension_mismatch_errors() {
        let mask = MaskMeta {
            visible_ids: vec![0],
            masked_ids: vec![1],
        };
        let r = mae_loss(&[0.0; 4], &[0.0; 5], &mask);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn loss_mask_ratio_zero_returns_zero() {
        let mut rng = LcgRng::new(21);
        let m = generate_random_mask(6, 0.0, &mut rng).expect("ok");
        let r = vec![1.0f32; 6 * 4];
        let g = vec![2.0f32; 6 * 4];
        let l = mae_loss(&r, &g, &m).expect("ok");
        assert!(l.abs() < 1e-6, "no masked → loss = 0; got {l}");
    }

    #[test]
    fn loss_positive_when_recon_off() {
        let mask = MaskMeta {
            visible_ids: vec![0],
            masked_ids: vec![1],
        };
        let pp = 4;
        let gt = vec![0.0f32; 2 * pp];
        let mut recon = vec![0.0f32; 2 * pp];
        for k in 0..pp {
            recon[pp + k] = 1.0;
        }
        let l = mae_loss(&recon, &gt, &mask).expect("ok");
        assert!((l - 1.0).abs() < 1e-6, "expected MSE=1, got {l}");
    }

    // ----- numerical sanity and end-to-end behaviour -----

    #[test]
    fn encode_decode_finite() {
        let cfg = make_medium_cfg();
        let mut rng = LcgRng::new(45);
        let mae = Mae::new(cfg.clone(), &mut rng).expect("ok");
        let n_patches = cfg.n_patches();
        let pp = cfg.patch_pixels();
        let mut patches = vec![0.0f32; n_patches * pp];
        let mut rin = LcgRng::new(55);
        rin.fill_normal(&mut patches);
        let mut r2 = LcgRng::new(66);
        let (enc, mask) = mae.encode(&patches, &mut r2).expect("ok");
        assert!(enc.iter().all(|v| v.is_finite()));
        let recon = mae.decode(&enc, &mask).expect("ok");
        assert!(recon.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn identity_decoder_reconstructs_mask_token_at_masked() {
        let cfg = MaeConfig::new(2, 2, 1, 4, 1, 1, 4, 1, 1, 1, 0.5).expect("ok");
        let mut rng = LcgRng::new(123);
        let mut mae = Mae::new(cfg.clone(), &mut rng).expect("ok");

        // Zero every decoder-block parameter so the blocks contribute nothing
        // beyond passing their input through.
        for block in mae.decoder_blocks.iter_mut() {
            block.weights.qkv_weight.iter_mut().for_each(|v| *v = 0.0);
            block.weights.qkv_bias.iter_mut().for_each(|v| *v = 0.0);
            block.weights.out_weight.iter_mut().for_each(|v| *v = 0.0);
            block.weights.out_bias.iter_mut().for_each(|v| *v = 0.0);
            block.weights.mlp1_weight.iter_mut().for_each(|v| *v = 0.0);
            block.weights.mlp1_bias.iter_mut().for_each(|v| *v = 0.0);
            block.weights.mlp2_weight.iter_mut().for_each(|v| *v = 0.0);
            block.weights.mlp2_bias.iter_mut().for_each(|v| *v = 0.0);
        }
        // No positional offset, and an identity prediction head (4 -> 4).
        for v in mae.decoder_pos_embed.iter_mut() {
            *v = 0.0;
        }
        for v in mae.decoder_pred_weights.iter_mut() {
            *v = 0.0;
        }
        for i in 0..4 {
            mae.decoder_pred_weights[i * 4 + i] = 1.0;
        }
        for v in mae.decoder_pred_bias.iter_mut() {
            *v = 0.0;
        }
        mae.mask_token = vec![0.1, -0.2, 0.3, -0.4];

        let n_patches = cfg.n_patches();
        let pp = cfg.patch_pixels();

        // Two different images, masked identically.
        let patches_a = vec![1.0f32; n_patches * pp];
        let patches_b = vec![7.7f32; n_patches * pp];

        let mut r_a = LcgRng::new(2024);
        let mut r_b = LcgRng::new(2024);
        let (enc_a, ma) = mae.encode(&patches_a, &mut r_a).expect("ok");
        let (enc_b, mb) = mae.encode(&patches_b, &mut r_b).expect("ok");
        assert_eq!(ma, mb, "same RNG seed must produce same mask");

        let recon_a = mae.decode(&enc_a, &ma).expect("ok");
        let recon_b = mae.decode(&enc_b, &mb).expect("ok");

        // Expected output at a masked slot: the layer-normalised mask token.
        let mean = (0.1f32 + (-0.2) + 0.3 + (-0.4)) / 4.0;
        let centred = [0.1f32 - mean, -0.2 - mean, 0.3 - mean, -0.4 - mean];
        let var = centred.iter().map(|c| c * c).sum::<f32>() / 4.0;
        let inv_std = 1.0 / (var + 1e-5).sqrt();
        let expected_at_mask: Vec<f32> = centred.iter().map(|c| c * inv_std).collect();

        for &mi in &ma.masked_ids {
            for k in 0..pp {
                let a = recon_a[mi * pp + k];
                let b = recon_b[mi * pp + k];
                assert!(
                    (a - b).abs() < 1e-5,
                    "masked pos {mi} k={k}: a={a} b={b} (depends on visible!)"
                );
                let exp = expected_at_mask[k];
                assert!(
                    (a - exp).abs() < 1e-4,
                    "masked pos {mi} k={k}: got {a} expected {exp}"
                );
            }
        }
    }

    #[test]
    fn mask_full_ratio_encoder_skipped() {
        let cfg = MaeConfig::new(4, 2, 1, 4, 1, 1, 4, 1, 1, 1, 1.0).expect("ok");
        let mut rng = LcgRng::new(31);
        let mae = Mae::new(cfg.clone(), &mut rng).expect("ok");
        let pp = cfg.patch_pixels();
        let n = cfg.n_patches();
        let patches = vec![0.0f32; n * pp];
        let mut r2 = LcgRng::new(32);
        let (enc, mask) = mae.encode(&patches, &mut r2).expect("ok");
        assert_eq!(enc.len(), 0);
        assert_eq!(mask.masked_ids.len(), n);
        let recon = mae.decode(&enc, &mask).expect("ok");
        assert_eq!(recon.len(), n * pp);
        assert!(recon.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn mask_zero_ratio_full_encoder() {
        let cfg = MaeConfig::new(4, 2, 1, 4, 1, 1, 4, 1, 1, 1, 0.0).expect("ok");
        let mut rng = LcgRng::new(41);
        let mae = Mae::new(cfg.clone(), &mut rng).expect("ok");
        let pp = cfg.patch_pixels();
        let n = cfg.n_patches();
        let patches = vec![0.1f32; n * pp];
        let mut r2 = LcgRng::new(42);
        let (enc, mask) = mae.encode(&patches, &mut r2).expect("ok");
        assert_eq!(mask.masked_ids.len(), 0);
        assert_eq!(mask.visible_ids.len(), n);
        assert_eq!(enc.len(), n * cfg.encoder_dim);
    }
}