pub mod buft_overrides;
pub mod kv_overrides;
pub mod params;
pub mod vocab;
use std::path::Path;
use std::ptr::NonNull;
use llama_crab_sys as sys;
use crate::backend::LlamaBackend;
use crate::context::{LlamaContext, LlamaContextParams};
use crate::error::{LlamaError, Result};
use crate::token::LlamaToken;
pub use buft_overrides::BufferTypeOverride;
pub use kv_overrides::ParamOverrideValue;
pub(crate) use vocab::VocabPtr;
#[derive(Debug)]
pub struct LlamaModel {
pub(crate) handle: NonNull<sys::llama_model>,
pub(crate) vocab: VocabPtr,
}
impl LlamaModel {
pub fn load_from_file(
_backend: &LlamaBackend,
path: &Path,
params: ¶ms::LlamaModelParams,
) -> Result<Self> {
let path_cstr = std::ffi::CString::new(path.display().to_string())?;
let raw_params = params.build();
let handle = unsafe { sys::llama_load_model_from_file(path_cstr.as_ptr(), raw_params) };
let handle = NonNull::new(handle).ok_or_else(|| {
LlamaError::ModelLoad(format!("llama_load_model_from_file: {}", path.display()))
})?;
let vocab_raw = unsafe { sys::llama_model_get_vocab(handle.as_ptr()) };
let vocab = NonNull::new(vocab_raw.cast_mut())
.ok_or_else(|| LlamaError::ModelLoad("llama_model_get_vocab returned null".into()))?;
Ok(Self {
handle,
vocab: VocabPtr(vocab),
})
}
pub(crate) fn raw(&self) -> *mut sys::llama_model {
self.handle.as_ptr()
}
pub(crate) fn vocab(&self) -> *const sys::llama_vocab {
self.vocab.0.as_ptr()
}
#[must_use]
pub fn n_layer(&self) -> u32 {
unsafe { sys::llama_model_n_layer(self.handle.as_ptr()) as u32 }
}
#[must_use]
pub fn n_ctx_train(&self) -> u32 {
unsafe { sys::llama_model_n_ctx_train(self.handle.as_ptr()) as u32 }
}
#[must_use]
pub fn n_embd(&self) -> i32 {
unsafe { sys::llama_model_n_embd(self.handle.as_ptr()) }
}
#[must_use]
pub fn n_vocab(&self) -> i32 {
unsafe { sys::llama_n_vocab(self.vocab()) }
}
#[must_use]
pub fn is_recurrent(&self) -> bool {
unsafe { sys::llama_model_is_recurrent(self.handle.as_ptr()) }
}
#[must_use]
pub fn is_hybrid(&self) -> bool {
unsafe { sys::llama_model_is_hybrid(self.handle.as_ptr()) }
}
#[must_use]
pub fn size(&self) -> u64 {
unsafe { sys::llama_model_size(self.handle.as_ptr()) }
}
#[must_use]
pub fn n_params(&self) -> u64 {
unsafe { sys::llama_model_n_params(self.handle.as_ptr()) }
}
#[must_use]
pub fn token_bos(&self) -> LlamaToken {
LlamaToken(unsafe { sys::llama_token_bos(self.vocab()) })
}
#[must_use]
pub fn token_eos(&self) -> LlamaToken {
LlamaToken(unsafe { sys::llama_token_eos(self.vocab()) })
}
pub fn tokenize(
&self,
text: &str,
add_bos: bool,
special: bool,
) -> Result<Vec<LlamaToken>> {
let bytes = text.as_bytes();
let mut buf: Vec<i32> = vec![0; bytes.len().saturating_add(8)];
let n = unsafe {
sys::llama_tokenize(
self.vocab(),
bytes.as_ptr().cast(),
bytes.len() as i32,
buf.as_mut_ptr(),
buf.len() as i32,
add_bos,
special,
)
};
if n < 0 {
return Err(LlamaError::Batch(format!("tokenize failed: {n}")));
}
if (n as usize) > buf.len() {
buf.resize(n as usize, 0);
let n2 = unsafe {
sys::llama_tokenize(
self.vocab(),
bytes.as_ptr().cast(),
bytes.len() as i32,
buf.as_mut_ptr(),
buf.len() as i32,
add_bos,
special,
)
};
if n2 < 0 {
return Err(LlamaError::Batch(format!("tokenize retry: {n2}")));
}
}
buf.truncate(n as usize);
Ok(buf.into_iter().map(LlamaToken).collect())
}
pub fn detokenize(&self, tokens: &[LlamaToken], special: bool) -> Result<String> {
let mut raw_buf: Vec<u8> = vec![0; 64];
let mut len = raw_buf.len() as i32;
let n = unsafe {
sys::llama_detokenize(
self.vocab(),
tokens.as_ptr().cast(),
tokens.len() as i32,
raw_buf.as_mut_ptr().cast(),
len,
special,
special,
)
};
if n < 0 {
return Err(LlamaError::Batch(format!("detokenize: {n}")));
}
if (n as usize) > raw_buf.len() {
raw_buf.resize(n as usize, 0);
len = n;
let n2 = unsafe {
sys::llama_detokenize(
self.vocab(),
tokens.as_ptr().cast(),
tokens.len() as i32,
raw_buf.as_mut_ptr().cast(),
len,
special,
special,
)
};
if n2 < 0 {
return Err(LlamaError::Batch(format!("detokenize: {n2}")));
}
}
raw_buf.truncate(n as usize);
Ok(String::from_utf8(raw_buf)?)
}
pub fn new_context(
&self,
_backend: &LlamaBackend,
params: LlamaContextParams,
) -> Result<LlamaContext<'_>> {
let raw_params = params.build();
let ctx = unsafe { sys::llama_new_context_with_model(self.handle.as_ptr(), raw_params) };
NonNull::new(ctx)
.map(|ctx| LlamaContext::from_raw(ctx, self))
.ok_or(LlamaError::ContextLoad(
"llama_new_context_with_model returned null".into(),
))
}
}
unsafe impl Send for LlamaModel {}
unsafe impl Sync for LlamaModel {}
impl Drop for LlamaModel {
fn drop(&mut self) {
unsafe { sys::llama_free_model(self.handle.as_ptr()) };
}
}