Struct ContrastiveSearcher 

pub struct ContrastiveSearcher;

Stateless struct providing all contrastive search decoding primitives.

All methods are associated functions that take only explicit inputs (none takes self), so the struct itself holds no mutable state. The generation loops in Self::decode and Self::decode_logits_only keep local state internally but expose a pure interface.

Implementations§

impl ContrastiveSearcher

pub fn cosine_similarity(a: &[f32], b: &[f32]) -> SeqResult<f32>

Cosine similarity between two equal-length, non-empty vectors.

Returns (a·b) / (‖a‖ · ‖b‖ + ε) where ε = 1e-12. If either vector is the zero vector the function returns 0.0 (not NaN).
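
The formula above can be sketched as a standalone function. This is a minimal illustration, not the crate's implementation: the name cosine_similarity_sketch is hypothetical, and it returns Option<f32> in place of SeqResult<f32> because the error type is not shown here. Note that the zero-vector case needs no special branch: the numerator is 0 and the denominator is ε, so the quotient is 0.0.

```rust
/// Hypothetical standalone sketch of the documented formula.
/// Returns `None` where the real function would presumably return an error
/// (assumption: empty or unequal-length inputs are rejected).
fn cosine_similarity_sketch(a: &[f32], b: &[f32]) -> Option<f32> {
    if a.is_empty() || a.len() != b.len() {
        return None;
    }
    const EPS: f32 = 1e-12;
    // Dot product a·b.
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    // Euclidean norms ‖a‖ and ‖b‖.
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    // ε keeps the denominator non-zero, so a zero vector yields 0.0, not NaN.
    Some(dot / (norm_a * norm_b + EPS))
}
```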

§Errors

pub fn degeneration_penalty( context_hiddens: &[f32], n_context: usize, candidate_hidden: &[f32], hidden_dim: usize, ) -> SeqResult<f32>

Degeneration penalty for a candidate token given context hidden states.

The penalty is the maximum cosine similarity between candidate_hidden and any of the n_context previously-seen context hidden states packed row-major in context_hiddens.

When n_context == 0 (no prior context), the penalty is 0.0.

§Parameters
  • context_hiddens — flat [n_context × hidden_dim] buffer.
  • n_context — number of prior context tokens.
  • candidate_hidden — [hidden_dim] hidden state for this candidate.
  • hidden_dim — dimension of each hidden-state vector.
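
To illustrate the row-major layout and the max-over-context rule, here is a minimal sketch under stated assumptions. The names degeneration_penalty_sketch and cosine are hypothetical, the shape checks are a guess at what the real function validates, and Option<f32> stands in for SeqResult<f32>.

```rust
/// Cosine similarity with the documented ε = 1e-12 (helper for this sketch).
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let na = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let nb = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (na * nb + 1e-12)
}

/// Hypothetical sketch: maximum cosine similarity between the candidate and
/// each context row. `None` marks inputs assumed to be rejected.
fn degeneration_penalty_sketch(
    context_hiddens: &[f32],
    n_context: usize,
    candidate_hidden: &[f32],
    hidden_dim: usize,
) -> Option<f32> {
    // Assumed validation: buffer sizes must match the declared shapes.
    if hidden_dim == 0
        || candidate_hidden.len() != hidden_dim
        || context_hiddens.len() != n_context * hidden_dim
    {
        return None;
    }
    // Documented behaviour: no prior context means no penalty.
    if n_context == 0 {
        return Some(0.0);
    }
    // Row i of the flat buffer is context_hiddens[i * hidden_dim..(i + 1) * hidden_dim].
    let max_sim = context_hiddens
        .chunks_exact(hidden_dim)
        .map(|row| cosine(row, candidate_hidden))
        .fold(f32::NEG_INFINITY, f32::max);
    Some(max_sim)
}
```

With two context rows [1, 0] and [0, 1], a candidate [1, 0] gets a penalty near 1.0 because it matches the first row exactly.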
§Errors

pub fn top_k_candidates( logits: &[f32], k: usize, ) -> SeqResult<Vec<(usize, f32)>>

Select the top-k tokens by logit value, returning (token_id, prob) pairs sorted by softmax probability in descending order.

Softmax is computed over all logits for correct probability values; only the top-k are returned. If k >= vocab_size, all tokens are returned.
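
A minimal sketch of the selection described above, assuming a numerically stable softmax (subtracting the maximum logit) and a plain descending sort. The name top_k_sketch is hypothetical, and returning an empty Vec for empty input or k = 0 is an assumption; the real function reports errors through SeqResult.

```rust
/// Hypothetical sketch: softmax over the full vocabulary, then keep the k
/// most probable (token_id, prob) pairs in descending order.
fn top_k_sketch(logits: &[f32], k: usize) -> Vec<(usize, f32)> {
    if logits.is_empty() || k == 0 {
        return Vec::new();
    }
    // Subtract the max logit before exponentiating to avoid overflow.
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&l| (l - max).exp()).collect();
    // Normalise over ALL logits so the probabilities are true softmax values.
    let sum: f32 = exps.iter().sum();
    let mut pairs: Vec<(usize, f32)> = exps
        .iter()
        .enumerate()
        .map(|(id, &e)| (id, e / sum))
        .collect();
    // Sort by probability, highest first.
    pairs.sort_by(|a, b| b.1.total_cmp(&a.1));
    // If k >= vocab_size this is a no-op and every token is returned.
    pairs.truncate(k);
    pairs
}
```

Because normalisation happens before truncation, the returned probabilities generally sum to less than 1 when k is smaller than the vocabulary size.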

§Errors

pub fn contrastive_score(prob: f32, degen_penalty: f32, alpha: f32) -> f32

Combine model probability and degeneration penalty into a contrastive score.

score = (1 − alpha) · prob − alpha · degen_penalty
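
As a worked illustration of the formula, the sketch below scores a small candidate set and picks the best one. The names contrastive_score_sketch and pick_best are hypothetical, as is the (token_id, prob, degen_penalty) tuple layout. With alpha = 0.6, a candidate with prob 0.5 and penalty 0.9 scores about -0.34, while one with prob 0.3 and penalty 0.2 scores about 0.0, so the less probable but less repetitive candidate wins.

```rust
/// Direct transcription of the documented formula.
fn contrastive_score_sketch(prob: f32, degen_penalty: f32, alpha: f32) -> f32 {
    (1.0 - alpha) * prob - alpha * degen_penalty
}

/// Hypothetical helper: return the token id with the highest contrastive
/// score among (token_id, prob, degen_penalty) candidates.
fn pick_best(candidates: &[(usize, f32, f32)], alpha: f32) -> Option<usize> {
    candidates
        .iter()
        .map(|&(id, prob, pen)| (id, contrastive_score_sketch(prob, pen, alpha)))
        .max_by(|a, b| a.1.total_cmp(&b.1))
        .map(|(id, _)| id)
}
```

Setting alpha to 0 reduces the score to the model probability alone, so selection falls back to picking the most probable candidate.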

pub fn decode<F>( initial_logits: &[f32], initial_hiddens: &[f32], vocab_size: usize, hidden_dim: usize, step_fn: F, cfg: &ContrastiveConfig, ) -> SeqResult<Vec<usize>>
where F: Fn(usize, &[f32]) -> (Vec<f32>, Vec<f32>),

Run contrastive search decoding for cfg.max_len steps using a caller-provided step function that also returns hidden states.

§Step function contract
step_fn(selected_token_id: usize, last_hidden: &[f32])
    -> (logits: Vec<f32>, next_hidden: Vec<f32>)

logits must have length vocab_size; next_hidden must have length hidden_dim and represents the hidden state for the selected token.
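
For orientation, a closure satisfying this contract might look like the toy below. It is purely illustrative: make_toy_step_fn is a hypothetical helper, and the "model" is a stand-in that favours the next token id and decays the previous hidden state. A real step function would run a forward pass instead.

```rust
/// Hypothetical toy step function matching the documented shape
/// Fn(usize, &[f32]) -> (Vec<f32>, Vec<f32>).
fn make_toy_step_fn(
    vocab_size: usize,
    hidden_dim: usize,
) -> impl Fn(usize, &[f32]) -> (Vec<f32>, Vec<f32>) {
    assert!(vocab_size > 0 && hidden_dim > 0);
    move |selected_token_id: usize, last_hidden: &[f32]| {
        // Logits of length vocab_size: favour the token after the selected one.
        let mut logits = vec![0.0_f32; vocab_size];
        logits[(selected_token_id + 1) % vocab_size] = 1.0;
        // Hidden state of length hidden_dim: halve the previous state,
        // then bump one coordinate chosen by the selected token.
        let mut next_hidden: Vec<f32> = last_hidden.iter().map(|h| 0.5 * h).collect();
        next_hidden.resize(hidden_dim, 0.0);
        next_hidden[selected_token_id % hidden_dim] += 1.0;
        (logits, next_hidden)
    }
}
```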

§Initial hidden states

initial_hiddens must be a flat [vocab_size × hidden_dim] buffer providing one hidden state per vocabulary token for the very first step. If your model produces a single shared hidden state at step 0 rather than one per token, replicate it vocab_size times before calling.
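
The replication step mentioned above can be sketched in one line with the standard slice method repeat. The helper name replicate_hidden is hypothetical.

```rust
/// Hypothetical helper: tile one shared [hidden_dim] state into a flat
/// [vocab_size × hidden_dim] buffer, one identical row per vocabulary token.
fn replicate_hidden(shared_hidden: &[f32], vocab_size: usize) -> Vec<f32> {
    shared_hidden.repeat(vocab_size)
}
```

The result has length vocab_size × hidden_dim, which is the shape expected for initial_hiddens.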

§Errors

pub fn decode_logits_only<F>( initial_logits: &[f32], step_fn: F, cfg: &ContrastiveConfig, ) -> SeqResult<Vec<usize>>
where F: Fn(usize) -> Vec<f32>,

Simplified contrastive search that works with a logit-only step function.

Because hidden states are unavailable, the past logit vectors serve as proxy hidden states: the degeneration penalty for a candidate at step t is the maximum cosine similarity between the current logit vector and each of the past logit vectors stored in the context.

The degeneration penalty is therefore shared across all candidates at each step (it measures how similar the new distribution is to past ones), which implicitly penalises repetitive distributions.
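
A minimal sketch of the shared penalty described above, assuming past logit vectors are kept as a list of equal-length vectors and that an empty history yields 0.0 (mirroring the n_context == 0 rule of degeneration_penalty). The names shared_logit_penalty and cosine are hypothetical, and how the real loop stores its context is not shown here.

```rust
/// Cosine similarity with the documented ε = 1e-12 (helper for this sketch).
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let na = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let nb = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (na * nb + 1e-12)
}

/// Hypothetical sketch: one penalty per step, computed from the current
/// logit vector against every stored past logit vector.
fn shared_logit_penalty(past_logits: &[Vec<f32>], current_logits: &[f32]) -> f32 {
    // Assumption: with no history there is nothing to penalise.
    if past_logits.is_empty() {
        return 0.0;
    }
    past_logits
        .iter()
        .map(|past| cosine(past, current_logits))
        .fold(f32::NEG_INFINITY, f32::max)
}
```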

§Step function contract
step_fn(selected_token_id: usize) -> logits: Vec<f32>   [vocab_size]
§Errors
