Skip to main content

oxicuda_ssl/masked/
data2vec.rs

1//! data2vec — Baevski et al. 2022, ICML.
2//!
3//! Unified self-supervised learning via **teacher-student masked prediction**:
4//!
5//! 1. A **teacher** network (EMA of the student) encodes the **full, unmasked**
6//!    input and produces target representations.
7//! 2. The **student** encoder receives the **masked** input and predicts the
8//!    teacher's representations at masked positions.
//! 3. The loss is the smooth-L1 (Huber) divergence between student predictions and
//!    (optionally) per-dimension RMS-normalised teacher targets, averaged over masked tokens only.
11//!
12//! ```text
13//!  θ_teacher ← m · θ_teacher + (1−m) · θ_student       [EMA update]
14//!  target_j  ← target_j / (‖target[:,j]‖₂ + ε)         [per-dim batch norm]
15//!  L          = mean huber(student_pred − target, β)     [masked positions only]
16//! ```
17//!
18//! Reference: "data2vec: A General Framework for Self-supervised Learning in
19//! Speech, Vision and Language", Baevski et al., ICML 2022.
20
21use crate::error::{SslError, SslResult};
22use crate::handle::LcgRng;
23
24// ─── Configuration ────────────────────────────────────────────────────────────
25
/// Hyper-parameters for the data2vec training objective.
///
/// Implements `PartialEq` so that configurations can be compared directly
/// (for example with `assert_eq!` in tests or when checking whether a
/// checkpointed configuration matches the current one).
#[derive(Debug, Clone, PartialEq)]
pub struct Data2VecConfig {
    /// Fraction of tokens to mask (default 0.65).
    pub mask_ratio: f32,
    /// EMA coefficient for the teacher network (default 0.999).
    pub momentum: f32,
    /// Huber loss threshold β (default 2.0).
    pub beta: f32,
    /// Whether to rescale teacher representations by their per-feature RMS
    /// (computed across the masked tokens of a sample) before computing the
    /// loss (default `true`).
    pub normalize_targets: bool,
    /// Number of teacher layer outputs to average for target computation
    /// (default 1).
    ///
    /// NOTE(review): none of the functions in this module read this field;
    /// confirm where (or whether) values > 1 are honoured.
    pub top_k_average: usize,
}
42
43impl Default for Data2VecConfig {
44    fn default() -> Self {
45        Self {
46            mask_ratio: 0.65,
47            momentum: 0.999,
48            beta: 2.0,
49            normalize_targets: true,
50            top_k_average: 1,
51        }
52    }
53}
54
55// ─── Result types ─────────────────────────────────────────────────────────────
56
/// Output of a single data2vec loss computation.
///
/// Implements `PartialEq` so results can be compared directly in tests and
/// regression checks.
#[derive(Debug, Clone, PartialEq)]
pub struct Data2VecResult {
    /// Mean Huber loss over all masked positions and feature dimensions
    /// (`0.0` when no position was masked).
    pub loss: f32,
    /// Number of token positions that were masked (contributed to the loss).
    pub n_masked: usize,
    /// Not meaningful for a regression objective; the loss functions in this
    /// module always set it to `0.0`.
    pub accuracy_at_1: f32,
}
67
68// ─── Teacher-state ────────────────────────────────────────────────────────────
69
/// Mutable state that tracks the teacher EMA parameter vector and training step.
///
/// The teacher is initialised to a copy of the online (student) parameters and
/// updated each training step via:
/// ```text
///   θ_teacher ← m · θ_teacher + (1−m) · θ_online
/// ```
///
/// Implements `PartialEq` so two states (e.g. before/after a checkpoint
/// round-trip) can be compared directly.
#[derive(Debug, Clone, PartialEq)]
pub struct Data2VecState {
    /// EMA teacher parameter vector (flat, same length as the student).
    pub teacher_params: Vec<f32>,
    /// Training-step counter; starts at 0 and is incremented by every
    /// successful [`Self::update_teacher`] call.
    pub step: usize,
}
84
85impl Data2VecState {
86    /// Initialise state by cloning the current online (student) parameters.
87    ///
88    /// The teacher starts as an exact copy of the student so that the very first
89    /// EMA update moves from a well-defined position.
90    #[must_use]
91    pub fn new(online_params: &[f32]) -> Self {
92        Self {
93            teacher_params: online_params.to_vec(),
94            step: 0,
95        }
96    }
97
98    /// Apply the EMA update `θ_teacher ← m·θ_teacher + (1−m)·θ_online` and
99    /// increment the step counter.
100    ///
101    /// # Errors
102    /// - [`SslError::InvalidMomentum`] when `momentum` is outside `[0, 1]` or
103    ///   non-finite.
104    /// - [`SslError::DimensionMismatch`] when `online_params` has a different
105    ///   length than the stored teacher vector.
106    pub fn update_teacher(&mut self, online_params: &[f32], momentum: f32) -> SslResult<()> {
107        if !(momentum.is_finite() && (0.0..=1.0).contains(&momentum)) {
108            return Err(SslError::InvalidMomentum { momentum });
109        }
110        if self.teacher_params.len() != online_params.len() {
111            return Err(SslError::DimensionMismatch {
112                expected: self.teacher_params.len(),
113                got: online_params.len(),
114            });
115        }
116        let one_minus_m = 1.0 - momentum;
117        for (t, &o) in self.teacher_params.iter_mut().zip(online_params.iter()) {
118            *t = momentum * *t + one_minus_m * o;
119        }
120        self.step += 1;
121        Ok(())
122    }
123
124    /// Reference to the current teacher parameter vector.
125    #[must_use]
126    #[inline]
127    pub fn teacher(&self) -> &[f32] {
128        &self.teacher_params
129    }
130}
131
132// ─── Core primitives ──────────────────────────────────────────────────────────
133
/// Per-element Huber (smooth-L1) loss, averaged over all elements.
///
/// For each element `x = prediction − target`:
/// ```text
///   huber(x, β) = 0.5·x²/β   if |x| < β
///               = |x| − β/2  otherwise
/// ```
///
/// # Arguments
/// * `predictions`, `targets` — equally long slices of values to compare.
/// * `beta` — threshold between the quadratic and the linear regime.
///
/// # Returns
/// The mean loss over all elements. Special cases:
/// - Empty input or slices of different length yield `0.0`.
/// - A non-positive `beta` degenerates to the plain L1 (mean absolute) loss.
///   This is the natural limit for `β → 0`; it also prevents a negative `β`
///   from adding a spurious `|β|/2` offset to every element (which would make
///   the loss non-zero even for a perfect prediction).
/// - Non-finite inputs or a NaN `beta` propagate through the arithmetic, so
///   the result may be NaN or infinite. The function never panics.
#[must_use]
pub fn huber_loss(predictions: &[f32], targets: &[f32], beta: f32) -> f32 {
    if predictions.is_empty() || predictions.len() != targets.len() {
        return 0.0;
    }

    let beta = f64::from(beta);
    // Residuals are formed in f32 (as the inputs are) and accumulated in f64
    // to limit rounding error over long slices.
    let residuals = predictions
        .iter()
        .zip(targets)
        .map(|(&p, &t)| f64::from(p - t));

    let total: f64 = if beta <= 0.0 {
        // Degenerate threshold: there is no quadratic region, use pure L1.
        residuals.map(f64::abs).sum()
    } else {
        let half_beta = beta / 2.0;
        let inv_beta = 1.0 / beta;
        residuals
            .map(|x| {
                let ax = x.abs();
                if ax < beta {
                    0.5 * x * x * inv_beta
                } else {
                    ax - half_beta
                }
            })
            .sum()
    };

    (total / predictions.len() as f64) as f32
}
167
/// Normalise teacher representations **along the token dimension** in-place.
///
/// `targets` is interpreted as a row-major `[n_tokens × dim]` matrix. For each
/// feature dimension `d ∈ [0, dim)` the normalisation is:
/// ```text
///   rms_d = sqrt( Σᵢ target[i·dim + d]² / n_tokens )
///   target[i·dim + d] /= (rms_d + ε),   ε = 1e-8
/// ```
///
/// Only the provided slice is considered, so passing the rows of the masked
/// tokens restricts the statistics to those tokens.
///
/// The call is a silent no-op when `n_tokens == 0`, when `dim == 0`, when
/// `targets.len() != n_tokens × dim`, or when `n_tokens × dim` does not fit in
/// a `usize`. The product is computed with overflow checking so that absurd
/// shape arguments can neither panic nor wrap around to a length that happens
/// to match the slice.
pub fn normalize_teacher_targets(targets: &mut [f32], n_tokens: usize, dim: usize) {
    const EPS: f32 = 1e-8;

    if n_tokens == 0 || dim == 0 || n_tokens.checked_mul(dim) != Some(targets.len()) {
        return;
    }

    let n = n_tokens as f32;
    for d in 0..dim {
        // Column `d` of the row-major matrix: elements d, d + dim, d + 2·dim, …
        let sum_sq = targets
            .iter()
            .skip(d)
            .step_by(dim)
            .fold(0.0_f32, |acc, &v| acc + v * v);
        let scale = 1.0 / ((sum_sq / n).sqrt() + EPS);
        for v in targets.iter_mut().skip(d).step_by(dim) {
            *v *= scale;
        }
    }
}
198
199/// Generate a boolean mask of length `n_tokens` with exactly
200/// `floor(n_tokens × mask_ratio)` positions set to `true` (= masked).
201///
202/// The selection is performed via a Fisher-Yates partial shuffle over an index
203/// array, mirroring the approach in [`crate::masked::mae::random_patch_mask`],
204/// but produces a `Vec<bool>` directly keyed to token indices.
205///
206/// # Errors
207/// - [`SslError::EmptyInput`] when `n_tokens == 0`.
208/// - [`SslError::InvalidMaskRatio`] when `mask_ratio` is outside `[0, 1)` or
209///   non-finite.
210pub fn data2vec_mask(n_tokens: usize, mask_ratio: f32, rng: &mut LcgRng) -> SslResult<Vec<bool>> {
211    if n_tokens == 0 {
212        return Err(SslError::EmptyInput);
213    }
214    if !(mask_ratio.is_finite() && (0.0..1.0).contains(&mask_ratio)) {
215        return Err(SslError::InvalidMaskRatio { ratio: mask_ratio });
216    }
217    let n_mask = (n_tokens as f32 * mask_ratio) as usize;
218    // Build an index pool and shuffle the first n_mask positions.
219    let mut indices: Vec<usize> = (0..n_tokens).collect();
220    rng.shuffle(&mut indices);
221    let mut mask = vec![false; n_tokens];
222    for &idx in indices.iter().take(n_mask) {
223        mask[idx] = true;
224    }
225    Ok(mask)
226}
227
228// ─── Loss computation ─────────────────────────────────────────────────────────
229
230/// Compute the data2vec loss for a single sample.
231///
232/// Implements the full algorithm:
233/// 1. Validate shapes.
234/// 2. Optionally L2-normalise teacher representations across masked tokens per
235///    feature dimension.
236/// 3. Compute mean Huber loss between student predictions and normalised teacher
237///    targets at **masked positions only**.
238///
239/// # Arguments
240/// * `student_pred`   — `[n_tokens × dim]` student predictions (row-major).
241/// * `teacher_repr`   — `[n_tokens × dim]` teacher representations (row-major).
242/// * `mask`           — `[n_tokens]` boolean vector; `true` = masked position.
243/// * `n_tokens`, `dim` — spatial and channel dimensions.
244/// * `config`         — data2vec hyper-parameters.
245///
246/// # Errors
247/// - [`SslError::EmptyInput`] when `n_tokens == 0` or `dim == 0`.
248/// - [`SslError::DimensionMismatch`] when any buffer has the wrong length.
249/// - [`SslError::EmptyInput`] when no tokens are masked (graceful; returns 0.0
250///   loss inside `Data2VecResult` rather than erroring, since the caller may
251///   legitimately supply an all-visible batch during warm-up).
252pub fn data2vec_loss(
253    student_pred: &[f32],
254    teacher_repr: &[f32],
255    mask: &[bool],
256    n_tokens: usize,
257    dim: usize,
258    config: &Data2VecConfig,
259) -> SslResult<Data2VecResult> {
260    // ── 1. Shape validation ───────────────────────────────────────────────────
261    if n_tokens == 0 || dim == 0 {
262        return Err(SslError::EmptyInput);
263    }
264    let expected = n_tokens * dim;
265    if student_pred.len() != expected {
266        return Err(SslError::DimensionMismatch {
267            expected,
268            got: student_pred.len(),
269        });
270    }
271    if teacher_repr.len() != expected {
272        return Err(SslError::DimensionMismatch {
273            expected,
274            got: teacher_repr.len(),
275        });
276    }
277    if mask.len() != n_tokens {
278        return Err(SslError::DimensionMismatch {
279            expected: n_tokens,
280            got: mask.len(),
281        });
282    }
283
284    // ── 2. Collect masked-token indices ───────────────────────────────────────
285    let masked_indices: Vec<usize> = (0..n_tokens).filter(|&i| mask[i]).collect();
286    let n_masked = masked_indices.len();
287
288    if n_masked == 0 {
289        // Graceful: no masked tokens → loss is trivially 0.
290        return Ok(Data2VecResult {
291            loss: 0.0,
292            n_masked: 0,
293            accuracy_at_1: 0.0,
294        });
295    }
296
297    // ── 3. Build contiguous masked-token buffers ──────────────────────────────
298    // Collect masked teacher representations into a contiguous buffer so that
299    // normalize_teacher_targets can operate with simple row-major indexing.
300    let mut teacher_masked = Vec::with_capacity(n_masked * dim);
301    let mut student_masked = Vec::with_capacity(n_masked * dim);
302    for &i in &masked_indices {
303        let start = i * dim;
304        let end = start + dim;
305        teacher_masked.extend_from_slice(&teacher_repr[start..end]);
306        student_masked.extend_from_slice(&student_pred[start..end]);
307    }
308
309    // ── 4. Optional target normalisation ─────────────────────────────────────
310    if config.normalize_targets {
311        normalize_teacher_targets(&mut teacher_masked, n_masked, dim);
312    }
313
314    // ── 5. Huber loss over masked positions ───────────────────────────────────
315    let loss = huber_loss(&student_masked, &teacher_masked, config.beta);
316
317    Ok(Data2VecResult {
318        loss,
319        n_masked,
320        accuracy_at_1: 0.0,
321    })
322}
323
324/// Compute the mean data2vec loss over a batch of samples.
325///
326/// Buffers are laid out batch-first:
327/// - `student_preds` : `[batch_size × n_tokens × dim]`
328/// - `teacher_reprs` : `[batch_size × n_tokens × dim]`
329/// - `masks`         : `[batch_size × n_tokens]` boolean
330///
331/// Each sample's loss is computed independently with
332/// [`data2vec_loss`] and the results are averaged.
333///
334/// # Errors
335/// Propagates all errors from [`data2vec_loss`] together with additional shape
336/// checks for the batch dimension.
337pub fn data2vec_batch_loss(
338    student_preds: &[f32],
339    teacher_reprs: &[f32],
340    masks: &[bool],
341    batch_size: usize,
342    n_tokens: usize,
343    dim: usize,
344    config: &Data2VecConfig,
345) -> SslResult<f32> {
346    if batch_size == 0 {
347        return Err(SslError::EmptyInput);
348    }
349    let sample_len = n_tokens * dim;
350    let expected_feat = batch_size * sample_len;
351    let expected_mask = batch_size * n_tokens;
352
353    if student_preds.len() != expected_feat {
354        return Err(SslError::DimensionMismatch {
355            expected: expected_feat,
356            got: student_preds.len(),
357        });
358    }
359    if teacher_reprs.len() != expected_feat {
360        return Err(SslError::DimensionMismatch {
361            expected: expected_feat,
362            got: teacher_reprs.len(),
363        });
364    }
365    if masks.len() != expected_mask {
366        return Err(SslError::DimensionMismatch {
367            expected: expected_mask,
368            got: masks.len(),
369        });
370    }
371
372    let mut total_loss = 0.0_f64;
373    for b in 0..batch_size {
374        let feat_start = b * sample_len;
375        let feat_end = feat_start + sample_len;
376        let mask_start = b * n_tokens;
377        let mask_end = mask_start + n_tokens;
378
379        let result = data2vec_loss(
380            &student_preds[feat_start..feat_end],
381            &teacher_reprs[feat_start..feat_end],
382            &masks[mask_start..mask_end],
383            n_tokens,
384            dim,
385            config,
386        )?;
387        total_loss += result.loss as f64;
388    }
389    Ok((total_loss / batch_size as f64) as f32)
390}
391
392// ─── Tests ────────────────────────────────────────────────────────────────────
393
/// Unit tests for the data2vec configuration, primitives, teacher state and
/// loss functions.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::LcgRng;

    // ── 1. Config defaults ────────────────────────────────────────────────────

    #[test]
    fn config_defaults() {
        let cfg = Data2VecConfig::default();
        assert!((cfg.mask_ratio - 0.65).abs() < 1e-7);
        assert!((cfg.momentum - 0.999).abs() < 1e-7);
        assert!((cfg.beta - 2.0).abs() < 1e-7);
        assert!(cfg.normalize_targets);
        assert_eq!(cfg.top_k_average, 1);
    }

    // ── 2. Huber loss — small error (|x| < β) ────────────────────────────────
    // huber(0.5, β=2) = 0.5 * 0.5² / 2.0 = 0.5 * 0.25 / 2.0 = 0.0625

    #[test]
    fn huber_loss_small_error() {
        let pred = vec![0.5_f32];
        let tgt = vec![0.0_f32];
        let loss = huber_loss(&pred, &tgt, 2.0);
        let expected = 0.5_f32 * 0.25_f32 / 2.0_f32; // 0.0625
        assert!(
            (loss - expected).abs() < 1e-6,
            "loss={loss} expected={expected}"
        );
    }

    // ── 3. Huber loss — large error (|x| >= β) ───────────────────────────────
    // huber(3.0, β=2) = |3| − β/2 = 3 − 1 = 2.0

    #[test]
    fn huber_loss_large_error() {
        let pred = vec![3.0_f32];
        let tgt = vec![0.0_f32];
        let loss = huber_loss(&pred, &tgt, 2.0);
        assert!((loss - 2.0).abs() < 1e-6, "loss={loss}");
    }

    // ── 4. Huber loss — zero when pred == target ──────────────────────────────

    #[test]
    fn huber_loss_zero() {
        let v = vec![1.5_f32, -0.7, 3.2, 0.0];
        let loss = huber_loss(&v, &v, 2.0);
        assert!(loss.abs() < 1e-7, "loss={loss}");
    }

    // ── Extra: Huber loss — empty or mismatched slices yield 0.0 ─────────────

    #[test]
    fn huber_loss_degenerate_inputs() {
        let empty: Vec<f32> = Vec::new();
        assert!(huber_loss(&empty, &empty, 2.0).abs() < 1e-7);
        assert!(huber_loss(&[1.0, 2.0], &[1.0], 2.0).abs() < 1e-7);
    }

    // ── 5. Mask has exact number of masked tokens ─────────────────────────────

    #[test]
    fn mask_exact_ratio() {
        let mut rng = LcgRng::new(42);
        let mask = data2vec_mask(100, 0.65, &mut rng).expect("data2vec_mask should succeed");
        let n_masked = mask.iter().filter(|&&v| v).count();
        assert_eq!(n_masked, 65, "expected 65 masked, got {n_masked}");
    }

    // ── 6. Mask length equals n_tokens ───────────────────────────────────────

    #[test]
    fn mask_length() {
        let mut rng = LcgRng::new(7);
        let mask = data2vec_mask(196, 0.75, &mut rng).expect("data2vec_mask should succeed");
        assert_eq!(mask.len(), 196);
    }

    // ── 7. Loss is zero when student equals teacher; only masked tokens count ─

    #[test]
    fn data2vec_loss_only_masked() {
        let n_tokens = 10;
        let dim = 4;
        // Construct a mask where only token 3 and 7 are masked.
        let mut mask = vec![false; n_tokens];
        mask[3] = true;
        mask[7] = true;

        // Teacher = student everywhere.
        let repr: Vec<f32> = (0..n_tokens * dim).map(|i| (i as f32) * 0.1).collect();
        let cfg = Data2VecConfig {
            normalize_targets: false,
            ..Data2VecConfig::default()
        };
        let result = data2vec_loss(&repr, &repr, &mask, n_tokens, dim, &cfg)
            .expect("data2vec_loss should succeed");
        assert!(result.loss.abs() < 1e-6, "loss={}", result.loss);
        assert_eq!(result.n_masked, 2);
        assert!((result.accuracy_at_1 - 0.0).abs() < 1e-7);
    }

    // ── 8. All-false mask (no masked tokens) → graceful zero loss ─────────────

    #[test]
    fn data2vec_loss_no_masked_tokens() {
        let n_tokens = 8;
        let dim = 3;
        let mask = vec![false; n_tokens];
        let v = vec![0.0_f32; n_tokens * dim];
        let cfg = Data2VecConfig::default();
        let result = data2vec_loss(&v, &v, &mask, n_tokens, dim, &cfg)
            .expect("data2vec_loss should succeed");
        assert_eq!(result.n_masked, 0);
        assert!(result.loss.abs() < 1e-7);
    }

    // ── 9. normalize_teacher_targets rescales each column to unit RMS ─────────

    #[test]
    fn normalize_targets_reduces_large_values() {
        let n_tokens = 4;
        let dim = 2;
        // All values are large and equal.
        let mut targets = vec![100.0_f32; n_tokens * dim];
        normalize_teacher_targets(&mut targets, n_tokens, dim);
        // rms_d = sqrt(Σ 100² / 4) = sqrt(10000) = 100, scale = 1/100, so every
        // value must land on 1.0 (up to rounding).
        for &v in &targets {
            assert!((v - 1.0).abs() < 1e-5, "value after norm={v}");
        }
    }

    // ── 10. State init: teacher equals online params ──────────────────────────

    #[test]
    fn state_init_matches_online() {
        let online = vec![0.1_f32, 0.5, -0.3, 1.2];
        let state = Data2VecState::new(&online);
        assert_eq!(state.teacher(), online.as_slice());
        assert_eq!(state.step, 0);
    }

    // ── 11. EMA update with momentum=0 copies online exactly ─────────────────

    #[test]
    fn state_update_closer_to_online_m0() {
        let teacher_init = vec![1.0_f32, 2.0, 3.0];
        let online = vec![10.0_f32, 20.0, 30.0];
        let mut state = Data2VecState::new(&teacher_init);
        state
            .update_teacher(&online, 0.0)
            .expect("update_teacher should succeed");
        // With m=0: teacher = 0*old + 1*online = online.
        for (&t, &o) in state.teacher().iter().zip(online.iter()) {
            assert!((t - o).abs() < 1e-6, "teacher={t} online={o}");
        }
        assert_eq!(state.step, 1);
    }

    // ── 12. EMA update with momentum=1.0 leaves teacher unchanged ────────────

    #[test]
    fn state_update_m1_unchanged() {
        let teacher_init = vec![5.0_f32, -3.0, 0.7];
        let online = vec![0.0_f32, 0.0, 0.0];
        let mut state = Data2VecState::new(&teacher_init);
        let expected = state.teacher().to_vec();
        state
            .update_teacher(&online, 1.0)
            .expect("update_teacher should succeed");
        // With m=1: teacher = 1*old + 0*online = old.
        for (&t, &e) in state.teacher().iter().zip(expected.iter()) {
            assert!((t - e).abs() < 1e-6, "teacher={t} expected={e}");
        }
        // The step counter advances even when the values do not move.
        assert_eq!(state.step, 1);
    }

    // ── 13. Batch loss with batch_size=1 matches single-sample loss ───────────

    #[test]
    fn batch_loss_matches_single() {
        let n_tokens = 6;
        let dim = 4;
        let mut rng = LcgRng::new(99);

        let mut student = vec![0.0_f32; n_tokens * dim];
        let mut teacher = vec![0.0_f32; n_tokens * dim];
        rng.fill_normal(&mut student);
        rng.fill_normal(&mut teacher);

        let mask = data2vec_mask(n_tokens, 0.5, &mut rng).expect("data2vec_mask should succeed");

        let cfg = Data2VecConfig::default();

        let single = data2vec_loss(&student, &teacher, &mask, n_tokens, dim, &cfg)
            .expect("data2vec_loss should succeed")
            .loss;
        let batch = data2vec_batch_loss(&student, &teacher, &mask, 1, n_tokens, dim, &cfg)
            .expect("data2vec_batch_loss should succeed");

        assert!(
            (single - batch).abs() < 1e-5,
            "single={single} batch={batch}"
        );
    }

    // ── 14. Batch loss is finite for random inputs ────────────────────────────

    #[test]
    fn batch_loss_finite() {
        let batch_size = 4;
        let n_tokens = 16;
        let dim = 8;
        let mut rng = LcgRng::new(1337);

        let total_feat = batch_size * n_tokens * dim;
        let mut student = vec![0.0_f32; total_feat];
        let mut teacher = vec![0.0_f32; total_feat];
        rng.fill_normal(&mut student);
        rng.fill_normal(&mut teacher);

        let mut masks = Vec::with_capacity(batch_size * n_tokens);
        for _ in 0..batch_size {
            masks.extend(
                data2vec_mask(n_tokens, 0.65, &mut rng).expect("data2vec_mask should succeed"),
            );
        }

        let cfg = Data2VecConfig::default();
        let loss = data2vec_batch_loss(&student, &teacher, &masks, batch_size, n_tokens, dim, &cfg)
            .expect("data2vec_batch_loss should succeed");

        assert!(loss.is_finite(), "loss={loss}");
        assert!(loss >= 0.0, "loss={loss}");
    }

    // ── Extra: empty batch is rejected ───────────────────────────────────────

    #[test]
    fn batch_loss_rejects_empty_batch() {
        let cfg = Data2VecConfig::default();
        let no_feats: Vec<f32> = Vec::new();
        let no_flags: Vec<bool> = Vec::new();
        assert!(data2vec_batch_loss(&no_feats, &no_feats, &no_flags, 0, 4, 3, &cfg).is_err());
    }

    // ── Extra: invalid mask arguments error ──────────────────────────────────

    #[test]
    fn mask_invalid_ratio_errors() {
        let mut rng = LcgRng::new(1);
        assert!(data2vec_mask(10, 1.0, &mut rng).is_err()); // ratio == 1.0 forbidden
        assert!(data2vec_mask(10, -0.1, &mut rng).is_err());
        assert!(data2vec_mask(10, f32::NAN, &mut rng).is_err());
        assert!(data2vec_mask(0, 0.5, &mut rng).is_err()); // no tokens at all
    }

    // ── Extra: invalid momentum rejected by state update ─────────────────────

    #[test]
    fn state_update_rejects_invalid_momentum() {
        let mut state = Data2VecState::new(&[1.0_f32, 2.0]);
        let online = vec![3.0_f32, 4.0];
        assert!(state.update_teacher(&online, 1.5).is_err());
        assert!(state.update_teacher(&online, -0.1).is_err());
        assert!(state.update_teacher(&online, f32::NAN).is_err());
        // Rejected updates must not touch the state.
        assert_eq!(state.teacher(), &[1.0_f32, 2.0][..]);
        assert_eq!(state.step, 0);
    }

    // ── Extra: normalize_teacher_targets is a no-op on empty input ───────────

    #[test]
    fn normalize_teacher_targets_empty_noop() {
        let mut v: Vec<f32> = vec![];
        normalize_teacher_targets(&mut v, 0, 4); // must not panic
        assert!(v.is_empty());
        let mut v2 = vec![1.0_f32; 8];
        normalize_teacher_targets(&mut v2, 4, 0); // must not panic
        assert_eq!(v2, vec![1.0_f32; 8]);
    }

    // ── Extra: shape errors are reported ─────────────────────────────────────

    #[test]
    fn data2vec_loss_shape_errors() {
        let n = 4;
        let d = 3;
        let cfg = Data2VecConfig::default();
        let good = vec![0.0_f32; n * d];
        let short = vec![0.0_f32; n * d - 1];
        let mask = vec![true; n];
        // Wrong student length.
        assert!(data2vec_loss(&short, &good, &mask, n, d, &cfg).is_err());
        // Wrong teacher length.
        assert!(data2vec_loss(&good, &short, &mask, n, d, &cfg).is_err());
        // Wrong mask length.
        let bad_mask = vec![true; n - 1];
        assert!(data2vec_loss(&good, &good, &bad_mask, n, d, &cfg).is_err());
        // Zero-sized dimensions.
        assert!(data2vec_loss(&good, &good, &mask, 0, d, &cfg).is_err());
        assert!(data2vec_loss(&good, &good, &mask, n, 0, &cfg).is_err());
    }
}