1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112
//! Create a sampler struct to encapsulate the sampling process. This allows passing all the possible
//! sampling parameters around as a single struct, and also allows late binding of expensive context
//! like [`crate::context::LlamaContext`] or token history to the sampler.
//!
//! # Example
//!
//! **Llama.cpp default sampler**
//!
//! ```rust
//! use llama_cpp_2::context::sample::sampler::{Sampler, SampleStep};
//! use llama_cpp_2::token::data::LlamaTokenData;
//! use llama_cpp_2::token::data_array::LlamaTokenDataArray;
//! use llama_cpp_2::token::LlamaToken;
//!
//! // Sample a token greedily and add to the history.
//! let mut finalizer = &|mut candidates: LlamaTokenDataArray, history: &mut Vec<LlamaToken>| {
//! candidates.sample_softmax(None);
//! let token = candidates.data[0];
//! history.push(token.id());
//! vec![token]
//! };
//!
//! let mut history = vec![];
//! let mut sampler = Sampler::new(finalizer);
//!
//! sampler.push_step(&|c, history| c.sample_repetition_penalty(None, history, 64, 1.1, 0.0, 0.0));
//! sampler.push_step(&|c, _| c.sample_top_k(None, 40, 1));
//! sampler.push_step(&|c, _| c.sample_tail_free(None, 1.0, 1));
//! sampler.push_step(&|c, _| c.sample_typical(None, 1.0, 1));
//! sampler.push_step(&|c, _| c.sample_top_p(None, 0.95, 1));
//! sampler.push_step(&|c, _| c.sample_min_p(None, 0.05, 1));
//! sampler.push_step(&|c, _| c.sample_temp(None, 0.5));
//!
//! // random candidates
//! let candidates = LlamaTokenDataArray::from_iter((0..4).map(|i| LlamaTokenData::new(LlamaToken::new(i), i as f32 / 6.0, 0.0)), false);
//!
//! for _ in 0..10 {
//! let tokens = sampler.sample(&mut history, candidates.clone());
//! assert_eq!(tokens.len(), 1);
//! }
//!
//! assert_eq!(history.len(), 10);
//! ```
use crate::token::data::LlamaTokenData;
use crate::token::data_array::LlamaTokenDataArray;
use std::fmt::{Debug, Formatter};
/// A single step to sample tokens from the remaining candidates.
///
/// A step is a callable that receives the candidate array by mutable reference, so it can modify
/// the candidates in place, together with a mutable reference to the caller-supplied context `C`.
/// It returns nothing.
pub type SampleStep<C> = dyn Fn(&mut LlamaTokenDataArray, &mut C);
/// The final step to select tokens from the remaining candidates.
///
/// Unlike a [`SampleStep`], the finalizer takes ownership of the candidate array. It also receives
/// a mutable reference to the caller-supplied context `C`, and returns the selected token data as a
/// `Vec`.
pub type SampleFinalizer<C> = dyn Fn(LlamaTokenDataArray, &mut C) -> Vec<LlamaTokenData>;
/// A series of sampling steps that will produce a vector of token data.
///
/// `C` is dynamic context that will be passed to the sampling functions. Some sampling steps may
/// require state to be maintained across multiple samples, and this context can be used to store
/// that state. For example, [`LlamaTokenDataArray::sample_token_mirostat_v2`] requires a `mu` to be
/// shared across multiple samples.
///
/// The sampler does not own its callables: both the steps and the finalizer are borrowed for the
/// lifetime `'a`, so they must outlive the sampler.
pub struct Sampler<'a, C> {
/// The steps to take when sampling, stored as borrowed [`SampleStep`] trait objects.
pub steps: Vec<&'a SampleStep<C>>,
/// The final step to select one or more tokens from the remaining candidates.
pub finalizer: &'a SampleFinalizer<C>,
}
impl<C> Debug for Sampler<'_, C> {
    /// Formats the sampler without calling or inspecting its closures.
    ///
    /// Closures carry no useful `Debug` representation, so the output reports how many steps are
    /// registered and spells out the trait-object type of the steps and of the finalizer. The
    /// type descriptions mirror the declared field types (`&dyn Fn(..)` taking the context `C`).
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Sampler")
            .field(
                "steps",
                // `format_args!` writes straight into the formatter, so no intermediate `String`
                // is allocated just to be printed.
                &format_args!(
                    "{} steps of &dyn Fn(&mut LlamaTokenDataArray, &mut C)",
                    self.steps.len()
                ),
            )
            .field(
                "finalizer",
                &format_args!("&dyn Fn(LlamaTokenDataArray, &mut C) -> Vec<LlamaTokenData>"),
            )
            .finish()
    }
}
impl<'a, C> Sampler<'a, C> {
    /// Builds a sampler that has no intermediate steps and ends with `finalizer`.
    pub fn new(finalizer: &'a SampleFinalizer<C>) -> Self {
        let steps = Vec::new();
        Self { steps, finalizer }
    }

    /// Appends `step` to the end of the list of steps.
    pub fn push_step(&mut self, step: &'a SampleStep<C>) {
        self.steps.push(step);
    }

    /// Runs the sampling pipeline over `candidates`.
    ///
    /// Every registered step is applied in the order it was pushed, each one receiving the
    /// candidates and `context` by mutable reference. The candidates are then handed to the
    /// finalizer, whose result is returned.
    #[must_use]
    pub fn sample(
        &mut self,
        context: &mut C,
        mut candidates: LlamaTokenDataArray,
    ) -> Vec<LlamaTokenData> {
        self.steps
            .iter()
            .for_each(|apply| apply(&mut candidates, &mut *context));

        let finalize = self.finalizer;
        finalize(candidates, context)
    }
}