1use crate::{
35 error::{VisionError, VisionResult},
36 handle::LcgRng,
37 patch_embed::{PatchEmbed, PatchEmbedConfig, prepend_cls},
38 vit::vit_block::{gelu_exact, linear},
39 vit::{ViTConfig, ViTEncoder, ViTEncoderConfig},
40};
41
/// Result of [`DinoBackbone::forward`] for a single image.
///
/// `PartialEq` is derived so outputs can be compared directly (e.g. to check
/// that two forward passes agree).
#[derive(Debug, Clone, PartialEq)]
pub struct BackboneOutput {
    /// Encoded CLS token; its length is the backbone's `embed_dim`.
    pub cls: Vec<f32>,
    /// Encoded patch tokens, `n_patches * embed_dim` values (CLS excluded).
    pub patches: Vec<f32>,
    /// Number of patch tokens stored in `patches`.
    pub n_patches: usize,
}
54
55pub struct DinoBackbone {
59 pub config: ViTConfig,
61 patch_embed: PatchEmbed,
62 cls_token: Vec<f32>,
63 pos_embed: Vec<f32>, encoder: ViTEncoder,
65}
66
67impl DinoBackbone {
68 pub fn new(cfg: ViTConfig, rng: &mut LcgRng) -> VisionResult<Self> {
73 let e = cfg.embed_dim;
74 let pe_cfg = PatchEmbedConfig::new(cfg.img_size, cfg.patch_size, cfg.in_chans, e)?;
75 let patch_embed = PatchEmbed::new(pe_cfg, rng);
76
77 let mut cls_token = vec![0.0f32; e];
78 rng.fill_normal(&mut cls_token);
79 for v in &mut cls_token {
80 *v *= 0.02;
81 }
82
83 let seq_len = cfg.n_patches() + 1;
84 let mut pos_embed = vec![0.0f32; seq_len * e];
85 rng.fill_normal(&mut pos_embed);
86 for v in &mut pos_embed {
87 *v *= 0.02;
88 }
89
90 let enc_cfg = ViTEncoderConfig::new(e, cfg.n_heads, cfg.mlp_ratio, cfg.depth)?;
91 let encoder = ViTEncoder::new(enc_cfg, rng)?;
92
93 Ok(Self {
94 config: cfg,
95 patch_embed,
96 cls_token,
97 pos_embed,
98 encoder,
99 })
100 }
101
102 pub fn forward(&self, image: &[f32]) -> VisionResult<BackboneOutput> {
107 let e = self.config.embed_dim;
108 let n_patches = self.config.n_patches();
109
110 let patch_tokens = self.patch_embed.forward(image)?;
111 let mut tokens = prepend_cls(&patch_tokens, &self.cls_token, e)?;
112 for (t, p) in tokens.iter_mut().zip(self.pos_embed.iter()) {
114 *t += p;
115 }
116 let seq_len = n_patches + 1;
117 let encoded = self.encoder.forward(&tokens, seq_len)?;
118
119 let cls = encoded[..e].to_vec();
120 let patches = encoded[e..].to_vec();
121 Ok(BackboneOutput {
122 cls,
123 patches,
124 n_patches,
125 })
126 }
127}
128
/// DINO projection head: a three-layer MLP into a bottleneck that is
/// L2-normalised and scored against learnable prototypes by cosine similarity.
///
/// `Debug` is derived (every field supports it) so heads can be inspected in
/// assertions and logs.
#[derive(Debug, Clone)]
pub struct DinoHead {
    /// Width of the input feature vector.
    in_dim: usize,
    /// Width of both hidden layers.
    hidden_dim: usize,
    /// Width of the normalised bottleneck.
    bottleneck_dim: usize,
    /// Number of prototypes, i.e. number of output logits.
    n_prototypes: usize,
    /// First-layer weights, `hidden_dim * in_dim` values (layout as used by `linear`).
    w1: Vec<f32>,
    /// First-layer bias, `hidden_dim` values.
    b1: Vec<f32>,
    /// Second-layer weights, `hidden_dim * hidden_dim` values.
    w2: Vec<f32>,
    /// Second-layer bias, `hidden_dim` values.
    b2: Vec<f32>,
    /// Bottleneck weights, `bottleneck_dim * hidden_dim` values.
    w3: Vec<f32>,
    /// Bottleneck bias, `bottleneck_dim` values.
    b3: Vec<f32>,
    /// Prototype bank: `n_prototypes` consecutive rows of `bottleneck_dim` values.
    prototypes: Vec<f32>,
    /// Scalar multiplier applied to every cosine-similarity logit.
    gain: f32,
}
155
156impl DinoHead {
157 pub fn new(
164 in_dim: usize,
165 hidden_dim: usize,
166 bottleneck_dim: usize,
167 n_prototypes: usize,
168 rng: &mut LcgRng,
169 ) -> VisionResult<Self> {
170 if in_dim == 0 {
171 return Err(VisionError::InvalidEmbedDim(in_dim));
172 }
173 if hidden_dim == 0 {
174 return Err(VisionError::InvalidEmbedDim(hidden_dim));
175 }
176 if bottleneck_dim == 0 {
177 return Err(VisionError::InvalidEmbedDim(bottleneck_dim));
178 }
179 if n_prototypes == 0 {
180 return Err(VisionError::InvalidProjDim(n_prototypes));
181 }
182
183 let fill = |rng: &mut LcgRng, n: usize, sc: f32| -> Vec<f32> {
184 let mut v = vec![0.0f32; n];
185 rng.fill_normal(&mut v);
186 for x in &mut v {
187 *x *= sc;
188 }
189 v
190 };
191
192 let w1 = fill(rng, hidden_dim * in_dim, 1.0 / (in_dim as f32).sqrt());
193 let b1 = vec![0.0f32; hidden_dim];
194 let w2 = fill(
195 rng,
196 hidden_dim * hidden_dim,
197 1.0 / (hidden_dim as f32).sqrt(),
198 );
199 let b2 = vec![0.0f32; hidden_dim];
200 let w3 = fill(
201 rng,
202 bottleneck_dim * hidden_dim,
203 1.0 / (hidden_dim as f32).sqrt(),
204 );
205 let b3 = vec![0.0f32; bottleneck_dim];
206 let prototypes = fill(
208 rng,
209 n_prototypes * bottleneck_dim,
210 1.0 / (bottleneck_dim as f32).sqrt(),
211 );
212
213 Ok(Self {
214 in_dim,
215 hidden_dim,
216 bottleneck_dim,
217 n_prototypes,
218 w1,
219 b1,
220 w2,
221 b2,
222 w3,
223 b3,
224 prototypes,
225 gain: 1.0,
226 })
227 }
228
229 #[must_use]
231 pub fn n_prototypes(&self) -> usize {
232 self.n_prototypes
233 }
234
235 pub fn forward(&self, x: &[f32]) -> VisionResult<Vec<f32>> {
243 if x.len() != self.in_dim {
244 return Err(VisionError::DimensionMismatch {
245 expected: self.in_dim,
246 got: x.len(),
247 });
248 }
249
250 let h1 = linear(x, &self.w1, &self.b1, self.in_dim, self.hidden_dim);
252 let h1: Vec<f32> = h1.into_iter().map(gelu_exact).collect();
253 let h2 = linear(&h1, &self.w2, &self.b2, self.hidden_dim, self.hidden_dim);
254 let h2: Vec<f32> = h2.into_iter().map(gelu_exact).collect();
255 let mut z = linear(
256 &h2,
257 &self.w3,
258 &self.b3,
259 self.hidden_dim,
260 self.bottleneck_dim,
261 );
262
263 let norm: f32 = z.iter().map(|&v| v * v).sum::<f32>().sqrt();
265 let inv = 1.0 / norm.max(1e-12);
266 for v in &mut z {
267 *v *= inv;
268 }
269
270 let bd = self.bottleneck_dim;
272 let mut logits = vec![0.0f32; self.n_prototypes];
273 for (k, lk) in logits.iter_mut().enumerate() {
274 let proto = &self.prototypes[k * bd..(k + 1) * bd];
275 let pnorm: f32 = proto.iter().map(|&v| v * v).sum::<f32>().sqrt();
276 let pinv = 1.0 / pnorm.max(1e-12);
277 let dot: f32 = z.iter().zip(proto.iter()).map(|(&a, &b)| a * b).sum();
278 *lk = self.gain * dot * pinv;
279 }
280 Ok(logits)
281 }
282
283 pub fn forward_batch(&self, x: &[f32]) -> VisionResult<Vec<f32>> {
289 if x.is_empty() || x.len() % self.in_dim != 0 {
290 return Err(VisionError::DimensionMismatch {
291 expected: self.in_dim,
292 got: x.len() % self.in_dim,
293 });
294 }
295 let batch = x.len() / self.in_dim;
296 let mut out = vec![0.0f32; batch * self.n_prototypes];
297 for b in 0..batch {
298 let row = self.forward(&x[b * self.in_dim..(b + 1) * self.in_dim])?;
299 out[b * self.n_prototypes..(b + 1) * self.n_prototypes].copy_from_slice(&row);
300 }
301 Ok(out)
302 }
303
304 fn num_params(&self) -> usize {
306 self.w1.len()
307 + self.b1.len()
308 + self.w2.len()
309 + self.b2.len()
310 + self.w3.len()
311 + self.b3.len()
312 + self.prototypes.len()
313 + 1 }
315
316 #[cfg(test)]
318 fn flatten(&self) -> Vec<f32> {
319 let mut v = Vec::with_capacity(self.num_params());
320 v.extend_from_slice(&self.w1);
321 v.extend_from_slice(&self.b1);
322 v.extend_from_slice(&self.w2);
323 v.extend_from_slice(&self.b2);
324 v.extend_from_slice(&self.w3);
325 v.extend_from_slice(&self.b3);
326 v.extend_from_slice(&self.prototypes);
327 v.push(self.gain);
328 v
329 }
330
331 pub fn ema_update(&mut self, student: &DinoHead, momentum: f32) -> VisionResult<()> {
337 if self.num_params() != student.num_params()
338 || self.w1.len() != student.w1.len()
339 || self.prototypes.len() != student.prototypes.len()
340 {
341 return Err(VisionError::Internal(
342 "ema_update: teacher/student head shape mismatch".into(),
343 ));
344 }
345 let m = momentum;
346 let lerp = |dst: &mut [f32], src: &[f32]| {
347 for (d, &s) in dst.iter_mut().zip(src.iter()) {
348 *d = m * *d + (1.0 - m) * s;
349 }
350 };
351 lerp(&mut self.w1, &student.w1);
352 lerp(&mut self.b1, &student.b1);
353 lerp(&mut self.w2, &student.w2);
354 lerp(&mut self.b2, &student.b2);
355 lerp(&mut self.w3, &student.w3);
356 lerp(&mut self.b3, &student.b3);
357 lerp(&mut self.prototypes, &student.prototypes);
358 self.gain = m * self.gain + (1.0 - m) * student.gain;
359 Ok(())
360 }
361}
362
/// Numerically stable tempered softmax: `softmax((logits − center) / temperature)`.
///
/// An empty `center` disables centering. Otherwise `center` is indexed
/// position-by-position, so it must be at least as long as `logits`; callers
/// validate this. The maximum is subtracted before exponentiating to avoid
/// overflow. If the exponentials sum to zero or less, they are returned unscaled.
fn softmax_temp(logits: &[f32], center: &[f32], temperature: f32) -> Vec<f32> {
    let mut probs: Vec<f32> = if center.is_empty() {
        logits.iter().map(|&l| l / temperature).collect()
    } else {
        logits
            .iter()
            .enumerate()
            .map(|(i, &l)| (l - center[i]) / temperature)
            .collect()
    };

    let peak = probs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let total = probs.iter_mut().fold(0.0f32, |acc, p| {
        *p = (*p - peak).exp();
        acc + *p
    });

    let scale = if total > 0.0 { 1.0 / total } else { 1.0 };
    probs.iter_mut().for_each(|p| *p *= scale);
    probs
}
393
394pub fn student_softmax(logits: &[f32], tau: f32) -> VisionResult<Vec<f32>> {
399 if tau <= 0.0 {
400 return Err(VisionError::NonPositiveTemperature(tau));
401 }
402 Ok(softmax_temp(logits, &[], tau))
403}
404
405pub fn teacher_softmax(logits: &[f32], center: &[f32], tau: f32) -> VisionResult<Vec<f32>> {
413 if tau <= 0.0 {
414 return Err(VisionError::NonPositiveTemperature(tau));
415 }
416 if !center.is_empty() && center.len() != logits.len() {
417 return Err(VisionError::DimensionMismatch {
418 expected: logits.len(),
419 got: center.len(),
420 });
421 }
422 Ok(softmax_temp(logits, center, tau))
423}
424
425pub fn cross_entropy(p_teacher: &[f32], p_student: &[f32]) -> VisionResult<f32> {
435 if p_teacher.len() != p_student.len() {
436 return Err(VisionError::DimensionMismatch {
437 expected: p_teacher.len(),
438 got: p_student.len(),
439 });
440 }
441 let mut h = 0.0f32;
442 for (&pt, &ps) in p_teacher.iter().zip(p_student.iter()) {
443 if pt > 0.0 {
444 h -= pt * ps.max(1e-12).ln();
446 }
447 }
448 Ok(h)
449}
450
451pub fn dino_loss(
459 student_logits: &[f32],
460 teacher_logits: &[f32],
461 center: &[f32],
462 tau_student: f32,
463 tau_teacher: f32,
464) -> VisionResult<f32> {
465 if student_logits.len() != teacher_logits.len() {
466 return Err(VisionError::DimensionMismatch {
467 expected: teacher_logits.len(),
468 got: student_logits.len(),
469 });
470 }
471 let p_t = teacher_softmax(teacher_logits, center, tau_teacher)?;
472 let p_s = student_softmax(student_logits, tau_student)?;
473 cross_entropy(&p_t, &p_s)
474}
475
476#[derive(Debug, Clone)]
484pub struct CenteringBuffer {
485 pub center: Vec<f32>,
487 pub momentum: f32,
489}
490
491impl CenteringBuffer {
492 #[must_use]
494 pub fn new(dim: usize, momentum: f32) -> Self {
495 Self {
496 center: vec![0.0f32; dim],
497 momentum,
498 }
499 }
500
501 pub fn update(&mut self, batch_logits: &[f32]) -> VisionResult<()> {
509 let dim = self.center.len();
510 if dim == 0 || batch_logits.is_empty() || batch_logits.len() % dim != 0 {
511 return Err(VisionError::DimensionMismatch {
512 expected: dim,
513 got: batch_logits.len(),
514 });
515 }
516 let batch = batch_logits.len() / dim;
517 let mut mean = vec![0.0f32; dim];
518 for b in 0..batch {
519 for k in 0..dim {
520 mean[k] += batch_logits[b * dim + k];
521 }
522 }
523 let inv_b = 1.0 / batch as f32;
524 let lam = self.momentum;
525 for (c, m) in self.center.iter_mut().zip(mean.iter()) {
526 let batch_mean = m * inv_b;
527 *c = lam * *c + (1.0 - lam) * batch_mean;
528 }
529 Ok(())
530 }
531}
532
533pub fn ibot_loss(
552 student_patch_logits: &[f32],
553 teacher_patch_logits: &[f32],
554 mask: &[bool],
555 patch_center: &[f32],
556 n_proto: usize,
557 tau_student: f32,
558 tau_teacher: f32,
559) -> VisionResult<f32> {
560 if n_proto == 0 {
561 return Err(VisionError::InvalidProjDim(n_proto));
562 }
563 let n_patches = mask.len();
564 if student_patch_logits.len() != n_patches * n_proto
565 || teacher_patch_logits.len() != n_patches * n_proto
566 {
567 return Err(VisionError::DimensionMismatch {
568 expected: n_patches * n_proto,
569 got: student_patch_logits.len(),
570 });
571 }
572
573 let mut total = 0.0f32;
574 let mut count = 0usize;
575 for p in 0..n_patches {
576 if !mask[p] {
577 continue;
578 }
579 let s = &student_patch_logits[p * n_proto..(p + 1) * n_proto];
580 let t = &teacher_patch_logits[p * n_proto..(p + 1) * n_proto];
581 let l = dino_loss(s, t, patch_center, tau_student, tau_teacher)?;
582 total += l;
583 count += 1;
584 }
585 if count == 0 {
586 return Ok(0.0);
587 }
588 Ok(total / count as f32)
589}
590
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean distance between two parameter vectors (zipped pairwise).
    fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
        let squared: f32 = a.iter().zip(b).map(|(&x, &y)| (x - y) * (x - y)).sum();
        squared.sqrt()
    }

    /// Shannon entropy `−Σ p ln p` in nats, skipping non-positive entries.
    fn shannon_entropy(p: &[f32]) -> f32 {
        p.iter()
            .filter(|&&v| v > 0.0)
            .fold(0.0f32, |acc, &v| acc - v * v.ln())
    }

    /// Head with the fixed test geometry `32 → 64 → 64 → 16` and `k` prototypes.
    fn head_with(seed: u64, k: usize) -> DinoHead {
        let mut rng = LcgRng::new(seed);
        DinoHead::new(32, 64, 16, k, &mut rng).expect("head ok")
    }

    /// The backbone returns one CLS vector and `n_patches` finite patch vectors.
    #[test]
    fn backbone_returns_cls_and_patches() {
        let mut rng = LcgRng::new(1);
        let cfg = ViTConfig::tiny();
        let embed = cfg.embed_dim;
        let n_patches = cfg.n_patches();
        let backbone = DinoBackbone::new(cfg, &mut rng).expect("backbone ok");

        let image = vec![0.3f32; 3 * 32 * 32];
        let out = backbone.forward(&image).expect("forward ok");

        assert_eq!(out.cls.len(), embed, "CLS must be [embed_dim]");
        assert_eq!(
            out.patches.len(),
            n_patches * embed,
            "patches must be [n_patches, e]"
        );
        assert_eq!(out.n_patches, n_patches);
        assert!(out.cls.iter().chain(&out.patches).all(|v| v.is_finite()));
    }

    /// Head logits have one entry per prototype and lie in the cosine range.
    #[test]
    fn head_prototype_logits_shape_and_softmax() {
        let head = head_with(2, 128);
        let mut rng = LcgRng::new(3);
        let mut input = vec![0.0f32; 32];
        rng.fill_normal(&mut input);

        let logits = head.forward(&input).expect("ok");
        assert_eq!(logits.len(), 128, "prototype logits must be [n_prototypes]");

        let probs = student_softmax(&logits, 0.1).expect("ok");
        let sum: f32 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5, "softmax must sum to 1; got {sum}");

        // Gain starts at 1.0, so logits are plain cosines; allow rounding slack.
        let bound = 1.0 + 1e-4;
        for &l in &logits {
            assert!(
                (-bound..=bound).contains(&l),
                "logit out of cosine range: {l}"
            );
        }
    }

    /// One EMA step with m = 0.9 shrinks the teacher/student distance by 0.9.
    #[test]
    fn ema_update_moves_teacher_toward_student() {
        let mut teacher = head_with(10, 64);
        let student = head_with(20, 64);

        let before = euclidean_distance(&teacher.flatten(), &student.flatten());
        assert!(before > 0.0, "teacher and student must start apart");

        teacher.ema_update(&student, 0.9).expect("ema ok");
        let after = euclidean_distance(&teacher.flatten(), &student.flatten());

        assert!(
            after < before,
            "EMA must reduce ‖θ_t − θ_s‖: before={before}, after={after}"
        );
        // θ_t' − θ_s = m·(θ_t − θ_s), so the distance scales by exactly m.
        let scaled = 0.9 * before;
        assert!(
            (after - scaled).abs() < 1e-3 * before.max(1.0),
            "EMA distance should scale by m=0.9: after={after}, 0.9·before={}",
            scaled
        );
    }

    /// Heads with different prototype counts cannot be EMA-combined.
    #[test]
    fn ema_update_shape_mismatch_errors() {
        let mut teacher = head_with(10, 64);
        let other = head_with(11, 32);
        let result = teacher.ema_update(&other, 0.9);
        assert!(matches!(result, Err(VisionError::Internal(_))));
    }

    /// Cross-entropy between probability vectors is never negative.
    #[test]
    fn dino_loss_nonnegative() {
        let mut rng = LcgRng::new(30);
        for _ in 0..20 {
            let mut sl = vec![0.0f32; 16];
            let mut tl = vec![0.0f32; 16];
            rng.fill_normal(&mut sl);
            rng.fill_normal(&mut tl);
            let l = dino_loss(&sl, &tl, &[], 0.1, 0.04).expect("ok");
            assert!(l >= -1e-6, "DINO loss must be ≥ 0; got {l}");
        }
    }

    /// Identical near-one-hot distributions give ≈0 loss; a wrong peak does not.
    #[test]
    fn dino_loss_minimised_when_student_matches_teacher() {
        let teacher_logits = vec![20.0f32, -20.0, -20.0, -20.0];
        let student_logits = teacher_logits.clone();

        let p_t = teacher_softmax(&teacher_logits, &[], 0.1).expect("ok");
        let p_s = student_softmax(&student_logits, 0.1).expect("ok");
        let h_self = cross_entropy(&p_t, &p_s).expect("ok");
        assert!(
            h_self < 1e-3,
            "matched ~one-hot dists give ≈0 loss; got {h_self}"
        );

        let student_bad = vec![-20.0f32, 20.0, -20.0, -20.0];
        let p_bad = student_softmax(&student_bad, 0.1).expect("ok");
        let h_bad = cross_entropy(&p_t, &p_bad).expect("ok");
        assert!(
            h_bad > h_self + 1.0,
            "mismatched student must raise the loss: self={h_self}, bad={h_bad}"
        );
    }

    /// H(p, p) equals the Shannon entropy of p.
    #[test]
    fn cross_entropy_equals_entropy_at_self() {
        let logits = vec![1.0f32, 0.3, -0.5, 2.0, -1.0];
        let p = student_softmax(&logits, 1.0).expect("ok");
        let ce = cross_entropy(&p, &p).expect("ok");
        let ent = shannon_entropy(&p);
        assert!(
            (ce - ent).abs() < 1e-5,
            "H(p,p) must equal entropy(p): {ce} vs {ent}"
        );
    }

    /// Repeated updates with a constant batch make the center converge to it.
    #[test]
    fn centering_drives_mean_near_zero() {
        let dim = 8;
        let mut buf = CenteringBuffer::new(dim, 0.9);
        let base: Vec<f32> = (0..dim).map(|k| if k == 0 { 5.0 } else { 0.1 }).collect();
        let flat = base.repeat(4);

        for _ in 0..400 {
            buf.update(&flat).expect("ok");
        }

        let centred_mean: f32 = base
            .iter()
            .zip(&buf.center)
            .map(|(&g, &c)| (g - c).abs())
            .sum::<f32>()
            / dim as f32;
        assert!(
            centred_mean < 1e-2,
            "centering should drive (g − c) mean ≈ 0; got {centred_mean}"
        );
    }

    /// A buffer whose length is not a multiple of the width is rejected.
    #[test]
    fn centering_update_bad_shape_errors() {
        let mut buf = CenteringBuffer::new(8, 0.9);
        let result = buf.update(&[0.0f32; 7]);
        assert!(matches!(result, Err(VisionError::DimensionMismatch { .. })));
    }

    /// A lower teacher temperature yields a sharper, lower-entropy distribution.
    #[test]
    fn lower_teacher_temperature_sharpens_distribution() {
        let logits = vec![2.0f32, 1.0, 0.5, -0.5, -1.0, 0.2];
        let p_hot = teacher_softmax(&logits, &[], 0.04).expect("ok");
        let p_soft = teacher_softmax(&logits, &[], 0.5).expect("ok");

        let h_hot = shannon_entropy(&p_hot);
        let h_soft = shannon_entropy(&p_soft);
        assert!(
            h_hot < h_soft,
            "lower τ_t must lower entropy (sharper): H(0.04)={h_hot} vs H(0.5)={h_soft}"
        );

        let peak = |p: &[f32]| p.iter().copied().fold(0.0f32, f32::max);
        assert!(
            peak(&p_hot) > peak(&p_soft),
            "sharper dist must have a higher peak"
        );
    }

    /// Interpolating student logits towards the teacher lowers the loss.
    #[test]
    fn nudging_student_toward_teacher_lowers_loss() {
        let teacher_logits = vec![1.5f32, -0.5, 0.7, -1.2, 0.3, 0.9];
        let student_before = vec![-1.0f32, 0.8, -0.3, 1.1, -0.6, 0.0];
        let tau = 0.1;

        let loss_before = dino_loss(&student_before, &teacher_logits, &[], tau, tau).expect("ok");

        let alpha = 0.6f32;
        let student_after: Vec<f32> = student_before
            .iter()
            .zip(&teacher_logits)
            .map(|(&s, &t)| s + alpha * (t - s))
            .collect();
        let loss_after = dino_loss(&student_after, &teacher_logits, &[], tau, tau).expect("ok");

        assert!(
            loss_after < loss_before,
            "moving the student toward the teacher must lower the loss: before={loss_before}, after={loss_after}"
        );
    }

    /// With head outputs from two random views, aligning the student still helps.
    #[test]
    fn two_views_loss_decreases_when_student_aligns() {
        let head = head_with(40, 32);
        let mut rng = LcgRng::new(41);
        let mut view_a = vec![0.0f32; 32];
        let mut view_b = vec![0.0f32; 32];
        rng.fill_normal(&mut view_a);
        rng.fill_normal(&mut view_b);

        let teacher_logits = head.forward(&view_a).expect("ok");
        let student_logits = head.forward(&view_b).expect("ok");
        let tau = 0.1;
        let loss_before = dino_loss(&student_logits, &teacher_logits, &[], tau, tau).expect("ok");

        let nudged: Vec<f32> = student_logits
            .iter()
            .zip(&teacher_logits)
            .map(|(&s, &t)| s + 0.5 * (t - s))
            .collect();
        let loss_after = dino_loss(&nudged, &teacher_logits, &[], tau, tau).expect("ok");

        assert!(
            loss_after < loss_before,
            "aligning student to teacher across views must lower loss: {loss_before} → {loss_after}"
        );
    }

    /// Zero and negative temperatures are rejected by both softmax variants.
    #[test]
    fn nonpositive_temperature_errors() {
        let student = student_softmax(&[1.0, 2.0], 0.0);
        assert!(matches!(student, Err(VisionError::NonPositiveTemperature(_))));
        let teacher = teacher_softmax(&[1.0, 2.0], &[], -0.1);
        assert!(matches!(teacher, Err(VisionError::NonPositiveTemperature(_))));
    }

    /// Only masked patches contribute, and their losses are averaged.
    #[test]
    fn ibot_loss_only_counts_masked_patches() {
        /// Row `p` of a flat buffer whose rows are `k` wide.
        fn row(v: &[f32], p: usize, k: usize) -> &[f32] {
            &v[p * k..(p + 1) * k]
        }

        let n_patches = 4;
        let n_proto = 6;
        let mut rng = LcgRng::new(50);
        let mut s = vec![0.0f32; n_patches * n_proto];
        let mut t = vec![0.0f32; n_patches * n_proto];
        rng.fill_normal(&mut s);
        rng.fill_normal(&mut t);

        let none = vec![false; n_patches];
        let l0 = ibot_loss(&s, &t, &none, &[], n_proto, 0.1, 0.04).expect("ok");
        assert_eq!(l0, 0.0, "no masked patches ⇒ zero iBOT loss");

        let mask: Vec<bool> = (0..n_patches).map(|p| p == 0 || p == 2).collect();
        let l = ibot_loss(&s, &t, &mask, &[], n_proto, 0.1, 0.04).expect("ok");
        let l_p0 = dino_loss(row(&s, 0, n_proto), row(&t, 0, n_proto), &[], 0.1, 0.04).expect("ok");
        let l_p2 = dino_loss(row(&s, 2, n_proto), row(&t, 2, n_proto), &[], 0.1, 0.04).expect("ok");

        let expected = 0.5 * (l_p0 + l_p2);
        assert!(
            (l - expected).abs() < 1e-5,
            "iBOT loss must average masked-patch losses: {l} vs {expected}"
        );
        assert!(l >= 0.0, "iBOT loss must be ≥ 0");
    }

    /// Nudging the masked patch's student row towards the teacher lowers the loss.
    #[test]
    fn ibot_loss_nudging_masked_student_lowers_loss() {
        let n_proto = 5;
        let teacher = vec![
            1.2f32, -0.4, 0.6, -1.0, 0.2, 0.1, 0.1, 0.1, 0.1, 0.1,
        ];
        let student = vec![
            -1.0f32, 0.7, -0.2, 1.0, -0.5, 0.0, 0.0, 0.0, 0.0, 0.0,
        ];
        let mask = vec![true, false];
        let tau = 0.1;
        let before = ibot_loss(&student, &teacher, &mask, &[], n_proto, tau, tau).expect("ok");

        let mut nudged = student.clone();
        for (n, (&t, &s)) in nudged
            .iter_mut()
            .zip(teacher.iter().zip(&student))
            .take(n_proto)
        {
            *n += 0.6 * (t - s);
        }
        let after = ibot_loss(&nudged, &teacher, &mask, &[], n_proto, tau, tau).expect("ok");

        assert!(
            after < before,
            "nudging masked student patch toward teacher must lower iBOT loss: {before} → {after}"
        );
    }

    /// A logit buffer of the wrong length is rejected.
    #[test]
    fn ibot_loss_bad_shape_errors() {
        let mask = vec![true, false];
        let result = ibot_loss(&[0.0f32; 5], &[0.0f32; 10], &mask, &[], 5, 0.1, 0.04);
        assert!(matches!(result, Err(VisionError::DimensionMismatch { .. })));
    }

    /// `forward_batch` agrees row-by-row with `forward`.
    #[test]
    fn head_forward_batch_matches_single() {
        let head = head_with(60, 32);
        let mut rng = LcgRng::new(61);
        let batch = 3;
        let mut x = vec![0.0f32; batch * 32];
        rng.fill_normal(&mut x);

        let all = head.forward_batch(&x).expect("ok");
        let k = head.n_prototypes();
        for (b, input_row) in x.chunks_exact(32).enumerate() {
            let single = head.forward(input_row).expect("ok");
            for (j, &v) in single.iter().enumerate() {
                assert!(
                    (all[b * k + j] - v).abs() < 1e-6,
                    "batch vs single mismatch at b={b}, j={j}"
                );
            }
        }
    }

    /// An input of the wrong width is rejected by `forward`.
    #[test]
    fn head_dimension_mismatch_errors() {
        let head = head_with(70, 32);
        let result = head.forward(&[0.0f32; 31]);
        assert!(matches!(result, Err(VisionError::DimensionMismatch { .. })));
    }

    /// A head with zero prototypes cannot be constructed.
    #[test]
    fn head_zero_prototypes_errors() {
        let mut rng = LcgRng::new(80);
        let result = DinoHead::new(32, 64, 16, 0, &mut rng);
        assert!(matches!(result, Err(VisionError::InvalidProjDim(0))));
    }
}