use crate::error::{SslError, SslResult};
/// Fixed-capacity ring buffer of feature vectors, used by [`moco_loss`] as the
/// pool of negatives.
///
/// Rows are stored row-major in `data`, one `dim`-wide row per slot. Once the
/// buffer is full, each newly enqueued row overwrites the slot at `head`, which
/// is the slot holding the oldest row.
#[derive(Debug, Clone)]
pub struct MocoQueue {
    /// Maximum number of rows the queue can hold.
    pub capacity: usize,
    /// Width (number of features) of every stored row.
    pub dim: usize,
    /// Backing storage of `capacity * dim` values, zero-initialised by `new`.
    pub data: Vec<f32>,
    /// Slot index the next enqueued row will be written to.
    pub head: usize,
    /// Number of slots currently holding an enqueued row (never above `capacity`
    /// when only `enqueue` mutates the queue).
    pub len: usize,
}
impl MocoQueue {
    /// Creates an empty queue for up to `capacity` rows of `dim` features each.
    ///
    /// # Errors
    ///
    /// Returns [`SslError::QueueCapacityTooSmall`] when `capacity` is zero (this
    /// check wins when both arguments are zero), and
    /// [`SslError::InvalidFeatureDim`] when `dim` is zero.
    pub fn new(capacity: usize, dim: usize) -> SslResult<Self> {
        match (capacity, dim) {
            (0, _) => Err(SslError::QueueCapacityTooSmall),
            (_, 0) => Err(SslError::InvalidFeatureDim),
            _ => Ok(Self {
                capacity,
                dim,
                data: vec![0.0_f32; capacity * dim],
                head: 0,
                len: 0,
            }),
        }
    }

    /// Appends every `dim`-wide row of `batch`, overwriting the oldest rows once
    /// the queue is full.
    ///
    /// An empty `batch` is a no-op. If `batch` holds more rows than `capacity`,
    /// only its last `capacity` rows survive.
    ///
    /// # Errors
    ///
    /// Returns [`SslError::DimensionMismatch`] when `batch.len()` is not a
    /// multiple of `dim`. In that case `got` carries the whole batch length, and
    /// the queue is left untouched.
    pub fn enqueue(&mut self, batch: &[f32]) -> SslResult<()> {
        if batch.is_empty() {
            return Ok(());
        }
        let width = self.dim;
        if batch.len() % width != 0 {
            return Err(SslError::DimensionMismatch {
                expected: width,
                got: batch.len(),
            });
        }
        for row in batch.chunks_exact(width) {
            let start = self.head * width;
            self.data[start..start + width].copy_from_slice(row);
            // Advance the write cursor cyclically; the fill count saturates at
            // capacity once every slot has been written at least once.
            self.head = (self.head + 1) % self.capacity;
            if self.len < self.capacity {
                self.len += 1;
            }
        }
        Ok(())
    }

    /// Number of rows currently stored.
    #[must_use]
    pub fn len(&self) -> usize {
        self.len
    }

    /// `true` while no row has been enqueued.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// The stored rows as one flat, row-major slice of `len * dim` values.
    ///
    /// Rows appear in slot order, which after wrap-around is not insertion order.
    #[must_use]
    pub fn entries(&self) -> &[f32] {
        let filled = self.len * self.dim;
        &self.data[..filled]
    }
}
/// Computes the MoCo (InfoNCE-style) contrastive loss, averaged over a batch.
///
/// `q` and `k_pos` are row-major `batch x dim` matrices. Row `i` of `k_pos` is
/// the positive key for row `i` of `q`, and every row currently held in `queue`
/// serves as a negative for every query. All three inputs are row-wise L2
/// normalised (via `l2_normalize_clone`) before dot products are taken, and
/// every logit is divided by `temperature`.
///
/// For each query the per-row loss is
/// `-log(exp(pos) / (exp(pos) + sum_k exp(neg_k)))`. It is evaluated as
/// `log_sum_exp - pos`, shifting by the largest logit and taking the
/// exponentials in `f64` for numerical stability. The mean over the batch is
/// returned.
///
/// # Errors
///
/// * [`SslError::EmptyInput`] if `q` is empty or `batch` / `dim` is zero.
/// * [`SslError::InvalidTemperature`] if `temperature` is not finite and
///   strictly positive.
/// * [`SslError::DimensionMismatch`] if `q` or `k_pos` is not `batch * dim`
///   long, or if `queue.dim` differs from `dim`.
/// * [`SslError::QueueEmpty`] if the queue holds no negatives.
pub fn moco_loss(
    q: &[f32],
    k_pos: &[f32],
    batch: usize,
    dim: usize,
    queue: &MocoQueue,
    temperature: f32,
) -> SslResult<f32> {
    if q.is_empty() || batch == 0 || dim == 0 {
        return Err(SslError::EmptyInput);
    }
    if !(temperature.is_finite() && temperature > 0.0) {
        return Err(SslError::InvalidTemperature { temp: temperature });
    }
    if q.len() != batch * dim {
        return Err(SslError::DimensionMismatch {
            expected: batch * dim,
            got: q.len(),
        });
    }
    if k_pos.len() != batch * dim {
        return Err(SslError::DimensionMismatch {
            expected: batch * dim,
            got: k_pos.len(),
        });
    }
    if queue.dim != dim {
        return Err(SslError::DimensionMismatch {
            expected: dim,
            got: queue.dim,
        });
    }
    if queue.is_empty() {
        return Err(SslError::QueueEmpty);
    }

    let q_n = l2_normalize_clone(q, batch, dim);
    let k_n = l2_normalize_clone(k_pos, batch, dim);
    let neg = l2_normalize_clone(queue.entries(), queue.len(), dim);
    let inv_t = 1.0_f32 / temperature;

    // One buffer reused for every query, so the hot loop does not allocate per row.
    let mut neg_logits: Vec<f32> = Vec::with_capacity(queue.len());
    let mut total_loss = 0.0_f64;
    for (q_row, k_row) in q_n.chunks_exact(dim).zip(k_n.chunks_exact(dim)) {
        let pos = row_dot(q_row, k_row) * inv_t;

        neg_logits.clear();
        neg_logits.extend(
            neg.chunks_exact(dim)
                .map(|neg_row| row_dot(q_row, neg_row) * inv_t),
        );

        // Shift by the largest logit so exp() cannot overflow.
        let max_v = neg_logits
            .iter()
            .fold(pos, |m, &l| if l > m { l } else { m });
        let sum_exp = neg_logits
            .iter()
            .fold(((pos - max_v) as f64).exp(), |acc, &l| {
                acc + ((l - max_v) as f64).exp()
            });
        let log_z = (max_v as f64) + sum_exp.ln();
        total_loss += -((pos as f64) - log_z);
    }
    Ok((total_loss / batch as f64) as f32)
}

/// Dot product of two equally long rows, accumulated left to right in `f32`.
fn row_dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).fold(0.0_f32, |acc, (x, y)| acc + x * y)
}
/// Returns a copy of `z`, interpreted as a row-major `n x d` matrix, with every
/// row scaled to unit L2 norm.
///
/// Rows whose norm is at most `1e-12` (e.g. all-zero rows) are copied unchanged
/// instead of being divided by ~0.
///
/// If a row's sum of squares overflows `f32` to infinity while its entries are
/// themselves finite, the row is first divided by its largest absolute entry
/// and then normalised. Without that step the computed norm is `inf`, the
/// scale factor is `0`, and the row silently collapses to zeros.
///
/// Rows whose norm is NaN, or that contain an infinite entry, take the plain
/// path unchanged.
///
/// The `n x d` shape is checked with a debug assertion; `d` must be non-zero.
fn l2_normalize_clone(z: &[f32], n: usize, d: usize) -> Vec<f32> {
    /// Euclidean norm of one row, accumulated in `f32`.
    fn norm_of(row: &[f32]) -> f32 {
        row.iter().map(|v| v * v).sum::<f32>().sqrt()
    }

    debug_assert_eq!(z.len(), n * d, "l2_normalize_clone: input is not n x d");
    let mut out = z.to_vec();
    for row in out.chunks_mut(d) {
        let mut norm = norm_of(row);
        if norm.is_infinite() {
            // Squares overflowed: rescale so the largest entry is +/-1, which
            // bounds the sum of squares by `d`.
            let peak = row.iter().fold(0.0_f32, |m, v| m.max(v.abs()));
            if peak.is_finite() {
                for v in row.iter_mut() {
                    *v /= peak;
                }
                norm = norm_of(row);
            }
        }
        let inv = if norm > 1e-12 { 1.0 / norm } else { 1.0 };
        for v in row.iter_mut() {
            *v *= inv;
        }
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn queue_new_zero_capacity_errors() {
        assert!(matches!(
            MocoQueue::new(0, 4),
            Err(SslError::QueueCapacityTooSmall)
        ));
    }

    #[test]
    fn queue_new_zero_dim_errors() {
        assert!(matches!(
            MocoQueue::new(4, 0),
            Err(SslError::InvalidFeatureDim)
        ));
    }

    #[test]
    fn queue_enqueue_grows_until_capacity() {
        let mut q = MocoQueue::new(4, 2).expect("new should succeed");
        let batch = vec![1.0_f32, 0.0, 0.0, 1.0];
        q.enqueue(&batch).expect("enqueue should succeed");
        assert_eq!(q.len(), 2);
        q.enqueue(&batch).expect("enqueue should succeed");
        assert_eq!(q.len(), 4);
        q.enqueue(&[0.5_f32, 0.5]).expect("enqueue should succeed");
        assert_eq!(q.len(), 4);
    }

    #[test]
    fn queue_enqueue_wraps_and_overwrites_oldest() {
        let mut q = MocoQueue::new(2, 1).expect("new should succeed");
        q.enqueue(&[1.0_f32, 2.0, 3.0])
            .expect("enqueue should succeed");
        assert_eq!(q.len(), 2);
        assert_eq!(q.head, 1);
        // Slot 0 held the oldest row (1.0) and was overwritten by 3.0.
        assert_eq!(q.entries().to_vec(), vec![3.0_f32, 2.0]);
    }

    #[test]
    fn queue_enqueue_empty_batch_ok() {
        let mut q = MocoQueue::new(4, 2).expect("new should succeed");
        q.enqueue(&[]).expect("enqueue should succeed");
        assert!(q.is_empty());
    }

    #[test]
    fn queue_enqueue_rejects_misaligned() {
        let mut q = MocoQueue::new(4, 3).expect("new should succeed");
        let r = q.enqueue(&[1.0_f32, 2.0]);
        assert!(matches!(
            r,
            Err(SslError::DimensionMismatch { expected: 3, got: 2 })
        ));
        assert!(q.is_empty(), "a rejected batch must not modify the queue");
    }

    #[test]
    fn moco_loss_perfect_positives_low() {
        let mut q = MocoQueue::new(8, 4).expect("new should succeed");
        // Deterministic LCG so the negatives are reproducible without a rand crate.
        let mut rng = 42u64;
        let mut neg = vec![0.0_f32; 8 * 4];
        for v in neg.iter_mut() {
            rng = rng
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            *v = ((rng >> 33) as f32 / (u32::MAX as f32 + 1.0)) - 0.5;
        }
        q.enqueue(&neg).expect("enqueue should succeed");
        let pos = vec![1.0_f32, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
        let loss = moco_loss(&pos, &pos, 2, 4, &q, 0.1).expect("moco_loss should succeed");
        assert!(loss.is_finite());
        assert!(loss < 1.0);
    }

    #[test]
    fn moco_loss_empty_queue_errors() {
        let q = MocoQueue::new(4, 2).expect("new should succeed");
        let pos = vec![1.0_f32, 0.0];
        let r = moco_loss(&pos, &pos, 1, 2, &q, 0.1);
        assert!(matches!(r, Err(SslError::QueueEmpty)));
    }

    #[test]
    fn moco_loss_dim_mismatch_errors() {
        // q/k_pos are valid 1 x 2 inputs but the queue stores 3-wide rows, so
        // this reaches the queue-dimension check rather than the q-length check.
        let mut q = MocoQueue::new(2, 3).expect("new should succeed");
        q.enqueue(&[1.0_f32; 6]).expect("enqueue should succeed");
        let r = moco_loss(&[1.0_f32; 2], &[1.0_f32; 2], 1, 2, &q, 0.1);
        assert!(matches!(
            r,
            Err(SslError::DimensionMismatch { expected: 2, got: 3 })
        ));
    }

    #[test]
    fn moco_loss_query_length_mismatch_errors() {
        let mut q = MocoQueue::new(2, 2).expect("new should succeed");
        q.enqueue(&[1.0_f32; 4]).expect("enqueue should succeed");
        let r = moco_loss(&[1.0_f32; 4], &[1.0_f32; 2], 1, 2, &q, 0.1);
        assert!(matches!(
            r,
            Err(SslError::DimensionMismatch { expected: 2, got: 4 })
        ));
    }

    #[test]
    fn moco_loss_temperature_must_be_positive() {
        let mut q = MocoQueue::new(2, 2).expect("new should succeed");
        q.enqueue(&[1.0_f32; 4]).expect("enqueue should succeed");
        for t in [0.0_f32, -1.0, f32::NAN, f32::INFINITY] {
            let r = moco_loss(&[1.0_f32, 0.0], &[1.0_f32, 0.0], 1, 2, &q, t);
            assert!(
                matches!(r, Err(SslError::InvalidTemperature { .. })),
                "temperature {t} should be rejected"
            );
        }
    }

    #[test]
    fn queue_entries_view_correct_length() {
        let mut q = MocoQueue::new(4, 2).expect("new should succeed");
        q.enqueue(&[1.0_f32, 2.0, 3.0, 4.0])
            .expect("enqueue should succeed");
        let entries = q.entries();
        assert_eq!(entries.len(), 4);
    }
}