use crate::batch::{Batch, Token};
use crate::model::Model;
use crate::vocab::Vocab;
use slm_inference::core::shared_ptr::{Free, SharedPtr};
use slm_inference::errors::{
BatchError, ContextBuilderError, ContextError, DecodeError, SamplingError,
};
use slm_inference::{SlmConstraint, SlmContext, SlmContextBuilder, SlmEditLevel, SlmKvType, SlmPos};
/// Marker type whose `Free` implementation releases a `llama_context` through `llama_free`.
///
/// It carries no data; it only selects the deallocation routine used by the
/// `SharedPtr` stored in `Context`.
#[derive(Clone)]
struct LlamaContextFree;
impl Free<slm_ikllama_sys::llama_context> for LlamaContextFree {
    /// Hands `ptr` to `llama_free`.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably `ptr` must be a context pointer that has not been
    /// freed yet — confirm against the `Free` trait contract.
    #[inline(never)]
    unsafe fn free(ptr: *mut slm_ikllama_sys::llama_context) {
        // SAFETY: the caller of this `unsafe fn` is responsible for the validity of `ptr`.
        unsafe {
            slm_ikllama_sys::llama_free(ptr);
        }
    }
}
/// Inference context: a `llama_context` handle plus the sampling settings and
/// model-derived values captured by `Builder::build`.
///
/// NOTE(review): `Clone` is derived, so a clone copies the `SharedPtr` handle and the raw
/// vocab pointer; presumably both clones then refer to the same llama context — confirm
/// against `SharedPtr`.
#[derive(Clone)]
pub struct Context {
/// Vocabulary pointer obtained from `llama_model_get_vocab` at build time; passed to
/// `llama_vocab_is_eog` when sampling.
vocab_ptr: *const slm_ikllama_sys::llama_vocab,
/// Copy of `params.n_batch` taken at build time; reported by `max_batch_len`.
n_batch: u32,
/// Vocabulary size from `llama_n_vocab`; used as the length of the logits slice.
n_vocab: usize,
/// Handle to the llama context, with `LlamaContextFree` as its deallocation routine.
ctx: SharedPtr<slm_ikllama_sys::llama_context, LlamaContextFree>,
/// Value returned by `edit_level()`; the builder sets it to `SlmEditLevel::Cut`.
edit_level: SlmEditLevel,
/// Sampling temperature; a value `<= 0.0` selects greedy sampling.
temperature: f32,
/// `k` passed to `llama_sample_top_k` on the non-greedy path.
top_k: i32,
/// `p` passed to `llama_sample_top_p` on the non-greedy path.
top_p: f32,
/// Model the context was created from; also the source of `vocab()`.
/// Declared after `ctx`, so the context handle is dropped before the model.
#[allow(dead_code)]
model: Model,
}
impl SlmContext for Context {
    type Token = Token;
    type Batch = Batch;
    type Vocab = Vocab;

    /// Returns the vocabulary exposed by the model held in this context.
    fn vocab(&self) -> &Self::Vocab {
        self.model.vocab()
    }

    /// Creates a batch by forwarding `tokens` and `sequences` to `Batch::new`.
    fn new_batch(&self, tokens: usize, sequences: usize) -> Result<Batch, BatchError> {
        Batch::new(tokens, sequences)
    }

    /// Returns the `n_batch` value captured when the context was built.
    fn max_batch_len(&self) -> usize {
        self.n_batch as usize
    }

    /// Runs `llama_decode` over `batch`.
    ///
    /// # Errors
    ///
    /// Any non-zero status returned by `llama_decode` is converted with
    /// `DecodeError::from`.
    #[inline(never)]
    fn decode(&mut self, batch: &mut Batch) -> Result<(), DecodeError> {
        let result =
            unsafe { slm_ikllama_sys::llama_decode(self.ctx.get_ptr(), batch.llama_batch) };
        if result != 0 {
            return Err(DecodeError::from(result));
        }
        Ok(())
    }

    /// Samples one token from the logits at `logit_idx`.
    ///
    /// When `constraint` is given, its `mask` is applied to the logits first; if `mask`
    /// returns `false` the method returns `Ok(None)` without sampling.
    ///
    /// With `temperature <= 0.0` the token is chosen by `llama_sample_token_greedy`.
    /// Otherwise the candidates go through top-k, temperature, softmax and top-p
    /// (in that order) before `llama_sample_token` draws the token.
    ///
    /// Returns `Ok(None)` when `llama_vocab_is_eog` reports the sampled token as
    /// end-of-generation, `Ok(Some(token))` otherwise.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `constraint.mask`.
    #[inline(never)]
    fn sample_with_constraint(
        &mut self,
        logit_idx: usize,
        constraint: Option<&mut dyn SlmConstraint>,
    ) -> Result<Option<Self::Token>, SamplingError> {
        unsafe {
            let ctx = self.ctx.get_ptr();
            // NOTE(review): the returned pointer is not checked before it is turned into a
            // slice; presumably callers only pass an index that had logits requested in
            // the last decoded batch — confirm.
            let logits_ptr = slm_ikllama_sys::llama_get_logits_ith(ctx, logit_idx as i32);
            let logits = std::slice::from_raw_parts_mut(logits_ptr, self.n_vocab);
            if let Some(c) = constraint {
                if !c.mask(logits)? {
                    return Ok(None);
                }
            }
            // One candidate per vocabulary entry, carrying the (possibly masked) logit.
            let mut candidates_vec: Vec<slm_ikllama_sys::llama_token_data> = logits
                .iter()
                .enumerate()
                .map(|(id, &logit)| slm_ikllama_sys::llama_token_data {
                    id: id as slm_ikllama_sys::llama_token,
                    logit,
                    p: 0.0,
                })
                .collect();
            // `candidates_vec` must stay alive while `candidates_array` points into it.
            let mut candidates_array = slm_ikllama_sys::llama_token_data_array {
                data: candidates_vec.as_mut_ptr(),
                size: candidates_vec.len(),
                selected: 0,
                sorted: false,
            };
            let token = if self.temperature <= 0.0 {
                slm_ikllama_sys::llama_sample_token_greedy(ctx, &mut candidates_array)
            } else {
                slm_ikllama_sys::llama_sample_top_k(ctx, &mut candidates_array, self.top_k, 1);
                slm_ikllama_sys::llama_sample_temp(ctx, &mut candidates_array, self.temperature);
                slm_ikllama_sys::llama_sample_softmax(ctx, &mut candidates_array);
                slm_ikllama_sys::llama_sample_top_p(ctx, &mut candidates_array, self.top_p, 1);
                slm_ikllama_sys::llama_sample_token(ctx, &mut candidates_array)
            };
            if slm_ikllama_sys::llama_vocab_is_eog(self.vocab_ptr, token) {
                Ok(None)
            } else {
                Ok(Some(token.into()))
            }
        }
    }

    /// Clears the whole KV cache via `llama_kv_cache_clear`. Always returns `Ok(())`.
    #[inline(never)]
    fn clear(&mut self) -> Result<(), ContextError> {
        let ctx = self.ctx.get_ptr();
        unsafe { slm_ikllama_sys::llama_kv_cache_clear(ctx) };
        Ok(())
    }

    /// Removes sequence `fork_id` from the KV cache.
    ///
    /// Both range bounds are passed as `-1`, which llama's KV-cache API treats as an
    /// open bound, i.e. every position of the sequence. The result of the removal call
    /// is not inspected; the method always returns `Ok(())`.
    fn drop(&mut self, fork_id: usize) -> Result<(), ContextError> {
        let ctx = self.ctx.get_ptr();
        unsafe { slm_ikllama_sys::llama_kv_cache_seq_rm(ctx, fork_id as i32, -1, -1) };
        Ok(())
    }

    /// Removes every cached position of `pos.fork_id` from `pos.token_pos` onward
    /// (upper bound `-1` = open) and returns a position equal to `pos`.
    #[inline(never)]
    fn truncate(&mut self, pos: &SlmPos) -> Result<SlmPos, ContextError> {
        let SlmPos { token_pos, fork_id } = *pos;
        let ctx = self.ctx.get_ptr();
        unsafe {
            slm_ikllama_sys::llama_kv_cache_seq_rm(ctx, fork_id as i32, token_pos as i32, -1)
        };
        Ok(SlmPos::new(token_pos, fork_id))
    }

    /// Cuts the half-open position range `[start_pos, end_pos)` out of a sequence and
    /// closes the gap.
    ///
    /// The tokens in the range are removed from the KV cache, then every position from
    /// `end_pos` onward is shifted down by the number of removed positions. An empty
    /// range (`start_pos == end_pos`) removes nothing and shifts by zero.
    ///
    /// Returns the next free position of the sequence (largest cached position + 1).
    ///
    /// # Errors
    ///
    /// Returns `ContextError::Error` when the two positions belong to different forks
    /// or when `start_pos` lies after `end_pos`.
    #[inline(never)]
    fn cut(&mut self, start_pos: &SlmPos, end_pos: &SlmPos) -> Result<SlmPos, ContextError> {
        if start_pos.fork_id != end_pos.fork_id {
            return Err(ContextError::Error(
                "positions must have the same fork_id".to_string(),
            ));
        }
        // Reject reversed ranges; this also guarantees the subtraction below cannot
        // underflow.
        if start_pos.token_pos > end_pos.token_pos {
            return Err(ContextError::Error(
                "start_pos must be before end_pos".to_string(),
            ));
        }
        let ctx = self.ctx.get_ptr();
        let seq_id = start_pos.fork_id as i32;
        let cut_start = start_pos.token_pos as i32;
        let cut_end = end_pos.token_pos as i32;
        let pos_n = end_pos.token_pos - start_pos.token_pos;
        unsafe {
            // The upper bound of `seq_rm` is exclusive, so `cut_end` itself is the right
            // value to drop exactly `pos_n` positions.
            slm_ikllama_sys::llama_kv_cache_seq_rm(ctx, seq_id, cut_start, cut_end);
            // Shift the surviving tail down (negative delta) so positions stay contiguous.
            slm_ikllama_sys::llama_kv_cache_seq_add(ctx, seq_id, cut_end, -1, -(pos_n as i32));
        }
        let next_pos =
            unsafe { slm_ikllama_sys::llama_kv_cache_seq_pos_max(ctx, seq_id) + 1 };
        Ok(SlmPos::new(next_pos as usize, start_pos.fork_id))
    }

    /// Not implemented yet: calling this panics via `todo!()`.
    fn dump(&mut self) -> Result<Vec<u8>, ContextError> {
        todo!()
    }

    /// Not implemented yet: calling this panics via `todo!()`.
    fn restore(&mut self, _data: Vec<u8>) -> Result<(), ContextError> {
        todo!()
    }

    /// Returns the edit level stored at build time.
    fn edit_level(&self) -> SlmEditLevel {
        self.edit_level
    }
}
/// Collects llama context parameters and sampler settings before a `Context` is built.
pub struct Builder {
/// Model the context will be created from; moved into the `Context` by `build`.
model: Model,
/// Raw llama context parameters, initialised from `llama_context_default_params`
/// in `Builder::new` and adjusted by the `with_*` methods.
params: slm_ikllama_sys::llama_context_params,
/// Sampling temperature copied into the `Context`; `Builder::new` starts it at `0.0`.
temperature: f32,
/// Top-k setting copied into the `Context`; `Builder::new` starts it at `0`.
top_k: i32,
/// Top-p setting copied into the `Context`; `Builder::new` starts it at `0.0`.
top_p: f32,
}
/// Element types that can be selected for the K and V caches.
///
/// Each discriminant is the matching `GGML_TYPE_*` constant from `slm_ikllama_sys`, so a
/// variant can be cast with `as u32` when it is written into `llama_context_params`.
#[repr(u32)]
#[allow(dead_code)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum KVType {
Q4_0 = slm_ikllama_sys::GGML_TYPE_Q4_0,
Q5_0 = slm_ikllama_sys::GGML_TYPE_Q5_0,
Q6_0 = slm_ikllama_sys::GGML_TYPE_Q6_0,
Q8_0 = slm_ikllama_sys::GGML_TYPE_Q8_0,
F16 = slm_ikllama_sys::GGML_TYPE_F16,
F32 = slm_ikllama_sys::GGML_TYPE_F32,
}
impl KVType {
    /// Translates a generic `SlmKvType` into the concrete cache type plus a flag.
    ///
    /// The flag is what `with_gen_type_kv` writes into `k_cache_hadamard` /
    /// `v_cache_hadamard`: `true` for `Q4`, `Q5`, `Q6` and `Q8`, `false` for `RawQ8`,
    /// `F16` and `F32`. Every variant handled here yields `Some`.
    pub fn from(t: SlmKvType) -> Option<(KVType, bool)> {
        let (kv_type, hadamard) = match t {
            SlmKvType::Q4 => (KVType::Q4_0, true),
            SlmKvType::Q5 => (KVType::Q5_0, true),
            SlmKvType::Q6 => (KVType::Q6_0, true),
            SlmKvType::Q8 => (KVType::Q8_0, true),
            SlmKvType::RawQ8 => (KVType::Q8_0, false),
            SlmKvType::F16 => (KVType::F16, false),
            SlmKvType::F32 => (KVType::F32, false),
        };
        Some((kv_type, hadamard))
    }
}
impl Builder {
    /// Creates a builder for `model`.
    ///
    /// Context parameters start from `llama_context_default_params`. The sampler starts
    /// with `temperature = 0.0`, `top_k = 0` and `top_p = 0.0`; a temperature of `0.0`
    /// makes the resulting context sample greedily.
    #[inline(never)]
    pub fn new(model: Model) -> Self {
        Self {
            model,
            params: unsafe { slm_ikllama_sys::llama_context_default_params() },
            temperature: 0.0,
            top_k: 0,
            top_p: 0.0,
        }
    }

    /// Turns flash attention on.
    #[allow(dead_code)]
    pub fn with_flash_attn(mut self) -> Self {
        self.params.flash_attn = true;
        self
    }

    /// Sets the Hadamard flags of the K cache (`k`) and V cache (`v`).
    ///
    /// Flash attention is switched on as well.
    #[allow(dead_code)]
    pub fn with_kv_hadamard(mut self, k: bool, v: bool) -> Self {
        self.params.flash_attn = true;
        self.params.k_cache_hadamard = k;
        self.params.v_cache_hadamard = v;
        self
    }

    /// Selects the element types of the K cache (`type_k`) and V cache (`type_v`).
    ///
    /// Flash attention is switched on. For each cache independently, choosing `Q4_0` or
    /// `Q5_0` also turns on that cache's Hadamard flag; other types leave the flag as it
    /// was.
    #[allow(dead_code)]
    pub fn with_type_kv(mut self, type_k: KVType, type_v: KVType) -> Self {
        self.params.flash_attn = true;
        if matches!(type_k, KVType::Q4_0 | KVType::Q5_0) {
            self.params.k_cache_hadamard = true;
        }
        // The V-cache flag must follow the V type, not the K type.
        if matches!(type_v, KVType::Q4_0 | KVType::Q5_0) {
            self.params.v_cache_hadamard = true;
        }
        self.params.type_k = type_k as u32;
        self.params.type_v = type_v as u32;
        self
    }
}
impl SlmContextBuilder<Context> for Builder {
#[inline(never)]
fn build(mut self) -> Result<Context, ContextBuilderError> {
let ctx =
unsafe { slm_ikllama_sys::llama_init_from_model(self.model.get_ptr()?, self.params) };
let model_ptr = self.model.get_const_ptr()?;
let vocab_ptr = unsafe { slm_ikllama_sys::llama_model_get_vocab(model_ptr) };
let n_vocab = unsafe { slm_ikllama_sys::llama_n_vocab(model_ptr) } as usize;
let edit_level = SlmEditLevel::Cut;
Ok(Context {
ctx: SharedPtr::new(ctx),
vocab_ptr,
n_batch: self.params.n_batch,
n_vocab,
edit_level,
temperature: self.temperature,
top_k: self.top_k,
top_p: self.top_p,
model: self.model,
})
}
#[inline(never)]
fn with_sampler(mut self, temperature: f32, top_k: i32, top_p: f32) -> Self {
self.temperature = temperature;
self.top_k = top_k;
self.top_p = top_p;
self
}
#[inline(never)]
fn with_n_ctx(mut self, n_ctx: usize) -> Self {
self.params.n_ctx = n_ctx as u32;
self
}
#[inline(never)]
fn with_n_batch(mut self, n_batch: usize) -> Self {
self.params.n_batch = n_batch as u32;
self
}
#[inline(never)]
fn with_gen_type_kv(mut self, k: SlmKvType, v: SlmKvType) -> Self {
let (k, kh) = KVType::from(k).unwrap();
let (v, vh) = KVType::from(v).unwrap();
self.params.flash_attn = true;
self.params.type_k = k as u32;
self.params.k_cache_hadamard = kh;
self.params.type_v = v as u32;
self.params.v_cache_hadamard = vh;
self
}
#[inline(never)]
fn with_flash_attn(mut self, enable: bool) -> Self {
self.params.flash_attn = enable;
self
}
}