Skip to main content

oxicuda_ssl/contrastive/
moco.rs

//! MoCo (He et al., 2020): Momentum Contrast with a memory queue.
//!
//! MoCo decouples the size of the negative-key set from the batch size by
//! maintaining a fixed-size FIFO queue `Q` of D-dimensional embeddings (computed
//! by a momentum-updated key encoder). For a query `q` with positive key `k_+`
//! and temperature `τ`, the contrastive loss is
//!
//! ```text
//!   L = -log( exp(q·k_+/τ) / [ exp(q·k_+/τ) + Σ_{k_n ∈ Q} exp(q·k_n/τ) ] )
//! ```
//!
//! The queue is updated with the current batch of `k_+` embeddings after each
//! step, evicting the oldest entries.
13
14use crate::error::{SslError, SslResult};
15
/// Fixed-capacity ring buffer of `dim`-dimensional embedding vectors.
///
/// Once `capacity` rows have been written, each new row overwrites the oldest
/// one (first-in, first-out).
#[derive(Clone, Debug)]
pub struct MocoQueue {
    /// Maximum number of vectors the queue can hold.
    pub capacity: usize,
    /// Length of every stored embedding vector.
    pub dim: usize,
    /// Flat row-major storage of `capacity * dim` values; row `r` occupies
    /// `data[r * dim..(r + 1) * dim]`.
    pub data: Vec<f32>,
    /// Row index the next enqueued vector is written to; wraps at `capacity`.
    pub head: usize,
    /// Number of rows that currently hold real data (`<= capacity`).
    pub len: usize,
}
30
31impl MocoQueue {
32    /// Create an empty queue with given capacity and embedding dim.
33    ///
34    /// # Errors
35    /// - [`SslError::QueueCapacityTooSmall`] if `capacity == 0`.
36    /// - [`SslError::InvalidFeatureDim`] if `dim == 0`.
37    pub fn new(capacity: usize, dim: usize) -> SslResult<Self> {
38        if capacity == 0 {
39            return Err(SslError::QueueCapacityTooSmall);
40        }
41        if dim == 0 {
42            return Err(SslError::InvalidFeatureDim);
43        }
44        Ok(Self {
45            capacity,
46            dim,
47            data: vec![0.0_f32; capacity * dim],
48            head: 0,
49            len: 0,
50        })
51    }
52
53    /// Enqueue a batch of `[batch × dim]` row-major key embeddings.
54    /// Old entries are evicted FIFO if the queue overflows.
55    ///
56    /// # Errors
57    /// - [`SslError::DimensionMismatch`] when `batch.len() % self.dim != 0`.
58    pub fn enqueue(&mut self, batch: &[f32]) -> SslResult<()> {
59        if batch.is_empty() {
60            return Ok(());
61        }
62        if batch.len() % self.dim != 0 {
63            return Err(SslError::DimensionMismatch {
64                expected: self.dim,
65                got: batch.len(),
66            });
67        }
68        let n = batch.len() / self.dim;
69        for i in 0..n {
70            let src = &batch[i * self.dim..(i + 1) * self.dim];
71            let dst = &mut self.data[self.head * self.dim..(self.head + 1) * self.dim];
72            dst.copy_from_slice(src);
73            self.head = (self.head + 1) % self.capacity;
74            if self.len < self.capacity {
75                self.len += 1;
76            }
77        }
78        Ok(())
79    }
80
81    /// Number of currently stored entries (≤ capacity).
82    #[must_use]
83    pub fn len(&self) -> usize {
84        self.len
85    }
86
87    /// True if the queue is empty.
88    #[must_use]
89    pub fn is_empty(&self) -> bool {
90        self.len == 0
91    }
92
93    /// Snapshot of the currently stored entries in the order they appear in the
94    /// backing buffer (i.e., **not** strictly chronological — both directions
95    /// are negatives, so order does not matter).
96    #[must_use]
97    pub fn entries(&self) -> &[f32] {
98        &self.data[..self.len * self.dim]
99    }
100}
101
102/// MoCo contrastive loss for a batch of `B` query embeddings against a single
103/// positive key per query plus a fixed memory bank.
104///
105/// `q` is `[B × D]`, `k_pos` is `[B × D]` (the *current* batch of positives),
106/// `queue` is the [`MocoQueue`] of historical negatives.
107///
108/// All inputs should be L2-normalised on the host (we re-normalise defensively).
109///
110/// Returns the average per-query loss `(1/B) Σ_i L_i`.
111///
112/// # Errors
113/// - [`SslError::DimensionMismatch`] if shapes disagree.
114/// - [`SslError::EmptyInput`] if `q` or `k_pos` is empty.
115/// - [`SslError::QueueEmpty`] if the queue has no negatives.
116/// - [`SslError::InvalidTemperature`] if `temperature <= 0` or non-finite.
117pub fn moco_loss(
118    q: &[f32],
119    k_pos: &[f32],
120    batch: usize,
121    dim: usize,
122    queue: &MocoQueue,
123    temperature: f32,
124) -> SslResult<f32> {
125    if q.is_empty() || batch == 0 || dim == 0 {
126        return Err(SslError::EmptyInput);
127    }
128    if !(temperature.is_finite() && temperature > 0.0) {
129        return Err(SslError::InvalidTemperature { temp: temperature });
130    }
131    if q.len() != batch * dim {
132        return Err(SslError::DimensionMismatch {
133            expected: batch * dim,
134            got: q.len(),
135        });
136    }
137    if k_pos.len() != batch * dim {
138        return Err(SslError::DimensionMismatch {
139            expected: batch * dim,
140            got: k_pos.len(),
141        });
142    }
143    if queue.dim != dim {
144        return Err(SslError::DimensionMismatch {
145            expected: dim,
146            got: queue.dim,
147        });
148    }
149    if queue.is_empty() {
150        return Err(SslError::QueueEmpty);
151    }
152
153    let q_n = l2_normalize_clone(q, batch, dim);
154    let k_n = l2_normalize_clone(k_pos, batch, dim);
155    // Negative bank is already accumulated by the caller; defensive
156    // re-normalisation costs O(K·D) but improves stability.
157    let neg = l2_normalize_clone(queue.entries(), queue.len(), dim);
158
159    let inv_t = 1.0_f32 / temperature;
160    let mut total_loss = 0.0_f64;
161    for i in 0..batch {
162        let q_row = &q_n[i * dim..(i + 1) * dim];
163        let k_row = &k_n[i * dim..(i + 1) * dim];
164        // Positive logit
165        let mut pos = 0.0_f32;
166        for (a, b) in q_row.iter().zip(k_row.iter()) {
167            pos += a * b;
168        }
169        pos *= inv_t;
170        // Negative logits
171        let mut max_v = pos;
172        let mut neg_logits: Vec<f32> = Vec::with_capacity(queue.len());
173        for k in 0..queue.len() {
174            let neg_row = &neg[k * dim..(k + 1) * dim];
175            let mut s = 0.0_f32;
176            for (a, b) in q_row.iter().zip(neg_row.iter()) {
177                s += a * b;
178            }
179            let l = s * inv_t;
180            if l > max_v {
181                max_v = l;
182            }
183            neg_logits.push(l);
184        }
185        // log-sum-exp denom
186        let mut sum_exp = ((pos - max_v) as f64).exp();
187        for &l in &neg_logits {
188            sum_exp += ((l - max_v) as f64).exp();
189        }
190        let log_z = (max_v as f64) + sum_exp.ln();
191        total_loss += -((pos as f64) - log_z);
192    }
193    Ok((total_loss / batch as f64) as f32)
194}
195
/// Return a copy of `z` in which every `d`-wide row is scaled to unit L2 norm.
///
/// `z` is read as consecutive rows of `d` values (a shorter trailing row, if
/// any, is normalised on its own). Rows whose norm is not greater than `1e-12`
/// (including all-zero rows) are copied through unchanged. `n` is the nominal
/// row count; it is not needed because the row count follows from
/// `z.len()` and `d`, and is accepted only to keep call sites uniform.
///
/// The squared norm and the scaling are computed in `f64`, so rows with very
/// large components (whose squares would overflow `f32`) are still normalised
/// instead of collapsing to zero.
///
/// # Panics
/// Panics if `d == 0` (zero chunk size).
fn l2_normalize_clone(z: &[f32], n: usize, d: usize) -> Vec<f32> {
    const MIN_NORM: f32 = 1e-12;
    let _ = n;

    let mut out = z.to_vec();
    for row in out.chunks_mut(d) {
        let norm = row
            .iter()
            .map(|&v| f64::from(v) * f64::from(v))
            .sum::<f64>()
            .sqrt();
        // Degenerate (near-zero or NaN-norm) rows are left untouched.
        if norm > f64::from(MIN_NORM) {
            let inv = 1.0 / norm;
            for v in row.iter_mut() {
                *v = (f64::from(*v) * inv) as f32;
            }
        }
    }
    out
}
208
#[cfg(test)]
mod tests {
    use super::*;

    /// A zero-capacity queue is rejected with the capacity-specific error.
    #[test]
    fn queue_new_zero_capacity_errors() {
        assert!(matches!(
            MocoQueue::new(0, 4),
            Err(SslError::QueueCapacityTooSmall)
        ));
    }

    /// A zero embedding dimension is rejected with the dimension-specific error.
    #[test]
    fn queue_new_zero_dim_errors() {
        assert!(matches!(
            MocoQueue::new(4, 0),
            Err(SslError::InvalidFeatureDim)
        ));
    }

    /// `len` grows up to capacity; afterwards new rows overwrite the oldest.
    #[test]
    fn queue_enqueue_grows_until_capacity() {
        let mut q = MocoQueue::new(4, 2).expect("new should succeed");
        let batch = vec![1.0_f32, 0.0, 0.0, 1.0];
        q.enqueue(&batch).expect("enqueue should succeed");
        assert_eq!(q.len(), 2);
        q.enqueue(&batch).expect("enqueue should succeed");
        assert_eq!(q.len(), 4);
        // Overflow: the oldest row (slot 0) is replaced, the rest survive.
        q.enqueue(&[0.5_f32, 0.5]).expect("enqueue should succeed");
        assert_eq!(q.len(), 4);
        assert_eq!(q.head, 1);
        assert_eq!(
            q.entries(),
            &[0.5_f32, 0.5, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0][..]
        );
    }

    /// Enqueuing nothing is a successful no-op.
    #[test]
    fn queue_enqueue_empty_batch_ok() {
        let mut q = MocoQueue::new(4, 2).expect("new should succeed");
        q.enqueue(&[]).expect("enqueue should succeed");
        assert!(q.is_empty());
        assert_eq!(q.head, 0);
    }

    /// A batch whose length is not a multiple of `dim` is rejected and
    /// leaves the queue untouched.
    #[test]
    fn queue_enqueue_rejects_misaligned() {
        let mut q = MocoQueue::new(4, 3).expect("new should succeed");
        let r = q.enqueue(&[1.0_f32, 2.0]);
        assert!(matches!(
            r,
            Err(SslError::DimensionMismatch {
                expected: 3,
                got: 2
            })
        ));
        assert!(q.is_empty());
    }

    /// Identical query/key pairs against dissimilar negatives give a small loss.
    #[test]
    fn moco_loss_perfect_positives_low() {
        let mut q = MocoQueue::new(8, 4).expect("new should succeed");
        // Deterministic LCG negatives. `rng >> 33` keeps 31 bits, so every
        // component lies in [-0.5, 0] and the inner product with the one-hot
        // queries below is never positive.
        let mut rng = 42u64;
        let mut neg = vec![0.0_f32; 8 * 4];
        for v in neg.iter_mut() {
            rng = rng
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            *v = ((rng >> 33) as f32 / (u32::MAX as f32 + 1.0)) - 0.5;
        }
        q.enqueue(&neg).expect("enqueue should succeed");
        // Identical query and key positive → max similarity, so loss is small.
        let pos = vec![1.0_f32, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
        let loss = moco_loss(&pos, &pos, 2, 4, &q, 0.1).expect("moco_loss should succeed");
        assert!(loss.is_finite());
        assert!(loss >= 0.0);
        assert!(loss < 1.0);
    }

    /// The loss needs at least one stored negative.
    #[test]
    fn moco_loss_empty_queue_errors() {
        let q = MocoQueue::new(4, 2).expect("new should succeed");
        let pos = vec![1.0_f32, 0.0];
        let r = moco_loss(&pos, &pos, 1, 2, &q, 0.1);
        assert!(matches!(r, Err(SslError::QueueEmpty)));
    }

    /// Inputs whose length disagrees with `batch * dim` are rejected.
    #[test]
    fn moco_loss_dim_mismatch_errors() {
        let mut q = MocoQueue::new(2, 4).expect("new should succeed");
        q.enqueue(&[1.0_f32; 8]).expect("enqueue should succeed");
        let r = moco_loss(&[1.0_f32; 4], &[1.0_f32; 4], 1, 2, &q, 0.1);
        assert!(matches!(r, Err(SslError::DimensionMismatch { .. })));
    }

    /// A zero temperature is rejected before any logits are computed.
    #[test]
    fn moco_loss_temperature_must_be_positive() {
        let mut q = MocoQueue::new(2, 2).expect("new should succeed");
        q.enqueue(&[1.0_f32; 4]).expect("enqueue should succeed");
        let r = moco_loss(&[1.0_f32, 0.0], &[1.0_f32, 0.0], 1, 2, &q, 0.0);
        assert!(matches!(r, Err(SslError::InvalidTemperature { .. })));
    }

    /// `entries` exposes exactly the rows written so far, in storage order.
    #[test]
    fn queue_entries_view_correct_length() {
        let mut q = MocoQueue::new(4, 2).expect("new should succeed");
        q.enqueue(&[1.0_f32, 2.0, 3.0, 4.0])
            .expect("enqueue should succeed");
        let entries = q.entries();
        assert_eq!(entries.len(), 4);
        assert_eq!(entries, &[1.0_f32, 2.0, 3.0, 4.0][..]);
    }
}