Skip to main content

oxicuda_vision/vit/
mae.rs

1//! MAE — Masked Autoencoder (He et al. 2022 CVPR
2//! "Masked Autoencoders Are Scalable Vision Learners").
3//!
4//! Implements the canonical self-supervised vision pre-training recipe:
5//!
6//! 1. Split the image into a flat sequence of `n_patches` patches of length
7//!    `patch_pixels = patch_size² · in_channels`.
8//! 2. Apply a learnable linear `patch_embed` projection to every patch and
9//!    add an `encoder_pos_embed`.
10//! 3. Randomly select a subset of `n_visible = n_patches − n_masked` indices
11//!    (via partial Fisher–Yates) and keep only those tokens.
12//! 4. Run the kept tokens through a deep `ViTBlock` stack (the encoder).
13//! 5. Project encoded visible tokens to `decoder_dim`, scatter them back to
14//!    their original positions, fill masked positions with a learnable
15//!    `mask_token`, add a separate `decoder_pos_embed`, and run a (typically
16//!    shallower) ViTBlock stack (the decoder).
17//! 6. Project the decoder output to per-patch pixel space via
18//!    `decoder_pred`.
19//! 7. Loss = MEAN squared error **only over masked positions** (the canonical
20//!    MAE objective — visible reconstructions do not contribute).
21//!
22//! ## RNG safety
23//! The crate's `LcgRng::next_f32()` is biased — its output spans only
24//! `[0, ~0.5)` because `next_u32()` returns the high 31 bits. We therefore
25//! never call `next_f32() < mask_ratio` for masking. The random mask is built
26//! with a partial Fisher–Yates shuffle driven by `next_usize(n)`, which gives
27//! an exact (deterministic, count-correct) selection of `round(mask_ratio · n)`
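//! masked indices. Weight initialisation likewise never calls `next_f32()`:
//! every random weight is drawn through the private `safe_centered_uniform`
//! helper, which rescales raw `next_u32()` output with the aim of producing a
//! centred uniform sample on `[-0.5, 0.5)`.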
31
32use crate::{
33    error::{VisionError, VisionResult},
34    handle::LcgRng,
35    vit::{
36        vit_block::{ViTBlock, ViTBlockConfig, layer_norm, linear},
37        vit_encoder::ViTEncoderConfig,
38    },
39};
40
41// ─── Config ──────────────────────────────────────────────────────────────────
42
/// Hyper-parameters of an MAE (Masked Autoencoder) model.
///
/// Use [`MaeConfig::new`] to obtain a validated instance. The fields are
/// public, so a value built with a struct literal bypasses that validation.
#[derive(Debug, Clone, PartialEq)]
pub struct MaeConfig {
    /// Side length of the (square) input image, i.e. H = W.
    pub img_size: usize,
    /// Side length of one square patch; has to divide `img_size` evenly.
    pub patch_size: usize,
    /// Channel count of the input image (3 for RGB, for example).
    pub in_channels: usize,
    /// Width of the encoder token embeddings.
    pub encoder_dim: usize,
    /// How many transformer blocks the encoder stacks.
    pub encoder_depth: usize,
    /// Attention-head count used by each encoder block.
    pub encoder_heads: usize,
    /// Width of the decoder token embeddings (typically below `encoder_dim`).
    pub decoder_dim: usize,
    /// How many transformer blocks the decoder stacks.
    pub decoder_depth: usize,
    /// Attention-head count used by each decoder block.
    pub decoder_heads: usize,
    /// Hidden-width multiplier of the MLP, common to encoder and decoder blocks.
    pub mlp_ratio: usize,
    /// Share of patches hidden from the encoder, within `[0, 1]` (0.75 is typical).
    pub mask_ratio: f32,
}
69
70impl MaeConfig {
71    /// Build and validate a new `MaeConfig`.
72    ///
73    /// # Errors
74    /// - `img_size % patch_size != 0` → `InvalidPatchSize`
75    /// - any zero-sized dimension → `InvalidEmbedDim` / `EmptyInput`
76    /// - `mask_ratio` outside `[0, 1]` → `Internal`
77    /// - `encoder_dim % encoder_heads != 0` or
78    ///   `decoder_dim % decoder_heads != 0` → `HeadDimMismatch`
79    #[allow(clippy::too_many_arguments)]
80    pub fn new(
81        img_size: usize,
82        patch_size: usize,
83        in_channels: usize,
84        encoder_dim: usize,
85        encoder_depth: usize,
86        encoder_heads: usize,
87        decoder_dim: usize,
88        decoder_depth: usize,
89        decoder_heads: usize,
90        mlp_ratio: usize,
91        mask_ratio: f32,
92    ) -> VisionResult<Self> {
93        if patch_size == 0 || img_size == 0 || img_size % patch_size != 0 {
94            return Err(VisionError::InvalidPatchSize {
95                patch_size,
96                img_size,
97            });
98        }
99        if in_channels == 0 {
100            return Err(VisionError::EmptyInput("in_channels"));
101        }
102        if encoder_dim == 0 {
103            return Err(VisionError::InvalidEmbedDim(encoder_dim));
104        }
105        if decoder_dim == 0 {
106            return Err(VisionError::InvalidEmbedDim(decoder_dim));
107        }
108        if encoder_depth == 0 {
109            return Err(VisionError::Internal("encoder_depth must be > 0".into()));
110        }
111        if decoder_depth == 0 {
112            return Err(VisionError::Internal("decoder_depth must be > 0".into()));
113        }
114        if !(0.0..=1.0).contains(&mask_ratio) || !mask_ratio.is_finite() {
115            return Err(VisionError::Internal(format!(
116                "mask_ratio {mask_ratio} not in [0, 1]"
117            )));
118        }
119        // Block configs validate head divisibility:
120        let _ = ViTBlockConfig::new(encoder_dim, encoder_heads, mlp_ratio)?;
121        let _ = ViTBlockConfig::new(decoder_dim, decoder_heads, mlp_ratio)?;
122        Ok(Self {
123            img_size,
124            patch_size,
125            in_channels,
126            encoder_dim,
127            encoder_depth,
128            encoder_heads,
129            decoder_dim,
130            decoder_depth,
131            decoder_heads,
132            mlp_ratio,
133            mask_ratio,
134        })
135    }
136
137    /// Number of non-overlapping image patches.
138    #[must_use]
139    pub fn n_patches(&self) -> usize {
140        let grid = self.img_size / self.patch_size;
141        grid * grid
142    }
143
144    /// Per-patch flat pixel count (`patch_size² · in_channels`).
145    #[must_use]
146    pub fn patch_pixels(&self) -> usize {
147        self.patch_size * self.patch_size * self.in_channels
148    }
149}
150
151// ─── MaskMeta ────────────────────────────────────────────────────────────────
152
/// Record of which of the `n_patches` positions a random mask hides.
///
/// As produced by `generate_random_mask`, each vector is in ascending order
/// and the two form a partition of `0..n_patches`: every position appears in
/// exactly one of them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MaskMeta {
    /// Ascending indices of the patches that stay visible to the encoder.
    pub visible_ids: Vec<usize>,
    /// Ascending indices of the patches substituted by the mask token.
    pub masked_ids: Vec<usize>,
}
164
165// ─── generate_random_mask ────────────────────────────────────────────────────
166
167/// Build a deterministic random mask over `n_patches` positions.
168///
169/// Uses a **partial Fisher–Yates shuffle** — the only RNG-safe technique with
170/// this crate's biased `LcgRng::next_f32` (see module-level doc-comment).
171///
172/// The shuffle picks exactly `n_masked = round(mask_ratio · n_patches)` masked
173/// indices regardless of `mask_ratio`'s decimal value; the returned counts are
174/// not stochastic.
175///
176/// # Errors
177/// - `n_patches == 0` → `EmptyInput`
178/// - `mask_ratio ∉ [0, 1]` or NaN/Inf → `Internal`
179pub fn generate_random_mask(
180    n_patches: usize,
181    mask_ratio: f32,
182    rng: &mut LcgRng,
183) -> VisionResult<MaskMeta> {
184    if n_patches == 0 {
185        return Err(VisionError::EmptyInput("n_patches"));
186    }
187    if !(0.0..=1.0).contains(&mask_ratio) || !mask_ratio.is_finite() {
188        return Err(VisionError::Internal(format!(
189            "mask_ratio {mask_ratio} not in [0, 1]"
190        )));
191    }
192    let mut ids: Vec<usize> = (0..n_patches).collect();
193    let n_masked = (mask_ratio * (n_patches as f32)).round() as usize;
194    let n_masked = n_masked.min(n_patches);
195
196    // Partial Fisher–Yates: for i in 0..n_masked, pick j uniformly in [i, n_patches)
197    // and swap. The first n_masked positions become a uniform random sample without
198    // replacement.
199    for i in 0..n_masked {
200        let remaining = n_patches - i;
201        // next_usize(remaining) returns a value in [0, remaining); add i for [i, n_patches)
202        let j = i + rng.next_usize(remaining);
203        ids.swap(i, j);
204    }
205
206    let mut masked_ids: Vec<usize> = ids[..n_masked].to_vec();
207    let mut visible_ids: Vec<usize> = ids[n_masked..].to_vec();
208    masked_ids.sort_unstable();
209    visible_ids.sort_unstable();
210    Ok(MaskMeta {
211        visible_ids,
212        masked_ids,
213    })
214}
215
216// ─── Mae ─────────────────────────────────────────────────────────────────────
217
218/// Masked Autoencoder model (He et al. 2022).
219///
220/// Weights are stored as flat row-major `Vec<f32>` consistent with the rest of
221/// the crate's transformer modules.
222pub struct Mae {
223    /// Top-level configuration.
224    pub config: MaeConfig,
225    /// Patch-embedding kernel: `[encoder_dim, patch_pixels]` row-major.
226    pub patch_embed_weights: Vec<f32>,
227    /// Patch-embedding bias: `[encoder_dim]`.
228    pub patch_embed_bias: Vec<f32>,
229    /// Encoder positional embeddings: `[n_patches, encoder_dim]`.
230    pub encoder_pos_embed: Vec<f32>,
231    /// Encoder ViT block stack.
232    pub encoder_blocks: Vec<ViTBlock>,
233    /// Final encoder LayerNorm scale (encoder_dim).
234    pub encoder_norm_gamma: Vec<f32>,
235    /// Final encoder LayerNorm bias  (encoder_dim).
236    pub encoder_norm_beta: Vec<f32>,
237    /// Decoder embed projection kernel: `[decoder_dim, encoder_dim]`.
238    pub decoder_embed_weights: Vec<f32>,
239    /// Decoder embed bias: `[decoder_dim]`.
240    pub decoder_embed_bias: Vec<f32>,
241    /// Learnable mask token shared by every masked position (decoder_dim).
242    pub mask_token: Vec<f32>,
243    /// Decoder positional embeddings: `[n_patches, decoder_dim]`.
244    pub decoder_pos_embed: Vec<f32>,
245    /// Decoder ViT block stack.
246    pub decoder_blocks: Vec<ViTBlock>,
247    /// Final decoder LayerNorm scale (decoder_dim).
248    pub decoder_norm_gamma: Vec<f32>,
249    /// Final decoder LayerNorm bias  (decoder_dim).
250    pub decoder_norm_beta: Vec<f32>,
251    /// Per-patch pixel projection kernel: `[patch_pixels, decoder_dim]`.
252    pub decoder_pred_weights: Vec<f32>,
253    /// Per-patch pixel projection bias: `[patch_pixels]`.
254    pub decoder_pred_bias: Vec<f32>,
255}
256
257/// Hazard-safe `[-0.5, 0.5)` uniform sample using the high 31 bits of LcgRng.
258///
259/// `next_u32()` already returns values in `[0, 2³¹)`, so dividing by `2³¹`
260/// gives a true `[0, 1)` sample; subtracting 0.5 centres it.
261#[inline]
262fn safe_centered_uniform(rng: &mut LcgRng) -> f32 {
263    (rng.next_u32() as f32) / 4_294_967_296.0 - 0.5
264}
265
266/// Fill `buf` with i.i.d. `[-scale, scale)` samples using the hazard-safe
267/// recipe (NOT `next_f32`, which is biased to `[0, ~0.5)`).
268fn fill_centered_uniform(buf: &mut [f32], scale: f32, rng: &mut LcgRng) {
269    for v in buf.iter_mut() {
270        *v = safe_centered_uniform(rng) * 2.0 * scale;
271    }
272}
273
274impl Mae {
275    /// Build and initialise a fresh MAE model.
276    ///
277    /// Weight initialisation:
278    /// - Linear / projection kernels: uniform `[-scale, scale)` with
279    ///   `scale = 1 / sqrt(fan_in)` (Xavier-like).
280    /// - Biases: zeros.
281    /// - LayerNorm gammas: ones, betas: zeros.
282    /// - Positional embeddings: small-magnitude uniform.
283    /// - `mask_token`: small-magnitude uniform.
284    ///
285    /// # Errors
286    /// Propagates any block-config validation error.
287    pub fn new(cfg: MaeConfig, rng: &mut LcgRng) -> VisionResult<Self> {
288        let n_patches = cfg.n_patches();
289        let pp = cfg.patch_pixels();
290        let edim = cfg.encoder_dim;
291        let ddim = cfg.decoder_dim;
292
293        let enc_scale = 1.0 / (pp as f32).sqrt();
294        let mut patch_embed_weights = vec![0.0f32; edim * pp];
295        fill_centered_uniform(&mut patch_embed_weights, enc_scale, rng);
296        let patch_embed_bias = vec![0.0f32; edim];
297
298        let pos_scale = 0.02f32; // canonical ViT pos-embed init magnitude
299        let mut encoder_pos_embed = vec![0.0f32; n_patches * edim];
300        fill_centered_uniform(&mut encoder_pos_embed, pos_scale, rng);
301
302        // Encoder block stack — reuse ViTEncoderConfig to validate and create blocks.
303        let enc_block_cfg =
304            ViTEncoderConfig::new(edim, cfg.encoder_heads, cfg.mlp_ratio, cfg.encoder_depth)?;
305        let mut encoder_blocks = Vec::with_capacity(cfg.encoder_depth);
306        for _ in 0..cfg.encoder_depth {
307            encoder_blocks.push(ViTBlock::new(enc_block_cfg.block_cfg.clone(), rng));
308        }
309        let encoder_norm_gamma = vec![1.0f32; edim];
310        let encoder_norm_beta = vec![0.0f32; edim];
311
312        let dec_in_scale = 1.0 / (edim as f32).sqrt();
313        let mut decoder_embed_weights = vec![0.0f32; ddim * edim];
314        fill_centered_uniform(&mut decoder_embed_weights, dec_in_scale, rng);
315        let decoder_embed_bias = vec![0.0f32; ddim];
316
317        let mut mask_token = vec![0.0f32; ddim];
318        fill_centered_uniform(&mut mask_token, pos_scale, rng);
319
320        let mut decoder_pos_embed = vec![0.0f32; n_patches * ddim];
321        fill_centered_uniform(&mut decoder_pos_embed, pos_scale, rng);
322
323        let dec_block_cfg =
324            ViTEncoderConfig::new(ddim, cfg.decoder_heads, cfg.mlp_ratio, cfg.decoder_depth)?;
325        let mut decoder_blocks = Vec::with_capacity(cfg.decoder_depth);
326        for _ in 0..cfg.decoder_depth {
327            decoder_blocks.push(ViTBlock::new(dec_block_cfg.block_cfg.clone(), rng));
328        }
329        let decoder_norm_gamma = vec![1.0f32; ddim];
330        let decoder_norm_beta = vec![0.0f32; ddim];
331
332        let pred_scale = 1.0 / (ddim as f32).sqrt();
333        let mut decoder_pred_weights = vec![0.0f32; pp * ddim];
334        fill_centered_uniform(&mut decoder_pred_weights, pred_scale, rng);
335        let decoder_pred_bias = vec![0.0f32; pp];
336
337        Ok(Self {
338            config: cfg,
339            patch_embed_weights,
340            patch_embed_bias,
341            encoder_pos_embed,
342            encoder_blocks,
343            encoder_norm_gamma,
344            encoder_norm_beta,
345            decoder_embed_weights,
346            decoder_embed_bias,
347            mask_token,
348            decoder_pos_embed,
349            decoder_blocks,
350            decoder_norm_gamma,
351            decoder_norm_beta,
352            decoder_pred_weights,
353            decoder_pred_bias,
354        })
355    }
356
357    /// Encode an already-patchified image.
358    ///
359    /// `image_patches` is `[n_patches, patch_pixels]` row-major. Returns the
360    /// `[n_visible, encoder_dim]` features of visible tokens plus the
361    /// generated `MaskMeta`.
362    ///
363    /// # Errors
364    /// - `image_patches.len() != n_patches · patch_pixels` → `DimensionMismatch`
365    /// - mask-generation errors propagated
366    pub fn encode(
367        &self,
368        image_patches: &[f32],
369        rng: &mut LcgRng,
370    ) -> VisionResult<(Vec<f32>, MaskMeta)> {
371        let n_patches = self.config.n_patches();
372        let pp = self.config.patch_pixels();
373        let edim = self.config.encoder_dim;
374
375        if n_patches == 0 {
376            return Err(VisionError::EmptyInput("n_patches"));
377        }
378        let expected = n_patches * pp;
379        if image_patches.len() != expected {
380            return Err(VisionError::DimensionMismatch {
381                expected,
382                got: image_patches.len(),
383            });
384        }
385
386        // Patch embed: y = X · W^T + b — output [n_patches, encoder_dim]
387        let mut embedded = linear(
388            image_patches,
389            &self.patch_embed_weights,
390            &self.patch_embed_bias,
391            pp,
392            edim,
393        );
394
395        // Add encoder positional embeddings.
396        for (i, v) in embedded.iter_mut().enumerate() {
397            *v += self
398                .encoder_pos_embed
399                .get(i)
400                .copied()
401                .ok_or(VisionError::Internal(
402                    "encoder_pos_embed shorter than embedded".into(),
403                ))?;
404        }
405
406        // Generate mask and gather visible tokens.
407        let mask_meta = generate_random_mask(n_patches, self.config.mask_ratio, rng)?;
408        let n_visible = mask_meta.visible_ids.len();
409
410        let mut visible_tokens = vec![0.0f32; n_visible * edim];
411        for (out_i, &src_i) in mask_meta.visible_ids.iter().enumerate() {
412            let src = embedded
413                .get(src_i * edim..(src_i + 1) * edim)
414                .ok_or(VisionError::Internal("visible idx out of range".into()))?;
415            let dst = visible_tokens
416                .get_mut(out_i * edim..(out_i + 1) * edim)
417                .ok_or(VisionError::Internal(
418                    "visible_tokens slice out of range".into(),
419                ))?;
420            dst.copy_from_slice(src);
421        }
422
423        // Encoder block stack (only on visible tokens — the MAE speed-up).
424        // ViTBlock::forward rejects n_tokens == 0, so for mask_ratio == 1 we
425        // produce an empty Vec without touching the blocks.
426        let encoded = if n_visible == 0 {
427            Vec::new()
428        } else {
429            let mut h = visible_tokens;
430            for block in &self.encoder_blocks {
431                h = block.forward(&h, n_visible)?;
432            }
433            layer_norm(
434                &h,
435                &self.encoder_norm_gamma,
436                &self.encoder_norm_beta,
437                n_visible,
438                edim,
439                1e-5,
440            )
441        };
442
443        Ok((encoded, mask_meta))
444    }
445
446    /// Decode visible features and reconstruct per-patch pixels.
447    ///
448    /// Returns `[n_patches, patch_pixels]` (full sequence including
449    /// reconstructions of the originally-visible positions).
450    ///
451    /// # Errors
452    /// - `encoded_visible.len() != |visible_ids| · encoder_dim` →
453    ///   `DimensionMismatch`
454    pub fn decode(&self, encoded_visible: &[f32], mask_meta: &MaskMeta) -> VisionResult<Vec<f32>> {
455        let n_patches = self.config.n_patches();
456        let edim = self.config.encoder_dim;
457        let ddim = self.config.decoder_dim;
458        let pp = self.config.patch_pixels();
459        let n_visible = mask_meta.visible_ids.len();
460        let n_masked = mask_meta.masked_ids.len();
461
462        if n_visible + n_masked != n_patches {
463            return Err(VisionError::Internal(
464                "MaskMeta visible + masked sizes do not sum to n_patches".into(),
465            ));
466        }
467        if encoded_visible.len() != n_visible * edim {
468            return Err(VisionError::DimensionMismatch {
469                expected: n_visible * edim,
470                got: encoded_visible.len(),
471            });
472        }
473
474        // Project visible features encoder_dim → decoder_dim (if any).
475        let visible_dec = if n_visible == 0 {
476            Vec::new()
477        } else {
478            linear(
479                encoded_visible,
480                &self.decoder_embed_weights,
481                &self.decoder_embed_bias,
482                edim,
483                ddim,
484            )
485        };
486
487        // Scatter into the full-length [n_patches, ddim] sequence: visible at
488        // their original ids, mask_token at masked ids.
489        let mut full = vec![0.0f32; n_patches * ddim];
490        for (vis_i, &dst_i) in mask_meta.visible_ids.iter().enumerate() {
491            let src = visible_dec
492                .get(vis_i * ddim..(vis_i + 1) * ddim)
493                .ok_or(VisionError::Internal("visible_dec slice".into()))?;
494            let dst = full
495                .get_mut(dst_i * ddim..(dst_i + 1) * ddim)
496                .ok_or(VisionError::Internal("full slice (visible)".into()))?;
497            dst.copy_from_slice(src);
498        }
499        for &dst_i in &mask_meta.masked_ids {
500            let dst = full
501                .get_mut(dst_i * ddim..(dst_i + 1) * ddim)
502                .ok_or(VisionError::Internal("full slice (masked)".into()))?;
503            dst.copy_from_slice(&self.mask_token);
504        }
505
506        // Add decoder positional embeddings.
507        for (i, v) in full.iter_mut().enumerate() {
508            *v += self
509                .decoder_pos_embed
510                .get(i)
511                .copied()
512                .ok_or(VisionError::Internal("decoder_pos_embed".into()))?;
513        }
514
515        // Decoder block stack on the FULL sequence.
516        let mut h = full;
517        for block in &self.decoder_blocks {
518            h = block.forward(&h, n_patches)?;
519        }
520        let post_norm = layer_norm(
521            &h,
522            &self.decoder_norm_gamma,
523            &self.decoder_norm_beta,
524            n_patches,
525            ddim,
526            1e-5,
527        );
528
529        // Per-patch pixel projection [n_patches, ddim] → [n_patches, patch_pixels].
530        let reconstructed = linear(
531            &post_norm,
532            &self.decoder_pred_weights,
533            &self.decoder_pred_bias,
534            ddim,
535            pp,
536        );
537        Ok(reconstructed)
538    }
539}
540
541// ─── mae_loss ────────────────────────────────────────────────────────────────
542
543/// Canonical MAE reconstruction loss.
544///
545/// Mean squared error averaged over **masked positions only** (visible
546/// positions are intentionally ignored — this is the core MAE insight).
547///
548/// # Errors
549/// - `reconstructed.len() != ground_truth_patches.len()` → `DimensionMismatch`
550/// - lengths not divisible by `n_patches` → `DimensionMismatch`
551/// - any masked index out of range → `Internal`
552/// - empty `masked_ids` → returns `Ok(0.0)`  (no error: with mask_ratio=0
553///   there is nothing to score; downstream code should treat 0 as "no signal")
554pub fn mae_loss(
555    reconstructed: &[f32],
556    ground_truth_patches: &[f32],
557    mask_meta: &MaskMeta,
558) -> VisionResult<f32> {
559    if reconstructed.len() != ground_truth_patches.len() {
560        return Err(VisionError::DimensionMismatch {
561            expected: reconstructed.len(),
562            got: ground_truth_patches.len(),
563        });
564    }
565    let n_patches = mask_meta.visible_ids.len() + mask_meta.masked_ids.len();
566    if n_patches == 0 {
567        return Err(VisionError::EmptyInput("mask_meta n_patches"));
568    }
569    if reconstructed.len() % n_patches != 0 {
570        return Err(VisionError::DimensionMismatch {
571            expected: n_patches,
572            got: reconstructed.len(),
573        });
574    }
575    let pp = reconstructed.len() / n_patches;
576    if mask_meta.masked_ids.is_empty() {
577        return Ok(0.0);
578    }
579    let mut sum_sq = 0.0f64;
580    let mut count: u64 = 0;
581    for &mi in &mask_meta.masked_ids {
582        let r = reconstructed
583            .get(mi * pp..(mi + 1) * pp)
584            .ok_or(VisionError::Internal("loss: masked idx".into()))?;
585        let g = ground_truth_patches
586            .get(mi * pp..(mi + 1) * pp)
587            .ok_or(VisionError::Internal("loss: masked idx (gt)".into()))?;
588        for (rv, gv) in r.iter().zip(g.iter()) {
589            let d = (*rv - *gv) as f64;
590            sum_sq += d * d;
591            count += 1;
592        }
593    }
594    let mean = if count == 0 {
595        0.0
596    } else {
597        sum_sq / (count as f64)
598    };
599    Ok(mean as f32)
600}
601
602// ─── Tests ───────────────────────────────────────────────────────────────────
603
604#[cfg(test)]
605mod tests {
606    use super::*;
607    use std::collections::HashSet;
608
609    fn make_tiny_cfg() -> MaeConfig {
610        // img 8 / patch 4 = 2x2 = 4 patches, in_chans 3, patch_pixels = 48
611        MaeConfig::new(8, 4, 3, 16, 2, 4, 8, 1, 4, 2, 0.5).expect("valid tiny cfg")
612    }
613
614    fn make_medium_cfg() -> MaeConfig {
615        // img 16 / patch 4 = 4x4 = 16 patches, in_chans 3, patch_pixels = 48
616        MaeConfig::new(16, 4, 3, 32, 2, 4, 16, 1, 4, 2, 0.75).expect("valid med cfg")
617    }
618
619    // ── generate_random_mask ──────────────────────────────────────────────────
620
621    #[test]
622    fn mask_union_and_disjoint() {
623        let mut rng = LcgRng::new(1);
624        let n = 16;
625        let m = generate_random_mask(n, 0.5, &mut rng).expect("ok");
626        let v: HashSet<usize> = m.visible_ids.iter().copied().collect();
627        let k: HashSet<usize> = m.masked_ids.iter().copied().collect();
628        assert!(v.is_disjoint(&k));
629        let union: HashSet<usize> = v.union(&k).copied().collect();
630        let expected: HashSet<usize> = (0..n).collect();
631        assert_eq!(union, expected);
632        assert_eq!(v.len() + k.len(), n);
633    }
634
635    #[test]
636    fn mask_count_matches_round() {
637        let mut rng = LcgRng::new(2);
638        // 0.75 * 100 = 75 exactly
639        let m = generate_random_mask(100, 0.75, &mut rng).expect("ok");
640        assert_eq!(m.masked_ids.len(), 75);
641        assert_eq!(m.visible_ids.len(), 25);
642    }
643
644    #[test]
645    fn mask_count_rounds_correctly() {
646        let mut rng = LcgRng::new(3);
647        // 0.7 * 10 = 7 exactly
648        let m = generate_random_mask(10, 0.7, &mut rng).expect("ok");
649        assert_eq!(m.masked_ids.len(), 7);
650    }
651
652    #[test]
653    fn mask_ratio_zero_all_visible() {
654        let mut rng = LcgRng::new(4);
655        let m = generate_random_mask(8, 0.0, &mut rng).expect("ok");
656        assert_eq!(m.masked_ids.len(), 0);
657        assert_eq!(m.visible_ids.len(), 8);
658        assert_eq!(m.visible_ids, (0..8).collect::<Vec<_>>());
659    }
660
661    #[test]
662    fn mask_ratio_one_all_masked() {
663        let mut rng = LcgRng::new(5);
664        let m = generate_random_mask(8, 1.0, &mut rng).expect("ok");
665        assert_eq!(m.masked_ids.len(), 8);
666        assert_eq!(m.visible_ids.len(), 0);
667        assert_eq!(m.masked_ids, (0..8).collect::<Vec<_>>());
668    }
669
670    #[test]
671    fn mask_deterministic_same_seed() {
672        let mut a = LcgRng::new(42);
673        let mut b = LcgRng::new(42);
674        let ma = generate_random_mask(64, 0.75, &mut a).expect("ok");
675        let mb = generate_random_mask(64, 0.75, &mut b).expect("ok");
676        assert_eq!(ma, mb);
677    }
678
679    #[test]
680    fn mask_sorted_ascending() {
681        let mut rng = LcgRng::new(6);
682        let m = generate_random_mask(50, 0.6, &mut rng).expect("ok");
683        for w in m.visible_ids.windows(2) {
684            assert!(w[0] < w[1]);
685        }
686        for w in m.masked_ids.windows(2) {
687            assert!(w[0] < w[1]);
688        }
689    }
690
691    #[test]
692    fn mask_invalid_ratio_errors() {
693        let mut rng = LcgRng::new(7);
694        assert!(generate_random_mask(8, -0.1, &mut rng).is_err());
695        assert!(generate_random_mask(8, 1.5, &mut rng).is_err());
696        assert!(generate_random_mask(8, f32::NAN, &mut rng).is_err());
697    }
698
699    #[test]
700    fn mask_n_patches_zero_errors() {
701        let mut rng = LcgRng::new(8);
702        let r = generate_random_mask(0, 0.5, &mut rng);
703        assert!(matches!(r, Err(VisionError::EmptyInput(_))));
704    }
705
706    // ── Mae construction / config ─────────────────────────────────────────────
707
708    #[test]
709    fn cfg_patch_not_divisible_errors() {
710        let r = MaeConfig::new(7, 4, 3, 16, 1, 4, 8, 1, 4, 2, 0.5);
711        assert!(matches!(r, Err(VisionError::InvalidPatchSize { .. })));
712    }
713
714    #[test]
715    fn cfg_zero_channels_errors() {
716        let r = MaeConfig::new(8, 4, 0, 16, 1, 4, 8, 1, 4, 2, 0.5);
717        assert!(r.is_err());
718    }
719
720    #[test]
721    fn cfg_zero_encoder_dim_errors() {
722        let r = MaeConfig::new(8, 4, 3, 0, 1, 4, 8, 1, 4, 2, 0.5);
723        assert!(matches!(r, Err(VisionError::InvalidEmbedDim(0))));
724    }
725
726    #[test]
727    fn cfg_zero_decoder_dim_errors() {
728        let r = MaeConfig::new(8, 4, 3, 16, 1, 4, 0, 1, 4, 2, 0.5);
729        assert!(matches!(r, Err(VisionError::InvalidEmbedDim(0))));
730    }
731
732    #[test]
733    fn cfg_zero_depth_errors() {
734        let r1 = MaeConfig::new(8, 4, 3, 16, 0, 4, 8, 1, 4, 2, 0.5);
735        let r2 = MaeConfig::new(8, 4, 3, 16, 1, 4, 8, 0, 4, 2, 0.5);
736        assert!(r1.is_err());
737        assert!(r2.is_err());
738    }
739
740    #[test]
741    fn cfg_mask_ratio_out_of_range_errors() {
742        let r1 = MaeConfig::new(8, 4, 3, 16, 1, 4, 8, 1, 4, 2, -0.1);
743        let r2 = MaeConfig::new(8, 4, 3, 16, 1, 4, 8, 1, 4, 2, 1.5);
744        assert!(r1.is_err());
745        assert!(r2.is_err());
746    }
747
748    #[test]
749    fn cfg_n_patches_and_pixels() {
750        let cfg = make_tiny_cfg();
751        assert_eq!(cfg.n_patches(), 4);
752        assert_eq!(cfg.patch_pixels(), 4 * 4 * 3);
753    }
754
755    // ── encode / decode shapes ────────────────────────────────────────────────
756
757    #[test]
758    fn encode_shape() {
759        let cfg = make_medium_cfg();
760        let mut rng = LcgRng::new(11);
761        let mae = Mae::new(cfg.clone(), &mut rng).expect("ok");
762        let n_patches = cfg.n_patches();
763        let pp = cfg.patch_pixels();
764        let edim = cfg.encoder_dim;
765        let patches = vec![0.1f32; n_patches * pp];
766        let mut rng2 = LcgRng::new(99);
767        let (enc, mask) = mae.encode(&patches, &mut rng2).expect("ok");
768        assert_eq!(enc.len(), mask.visible_ids.len() * edim);
769    }
770
771    #[test]
772    fn decode_shape_matches_patches() {
773        let cfg = make_medium_cfg();
774        let mut rng = LcgRng::new(13);
775        let mae = Mae::new(cfg.clone(), &mut rng).expect("ok");
776        let n_patches = cfg.n_patches();
777        let pp = cfg.patch_pixels();
778        let patches = vec![0.1f32; n_patches * pp];
779        let mut rng2 = LcgRng::new(101);
780        let (enc, mask) = mae.encode(&patches, &mut rng2).expect("ok");
781        let recon = mae.decode(&enc, &mask).expect("ok");
782        assert_eq!(recon.len(), n_patches * pp);
783    }
784
785    #[test]
786    fn full_pipeline_deterministic_same_seed() {
787        let cfg = make_medium_cfg();
788        let mut rng_a = LcgRng::new(33);
789        let mae_a = Mae::new(cfg.clone(), &mut rng_a).expect("ok");
790        let mut rng_b = LcgRng::new(33);
791        let mae_b = Mae::new(cfg.clone(), &mut rng_b).expect("ok");
792
793        let n_patches = cfg.n_patches();
794        let pp = cfg.patch_pixels();
795        let mut patches = vec![0.0f32; n_patches * pp];
796        let mut rin = LcgRng::new(5);
797        for v in patches.iter_mut() {
798            *v = (rin.next_u32() as f32) / 4_294_967_296.0;
799        }
800
801        let mut r_a = LcgRng::new(77);
802        let mut r_b = LcgRng::new(77);
803        let (ea, ma) = mae_a.encode(&patches, &mut r_a).expect("ok");
804        let (eb, mb) = mae_b.encode(&patches, &mut r_b).expect("ok");
805        assert_eq!(ma, mb);
806        for (a, b) in ea.iter().zip(eb.iter()) {
807            assert!((a - b).abs() < 1e-6, "encode differs: {a} vs {b}");
808        }
809        let recon_a = mae_a.decode(&ea, &ma).expect("ok");
810        let recon_b = mae_b.decode(&eb, &mb).expect("ok");
811        for (a, b) in recon_a.iter().zip(recon_b.iter()) {
812            assert!((a - b).abs() < 1e-6, "decode differs: {a} vs {b}");
813        }
814    }
815
816    // ── encode / decode error paths ───────────────────────────────────────────
817
818    #[test]
819    fn encode_dimension_mismatch_errors() {
820        let cfg = make_tiny_cfg();
821        let mut rng = LcgRng::new(15);
822        let mae = Mae::new(cfg.clone(), &mut rng).expect("ok");
823        // wrong size: 1 patch short
824        let pp = cfg.patch_pixels();
825        let patches = vec![0.0f32; (cfg.n_patches() - 1) * pp];
826        let mut rng2 = LcgRng::new(16);
827        let r = mae.encode(&patches, &mut rng2);
828        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
829    }
830
831    #[test]
832    fn decode_wrong_visible_length_errors() {
833        let cfg = make_tiny_cfg();
834        let mut rng = LcgRng::new(17);
835        let mae = Mae::new(cfg.clone(), &mut rng).expect("ok");
836        // Build a valid mask, then truncate features.
837        let mut rng_m = LcgRng::new(18);
838        let mask = generate_random_mask(cfg.n_patches(), 0.5, &mut rng_m).expect("ok");
839        let wrong = vec![0.0f32; mask.visible_ids.len() * cfg.encoder_dim - 1];
840        let r = mae.decode(&wrong, &mask);
841        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
842    }
843
844    // ── mae_loss ──────────────────────────────────────────────────────────────
845
846    #[test]
847    fn loss_zero_when_match_on_masked_positions() {
848        // Ground truth and reconstructed only need to agree on masked positions.
849        let mask = MaskMeta {
850            visible_ids: vec![0, 2],
851            masked_ids: vec![1, 3],
852        };
853        let pp = 5;
854        let n_patches = 4;
855        let mut gt = vec![0.0f32; n_patches * pp];
856        let mut recon = vec![0.0f32; n_patches * pp];
857        for (i, g) in gt.iter_mut().enumerate() {
858            *g = i as f32;
859        }
860        // Match at masked positions (1, 3):
861        for &mi in &mask.masked_ids {
862            for k in 0..pp {
863                recon[mi * pp + k] = gt[mi * pp + k];
864            }
865        }
866        // Differ on visible positions (0, 2):
867        for k in 0..pp {
868            recon[k] = 999.0;
869            recon[2 * pp + k] = -777.0;
870        }
871        let loss = mae_loss(&recon, &gt, &mask).expect("ok");
872        assert!(
873            loss.abs() < 1e-6,
874            "loss should be 0 when masked match: {loss}"
875        );
876    }
877
878    #[test]
879    fn loss_independent_of_visible_positions() {
880        let mask = MaskMeta {
881            visible_ids: vec![0, 2],
882            masked_ids: vec![1, 3],
883        };
884        let pp = 3;
885        let n_patches = 4;
886        let mut gt = vec![0.0f32; n_patches * pp];
887        let mut recon_a = vec![0.0f32; n_patches * pp];
888        let mut recon_b = vec![0.0f32; n_patches * pp];
889        for i in 0..n_patches * pp {
890            gt[i] = (i as f32) * 0.1;
891            recon_a[i] = gt[i] + 0.5; // off everywhere
892            recon_b[i] = gt[i] + 0.5;
893        }
894        // Now alter ONLY visible positions in recon_b drastically:
895        for &vi in &mask.visible_ids {
896            for k in 0..pp {
897                recon_b[vi * pp + k] = 1234.0;
898            }
899        }
900        let la = mae_loss(&recon_a, &gt, &mask).expect("ok");
901        let lb = mae_loss(&recon_b, &gt, &mask).expect("ok");
902        assert!(
903            (la - lb).abs() < 1e-6,
904            "loss depends on visible: {la} vs {lb}"
905        );
906    }
907
908    #[test]
909    fn loss_dimension_mismatch_errors() {
910        let mask = MaskMeta {
911            visible_ids: vec![0],
912            masked_ids: vec![1],
913        };
914        let r = mae_loss(&[0.0; 4], &[0.0; 5], &mask);
915        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
916    }
917
918    #[test]
919    fn loss_mask_ratio_zero_returns_zero() {
920        let mut rng = LcgRng::new(21);
921        let m = generate_random_mask(6, 0.0, &mut rng).expect("ok");
922        let r = vec![1.0f32; 6 * 4];
923        let g = vec![2.0f32; 6 * 4];
924        let l = mae_loss(&r, &g, &m).expect("ok");
925        assert!(l.abs() < 1e-6, "no masked → loss = 0; got {l}");
926    }
927
928    #[test]
929    fn loss_positive_when_recon_off() {
930        let mask = MaskMeta {
931            visible_ids: vec![0],
932            masked_ids: vec![1],
933        };
934        let pp = 4;
935        let gt = vec![0.0f32; 2 * pp];
936        let mut recon = vec![0.0f32; 2 * pp];
937        for k in 0..pp {
938            recon[pp + k] = 1.0; // squared error of 1 per element
939        }
940        let l = mae_loss(&recon, &gt, &mask).expect("ok");
941        assert!((l - 1.0).abs() < 1e-6, "expected MSE=1, got {l}");
942    }
943
944    // ── encode produces finite outputs ────────────────────────────────────────
945
946    #[test]
947    fn encode_decode_finite() {
948        let cfg = make_medium_cfg();
949        let mut rng = LcgRng::new(45);
950        let mae = Mae::new(cfg.clone(), &mut rng).expect("ok");
951        let n_patches = cfg.n_patches();
952        let pp = cfg.patch_pixels();
953        let mut patches = vec![0.0f32; n_patches * pp];
954        let mut rin = LcgRng::new(55);
955        rin.fill_normal(&mut patches);
956        let mut r2 = LcgRng::new(66);
957        let (enc, mask) = mae.encode(&patches, &mut r2).expect("ok");
958        assert!(enc.iter().all(|v| v.is_finite()));
959        let recon = mae.decode(&enc, &mask).expect("ok");
960        assert!(recon.iter().all(|v| v.is_finite()));
961    }
962
963    // ── identity-ish weight round trip ────────────────────────────────────────
964
965    #[test]
966    fn identity_decoder_reconstructs_mask_token_at_masked() {
967        // Hand-craft the decoder so that, for any visible token, the value of
968        // the reconstructed patch at MASKED positions depends ONLY on
969        // mask_token and decoder_pos_embed, not on visible content. We do this
970        // by:
971        //   1. Setting decoder Q/K/V projections to zero → MHSA output is 0.
972        //   2. Setting MLP projections to zero → MLP residual contributes 0.
973        //   3. Pre-norm LN1/LN2 weights to 1 / biases to 0 (default).
974        //   4. final layer norm gamma=1, beta=0.
975        //   5. decoder_pred_weights = identity over decoder_dim → patch_pixels
976        //      identity when patch_pixels == decoder_dim, biases zero.
977        //   6. decoder_pos_embed = 0 to eliminate its contribution.
978        //
979        // Then at any masked position i: input to decoder is mask_token; after
980        // all the (post-norm) zero residuals it remains mask_token; LN
981        // normalises it; the linear projection passes it through. We check
982        // shape and that the masked-position outputs are independent of the
983        // visible token content by perturbing the visible input.
984        //
985        // To make the projection "identity-ish" cleanly, pick patch_pixels =
986        // decoder_dim = 4. So img_size = 2, patch_size = 2, in_channels = 1.
987        let mut cfg = MaeConfig::new(2, 2, 1, 4, 1, 1, 4, 1, 1, 1, 0.5).expect("ok");
988        cfg.mask_ratio = 0.5;
989        let mut rng = LcgRng::new(123);
990        let mut mae = Mae::new(cfg.clone(), &mut rng).expect("ok");
991
992        // Zero out all decoder block weights (Q/K/V, output, MLP1, MLP2)
993        for block in mae.decoder_blocks.iter_mut() {
994            for v in block.weights.qkv_weight.iter_mut() {
995                *v = 0.0;
996            }
997            for v in block.weights.qkv_bias.iter_mut() {
998                *v = 0.0;
999            }
1000            for v in block.weights.out_weight.iter_mut() {
1001                *v = 0.0;
1002            }
1003            for v in block.weights.out_bias.iter_mut() {
1004                *v = 0.0;
1005            }
1006            for v in block.weights.mlp1_weight.iter_mut() {
1007                *v = 0.0;
1008            }
1009            for v in block.weights.mlp1_bias.iter_mut() {
1010                *v = 0.0;
1011            }
1012            for v in block.weights.mlp2_weight.iter_mut() {
1013                *v = 0.0;
1014            }
1015            for v in block.weights.mlp2_bias.iter_mut() {
1016                *v = 0.0;
1017            }
1018            // LayerNorms: gamma=1, beta=0 already.
1019        }
1020        // Final norm: gamma=1, beta=0 already.
1021        // decoder_pos_embed → 0 (no positional contribution at masked sites)
1022        for v in mae.decoder_pos_embed.iter_mut() {
1023            *v = 0.0;
1024        }
1025        // decoder_pred = identity (patch_pixels == decoder_dim == 4)
1026        for v in mae.decoder_pred_weights.iter_mut() {
1027            *v = 0.0;
1028        }
1029        for i in 0..4 {
1030            mae.decoder_pred_weights[i * 4 + i] = 1.0;
1031        }
1032        for v in mae.decoder_pred_bias.iter_mut() {
1033            *v = 0.0;
1034        }
1035        // mask_token set to a known value
1036        mae.mask_token = vec![0.1, -0.2, 0.3, -0.4];
1037
1038        // Run encode with TWO different visible-token contents but the SAME
1039        // mask (achieved by reusing the same RNG seed for encode).
1040        let n_patches = cfg.n_patches();
1041        let pp = cfg.patch_pixels();
1042
1043        let patches_a = vec![1.0f32; n_patches * pp];
1044        let mut patches_b = vec![1.0f32; n_patches * pp];
1045        // Different content at every pixel:
1046        for v in patches_b.iter_mut() {
1047            *v = 7.7;
1048        }
1049
1050        let mut r_a = LcgRng::new(2024);
1051        let mut r_b = LcgRng::new(2024);
1052        let (enc_a, ma) = mae.encode(&patches_a, &mut r_a).expect("ok");
1053        let (enc_b, mb) = mae.encode(&patches_b, &mut r_b).expect("ok");
1054        assert_eq!(ma, mb, "same RNG seed must produce same mask");
1055
1056        let recon_a = mae.decode(&enc_a, &ma).expect("ok");
1057        let recon_b = mae.decode(&enc_b, &mb).expect("ok");
1058        // At masked positions, reconstruction must be identical (zero
1059        // attention prevents visible info from reaching mask sites, and zero
1060        // decoder_pos_embed eliminates positional bias).
1061        //
1062        // Compute LN(mask_token) for the expected value: with 4 elements
1063        // (0.1, -0.2, 0.3, -0.4): mean=-0.05, centred=(0.15,-0.15,0.35,-0.35).
1064        // var = (.0225 + .0225 + .1225 + .1225)/4 = .0725
1065        // inv_std = 1/sqrt(.0725 + 1e-5) ≈ 3.71...
1066        let mean = (0.1f32 + (-0.2) + 0.3 + (-0.4)) / 4.0;
1067        let centred = [0.1f32 - mean, -0.2 - mean, 0.3 - mean, -0.4 - mean];
1068        let var = centred.iter().map(|c| c * c).sum::<f32>() / 4.0;
1069        let inv_std = 1.0 / (var + 1e-5).sqrt();
1070        let expected_at_mask: Vec<f32> = centred.iter().map(|c| c * inv_std).collect();
1071
1072        for &mi in &ma.masked_ids {
1073            for k in 0..pp {
1074                let a = recon_a[mi * pp + k];
1075                let b = recon_b[mi * pp + k];
1076                assert!(
1077                    (a - b).abs() < 1e-5,
1078                    "masked pos {mi} k={k}: a={a} b={b} (depends on visible!)"
1079                );
1080                let exp = expected_at_mask[k];
1081                assert!(
1082                    (a - exp).abs() < 1e-4,
1083                    "masked pos {mi} k={k}: got {a} expected {exp}"
1084                );
1085            }
1086        }
1087    }
1088
1089    // ── extra coverage ────────────────────────────────────────────────────────
1090
1091    #[test]
1092    fn mask_full_ratio_encoder_skipped() {
1093        // mask_ratio == 1: no visible tokens — encoder returns empty Vec
1094        // without panicking and decode reconstructs all-mask_token paths.
1095        let cfg = MaeConfig::new(4, 2, 1, 4, 1, 1, 4, 1, 1, 1, 1.0).expect("ok");
1096        let mut rng = LcgRng::new(31);
1097        let mae = Mae::new(cfg.clone(), &mut rng).expect("ok");
1098        let pp = cfg.patch_pixels();
1099        let n = cfg.n_patches();
1100        let patches = vec![0.0f32; n * pp];
1101        let mut r2 = LcgRng::new(32);
1102        let (enc, mask) = mae.encode(&patches, &mut r2).expect("ok");
1103        assert_eq!(enc.len(), 0);
1104        assert_eq!(mask.masked_ids.len(), n);
1105        let recon = mae.decode(&enc, &mask).expect("ok");
1106        assert_eq!(recon.len(), n * pp);
1107        assert!(recon.iter().all(|v| v.is_finite()));
1108    }
1109
1110    #[test]
1111    fn mask_zero_ratio_full_encoder() {
1112        let cfg = MaeConfig::new(4, 2, 1, 4, 1, 1, 4, 1, 1, 1, 0.0).expect("ok");
1113        let mut rng = LcgRng::new(41);
1114        let mae = Mae::new(cfg.clone(), &mut rng).expect("ok");
1115        let pp = cfg.patch_pixels();
1116        let n = cfg.n_patches();
1117        let patches = vec![0.1f32; n * pp];
1118        let mut r2 = LcgRng::new(42);
1119        let (enc, mask) = mae.encode(&patches, &mut r2).expect("ok");
1120        assert_eq!(mask.masked_ids.len(), 0);
1121        assert_eq!(mask.visible_ids.len(), n);
1122        assert_eq!(enc.len(), n * cfg.encoder_dim);
1123    }
1124}