use std::ptr::addr_of_mut;
use llama_cpp_sys::{
llama_context, llama_grammar_accept_token, llama_sample_entropy, llama_sample_grammar,
llama_sample_min_p, llama_sample_repetition_penalties, llama_sample_tail_free,
llama_sample_temp, llama_sample_token, llama_sample_token_greedy, llama_sample_token_mirostat,
llama_sample_token_mirostat_v2, llama_sample_top_k, llama_sample_top_p, llama_sample_typical,
llama_token, llama_token_data_array,
};
use crate::{grammar::LlamaGrammar, Sampler, Token};
/// One stage of a sampling pipeline, applied to the candidate token array before the final
/// token is selected.
///
/// Each variant is dispatched by `SamplerStage::apply` to a llama.cpp sampling function.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum SamplerStage {
/// Temperature scaling via `llama_sample_temp`. A temperature of `0.0` instead keeps only
/// the single most likely candidate (`llama_sample_top_k` with k = 1).
Temperature(f32),
/// Entropy-based dynamic temperature via `llama_sample_entropy`; the three fields are
/// forwarded to it unchanged.
DynamicTemperature {
min_temp: f32,
max_temp: f32,
exponent_val: f32,
},
/// Repetition / frequency / presence penalties via `llama_sample_repetition_penalties`,
/// computed over the most recent `last_n` tokens of the history.
RepetitionPenalty {
/// Forwarded as the repeat penalty.
repetition_penalty: f32,
/// Forwarded as the frequency penalty.
frequency_penalty: f32,
/// Forwarded as the presence penalty.
presence_penalty: f32,
/// Number of most recent history tokens considered; clamped to the history length.
/// A negative value means the whole history.
last_n: i32,
},
/// Top-p filtering via `llama_sample_top_p`.
TopP(f32),
/// Min-p filtering via `llama_sample_min_p`.
MinP(f32),
/// Top-k filtering via `llama_sample_top_k`.
TopK(i32),
/// Typical sampling via `llama_sample_typical`.
Typical(f32),
/// Tail-free sampling via `llama_sample_tail_free`.
TailFree(f32),
/// Grammar-constrained filtering; construct with `SamplerStage::from_grammar`.
Grammar(GrammarStage),
}
impl SamplerStage {
    /// Creates a [`SamplerStage::Grammar`] stage that constrains candidates to `grammar`.
    ///
    /// `start_position` is the index in the token history from which tokens are fed to the
    /// grammar on the first [`SamplerStage::apply`] call. With `None`, the history present at
    /// that first call is treated as already accepted by the grammar.
    pub fn from_grammar(grammar: LlamaGrammar, start_position: Option<usize>) -> Self {
        SamplerStage::Grammar(GrammarStage {
            grammar,
            accepted_up_to: start_position,
        })
    }

    /// Applies this stage to `candidates_p` and returns the updated candidate array.
    ///
    /// * `context` - llama context forwarded to the llama.cpp sampling functions.
    /// * `tokens` - token history, read by the repetition-penalty and grammar stages.
    /// * `candidates_p` - candidate array to filter or reweight.
    /// * `min_keep` - minimum number of candidates the truncating stages must keep.
    ///
    /// A temperature of `0.0` **or below** keeps only the single most likely candidate.
    /// Dividing logits by a negative temperature would invert their ranking (favouring the
    /// least likely tokens), so negative values are treated as greedy, matching the
    /// convention of llama.cpp's own sampling helpers.
    ///
    /// # Panics
    ///
    /// The grammar stage panics if its recorded history position exceeds `tokens.len()`.
    // The caller must pass a valid `context`; the signature cannot become `unsafe fn`
    // without breaking callers, hence the allow.
    #[allow(clippy::not_unsafe_ptr_arg_deref)]
    pub fn apply(
        &mut self,
        context: *mut llama_context,
        tokens: &[Token],
        mut candidates_p: llama_token_data_array,
        min_keep: usize,
    ) -> llama_token_data_array {
        let p_ptr = addr_of_mut!(candidates_p);
        // SAFETY (every `unsafe` block below): `p_ptr` points at the local `candidates_p`,
        // which outlives each call; `context` is forwarded unchanged from the caller, who is
        // responsible for its validity.
        match self {
            SamplerStage::RepetitionPenalty {
                repetition_penalty,
                frequency_penalty,
                presence_penalty,
                last_n,
            } => {
                // Negative `last_n` means "whole history"; otherwise clamp to what we have.
                let window =
                    usize::try_from(*last_n).map_or(tokens.len(), |n| n.min(tokens.len()));
                let recent = &tokens[tokens.len() - window..];
                // NOTE(review): the pointer cast assumes `Token` is layout-compatible with
                // `llama_token` (e.g. `#[repr(transparent)]`); `Token` is defined elsewhere.
                unsafe {
                    llama_sample_repetition_penalties(
                        context,
                        p_ptr,
                        recent.as_ptr() as *const llama_token,
                        recent.len(),
                        *repetition_penalty,
                        *frequency_penalty,
                        *presence_penalty,
                    );
                }
            }
            SamplerStage::Temperature(temp) => {
                if *temp <= 0.0 {
                    // Greedy: top-k with k = 1 and min_keep = 1 leaves exactly one candidate.
                    unsafe { llama_sample_top_k(context, p_ptr, 1, 1) };
                } else {
                    unsafe { llama_sample_temp(context, p_ptr, *temp) };
                }
            }
            SamplerStage::DynamicTemperature {
                min_temp,
                max_temp,
                exponent_val,
            } => unsafe {
                llama_sample_entropy(context, p_ptr, *min_temp, *max_temp, *exponent_val);
            },
            SamplerStage::TopP(top_p) => unsafe {
                llama_sample_top_p(context, p_ptr, *top_p, min_keep);
            },
            SamplerStage::MinP(min_p) => unsafe {
                llama_sample_min_p(context, p_ptr, *min_p, min_keep);
            },
            SamplerStage::TopK(top_k) => unsafe {
                llama_sample_top_k(context, p_ptr, *top_k, min_keep);
            },
            SamplerStage::Typical(p) => unsafe {
                llama_sample_typical(context, p_ptr, *p, min_keep);
            },
            SamplerStage::TailFree(z) => unsafe {
                llama_sample_tail_free(context, p_ptr, *z, min_keep);
            },
            SamplerStage::Grammar(stage) => {
                // Safe call: the grammar stage performs its own FFI and returns the new array.
                return stage.apply(context, tokens, candidates_p, min_keep);
            }
        }
        candidates_p
    }
}
/// State for the grammar-constrained sampling stage (see `SamplerStage::Grammar`).
#[derive(Clone, Debug)]
pub struct GrammarStage {
/// Grammar that is advanced with each accepted history token and used to filter candidates.
grammar: LlamaGrammar,
/// Index into the token history up to which tokens have already been fed to the grammar.
/// When `None` (before the first call), the history seen on the first call is treated as
/// already accepted.
accepted_up_to: Option<usize>,
}
impl GrammarStage {
    /// Catches the grammar up with the token history, then filters `candidates_p` with it.
    ///
    /// Every history token from the recorded position onward is passed to
    /// `llama_grammar_accept_token`. The position is then set to `tokens.len()`, and the
    /// candidates are run through `llama_sample_grammar`. `_min_keep` is accepted for
    /// signature parity with the other stages but is not used.
    ///
    /// # Panics
    ///
    /// Panics if the recorded position is greater than `tokens.len()` (for example, if the
    /// history shrank between calls).
    fn apply(
        &mut self,
        context: *mut llama_context,
        tokens: &[Token],
        mut candidates_p: llama_token_data_array,
        _min_keep: usize,
    ) -> llama_token_data_array {
        let grammar_ptr = self.grammar.grammar.as_ptr();
        let first_unseen = match self.accepted_up_to {
            Some(position) => position,
            None => tokens.len(),
        };
        tokens[first_unseen..].iter().for_each(|tok| {
            // SAFETY: `grammar_ptr` comes from the grammar owned by `self`, which is alive for
            // the whole call; `context` is forwarded unchanged from the caller.
            unsafe { llama_grammar_accept_token(context, grammar_ptr, tok.0) }
        });
        self.accepted_up_to = Some(tokens.len());
        // SAFETY: the candidate pointer targets the local `candidates_p`, which outlives the call.
        unsafe { llama_sample_grammar(context, addr_of_mut!(candidates_p), grammar_ptr) };
        candidates_p
    }
}
/// Strategy used to pick the final token after all `SamplerStage`s have run.
#[derive(Clone, Debug)]
#[non_exhaustive]
enum TokenSelector {
/// Samples from the candidate distribution via `llama_sample_token`.
Softmax,
/// Picks a token via `llama_sample_token_greedy`.
Greedy,
/// Mirostat via `llama_sample_token_mirostat`. `mu` is passed by mutable pointer, so
/// llama.cpp can update it from one call to the next.
Mirostat { tau: f32, eta: f32, m: i32, mu: f32 },
/// Mirostat v2 via `llama_sample_token_mirostat_v2`. `mu` is passed by mutable pointer,
/// as in `Mirostat`.
MirostatV2 { tau: f32, eta: f32, mu: f32 },
}
impl TokenSelector {
    /// Selects the final token from `candidates_p` using this strategy.
    ///
    /// The Mirostat variants hand their `mu` field to llama.cpp by mutable pointer, so it
    /// carries state across calls.
    #[allow(clippy::not_unsafe_ptr_arg_deref)]
    pub fn select(
        &mut self,
        context: *mut llama_context,
        mut candidates_p: llama_token_data_array,
    ) -> Token {
        let candidates = addr_of_mut!(candidates_p);
        // SAFETY: `candidates` points at the local `candidates_p`, and each `mu` pointer at a
        // field of `self`; both outlive the calls. `context` is forwarded from the caller.
        let raw_id = match self {
            Self::Softmax => unsafe { llama_sample_token(context, candidates) },
            Self::Greedy => unsafe { llama_sample_token_greedy(context, candidates) },
            Self::Mirostat { tau, eta, m, mu } => unsafe {
                llama_sample_token_mirostat(context, candidates, *tau, *eta, *m, mu as *mut f32)
            },
            Self::MirostatV2 { tau, eta, mu } => unsafe {
                llama_sample_token_mirostat_v2(context, candidates, *tau, *eta, mu as *mut f32)
            },
        };
        Token(raw_id)
    }
}
/// Sampler that runs a list of `SamplerStage`s in order and then picks a token with a
/// `TokenSelector` (softmax sampling, greedy, or Mirostat).
#[derive(Clone, Debug)]
pub struct StandardSampler {
/// Stages applied to the candidate array, in order.
stages: Vec<SamplerStage>,
/// Minimum number of candidates the truncating stages must keep; raised to at least 1
/// when sampling.
min_keep: usize,
/// Strategy that picks the final token.
token_selector: TokenSelector,
}
impl StandardSampler {
pub fn new_softmax(stages: Vec<SamplerStage>, min_keep: usize) -> StandardSampler {
StandardSampler {
stages,
min_keep,
token_selector: TokenSelector::Softmax,
}
}
pub fn new_greedy() -> StandardSampler {
StandardSampler {
stages: Vec::new(),
min_keep: 0,
token_selector: TokenSelector::Greedy,
}
}
pub fn new_mirostat(
stages: Vec<SamplerStage>,
min_keep: usize,
tau: f32,
eta: f32,
m: i32,
) -> StandardSampler {
StandardSampler {
stages,
min_keep,
token_selector: TokenSelector::Mirostat {
tau,
eta,
m,
mu: 2.0 * tau,
},
}
}
pub fn new_mirostat_v2(
stages: Vec<SamplerStage>,
min_keep: usize,
tau: f32,
eta: f32,
) -> StandardSampler {
StandardSampler {
stages,
min_keep,
token_selector: TokenSelector::MirostatV2 {
tau,
eta,
mu: 2.0 * tau,
},
}
}
}
impl Default for StandardSampler {
    /// Default pipeline, in order:
    /// 1. repetition penalty 1.1 over the last 64 tokens (no frequency or presence penalty);
    /// 2. top-k 40;
    /// 3. top-p 0.95;
    /// 4. min-p 0.05;
    /// 5. temperature 0.8.
    ///
    /// It then does softmax sampling with `min_keep = 1`.
    fn default() -> Self {
        let repetition = SamplerStage::RepetitionPenalty {
            repetition_penalty: 1.1,
            frequency_penalty: 0.0,
            presence_penalty: 0.0,
            last_n: 64,
        };
        let stages = vec![
            repetition,
            SamplerStage::TopK(40),
            SamplerStage::TopP(0.95),
            SamplerStage::MinP(0.05),
            SamplerStage::Temperature(0.8),
        ];
        Self {
            token_selector: TokenSelector::Softmax,
            min_keep: 1,
            stages,
        }
    }
}
impl Sampler for StandardSampler {
    /// Threads `candidates_p` through every configured stage in order, then lets the token
    /// selector pick the final token. The configured `min_keep` is raised to at least 1.
    #[allow(clippy::not_unsafe_ptr_arg_deref)]
    fn sample(
        &mut self,
        context: *mut llama_context,
        tokens: &[Token],
        candidates_p: llama_token_data_array,
    ) -> Token {
        let keep = self.min_keep.max(1);
        let filtered = self
            .stages
            .iter_mut()
            .fold(candidates_p, |acc, stage| stage.apply(context, tokens, acc, keep));
        self.token_selector.select(context, filtered)
    }
}