pub struct ContrastiveSearcher;
Stateless struct providing all contrastive search decoding primitives.
All methods are associated functions (they take no self receiver) and operate
only on their explicit inputs; the struct itself holds no state. The generation
loops in Self::decode and Self::decode_logits_only keep local state while they
run, but none of it persists between calls.
Implementations
impl ContrastiveSearcher
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> SeqResult<f32>
Cosine similarity between two equal-length, non-empty vectors.
Returns (a·b) / (‖a‖ · ‖b‖ + ε) where ε = 1e-12. If either
vector is the zero vector, the numerator is 0 and the denominator is ε, so the
function returns 0.0 (not NaN).
Errors
- SeqError::EmptyInput if either slice is empty.
- SeqError::LengthMismatch if a.len() != b.len().
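The formula and error cases above can be sketched in a few lines of Rust. This is a minimal sketch, not the crate's implementation: SimError is a hypothetical stand-in for SeqError (whose variant payloads are not shown on this page), and the order of the two checks is an assumption.

```rust
/// Hypothetical stand-in for SeqError; the real variants may carry data.
#[derive(Debug, PartialEq)]
pub enum SimError {
    EmptyInput,
    LengthMismatch,
}

/// Cosine similarity with an epsilon-guarded denominator.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> Result<f32, SimError> {
    if a.is_empty() || b.is_empty() {
        return Err(SimError::EmptyInput);
    }
    if a.len() != b.len() {
        return Err(SimError::LengthMismatch);
    }
    const EPS: f32 = 1e-12;
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    // With a zero vector, dot == 0 and the denominator is EPS, so the result is 0.0, not NaN.
    Ok(dot / (norm_a * norm_b + EPS))
}
```

Because the denominator always carries ε, a zero vector yields 0.0 instead of 0/0 = NaN.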
pub fn degeneration_penalty(
context_hiddens: &[f32],
n_context: usize,
candidate_hidden: &[f32],
hidden_dim: usize,
) -> SeqResult<f32>
Degeneration penalty for a candidate token given context hidden states.
The penalty is the maximum cosine similarity between
candidate_hidden and any of the n_context previously seen context
hidden states packed row-major in context_hiddens.
When n_context == 0 (no prior context), the penalty is 0.0.
Parameters
- context_hiddens: flat [n_context × hidden_dim] buffer.
- n_context: number of prior context tokens.
- candidate_hidden: [hidden_dim] hidden state for this candidate.
- hidden_dim: dimension of each hidden-state vector.
Errors
- SeqError::ShapeMismatch if context_hiddens.len() != n_context * hidden_dim.
- SeqError::LengthMismatch if candidate_hidden.len() != hidden_dim.
- SeqError::EmptyInput if hidden_dim == 0.
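Here is a sketch of the penalty computation, assuming the row-major layout described above. PenaltyError is a hypothetical stand-in for SeqError, cosine is a local helper mirroring cosine_similarity without its error checks, and checking hidden_dim == 0 first is an assumption (it also keeps chunks_exact from panicking on a zero chunk size).

```rust
/// Hypothetical stand-in for SeqError.
#[derive(Debug, PartialEq)]
pub enum PenaltyError {
    ShapeMismatch,
    LengthMismatch,
    EmptyInput,
}

/// Epsilon-guarded cosine similarity (no input validation).
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let na = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let nb = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (na * nb + 1e-12)
}

/// Max cosine similarity between the candidate and any context row.
pub fn degeneration_penalty(
    context_hiddens: &[f32],
    n_context: usize,
    candidate_hidden: &[f32],
    hidden_dim: usize,
) -> Result<f32, PenaltyError> {
    if hidden_dim == 0 {
        return Err(PenaltyError::EmptyInput);
    }
    if context_hiddens.len() != n_context * hidden_dim {
        return Err(PenaltyError::ShapeMismatch);
    }
    if candidate_hidden.len() != hidden_dim {
        return Err(PenaltyError::LengthMismatch);
    }
    if n_context == 0 {
        return Ok(0.0); // no prior context: no penalty
    }
    // Each chunk of hidden_dim floats is one context row (row-major layout).
    let max = context_hiddens
        .chunks_exact(hidden_dim)
        .map(|row| cosine(row, candidate_hidden))
        .fold(f32::NEG_INFINITY, f32::max);
    Ok(max)
}
```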
pub fn top_k_candidates(
logits: &[f32],
k: usize,
) -> SeqResult<Vec<(usize, f32)>>
Select the top-k tokens by logit value, returning (token_id, prob)
pairs sorted by softmax probability in descending order.
Softmax is computed over the full logit vector, so each prob is normalised
over the whole vocabulary rather than over the top k alone; only the top-k
pairs are returned. If k >= logits.len() (the vocabulary size), all tokens are
returned.
Errors
- SeqError::EmptyInput if logits is empty.
- SeqError::InvalidConfiguration if k == 0.
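The behaviour described above (softmax over the whole vocabulary, then truncation to k) might look like the following sketch. TopKError is a hypothetical stand-in for SeqError; subtracting the maximum logit before exponentiating is a standard numerical-stability choice and an assumption about this crate.

```rust
/// Hypothetical stand-in for SeqError.
#[derive(Debug, PartialEq)]
pub enum TopKError {
    EmptyInput,
    InvalidConfiguration,
}

/// Softmax over all logits, then keep the k most probable (token_id, prob) pairs.
pub fn top_k_candidates(logits: &[f32], k: usize) -> Result<Vec<(usize, f32)>, TopKError> {
    if logits.is_empty() {
        return Err(TopKError::EmptyInput);
    }
    if k == 0 {
        return Err(TopKError::InvalidConfiguration);
    }
    // Numerically stable softmax over the *full* vocabulary.
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    let mut probs: Vec<(usize, f32)> =
        exps.iter().enumerate().map(|(i, &e)| (i, e / sum)).collect();
    // Sort descending by probability; total_cmp gives a total order over f32.
    probs.sort_by(|a, b| b.1.total_cmp(&a.1));
    probs.truncate(k); // no-op when k >= logits.len()
    Ok(probs)
}
```

Since the probabilities are normalised over all logits, the returned values generally sum to less than 1 when k < logits.len().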
pub fn contrastive_score(prob: f32, degen_penalty: f32, alpha: f32) -> f32
Combine model probability and degeneration penalty into a contrastive score.
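score = (1 − alpha) · prob − alpha · degen_penalty
With alpha = 0 the score is the probability alone (greedy selection among the candidates); with alpha = 1 only the penalty counts.
pub fn decode<F>(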
initial_logits: &[f32],
initial_hiddens: &[f32],
vocab_size: usize,
hidden_dim: usize,
step_fn: F,
cfg: &ContrastiveConfig,
) -> SeqResult<Vec<usize>>
Run contrastive search decoding for cfg.max_len steps using a
caller-provided step function that also returns hidden states.
Step function contract
step_fn(selected_token_id: usize, last_hidden: &[f32]) -> (logits: Vec<f32>, next_hidden: Vec<f32>)
logits must have length vocab_size; next_hidden must have length
hidden_dim and represents the hidden state for the selected token.
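A toy closure with this shape is handy for exercising decode in tests. make_toy_step is hypothetical and not part of this crate, and because the trait bound on F is not shown on this page, returning an FnMut is an assumption.

```rust
/// Hypothetical toy step function with the documented shape:
/// (selected_token_id, last_hidden) -> (logits[vocab_size], next_hidden[hidden_dim]).
/// Panics if vocab_size or hidden_dim is 0 (decode rejects those anyway).
pub fn make_toy_step(
    vocab_size: usize,
    hidden_dim: usize,
) -> impl FnMut(usize, &[f32]) -> (Vec<f32>, Vec<f32>) {
    move |selected: usize, last_hidden: &[f32]| {
        assert_eq!(last_hidden.len(), hidden_dim, "last_hidden must have hidden_dim entries");
        // Favour the token after the selected one, wrapping around the vocabulary.
        let mut logits = vec![0.0f32; vocab_size];
        logits[(selected + 1) % vocab_size] = 5.0;
        // One-hot hidden state keyed on the selected token.
        let mut next_hidden = vec![0.0f32; hidden_dim];
        next_hidden[selected % hidden_dim] = 1.0;
        (logits, next_hidden)
    }
}
```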
Initial hidden states
initial_hiddens must be a flat [vocab_size × hidden_dim] buffer
providing one hidden state per vocabulary token for the very first step.
If your model produces a single shared hidden state at step 0 rather
than one per token, replicate it vocab_size times before calling.
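Replicating a shared step-0 hidden state is a single slice operation. replicate_hidden is a hypothetical helper name; it relies only on the standard library's slice repeat method.

```rust
/// Hypothetical helper: replicate one shared step-0 hidden state into the
/// flat [vocab_size × hidden_dim] row-major buffer expected by initial_hiddens.
pub fn replicate_hidden(shared: &[f32], vocab_size: usize) -> Vec<f32> {
    // `repeat` concatenates vocab_size copies of the slice, one row per token.
    shared.repeat(vocab_size)
}
```

Row i of the result, &buf[i * hidden_dim..(i + 1) * hidden_dim], is then the (identical) hidden state for token i.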
Errors
- SeqError::InvalidConfiguration if k == 0 or alpha ∉ [0, 1].
- SeqError::ShapeMismatch if initial_logits.len() != vocab_size or initial_hiddens.len() != vocab_size * hidden_dim.
- SeqError::EmptyInput if vocab_size == 0 or hidden_dim == 0.
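To make the interaction between top-k candidates, the degeneration penalty, and contrastive_score concrete, here is a sketch of a single selection step using score = (1 − alpha) · prob − alpha · degen_penalty. select_next is hypothetical: this page does not document how decode obtains per-candidate hidden states after the first step, so the sketch simply takes them as an argument.

```rust
/// Epsilon-guarded cosine similarity (no input validation).
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let na = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let nb = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (na * nb + 1e-12)
}

/// score = (1 − alpha) · prob − alpha · degen_penalty
pub fn contrastive_score(prob: f32, degen_penalty: f32, alpha: f32) -> f32 {
    (1.0 - alpha) * prob - alpha * degen_penalty
}

/// Hypothetical helper (not part of this crate): one contrastive selection step.
/// `candidates`: (token_id, prob) pairs, e.g. from top-k.
/// `candidate_hiddens`: one hidden_dim row per candidate, in the same order.
/// `context`: one hidden_dim row per previously generated token.
/// Returns None for an invalid alpha, hidden_dim == 0, or no candidates.
pub fn select_next(
    candidates: &[(usize, f32)],
    candidate_hiddens: &[f32],
    context: &[f32],
    hidden_dim: usize,
    alpha: f32,
) -> Option<usize> {
    if !(0.0..=1.0).contains(&alpha) || hidden_dim == 0 {
        return None;
    }
    let mut best: Option<(usize, f32)> = None;
    let rows = candidate_hiddens.chunks_exact(hidden_dim);
    for (&(token, prob), hidden) in candidates.iter().zip(rows) {
        // Degeneration penalty: max cosine similarity to any context row, 0.0 if none.
        let penalty = context
            .chunks_exact(hidden_dim)
            .map(|row| cosine(row, hidden))
            .fold(None, |acc: Option<f32>, s| Some(acc.map_or(s, |m| m.max(s))))
            .unwrap_or(0.0);
        let score = contrastive_score(prob, penalty, alpha);
        // Keep the first candidate on ties.
        if best.map_or(true, |(_, b)| score > b) {
            best = Some((token, score));
        }
    }
    best.map(|(token, _)| token)
}
```

With alpha = 0.5, a probable candidate whose hidden state duplicates the context can lose to a less probable but novel one; with alpha = 0 the most probable candidate always wins.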
pub fn decode_logits_only<F>(
initial_logits: &[f32],
step_fn: F,
cfg: &ContrastiveConfig,
) -> SeqResult<Vec<usize>>
Simplified contrastive search that works with a logit-only step function.
Because hidden states are unavailable, the past logit vectors serve
as proxy hidden states: the degeneration penalty for a candidate at step
t is the maximum cosine similarity between the current logit vector
and each of the past logit vectors stored in the context.
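The degeneration penalty is therefore the same for every candidate at a given step: it measures how similar the new distribution is to past ones, which implicitly penalises repetitive distributions. Because a shared penalty lowers every candidate's score by the same amount, it does not by itself change which candidate ranks highest within that step.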
Step function contract
step_fn(selected_token_id: usize) -> logits: Vec<f32> (length vocab_size)
Errors
- SeqError::InvalidConfiguration if k == 0 or alpha ∉ [0, 1].
- SeqError::EmptyInput if initial_logits is empty.
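The proxy penalty and its effect on candidate scores can be sketched as follows. logit_proxy_penalty and score_candidates are hypothetical names, and storing past logit vectors as Vec<Vec<f32>> is an assumption about the context layout.

```rust
/// Epsilon-guarded cosine similarity (no input validation).
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let na = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let nb = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (na * nb + 1e-12)
}

/// Hypothetical: max cosine similarity between the current logit vector and
/// each past logit vector, or 0.0 when there is no history.
pub fn logit_proxy_penalty(current: &[f32], history: &[Vec<f32>]) -> f32 {
    history
        .iter()
        .map(|past| cosine(current, past))
        .fold(None, |acc: Option<f32>, s| Some(acc.map_or(s, |m| m.max(s))))
        .unwrap_or(0.0)
}

/// Hypothetical: apply the same shared penalty to every (token_id, prob) candidate.
pub fn score_candidates(candidates: &[(usize, f32)], penalty: f32, alpha: f32) -> Vec<(usize, f32)> {
    candidates
        .iter()
        .map(|&(token, prob)| (token, (1.0 - alpha) * prob - alpha * penalty))
        .collect()
}
```

Because score_candidates subtracts the same alpha · penalty from every candidate, for alpha < 1 the scored candidates keep their probability ordering.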