1use crate::error::{SslError, SslResult};
15
16#[derive(Debug, Clone)]
18pub struct MocoQueue {
19 pub capacity: usize,
21 pub dim: usize,
23 pub data: Vec<f32>,
25 pub head: usize,
27 pub len: usize,
29}
30
31impl MocoQueue {
32 pub fn new(capacity: usize, dim: usize) -> SslResult<Self> {
38 if capacity == 0 {
39 return Err(SslError::QueueCapacityTooSmall);
40 }
41 if dim == 0 {
42 return Err(SslError::InvalidFeatureDim);
43 }
44 Ok(Self {
45 capacity,
46 dim,
47 data: vec![0.0_f32; capacity * dim],
48 head: 0,
49 len: 0,
50 })
51 }
52
53 pub fn enqueue(&mut self, batch: &[f32]) -> SslResult<()> {
59 if batch.is_empty() {
60 return Ok(());
61 }
62 if batch.len() % self.dim != 0 {
63 return Err(SslError::DimensionMismatch {
64 expected: self.dim,
65 got: batch.len(),
66 });
67 }
68 let n = batch.len() / self.dim;
69 for i in 0..n {
70 let src = &batch[i * self.dim..(i + 1) * self.dim];
71 let dst = &mut self.data[self.head * self.dim..(self.head + 1) * self.dim];
72 dst.copy_from_slice(src);
73 self.head = (self.head + 1) % self.capacity;
74 if self.len < self.capacity {
75 self.len += 1;
76 }
77 }
78 Ok(())
79 }
80
81 #[must_use]
83 pub fn len(&self) -> usize {
84 self.len
85 }
86
87 #[must_use]
89 pub fn is_empty(&self) -> bool {
90 self.len == 0
91 }
92
93 #[must_use]
97 pub fn entries(&self) -> &[f32] {
98 &self.data[..self.len * self.dim]
99 }
100}
101
102pub fn moco_loss(
118 q: &[f32],
119 k_pos: &[f32],
120 batch: usize,
121 dim: usize,
122 queue: &MocoQueue,
123 temperature: f32,
124) -> SslResult<f32> {
125 if q.is_empty() || batch == 0 || dim == 0 {
126 return Err(SslError::EmptyInput);
127 }
128 if !(temperature.is_finite() && temperature > 0.0) {
129 return Err(SslError::InvalidTemperature { temp: temperature });
130 }
131 if q.len() != batch * dim {
132 return Err(SslError::DimensionMismatch {
133 expected: batch * dim,
134 got: q.len(),
135 });
136 }
137 if k_pos.len() != batch * dim {
138 return Err(SslError::DimensionMismatch {
139 expected: batch * dim,
140 got: k_pos.len(),
141 });
142 }
143 if queue.dim != dim {
144 return Err(SslError::DimensionMismatch {
145 expected: dim,
146 got: queue.dim,
147 });
148 }
149 if queue.is_empty() {
150 return Err(SslError::QueueEmpty);
151 }
152
153 let q_n = l2_normalize_clone(q, batch, dim);
154 let k_n = l2_normalize_clone(k_pos, batch, dim);
155 let neg = l2_normalize_clone(queue.entries(), queue.len(), dim);
158
159 let inv_t = 1.0_f32 / temperature;
160 let mut total_loss = 0.0_f64;
161 for i in 0..batch {
162 let q_row = &q_n[i * dim..(i + 1) * dim];
163 let k_row = &k_n[i * dim..(i + 1) * dim];
164 let mut pos = 0.0_f32;
166 for (a, b) in q_row.iter().zip(k_row.iter()) {
167 pos += a * b;
168 }
169 pos *= inv_t;
170 let mut max_v = pos;
172 let mut neg_logits: Vec<f32> = Vec::with_capacity(queue.len());
173 for k in 0..queue.len() {
174 let neg_row = &neg[k * dim..(k + 1) * dim];
175 let mut s = 0.0_f32;
176 for (a, b) in q_row.iter().zip(neg_row.iter()) {
177 s += a * b;
178 }
179 let l = s * inv_t;
180 if l > max_v {
181 max_v = l;
182 }
183 neg_logits.push(l);
184 }
185 let mut sum_exp = ((pos - max_v) as f64).exp();
187 for &l in &neg_logits {
188 sum_exp += ((l - max_v) as f64).exp();
189 }
190 let log_z = (max_v as f64) + sum_exp.ln();
191 total_loss += -((pos as f64) - log_z);
192 }
193 Ok((total_loss / batch as f64) as f32)
194}
195
/// Returns a copy of `z` in which every consecutive `d`-length row is scaled
/// to unit L2 norm.
///
/// Rows whose norm is at most `1e-12` (e.g. all-zero rows) are copied unchanged
/// rather than divided by a near-zero value. If `z.len()` is not a multiple of
/// `d`, the trailing partial row is normalised on its own. When `d == 0` there
/// are no rows to scale and `z` is copied as is.
///
/// `_n` is the caller's expected row count. The computation does not need it;
/// it is kept so existing call sites keep compiling.
fn l2_normalize_clone(z: &[f32], _n: usize, d: usize) -> Vec<f32> {
    let mut out = z.to_vec();
    if d == 0 {
        // `chunks_mut(0)` panics; with zero-width rows there is nothing to scale.
        return out;
    }
    for row in out.chunks_mut(d) {
        // Accumulate the squared norm in f64: squaring f32 components above
        // ~1.8e19 would overflow f32 to infinity and collapse the row to zeros.
        let norm = row
            .iter()
            .map(|&v| f64::from(v) * f64::from(v))
            .sum::<f64>()
            .sqrt();
        if norm > 1e-12 {
            let inv = norm.recip();
            for v in row.iter_mut() {
                *v = (f64::from(*v) * inv) as f32;
            }
        }
    }
    out
}
208
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn queue_new_zero_capacity_errors() {
        assert!(matches!(
            MocoQueue::new(0, 4),
            Err(SslError::QueueCapacityTooSmall)
        ));
    }

    #[test]
    fn queue_new_zero_dim_errors() {
        assert!(matches!(
            MocoQueue::new(4, 0),
            Err(SslError::InvalidFeatureDim)
        ));
    }

    #[test]
    fn queue_enqueue_grows_until_capacity() {
        let mut q = MocoQueue::new(4, 2).expect("new should succeed");
        let batch = [1.0_f32, 0.0, 0.0, 1.0];
        q.enqueue(&batch).expect("enqueue should succeed");
        assert_eq!(q.len(), 2);
        q.enqueue(&batch).expect("enqueue should succeed");
        assert_eq!(q.len(), 4);
        // A fifth row wraps around; the length saturates at capacity.
        q.enqueue(&[0.5_f32, 0.5]).expect("enqueue should succeed");
        assert_eq!(q.len(), 4);
    }

    #[test]
    fn queue_enqueue_wraps_and_overwrites_oldest() {
        let mut q = MocoQueue::new(2, 1).expect("new should succeed");
        q.enqueue(&[1.0_f32, 2.0, 3.0])
            .expect("enqueue should succeed");
        // 3.0 overwrote slot 0, which held the oldest value 1.0.
        assert_eq!(q.entries(), &[3.0_f32, 2.0]);
        assert_eq!(q.head, 1);
        assert_eq!(q.len(), 2);
    }

    #[test]
    fn queue_enqueue_empty_batch_ok() {
        let mut q = MocoQueue::new(4, 2).expect("new should succeed");
        q.enqueue(&[]).expect("enqueue should succeed");
        assert!(q.is_empty());
    }

    #[test]
    fn queue_enqueue_rejects_misaligned() {
        let mut q = MocoQueue::new(4, 3).expect("new should succeed");
        let r = q.enqueue(&[1.0_f32, 2.0]);
        assert!(matches!(
            r,
            Err(SslError::DimensionMismatch {
                expected: 3,
                got: 2
            })
        ));
        assert!(q.is_empty(), "a rejected batch must not be partially written");
    }

    #[test]
    fn moco_loss_perfect_positives_low() {
        let mut q = MocoQueue::new(8, 4).expect("new should succeed");
        // Deterministic LCG negatives. `>> 33` keeps 31 bits and the divisor is
        // 2^32, so the values fall in [-0.5, 0.0).
        let mut rng = 42u64;
        let mut neg = vec![0.0_f32; 8 * 4];
        for v in neg.iter_mut() {
            rng = rng
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            *v = ((rng >> 33) as f32 / (u32::MAX as f32 + 1.0)) - 0.5;
        }
        q.enqueue(&neg).expect("enqueue should succeed");
        let pos = [1.0_f32, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
        let loss = moco_loss(&pos, &pos, 2, 4, &q, 0.1).expect("moco_loss should succeed");
        assert!(loss.is_finite());
        assert!(loss < 1.0);
    }

    #[test]
    fn moco_loss_matches_hand_computed_value() {
        // Query = key = e1, a single orthogonal negative e2, temperature 1:
        // pos = 1, neg = 0, so loss = -(1 - ln(e^1 + e^0)) = ln(1 + e^-1).
        let mut q = MocoQueue::new(1, 2).expect("new should succeed");
        q.enqueue(&[0.0_f32, 1.0]).expect("enqueue should succeed");
        let anchor = [1.0_f32, 0.0];
        let loss =
            moco_loss(&anchor, &anchor, 1, 2, &q, 1.0).expect("moco_loss should succeed");
        let expected = (1.0_f32 + (-1.0_f32).exp()).ln();
        assert!((loss - expected).abs() < 1e-6, "loss {loss} != {expected}");
    }

    #[test]
    fn moco_loss_empty_queue_errors() {
        let q = MocoQueue::new(4, 2).expect("new should succeed");
        let pos = [1.0_f32, 0.0];
        let r = moco_loss(&pos, &pos, 1, 2, &q, 0.1);
        assert!(matches!(r, Err(SslError::QueueEmpty)));
    }

    #[test]
    fn moco_loss_empty_input_errors() {
        let mut q = MocoQueue::new(2, 2).expect("new should succeed");
        q.enqueue(&[1.0_f32; 4]).expect("enqueue should succeed");
        let r = moco_loss(&[], &[], 0, 2, &q, 0.1);
        assert!(matches!(r, Err(SslError::EmptyInput)));
    }

    #[test]
    fn moco_loss_dim_mismatch_errors() {
        // `q` and `k_pos` are consistent with batch = 1, dim = 2, so the only
        // mismatch left is the queue's feature dimension (4 vs 2).
        let mut q = MocoQueue::new(2, 4).expect("new should succeed");
        q.enqueue(&[1.0_f32; 8]).expect("enqueue should succeed");
        let r = moco_loss(&[1.0_f32, 0.0], &[1.0_f32, 0.0], 1, 2, &q, 0.1);
        assert!(matches!(
            r,
            Err(SslError::DimensionMismatch {
                expected: 2,
                got: 4
            })
        ));
    }

    #[test]
    fn moco_loss_input_length_mismatch_errors() {
        let mut q = MocoQueue::new(2, 2).expect("new should succeed");
        q.enqueue(&[1.0_f32; 4]).expect("enqueue should succeed");
        // `q` is too long for batch = 1, dim = 2.
        let r = moco_loss(&[1.0_f32; 4], &[1.0_f32, 0.0], 1, 2, &q, 0.1);
        assert!(matches!(
            r,
            Err(SslError::DimensionMismatch {
                expected: 2,
                got: 4
            })
        ));
        // `k_pos` is too short.
        let r = moco_loss(&[1.0_f32, 0.0], &[1.0_f32], 1, 2, &q, 0.1);
        assert!(matches!(
            r,
            Err(SslError::DimensionMismatch {
                expected: 2,
                got: 1
            })
        ));
    }

    #[test]
    fn moco_loss_temperature_must_be_positive() {
        let mut q = MocoQueue::new(2, 2).expect("new should succeed");
        q.enqueue(&[1.0_f32; 4]).expect("enqueue should succeed");
        for t in [0.0_f32, -1.0, f32::NAN, f32::INFINITY] {
            let r = moco_loss(&[1.0_f32, 0.0], &[1.0_f32, 0.0], 1, 2, &q, t);
            assert!(
                matches!(r, Err(SslError::InvalidTemperature { .. })),
                "temperature {t} was accepted"
            );
        }
    }

    #[test]
    fn queue_entries_view_correct_length() {
        let mut q = MocoQueue::new(4, 2).expect("new should succeed");
        q.enqueue(&[1.0_f32, 2.0, 3.0, 4.0])
            .expect("enqueue should succeed");
        let entries = q.entries();
        assert_eq!(entries.len(), 4);
    }
}