// oxirs_embed/contrastive_learning.rs
//! # Contrastive Learning Loss Functions for Knowledge Graph Embeddings
//!
//! Implements contrastive loss functions for training knowledge graph embeddings
//! with self-supervised and supervised contrastive objectives.
//!
//! ## Loss Functions
//!
//! - **InfoNCE**: Noise-contrastive estimation (SimCLR-style)
//! - **Triplet Loss**: Margin-based triplet loss (anchor, positive, negative)
//! - **NT-Xent**: Normalised temperature-scaled cross-entropy
//! - **SupCon**: Supervised contrastive loss
//! - **Hard Negative Mining**: Strategies for selecting challenging negatives

use serde::{Deserialize, Serialize};

// ─────────────────────────────────────────────
// Configuration
// ─────────────────────────────────────────────

/// Configuration for contrastive learning.
///
/// Defaults follow common practice for contrastive objectives
/// (temperature τ = 0.07 as popularised by SimCLR/MoCo).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContrastiveConfig {
    /// Temperature parameter for softmax scaling (default: 0.07).
    /// Lower values sharpen the softmax distribution over negatives.
    pub temperature: f64,
    /// Margin for triplet loss (default: 1.0). Also used as the band width
    /// for semi-hard negative mining and hard-negative counting.
    pub margin: f64,
    /// Number of negative samples per positive (default: 128).
    /// NOTE(review): not currently read by `ContrastiveLossEngine` — callers
    /// supply negative pools explicitly; confirm intended use.
    pub num_negatives: usize,
    /// Hard negative mining strategy (default: SemiHard).
    /// NOTE(review): the engine exposes explicit mining helpers and does not
    /// dispatch on this field — confirm intended use.
    pub mining_strategy: NegativeMiningStrategy,
    /// Whether to use cosine similarity (true) or dot product (false).
    pub use_cosine: bool,
    /// Label smoothing factor (0.0 = no smoothing, default: 0.0).
    /// NOTE(review): not currently applied by any loss in this file.
    pub label_smoothing: f64,
}

37impl Default for ContrastiveConfig {
38    fn default() -> Self {
39        Self {
40            temperature: 0.07,
41            margin: 1.0,
42            num_negatives: 128,
43            mining_strategy: NegativeMiningStrategy::SemiHard,
44            use_cosine: true,
45            label_smoothing: 0.0,
46        }
47    }
48}
49
/// Strategy for mining negative samples.
///
/// NOTE(review): `ContrastiveLossEngine` exposes explicit `mine_semi_hard`
/// and `mine_hardest` helpers but does not currently dispatch on this enum —
/// confirm whether callers are expected to select the strategy themselves.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NegativeMiningStrategy {
    /// Random negatives.
    Random,
    /// Semi-hard negatives (closer than positive but outside margin).
    SemiHard,
    /// Hardest negatives (closest to anchor).
    Hard,
    /// Mix of hard and random negatives.
    Mixed,
}

// ─────────────────────────────────────────────
// Loss Results
// ─────────────────────────────────────────────

/// Result of computing a contrastive loss.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContrastiveLossResult {
    /// Scalar loss value (mean over the batch; 0.0 for an empty batch).
    pub loss: f64,
    /// Per-sample losses (if applicable).
    pub per_sample_losses: Vec<f64>,
    /// Average positive similarity.
    /// For `triplet_loss` this holds the *negated* mean anchor-positive
    /// L2 distance so that "higher is better" holds for every loss.
    pub avg_positive_similarity: f64,
    /// Average negative similarity.
    /// For `triplet_loss` this holds the *negated* mean anchor-negative
    /// L2 distance (same convention as above).
    pub avg_negative_similarity: f64,
    /// Number of samples in the batch.
    pub batch_size: usize,
    /// Number of hard negatives found (always 0 for `nt_xent_loss`).
    pub hard_negatives_count: usize,
}

/// Statistics for contrastive training.
///
/// NOTE: `Default` zero-initialises every field, including `min_loss`;
/// `ContrastiveLossEngine::new` and `reset_stats` override `min_loss` with
/// `f64::MAX` so the first observed loss always becomes the minimum.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ContrastiveTrainingStats {
    /// Total batches processed.
    pub batches_processed: u64,
    /// Running average loss (incremental mean over batches).
    pub avg_loss: f64,
    /// Minimum loss seen.
    pub min_loss: f64,
    /// Maximum loss seen.
    pub max_loss: f64,
    /// Average positive-negative similarity gap.
    pub avg_similarity_gap: f64,
    /// Total samples processed.
    pub total_samples: u64,
}

// ─────────────────────────────────────────────
// Similarity functions
// ─────────────────────────────────────────────

/// Cosine similarity of two equal-length vectors.
///
/// Returns 0.0 when either vector has (near-)zero norm, and clamps the
/// result to [-1, 1] to absorb floating-point drift.
pub fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
    let (mut dot, mut sq_a, mut sq_b) = (0.0_f64, 0.0_f64, 0.0_f64);
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let (norm_a, norm_b) = (sq_a.sqrt(), sq_b.sqrt());
    if norm_a < 1e-30 || norm_b < 1e-30 {
        0.0
    } else {
        (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
    }
}

/// Dot product of two equal-length vectors (0.0 for empty input).
pub fn dot_product(a: &[f64], b: &[f64]) -> f64 {
    let mut acc = 0.0;
    for (x, y) in a.iter().zip(b.iter()) {
        acc += x * y;
    }
    acc
}

/// L2 (Euclidean) distance between two equal-length vectors.
pub fn l2_distance(a: &[f64], b: &[f64]) -> f64 {
    let sum_sq: f64 = a
        .iter()
        .zip(b.iter())
        .map(|(x, y)| {
            let d = x - y;
            d * d
        })
        .sum();
    sum_sq.sqrt()
}

// ─────────────────────────────────────────────
// Loss Functions
// ─────────────────────────────────────────────

/// Contrastive loss function engine.
///
/// Holds the loss configuration and accumulates running training
/// statistics across batches (see [`ContrastiveTrainingStats`]).
pub struct ContrastiveLossEngine {
    /// Loss hyper-parameters (temperature, margin, similarity choice, …).
    config: ContrastiveConfig,
    /// Running statistics updated after every loss computation.
    stats: ContrastiveTrainingStats,
}

140impl ContrastiveLossEngine {
141    /// Create a new contrastive loss engine.
142    pub fn new(config: ContrastiveConfig) -> Self {
143        Self {
144            config,
145            stats: ContrastiveTrainingStats {
146                min_loss: f64::MAX,
147                ..Default::default()
148            },
149        }
150    }
151
152    /// Create with default configuration.
153    pub fn with_defaults() -> Self {
154        Self::new(ContrastiveConfig::default())
155    }
156
157    /// Compute InfoNCE (Noise Contrastive Estimation) loss.
158    ///
159    /// L = -log( exp(sim(anchor, positive) / τ) / Σ exp(sim(anchor, neg_i) / τ) )
160    pub fn info_nce_loss(
161        &mut self,
162        anchors: &[Vec<f64>],
163        positives: &[Vec<f64>],
164        negatives: &[Vec<f64>],
165    ) -> ContrastiveLossResult {
166        let batch_size = anchors.len().min(positives.len());
167        let tau = self.config.temperature;
168        let mut per_sample_losses = Vec::with_capacity(batch_size);
169        let mut total_pos_sim = 0.0;
170        let mut total_neg_sim = 0.0;
171        let mut hard_count = 0;
172
173        for i in 0..batch_size {
174            let pos_sim = self.similarity(&anchors[i], &positives[i]) / tau;
175            total_pos_sim += pos_sim * tau;
176
177            let mut log_sum_exp = pos_sim.exp();
178            let mut max_neg_sim = f64::NEG_INFINITY;
179
180            for neg in negatives.iter() {
181                let neg_sim = self.similarity(&anchors[i], neg) / tau;
182                total_neg_sim += neg_sim * tau;
183                log_sum_exp += neg_sim.exp();
184                if neg_sim > max_neg_sim {
185                    max_neg_sim = neg_sim;
186                }
187            }
188
189            if max_neg_sim * tau > pos_sim * tau - self.config.margin {
190                hard_count += 1;
191            }
192
193            let loss = -pos_sim + log_sum_exp.ln();
194            per_sample_losses.push(loss);
195        }
196
197        let total_loss: f64 = per_sample_losses.iter().sum();
198        let avg_loss = if batch_size > 0 {
199            total_loss / batch_size as f64
200        } else {
201            0.0
202        };
203
204        let neg_count = negatives.len().max(1) * batch_size;
205        let result = ContrastiveLossResult {
206            loss: avg_loss,
207            per_sample_losses,
208            avg_positive_similarity: if batch_size > 0 {
209                total_pos_sim / batch_size as f64
210            } else {
211                0.0
212            },
213            avg_negative_similarity: if neg_count > 0 {
214                total_neg_sim / neg_count as f64
215            } else {
216                0.0
217            },
218            batch_size,
219            hard_negatives_count: hard_count,
220        };
221
222        self.update_stats(&result);
223        result
224    }
225
226    /// Compute triplet margin loss.
227    ///
228    /// L = max(0, d(anchor, positive) - d(anchor, negative) + margin)
229    pub fn triplet_loss(
230        &mut self,
231        anchors: &[Vec<f64>],
232        positives: &[Vec<f64>],
233        negatives: &[Vec<f64>],
234    ) -> ContrastiveLossResult {
235        let batch_size = anchors.len().min(positives.len()).min(negatives.len());
236        let margin = self.config.margin;
237        let mut per_sample_losses = Vec::with_capacity(batch_size);
238        let mut total_pos_dist = 0.0;
239        let mut total_neg_dist = 0.0;
240        let mut hard_count = 0;
241
242        for i in 0..batch_size {
243            let pos_dist = l2_distance(&anchors[i], &positives[i]);
244            let neg_dist = l2_distance(&anchors[i], &negatives[i]);
245
246            total_pos_dist += pos_dist;
247            total_neg_dist += neg_dist;
248
249            let loss = (pos_dist - neg_dist + margin).max(0.0);
250            if loss > 0.0 {
251                hard_count += 1;
252            }
253            per_sample_losses.push(loss);
254        }
255
256        let total_loss: f64 = per_sample_losses.iter().sum();
257        let avg_loss = if batch_size > 0 {
258            total_loss / batch_size as f64
259        } else {
260            0.0
261        };
262
263        let result = ContrastiveLossResult {
264            loss: avg_loss,
265            per_sample_losses,
266            avg_positive_similarity: if batch_size > 0 {
267                -(total_pos_dist / batch_size as f64)
268            } else {
269                0.0
270            },
271            avg_negative_similarity: if batch_size > 0 {
272                -(total_neg_dist / batch_size as f64)
273            } else {
274                0.0
275            },
276            batch_size,
277            hard_negatives_count: hard_count,
278        };
279
280        self.update_stats(&result);
281        result
282    }
283
284    /// Compute NT-Xent (Normalised Temperature-Scaled Cross-Entropy) loss.
285    ///
286    /// SimCLR-style loss over a batch of augmented pairs.
287    pub fn nt_xent_loss(
288        &mut self,
289        embeddings_a: &[Vec<f64>],
290        embeddings_b: &[Vec<f64>],
291    ) -> ContrastiveLossResult {
292        let batch_size = embeddings_a.len().min(embeddings_b.len());
293        let tau = self.config.temperature;
294        let mut per_sample_losses = Vec::with_capacity(batch_size);
295        let mut total_pos_sim = 0.0;
296        let mut total_neg_sim = 0.0;
297        let mut neg_count = 0usize;
298
299        for i in 0..batch_size {
300            let pos_sim = self.similarity(&embeddings_a[i], &embeddings_b[i]) / tau;
301            total_pos_sim += pos_sim * tau;
302
303            let mut log_sum = 0.0f64;
304            for j in 0..batch_size {
305                if j != i {
306                    let sim_aj = self.similarity(&embeddings_a[i], &embeddings_b[j]) / tau;
307                    let sim_ai = self.similarity(&embeddings_a[i], &embeddings_a[j]) / tau;
308                    total_neg_sim += sim_aj * tau + sim_ai * tau;
309                    neg_count += 2;
310                    log_sum += sim_aj.exp() + sim_ai.exp();
311                }
312            }
313            log_sum += pos_sim.exp();
314
315            let loss = -pos_sim + log_sum.ln();
316            per_sample_losses.push(loss);
317        }
318
319        let total_loss: f64 = per_sample_losses.iter().sum();
320        let avg_loss = if batch_size > 0 {
321            total_loss / batch_size as f64
322        } else {
323            0.0
324        };
325
326        let result = ContrastiveLossResult {
327            loss: avg_loss,
328            per_sample_losses,
329            avg_positive_similarity: if batch_size > 0 {
330                total_pos_sim / batch_size as f64
331            } else {
332                0.0
333            },
334            avg_negative_similarity: if neg_count > 0 {
335                total_neg_sim / neg_count as f64
336            } else {
337                0.0
338            },
339            batch_size,
340            hard_negatives_count: 0,
341        };
342
343        self.update_stats(&result);
344        result
345    }
346
347    /// Mine semi-hard negatives from a pool.
348    ///
349    /// Returns indices of negatives that are farther from anchor than positive
350    /// but within the margin boundary.
351    pub fn mine_semi_hard(
352        &self,
353        anchor: &[f64],
354        positive: &[f64],
355        negative_pool: &[Vec<f64>],
356    ) -> Vec<usize> {
357        let pos_dist = l2_distance(anchor, positive);
358        let margin = self.config.margin;
359
360        negative_pool
361            .iter()
362            .enumerate()
363            .filter_map(|(i, neg)| {
364                let neg_dist = l2_distance(anchor, neg);
365                if neg_dist > pos_dist && neg_dist < pos_dist + margin {
366                    Some(i)
367                } else {
368                    None
369                }
370            })
371            .collect()
372    }
373
374    /// Mine the hardest negative from a pool (closest to anchor).
375    pub fn mine_hardest(&self, anchor: &[f64], negative_pool: &[Vec<f64>]) -> Option<usize> {
376        negative_pool
377            .iter()
378            .enumerate()
379            .min_by(|(_, a), (_, b)| {
380                let da = l2_distance(anchor, a);
381                let db = l2_distance(anchor, b);
382                da.partial_cmp(&db).unwrap_or(std::cmp::Ordering::Equal)
383            })
384            .map(|(i, _)| i)
385    }
386
387    /// Get training statistics.
388    pub fn stats(&self) -> &ContrastiveTrainingStats {
389        &self.stats
390    }
391
392    /// Reset training statistics.
393    pub fn reset_stats(&mut self) {
394        self.stats = ContrastiveTrainingStats {
395            min_loss: f64::MAX,
396            ..Default::default()
397        };
398    }
399
400    /// Get the configuration.
401    pub fn config(&self) -> &ContrastiveConfig {
402        &self.config
403    }
404
405    // ─── Internal ────────────────────────────
406
407    fn similarity(&self, a: &[f64], b: &[f64]) -> f64 {
408        if self.config.use_cosine {
409            cosine_similarity(a, b)
410        } else {
411            dot_product(a, b)
412        }
413    }
414
415    fn update_stats(&mut self, result: &ContrastiveLossResult) {
416        self.stats.batches_processed += 1;
417        self.stats.total_samples += result.batch_size as u64;
418
419        let n = self.stats.batches_processed as f64;
420        self.stats.avg_loss = self.stats.avg_loss * (n - 1.0) / n + result.loss / n;
421
422        if result.loss < self.stats.min_loss {
423            self.stats.min_loss = result.loss;
424        }
425        if result.loss > self.stats.max_loss {
426            self.stats.max_loss = result.loss;
427        }
428
429        let gap = result.avg_positive_similarity - result.avg_negative_similarity;
430        self.stats.avg_similarity_gap = self.stats.avg_similarity_gap * (n - 1.0) / n + gap / n;
431    }
432}
433
// ─────────────────────────────────────────────
// Tests
// ─────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vector: component i is sin(seed + 0.1·i).
    fn sample_vector(seed: f64, dim: usize) -> Vec<f64> {
        (0..dim).map(|i| (seed + i as f64 * 0.1).sin()).collect()
    }

    /// Standard basis vector e_idx (all zeros when idx is out of range).
    fn unit_vector(dim: usize, idx: usize) -> Vec<f64> {
        let mut v = vec![0.0; dim];
        if idx < dim {
            v[idx] = 1.0;
        }
        v
    }

    // ─── Similarity primitives ───────────────

    #[test]
    fn test_cosine_similarity_identical() {
        let v = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-10);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim - (-1.0)).abs() < 1e-10);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        let a = vec![1.0, 2.0];
        let b = vec![0.0, 0.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    #[test]
    fn test_dot_product_simple() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        assert!((dot_product(&a, &b) - 32.0).abs() < 1e-10);
    }

    #[test]
    fn test_l2_distance_same() {
        let v = vec![1.0, 2.0, 3.0];
        assert!(l2_distance(&v, &v) < 1e-10);
    }

    #[test]
    fn test_l2_distance_known() {
        // 3-4-5 right triangle.
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];
        assert!((l2_distance(&a, &b) - 5.0).abs() < 1e-10);
    }

    // ─── Configuration ───────────────────────

    #[test]
    fn test_default_config() {
        let config = ContrastiveConfig::default();
        assert!((config.temperature - 0.07).abs() < 1e-10);
        assert!((config.margin - 1.0).abs() < 1e-10);
        assert_eq!(config.num_negatives, 128);
        assert!(config.use_cosine);
    }

    // ─── Loss functions ──────────────────────

    #[test]
    fn test_info_nce_basic() {
        let mut engine = ContrastiveLossEngine::with_defaults();
        let anchors = vec![sample_vector(1.0, 8)];
        let positives = vec![sample_vector(1.1, 8)]; // Similar
        let negatives = vec![sample_vector(5.0, 8), sample_vector(10.0, 8)];

        let result = engine.info_nce_loss(&anchors, &positives, &negatives);
        assert!(result.loss.is_finite());
        assert_eq!(result.batch_size, 1);
        assert_eq!(result.per_sample_losses.len(), 1);
    }

    #[test]
    fn test_info_nce_positive_higher_similarity() {
        let mut engine = ContrastiveLossEngine::with_defaults();
        let anchor = vec![1.0, 0.0, 0.0, 0.0];
        let positive = vec![0.9, 0.1, 0.0, 0.0]; // Very similar
        let negatives = vec![vec![0.0, 1.0, 0.0, 0.0], vec![0.0, 0.0, 1.0, 0.0]];

        let result = engine.info_nce_loss(&[anchor], &[positive], &negatives);
        assert!(result.avg_positive_similarity > result.avg_negative_similarity);
    }

    #[test]
    fn test_triplet_loss_zero_when_separated() {
        let mut engine = ContrastiveLossEngine::new(ContrastiveConfig {
            margin: 1.0,
            ..Default::default()
        });
        let anchor = vec![0.0, 0.0];
        let positive = vec![0.1, 0.0]; // Very close
        let negative = vec![10.0, 10.0]; // Very far

        let result = engine.triplet_loss(&[anchor], &[positive], &[negative]);
        assert!(
            result.loss < 1e-10,
            "Loss should be 0 when negative is far away"
        );
    }

    #[test]
    fn test_triplet_loss_positive_when_close() {
        let mut engine = ContrastiveLossEngine::new(ContrastiveConfig {
            margin: 1.0,
            ..Default::default()
        });
        let anchor = vec![0.0, 0.0];
        let positive = vec![2.0, 0.0]; // dist = 2
        let negative = vec![1.5, 0.0]; // dist = 1.5 (closer than positive!)

        let result = engine.triplet_loss(&[anchor], &[positive], &[negative]);
        assert!(
            result.loss > 0.0,
            "Loss should be positive when negative is closer"
        );
    }

    #[test]
    fn test_nt_xent_basic() {
        let mut engine = ContrastiveLossEngine::with_defaults();
        let a = vec![sample_vector(1.0, 8), sample_vector(2.0, 8)];
        let b = vec![sample_vector(1.1, 8), sample_vector(2.1, 8)];

        let result = engine.nt_xent_loss(&a, &b);
        assert!(result.loss.is_finite());
        assert_eq!(result.batch_size, 2);
    }

    // ─── Negative mining ─────────────────────

    #[test]
    fn test_mine_semi_hard() {
        let engine = ContrastiveLossEngine::new(ContrastiveConfig {
            margin: 2.0,
            ..Default::default()
        });
        let anchor = vec![0.0, 0.0];
        let positive = vec![1.0, 0.0]; // dist = 1.0
        let pool = vec![
            vec![0.5, 0.0],  // dist = 0.5 (too close, not semi-hard)
            vec![1.5, 0.0],  // dist = 1.5 (semi-hard: > 1.0 and < 3.0)
            vec![2.5, 0.0],  // dist = 2.5 (semi-hard)
            vec![10.0, 0.0], // dist = 10.0 (too far)
        ];

        let indices = engine.mine_semi_hard(&anchor, &positive, &pool);
        assert!(indices.contains(&1));
        assert!(indices.contains(&2));
    }

    #[test]
    fn test_mine_hardest() {
        let engine = ContrastiveLossEngine::with_defaults();
        let anchor = vec![0.0, 0.0];
        let pool = vec![
            vec![10.0, 0.0], // dist = 10
            vec![2.0, 0.0],  // dist = 2
            vec![5.0, 0.0],  // dist = 5
        ];

        let idx = engine.mine_hardest(&anchor, &pool);
        assert_eq!(idx, Some(1)); // Closest
    }

    #[test]
    fn test_mine_hardest_empty() {
        let engine = ContrastiveLossEngine::with_defaults();
        let anchor = vec![0.0, 0.0];
        assert!(engine.mine_hardest(&anchor, &[]).is_none());
    }

    // ─── Statistics ──────────────────────────

    #[test]
    fn test_stats_tracking() {
        let mut engine = ContrastiveLossEngine::with_defaults();
        let a = vec![sample_vector(1.0, 4)];
        let p = vec![sample_vector(1.1, 4)];
        let n = vec![sample_vector(5.0, 4)];

        engine.info_nce_loss(&a, &p, &n);
        engine.info_nce_loss(&a, &p, &n);

        assert_eq!(engine.stats().batches_processed, 2);
        assert_eq!(engine.stats().total_samples, 2);
    }

    #[test]
    fn test_stats_reset() {
        let mut engine = ContrastiveLossEngine::with_defaults();
        let a = vec![sample_vector(1.0, 4)];
        let p = vec![sample_vector(1.1, 4)];
        let n = vec![sample_vector(5.0, 4)];
        engine.info_nce_loss(&a, &p, &n);

        engine.reset_stats();
        assert_eq!(engine.stats().batches_processed, 0);
    }

    // ─── Edge cases ──────────────────────────

    #[test]
    fn test_dot_product_mode() {
        let mut engine = ContrastiveLossEngine::new(ContrastiveConfig {
            use_cosine: false,
            ..Default::default()
        });
        let a = vec![vec![1.0, 0.0]];
        let p = vec![vec![0.9, 0.1]];
        let n = vec![vec![0.0, 1.0]];

        let result = engine.info_nce_loss(&a, &p, &n);
        assert!(result.loss.is_finite());
    }

    #[test]
    fn test_empty_batch() {
        let mut engine = ContrastiveLossEngine::with_defaults();
        let result = engine.info_nce_loss(&[], &[], &[]);
        assert_eq!(result.batch_size, 0);
        assert!((result.loss).abs() < 1e-10);
    }

    #[test]
    fn test_triplet_empty_batch() {
        let mut engine = ContrastiveLossEngine::with_defaults();
        let result = engine.triplet_loss(&[], &[], &[]);
        assert_eq!(result.batch_size, 0);
    }

    #[test]
    fn test_nt_xent_single_sample() {
        // With N = 1 there are no in-batch negatives; loss should still be finite.
        let mut engine = ContrastiveLossEngine::with_defaults();
        let a = vec![sample_vector(1.0, 4)];
        let b = vec![sample_vector(1.1, 4)];
        let result = engine.nt_xent_loss(&a, &b);
        assert!(result.loss.is_finite());
    }

    // ─── Serialization ───────────────────────

    #[test]
    fn test_config_serialization() {
        let config = ContrastiveConfig::default();
        let json = serde_json::to_string(&config).expect("serialize failed");
        let deser: ContrastiveConfig = serde_json::from_str(&json).expect("deser failed");
        assert!((deser.temperature - config.temperature).abs() < 1e-10);
    }

    #[test]
    fn test_result_serialization() {
        let result = ContrastiveLossResult {
            loss: 0.5,
            per_sample_losses: vec![0.5],
            avg_positive_similarity: 0.8,
            avg_negative_similarity: 0.2,
            batch_size: 1,
            hard_negatives_count: 0,
        };
        let json = serde_json::to_string(&result).expect("serialize failed");
        assert!(json.contains("loss"));
    }

    #[test]
    fn test_stats_serialization() {
        let stats = ContrastiveTrainingStats::default();
        let json = serde_json::to_string(&stats).expect("serialize failed");
        assert!(json.contains("batches_processed"));
    }

    #[test]
    fn test_mining_strategy_serde() {
        let s = NegativeMiningStrategy::SemiHard;
        let json = serde_json::to_string(&s).expect("serialize failed");
        let deser: NegativeMiningStrategy = serde_json::from_str(&json).expect("deser failed");
        assert_eq!(deser, s);
    }

    // ─── Larger scenarios ────────────────────

    #[test]
    fn test_large_batch() {
        let mut engine = ContrastiveLossEngine::with_defaults();
        let dim = 32;
        let batch: Vec<Vec<f64>> = (0..16).map(|i| sample_vector(i as f64, dim)).collect();
        let pos: Vec<Vec<f64>> = (0..16)
            .map(|i| sample_vector(i as f64 + 0.01, dim))
            .collect();
        let neg: Vec<Vec<f64>> = (0..8)
            .map(|i| sample_vector(i as f64 + 100.0, dim))
            .collect();

        let result = engine.info_nce_loss(&batch, &pos, &neg);
        assert_eq!(result.batch_size, 16);
        assert!(result.loss.is_finite());
    }

    #[test]
    fn test_hard_negatives_count() {
        let mut engine = ContrastiveLossEngine::new(ContrastiveConfig {
            margin: 0.5,
            ..Default::default()
        });
        let anchor = vec![1.0, 0.0, 0.0, 0.0];
        let positive = vec![0.9, 0.1, 0.0, 0.0];
        // Negatives very close to anchor — should be counted as hard
        let negatives = vec![vec![0.95, 0.05, 0.0, 0.0]];

        let result = engine.info_nce_loss(&[anchor], &[positive], &negatives);
        // hard_negatives_count may or may not be 1 depending on similarity vs margin
        assert!(result.hard_negatives_count <= 1);
    }

    #[test]
    fn test_min_max_loss_tracking() {
        let mut engine = ContrastiveLossEngine::with_defaults();
        let a1 = vec![sample_vector(1.0, 4)];
        let p1 = vec![sample_vector(1.1, 4)];
        let n1 = vec![sample_vector(5.0, 4)];
        engine.info_nce_loss(&a1, &p1, &n1);

        let a2 = vec![sample_vector(1.0, 4)];
        let p2 = vec![sample_vector(100.0, 4)]; // Very different "positive"
        let n2 = vec![sample_vector(1.01, 4)]; // Very close negative
        engine.info_nce_loss(&a2, &p2, &n2);

        assert!(engine.stats().min_loss <= engine.stats().max_loss);
    }
}