//! A safe wrapper around `llama_model`.
use crate::context::params::LlamaContextParams;
use crate::context::LlamaContext;
use crate::llama_backend::LlamaBackend;
use crate::model::params::LlamaModelParams;
use crate::token::LlamaToken;
use crate::token_type::LlamaTokenType;
use crate::{LlamaContextLoadError, LlamaModelLoadError, StringToTokenError, TokenToStringError};
use std::ffi::CString;
use std::os::raw::c_int;
use std::path::Path;
use std::ptr::NonNull;
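/// A safe wrapper around `llama_model_params`.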
pub mod params;
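/// A safe wrapper around a `llama_model` pointer; the model is freed on drop.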
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaModel {
pub(crate) model: NonNull<llama_cpp_sys_2::llama_model>,
}
/// Whether to prepend the model's beginning-of-sequence (BOS) token when tokenizing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AddBos {
    /// Add the BOS token to the start of the tokenized string.
    Always,
    /// Do not add the BOS token.
    Never,
}
// SAFETY: a loaded `llama_model` is never mutated through this wrapper (llama.cpp treats the
// model as read-only after loading), so it is safe to share and send across threads.
unsafe impl Send for LlamaModel {}
unsafe impl Sync for LlamaModel {}
impl LlamaModel {
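    /// Get the size (in tokens) of the context window the model was trained on.
    ///
    /// # Panics
    ///
    /// Panics if `n_ctx_train` does not fit into a `u16`.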
#[must_use]
pub fn n_ctx_train(&self) -> u16 {
let n_ctx_train = unsafe { llama_cpp_sys_2::llama_n_ctx_train(self.model.as_ptr()) };
        u16::try_from(n_ctx_train).expect("n_ctx_train fits into a u16")
}
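    /// Get every token in the model's vocabulary, paired with its string representation.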
pub fn tokens(
&self,
) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
(0..self.n_vocab())
.map(LlamaToken::new)
.map(|llama_token| (llama_token, self.token_to_str(llama_token)))
}
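    /// Get the beginning-of-sequence (BOS) token.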
#[must_use]
pub fn token_bos(&self) -> LlamaToken {
let token = unsafe { llama_cpp_sys_2::llama_token_bos(self.model.as_ptr()) };
LlamaToken(token)
}
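    /// Get the end-of-sequence (EOS) token.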
#[must_use]
pub fn token_eos(&self) -> LlamaToken {
let token = unsafe { llama_cpp_sys_2::llama_token_eos(self.model.as_ptr()) };
LlamaToken(token)
}
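    /// Get the newline token.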
#[must_use]
pub fn token_nl(&self) -> LlamaToken {
let token = unsafe { llama_cpp_sys_2::llama_token_nl(self.model.as_ptr()) };
LlamaToken(token)
}
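    /// Convert a single token to a [`String`], using a 32-byte buffer.
    ///
    /// # Errors
    ///
    /// See [`TokenToStringError`] for more information.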
pub fn token_to_str(&self, token: LlamaToken) -> Result<String, TokenToStringError> {
self.token_to_str_with_size(token, 32)
}
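    /// Convert a slice of tokens to a single [`String`].
    ///
    /// # Errors
    ///
    /// See [`TokenToStringError`] for more information.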
pub fn tokens_to_str(&self, tokens: &[LlamaToken]) -> Result<String, TokenToStringError> {
let mut builder = String::with_capacity(tokens.len() * 4);
for str in tokens.iter().copied().map(|t| self.token_to_str(t)) {
builder += &str?;
}
Ok(builder)
}
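    /// Convert a string to a vector of tokens.
    ///
    /// # Errors
    ///
    /// Fails if `str` contains a null byte.
    ///
    /// # Panics
    ///
    /// Panics if the token count reported by llama.cpp does not fit into the integer types
    /// used here.
    ///
    /// A minimal usage sketch (`no_run`: it needs a real model on disk). It assumes these
    /// bindings are consumed as a `llama_cpp_2` crate, that `LlamaBackend::init()` and
    /// `LlamaModelParams::default()` exist as in the sibling modules, and uses a placeholder
    /// model path:
    ///
    /// ```no_run
    /// # use llama_cpp_2::llama_backend::LlamaBackend;
    /// # use llama_cpp_2::model::params::LlamaModelParams;
    /// # use llama_cpp_2::model::{AddBos, LlamaModel};
    /// # use std::path::Path;
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, Path::new("path/to/model"), &LlamaModelParams::default())?;
    /// let tokens = model.str_to_token("Hello, World!", AddBos::Always)?;
    /// # Ok::<(), Box<dyn std::error::Error>>(())
    /// ```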
pub fn str_to_token(
&self,
str: &str,
add_bos: AddBos,
) -> Result<Vec<LlamaToken>, StringToTokenError> {
let add_bos = match add_bos {
AddBos::Always => true,
            AddBos::Never => false,
};
let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
let mut buffer = Vec::with_capacity(tokens_estimation);
let c_string = CString::new(str)?;
let buffer_capacity =
c_int::try_from(buffer.capacity()).expect("buffer capacity should fit into a c_int");
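        // First pass: tokenize into the estimated buffer. A negative return value means the
        // buffer was too small; its absolute value is the number of tokens required.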
let size = unsafe {
llama_cpp_sys_2::llama_tokenize(
self.model.as_ptr(),
c_string.as_ptr(),
c_int::try_from(c_string.as_bytes().len())?,
buffer.as_mut_ptr(),
buffer_capacity,
add_bos,
true,
)
};
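        // Second pass, only taken if the estimate was too small: reserve exactly the required
        // capacity and tokenize again.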
let size = if size.is_negative() {
            buffer.reserve_exact(usize::try_from(-size).expect("-size fits into a usize"));
unsafe {
llama_cpp_sys_2::llama_tokenize(
self.model.as_ptr(),
c_string.as_ptr(),
c_int::try_from(c_string.as_bytes().len())?,
buffer.as_mut_ptr(),
-size,
add_bos,
true,
)
}
} else {
size
};
        let size = usize::try_from(size).expect("size is positive and fits into a usize");
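        // SAFETY: `llama_tokenize` has initialized exactly `size` tokens in `buffer`.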
unsafe { buffer.set_len(size) }
Ok(buffer.into_iter().map(LlamaToken).collect())
}
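    /// Get the type of a token.
    ///
    /// # Panics
    ///
    /// Panics if llama.cpp returns a token type unknown to these bindings.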
#[must_use]
pub fn token_type(&self, LlamaToken(id): LlamaToken) -> LlamaTokenType {
let token_type = unsafe { llama_cpp_sys_2::llama_token_get_type(self.model.as_ptr(), id) };
LlamaTokenType::try_from(token_type).expect("token type is valid")
}
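    /// Convert a token to a [`String`], writing through a buffer of `buffer_size` bytes.
    ///
    /// Generally you should use [`LlamaModel::token_to_str`] instead, as 32 bytes is enough for
    /// most pieces. The newline token maps to `"\n"`; BOS/EOS and other special token types map
    /// to the empty string.
    ///
    /// # Errors
    ///
    /// - if the token type is unknown
    /// - if the piece is larger than `buffer_size`
    /// - if the piece is not valid UTF-8
    ///
    /// # Panics
    ///
    /// Panics if `buffer_size` does not fit into a [`c_int`], or if the size returned by
    /// llama.cpp does not fit into a [`usize`] (it never should, as it is non-negative there).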
pub fn token_to_str_with_size(
&self,
token: LlamaToken,
buffer_size: usize,
) -> Result<String, TokenToStringError> {
if token == self.token_nl() {
return Ok(String::from("\n"));
}
match self.token_type(token) {
LlamaTokenType::Normal => {}
LlamaTokenType::Control => {
if token == self.token_bos() || token == self.token_eos() {
return Ok(String::new());
}
}
LlamaTokenType::Unknown
| LlamaTokenType::Undefined
| LlamaTokenType::Byte
| LlamaTokenType::UserDefined
| LlamaTokenType::Unused => {
return Ok(String::new());
}
}
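        // Allocate a NUL-terminated scratch buffer of `buffer_size` non-zero bytes for
        // llama.cpp to write the piece into.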
        let string = CString::new(vec![b'*'; buffer_size]).expect("buffer contains no null bytes");
let len = string.as_bytes().len();
let len = c_int::try_from(len).expect("length fits into c_int");
let buf = string.into_raw();
let size = unsafe {
llama_cpp_sys_2::llama_token_to_piece(self.model.as_ptr(), token.0, buf, len)
};
match size {
0 => Err(TokenToStringError::UnknownTokenType),
i if i.is_negative() => Err(TokenToStringError::InsufficientBufferSpace(i)),
size => {
let string = unsafe { CString::from_raw(buf) };
let mut bytes = string.into_bytes();
let len = usize::try_from(size).expect("size is positive and fits into usize");
bytes.truncate(len);
Ok(String::from_utf8(bytes)?)
}
}
}
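    /// Get the number of tokens in the model's vocabulary.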
#[must_use]
pub fn n_vocab(&self) -> i32 {
unsafe { llama_cpp_sys_2::llama_n_vocab(self.model.as_ptr()) }
}
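    /// Get the type of the model's vocabulary.
    ///
    /// # Panics
    ///
    /// Panics if llama.cpp returns a vocabulary type unknown to these bindings.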
#[must_use]
pub fn vocab_type(&self) -> VocabType {
let vocab_type = unsafe { llama_cpp_sys_2::llama_vocab_type(self.model.as_ptr()) };
        VocabType::try_from(vocab_type).expect("vocab type is valid")
}
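    /// Get the embedding dimension of the model.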
#[must_use]
pub fn n_embd(&self) -> c_int {
unsafe { llama_cpp_sys_2::llama_n_embd(self.model.as_ptr()) }
}
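    /// Load a model from the given path.
    ///
    /// In debug builds this asserts that `path` exists before calling into llama.cpp.
    ///
    /// # Errors
    ///
    /// See [`LlamaModelLoadError`] for more information.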
#[tracing::instrument(skip_all)]
pub fn load_from_file(
_: &LlamaBackend,
path: impl AsRef<Path>,
params: &LlamaModelParams,
) -> Result<Self, LlamaModelLoadError> {
let path = path.as_ref();
        debug_assert!(path.exists(), "{path:?} does not exist");
let path = path
.to_str()
.ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
let cstr = CString::new(path)?;
        tracing::debug!(params = ?params.params, "Loading model");
        let llama_model =
            unsafe { llama_cpp_sys_2::llama_load_model_from_file(cstr.as_ptr(), params.params) };
let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
println!("Loaded {path:?}");
Ok(LlamaModel { model })
}
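    /// Create a new context for this model.
    ///
    /// # Errors
    ///
    /// See [`LlamaContextLoadError`] for more information.
    ///
    /// A minimal sketch, assuming a `model` loaded as in [`LlamaModel::str_to_token`]'s example,
    /// the same `llama_cpp_2` crate name, and a `Default` impl on `LlamaContextParams`:
    ///
    /// ```no_run
    /// # use llama_cpp_2::context::params::LlamaContextParams;
    /// # use llama_cpp_2::llama_backend::LlamaBackend;
    /// # use llama_cpp_2::model::params::LlamaModelParams;
    /// # use llama_cpp_2::model::LlamaModel;
    /// # use std::path::Path;
    /// # let backend = LlamaBackend::init()?;
    /// # let model = LlamaModel::load_from_file(&backend, Path::new("path/to/model"), &LlamaModelParams::default())?;
    /// let ctx = model.new_context(&backend, LlamaContextParams::default())?;
    /// # Ok::<(), Box<dyn std::error::Error>>(())
    /// ```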
pub fn new_context(
&self,
_: &LlamaBackend,
params: LlamaContextParams,
) -> Result<LlamaContext, LlamaContextLoadError> {
let context_params = params.context_params;
let context = unsafe {
llama_cpp_sys_2::llama_new_context_with_model(self.model.as_ptr(), context_params)
};
let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;
Ok(LlamaContext::new(self, context))
}
}
impl Drop for LlamaModel {
fn drop(&mut self) {
unsafe { llama_cpp_sys_2::llama_free_model(self.model.as_ptr()) }
}
}
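/// The type of vocabulary used by the model.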
#[repr(u32)]
#[derive(Debug, Eq, Copy, Clone, PartialEq)]
pub enum VocabType {
    /// Byte-pair encoding.
    BPE = llama_cpp_sys_2::LLAMA_VOCAB_TYPE_BPE,
    /// SentencePiece.
    SPM = llama_cpp_sys_2::LLAMA_VOCAB_TYPE_SPM,
}
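/// An error that can occur when converting a raw `llama_vocab_type` into a [`VocabType`].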
#[derive(thiserror::Error, Debug, Eq, PartialEq)]
pub enum LlamaTokenTypeFromIntError {
    /// The value did not match any known vocabulary type.
    #[error("Unknown Value {0}")]
    UnknownValue(llama_cpp_sys_2::llama_vocab_type),
}
impl TryFrom<llama_cpp_sys_2::llama_vocab_type> for VocabType {
type Error = LlamaTokenTypeFromIntError;
fn try_from(value: llama_cpp_sys_2::llama_vocab_type) -> Result<Self, Self::Error> {
match value {
llama_cpp_sys_2::LLAMA_VOCAB_TYPE_BPE => Ok(VocabType::BPE),
llama_cpp_sys_2::LLAMA_VOCAB_TYPE_SPM => Ok(VocabType::SPM),
unknown => Err(LlamaTokenTypeFromIntError::UnknownValue(unknown)),
}
}
}