use crate::{
error::{VisionError, VisionResult},
handle::LcgRng,
vit::{
vit_block::{ViTBlock, ViTBlockConfig, layer_norm, linear},
vit_encoder::ViTEncoderConfig,
},
};
/// Hyper-parameters of a masked autoencoder (MAE): patch geometry, the sizes
/// of the encoder and decoder stacks, and the share of patches that is masked.
///
/// Use [`MaeConfig::new`] to obtain a validated value. The fields are public,
/// so later edits to them bypass the checks performed by `new`.
#[derive(Debug, Clone, PartialEq)]
pub struct MaeConfig {
    /// Side length of the square input image; `new` requires a non-zero
    /// multiple of `patch_size`.
    pub img_size: usize,
    /// Side length of one square patch; `new` requires it to be non-zero.
    pub patch_size: usize,
    /// Number of image channels; `new` requires it to be non-zero.
    pub in_channels: usize,
    /// Width of each token inside the encoder; `new` requires it to be non-zero.
    pub encoder_dim: usize,
    /// Number of `ViTBlock`s in the encoder; `new` requires it to be non-zero.
    pub encoder_depth: usize,
    /// Head count handed to the encoder's `ViTBlockConfig`.
    pub encoder_heads: usize,
    /// Width of each token inside the decoder; `new` requires it to be non-zero.
    pub decoder_dim: usize,
    /// Number of `ViTBlock`s in the decoder; `new` requires it to be non-zero.
    pub decoder_depth: usize,
    /// Head count handed to the decoder's `ViTBlockConfig`.
    pub decoder_heads: usize,
    /// MLP ratio handed to `ViTBlockConfig` for both stacks.
    /// NOTE(review): presumably the hidden-width multiplier of the block MLP —
    /// confirm against `vit_block`.
    pub mlp_ratio: usize,
    /// Fraction of patches to mask; `new` requires a value in `[0, 1]`.
    pub mask_ratio: f32,
}
impl MaeConfig {
    /// Builds a configuration after validating every argument.
    ///
    /// Checks run in this order, and the first failure is returned:
    /// patch geometry, channel count, encoder/decoder widths, encoder/decoder
    /// depths, mask ratio, and finally whatever `ViTBlockConfig::new` enforces
    /// for the encoder and then the decoder.
    ///
    /// # Errors
    ///
    /// * [`VisionError::InvalidPatchSize`] if `patch_size` or `img_size` is 0,
    ///   or `img_size` is not a multiple of `patch_size`.
    /// * [`VisionError::EmptyInput`] if `in_channels` is 0.
    /// * [`VisionError::InvalidEmbedDim`] if `encoder_dim` or `decoder_dim` is 0.
    /// * [`VisionError::Internal`] if a depth is 0 or `mask_ratio` is outside
    ///   `[0, 1]` (NaN included).
    /// * Any error produced by `ViTBlockConfig::new`.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        img_size: usize,
        patch_size: usize,
        in_channels: usize,
        encoder_dim: usize,
        encoder_depth: usize,
        encoder_heads: usize,
        decoder_dim: usize,
        decoder_depth: usize,
        decoder_heads: usize,
        mlp_ratio: usize,
        mask_ratio: f32,
    ) -> VisionResult<Self> {
        if patch_size == 0 || img_size == 0 || img_size % patch_size != 0 {
            return Err(VisionError::InvalidPatchSize {
                patch_size,
                img_size,
            });
        }
        if in_channels == 0 {
            return Err(VisionError::EmptyInput("in_channels"));
        }
        if encoder_dim == 0 {
            return Err(VisionError::InvalidEmbedDim(encoder_dim));
        }
        if decoder_dim == 0 {
            return Err(VisionError::InvalidEmbedDim(decoder_dim));
        }
        if encoder_depth == 0 {
            return Err(VisionError::Internal("encoder_depth must be > 0".into()));
        }
        if decoder_depth == 0 {
            return Err(VisionError::Internal("decoder_depth must be > 0".into()));
        }
        // The range test alone rejects NaN and both infinities: every
        // comparison against NaN is false, and infinities lie outside [0, 1].
        if !(0.0..=1.0).contains(&mask_ratio) {
            return Err(VisionError::Internal(format!(
                "mask_ratio {mask_ratio} not in [0, 1]"
            )));
        }
        // Built only for their validation; the values themselves are discarded.
        let _ = ViTBlockConfig::new(encoder_dim, encoder_heads, mlp_ratio)?;
        let _ = ViTBlockConfig::new(decoder_dim, decoder_heads, mlp_ratio)?;
        Ok(Self {
            img_size,
            patch_size,
            in_channels,
            encoder_dim,
            encoder_depth,
            encoder_heads,
            decoder_dim,
            decoder_depth,
            decoder_heads,
            mlp_ratio,
            mask_ratio,
        })
    }

    /// Number of patches in the image: `(img_size / patch_size)^2`.
    ///
    /// Returns 0 instead of panicking when `patch_size` has been set to 0
    /// through the public field after construction.
    #[must_use]
    pub fn n_patches(&self) -> usize {
        self.img_size
            .checked_div(self.patch_size)
            .map_or(0, |grid| grid * grid)
    }

    /// Number of scalar values in one flattened patch:
    /// `patch_size * patch_size * in_channels`.
    #[must_use]
    pub fn patch_pixels(&self) -> usize {
        self.patch_size * self.patch_size * self.in_channels
    }
}
/// Which patch indices stay visible to the encoder and which are masked.
///
/// Values produced by `generate_random_mask` hold two ascending, disjoint
/// lists that together cover `0..n_patches`. The fields are public, so a
/// hand-built value carries no such guarantee.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MaskMeta {
    /// Indices of the patches that are fed to the encoder.
    pub visible_ids: Vec<usize>,
    /// Indices of the patches that are hidden and must be reconstructed.
    pub masked_ids: Vec<usize>,
}
/// Randomly splits the patch indices `0..n_patches` into masked and visible
/// sets.
///
/// The number of masked patches is `round(mask_ratio * n_patches)`, capped at
/// `n_patches`. Selection is a partial Fisher–Yates shuffle that draws exactly
/// one `rng.next_usize` value per masked patch, so equal RNG states give equal
/// masks. Both returned lists are sorted ascending.
///
/// # Errors
///
/// * [`VisionError::EmptyInput`] if `n_patches` is 0.
/// * [`VisionError::Internal`] if `mask_ratio` is NaN, infinite, or outside
///   `[0, 1]`.
pub fn generate_random_mask(
    n_patches: usize,
    mask_ratio: f32,
    rng: &mut LcgRng,
) -> VisionResult<MaskMeta> {
    if n_patches == 0 {
        return Err(VisionError::EmptyInput("n_patches"));
    }
    let ratio_is_valid = mask_ratio.is_finite() && (0.0..=1.0).contains(&mask_ratio);
    if !ratio_is_valid {
        return Err(VisionError::Internal(format!(
            "mask_ratio {mask_ratio} not in [0, 1]"
        )));
    }

    let target = (mask_ratio * (n_patches as f32)).round() as usize;
    let n_masked = target.min(n_patches);

    // Shuffle only the prefix that will be masked: position `slot` receives a
    // uniformly chosen element from the not-yet-fixed tail `slot..n_patches`.
    let mut masked_ids: Vec<usize> = (0..n_patches).collect();
    for slot in 0..n_masked {
        let pick = slot + rng.next_usize(n_patches - slot);
        masked_ids.swap(slot, pick);
    }

    // Everything after the shuffled prefix stays visible.
    let mut visible_ids = masked_ids.split_off(n_masked);
    masked_ids.sort_unstable();
    visible_ids.sort_unstable();

    Ok(MaskMeta {
        visible_ids,
        masked_ids,
    })
}
/// Parameters of a masked autoencoder: a patch embedding, an encoder stack
/// that sees only the visible patches, and a decoder stack that predicts every
/// patch from the encoded tokens plus a shared mask token.
///
/// The buffer lengths quoted below are the ones allocated by `Mae::new`. All
/// fields are public and can be overwritten afterwards.
/// NOTE(review): the element layout of the weight matrices is whatever
/// `vit_block::linear` expects — confirm there before editing them by hand.
pub struct Mae {
    /// Configuration the buffers were sized from.
    pub config: MaeConfig,
    /// Patch-embedding weights, `encoder_dim * patch_pixels` values.
    pub patch_embed_weights: Vec<f32>,
    /// Patch-embedding bias, `encoder_dim` values (initialised to zero).
    pub patch_embed_bias: Vec<f32>,
    /// Encoder positional embedding, `n_patches * encoder_dim` values, added
    /// element-wise to the embedded patches before masking.
    pub encoder_pos_embed: Vec<f32>,
    /// Encoder blocks, `encoder_depth` of them, applied in order.
    pub encoder_blocks: Vec<ViTBlock>,
    /// Scale of the layer norm applied after the encoder blocks
    /// (`encoder_dim` values, initialised to one).
    pub encoder_norm_gamma: Vec<f32>,
    /// Shift of the layer norm applied after the encoder blocks
    /// (`encoder_dim` values, initialised to zero).
    pub encoder_norm_beta: Vec<f32>,
    /// Weights projecting encoder tokens to decoder width,
    /// `decoder_dim * encoder_dim` values.
    pub decoder_embed_weights: Vec<f32>,
    /// Bias of the encoder-to-decoder projection, `decoder_dim` values
    /// (initialised to zero).
    pub decoder_embed_bias: Vec<f32>,
    /// Token of `decoder_dim` values placed at every masked position before
    /// decoding.
    pub mask_token: Vec<f32>,
    /// Decoder positional embedding, `n_patches * decoder_dim` values, added
    /// element-wise to the full token sequence.
    pub decoder_pos_embed: Vec<f32>,
    /// Decoder blocks, `decoder_depth` of them, applied in order.
    pub decoder_blocks: Vec<ViTBlock>,
    /// Scale of the layer norm applied after the decoder blocks
    /// (`decoder_dim` values, initialised to one).
    pub decoder_norm_gamma: Vec<f32>,
    /// Shift of the layer norm applied after the decoder blocks
    /// (`decoder_dim` values, initialised to zero).
    pub decoder_norm_beta: Vec<f32>,
    /// Weights of the prediction head, `patch_pixels * decoder_dim` values.
    pub decoder_pred_weights: Vec<f32>,
    /// Bias of the prediction head, `patch_pixels` values (initialised to zero).
    pub decoder_pred_bias: Vec<f32>,
}
/// Draws one sample computed as `next_u32() / 2^31 - 0.5`.
///
/// NOTE(review): the result is centred on zero (range `[-0.5, 0.5)`) only if
/// `LcgRng::next_u32` yields values below 2^31 — confirm against `LcgRng`.
#[inline]
fn safe_centered_uniform(rng: &mut LcgRng) -> f32 {
    const TWO_POW_31: f32 = 2_147_483_648.0;
    let unit = (rng.next_u32() as f32) / TWO_POW_31;
    unit - 0.5
}
/// Overwrites every element of `buf` with a fresh `safe_centered_uniform`
/// sample multiplied by `2 * scale`, drawing one sample per element in order.
fn fill_centered_uniform(buf: &mut [f32], scale: f32, rng: &mut LcgRng) {
    buf.iter_mut()
        .for_each(|slot| *slot = safe_centered_uniform(rng) * 2.0 * scale);
}
impl Mae {
pub fn new(cfg: MaeConfig, rng: &mut LcgRng) -> VisionResult<Self> {
let n_patches = cfg.n_patches();
let pp = cfg.patch_pixels();
let edim = cfg.encoder_dim;
let ddim = cfg.decoder_dim;
let enc_scale = 1.0 / (pp as f32).sqrt();
let mut patch_embed_weights = vec![0.0f32; edim * pp];
fill_centered_uniform(&mut patch_embed_weights, enc_scale, rng);
let patch_embed_bias = vec![0.0f32; edim];
let pos_scale = 0.02f32; let mut encoder_pos_embed = vec![0.0f32; n_patches * edim];
fill_centered_uniform(&mut encoder_pos_embed, pos_scale, rng);
let enc_block_cfg =
ViTEncoderConfig::new(edim, cfg.encoder_heads, cfg.mlp_ratio, cfg.encoder_depth)?;
let mut encoder_blocks = Vec::with_capacity(cfg.encoder_depth);
for _ in 0..cfg.encoder_depth {
encoder_blocks.push(ViTBlock::new(enc_block_cfg.block_cfg.clone(), rng));
}
let encoder_norm_gamma = vec![1.0f32; edim];
let encoder_norm_beta = vec![0.0f32; edim];
let dec_in_scale = 1.0 / (edim as f32).sqrt();
let mut decoder_embed_weights = vec![0.0f32; ddim * edim];
fill_centered_uniform(&mut decoder_embed_weights, dec_in_scale, rng);
let decoder_embed_bias = vec![0.0f32; ddim];
let mut mask_token = vec![0.0f32; ddim];
fill_centered_uniform(&mut mask_token, pos_scale, rng);
let mut decoder_pos_embed = vec![0.0f32; n_patches * ddim];
fill_centered_uniform(&mut decoder_pos_embed, pos_scale, rng);
let dec_block_cfg =
ViTEncoderConfig::new(ddim, cfg.decoder_heads, cfg.mlp_ratio, cfg.decoder_depth)?;
let mut decoder_blocks = Vec::with_capacity(cfg.decoder_depth);
for _ in 0..cfg.decoder_depth {
decoder_blocks.push(ViTBlock::new(dec_block_cfg.block_cfg.clone(), rng));
}
let decoder_norm_gamma = vec![1.0f32; ddim];
let decoder_norm_beta = vec![0.0f32; ddim];
let pred_scale = 1.0 / (ddim as f32).sqrt();
let mut decoder_pred_weights = vec![0.0f32; pp * ddim];
fill_centered_uniform(&mut decoder_pred_weights, pred_scale, rng);
let decoder_pred_bias = vec![0.0f32; pp];
Ok(Self {
config: cfg,
patch_embed_weights,
patch_embed_bias,
encoder_pos_embed,
encoder_blocks,
encoder_norm_gamma,
encoder_norm_beta,
decoder_embed_weights,
decoder_embed_bias,
mask_token,
decoder_pos_embed,
decoder_blocks,
decoder_norm_gamma,
decoder_norm_beta,
decoder_pred_weights,
decoder_pred_bias,
})
}
pub fn encode(
&self,
image_patches: &[f32],
rng: &mut LcgRng,
) -> VisionResult<(Vec<f32>, MaskMeta)> {
let n_patches = self.config.n_patches();
let pp = self.config.patch_pixels();
let edim = self.config.encoder_dim;
if n_patches == 0 {
return Err(VisionError::EmptyInput("n_patches"));
}
let expected = n_patches * pp;
if image_patches.len() != expected {
return Err(VisionError::DimensionMismatch {
expected,
got: image_patches.len(),
});
}
let mut embedded = linear(
image_patches,
&self.patch_embed_weights,
&self.patch_embed_bias,
pp,
edim,
);
for (i, v) in embedded.iter_mut().enumerate() {
*v += self
.encoder_pos_embed
.get(i)
.copied()
.ok_or(VisionError::Internal(
"encoder_pos_embed shorter than embedded".into(),
))?;
}
let mask_meta = generate_random_mask(n_patches, self.config.mask_ratio, rng)?;
let n_visible = mask_meta.visible_ids.len();
let mut visible_tokens = vec![0.0f32; n_visible * edim];
for (out_i, &src_i) in mask_meta.visible_ids.iter().enumerate() {
let src = embedded
.get(src_i * edim..(src_i + 1) * edim)
.ok_or(VisionError::Internal("visible idx out of range".into()))?;
let dst = visible_tokens
.get_mut(out_i * edim..(out_i + 1) * edim)
.ok_or(VisionError::Internal(
"visible_tokens slice out of range".into(),
))?;
dst.copy_from_slice(src);
}
let encoded = if n_visible == 0 {
Vec::new()
} else {
let mut h = visible_tokens;
for block in &self.encoder_blocks {
h = block.forward(&h, n_visible)?;
}
layer_norm(
&h,
&self.encoder_norm_gamma,
&self.encoder_norm_beta,
n_visible,
edim,
1e-5,
)
};
Ok((encoded, mask_meta))
}
pub fn decode(&self, encoded_visible: &[f32], mask_meta: &MaskMeta) -> VisionResult<Vec<f32>> {
let n_patches = self.config.n_patches();
let edim = self.config.encoder_dim;
let ddim = self.config.decoder_dim;
let pp = self.config.patch_pixels();
let n_visible = mask_meta.visible_ids.len();
let n_masked = mask_meta.masked_ids.len();
if n_visible + n_masked != n_patches {
return Err(VisionError::Internal(
"MaskMeta visible + masked sizes do not sum to n_patches".into(),
));
}
if encoded_visible.len() != n_visible * edim {
return Err(VisionError::DimensionMismatch {
expected: n_visible * edim,
got: encoded_visible.len(),
});
}
let visible_dec = if n_visible == 0 {
Vec::new()
} else {
linear(
encoded_visible,
&self.decoder_embed_weights,
&self.decoder_embed_bias,
edim,
ddim,
)
};
let mut full = vec![0.0f32; n_patches * ddim];
for (vis_i, &dst_i) in mask_meta.visible_ids.iter().enumerate() {
let src = visible_dec
.get(vis_i * ddim..(vis_i + 1) * ddim)
.ok_or(VisionError::Internal("visible_dec slice".into()))?;
let dst = full
.get_mut(dst_i * ddim..(dst_i + 1) * ddim)
.ok_or(VisionError::Internal("full slice (visible)".into()))?;
dst.copy_from_slice(src);
}
for &dst_i in &mask_meta.masked_ids {
let dst = full
.get_mut(dst_i * ddim..(dst_i + 1) * ddim)
.ok_or(VisionError::Internal("full slice (masked)".into()))?;
dst.copy_from_slice(&self.mask_token);
}
for (i, v) in full.iter_mut().enumerate() {
*v += self
.decoder_pos_embed
.get(i)
.copied()
.ok_or(VisionError::Internal("decoder_pos_embed".into()))?;
}
let mut h = full;
for block in &self.decoder_blocks {
h = block.forward(&h, n_patches)?;
}
let post_norm = layer_norm(
&h,
&self.decoder_norm_gamma,
&self.decoder_norm_beta,
n_patches,
ddim,
1e-5,
);
let reconstructed = linear(
&post_norm,
&self.decoder_pred_weights,
&self.decoder_pred_bias,
ddim,
pp,
);
Ok(reconstructed)
}
}
/// Mean squared error between reconstruction and ground truth, computed over
/// the masked patches only.
///
/// Both slices hold one flattened patch after another; the patch width is
/// derived as `reconstructed.len() / n_patches`, where `n_patches` is the
/// total number of indices in `mask_meta`. Visible patches do not contribute.
/// Returns `0.0` when no patch is masked. Squared errors are accumulated in
/// `f64` and the mean is narrowed to `f32` at the end.
///
/// # Errors
///
/// * [`VisionError::DimensionMismatch`] if the two slices differ in length or
///   the length is not a multiple of `n_patches`.
/// * [`VisionError::EmptyInput`] if `mask_meta` lists no patches at all.
/// * [`VisionError::Internal`] if a masked index lies outside the slices.
pub fn mae_loss(
    reconstructed: &[f32],
    ground_truth_patches: &[f32],
    mask_meta: &MaskMeta,
) -> VisionResult<f32> {
    if reconstructed.len() != ground_truth_patches.len() {
        return Err(VisionError::DimensionMismatch {
            expected: reconstructed.len(),
            got: ground_truth_patches.len(),
        });
    }
    let n_patches = mask_meta.visible_ids.len() + mask_meta.masked_ids.len();
    if n_patches == 0 {
        return Err(VisionError::EmptyInput("mask_meta n_patches"));
    }
    if reconstructed.len() % n_patches != 0 {
        return Err(VisionError::DimensionMismatch {
            expected: n_patches,
            got: reconstructed.len(),
        });
    }
    let pp = reconstructed.len() / n_patches;
    if mask_meta.masked_ids.is_empty() {
        return Ok(0.0);
    }

    let mut sum_sq = 0.0f64;
    let mut count: u64 = 0;
    for &mi in &mask_meta.masked_ids {
        // Checked arithmetic: a huge caller-supplied index must surface as an
        // error rather than overflow (which panics in debug builds).
        let span = mi
            .checked_mul(pp)
            .and_then(|start| start.checked_add(pp).map(|end| start..end))
            .ok_or_else(|| VisionError::Internal("loss: masked idx".into()))?;
        let recon_row = reconstructed
            .get(span.clone())
            .ok_or_else(|| VisionError::Internal("loss: masked idx".into()))?;
        let truth_row = ground_truth_patches
            .get(span)
            .ok_or_else(|| VisionError::Internal("loss: masked idx (gt)".into()))?;
        // Strictly sequential accumulation keeps the floating-point result
        // independent of how the elements are grouped into patches.
        sum_sq = recon_row
            .iter()
            .zip(truth_row)
            .fold(sum_sq, |acc, (&rv, &tv)| {
                let d = f64::from(rv - tv);
                acc + d * d
            });
        count += recon_row.len() as u64;
    }

    // `count` is zero only when the patch width is zero.
    let mean = if count == 0 {
        0.0
    } else {
        sum_sq / (count as f64)
    };
    Ok(mean as f32)
}
/// Unit tests for mask generation, configuration validation, the
/// encode/decode pipeline and the masked reconstruction loss.
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// 8x8 image, 4x4 patches (4 patches), half of them masked.
    fn make_tiny_cfg() -> MaeConfig {
        MaeConfig::new(8, 4, 3, 16, 2, 4, 8, 1, 4, 2, 0.5).expect("valid tiny cfg")
    }

    /// 16x16 image, 4x4 patches (16 patches), three quarters masked.
    fn make_medium_cfg() -> MaeConfig {
        MaeConfig::new(16, 4, 3, 32, 2, 4, 16, 1, 4, 2, 0.75).expect("valid med cfg")
    }

    #[test]
    fn mask_union_and_disjoint() {
        let mut rng = LcgRng::new(1);
        let n = 16;
        let m = generate_random_mask(n, 0.5, &mut rng).expect("ok");
        let v: HashSet<usize> = m.visible_ids.iter().copied().collect();
        let k: HashSet<usize> = m.masked_ids.iter().copied().collect();
        assert!(v.is_disjoint(&k));
        let union: HashSet<usize> = v.union(&k).copied().collect();
        let expected: HashSet<usize> = (0..n).collect();
        assert_eq!(union, expected);
        assert_eq!(v.len() + k.len(), n);
    }

    #[test]
    fn mask_count_matches_round() {
        let mut rng = LcgRng::new(2);
        let m = generate_random_mask(100, 0.75, &mut rng).expect("ok");
        assert_eq!(m.masked_ids.len(), 75);
        assert_eq!(m.visible_ids.len(), 25);
    }

    #[test]
    fn mask_count_rounds_correctly() {
        let mut rng = LcgRng::new(3);
        let m = generate_random_mask(10, 0.7, &mut rng).expect("ok");
        assert_eq!(m.masked_ids.len(), 7);
    }

    #[test]
    fn mask_ratio_zero_all_visible() {
        let mut rng = LcgRng::new(4);
        let m = generate_random_mask(8, 0.0, &mut rng).expect("ok");
        assert_eq!(m.masked_ids.len(), 0);
        assert_eq!(m.visible_ids.len(), 8);
        assert_eq!(m.visible_ids, (0..8).collect::<Vec<_>>());
    }

    #[test]
    fn mask_ratio_one_all_masked() {
        let mut rng = LcgRng::new(5);
        let m = generate_random_mask(8, 1.0, &mut rng).expect("ok");
        assert_eq!(m.masked_ids.len(), 8);
        assert_eq!(m.visible_ids.len(), 0);
        assert_eq!(m.masked_ids, (0..8).collect::<Vec<_>>());
    }

    #[test]
    fn mask_deterministic_same_seed() {
        let mut a = LcgRng::new(42);
        let mut b = LcgRng::new(42);
        let ma = generate_random_mask(64, 0.75, &mut a).expect("ok");
        let mb = generate_random_mask(64, 0.75, &mut b).expect("ok");
        assert_eq!(ma, mb);
    }

    #[test]
    fn mask_sorted_ascending() {
        let mut rng = LcgRng::new(6);
        let m = generate_random_mask(50, 0.6, &mut rng).expect("ok");
        for w in m.visible_ids.windows(2) {
            assert!(w[0] < w[1]);
        }
        for w in m.masked_ids.windows(2) {
            assert!(w[0] < w[1]);
        }
    }

    #[test]
    fn mask_invalid_ratio_errors() {
        let mut rng = LcgRng::new(7);
        assert!(generate_random_mask(8, -0.1, &mut rng).is_err());
        assert!(generate_random_mask(8, 1.5, &mut rng).is_err());
        assert!(generate_random_mask(8, f32::NAN, &mut rng).is_err());
    }

    #[test]
    fn mask_n_patches_zero_errors() {
        let mut rng = LcgRng::new(8);
        let r = generate_random_mask(0, 0.5, &mut rng);
        assert!(matches!(r, Err(VisionError::EmptyInput(_))));
    }

    #[test]
    fn cfg_patch_not_divisible_errors() {
        let r = MaeConfig::new(7, 4, 3, 16, 1, 4, 8, 1, 4, 2, 0.5);
        assert!(matches!(r, Err(VisionError::InvalidPatchSize { .. })));
    }

    #[test]
    fn cfg_zero_channels_errors() {
        let r = MaeConfig::new(8, 4, 0, 16, 1, 4, 8, 1, 4, 2, 0.5);
        assert!(r.is_err());
    }

    #[test]
    fn cfg_zero_encoder_dim_errors() {
        let r = MaeConfig::new(8, 4, 3, 0, 1, 4, 8, 1, 4, 2, 0.5);
        assert!(matches!(r, Err(VisionError::InvalidEmbedDim(0))));
    }

    #[test]
    fn cfg_zero_decoder_dim_errors() {
        let r = MaeConfig::new(8, 4, 3, 16, 1, 4, 0, 1, 4, 2, 0.5);
        assert!(matches!(r, Err(VisionError::InvalidEmbedDim(0))));
    }

    #[test]
    fn cfg_zero_depth_errors() {
        let r1 = MaeConfig::new(8, 4, 3, 16, 0, 4, 8, 1, 4, 2, 0.5);
        let r2 = MaeConfig::new(8, 4, 3, 16, 1, 4, 8, 0, 4, 2, 0.5);
        assert!(r1.is_err());
        assert!(r2.is_err());
    }

    #[test]
    fn cfg_mask_ratio_out_of_range_errors() {
        let r1 = MaeConfig::new(8, 4, 3, 16, 1, 4, 8, 1, 4, 2, -0.1);
        let r2 = MaeConfig::new(8, 4, 3, 16, 1, 4, 8, 1, 4, 2, 1.5);
        assert!(r1.is_err());
        assert!(r2.is_err());
    }

    #[test]
    fn cfg_n_patches_and_pixels() {
        let cfg = make_tiny_cfg();
        assert_eq!(cfg.n_patches(), 4);
        assert_eq!(cfg.patch_pixels(), 4 * 4 * 3);
    }

    #[test]
    fn encode_shape() {
        let cfg = make_medium_cfg();
        let mut rng = LcgRng::new(11);
        let mae = Mae::new(cfg.clone(), &mut rng).expect("ok");
        let n_patches = cfg.n_patches();
        let pp = cfg.patch_pixels();
        let edim = cfg.encoder_dim;
        let patches = vec![0.1f32; n_patches * pp];
        let mut rng2 = LcgRng::new(99);
        let (enc, mask) = mae.encode(&patches, &mut rng2).expect("ok");
        assert_eq!(enc.len(), mask.visible_ids.len() * edim);
    }

    #[test]
    fn decode_shape_matches_patches() {
        let cfg = make_medium_cfg();
        let mut rng = LcgRng::new(13);
        let mae = Mae::new(cfg.clone(), &mut rng).expect("ok");
        let n_patches = cfg.n_patches();
        let pp = cfg.patch_pixels();
        let patches = vec![0.1f32; n_patches * pp];
        let mut rng2 = LcgRng::new(101);
        let (enc, mask) = mae.encode(&patches, &mut rng2).expect("ok");
        let recon = mae.decode(&enc, &mask).expect("ok");
        assert_eq!(recon.len(), n_patches * pp);
    }

    /// Two models built from the same seed and driven by the same mask seed
    /// must agree on the mask, the encoding and the reconstruction.
    #[test]
    fn full_pipeline_deterministic_same_seed() {
        let cfg = make_medium_cfg();
        let mut rng_a = LcgRng::new(33);
        let mae_a = Mae::new(cfg.clone(), &mut rng_a).expect("ok");
        let mut rng_b = LcgRng::new(33);
        let mae_b = Mae::new(cfg.clone(), &mut rng_b).expect("ok");
        let n_patches = cfg.n_patches();
        let pp = cfg.patch_pixels();
        let mut patches = vec![0.0f32; n_patches * pp];
        let mut rin = LcgRng::new(5);
        for v in patches.iter_mut() {
            *v = (rin.next_u32() as f32) / 2_147_483_648.0;
        }
        let mut r_a = LcgRng::new(77);
        let mut r_b = LcgRng::new(77);
        let (ea, ma) = mae_a.encode(&patches, &mut r_a).expect("ok");
        let (eb, mb) = mae_b.encode(&patches, &mut r_b).expect("ok");
        assert_eq!(ma, mb);
        for (a, b) in ea.iter().zip(eb.iter()) {
            assert!((a - b).abs() < 1e-6, "encode differs: {a} vs {b}");
        }
        let recon_a = mae_a.decode(&ea, &ma).expect("ok");
        let recon_b = mae_b.decode(&eb, &mb).expect("ok");
        for (a, b) in recon_a.iter().zip(recon_b.iter()) {
            assert!((a - b).abs() < 1e-6, "decode differs: {a} vs {b}");
        }
    }

    #[test]
    fn encode_dimension_mismatch_errors() {
        let cfg = make_tiny_cfg();
        let mut rng = LcgRng::new(15);
        let mae = Mae::new(cfg.clone(), &mut rng).expect("ok");
        let pp = cfg.patch_pixels();
        let patches = vec![0.0f32; (cfg.n_patches() - 1) * pp];
        let mut rng2 = LcgRng::new(16);
        let r = mae.encode(&patches, &mut rng2);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn decode_wrong_visible_length_errors() {
        let cfg = make_tiny_cfg();
        let mut rng = LcgRng::new(17);
        let mae = Mae::new(cfg.clone(), &mut rng).expect("ok");
        let mut rng_m = LcgRng::new(18);
        let mask = generate_random_mask(cfg.n_patches(), 0.5, &mut rng_m).expect("ok");
        let wrong = vec![0.0f32; mask.visible_ids.len() * cfg.encoder_dim - 1];
        let r = mae.decode(&wrong, &mask);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    /// The reconstruction matches the ground truth on masked patches and is
    /// wildly off on visible ones; only the former may count.
    #[test]
    fn loss_zero_when_match_on_masked_positions() {
        let mask = MaskMeta {
            visible_ids: vec![0, 2],
            masked_ids: vec![1, 3],
        };
        let pp = 5;
        let n_patches = 4;
        let truth: Vec<f32> = (0..n_patches * pp).map(|i| i as f32).collect();
        let mut recon = vec![0.0f32; n_patches * pp];
        for &mi in &mask.masked_ids {
            let span = mi * pp..(mi + 1) * pp;
            recon[span.clone()].copy_from_slice(&truth[span]);
        }
        recon[..pp].fill(999.0);
        recon[2 * pp..3 * pp].fill(-777.0);
        let loss = mae_loss(&recon, &truth, &mask).expect("ok");
        assert!(
            loss.abs() < 1e-6,
            "loss should be 0 when masked match: {loss}"
        );
    }

    #[test]
    fn loss_independent_of_visible_positions() {
        let mask = MaskMeta {
            visible_ids: vec![0, 2],
            masked_ids: vec![1, 3],
        };
        let pp = 3;
        let n_patches = 4;
        let truth: Vec<f32> = (0..n_patches * pp).map(|i| (i as f32) * 0.1).collect();
        let recon_a: Vec<f32> = truth.iter().map(|t| t + 0.5).collect();
        let mut recon_b = recon_a.clone();
        for &vi in &mask.visible_ids {
            recon_b[vi * pp..(vi + 1) * pp].fill(1234.0);
        }
        let la = mae_loss(&recon_a, &truth, &mask).expect("ok");
        let lb = mae_loss(&recon_b, &truth, &mask).expect("ok");
        assert!(
            (la - lb).abs() < 1e-6,
            "loss depends on visible: {la} vs {lb}"
        );
    }

    #[test]
    fn loss_dimension_mismatch_errors() {
        let mask = MaskMeta {
            visible_ids: vec![0],
            masked_ids: vec![1],
        };
        let r = mae_loss(&[0.0; 4], &[0.0; 5], &mask);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn loss_mask_ratio_zero_returns_zero() {
        let mut rng = LcgRng::new(21);
        let m = generate_random_mask(6, 0.0, &mut rng).expect("ok");
        let r = vec![1.0f32; 6 * 4];
        let g = vec![2.0f32; 6 * 4];
        let l = mae_loss(&r, &g, &m).expect("ok");
        assert!(l.abs() < 1e-6, "no masked → loss = 0; got {l}");
    }

    #[test]
    fn loss_positive_when_recon_off() {
        let mask = MaskMeta {
            visible_ids: vec![0],
            masked_ids: vec![1],
        };
        let pp = 4;
        let truth = vec![0.0f32; 2 * pp];
        let mut recon = vec![0.0f32; 2 * pp];
        // Every value of the single masked patch is off by exactly one.
        recon[pp..].fill(1.0);
        let l = mae_loss(&recon, &truth, &mask).expect("ok");
        assert!((l - 1.0).abs() < 1e-6, "expected MSE=1, got {l}");
    }

    #[test]
    fn encode_decode_finite() {
        let cfg = make_medium_cfg();
        let mut rng = LcgRng::new(45);
        let mae = Mae::new(cfg.clone(), &mut rng).expect("ok");
        let n_patches = cfg.n_patches();
        let pp = cfg.patch_pixels();
        let mut patches = vec![0.0f32; n_patches * pp];
        let mut rin = LcgRng::new(55);
        rin.fill_normal(&mut patches);
        let mut r2 = LcgRng::new(66);
        let (enc, mask) = mae.encode(&patches, &mut r2).expect("ok");
        assert!(enc.iter().all(|v| v.is_finite()));
        let recon = mae.decode(&enc, &mask).expect("ok");
        assert!(recon.iter().all(|v| v.is_finite()));
    }

    /// With every decoder-block weight zeroed, no positional embedding and an
    /// identity prediction head, the output at a masked position must equal
    /// the layer-normalised mask token regardless of the image content.
    /// NOTE(review): this relies on a zero-weight `ViTBlock` passing its input
    /// through unchanged — confirm against `vit_block`.
    #[test]
    fn identity_decoder_reconstructs_mask_token_at_masked() {
        let cfg = MaeConfig::new(2, 2, 1, 4, 1, 1, 4, 1, 1, 1, 0.5).expect("ok");
        let mut rng = LcgRng::new(123);
        let mut mae = Mae::new(cfg.clone(), &mut rng).expect("ok");
        for block in mae.decoder_blocks.iter_mut() {
            block.weights.qkv_weight.iter_mut().for_each(|v| *v = 0.0);
            block.weights.qkv_bias.iter_mut().for_each(|v| *v = 0.0);
            block.weights.out_weight.iter_mut().for_each(|v| *v = 0.0);
            block.weights.out_bias.iter_mut().for_each(|v| *v = 0.0);
            block.weights.mlp1_weight.iter_mut().for_each(|v| *v = 0.0);
            block.weights.mlp1_bias.iter_mut().for_each(|v| *v = 0.0);
            block.weights.mlp2_weight.iter_mut().for_each(|v| *v = 0.0);
            block.weights.mlp2_bias.iter_mut().for_each(|v| *v = 0.0);
        }
        mae.decoder_pos_embed.fill(0.0);
        // Identity prediction head: 4x4 matrix with ones on the diagonal.
        mae.decoder_pred_weights.fill(0.0);
        for i in 0..4 {
            mae.decoder_pred_weights[i * 4 + i] = 1.0;
        }
        mae.decoder_pred_bias.fill(0.0);
        mae.mask_token = vec![0.1, -0.2, 0.3, -0.4];

        let n_patches = cfg.n_patches();
        let pp = cfg.patch_pixels();
        let patches_a = vec![1.0f32; n_patches * pp];
        let patches_b = vec![7.7f32; n_patches * pp];
        let mut r_a = LcgRng::new(2024);
        let mut r_b = LcgRng::new(2024);
        let (enc_a, ma) = mae.encode(&patches_a, &mut r_a).expect("ok");
        let (enc_b, mb) = mae.encode(&patches_b, &mut r_b).expect("ok");
        assert_eq!(ma, mb, "same RNG seed must produce same mask");
        let recon_a = mae.decode(&enc_a, &ma).expect("ok");
        let recon_b = mae.decode(&enc_b, &mb).expect("ok");

        // Layer norm of the mask token with gamma = 1, beta = 0, eps = 1e-5.
        let mean = (0.1f32 + (-0.2) + 0.3 + (-0.4)) / 4.0;
        let centred = [0.1f32 - mean, -0.2 - mean, 0.3 - mean, -0.4 - mean];
        let var = centred.iter().map(|c| c * c).sum::<f32>() / 4.0;
        let inv_std = 1.0 / (var + 1e-5).sqrt();
        let expected_at_mask: Vec<f32> = centred.iter().map(|c| c * inv_std).collect();

        for &mi in &ma.masked_ids {
            for k in 0..pp {
                let a = recon_a[mi * pp + k];
                let b = recon_b[mi * pp + k];
                assert!(
                    (a - b).abs() < 1e-5,
                    "masked pos {mi} k={k}: a={a} b={b} (depends on visible!)"
                );
                let exp = expected_at_mask[k];
                assert!(
                    (a - exp).abs() < 1e-4,
                    "masked pos {mi} k={k}: got {a} expected {exp}"
                );
            }
        }
    }

    #[test]
    fn mask_full_ratio_encoder_skipped() {
        let cfg = MaeConfig::new(4, 2, 1, 4, 1, 1, 4, 1, 1, 1, 1.0).expect("ok");
        let mut rng = LcgRng::new(31);
        let mae = Mae::new(cfg.clone(), &mut rng).expect("ok");
        let pp = cfg.patch_pixels();
        let n = cfg.n_patches();
        let patches = vec![0.0f32; n * pp];
        let mut r2 = LcgRng::new(32);
        let (enc, mask) = mae.encode(&patches, &mut r2).expect("ok");
        assert_eq!(enc.len(), 0);
        assert_eq!(mask.masked_ids.len(), n);
        let recon = mae.decode(&enc, &mask).expect("ok");
        assert_eq!(recon.len(), n * pp);
        assert!(recon.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn mask_zero_ratio_full_encoder() {
        let cfg = MaeConfig::new(4, 2, 1, 4, 1, 1, 4, 1, 1, 1, 0.0).expect("ok");
        let mut rng = LcgRng::new(41);
        let mae = Mae::new(cfg.clone(), &mut rng).expect("ok");
        let pp = cfg.patch_pixels();
        let n = cfg.n_patches();
        let patches = vec![0.1f32; n * pp];
        let mut r2 = LcgRng::new(42);
        let (enc, mask) = mae.encode(&patches, &mut r2).expect("ok");
        assert_eq!(mask.masked_ids.len(), 0);
        assert_eq!(mask.visible_ids.len(), n);
        assert_eq!(enc.len(), n * cfg.encoder_dim);
    }
}