use crate::{
error::{VisionError, VisionResult},
handle::LcgRng,
patch_embed::{PatchEmbed, PatchEmbedConfig, prepend_cls},
vit::vit_block::{gelu_exact, linear},
vit::{ViTConfig, ViTEncoder, ViTEncoderConfig},
};
/// Result of [`DinoBackbone::forward`] for a single image.
#[derive(Clone, Debug)]
pub struct BackboneOutput {
    /// Encoded CLS token, `embed_dim` values.
    pub cls: Vec<f32>,
    /// Encoded patch tokens that follow the CLS slot: `n_patches * embed_dim`
    /// values, one `embed_dim`-sized row per patch.
    pub patches: Vec<f32>,
    /// Number of patch rows stored in `patches`.
    pub n_patches: usize,
}
pub struct DinoBackbone {
pub config: ViTConfig,
patch_embed: PatchEmbed,
cls_token: Vec<f32>,
pos_embed: Vec<f32>, encoder: ViTEncoder,
}
impl DinoBackbone {
pub fn new(cfg: ViTConfig, rng: &mut LcgRng) -> VisionResult<Self> {
let e = cfg.embed_dim;
let pe_cfg = PatchEmbedConfig::new(cfg.img_size, cfg.patch_size, cfg.in_chans, e)?;
let patch_embed = PatchEmbed::new(pe_cfg, rng);
let mut cls_token = vec![0.0f32; e];
rng.fill_normal(&mut cls_token);
for v in &mut cls_token {
*v *= 0.02;
}
let seq_len = cfg.n_patches() + 1;
let mut pos_embed = vec![0.0f32; seq_len * e];
rng.fill_normal(&mut pos_embed);
for v in &mut pos_embed {
*v *= 0.02;
}
let enc_cfg = ViTEncoderConfig::new(e, cfg.n_heads, cfg.mlp_ratio, cfg.depth)?;
let encoder = ViTEncoder::new(enc_cfg, rng)?;
Ok(Self {
config: cfg,
patch_embed,
cls_token,
pos_embed,
encoder,
})
}
pub fn forward(&self, image: &[f32]) -> VisionResult<BackboneOutput> {
let e = self.config.embed_dim;
let n_patches = self.config.n_patches();
let patch_tokens = self.patch_embed.forward(image)?;
let mut tokens = prepend_cls(&patch_tokens, &self.cls_token, e)?;
for (t, p) in tokens.iter_mut().zip(self.pos_embed.iter()) {
*t += p;
}
let seq_len = n_patches + 1;
let encoded = self.encoder.forward(&tokens, seq_len)?;
let cls = encoded[..e].to_vec();
let patches = encoded[e..].to_vec();
Ok(BackboneOutput {
cls,
patches,
n_patches,
})
}
}
/// DINO projection head.
///
/// A three-layer MLP (`in_dim → hidden_dim → hidden_dim → bottleneck_dim`,
/// exact GELU after the first two layers). Its output is L2-normalised and
/// scored against `n_prototypes` prototype vectors by cosine similarity,
/// scaled by `gain`.
#[derive(Clone)]
pub struct DinoHead {
    in_dim: usize,
    hidden_dim: usize,
    bottleneck_dim: usize,
    n_prototypes: usize,
    /// First-layer weights, `hidden_dim * in_dim` values (layout as consumed by `linear`).
    w1: Vec<f32>,
    /// First-layer bias, `hidden_dim` values.
    b1: Vec<f32>,
    /// Second-layer weights, `hidden_dim * hidden_dim` values.
    w2: Vec<f32>,
    /// Second-layer bias, `hidden_dim` values.
    b2: Vec<f32>,
    /// Bottleneck weights, `bottleneck_dim * hidden_dim` values.
    w3: Vec<f32>,
    /// Bottleneck bias, `bottleneck_dim` values.
    b3: Vec<f32>,
    /// Prototype bank, row-major `[n_prototypes, bottleneck_dim]`.
    prototypes: Vec<f32>,
    /// Scalar multiplier applied to every cosine logit (initialised to 1.0).
    gain: f32,
}

impl DinoHead {
    /// Creates a randomly initialised head.
    ///
    /// Weight matrices and prototypes are filled with `LcgRng::fill_normal` and
    /// scaled by `1/sqrt(fan_in)`. They are drawn in the order w1, w2, w3,
    /// prototypes. Biases start at zero and `gain` at 1.0.
    ///
    /// # Errors
    /// `InvalidEmbedDim` if `in_dim`, `hidden_dim` or `bottleneck_dim` is zero;
    /// `InvalidProjDim` if `n_prototypes` is zero.
    pub fn new(
        in_dim: usize,
        hidden_dim: usize,
        bottleneck_dim: usize,
        n_prototypes: usize,
        rng: &mut LcgRng,
    ) -> VisionResult<Self> {
        if in_dim == 0 {
            return Err(VisionError::InvalidEmbedDim(in_dim));
        }
        if hidden_dim == 0 {
            return Err(VisionError::InvalidEmbedDim(hidden_dim));
        }
        if bottleneck_dim == 0 {
            return Err(VisionError::InvalidEmbedDim(bottleneck_dim));
        }
        if n_prototypes == 0 {
            return Err(VisionError::InvalidProjDim(n_prototypes));
        }
        let fill = |rng: &mut LcgRng, n: usize, sc: f32| -> Vec<f32> {
            let mut v = vec![0.0f32; n];
            rng.fill_normal(&mut v);
            for x in &mut v {
                *x *= sc;
            }
            v
        };
        let w1 = fill(rng, hidden_dim * in_dim, 1.0 / (in_dim as f32).sqrt());
        let b1 = vec![0.0f32; hidden_dim];
        let w2 = fill(
            rng,
            hidden_dim * hidden_dim,
            1.0 / (hidden_dim as f32).sqrt(),
        );
        let b2 = vec![0.0f32; hidden_dim];
        let w3 = fill(
            rng,
            bottleneck_dim * hidden_dim,
            1.0 / (hidden_dim as f32).sqrt(),
        );
        let b3 = vec![0.0f32; bottleneck_dim];
        let prototypes = fill(
            rng,
            n_prototypes * bottleneck_dim,
            1.0 / (bottleneck_dim as f32).sqrt(),
        );
        Ok(Self {
            in_dim,
            hidden_dim,
            bottleneck_dim,
            n_prototypes,
            w1,
            b1,
            w2,
            b2,
            w3,
            b3,
            prototypes,
            gain: 1.0,
        })
    }

    /// Number of logits produced per input row.
    #[must_use]
    pub fn n_prototypes(&self) -> usize {
        self.n_prototypes
    }

    /// Projects one `in_dim`-long feature vector to `n_prototypes` logits.
    ///
    /// Logit `k` is `gain · cos(z, prototype_k)`, where `z` is the bottleneck
    /// output of the MLP. Both norms are clamped below at `1e-12`, so a zero
    /// vector yields zero logits rather than NaN.
    ///
    /// # Errors
    /// `DimensionMismatch` if `x.len() != in_dim`.
    pub fn forward(&self, x: &[f32]) -> VisionResult<Vec<f32>> {
        if x.len() != self.in_dim {
            return Err(VisionError::DimensionMismatch {
                expected: self.in_dim,
                got: x.len(),
            });
        }
        let h1 = linear(x, &self.w1, &self.b1, self.in_dim, self.hidden_dim);
        let h1: Vec<f32> = h1.into_iter().map(gelu_exact).collect();
        let h2 = linear(&h1, &self.w2, &self.b2, self.hidden_dim, self.hidden_dim);
        let h2: Vec<f32> = h2.into_iter().map(gelu_exact).collect();
        let mut z = linear(
            &h2,
            &self.w3,
            &self.b3,
            self.hidden_dim,
            self.bottleneck_dim,
        );
        // L2-normalise the bottleneck so the dot products below are cosines.
        let norm: f32 = z.iter().map(|&v| v * v).sum::<f32>().sqrt();
        let inv = 1.0 / norm.max(1e-12);
        for v in &mut z {
            *v *= inv;
        }
        let bd = self.bottleneck_dim;
        let mut logits = vec![0.0f32; self.n_prototypes];
        for (k, lk) in logits.iter_mut().enumerate() {
            let proto = &self.prototypes[k * bd..(k + 1) * bd];
            let pnorm: f32 = proto.iter().map(|&v| v * v).sum::<f32>().sqrt();
            let pinv = 1.0 / pnorm.max(1e-12);
            let dot: f32 = z.iter().zip(proto.iter()).map(|(&a, &b)| a * b).sum();
            *lk = self.gain * dot * pinv;
        }
        Ok(logits)
    }

    /// Applies [`forward`](Self::forward) to each `in_dim`-sized row of `x`.
    /// The per-row results are concatenated into a row-major
    /// `[batch, n_prototypes]` buffer.
    ///
    /// # Errors
    /// `DimensionMismatch` (with `got` = total input length) if `x` is empty
    /// or its length is not a multiple of `in_dim`.
    pub fn forward_batch(&self, x: &[f32]) -> VisionResult<Vec<f32>> {
        // `in_dim > 0` is guaranteed by `new`, so the modulo cannot divide by zero.
        if x.is_empty() || x.len() % self.in_dim != 0 {
            return Err(VisionError::DimensionMismatch {
                expected: self.in_dim,
                got: x.len(),
            });
        }
        let batch = x.len() / self.in_dim;
        let mut out = Vec::with_capacity(batch * self.n_prototypes);
        for row in x.chunks_exact(self.in_dim) {
            out.extend_from_slice(&self.forward(row)?);
        }
        Ok(out)
    }

    /// Total number of scalar parameters, including `gain`.
    fn num_params(&self) -> usize {
        self.w1.len()
            + self.b1.len()
            + self.w2.len()
            + self.b2.len()
            + self.w3.len()
            + self.b3.len()
            + self.prototypes.len()
            + 1 // gain
    }

    /// Architecture dimensions; every parameter buffer's length is derived from these.
    fn dims(&self) -> (usize, usize, usize, usize) {
        (
            self.in_dim,
            self.hidden_dim,
            self.bottleneck_dim,
            self.n_prototypes,
        )
    }

    /// All parameters concatenated in declaration order (tests only).
    #[cfg(test)]
    fn flatten(&self) -> Vec<f32> {
        let mut v = Vec::with_capacity(self.num_params());
        v.extend_from_slice(&self.w1);
        v.extend_from_slice(&self.b1);
        v.extend_from_slice(&self.w2);
        v.extend_from_slice(&self.b2);
        v.extend_from_slice(&self.w3);
        v.extend_from_slice(&self.b3);
        v.extend_from_slice(&self.prototypes);
        v.push(self.gain);
        v
    }

    /// Exponential-moving-average update of this (teacher) head from `student`.
    ///
    /// Every parameter becomes `momentum · teacher + (1 − momentum) · student`.
    ///
    /// # Errors
    /// `Internal` if the two heads were built with different dimensions.
    pub fn ema_update(&mut self, student: &DinoHead, momentum: f32) -> VisionResult<()> {
        // Compare the architecture itself rather than a few buffer lengths.
        // Different per-layer shapes can share those lengths, and `zip` in
        // `lerp` would then silently blend only a prefix of each buffer.
        if self.dims() != student.dims() {
            return Err(VisionError::Internal(
                "ema_update: teacher/student head shape mismatch".into(),
            ));
        }
        debug_assert_eq!(self.num_params(), student.num_params());
        let m = momentum;
        let lerp = |dst: &mut [f32], src: &[f32]| {
            for (d, &s) in dst.iter_mut().zip(src.iter()) {
                *d = m * *d + (1.0 - m) * s;
            }
        };
        lerp(&mut self.w1, &student.w1);
        lerp(&mut self.b1, &student.b1);
        lerp(&mut self.w2, &student.w2);
        lerp(&mut self.b2, &student.b2);
        lerp(&mut self.w3, &student.w3);
        lerp(&mut self.b3, &student.b3);
        lerp(&mut self.prototypes, &student.prototypes);
        self.gain = m * self.gain + (1.0 - m) * student.gain;
        Ok(())
    }
}
/// Numerically stable softmax of `(logits − center) / temperature`.
///
/// An empty `center` means no centering. A non-empty `center` is indexed once
/// per logit, so it must be at least as long as `logits`; callers validate this.
/// The largest scaled logit is subtracted before exponentiating to avoid
/// overflow. If the exponentials do not sum to a positive value, they are
/// returned unnormalised.
fn softmax_temp(logits: &[f32], center: &[f32], temperature: f32) -> Vec<f32> {
    let mut probs: Vec<f32> = if center.is_empty() {
        logits.iter().map(|&l| l / temperature).collect()
    } else {
        logits
            .iter()
            .enumerate()
            .map(|(i, &l)| (l - center[i]) / temperature)
            .collect()
    };
    let peak = probs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let total = probs.iter_mut().fold(0.0f32, |acc, p| {
        *p = (*p - peak).exp();
        acc + *p
    });
    let scale = if total > 0.0 { 1.0 / total } else { 1.0 };
    probs.iter_mut().for_each(|p| *p *= scale);
    probs
}
/// Student distribution: `softmax(logits / tau)`.
///
/// # Errors
/// `NonPositiveTemperature` unless `tau` is a positive number. NaN is rejected
/// too, since it would turn every probability into NaN.
pub fn student_softmax(logits: &[f32], tau: f32) -> VisionResult<Vec<f32>> {
    // `tau <= 0.0` alone is false for NaN, so NaN is checked explicitly.
    if tau.is_nan() || tau <= 0.0 {
        return Err(VisionError::NonPositiveTemperature(tau));
    }
    Ok(softmax_temp(logits, &[], tau))
}
/// Teacher distribution: `softmax((logits − center) / tau)`.
///
/// An empty `center` disables centering.
///
/// # Errors
/// - `NonPositiveTemperature` unless `tau` is a positive number (NaN is rejected).
/// - `DimensionMismatch` if a non-empty `center` differs in length from `logits`.
pub fn teacher_softmax(logits: &[f32], center: &[f32], tau: f32) -> VisionResult<Vec<f32>> {
    // `tau <= 0.0` alone is false for NaN, so NaN is checked explicitly.
    if tau.is_nan() || tau <= 0.0 {
        return Err(VisionError::NonPositiveTemperature(tau));
    }
    if !center.is_empty() && center.len() != logits.len() {
        return Err(VisionError::DimensionMismatch {
            expected: logits.len(),
            got: center.len(),
        });
    }
    Ok(softmax_temp(logits, center, tau))
}
/// Cross-entropy `H(p_t, p_s) = −Σ p_t · ln p_s`.
///
/// Entries whose teacher probability is not positive contribute nothing.
/// Student probabilities are clamped to at least `1e-12` before the log, so a
/// zero student probability gives a large finite penalty instead of `+∞`.
///
/// # Errors
/// `DimensionMismatch` if the two distributions differ in length.
pub fn cross_entropy(p_teacher: &[f32], p_student: &[f32]) -> VisionResult<f32> {
    if p_student.len() != p_teacher.len() {
        return Err(VisionError::DimensionMismatch {
            expected: p_teacher.len(),
            got: p_student.len(),
        });
    }
    let h = p_teacher
        .iter()
        .zip(p_student)
        .filter(|&(&pt, _)| pt > 0.0)
        .fold(0.0f32, |acc, (&pt, &ps)| acc - pt * ps.max(1e-12).ln());
    Ok(h)
}
/// DINO loss for one view pair.
///
/// This is the cross-entropy between the centred, temperature-scaled teacher
/// distribution and the temperature-scaled student distribution.
///
/// # Errors
/// `DimensionMismatch` if the logit vectors differ in length. Also propagates
/// temperature and centre errors from [`teacher_softmax`] / [`student_softmax`],
/// with the teacher checked first.
pub fn dino_loss(
    student_logits: &[f32],
    teacher_logits: &[f32],
    center: &[f32],
    tau_student: f32,
    tau_teacher: f32,
) -> VisionResult<f32> {
    let (n_student, n_teacher) = (student_logits.len(), teacher_logits.len());
    if n_student != n_teacher {
        return Err(VisionError::DimensionMismatch {
            expected: n_teacher,
            got: n_student,
        });
    }
    let teacher_probs = teacher_softmax(teacher_logits, center, tau_teacher)?;
    let student_probs = student_softmax(student_logits, tau_student)?;
    cross_entropy(&teacher_probs, &student_probs)
}
/// Exponential moving average of teacher logits, used as the DINO centre.
#[derive(Debug, Clone)]
pub struct CenteringBuffer {
    /// Current centre, one entry per logit dimension.
    pub center: Vec<f32>,
    /// EMA momentum `λ`: `center ← λ·center + (1 − λ)·batch_mean`.
    pub momentum: f32,
}

impl CenteringBuffer {
    /// Creates a zero centre of length `dim` with the given EMA momentum.
    #[must_use]
    pub fn new(dim: usize, momentum: f32) -> Self {
        let center = vec![0.0f32; dim];
        Self { center, momentum }
    }

    /// Folds the mean of a row-major `[batch, dim]` block of logits into the centre.
    ///
    /// # Errors
    /// `DimensionMismatch` (with `got` = total length) if the centre has zero
    /// length, `batch_logits` is empty, or its length is not a multiple of the
    /// centre length.
    pub fn update(&mut self, batch_logits: &[f32]) -> VisionResult<()> {
        let dim = self.center.len();
        let well_shaped =
            dim != 0 && !batch_logits.is_empty() && batch_logits.len() % dim == 0;
        if !well_shaped {
            return Err(VisionError::DimensionMismatch {
                expected: dim,
                got: batch_logits.len(),
            });
        }
        let rows = batch_logits.len() / dim;
        // Per-dimension sums, accumulated row by row.
        let mut sums = vec![0.0f32; dim];
        for row in batch_logits.chunks_exact(dim) {
            for (acc, &v) in sums.iter_mut().zip(row) {
                *acc += v;
            }
        }
        let inv_rows = 1.0 / rows as f32;
        let lam = self.momentum;
        for (c, s) in self.center.iter_mut().zip(&sums) {
            *c = lam * *c + (1.0 - lam) * (s * inv_rows);
        }
        Ok(())
    }
}
/// iBOT masked-patch loss.
///
/// This is the mean of per-patch [`dino_loss`] over the patches whose `mask`
/// entry is `true`. Both logit buffers are row-major
/// `[mask.len(), n_proto]`. Returns `0.0` when no patch is masked.
///
/// # Errors
/// - `InvalidProjDim` if `n_proto` is zero.
/// - `DimensionMismatch` if either logit buffer is not `mask.len() * n_proto`
///   long (reporting the offending buffer, student first), or if a non-empty
///   `patch_center` is not `n_proto` long.
/// - Any error from [`dino_loss`] for a masked patch.
pub fn ibot_loss(
    student_patch_logits: &[f32],
    teacher_patch_logits: &[f32],
    mask: &[bool],
    patch_center: &[f32],
    n_proto: usize,
    tau_student: f32,
    tau_teacher: f32,
) -> VisionResult<f32> {
    if n_proto == 0 {
        return Err(VisionError::InvalidProjDim(n_proto));
    }
    let expected = mask.len() * n_proto;
    if student_patch_logits.len() != expected {
        return Err(VisionError::DimensionMismatch {
            expected,
            got: student_patch_logits.len(),
        });
    }
    if teacher_patch_logits.len() != expected {
        return Err(VisionError::DimensionMismatch {
            expected,
            got: teacher_patch_logits.len(),
        });
    }
    // Validate the centre up front so a bad one is caught even if nothing is masked.
    if !patch_center.is_empty() && patch_center.len() != n_proto {
        return Err(VisionError::DimensionMismatch {
            expected: n_proto,
            got: patch_center.len(),
        });
    }
    let rows = student_patch_logits
        .chunks_exact(n_proto)
        .zip(teacher_patch_logits.chunks_exact(n_proto));
    let mut total = 0.0f32;
    let mut count = 0usize;
    for ((s, t), &masked) in rows.zip(mask) {
        if masked {
            total += dino_loss(s, t, patch_center, tau_student, tau_teacher)?;
            count += 1;
        }
    }
    if count == 0 {
        return Ok(0.0);
    }
    Ok(total / count as f32)
}
#[cfg(test)]
mod tests {
// Unit tests for the DINO backbone, projection head, losses and centering buffer.
use super::*;
// Euclidean (L2) distance between `a` and `b`; `zip` stops at the shorter slice.
fn l2(a: &[f32], b: &[f32]) -> f32 {
a.iter()
.zip(b.iter())
.map(|(&x, &y)| (x - y) * (x - y))
.sum::<f32>()
.sqrt()
}
// Shannon entropy `−Σ v·ln v` (natural log), skipping non-positive entries.
fn entropy(p: &[f32]) -> f32 {
let mut h = 0.0f32;
for &v in p {
if v > 0.0 {
h -= v * v.ln();
}
}
h
}
// Deterministic head with in_dim = 32, hidden_dim = 64, bottleneck_dim = 16 and `k` prototypes.
fn make_head(seed: u64, k: usize) -> DinoHead {
let mut rng = LcgRng::new(seed);
DinoHead::new(32, 64, 16, k, &mut rng).expect("head ok")
}
// The backbone splits the encoded sequence into one CLS vector and
// `n_patches` patch rows, all finite.
#[test]
fn backbone_returns_cls_and_patches() {
let mut rng = LcgRng::new(1);
let cfg = ViTConfig::tiny();
let e = cfg.embed_dim;
let n_patches = cfg.n_patches();
let bb = DinoBackbone::new(cfg, &mut rng).expect("backbone ok");
// NOTE(review): assumes `ViTConfig::tiny()` takes 3-channel 32×32 images — confirm.
let img = vec![0.3f32; 3 * 32 * 32];
let out = bb.forward(&img).expect("forward ok");
assert_eq!(out.cls.len(), e, "CLS must be [embed_dim]");
assert_eq!(
out.patches.len(),
n_patches * e,
"patches must be [n_patches, e]"
);
assert_eq!(out.n_patches, n_patches);
assert!(out.cls.iter().all(|v| v.is_finite()));
assert!(out.patches.iter().all(|v| v.is_finite()));
}
// One logit per prototype. With the initial gain of 1.0 every logit is a
// cosine similarity, so it must lie in [-1, 1]. The student softmax sums to 1.
#[test]
fn head_prototype_logits_shape_and_softmax() {
let head = make_head(2, 128);
let mut rng = LcgRng::new(3);
let mut x = vec![0.0f32; 32];
rng.fill_normal(&mut x);
let logits = head.forward(&x).expect("ok");
assert_eq!(logits.len(), 128, "prototype logits must be [n_prototypes]");
let p = student_softmax(&logits, 0.1).expect("ok");
let sum: f32 = p.iter().sum();
assert!((sum - 1.0).abs() < 1e-5, "softmax must sum to 1; got {sum}");
for &l in &logits {
assert!(
(-1.0 - 1e-4..=1.0 + 1e-4).contains(&l),
"logit out of cosine range: {l}"
);
}
}
// θ_t ← m·θ_t + (1 − m)·θ_s implies θ_t' − θ_s = m·(θ_t − θ_s), so the
// parameter distance must shrink by exactly m (up to float error).
#[test]
fn ema_update_moves_teacher_toward_student() {
let mut teacher = make_head(10, 64);
let student = make_head(20, 64); let before = l2(&teacher.flatten(), &student.flatten());
assert!(before > 0.0, "teacher and student must start apart");
teacher.ema_update(&student, 0.9).expect("ema ok");
let after = l2(&teacher.flatten(), &student.flatten());
assert!(
after < before,
"EMA must reduce ‖θ_t − θ_s‖: before={before}, after={after}"
);
assert!(
(after - 0.9 * before).abs() < 1e-3 * before.max(1.0),
"EMA distance should scale by m=0.9: after={after}, 0.9·before={}",
0.9 * before
);
}
// Heads with different prototype counts cannot be blended.
#[test]
fn ema_update_shape_mismatch_errors() {
let mut teacher = make_head(10, 64);
let other = make_head(11, 32); let r = teacher.ema_update(&other, 0.9);
assert!(matches!(r, Err(VisionError::Internal(_))));
}
// Cross-entropy between two probability distributions is never negative.
#[test]
fn dino_loss_nonnegative() {
let mut rng = LcgRng::new(30);
for _ in 0..20 {
let mut sl = vec![0.0f32; 16];
let mut tl = vec![0.0f32; 16];
rng.fill_normal(&mut sl);
rng.fill_normal(&mut tl);
let l = dino_loss(&sl, &tl, &[], 0.1, 0.04).expect("ok");
assert!(l >= -1e-6, "DINO loss must be ≥ 0; got {l}");
}
}
// Matching near-one-hot distributions give ≈0 loss. Moving the student's mass
// to a different prototype raises it sharply.
#[test]
fn dino_loss_minimised_when_student_matches_teacher() {
let teacher_logits = vec![20.0f32, -20.0, -20.0, -20.0]; let student_logits = vec![20.0f32, -20.0, -20.0, -20.0];
let p_t = teacher_softmax(&teacher_logits, &[], 0.1).expect("ok");
let p_s = student_softmax(&student_logits, 0.1).expect("ok");
let h_self = cross_entropy(&p_t, &p_s).expect("ok");
assert!(
h_self < 1e-3,
"matched ~one-hot dists give ≈0 loss; got {h_self}"
);
let student_bad = vec![-20.0f32, 20.0, -20.0, -20.0]; let p_bad = student_softmax(&student_bad, 0.1).expect("ok");
let h_bad = cross_entropy(&p_t, &p_bad).expect("ok");
assert!(
h_bad > h_self + 1.0,
"mismatched student must raise the loss: self={h_self}, bad={h_bad}"
);
}
// H(p, p) reduces to the Shannon entropy of p.
#[test]
fn cross_entropy_equals_entropy_at_self() {
let logits = vec![1.0f32, 0.3, -0.5, 2.0, -1.0];
let p = student_softmax(&logits, 1.0).expect("ok");
let ce = cross_entropy(&p, &p).expect("ok");
let ent = entropy(&p);
assert!(
(ce - ent).abs() < 1e-5,
"H(p,p) must equal entropy(p): {ce} vs {ent}"
);
}
// Repeatedly feeding the same batch drives the EMA centre to that batch's mean
// (the residual weight 0.9^400 is negligible), so `g − c` ≈ 0.
#[test]
fn centering_drives_mean_near_zero() {
let dim = 8;
let mut buf = CenteringBuffer::new(dim, 0.9);
let base: Vec<f32> = (0..dim).map(|k| if k == 0 { 5.0 } else { 0.1 }).collect();
let batch = 4;
let mut flat = Vec::new();
for _ in 0..batch {
flat.extend_from_slice(&base);
}
for _ in 0..400 {
buf.update(&flat).expect("ok");
}
let centred_mean: f32 = base
.iter()
.zip(buf.center.iter())
.map(|(&g, &c)| (g - c).abs())
.sum::<f32>()
/ dim as f32;
assert!(
centred_mean < 1e-2,
"centering should drive (g − c) mean ≈ 0; got {centred_mean}"
);
}
// 7 logits are not a whole number of rows for an 8-dim centre.
#[test]
fn centering_update_bad_shape_errors() {
let mut buf = CenteringBuffer::new(8, 0.9);
let r = buf.update(&[0.0f32; 7]); assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
}
// A lower teacher temperature gives a lower-entropy, higher-peaked distribution.
#[test]
fn lower_teacher_temperature_sharpens_distribution() {
let logits = vec![2.0f32, 1.0, 0.5, -0.5, -1.0, 0.2];
let p_hot = teacher_softmax(&logits, &[], 0.04).expect("ok"); let p_soft = teacher_softmax(&logits, &[], 0.5).expect("ok"); let h_hot = entropy(&p_hot);
let h_soft = entropy(&p_soft);
assert!(
h_hot < h_soft,
"lower τ_t must lower entropy (sharper): H(0.04)={h_hot} vs H(0.5)={h_soft}"
);
let max_hot = p_hot.iter().cloned().fold(0.0f32, f32::max);
let max_soft = p_soft.iter().cloned().fold(0.0f32, f32::max);
assert!(max_hot > max_soft, "sharper dist must have a higher peak");
}
// Linearly interpolating the student logits 60% of the way toward the
// teacher's must reduce the DINO loss.
#[test]
fn nudging_student_toward_teacher_lowers_loss() {
let teacher_logits = vec![1.5f32, -0.5, 0.7, -1.2, 0.3, 0.9];
let student_before = vec![-1.0f32, 0.8, -0.3, 1.1, -0.6, 0.0];
let tau = 0.1;
let loss_before = dino_loss(&student_before, &teacher_logits, &[], tau, tau).expect("ok");
let alpha = 0.6f32;
let student_after: Vec<f32> = student_before
.iter()
.zip(teacher_logits.iter())
.map(|(&s, &t)| s + alpha * (t - s))
.collect();
let loss_after = dino_loss(&student_after, &teacher_logits, &[], tau, tau).expect("ok");
assert!(
loss_after < loss_before,
"moving the student toward the teacher must lower the loss: before={loss_before}, after={loss_after}"
);
}
// Same property with logits produced by a real head from two random "views".
#[test]
fn two_views_loss_decreases_when_student_aligns() {
let head = make_head(40, 32);
let mut rng = LcgRng::new(41);
let mut view_a = vec![0.0f32; 32];
let mut view_b = vec![0.0f32; 32];
rng.fill_normal(&mut view_a);
rng.fill_normal(&mut view_b);
let teacher_logits = head.forward(&view_a).expect("ok");
let student_logits = head.forward(&view_b).expect("ok");
let tau = 0.1;
let loss_before = dino_loss(&student_logits, &teacher_logits, &[], tau, tau).expect("ok");
let nudged: Vec<f32> = student_logits
.iter()
.zip(teacher_logits.iter())
.map(|(&s, &t)| s + 0.5 * (t - s))
.collect();
let loss_after = dino_loss(&nudged, &teacher_logits, &[], tau, tau).expect("ok");
assert!(
loss_after < loss_before,
"aligning student to teacher across views must lower loss: {loss_before} → {loss_after}"
);
}
// Zero and negative temperatures are rejected by both softmax entry points.
#[test]
fn nonpositive_temperature_errors() {
let r = student_softmax(&[1.0, 2.0], 0.0);
assert!(matches!(r, Err(VisionError::NonPositiveTemperature(_))));
let r2 = teacher_softmax(&[1.0, 2.0], &[], -0.1);
assert!(matches!(r2, Err(VisionError::NonPositiveTemperature(_))));
}
// An empty mask gives 0. Otherwise the iBOT loss is the plain mean of the
// per-patch DINO losses over the masked patches (here patches 0 and 2).
#[test]
fn ibot_loss_only_counts_masked_patches() {
let n_patches = 4;
let n_proto = 6;
let mut rng = LcgRng::new(50);
let mut s = vec![0.0f32; n_patches * n_proto];
let mut t = vec![0.0f32; n_patches * n_proto];
rng.fill_normal(&mut s);
rng.fill_normal(&mut t);
let none = vec![false; n_patches];
let l0 = ibot_loss(&s, &t, &none, &[], n_proto, 0.1, 0.04).expect("ok");
assert_eq!(l0, 0.0, "no masked patches ⇒ zero iBOT loss");
let mut mask = vec![false; n_patches];
mask[0] = true;
mask[2] = true;
let l = ibot_loss(&s, &t, &mask, &[], n_proto, 0.1, 0.04).expect("ok");
let l_p0 = dino_loss(&s[0..n_proto], &t[0..n_proto], &[], 0.1, 0.04).expect("ok");
let l_p2 = dino_loss(
&s[2 * n_proto..3 * n_proto],
&t[2 * n_proto..3 * n_proto],
&[],
0.1,
0.04,
)
.expect("ok");
let expected = 0.5 * (l_p0 + l_p2);
assert!(
(l - expected).abs() < 1e-5,
"iBOT loss must average masked-patch losses: {l} vs {expected}"
);
assert!(l >= 0.0, "iBOT loss must be ≥ 0");
}
// Two patches of 5 prototypes each; only patch 0 is masked, and only its
// student logits are nudged toward the teacher.
#[test]
fn ibot_loss_nudging_masked_student_lowers_loss() {
let n_proto = 5;
let teacher = vec![
1.2f32, -0.4, 0.6, -1.0, 0.2, 0.1, 0.1, 0.1, 0.1, 0.1,
];
let student = vec![
-1.0f32, 0.7, -0.2, 1.0, -0.5, 0.0, 0.0, 0.0, 0.0, 0.0,
];
let mask = vec![true, false];
let tau = 0.1;
let before = ibot_loss(&student, &teacher, &mask, &[], n_proto, tau, tau).expect("ok");
let mut nudged = student.clone();
for k in 0..n_proto {
nudged[k] += 0.6 * (teacher[k] - student[k]);
}
let after = ibot_loss(&nudged, &teacher, &mask, &[], n_proto, tau, tau).expect("ok");
assert!(
after < before,
"nudging masked student patch toward teacher must lower iBOT loss: {before} → {after}"
);
}
// The student buffer holds 5 logits where 2 patches × 5 prototypes = 10 are required.
#[test]
fn ibot_loss_bad_shape_errors() {
let mask = vec![true, false];
let r = ibot_loss(&[0.0f32; 5], &[0.0f32; 10], &mask, &[], 5, 0.1, 0.04);
assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
}
// `forward_batch` must reproduce row-by-row `forward` results in a [batch, k] layout.
#[test]
fn head_forward_batch_matches_single() {
let head = make_head(60, 32);
let mut rng = LcgRng::new(61);
let batch = 3;
let mut x = vec![0.0f32; batch * 32];
rng.fill_normal(&mut x);
let all = head.forward_batch(&x).expect("ok");
let k = head.n_prototypes();
for b in 0..batch {
let single = head.forward(&x[b * 32..(b + 1) * 32]).expect("ok");
for (j, &v) in single.iter().enumerate() {
assert!(
(all[b * k + j] - v).abs() < 1e-6,
"batch vs single mismatch at b={b}, j={j}"
);
}
}
}
// The head built by `make_head` expects exactly 32 input features.
#[test]
fn head_dimension_mismatch_errors() {
let head = make_head(70, 32);
let r = head.forward(&[0.0f32; 31]);
assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
}
// A head with zero prototypes is rejected at construction.
#[test]
fn head_zero_prototypes_errors() {
let mut rng = LcgRng::new(80);
let r = DinoHead::new(32, 64, 16, 0, &mut rng);
assert!(matches!(r, Err(VisionError::InvalidProjDim(0))));
}
}