Trait SamplingOps 

pub trait SamplingOps: Send + Sync {
    // Required methods
    fn sample_token(
        &self,
        logits: &TensorRef,
        params: &SamplingParams,
    ) -> Result<u32>;
    fn argmax(&self, logits: &TensorRef) -> Result<u32>;
}

Token sampling operations (GPU-side when possible).
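The `Send + Sync` bound on the trait means a single implementor can be shared across threads, for example behind an `Arc<dyn SamplingOps>`. The sketch below illustrates that pattern under loud assumptions: `DemoSamplingOps`, `CpuSampler`, and `argmax_in_parallel` are hypothetical stand-ins, logits are a plain `&[f32]` instead of `TensorRef`, and the error type is a `String`, because the real `TensorRef` and `Result` types are not shown on this page.

```rust
use std::sync::Arc;
use std::thread;

// Stand-in for the real trait (assumption): same `Send + Sync` bounds,
// but with a plain slice for logits and a String error type.
trait DemoSamplingOps: Send + Sync {
    fn argmax(&self, logits: &[f32]) -> Result<u32, String>;
}

// Hypothetical CPU-only implementor used purely for illustration.
struct CpuSampler;

impl DemoSamplingOps for CpuSampler {
    fn argmax(&self, logits: &[f32]) -> Result<u32, String> {
        logits
            .iter()
            .enumerate()
            // total_cmp gives a total order over f32, so no unwrap is needed.
            .max_by(|a, b| a.1.total_cmp(b.1))
            .map(|(i, _)| i as u32)
            .ok_or_else(|| "empty logits".to_string())
    }
}

/// Runs argmax for each row on its own thread, all sharing one sampler.
/// This compiles only because the trait requires `Send + Sync`.
fn argmax_in_parallel(
    ops: Arc<dyn DemoSamplingOps>,
    rows: Vec<Vec<f32>>,
) -> Vec<Result<u32, String>> {
    let handles: Vec<_> = rows
        .into_iter()
        .map(|row| {
            let ops = Arc::clone(&ops);
            thread::spawn(move || ops.argmax(&row))
        })
        .collect();
    handles
        .into_iter()
        .map(|h| h.join().expect("worker thread panicked"))
        .collect()
}
```

Results come back in the same order as the input rows, since the handles are joined in spawn order.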

Required Methods

fn sample_token(&self, logits: &TensorRef, params: &SamplingParams) -> Result<u32>

Sample a single token from logits using the full sampling pipeline.

logits shape: [vocab_size] or [1, vocab_size] (last-token logits).
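This page does not spell out which stages the sampling pipeline contains or which fields `SamplingParams` carries. As a rough illustration only, the following CPU sketch assumes a common arrangement: top-k filtering followed by a temperature-scaled softmax and an inverse-CDF draw. `DemoParams`, `sample_token_cpu`, and the `u` argument are hypothetical names introduced here, not part of this crate's API.

```rust
/// Hypothetical parameters for this sketch only (assumption); the real
/// `SamplingParams` fields are not documented on this page.
struct DemoParams {
    temperature: f32, // values <= 0 fall back to greedy selection here
    top_k: usize,     // 0 means "keep every token"
}

/// Samples one token id from `logits` (shape [vocab_size], assumed finite).
/// `u` is a uniform draw in [0, 1), supplied by the caller so the
/// function itself stays deterministic and easy to test.
fn sample_token_cpu(logits: &[f32], params: &DemoParams, u: f32) -> Option<u32> {
    if logits.is_empty() {
        return None;
    }
    // Order token ids by descending logit; the sort is stable, so ties
    // keep the lower id first.
    let mut ids: Vec<usize> = (0..logits.len()).collect();
    ids.sort_by(|&a, &b| logits[b].total_cmp(&logits[a]));

    // Greedy fallback when the temperature is not positive.
    if params.temperature <= 0.0 {
        return Some(ids[0] as u32);
    }
    // Top-k: keep only the k highest-scoring candidates.
    if params.top_k > 0 {
        ids.truncate(params.top_k);
    }
    // Temperature-scaled softmax weights, shifted by the max logit for
    // numerical stability. Normalization is folded into `target` below.
    let max = logits[ids[0]];
    let weights: Vec<f32> = ids
        .iter()
        .map(|&i| ((logits[i] - max) / params.temperature).exp())
        .collect();
    let total: f32 = weights.iter().sum();

    // Inverse-CDF sampling: walk the cumulative weights until `target` is passed.
    let target = u * total;
    let mut acc = 0.0;
    for (&id, &w) in ids.iter().zip(&weights) {
        acc += w;
        if target < acc {
            return Some(id as u32);
        }
    }
    // Floating-point rounding can leave `target` just past the final sum.
    ids.last().map(|&id| id as u32)
}
```

Passing the random draw in as `u` is a design choice for the sketch; a real implementation would more likely own a random number generator or take a seed through its parameters.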

fn argmax(&self, logits: &TensorRef) -> Result<u32>

Greedy argmax over the last dimension: returns the index of the largest logit as the token ID.
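For reference, greedy selection on the CPU can be sketched as below. `argmax_cpu` is a hypothetical helper operating on a plain `&[f32]` rather than a `TensorRef`, and its tie-breaking (lowest index wins) and NaN handling (NaN entries are skipped) are assumptions of this sketch; the page does not say how an implementor treats either case.

```rust
/// Hypothetical CPU reference (assumption): index of the largest logit.
/// Returns `None` for an empty slice or a slice containing only NaN.
fn argmax_cpu(logits: &[f32]) -> Option<u32> {
    let mut best: Option<(usize, f32)> = None;
    for (i, &v) in logits.iter().enumerate() {
        // Skip NaN so it can never be selected as the maximum.
        if v.is_nan() {
            continue;
        }
        match best {
            // Keep the current best on ties, so the lowest index wins.
            Some((_, b)) if v <= b => {}
            _ => best = Some((i, v)),
        }
    }
    best.map(|(i, _)| i as u32)
}
```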

Implementors