use std::ffi::CString;
use std::os::raw::{c_char, c_int};
use std::sync::{Arc, OnceLock};
use slm_inference::errors::{FfiError, StringToTokenError, TokenToStringError};
use slm_inference::{SlmSimpleTokEnv, SlmToken, SlmVocab, SlmTokEnv};
use crate::batch::Token;
/// Vocabulary handle backed by raw llama model/vocabulary pointers.
///
/// The struct only stores the pointers handed to it; it performs no
/// validation of them and does not free them on drop.
pub struct Vocab {
    /// Tokenizer environment, built lazily and cached after the first
    /// `tok_env` call.
    tokenv: OnceLock<toktrie::TokEnv>,
    /// Vocabulary pointer used for vocabulary-level queries (token count,
    /// end-of-sequence id).
    vocab_ptr: *const slm_ikllama_sys::llama_vocab,
    /// Model pointer used for tokenization and token-to-piece rendering.
    model_ptr: *const slm_ikllama_sys::llama_model,
}
impl Vocab {
    /// Builds a `Vocab` around the given model and vocabulary pointers.
    ///
    /// The pointers are stored as-is and are not checked here; the tokenizer
    /// environment cache starts out empty.
    #[inline(never)]
    pub fn new(
        model_ptr: *const slm_ikllama_sys::llama_model,
        vocab_ptr: *const slm_ikllama_sys::llama_vocab,
    ) -> Self {
        let tokenv = OnceLock::new();
        Self {
            tokenv,
            vocab_ptr,
            model_ptr,
        }
    }
}
impl SlmVocab for Vocab {
    type Token = Token;

    /// Returns the number of tokens reported by `llama_vocab_n_tokens`.
    ///
    /// A negative count from the library is reported as `0` instead of being
    /// reinterpreted as a huge unsigned value.
    #[inline(never)]
    fn vocab_size(&self) -> usize {
        // SAFETY: relies on `vocab_ptr` being a valid vocabulary pointer; this
        // type stores the pointer unchecked, so the caller of `Vocab::new`
        // must guarantee it.
        let count = unsafe { slm_ikllama_sys::llama_vocab_n_tokens(self.vocab_ptr) };
        usize::try_from(count).unwrap_or(0)
    }

    /// Renders `token` into its raw bytes via `llama_token_to_piece`.
    ///
    /// * `special` is forwarded unchanged to `llama_token_to_piece`.
    /// * `left_strip` is forwarded as the `lstrip` argument; `None` means `0`.
    ///
    /// The call is first attempted with a 128-byte buffer. If the library
    /// answers with a negative value (the negated number of bytes it needs),
    /// the buffer is grown to that size and the call is repeated once.
    ///
    /// # Errors
    ///
    /// * `TokenToStringError::InvalidLstrip` if `left_strip` does not fit in
    ///   an `i32`.
    /// * `TokenToStringError::InsufficientBufferSpace` if the library still
    ///   reports a negative size after the retry, or reports a size that
    ///   cannot be negated.
    #[inline(never)]
    fn token_to_bytes(
        &self,
        token: Self::Token,
        special: bool,
        left_strip: Option<usize>,
    ) -> Result<Vec<u8>, TokenToStringError> {
        /// Size of the first attempt; large enough for typical pieces.
        const INITIAL_PIECE_CAPACITY: usize = 128;

        let lstrip = left_strip
            .map_or(Ok(0), i32::try_from)
            .map_err(|_| TokenToStringError::InvalidLstrip)?;
        let raw_token = token.as_i32() as slm_ikllama_sys::llama_token;

        let piece_into = |buf: &mut Vec<u8>, capacity: c_int| {
            // SAFETY: `buf` is a live, initialized allocation of at least
            // `capacity` bytes. Validity of `model_ptr` is not checked by this
            // type and must be guaranteed by the caller of `Vocab::new`.
            unsafe {
                slm_ikllama_sys::llama_token_to_piece(
                    self.model_ptr,
                    raw_token,
                    buf.as_mut_ptr().cast::<c_char>(),
                    capacity,
                    lstrip,
                    special,
                )
            }
        };

        let mut buf = vec![0u8; INITIAL_PIECE_CAPACITY];
        // The constant is small, so the cast cannot truncate.
        let mut written = piece_into(&mut buf, INITIAL_PIECE_CAPACITY as c_int);

        if written.is_negative() {
            // Negative result: the buffer was too small and `-written` bytes
            // are required. Retry once with exactly that much room.
            let required = written
                .checked_neg()
                .ok_or(TokenToStringError::InsufficientBufferSpace(written))?;
            let required_len = usize::try_from(required)
                .map_err(|_| TokenToStringError::InsufficientBufferSpace(written))?;
            buf.resize(required_len, 0);
            written = piece_into(&mut buf, required);
        }

        match usize::try_from(written) {
            Ok(len) => {
                buf.truncate(len);
                Ok(buf)
            }
            Err(_) => Err(TokenToStringError::InsufficientBufferSpace(written)),
        }
    }

    /// Tokenizes `str` with `llama_tokenize`.
    ///
    /// `add_special` and `parse_special` are forwarded unchanged to the
    /// library. The output buffer is sized from a heuristic
    /// (`max(8, len / 2 + 1)`); if the library reports that more room is
    /// needed (negative return value), the buffer is grown to the reported
    /// size and tokenization is repeated once.
    ///
    /// # Errors
    ///
    /// * `FfiError::CstAllocationError` (converted) if `str` cannot be turned
    ///   into a `CString`, e.g. because it contains an interior NUL byte.
    /// * `FfiError::CintConversionError` (converted) if a length does not fit
    ///   in a C `int`, or if the token count returned by the library is not a
    ///   usable length (still negative after the retry, not negatable, or
    ///   larger than the buffer that was provided).
    #[inline(never)]
    fn str_to_tokens(
        &self,
        str: &str,
        add_special: bool,
        parse_special: bool,
    ) -> Result<Vec<Self::Token>, StringToTokenError> {
        // Reserve one extra slot when special tokens are added.
        let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_special));
        let mut buffer: Vec<Self::Token> = Vec::with_capacity(tokens_estimation);

        let c_string = CString::new(str).map_err(|_| FfiError::CstAllocationError)?;
        let buffer_capacity =
            c_int::try_from(buffer.capacity()).map_err(|_| FfiError::CintConversionError)?;
        let text_len = c_int::try_from(c_string.as_bytes().len())
            .map_err(|_| FfiError::CintConversionError)?;

        let tokenize_into = |out: &mut Vec<Self::Token>, max_tokens: c_int| {
            // SAFETY: `out` has capacity for at least `max_tokens` elements
            // and `c_string` outlives the call.
            // NOTE(review): the pointer cast assumes `Token` has the same
            // layout as `llama_token` — confirm against `crate::batch::Token`.
            // Validity of `model_ptr` is not checked by this type.
            unsafe {
                slm_ikllama_sys::llama_tokenize(
                    self.model_ptr,
                    c_string.as_ptr(),
                    text_len,
                    out.as_mut_ptr().cast::<slm_ikllama_sys::llama_token>(),
                    max_tokens,
                    add_special,
                    parse_special,
                )
            }
        };

        let mut count = tokenize_into(&mut buffer, buffer_capacity);
        if count.is_negative() {
            // Negative result: `-count` tokens are needed. `checked_neg`
            // guards against `c_int::MIN`, which has no positive counterpart.
            let required = count
                .checked_neg()
                .ok_or(FfiError::CintConversionError)?;
            let required_len =
                usize::try_from(required).map_err(|_| FfiError::CintConversionError)?;
            // The vector is still empty, so this guarantees
            // `capacity >= required_len`.
            buffer.reserve_exact(required_len);
            count = tokenize_into(&mut buffer, required);
        }

        let len = usize::try_from(count).map_err(|_| FfiError::CintConversionError)?;
        if len > buffer.capacity() {
            // Never expose memory the library cannot have written.
            return Err(FfiError::CintConversionError.into());
        }
        // SAFETY: `len <= capacity` was checked above and the library reported
        // writing `len` tokens into the buffer.
        unsafe { buffer.set_len(len) }
        Ok(buffer)
    }

    /// Returns the tokenizer environment, building and caching it on first
    /// use.
    ///
    /// Every token id in `0..vocab_size()` is rendered with
    /// `token_to_bytes(.., special = true, None)`. Tokens whose bytes start
    /// with `<unused` are replaced by an empty entry, and trailing entries
    /// after the last token that was not replaced are dropped before the
    /// environment is constructed.
    ///
    /// # Panics
    ///
    /// Panics if any token in the vocabulary cannot be rendered to bytes.
    #[inline(never)]
    fn tok_env(&self) -> &SlmTokEnv {
        self.tokenv.get_or_init(|| {
            let vocab_size = self.vocab_size();
            // SAFETY: relies on `vocab_ptr` being valid; not checked by this type.
            // NOTE(review): a negative id returned here would wrap to a large
            // `u32` — confirm how a missing EOS token should be handled.
            let tok_eos = unsafe { slm_ikllama_sys::llama_vocab_eos(self.vocab_ptr) } as u32;

            let mut words = Vec::with_capacity(vocab_size);
            let mut last_used = 0;
            for i in 0..vocab_size {
                let bytes = self
                    .token_to_bytes(Token::from_i32(i as i32), true, None)
                    .unwrap_or_else(|err| {
                        panic!("failed to render vocabulary token {i} to bytes: {err:?}")
                    });
                if bytes.starts_with(b"<unused") {
                    // Keep the slot so later ids stay aligned with indices.
                    words.push(vec![]);
                } else {
                    words.push(bytes);
                    last_used = i;
                }
            }
            // Drop the trailing run of placeholder entries.
            words.truncate(last_used + 1);
            Arc::new(SlmSimpleTokEnv::new(tok_eos, &words))
        })
    }
}