1use crate::error::{SeqError, SeqResult};
44
/// Hyper-parameters for the contrastive-search decoders.
///
/// Each candidate is scored as `(1 - alpha) * probability - alpha * penalty`,
/// where the penalty is a maximum cosine similarity against the context.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ContrastiveConfig {
    /// Number of top-probability candidates considered per step; must be at least 1.
    pub k: usize,
    /// Weight of the degeneration penalty; must be a finite value in `[0, 1]`.
    pub alpha: f32,
    /// Length of the token sequence to generate.
    pub max_len: usize,
}
62
63impl Default for ContrastiveConfig {
64 fn default() -> Self {
65 Self {
66 k: 5,
67 alpha: 0.6,
68 max_len: 50,
69 }
70 }
71}
72
73impl ContrastiveConfig {
74 fn validate(&self) -> SeqResult<()> {
76 if self.k == 0 {
77 return Err(SeqError::InvalidConfiguration(
78 "contrastive: k must be >= 1".to_string(),
79 ));
80 }
81 if !self.alpha.is_finite() || self.alpha < 0.0 || self.alpha > 1.0 {
82 return Err(SeqError::InvalidConfiguration(format!(
83 "contrastive: alpha must be in [0, 1], got {}",
84 self.alpha
85 )));
86 }
87 Ok(())
88 }
89}
90
/// Stateless namespace for the contrastive-search decoding routines; every
/// operation is an associated function, so no instance is ever needed.
#[derive(Debug, Clone, Copy, Default)]
pub struct ContrastiveSearcher;
98
99impl ContrastiveSearcher {
100 pub fn cosine_similarity(a: &[f32], b: &[f32]) -> SeqResult<f32> {
110 if a.is_empty() || b.is_empty() {
111 return Err(SeqError::EmptyInput);
112 }
113 if a.len() != b.len() {
114 return Err(SeqError::LengthMismatch {
115 a: a.len(),
116 b: b.len(),
117 });
118 }
119 let mut dot = 0.0_f32;
120 let mut norm_a = 0.0_f32;
121 let mut norm_b = 0.0_f32;
122 for (x, y) in a.iter().zip(b.iter()) {
123 dot += x * y;
124 norm_a += x * x;
125 norm_b += y * y;
126 }
127 let denom = norm_a.sqrt() * norm_b.sqrt() + 1e-12_f32;
128 Ok(dot / denom)
129 }
130
131 pub fn degeneration_penalty(
152 context_hiddens: &[f32],
153 n_context: usize,
154 candidate_hidden: &[f32],
155 hidden_dim: usize,
156 ) -> SeqResult<f32> {
157 if hidden_dim == 0 {
158 return Err(SeqError::EmptyInput);
159 }
160 if context_hiddens.len() != n_context * hidden_dim {
161 return Err(SeqError::ShapeMismatch {
162 expected: n_context * hidden_dim,
163 got: context_hiddens.len(),
164 });
165 }
166 if candidate_hidden.len() != hidden_dim {
167 return Err(SeqError::LengthMismatch {
168 a: candidate_hidden.len(),
169 b: hidden_dim,
170 });
171 }
172 if n_context == 0 {
173 return Ok(0.0);
174 }
175
176 let mut max_sim = f32::NEG_INFINITY;
177 for t in 0..n_context {
178 let ctx_slice = &context_hiddens[t * hidden_dim..(t + 1) * hidden_dim];
179 let sim = Self::cosine_similarity(ctx_slice, candidate_hidden)?;
180 if sim > max_sim {
181 max_sim = sim;
182 }
183 }
184 Ok(max_sim)
185 }
186
187 pub fn top_k_candidates(logits: &[f32], k: usize) -> SeqResult<Vec<(usize, f32)>> {
199 if logits.is_empty() {
200 return Err(SeqError::EmptyInput);
201 }
202 if k == 0 {
203 return Err(SeqError::InvalidConfiguration(
204 "contrastive: k must be >= 1".to_string(),
205 ));
206 }
207 let vocab = logits.len();
208 let k_eff = k.min(vocab);
209
210 let mut indices: Vec<usize> = (0..vocab).collect();
212 indices.sort_by(|&a, &b| {
213 logits[b]
214 .partial_cmp(&logits[a])
215 .unwrap_or(std::cmp::Ordering::Equal)
216 });
217 indices.truncate(k_eff);
218
219 let max_l = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
221 let mut exps = vec![0.0_f32; vocab];
222 let mut sum = 0.0_f32;
223 for (i, &l) in logits.iter().enumerate() {
224 let e = (l - max_l).exp();
225 exps[i] = e;
226 sum += e;
227 }
228 let sum_safe = if sum > 0.0 && sum.is_finite() {
230 sum
231 } else {
232 1.0
233 };
234
235 let mut candidates: Vec<(usize, f32)> = indices
237 .iter()
238 .map(|&idx| (idx, exps[idx] / sum_safe))
239 .collect();
240
241 candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
243
244 Ok(candidates)
245 }
246
247 #[inline]
254 pub fn contrastive_score(prob: f32, degen_penalty: f32, alpha: f32) -> f32 {
255 (1.0 - alpha) * prob - alpha * degen_penalty
256 }
257
258 pub fn decode<F>(
285 initial_logits: &[f32],
286 initial_hiddens: &[f32],
287 vocab_size: usize,
288 hidden_dim: usize,
289 step_fn: F,
290 cfg: &ContrastiveConfig,
291 ) -> SeqResult<Vec<usize>>
292 where
293 F: Fn(usize, &[f32]) -> (Vec<f32>, Vec<f32>),
294 {
295 cfg.validate()?;
296 if vocab_size == 0 || hidden_dim == 0 {
297 return Err(SeqError::EmptyInput);
298 }
299 if initial_logits.len() != vocab_size {
300 return Err(SeqError::ShapeMismatch {
301 expected: vocab_size,
302 got: initial_logits.len(),
303 });
304 }
305 if initial_hiddens.len() != vocab_size * hidden_dim {
306 return Err(SeqError::ShapeMismatch {
307 expected: vocab_size * hidden_dim,
308 got: initial_hiddens.len(),
309 });
310 }
311
312 let mut generated: Vec<usize> = Vec::with_capacity(cfg.max_len);
313 let mut context_hiddens: Vec<f32> = Vec::new();
315
316 let candidates_0 = Self::top_k_candidates(initial_logits, cfg.k)?;
318
319 let mut best_score = f32::NEG_INFINITY;
320 let mut best_token = candidates_0[0].0;
321 let mut best_hidden: Vec<f32> =
322 initial_hiddens[best_token * hidden_dim..(best_token + 1) * hidden_dim].to_vec();
323
324 for (tok, prob) in &candidates_0 {
325 let score = Self::contrastive_score(*prob, 0.0, cfg.alpha);
327 if score > best_score {
328 best_score = score;
329 best_token = *tok;
330 best_hidden = initial_hiddens[tok * hidden_dim..(tok + 1) * hidden_dim].to_vec();
331 }
332 }
333
334 generated.push(best_token);
335 context_hiddens.extend_from_slice(&best_hidden);
336 let mut last_hidden = best_hidden;
337
338 for _step in 1..cfg.max_len {
340 let (next_logits, next_hidden) = step_fn(generated[generated.len() - 1], &last_hidden);
341
342 if next_logits.len() != vocab_size {
343 return Err(SeqError::ShapeMismatch {
344 expected: vocab_size,
345 got: next_logits.len(),
346 });
347 }
348 if next_hidden.len() != hidden_dim {
349 return Err(SeqError::ShapeMismatch {
350 expected: hidden_dim,
351 got: next_hidden.len(),
352 });
353 }
354
355 let candidates = Self::top_k_candidates(&next_logits, cfg.k)?;
356 let n_ctx = context_hiddens.len() / hidden_dim;
357
358 let mut step_best_score = f32::NEG_INFINITY;
359 let mut step_best_token = candidates[0].0;
360
361 for (tok, prob) in &candidates {
362 let degen =
367 Self::degeneration_penalty(&context_hiddens, n_ctx, &next_hidden, hidden_dim)?;
368 let score = Self::contrastive_score(*prob, degen, cfg.alpha);
369 if score > step_best_score {
370 step_best_score = score;
371 step_best_token = *tok;
372 }
373 }
374
375 generated.push(step_best_token);
376 context_hiddens.extend_from_slice(&next_hidden);
377 last_hidden = next_hidden;
378 }
379
380 Ok(generated)
381 }
382
383 pub fn decode_logits_only<F>(
406 initial_logits: &[f32],
407 step_fn: F,
408 cfg: &ContrastiveConfig,
409 ) -> SeqResult<Vec<usize>>
410 where
411 F: Fn(usize) -> Vec<f32>,
412 {
413 cfg.validate()?;
414 if initial_logits.is_empty() {
415 return Err(SeqError::EmptyInput);
416 }
417
418 let vocab_size = initial_logits.len();
419 let mut generated: Vec<usize> = Vec::with_capacity(cfg.max_len);
420 let mut context_logits_flat: Vec<f32> = Vec::new();
423
424 let candidates_0 = Self::top_k_candidates(initial_logits, cfg.k)?;
426
427 let mut best_score = f32::NEG_INFINITY;
429 let mut best_token = candidates_0[0].0;
430 for (tok, prob) in &candidates_0 {
431 let score = Self::contrastive_score(*prob, 0.0, cfg.alpha);
432 if score > best_score {
433 best_score = score;
434 best_token = *tok;
435 }
436 }
437
438 generated.push(best_token);
439 context_logits_flat.extend_from_slice(initial_logits);
441
442 for _step in 1..cfg.max_len {
444 let next_logits = step_fn(generated[generated.len() - 1]);
445
446 if next_logits.is_empty() {
447 return Err(SeqError::EmptyInput);
448 }
449 let cur_vocab = next_logits.len();
450 let dim = cur_vocab.min(vocab_size);
454
455 let candidates = Self::top_k_candidates(&next_logits, cfg.k)?;
456 let n_ctx = context_logits_flat.len() / vocab_size;
457
458 let mut degen = 0.0_f32;
462 for t in 0..n_ctx {
463 let ctx_slice = &context_logits_flat[t * vocab_size..t * vocab_size + dim];
464 let cand_slice = &next_logits[..dim];
465 let sim = Self::cosine_similarity(ctx_slice, cand_slice)?;
466 if sim > degen {
467 degen = sim;
468 }
469 }
470
471 let mut step_best_score = f32::NEG_INFINITY;
472 let mut step_best_token = candidates[0].0;
473 for (tok, prob) in &candidates {
474 let score = Self::contrastive_score(*prob, degen, cfg.alpha);
475 if score > step_best_score {
476 step_best_score = score;
477 step_best_token = *tok;
478 }
479 }
480
481 generated.push(step_best_token);
482 let mut entry = next_logits.clone();
484 entry.resize(vocab_size, 0.0);
485 context_logits_flat.extend_from_slice(&entry);
486 }
487
488 Ok(generated)
489 }
490}
491
#[cfg(test)]
mod tests {
    use super::*;

    // ------------------------------------------------------------------
    // cosine_similarity
    // ------------------------------------------------------------------

    #[test]
    fn cosine_similarity_identical_vectors_is_one() {
        let v = [1.0_f32, 2.0, 3.0];
        let sim = ContrastiveSearcher::cosine_similarity(&v, &v).expect("ok");
        assert!((sim - 1.0).abs() < 1e-5, "got {sim}");
    }

    #[test]
    fn cosine_similarity_orthogonal_is_zero() {
        let a = [1.0_f32, 0.0];
        let b = [0.0_f32, 1.0];
        let sim = ContrastiveSearcher::cosine_similarity(&a, &b).expect("ok");
        assert!(sim.abs() < 1e-6, "got {sim}");
    }

    #[test]
    fn cosine_similarity_zero_vector_is_zero_not_nan() {
        let a = [0.0_f32, 0.0, 0.0];
        let b = [1.0_f32, 2.0, 3.0];
        let sim = ContrastiveSearcher::cosine_similarity(&a, &b).expect("ok");
        assert!(!sim.is_nan(), "must not be NaN");
        assert!(sim.abs() < 1e-6, "got {sim}");
    }

    #[test]
    fn cosine_similarity_length_mismatch_error() {
        let a = [1.0_f32, 2.0];
        let b = [1.0_f32, 2.0, 3.0];
        let err = ContrastiveSearcher::cosine_similarity(&a, &b).unwrap_err();
        assert!(matches!(err, SeqError::LengthMismatch { .. }));
    }

    #[test]
    fn cosine_similarity_empty_error() {
        let err = ContrastiveSearcher::cosine_similarity(&[], &[]).unwrap_err();
        assert!(matches!(err, SeqError::EmptyInput));
    }

    #[test]
    fn cosine_similarity_negative_vectors() {
        // Opposite directions give -1.
        let a = [1.0_f32, 0.0];
        let b = [-1.0_f32, 0.0];
        let sim = ContrastiveSearcher::cosine_similarity(&a, &b).expect("ok");
        assert!((sim + 1.0).abs() < 1e-5, "got {sim}");
    }

    // ------------------------------------------------------------------
    // degeneration_penalty
    // ------------------------------------------------------------------

    #[test]
    fn degeneration_penalty_no_context_is_zero() {
        let candidate = [1.0_f32, 2.0, 3.0];
        let pen = ContrastiveSearcher::degeneration_penalty(&[], 0, &candidate, 3).expect("ok");
        assert!(pen.abs() < 1e-6, "got {pen}");
    }

    #[test]
    fn degeneration_penalty_identical_context_is_one() {
        let hidden = [1.0_f32, 0.0, 0.0];
        // A one-row context equal to the candidate.
        let context = hidden;
        let pen = ContrastiveSearcher::degeneration_penalty(&context, 1, &hidden, 3).expect("ok");
        assert!((pen - 1.0).abs() < 1e-5, "got {pen}");
    }

    #[test]
    fn degeneration_penalty_orthogonal_context_is_zero() {
        let context = [1.0_f32, 0.0];
        let candidate = [0.0_f32, 1.0];
        let pen =
            ContrastiveSearcher::degeneration_penalty(&context, 1, &candidate, 2).expect("ok");
        assert!(pen.abs() < 1e-6, "got {pen}");
    }

    #[test]
    fn degeneration_penalty_multiple_context_returns_max() {
        let dim = 2usize;
        // Two rows: [1, 0] is orthogonal to the candidate, [0, 1] is identical to it.
        let ctx = [1.0_f32, 0.0, 0.0, 1.0];
        let candidate = [0.0_f32, 1.0];
        let pen = ContrastiveSearcher::degeneration_penalty(&ctx, 2, &candidate, dim).expect("ok");
        assert!((pen - 1.0).abs() < 1e-5, "got {pen}");
    }

    #[test]
    fn degeneration_penalty_shape_mismatch_error() {
        // Two context rows of width 3 need 6 values; only 2 are supplied.
        let err = ContrastiveSearcher::degeneration_penalty(&[1.0, 2.0], 2, &[1.0, 2.0, 3.0], 3)
            .unwrap_err();
        assert!(matches!(err, SeqError::ShapeMismatch { .. }));
    }

    // ------------------------------------------------------------------
    // top_k_candidates
    // ------------------------------------------------------------------

    #[test]
    fn top_k_k_equals_one_returns_argmax() {
        let logits = [-1.0_f32, 5.0, 2.0, 0.5];
        let cands = ContrastiveSearcher::top_k_candidates(&logits, 1).expect("ok");
        assert_eq!(cands.len(), 1);
        assert_eq!(cands[0].0, 1, "argmax should be token 1");
    }

    #[test]
    fn top_k_k_ge_vocab_returns_all() {
        let logits = [1.0_f32, 2.0, 0.5];
        let cands = ContrastiveSearcher::top_k_candidates(&logits, 100).expect("ok");
        assert_eq!(cands.len(), 3, "should return all 3 tokens");
    }

    #[test]
    fn top_k_probs_are_valid_softmax() {
        let logits = [1.0_f32, 2.0, 0.5, -1.0, 3.0];
        let cands = ContrastiveSearcher::top_k_candidates(&logits, 3).expect("ok");
        for (_, prob) in &cands {
            // A softmax over finite logits is strictly positive.
            assert!(*prob > 0.0, "prob must be positive");
        }
        // Only 3 of 5 tokens are returned, so their mass is at most one.
        let partial_sum: f32 = cands.iter().map(|(_, p)| p).sum();
        assert!(partial_sum <= 1.0 + 1e-5, "partial sum {partial_sum} > 1");
    }

    #[test]
    fn top_k_sorted_descending_by_prob() {
        let logits = [1.0_f32, 3.0, 2.0, 0.5];
        let cands = ContrastiveSearcher::top_k_candidates(&logits, 4).expect("ok");
        for pair in cands.windows(2) {
            assert!(
                pair[0].1 >= pair[1].1,
                "probs should be non-increasing: {:?}",
                cands
            );
        }
    }

    #[test]
    fn top_k_ties_broken_by_lower_token_id() {
        let logits = [2.0_f32, 5.0, 5.0, 1.0];
        let cands = ContrastiveSearcher::top_k_candidates(&logits, 2).expect("ok");
        let ids: Vec<usize> = cands.iter().map(|&(id, _)| id).collect();
        assert_eq!(ids, vec![1, 2]);
    }

    #[test]
    fn top_k_empty_logits_error() {
        let err = ContrastiveSearcher::top_k_candidates(&[], 3).unwrap_err();
        assert!(matches!(err, SeqError::EmptyInput));
    }

    #[test]
    fn top_k_k_zero_error() {
        let err = ContrastiveSearcher::top_k_candidates(&[1.0, 2.0], 0).unwrap_err();
        assert!(matches!(err, SeqError::InvalidConfiguration(_)));
    }

    // ------------------------------------------------------------------
    // contrastive_score
    // ------------------------------------------------------------------

    #[test]
    fn contrastive_score_alpha_zero_equals_prob() {
        let score = ContrastiveSearcher::contrastive_score(0.7, 0.9, 0.0);
        assert!((score - 0.7).abs() < 1e-6, "got {score}");
    }

    #[test]
    fn contrastive_score_alpha_one_equals_neg_degen() {
        let score = ContrastiveSearcher::contrastive_score(0.7, 0.4, 1.0);
        assert!((score + 0.4).abs() < 1e-6, "got {score}");
    }

    #[test]
    fn contrastive_score_midpoint() {
        // 0.5 * 0.8 - 0.5 * 0.5 = 0.15
        let score = ContrastiveSearcher::contrastive_score(0.8, 0.5, 0.5);
        assert!((score - 0.15).abs() < 1e-6, "got {score}");
    }

    // ------------------------------------------------------------------
    // ContrastiveConfig validation
    // ------------------------------------------------------------------

    #[test]
    fn default_config_is_valid() {
        assert!(ContrastiveConfig::default().validate().is_ok());
    }

    #[test]
    fn nan_alpha_is_rejected() {
        let cfg = ContrastiveConfig {
            alpha: f32::NAN,
            ..ContrastiveConfig::default()
        };
        assert!(matches!(
            cfg.validate(),
            Err(SeqError::InvalidConfiguration(_))
        ));
    }

    // ------------------------------------------------------------------
    // decode_logits_only
    // ------------------------------------------------------------------

    #[test]
    fn decode_logits_only_length_matches_max_len() {
        let initial = [1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let cfg = ContrastiveConfig {
            k: 3,
            alpha: 0.5,
            max_len: 5,
        };
        let seq = ContrastiveSearcher::decode_logits_only(
            &initial,
            |_tok| vec![1.0_f32, 2.0, 3.0, 4.0, 5.0],
            &cfg,
        )
        .expect("ok");
        assert_eq!(seq.len(), 5);
    }

    #[test]
    fn decode_logits_only_constant_step_fn_valid_tokens() {
        let initial = [0.0_f32, 1.0, -1.0, 2.0];
        let cfg = ContrastiveConfig {
            k: 2,
            alpha: 0.4,
            max_len: 10,
        };
        let seq = ContrastiveSearcher::decode_logits_only(
            &initial,
            |_tok| vec![0.0_f32, 1.0, -1.0, 2.0],
            &cfg,
        )
        .expect("ok");
        assert_eq!(seq.len(), 10);
        for tok in &seq {
            assert!(*tok < 4, "token {tok} out of vocab");
        }
    }

    #[test]
    fn decode_logits_only_can_produce_repetition() {
        // k = 1 with alpha = 0 is plain greedy decoding, so a constant step
        // function keeps selecting the argmax token.
        let initial = [0.0_f32, 5.0, 1.0];
        let cfg = ContrastiveConfig {
            k: 1,
            alpha: 0.0,
            max_len: 5,
        };
        let seq =
            ContrastiveSearcher::decode_logits_only(&initial, |_tok| vec![0.0_f32, 5.0, 1.0], &cfg)
                .expect("ok");
        for tok in &seq {
            assert_eq!(*tok, 1);
        }
    }

    #[test]
    fn decode_logits_only_high_alpha_yields_valid_tokens() {
        // Only checks that a strongly penalised run still yields in-vocabulary
        // tokens of the requested length; it does not measure repetition.
        let vocab = 8usize;
        let initial: Vec<f32> = (0..vocab).map(|i| i as f32).collect();
        let cfg = ContrastiveConfig {
            k: 4,
            alpha: 0.8,
            max_len: 20,
        };
        let seq = ContrastiveSearcher::decode_logits_only(
            &initial,
            |_tok| (0..vocab).map(|i| i as f32).collect(),
            &cfg,
        )
        .expect("ok");
        assert_eq!(seq.len(), 20);
        for tok in &seq {
            assert!(*tok < vocab);
        }
    }

    #[test]
    fn decode_logits_only_k_zero_error() {
        let cfg = ContrastiveConfig {
            k: 0,
            alpha: 0.5,
            max_len: 5,
        };
        let err = ContrastiveSearcher::decode_logits_only(&[1.0, 2.0], |_| vec![1.0, 2.0], &cfg)
            .unwrap_err();
        assert!(matches!(err, SeqError::InvalidConfiguration(_)));
    }

    #[test]
    fn decode_logits_only_alpha_above_one_error() {
        let cfg = ContrastiveConfig {
            k: 3,
            alpha: 1.5,
            max_len: 5,
        };
        let err = ContrastiveSearcher::decode_logits_only(&[1.0, 2.0], |_| vec![1.0, 2.0], &cfg)
            .unwrap_err();
        assert!(matches!(err, SeqError::InvalidConfiguration(_)));
    }

    #[test]
    fn decode_logits_only_empty_logits_error() {
        let cfg = ContrastiveConfig::default();
        let err = ContrastiveSearcher::decode_logits_only(&[], |_| vec![], &cfg).unwrap_err();
        assert!(matches!(err, SeqError::EmptyInput));
    }

    // ------------------------------------------------------------------
    // decode (with hidden states)
    // ------------------------------------------------------------------

    #[test]
    fn decode_with_hidden_states_length_matches_max_len() {
        let vocab = 4usize;
        let hidden_dim = 3usize;
        let initial_logits = [1.0_f32, 2.0, 3.0, 0.5];
        let initial_hiddens: Vec<f32> = (0..vocab * hidden_dim).map(|i| i as f32 * 0.1).collect();
        let cfg = ContrastiveConfig {
            k: 2,
            alpha: 0.5,
            max_len: 7,
        };

        let seq = ContrastiveSearcher::decode(
            &initial_logits,
            &initial_hiddens,
            vocab,
            hidden_dim,
            |_tok, _last| {
                let logits = vec![0.5_f32, 1.5, 2.5, 0.2];
                let hidden = vec![0.1_f32, 0.2, 0.3];
                (logits, hidden)
            },
            &cfg,
        )
        .expect("ok");
        assert_eq!(seq.len(), 7);
    }

    #[test]
    fn decode_with_hidden_states_valid_token_ids() {
        let vocab = 5usize;
        let hidden_dim = 4usize;
        let initial_logits = [1.0_f32, 2.0, 3.0, 0.5, 1.5];
        let initial_hiddens: Vec<f32> = (0..vocab * hidden_dim).map(|i| (i as f32).sin()).collect();
        let cfg = ContrastiveConfig {
            k: 3,
            alpha: 0.6,
            max_len: 12,
        };
        let seq = ContrastiveSearcher::decode(
            &initial_logits,
            &initial_hiddens,
            vocab,
            hidden_dim,
            |_tok, _last| {
                let logits: Vec<f32> = vec![0.1, 0.5, 2.0, 1.0, 0.3];
                let hidden: Vec<f32> = vec![0.5, -0.5, 0.3, -0.3];
                (logits, hidden)
            },
            &cfg,
        )
        .expect("ok");
        for tok in &seq {
            assert!(*tok < vocab, "token {tok} out of range");
        }
    }

    #[test]
    fn decode_empty_vocab_error() {
        let cfg = ContrastiveConfig::default();
        let err = ContrastiveSearcher::decode(&[], &[], 0, 4, |_tok, _h| (vec![], vec![]), &cfg)
            .unwrap_err();
        assert!(matches!(err, SeqError::EmptyInput));
    }

    #[test]
    fn decode_initial_hiddens_shape_error() {
        // vocab 2 × hidden 2 needs 4 values; only 3 are supplied.
        let cfg = ContrastiveConfig::default();
        let err = ContrastiveSearcher::decode(
            &[1.0, 2.0],
            &[0.0; 3],
            2,
            2,
            |_tok, _h| (vec![1.0, 2.0], vec![0.0, 1.0]),
            &cfg,
        )
        .unwrap_err();
        assert!(matches!(
            err,
            SeqError::ShapeMismatch {
                expected: 4,
                got: 3
            }
        ));
    }

    #[test]
    fn decode_step_hidden_length_error() {
        // The step function returns a 1-wide hidden state for a 2-wide model.
        let cfg = ContrastiveConfig {
            k: 2,
            alpha: 0.5,
            max_len: 3,
        };
        let err = ContrastiveSearcher::decode(
            &[1.0, 2.0],
            &[1.0, 0.0, 0.0, 1.0],
            2,
            2,
            |_tok, _h| (vec![1.0, 2.0], vec![1.0]),
            &cfg,
        )
        .unwrap_err();
        assert!(matches!(
            err,
            SeqError::ShapeMismatch {
                expected: 2,
                got: 1
            }
        ));
    }
}