Skip to main content

oxicuda_seq/decoders/
contrastive.rs

1//! Contrastive search decoding (Su et al. 2022 ACL).
2//!
3//! Reference: Su, Y., Lan, T., Wang, Y., Yatbaz, H. Y., & Xu, X. (2022).
4//! *A Contrastive Framework for Neural Text Generation*. ACL 2022.
5//! <https://aclanthology.org/2022.acl-long.365/>.
6//!
7//! # Background
8//!
9//! Standard stochastic decoders (nucleus, top-k) address the *degeneration
10//! problem* (repetition, incoherence) only partially; at low temperatures they
11//! tend to loop, while at high temperatures they produce incoherent text.
12//!
13//! Contrastive search is **deterministic** and addresses degeneration through
14//! an explicit penalty on the cosine similarity between the candidate token's
15//! hidden state and any of the previous context hidden states:
16//!
17//! ```text
18//! score(v, t) = (1 − α) · model_prob(v | context)
19//!             − α · max_{j < t} cos_sim(h_v, h_{context_j})
20//! ```
21//!
22//! The first term rewards tokens with high model probability; the second term
23//! penalises tokens whose hidden-state representation is too similar to those
24//! already present in the context (i.e., tokens that would cause degeneration).
25//! The hyperparameter `α ∈ [0, 1]` controls the trade-off.
26//!
27//! # Provided API
28//!
29//! * [`ContrastiveConfig`] — configuration struct (`k`, `alpha`, `max_len`).
30//! * [`ContrastiveSearcher`] — stateless struct with all algorithm steps as
31//!   associated functions.
32//! * [`ContrastiveSearcher::cosine_similarity`] — `(a·b) / (‖a‖·‖b‖ + ε)`.
33//! * [`ContrastiveSearcher::degeneration_penalty`] — max cos-sim to any
34//!   previous context hidden state.
35//! * [`ContrastiveSearcher::top_k_candidates`] — select top-k logit indices,
36//!   return `(token_id, softmax_prob)` pairs sorted descending by prob.
37//! * [`ContrastiveSearcher::contrastive_score`] — scalar score combination.
38//! * [`ContrastiveSearcher::decode`] — full generation loop with explicit
39//!   hidden states from a step function.
40//! * [`ContrastiveSearcher::decode_logits_only`] — generation with logit-only
41//!   step function using past logit vectors as proxy hidden states.
42
43use crate::error::{SeqError, SeqResult};
44
/// Configuration for contrastive search decoding.
///
/// # Fields
///
/// * `k` — top-k candidates to consider at each step (`≥ 1`).
/// * `alpha` — degeneration penalty weight `∈ [0, 1]`.  `α = 0` reduces to
///   greedy decoding (score is the model probability alone); `α = 1` ignores
///   model probability entirely.
/// * `max_len` — maximum number of tokens to generate.
///
/// Values are checked when a decoding entry point runs, not at construction.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ContrastiveConfig {
    /// Number of top-probability candidates to consider at each step.
    pub k: usize,
    /// Degeneration penalty weight ∈ [0, 1].
    pub alpha: f32,
    /// Maximum generation length (number of tokens produced).
    pub max_len: usize,
}
62
63impl Default for ContrastiveConfig {
64    fn default() -> Self {
65        Self {
66            k: 5,
67            alpha: 0.6,
68            max_len: 50,
69        }
70    }
71}
72
73impl ContrastiveConfig {
74    /// Validate the configuration fields.
75    fn validate(&self) -> SeqResult<()> {
76        if self.k == 0 {
77            return Err(SeqError::InvalidConfiguration(
78                "contrastive: k must be >= 1".to_string(),
79            ));
80        }
81        if !self.alpha.is_finite() || self.alpha < 0.0 || self.alpha > 1.0 {
82            return Err(SeqError::InvalidConfiguration(format!(
83                "contrastive: alpha must be in [0, 1], got {}",
84                self.alpha
85            )));
86        }
87        Ok(())
88    }
89}
90
/// Stateless struct providing all contrastive search decoding primitives.
///
/// All methods are associated functions taking only explicit inputs; the type
/// carries no data.  The generation loops in [`Self::decode`] and
/// [`Self::decode_logits_only`] keep local state internally but expose only a
/// pure interface.  The trivial derives let the type be logged, copied, or
/// default-constructed like the other public types in this module.
#[derive(Debug, Clone, Copy, Default)]
pub struct ContrastiveSearcher;
98
99impl ContrastiveSearcher {
100    /// Cosine similarity between two equal-length, non-empty vectors.
101    ///
102    /// Returns `(a·b) / (‖a‖ · ‖b‖ + ε)` where `ε = 1e-12`.  If either
103    /// vector is the zero vector the function returns `0.0` (not `NaN`).
104    ///
105    /// # Errors
106    ///
107    /// * [`SeqError::EmptyInput`] if either slice is empty.
108    /// * [`SeqError::LengthMismatch`] if `a.len() != b.len()`.
109    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> SeqResult<f32> {
110        if a.is_empty() || b.is_empty() {
111            return Err(SeqError::EmptyInput);
112        }
113        if a.len() != b.len() {
114            return Err(SeqError::LengthMismatch {
115                a: a.len(),
116                b: b.len(),
117            });
118        }
119        let mut dot = 0.0_f32;
120        let mut norm_a = 0.0_f32;
121        let mut norm_b = 0.0_f32;
122        for (x, y) in a.iter().zip(b.iter()) {
123            dot += x * y;
124            norm_a += x * x;
125            norm_b += y * y;
126        }
127        let denom = norm_a.sqrt() * norm_b.sqrt() + 1e-12_f32;
128        Ok(dot / denom)
129    }
130
131    /// Degeneration penalty for a candidate token given context hidden states.
132    ///
133    /// The penalty is the **maximum** cosine similarity between
134    /// `candidate_hidden` and any of the `n_context` previously-seen context
135    /// hidden states packed row-major in `context_hiddens`.
136    ///
137    /// When `n_context == 0` (no prior context), the penalty is `0.0`.
138    ///
139    /// # Parameters
140    ///
141    /// * `context_hiddens` — flat `[n_context × hidden_dim]` buffer.
142    /// * `n_context` — number of prior context tokens.
143    /// * `candidate_hidden` — `[hidden_dim]` hidden state for this candidate.
144    /// * `hidden_dim` — dimension of each hidden-state vector.
145    ///
146    /// # Errors
147    ///
148    /// * [`SeqError::ShapeMismatch`] if `context_hiddens.len() != n_context * hidden_dim`.
149    /// * [`SeqError::LengthMismatch`] if `candidate_hidden.len() != hidden_dim`.
150    /// * [`SeqError::EmptyInput`] if `hidden_dim == 0`.
151    pub fn degeneration_penalty(
152        context_hiddens: &[f32],
153        n_context: usize,
154        candidate_hidden: &[f32],
155        hidden_dim: usize,
156    ) -> SeqResult<f32> {
157        if hidden_dim == 0 {
158            return Err(SeqError::EmptyInput);
159        }
160        if context_hiddens.len() != n_context * hidden_dim {
161            return Err(SeqError::ShapeMismatch {
162                expected: n_context * hidden_dim,
163                got: context_hiddens.len(),
164            });
165        }
166        if candidate_hidden.len() != hidden_dim {
167            return Err(SeqError::LengthMismatch {
168                a: candidate_hidden.len(),
169                b: hidden_dim,
170            });
171        }
172        if n_context == 0 {
173            return Ok(0.0);
174        }
175
176        let mut max_sim = f32::NEG_INFINITY;
177        for t in 0..n_context {
178            let ctx_slice = &context_hiddens[t * hidden_dim..(t + 1) * hidden_dim];
179            let sim = Self::cosine_similarity(ctx_slice, candidate_hidden)?;
180            if sim > max_sim {
181                max_sim = sim;
182            }
183        }
184        Ok(max_sim)
185    }
186
187    /// Select the top-k tokens by logit value, returning `(token_id, prob)`
188    /// pairs sorted by softmax probability in **descending** order.
189    ///
190    /// Softmax is computed over **all** logits for correct probability values;
191    /// only the top-k are returned.  If `k >= vocab_size`, all tokens are
192    /// returned.
193    ///
194    /// # Errors
195    ///
196    /// * [`SeqError::EmptyInput`] if `logits` is empty.
197    /// * [`SeqError::InvalidConfiguration`] if `k == 0`.
198    pub fn top_k_candidates(logits: &[f32], k: usize) -> SeqResult<Vec<(usize, f32)>> {
199        if logits.is_empty() {
200            return Err(SeqError::EmptyInput);
201        }
202        if k == 0 {
203            return Err(SeqError::InvalidConfiguration(
204                "contrastive: k must be >= 1".to_string(),
205            ));
206        }
207        let vocab = logits.len();
208        let k_eff = k.min(vocab);
209
210        // Partial-sort: find the top-k indices by raw logit (descending).
211        let mut indices: Vec<usize> = (0..vocab).collect();
212        indices.sort_by(|&a, &b| {
213            logits[b]
214                .partial_cmp(&logits[a])
215                .unwrap_or(std::cmp::Ordering::Equal)
216        });
217        indices.truncate(k_eff);
218
219        // Full-vocabulary numerically-stable softmax.
220        let max_l = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
221        let mut exps = vec![0.0_f32; vocab];
222        let mut sum = 0.0_f32;
223        for (i, &l) in logits.iter().enumerate() {
224            let e = (l - max_l).exp();
225            exps[i] = e;
226            sum += e;
227        }
228        // Guard against degenerate all-NEG_INFINITY logits.
229        let sum_safe = if sum > 0.0 && sum.is_finite() {
230            sum
231        } else {
232            1.0
233        };
234
235        // Collect (token_id, prob) for top-k candidates.
236        let mut candidates: Vec<(usize, f32)> = indices
237            .iter()
238            .map(|&idx| (idx, exps[idx] / sum_safe))
239            .collect();
240
241        // Sort by probability descending.
242        candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
243
244        Ok(candidates)
245    }
246
247    /// Combine model probability and degeneration penalty into a contrastive
248    /// score.
249    ///
250    /// ```text
251    /// score = (1 − alpha) · prob − alpha · degen_penalty
252    /// ```
253    #[inline]
254    pub fn contrastive_score(prob: f32, degen_penalty: f32, alpha: f32) -> f32 {
255        (1.0 - alpha) * prob - alpha * degen_penalty
256    }
257
258    /// Run contrastive search decoding for `cfg.max_len` steps using a
259    /// caller-provided step function that also returns hidden states.
260    ///
261    /// # Step function contract
262    ///
263    /// ```text
264    /// step_fn(selected_token_id: usize, last_hidden: &[f32])
265    ///     -> (logits: Vec<f32>, next_hidden: Vec<f32>)
266    /// ```
267    ///
268    /// `logits` must have length `vocab_size`; `next_hidden` must have length
269    /// `hidden_dim` and represents the hidden state **for the selected token**.
270    ///
271    /// # Initial hidden states
272    ///
273    /// `initial_hiddens` must be a flat `[vocab_size × hidden_dim]` buffer
274    /// providing one hidden state per vocabulary token for the very first step.
275    /// If your model produces a single shared hidden state at step 0 rather
276    /// than one per token, replicate it `vocab_size` times before calling.
277    ///
278    /// # Errors
279    ///
280    /// * [`SeqError::InvalidConfiguration`] if `k == 0` or `alpha ∉ [0, 1]`.
281    /// * [`SeqError::ShapeMismatch`] if `initial_logits.len() != vocab_size` or
282    ///   `initial_hiddens.len() != vocab_size * hidden_dim`.
283    /// * [`SeqError::EmptyInput`] if `vocab_size == 0` or `hidden_dim == 0`.
284    pub fn decode<F>(
285        initial_logits: &[f32],
286        initial_hiddens: &[f32],
287        vocab_size: usize,
288        hidden_dim: usize,
289        step_fn: F,
290        cfg: &ContrastiveConfig,
291    ) -> SeqResult<Vec<usize>>
292    where
293        F: Fn(usize, &[f32]) -> (Vec<f32>, Vec<f32>),
294    {
295        cfg.validate()?;
296        if vocab_size == 0 || hidden_dim == 0 {
297            return Err(SeqError::EmptyInput);
298        }
299        if initial_logits.len() != vocab_size {
300            return Err(SeqError::ShapeMismatch {
301                expected: vocab_size,
302                got: initial_logits.len(),
303            });
304        }
305        if initial_hiddens.len() != vocab_size * hidden_dim {
306            return Err(SeqError::ShapeMismatch {
307                expected: vocab_size * hidden_dim,
308                got: initial_hiddens.len(),
309            });
310        }
311
312        let mut generated: Vec<usize> = Vec::with_capacity(cfg.max_len);
313        // Accumulate context hidden states: [t × hidden_dim]
314        let mut context_hiddens: Vec<f32> = Vec::new();
315
316        // ---- Step 0 ------------------------------------------------
317        let candidates_0 = Self::top_k_candidates(initial_logits, cfg.k)?;
318
319        let mut best_score = f32::NEG_INFINITY;
320        let mut best_token = candidates_0[0].0;
321        let mut best_hidden: Vec<f32> =
322            initial_hiddens[best_token * hidden_dim..(best_token + 1) * hidden_dim].to_vec();
323
324        for (tok, prob) in &candidates_0 {
325            // At step 0 there is no context yet: degeneration penalty is 0.
326            let score = Self::contrastive_score(*prob, 0.0, cfg.alpha);
327            if score > best_score {
328                best_score = score;
329                best_token = *tok;
330                best_hidden = initial_hiddens[tok * hidden_dim..(tok + 1) * hidden_dim].to_vec();
331            }
332        }
333
334        generated.push(best_token);
335        context_hiddens.extend_from_slice(&best_hidden);
336        let mut last_hidden = best_hidden;
337
338        // ---- Steps 1..max_len ----------------------------------------
339        for _step in 1..cfg.max_len {
340            let (next_logits, next_hidden) = step_fn(generated[generated.len() - 1], &last_hidden);
341
342            if next_logits.len() != vocab_size {
343                return Err(SeqError::ShapeMismatch {
344                    expected: vocab_size,
345                    got: next_logits.len(),
346                });
347            }
348            if next_hidden.len() != hidden_dim {
349                return Err(SeqError::ShapeMismatch {
350                    expected: hidden_dim,
351                    got: next_hidden.len(),
352                });
353            }
354
355            let candidates = Self::top_k_candidates(&next_logits, cfg.k)?;
356            let n_ctx = context_hiddens.len() / hidden_dim;
357
358            let mut step_best_score = f32::NEG_INFINITY;
359            let mut step_best_token = candidates[0].0;
360
361            for (tok, prob) in &candidates {
362                // Use next_hidden as the proxy hidden state for all candidates.
363                // In a true implementation each candidate would have its own
364                // forward pass; here we use the single next_hidden (for the
365                // selected token) as the representative for the candidate set.
366                let degen =
367                    Self::degeneration_penalty(&context_hiddens, n_ctx, &next_hidden, hidden_dim)?;
368                let score = Self::contrastive_score(*prob, degen, cfg.alpha);
369                if score > step_best_score {
370                    step_best_score = score;
371                    step_best_token = *tok;
372                }
373            }
374
375            generated.push(step_best_token);
376            context_hiddens.extend_from_slice(&next_hidden);
377            last_hidden = next_hidden;
378        }
379
380        Ok(generated)
381    }
382
383    /// Simplified contrastive search that works with a logit-only step
384    /// function.
385    ///
386    /// Because hidden states are unavailable, the past **logit vectors** serve
387    /// as proxy hidden states: the degeneration penalty for a candidate at step
388    /// `t` is the maximum cosine similarity between the *current* logit vector
389    /// and each of the past logit vectors stored in the context.
390    ///
391    /// The degeneration is therefore shared across all candidates at each step
392    /// (it measures how similar the new distribution is to past ones), which
393    /// implicitly penalises repetitive distributions.
394    ///
395    /// # Step function contract
396    ///
397    /// ```text
398    /// step_fn(selected_token_id: usize) -> logits: Vec<f32>   [vocab_size]
399    /// ```
400    ///
401    /// # Errors
402    ///
403    /// * [`SeqError::InvalidConfiguration`] if `k == 0` or `alpha ∉ [0, 1]`.
404    /// * [`SeqError::EmptyInput`] if `initial_logits` is empty.
405    pub fn decode_logits_only<F>(
406        initial_logits: &[f32],
407        step_fn: F,
408        cfg: &ContrastiveConfig,
409    ) -> SeqResult<Vec<usize>>
410    where
411        F: Fn(usize) -> Vec<f32>,
412    {
413        cfg.validate()?;
414        if initial_logits.is_empty() {
415            return Err(SeqError::EmptyInput);
416        }
417
418        let vocab_size = initial_logits.len();
419        let mut generated: Vec<usize> = Vec::with_capacity(cfg.max_len);
420        // Context logit vectors act as proxy hidden states.
421        // Each entry is one logit vector [vocab_size].
422        let mut context_logits_flat: Vec<f32> = Vec::new();
423
424        // ---- Step 0 ------------------------------------------------
425        let candidates_0 = Self::top_k_candidates(initial_logits, cfg.k)?;
426
427        // At step 0 there is no context; degeneration = 0 for all candidates.
428        let mut best_score = f32::NEG_INFINITY;
429        let mut best_token = candidates_0[0].0;
430        for (tok, prob) in &candidates_0 {
431            let score = Self::contrastive_score(*prob, 0.0, cfg.alpha);
432            if score > best_score {
433                best_score = score;
434                best_token = *tok;
435            }
436        }
437
438        generated.push(best_token);
439        // Store initial_logits as the first context entry.
440        context_logits_flat.extend_from_slice(initial_logits);
441
442        // ---- Steps 1..max_len ----------------------------------------
443        for _step in 1..cfg.max_len {
444            let next_logits = step_fn(generated[generated.len() - 1]);
445
446            if next_logits.is_empty() {
447                return Err(SeqError::EmptyInput);
448            }
449            let cur_vocab = next_logits.len();
450            // Allow vocab to grow/shrink within a step (though unusual), but
451            // we need a consistent dimension for cosine similarity.  Use
452            // the minimum dimension when comparing with past logit vectors.
453            let dim = cur_vocab.min(vocab_size);
454
455            let candidates = Self::top_k_candidates(&next_logits, cfg.k)?;
456            let n_ctx = context_logits_flat.len() / vocab_size;
457
458            // Compute degeneration penalty using the current logit vector as
459            // the "candidate hidden state", compared against past logit vectors.
460            // All candidates share this penalty (distribution-level penalty).
461            let mut degen = 0.0_f32;
462            for t in 0..n_ctx {
463                let ctx_slice = &context_logits_flat[t * vocab_size..t * vocab_size + dim];
464                let cand_slice = &next_logits[..dim];
465                let sim = Self::cosine_similarity(ctx_slice, cand_slice)?;
466                if sim > degen {
467                    degen = sim;
468                }
469            }
470
471            let mut step_best_score = f32::NEG_INFINITY;
472            let mut step_best_token = candidates[0].0;
473            for (tok, prob) in &candidates {
474                let score = Self::contrastive_score(*prob, degen, cfg.alpha);
475                if score > step_best_score {
476                    step_best_score = score;
477                    step_best_token = *tok;
478                }
479            }
480
481            generated.push(step_best_token);
482            // Extend context with the current logit vector (padded/truncated to vocab_size).
483            let mut entry = next_logits.clone();
484            entry.resize(vocab_size, 0.0);
485            context_logits_flat.extend_from_slice(&entry);
486        }
487
488        Ok(generated)
489    }
490}
491
#[cfg(test)]
mod tests {
    use super::*;

    // -----------------------------------------------------------------------
    // ContrastiveConfig tests
    // -----------------------------------------------------------------------

    #[test]
    fn default_config_values() {
        let cfg = ContrastiveConfig::default();
        assert_eq!(cfg.k, 5);
        assert!((cfg.alpha - 0.6).abs() < 1e-6);
        assert_eq!(cfg.max_len, 50);
    }

    // -----------------------------------------------------------------------
    // cosine_similarity tests
    // -----------------------------------------------------------------------

    #[test]
    fn cosine_similarity_identical_vectors_is_one() {
        let v = vec![1.0_f32, 2.0, 3.0];
        let sim = ContrastiveSearcher::cosine_similarity(&v, &v).expect("ok");
        assert!((sim - 1.0).abs() < 1e-5, "got {sim}");
    }

    #[test]
    fn cosine_similarity_orthogonal_is_zero() {
        let a = vec![1.0_f32, 0.0];
        let b = vec![0.0_f32, 1.0];
        let sim = ContrastiveSearcher::cosine_similarity(&a, &b).expect("ok");
        assert!(sim.abs() < 1e-6, "got {sim}");
    }

    #[test]
    fn cosine_similarity_zero_vector_is_zero_not_nan() {
        let a = vec![0.0_f32, 0.0, 0.0];
        let b = vec![1.0_f32, 2.0, 3.0];
        let sim = ContrastiveSearcher::cosine_similarity(&a, &b).expect("ok");
        assert!(!sim.is_nan(), "must not be NaN");
        assert!(sim.abs() < 1e-6, "got {sim}");
    }

    #[test]
    fn cosine_similarity_is_scale_invariant() {
        let a = vec![1.0_f32, 2.0, 3.0];
        let b: Vec<f32> = a.iter().map(|x| x * 10.0).collect();
        let sim = ContrastiveSearcher::cosine_similarity(&a, &b).expect("ok");
        assert!((sim - 1.0).abs() < 1e-5, "got {sim}");
    }

    #[test]
    fn cosine_similarity_length_mismatch_error() {
        let a = vec![1.0_f32, 2.0];
        let b = vec![1.0_f32, 2.0, 3.0];
        let err = ContrastiveSearcher::cosine_similarity(&a, &b).unwrap_err();
        assert!(matches!(err, SeqError::LengthMismatch { .. }));
    }

    #[test]
    fn cosine_similarity_empty_error() {
        let err = ContrastiveSearcher::cosine_similarity(&[], &[]).unwrap_err();
        assert!(matches!(err, SeqError::EmptyInput));
    }

    #[test]
    fn cosine_similarity_negative_vectors() {
        // Antiparallel vectors should give -1.
        let a = vec![1.0_f32, 0.0];
        let b = vec![-1.0_f32, 0.0];
        let sim = ContrastiveSearcher::cosine_similarity(&a, &b).expect("ok");
        assert!((sim + 1.0).abs() < 1e-5, "got {sim}");
    }

    // -----------------------------------------------------------------------
    // degeneration_penalty tests
    // -----------------------------------------------------------------------

    #[test]
    fn degeneration_penalty_no_context_is_zero() {
        let candidate = vec![1.0_f32, 2.0, 3.0];
        let pen = ContrastiveSearcher::degeneration_penalty(&[], 0, &candidate, 3).expect("ok");
        assert!(pen.abs() < 1e-6, "got {pen}");
    }

    #[test]
    fn degeneration_penalty_identical_context_is_one() {
        // If the candidate is identical to a context token, penalty = 1.0.
        let hidden = vec![1.0_f32, 0.0, 0.0];
        let context = hidden.clone();
        let pen = ContrastiveSearcher::degeneration_penalty(&context, 1, &hidden, 3).expect("ok");
        assert!((pen - 1.0).abs() < 1e-5, "got {pen}");
    }

    #[test]
    fn degeneration_penalty_orthogonal_context_is_zero() {
        let context = vec![1.0_f32, 0.0];
        let candidate = vec![0.0_f32, 1.0];
        let pen =
            ContrastiveSearcher::degeneration_penalty(&context, 1, &candidate, 2).expect("ok");
        assert!(pen.abs() < 1e-6, "got {pen}");
    }

    #[test]
    fn degeneration_penalty_multiple_context_returns_max() {
        // Two context vectors: first orthogonal, second identical to candidate.
        let dim = 2usize;
        let mut ctx = vec![1.0_f32, 0.0]; // orthogonal to candidate
        ctx.extend_from_slice(&[0.0, 1.0]); // identical to candidate
        let candidate = vec![0.0_f32, 1.0];
        let pen = ContrastiveSearcher::degeneration_penalty(&ctx, 2, &candidate, dim).expect("ok");
        // Max should be ~1.0 (the identical pair).
        assert!((pen - 1.0).abs() < 1e-5, "got {pen}");
    }

    #[test]
    fn degeneration_penalty_shape_mismatch_error() {
        let err = ContrastiveSearcher::degeneration_penalty(
            &[1.0, 2.0],
            2, // expects 2 * 3 = 6 floats
            &[1.0, 2.0, 3.0],
            3,
        )
        .unwrap_err();
        assert!(matches!(err, SeqError::ShapeMismatch { .. }));
    }

    #[test]
    fn degeneration_penalty_candidate_length_mismatch_error() {
        let err =
            ContrastiveSearcher::degeneration_penalty(&[1.0, 0.0], 1, &[1.0, 0.0, 0.0], 2)
                .unwrap_err();
        assert!(matches!(err, SeqError::LengthMismatch { .. }));
    }

    #[test]
    fn degeneration_penalty_zero_hidden_dim_error() {
        let err = ContrastiveSearcher::degeneration_penalty(&[], 0, &[], 0).unwrap_err();
        assert!(matches!(err, SeqError::EmptyInput));
    }

    // -----------------------------------------------------------------------
    // top_k_candidates tests
    // -----------------------------------------------------------------------

    #[test]
    fn top_k_k_equals_one_returns_argmax() {
        let logits = vec![-1.0_f32, 5.0, 2.0, 0.5];
        let cands = ContrastiveSearcher::top_k_candidates(&logits, 1).expect("ok");
        assert_eq!(cands.len(), 1);
        assert_eq!(cands[0].0, 1, "argmax should be token 1");
    }

    #[test]
    fn top_k_k_ge_vocab_returns_all() {
        let logits = vec![1.0_f32, 2.0, 0.5];
        let cands = ContrastiveSearcher::top_k_candidates(&logits, 100).expect("ok");
        assert_eq!(cands.len(), 3, "should return all 3 tokens");
    }

    #[test]
    fn top_k_returns_highest_logits_in_order() {
        let logits = vec![1.0_f32, 3.0, 2.0, 0.5];
        let cands = ContrastiveSearcher::top_k_candidates(&logits, 2).expect("ok");
        let ids: Vec<usize> = cands.iter().map(|&(id, _)| id).collect();
        assert_eq!(ids, vec![1, 2]);
    }

    #[test]
    fn top_k_ties_prefer_lower_index() {
        let logits = vec![2.0_f32, 1.0, 2.0, 2.0];
        let cands = ContrastiveSearcher::top_k_candidates(&logits, 2).expect("ok");
        let ids: Vec<usize> = cands.iter().map(|&(id, _)| id).collect();
        assert_eq!(ids, vec![0, 2]);
    }

    #[test]
    fn top_k_probs_are_valid_softmax() {
        let logits = vec![1.0_f32, 2.0, 0.5, -1.0, 3.0];
        let cands = ContrastiveSearcher::top_k_candidates(&logits, 3).expect("ok");
        // Probs should be positive.
        for (_, prob) in &cands {
            assert!(*prob > 0.0, "prob must be positive");
        }
        // Sum of all vocab probs ≈ 1; we only have top-k, so sum ≤ 1.
        let partial_sum: f32 = cands.iter().map(|(_, p)| p).sum();
        assert!(partial_sum <= 1.0 + 1e-5, "partial sum {partial_sum} > 1");
    }

    #[test]
    fn top_k_full_vocab_probs_sum_to_one() {
        let logits = vec![0.3_f32, -1.2, 2.5, 0.0];
        let cands = ContrastiveSearcher::top_k_candidates(&logits, 4).expect("ok");
        let total: f32 = cands.iter().map(|(_, p)| p).sum();
        assert!((total - 1.0).abs() < 1e-5, "total {total}");
    }

    #[test]
    fn top_k_sorted_descending_by_prob() {
        let logits = vec![1.0_f32, 3.0, 2.0, 0.5];
        let cands = ContrastiveSearcher::top_k_candidates(&logits, 4).expect("ok");
        for i in 1..cands.len() {
            assert!(
                cands[i - 1].1 >= cands[i].1,
                "probs should be non-increasing: {:?}",
                cands
            );
        }
    }

    #[test]
    fn top_k_empty_logits_error() {
        let err = ContrastiveSearcher::top_k_candidates(&[], 3).unwrap_err();
        assert!(matches!(err, SeqError::EmptyInput));
    }

    #[test]
    fn top_k_k_zero_error() {
        let err = ContrastiveSearcher::top_k_candidates(&[1.0, 2.0], 0).unwrap_err();
        assert!(matches!(err, SeqError::InvalidConfiguration(_)));
    }

    // -----------------------------------------------------------------------
    // contrastive_score tests
    // -----------------------------------------------------------------------

    #[test]
    fn contrastive_score_alpha_zero_equals_prob() {
        let score = ContrastiveSearcher::contrastive_score(0.7, 0.9, 0.0);
        assert!((score - 0.7).abs() < 1e-6, "got {score}");
    }

    #[test]
    fn contrastive_score_alpha_one_equals_neg_degen() {
        let score = ContrastiveSearcher::contrastive_score(0.7, 0.4, 1.0);
        assert!((score + 0.4).abs() < 1e-6, "got {score}");
    }

    #[test]
    fn contrastive_score_midpoint() {
        let score = ContrastiveSearcher::contrastive_score(0.8, 0.5, 0.5);
        // (1-0.5)*0.8 - 0.5*0.5 = 0.4 - 0.25 = 0.15
        assert!((score - 0.15).abs() < 1e-6, "got {score}");
    }

    // -----------------------------------------------------------------------
    // decode_logits_only tests
    // -----------------------------------------------------------------------

    #[test]
    fn decode_logits_only_length_matches_max_len() {
        let initial = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let cfg = ContrastiveConfig {
            k: 3,
            alpha: 0.5,
            max_len: 5,
        };
        let seq = ContrastiveSearcher::decode_logits_only(
            &initial,
            |_tok| vec![1.0_f32, 2.0, 3.0, 4.0, 5.0],
            &cfg,
        )
        .expect("ok");
        assert_eq!(seq.len(), 5);
    }

    #[test]
    fn decode_logits_only_constant_step_fn_valid_tokens() {
        let initial = vec![0.0_f32, 1.0, -1.0, 2.0];
        let cfg = ContrastiveConfig {
            k: 2,
            alpha: 0.4,
            max_len: 10,
        };
        let seq = ContrastiveSearcher::decode_logits_only(
            &initial,
            |_tok| vec![0.0_f32, 1.0, -1.0, 2.0],
            &cfg,
        )
        .expect("ok");
        assert_eq!(seq.len(), 10);
        for tok in &seq {
            assert!(*tok < 4, "token {tok} out of vocab");
        }
    }

    #[test]
    fn decode_logits_only_can_produce_repetition() {
        // When the step function always returns the same logits and alpha=0
        // (no degeneration penalty), the argmax (greedy) token is always
        // the same token — demonstrating that contrastive search CAN repeat
        // when α=0.
        let initial = vec![0.0_f32, 5.0, 1.0];
        let cfg = ContrastiveConfig {
            k: 1,
            alpha: 0.0,
            max_len: 5,
        };
        let seq =
            ContrastiveSearcher::decode_logits_only(&initial, |_tok| vec![0.0_f32, 5.0, 1.0], &cfg)
                .expect("ok");
        // With k=1 and alpha=0, must always pick token 1 (greedy).
        for tok in &seq {
            assert_eq!(*tok, 1);
        }
    }

    #[test]
    fn decode_logits_only_high_alpha_stays_in_vocab() {
        // In logit-only mode the degeneration penalty is shared by every
        // candidate at a step, so it cannot reorder candidates.  This test
        // only checks that a high alpha runs cleanly with a constant step
        // function and yields the requested number of in-vocab tokens.
        let vocab = 8usize;
        let initial: Vec<f32> = (0..vocab).map(|i| i as f32).collect();
        let cfg = ContrastiveConfig {
            k: 4,
            alpha: 0.8,
            max_len: 20,
        };
        let seq = ContrastiveSearcher::decode_logits_only(
            &initial,
            |_tok| (0..vocab).map(|i| i as f32).collect(),
            &cfg,
        )
        .expect("ok");
        assert_eq!(seq.len(), 20);
        for tok in &seq {
            assert!(*tok < vocab);
        }
    }

    #[test]
    fn decode_logits_only_tolerates_vocab_size_changes() {
        // Shorter step outputs are compared over the common prefix and padded
        // when stored; generation must still succeed.
        let initial = vec![0.5_f32, 1.0, 1.5, 2.0];
        let cfg = ContrastiveConfig {
            k: 2,
            alpha: 0.5,
            max_len: 4,
        };
        let seq = ContrastiveSearcher::decode_logits_only(
            &initial,
            |_tok| vec![2.0_f32, 1.0, 0.0],
            &cfg,
        )
        .expect("ok");
        assert_eq!(seq.len(), 4);
        assert_eq!(seq[0], 3, "step 0 uses the initial logits");
        for tok in &seq[1..] {
            assert!(*tok < 3, "token {tok} outside the shrunken vocab");
        }
    }

    #[test]
    fn decode_logits_only_k_zero_error() {
        let cfg = ContrastiveConfig {
            k: 0,
            alpha: 0.5,
            max_len: 5,
        };
        let err = ContrastiveSearcher::decode_logits_only(&[1.0, 2.0], |_| vec![1.0, 2.0], &cfg)
            .unwrap_err();
        assert!(matches!(err, SeqError::InvalidConfiguration(_)));
    }

    #[test]
    fn decode_logits_only_alpha_above_one_error() {
        let cfg = ContrastiveConfig {
            k: 3,
            alpha: 1.5,
            max_len: 5,
        };
        let err = ContrastiveSearcher::decode_logits_only(&[1.0, 2.0], |_| vec![1.0, 2.0], &cfg)
            .unwrap_err();
        assert!(matches!(err, SeqError::InvalidConfiguration(_)));
    }

    #[test]
    fn decode_logits_only_nan_alpha_error() {
        let cfg = ContrastiveConfig {
            k: 3,
            alpha: f32::NAN,
            max_len: 5,
        };
        let err = ContrastiveSearcher::decode_logits_only(&[1.0, 2.0], |_| vec![1.0, 2.0], &cfg)
            .unwrap_err();
        assert!(matches!(err, SeqError::InvalidConfiguration(_)));
    }

    #[test]
    fn decode_logits_only_empty_logits_error() {
        let cfg = ContrastiveConfig::default();
        let err = ContrastiveSearcher::decode_logits_only(&[], |_| vec![], &cfg).unwrap_err();
        assert!(matches!(err, SeqError::EmptyInput));
    }

    // -----------------------------------------------------------------------
    // decode (with hidden states) tests
    // -----------------------------------------------------------------------

    #[test]
    fn decode_with_hidden_states_length_matches_max_len() {
        let vocab = 4usize;
        let hidden_dim = 3usize;
        let initial_logits = vec![1.0_f32, 2.0, 3.0, 0.5];
        // One [hidden_dim] vector per vocabulary token.
        let initial_hiddens: Vec<f32> = (0..vocab * hidden_dim).map(|i| i as f32 * 0.1).collect();
        let cfg = ContrastiveConfig {
            k: 2,
            alpha: 0.5,
            max_len: 7,
        };

        let seq = ContrastiveSearcher::decode(
            &initial_logits,
            &initial_hiddens,
            vocab,
            hidden_dim,
            |_tok, _last| {
                let logits = vec![0.5_f32, 1.5, 2.5, 0.2];
                let hidden = vec![0.1_f32, 0.2, 0.3];
                (logits, hidden)
            },
            &cfg,
        )
        .expect("ok");
        assert_eq!(seq.len(), 7);
    }

    #[test]
    fn decode_with_hidden_states_valid_token_ids() {
        let vocab = 5usize;
        let hidden_dim = 4usize;
        let initial_logits: Vec<f32> = vec![1.0, 2.0, 3.0, 0.5, 1.5];
        let initial_hiddens: Vec<f32> = (0..vocab * hidden_dim).map(|i| (i as f32).sin()).collect();
        let cfg = ContrastiveConfig {
            k: 3,
            alpha: 0.6,
            max_len: 12,
        };
        let seq = ContrastiveSearcher::decode(
            &initial_logits,
            &initial_hiddens,
            vocab,
            hidden_dim,
            |_tok, _last| {
                let logits: Vec<f32> = vec![0.1, 0.5, 2.0, 1.0, 0.3];
                let hidden: Vec<f32> = vec![0.5, -0.5, 0.3, -0.3];
                (logits, hidden)
            },
            &cfg,
        )
        .expect("ok");
        for tok in &seq {
            assert!(*tok < vocab, "token {tok} out of range");
        }
    }

    #[test]
    fn decode_empty_vocab_error() {
        let cfg = ContrastiveConfig::default();
        let err = ContrastiveSearcher::decode(&[], &[], 0, 4, |_tok, _h| (vec![], vec![]), &cfg)
            .unwrap_err();
        assert!(matches!(err, SeqError::EmptyInput));
    }

    #[test]
    fn decode_initial_hiddens_shape_mismatch_error() {
        let cfg = ContrastiveConfig::default();
        let err = ContrastiveSearcher::decode(
            &[1.0, 2.0],
            &[0.0; 3], // expects 2 * 2 = 4 floats
            2,
            2,
            |_tok, _h| (vec![0.0; 2], vec![0.0; 2]),
            &cfg,
        )
        .unwrap_err();
        assert!(matches!(err, SeqError::ShapeMismatch { .. }));
    }

    #[test]
    fn decode_step_fn_wrong_logits_len_error() {
        let cfg = ContrastiveConfig {
            k: 2,
            alpha: 0.5,
            max_len: 3,
        };
        let err = ContrastiveSearcher::decode(
            &[1.0, 2.0],
            &[0.1, 0.2, 0.3, 0.4],
            2,
            2,
            |_tok, _h| (vec![0.0; 3], vec![0.0; 2]),
            &cfg,
        )
        .unwrap_err();
        assert!(matches!(err, SeqError::ShapeMismatch { .. }));
    }
}