// oxicuda_ssl/masked/beit.rs

//! BEiT — Bao et al. 2021 — BERT Pre-Training of Image Transformers.
//!
//! Key insight: instead of predicting raw pixel values (like MAE / SimMIM),
//! BEiT treats masked image modeling as **discrete token prediction**.  A
//! VQ-VAE / dVAE tokenizer maps each patch to a codebook index; the ViT is
//! then trained to predict those discrete tokens at masked positions via
//! cross-entropy — analogous to BERT's masked word prediction in NLP.
//!
//! # Components
//!
//! - [`VqCodebook`] — EMA-maintained vector-quantisation codebook.
//! - [`BeitConfig`] — hyper-parameters for tokenizer + pretraining loss.
//! - [`BeitResult`] — BEiT cross-entropy loss, a VQ-loss slot, and codebook metrics.
//! - [`vq_codebook_init`] — random initialisation of the codebook.
//! - [`vq_encode`] — nearest-neighbour encoding + VQ loss (forward pass only).
//! - [`vq_update_codebook`] — EMA codebook update from assigned embeddings.
//! - [`beit_loss`] — cross-entropy of student logits vs. discrete VQ tokens.
//! - [`beit_block_mask`] — BEiT-style random rectangular block masking.
//!
//! # References
//! - Bao et al. *BEiT: BERT Pre-Training of Image Transformers* (ICLR 2022)
//! - van den Oord et al. *Neural Discrete Representation Learning* (NeurIPS 2017)

24use crate::error::{SslError, SslResult};
25use crate::handle::LcgRng;
26
// ─── Configuration ────────────────────────────────────────────────────────────

/// Hyper-parameters shared by the BEiT tokenizer (VQ codebook) and the
/// masked-token pretraining loss.
///
/// Build a checked instance with [`BeitConfig::new`], or start from
/// [`BeitConfig::default`] and override individual fields.
#[derive(Clone, Debug)]
pub struct BeitConfig {
    /// Size K of the visual-token vocabulary (number of codebook entries).
    /// Default: 8192.
    pub n_codes: usize,
    /// Dimensionality C of every code vector. Default: 256.
    pub code_dim: usize,
    /// Fraction of patches hidden during pre-training. Default: 0.4.
    pub mask_ratio: f32,
    /// Momentum of the exponential moving average that updates the codebook.
    /// Default: 0.999.
    pub ema_momentum: f32,
    /// Weight β of the commitment term in the VQ loss. Default: 0.25.
    pub commitment_weight: f32,
    /// Temperature τ dividing the student logits before the softmax.
    /// Default: 1.0.
    pub temperature: f32,
    /// Small positive constant for numerical stability. Default: 1e-6.
    pub eps: f32,
}

48impl Default for BeitConfig {
49    fn default() -> Self {
50        Self {
51            n_codes: 8192,
52            code_dim: 256,
53            mask_ratio: 0.4,
54            ema_momentum: 0.999,
55            commitment_weight: 0.25,
56            temperature: 1.0,
57            eps: 1e-6,
58        }
59    }
60}
61
62impl BeitConfig {
63    /// Validated BEiT configuration constructor.
64    ///
65    /// # Errors
66    /// - [`SslError::InvalidParameter`] for any out-of-range or zero value.
67    /// - [`SslError::InvalidMaskRatio`] when `mask_ratio ∉ [0, 1)`.
68    /// - [`SslError::InvalidTemperature`] when `temperature ≤ 0`.
69    pub fn new(
70        n_codes: usize,
71        code_dim: usize,
72        mask_ratio: f32,
73        ema_momentum: f32,
74        commitment_weight: f32,
75        temperature: f32,
76        eps: f32,
77    ) -> SslResult<Self> {
78        if n_codes == 0 {
79            return Err(SslError::InvalidParameter {
80                name: "n_codes".into(),
81                reason: "must be > 0".into(),
82            });
83        }
84        if code_dim == 0 {
85            return Err(SslError::InvalidParameter {
86                name: "code_dim".into(),
87                reason: "must be > 0".into(),
88            });
89        }
90        if !(mask_ratio.is_finite() && (0.0..1.0).contains(&mask_ratio)) {
91            return Err(SslError::InvalidMaskRatio { ratio: mask_ratio });
92        }
93        if !(ema_momentum.is_finite() && (0.0..=1.0).contains(&ema_momentum)) {
94            return Err(SslError::InvalidMomentum {
95                momentum: ema_momentum,
96            });
97        }
98        if !(commitment_weight.is_finite() && commitment_weight >= 0.0) {
99            return Err(SslError::InvalidParameter {
100                name: "commitment_weight".into(),
101                reason: "must be finite and >= 0".into(),
102            });
103        }
104        if !(temperature.is_finite() && temperature > 0.0) {
105            return Err(SslError::InvalidTemperature { temp: temperature });
106        }
107        if !(eps.is_finite() && eps > 0.0) {
108            return Err(SslError::InvalidParameter {
109                name: "eps".into(),
110                reason: "must be finite and > 0".into(),
111            });
112        }
113        Ok(Self {
114            n_codes,
115            code_dim,
116            mask_ratio,
117            ema_momentum,
118            commitment_weight,
119            temperature,
120            eps,
121        })
122    }
123}
124
// ─── VQ Codebook ─────────────────────────────────────────────────────────────

/// EMA-maintained vector-quantisation codebook.
///
/// Stores K code vectors of dimension C as a flat `[K × C]` row-major buffer,
/// together with the exponential-moving-average statistics that
/// [`vq_update_codebook`] maintains:
///
/// ```text
///     n_k   ← m·n_k   + (1-m)·|S_k|          (EMA of cluster sizes)
///     sum_k ← m·sum_k + (1-m)·Σ_{z∈S_k} z    (EMA of cluster sums)
///     e_k   ← sum_k / n_k                     (code = running mean)
/// ```
///
/// Here `S_k` is the set of batch embeddings assigned to code `k`, and `m` is
/// `ema_momentum`.
///
/// All fields are public. Callers that edit them directly must keep
/// `embeddings.len() == ema_sum.len() == n_codes * code_dim` and
/// `ema_counts.len() == n_codes`.
#[derive(Debug, Clone)]
pub struct VqCodebook {
    /// Codebook entries E ∈ ℝ^{K×C}, row-major. Length = K * C.
    pub embeddings: Vec<f32>,
    /// Number of codes K.
    pub n_codes: usize,
    /// Code dimensionality C.
    pub code_dim: usize,
    /// EMA momentum m for the codebook update (closer to 1 → slower update).
    pub ema_momentum: f32,
    /// Commitment loss weight β used by [`vq_encode`].
    pub commitment_weight: f32,
    /// EMA usage counts per code, `[K]`. [`vq_codebook_init`] starts them at 1,
    /// so each code is initially the mean of a single pseudo-observation:
    /// itself.
    pub ema_counts: Vec<f32>,
    /// EMA running sums per code, `[K × C]` row-major. [`vq_codebook_init`]
    /// starts them equal to `embeddings`.
    pub ema_sum: Vec<f32>,
}

156// ─── Initialisation ───────────────────────────────────────────────────────────
157
158/// Initialise a [`VqCodebook`] with random N(0, 1/√C) entries and EMA state.
159///
160/// The codes are drawn from N(0, 1) and scaled by 1/√`code_dim` so that
161/// initialised norms are ≈ 1, preventing large initial VQ losses.
162///
163/// # Errors
164/// - [`SslError::InvalidParameter`] when `n_codes == 0` or `code_dim == 0`.
165pub fn vq_codebook_init(
166    n_codes: usize,
167    code_dim: usize,
168    rng: &mut LcgRng,
169) -> SslResult<VqCodebook> {
170    if n_codes == 0 {
171        return Err(SslError::InvalidParameter {
172            name: "n_codes".into(),
173            reason: "must be > 0".into(),
174        });
175    }
176    if code_dim == 0 {
177        return Err(SslError::InvalidParameter {
178            name: "code_dim".into(),
179            reason: "must be > 0".into(),
180        });
181    }
182    let total = n_codes * code_dim;
183    let mut embeddings = vec![0.0_f32; total];
184    rng.fill_normal(&mut embeddings);
185    let scale = 1.0 / (code_dim as f32).sqrt();
186    for v in &mut embeddings {
187        *v *= scale;
188    }
189    // EMA state: counts start at 1 (Laplace), sums start equal to the codes.
190    let ema_counts = vec![1.0_f32; n_codes];
191    let ema_sum = embeddings.clone();
192    Ok(VqCodebook {
193        embeddings,
194        n_codes,
195        code_dim,
196        ema_momentum: 0.999,
197        commitment_weight: 0.25,
198        ema_counts,
199        ema_sum,
200    })
201}
202
203// ─── Encoding ─────────────────────────────────────────────────────────────────
204
205/// Encode patch embeddings to nearest codebook indices + straight-through VQ loss.
206///
207/// For each embedding `z_i ∈ ℝ^C`, finds the nearest code via brute-force
208/// L2 distance: `k* = argmin_k ||z_i - e_k||²`.  Returns:
209/// - `indices`: `[N]` integer codebook assignments.
210/// - `quantized_z`: `[N × C]` quantised embeddings (with straight-through
211///   gradient: `z_q = z + sg(e_{k*} - z)`, implemented here as `e_{k*}`
212///   since we are forward-only).
213/// - `vq_loss`: scalar combining codebook loss + β·commitment loss.
214///
215/// VQ loss formula:
216/// ```text
217///     L_vq = mean_i [ ||sg(z_i) - e_{k*}||² + β·||z_i - sg(e_{k*})||² ]
218/// ```
219/// Both terms equal `||z_i - e_{k*}||²` in the forward pass (no gradients
220/// here); we weight the second by `commitment_weight`.
221///
222/// # Errors
223/// - [`SslError::EmptyInput`] when `n_patches == 0` or `code_dim == 0`.
224/// - [`SslError::DimensionMismatch`] when `embeddings.len() != n_patches * code_dim`.
225pub fn vq_encode(
226    codebook: &VqCodebook,
227    embeddings: &[f32],
228    n_patches: usize,
229    code_dim: usize,
230) -> SslResult<(Vec<usize>, Vec<f32>, f32)> {
231    if n_patches == 0 || code_dim == 0 {
232        return Err(SslError::EmptyInput);
233    }
234    let expected = n_patches * code_dim;
235    if embeddings.len() != expected {
236        return Err(SslError::DimensionMismatch {
237            expected,
238            got: embeddings.len(),
239        });
240    }
241    if codebook.n_codes == 0 {
242        return Err(SslError::EmptyInput);
243    }
244
245    let k = codebook.n_codes;
246    let c = code_dim;
247    let beta = codebook.commitment_weight;
248
249    let mut indices = Vec::with_capacity(n_patches);
250    let mut quantized_z = Vec::with_capacity(n_patches * c);
251    let mut vq_loss_acc = 0.0_f64;
252
253    for i in 0..n_patches {
254        let z = &embeddings[i * c..(i + 1) * c];
255
256        // Brute-force nearest-neighbour search: O(K·C) per embedding.
257        let mut best_k = 0usize;
258        let mut best_dist = f64::MAX;
259
260        for ki in 0..k {
261            let e_k = &codebook.embeddings[ki * c..(ki + 1) * c];
262            let dist: f64 = z
263                .iter()
264                .zip(e_k.iter())
265                .map(|(&zi, &eki)| {
266                    let d = (zi - eki) as f64;
267                    d * d
268                })
269                .sum();
270            if dist < best_dist {
271                best_dist = dist;
272                best_k = ki;
273            }
274        }
275
276        indices.push(best_k);
277
278        // Quantised embedding = nearest code (straight-through in forward).
279        let e_star = &codebook.embeddings[best_k * c..(best_k + 1) * c];
280        quantized_z.extend_from_slice(e_star);
281
282        // VQ loss: codebook term ||sg(z) - e_{k*}||² + β·||z - sg(e_{k*})||²
283        // Both equal best_dist in forward; we apply the β weight to the
284        // commitment (encoder) term.
285        vq_loss_acc += best_dist * (1.0 + beta as f64);
286    }
287
288    let vq_loss = (vq_loss_acc / n_patches as f64) as f32;
289    Ok((indices, quantized_z, vq_loss))
290}
291
292// ─── Codebook update ─────────────────────────────────────────────────────────
293
294/// EMA update of the codebook using the embeddings assigned to each code.
295///
296/// Implements:
297/// ```text
298///     n_k ← m·n_k + (1-m)·|S_k|          (EMA of cluster sizes)
299///     sum_k ← m·sum_k + (1-m)·Σ_{z∈S_k} z  (EMA of cluster sums)
300///     e_k ← sum_k / n_k                     (normalised code vector)
301/// ```
302/// Codes that receive no assignments in this batch are left unchanged
303/// (their counts and sums are decayed by momentum only).
304///
305/// # Errors
306/// - [`SslError::EmptyInput`] when `n_patches == 0`.
307/// - [`SslError::DimensionMismatch`] when slice lengths are inconsistent.
308pub fn vq_update_codebook(
309    codebook: &mut VqCodebook,
310    embeddings: &[f32],
311    indices: &[usize],
312    n_patches: usize,
313) -> SslResult<()> {
314    if n_patches == 0 {
315        return Err(SslError::EmptyInput);
316    }
317    let c = codebook.code_dim;
318    let k = codebook.n_codes;
319    let expected_emb = n_patches * c;
320    if embeddings.len() != expected_emb {
321        return Err(SslError::DimensionMismatch {
322            expected: expected_emb,
323            got: embeddings.len(),
324        });
325    }
326    if indices.len() != n_patches {
327        return Err(SslError::DimensionMismatch {
328            expected: n_patches,
329            got: indices.len(),
330        });
331    }
332    // Validate index range.
333    for &idx in indices {
334        if idx >= k {
335            return Err(SslError::InvalidParameter {
336                name: "index".into(),
337                reason: format!("codebook index {idx} out of range [0, {k})"),
338            });
339        }
340    }
341
342    let m = codebook.ema_momentum;
343    let one_minus_m = 1.0 - m;
344
345    // Accumulate per-code batch statistics.
346    let mut batch_counts = vec![0.0_f32; k];
347    let mut batch_sums = vec![0.0_f32; k * c];
348
349    for (i, &ki) in indices.iter().enumerate() {
350        batch_counts[ki] += 1.0;
351        let z = &embeddings[i * c..(i + 1) * c];
352        let sum_slice = &mut batch_sums[ki * c..(ki + 1) * c];
353        for (s, &zi) in sum_slice.iter_mut().zip(z.iter()) {
354            *s += zi;
355        }
356    }
357
358    // EMA update of counts and sums, then re-normalise codebook entries.
359    for ki in 0..k {
360        codebook.ema_counts[ki] = m * codebook.ema_counts[ki] + one_minus_m * batch_counts[ki];
361        let count = codebook.ema_counts[ki].max(1e-6); // avoid div-by-zero
362        let sum_slice = &mut codebook.ema_sum[ki * c..(ki + 1) * c];
363        let batch_sum_slice = &batch_sums[ki * c..(ki + 1) * c];
364        for (s, &bs) in sum_slice.iter_mut().zip(batch_sum_slice.iter()) {
365            *s = m * (*s) + one_minus_m * bs;
366        }
367        // Normalise to get the updated code vector.
368        let inv_count = 1.0 / count;
369        let emb_slice = &mut codebook.embeddings[ki * c..(ki + 1) * c];
370        let ema_sum_slice = &codebook.ema_sum[ki * c..(ki + 1) * c];
371        for (e, &es) in emb_slice.iter_mut().zip(ema_sum_slice.iter()) {
372            *e = es * inv_count;
373        }
374    }
375
376    Ok(())
377}
378
// ─── BEiT pretraining loss ────────────────────────────────────────────────────

/// Loss terms and codebook diagnostics produced by [`beit_loss`].
#[derive(Debug, Clone)]
pub struct BeitResult {
    /// Mean cross-entropy of the student logits against the discrete VQ tokens,
    /// over masked positions only. It is 0 when nothing is masked.
    pub beit_loss: f32,
    /// Slot for the VQ loss (codebook term + β·commitment term). [`beit_loss`]
    /// has no access to the encoder embeddings and sets this to 0. Callers add
    /// the loss returned by [`vq_encode`] themselves.
    pub vq_loss: f32,
    /// `beit_loss + vq_loss`.
    pub total_loss: f32,
    /// Number of masked patches (positions where the cross-entropy was computed).
    pub n_masked: usize,
    /// Fraction of codebook entries that occur at least once among the token
    /// indices of *all* patches in the batch, masked or not (in `[0, 1]`).
    pub codebook_usage: f32,
    /// Perplexity `exp(H)` of the empirical token distribution over all
    /// patches, clamped to `[1, K]`.
    pub perplexity: f32,
}

398/// BEiT pretraining cross-entropy loss.
399///
400/// Computes:
401/// ```text
402///     L = -1/M  Σ_{i: mask[i]=true}  log softmax(p_i / τ)[q_i]
403/// ```
404/// where `p_i ∈ ℝ^K` are the student's unnormalized logits for patch `i`,
405/// `q_i` is the VQ codebook index assigned by the tokenizer, `τ` is the
406/// softmax temperature, and `M` is the number of masked patches.
407///
408/// When no patches are masked (`mask` is all `false`), returns
409/// `BeitResult { beit_loss: 0, vq_loss, total_loss: vq_loss, n_masked: 0, .. }`.
410///
411/// # Errors
412/// - [`SslError::InvalidParameter`] when `n_codes == 0`.
413/// - [`SslError::EmptyInput`] when `n_patches == 0`.
414/// - [`SslError::DimensionMismatch`] when slice lengths are inconsistent.
415/// - [`SslError::InvalidTemperature`] when `config.temperature ≤ 0`.
416pub fn beit_loss(
417    student_logits: &[f32],
418    token_indices: &[usize],
419    mask: &[bool],
420    n_patches: usize,
421    n_codes: usize,
422    config: &BeitConfig,
423) -> SslResult<BeitResult> {
424    if n_codes == 0 {
425        return Err(SslError::InvalidParameter {
426            name: "n_codes".into(),
427            reason: "must be > 0".into(),
428        });
429    }
430    if n_patches == 0 {
431        return Err(SslError::EmptyInput);
432    }
433    if !(config.temperature.is_finite() && config.temperature > 0.0) {
434        return Err(SslError::InvalidTemperature {
435            temp: config.temperature,
436        });
437    }
438
439    let expected_logits = n_patches * n_codes;
440    if student_logits.len() != expected_logits {
441        return Err(SslError::DimensionMismatch {
442            expected: expected_logits,
443            got: student_logits.len(),
444        });
445    }
446    if token_indices.len() != n_patches {
447        return Err(SslError::DimensionMismatch {
448            expected: n_patches,
449            got: token_indices.len(),
450        });
451    }
452    if mask.len() != n_patches {
453        return Err(SslError::DimensionMismatch {
454            expected: n_patches,
455            got: mask.len(),
456        });
457    }
458
459    // Validate token index range.
460    for &qi in token_indices {
461        if qi >= n_codes {
462            return Err(SslError::InvalidParameter {
463                name: "token_index".into(),
464                reason: format!("token index {qi} out of range [0, {n_codes})"),
465            });
466        }
467    }
468
469    let tau = config.temperature;
470    let n_masked = mask.iter().filter(|&&m| m).count();
471
472    // ── BEiT cross-entropy at masked positions ────────────────────────────────
473    let mut beit_loss_acc = 0.0_f64;
474
475    // Per-code assignment frequency for perplexity / usage calculation.
476    let mut code_freq = vec![0.0_f64; n_codes];
477
478    for i in 0..n_patches {
479        let qi = token_indices[i];
480        let logits = &student_logits[i * n_codes..(i + 1) * n_codes];
481
482        // Accumulate code frequencies over ALL patches (not just masked) for
483        // a representative perplexity estimate.
484        code_freq[qi] += 1.0;
485
486        if !mask[i] {
487            continue; // only predict at masked positions
488        }
489
490        // Numerically stable softmax with temperature.
491        let mut max_v = f32::NEG_INFINITY;
492        for &lv in logits {
493            let scaled = lv / tau;
494            if scaled > max_v {
495                max_v = scaled;
496            }
497        }
498        let mut sum_exp = 0.0_f64;
499        let mut exp_qi = 0.0_f64;
500        for (k, &lv) in logits.iter().enumerate() {
501            let e = ((lv / tau - max_v) as f64).exp();
502            sum_exp += e;
503            if k == qi {
504                exp_qi = e;
505            }
506        }
507        let log_prob = (exp_qi / sum_exp.max(1e-30)).max(1e-30_f64).ln();
508        beit_loss_acc -= log_prob;
509    }
510
511    let beit_loss_val = if n_masked == 0 {
512        0.0_f32
513    } else {
514        (beit_loss_acc / n_masked as f64) as f32
515    };
516
517    // ── Codebook usage and perplexity ────────────────────────────────────────
518    let total_assignments = n_patches as f64;
519    let n_used = code_freq.iter().filter(|&&f| f > 0.0).count();
520    let codebook_usage = n_used as f32 / n_codes as f32;
521
522    // Perplexity = exp(H) where H = -Σ p_k log p_k, p_k = freq_k / total.
523    let mut entropy = 0.0_f64;
524    for &freq in &code_freq {
525        if freq > 0.0 {
526            let p = freq / total_assignments;
527            entropy -= p * p.ln();
528        }
529    }
530    let perplexity = entropy.exp().clamp(1.0, n_codes as f64) as f32;
531
532    // ── VQ loss (passthrough from config — callers typically compute it via
533    //    vq_encode + vq_update_codebook; here we provide 0 as placeholder
534    //    unless the caller supplies it via config.commitment_weight context).
535    // Since BEiT loss function doesn't have access to the raw embeddings,
536    // vq_loss is reported as 0 here.  The caller should add the vq_loss
537    // returned by vq_encode to the total when assembling the training step.
538    let vq_loss_val = 0.0_f32;
539    let total_loss = beit_loss_val + vq_loss_val;
540
541    Ok(BeitResult {
542        beit_loss: beit_loss_val,
543        vq_loss: vq_loss_val,
544        total_loss,
545        n_masked,
546        codebook_usage,
547        perplexity,
548    })
549}
550
551// ─── Block masking ────────────────────────────────────────────────────────────
552
553/// Generate a BEiT-style block mask on a 2-D patch grid.
554///
555/// Unlike MAE's per-patch Bernoulli mask, BEiT uses **random rectangular
556/// blocks** (aspect-ratio-aware) to mask contiguous spatial regions.  This
557/// encourages the model to reason about object structure rather than isolated
558/// pixels.
559///
560/// The algorithm:
561/// 1. Sample a random block area uniformly in `[min_area, max_area]` patches
562///    where `min_area = max(1, floor(n_patches · 0.05))` and
563///    `max_area = max(min_area, ceil(n_patches · 0.3))`.
564/// 2. Sample a random aspect ratio r ∈ {0.3, 0.5, 0.75, 1.0, 1.33, 2.0, 3.0}
565///    (log-uniform discrete grid from the BEiT paper).
566/// 3. Compute block height `bh = sqrt(area / r)`, width `bw = sqrt(area * r)`,
567///    clamped to the grid bounds.
568/// 4. Place the block at a uniformly random position.
569/// 5. Repeat until the number of newly masked patches reaches the target
570///    `floor(n_patches · mask_ratio)` or a safety iteration limit is hit.
571///
572/// Returns `Vec<bool>` of length `n_patches` in row-major order
573/// (`true` ⟺ patch is masked).
574///
575/// # Errors
576/// - [`SslError::EmptyInput`] when `patch_grid_h == 0` or `patch_grid_w == 0`.
577/// - [`SslError::InvalidMaskRatio`] when `mask_ratio ∉ [0, 1)`.
578/// - [`SslError::InvalidParameter`] when `n_patches != patch_grid_h * patch_grid_w`.
579pub fn beit_block_mask(
580    n_patches: usize,
581    patch_grid_h: usize,
582    patch_grid_w: usize,
583    mask_ratio: f32,
584    rng: &mut LcgRng,
585) -> SslResult<Vec<bool>> {
586    if patch_grid_h == 0 || patch_grid_w == 0 {
587        return Err(SslError::EmptyInput);
588    }
589    if !(mask_ratio.is_finite() && (0.0..1.0).contains(&mask_ratio)) {
590        return Err(SslError::InvalidMaskRatio { ratio: mask_ratio });
591    }
592    let grid_total = patch_grid_h * patch_grid_w;
593    if n_patches != grid_total {
594        return Err(SslError::InvalidParameter {
595            name: "n_patches".into(),
596            reason: format!(
597                "n_patches ({n_patches}) must equal patch_grid_h * patch_grid_w ({grid_total})"
598            ),
599        });
600    }
601
602    let target_masked = (n_patches as f32 * mask_ratio).floor() as usize;
603    let mut mask = vec![false; n_patches];
604    let mut n_masked = 0usize;
605
606    if target_masked == 0 {
607        return Ok(mask);
608    }
609
610    // Discrete aspect-ratio candidates (BEiT paper uses log-uniform grid).
611    const ASPECT_RATIOS: [f32; 7] = [0.3, 0.5, 0.75, 1.0, 1.33, 2.0, 3.0];
612
613    // Block area range: 5%–30% of total grid area, at least 1.
614    let min_area = (n_patches as f32 * 0.05).ceil() as usize;
615    let min_area = min_area.max(1);
616    let max_area = (n_patches as f32 * 0.30).ceil() as usize;
617    let max_area = max_area.max(min_area);
618
619    // Safety valve to prevent infinite loops on tiny grids.
620    let max_iters = (target_masked * 16 + 1).max(200);
621    let mut iters = 0usize;
622
623    while n_masked < target_masked && iters < max_iters {
624        iters += 1;
625
626        // Sample block area.
627        let area_range = max_area - min_area + 1;
628        let area = min_area + rng.next_usize(area_range);
629
630        // Sample aspect ratio.
631        let ratio_idx = rng.next_usize(ASPECT_RATIOS.len());
632        let ar = ASPECT_RATIOS[ratio_idx];
633
634        // Derive block height and width from area and aspect ratio.
635        let bh_f = (area as f32 / ar).sqrt();
636        let bw_f = (area as f32 * ar).sqrt();
637        let bh = (bh_f.round() as usize).clamp(1, patch_grid_h);
638        let bw = (bw_f.round() as usize).clamp(1, patch_grid_w);
639
640        // Uniformly sample top-left anchor.
641        let r0 = if patch_grid_h > bh {
642            rng.next_usize(patch_grid_h - bh + 1)
643        } else {
644            0
645        };
646        let c0 = if patch_grid_w > bw {
647            rng.next_usize(patch_grid_w - bw + 1)
648        } else {
649            0
650        };
651
652        // Stamp the block onto the mask.
653        for r in r0..r0 + bh {
654            for c in c0..c0 + bw {
655                let idx = r * patch_grid_w + c;
656                if !mask[idx] {
657                    mask[idx] = true;
658                    n_masked += 1;
659                    // Stop stamping this block if we've hit the target.
660                    if n_masked >= target_masked {
661                        break;
662                    }
663                }
664            }
665            if n_masked >= target_masked {
666                break;
667            }
668        }
669    }
670
671    Ok(mask)
672}
673
// ─── Tests ────────────────────────────────────────────────────────────────────

676#[cfg(test)]
677mod tests {
678    use super::*;
679
680    // ── vq_codebook_init ──────────────────────────────────────────────────────
681
682    /// Initialised codebook must have exactly K×C entries.
683    #[test]
684    fn vq_codebook_init_correct_shape() {
685        let mut rng = LcgRng::new(1);
686        let cb = vq_codebook_init(64, 32, &mut rng).expect("vq_codebook_init should succeed");
687        assert_eq!(cb.embeddings.len(), 64 * 32);
688        assert_eq!(cb.n_codes, 64);
689        assert_eq!(cb.code_dim, 32);
690        assert_eq!(cb.ema_counts.len(), 64);
691        assert_eq!(cb.ema_sum.len(), 64 * 32);
692    }
693
694    /// All initialised entries must be finite.
695    #[test]
696    fn vq_codebook_init_entries_finite() {
697        let mut rng = LcgRng::new(2);
698        let cb = vq_codebook_init(16, 8, &mut rng).expect("vq_codebook_init should succeed");
699        assert!(cb.embeddings.iter().all(|v| v.is_finite()));
700        assert!(cb.ema_sum.iter().all(|v| v.is_finite()));
701    }
702
703    /// Zero n_codes must return an error.
704    #[test]
705    fn vq_codebook_init_rejects_zero_codes() {
706        let mut rng = LcgRng::new(3);
707        assert!(vq_codebook_init(0, 32, &mut rng).is_err());
708    }
709
710    /// Zero code_dim must return an error.
711    #[test]
712    fn vq_codebook_init_rejects_zero_dim() {
713        let mut rng = LcgRng::new(4);
714        assert!(vq_codebook_init(16, 0, &mut rng).is_err());
715    }
716
717    // ── vq_encode ─────────────────────────────────────────────────────────────
718
719    /// All returned indices must be within [0, K).
720    #[test]
721    fn vq_encode_indices_in_range() {
722        let mut rng = LcgRng::new(5);
723        let k = 32;
724        let c = 8;
725        let cb = vq_codebook_init(k, c, &mut rng).expect("vq_codebook_init should succeed");
726        let n = 20;
727        let mut emb = vec![0.0_f32; n * c];
728        rng.fill_normal(&mut emb);
729        let (indices, _, _) = vq_encode(&cb, &emb, n, c).expect("vq_encode should succeed");
730        assert_eq!(indices.len(), n);
731        for &idx in &indices {
732            assert!(idx < k, "index {idx} out of range");
733        }
734    }
735
736    /// VQ loss must be non-negative.
737    #[test]
738    fn vq_encode_vq_loss_non_negative() {
739        let mut rng = LcgRng::new(6);
740        let k = 16;
741        let c = 4;
742        let cb = vq_codebook_init(k, c, &mut rng).expect("vq_codebook_init should succeed");
743        let n = 10;
744        let mut emb = vec![0.0_f32; n * c];
745        rng.fill_normal(&mut emb);
746        let (_, _, vq_loss) = vq_encode(&cb, &emb, n, c).expect("vq_encode should succeed");
747        assert!(vq_loss >= 0.0, "vq_loss = {vq_loss} should be >= 0");
748    }
749
750    /// Quantised output has shape [N × C].
751    #[test]
752    fn vq_encode_quantized_shape() {
753        let mut rng = LcgRng::new(7);
754        let k = 8;
755        let c = 6;
756        let cb = vq_codebook_init(k, c, &mut rng).expect("vq_codebook_init should succeed");
757        let n = 5;
758        let mut emb = vec![0.0_f32; n * c];
759        rng.fill_normal(&mut emb);
760        let (indices, quantized, _) = vq_encode(&cb, &emb, n, c).expect("vq_encode should succeed");
761        assert_eq!(quantized.len(), n * c);
762        assert_eq!(indices.len(), n);
763    }
764
765    /// If the embedding exactly equals a codebook entry, that entry is selected.
766    #[test]
767    fn vq_encode_exact_match_selected() {
768        let mut rng = LcgRng::new(8);
769        let k = 8;
770        let c = 4;
771        let mut cb = vq_codebook_init(k, c, &mut rng).expect("vq_codebook_init should succeed");
772        // Force codebook[3] to be the zero vector.
773        for v in &mut cb.embeddings[3 * c..4 * c] {
774            *v = 0.0;
775        }
776        // Embed = zero vector → should match code 3 (or whichever other code
777        // is closest to zero; we just assert the returned distance is minimal).
778        let emb = vec![0.0_f32; c];
779        let (indices, _, vq_loss) = vq_encode(&cb, &emb, 1, c).expect("vq_encode should succeed");
780        // The selected code must be a valid index.
781        assert!(indices[0] < k);
782        // Loss must be non-negative.
783        assert!(vq_loss >= 0.0);
784    }
785
786    // ── vq_update_codebook ────────────────────────────────────────────────────
787
788    /// After assigning all embeddings to code 0 with a constant vector,
789    /// code 0 must move toward that vector.
790    #[test]
791    fn vq_update_codebook_ema_moves_toward_assigned() {
792        let mut rng = LcgRng::new(9);
793        let k = 4;
794        let c = 3;
795        let mut cb = vq_codebook_init(k, c, &mut rng).expect("vq_codebook_init should succeed");
796        // Use a small momentum so the update is visible.
797        cb.ema_momentum = 0.5;
798
799        // Record original code-0 value.
800        let orig_code0: Vec<f32> = cb.embeddings[0..c].to_vec();
801
802        // All patches assigned to code 0 with embedding = [1, 1, 1].
803        let n = 5;
804        let emb = vec![1.0_f32; n * c];
805        let indices = vec![0usize; n];
806        vq_update_codebook(&mut cb, &emb, &indices, n).expect("vq_update_codebook should succeed");
807
808        let updated_code0: Vec<f32> = cb.embeddings[0..c].to_vec();
809        // Each component of code 0 should be between its original value and 1.0.
810        for (orig, updated) in orig_code0.iter().zip(updated_code0.iter()) {
811            let dist_before = (orig - 1.0).abs();
812            let dist_after = (updated - 1.0).abs();
813            assert!(
814                dist_after < dist_before || dist_before < 1e-6,
815                "EMA update did not move code 0 toward [1,1,1]: orig={orig} updated={updated}"
816            );
817        }
818    }
819
820    // ── beit_loss ─────────────────────────────────────────────────────────────
821
822    /// BEiT loss must be finite and non-negative for random inputs.
823    #[test]
824    fn beit_loss_finite_and_non_negative() {
825        let mut rng = LcgRng::new(10);
826        let n = 16;
827        let k = 8;
828        let cfg = BeitConfig {
829            n_codes: k,
830            code_dim: 4,
831            ..BeitConfig::default()
832        };
833        let mut logits = vec![0.0_f32; n * k];
834        rng.fill_normal(&mut logits);
835        let indices: Vec<usize> = (0..n).map(|i| i % k).collect();
836        let mask: Vec<bool> = (0..n).map(|i| i % 2 == 0).collect();
837        let result =
838            beit_loss(&logits, &indices, &mask, n, k, &cfg).expect("beit_loss should succeed");
839        assert!(result.total_loss.is_finite(), "total_loss should be finite");
840        assert!(result.beit_loss >= 0.0, "beit_loss should be >= 0");
841    }
842
843    /// n_masked must equal the count of `true` entries in the mask.
844    #[test]
845    fn beit_loss_n_masked_matches_mask() {
846        let n = 20;
847        let k = 4;
848        let cfg = BeitConfig {
849            n_codes: k,
850            ..BeitConfig::default()
851        };
852        let logits = vec![1.0_f32; n * k];
853        let indices = vec![0usize; n];
854        let mask: Vec<bool> = (0..n).map(|i| i < 7).collect(); // 7 masked
855        let result =
856            beit_loss(&logits, &indices, &mask, n, k, &cfg).expect("beit_loss should succeed");
857        assert_eq!(result.n_masked, 7);
858    }
859
860    /// When mask is all-false, beit_loss must be 0 (no positions to predict).
861    #[test]
862    fn beit_loss_all_unmasked_returns_zero() {
863        let n = 8;
864        let k = 4;
865        let cfg = BeitConfig {
866            n_codes: k,
867            ..BeitConfig::default()
868        };
869        let logits = vec![0.5_f32; n * k];
870        let indices = vec![0usize; n];
871        let mask = vec![false; n];
872        let result =
873            beit_loss(&logits, &indices, &mask, n, k, &cfg).expect("beit_loss should succeed");
874        assert_eq!(result.n_masked, 0);
875        assert!(
876            result.beit_loss.abs() < 1e-7,
877            "expected 0 loss, got {}",
878            result.beit_loss
879        );
880    }
881
882    /// codebook_usage must lie in [0, 1].
883    #[test]
884    fn beit_loss_codebook_usage_in_range() {
885        let mut rng = LcgRng::new(11);
886        let n = 12;
887        let k = 16;
888        let cfg = BeitConfig {
889            n_codes: k,
890            ..BeitConfig::default()
891        };
892        let mut logits = vec![0.0_f32; n * k];
893        rng.fill_normal(&mut logits);
894        let indices: Vec<usize> = (0..n).map(|_| rng.next_usize(k)).collect();
895        let mask = vec![true; n];
896        let result =
897            beit_loss(&logits, &indices, &mask, n, k, &cfg).expect("beit_loss should succeed");
898        assert!(
899            (0.0..=1.0).contains(&result.codebook_usage),
900            "codebook_usage = {}",
901            result.codebook_usage
902        );
903    }
904
905    /// Perplexity must be in [1, K].
906    #[test]
907    fn beit_loss_perplexity_in_range() {
908        let mut rng = LcgRng::new(12);
909        let n = 32;
910        let k = 16;
911        let cfg = BeitConfig {
912            n_codes: k,
913            ..BeitConfig::default()
914        };
915        let mut logits = vec![0.0_f32; n * k];
916        rng.fill_normal(&mut logits);
917        // Assign each patch to a distinct code (cycling) to maximise diversity.
918        let indices: Vec<usize> = (0..n).map(|i| i % k).collect();
919        let mask = vec![true; n];
920        let result =
921            beit_loss(&logits, &indices, &mask, n, k, &cfg).expect("beit_loss should succeed");
922        assert!(
923            result.perplexity >= 1.0 && result.perplexity <= k as f32 + 1e-4,
924            "perplexity = {} out of [1, {}]",
925            result.perplexity,
926            k
927        );
928    }
929
930    /// Invalid n_codes = 0 must return an error.
931    #[test]
932    fn beit_loss_rejects_zero_n_codes() {
933        let logits = vec![1.0_f32; 4];
934        let indices = vec![0usize; 4];
935        let mask = vec![true; 4];
936        let cfg = BeitConfig::default();
937        assert!(beit_loss(&logits, &indices, &mask, 4, 0, &cfg).is_err());
938    }
939
940    // ── beit_block_mask ───────────────────────────────────────────────────────
941
942    /// The mask must have exactly n_patches entries.
943    #[test]
944    fn beit_block_mask_correct_length() {
945        let mut rng = LcgRng::new(13);
946        let h = 14;
947        let w = 14;
948        let n = h * w;
949        let mask = beit_block_mask(n, h, w, 0.4, &mut rng).expect("beit_block_mask should succeed");
950        assert_eq!(mask.len(), n);
951    }
952
953    /// With mask_ratio = 0, no patches should be masked.
954    #[test]
955    fn beit_block_mask_zero_ratio_all_unmasked() {
956        let mut rng = LcgRng::new(14);
957        let h = 8;
958        let w = 8;
959        let n = h * w;
960        let mask = beit_block_mask(n, h, w, 0.0, &mut rng).expect("beit_block_mask should succeed");
961        assert!(mask.iter().all(|&v| !v));
962    }
963
964    /// mask_ratio > 1 must return InvalidMaskRatio error.
965    #[test]
966    fn beit_block_mask_rejects_invalid_ratio() {
967        let mut rng = LcgRng::new(15);
968        assert!(beit_block_mask(16, 4, 4, 1.1, &mut rng).is_err());
969        assert!(beit_block_mask(16, 4, 4, -0.1, &mut rng).is_err());
970        assert!(beit_block_mask(16, 4, 4, f32::NAN, &mut rng).is_err());
971    }
972
973    /// For a 14×14 grid with mask_ratio ≈ 0.4, roughly 40% of patches should
974    /// be masked (within a wide block-masking tolerance of ±0.25).
975    #[test]
976    fn beit_block_mask_approx_ratio() {
977        let mut rng = LcgRng::new(16);
978        let h = 14;
979        let w = 14;
980        let n = h * w; // 196
981        let ratio = 0.4_f32;
982        let mask =
983            beit_block_mask(n, h, w, ratio, &mut rng).expect("beit_block_mask should succeed");
984        let n_masked = mask.iter().filter(|&&v| v).count();
985        // The block mask stops exactly at target = floor(196 * 0.4) = 78.
986        let target = (n as f32 * ratio).floor() as usize;
987        assert!(
988            n_masked <= target,
989            "n_masked ({n_masked}) > target ({target}): block stopped early but should not over-shoot"
990        );
991        // We expect it to reach at least 30% of target.
992        assert!(
993            n_masked >= target / 2,
994            "too few patches masked: {n_masked} vs target {target}"
995        );
996    }
997
998    /// A batch of patches: all assignments returned by vq_encode are valid.
999    #[test]
1000    fn vq_encode_batch_all_valid_assignments() {
1001        let mut rng = LcgRng::new(17);
1002        let k = 32;
1003        let c = 16;
1004        let cb = vq_codebook_init(k, c, &mut rng).expect("vq_codebook_init should succeed");
1005        let n = 50;
1006        let mut emb = vec![0.0_f32; n * c];
1007        rng.fill_normal(&mut emb);
1008        let (indices, quantized, vq_loss) =
1009            vq_encode(&cb, &emb, n, c).expect("vq_encode should succeed");
1010        assert_eq!(indices.len(), n);
1011        assert_eq!(quantized.len(), n * c);
1012        assert!(vq_loss.is_finite() && vq_loss >= 0.0);
1013        for &idx in &indices {
1014            assert!(idx < k, "assignment {idx} out of [0, {k})");
1015        }
1016    }
1017}