/// One training example for embedding fine-tuning.
///
/// `anchor` and `positive` are meant to be close in embedding space. The optional
/// `negative` is meant to be far from `anchor`: the triplet loss requires it, and the
/// contrastive loss uses it only when present.
#[derive(Debug, Clone, PartialEq)]
pub struct EmbeddingPair {
    /// Reference embedding.
    pub anchor: Vec<f32>,
    /// Embedding that should lie close to `anchor`.
    pub positive: Vec<f32>,
    /// Optional embedding that should lie far from `anchor`.
    pub negative: Option<Vec<f32>>,
}
22
23impl EmbeddingPair {
24 pub fn with_negative(anchor: Vec<f32>, positive: Vec<f32>, negative: Vec<f32>) -> Self {
26 Self {
27 anchor,
28 positive,
29 negative: Some(negative),
30 }
31 }
32
33 pub fn without_negative(anchor: Vec<f32>, positive: Vec<f32>) -> Self {
35 Self {
36 anchor,
37 positive,
38 negative: None,
39 }
40 }
41}
42
/// Parameters for a triplet margin loss.
///
/// `margin` is the minimum gap wanted between the anchor–negative and anchor–positive
/// distances. `FineTuner` does not read this struct; it applies a fixed margin of `1.0`,
/// which is also this type's `Default`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TripletLoss {
    pub margin: f32,
}

impl Default for TripletLoss {
    /// A margin of `1.0`, matching the margin `FineTuner` uses for triplet loss.
    fn default() -> Self {
        Self { margin: 1.0 }
    }
}
49
/// Parameters for a contrastive loss.
///
/// `margin` is the distance beyond which a dissimilar pair stops contributing loss.
/// `FineTuner` does not read this struct; it applies a fixed margin of `1.0`, which is
/// also this type's `Default`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ContrastiveLoss {
    pub margin: f32,
}

impl Default for ContrastiveLoss {
    /// A margin of `1.0`, matching the margin `FineTuner` uses for contrastive loss.
    fn default() -> Self {
        Self { margin: 1.0 }
    }
}
56
/// Selects which objective `FineTuner` evaluates for each `EmbeddingPair`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LossType {
    /// Triplet margin loss on Euclidean distances. Needs a negative example; pairs
    /// without one contribute zero loss.
    Triplet,
    /// Contrastive loss on Euclidean distances. Anchor/positive is treated as a similar
    /// pair and anchor/negative (when present) as a dissimilar one.
    Contrastive,
    /// Squared error between the anchor/positive cosine similarity and a target of `1.0`.
    CosineSimilarity,
}
67
68#[derive(Debug, Clone)]
70pub struct FinetuneConfig {
71 pub learning_rate: f32,
72 pub epochs: usize,
73 pub batch_size: usize,
74 pub loss_type: LossType,
75}
76
77impl Default for FinetuneConfig {
78 fn default() -> Self {
79 Self {
80 learning_rate: 1e-3,
81 epochs: 10,
82 batch_size: 32,
83 loss_type: LossType::Triplet,
84 }
85 }
86}
87
/// One entry in a `FineTuner`'s history: the mean loss recorded by a single evaluation.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TrainingStep {
    /// Epoch index the entry was recorded under.
    pub epoch: usize,
    /// Zero-based position of the entry in the history.
    pub step: usize,
    /// Mean loss over the pairs evaluated for this entry.
    pub loss: f32,
}
95
96pub struct FineTuner {
102 config: FinetuneConfig,
103 history: Vec<TrainingStep>,
104}
105
106impl FineTuner {
107 pub fn new(config: FinetuneConfig) -> Self {
109 Self {
110 config,
111 history: Vec::new(),
112 }
113 }
114
115 pub fn compute_triplet_loss(&self, anchor: &[f32], positive: &[f32], negative: &[f32]) -> f32 {
121 let d_pos = euclidean_distance(anchor, positive);
122 let d_neg = euclidean_distance(anchor, negative);
123 let margin = match &self.config.loss_type {
124 LossType::Triplet => 1.0_f32, _ => 1.0_f32,
126 };
127 (d_pos - d_neg + margin).max(0.0)
128 }
129
130 pub fn compute_contrastive_loss(&self, a: &[f32], b: &[f32], label: f32) -> f32 {
135 let d = euclidean_distance(a, b);
136 let margin = 1.0_f32;
137 if label >= 0.5 {
138 d * d
139 } else {
140 (margin - d).max(0.0).powi(2)
141 }
142 }
143
144 pub fn compute_cosine_loss(&self, a: &[f32], b: &[f32], target: f32) -> f32 {
146 let sim = cosine_similarity(a, b);
147 let diff = sim - target;
148 diff * diff
149 }
150
151 pub fn step(&mut self, pairs: &[EmbeddingPair]) -> f32 {
155 if pairs.is_empty() {
156 return 0.0;
157 }
158 let total_loss: f32 = pairs.iter().map(|p| self.pair_loss(p)).sum();
159 let mean_loss = total_loss / pairs.len() as f32;
160
161 let epoch = if self.history.is_empty() {
162 0
163 } else {
164 self.history.last().map(|s| s.epoch).unwrap_or(0)
165 };
166 let step = self.history.len();
167
168 self.history.push(TrainingStep {
169 epoch,
170 step,
171 loss: mean_loss,
172 });
173
174 mean_loss
175 }
176
177 pub fn train(&mut self, pairs: &[EmbeddingPair]) -> Vec<f32> {
179 let epochs = self.config.epochs;
180 let mut epoch_losses = Vec::with_capacity(epochs);
181
182 for epoch in 0..epochs {
183 if pairs.is_empty() {
184 epoch_losses.push(0.0);
185 continue;
186 }
187 let total_loss: f32 = pairs.iter().map(|p| self.pair_loss(p)).sum();
188 let mean_loss = total_loss / pairs.len() as f32;
189
190 let step = self.history.len();
191 self.history.push(TrainingStep {
192 epoch,
193 step,
194 loss: mean_loss,
195 });
196
197 epoch_losses.push(mean_loss);
198 }
199
200 epoch_losses
201 }
202
203 pub fn training_history(&self) -> &[TrainingStep] {
205 &self.history
206 }
207
208 pub fn total_steps(&self) -> usize {
210 self.history.len()
211 }
212
213 fn pair_loss(&self, pair: &EmbeddingPair) -> f32 {
216 match self.config.loss_type {
217 LossType::Triplet => {
218 if let Some(neg) = &pair.negative {
219 self.compute_triplet_loss(&pair.anchor, &pair.positive, neg)
220 } else {
221 0.0
222 }
223 }
224 LossType::Contrastive => {
225 if let Some(neg) = &pair.negative {
227 let l_sim = self.compute_contrastive_loss(&pair.anchor, &pair.positive, 1.0);
229 let l_dis = self.compute_contrastive_loss(&pair.anchor, neg, 0.0);
230 (l_sim + l_dis) / 2.0
231 } else {
232 self.compute_contrastive_loss(&pair.anchor, &pair.positive, 1.0)
233 }
234 }
235 LossType::CosineSimilarity => {
236 self.compute_cosine_loss(&pair.anchor, &pair.positive, 1.0)
237 }
238 }
239 }
240}
241
/// Cosine similarity of `a` and `b`, clamped to `[-1.0, 1.0]`.
///
/// Returns `0.0` when either vector has zero Euclidean norm. The dot product covers only
/// the first `min(a.len(), b.len())` components, while each norm covers its whole vector.
/// Callers should therefore pass equal-length slices.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (len_a, len_b) = (norm(a), norm(b));

    // A zero-length vector has no direction; report "no similarity" instead of NaN.
    if len_a == 0.0 || len_b == 0.0 {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    // Rounding can push the ratio slightly outside the valid cosine range.
    (dot / (len_a * len_b)).clamp(-1.0, 1.0)
}
259
/// Euclidean (L2) norm of `v`: the square root of the sum of its squared components.
///
/// An empty slice has norm zero.
pub fn l2_norm(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|&component| component * component).sum();
    sum_of_squares.sqrt()
}
264
265pub fn l2_normalize(v: &[f32]) -> Vec<f32> {
267 let norm = l2_norm(v);
268 if norm == 0.0 {
269 v.to_vec()
270 } else {
271 v.iter().map(|x| x / norm).collect()
272 }
273}
274
/// Euclidean distance between `a` and `b`.
///
/// Only the first `min(a.len(), b.len())` components are compared, because `zip` stops at
/// the shorter slice. Callers should therefore pass equal-length slices.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_gaps = a.iter().zip(b).map(|(&x, &y)| {
        let gap = x - y;
        gap * gap
    });
    squared_gaps.sum::<f32>().sqrt()
}
283
#[cfg(test)]
mod tests {
    //! Unit tests for the loss functions, `FineTuner` history bookkeeping, and the
    //! vector helpers.

    use super::*;

    /// Absolute tolerance for floating-point comparisons.
    const EPS: f32 = 1e-5;

    fn ones(dim: usize) -> Vec<f32> {
        vec![1.0; dim]
    }
    fn zeros(dim: usize) -> Vec<f32> {
        vec![0.0; dim]
    }
    fn unit_x() -> Vec<f32> {
        vec![1.0, 0.0, 0.0]
    }
    fn unit_y() -> Vec<f32> {
        vec![0.0, 1.0, 0.0]
    }

    fn triplet_tuner() -> FineTuner {
        FineTuner::new(FinetuneConfig {
            loss_type: LossType::Triplet,
            ..Default::default()
        })
    }
    fn contrastive_tuner() -> FineTuner {
        FineTuner::new(FinetuneConfig {
            loss_type: LossType::Contrastive,
            ..Default::default()
        })
    }
    fn cosine_tuner() -> FineTuner {
        FineTuner::new(FinetuneConfig {
            loss_type: LossType::CosineSimilarity,
            ..Default::default()
        })
    }

    #[test]
    fn test_triplet_loss_same_anchor_positive() {
        // d_pos = 0 and d_neg = 9 > margin, so the hinge clamps to zero.
        let tuner = triplet_tuner();
        let a = vec![1.0, 0.0];
        let neg = vec![10.0, 0.0];
        let loss = tuner.compute_triplet_loss(&a, &a, &neg);
        assert!(loss.abs() < EPS);
    }

    #[test]
    fn test_triplet_loss_positive_equals_negative() {
        // Identical positive and negative cancel out, leaving exactly the margin.
        let tuner = triplet_tuner();
        let a = unit_x();
        let p = unit_y();
        let loss = tuner.compute_triplet_loss(&a, &p, &p);
        assert!((loss - 1.0).abs() < EPS);
    }

    #[test]
    fn test_triplet_loss_negative_far_gives_zero() {
        let tuner = triplet_tuner();
        let a = vec![0.0, 0.0];
        let p = vec![0.1, 0.0];
        let n = vec![100.0, 0.0];
        let loss = tuner.compute_triplet_loss(&a, &p, &n);
        assert!(loss < EPS);
    }

    #[test]
    fn test_triplet_loss_non_negative() {
        let tuner = triplet_tuner();
        let a = vec![1.0, 2.0];
        let p = vec![1.1, 2.1];
        let n = vec![0.5, 0.5];
        let loss = tuner.compute_triplet_loss(&a, &p, &n);
        assert!(loss >= 0.0);
    }

    #[test]
    fn test_zero_margin_triplet_direct() {
        // The margin is 1.0. Here d_neg - d_pos == 1.0 exactly, so the loss sits on the
        // hinge boundary and is zero.
        let tuner = triplet_tuner();
        let a = vec![0.0];
        let p = vec![1.0];
        let n = vec![2.0];
        let loss = tuner.compute_triplet_loss(&a, &p, &n);
        assert!(loss.abs() < EPS);
    }

    #[test]
    fn test_contrastive_similar_pair() {
        // Similar pairs pay d² = 0.5² = 0.25.
        let tuner = contrastive_tuner();
        let a = vec![0.0];
        let b = vec![0.5];
        let loss = tuner.compute_contrastive_loss(&a, &b, 1.0);
        assert!((loss - 0.25).abs() < EPS);
    }

    #[test]
    fn test_contrastive_dissimilar_pair() {
        // Dissimilar pairs beyond the margin (1.5 > 1.0) cost nothing.
        let tuner = contrastive_tuner();
        let a = vec![0.0];
        let b = vec![1.5];
        let loss = tuner.compute_contrastive_loss(&a, &b, 0.0);
        assert!(loss.abs() < EPS);
    }

    #[test]
    fn test_contrastive_dissimilar_close() {
        // Dissimilar pairs inside the margin pay (1.0 - 0.5)² = 0.25.
        let tuner = contrastive_tuner();
        let a = vec![0.0];
        let b = vec![0.5];
        let loss = tuner.compute_contrastive_loss(&a, &b, 0.0);
        assert!((loss - 0.25).abs() < EPS);
    }

    #[test]
    fn test_contrastive_identical_similar() {
        let tuner = contrastive_tuner();
        let a = vec![1.0, 2.0];
        let loss = tuner.compute_contrastive_loss(&a, &a, 1.0);
        assert!(loss.abs() < EPS);
    }

    #[test]
    fn test_cosine_loss_identical() {
        let tuner = cosine_tuner();
        let a = unit_x();
        let loss = tuner.compute_cosine_loss(&a, &a, 1.0);
        assert!(loss.abs() < EPS);
    }

    #[test]
    fn test_cosine_loss_orthogonal() {
        let tuner = cosine_tuner();
        let a = unit_x();
        let b = unit_y();
        let loss = tuner.compute_cosine_loss(&a, &b, 0.0);
        assert!(loss.abs() < EPS);
    }

    #[test]
    fn test_cosine_loss_opposite() {
        let tuner = cosine_tuner();
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let loss = tuner.compute_cosine_loss(&a, &b, -1.0);
        assert!(loss.abs() < EPS);
    }

    #[test]
    fn test_train_returns_one_loss_per_epoch() {
        let mut tuner = FineTuner::new(FinetuneConfig {
            epochs: 5,
            loss_type: LossType::Triplet,
            ..Default::default()
        });
        let pairs = vec![EmbeddingPair::with_negative(
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
        )];
        let losses = tuner.train(&pairs);
        assert_eq!(losses.len(), 5);
    }

    #[test]
    fn test_step_increments_total_steps() {
        let mut tuner = triplet_tuner();
        let pairs = vec![EmbeddingPair::with_negative(
            vec![0.0],
            vec![1.0],
            vec![2.0],
        )];
        tuner.step(&pairs);
        assert_eq!(tuner.total_steps(), 1);
        tuner.step(&pairs);
        assert_eq!(tuner.total_steps(), 2);
    }

    #[test]
    fn test_history_grows() {
        // A step is recorded even when every pair contributes zero loss.
        let mut tuner = triplet_tuner();
        let pairs = vec![EmbeddingPair::without_negative(vec![0.0], vec![1.0])];
        for _ in 0..7 {
            tuner.step(&pairs);
        }
        assert_eq!(tuner.training_history().len(), 7);
    }

    #[test]
    fn test_train_appends_to_history() {
        let mut tuner = FineTuner::new(FinetuneConfig {
            epochs: 3,
            ..Default::default()
        });
        let pairs = vec![EmbeddingPair::with_negative(ones(4), ones(4), zeros(4))];
        tuner.train(&pairs);
        assert_eq!(tuner.training_history().len(), 3);
    }

    #[test]
    fn test_step_empty_pairs() {
        let mut tuner = triplet_tuner();
        let loss = tuner.step(&[]);
        assert_eq!(loss, 0.0);
        // Empty input is not recorded as a step.
        assert_eq!(tuner.total_steps(), 0);
    }

    #[test]
    fn test_train_empty_pairs() {
        let mut tuner = FineTuner::new(FinetuneConfig {
            epochs: 3,
            ..Default::default()
        });
        let losses = tuner.train(&[]);
        assert_eq!(losses.len(), 3);
        assert!(losses.iter().all(|&l| l == 0.0));
        // Each epoch still reports a loss, but nothing is recorded in the history.
        assert_eq!(tuner.total_steps(), 0);
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let a = unit_x();
        let sim = cosine_similarity(&a, &a);
        assert!((sim - 1.0).abs() < EPS);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let sim = cosine_similarity(&unit_x(), &unit_y());
        assert!(sim.abs() < EPS);
    }

    #[test]
    fn test_cosine_similarity_antiparallel() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim + 1.0).abs() < EPS);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        let a = vec![1.0, 0.0];
        let b = zeros(2);
        let sim = cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_l2_normalize_unit_vector() {
        let v = vec![3.0, 4.0];
        let n = l2_normalize(&v);
        assert!((n[0] - 0.6).abs() < EPS);
        assert!((n[1] - 0.8).abs() < EPS);
    }

    #[test]
    fn test_l2_normalize_already_unit() {
        let v = unit_x();
        let n = l2_normalize(&v);
        assert!((l2_norm(&n) - 1.0).abs() < EPS);
    }

    #[test]
    fn test_l2_normalize_zero_vector() {
        let v = zeros(3);
        let n = l2_normalize(&v);
        assert_eq!(n, zeros(3));
    }

    #[test]
    fn test_normalized_cosine_similarity() {
        let v = vec![3.0, 4.0];
        let n = l2_normalize(&v);
        let sim = cosine_similarity(&n, &n);
        assert!((sim - 1.0).abs() < EPS);
    }

    #[test]
    fn test_training_step_epoch() {
        let mut tuner = FineTuner::new(FinetuneConfig {
            epochs: 1,
            ..Default::default()
        });
        let pairs = vec![EmbeddingPair::with_negative(ones(2), ones(2), zeros(2))];
        tuner.train(&pairs);
        assert_eq!(tuner.training_history()[0].epoch, 0);
    }

    #[test]
    fn test_training_step_step_index() {
        let mut tuner = FineTuner::new(FinetuneConfig {
            epochs: 3,
            ..Default::default()
        });
        let pairs = vec![EmbeddingPair::without_negative(ones(2), zeros(2))];
        tuner.train(&pairs);
        let steps: Vec<usize> = tuner.training_history().iter().map(|s| s.step).collect();
        assert_eq!(steps, vec![0, 1, 2]);
    }

    #[test]
    fn test_step_after_train_reuses_last_epoch() {
        // `step` labels its entry with the epoch of the most recent history entry.
        let mut tuner = FineTuner::new(FinetuneConfig {
            epochs: 3,
            ..Default::default()
        });
        let pairs = vec![EmbeddingPair::with_negative(ones(2), ones(2), zeros(2))];
        tuner.train(&pairs);
        tuner.step(&pairs);
        let last = tuner
            .training_history()
            .last()
            .expect("train and step both recorded entries");
        assert_eq!(last.epoch, 2);
        assert_eq!(last.step, 3);
    }

    #[test]
    fn test_contrastive_loss_train() {
        let mut tuner = FineTuner::new(FinetuneConfig {
            epochs: 2,
            loss_type: LossType::Contrastive,
            ..Default::default()
        });
        let pairs = vec![EmbeddingPair::with_negative(
            vec![0.0, 0.0],
            vec![0.1, 0.0],
            vec![2.0, 0.0],
        )];
        let losses = tuner.train(&pairs);
        assert_eq!(losses.len(), 2);
        assert!(losses.iter().all(|&l| l >= 0.0));
    }

    #[test]
    fn test_cosine_loss_train() {
        let mut tuner = FineTuner::new(FinetuneConfig {
            epochs: 2,
            loss_type: LossType::CosineSimilarity,
            ..Default::default()
        });
        let pairs = vec![EmbeddingPair::without_negative(unit_x(), unit_y())];
        let losses = tuner.train(&pairs);
        assert!(losses.iter().all(|&l| l >= 0.0));
    }

    #[test]
    fn test_step_positive_loss() {
        // The negative (0.1) is closer than the positive (0.5): a hard triplet.
        let mut tuner = triplet_tuner();
        let pairs = vec![EmbeddingPair::with_negative(
            vec![0.0, 0.0],
            vec![0.5, 0.0],
            vec![0.1, 0.0],
        )];
        let loss = tuner.step(&pairs);
        assert!(loss >= 0.0);
    }

    #[test]
    fn test_total_steps_after_train() {
        let mut tuner = FineTuner::new(FinetuneConfig {
            epochs: 4,
            ..Default::default()
        });
        let pairs = vec![EmbeddingPair::without_negative(ones(2), zeros(2))];
        tuner.train(&pairs);
        assert_eq!(tuner.total_steps(), 4);
    }

    #[test]
    fn test_step_plus_train_accumulate() {
        // One entry from `step` plus three from `train`.
        let mut tuner = FineTuner::new(FinetuneConfig {
            epochs: 3,
            ..Default::default()
        });
        let pairs = vec![EmbeddingPair::without_negative(ones(2), zeros(2))];
        tuner.step(&pairs);
        tuner.train(&pairs);
        assert_eq!(tuner.total_steps(), 4);
    }

    #[test]
    fn test_finetune_config_default() {
        let cfg = FinetuneConfig::default();
        assert_eq!(cfg.epochs, 10);
        assert_eq!(cfg.loss_type, LossType::Triplet);
    }

    #[test]
    fn test_embedding_pair_with_negative() {
        let p = EmbeddingPair::with_negative(vec![1.0], vec![2.0], vec![3.0]);
        assert!(p.negative.is_some());
    }

    #[test]
    fn test_embedding_pair_without_negative() {
        let p = EmbeddingPair::without_negative(vec![1.0], vec![2.0]);
        assert!(p.negative.is_none());
    }

    #[test]
    fn test_triplet_pair_no_negative_zero_loss() {
        let mut tuner = triplet_tuner();
        let pairs = vec![EmbeddingPair::without_negative(ones(4), zeros(4))];
        let loss = tuner.step(&pairs);
        assert_eq!(loss, 0.0);
    }

    #[test]
    fn test_cosine_similarity_clamped() {
        // Nearly parallel vectors: rounding must not push the result outside [-1, 1].
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 1e-7, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!((-1.0..=1.0).contains(&sim));
    }

    #[test]
    fn test_l2_norm_345() {
        let v = vec![3.0, 4.0];
        assert!((l2_norm(&v) - 5.0).abs() < EPS);
    }

    #[test]
    fn test_loss_recorded_in_history() {
        let mut tuner = FineTuner::new(FinetuneConfig {
            epochs: 5,
            loss_type: LossType::CosineSimilarity,
            ..Default::default()
        });
        let pairs = vec![EmbeddingPair::without_negative(unit_x(), unit_y())];
        tuner.train(&pairs);
        assert!(tuner.training_history().iter().all(|s| s.loss >= 0.0));
    }

    #[test]
    fn test_loss_types_differ() {
        // The negative (0.2) is closer to the anchor than the positive (0.5).
        let pairs = vec![EmbeddingPair::with_negative(
            vec![0.0, 0.0],
            vec![0.5, 0.0],
            vec![0.2, 0.0],
        )];
        let mut t1 = FineTuner::new(FinetuneConfig {
            epochs: 1,
            loss_type: LossType::Triplet,
            ..Default::default()
        });
        let mut t2 = FineTuner::new(FinetuneConfig {
            epochs: 1,
            loss_type: LossType::Contrastive,
            ..Default::default()
        });
        let l1 = t1.step(&pairs);
        let l2 = t2.step(&pairs);
        // Triplet: 0.5 - 0.2 + 1.0 = 1.3.
        assert!((l1 - 1.3).abs() < EPS);
        // Contrastive: (0.5² + (1.0 - 0.2)²) / 2 = (0.25 + 0.64) / 2 = 0.445.
        assert!((l2 - 0.445).abs() < EPS);
        assert!((l1 - l2).abs() > EPS);
    }

    #[test]
    fn test_high_dimensional_embeddings() {
        let dim = 768;
        let anchor: Vec<f32> = (0..dim).map(|i| i as f32 / dim as f32).collect();
        let positive: Vec<f32> = (0..dim).map(|i| (i as f32 + 1.0) / dim as f32).collect();
        let negative: Vec<f32> = vec![-1.0; dim];
        let tuner = triplet_tuner();
        let loss = tuner.compute_triplet_loss(&anchor, &positive, &negative);
        assert!(loss >= 0.0);
    }

    #[test]
    fn test_cosine_loss_zero_when_exact() {
        let tuner = cosine_tuner();
        let a = unit_x();
        let b = unit_x();
        let sim = cosine_similarity(&a, &b);
        let loss = tuner.compute_cosine_loss(&a, &b, sim);
        assert!(loss.abs() < EPS);
    }

    #[test]
    fn test_train_large_batch_size() {
        // A batch size larger than the dataset must not affect epoch counting.
        let mut tuner = FineTuner::new(FinetuneConfig {
            batch_size: 512,
            epochs: 2,
            ..Default::default()
        });
        let pairs: Vec<_> = (0..100)
            .map(|_| EmbeddingPair::without_negative(ones(16), zeros(16)))
            .collect();
        let losses = tuner.train(&pairs);
        assert_eq!(losses.len(), 2);
    }

    #[test]
    fn test_total_steps_initially_zero() {
        let tuner = triplet_tuner();
        assert_eq!(tuner.total_steps(), 0);
    }
}