use crate::error::{SslError, SslResult};
use crate::handle::LcgRng;
/// Hyper-parameters consumed by [`msn_loss`] and [`msn_update_prototypes`].
///
/// Every field is public, so a value built with a struct literal bypasses the
/// checks in [`MsnConfig::new`]; `msn_loss` and `msn_update_prototypes`
/// re-check `n_prototypes` and both temperatures on each call.
#[derive(Debug, Clone)]
pub struct MsnConfig {
/// Number of prototypes; values below 2 are rejected by the validators.
pub n_prototypes: usize,
/// Temperature that divides the anchor scores before the row-wise softmax.
pub tau_anchor: f32,
/// Temperature that divides the target scores before the row-wise softmax.
pub tau_target: f32,
/// Weight of the mean-prediction entropy term, which `msn_loss` subtracts
/// from the cross-entropy.
pub lambda_reg: f32,
/// Small constant added inside the logarithm of the mean-prediction
/// entropy term.
pub eps: f32,
}
impl Default for MsnConfig {
fn default() -> Self {
Self {
n_prototypes: 64,
tau_anchor: 0.10,
tau_target: 0.25,
lambda_reg: 1.0,
eps: 1e-8,
}
}
}
impl MsnConfig {
    /// Builds a validated configuration.
    ///
    /// `eps` is not rejected: it is clamped from below at `0.0` (a NaN `eps`
    /// therefore becomes `0.0`, following `f32::max`).
    ///
    /// # Errors
    ///
    /// * [`SslError::NumPrototypesTooSmall`] when `n_prototypes < 2`.
    /// * [`SslError::InvalidTemperature`] carrying the first of
    ///   `tau_anchor`, `tau_target` that is non-finite or not strictly
    ///   positive.
    /// * [`SslError::InvalidLossWeight`] when `lambda_reg` is NaN or infinite.
    pub fn new(
        n_prototypes: usize,
        tau_anchor: f32,
        tau_target: f32,
        lambda_reg: f32,
        eps: f32,
    ) -> SslResult<Self> {
        if n_prototypes < 2 {
            return Err(SslError::NumPrototypesTooSmall);
        }

        // Anchor is inspected first so it wins when both temperatures are bad.
        let bad_temperature = [tau_anchor, tau_target]
            .iter()
            .copied()
            .find(|t| !t.is_finite() || *t <= 0.0);
        if let Some(temp) = bad_temperature {
            return Err(SslError::InvalidTemperature { temp });
        }

        if lambda_reg.is_nan() || lambda_reg.is_infinite() {
            return Err(SslError::InvalidLossWeight { weight: lambda_reg });
        }

        Ok(Self {
            n_prototypes,
            tau_anchor,
            tau_target,
            lambda_reg,
            eps: eps.max(0.0),
        })
    }
}
/// Prototype matrix stored as one flat buffer.
///
/// Prototype `k` occupies `weights[k * dim .. (k + 1) * dim]` (see
/// [`MsnPrototypes::row`]). The fields are public and nothing in this type
/// enforces `weights.len() == n_prototypes * dim`.
#[derive(Debug, Clone)]
pub struct MsnPrototypes {
/// Flat, row-major prototype values.
pub weights: Vec<f32>,
/// Number of prototype rows.
pub n_prototypes: usize,
/// Length of each prototype row.
pub dim: usize,
}
impl MsnPrototypes {
    /// Index range that prototype `k` occupies inside the flat `weights` buffer.
    #[inline]
    fn row_span(&self, k: usize) -> std::ops::Range<usize> {
        let first = k * self.dim;
        first..first + self.dim
    }

    /// Borrows prototype `k` as a slice of length `dim`.
    ///
    /// # Panics
    ///
    /// Panics if the row lies outside `weights`.
    #[inline]
    #[must_use]
    pub fn row(&self, k: usize) -> &[f32] {
        &self.weights[self.row_span(k)]
    }

    /// Mutably borrows prototype `k` as a slice of length `dim`.
    ///
    /// # Panics
    ///
    /// Panics if the row lies outside `weights`.
    #[inline]
    pub fn row_mut(&mut self, k: usize) -> &mut [f32] {
        let span = self.row_span(k);
        &mut self.weights[span]
    }
}
/// Scalar outputs of [`msn_loss`].
#[derive(Debug, Clone)]
pub struct MsnResult {
/// Total objective: `ce_loss - lambda_reg * me_max_loss`.
pub loss: f32,
/// Batch-mean cross-entropy between the anchor and target prototype
/// distributions.
pub ce_loss: f32,
/// Entropy of the batch-averaged anchor distribution (the regulariser term).
pub me_max_loss: f32,
/// Currently filled with the same value as `me_max_loss`.
pub mean_entropy: f32,
/// Fraction of prototypes whose batch-averaged anchor probability is
/// strictly greater than `1 / K`.
pub mean_prototype_util: f32,
}
/// Row-wise softmax with temperature over a flat `n × k` score matrix.
///
/// Every score is divided by `temperature`, the row maximum is subtracted for
/// numerical stability, and the exponentials are accumulated in `f64`. The row
/// sum is clamped to at least `1e-30` before it is inverted. Returns a flat
/// `n × k` vector of `f32` probabilities.
///
/// # Panics
///
/// Panics if `scores` holds fewer than `n * k` values.
#[inline]
fn row_softmax_t(scores: &[f32], n: usize, k: usize, temperature: f32) -> Vec<f32> {
    let mut probs = Vec::with_capacity(n * k);
    for r in 0..n {
        let logits = &scores[r * k..(r + 1) * k];

        // Largest scaled logit of the row; NaN entries never win the comparison.
        let peak = logits
            .iter()
            .map(|&s| s / temperature)
            .fold(f32::NEG_INFINITY, |best, s| if s > best { s } else { best });

        let shifted: Vec<f64> = logits
            .iter()
            .map(|&s| f64::from(s / temperature - peak).exp())
            .collect();
        let total = shifted.iter().fold(0.0_f64, |acc, &e| acc + e);
        let scale = 1.0_f64 / total.max(1e-30_f64);

        probs.extend(shifted.iter().map(|&e| (e * scale) as f32));
    }
    probs
}
/// Dot-product scores between every feature row and every prototype.
///
/// `features` is read as a flat, row-major `batch_size × feat_dim` buffer. The
/// result is a flat `batch_size × K` buffer (`K = prototypes.n_prototypes`)
/// whose entry `b * K + k` is the dot product of feature row `b` with
/// prototype `k`.
///
/// # Panics
///
/// Panics if `features` is shorter than `batch_size * feat_dim`, or if a
/// prototype row lies outside `prototypes.weights`.
#[inline]
fn compute_scores(
    features: &[f32],
    prototypes: &MsnPrototypes,
    batch_size: usize,
    feat_dim: usize,
) -> Vec<f32> {
    let n_protos = prototypes.n_prototypes;
    let mut scores = Vec::with_capacity(batch_size * n_protos);
    for sample in 0..batch_size {
        let sample_feats = &features[sample * feat_dim..(sample + 1) * feat_dim];
        scores.extend((0..n_protos).map(|proto_idx| {
            sample_feats
                .iter()
                .zip(prototypes.row(proto_idx).iter())
                .map(|(&x, &w)| x * w)
                .sum::<f32>()
        }));
    }
    scores
}
/// Scales each of the first `n_rows` rows of a flat `n_rows × dim` buffer to
/// unit L2 norm, in place.
///
/// Rows whose norm is not strictly greater than `1e-12` (including NaN norms)
/// are left untouched.
///
/// # Panics
///
/// Panics if `weights` is shorter than `n_rows * dim`.
#[inline]
fn l2_normalize_rows(weights: &mut [f32], n_rows: usize, dim: usize) {
    for row_idx in 0..n_rows {
        let row = &mut weights[row_idx * dim..(row_idx + 1) * dim];
        let norm = row.iter().map(|&x| x * x).sum::<f32>().sqrt();
        // `then` keeps the original `norm > 1e-12` test, so NaN norms are skipped too.
        if let Some(scale) = (norm > 1e-12).then(|| 1.0 / norm) {
            row.iter_mut().for_each(|x| *x *= scale);
        }
    }
}
/// Evaluates the MSN objective for one batch of anchor/target feature rows.
///
/// Both feature buffers are flat, row-major `batch_size × feat_dim` slices.
/// Each row is scored against every prototype by dot product and turned into a
/// distribution over prototypes with a temperature softmax
/// (`config.tau_anchor` for anchors, `config.tau_target` for targets).
///
/// The returned [`MsnResult`] contains:
///
/// * `ce_loss` — batch mean of `-Σ_k q_k · ln(max(p_k, 1e-30))`, where `q` is
///   the anchor distribution and `p` the target distribution;
/// * `me_max_loss` (also copied into `mean_entropy`) —
///   `-Σ_k q̄_k · ln(q̄_k + eps)`, with `q̄` the batch mean of the anchor
///   distributions;
/// * `loss` — `ce_loss - lambda_reg * me_max_loss`;
/// * `mean_prototype_util` — fraction of prototypes with `q̄_k > 1 / K`.
///
/// The prototype count `K` is taken from `prototypes.n_prototypes`;
/// `config.n_prototypes` is only range-checked.
///
/// # Errors
///
/// * [`SslError::EmptyInput`] when `batch_size` or `feat_dim` is zero.
/// * [`SslError::NumPrototypesTooSmall`] when `config.n_prototypes < 2` or
///   `prototypes.n_prototypes < 2`.
/// * [`SslError::InvalidTemperature`] when either temperature is non-finite or
///   not strictly positive.
/// * [`SslError::DimensionMismatch`] when a feature buffer is not
///   `batch_size * feat_dim` long, `prototypes.dim != feat_dim`, or the
///   prototype buffer is not `K * feat_dim` long.
/// * [`SslError::NanEncountered`] when the total loss is not finite.
pub fn msn_loss(
    anchor_features: &[f32],
    target_features: &[f32],
    prototypes: &MsnPrototypes,
    batch_size: usize,
    feat_dim: usize,
    config: &MsnConfig,
) -> SslResult<MsnResult> {
    if batch_size == 0 || feat_dim == 0 {
        return Err(SslError::EmptyInput);
    }
    if config.n_prototypes < 2 {
        return Err(SslError::NumPrototypesTooSmall);
    }
    for t in [config.tau_anchor, config.tau_target] {
        if !(t.is_finite() && t > 0.0) {
            return Err(SslError::InvalidTemperature { temp: t });
        }
    }

    let expected_feat = batch_size * feat_dim;
    for features in [anchor_features, target_features] {
        if features.len() != expected_feat {
            return Err(SslError::DimensionMismatch {
                expected: expected_feat,
                got: features.len(),
            });
        }
    }

    let k = prototypes.n_prototypes;
    if prototypes.dim != feat_dim {
        return Err(SslError::DimensionMismatch {
            expected: feat_dim,
            got: prototypes.dim,
        });
    }
    let expected_proto = k * feat_dim;
    if prototypes.weights.len() != expected_proto {
        return Err(SslError::DimensionMismatch {
            expected: expected_proto,
            got: prototypes.weights.len(),
        });
    }
    // `K` actually comes from `prototypes`, not from `config`. Without this
    // guard an empty prototype set passes every check above and the
    // utilisation below becomes 0/0 = NaN inside an `Ok` result. The check sits
    // last so inputs that already failed keep reporting the same error.
    if k < 2 {
        return Err(SslError::NumPrototypesTooSmall);
    }

    let anchor_scores = compute_scores(anchor_features, prototypes, batch_size, feat_dim);
    let target_scores = compute_scores(target_features, prototypes, batch_size, feat_dim);
    let q = row_softmax_t(&anchor_scores, batch_size, k, config.tau_anchor);
    let p = row_softmax_t(&target_scores, batch_size, k, config.tau_target);

    // One pass over the batch accumulates both the cross-entropy and the
    // per-prototype sum of anchor probabilities (in f64 to limit round-off).
    // NOTE(review): the anchor distribution weights the log of the target
    // distribution here; the MSN paper states the cross-entropy the other way
    // round (target weights log-anchor) — confirm this orientation is intended.
    let mut ce_sum = 0.0_f64;
    let mut p_avg = vec![0.0_f64; k];
    for (q_row, p_row) in q.chunks_exact(k).zip(p.chunks_exact(k)) {
        for ((&q_bk, &p_bk), avg) in q_row.iter().zip(p_row.iter()).zip(p_avg.iter_mut()) {
            let q_bk = f64::from(q_bk);
            // Floor the target probability so `ln` never sees zero.
            let p_bk = f64::from(p_bk).max(1e-30_f64);
            ce_sum -= q_bk * p_bk.ln();
            *avg += q_bk;
        }
    }
    let inv_b = 1.0_f64 / batch_size as f64;
    let ce_loss = (ce_sum / batch_size as f64) as f32;
    for v in p_avg.iter_mut() {
        *v *= inv_b;
    }

    // Entropy of the batch-averaged anchor distribution.
    let eps64 = f64::from(config.eps);
    let mut me_max_sum = 0.0_f64;
    for &pk in &p_avg {
        me_max_sum -= pk * (pk + eps64).ln();
    }
    let me_max_loss = me_max_sum as f32;

    // The entropy term is subtracted, so a higher mean entropy lowers the loss.
    let total_loss = ce_loss - config.lambda_reg * me_max_loss;
    if !total_loss.is_finite() {
        return Err(SslError::NanEncountered {
            location: "msn_loss: total",
        });
    }

    // A prototype counts as "used" when its mean probability beats uniform.
    let threshold = 1.0_f64 / k as f64;
    let n_used = p_avg.iter().filter(|&&v| v > threshold).count();
    let mean_prototype_util = n_used as f32 / k as f32;

    Ok(MsnResult {
        loss: total_loss,
        ce_loss,
        me_max_loss,
        mean_entropy: me_max_loss,
        mean_prototype_util,
    })
}
/// Draws a boolean mask over `n_tokens` positions with a fixed number of
/// `true` entries.
///
/// The number of masked positions is `n_tokens as f32 * mask_ratio` truncated
/// to `usize`. The positions are the leading entries of the index list
/// `0..n_tokens` after it has been passed through `rng.shuffle`.
///
/// # Errors
///
/// * [`SslError::EmptyInput`] when `n_tokens == 0`.
/// * [`SslError::InvalidMaskRatio`] when `mask_ratio` is non-finite or lies
///   outside the half-open interval `[0, 1)`.
pub fn msn_random_mask(n_tokens: usize, mask_ratio: f32, rng: &mut LcgRng) -> SslResult<Vec<bool>> {
    if n_tokens == 0 {
        return Err(SslError::EmptyInput);
    }
    let ratio_ok = mask_ratio.is_finite() && mask_ratio >= 0.0 && mask_ratio < 1.0;
    if !ratio_ok {
        return Err(SslError::InvalidMaskRatio { ratio: mask_ratio });
    }

    let n_masked = (n_tokens as f32 * mask_ratio) as usize;

    let mut order: Vec<usize> = Vec::with_capacity(n_tokens);
    order.extend(0..n_tokens);
    rng.shuffle(&mut order);

    let mut mask = vec![false; n_tokens];
    order
        .iter()
        .take(n_masked)
        .for_each(|&token| mask[token] = true);
    Ok(mask)
}
/// Creates `n_prototypes` prototypes of length `dim`.
///
/// The flat weight buffer is filled by `rng.fill_normal` (presumably normal
/// samples — see `LcgRng::fill_normal`) and every row is then scaled to unit
/// L2 norm.
#[must_use]
pub fn msn_prototype_init(n_prototypes: usize, dim: usize, rng: &mut LcgRng) -> MsnPrototypes {
    let mut protos = MsnPrototypes {
        weights: vec![0.0_f32; n_prototypes * dim],
        n_prototypes,
        dim,
    };
    rng.fill_normal(&mut protos.weights);
    l2_normalize_rows(&mut protos.weights, n_prototypes, dim);
    protos
}
/// Applies one in-place update step to the prototypes from a batch of anchor
/// features, then rescales every prototype row to unit L2 norm.
///
/// The anchor scores are converted into two distributions over prototypes:
/// `q` with `config.tau_anchor` and `p` with `config.tau_target`. For each
/// prototype `k` the accumulator is `g_k = -(1/B) Σ_b (q_bk - p_bk) · f_b`,
/// and the weights are updated as `w_k -= lr * g_k`.
///
/// `lr` is used as given; it is not validated.
///
/// # Errors
///
/// * [`SslError::EmptyInput`] when `batch_size` or `feat_dim` is zero.
/// * [`SslError::NumPrototypesTooSmall`] when `config.n_prototypes < 2`.
/// * [`SslError::InvalidTemperature`] when either temperature is non-finite or
///   not strictly positive.
/// * [`SslError::DimensionMismatch`] when `anchor_features` is not
///   `batch_size * feat_dim` long, `prototypes.dim != feat_dim`, or the
///   prototype buffer is not `prototypes.n_prototypes * feat_dim` long.
pub fn msn_update_prototypes(
    prototypes: &mut MsnPrototypes,
    anchor_features: &[f32],
    batch_size: usize,
    feat_dim: usize,
    lr: f32,
    config: &MsnConfig,
) -> SslResult<()> {
    if batch_size == 0 || feat_dim == 0 {
        return Err(SslError::EmptyInput);
    }
    if config.n_prototypes < 2 {
        return Err(SslError::NumPrototypesTooSmall);
    }
    for t in [config.tau_anchor, config.tau_target] {
        if !(t.is_finite() && t > 0.0) {
            return Err(SslError::InvalidTemperature { temp: t });
        }
    }
    let expected_feat = batch_size * feat_dim;
    if anchor_features.len() != expected_feat {
        return Err(SslError::DimensionMismatch {
            expected: expected_feat,
            got: anchor_features.len(),
        });
    }
    if prototypes.dim != feat_dim {
        return Err(SslError::DimensionMismatch {
            expected: feat_dim,
            got: prototypes.dim,
        });
    }
    let k = prototypes.n_prototypes;
    // Same buffer-length check as `msn_loss`: without it a short `weights`
    // buffer panics inside `prototypes.row` instead of returning an error.
    let expected_proto = k * feat_dim;
    if prototypes.weights.len() != expected_proto {
        return Err(SslError::DimensionMismatch {
            expected: expected_proto,
            got: prototypes.weights.len(),
        });
    }

    // Both distributions come from the same anchor scores; only the
    // temperature differs.
    let anchor_scores = compute_scores(anchor_features, prototypes, batch_size, feat_dim);
    let q = row_softmax_t(&anchor_scores, batch_size, k, config.tau_anchor);
    let p = row_softmax_t(&anchor_scores, batch_size, k, config.tau_target);

    let inv_b = 1.0_f32 / batch_size as f32;
    let mut grad = vec![0.0_f32; expected_proto];
    for (b, feat_row) in anchor_features.chunks_exact(feat_dim).enumerate() {
        for (ki, g_row) in grad.chunks_exact_mut(feat_dim).enumerate() {
            let delta = (q[b * k + ki] - p[b * k + ki]) * inv_b;
            for (g, &f) in g_row.iter_mut().zip(feat_row.iter()) {
                *g -= delta * f;
            }
        }
    }

    for (w, g) in prototypes.weights.iter_mut().zip(grad.iter()) {
        *w -= lr * g;
    }
    // Keep every prototype on the unit sphere after the step.
    l2_normalize_rows(&mut prototypes.weights, k, feat_dim);
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `batch_size` deterministic feature rows of length `dim`, each
    /// divided by its L2 norm (floored at `1e-12`).
    fn make_features_normalised(batch_size: usize, dim: usize, seed_offset: f32) -> Vec<f32> {
        let mut feats = Vec::with_capacity(batch_size * dim);
        for b in 0..batch_size {
            let raw: Vec<f32> = (0..dim)
                .map(|d| (b as f32 * 0.31 + d as f32 * 0.17 + seed_offset).sin())
                .collect();
            let norm: f32 = raw.iter().map(|&v| v * v).sum::<f32>().sqrt().max(1e-12);
            feats.extend(raw.iter().map(|&v| v / norm));
        }
        feats
    }

    /// L2 norm of a single row.
    fn row_norm(row: &[f32]) -> f32 {
        row.iter().map(|&v| v * v).sum::<f32>().sqrt()
    }

    /// Default config with random prototypes: loss is finite, CE is non-negative.
    #[test]
    fn loss_is_finite_and_non_negative() {
        let (batch, dim) = (4, 16);
        let mut rng = LcgRng::new(42);
        let cfg = MsnConfig::default();
        let protos = msn_prototype_init(cfg.n_prototypes, dim, &mut rng);
        let anchor = make_features_normalised(batch, dim, 0.0);
        let target = make_features_normalised(batch, dim, 1.0);

        let result = msn_loss(&anchor, &target, &protos, batch, dim, &cfg)
            .expect("msn_loss should succeed");

        assert!(result.loss.is_finite(), "loss not finite: {}", result.loss);
        assert!(
            result.ce_loss >= 0.0,
            "ce_loss negative: {}",
            result.ce_loss
        );
    }

    /// With identical inputs and equal temperatures the CE reduces to an
    /// entropy, so it must lie in `[0, ln K]`.
    #[test]
    fn equal_temperatures_identical_inputs_ce_bounded_by_ln_k() {
        let (batch, dim, n_protos) = (8, 32, 16);
        let cfg = MsnConfig {
            n_prototypes: n_protos,
            tau_anchor: 0.15,
            tau_target: 0.15,
            lambda_reg: 0.0,
            eps: 1e-8,
        };
        let mut rng = LcgRng::new(7);
        let protos = msn_prototype_init(n_protos, dim, &mut rng);
        let features = make_features_normalised(batch, dim, 0.0);

        let result = msn_loss(&features, &features, &protos, batch, dim, &cfg)
            .expect("msn_loss should succeed");

        let max_ce = (n_protos as f32).ln();
        assert!(
            result.ce_loss <= max_ce + 1e-4,
            "CE ({}) must be ≤ ln(K) = {}",
            result.ce_loss,
            max_ce
        );
        assert!(
            result.ce_loss >= 0.0,
            "CE must be non-negative, got {}",
            result.ce_loss
        );
    }

    /// One-hot prototypes and constant features score every prototype
    /// equally, so the mean-prediction entropy is close to `ln K`.
    #[test]
    fn uniform_anchors_high_me_max_entropy() {
        let (batch, dim, n_protos) = (8, 8, 8);
        let cfg = MsnConfig {
            n_prototypes: n_protos,
            tau_anchor: 0.1,
            tau_target: 0.25,
            lambda_reg: 1.0,
            eps: 1e-8,
        };

        // Prototype `i` is the `i`-th standard basis vector.
        let mut identity = vec![0.0_f32; n_protos * dim];
        for i in 0..n_protos {
            identity[i * dim + i] = 1.0;
        }
        let protos = MsnPrototypes {
            weights: identity,
            n_prototypes: n_protos,
            dim,
        };

        let uniform_val = 1.0_f32 / (dim as f32).sqrt();
        let anchor: Vec<f32> = vec![uniform_val; batch * dim];
        let target: Vec<f32> = vec![uniform_val; batch * dim];

        let result = msn_loss(&anchor, &target, &protos, batch, dim, &cfg)
            .expect("msn_loss should succeed");

        let expected_h = (n_protos as f32).ln();
        assert!(
            (result.me_max_loss - expected_h).abs() < 0.1,
            "Expected entropy ≈ ln({}) ≈ {:.4}, got {:.4}",
            n_protos,
            expected_h,
            result.me_max_loss
        );
    }

    /// CE stays non-negative for a second, unrelated pair of inputs.
    #[test]
    fn ce_loss_always_non_negative() {
        let (batch, dim) = (6, 24);
        let mut rng = LcgRng::new(123);
        let cfg = MsnConfig::default();
        let protos = msn_prototype_init(cfg.n_prototypes, dim, &mut rng);
        let anchor = make_features_normalised(batch, dim, 2.5);
        let target = make_features_normalised(batch, dim, 3.7);

        let result = msn_loss(&anchor, &target, &protos, batch, dim, &cfg)
            .expect("msn_loss should succeed");

        assert!(
            result.ce_loss >= 0.0,
            "CE must be non-negative, got {}",
            result.ce_loss
        );
    }

    /// A config asking for a single prototype is rejected by `msn_loss`.
    #[test]
    fn invalid_n_prototypes_returns_error() {
        let dim = 8;
        let anchor = make_features_normalised(2, dim, 0.0);
        let target = make_features_normalised(2, dim, 1.0);
        let cfg = MsnConfig {
            n_prototypes: 1,
            ..MsnConfig::default()
        };
        let protos = MsnPrototypes {
            weights: vec![1.0_f32; dim],
            n_prototypes: 1,
            dim,
        };

        let result = msn_loss(&anchor, &target, &protos, 2, dim, &cfg);

        assert!(
            result.is_err(),
            "Expected error for n_prototypes < 2, got Ok"
        );
    }

    /// `MsnConfig::new` refuses a zero temperature in either slot.
    #[test]
    fn invalid_temperature_zero_returns_error() {
        let zero_anchor = MsnConfig::new(16, 0.0, 0.25, 1.0, 1e-8);
        assert!(
            zero_anchor.is_err(),
            "Expected InvalidTemperature for tau_anchor=0"
        );

        let zero_target = MsnConfig::new(16, 0.1, 0.0, 1.0, 1e-8);
        assert!(
            zero_target.is_err(),
            "Expected InvalidTemperature for tau_target=0"
        );
    }

    /// A constant target (every row identical) still yields a finite loss.
    #[test]
    fn near_full_masking_loss_still_finite() {
        let (batch, dim) = (4, 16);
        let mut rng = LcgRng::new(99);
        let cfg = MsnConfig::default();
        let protos = msn_prototype_init(cfg.n_prototypes, dim, &mut rng);
        let anchor = make_features_normalised(batch, dim, 0.0);
        let target: Vec<f32> = vec![1.0_f32 / (dim as f32).sqrt(); batch * dim];

        let result = msn_loss(&anchor, &target, &protos, batch, dim, &cfg)
            .expect("msn_loss should succeed");

        assert!(result.loss.is_finite(), "Expected finite loss, got NaN");
    }

    /// A batch of one sample is accepted.
    #[test]
    fn single_sample_batch_works() {
        let (batch, dim) = (1, 8);
        let mut rng = LcgRng::new(55);
        let cfg = MsnConfig {
            n_prototypes: 4,
            ..MsnConfig::default()
        };
        let protos = msn_prototype_init(cfg.n_prototypes, dim, &mut rng);
        let anchor = make_features_normalised(batch, dim, 0.0);
        let target = make_features_normalised(batch, dim, 1.0);

        let result = msn_loss(&anchor, &target, &protos, batch, dim, &cfg)
            .expect("msn_loss should succeed");

        assert!(result.loss.is_finite(), "Single-sample loss not finite");
    }

    /// The utilisation diagnostic is a fraction in `[0, 1]`.
    #[test]
    fn prototype_util_in_unit_interval() {
        let (batch, dim) = (8, 32);
        let mut rng = LcgRng::new(11);
        let cfg = MsnConfig::default();
        let protos = msn_prototype_init(cfg.n_prototypes, dim, &mut rng);
        let anchor = make_features_normalised(batch, dim, 0.0);
        let target = make_features_normalised(batch, dim, 1.0);

        let result = msn_loss(&anchor, &target, &protos, batch, dim, &cfg)
            .expect("msn_loss should succeed");

        assert!(
            (0.0..=1.0).contains(&result.mean_prototype_util),
            "util out of [0,1]: {}",
            result.mean_prototype_util
        );
    }

    /// Freshly initialised prototypes have unit-norm rows.
    #[test]
    fn prototype_init_rows_are_unit_norm() {
        let (k, dim) = (32, 64);
        let mut rng = LcgRng::new(17);
        let protos = msn_prototype_init(k, dim, &mut rng);

        for ki in 0..k {
            let norm = row_norm(protos.row(ki));
            assert!(
                (norm - 1.0).abs() < 1e-5,
                "Prototype row {ki} has norm {norm:.6}, expected ≈ 1.0"
            );
        }
    }

    /// The mask has exactly `trunc(n_tokens * ratio)` true entries.
    #[test]
    fn random_mask_exact_count() {
        let n_tokens = 196;
        let mask_ratio = 0.75_f32;
        let mut rng = LcgRng::new(31);

        let mask = msn_random_mask(n_tokens, mask_ratio, &mut rng)
            .expect("msn_random_mask should succeed");

        assert_eq!(mask.len(), n_tokens);
        let n_masked = mask.iter().filter(|&&flag| flag).count();
        let expected = (n_tokens as f32 * mask_ratio) as usize;
        assert_eq!(
            n_masked, expected,
            "Expected {expected} masked tokens, got {n_masked}"
        );
    }

    /// One update step leaves every prototype row at unit norm.
    #[test]
    fn update_prototypes_keeps_unit_norm() {
        let (batch, dim) = (4, 16);
        let mut rng = LcgRng::new(77);
        let cfg = MsnConfig {
            n_prototypes: 8,
            ..MsnConfig::default()
        };
        let mut protos = msn_prototype_init(cfg.n_prototypes, dim, &mut rng);
        let anchor = make_features_normalised(batch, dim, 0.0);

        msn_update_prototypes(&mut protos, &anchor, batch, dim, 0.01, &cfg)
            .expect("msn_update_prototypes should succeed");

        for ki in 0..cfg.n_prototypes {
            let norm = row_norm(protos.row(ki));
            assert!(
                (norm - 1.0).abs() < 1e-5,
                "After update, prototype {ki} has norm {norm:.6}, expected ≈ 1.0"
            );
        }
    }

    /// Ratios of 1.0 or below zero, and an empty token list, are rejected.
    #[test]
    fn random_mask_rejects_invalid_ratio() {
        let mut rng = LcgRng::new(42);
        assert!(msn_random_mask(16, 1.0, &mut rng).is_err());
        assert!(msn_random_mask(16, -0.1, &mut rng).is_err());
        assert!(msn_random_mask(0, 0.5, &mut rng).is_err());
    }

    /// A target buffer sized for a different batch is rejected.
    #[test]
    fn loss_rejects_dim_mismatch() {
        let dim = 16;
        let mut rng = LcgRng::new(42);
        let protos = msn_prototype_init(8, dim, &mut rng);
        let anchor = vec![0.1_f32; 3 * dim];
        let target = vec![0.1_f32; 4 * dim];
        let cfg = MsnConfig {
            n_prototypes: 8,
            ..MsnConfig::default()
        };

        let result = msn_loss(&anchor, &target, &protos, 3, dim, &cfg);

        assert!(result.is_err(), "Expected DimensionMismatch error");
    }

    /// `mean_entropy` mirrors `me_max_loss` bit for bit.
    #[test]
    fn me_max_and_mean_entropy_are_equal() {
        let (batch, dim) = (4, 16);
        let mut rng = LcgRng::new(88);
        let cfg = MsnConfig::default();
        let protos = msn_prototype_init(cfg.n_prototypes, dim, &mut rng);
        let anchor = make_features_normalised(batch, dim, 0.0);
        let target = make_features_normalised(batch, dim, 1.0);

        let result = msn_loss(&anchor, &target, &protos, batch, dim, &cfg)
            .expect("msn_loss should succeed");

        assert_eq!(
            result.me_max_loss, result.mean_entropy,
            "me_max_loss and mean_entropy must be identical diagnostic values"
        );
    }
}