use crate::error::{SslError, SslResult};
use crate::handle::LcgRng;
/// Random number generator used by [`DinoV2::new`] to initialise prototypes
/// (an alias for `crate::handle::LcgRng`).
pub type SslRng = LcgRng;
/// Hyper-parameters for [`DinoV2`].
///
/// `DinoV2::new` validates `d_model`, `n_prototypes`, both temperatures and
/// `center_momentum`. The remaining fields are stored but not read by any code
/// visible in this file.
#[derive(Debug, Clone)]
pub struct DinoV2Config {
/// NOTE(review): not read anywhere in this file; presumably the number of ViT
/// register tokens — confirm with callers.
pub n_register_tokens: usize,
/// Per-sample feature dimension (length of each feature row and of each prototype);
/// must be greater than zero.
pub d_model: usize,
/// Number of prototype vectors, i.e. the width of every score/probability row;
/// must be at least 2.
pub n_prototypes: usize,
/// Softmax temperature applied to student scores; must be finite and > 0.
pub temp_student: f32,
/// Softmax temperature applied to centred teacher scores; must be finite and > 0.
pub temp_teacher: f32,
/// NOTE(review): not applied inside `koleo_loss`; presumably callers scale the KoLeo
/// term by this weight — confirm.
pub koleo_weight: f32,
/// Weight kept on the previous center in `update_center`
/// (`m * old + (1 - m) * batch_mean`); must be finite and in `[0, 1]`.
pub center_momentum: f32,
}
impl Default for DinoV2Config {
fn default() -> Self {
Self {
n_register_tokens: 4,
d_model: 64,
n_prototypes: 8,
temp_student: 0.1,
temp_teacher: 0.04,
koleo_weight: 0.1,
center_momentum: 0.9,
}
}
}
/// DINOv2-style self-distillation head.
///
/// Holds a bank of prototype vectors, the teacher centering state and the
/// configuration it was built with. Instances are created through `DinoV2::new`.
#[derive(Debug)]
pub struct DinoV2 {
/// Prototype matrix stored row-major: `n_prototypes` rows of `d_model` values
/// (prototype `k`, component `j` lives at index `k * d_model + j`).
prototypes: Vec<f32>,
/// One value per prototype, subtracted from teacher scores before the teacher softmax;
/// starts at all zeros.
center: Vec<f32>,
/// Hyper-parameters, validated when the head is constructed.
config: DinoV2Config,
}
/// Row-wise softmax with temperature over an `n × k` row-major score matrix.
///
/// Row `r` (the `k` values starting at index `r * k`) becomes
/// `exp(v / t - peak) / Σ exp(v / t - peak)`, where `peak` is the row's largest `v / t`.
/// Exponentials and the normaliser are accumulated in `f64`. The normaliser is floored at
/// `1e-30`, and the `n * k` results are returned as `f32`. Only the first `n * k` entries of
/// `scores` are read.
///
/// # Panics
///
/// Panics if `scores.len() < n * k`.
fn row_softmax_temp(scores: &[f32], n: usize, k: usize, t: f32) -> Vec<f32> {
    let mut probs = Vec::with_capacity(n * k);
    for row_idx in 0..n {
        let start = row_idx * k;
        let row = &scores[start..start + k];
        // Largest tempered score. A NaN never wins a `>` comparison, so it is never chosen.
        let peak = row
            .iter()
            .map(|&v| v / t)
            .fold(f32::NEG_INFINITY, |acc, x| if x > acc { x } else { acc });
        // Shifting by the peak makes the largest exponent exp(0) = 1, so nothing overflows.
        let weights: Vec<f64> = row
            .iter()
            .map(|&v| f64::from(v / t - peak).exp())
            .collect();
        let total = weights.iter().fold(0.0_f64, |acc, &w| acc + w);
        let scale = 1.0_f64 / total.max(1e-30_f64);
        probs.extend(weights.into_iter().map(|w| (w * scale) as f32));
    }
    probs
}
/// Euclidean (L2) norm of `v`: the square root of the sum of squared components.
#[inline]
fn l2_norm(v: &[f32]) -> f32 {
    v.iter().map(|&x| x * x).sum::<f32>().sqrt()
}
/// Cosine similarity of `a` and `b`, clamped to `[-1, 1]`.
///
/// Returns `0.0` when the product of the two L2 norms is below `1e-12`, e.g. when either
/// vector is all zeros. The dot product runs over the common prefix of the two slices,
/// while each norm covers its whole slice.
#[inline]
fn cosine_sim(a: &[f32], b: &[f32]) -> f32 {
    let norm_product = l2_norm(a) * l2_norm(b);
    // Keep the `<` form: a NaN product must fall through to the division, as before.
    if norm_product < 1e-12 {
        return 0.0;
    }
    let dot = a.iter().zip(b).map(|(&x, &y)| x * y).sum::<f32>();
    (dot / norm_product).clamp(-1.0, 1.0)
}
impl DinoV2 {
pub fn new(config: DinoV2Config, rng: &mut SslRng) -> SslResult<Self> {
if config.d_model == 0 {
return Err(SslError::InvalidFeatureDim);
}
if config.n_prototypes < 2 {
return Err(SslError::NumPrototypesTooSmall);
}
if !(config.temp_student.is_finite() && config.temp_student > 0.0) {
return Err(SslError::InvalidTemperature {
temp: config.temp_student,
});
}
if !(config.temp_teacher.is_finite() && config.temp_teacher > 0.0) {
return Err(SslError::InvalidTemperature {
temp: config.temp_teacher,
});
}
if !(config.center_momentum.is_finite()
&& (0.0_f32..=1.0_f32).contains(&config.center_momentum))
{
return Err(SslError::InvalidMomentum {
momentum: config.center_momentum,
});
}
let p = config.n_prototypes;
let d = config.d_model;
let scale = 1.0_f32 / (d as f32).sqrt();
let mut prototypes = Vec::with_capacity(p * d);
for _ in 0..p * d {
prototypes.push((rng.next_f32() * 2.0 - 1.0) * scale);
}
let center = vec![0.0_f32; p];
Ok(Self {
prototypes,
center,
config,
})
}
#[must_use]
#[inline]
pub fn d_model(&self) -> usize {
self.config.d_model
}
#[must_use]
#[inline]
pub fn center(&self) -> &[f32] {
&self.center
}
pub fn compute_scores(&self, features: &[f32], n_samples: usize) -> SslResult<Vec<f32>> {
let d = self.config.d_model;
let p = self.config.n_prototypes;
if n_samples == 0 {
return Err(SslError::EmptyInput);
}
let expected = n_samples * d;
if features.len() != expected {
return Err(SslError::DimensionMismatch {
expected,
got: features.len(),
});
}
let mut scores = vec![0.0_f32; n_samples * p];
for i in 0..n_samples {
for k in 0..p {
let mut dot = 0.0_f32;
for j in 0..d {
dot += features[i * d + j] * self.prototypes[k * d + j];
}
scores[i * p + k] = dot;
}
}
Ok(scores)
}
pub fn student_probs(&self, scores: &[f32], n_samples: usize) -> SslResult<Vec<f32>> {
let p = self.config.n_prototypes;
if n_samples == 0 {
return Err(SslError::EmptyInput);
}
let expected = n_samples * p;
if scores.len() != expected {
return Err(SslError::DimensionMismatch {
expected,
got: scores.len(),
});
}
Ok(row_softmax_temp(
scores,
n_samples,
p,
self.config.temp_student,
))
}
pub fn teacher_probs(&self, scores: &[f32], n_samples: usize) -> SslResult<Vec<f32>> {
let p = self.config.n_prototypes;
if n_samples == 0 {
return Err(SslError::EmptyInput);
}
let expected = n_samples * p;
if scores.len() != expected {
return Err(SslError::DimensionMismatch {
expected,
got: scores.len(),
});
}
let mut centred = scores.to_vec();
for i in 0..n_samples {
for k in 0..p {
centred[i * p + k] -= self.center[k];
}
}
Ok(row_softmax_temp(
¢red,
n_samples,
p,
self.config.temp_teacher,
))
}
pub fn dino_v2_loss(
&self,
student_scores: &[f32],
teacher_scores: &[f32],
n_samples: usize,
) -> SslResult<f32> {
let p = self.config.n_prototypes;
if n_samples == 0 {
return Err(SslError::EmptyInput);
}
let expected = n_samples * p;
if student_scores.len() != expected {
return Err(SslError::DimensionMismatch {
expected,
got: student_scores.len(),
});
}
if teacher_scores.len() != expected {
return Err(SslError::DimensionMismatch {
expected,
got: teacher_scores.len(),
});
}
let p_s = self.student_probs(student_scores, n_samples)?;
let p_t = self.teacher_probs(teacher_scores, n_samples)?;
let mut total = 0.0_f64;
for i in 0..n_samples {
for k in 0..p {
let log_ps = ((p_s[i * p + k] as f64) + 1e-8_f64).ln();
total -= (p_t[i * p + k] as f64) * log_ps;
}
}
Ok((total / n_samples as f64) as f32)
}
pub fn update_center(&mut self, teacher_scores: &[f32], n_samples: usize) -> SslResult<()> {
let p = self.config.n_prototypes;
if n_samples == 0 {
return Err(SslError::EmptyInput);
}
let expected = n_samples * p;
if teacher_scores.len() != expected {
return Err(SslError::DimensionMismatch {
expected,
got: teacher_scores.len(),
});
}
let p_t = self.teacher_probs(teacher_scores, n_samples)?;
let m = self.config.center_momentum;
let inv_n = 1.0_f32 / n_samples as f32;
for k in 0..p {
let mut mean_k = 0.0_f32;
for i in 0..n_samples {
mean_k += p_t[i * p + k];
}
mean_k *= inv_n;
self.center[k] = m * self.center[k] + (1.0 - m) * mean_k;
}
Ok(())
}
pub fn koleo_loss(&self, features: &[f32], n_samples: usize) -> SslResult<f32> {
let d = self.config.d_model;
if n_samples == 0 {
return Err(SslError::EmptyInput);
}
let expected = n_samples * d;
if features.len() != expected {
return Err(SslError::DimensionMismatch {
expected,
got: features.len(),
});
}
if n_samples < 2 {
return Ok(0.0);
}
let mut total = 0.0_f64;
for i in 0..n_samples {
let fi = &features[i * d..(i + 1) * d];
let mut min_sim = f32::INFINITY;
for j in 0..n_samples {
if j == i {
continue;
}
let fj = &features[j * d..(j + 1) * d];
let s = cosine_sim(fi, fj);
if s < min_sim {
min_sim = s;
}
}
let sim_shifted = (min_sim + 1.0).max(0.0);
total -= (sim_shifted as f64 + 1e-8_f64).ln();
}
Ok((total / n_samples as f64) as f32)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Small, valid configuration: 8-dimensional features and 4 prototypes.
    fn default_config() -> DinoV2Config {
        DinoV2Config {
            n_register_tokens: 4,
            d_model: 8,
            n_prototypes: 4,
            temp_student: 0.1,
            temp_teacher: 0.04,
            koleo_weight: 0.1,
            center_momentum: 0.9,
        }
    }

    /// Fixed-seed RNG so prototype initialisation is reproducible.
    fn make_rng() -> SslRng {
        LcgRng::new(99)
    }

    /// Model built from `default_config()` with the fixed-seed RNG.
    fn make_model() -> DinoV2 {
        DinoV2::new(default_config(), &mut make_rng()).expect("default config is valid")
    }

    #[test]
    fn compute_scores_shape() {
        let model = make_model();
        let d = model.d_model();
        let n = 6_usize;
        let feats = vec![0.3_f32; n * d];
        let scores = model
            .compute_scores(&feats, n)
            .expect("compute_scores should succeed");
        assert_eq!(scores.len(), n * default_config().n_prototypes);
    }

    #[test]
    fn student_probs_sum_to_one() {
        let model = make_model();
        let p = default_config().n_prototypes;
        let n = 5_usize;
        let scores: Vec<f32> = (0..n * p).map(|i| (i as f32 * 0.17).sin()).collect();
        let probs = model
            .student_probs(&scores, n)
            .expect("student_probs should succeed");
        for (i, row) in probs.chunks_exact(p).enumerate() {
            let row_sum: f32 = row.iter().sum();
            assert!((row_sum - 1.0).abs() < 1e-5, "row {i} sum = {row_sum}");
        }
    }

    #[test]
    fn teacher_probs_sum_to_one() {
        let model = make_model();
        let p = default_config().n_prototypes;
        let n = 5_usize;
        let scores: Vec<f32> = (0..n * p).map(|i| (i as f32 * 0.23).cos()).collect();
        let probs = model
            .teacher_probs(&scores, n)
            .expect("teacher_probs should succeed");
        for (i, row) in probs.chunks_exact(p).enumerate() {
            let row_sum: f32 = row.iter().sum();
            assert!((row_sum - 1.0).abs() < 1e-5, "row {i} sum = {row_sum}");
        }
    }

    #[test]
    fn dino_v2_loss_finite() {
        let model = make_model();
        let p = default_config().n_prototypes;
        let n = 8_usize;
        let student_scores: Vec<f32> = (0..n * p).map(|i| (i as f32 * 0.09).sin()).collect();
        let teacher_scores: Vec<f32> = (0..n * p).map(|i| (i as f32 * 0.11).cos()).collect();
        let loss = model
            .dino_v2_loss(&student_scores, &teacher_scores, n)
            .expect("dino_v2_loss should succeed");
        assert!(loss.is_finite(), "loss must be finite, got {loss}");
        assert!(loss > 0.0, "cross-entropy must be positive, got {loss}");
    }

    #[test]
    fn dino_v2_loss_teacher_dim_mismatch_error() {
        let model = make_model();
        let p = default_config().n_prototypes;
        let n = 3_usize;
        let student_scores = vec![0.0_f32; n * p];
        let teacher_scores = vec![0.0_f32; n * p - 1];
        let result = model.dino_v2_loss(&student_scores, &teacher_scores, n);
        assert!(matches!(
            result,
            Err(SslError::DimensionMismatch { expected, got })
                if expected == n * p && got == n * p - 1
        ));
    }

    #[test]
    fn update_center_changes_center() {
        let mut model = make_model();
        let p = default_config().n_prototypes;
        let n = 4_usize;
        let scores: Vec<f32> = (0..n * p).map(|i| (i as f32 * 0.31).sin() + 1.0).collect();
        let center_before = model.center().to_vec();
        model
            .update_center(&scores, n)
            .expect("update_center should succeed");
        let changed = center_before
            .iter()
            .zip(model.center())
            .any(|(a, b)| (a - b).abs() > 1e-7);
        assert!(changed, "center should change after update");
    }

    #[test]
    fn update_center_full_momentum_keeps_center() {
        let mut cfg = default_config();
        cfg.center_momentum = 1.0;
        let p = cfg.n_prototypes;
        let mut model = DinoV2::new(cfg, &mut make_rng()).expect("momentum 1.0 is valid");
        let n = 3_usize;
        let scores: Vec<f32> = (0..n * p).map(|i| i as f32 * 0.5).collect();
        model
            .update_center(&scores, n)
            .expect("update_center should succeed");
        assert_eq!(model.center(), &vec![0.0_f32; p][..]);
    }

    #[test]
    fn d_model_zero_error() {
        let mut cfg = default_config();
        cfg.d_model = 0;
        let result = DinoV2::new(cfg, &mut make_rng());
        assert!(matches!(result, Err(SslError::InvalidFeatureDim)));
    }

    #[test]
    fn n_prototypes_too_small_error() {
        let mut cfg = default_config();
        cfg.n_prototypes = 1;
        let result = DinoV2::new(cfg, &mut make_rng());
        assert!(matches!(result, Err(SslError::NumPrototypesTooSmall)));
    }

    #[test]
    fn koleo_loss_nonneg() {
        let model = make_model();
        let d = model.d_model();
        let n = 6_usize;
        let feats: Vec<f32> = (0..n * d).map(|i| (i as f32 * 0.19).sin()).collect();
        let loss = model
            .koleo_loss(&feats, n)
            .expect("koleo_loss should succeed");
        assert!(loss >= 0.0, "KoLeo loss must be non-negative, got {loss}");
        assert!(loss.is_finite(), "KoLeo loss must be finite, got {loss}");
    }

    #[test]
    fn invalid_temperature_error() {
        let mut cfg = default_config();
        cfg.temp_student = 0.0;
        let result = DinoV2::new(cfg, &mut make_rng());
        assert!(matches!(result, Err(SslError::InvalidTemperature { .. })));
    }

    #[test]
    fn invalid_teacher_temperature_error() {
        let mut cfg = default_config();
        cfg.temp_teacher = f32::NAN;
        let result = DinoV2::new(cfg, &mut make_rng());
        assert!(matches!(result, Err(SslError::InvalidTemperature { .. })));
    }

    #[test]
    fn invalid_momentum_error() {
        let mut cfg = default_config();
        cfg.center_momentum = 1.5;
        let result = DinoV2::new(cfg, &mut make_rng());
        assert!(matches!(result, Err(SslError::InvalidMomentum { .. })));
    }

    #[test]
    fn koleo_single_sample_returns_zero() {
        let model = make_model();
        let d = model.d_model();
        let feats = vec![0.5_f32; d];
        let loss = model
            .koleo_loss(&feats, 1)
            .expect("koleo_loss should succeed");
        assert_eq!(loss, 0.0);
    }

    #[test]
    fn empty_input_error() {
        let model = make_model();
        assert!(matches!(
            model.teacher_probs(&[], 0),
            Err(SslError::EmptyInput)
        ));
    }

    #[test]
    fn compute_scores_dim_mismatch_error() {
        let model = make_model();
        let feats = vec![0.1_f32; 3];
        let result = model.compute_scores(&feats, 2);
        assert!(matches!(result, Err(SslError::DimensionMismatch { .. })));
    }
}