use std::ffi::CStr;
use std::ffi::CString;
use std::fmt;
use std::num::NonZeroU16;
use std::os::raw::{c_char, c_int};
use std::path::Path;
use std::ptr::NonNull;
use std::slice;
use llama_cpp_sys_4::{
llama_adapter_lora, llama_adapter_lora_init, llama_chat_apply_template,
llama_chat_builtin_templates, llama_chat_message, llama_detokenize, llama_init_from_model,
llama_model, llama_model_cls_label, llama_model_decoder_start_token, llama_model_desc,
llama_model_free, llama_model_get_device, llama_model_get_vocab, llama_model_has_decoder,
llama_model_has_encoder, llama_model_is_diffusion, llama_model_is_hybrid,
llama_model_is_recurrent, llama_model_load_from_file, llama_model_load_from_splits,
llama_model_meta_count, llama_model_meta_key_by_index, llama_model_meta_val_str,
llama_model_meta_val_str_by_index, llama_model_n_cls_out, llama_model_n_ctx_train,
llama_model_n_devices, llama_model_n_embd, llama_model_n_embd_inp, llama_model_n_embd_out,
llama_model_n_expert, llama_model_n_head, llama_model_n_head_kv, llama_model_n_layer,
llama_model_n_layer_nextn, llama_model_n_params, llama_model_n_swa,
llama_model_rope_freq_scale_train, llama_model_rope_type, llama_model_save_to_file,
llama_model_size, llama_model_target_layer_ids, llama_model_target_layer_ids_n,
llama_split_path, llama_split_prefix, llama_token_to_piece, llama_tokenize, llama_vocab,
llama_vocab_type, LLAMA_VOCAB_TYPE_BPE, LLAMA_VOCAB_TYPE_SPM,
};
use crate::context::params::LlamaContextParams;
use crate::context::LlamaContext;
use crate::llama_backend::LlamaBackend;
use crate::model::params::LlamaModelParams;
use crate::token::LlamaToken;
use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
use crate::{
ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
LlamaModelLoadError, NewLlamaChatMessageError, StringFromModelError, StringToTokenError,
TokenToStringError,
};
pub mod params;
/// Copyable handle wrapping a raw `ggml_backend_dev_t` device pointer.
///
/// The handle does not own the device; it is only passed through to the
/// `ggml_backend_dev_*` query functions.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct LlamaBackendDevice {
/// Raw device handle as returned by `llama_model_get_device`.
pub(crate) dev: llama_cpp_sys_4::ggml_backend_dev_t,
}
/// Kind of backend device, mirroring the `GGML_BACKEND_DEVICE_TYPE_*` constants.
///
/// Each discriminant is the matching constant reinterpreted as a signed `i32`.
#[repr(i32)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaBackendDeviceType {
/// Corresponds to `GGML_BACKEND_DEVICE_TYPE_CPU`.
Cpu = llama_cpp_sys_4::GGML_BACKEND_DEVICE_TYPE_CPU.cast_signed(),
/// Corresponds to `GGML_BACKEND_DEVICE_TYPE_GPU`.
Gpu = llama_cpp_sys_4::GGML_BACKEND_DEVICE_TYPE_GPU.cast_signed(),
/// Corresponds to `GGML_BACKEND_DEVICE_TYPE_IGPU`.
IntegratedGpu = llama_cpp_sys_4::GGML_BACKEND_DEVICE_TYPE_IGPU.cast_signed(),
/// Corresponds to `GGML_BACKEND_DEVICE_TYPE_ACCEL`.
Accel = llama_cpp_sys_4::GGML_BACKEND_DEVICE_TYPE_ACCEL.cast_signed(),
/// Corresponds to `GGML_BACKEND_DEVICE_TYPE_META`.
Meta = llama_cpp_sys_4::GGML_BACKEND_DEVICE_TYPE_META.cast_signed(),
}
impl From<llama_cpp_sys_4::ggml_backend_dev_type> for LlamaBackendDeviceType {
fn from(value: llama_cpp_sys_4::ggml_backend_dev_type) -> Self {
match value {
llama_cpp_sys_4::GGML_BACKEND_DEVICE_TYPE_CPU => Self::Cpu,
llama_cpp_sys_4::GGML_BACKEND_DEVICE_TYPE_GPU => Self::Gpu,
llama_cpp_sys_4::GGML_BACKEND_DEVICE_TYPE_IGPU => Self::IntegratedGpu,
llama_cpp_sys_4::GGML_BACKEND_DEVICE_TYPE_ACCEL => Self::Accel,
_ => Self::Meta,
}
}
}
impl LlamaBackendDevice {
pub fn name(&self) -> Result<&str, StringFromModelError> {
let ptr = unsafe { llama_cpp_sys_4::ggml_backend_dev_name(self.dev) };
if ptr.is_null() {
return Err(StringFromModelError::ReturnedError(-1));
}
let cstr = unsafe { CStr::from_ptr(ptr) };
cstr.to_str().map_err(StringFromModelError::Utf8Error)
}
pub fn description(&self) -> Result<&str, StringFromModelError> {
let ptr = unsafe { llama_cpp_sys_4::ggml_backend_dev_description(self.dev) };
if ptr.is_null() {
return Err(StringFromModelError::ReturnedError(-1));
}
let cstr = unsafe { CStr::from_ptr(ptr) };
cstr.to_str().map_err(StringFromModelError::Utf8Error)
}
#[must_use]
pub fn device_type(&self) -> LlamaBackendDeviceType {
unsafe { llama_cpp_sys_4::ggml_backend_dev_type(self.dev).into() }
}
#[must_use]
pub fn memory(&self) -> (usize, usize) {
let mut free = 0usize;
let mut total = 0usize;
unsafe {
llama_cpp_sys_4::ggml_backend_dev_memory(self.dev, &raw mut free, &raw mut total);
}
(free, total)
}
}
/// Iterator over the backend devices of a [`LlamaModel`], in index order.
#[derive(Debug, Clone, Copy)]
pub struct LlamaBackendDevices<'a> {
/// Model whose devices are enumerated.
model: &'a LlamaModel,
/// Index of the device the next call to `next()` will request.
next: i32,
}
#[allow(clippy::copy_iterator)]
impl Iterator for LlamaBackendDevices<'_> {
    type Item = LlamaBackendDevice;

    /// Yields the device at the current index and advances; stops as soon as
    /// `LlamaModel::get_device` returns `None`.
    fn next(&mut self) -> Option<Self::Item> {
        let index = self.next;
        let device = self.model.get_device(index)?;
        self.next = index + 1;
        Some(device)
    }

    /// Reports `n_devices() - next`, clamped at zero, as both bounds.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.model.n_devices() - self.next;
        let remaining = usize::try_from(left.max(0)).unwrap_or(0);
        (remaining, Some(remaining))
    }
}
// Relies on `size_hint` above returning identical lower and upper bounds.
impl ExactSizeIterator for LlamaBackendDevices<'_> {}
/// Owning wrapper around a non-null `llama_model` pointer.
///
/// The pointer is released with `llama_model_free` when the value is dropped.
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaModel {
/// Raw model pointer handed to the `llama_model_*` functions.
pub(crate) model: NonNull<llama_model>,
}
/// Wrapper around a non-null `llama_vocab` pointer.
///
/// No `Drop` impl is visible in this file; the pointer is obtained from
/// `llama_model_get_vocab` — presumably the model owns it (TODO confirm).
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaVocab {
/// Raw vocabulary pointer handed to the `llama_vocab_*` functions.
pub(crate) vocab: NonNull<llama_vocab>,
}
impl LlamaVocab {
#[must_use]
pub fn n_tokens(&self) -> i32 {
unsafe { llama_cpp_sys_4::llama_vocab_n_tokens(self.vocab.as_ref()) }
}
#[must_use]
pub fn vocab_type(&self) -> u32 {
unsafe { llama_cpp_sys_4::llama_vocab_type(self.vocab.as_ref()) as u32 }
}
#[must_use]
pub fn bos(&self) -> LlamaToken {
LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_bos(self.vocab.as_ref()) })
}
#[must_use]
pub fn eos(&self) -> LlamaToken {
LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_eos(self.vocab.as_ref()) })
}
#[must_use]
pub fn eot(&self) -> LlamaToken {
LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_eot(self.vocab.as_ref()) })
}
#[must_use]
pub fn cls(&self) -> LlamaToken {
LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_cls(self.vocab.as_ref()) })
}
#[must_use]
pub fn sep(&self) -> LlamaToken {
LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_sep(self.vocab.as_ref()) })
}
#[must_use]
pub fn nl(&self) -> LlamaToken {
LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_nl(self.vocab.as_ref()) })
}
#[must_use]
pub fn pad(&self) -> LlamaToken {
LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_pad(self.vocab.as_ref()) })
}
#[must_use]
pub fn fim_pre(&self) -> LlamaToken {
LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_pre(self.vocab.as_ref()) })
}
#[must_use]
pub fn fim_suf(&self) -> LlamaToken {
LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_suf(self.vocab.as_ref()) })
}
#[must_use]
pub fn fim_mid(&self) -> LlamaToken {
LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_mid(self.vocab.as_ref()) })
}
#[must_use]
pub fn fim_pad(&self) -> LlamaToken {
LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_pad(self.vocab.as_ref()) })
}
#[must_use]
pub fn fim_rep(&self) -> LlamaToken {
LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_rep(self.vocab.as_ref()) })
}
#[must_use]
pub fn fim_sep(&self) -> LlamaToken {
LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_sep(self.vocab.as_ref()) })
}
#[must_use]
pub fn get_add_bos(&self) -> bool {
unsafe { llama_cpp_sys_4::llama_vocab_get_add_bos(self.vocab.as_ref()) }
}
#[must_use]
pub fn get_add_eos(&self) -> bool {
unsafe { llama_cpp_sys_4::llama_vocab_get_add_eos(self.vocab.as_ref()) }
}
#[must_use]
pub fn get_add_sep(&self) -> bool {
unsafe { llama_cpp_sys_4::llama_vocab_get_add_sep(self.vocab.as_ref()) }
}
pub fn get_text(&self, token: LlamaToken) -> Result<&str, StringFromModelError> {
let ptr = unsafe { llama_cpp_sys_4::llama_vocab_get_text(self.vocab.as_ref(), token.0) };
if ptr.is_null() {
return Err(StringFromModelError::ReturnedError(-1));
}
let cstr = unsafe { CStr::from_ptr(ptr) };
cstr.to_str().map_err(StringFromModelError::Utf8Error)
}
#[must_use]
pub fn get_score(&self, token: LlamaToken) -> f32 {
unsafe { llama_cpp_sys_4::llama_vocab_get_score(self.vocab.as_ref(), token.0) }
}
#[must_use]
pub fn get_attr(&self, token: LlamaToken) -> u32 {
unsafe { llama_cpp_sys_4::llama_vocab_get_attr(self.vocab.as_ref(), token.0) as u32 }
}
#[must_use]
pub fn is_control(&self, token: LlamaToken) -> bool {
unsafe { llama_cpp_sys_4::llama_vocab_is_control(self.vocab.as_ref(), token.0) }
}
#[must_use]
pub fn is_eog(&self, token: LlamaToken) -> bool {
unsafe { llama_cpp_sys_4::llama_vocab_is_eog(self.vocab.as_ref(), token.0) }
}
#[must_use]
pub fn mask(&self) -> LlamaToken {
LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_mask(self.vocab.as_ref()) })
}
}
/// Owning wrapper around a non-null `llama_adapter_lora` pointer.
///
/// The pointer is released with `llama_adapter_lora_free` when the value is
/// dropped.
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaLoraAdapter {
/// Raw adapter pointer handed to the `llama_adapter_*` functions.
pub(crate) lora_adapter: NonNull<llama_adapter_lora>,
}
impl LlamaLoraAdapter {
    /// Runs `fill` against a zeroed buffer of `buf_size` bytes and converts the
    /// written bytes to a `String`.
    ///
    /// `fill` receives the buffer pointer and its size and must return the
    /// C function's result: a negative error code, or the length of the full
    /// string. When that length does not fit (it is `>= ` the buffer size, so
    /// the text was truncated or nothing was written), the call is repeated
    /// once with a buffer of `length + 1` bytes instead of slicing out of
    /// bounds.
    fn read_meta_with_retry(
        buf_size: usize,
        mut fill: impl FnMut(*mut c_char, usize) -> i32,
    ) -> Result<String, StringFromModelError> {
        let mut buf = vec![0u8; buf_size];
        let first = fill(buf.as_mut_ptr().cast::<c_char>(), buf.len());
        let Ok(mut len) = usize::try_from(first) else {
            return Err(StringFromModelError::ReturnedError(first));
        };
        if len >= buf.len() {
            buf = vec![0u8; len + 1];
            let second = fill(buf.as_mut_ptr().cast::<c_char>(), buf.len());
            let Ok(written) = usize::try_from(second) else {
                return Err(StringFromModelError::ReturnedError(second));
            };
            // Never index past the bytes the buffer can actually hold.
            len = written.min(buf.len() - 1);
        }
        let text = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
        Ok(text.to_owned())
    }

    /// Returns the value reported by `llama_adapter_meta_count`.
    #[must_use]
    pub fn meta_count(&self) -> i32 {
        // SAFETY: `lora_adapter` is non-null; assumed to point at a live adapter.
        unsafe { llama_cpp_sys_4::llama_adapter_meta_count(self.lora_adapter.as_ptr()) }
    }

    /// Reads the metadata key at `index`.
    ///
    /// `buf_size` is the size of the first buffer tried; if the key does not
    /// fit, the read is retried once with a buffer that does.
    ///
    /// # Errors
    ///
    /// Returns `StringFromModelError::ReturnedError` with the negative code
    /// from the C function, or `StringFromModelError::Utf8Error` when the key
    /// is not valid UTF-8.
    pub fn meta_key_by_index(
        &self,
        index: i32,
        buf_size: usize,
    ) -> Result<String, StringFromModelError> {
        let adapter = self.lora_adapter.as_ptr();
        Self::read_meta_with_retry(buf_size, |ptr, size| {
            // SAFETY: `ptr` points at `size` writable bytes owned by the helper.
            unsafe { llama_cpp_sys_4::llama_adapter_meta_key_by_index(adapter, index, ptr, size) }
        })
    }

    /// Reads the metadata value stored under `key`.
    ///
    /// `buf_size` is the size of the first buffer tried; if the value does not
    /// fit, the read is retried once with a buffer that does.
    ///
    /// # Errors
    ///
    /// Returns `StringFromModelError::ReturnedError(-1)` when `key` contains a
    /// NUL byte, `ReturnedError` with the negative code from the C function,
    /// or `Utf8Error` when the value is not valid UTF-8.
    pub fn meta_val_str(&self, key: &str, buf_size: usize) -> Result<String, StringFromModelError> {
        let c_key = CString::new(key).map_err(|_| StringFromModelError::ReturnedError(-1))?;
        let adapter = self.lora_adapter.as_ptr();
        Self::read_meta_with_retry(buf_size, |ptr, size| {
            // SAFETY: `c_key` outlives the call; `ptr` points at `size` writable bytes.
            unsafe {
                llama_cpp_sys_4::llama_adapter_meta_val_str(adapter, c_key.as_ptr(), ptr, size)
            }
        })
    }

    /// Reads the metadata value at `index`.
    ///
    /// `buf_size` is the size of the first buffer tried; if the value does not
    /// fit, the read is retried once with a buffer that does.
    ///
    /// # Errors
    ///
    /// Returns `StringFromModelError::ReturnedError` with the negative code
    /// from the C function, or `StringFromModelError::Utf8Error` when the
    /// value is not valid UTF-8.
    pub fn meta_val_str_by_index(
        &self,
        index: i32,
        buf_size: usize,
    ) -> Result<String, StringFromModelError> {
        let adapter = self.lora_adapter.as_ptr();
        Self::read_meta_with_retry(buf_size, |ptr, size| {
            // SAFETY: `ptr` points at `size` writable bytes owned by the helper.
            unsafe {
                llama_cpp_sys_4::llama_adapter_meta_val_str_by_index(adapter, index, ptr, size)
            }
        })
    }

    /// Collects every `(key, value)` metadata pair of the adapter.
    ///
    /// Keys are first read into 256-byte buffers and values into 4096-byte
    /// buffers; longer entries are handled by the retry in the readers.
    ///
    /// # Errors
    ///
    /// Propagates the first error from reading a key or a value.
    pub fn metadata(&self) -> Result<Vec<(String, String)>, StringFromModelError> {
        let count = self.meta_count();
        // A negative count pre-allocates nothing; the loop below is then empty.
        let mut result = Vec::with_capacity(usize::try_from(count).unwrap_or(0));
        for i in 0..count {
            let key = self.meta_key_by_index(i, 256)?;
            let val = self.meta_val_str_by_index(i, 4096)?;
            result.push((key, val));
        }
        Ok(result)
    }

    /// Returns the value reported by
    /// `llama_adapter_get_alora_n_invocation_tokens`.
    #[must_use]
    pub fn n_invocation_tokens(&self) -> u64 {
        // SAFETY: `lora_adapter` is non-null; assumed to point at a live adapter.
        unsafe {
            llama_cpp_sys_4::llama_adapter_get_alora_n_invocation_tokens(self.lora_adapter.as_ptr())
        }
    }

    /// Returns the invocation tokens as a slice borrowed from the adapter.
    ///
    /// The slice is empty when the count is zero or the C function returns a
    /// null pointer.
    #[must_use]
    #[allow(clippy::cast_possible_truncation)]
    pub fn invocation_tokens(&self) -> &[LlamaToken] {
        let n = self.n_invocation_tokens() as usize;
        if n == 0 {
            return &[];
        }
        // SAFETY: `lora_adapter` is non-null; assumed to point at a live adapter.
        let ptr = unsafe {
            llama_cpp_sys_4::llama_adapter_get_alora_invocation_tokens(self.lora_adapter.as_ptr())
        };
        if ptr.is_null() {
            return &[];
        }
        // SAFETY: assumes `ptr` addresses `n` tokens that live as long as the
        // adapter, and that `LlamaToken` has the layout of the raw token type
        // — TODO confirm against `LlamaToken`'s definition.
        unsafe { std::slice::from_raw_parts(ptr.cast::<LlamaToken>(), n) }
    }
}
/// Releases the adapter with `llama_adapter_lora_free`.
impl Drop for LlamaLoraAdapter {
    fn drop(&mut self) {
        let adapter = self.lora_adapter.as_ptr();
        // SAFETY: the pointer is non-null and is not used again after this call.
        unsafe { llama_cpp_sys_4::llama_adapter_lora_free(adapter) };
    }
}
/// One chat message, stored as NUL-terminated strings so the pointers can be
/// handed to `llama_chat_apply_template`.
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct LlamaChatMessage {
/// Role of the message author.
role: CString,
/// Message text.
content: CString,
}
impl LlamaChatMessage {
    /// Builds a message from owned `role` and `content` strings.
    ///
    /// # Errors
    ///
    /// Returns an error when either string cannot be converted to a `CString`
    /// (it contains an interior NUL byte); `role` is checked first.
    pub fn new(role: String, content: String) -> Result<Self, NewLlamaChatMessageError> {
        let role = CString::new(role)?;
        let content = CString::new(content)?;
        Ok(Self { role, content })
    }
}
/// Controls the `add_bos` flag passed to `llama_tokenize` by
/// `LlamaModel::str_to_token`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AddBos {
/// Pass `true`.
Always,
/// Pass `false`.
Never,
}
/// Controls the `special` flag passed to `llama_token_to_piece` when a token
/// is converted to text.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Special {
/// Pass `true`.
Tokenize,
/// Pass `false`.
Plaintext,
}
// NOTE(review): these impls assert that the wrapped `llama_model` pointer may
// be moved to and shared between threads. Nothing in this file proves that —
// confirm against llama.cpp's threading guarantees for model objects.
unsafe impl Send for LlamaModel {}
unsafe impl Sync for LlamaModel {}
impl LlamaModel {
/// Returns a [`LlamaVocab`] wrapping the pointer from `llama_model_get_vocab`.
///
/// # Panics
///
/// Panics if the C function returns a null pointer.
#[must_use]
pub fn get_vocab(&self) -> LlamaVocab {
    // SAFETY: `self.model` is non-null; assumed to point at a live model.
    let raw = unsafe { llama_model_get_vocab(self.model.as_ptr()) };
    let vocab = NonNull::new(raw.cast_mut()).unwrap();
    LlamaVocab { vocab }
}
/// Returns the value reported by `llama_model_n_ctx_train` as a `u32`.
///
/// # Panics
///
/// Panics if the reported value does not fit into a `u32`.
#[must_use]
pub fn n_ctx_train(&self) -> u32 {
    // SAFETY: `self.model` is non-null; assumed to point at a live model.
    let raw = unsafe { llama_model_n_ctx_train(self.model.as_ptr()) };
    u32::try_from(raw).expect("n_ctx_train fits into an u32")
}
/// Iterates over every token id in `0..n_vocab()`, pairing each token with
/// the result of converting it to a string via [`Self::token_to_str`].
pub fn tokens(
    &self,
    special: Special,
) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
    (0..self.n_vocab()).map(move |id| {
        let token = LlamaToken::new(id);
        (token, self.token_to_str(token, special))
    })
}
/// Shortcut for `self.get_vocab().bos()`.
#[must_use]
pub fn token_bos(&self) -> LlamaToken {
    let vocab = self.get_vocab();
    vocab.bos()
}
/// Shortcut for `self.get_vocab().eos()`.
#[must_use]
pub fn token_eos(&self) -> LlamaToken {
    let vocab = self.get_vocab();
    vocab.eos()
}
/// Shortcut for `self.get_vocab().nl()`.
#[must_use]
pub fn token_nl(&self) -> LlamaToken {
    let vocab = self.get_vocab();
    vocab.nl()
}
/// Shortcut for `self.get_vocab().is_eog(token)`.
#[must_use]
pub fn is_eog_token(&self, token: LlamaToken) -> bool {
    let vocab = self.get_vocab();
    vocab.is_eog(token)
}
/// Shortcut for `self.get_vocab().cls()`.
#[must_use]
pub fn token_cls(&self) -> LlamaToken {
    let vocab = self.get_vocab();
    vocab.cls()
}
/// Shortcut for `self.get_vocab().eot()`.
#[must_use]
pub fn token_eot(&self) -> LlamaToken {
    let vocab = self.get_vocab();
    vocab.eot()
}
/// Shortcut for `self.get_vocab().pad()`.
#[must_use]
pub fn token_pad(&self) -> LlamaToken {
    let vocab = self.get_vocab();
    vocab.pad()
}
/// Shortcut for `self.get_vocab().sep()`.
#[must_use]
pub fn token_sep(&self) -> LlamaToken {
    let vocab = self.get_vocab();
    vocab.sep()
}
/// Shortcut for `self.get_vocab().fim_pre()`.
#[must_use]
pub fn token_fim_pre(&self) -> LlamaToken {
    let vocab = self.get_vocab();
    vocab.fim_pre()
}
/// Shortcut for `self.get_vocab().fim_suf()`.
#[must_use]
pub fn token_fim_suf(&self) -> LlamaToken {
    let vocab = self.get_vocab();
    vocab.fim_suf()
}
/// Shortcut for `self.get_vocab().fim_mid()`.
#[must_use]
pub fn token_fim_mid(&self) -> LlamaToken {
    let vocab = self.get_vocab();
    vocab.fim_mid()
}
/// Shortcut for `self.get_vocab().fim_pad()`.
#[must_use]
pub fn token_fim_pad(&self) -> LlamaToken {
    let vocab = self.get_vocab();
    vocab.fim_pad()
}
/// Shortcut for `self.get_vocab().fim_rep()`.
#[must_use]
pub fn token_fim_rep(&self) -> LlamaToken {
    let vocab = self.get_vocab();
    vocab.fim_rep()
}
/// Shortcut for `self.get_vocab().fim_sep()`.
#[must_use]
pub fn token_fim_sep(&self) -> LlamaToken {
    let vocab = self.get_vocab();
    vocab.fim_sep()
}
/// Shortcut for `self.get_vocab().is_control(token)`.
#[must_use]
pub fn token_is_control(&self, token: LlamaToken) -> bool {
    let vocab = self.get_vocab();
    vocab.is_control(token)
}
/// Shortcut for `self.get_vocab().get_score(token)`.
#[must_use]
pub fn token_get_score(&self, token: LlamaToken) -> f32 {
    let vocab = self.get_vocab();
    vocab.get_score(token)
}
/// Returns the text stored for `token`, via `llama_vocab_get_text`.
///
/// # Errors
///
/// Returns `StringFromModelError::ReturnedError(-1)` when the C function
/// yields a null pointer, and `StringFromModelError::Utf8Error` when the text
/// is not valid UTF-8.
pub fn token_get_text(&self, token: LlamaToken) -> Result<&str, StringFromModelError> {
    let vocab = self.get_vocab().vocab.as_ptr();
    // SAFETY: `vocab` is non-null (checked by `get_vocab`).
    let raw = unsafe { llama_cpp_sys_4::llama_vocab_get_text(vocab, token.0) };
    if raw.is_null() {
        Err(StringFromModelError::ReturnedError(-1))
    } else {
        // SAFETY: non-null; assumed to be a NUL-terminated string.
        let text = unsafe { CStr::from_ptr(raw) };
        text.to_str().map_err(StringFromModelError::Utf8Error)
    }
}
/// Shortcut for `self.get_vocab().get_add_bos()`.
#[must_use]
pub fn add_bos_token(&self) -> bool {
    let vocab = self.get_vocab();
    vocab.get_add_bos()
}
/// Shortcut for `self.get_vocab().get_add_eos()`.
#[must_use]
pub fn add_eos_token(&self) -> bool {
    let vocab = self.get_vocab();
    vocab.get_add_eos()
}
/// Returns the token reported by `llama_model_decoder_start_token`.
#[must_use]
pub fn decode_start_token(&self) -> LlamaToken {
    let model = self.model.as_ptr();
    // SAFETY: `model` is non-null; assumed to point at a live model.
    LlamaToken(unsafe { llama_model_decoder_start_token(model) })
}
pub fn token_to_str(
&self,
token: LlamaToken,
special: Special,
) -> Result<String, TokenToStringError> {
self.token_to_str_with_size(token, 32, special)
}
pub fn token_to_bytes(
&self,
token: LlamaToken,
special: Special,
) -> Result<Vec<u8>, TokenToStringError> {
self.token_to_bytes_with_size(token, 32, special, None)
}
pub fn tokens_to_str(
&self,
tokens: &[LlamaToken],
special: Special,
) -> Result<String, TokenToStringError> {
let mut builder = String::with_capacity(tokens.len() * 4);
for str in tokens
.iter()
.copied()
.map(|t| self.token_to_str(t, special))
{
builder += &str?;
}
Ok(builder)
}
/// Tokenizes `str` with `llama_tokenize`.
///
/// The output buffer starts with room for `max(8, str.len() / 2 + add_bos)`
/// tokens. A negative result from the first call is treated as the negated
/// number of tokens required: the buffer is grown and the call repeated.
///
/// # Errors
///
/// Returns an error when `str` contains an interior NUL byte or its length
/// does not fit into a `c_int`.
///
/// # Panics
///
/// Panics if the buffer capacity does not fit into a `c_int`, or if the final
/// call still returns a negative count.
pub fn str_to_token(
    &self,
    str: &str,
    add_bos: AddBos,
) -> Result<Vec<LlamaToken>, StringToTokenError> {
    let prepend_bos = matches!(add_bos, AddBos::Always);
    let estimate = ((str.len() / 2) + usize::from(prepend_bos)).max(8);
    let mut ids = Vec::with_capacity(estimate);
    let text = CString::new(str)?;
    let capacity =
        c_int::try_from(ids.capacity()).expect("buffer capacity should fit into a c_int");

    // SAFETY: `ids` has room for `capacity` tokens and `text` is NUL-terminated.
    let first = unsafe {
        llama_tokenize(
            self.get_vocab().vocab.as_ptr(),
            text.as_ptr(),
            c_int::try_from(text.as_bytes().len())?,
            ids.as_mut_ptr(),
            capacity,
            prepend_bos,
            true,
        )
    };

    let count = if first.is_negative() {
        // `ids` is still empty, so this guarantees a capacity of at least `-first`.
        ids.reserve_exact(usize::try_from(-first).expect("usize's are larger "));
        // SAFETY: the pointer is fetched after the reallocation above.
        unsafe {
            llama_tokenize(
                self.get_vocab().vocab.as_ptr(),
                text.as_ptr(),
                c_int::try_from(text.as_bytes().len())?,
                ids.as_mut_ptr(),
                -first,
                prepend_bos,
                true,
            )
        }
    } else {
        first
    };

    let count = usize::try_from(count).expect("size is positive and usize ");
    // SAFETY: the C call reported writing `count` tokens into the buffer.
    unsafe { ids.set_len(count) }
    Ok(ids.into_iter().map(LlamaToken).collect())
}
/// Returns the attributes of a token, converted from the raw value of
/// `llama_vocab_get_attr`.
///
/// # Panics
///
/// Panics if the raw value cannot be converted to [`LlamaTokenAttrs`].
#[must_use]
pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs {
    let vocab = self.get_vocab().vocab.as_ptr();
    // SAFETY: `vocab` is non-null (checked by `get_vocab`).
    let raw = unsafe { llama_cpp_sys_4::llama_vocab_get_attr(vocab, id) };
    LlamaTokenAttrs::try_from(raw).expect("token type is valid")
}
/// Converts `tokens` back into text with `llama_detokenize`.
///
/// The C function is first called with a null, zero-length buffer; the
/// magnitude of its result is used as the size of the real buffer for the
/// second call.
///
/// # Errors
///
/// Returns `StringFromModelError::ReturnedError` when the second call returns
/// a negative value, and `StringFromModelError::Utf8Error` when the output is
/// not valid UTF-8.
#[allow(
    clippy::cast_possible_truncation,
    clippy::cast_possible_wrap,
    clippy::cast_sign_loss
)]
pub fn detokenize(
    &self,
    tokens: &[LlamaToken],
    remove_special: bool,
    unparse_special: bool,
) -> Result<String, StringFromModelError> {
    let count = tokens.len() as i32;
    let ids = tokens.as_ptr().cast::<llama_cpp_sys_4::llama_token>();
    let vocab = self.get_vocab().vocab.as_ptr();

    // SAFETY: a null buffer is paired with a length of zero.
    let probe = unsafe {
        llama_detokenize(
            vocab,
            ids,
            count,
            std::ptr::null_mut(),
            0,
            remove_special,
            unparse_special,
        )
    };
    let capacity = if probe < 0 {
        (-probe) as usize
    } else {
        probe as usize
    };

    let mut text = vec![0u8; capacity];
    // SAFETY: `text` provides `capacity` writable bytes.
    let written = unsafe {
        llama_detokenize(
            vocab,
            ids,
            count,
            text.as_mut_ptr().cast::<c_char>(),
            capacity as i32,
            remove_special,
            unparse_special,
        )
    };
    if written < 0 {
        return Err(StringFromModelError::ReturnedError(written));
    }

    let decoded = std::str::from_utf8(&text[..written as usize])
        .map_err(StringFromModelError::Utf8Error)?;
    Ok(decoded.to_owned())
}
pub fn token_to_str_with_size(
&self,
token: LlamaToken,
buffer_size: usize,
special: Special,
) -> Result<String, TokenToStringError> {
let bytes = self.token_to_bytes_with_size(token, buffer_size, special, None)?;
Ok(String::from_utf8(bytes)?)
}
pub fn token_to_bytes_with_size(
&self,
token: LlamaToken,
buffer_size: usize,
special: Special,
lstrip: Option<NonZeroU16>,
) -> Result<Vec<u8>, TokenToStringError> {
if token == self.token_nl() {
return Ok(String::from("\n").into_bytes());
}
let attrs = self.token_attr(token);
if (attrs.contains(LlamaTokenAttr::Control)
&& (token == self.token_bos() || token == self.token_eos()))
|| attrs.is_empty()
|| attrs
.intersects(LlamaTokenAttr::Unknown | LlamaTokenAttr::Byte | LlamaTokenAttr::Unused)
{
return Ok(Vec::new());
}
let special = match special {
Special::Tokenize => true,
Special::Plaintext => false,
};
let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
let len = string.as_bytes().len();
let len = c_int::try_from(len).expect("length fits into c_int");
let buf = string.into_raw();
let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
let size = unsafe {
llama_token_to_piece(
self.get_vocab().vocab.as_ref(),
token.0,
buf,
len,
lstrip,
special,
)
};
match size {
0 => Err(TokenToStringError::UnknownTokenType),
i if i.is_negative() => Err(TokenToStringError::InsufficientBufferSpace(i)),
size => {
let string = unsafe { CString::from_raw(buf) };
let mut bytes = string.into_bytes();
let len = usize::try_from(size).expect("size is positive and fits into usize");
bytes.truncate(len);
Ok(bytes)
}
}
}
/// Shortcut for `self.get_vocab().n_tokens()`.
#[must_use]
pub fn n_vocab(&self) -> i32 {
    let vocab = self.get_vocab();
    vocab.n_tokens()
}
/// Returns the vocabulary type, converted from the raw `llama_vocab_type`.
///
/// # Panics
///
/// Panics if the raw value has no [`VocabType`] counterpart.
#[must_use]
pub fn vocab_type(&self) -> VocabType {
    let vocab = self.get_vocab().vocab.as_ptr();
    // SAFETY: `vocab` is non-null (checked by `get_vocab`).
    let raw = unsafe { llama_vocab_type(vocab) };
    VocabType::try_from(raw).expect("invalid vocab type")
}
/// Returns the value reported by `llama_model_n_embd`.
#[must_use]
pub fn n_embd(&self) -> c_int {
    let model = self.model.as_ptr();
    // SAFETY: `model` is non-null; assumed to point at a live model.
    unsafe { llama_model_n_embd(model) }
}
/// Returns the value reported by `llama_model_n_layer`.
#[must_use]
pub fn n_layer(&self) -> c_int {
    let model = self.model.as_ptr();
    // SAFETY: `model` is non-null; assumed to point at a live model.
    unsafe { llama_model_n_layer(model) }
}
/// Returns the value reported by `llama_model_n_layer_nextn`.
#[must_use]
pub fn n_layer_nextn(&self) -> c_int {
    let model = self.model.as_ptr();
    // SAFETY: `model` is non-null; assumed to point at a live model.
    unsafe { llama_model_n_layer_nextn(model) }
}
/// Returns the value reported by `llama_model_n_expert`.
#[must_use]
pub fn n_expert(&self) -> c_int {
    let model = self.model.as_ptr();
    // SAFETY: `model` is non-null; assumed to point at a live model.
    unsafe { llama_model_n_expert(model) }
}
/// Returns the value reported by `llama_model_n_devices`.
#[must_use]
pub fn n_devices(&self) -> c_int {
    let model = self.model.as_ptr();
    // SAFETY: `model` is non-null; assumed to point at a live model.
    unsafe { llama_model_n_devices(model) }
}
/// Returns the device at `index`.
///
/// Yields `None` when `index` is outside `0..n_devices()` or when
/// `llama_model_get_device` returns a null handle.
#[must_use]
pub fn get_device(&self, index: i32) -> Option<LlamaBackendDevice> {
    if !(0..self.n_devices()).contains(&index) {
        return None;
    }
    // SAFETY: `index` was range-checked against `n_devices()` above.
    let dev = unsafe { llama_model_get_device(self.model.as_ptr(), index) };
    (!dev.is_null()).then_some(LlamaBackendDevice { dev })
}
/// Returns an iterator over this model's devices, starting at index 0.
#[must_use]
pub fn devices(&self) -> LlamaBackendDevices<'_> {
    let (model, next) = (self, 0);
    LlamaBackendDevices { model, next }
}
/// Returns the ids from `llama_model_target_layer_ids` as a borrowed slice.
///
/// The slice is empty when the reported count is zero or the pointer is null.
#[must_use]
pub fn target_layer_ids(&self) -> &[i32] {
    let model = self.model.as_ptr();
    // SAFETY: `model` is non-null; assumed to point at a live model.
    let count = unsafe { llama_model_target_layer_ids_n(model) };
    if count == 0 {
        return &[];
    }
    // SAFETY: as above.
    let ids = unsafe { llama_model_target_layer_ids(model) };
    if ids.is_null() {
        return &[];
    }
    // SAFETY: assumes `ids` addresses `count` values that live as long as the
    // model — TODO confirm against the llama.cpp API contract.
    unsafe { slice::from_raw_parts(ids, count as usize) }
}
/// Returns the value reported by `llama_model_n_head`.
#[must_use]
pub fn n_head(&self) -> c_int {
    let model = self.model.as_ptr();
    // SAFETY: `model` is non-null; assumed to point at a live model.
    unsafe { llama_model_n_head(model) }
}
/// Returns the value reported by `llama_model_n_head_kv`.
#[must_use]
pub fn n_head_kv(&self) -> c_int {
    let model = self.model.as_ptr();
    // SAFETY: `model` is non-null; assumed to point at a live model.
    unsafe { llama_model_n_head_kv(model) }
}
/// Returns the value reported by `llama_model_n_embd_inp`.
#[must_use]
pub fn n_embd_inp(&self) -> c_int {
    let model = self.model.as_ptr();
    // SAFETY: `model` is non-null; assumed to point at a live model.
    unsafe { llama_model_n_embd_inp(model) }
}
/// Returns the value reported by `llama_model_n_embd_out`.
#[must_use]
pub fn n_embd_out(&self) -> c_int {
    let model = self.model.as_ptr();
    // SAFETY: `model` is non-null; assumed to point at a live model.
    unsafe { llama_model_n_embd_out(model) }
}
/// Returns the value reported by `llama_model_n_swa`.
#[must_use]
pub fn n_swa(&self) -> c_int {
    let model = self.model.as_ptr();
    // SAFETY: `model` is non-null; assumed to point at a live model.
    unsafe { llama_model_n_swa(model) }
}
/// Returns the raw value reported by `llama_model_rope_type`.
#[must_use]
pub fn rope_type(&self) -> i32 {
    let model = self.model.as_ptr();
    // SAFETY: `model` is non-null; assumed to point at a live model.
    unsafe { llama_model_rope_type(model) }
}
/// Returns the value reported by `llama_model_rope_freq_scale_train`.
#[must_use]
pub fn rope_freq_scale_train(&self) -> f32 {
    let model = self.model.as_ptr();
    // SAFETY: `model` is non-null; assumed to point at a live model.
    unsafe { llama_model_rope_freq_scale_train(model) }
}
/// Returns the value reported by `llama_model_size`.
///
/// `Display for LlamaModel` treats this value as a byte count.
#[must_use]
pub fn model_size(&self) -> u64 {
    let model = self.model.as_ptr();
    // SAFETY: `model` is non-null; assumed to point at a live model.
    unsafe { llama_model_size(model) }
}
/// Returns the value reported by `llama_model_n_params`.
#[must_use]
pub fn n_params(&self) -> u64 {
    let model = self.model.as_ptr();
    // SAFETY: `model` is non-null; assumed to point at a live model.
    unsafe { llama_model_n_params(model) }
}
/// Returns the value reported by `llama_model_n_cls_out`.
#[must_use]
pub fn n_cls_out(&self) -> u32 {
    let model = self.model.as_ptr();
    // SAFETY: `model` is non-null; assumed to point at a live model.
    unsafe { llama_model_n_cls_out(model) }
}
/// Returns the label reported by `llama_model_cls_label` for `index`.
///
/// # Errors
///
/// Returns `StringFromModelError::ReturnedError(-1)` when the C function
/// yields a null pointer, and `StringFromModelError::Utf8Error` when the
/// label is not valid UTF-8.
pub fn cls_label(&self, index: u32) -> Result<&str, StringFromModelError> {
    // SAFETY: `self.model` is non-null; assumed to point at a live model.
    let raw = unsafe { llama_model_cls_label(self.model.as_ptr(), index) };
    if raw.is_null() {
        Err(StringFromModelError::ReturnedError(-1))
    } else {
        // SAFETY: non-null; assumed to be a NUL-terminated string.
        let label = unsafe { CStr::from_ptr(raw) };
        label.to_str().map_err(StringFromModelError::Utf8Error)
    }
}
/// Returns the value reported by `llama_model_meta_count`.
#[must_use]
pub fn meta_count(&self) -> c_int {
    let model = self.model.as_ptr();
    // SAFETY: `model` is non-null; assumed to point at a live model.
    unsafe { llama_model_meta_count(model) }
}
/// Returns the model description produced by `llama_model_desc`.
///
/// `buf_size` is the size of the first buffer tried. When the reported length
/// does not fit (it is `>=` the buffer size), the call is repeated once with a
/// buffer of `length + 1` bytes instead of slicing past the end of the buffer.
///
/// # Errors
///
/// Returns `StringFromModelError::ReturnedError` with the negative code from
/// the C function, or `StringFromModelError::Utf8Error` when the description
/// is not valid UTF-8.
pub fn desc(&self, buf_size: usize) -> Result<String, StringFromModelError> {
    let model = self.model.as_ptr();
    // SAFETY: the slice passed in provides `buf.len()` writable bytes.
    let fill = |buf: &mut [u8]| unsafe {
        llama_model_desc(model, buf.as_mut_ptr().cast::<c_char>(), buf.len())
    };

    let mut buf = vec![0u8; buf_size];
    let first = fill(buf.as_mut_slice());
    let Ok(mut len) = usize::try_from(first) else {
        return Err(StringFromModelError::ReturnedError(first));
    };
    if len >= buf.len() {
        // Truncated (or nothing written): retry with room for text plus NUL.
        buf = vec![0u8; len + 1];
        let second = fill(buf.as_mut_slice());
        let Ok(written) = usize::try_from(second) else {
            return Err(StringFromModelError::ReturnedError(second));
        };
        len = written.min(buf.len() - 1);
    }

    let text = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
    Ok(text.to_owned())
}
/// Reads the metadata key at `index` via `llama_model_meta_key_by_index`.
///
/// `buf_size` is the size of the first buffer tried. When the reported length
/// does not fit (it is `>=` the buffer size), the call is repeated once with a
/// buffer of `length + 1` bytes instead of slicing past the end of the buffer.
///
/// # Errors
///
/// Returns `StringFromModelError::ReturnedError` with the negative code from
/// the C function, or `StringFromModelError::Utf8Error` when the key is not
/// valid UTF-8.
pub fn meta_key_by_index(
    &self,
    index: i32,
    buf_size: usize,
) -> Result<String, StringFromModelError> {
    let model = self.model.as_ptr();
    // SAFETY: the slice passed in provides `buf.len()` writable bytes.
    let fill = |buf: &mut [u8]| unsafe {
        llama_model_meta_key_by_index(
            model,
            index,
            buf.as_mut_ptr().cast::<c_char>(),
            buf.len(),
        )
    };

    let mut buf = vec![0u8; buf_size];
    let first = fill(buf.as_mut_slice());
    let Ok(mut len) = usize::try_from(first) else {
        return Err(StringFromModelError::ReturnedError(first));
    };
    if len >= buf.len() {
        // Truncated (or nothing written): retry with room for text plus NUL.
        buf = vec![0u8; len + 1];
        let second = fill(buf.as_mut_slice());
        let Ok(written) = usize::try_from(second) else {
            return Err(StringFromModelError::ReturnedError(second));
        };
        len = written.min(buf.len() - 1);
    }

    let text = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
    Ok(text.to_owned())
}
/// Reads the metadata value at `index` via `llama_model_meta_val_str_by_index`.
///
/// `buf_size` is the size of the first buffer tried. When the reported length
/// does not fit (it is `>=` the buffer size), the call is repeated once with a
/// buffer of `length + 1` bytes instead of slicing past the end of the buffer.
///
/// # Errors
///
/// Returns `StringFromModelError::ReturnedError` with the negative code from
/// the C function, or `StringFromModelError::Utf8Error` when the value is not
/// valid UTF-8.
pub fn meta_val_str_by_index(
    &self,
    index: i32,
    buf_size: usize,
) -> Result<String, StringFromModelError> {
    let model = self.model.as_ptr();
    // SAFETY: the slice passed in provides `buf.len()` writable bytes.
    let fill = |buf: &mut [u8]| unsafe {
        llama_model_meta_val_str_by_index(
            model,
            index,
            buf.as_mut_ptr().cast::<c_char>(),
            buf.len(),
        )
    };

    let mut buf = vec![0u8; buf_size];
    let first = fill(buf.as_mut_slice());
    let Ok(mut len) = usize::try_from(first) else {
        return Err(StringFromModelError::ReturnedError(first));
    };
    if len >= buf.len() {
        // Truncated (or nothing written): retry with room for text plus NUL.
        buf = vec![0u8; len + 1];
        let second = fill(buf.as_mut_slice());
        let Ok(written) = usize::try_from(second) else {
            return Err(StringFromModelError::ReturnedError(second));
        };
        len = written.min(buf.len() - 1);
    }

    let text = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
    Ok(text.to_owned())
}
/// Reads the metadata value stored under `key` via `llama_model_meta_val_str`.
///
/// `buf_size` is the size of the first buffer tried. When the reported length
/// does not fit (it is `>=` the buffer size), the call is repeated once with a
/// buffer of `length + 1` bytes instead of slicing past the end of the buffer.
///
/// # Errors
///
/// Returns `StringFromModelError::ReturnedError(-1)` when `key` contains a NUL
/// byte, `ReturnedError` with the negative code from the C function, or
/// `Utf8Error` when the value is not valid UTF-8.
pub fn meta_val_str(&self, key: &str, buf_size: usize) -> Result<String, StringFromModelError> {
    let c_key = CString::new(key).map_err(|_| StringFromModelError::ReturnedError(-1))?;
    let model = self.model.as_ptr();
    // SAFETY: `c_key` outlives the closure; the slice provides `buf.len()`
    // writable bytes.
    let fill = |buf: &mut [u8]| unsafe {
        llama_model_meta_val_str(
            model,
            c_key.as_ptr(),
            buf.as_mut_ptr().cast::<c_char>(),
            buf.len(),
        )
    };

    let mut buf = vec![0u8; buf_size];
    let first = fill(buf.as_mut_slice());
    let Ok(mut len) = usize::try_from(first) else {
        return Err(StringFromModelError::ReturnedError(first));
    };
    if len >= buf.len() {
        // Truncated (or nothing written): retry with room for text plus NUL.
        buf = vec![0u8; len + 1];
        let second = fill(buf.as_mut_slice());
        let Ok(written) = usize::try_from(second) else {
            return Err(StringFromModelError::ReturnedError(second));
        };
        len = written.min(buf.len() - 1);
    }

    let text = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
    Ok(text.to_owned())
}
/// Collects every `(key, value)` metadata pair of the model.
///
/// Keys are read with a 256-byte buffer and values with a 4096-byte buffer.
///
/// # Errors
///
/// Stops at, and returns, the first error from reading a key or a value.
#[allow(clippy::cast_sign_loss)]
pub fn metadata(&self) -> Result<Vec<(String, String)>, StringFromModelError> {
    let total = self.meta_count();
    let mut pairs = Vec::with_capacity(total as usize);
    for index in 0..total {
        let key = self.meta_key_by_index(index, 256)?;
        let value = self.meta_val_str_by_index(index, 4096)?;
        pairs.push((key, value));
    }
    Ok(pairs)
}
/// Returns the flag reported by `llama_model_has_encoder`.
#[must_use]
pub fn has_encoder(&self) -> bool {
    let model = self.model.as_ptr();
    // SAFETY: `model` is non-null; assumed to point at a live model.
    unsafe { llama_model_has_encoder(model) }
}
/// Returns the flag reported by `llama_model_has_decoder`.
#[must_use]
pub fn has_decoder(&self) -> bool {
    let model = self.model.as_ptr();
    // SAFETY: `model` is non-null; assumed to point at a live model.
    unsafe { llama_model_has_decoder(model) }
}
/// Returns the flag reported by `llama_model_is_recurrent`.
#[must_use]
pub fn is_recurrent(&self) -> bool {
    let model = self.model.as_ptr();
    // SAFETY: `model` is non-null; assumed to point at a live model.
    unsafe { llama_model_is_recurrent(model) }
}
/// Returns the flag reported by `llama_model_is_hybrid`.
#[must_use]
pub fn is_hybrid(&self) -> bool {
    let model = self.model.as_ptr();
    // SAFETY: `model` is non-null; assumed to point at a live model.
    unsafe { llama_model_is_hybrid(model) }
}
/// Returns the flag reported by `llama_model_is_diffusion`.
#[must_use]
pub fn is_diffusion(&self) -> bool {
    let model = self.model.as_ptr();
    // SAFETY: `model` is non-null; assumed to point at a live model.
    unsafe { llama_model_is_diffusion(model) }
}
#[allow(clippy::missing_panics_doc)] pub fn get_chat_template(&self, buf_size: usize) -> Result<String, ChatTemplateError> {
let chat_temp = CString::new(vec![b'*'; buf_size]).expect("no null");
let chat_ptr = chat_temp.into_raw();
let chat_name = CString::new("tokenizer.chat_template").expect("no null bytes");
let ret = unsafe {
llama_model_meta_val_str(self.model.as_ptr(), chat_name.as_ptr(), chat_ptr, buf_size)
};
if ret < 0 {
return Err(ChatTemplateError::MissingTemplate(ret));
}
let template_c = unsafe { CString::from_raw(chat_ptr) };
let template = template_c.to_str()?;
let ret: usize = ret.try_into().unwrap();
if template.len() < ret {
return Err(ChatTemplateError::BuffSizeError(ret + 1));
}
Ok(template.to_owned())
}
/// Loads a model from `path` with `llama_model_load_from_file`.
///
/// In debug builds, asserts that `path` exists.
///
/// # Errors
///
/// * `LlamaModelLoadError::PathToStrError` when `path` is not valid UTF-8.
/// * A conversion error when the path contains an interior NUL byte.
/// * `LlamaModelLoadError::NullResult` when the C function returns null.
#[tracing::instrument(skip_all, fields(params))]
pub fn load_from_file(
    _: &LlamaBackend,
    path: impl AsRef<Path>,
    params: &LlamaModelParams,
) -> Result<Self, LlamaModelLoadError> {
    let path = path.as_ref();
    debug_assert!(path.exists(), "{} does not exist", path.display());

    let Some(path) = path.to_str() else {
        return Err(LlamaModelLoadError::PathToStrError(path.to_path_buf()));
    };
    let c_path = CString::new(path)?;

    // SAFETY: `c_path` is NUL-terminated and outlives the call.
    let raw = unsafe { llama_model_load_from_file(c_path.as_ptr(), params.params) };
    match NonNull::new(raw) {
        Some(model) => {
            tracing::debug!(?path, "Loaded model");
            Ok(LlamaModel { model })
        }
        None => Err(LlamaModelLoadError::NullResult),
    }
}
/// Loads a model from several split files with `llama_model_load_from_splits`.
///
/// `paths` is forwarded in the given order. In debug builds, asserts that each
/// path exists.
///
/// # Errors
///
/// * `LlamaModelLoadError::PathToStrError` for the first path that is not
///   valid UTF-8.
/// * A conversion error for the first path containing an interior NUL byte.
/// * `LlamaModelLoadError::NullResult` when the C function returns null.
#[tracing::instrument(skip_all)]
pub fn load_from_splits(
    _: &LlamaBackend,
    paths: &[impl AsRef<Path>],
    params: &LlamaModelParams,
) -> Result<Self, LlamaModelLoadError> {
    let mut c_strings: Vec<CString> = Vec::with_capacity(paths.len());
    for entry in paths {
        let path = entry.as_ref();
        debug_assert!(path.exists(), "{} does not exist", path.display());
        let Some(path_str) = path.to_str() else {
            return Err(LlamaModelLoadError::PathToStrError(path.to_path_buf()));
        };
        c_strings.push(CString::new(path_str)?);
    }

    // The pointers stay valid because `c_strings` outlives the call below.
    let c_ptrs: Vec<*const c_char> = c_strings.iter().map(|s| s.as_ptr()).collect();
    // SAFETY: `c_ptrs` holds `c_ptrs.len()` pointers to NUL-terminated strings.
    let raw = unsafe {
        llama_model_load_from_splits(c_ptrs.as_ptr().cast_mut(), c_ptrs.len(), params.params)
    };

    let Some(model) = NonNull::new(raw) else {
        return Err(LlamaModelLoadError::NullResult);
    };
    tracing::debug!("Loaded model from {} splits", paths.len());
    Ok(LlamaModel { model })
}
/// Loads a model from an already opened C `FILE` via
/// `llama_model_load_from_file_ptr`.
///
/// # Safety
///
/// `file` is passed to the C function unchecked. The caller must supply a
/// pointer that function accepts — presumably a valid, open `FILE*`; verify
/// against the llama.cpp API documentation.
///
/// # Errors
///
/// Returns `LlamaModelLoadError::NullResult` when the C function returns null.
pub unsafe fn load_from_file_ptr(
    file: *mut llama_cpp_sys_4::FILE,
    params: &LlamaModelParams,
) -> Result<Self, LlamaModelLoadError> {
    // SAFETY: forwarded from the caller's contract on `file`.
    let raw = unsafe { llama_cpp_sys_4::llama_model_load_from_file_ptr(file, params.params) };
    match NonNull::new(raw) {
        Some(model) => Ok(LlamaModel { model }),
        None => Err(LlamaModelLoadError::NullResult),
    }
}
/// Builds a model through `llama_model_init_from_user` from caller-supplied
/// metadata and a tensor-data callback.
///
/// # Safety
///
/// `metadata`, `set_tensor_data` and `set_tensor_data_ud` are passed to the C
/// function unchecked. The caller must uphold whatever that function requires
/// of them — see the llama.cpp API documentation.
///
/// # Errors
///
/// Returns `LlamaModelLoadError::NullResult` when the C function returns null.
pub unsafe fn init_from_user(
    metadata: *mut llama_cpp_sys_4::gguf_context,
    set_tensor_data: llama_cpp_sys_4::llama_model_set_tensor_data_t,
    set_tensor_data_ud: *mut std::ffi::c_void,
    params: &LlamaModelParams,
) -> Result<Self, LlamaModelLoadError> {
    // SAFETY: forwarded from the caller's contract on the raw arguments.
    let raw = unsafe {
        llama_cpp_sys_4::llama_model_init_from_user(
            metadata,
            set_tensor_data,
            set_tensor_data_ud,
            params.params,
        )
    };
    match NonNull::new(raw) {
        Some(model) => Ok(LlamaModel { model }),
        None => Err(LlamaModelLoadError::NullResult),
    }
}
/// Writes the model to `path` with `llama_model_save_to_file`.
///
/// No result is returned, so a failure inside the C function is not reported
/// to the caller.
///
/// # Panics
///
/// Panics if `path` is not valid UTF-8 or contains a NUL byte.
pub fn save_to_file(&self, path: impl AsRef<Path>) {
    let utf8 = path.as_ref().to_str().expect("path is not valid UTF-8");
    let target = CString::new(utf8).expect("path contains null bytes");
    // SAFETY: `target` is NUL-terminated and outlives the call.
    unsafe { llama_model_save_to_file(self.model.as_ptr(), target.as_ptr()) };
}
/// Lists the names returned by `llama_chat_builtin_templates`.
///
/// The C function is called once with a null, zero-length output to get the
/// count, then again to fill an array of that many string pointers. Returns
/// an empty vector when the count is not positive.
///
/// # Panics
///
/// Panics if a template name is not valid UTF-8.
#[allow(clippy::cast_sign_loss)]
#[must_use]
pub fn chat_builtin_templates() -> Vec<String> {
    // SAFETY: a null output pointer is paired with a length of zero.
    let total = unsafe { llama_chat_builtin_templates(std::ptr::null_mut(), 0) };
    if total <= 0 {
        return Vec::new();
    }
    let total = total as usize;

    let mut raw_names: Vec<*const c_char> = vec![std::ptr::null(); total];
    // SAFETY: `raw_names` has room for `total` pointers.
    unsafe {
        llama_chat_builtin_templates(raw_names.as_mut_ptr(), total);
    }

    let mut names = Vec::with_capacity(total);
    for raw in raw_names {
        // SAFETY: assumes the C call filled every slot with a valid
        // NUL-terminated string.
        let name = unsafe { CStr::from_ptr(raw) };
        names.push(
            name.to_str()
                .expect("template name is not valid UTF-8")
                .to_owned(),
        );
    }
    names
}
/// Loads a LoRA adapter for this model from `path` with
/// `llama_adapter_lora_init`.
///
/// In debug builds, asserts that `path` exists.
///
/// # Errors
///
/// * `LlamaLoraAdapterInitError::PathToStrError` when `path` is not valid UTF-8.
/// * A conversion error when the path contains an interior NUL byte.
/// * `LlamaLoraAdapterInitError::NullResult` when the C function returns null.
pub fn lora_adapter_init(
    &self,
    path: impl AsRef<Path>,
) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
    let path = path.as_ref();
    debug_assert!(path.exists(), "{} does not exist", path.display());

    let Some(path) = path.to_str() else {
        return Err(LlamaLoraAdapterInitError::PathToStrError(
            path.to_path_buf(),
        ));
    };
    let c_path = CString::new(path)?;

    // SAFETY: `c_path` is NUL-terminated and outlives the call.
    let raw = unsafe { llama_adapter_lora_init(self.model.as_ptr(), c_path.as_ptr()) };
    let Some(lora_adapter) = NonNull::new(raw) else {
        return Err(LlamaLoraAdapterInitError::NullResult);
    };
    tracing::debug!(?path, "Initialized lora adapter");
    Ok(LlamaLoraAdapter { lora_adapter })
}
/// Creates a [`LlamaContext`] for this model with `llama_init_from_model`.
///
/// When `params.attn_rot_disabled` is set, the `LLAMA_ATTN_ROT_DISABLE`
/// environment variable is set to `"1"` for the duration of the C call. After
/// the call it is restored to its previous value, or removed if it was unset.
///
/// NOTE(review): this changes process-wide environment state, so concurrent
/// calls or other threads reading the environment can observe the temporary
/// value — confirm callers serialize context creation.
///
/// # Errors
///
/// Returns `LlamaContextLoadError::NullReturn` when the C function returns
/// null.
#[allow(clippy::needless_pass_by_value)]
pub fn new_context(
    &self,
    _: &LlamaBackend,
    params: LlamaContextParams,
) -> Result<LlamaContext<'_>, LlamaContextLoadError> {
    // Remember the caller's value so it can be restored after initialisation.
    let prev_rot_var = std::env::var("LLAMA_ATTN_ROT_DISABLE").ok();
    if params.attn_rot_disabled {
        #[allow(unused_unsafe)]
        unsafe {
            std::env::set_var("LLAMA_ATTN_ROT_DISABLE", "1");
        }
    }

    let context_params = params.context_params;
    // SAFETY: `self.model` is non-null; assumed to point at a live model.
    let context = unsafe { llama_init_from_model(self.model.as_ptr(), context_params) };

    // Restore the environment before inspecting the result, so it is also
    // restored when context creation failed.
    #[allow(unused_unsafe)]
    match prev_rot_var {
        Some(v) => unsafe { std::env::set_var("LLAMA_ATTN_ROT_DISABLE", v) },
        None if params.attn_rot_disabled => unsafe {
            std::env::remove_var("LLAMA_ATTN_ROT_DISABLE");
        },
        None => {}
    }

    let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;
    Ok(LlamaContext::new(self, context, params.embeddings()))
}
/// Formats `chat` with `llama_chat_apply_template`.
///
/// `tmpl` is passed through as the template string; `None` is passed as a null
/// pointer. `add_ass` is forwarded unchanged.
///
/// The first buffer holds `max(4096, 4 * total message bytes)` bytes. When the
/// reported length leaves no room for the text plus a terminator, the call is
/// repeated once with `length + 1` bytes. The result is taken from the first
/// `length` bytes of the buffer, so it does not depend on the C function
/// having written a NUL terminator. Invalid UTF-8 is replaced lossily.
///
/// # Errors
///
/// * A conversion error when `tmpl` contains an interior NUL byte.
/// * `ApplyChatTemplateError::BuffSizeError` when the C function returns a
///   negative value, or when the output still does not fit after the retry.
///
/// # Panics
///
/// Panics if the buffer length does not fit into an `i32`.
#[tracing::instrument(skip_all)]
pub fn apply_chat_template(
    &self,
    tmpl: Option<&str>,
    chat: &[LlamaChatMessage],
    add_ass: bool,
) -> Result<String, ApplyChatTemplateError> {
    let message_length = chat.iter().fold(0usize, |acc, c| {
        acc + c.role.to_bytes().len() + c.content.to_bytes().len()
    });

    // The raw pointers stay valid because `chat` is borrowed for the whole call.
    let chat_sys: Vec<llama_chat_message> = chat
        .iter()
        .map(|c| llama_chat_message {
            role: c.role.as_ptr(),
            content: c.content.as_ptr(),
        })
        .collect();

    let tmpl_cstring = tmpl.map(CString::new).transpose()?;
    let tmpl_ptr = tmpl_cstring
        .as_ref()
        .map_or(std::ptr::null(), |s| s.as_ptr());

    let mut buf_size = message_length.saturating_mul(4).max(4096);
    for _ in 0..2 {
        let mut buff = vec![0u8; buf_size];
        // SAFETY: `buff` provides `buff.len()` writable bytes; `chat_sys` and
        // `tmpl_cstring` outlive the call.
        let res = unsafe {
            llama_chat_apply_template(
                tmpl_ptr,
                chat_sys.as_ptr(),
                chat_sys.len(),
                add_ass,
                buff.as_mut_ptr().cast(),
                i32::try_from(buff.len()).expect("buffer length fits in i32"),
            )
        };
        let Ok(needed) = usize::try_from(res) else {
            return Err(ApplyChatTemplateError::BuffSizeError);
        };
        if needed >= buff.len() {
            // Equal sizes leave no room for a terminator: the text is either
            // unterminated or cut short, so retry with one spare byte.
            buf_size = needed + 1;
            continue;
        }
        buff.truncate(needed);
        return Ok(String::from_utf8_lossy(&buff).into_owned());
    }
    Err(ApplyChatTemplateError::BuffSizeError)
}
/// Builds a split file path from `path_prefix`, `split_no` and `split_count`
/// using `llama_split_path`.
///
/// A 1024-byte buffer is tried first. When the reported length does not fit,
/// the call is repeated once with a buffer that does, so long paths are no
/// longer returned truncated with embedded NUL bytes.
///
/// Returns an empty string when the result is not valid UTF-8.
///
/// # Panics
///
/// Panics if `path_prefix` contains a NUL byte, or if the C function reports a
/// negative length.
#[must_use]
pub fn split_path(path_prefix: &str, split_no: i32, split_count: i32) -> String {
    let prefix = CString::new(path_prefix).unwrap();
    // SAFETY: `prefix` outlives the closure; the slice provides `buf.len()`
    // writable bytes.
    let render = |buf: &mut [u8]| unsafe {
        llama_split_path(
            buf.as_mut_ptr().cast::<c_char>(),
            buf.len(),
            prefix.as_ptr(),
            split_no,
            split_count,
        )
    };

    let mut buffer = vec![0u8; 1024];
    let mut len =
        usize::try_from(render(buffer.as_mut_slice())).expect("split_path length fits in usize");
    if len >= buffer.len() {
        buffer = vec![0u8; len + 1];
        len = usize::try_from(render(buffer.as_mut_slice()))
            .expect("split_path length fits in usize")
            .min(buffer.len() - 1);
    }
    buffer.truncate(len);
    String::from_utf8(buffer).unwrap_or_default()
}
/// Extracts the path prefix from `split_path` using `llama_split_prefix`.
///
/// Returns `None` when the C function reports a non-positive length or when
/// the prefix is not valid UTF-8.
///
/// A 1024-byte buffer is tried first. When the reported length does not fit,
/// the call is repeated once with a buffer that does, so long prefixes are no
/// longer returned truncated with embedded NUL bytes.
///
/// # Panics
///
/// Panics if `split_path` contains a NUL byte.
#[must_use]
pub fn split_prefix(split_path: &str, split_no: i32, split_count: i32) -> Option<String> {
    let path = CString::new(split_path).unwrap();
    // SAFETY: `path` outlives the closure; the slice provides `buf.len()`
    // writable bytes.
    let extract = |buf: &mut [u8]| unsafe {
        llama_split_prefix(
            buf.as_mut_ptr().cast::<c_char>(),
            buf.len(),
            path.as_ptr(),
            split_no,
            split_count,
        )
    };

    let mut buffer = vec![0u8; 1024];
    // Zero and negative results both mean "no prefix".
    let mut len = usize::try_from(extract(buffer.as_mut_slice()))
        .ok()
        .filter(|&n| n > 0)?;
    if len >= buffer.len() {
        buffer = vec![0u8; len + 1];
        len = usize::try_from(extract(buffer.as_mut_slice()))
            .ok()
            .filter(|&n| n > 0)?
            .min(buffer.len() - 1);
    }
    buffer.truncate(len);
    String::from_utf8(buffer).ok()
}
}
/// One-line summary: description, layer/head/embedding counts, parameter
/// count, and `model_size()` divided by 1024² shown as MiB.
///
/// The description is read with a 256-byte buffer and falls back to
/// `"unknown"` on error.
#[allow(clippy::cast_precision_loss)]
impl fmt::Display for LlamaModel {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let desc = match self.desc(256) {
            Ok(text) => text,
            Err(_) => "unknown".to_string(),
        };
        let size_mib = self.model_size() as f64 / (1024.0 * 1024.0);
        write!(
            f,
            "{desc} | {layers}L {heads}H {embd}E | {params} params | {size:.1} MiB",
            layers = self.n_layer(),
            heads = self.n_head(),
            embd = self.n_embd(),
            params = self.n_params(),
            size = size_mib,
        )
    }
}
/// Releases the model with `llama_model_free`.
impl Drop for LlamaModel {
    fn drop(&mut self) {
        let model = self.model.as_ptr();
        // SAFETY: the pointer is non-null and is not used again after this call.
        unsafe { llama_model_free(model) }
    }
}
/// Vocabulary kinds this wrapper recognises, mirroring two of the
/// `LLAMA_VOCAB_TYPE_*` constants.
#[repr(u32)]
#[derive(Debug, Eq, Copy, Clone, PartialEq)]
pub enum VocabType {
/// Corresponds to `LLAMA_VOCAB_TYPE_BPE`.
BPE = LLAMA_VOCAB_TYPE_BPE as _,
/// Corresponds to `LLAMA_VOCAB_TYPE_SPM`.
SPM = LLAMA_VOCAB_TYPE_SPM as _,
}
/// Error returned when a raw `llama_vocab_type` has no [`VocabType`]
/// counterpart.
#[derive(thiserror::Error, Debug, Eq, PartialEq)]
pub enum LlamaTokenTypeFromIntError {
/// The raw value matched none of the known vocabulary types.
#[error("Unknown Value {0}")]
UnknownValue(llama_vocab_type),
}
impl TryFrom<llama_vocab_type> for VocabType {
type Error = LlamaTokenTypeFromIntError;
fn try_from(value: llama_vocab_type) -> Result<Self, Self::Error> {
match value {
LLAMA_VOCAB_TYPE_BPE => Ok(VocabType::BPE),
LLAMA_VOCAB_TYPE_SPM => Ok(VocabType::SPM),
unknown => Err(LlamaTokenTypeFromIntError::UnknownValue(unknown)),
}
}
}