1use serde::{Deserialize, Serialize};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct ContrastiveConfig {
23 pub temperature: f64,
25 pub margin: f64,
27 pub num_negatives: usize,
29 pub mining_strategy: NegativeMiningStrategy,
31 pub use_cosine: bool,
33 pub label_smoothing: f64,
35}
36
37impl Default for ContrastiveConfig {
38 fn default() -> Self {
39 Self {
40 temperature: 0.07,
41 margin: 1.0,
42 num_negatives: 128,
43 mining_strategy: NegativeMiningStrategy::SemiHard,
44 use_cosine: true,
45 label_smoothing: 0.0,
46 }
47 }
48}
49
50#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
52pub enum NegativeMiningStrategy {
53 Random,
55 SemiHard,
57 Hard,
59 Mixed,
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct ContrastiveLossResult {
70 pub loss: f64,
72 pub per_sample_losses: Vec<f64>,
74 pub avg_positive_similarity: f64,
76 pub avg_negative_similarity: f64,
78 pub batch_size: usize,
80 pub hard_negatives_count: usize,
82}
83
84#[derive(Debug, Clone, Default, Serialize, Deserialize)]
86pub struct ContrastiveTrainingStats {
87 pub batches_processed: u64,
89 pub avg_loss: f64,
91 pub min_loss: f64,
93 pub max_loss: f64,
95 pub avg_similarity_gap: f64,
97 pub total_samples: u64,
99}
100
/// Cosine similarity between `a` and `b`, clamped to `[-1.0, 1.0]`.
///
/// Returns `0.0` when either vector has a Euclidean norm below `1e-30`
/// (which includes empty slices). When the slices differ in length the dot
/// product only covers the common prefix while each norm covers its whole
/// slice, i.e. the shorter vector behaves as if it were zero-padded.
pub fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
    // Norms below this threshold are treated as zero vectors.
    const MIN_NORM: f64 = 1e-30;

    let euclidean_norm = |v: &[f64]| -> f64 { v.iter().map(|c| c * c).sum::<f64>().sqrt() };

    let inner: f64 = a.iter().zip(b).map(|(p, q)| p * q).sum();
    let len_a = euclidean_norm(a);
    let len_b = euclidean_norm(b);

    // Written as `<` tests (not `>=`) so that NaN norms fall through to the
    // division and propagate as NaN.
    if len_a < MIN_NORM || len_b < MIN_NORM {
        0.0
    } else {
        (inner / (len_a * len_b)).clamp(-1.0, 1.0)
    }
}
115
/// Dot product of `a` and `b`.
///
/// Only the common prefix is used when the slices differ in length.
pub fn dot_product(a: &[f64], b: &[f64]) -> f64 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
120
/// Euclidean (L2) distance between `a` and `b`.
///
/// Only the common prefix is compared when the slices differ in length.
pub fn l2_distance(a: &[f64], b: &[f64]) -> f64 {
    let squared_sum: f64 = a
        .iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let delta = lhs - rhs;
            delta * delta
        })
        .sum();
    squared_sum.sqrt()
}
129
130pub struct ContrastiveLossEngine {
136 config: ContrastiveConfig,
137 stats: ContrastiveTrainingStats,
138}
139
140impl ContrastiveLossEngine {
141 pub fn new(config: ContrastiveConfig) -> Self {
143 Self {
144 config,
145 stats: ContrastiveTrainingStats {
146 min_loss: f64::MAX,
147 ..Default::default()
148 },
149 }
150 }
151
152 pub fn with_defaults() -> Self {
154 Self::new(ContrastiveConfig::default())
155 }
156
157 pub fn info_nce_loss(
161 &mut self,
162 anchors: &[Vec<f64>],
163 positives: &[Vec<f64>],
164 negatives: &[Vec<f64>],
165 ) -> ContrastiveLossResult {
166 let batch_size = anchors.len().min(positives.len());
167 let tau = self.config.temperature;
168 let mut per_sample_losses = Vec::with_capacity(batch_size);
169 let mut total_pos_sim = 0.0;
170 let mut total_neg_sim = 0.0;
171 let mut hard_count = 0;
172
173 for i in 0..batch_size {
174 let pos_sim = self.similarity(&anchors[i], &positives[i]) / tau;
175 total_pos_sim += pos_sim * tau;
176
177 let mut log_sum_exp = pos_sim.exp();
178 let mut max_neg_sim = f64::NEG_INFINITY;
179
180 for neg in negatives.iter() {
181 let neg_sim = self.similarity(&anchors[i], neg) / tau;
182 total_neg_sim += neg_sim * tau;
183 log_sum_exp += neg_sim.exp();
184 if neg_sim > max_neg_sim {
185 max_neg_sim = neg_sim;
186 }
187 }
188
189 if max_neg_sim * tau > pos_sim * tau - self.config.margin {
190 hard_count += 1;
191 }
192
193 let loss = -pos_sim + log_sum_exp.ln();
194 per_sample_losses.push(loss);
195 }
196
197 let total_loss: f64 = per_sample_losses.iter().sum();
198 let avg_loss = if batch_size > 0 {
199 total_loss / batch_size as f64
200 } else {
201 0.0
202 };
203
204 let neg_count = negatives.len().max(1) * batch_size;
205 let result = ContrastiveLossResult {
206 loss: avg_loss,
207 per_sample_losses,
208 avg_positive_similarity: if batch_size > 0 {
209 total_pos_sim / batch_size as f64
210 } else {
211 0.0
212 },
213 avg_negative_similarity: if neg_count > 0 {
214 total_neg_sim / neg_count as f64
215 } else {
216 0.0
217 },
218 batch_size,
219 hard_negatives_count: hard_count,
220 };
221
222 self.update_stats(&result);
223 result
224 }
225
226 pub fn triplet_loss(
230 &mut self,
231 anchors: &[Vec<f64>],
232 positives: &[Vec<f64>],
233 negatives: &[Vec<f64>],
234 ) -> ContrastiveLossResult {
235 let batch_size = anchors.len().min(positives.len()).min(negatives.len());
236 let margin = self.config.margin;
237 let mut per_sample_losses = Vec::with_capacity(batch_size);
238 let mut total_pos_dist = 0.0;
239 let mut total_neg_dist = 0.0;
240 let mut hard_count = 0;
241
242 for i in 0..batch_size {
243 let pos_dist = l2_distance(&anchors[i], &positives[i]);
244 let neg_dist = l2_distance(&anchors[i], &negatives[i]);
245
246 total_pos_dist += pos_dist;
247 total_neg_dist += neg_dist;
248
249 let loss = (pos_dist - neg_dist + margin).max(0.0);
250 if loss > 0.0 {
251 hard_count += 1;
252 }
253 per_sample_losses.push(loss);
254 }
255
256 let total_loss: f64 = per_sample_losses.iter().sum();
257 let avg_loss = if batch_size > 0 {
258 total_loss / batch_size as f64
259 } else {
260 0.0
261 };
262
263 let result = ContrastiveLossResult {
264 loss: avg_loss,
265 per_sample_losses,
266 avg_positive_similarity: if batch_size > 0 {
267 -(total_pos_dist / batch_size as f64)
268 } else {
269 0.0
270 },
271 avg_negative_similarity: if batch_size > 0 {
272 -(total_neg_dist / batch_size as f64)
273 } else {
274 0.0
275 },
276 batch_size,
277 hard_negatives_count: hard_count,
278 };
279
280 self.update_stats(&result);
281 result
282 }
283
284 pub fn nt_xent_loss(
288 &mut self,
289 embeddings_a: &[Vec<f64>],
290 embeddings_b: &[Vec<f64>],
291 ) -> ContrastiveLossResult {
292 let batch_size = embeddings_a.len().min(embeddings_b.len());
293 let tau = self.config.temperature;
294 let mut per_sample_losses = Vec::with_capacity(batch_size);
295 let mut total_pos_sim = 0.0;
296 let mut total_neg_sim = 0.0;
297 let mut neg_count = 0usize;
298
299 for i in 0..batch_size {
300 let pos_sim = self.similarity(&embeddings_a[i], &embeddings_b[i]) / tau;
301 total_pos_sim += pos_sim * tau;
302
303 let mut log_sum = 0.0f64;
304 for j in 0..batch_size {
305 if j != i {
306 let sim_aj = self.similarity(&embeddings_a[i], &embeddings_b[j]) / tau;
307 let sim_ai = self.similarity(&embeddings_a[i], &embeddings_a[j]) / tau;
308 total_neg_sim += sim_aj * tau + sim_ai * tau;
309 neg_count += 2;
310 log_sum += sim_aj.exp() + sim_ai.exp();
311 }
312 }
313 log_sum += pos_sim.exp();
314
315 let loss = -pos_sim + log_sum.ln();
316 per_sample_losses.push(loss);
317 }
318
319 let total_loss: f64 = per_sample_losses.iter().sum();
320 let avg_loss = if batch_size > 0 {
321 total_loss / batch_size as f64
322 } else {
323 0.0
324 };
325
326 let result = ContrastiveLossResult {
327 loss: avg_loss,
328 per_sample_losses,
329 avg_positive_similarity: if batch_size > 0 {
330 total_pos_sim / batch_size as f64
331 } else {
332 0.0
333 },
334 avg_negative_similarity: if neg_count > 0 {
335 total_neg_sim / neg_count as f64
336 } else {
337 0.0
338 },
339 batch_size,
340 hard_negatives_count: 0,
341 };
342
343 self.update_stats(&result);
344 result
345 }
346
347 pub fn mine_semi_hard(
352 &self,
353 anchor: &[f64],
354 positive: &[f64],
355 negative_pool: &[Vec<f64>],
356 ) -> Vec<usize> {
357 let pos_dist = l2_distance(anchor, positive);
358 let margin = self.config.margin;
359
360 negative_pool
361 .iter()
362 .enumerate()
363 .filter_map(|(i, neg)| {
364 let neg_dist = l2_distance(anchor, neg);
365 if neg_dist > pos_dist && neg_dist < pos_dist + margin {
366 Some(i)
367 } else {
368 None
369 }
370 })
371 .collect()
372 }
373
374 pub fn mine_hardest(&self, anchor: &[f64], negative_pool: &[Vec<f64>]) -> Option<usize> {
376 negative_pool
377 .iter()
378 .enumerate()
379 .min_by(|(_, a), (_, b)| {
380 let da = l2_distance(anchor, a);
381 let db = l2_distance(anchor, b);
382 da.partial_cmp(&db).unwrap_or(std::cmp::Ordering::Equal)
383 })
384 .map(|(i, _)| i)
385 }
386
387 pub fn stats(&self) -> &ContrastiveTrainingStats {
389 &self.stats
390 }
391
392 pub fn reset_stats(&mut self) {
394 self.stats = ContrastiveTrainingStats {
395 min_loss: f64::MAX,
396 ..Default::default()
397 };
398 }
399
400 pub fn config(&self) -> &ContrastiveConfig {
402 &self.config
403 }
404
405 fn similarity(&self, a: &[f64], b: &[f64]) -> f64 {
408 if self.config.use_cosine {
409 cosine_similarity(a, b)
410 } else {
411 dot_product(a, b)
412 }
413 }
414
415 fn update_stats(&mut self, result: &ContrastiveLossResult) {
416 self.stats.batches_processed += 1;
417 self.stats.total_samples += result.batch_size as u64;
418
419 let n = self.stats.batches_processed as f64;
420 self.stats.avg_loss = self.stats.avg_loss * (n - 1.0) / n + result.loss / n;
421
422 if result.loss < self.stats.min_loss {
423 self.stats.min_loss = result.loss;
424 }
425 if result.loss > self.stats.max_loss {
426 self.stats.max_loss = result.loss;
427 }
428
429 let gap = result.avg_positive_similarity - result.avg_negative_similarity;
430 self.stats.avg_similarity_gap = self.stats.avg_similarity_gap * (n - 1.0) / n + gap / n;
431 }
432}
433
#[cfg(test)]
mod tests {
    //! Unit tests for the similarity helpers, the loss functions, the
    //! negative miners, statistics tracking and serde round-trips.

    use super::*;

    /// Deterministic pseudo-embedding: `sin(seed + 0.1 * i)` for each component.
    fn sample_vector(seed: f64, dim: usize) -> Vec<f64> {
        (0..dim).map(|i| (seed + i as f64 * 0.1).sin()).collect()
    }

    /// Standard basis vector of length `dim`; all zeros when `idx` is out of range.
    fn unit_vector(dim: usize, idx: usize) -> Vec<f64> {
        let mut v = vec![0.0; dim];
        if idx < dim {
            v[idx] = 1.0;
        }
        v
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let v = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-10);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim - (-1.0)).abs() < 1e-10);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        let a = vec![1.0, 2.0];
        let b = vec![0.0, 0.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    #[test]
    fn test_dot_product_simple() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        assert!((dot_product(&a, &b) - 32.0).abs() < 1e-10);
    }

    #[test]
    fn test_l2_distance_same() {
        let v = vec![1.0, 2.0, 3.0];
        assert!(l2_distance(&v, &v) < 1e-10);
    }

    #[test]
    fn test_l2_distance_known() {
        // Classic 3-4-5 right triangle.
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];
        assert!((l2_distance(&a, &b) - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_default_config() {
        let config = ContrastiveConfig::default();
        assert!((config.temperature - 0.07).abs() < 1e-10);
        assert!((config.margin - 1.0).abs() < 1e-10);
        assert_eq!(config.num_negatives, 128);
        assert!(config.use_cosine);
    }

    #[test]
    fn test_info_nce_basic() {
        let mut engine = ContrastiveLossEngine::with_defaults();
        let anchors = vec![sample_vector(1.0, 8)];
        let positives = vec![sample_vector(1.1, 8)];
        let negatives = vec![sample_vector(5.0, 8), sample_vector(10.0, 8)];

        let result = engine.info_nce_loss(&anchors, &positives, &negatives);
        assert!(result.loss.is_finite());
        assert_eq!(result.batch_size, 1);
        assert_eq!(result.per_sample_losses.len(), 1);
    }

    #[test]
    fn test_info_nce_positive_higher_similarity() {
        let mut engine = ContrastiveLossEngine::with_defaults();
        let anchor = vec![1.0, 0.0, 0.0, 0.0];
        // Nearly parallel to the anchor.
        let positive = vec![0.9, 0.1, 0.0, 0.0];
        // Both orthogonal to the anchor.
        let negatives = vec![vec![0.0, 1.0, 0.0, 0.0], vec![0.0, 0.0, 1.0, 0.0]];

        let result = engine.info_nce_loss(&[anchor], &[positive], &negatives);
        assert!(result.avg_positive_similarity > result.avg_negative_similarity);
    }

    #[test]
    fn test_triplet_loss_zero_when_separated() {
        let mut engine = ContrastiveLossEngine::new(ContrastiveConfig {
            margin: 1.0,
            ..Default::default()
        });
        let anchor = vec![0.0, 0.0];
        let positive = vec![0.1, 0.0];
        let negative = vec![10.0, 10.0];

        let result = engine.triplet_loss(&[anchor], &[positive], &[negative]);
        assert!(
            result.loss < 1e-10,
            "Loss should be 0 when negative is far away"
        );
    }

    #[test]
    fn test_triplet_loss_positive_when_close() {
        let mut engine = ContrastiveLossEngine::new(ContrastiveConfig {
            margin: 1.0,
            ..Default::default()
        });
        let anchor = vec![0.0, 0.0];
        // The positive is farther from the anchor than the negative.
        let positive = vec![2.0, 0.0];
        let negative = vec![1.5, 0.0];

        let result = engine.triplet_loss(&[anchor], &[positive], &[negative]);
        assert!(
            result.loss > 0.0,
            "Loss should be positive when negative is closer"
        );
    }

    #[test]
    fn test_nt_xent_basic() {
        let mut engine = ContrastiveLossEngine::with_defaults();
        let a = vec![sample_vector(1.0, 8), sample_vector(2.0, 8)];
        let b = vec![sample_vector(1.1, 8), sample_vector(2.1, 8)];

        let result = engine.nt_xent_loss(&a, &b);
        assert!(result.loss.is_finite());
        assert_eq!(result.batch_size, 2);
    }

    #[test]
    fn test_mine_semi_hard() {
        let engine = ContrastiveLossEngine::new(ContrastiveConfig {
            margin: 2.0,
            ..Default::default()
        });
        let anchor = vec![0.0, 0.0];
        // Anchor/positive distance is 1.0, so the semi-hard band is (1.0, 3.0).
        let positive = vec![1.0, 0.0];
        let pool = vec![
            vec![0.5, 0.0],  // distance 0.5: closer than the positive
            vec![1.5, 0.0],  // distance 1.5: semi-hard
            vec![2.5, 0.0],  // distance 2.5: semi-hard
            vec![10.0, 0.0], // distance 10.0: beyond the margin
        ];

        let indices = engine.mine_semi_hard(&anchor, &positive, &pool);
        // Exact match also proves the out-of-band negatives are excluded.
        assert_eq!(indices, vec![1, 2]);
    }

    #[test]
    fn test_mine_hardest() {
        let engine = ContrastiveLossEngine::with_defaults();
        let anchor = vec![0.0, 0.0];
        let pool = vec![
            vec![10.0, 0.0], // distance 10.0
            vec![2.0, 0.0],  // distance 2.0: closest
            vec![5.0, 0.0],  // distance 5.0
        ];

        let idx = engine.mine_hardest(&anchor, &pool);
        assert_eq!(idx, Some(1));
    }

    #[test]
    fn test_mine_hardest_empty() {
        let engine = ContrastiveLossEngine::with_defaults();
        let anchor = vec![0.0, 0.0];
        assert!(engine.mine_hardest(&anchor, &[]).is_none());
    }

    #[test]
    fn test_stats_tracking() {
        let mut engine = ContrastiveLossEngine::with_defaults();
        let a = vec![sample_vector(1.0, 4)];
        let p = vec![sample_vector(1.1, 4)];
        let n = vec![sample_vector(5.0, 4)];

        engine.info_nce_loss(&a, &p, &n);
        engine.info_nce_loss(&a, &p, &n);

        assert_eq!(engine.stats().batches_processed, 2);
        assert_eq!(engine.stats().total_samples, 2);
    }

    #[test]
    fn test_stats_reset() {
        let mut engine = ContrastiveLossEngine::with_defaults();
        let a = vec![sample_vector(1.0, 4)];
        let p = vec![sample_vector(1.1, 4)];
        let n = vec![sample_vector(5.0, 4)];
        engine.info_nce_loss(&a, &p, &n);

        engine.reset_stats();
        assert_eq!(engine.stats().batches_processed, 0);
    }

    #[test]
    fn test_dot_product_mode() {
        let mut engine = ContrastiveLossEngine::new(ContrastiveConfig {
            use_cosine: false,
            ..Default::default()
        });
        let a = vec![vec![1.0, 0.0]];
        let p = vec![vec![0.9, 0.1]];
        let n = vec![vec![0.0, 1.0]];

        let result = engine.info_nce_loss(&a, &p, &n);
        assert!(result.loss.is_finite());
    }

    #[test]
    fn test_empty_batch() {
        let mut engine = ContrastiveLossEngine::with_defaults();
        let result = engine.info_nce_loss(&[], &[], &[]);
        assert_eq!(result.batch_size, 0);
        assert!((result.loss).abs() < 1e-10);
    }

    #[test]
    fn test_triplet_empty_batch() {
        let mut engine = ContrastiveLossEngine::with_defaults();
        let result = engine.triplet_loss(&[], &[], &[]);
        assert_eq!(result.batch_size, 0);
        assert!(result.per_sample_losses.is_empty());
    }

    #[test]
    fn test_nt_xent_single_sample() {
        let mut engine = ContrastiveLossEngine::with_defaults();
        let a = vec![sample_vector(1.0, 4)];
        let b = vec![sample_vector(1.1, 4)];
        let result = engine.nt_xent_loss(&a, &b);
        assert!(result.loss.is_finite());
        // With no negatives the only logit is the positive, so the loss vanishes.
        assert!(result.loss.abs() < 1e-9);
    }

    #[test]
    fn test_config_serialization() {
        let config = ContrastiveConfig::default();
        let json = serde_json::to_string(&config).expect("serialize failed");
        let deser: ContrastiveConfig = serde_json::from_str(&json).expect("deser failed");
        assert!((deser.temperature - config.temperature).abs() < 1e-10);
    }

    #[test]
    fn test_result_serialization() {
        let result = ContrastiveLossResult {
            loss: 0.5,
            per_sample_losses: vec![0.5],
            avg_positive_similarity: 0.8,
            avg_negative_similarity: 0.2,
            batch_size: 1,
            hard_negatives_count: 0,
        };
        let json = serde_json::to_string(&result).expect("serialize failed");
        assert!(json.contains("loss"));
    }

    #[test]
    fn test_stats_serialization() {
        let stats = ContrastiveTrainingStats::default();
        let json = serde_json::to_string(&stats).expect("serialize failed");
        assert!(json.contains("batches_processed"));
    }

    #[test]
    fn test_mining_strategy_serde() {
        let s = NegativeMiningStrategy::SemiHard;
        let json = serde_json::to_string(&s).expect("serialize failed");
        let deser: NegativeMiningStrategy = serde_json::from_str(&json).expect("deser failed");
        assert_eq!(deser, s);
    }

    #[test]
    fn test_large_batch() {
        let mut engine = ContrastiveLossEngine::with_defaults();
        let dim = 32;
        let batch: Vec<Vec<f64>> = (0..16).map(|i| sample_vector(i as f64, dim)).collect();
        let pos: Vec<Vec<f64>> = (0..16)
            .map(|i| sample_vector(i as f64 + 0.01, dim))
            .collect();
        let neg: Vec<Vec<f64>> = (0..8)
            .map(|i| sample_vector(i as f64 + 100.0, dim))
            .collect();

        let result = engine.info_nce_loss(&batch, &pos, &neg);
        assert_eq!(result.batch_size, 16);
        assert!(result.loss.is_finite());
    }

    #[test]
    fn test_hard_negatives_count() {
        let mut engine = ContrastiveLossEngine::new(ContrastiveConfig {
            margin: 0.5,
            ..Default::default()
        });
        let anchor = vec![1.0, 0.0, 0.0, 0.0];
        let positive = vec![0.9, 0.1, 0.0, 0.0];
        // Cosine to the anchor is about 0.9986, above the positive's 0.9939,
        // so it is well inside `positive - margin` and must count as hard.
        let negatives = vec![vec![0.95, 0.05, 0.0, 0.0]];

        let result = engine.info_nce_loss(&[anchor], &[positive], &negatives);
        assert_eq!(result.hard_negatives_count, 1);
    }

    #[test]
    fn test_min_max_loss_tracking() {
        let mut engine = ContrastiveLossEngine::with_defaults();
        let a1 = vec![sample_vector(1.0, 4)];
        let p1 = vec![sample_vector(1.1, 4)];
        let n1 = vec![sample_vector(5.0, 4)];
        engine.info_nce_loss(&a1, &p1, &n1);

        // Second batch: the "positive" is unrelated to the anchor while the
        // negative nearly coincides with it.
        let a2 = vec![sample_vector(1.0, 4)];
        let p2 = vec![sample_vector(100.0, 4)];
        let n2 = vec![sample_vector(1.01, 4)];
        engine.info_nce_loss(&a2, &p2, &n2);

        assert!(engine.stats().min_loss <= engine.stats().max_loss);
    }
}