Skip to main content

oxicuda_vision/ssl/
dinov2.rs

1//! DINOv2 — a faithful CPU reference of the self-supervised distillation
2//! recipe from Oquab et al. 2023, *"DINOv2: Learning Robust Visual Features
3//! without Supervision"*, which combines the image-level **DINO** objective
4//! (Caron et al. 2021, *"Emerging Properties in Self-Supervised Vision
5//! Transformers"*) with the patch-level **iBOT** masked-image-modelling
6//! objective (Zhou et al. 2022, *"iBOT: Image BERT Pre-Training with Online
7//! Tokenizer"*).
8//!
9//! ## The pieces (all real computation, no stubs)
10//!
11//! 1. **ViT backbone** — patch-embed → prepend CLS → positional embed →
12//!    transformer encoder, returning the `[CLS]` embedding *and* the patch
13//!    tokens. We reuse [`crate::patch_embed::PatchEmbed`] and
14//!    [`crate::vit::ViTEncoder`].
15//! 2. **DINO projection head** — a 3-layer MLP (GELU) → L2-normalise →
16//!    **weight-normalised prototype layer** producing `K` prototype logits.
17//!    The prototype layer's rows are L2-normalised (weight norm with unit
18//!    magnitude), so a logit is the cosine similarity between the projected
19//!    feature and a learned prototype.
20//! 3. **DINO loss** — cross-entropy between a *sharpened, centred* teacher
21//!    distribution and a *softer* student distribution:
22//!    `H = −Σ_k p_t(k) · log p_s(k)`, with
23//!    `p_t = softmax((g_t − c) / τ_t)`, `p_s = softmax(g_s / τ_s)`, and the
24//!    teacher temperature `τ_t < τ_s` (sharper teacher).
25//! 4. **EMA teacher** — `θ_t ← m·θ_t + (1−m)·θ_s`.
26//! 5. **Centering** — running buffer `c ← λ·c + (1−λ)·mean_batch(g_t)`,
27//!    subtracted from teacher logits before the softmax to prevent collapse.
28//! 6. **iBOT masked-patch term** — at student patch positions that are masked,
29//!    predict the *teacher's* (unmasked) patch-prototype distribution via the
30//!    same cross-entropy, giving a dense per-patch self-distillation signal.
31//!
32//! All parameters are flat row-major `Vec<f32>`; no `unsafe`, no external RNG.
33
34use crate::{
35    error::{VisionError, VisionResult},
36    handle::LcgRng,
37    patch_embed::{PatchEmbed, PatchEmbedConfig, prepend_cls},
38    vit::vit_block::{gelu_exact, linear},
39    vit::{ViTConfig, ViTEncoder, ViTEncoderConfig},
40};
41
42// ─── Backbone output ────────────────────────────────────────────────────────────
43
/// The two outputs of a DINOv2 ViT backbone forward pass.
///
/// Derives `PartialEq` so two outputs can be compared directly (e.g. to check
/// that repeated forward passes are deterministic) and `Default` for an empty
/// output with zero patches.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct BackboneOutput {
    /// The `[CLS]` embedding: flat `[embed_dim]`.
    pub cls: Vec<f32>,
    /// The patch tokens: flat `[n_patches · embed_dim]`, one token after another.
    pub patches: Vec<f32>,
    /// Number of patch tokens.
    pub n_patches: usize,
}
54
55// ─── ViT backbone ───────────────────────────────────────────────────────────────
56
57/// A ViT backbone that returns both the `[CLS]` embedding and the patch tokens.
58pub struct DinoBackbone {
59    /// ViT hyper-parameters.
60    pub config: ViTConfig,
61    patch_embed: PatchEmbed,
62    cls_token: Vec<f32>,
63    pos_embed: Vec<f32>, // [(n_patches+1) · embed_dim]
64    encoder: ViTEncoder,
65}
66
67impl DinoBackbone {
68    /// Construct a backbone with Gaussian-initialised weights.
69    ///
70    /// # Errors
71    /// Propagates patch / encoder validation errors.
72    pub fn new(cfg: ViTConfig, rng: &mut LcgRng) -> VisionResult<Self> {
73        let e = cfg.embed_dim;
74        let pe_cfg = PatchEmbedConfig::new(cfg.img_size, cfg.patch_size, cfg.in_chans, e)?;
75        let patch_embed = PatchEmbed::new(pe_cfg, rng);
76
77        let mut cls_token = vec![0.0f32; e];
78        rng.fill_normal(&mut cls_token);
79        for v in &mut cls_token {
80            *v *= 0.02;
81        }
82
83        let seq_len = cfg.n_patches() + 1;
84        let mut pos_embed = vec![0.0f32; seq_len * e];
85        rng.fill_normal(&mut pos_embed);
86        for v in &mut pos_embed {
87            *v *= 0.02;
88        }
89
90        let enc_cfg = ViTEncoderConfig::new(e, cfg.n_heads, cfg.mlp_ratio, cfg.depth)?;
91        let encoder = ViTEncoder::new(enc_cfg, rng)?;
92
93        Ok(Self {
94            config: cfg,
95            patch_embed,
96            cls_token,
97            pos_embed,
98            encoder,
99        })
100    }
101
102    /// Forward pass returning the `[CLS]` embedding and the patch tokens.
103    ///
104    /// # Errors
105    /// Propagates dimension / backbone errors.
106    pub fn forward(&self, image: &[f32]) -> VisionResult<BackboneOutput> {
107        let e = self.config.embed_dim;
108        let n_patches = self.config.n_patches();
109
110        let patch_tokens = self.patch_embed.forward(image)?;
111        let mut tokens = prepend_cls(&patch_tokens, &self.cls_token, e)?;
112        // Add positional embedding over CLS + patches.
113        for (t, p) in tokens.iter_mut().zip(self.pos_embed.iter()) {
114            *t += p;
115        }
116        let seq_len = n_patches + 1;
117        let encoded = self.encoder.forward(&tokens, seq_len)?;
118
119        let cls = encoded[..e].to_vec();
120        let patches = encoded[e..].to_vec();
121        Ok(BackboneOutput {
122            cls,
123            patches,
124            n_patches,
125        })
126    }
127}
128
129// ─── DINO projection head ───────────────────────────────────────────────────────
130
/// The DINO projection head: 3-layer MLP (GELU) → L2-normalise →
/// weight-normalised prototype layer producing `n_prototypes` logits.
///
/// The prototype layer is *weight-normalised*: its rows are L2-normalised so
/// that each logit is the cosine similarity of the bottleneck feature with a
/// learned unit prototype, scaled by a learned per-layer gain `g`.
///
/// Derives `Debug` so heads can be inspected in diagnostics and test failures.
#[derive(Debug, Clone)]
pub struct DinoHead {
    /// Input feature dimension.
    in_dim: usize,
    /// Width of both hidden MLP layers.
    hidden_dim: usize,
    /// Width of the MLP output (the feature that is L2-normalised).
    bottleneck_dim: usize,
    /// Number of prototypes, i.e. the number of output logits.
    n_prototypes: usize,
    // MLP: in → hidden → hidden → bottleneck.
    /// Layer-1 weights, `hidden_dim · in_dim` scalars.
    w1: Vec<f32>,
    /// Layer-1 bias `[hidden_dim]`.
    b1: Vec<f32>,
    /// Layer-2 weights, `hidden_dim · hidden_dim` scalars.
    w2: Vec<f32>,
    /// Layer-2 bias `[hidden_dim]`.
    b2: Vec<f32>,
    /// Layer-3 weights, `bottleneck_dim · hidden_dim` scalars.
    w3: Vec<f32>,
    /// Layer-3 bias `[bottleneck_dim]`.
    b3: Vec<f32>,
    /// Prototype directions `[n_prototypes · bottleneck_dim]` (normalised at use).
    prototypes: Vec<f32>,
    /// Weight-norm gain (scalar magnitude `g`), as in `weight_norm`.
    gain: f32,
}
155
156impl DinoHead {
157    /// Construct a head with Gaussian-initialised weights.
158    ///
159    /// # Errors
160    /// - [`VisionError::InvalidEmbedDim`] if `in_dim`, `hidden_dim`, or
161    ///   `bottleneck_dim` is 0.
162    /// - [`VisionError::InvalidProjDim`] if `n_prototypes == 0`.
163    pub fn new(
164        in_dim: usize,
165        hidden_dim: usize,
166        bottleneck_dim: usize,
167        n_prototypes: usize,
168        rng: &mut LcgRng,
169    ) -> VisionResult<Self> {
170        if in_dim == 0 {
171            return Err(VisionError::InvalidEmbedDim(in_dim));
172        }
173        if hidden_dim == 0 {
174            return Err(VisionError::InvalidEmbedDim(hidden_dim));
175        }
176        if bottleneck_dim == 0 {
177            return Err(VisionError::InvalidEmbedDim(bottleneck_dim));
178        }
179        if n_prototypes == 0 {
180            return Err(VisionError::InvalidProjDim(n_prototypes));
181        }
182
183        let fill = |rng: &mut LcgRng, n: usize, sc: f32| -> Vec<f32> {
184            let mut v = vec![0.0f32; n];
185            rng.fill_normal(&mut v);
186            for x in &mut v {
187                *x *= sc;
188            }
189            v
190        };
191
192        let w1 = fill(rng, hidden_dim * in_dim, 1.0 / (in_dim as f32).sqrt());
193        let b1 = vec![0.0f32; hidden_dim];
194        let w2 = fill(
195            rng,
196            hidden_dim * hidden_dim,
197            1.0 / (hidden_dim as f32).sqrt(),
198        );
199        let b2 = vec![0.0f32; hidden_dim];
200        let w3 = fill(
201            rng,
202            bottleneck_dim * hidden_dim,
203            1.0 / (hidden_dim as f32).sqrt(),
204        );
205        let b3 = vec![0.0f32; bottleneck_dim];
206        // Prototype directions — random, normalised on the fly.
207        let prototypes = fill(
208            rng,
209            n_prototypes * bottleneck_dim,
210            1.0 / (bottleneck_dim as f32).sqrt(),
211        );
212
213        Ok(Self {
214            in_dim,
215            hidden_dim,
216            bottleneck_dim,
217            n_prototypes,
218            w1,
219            b1,
220            w2,
221            b2,
222            w3,
223            b3,
224            prototypes,
225            gain: 1.0,
226        })
227    }
228
229    /// Number of prototype logits produced by this head.
230    #[must_use]
231    pub fn n_prototypes(&self) -> usize {
232        self.n_prototypes
233    }
234
235    /// Apply the head to a single feature vector `[in_dim]`, returning the
236    /// `[n_prototypes]` prototype logits.
237    ///
238    /// Pipeline: `MLP → L2-normalise (bottleneck) → cosine vs each prototype × gain`.
239    ///
240    /// # Errors
241    /// - [`VisionError::DimensionMismatch`] if `x.len() != in_dim`.
242    pub fn forward(&self, x: &[f32]) -> VisionResult<Vec<f32>> {
243        if x.len() != self.in_dim {
244            return Err(VisionError::DimensionMismatch {
245                expected: self.in_dim,
246                got: x.len(),
247            });
248        }
249
250        // 3-layer MLP with GELU between layers (no activation after the last).
251        let h1 = linear(x, &self.w1, &self.b1, self.in_dim, self.hidden_dim);
252        let h1: Vec<f32> = h1.into_iter().map(gelu_exact).collect();
253        let h2 = linear(&h1, &self.w2, &self.b2, self.hidden_dim, self.hidden_dim);
254        let h2: Vec<f32> = h2.into_iter().map(gelu_exact).collect();
255        let mut z = linear(
256            &h2,
257            &self.w3,
258            &self.b3,
259            self.hidden_dim,
260            self.bottleneck_dim,
261        );
262
263        // L2-normalise the bottleneck feature.
264        let norm: f32 = z.iter().map(|&v| v * v).sum::<f32>().sqrt();
265        let inv = 1.0 / norm.max(1e-12);
266        for v in &mut z {
267            *v *= inv;
268        }
269
270        // Weight-normalised prototype layer: logit_k = gain · ⟨z, p̂_k⟩.
271        let bd = self.bottleneck_dim;
272        let mut logits = vec![0.0f32; self.n_prototypes];
273        for (k, lk) in logits.iter_mut().enumerate() {
274            let proto = &self.prototypes[k * bd..(k + 1) * bd];
275            let pnorm: f32 = proto.iter().map(|&v| v * v).sum::<f32>().sqrt();
276            let pinv = 1.0 / pnorm.max(1e-12);
277            let dot: f32 = z.iter().zip(proto.iter()).map(|(&a, &b)| a * b).sum();
278            *lk = self.gain * dot * pinv;
279        }
280        Ok(logits)
281    }
282
283    /// Apply the head to a batch of features `[batch · in_dim]`, returning
284    /// `[batch · n_prototypes]`.
285    ///
286    /// # Errors
287    /// - [`VisionError::DimensionMismatch`] on a length not divisible by `in_dim`.
288    pub fn forward_batch(&self, x: &[f32]) -> VisionResult<Vec<f32>> {
289        if x.is_empty() || x.len() % self.in_dim != 0 {
290            return Err(VisionError::DimensionMismatch {
291                expected: self.in_dim,
292                got: x.len() % self.in_dim,
293            });
294        }
295        let batch = x.len() / self.in_dim;
296        let mut out = vec![0.0f32; batch * self.n_prototypes];
297        for b in 0..batch {
298            let row = self.forward(&x[b * self.in_dim..(b + 1) * self.in_dim])?;
299            out[b * self.n_prototypes..(b + 1) * self.n_prototypes].copy_from_slice(&row);
300        }
301        Ok(out)
302    }
303
304    /// Total number of learnable scalars (used by the EMA update).
305    fn num_params(&self) -> usize {
306        self.w1.len()
307            + self.b1.len()
308            + self.w2.len()
309            + self.b2.len()
310            + self.w3.len()
311            + self.b3.len()
312            + self.prototypes.len()
313            + 1 // gain
314    }
315
316    /// Flatten all parameters into a single vector (for distance computations).
317    #[cfg(test)]
318    fn flatten(&self) -> Vec<f32> {
319        let mut v = Vec::with_capacity(self.num_params());
320        v.extend_from_slice(&self.w1);
321        v.extend_from_slice(&self.b1);
322        v.extend_from_slice(&self.w2);
323        v.extend_from_slice(&self.b2);
324        v.extend_from_slice(&self.w3);
325        v.extend_from_slice(&self.b3);
326        v.extend_from_slice(&self.prototypes);
327        v.push(self.gain);
328        v
329    }
330
331    /// EMA update **of this (teacher) head toward** the `student` head:
332    /// `θ_t ← m·θ_t + (1−m)·θ_s` applied parameter-wise.
333    ///
334    /// # Errors
335    /// - [`VisionError::Internal`] if the heads have mismatched parameter shapes.
336    pub fn ema_update(&mut self, student: &DinoHead, momentum: f32) -> VisionResult<()> {
337        if self.num_params() != student.num_params()
338            || self.w1.len() != student.w1.len()
339            || self.prototypes.len() != student.prototypes.len()
340        {
341            return Err(VisionError::Internal(
342                "ema_update: teacher/student head shape mismatch".into(),
343            ));
344        }
345        let m = momentum;
346        let lerp = |dst: &mut [f32], src: &[f32]| {
347            for (d, &s) in dst.iter_mut().zip(src.iter()) {
348                *d = m * *d + (1.0 - m) * s;
349            }
350        };
351        lerp(&mut self.w1, &student.w1);
352        lerp(&mut self.b1, &student.b1);
353        lerp(&mut self.w2, &student.w2);
354        lerp(&mut self.b2, &student.b2);
355        lerp(&mut self.w3, &student.w3);
356        lerp(&mut self.b3, &student.b3);
357        lerp(&mut self.prototypes, &student.prototypes);
358        self.gain = m * self.gain + (1.0 - m) * student.gain;
359        Ok(())
360    }
361}
362
363// ─── Distributions, centering, and the DINO loss ────────────────────────────────
364
/// Numerically-stable softmax of `(logits − center) / temperature`.
///
/// `center` is either empty (no centering) or indexed element-wise alongside
/// `logits`; the public wrappers guarantee the two have equal length. The
/// largest scaled value is subtracted before exponentiating so that large
/// logits cannot overflow. If the exponentials do not sum to a positive value
/// (e.g. `NaN` inputs), the values are returned without normalisation.
fn softmax_temp(logits: &[f32], center: &[f32], temperature: f32) -> Vec<f32> {
    let mut probs: Vec<f32> = if center.is_empty() {
        logits.iter().map(|&l| l / temperature).collect()
    } else {
        logits
            .iter()
            .enumerate()
            .map(|(i, &l)| (l - center[i]) / temperature)
            .collect()
    };

    let peak = probs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut total = 0.0f32;
    for p in probs.iter_mut() {
        *p = (*p - peak).exp();
        total += *p;
    }

    let norm = if total > 0.0 { 1.0 / total } else { 1.0 };
    probs.iter_mut().for_each(|p| *p *= norm);
    probs
}
393
394/// Softmax of `logits/τ_s` (student branch — never centred).
395///
396/// # Errors
397/// - [`VisionError::NonPositiveTemperature`] if `tau <= 0`.
398pub fn student_softmax(logits: &[f32], tau: f32) -> VisionResult<Vec<f32>> {
399    if tau <= 0.0 {
400        return Err(VisionError::NonPositiveTemperature(tau));
401    }
402    Ok(softmax_temp(logits, &[], tau))
403}
404
405/// Sharpened, centred teacher distribution `softmax((g_t − c)/τ_t)`.
406///
407/// `center` must be empty or have the same length as `logits`.
408///
409/// # Errors
410/// - [`VisionError::NonPositiveTemperature`] if `tau <= 0`.
411/// - [`VisionError::DimensionMismatch`] if `center` is non-empty and mismatched.
412pub fn teacher_softmax(logits: &[f32], center: &[f32], tau: f32) -> VisionResult<Vec<f32>> {
413    if tau <= 0.0 {
414        return Err(VisionError::NonPositiveTemperature(tau));
415    }
416    if !center.is_empty() && center.len() != logits.len() {
417        return Err(VisionError::DimensionMismatch {
418            expected: logits.len(),
419            got: center.len(),
420        });
421    }
422    Ok(softmax_temp(logits, center, tau))
423}
424
425/// Cross-entropy `H(p_t, p_s) = −Σ_k p_t(k) · log p_s(k)`.
426///
427/// The teacher distribution `p_t` is the *target*; the student distribution
428/// `p_s` is the *prediction*. Returns a value `≥ 0`, equal to the entropy of
429/// `p_t` when `p_s == p_t` (its minimum over `p_s`), and `0` exactly when both
430/// are the same one-hot distribution.
431///
432/// # Errors
433/// - [`VisionError::DimensionMismatch`] if the two distributions differ in length.
434pub fn cross_entropy(p_teacher: &[f32], p_student: &[f32]) -> VisionResult<f32> {
435    if p_teacher.len() != p_student.len() {
436        return Err(VisionError::DimensionMismatch {
437            expected: p_teacher.len(),
438            got: p_student.len(),
439        });
440    }
441    let mut h = 0.0f32;
442    for (&pt, &ps) in p_teacher.iter().zip(p_student.iter()) {
443        if pt > 0.0 {
444            // Guard log(0); ps ∈ (0, 1] for a softmax, but clamp defensively.
445            h -= pt * ps.max(1e-12).ln();
446        }
447    }
448    Ok(h)
449}
450
451/// The full DINO loss between teacher logits and student logits.
452///
453/// Computes `H(softmax((g_t − c)/τ_t), softmax(g_s/τ_s))`. The teacher branch
454/// is centred (with running buffer `c`) and sharpened (`τ_t < τ_s`).
455///
456/// # Errors
457/// - Non-positive temperatures, or mismatched lengths.
458pub fn dino_loss(
459    student_logits: &[f32],
460    teacher_logits: &[f32],
461    center: &[f32],
462    tau_student: f32,
463    tau_teacher: f32,
464) -> VisionResult<f32> {
465    if student_logits.len() != teacher_logits.len() {
466        return Err(VisionError::DimensionMismatch {
467            expected: teacher_logits.len(),
468            got: student_logits.len(),
469        });
470    }
471    let p_t = teacher_softmax(teacher_logits, center, tau_teacher)?;
472    let p_s = student_softmax(student_logits, tau_student)?;
473    cross_entropy(&p_t, &p_s)
474}
475
476// ─── Centering buffer ───────────────────────────────────────────────────────────
477
478/// Running centre buffer for the teacher outputs.
479///
480/// Updated by `c ← λ·c + (1−λ)·mean_batch(g_t)` and subtracted from teacher
481/// logits before the softmax. This prevents the trivial collapse where the
482/// teacher always predicts the same prototype.
483#[derive(Debug, Clone)]
484pub struct CenteringBuffer {
485    /// The centre vector `[n_prototypes]`.
486    pub center: Vec<f32>,
487    /// EMA decay `λ ∈ [0, 1)`.
488    pub momentum: f32,
489}
490
491impl CenteringBuffer {
492    /// New zero-initialised buffer of dimension `dim`.
493    #[must_use]
494    pub fn new(dim: usize, momentum: f32) -> Self {
495        Self {
496            center: vec![0.0f32; dim],
497            momentum,
498        }
499    }
500
501    /// Update the centre from a batch of teacher logits `[batch · dim]`.
502    ///
503    /// `c ← λ·c + (1−λ)·mean_batch(g_t)`.
504    ///
505    /// # Errors
506    /// - [`VisionError::DimensionMismatch`] if `batch_logits` is not a multiple
507    ///   of `dim` or is empty.
508    pub fn update(&mut self, batch_logits: &[f32]) -> VisionResult<()> {
509        let dim = self.center.len();
510        if dim == 0 || batch_logits.is_empty() || batch_logits.len() % dim != 0 {
511            return Err(VisionError::DimensionMismatch {
512                expected: dim,
513                got: batch_logits.len(),
514            });
515        }
516        let batch = batch_logits.len() / dim;
517        let mut mean = vec![0.0f32; dim];
518        for b in 0..batch {
519            for k in 0..dim {
520                mean[k] += batch_logits[b * dim + k];
521            }
522        }
523        let inv_b = 1.0 / batch as f32;
524        let lam = self.momentum;
525        for (c, m) in self.center.iter_mut().zip(mean.iter()) {
526            let batch_mean = m * inv_b;
527            *c = lam * *c + (1.0 - lam) * batch_mean;
528        }
529        Ok(())
530    }
531}
532
533// ─── iBOT masked-patch term ─────────────────────────────────────────────────────
534
535/// The iBOT masked-image-modelling loss term.
536///
537/// For each *masked* student patch position, the student must predict the
538/// teacher's (unmasked) patch-prototype distribution. The loss is the mean
539/// cross-entropy over the masked positions; positions that are not masked are
540/// ignored.
541///
542/// - `student_patch_logits` / `teacher_patch_logits`: `[n_patches · n_proto]`.
543/// - `mask`: `[n_patches]` booleans; `true` ⇒ that patch is masked for the
544///   student and contributes to the loss.
545/// - `patch_center`: optional `[n_proto]` centre for the teacher patch head.
546///
547/// Returns `0.0` if no patch is masked.
548///
549/// # Errors
550/// - Mismatched shapes or non-positive temperatures.
551pub fn ibot_loss(
552    student_patch_logits: &[f32],
553    teacher_patch_logits: &[f32],
554    mask: &[bool],
555    patch_center: &[f32],
556    n_proto: usize,
557    tau_student: f32,
558    tau_teacher: f32,
559) -> VisionResult<f32> {
560    if n_proto == 0 {
561        return Err(VisionError::InvalidProjDim(n_proto));
562    }
563    let n_patches = mask.len();
564    if student_patch_logits.len() != n_patches * n_proto
565        || teacher_patch_logits.len() != n_patches * n_proto
566    {
567        return Err(VisionError::DimensionMismatch {
568            expected: n_patches * n_proto,
569            got: student_patch_logits.len(),
570        });
571    }
572
573    let mut total = 0.0f32;
574    let mut count = 0usize;
575    for p in 0..n_patches {
576        if !mask[p] {
577            continue;
578        }
579        let s = &student_patch_logits[p * n_proto..(p + 1) * n_proto];
580        let t = &teacher_patch_logits[p * n_proto..(p + 1) * n_proto];
581        let l = dino_loss(s, t, patch_center, tau_student, tau_teacher)?;
582        total += l;
583        count += 1;
584    }
585    if count == 0 {
586        return Ok(0.0);
587    }
588    Ok(total / count as f32)
589}
590
591// ─── Tests ──────────────────────────────────────────────────────────────────────
592
593#[cfg(test)]
594mod tests {
595    use super::*;
596
597    fn l2(a: &[f32], b: &[f32]) -> f32 {
598        a.iter()
599            .zip(b.iter())
600            .map(|(&x, &y)| (x - y) * (x - y))
601            .sum::<f32>()
602            .sqrt()
603    }
604
605    fn entropy(p: &[f32]) -> f32 {
606        let mut h = 0.0f32;
607        for &v in p {
608            if v > 0.0 {
609                h -= v * v.ln();
610            }
611        }
612        h
613    }
614
615    fn make_head(seed: u64, k: usize) -> DinoHead {
616        let mut rng = LcgRng::new(seed);
617        DinoHead::new(32, 64, 16, k, &mut rng).expect("head ok")
618    }
619
620    // ── Backbone ──────────────────────────────────────────────────────────────────
621
622    #[test]
623    fn backbone_returns_cls_and_patches() {
624        let mut rng = LcgRng::new(1);
625        let cfg = ViTConfig::tiny();
626        let e = cfg.embed_dim;
627        let n_patches = cfg.n_patches();
628        let bb = DinoBackbone::new(cfg, &mut rng).expect("backbone ok");
629        let img = vec![0.3f32; 3 * 32 * 32];
630        let out = bb.forward(&img).expect("forward ok");
631        assert_eq!(out.cls.len(), e, "CLS must be [embed_dim]");
632        assert_eq!(
633            out.patches.len(),
634            n_patches * e,
635            "patches must be [n_patches, e]"
636        );
637        assert_eq!(out.n_patches, n_patches);
638        assert!(out.cls.iter().all(|v| v.is_finite()));
639        assert!(out.patches.iter().all(|v| v.is_finite()));
640    }
641
642    // ── Head: (f) prototype logits shape + softmax sums to 1 ──────────────────────
643
644    #[test]
645    fn head_prototype_logits_shape_and_softmax() {
646        let head = make_head(2, 128);
647        let mut rng = LcgRng::new(3);
648        let mut x = vec![0.0f32; 32];
649        rng.fill_normal(&mut x);
650        let logits = head.forward(&x).expect("ok");
651        assert_eq!(logits.len(), 128, "prototype logits must be [n_prototypes]");
652        let p = student_softmax(&logits, 0.1).expect("ok");
653        let sum: f32 = p.iter().sum();
654        assert!((sum - 1.0).abs() < 1e-5, "softmax must sum to 1; got {sum}");
655        // Logits are cosine·gain, so each is within [-gain, gain] = [-1, 1].
656        for &l in &logits {
657            assert!(
658                (-1.0 - 1e-4..=1.0 + 1e-4).contains(&l),
659                "logit out of cosine range: {l}"
660            );
661        }
662    }
663
664    // ── (a) EMA update moves teacher strictly toward student ──────────────────────
665
666    #[test]
667    fn ema_update_moves_teacher_toward_student() {
668        let mut teacher = make_head(10, 64);
669        let student = make_head(20, 64); // different seed ⇒ different params
670        let before = l2(&teacher.flatten(), &student.flatten());
671        assert!(before > 0.0, "teacher and student must start apart");
672        teacher.ema_update(&student, 0.9).expect("ema ok");
673        let after = l2(&teacher.flatten(), &student.flatten());
674        assert!(
675            after < before,
676            "EMA must reduce ‖θ_t − θ_s‖: before={before}, after={after}"
677        );
678        // For momentum m, the distance scales by exactly m.
679        assert!(
680            (after - 0.9 * before).abs() < 1e-3 * before.max(1.0),
681            "EMA distance should scale by m=0.9: after={after}, 0.9·before={}",
682            0.9 * before
683        );
684    }
685
686    #[test]
687    fn ema_update_shape_mismatch_errors() {
688        let mut teacher = make_head(10, 64);
689        let other = make_head(11, 32); // different n_prototypes
690        let r = teacher.ema_update(&other, 0.9);
691        assert!(matches!(r, Err(VisionError::Internal(_))));
692    }
693
694    // ── (b) DINO loss ≥ 0, = 0 only when distributions match ──────────────────────
695
696    #[test]
697    fn dino_loss_nonnegative() {
698        let mut rng = LcgRng::new(30);
699        for _ in 0..20 {
700            let mut sl = vec![0.0f32; 16];
701            let mut tl = vec![0.0f32; 16];
702            rng.fill_normal(&mut sl);
703            rng.fill_normal(&mut tl);
704            let l = dino_loss(&sl, &tl, &[], 0.1, 0.04).expect("ok");
705            assert!(l >= -1e-6, "DINO loss must be ≥ 0; got {l}");
706        }
707    }
708
709    #[test]
710    fn dino_loss_minimised_when_student_matches_teacher() {
711        // When the student distribution equals the teacher distribution, the
712        // cross-entropy equals the teacher entropy (its minimum over p_s).
713        // A one-hot teacher (entropy 0) ⇒ loss → 0 as the student concentrates.
714        let teacher_logits = vec![20.0f32, -20.0, -20.0, -20.0]; // ~one-hot at 0
715        let student_logits = vec![20.0f32, -20.0, -20.0, -20.0];
716        // Same temperature so distributions coincide.
717        let p_t = teacher_softmax(&teacher_logits, &[], 0.1).expect("ok");
718        let p_s = student_softmax(&student_logits, 0.1).expect("ok");
719        let h_self = cross_entropy(&p_t, &p_s).expect("ok");
720        assert!(
721            h_self < 1e-3,
722            "matched ~one-hot dists give ≈0 loss; got {h_self}"
723        );
724
725        // A mismatched student must give a strictly larger loss.
726        let student_bad = vec![-20.0f32, 20.0, -20.0, -20.0]; // peaks elsewhere
727        let p_bad = student_softmax(&student_bad, 0.1).expect("ok");
728        let h_bad = cross_entropy(&p_t, &p_bad).expect("ok");
729        assert!(
730            h_bad > h_self + 1.0,
731            "mismatched student must raise the loss: self={h_self}, bad={h_bad}"
732        );
733    }
734
735    #[test]
736    fn cross_entropy_equals_entropy_at_self() {
737        // H(p, p) == entropy(p) for any distribution p.
738        let logits = vec![1.0f32, 0.3, -0.5, 2.0, -1.0];
739        let p = student_softmax(&logits, 1.0).expect("ok");
740        let ce = cross_entropy(&p, &p).expect("ok");
741        let ent = entropy(&p);
742        assert!(
743            (ce - ent).abs() < 1e-5,
744            "H(p,p) must equal entropy(p): {ce} vs {ent}"
745        );
746    }
747
748    // ── (c) Centering keeps running teacher-output mean near 0 ────────────────────
749
750    #[test]
751    fn centering_drives_mean_near_zero() {
752        // Repeatedly feed the same biased batch; the centred logits' mean must
753        // shrink toward 0 as the centre converges to the batch mean.
754        let dim = 8;
755        let mut buf = CenteringBuffer::new(dim, 0.9);
756        // Biased teacher logits: every sample equals the same vector with a
757        // strong offset on dim 0.
758        let base: Vec<f32> = (0..dim).map(|k| if k == 0 { 5.0 } else { 0.1 }).collect();
759        let batch = 4;
760        let mut flat = Vec::new();
761        for _ in 0..batch {
762            flat.extend_from_slice(&base);
763        }
764
765        // Many updates ⇒ centre → batch mean (== base here).
766        for _ in 0..400 {
767            buf.update(&flat).expect("ok");
768        }
769        // Centred logits: base − centre ≈ 0.
770        let centred_mean: f32 = base
771            .iter()
772            .zip(buf.center.iter())
773            .map(|(&g, &c)| (g - c).abs())
774            .sum::<f32>()
775            / dim as f32;
776        assert!(
777            centred_mean < 1e-2,
778            "centering should drive (g − c) mean ≈ 0; got {centred_mean}"
779        );
780    }
781
782    #[test]
783    fn centering_update_bad_shape_errors() {
784        let mut buf = CenteringBuffer::new(8, 0.9);
785        let r = buf.update(&[0.0f32; 7]); // 7 not a multiple of 8
786        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
787    }
788
789    // ── (d) Lower teacher temperature sharpens (lowers entropy) ───────────────────
790
791    #[test]
792    fn lower_teacher_temperature_sharpens_distribution() {
793        let logits = vec![2.0f32, 1.0, 0.5, -0.5, -1.0, 0.2];
794        let p_hot = teacher_softmax(&logits, &[], 0.04).expect("ok"); // sharp
795        let p_soft = teacher_softmax(&logits, &[], 0.5).expect("ok"); // soft
796        let h_hot = entropy(&p_hot);
797        let h_soft = entropy(&p_soft);
798        assert!(
799            h_hot < h_soft,
800            "lower τ_t must lower entropy (sharper): H(0.04)={h_hot} vs H(0.5)={h_soft}"
801        );
802        // And the sharper distribution must have a larger peak probability.
803        let max_hot = p_hot.iter().cloned().fold(0.0f32, f32::max);
804        let max_soft = p_soft.iter().cloned().fold(0.0f32, f32::max);
805        assert!(max_hot > max_soft, "sharper dist must have a higher peak");
806    }
807
808    // ── (e) Nudging student toward teacher drives the loss DOWN ───────────────────
809
810    #[test]
811    fn nudging_student_toward_teacher_lowers_loss() {
812        // Simulate one gradient-free optimisation step: move student logits a
813        // fraction of the way toward the teacher logits and check the DINO loss
814        // (with matched temperatures so the target is well-defined) decreases.
815        let teacher_logits = vec![1.5f32, -0.5, 0.7, -1.2, 0.3, 0.9];
816        let student_before = vec![-1.0f32, 0.8, -0.3, 1.1, -0.6, 0.0];
817        let tau = 0.1;
818
819        let loss_before = dino_loss(&student_before, &teacher_logits, &[], tau, tau).expect("ok");
820
821        // Nudge 60% toward the teacher.
822        let alpha = 0.6f32;
823        let student_after: Vec<f32> = student_before
824            .iter()
825            .zip(teacher_logits.iter())
826            .map(|(&s, &t)| s + alpha * (t - s))
827            .collect();
828        let loss_after = dino_loss(&student_after, &teacher_logits, &[], tau, tau).expect("ok");
829
830        assert!(
831            loss_after < loss_before,
832            "moving the student toward the teacher must lower the loss: before={loss_before}, after={loss_after}"
833        );
834    }
835
836    #[test]
837    fn two_views_loss_decreases_when_student_aligns() {
838        // Two augmented "views": teacher sees view A, student sees view B. We
839        // emulate the head outputs as logits and verify that aligning the
840        // student logits toward the teacher target (one gradient-free step)
841        // reduces the cross-view DINO loss.
842        let head = make_head(40, 32);
843        let mut rng = LcgRng::new(41);
844        let mut view_a = vec![0.0f32; 32];
845        let mut view_b = vec![0.0f32; 32];
846        rng.fill_normal(&mut view_a);
847        rng.fill_normal(&mut view_b);
848
849        let teacher_logits = head.forward(&view_a).expect("ok");
850        let student_logits = head.forward(&view_b).expect("ok");
851        let tau = 0.1;
852        let loss_before = dino_loss(&student_logits, &teacher_logits, &[], tau, tau).expect("ok");
853
854        let nudged: Vec<f32> = student_logits
855            .iter()
856            .zip(teacher_logits.iter())
857            .map(|(&s, &t)| s + 0.5 * (t - s))
858            .collect();
859        let loss_after = dino_loss(&nudged, &teacher_logits, &[], tau, tau).expect("ok");
860        assert!(
861            loss_after < loss_before,
862            "aligning student to teacher across views must lower loss: {loss_before} → {loss_after}"
863        );
864    }
865
    // ── Softmax temperature guards (student and teacher) ──────────────────────────
867
868    #[test]
869    fn nonpositive_temperature_errors() {
870        let r = student_softmax(&[1.0, 2.0], 0.0);
871        assert!(matches!(r, Err(VisionError::NonPositiveTemperature(_))));
872        let r2 = teacher_softmax(&[1.0, 2.0], &[], -0.1);
873        assert!(matches!(r2, Err(VisionError::NonPositiveTemperature(_))));
874    }
875
876    // ── iBOT masked-patch term ────────────────────────────────────────────────────
877
878    #[test]
879    fn ibot_loss_only_counts_masked_patches() {
880        let n_patches = 4;
881        let n_proto = 6;
882        let mut rng = LcgRng::new(50);
883        let mut s = vec![0.0f32; n_patches * n_proto];
884        let mut t = vec![0.0f32; n_patches * n_proto];
885        rng.fill_normal(&mut s);
886        rng.fill_normal(&mut t);
887
888        // No patch masked ⇒ loss is exactly 0.
889        let none = vec![false; n_patches];
890        let l0 = ibot_loss(&s, &t, &none, &[], n_proto, 0.1, 0.04).expect("ok");
891        assert_eq!(l0, 0.0, "no masked patches ⇒ zero iBOT loss");
892
893        // Mask patches 0 and 2 ⇒ loss equals the mean of their per-patch losses.
894        let mut mask = vec![false; n_patches];
895        mask[0] = true;
896        mask[2] = true;
897        let l = ibot_loss(&s, &t, &mask, &[], n_proto, 0.1, 0.04).expect("ok");
898        let l_p0 = dino_loss(&s[0..n_proto], &t[0..n_proto], &[], 0.1, 0.04).expect("ok");
899        let l_p2 = dino_loss(
900            &s[2 * n_proto..3 * n_proto],
901            &t[2 * n_proto..3 * n_proto],
902            &[],
903            0.1,
904            0.04,
905        )
906        .expect("ok");
907        let expected = 0.5 * (l_p0 + l_p2);
908        assert!(
909            (l - expected).abs() < 1e-5,
910            "iBOT loss must average masked-patch losses: {l} vs {expected}"
911        );
912        assert!(l >= 0.0, "iBOT loss must be ≥ 0");
913    }
914
915    #[test]
916    fn ibot_loss_nudging_masked_student_lowers_loss() {
917        // Aligning the masked student patch toward the teacher patch reduces it.
918        let n_proto = 5;
919        let teacher = vec![
920            // patch 0 (masked target)
921            1.2f32, -0.4, 0.6, -1.0, 0.2, // patch 1
922            0.1, 0.1, 0.1, 0.1, 0.1,
923        ];
924        let student = vec![
925            -1.0f32, 0.7, -0.2, 1.0, -0.5, // patch 1
926            0.0, 0.0, 0.0, 0.0, 0.0,
927        ];
928        let mask = vec![true, false];
929        let tau = 0.1;
930        let before = ibot_loss(&student, &teacher, &mask, &[], n_proto, tau, tau).expect("ok");
931
932        let mut nudged = student.clone();
933        for k in 0..n_proto {
934            nudged[k] += 0.6 * (teacher[k] - student[k]);
935        }
936        let after = ibot_loss(&nudged, &teacher, &mask, &[], n_proto, tau, tau).expect("ok");
937        assert!(
938            after < before,
939            "nudging masked student patch toward teacher must lower iBOT loss: {before} → {after}"
940        );
941    }
942
943    #[test]
944    fn ibot_loss_bad_shape_errors() {
945        let mask = vec![true, false];
946        let r = ibot_loss(&[0.0f32; 5], &[0.0f32; 10], &mask, &[], 5, 0.1, 0.04);
947        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
948    }
949
950    // ── Head batch ────────────────────────────────────────────────────────────────
951
952    #[test]
953    fn head_forward_batch_matches_single() {
954        let head = make_head(60, 32);
955        let mut rng = LcgRng::new(61);
956        let batch = 3;
957        let mut x = vec![0.0f32; batch * 32];
958        rng.fill_normal(&mut x);
959        let all = head.forward_batch(&x).expect("ok");
960        let k = head.n_prototypes();
961        for b in 0..batch {
962            let single = head.forward(&x[b * 32..(b + 1) * 32]).expect("ok");
963            for (j, &v) in single.iter().enumerate() {
964                assert!(
965                    (all[b * k + j] - v).abs() < 1e-6,
966                    "batch vs single mismatch at b={b}, j={j}"
967                );
968            }
969        }
970    }
971
972    #[test]
973    fn head_dimension_mismatch_errors() {
974        let head = make_head(70, 32);
975        let r = head.forward(&[0.0f32; 31]);
976        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
977    }
978
979    #[test]
980    fn head_zero_prototypes_errors() {
981        let mut rng = LcgRng::new(80);
982        let r = DinoHead::new(32, 64, 16, 0, &mut rng);
983        assert!(matches!(r, Err(VisionError::InvalidProjDim(0))));
984    }
985}