use crate::{context::Context, error::MullamaError, model::Model, sys, token::TokenId};
use std::{ffi::CString, sync::Arc};
/// Owning wrapper around a single native `llama_sampler`.
///
/// Instances are created by the associated constructors in `impl Sampler`; the
/// native object is released with `llama_sampler_free` when the wrapper is
/// dropped (see `impl Drop for Sampler`).
pub struct Sampler {
/// Native sampler handle; every construction site in this file checks it is non-null.
sampler_ptr: *mut sys::llama_sampler,
/// Model handle retained by samplers built from the model's vocabulary
/// (mirostat, grammar, DRY, infill) so the model outlives them; `None` otherwise.
_model: Option<Arc<Model>>, }
impl Sampler {
fn from_ptr(
sampler_ptr: *mut sys::llama_sampler,
model: Option<Arc<Model>>,
name: &str,
) -> Result<Self, MullamaError> {
if sampler_ptr.is_null() {
Err(MullamaError::SamplingError(format!(
"Failed to create {} sampler",
name
)))
} else {
Ok(Self {
sampler_ptr,
_model: model,
})
}
}
pub fn new() -> Result<Self, MullamaError> {
Self::greedy()
}
pub fn greedy() -> Result<Self, MullamaError> {
let sampler_ptr = unsafe { sys::llama_sampler_init_greedy() };
Self::from_ptr(sampler_ptr, None, "greedy")
}
pub fn dist(seed: u32) -> Result<Self, MullamaError> {
let sampler_ptr = unsafe { sys::llama_sampler_init_dist(seed) };
Self::from_ptr(sampler_ptr, None, "distribution")
}
pub fn top_k(k: i32) -> Result<Self, MullamaError> {
let sampler_ptr = unsafe { sys::llama_sampler_init_top_k(k) };
Self::from_ptr(sampler_ptr, None, "top-k")
}
pub fn top_p(p: f32, min_keep: usize) -> Result<Self, MullamaError> {
let sampler_ptr = unsafe { sys::llama_sampler_init_top_p(p, min_keep) };
Self::from_ptr(sampler_ptr, None, "top-p")
}
pub fn min_p(p: f32, min_keep: usize) -> Result<Self, MullamaError> {
let sampler_ptr = unsafe { sys::llama_sampler_init_min_p(p, min_keep) };
Self::from_ptr(sampler_ptr, None, "min-p")
}
pub fn typical(p: f32, min_keep: usize) -> Result<Self, MullamaError> {
let sampler_ptr = unsafe { sys::llama_sampler_init_typical(p, min_keep) };
Self::from_ptr(sampler_ptr, None, "typical")
}
pub fn temperature(temperature: f32) -> Result<Self, MullamaError> {
let sampler_ptr = unsafe { sys::llama_sampler_init_temp(temperature) };
Self::from_ptr(sampler_ptr, None, "temperature")
}
pub fn temperature_ext(
temperature: f32,
delta: f32,
exponent: f32,
) -> Result<Self, MullamaError> {
let sampler_ptr = unsafe { sys::llama_sampler_init_temp_ext(temperature, delta, exponent) };
Self::from_ptr(sampler_ptr, None, "temperature-ext")
}
pub fn mirostat(
model: Arc<Model>,
seed: u32,
tau: f32,
eta: f32,
m: i32,
) -> Result<Self, MullamaError> {
let vocab_ptr = unsafe { sys::llama_model_get_vocab(model.as_ptr()) };
let sampler_ptr = unsafe { sys::llama_sampler_init_mirostat(vocab_ptr, seed, tau, eta, m) };
Self::from_ptr(sampler_ptr, Some(model), "mirostat")
}
pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Result<Self, MullamaError> {
let sampler_ptr = unsafe { sys::llama_sampler_init_mirostat_v2(seed, tau, eta) };
Self::from_ptr(sampler_ptr, None, "mirostat-v2")
}
pub fn grammar(
model: Arc<Model>,
grammar_str: &str,
grammar_root: &str,
) -> Result<Self, MullamaError> {
let vocab_ptr = unsafe { sys::llama_model_get_vocab(model.as_ptr()) };
let c_grammar_str = CString::new(grammar_str)
.map_err(|_| MullamaError::SamplingError("Invalid grammar string".to_string()))?;
let c_grammar_root = CString::new(grammar_root)
.map_err(|_| MullamaError::SamplingError("Invalid grammar root".to_string()))?;
let sampler_ptr = unsafe {
sys::llama_sampler_init_grammar(
vocab_ptr,
c_grammar_str.as_ptr(),
c_grammar_root.as_ptr(),
)
};
Self::from_ptr(sampler_ptr, Some(model), "grammar")
}
pub fn penalties(
penalty_last_n: i32,
penalty_repeat: f32,
penalty_freq: f32,
penalty_present: f32,
) -> Result<Self, MullamaError> {
let sampler_ptr = unsafe {
sys::llama_sampler_init_penalties(
penalty_last_n,
penalty_repeat,
penalty_freq,
penalty_present,
)
};
Self::from_ptr(sampler_ptr, None, "penalties")
}
pub fn logit_bias(n_vocab: i32, logit_biases: &[LogitBias]) -> Result<Self, MullamaError> {
let sys_biases: Vec<sys::llama_logit_bias> = logit_biases
.iter()
.map(|bias| sys::llama_logit_bias {
token: bias.token as sys::llama_token,
bias: bias.bias,
})
.collect();
let sampler_ptr = unsafe {
sys::llama_sampler_init_logit_bias(
n_vocab,
sys_biases.len() as i32,
sys_biases.as_ptr(),
)
};
Self::from_ptr(sampler_ptr, None, "logit-bias")
}
pub fn softmax() -> Result<Self, MullamaError> {
let sampler_ptr = unsafe { sys::llama_sampler_init_softmax() };
Self::from_ptr(sampler_ptr, None, "softmax")
}
pub fn top_n_sigma(n: f32) -> Result<Self, MullamaError> {
let sampler_ptr = unsafe { sys::llama_sampler_init_top_n_sigma(n) };
Self::from_ptr(sampler_ptr, None, "top-n-sigma")
}
pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Result<Self, MullamaError> {
let sampler_ptr = unsafe { sys::llama_sampler_init_xtc(p, t, min_keep, seed) };
Self::from_ptr(sampler_ptr, None, "xtc")
}
pub fn dry(
model: Arc<Model>,
n_ctx_train: i32,
multiplier: f32,
base: f32,
allowed_length: i32,
penalty_last_n: i32,
seq_breakers: &[&str],
) -> Result<Self, MullamaError> {
let vocab_ptr = unsafe { sys::llama_model_get_vocab(model.as_ptr()) };
let c_strings: Vec<CString> = seq_breakers
.iter()
.filter_map(|s| CString::new(*s).ok())
.collect();
let c_ptrs: Vec<*const i8> = c_strings.iter().map(|s| s.as_ptr()).collect();
let sampler_ptr = unsafe {
sys::llama_sampler_init_dry(
vocab_ptr,
n_ctx_train,
multiplier,
base,
allowed_length,
penalty_last_n,
c_ptrs.as_ptr(),
c_ptrs.len(),
)
};
Self::from_ptr(sampler_ptr, Some(model), "dry")
}
pub fn infill(model: Arc<Model>) -> Result<Self, MullamaError> {
let vocab_ptr = unsafe { sys::llama_model_get_vocab(model.as_ptr()) };
let sampler_ptr = unsafe { sys::llama_sampler_init_infill(vocab_ptr) };
Self::from_ptr(sampler_ptr, Some(model), "infill")
}
pub fn sample(&mut self, context: &mut Context, idx: i32) -> TokenId {
let token = unsafe { sys::llama_sampler_sample(self.sampler_ptr, context.as_ptr(), idx) };
token as TokenId
}
pub fn accept(&mut self, token: TokenId) {
unsafe {
sys::llama_sampler_accept(self.sampler_ptr, token as sys::llama_token);
}
}
pub fn apply(&mut self, candidates: &mut TokenDataArray) {
unsafe {
sys::llama_sampler_apply(self.sampler_ptr, &mut candidates.inner);
}
}
pub fn reset(&mut self) {
unsafe {
sys::llama_sampler_reset(self.sampler_ptr);
}
}
pub fn try_clone(&self) -> Result<Self, MullamaError> {
let cloned_ptr = unsafe { sys::llama_sampler_clone(self.sampler_ptr) };
if cloned_ptr.is_null() {
return Err(MullamaError::SamplingError(
"Failed to clone sampler".to_string(),
));
}
Ok(Self {
sampler_ptr: cloned_ptr,
_model: self._model.clone(),
})
}
pub fn name(&self) -> String {
let name_ptr = unsafe { sys::llama_sampler_name(self.sampler_ptr) };
if name_ptr.is_null() {
return "unknown".to_string();
}
unsafe {
std::ffi::CStr::from_ptr(name_ptr)
.to_string_lossy()
.to_string()
}
}
#[allow(dead_code)]
pub(crate) fn as_ptr(&self) -> *mut sys::llama_sampler {
self.sampler_ptr
}
}
impl Drop for Sampler {
    /// Frees the owned native sampler; a null handle is left alone.
    fn drop(&mut self) {
        if self.sampler_ptr.is_null() {
            return;
        }
        // SAFETY: a non-null `sampler_ptr` is exclusively owned by this wrapper
        // (`SamplerChain::add` forgets the wrapper after handing the pointer over),
        // so it is freed exactly once here.
        unsafe { sys::llama_sampler_free(self.sampler_ptr) };
    }
}
pub struct SamplerChain {
chain_ptr: *mut sys::llama_sampler,
}
impl SamplerChain {
pub fn new(params: SamplerChainParams) -> Self {
let sys_params = sys::llama_sampler_chain_params {
no_perf: params.no_perf as sys::c_bool,
};
let chain_ptr = unsafe { sys::llama_sampler_chain_init(sys_params) };
Self { chain_ptr }
}
pub fn with_defaults() -> Self {
Self::new(SamplerChainParams::default())
}
pub fn add(&mut self, sampler: Sampler) {
unsafe {
sys::llama_sampler_chain_add(self.chain_ptr, sampler.sampler_ptr);
}
std::mem::forget(sampler);
}
pub fn get(&self, index: i32) -> Option<*mut sys::llama_sampler> {
let sampler_ptr = unsafe { sys::llama_sampler_chain_get(self.chain_ptr, index) };
if sampler_ptr.is_null() {
None
} else {
Some(sampler_ptr)
}
}
pub fn len(&self) -> i32 {
unsafe { sys::llama_sampler_chain_n(self.chain_ptr) }
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn remove(&mut self, index: i32) -> Option<Sampler> {
let removed_ptr = unsafe { sys::llama_sampler_chain_remove(self.chain_ptr, index) };
if removed_ptr.is_null() {
None
} else {
Some(Sampler {
sampler_ptr: removed_ptr,
_model: None, })
}
}
pub fn sample(&mut self, context: &mut Context, idx: i32) -> TokenId {
let token = unsafe { sys::llama_sampler_sample(self.chain_ptr, context.as_ptr(), idx) };
token as TokenId
}
pub fn accept(&mut self, token: TokenId) {
unsafe {
sys::llama_sampler_accept(self.chain_ptr, token as sys::llama_token);
}
}
pub fn reset(&mut self) {
unsafe {
sys::llama_sampler_reset(self.chain_ptr);
}
}
pub fn get_seed(&self) -> u32 {
unsafe { sys::llama_sampler_get_seed(self.chain_ptr) }
}
pub fn perf_data(&self) -> SamplerPerfData {
let data = unsafe { sys::llama_perf_sampler(self.chain_ptr) };
SamplerPerfData {
t_sample_ms: data.t_sample_ms,
n_sample: data.n_sample,
}
}
pub fn perf_print(&self) {
unsafe {
sys::llama_perf_sampler_print(self.chain_ptr);
}
}
pub fn perf_reset(&mut self) {
unsafe {
sys::llama_perf_sampler_reset(self.chain_ptr);
}
}
#[allow(dead_code)]
pub(crate) fn as_ptr(&self) -> *mut sys::llama_sampler {
self.chain_ptr
}
}
impl Drop for SamplerChain {
fn drop(&mut self) {
if !self.chain_ptr.is_null() {
unsafe {
sys::llama_sampler_free(self.chain_ptr);
}
}
}
}
impl Default for SamplerChain {
fn default() -> Self {
Self::with_defaults()
}
}
/// Construction options for `SamplerChain::new`.
#[derive(Debug, Clone, Default)]
pub struct SamplerChainParams {
/// Forwarded as `no_perf` in `llama_sampler_chain_params`.
/// NOTE(review): presumably disables performance timing collection — confirm against llama.h.
pub no_perf: bool,
}
/// A per-token bias entry consumed by `Sampler::logit_bias`, which converts it
/// to `sys::llama_logit_bias`.
#[derive(Debug, Clone)]
pub struct LogitBias {
/// Token the bias applies to.
pub token: TokenId,
/// Bias value passed through unchanged to the native struct.
pub bias: f32,
}
/// A candidate token together with its logit and probability values.
///
/// `#[repr(C)]` is load-bearing: `TokenDataArray::candidates` reinterprets a buffer
/// of `sys::llama_token_data` as `[TokenData]`, so field order and sizes must stay
/// identical to that native struct.
#[repr(C)]
#[derive(Debug, Clone)]
pub struct TokenData {
/// Token id.
pub id: TokenId,
/// Logit value for this token.
pub logit: f32,
/// Probability value for this token (NOTE(review): only meaningful after a sampler has filled it — confirm).
pub p: f32,
}
/// 64-byte aligned variant of [`TokenData`], one entry per alignment unit.
///
/// `_padding` makes the layout explicit: with a 4-byte `TokenId` the fields total
/// 4 + 4 + 4 + 52 = 64 bytes. NOTE(review): assumes `TokenId` is 4 bytes; a wider
/// id would round the struct up to 128 bytes.
#[repr(C, align(64))]
#[derive(Debug, Clone, Copy)]
pub struct AlignedTokenData {
/// Token id.
pub id: TokenId,
/// Logit value for this token.
pub logit: f32,
/// Probability value for this token.
pub p: f32,
/// Zero-filled padding; always written as zeros by `AlignedTokenData::new`.
_padding: [u8; 52],
}
impl AlignedTokenData {
    /// Builds an aligned entry from its parts, zero-filling the padding.
    #[inline]
    pub fn new(id: TokenId, logit: f32, p: f32) -> Self {
        const ZERO_PAD: [u8; 52] = [0; 52];
        Self {
            id,
            logit,
            p,
            _padding: ZERO_PAD,
        }
    }

    /// Copies the fields of `data` into a new aligned entry.
    #[inline]
    pub fn from_token_data(data: &TokenData) -> Self {
        let &TokenData { id, logit, p } = data;
        Self::new(id, logit, p)
    }

    /// Returns the unaligned [`TokenData`] form of this entry (padding dropped).
    #[inline]
    pub fn to_token_data(&self) -> TokenData {
        let Self { id, logit, p, .. } = *self;
        TokenData { id, logit, p }
    }
}
impl Default for AlignedTokenData {
    /// An all-zero entry: token id 0, zero logit, zero probability, zero padding.
    fn default() -> Self {
        Self {
            id: 0,
            logit: 0.0,
            p: 0.0,
            _padding: [0u8; 52],
        }
    }
}
impl From<TokenData> for AlignedTokenData {
fn from(data: TokenData) -> Self {
Self::from_token_data(&data)
}
}
impl From<AlignedTokenData> for TokenData {
    /// Drops the padding and keeps the id, logit and probability.
    fn from(data: AlignedTokenData) -> Self {
        Self {
            id: data.id,
            logit: data.logit,
            p: data.p,
        }
    }
}
/// Owned, growable list of [`AlignedTokenData`] candidates with an optional
/// selected index and a "sorted" flag.
///
/// Note: the `align(64)` attribute aligns this header struct only; the `Vec`'s
/// heap buffer gets its alignment from `AlignedTokenData` itself.
#[repr(C, align(64))]
pub struct AlignedTokenDataArray {
/// Candidate entries.
data: Vec<AlignedTokenData>,
/// Index of the selected entry; `-1` means nothing is selected.
selected: i64,
/// Set to true by `sort_by_logit`; reset to false when entries are added.
sorted: bool,
}
impl AlignedTokenDataArray {
    /// Creates an empty array with room for `capacity` entries (unsorted, nothing selected).
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            data: Vec::with_capacity(capacity),
            selected: -1,
            sorted: false,
        }
    }

    /// Copies `candidates` into aligned storage (unsorted, nothing selected).
    pub fn from_candidates(candidates: &[TokenData]) -> Self {
        let data = candidates
            .iter()
            .map(AlignedTokenData::from_token_data)
            .collect();
        Self {
            data,
            selected: -1,
            sorted: false,
        }
    }

    /// Appends an entry and marks the array unsorted.
    #[inline]
    pub fn push(&mut self, id: TokenId, logit: f32, p: f32) {
        self.data.push(AlignedTokenData::new(id, logit, p));
        self.sorted = false;
    }

    /// Number of entries.
    #[inline]
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Returns `true` if there are no entries.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Read-only view of the entries.
    #[inline]
    pub fn as_slice(&self) -> &[AlignedTokenData] {
        &self.data
    }

    /// Mutable view of the entries.
    ///
    /// Entries may be reordered or have their logits changed through the returned
    /// slice, so the array is conservatively marked unsorted.
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [AlignedTokenData] {
        self.sorted = false;
        &mut self.data
    }

    /// Whether the entries are known to be in descending-logit order: `sort_by_logit`
    /// ran and no entry has been pushed or mutably borrowed since.
    #[inline]
    pub fn is_sorted(&self) -> bool {
        self.sorted
    }

    /// Sorts entries by logit, highest first, with NaN logits placed last.
    ///
    /// The comparator is a total order. `slice::sort_by` may panic on a comparator
    /// that is not (Rust 1.81+). Treating NaN as "equal" to everything is not a
    /// total order.
    pub fn sort_by_logit(&mut self) {
        use std::cmp::Ordering;
        self.data
            .sort_by(|a, b| match (a.logit.is_nan(), b.logit.is_nan()) {
                // Non-NaN floats always compare, so the fallback is never taken;
                // +0.0 and -0.0 stay equal, preserving their relative order.
                (false, false) => b.logit.partial_cmp(&a.logit).unwrap_or(Ordering::Equal),
                (true, true) => Ordering::Equal,
                (true, false) => Ordering::Greater,
                (false, true) => Ordering::Less,
            });
        self.sorted = true;
    }

    /// Returns the first `k` entries (all of them if there are fewer).
    ///
    /// These are the `k` highest-logit entries only after `sort_by_logit`; this
    /// method does not sort.
    pub fn top_k(&self, k: usize) -> &[AlignedTokenData] {
        let end = k.min(self.data.len());
        &self.data[..end]
    }

    /// The selected entry, or `None` if nothing is selected or the stored index is
    /// out of range.
    pub fn selected(&self) -> Option<&AlignedTokenData> {
        if self.selected >= 0 && (self.selected as usize) < self.data.len() {
            Some(&self.data[self.selected as usize])
        } else {
            None
        }
    }

    /// Records `index` as the selected entry. Not bounds-checked here; `selected`
    /// returns `None` for out-of-range indices.
    pub fn set_selected(&mut self, index: usize) {
        self.selected = index as i64;
    }

    /// Copies all entries into unaligned [`TokenData`] values.
    pub fn to_token_data_vec(&self) -> Vec<TokenData> {
        self.data.iter().map(|d| d.to_token_data()).collect()
    }
}
/// Owned candidate buffer plus the native `llama_token_data_array` view over it
/// that `Sampler::apply` passes to llama.cpp.
pub struct TokenDataArray {
/// Native view; `inner.data` points into `_data`'s heap buffer (set up in `new`).
inner: sys::llama_token_data_array,
/// Backing storage for `inner.data`; must not be reallocated while `inner` is in use.
_data: Vec<sys::llama_token_data>, }
impl TokenDataArray {
    /// Copies `candidates` into a native-layout buffer and builds an unsorted,
    /// nothing-selected `llama_token_data_array` view over it.
    pub fn new(candidates: Vec<TokenData>) -> Self {
        let mut data: Vec<sys::llama_token_data> = candidates
            .iter()
            .map(|candidate| sys::llama_token_data {
                id: candidate.id as sys::llama_token,
                logit: candidate.logit,
                p: candidate.p,
            })
            .collect();
        // Moving `data` into `Self` below does not move its heap buffer, so this
        // pointer remains valid for as long as `_data` is alive and not resized.
        let inner = sys::llama_token_data_array {
            data: data.as_mut_ptr(),
            size: data.len(),
            selected: -1,
            sorted: false as sys::c_bool,
        };
        Self { inner, _data: data }
    }

    /// Current candidate count as recorded in the native view (samplers may shrink it).
    pub fn len(&self) -> usize {
        self.inner.size
    }

    /// Returns `true` if the native view holds no candidates.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Index of the selected candidate, or `None` if nothing is selected or the
    /// recorded index is not within the current candidates.
    pub fn selected(&self) -> Option<usize> {
        usize::try_from(self.inner.selected)
            .ok()
            .filter(|&i| i < self.inner.size)
    }

    /// Sorted flag as recorded in the native view.
    pub fn is_sorted(&self) -> bool {
        self.inner.sorted
    }

    /// The current candidates, viewed as [`TokenData`].
    pub fn candidates(&self) -> &[TokenData] {
        // `TokenData` is reinterpreted in place of `sys::llama_token_data`; fail the
        // build rather than read with a mismatched layout.
        const _: () = assert!(
            std::mem::size_of::<TokenData>() == std::mem::size_of::<sys::llama_token_data>()
                && std::mem::align_of::<TokenData>()
                    == std::mem::align_of::<sys::llama_token_data>()
        );
        // SAFETY: `inner.data` points into `_data`, which `self` owns. `inner.size` is
        // assumed not to exceed `_data.len()`, since samplers are expected to work in
        // place. Both types are `#[repr(C)]` with matching size and alignment
        // (checked above) and the same field order.
        unsafe { std::slice::from_raw_parts(self.inner.data as *const TokenData, self.inner.size) }
    }
}
/// Sampling performance counters copied out of `llama_perf_sampler` by
/// `SamplerChain::perf_data`.
#[derive(Debug, Clone)]
pub struct SamplerPerfData {
/// Copied from the native `t_sample_ms` field (presumably total sampling time in ms — confirm).
pub t_sample_ms: f64,
/// Copied from the native `n_sample` field (presumably the number of samples taken — confirm).
pub n_sample: i32,
}
/// High-level sampling configuration that `SamplerParams::build_chain` turns
/// into a `SamplerChain`. Default values come from `impl Default for SamplerParams`.
#[derive(Debug, Clone)]
pub struct SamplerParams {
/// Temperature; a temperature stage is added only when this is > 0.0.
pub temperature: f32,
/// Top-k cutoff; a top-k stage is added only when this is > 0.
pub top_k: i32,
/// Top-p threshold; a top-p stage is added only when this is < 1.0.
pub top_p: f32,
/// Min-p threshold; a min-p stage is added only when this is > 0.0.
pub min_p: f32,
/// Typical-p threshold; a typical stage is added only when 0.0 < value < 1.0.
pub typical_p: f32,
/// Repeat penalty; penalties are added when this != 1.0 or freq/present != 0.0.
pub penalty_repeat: f32,
/// Frequency penalty forwarded to `Sampler::penalties`.
pub penalty_freq: f32,
/// Presence penalty forwarded to `Sampler::penalties`.
pub penalty_present: f32,
/// Penalty window forwarded to `Sampler::penalties`.
pub penalty_last_n: i32,
/// NOTE(review): not read by `build_chain` in this file.
pub penalize_nl: bool,
/// NOTE(review): not read by `build_chain` in this file.
pub ignore_eos: bool,
/// Seed forwarded to `Sampler::dist` by `build_chain`.
pub seed: u32,
}
impl Default for SamplerParams {
    /// Default configuration:
    /// - temperature 0.8;
    /// - top-k 40, top-p 0.95, min-p 0.05;
    /// - typical-p 1.0;
    /// - a 1.1 repeat penalty over the last 64 tokens, with no frequency/presence penalty;
    /// - `penalize_nl` on, `ignore_eos` off;
    /// - the library's default seed.
    fn default() -> Self {
        Self {
            seed: sys::LLAMA_DEFAULT_SEED,
            temperature: 0.8,
            top_k: 40,
            top_p: 0.95,
            min_p: 0.05,
            typical_p: 1.0,
            penalty_last_n: 64,
            penalty_repeat: 1.1,
            penalty_freq: 0.0,
            penalty_present: 0.0,
            penalize_nl: true,
            ignore_eos: false,
        }
    }
}
impl SamplerParams {
    /// Builds a [`SamplerChain`] from these parameters.
    ///
    /// Stages are added in this order, each only when its parameter enables it:
    /// penalties → top-k → typical → top-p → min-p → temperature. A final
    /// token-selection stage comes last:
    /// - when `temperature > 0.0`, a seeded distribution sampler;
    /// - otherwise a greedy sampler, so a non-positive temperature (e.g. `0.0`)
    ///   gives deterministic selection rather than sampling the unscaled
    ///   distribution.
    ///
    /// `penalize_nl` and `ignore_eos` are not consulted here, and `_model` is
    /// currently unused.
    ///
    /// # Errors
    /// Propagates any error from creating an individual sampler.
    pub fn build_chain(&self, _model: Arc<Model>) -> Result<SamplerChain, MullamaError> {
        let mut chain = SamplerChain::default();
        if self.penalty_repeat != 1.0 || self.penalty_freq != 0.0 || self.penalty_present != 0.0 {
            let penalties = Sampler::penalties(
                self.penalty_last_n,
                self.penalty_repeat,
                self.penalty_freq,
                self.penalty_present,
            )?;
            chain.add(penalties);
        }
        if self.top_k > 0 {
            chain.add(Sampler::top_k(self.top_k)?);
        }
        if self.typical_p < 1.0 && self.typical_p > 0.0 {
            chain.add(Sampler::typical(self.typical_p, 1)?);
        }
        if self.top_p < 1.0 {
            chain.add(Sampler::top_p(self.top_p, 1)?);
        }
        if self.min_p > 0.0 {
            chain.add(Sampler::min_p(self.min_p, 1)?);
        }
        if self.temperature > 0.0 {
            chain.add(Sampler::temperature(self.temperature)?);
            chain.add(Sampler::dist(self.seed)?);
        } else {
            // Skipping the temperature stage and still drawing from `dist` would
            // sample at temperature 1.0; pick the highest-scoring token instead.
            chain.add(Sampler::greedy()?);
        }
        Ok(chain)
    }
}
// SAFETY (NOTE(review)): these wrappers hold raw pointers to native llama.cpp
// sampler objects plus optional `Arc<Model>` handles.
// - Every mutating method (`sample`, `accept`, `apply`, `reset`, `add`, `remove`,
//   `perf_reset`) takes `&mut self`, so the borrow checker serialises mutation.
// - The `&self` methods (`try_clone`, `name`, `get`, `len`, `get_seed`, `perf_data`,
//   `perf_print`) are assumed to only read native state.
// - These impls also assume native sampler objects are not tied to the creating
//   thread, and that `Model` is safe to share across threads.
// Confirm these assumptions against llama.cpp and the `Model` implementation.
unsafe impl Send for Sampler {}
unsafe impl Sync for Sampler {}
unsafe impl Send for SamplerChain {}
unsafe impl Sync for SamplerChain {}