// oxirs_embed/fine_tuner.rs
//! Embedding fine-tuning with contrastive learning (pure Rust, CPU-only).
//!
//! Supports three loss functions for adapting pre-trained embeddings:
//! - **Triplet loss** – margin-based distance metric learning
//! - **Contrastive loss** – similarity-aware pair learning
//! - **Cosine similarity loss** – MSE against a target cosine similarity

// ─────────────────────────────────────────────────────────────────────────────
// Types
// ─────────────────────────────────────────────────────────────────────────────

/// One training example: an anchor/positive pair, optionally augmented with a
/// negative sample (the negative is mandatory only for triplet loss).
#[derive(Debug, Clone)]
pub struct EmbeddingPair {
    /// Anchor embedding
    pub anchor: Vec<f32>,
    /// Positive (similar) embedding
    pub positive: Vec<f32>,
    /// Optional negative (dissimilar) embedding — required for triplet loss
    pub negative: Option<Vec<f32>>,
}

impl EmbeddingPair {
    /// Build a full triplet: anchor, positive, and negative samples.
    pub fn with_negative(anchor: Vec<f32>, positive: Vec<f32>, negative: Vec<f32>) -> Self {
        let negative = Some(negative);
        Self {
            anchor,
            positive,
            negative,
        }
    }

    /// Build an anchor/positive pair with no negative sample.
    pub fn without_negative(anchor: Vec<f32>, positive: Vec<f32>) -> Self {
        Self {
            anchor,
            positive,
            negative: None,
        }
    }
}
42
/// Configuration for the triplet margin loss.
///
/// NOTE(review): `FineTuner` currently hard-codes its margin and does not read
/// this struct — confirm whether it should be wired into `FinetuneConfig`.
#[derive(Debug, Clone)]
pub struct TripletLoss {
    /// Minimum margin between positive and negative distances
    pub margin: f32,
}
49
/// Configuration for the contrastive loss.
///
/// NOTE(review): `FineTuner` currently hard-codes its margin and does not read
/// this struct — confirm whether it should be wired into `FinetuneConfig`.
#[derive(Debug, Clone)]
pub struct ContrastiveLoss {
    /// Margin applied to dissimilar pairs
    pub margin: f32,
}
56
/// Selects the loss function applied during fine-tuning.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LossType {
    /// `max(0, d(a,p) - d(a,n) + margin)`
    Triplet,
    /// `label=1`: `d²`; `label=0`: `max(0, margin-d)²`
    Contrastive,
    /// MSE between `cosine_sim(a,b)` and a target value
    CosineSimilarity,
}
67
68/// Configuration for a fine-tuning run.
69#[derive(Debug, Clone)]
70pub struct FinetuneConfig {
71    pub learning_rate: f32,
72    pub epochs: usize,
73    pub batch_size: usize,
74    pub loss_type: LossType,
75}
76
77impl Default for FinetuneConfig {
78    fn default() -> Self {
79        Self {
80            learning_rate: 1e-3,
81            epochs: 10,
82            batch_size: 32,
83            loss_type: LossType::Triplet,
84        }
85    }
86}
87
/// A record of a single gradient-update step.
#[derive(Debug, Clone)]
pub struct TrainingStep {
    /// Zero-based epoch this step belongs to.
    pub epoch: usize,
    /// Global step index (its position in the history).
    pub step: usize,
    /// Mean loss observed at this step.
    pub loss: f32,
}
95
// ─────────────────────────────────────────────────────────────────────────────
// FineTuner
// ─────────────────────────────────────────────────────────────────────────────

100/// Embedding fine-tuner using contrastive learning losses.
101pub struct FineTuner {
102    config: FinetuneConfig,
103    history: Vec<TrainingStep>,
104}
105
106impl FineTuner {
107    /// Create a fine-tuner with the given configuration.
108    pub fn new(config: FinetuneConfig) -> Self {
109        Self {
110            config,
111            history: Vec::new(),
112        }
113    }
114
115    // ── Loss computations ──────────────────────────────────────────────────
116
117    /// Compute triplet margin loss.
118    ///
119    /// `loss = max(0, d(a,p) - d(a,n) + margin)`
120    pub fn compute_triplet_loss(&self, anchor: &[f32], positive: &[f32], negative: &[f32]) -> f32 {
121        let d_pos = euclidean_distance(anchor, positive);
122        let d_neg = euclidean_distance(anchor, negative);
123        let margin = match &self.config.loss_type {
124            LossType::Triplet => 1.0_f32, // default margin
125            _ => 1.0_f32,
126        };
127        (d_pos - d_neg + margin).max(0.0)
128    }
129
130    /// Compute contrastive loss for a pair.
131    ///
132    /// - `label = 1.0` (similar): loss = d²
133    /// - `label = 0.0` (dissimilar): loss = max(0, margin − d)²
134    pub fn compute_contrastive_loss(&self, a: &[f32], b: &[f32], label: f32) -> f32 {
135        let d = euclidean_distance(a, b);
136        let margin = 1.0_f32;
137        if label >= 0.5 {
138            d * d
139        } else {
140            (margin - d).max(0.0).powi(2)
141        }
142    }
143
144    /// Compute cosine similarity loss: MSE between cosine_sim(a,b) and `target`.
145    pub fn compute_cosine_loss(&self, a: &[f32], b: &[f32], target: f32) -> f32 {
146        let sim = cosine_similarity(a, b);
147        let diff = sim - target;
148        diff * diff
149    }
150
151    // ── Training ───────────────────────────────────────────────────────────
152
153    /// Simulate one gradient step over the given pairs and return the mean loss.
154    pub fn step(&mut self, pairs: &[EmbeddingPair]) -> f32 {
155        if pairs.is_empty() {
156            return 0.0;
157        }
158        let total_loss: f32 = pairs.iter().map(|p| self.pair_loss(p)).sum();
159        let mean_loss = total_loss / pairs.len() as f32;
160
161        let epoch = if self.history.is_empty() {
162            0
163        } else {
164            self.history.last().map(|s| s.epoch).unwrap_or(0)
165        };
166        let step = self.history.len();
167
168        self.history.push(TrainingStep {
169            epoch,
170            step,
171            loss: mean_loss,
172        });
173
174        mean_loss
175    }
176
177    /// Run full training for `config.epochs` epochs and return the mean loss per epoch.
178    pub fn train(&mut self, pairs: &[EmbeddingPair]) -> Vec<f32> {
179        let epochs = self.config.epochs;
180        let mut epoch_losses = Vec::with_capacity(epochs);
181
182        for epoch in 0..epochs {
183            if pairs.is_empty() {
184                epoch_losses.push(0.0);
185                continue;
186            }
187            let total_loss: f32 = pairs.iter().map(|p| self.pair_loss(p)).sum();
188            let mean_loss = total_loss / pairs.len() as f32;
189
190            let step = self.history.len();
191            self.history.push(TrainingStep {
192                epoch,
193                step,
194                loss: mean_loss,
195            });
196
197            epoch_losses.push(mean_loss);
198        }
199
200        epoch_losses
201    }
202
203    /// Access the full training history.
204    pub fn training_history(&self) -> &[TrainingStep] {
205        &self.history
206    }
207
208    /// Total number of gradient steps recorded.
209    pub fn total_steps(&self) -> usize {
210        self.history.len()
211    }
212
213    // ── private ────────────────────────────────────────────────────────────
214
215    fn pair_loss(&self, pair: &EmbeddingPair) -> f32 {
216        match self.config.loss_type {
217            LossType::Triplet => {
218                if let Some(neg) = &pair.negative {
219                    self.compute_triplet_loss(&pair.anchor, &pair.positive, neg)
220                } else {
221                    0.0
222                }
223            }
224            LossType::Contrastive => {
225                // Treat the pair as similar (label=1); use negative if available
226                if let Some(neg) = &pair.negative {
227                    // Average of similar and dissimilar
228                    let l_sim = self.compute_contrastive_loss(&pair.anchor, &pair.positive, 1.0);
229                    let l_dis = self.compute_contrastive_loss(&pair.anchor, neg, 0.0);
230                    (l_sim + l_dis) / 2.0
231                } else {
232                    self.compute_contrastive_loss(&pair.anchor, &pair.positive, 1.0)
233                }
234            }
235            LossType::CosineSimilarity => {
236                self.compute_cosine_loss(&pair.anchor, &pair.positive, 1.0)
237            }
238        }
239    }
240}
241
// ─────────────────────────────────────────────────────────────────────────────
// Free functions
// ─────────────────────────────────────────────────────────────────────────────

/// Compute cosine similarity between two vectors.
///
/// Returns 0.0 if either vector has zero norm; otherwise the ratio is clamped
/// into `[-1.0, 1.0]` to absorb floating-point noise.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let dot = a.iter().zip(b).fold(0.0_f32, |acc, (x, y)| acc + x * y);
    let norm_a = a.iter().fold(0.0_f32, |acc, x| acc + x * x).sqrt();
    let norm_b = b.iter().fold(0.0_f32, |acc, x| acc + x * x).sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
}
259
/// Compute the L2 (Euclidean) norm of a vector.
pub fn l2_norm(v: &[f32]) -> f32 {
    let sum_sq = v.iter().fold(0.0_f32, |acc, x| acc + x * x);
    sum_sq.sqrt()
}

/// Return a unit-length copy of `v`.  A zero-norm input is returned unchanged
/// (i.e. as a zero vector).
pub fn l2_normalize(v: &[f32]) -> Vec<f32> {
    match l2_norm(v) {
        n if n == 0.0 => v.to_vec(),
        n => v.iter().map(|x| x / n).collect(),
    }
}
274
/// Euclidean distance between two equal-length vectors.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let mut sum_sq = 0.0_f32;
    for (x, y) in a.iter().zip(b) {
        let diff = x - y;
        sum_sq += diff * diff;
    }
    sum_sq.sqrt()
}
283
// ─────────────────────────────────────────────────────────────────────────────
// Tests
// ─────────────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons.
    const EPS: f32 = 1e-5;

    // ── fixture helpers ────────────────────────────────────────────────────

    fn ones(dim: usize) -> Vec<f32> {
        vec![1.0; dim]
    }
    fn zeros(dim: usize) -> Vec<f32> {
        vec![0.0; dim]
    }
    fn unit_x() -> Vec<f32> {
        vec![1.0, 0.0, 0.0]
    }
    fn unit_y() -> Vec<f32> {
        vec![0.0, 1.0, 0.0]
    }

    // Tuner factories — one per LossType, all other config at defaults.
    fn triplet_tuner() -> FineTuner {
        FineTuner::new(FinetuneConfig {
            loss_type: LossType::Triplet,
            ..Default::default()
        })
    }
    fn contrastive_tuner() -> FineTuner {
        FineTuner::new(FinetuneConfig {
            loss_type: LossType::Contrastive,
            ..Default::default()
        })
    }
    fn cosine_tuner() -> FineTuner {
        FineTuner::new(FinetuneConfig {
            loss_type: LossType::CosineSimilarity,
            ..Default::default()
        })
    }

    // 1. Triplet loss: anchor = positive → d(a,p)=0 → loss = max(0, -d_neg + 1)
    #[test]
    fn test_triplet_loss_same_anchor_positive() {
        let tuner = triplet_tuner();
        let a = vec![1.0, 0.0];
        let neg = vec![10.0, 0.0];
        let loss = tuner.compute_triplet_loss(&a, &a, &neg);
        // d(a,p)=0, d(a,neg)=9 → max(0, 0-9+1)=0
        assert!(loss.abs() < EPS);
    }

    // 2. Triplet loss: positive equals negative → loss >= margin
    #[test]
    fn test_triplet_loss_positive_equals_negative() {
        let tuner = triplet_tuner();
        let a = unit_x();
        let p = unit_y();
        // negative same as positive → d(a,p)==d(a,neg) → loss = margin = 1
        let loss = tuner.compute_triplet_loss(&a, &p, &p);
        assert!((loss - 1.0).abs() < EPS);
    }

    // 3. Triplet loss: margin enforced — negative very far away
    #[test]
    fn test_triplet_loss_negative_far_gives_zero() {
        let tuner = triplet_tuner();
        let a = vec![0.0, 0.0];
        let p = vec![0.1, 0.0];
        let n = vec![100.0, 0.0];
        let loss = tuner.compute_triplet_loss(&a, &p, &n);
        assert!(loss < EPS); // d(a,p) << d(a,n) so loss=0
    }

    // 4. Triplet loss is non-negative
    #[test]
    fn test_triplet_loss_non_negative() {
        let tuner = triplet_tuner();
        let a = vec![1.0, 2.0];
        let p = vec![1.1, 2.1];
        let n = vec![0.5, 0.5];
        let loss = tuner.compute_triplet_loss(&a, &p, &n);
        assert!(loss >= 0.0);
    }

    // 5. Triplet loss where the margin is exactly cancelled:
    //    d(a,n) - d(a,p) == margin → loss clamps to 0.
    //    (The margin is hard-coded to 1.0 inside compute_triplet_loss.)
    #[test]
    fn test_zero_margin_triplet_direct() {
        let tuner = triplet_tuner();
        let a = vec![0.0];
        let p = vec![1.0];
        let n = vec![2.0];
        let loss = tuner.compute_triplet_loss(&a, &p, &n);
        // d(a,p)=1, d(a,n)=2, margin=1 → max(0, 1-2+1)=0
        assert!(loss.abs() < EPS);
    }

    // 6. Contrastive loss similar pair (label=1): loss = d²
    #[test]
    fn test_contrastive_similar_pair() {
        let tuner = contrastive_tuner();
        let a = vec![0.0];
        let b = vec![0.5];
        let loss = tuner.compute_contrastive_loss(&a, &b, 1.0);
        // d=0.5, loss = 0.25
        assert!((loss - 0.25).abs() < EPS);
    }

    // 7. Contrastive loss dissimilar pair (label=0): loss = max(0, margin-d)²
    #[test]
    fn test_contrastive_dissimilar_pair() {
        let tuner = contrastive_tuner();
        let a = vec![0.0];
        let b = vec![1.5]; // d=1.5 > margin=1 → loss=0
        let loss = tuner.compute_contrastive_loss(&a, &b, 0.0);
        assert!(loss.abs() < EPS);
    }

    // 8. Contrastive loss dissimilar pair close together
    #[test]
    fn test_contrastive_dissimilar_close() {
        let tuner = contrastive_tuner();
        let a = vec![0.0];
        let b = vec![0.5]; // d=0.5, margin=1 → loss=(1-0.5)²=0.25
        let loss = tuner.compute_contrastive_loss(&a, &b, 0.0);
        assert!((loss - 0.25).abs() < EPS);
    }

    // 9. Contrastive loss identical vectors (similar): loss = 0
    #[test]
    fn test_contrastive_identical_similar() {
        let tuner = contrastive_tuner();
        let a = vec![1.0, 2.0];
        let loss = tuner.compute_contrastive_loss(&a, &a, 1.0);
        assert!(loss.abs() < EPS);
    }

    // 10. Cosine loss: identical vectors → sim=1 → loss=(1-target)²
    #[test]
    fn test_cosine_loss_identical() {
        let tuner = cosine_tuner();
        let a = unit_x();
        let loss = tuner.compute_cosine_loss(&a, &a, 1.0);
        assert!(loss.abs() < EPS);
    }

    // 11. Cosine loss: orthogonal vectors → sim=0 → loss=target²
    #[test]
    fn test_cosine_loss_orthogonal() {
        let tuner = cosine_tuner();
        let a = unit_x();
        let b = unit_y();
        let loss = tuner.compute_cosine_loss(&a, &b, 0.0);
        assert!(loss.abs() < EPS);
    }

    // 12. Cosine loss: opposite vectors → sim=-1
    #[test]
    fn test_cosine_loss_opposite() {
        let tuner = cosine_tuner();
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let loss = tuner.compute_cosine_loss(&a, &b, -1.0);
        assert!(loss.abs() < EPS);
    }

    // 13. train returns one loss per epoch
    #[test]
    fn test_train_returns_one_loss_per_epoch() {
        let mut tuner = FineTuner::new(FinetuneConfig {
            epochs: 5,
            loss_type: LossType::Triplet,
            ..Default::default()
        });
        let pairs = vec![EmbeddingPair::with_negative(
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
        )];
        let losses = tuner.train(&pairs);
        assert_eq!(losses.len(), 5);
    }

    // 14. step increments total_steps
    #[test]
    fn test_step_increments_total_steps() {
        let mut tuner = triplet_tuner();
        let pairs = vec![EmbeddingPair::with_negative(
            vec![0.0],
            vec![1.0],
            vec![2.0],
        )];
        tuner.step(&pairs);
        assert_eq!(tuner.total_steps(), 1);
        tuner.step(&pairs);
        assert_eq!(tuner.total_steps(), 2);
    }

    // 15. history grows with step calls
    #[test]
    fn test_history_grows() {
        let mut tuner = triplet_tuner();
        let pairs = vec![EmbeddingPair::without_negative(vec![0.0], vec![1.0])];
        for _ in 0..7 {
            tuner.step(&pairs);
        }
        assert_eq!(tuner.training_history().len(), 7);
    }

    // 16. train appends to history (one record per epoch)
    #[test]
    fn test_train_appends_to_history() {
        let mut tuner = FineTuner::new(FinetuneConfig {
            epochs: 3,
            ..Default::default()
        });
        let pairs = vec![EmbeddingPair::with_negative(ones(4), ones(4), zeros(4))];
        tuner.train(&pairs);
        assert_eq!(tuner.training_history().len(), 3);
    }

    // 17. empty pairs: step returns 0
    #[test]
    fn test_step_empty_pairs() {
        let mut tuner = triplet_tuner();
        let loss = tuner.step(&[]);
        assert_eq!(loss, 0.0);
    }

    // 18. train with empty pairs returns zeros
    #[test]
    fn test_train_empty_pairs() {
        let mut tuner = FineTuner::new(FinetuneConfig {
            epochs: 3,
            ..Default::default()
        });
        let losses = tuner.train(&[]);
        assert_eq!(losses.len(), 3);
        assert!(losses.iter().all(|&l| l == 0.0));
    }

    // 19. cosine_similarity identical unit vectors = 1.0
    #[test]
    fn test_cosine_similarity_identical() {
        let a = unit_x();
        let sim = cosine_similarity(&a, &a);
        assert!((sim - 1.0).abs() < EPS);
    }

    // 20. cosine_similarity orthogonal = 0.0
    #[test]
    fn test_cosine_similarity_orthogonal() {
        let sim = cosine_similarity(&unit_x(), &unit_y());
        assert!(sim.abs() < EPS);
    }

    // 21. cosine_similarity antiparallel = -1.0
    #[test]
    fn test_cosine_similarity_antiparallel() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim + 1.0).abs() < EPS);
    }

    // 22. cosine_similarity zero vector returns 0 (documented zero-norm guard)
    #[test]
    fn test_cosine_similarity_zero_vector() {
        let a = vec![1.0, 0.0];
        let b = zeros(2);
        let sim = cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }

    // 23. l2_normalize scales a 3-4-5 vector to (0.6, 0.8)
    #[test]
    fn test_l2_normalize_unit_vector() {
        let v = vec![3.0, 4.0];
        let n = l2_normalize(&v);
        assert!((n[0] - 0.6).abs() < EPS);
        assert!((n[1] - 0.8).abs() < EPS);
    }

    // 24. l2_normalize already normalized vector
    #[test]
    fn test_l2_normalize_already_unit() {
        let v = unit_x();
        let n = l2_normalize(&v);
        assert!((l2_norm(&n) - 1.0).abs() < EPS);
    }

    // 25. l2_normalize zero vector returns zero
    #[test]
    fn test_l2_normalize_zero_vector() {
        let v = zeros(3);
        let n = l2_normalize(&v);
        assert_eq!(n, zeros(3));
    }

    // 26. Normalized vectors: cosine_similarity = 1 for identical
    #[test]
    fn test_normalized_cosine_similarity() {
        let v = vec![3.0, 4.0];
        let n = l2_normalize(&v);
        let sim = cosine_similarity(&n, &n);
        assert!((sim - 1.0).abs() < EPS);
    }

    // 27. TrainingStep epoch recorded
    #[test]
    fn test_training_step_epoch() {
        let mut tuner = FineTuner::new(FinetuneConfig {
            epochs: 1,
            ..Default::default()
        });
        let pairs = vec![EmbeddingPair::with_negative(ones(2), ones(2), zeros(2))];
        tuner.train(&pairs);
        assert_eq!(tuner.training_history()[0].epoch, 0);
    }

    // 28. TrainingStep step index recorded (monotonic from 0)
    #[test]
    fn test_training_step_step_index() {
        let mut tuner = FineTuner::new(FinetuneConfig {
            epochs: 3,
            ..Default::default()
        });
        let pairs = vec![EmbeddingPair::without_negative(ones(2), zeros(2))];
        tuner.train(&pairs);
        let steps: Vec<usize> = tuner.training_history().iter().map(|s| s.step).collect();
        assert_eq!(steps, vec![0, 1, 2]);
    }

    // 29. Contrastive loss LossType in train
    #[test]
    fn test_contrastive_loss_train() {
        let mut tuner = FineTuner::new(FinetuneConfig {
            epochs: 2,
            loss_type: LossType::Contrastive,
            ..Default::default()
        });
        let pairs = vec![EmbeddingPair::with_negative(
            vec![0.0, 0.0],
            vec![0.1, 0.0],
            vec![2.0, 0.0],
        )];
        let losses = tuner.train(&pairs);
        assert_eq!(losses.len(), 2);
        assert!(losses.iter().all(|&l| l >= 0.0));
    }

    // 30. CosineSimilarity LossType in train
    #[test]
    fn test_cosine_loss_train() {
        let mut tuner = FineTuner::new(FinetuneConfig {
            epochs: 2,
            loss_type: LossType::CosineSimilarity,
            ..Default::default()
        });
        let pairs = vec![EmbeddingPair::without_negative(unit_x(), unit_y())];
        let losses = tuner.train(&pairs);
        assert!(losses.iter().all(|&l| l >= 0.0));
    }

    // 31. step returns a non-negative loss for non-trivial pairs
    #[test]
    fn test_step_positive_loss() {
        let mut tuner = triplet_tuner();
        let pairs = vec![EmbeddingPair::with_negative(
            vec![0.0, 0.0],
            vec![0.5, 0.0],
            vec![0.1, 0.0], // negative closer than positive — loss is non-zero here
        )];
        let loss = tuner.step(&pairs);
        assert!(loss >= 0.0);
    }

    // 32. total_steps after train
    #[test]
    fn test_total_steps_after_train() {
        let mut tuner = FineTuner::new(FinetuneConfig {
            epochs: 4,
            ..Default::default()
        });
        let pairs = vec![EmbeddingPair::without_negative(ones(2), zeros(2))];
        tuner.train(&pairs);
        assert_eq!(tuner.total_steps(), 4);
    }

    // 33. step and train calls accumulate in the same history
    #[test]
    fn test_step_plus_train_accumulate() {
        let mut tuner = FineTuner::new(FinetuneConfig {
            epochs: 3,
            ..Default::default()
        });
        let pairs = vec![EmbeddingPair::without_negative(ones(2), zeros(2))];
        tuner.step(&pairs);
        tuner.train(&pairs);
        // 1 step + 3 train = 4 total
        assert_eq!(tuner.total_steps(), 4);
    }

    // 34. FinetuneConfig default
    #[test]
    fn test_finetune_config_default() {
        let cfg = FinetuneConfig::default();
        assert_eq!(cfg.epochs, 10);
        assert_eq!(cfg.loss_type, LossType::Triplet);
    }

    // 35. EmbeddingPair with_negative stores negative
    #[test]
    fn test_embedding_pair_with_negative() {
        let p = EmbeddingPair::with_negative(vec![1.0], vec![2.0], vec![3.0]);
        assert!(p.negative.is_some());
    }

    // 36. EmbeddingPair without_negative has None
    #[test]
    fn test_embedding_pair_without_negative() {
        let p = EmbeddingPair::without_negative(vec![1.0], vec![2.0]);
        assert!(p.negative.is_none());
    }

    // 37. Triplet pair without negative gives zero loss
    #[test]
    fn test_triplet_pair_no_negative_zero_loss() {
        let mut tuner = triplet_tuner();
        let pairs = vec![EmbeddingPair::without_negative(ones(4), zeros(4))];
        let loss = tuner.step(&pairs);
        assert_eq!(loss, 0.0);
    }

    // 38. cosine_similarity clamped to [-1,1]
    #[test]
    fn test_cosine_similarity_clamped() {
        // Even with floating-point noise, result should be in [-1, 1]
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 1e-7, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!((-1.0..=1.0).contains(&sim));
    }

    // 39. l2_norm of a 3-4-5 triangle
    #[test]
    fn test_l2_norm_345() {
        let v = vec![3.0, 4.0];
        assert!((l2_norm(&v) - 5.0).abs() < EPS);
    }

    // 40. Loss is recorded in history for every train epoch
    #[test]
    fn test_loss_recorded_in_history() {
        let mut tuner = FineTuner::new(FinetuneConfig {
            epochs: 5,
            loss_type: LossType::CosineSimilarity,
            ..Default::default()
        });
        let pairs = vec![EmbeddingPair::without_negative(unit_x(), unit_y())];
        tuner.train(&pairs);
        assert!(tuner.training_history().iter().all(|s| s.loss >= 0.0));
    }

    // 41. Different LossTypes both produce valid (non-negative) losses
    #[test]
    fn test_loss_types_differ() {
        let pairs = vec![EmbeddingPair::with_negative(
            vec![0.0, 0.0],
            vec![0.5, 0.0],
            vec![0.2, 0.0],
        )];
        let mut t1 = FineTuner::new(FinetuneConfig {
            epochs: 1,
            loss_type: LossType::Triplet,
            ..Default::default()
        });
        let mut t2 = FineTuner::new(FinetuneConfig {
            epochs: 1,
            loss_type: LossType::Contrastive,
            ..Default::default()
        });
        let l1 = t1.step(&pairs);
        let l2 = t2.step(&pairs);
        // They may or may not be equal; just verify both are non-negative
        assert!(l1 >= 0.0);
        assert!(l2 >= 0.0);
    }

    // 42. High-dimensional embeddings work
    #[test]
    fn test_high_dimensional_embeddings() {
        let dim = 768;
        let anchor: Vec<f32> = (0..dim).map(|i| i as f32 / dim as f32).collect();
        let positive: Vec<f32> = (0..dim).map(|i| (i as f32 + 1.0) / dim as f32).collect();
        let negative: Vec<f32> = vec![-1.0; dim];
        let tuner = triplet_tuner();
        let loss = tuner.compute_triplet_loss(&anchor, &positive, &negative);
        assert!(loss >= 0.0);
    }

    // 43. cosine_loss is zero when sim matches target exactly
    #[test]
    fn test_cosine_loss_zero_when_exact() {
        let tuner = cosine_tuner();
        let a = unit_x();
        let b = unit_x();
        let sim = cosine_similarity(&a, &b); // should be 1.0
        let loss = tuner.compute_cosine_loss(&a, &b, sim);
        assert!(loss.abs() < EPS);
    }

    // 44. train with large batch size config (batch_size does not affect epochs)
    #[test]
    fn test_train_large_batch_size() {
        let mut tuner = FineTuner::new(FinetuneConfig {
            batch_size: 512,
            epochs: 2,
            ..Default::default()
        });
        let pairs: Vec<_> = (0..100)
            .map(|_| EmbeddingPair::without_negative(ones(16), zeros(16)))
            .collect();
        let losses = tuner.train(&pairs);
        assert_eq!(losses.len(), 2);
    }

    // 45. total_steps 0 initially
    #[test]
    fn test_total_steps_initially_zero() {
        let tuner = triplet_tuner();
        assert_eq!(tuner.total_steps(), 0);
    }
}