mod chain;
mod custom;
mod grammar;
mod strategies;
use llama_crab_sys as sys;
use crate::token::LlamaToken;
use crate::token_data::LlamaTokenDataArray;
pub use chain::SamplerChain;
#[cfg(feature = "common")]
pub use grammar::GrammarError;
/// Wrapper around a native `llama_sampler` pointer.
///
/// When the wrapper is dropped, a non-null `handle` is released with
/// `llama_sampler_free`, so the wrapper is treated as the owner of its handle.
/// A null `handle` means the wrapper owns nothing. For example,
/// [`LlamaSampler::chain`] nulls the handle of every sampler it moves into a
/// native chain, because the chain takes over ownership.
#[derive(Debug)]
pub struct LlamaSampler {
// Raw native sampler; null once ownership has been transferred elsewhere.
handle: *mut sys::llama_sampler,
}
impl LlamaSampler {
#[allow(dead_code)]
pub(crate) unsafe fn from_raw(ptr: *mut sys::llama_sampler) -> Self {
Self { handle: ptr }
}
#[allow(dead_code)]
pub(crate) unsafe fn from_raw_borrowed(ptr: *mut sys::llama_sampler) -> Self {
Self { handle: ptr }
}
#[allow(dead_code)]
pub(crate) fn as_ptr(&self) -> *mut sys::llama_sampler {
self.handle
}
pub unsafe fn sample(&mut self, ctx: *mut sys::llama_context, idx: i32) -> LlamaToken {
let raw = unsafe { sys::llama_sampler_sample(self.handle, ctx, idx) };
LlamaToken(raw)
}
pub fn apply(&self, candidates: &mut LlamaTokenDataArray) {
unsafe {
sys::llama_sampler_apply(self.handle, candidates.as_mut_ptr());
}
}
pub fn reset(&mut self) {
unsafe { sys::llama_sampler_reset(self.handle) };
}
pub fn accept(&mut self, token: LlamaToken) {
unsafe { sys::llama_sampler_accept(self.handle, token.0) };
}
#[must_use]
pub fn get_seed(&self) -> u32 {
unsafe { sys::llama_sampler_get_seed(self.handle) }
}
#[must_use]
pub fn chain(samplers: Vec<LlamaSampler>, no_perf: bool) -> Option<Self> {
let mut chain_params = unsafe { sys::llama_sampler_chain_default_params() };
chain_params.no_perf = no_perf;
let chain = unsafe { sys::llama_sampler_chain_init(chain_params) };
if chain.is_null() {
return None;
}
for mut s in samplers {
unsafe { sys::llama_sampler_chain_add(chain, s.handle) };
s.handle = std::ptr::null_mut();
}
Some(unsafe { Self::from_raw(chain) })
}
}
impl Drop for LlamaSampler {
    /// Releases the native sampler, unless this wrapper holds a null handle.
    ///
    /// A null handle means ownership was handed elsewhere (for example into a
    /// chain), so there is nothing to free.
    fn drop(&mut self) {
        let owned = self.handle;
        if owned.is_null() {
            return;
        }
        // SAFETY: `owned` is a non-null sampler pointer held by this wrapper, and
        // this is the single place the wrapper releases it.
        unsafe { sys::llama_sampler_free(owned) };
    }
}