use crate::{
context::{Context, ContextParams},
error::MullamaError,
sys,
token::TokenId,
};
use std::os::raw::c_char;
use std::{ffi::CString, path::Path, ptr, sync::Arc};
/// Signature of a user-supplied model-loading progress callback.
///
/// The `f32` is the progress value reported by llama.cpp and the returned
/// `bool` is handed straight back to llama.cpp by `progress_callback_wrapper`.
/// NOTE(review): llama.h documents the value as 0.0..=1.0 and treats `false`
/// as "abort loading" — confirm against the bundled header.
type ProgressCallbackFn = fn(f32) -> bool;
/// State whose address is passed to llama.cpp as the progress callback's
/// `user_data`, so the C trampoline can find the Rust callback again.
struct ProgressCallbackData {
// The Rust callback invoked for each progress report.
callback: ProgressCallbackFn,
}
/// C-ABI trampoline registered as llama.cpp's `progress_callback`.
///
/// Forwards `progress` to the Rust callback stored behind `user_data` and
/// returns its result unchanged.
///
/// Two defensive behaviours:
/// - A null `user_data` yields `true`, so loading simply continues.
/// - A panic inside the user callback is caught and reported as `false`. It is
///   not allowed to unwind across the FFI boundary, which would abort the
///   process.
///
/// # Safety
/// `user_data` must be null or a valid pointer to a `ProgressCallbackData`
/// that stays alive for the duration of this call.
unsafe extern "C" fn progress_callback_wrapper(
    progress: std::os::raw::c_float,
    user_data: *mut std::os::raw::c_void,
) -> bool {
    // SAFETY: per the contract above the pointer is either null (handled by
    // `as_ref` returning `None`) or points to a live `ProgressCallbackData`.
    let Some(data) = (unsafe { (user_data as *const ProgressCallbackData).as_ref() }) else {
        return true;
    };
    std::panic::catch_unwind(move || (data.callback)(progress)).unwrap_or(false)
}
/// Owning handle to a loaded llama.cpp model; the model is freed in `Drop`.
#[derive(Debug)]
struct ModelInner {
// Pointer returned by `llama_model_load_from_file`; `Model::load_with_params`
// rejects a null result before constructing this struct.
model_ptr: *mut sys::llama_model,
// Vocabulary from `llama_model_get_vocab`, fetched once at load time. It is not
// freed separately here — NOTE(review): presumably owned by the model; confirm.
vocab_ptr: *const sys::llama_vocab,
}
impl Drop for ModelInner {
    /// Frees the llama.cpp model when the last `Model` handle sharing this
    /// `ModelInner` (through its `Arc`) is dropped. A null pointer is ignored.
    fn drop(&mut self) {
        let model = self.model_ptr;
        if model.is_null() {
            return;
        }
        // SAFETY: `model` came from `llama_model_load_from_file`, and `Drop`
        // runs exactly once for this value, so it is freed exactly once.
        unsafe { sys::llama_model_free(model) };
    }
}
// SAFETY (review): `ModelInner` holds only raw pointers to a llama.cpp model and
// its vocabulary. Marking it `Send`/`Sync` lets `Model` (an `Arc<ModelInner>`)
// be shared across threads. That assumes llama.cpp tolerates concurrent
// read-only use of one `llama_model`/`llama_vocab` from several threads.
// NOTE(review): confirm this against the llama.cpp version being linked.
unsafe impl Send for ModelInner {}
unsafe impl Sync for ModelInner {}
/// Cheaply clonable handle to a loaded model.
///
/// Clones share one `ModelInner` through an `Arc`, so the underlying llama.cpp
/// model is freed only after the last clone is dropped. Because `ModelInner` is
/// `Send + Sync`, the handle can be moved and shared across threads.
#[derive(Debug, Clone)]
pub struct Model {
inner: Arc<ModelInner>,
}
impl Model {
#[allow(dead_code)]
pub(crate) fn model_ptr(&self) -> *mut sys::llama_model {
self.inner.model_ptr
}
}
/// Options for `Model::load_with_params`, copied into llama.cpp's
/// `llama_model_default_params()` struct at load time.
#[derive(Debug, Clone)]
pub struct ModelParams {
/// Forwarded to `n_gpu_layers`. `split_mode` and `main_gpu` are only applied
/// when this is greater than zero.
pub n_gpu_layers: i32,
/// Forwarded to `split_mode` (only when `n_gpu_layers > 0`).
pub split_mode: sys::llama_split_mode,
/// Forwarded to `main_gpu` (only when `n_gpu_layers > 0`).
pub main_gpu: i32,
/// Pointer forwarded to `tensor_split`; an empty vector passes null.
/// NOTE(review): llama.cpp appears to read one entry per available device from
/// this pointer, so a non-empty vector shorter than the device count may be
/// over-read — confirm.
pub tensor_split: Vec<f32>,
/// Forwarded to `vocab_only`.
pub vocab_only: bool,
/// Forwarded to `use_mmap`.
pub use_mmap: bool,
/// Forwarded to `use_mlock`.
pub use_mlock: bool,
/// Forwarded to `check_tensors`.
pub check_tensors: bool,
/// Forwarded to `use_extra_bufts`.
pub use_extra_bufts: bool,
/// Metadata overrides, each converted by `Model::convert_kv_override`.
pub kv_overrides: Vec<ModelKvOverride>,
/// Optional load-progress callback, invoked through `progress_callback_wrapper`.
pub progress_callback: Option<fn(f32) -> bool>,
}
/// A single model-metadata override applied while the model is loaded.
///
/// The key (and a `Str` value) are copied into fixed 128-byte C buffers during
/// conversion, so each must be shorter than 128 bytes.
#[derive(Debug, Clone, PartialEq)]
pub struct ModelKvOverride {
    /// Metadata key to override.
    pub key: String,
    /// Replacement value.
    pub value: ModelKvOverrideValue,
}

/// Typed value of a [`ModelKvOverride`].
///
/// Only `PartialEq` (not `Eq`) is derived because `Float` holds an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub enum ModelKvOverrideValue {
    /// 64-bit integer value.
    Int(i64),
    /// 64-bit floating-point value.
    Float(f64),
    /// Boolean value.
    Bool(bool),
    /// String value (must be shorter than 128 bytes).
    Str(String),
}
impl Default for ModelParams {
fn default() -> Self {
Self {
n_gpu_layers: 0,
split_mode: sys::llama_split_mode::LLAMA_SPLIT_MODE_NONE,
main_gpu: 0,
tensor_split: Vec::new(),
vocab_only: false,
use_mmap: true,
use_mlock: false,
check_tensors: true,
use_extra_bufts: false,
kv_overrides: Vec::new(),
progress_callback: None,
}
}
}
impl Model {
pub fn load(path: impl AsRef<Path>) -> Result<Self, MullamaError> {
Self::load_with_params(path, ModelParams::default())
}
pub fn load_with_params(
path: impl AsRef<Path>,
params: ModelParams,
) -> Result<Self, MullamaError> {
let path = path.as_ref();
if !path.exists() {
return Err(MullamaError::ModelLoadError(format!(
"Model file not found: {}",
path.display()
)));
}
let c_path = CString::new(path.to_string_lossy().as_bytes())
.map_err(|_| MullamaError::ModelLoadError("Invalid path".to_string()))?;
unsafe {
sys::llama_backend_init();
}
let mut llama_params = unsafe { sys::llama_model_default_params() };
llama_params.n_gpu_layers = params.n_gpu_layers;
if params.n_gpu_layers > 0 {
llama_params.split_mode = params.split_mode;
llama_params.main_gpu = params.main_gpu;
}
llama_params.vocab_only = params.vocab_only as sys::c_bool;
llama_params.use_mmap = params.use_mmap as sys::c_bool;
llama_params.use_mlock = params.use_mlock as sys::c_bool;
llama_params.check_tensors = params.check_tensors as sys::c_bool;
llama_params.use_extra_bufts = params.use_extra_bufts as sys::c_bool;
if !params.tensor_split.is_empty() {
llama_params.tensor_split = params.tensor_split.as_ptr();
} else {
llama_params.tensor_split = ptr::null();
}
let kv_overrides: Vec<sys::llama_model_kv_override> = params
.kv_overrides
.iter()
.map(Self::convert_kv_override)
.collect::<Result<Vec<_>, _>>()?;
if !kv_overrides.is_empty() {
llama_params.kv_overrides = kv_overrides.as_ptr();
} else {
llama_params.kv_overrides = ptr::null();
}
let _callback_data: Option<Box<ProgressCallbackData>>;
if let Some(cb) = params.progress_callback {
let data = Box::new(ProgressCallbackData { callback: cb });
llama_params.progress_callback = Some(
progress_callback_wrapper
as unsafe extern "C" fn(
std::os::raw::c_float,
*mut std::os::raw::c_void,
) -> bool,
);
llama_params.progress_callback_user_data =
&*data as *const ProgressCallbackData as *mut std::os::raw::c_void;
_callback_data = Some(data);
} else {
llama_params.progress_callback = None;
llama_params.progress_callback_user_data = ptr::null_mut();
_callback_data = None;
}
llama_params.devices = ptr::null_mut();
llama_params.tensor_buft_overrides = ptr::null();
let model_ptr = unsafe { sys::llama_model_load_from_file(c_path.as_ptr(), llama_params) };
if model_ptr.is_null() {
return Err(MullamaError::ModelLoadError(
"Failed to load model - check file format and parameters".to_string(),
));
}
let vocab_ptr = unsafe { sys::llama_model_get_vocab(model_ptr) };
Ok(Model {
inner: Arc::new(ModelInner {
model_ptr,
vocab_ptr,
}),
})
}
pub fn create_context(&self, params: ContextParams) -> Result<Context, MullamaError> {
Context::new(Arc::new(self.clone()), params)
}
/// Tokenizes `text` with the model's vocabulary.
///
/// `add_bos` and `special` are forwarded to `llama_tokenize`. The call is made
/// twice:
/// 1. with an empty buffer, to learn the token count (reported as a negative
///    number);
/// 2. with a buffer of exactly that size.
///
/// # Errors
/// [`MullamaError::TokenizationError`] when:
/// - `text` contains a NUL byte;
/// - `text` is longer than `i32::MAX` bytes;
/// - the vocabulary pointer is null;
/// - llama.cpp reports a failure, including its `i32::MIN` overflow code.
pub fn tokenize(
    &self,
    text: &str,
    add_bos: bool,
    special: bool,
) -> Result<Vec<TokenId>, MullamaError> {
    let c_text = CString::new(text)
        .map_err(|_| MullamaError::TokenizationError("Invalid text".to_string()))?;
    // llama_tokenize takes an i32 length; refuse rather than silently truncate.
    let text_len = i32::try_from(text.len()).map_err(|_| {
        MullamaError::TokenizationError("Text too long to tokenize".to_string())
    })?;
    let vocab = self.vocab();
    if vocab.is_null() {
        return Err(MullamaError::TokenizationError(
            "Failed to get vocabulary".to_string(),
        ));
    }
    let run = |out: *mut sys::llama_token, capacity: i32| unsafe {
        // SAFETY: `vocab` is non-null; `c_text` holds `text_len` bytes; `out`
        // is either null with capacity 0 or points to `capacity` writable slots.
        sys::llama_tokenize(
            vocab,
            c_text.as_ptr(),
            text_len,
            out,
            capacity,
            add_bos as sys::c_bool,
            special as sys::c_bool,
        )
    };

    let probe = run(ptr::null_mut(), 0);
    // `checked_abs` fails only for i32::MIN, llama.cpp's overflow code, which
    // cannot be negated without overflowing.
    let capacity = probe.checked_abs().ok_or_else(|| {
        MullamaError::TokenizationError(format!("Tokenization failed with code: {}", probe))
    })?;
    if capacity == 0 {
        return Ok(Vec::new());
    }

    let mut tokens: Vec<sys::llama_token> = vec![0; capacity as usize];
    let written = run(tokens.as_mut_ptr(), capacity);
    if written < 0 {
        return Err(MullamaError::TokenizationError(format!(
            "Tokenization failed with code: {}",
            written
        )));
    }
    tokens.truncate(written as usize);
    Ok(tokens.into_iter().map(|t| t as TokenId).collect())
}
/// Converts `tokens` back into text.
///
/// Each token is rendered with `llama_token_to_piece`, and the first emitted
/// token is rendered with `lstrip = 1`. When `remove_special` is set, control
/// tokens are skipped.
///
/// The raw bytes of all pieces are concatenated and decoded as UTF-8 once, at
/// the end. A character whose bytes are spread over several tokens is
/// therefore reassembled instead of turning into replacement characters.
///
/// # Errors
/// [`MullamaError::TokenizationError`] for a token id outside the vocabulary,
/// or when llama.cpp fails to render a token.
pub fn detokenize(
    &self,
    tokens: &[TokenId],
    remove_special: bool,
    unparse_special: bool,
) -> Result<String, MullamaError> {
    if tokens.is_empty() {
        return Ok(String::new());
    }
    let mut bytes = Vec::new();
    let mut is_first = true;
    for &token in tokens {
        // Validate before any vocabulary lookup; llama.cpp's per-token
        // accessors may index the vocabulary without bounds checks.
        let id = self.checked_token_id(token)?;
        if remove_special && self.token_is_control(token) {
            continue;
        }
        let lstrip = if is_first { 1 } else { 0 };
        bytes.extend_from_slice(&self.piece_bytes(id, lstrip, unparse_special)?);
        is_first = false;
    }
    Ok(String::from_utf8_lossy(&bytes).into_owned())
}

/// Renders a single token as text.
///
/// `lstrip` and `special` are forwarded to `llama_token_to_piece`. Invalid
/// UTF-8 in the piece is replaced lossily.
///
/// # Errors
/// [`MullamaError::TokenizationError`] for a token id outside the vocabulary,
/// or when llama.cpp fails to render the token.
pub fn token_to_str(
    &self,
    token: TokenId,
    lstrip: i32,
    special: bool,
) -> Result<String, MullamaError> {
    let id = self.checked_token_id(token)?;
    let bytes = self.piece_bytes(id, lstrip, special)?;
    Ok(String::from_utf8_lossy(&bytes).into_owned())
}

/// Returns `token` as a `llama_token` if it indexes into this model's vocabulary.
fn checked_token_id(&self, token: TokenId) -> Result<sys::llama_token, MullamaError> {
    let id = token as sys::llama_token;
    // SAFETY: `self.vocab()` is the vocabulary of a live model.
    let n_tokens = unsafe { sys::llama_vocab_n_tokens(self.vocab()) };
    if (0..n_tokens).contains(&id) {
        Ok(id)
    } else {
        Err(MullamaError::TokenizationError(format!(
            "Token id {} is out of range (vocabulary size {})",
            id, n_tokens
        )))
    }
}

/// Raw bytes of the piece for an already validated token id.
///
/// `llama_token_to_piece` returns the negated required size when the buffer
/// is too small. In that case the buffer is grown to that size and the call
/// is repeated once.
fn piece_bytes(
    &self,
    id: sys::llama_token,
    lstrip: i32,
    special: bool,
) -> Result<Vec<u8>, MullamaError> {
    let vocab = self.vocab();
    let render = |buf: &mut [u8]| unsafe {
        // SAFETY: `buf` is writable for `buf.len()` bytes and `id` is in range.
        sys::llama_token_to_piece(
            vocab,
            id,
            buf.as_mut_ptr() as *mut c_char,
            buf.len() as i32,
            lstrip,
            special as sys::c_bool,
        )
    };
    let mut buf = vec![0u8; 128];
    let mut n_chars = render(buf.as_mut_slice());
    if n_chars < 0 {
        buf.resize(n_chars.unsigned_abs() as usize, 0);
        n_chars = render(buf.as_mut_slice());
        if n_chars < 0 {
            return Err(MullamaError::TokenizationError(format!(
                "Failed to convert token to string on retry: {}",
                n_chars
            )));
        }
    }
    buf.truncate(n_chars as usize);
    Ok(buf)
}
// Read-only model properties. Each getter passes the model pointer (or the
// vocabulary cached at load time) owned by `self.inner` to the matching
// llama.cpp query and returns its result.

/// Training context length, from `llama_model_n_ctx_train`.
pub fn n_ctx_train(&self) -> i32 {
    let model = self.as_ptr();
    unsafe { sys::llama_model_n_ctx_train(model) as i32 }
}
/// Embedding size, from `llama_model_n_embd`.
pub fn n_embd(&self) -> i32 {
    let model = self.as_ptr();
    unsafe { sys::llama_model_n_embd(model) }
}
/// Layer count, from `llama_model_n_layer`.
pub fn n_layer(&self) -> i32 {
    let model = self.as_ptr();
    unsafe { sys::llama_model_n_layer(model) }
}
/// Attention head count, from `llama_model_n_head`.
pub fn n_head(&self) -> i32 {
    let model = self.as_ptr();
    unsafe { sys::llama_model_n_head(model) }
}
/// KV head count, from `llama_model_n_head_kv`.
pub fn n_head_kv(&self) -> i32 {
    let model = self.as_ptr();
    unsafe { sys::llama_model_n_head_kv(model) }
}
/// Sliding-window size, from `llama_model_n_swa`.
pub fn n_swa(&self) -> i32 {
    let model = self.as_ptr();
    unsafe { sys::llama_model_n_swa(model) }
}
/// Training RoPE frequency scale, from `llama_model_rope_freq_scale_train`.
pub fn rope_freq_scale_train(&self) -> f32 {
    let model = self.as_ptr();
    unsafe { sys::llama_model_rope_freq_scale_train(model) }
}
/// Vocabulary type, from `llama_vocab_type`.
pub fn vocab_type(&self) -> sys::llama_vocab_type {
    let vocab = self.vocab();
    unsafe { sys::llama_vocab_type(vocab) }
}
/// Number of vocabulary entries, from `llama_vocab_n_tokens`.
pub fn vocab_size(&self) -> i32 {
    let vocab = self.vocab();
    unsafe { sys::llama_vocab_n_tokens(vocab) }
}
/// RoPE type, from `llama_model_rope_type`.
pub fn rope_type(&self) -> sys::llama_rope_type {
    let model = self.as_ptr();
    unsafe { sys::llama_model_rope_type(model) }
}
/// Raw model pointer; ownership stays with this `Model`.
pub fn as_ptr(&self) -> *mut sys::llama_model {
    let inner = &*self.inner;
    inner.model_ptr
}
}
/// Vocabulary entry returned by `Model::get_token_info`.
#[derive(Debug, Clone)]
pub struct Token {
/// Token id that was looked up.
pub id: TokenId,
/// Text from `llama_vocab_get_text`, decoded lossily as UTF-8.
pub text: String,
/// Score from `llama_vocab_get_score`.
pub score: f32,
/// Attribute value from `llama_vocab_get_attr`.
pub attr: sys::llama_token_attr,
}
impl Model {
/// Vocabulary pointer cached from the model at load time.
#[inline]
fn vocab(&self) -> *const sys::llama_vocab {
    self.inner.vocab_ptr
}
/// Looks up the text, score and attribute of `token`.
///
/// # Errors
/// [`MullamaError::TokenizationError`] if `token` is outside the vocabulary or
/// llama.cpp returns no text for it.
pub fn get_token_info(&self, token: TokenId) -> Result<Token, MullamaError> {
    let vocab = self.vocab();
    let id = token as sys::llama_token;
    // Range-check first: llama.cpp's per-token lookups are not safe for ids
    // outside `0..n_tokens` (they throw or index out of bounds upstream).
    let n_tokens = unsafe { sys::llama_vocab_n_tokens(vocab) };
    if !(0..n_tokens).contains(&id) {
        return Err(MullamaError::TokenizationError(format!(
            "Token id {} is out of range (vocabulary size {})",
            id, n_tokens
        )));
    }
    let text_ptr = unsafe { sys::llama_vocab_get_text(vocab, id) };
    if text_ptr.is_null() {
        return Err(MullamaError::TokenizationError(
            "Token not found".to_string(),
        ));
    }
    // SAFETY: non-null, NUL-terminated string owned by the vocabulary.
    let text = unsafe { std::ffi::CStr::from_ptr(text_ptr) }
        .to_string_lossy()
        .into_owned();
    let score = unsafe { sys::llama_vocab_get_score(vocab, id) };
    let attr = unsafe { sys::llama_vocab_get_attr(vocab, id) };
    Ok(Token {
        id: token,
        text,
        score,
        attr,
    })
}
/// Whether llama.cpp classifies `token` as end-of-generation (`llama_vocab_is_eog`).
pub fn token_is_eog(&self, token: TokenId) -> bool {
    let vocab = self.vocab();
    let id = token as sys::llama_token;
    unsafe { sys::llama_vocab_is_eog(vocab, id) as bool }
}
/// Whether llama.cpp classifies `token` as a control token (`llama_vocab_is_control`).
pub fn token_is_control(&self, token: TokenId) -> bool {
    let vocab = self.vocab();
    let id = token as sys::llama_token;
    unsafe { sys::llama_vocab_is_control(vocab, id) as bool }
}
/// BOS token id, from `llama_vocab_bos`.
pub fn token_bos(&self) -> TokenId {
    let vocab = self.vocab();
    unsafe { sys::llama_vocab_bos(vocab) as TokenId }
}
/// EOS token id, from `llama_vocab_eos`.
pub fn token_eos(&self) -> TokenId {
    let vocab = self.vocab();
    unsafe { sys::llama_vocab_eos(vocab) as TokenId }
}
/// Separator token id, from `llama_vocab_sep`.
pub fn token_sep(&self) -> TokenId {
    let vocab = self.vocab();
    unsafe { sys::llama_vocab_sep(vocab) as TokenId }
}
/// Newline token id, from `llama_vocab_nl`.
pub fn token_nl(&self) -> TokenId {
    let vocab = self.vocab();
    unsafe { sys::llama_vocab_nl(vocab) as TokenId }
}
/// Padding token id, from `llama_vocab_pad`.
pub fn token_pad(&self) -> TokenId {
    let vocab = self.vocab();
    unsafe { sys::llama_vocab_pad(vocab) as TokenId }
}
/// End-of-turn token id, from `llama_vocab_eot`.
pub fn token_eot(&self) -> TokenId {
    let vocab = self.vocab();
    unsafe { sys::llama_vocab_eot(vocab) as TokenId }
}
/// Whether the vocabulary asks for a BOS token (`llama_vocab_get_add_bos`).
pub fn add_bos_token(&self) -> bool {
    let vocab = self.vocab();
    unsafe { sys::llama_vocab_get_add_bos(vocab) as bool }
}
/// Whether the vocabulary asks for an EOS token (`llama_vocab_get_add_eos`).
pub fn add_eos_token(&self) -> bool {
    let vocab = self.vocab();
    unsafe { sys::llama_vocab_get_add_eos(vocab) as bool }
}
}
impl Model {
/// Value of the `general.architecture` metadata key, if present and non-empty.
pub fn architecture(&self) -> Option<String> {
    const KEY: &str = "general.architecture";
    self.meta_val(KEY)
}
/// Value of the `general.name` metadata key, if present and non-empty.
pub fn name(&self) -> Option<String> {
    const KEY: &str = "general.name";
    self.meta_val(KEY)
}
/// Human-readable model description from `llama_model_desc`.
///
/// `llama_model_desc` returns the full, untruncated length (snprintf
/// semantics). When that length does not fit the initial 256-byte buffer, the
/// call is repeated with a buffer sized to fit, so long descriptions are
/// neither cut off nor polluted by the terminating NUL.
/// Returns an empty string when llama.cpp reports no description.
pub fn desc(&self) -> String {
    let model = self.inner.model_ptr;
    let describe = |buf: &mut [u8]| unsafe {
        // SAFETY: `buf` is writable for `buf.len()` bytes.
        sys::llama_model_desc(model, buf.as_mut_ptr() as *mut c_char, buf.len())
    };
    let mut buf = vec![0u8; 256];
    let mut len = describe(buf.as_mut_slice());
    if len > 0 && len as usize >= buf.len() {
        buf.resize(len as usize + 1, 0);
        len = describe(buf.as_mut_slice());
    }
    if len <= 0 {
        return String::new();
    }
    // At most `buf.len() - 1` bytes are text; the last byte holds the NUL.
    let text_len = (len as usize).min(buf.len() - 1);
    buf.truncate(text_len);
    String::from_utf8_lossy(&buf).into_owned()
}
/// Model size as reported by `llama_model_size`.
/// NOTE(review): llama.h documents this as the total tensor size in bytes.
pub fn size(&self) -> u64 {
    let model = self.inner.model_ptr;
    unsafe { sys::llama_model_size(model) }
}
/// Parameter count, from `llama_model_n_params`.
pub fn n_params(&self) -> u64 {
    let model = self.inner.model_ptr;
    unsafe { sys::llama_model_n_params(model) }
}
/// Number of vocabulary entries (same query as `vocab_size`).
pub fn n_vocab(&self) -> i32 {
    let vocab = self.vocab();
    unsafe { sys::llama_vocab_n_tokens(vocab) }
}
/// Classifier output count, from `llama_model_n_cls_out`.
pub fn n_cls_out(&self) -> u32 {
    let model = self.inner.model_ptr;
    unsafe { sys::llama_model_n_cls_out(model) }
}
/// Whether the model has an encoder (`llama_model_has_encoder`).
pub fn has_encoder(&self) -> bool {
    let model = self.inner.model_ptr;
    unsafe { sys::llama_model_has_encoder(model) }
}
/// Whether the model has a decoder (`llama_model_has_decoder`).
pub fn has_decoder(&self) -> bool {
    let model = self.inner.model_ptr;
    unsafe { sys::llama_model_has_decoder(model) }
}
/// Whether the model is recurrent (`llama_model_is_recurrent`).
pub fn is_recurrent(&self) -> bool {
    let model = self.inner.model_ptr;
    unsafe { sys::llama_model_is_recurrent(model) }
}
/// Whether the model is a diffusion model (`llama_model_is_diffusion`).
pub fn is_diffusion(&self) -> bool {
    let model = self.inner.model_ptr;
    unsafe { sys::llama_model_is_diffusion(model) }
}
/// Decoder start token id, from `llama_model_decoder_start_token`.
pub fn decoder_start_token(&self) -> TokenId {
    let model = self.inner.model_ptr;
    unsafe { sys::llama_model_decoder_start_token(model) as TokenId }
}
/// The model's built-in chat template, or `None` if llama.cpp returns none.
///
/// A null template name is passed, which llama.h documents as selecting the
/// default template. The text is decoded lossily as UTF-8.
pub fn chat_template(&self) -> Option<String> {
    let raw = unsafe { sys::llama_model_chat_template(self.inner.model_ptr, std::ptr::null()) };
    (!raw.is_null()).then(|| {
        // SAFETY: non-null, NUL-terminated string owned by the model.
        unsafe { std::ffi::CStr::from_ptr(raw) }
            .to_string_lossy()
            .into_owned()
    })
}
/// Heuristic list of strings that should end a chat turn for this model.
///
/// Sequences are chosen in two passes:
/// 1. from the `general.architecture` metadata;
/// 2. from end-of-turn markers found in the built-in chat template, skipping
///    markers that are already present.
pub fn get_chat_stop_sequences(&self) -> Vec<String> {
    fn contains(stops: &[String], needle: &str) -> bool {
        stops.iter().any(|s| s == needle)
    }

    let mut stops: Vec<String> = Vec::new();
    if let Some(arch) = self.architecture() {
        let arch_stops: &[&str] = match arch.to_lowercase().as_str() {
            "qwen" | "qwen2" | "qwen2moe" | "qwen2vl" | "qwen3" => {
                &["<|im_end|>", "<|endoftext|>"]
            }
            "llama" => &["<|eot_id|>", "<|eom_id|>", "<|start_header_id|>"],
            "gemma" | "gemma2" | "gemma3" => &["<end_of_turn>"],
            "phi" | "phi2" | "phi3" | "phi4" => &["<|end|>", "<|im_end|>"],
            "mistral" | "mixtral" => &["</s>"],
            // DeepSeek's special tokens are spelled with FULLWIDTH VERTICAL LINE
            // (U+FF5C); the ASCII-bar spelling is kept as well.
            "deepseek" | "deepseek2" => &["<|end▁of▁sentence|>", "<｜end▁of▁sentence｜>"],
            _ => &[],
        };
        stops.extend(arch_stops.iter().map(|s| s.to_string()));
    }

    if let Some(template) = self.chat_template() {
        if template.contains("<|im_end|>") && !contains(&stops, "<|im_end|>") {
            stops.push("<|im_end|>".to_string());
            // NOTE(review): the bare form presumably catches a partially
            // rendered marker — confirm intent.
            stops.push("|im_end|".to_string());
        }
        for &marker in &["<|eot_id|>", "<|eom_id|>", "<|end|>", "<|endoftext|>", "<end_of_turn>"] {
            if template.contains(marker) && !contains(&stops, marker) {
                stops.push(marker.to_string());
            }
        }
        if template.contains("[/INST]") && template.contains("</s>") && !contains(&stops, "</s>")
        {
            stops.push("</s>".to_string());
        }
    }
    stops
}
/// Formats `messages` (`(role, content)` pairs) with a chat template.
///
/// Template selection:
/// - `template`, when given;
/// - otherwise the model's built-in template;
/// - otherwise a null template pointer is passed to llama.cpp.
///   NOTE(review): upstream substitutes a default (chatml) for null — confirm.
///
/// `add_generation_prompt` is forwarded to `llama_chat_apply_template`.
///
/// # Errors
/// [`MullamaError::InvalidInput`] if any string contains a NUL byte or
/// llama.cpp fails to apply the template.
pub fn apply_chat_template(
    &self,
    template: Option<&str>,
    messages: &[(&str, &str)],
    add_generation_prompt: bool,
) -> Result<String, MullamaError> {
    // These CStrings must outlive `chat_messages`, which only borrows their pointers.
    let mut owned_messages = Vec::with_capacity(messages.len());
    for (role, content) in messages {
        let role_cstr = CString::new(*role)
            .map_err(|_| MullamaError::InvalidInput("Role contains null byte".to_string()))?;
        let content_cstr = CString::new(*content).map_err(|_| {
            MullamaError::InvalidInput("Content contains null byte".to_string())
        })?;
        owned_messages.push((role_cstr, content_cstr));
    }
    let chat_messages: Vec<sys::llama_chat_message> = owned_messages
        .iter()
        .map(|(role, content)| sys::llama_chat_message {
            role: role.as_ptr(),
            content: content.as_ptr(),
        })
        .collect();

    let template_cstr: Option<CString> = match template {
        Some(tpl) => Some(CString::new(tpl).map_err(|_| {
            MullamaError::InvalidInput("Template contains null byte".to_string())
        })?),
        None => {
            let model_template_ptr = unsafe {
                sys::llama_model_chat_template(self.inner.model_ptr, std::ptr::null())
            };
            if model_template_ptr.is_null() {
                None
            } else {
                // SAFETY: non-null, NUL-terminated string owned by the model.
                // `to_owned` copies it without the fallible re-validation of `CString::new`.
                Some(unsafe { std::ffi::CStr::from_ptr(model_template_ptr) }.to_owned())
            }
        }
    };
    let template_ptr = template_cstr
        .as_ref()
        .map_or(std::ptr::null(), |t| t.as_ptr());

    let apply = |buf: *mut c_char, len: i32| unsafe {
        // SAFETY: every pointer refers to data alive in this frame; `buf` is
        // either null with `len == 0` or writable for `len` bytes.
        sys::llama_chat_apply_template(
            template_ptr,
            chat_messages.as_ptr(),
            chat_messages.len(),
            add_generation_prompt,
            buf,
            len,
        )
    };
    let failed = || MullamaError::InvalidInput("Failed to apply chat template".to_string());

    // First pass without a buffer returns the number of bytes required.
    let required = apply(std::ptr::null_mut(), 0);
    if required < 0 {
        return Err(failed());
    }
    // One extra byte for the terminator. The add is checked so that
    // `required == i32::MAX` cannot wrap the length handed back to C.
    let capacity = required.checked_add(1).ok_or_else(failed)?;
    let mut buffer = vec![0u8; capacity as usize];
    let written = apply(buffer.as_mut_ptr() as *mut c_char, capacity);
    if written < 0 {
        return Err(failed());
    }
    // Never keep more bytes than the buffer actually holds.
    let keep = (written as usize).min(buffer.len());
    buffer.truncate(keep);
    Ok(String::from_utf8_lossy(&buffer).into_owned())
}
/// Number of metadata entries, from `llama_model_meta_count`.
pub fn meta_count(&self) -> i32 {
    let model = self.inner.model_ptr;
    unsafe { sys::llama_model_meta_count(model) }
}
pub fn meta_key(&self, index: i32) -> Option<String> {
let mut buf = vec![0u8; 256];
let len = unsafe {
sys::llama_model_meta_key_by_index(
self.inner.model_ptr,
index,
buf.as_mut_ptr() as *mut c_char,
buf.len(),
)
};
if len > 0 {
buf.truncate(len as usize);
Some(String::from_utf8_lossy(&buf).into_owned())
} else {
None
}
}
pub fn meta_val(&self, key: &str) -> Option<String> {
let key_cstr = CString::new(key).ok()?;
let mut buf = vec![0u8; 1024];
let len = unsafe {
sys::llama_model_meta_val_str(
self.inner.model_ptr,
key_cstr.as_ptr(),
buf.as_mut_ptr() as *mut c_char,
buf.len(),
)
};
if len > 0 {
buf.truncate(len as usize);
Some(String::from_utf8_lossy(&buf).into_owned())
} else {
None
}
}
pub fn meta_val_by_index(&self, index: i32) -> Option<String> {
let mut buf = vec![0u8; 1024];
let len = unsafe {
sys::llama_model_meta_val_str_by_index(
self.inner.model_ptr,
index,
buf.as_mut_ptr() as *mut c_char,
buf.len(),
)
};
if len > 0 {
buf.truncate(len as usize);
Some(String::from_utf8_lossy(&buf).into_owned())
} else {
None
}
}
/// Collects every readable metadata key/value pair into a map.
///
/// Entries whose key or value cannot be read (or is empty) are skipped.
pub fn metadata(&self) -> std::collections::HashMap<String, String> {
    (0..self.meta_count())
        .filter_map(|index| Some((self.meta_key(index)?, self.meta_val_by_index(index)?)))
        .collect()
}
// NOTE(review): the other special-token accessors in this file go through the
// cached vocabulary (`llama_vocab_*(self.vocab())`), but these pass the model
// pointer to `llama_token_fim_*`. In upstream llama.h (after the vocab split)
// the deprecated `llama_token_fim_*` wrappers take a `const llama_vocab *`.
// Confirm the `sys` binding's parameter type, and prefer
// `llama_vocab_fim_*(self.vocab())` if it is available.
/// Fill-in-the-middle prefix token id, from `llama_token_fim_pre`.
pub fn token_fim_pre(&self) -> TokenId {
unsafe { sys::llama_token_fim_pre(self.inner.model_ptr) as TokenId }
}
/// Fill-in-the-middle suffix token id, from `llama_token_fim_suf`.
pub fn token_fim_suf(&self) -> TokenId {
unsafe { sys::llama_token_fim_suf(self.inner.model_ptr) as TokenId }
}
/// Fill-in-the-middle middle token id, from `llama_token_fim_mid`.
pub fn token_fim_mid(&self) -> TokenId {
unsafe { sys::llama_token_fim_mid(self.inner.model_ptr) as TokenId }
}
/// Fill-in-the-middle padding token id, from `llama_token_fim_pad`.
pub fn token_fim_pad(&self) -> TokenId {
unsafe { sys::llama_token_fim_pad(self.inner.model_ptr) as TokenId }
}
/// Fill-in-the-middle repository token id, from `llama_token_fim_rep`.
pub fn token_fim_rep(&self) -> TokenId {
unsafe { sys::llama_token_fim_rep(self.inner.model_ptr) as TokenId }
}
/// Fill-in-the-middle separator token id, from `llama_token_fim_sep`.
pub fn token_fim_sep(&self) -> TokenId {
unsafe { sys::llama_token_fim_sep(self.inner.model_ptr) as TokenId }
}
pub fn save(&self, path: &str) -> Result<(), MullamaError> {
let c_path = CString::new(path)
.map_err(|_| MullamaError::InvalidInput("Invalid path".to_string()))?;
unsafe {
sys::llama_model_save_to_file(self.inner.model_ptr, c_path.as_ptr());
}
if !Path::new(path).exists() {
return Err(MullamaError::IoError(std::io::Error::other(format!(
"Model save failed: {}",
path
))));
}
Ok(())
}
}
impl Model {
    /// Converts a [`ModelKvOverride`] into the fixed-layout
    /// `llama_model_kv_override` struct consumed by llama.cpp.
    ///
    /// The key and any string value are copied into 128-element, zero-padded C
    /// character arrays. The element type is inferred from the `sys` struct
    /// rather than hard-coded as `i8`, because `c_char` is unsigned on some
    /// targets (e.g. AArch64 Linux).
    ///
    /// # Errors
    /// [`MullamaError::ModelLoadError`] if:
    /// - the key is empty (llama.cpp treats an empty key as the end of the
    ///   override list);
    /// - the key or a string value is 128 bytes or longer;
    /// - the key or a string value contains a NUL byte.
    fn convert_kv_override(
        override_: &ModelKvOverride,
    ) -> Result<sys::llama_model_kv_override, MullamaError> {
        if override_.key.is_empty() {
            return Err(MullamaError::ModelLoadError(
                "KV override key must not be empty".to_string(),
            ));
        }
        let key_bytes = Self::kv_c_string_bytes(&override_.key, "KV override key")?;
        let (tag, value) = match &override_.value {
            ModelKvOverrideValue::Int(v) => (
                sys::llama_model_kv_override_type::LLAMA_KV_OVERRIDE_TYPE_INT,
                sys::llama_model_kv_override_value { val_i64: *v },
            ),
            ModelKvOverrideValue::Float(v) => (
                sys::llama_model_kv_override_type::LLAMA_KV_OVERRIDE_TYPE_FLOAT,
                sys::llama_model_kv_override_value { val_f64: *v },
            ),
            ModelKvOverrideValue::Bool(v) => (
                sys::llama_model_kv_override_type::LLAMA_KV_OVERRIDE_TYPE_BOOL,
                sys::llama_model_kv_override_value {
                    val_bool: *v as sys::c_bool,
                },
            ),
            ModelKvOverrideValue::Str(s) => {
                let str_bytes = Self::kv_c_string_bytes(s, "KV override string value")?;
                (
                    sys::llama_model_kv_override_type::LLAMA_KV_OVERRIDE_TYPE_STR,
                    sys::llama_model_kv_override_value {
                        val_str: str_bytes.map(|b| b as _),
                    },
                )
            }
        };
        Ok(sys::llama_model_kv_override {
            tag,
            key: key_bytes.map(|b| b as _),
            value,
        })
    }

    /// Validates `text` and copies it into a zero-padded 128-byte array that
    /// always keeps a trailing NUL. `what` names the field in error messages.
    fn kv_c_string_bytes(text: &str, what: &str) -> Result<[u8; 128], MullamaError> {
        let bytes = text.as_bytes();
        if bytes.len() >= 128 {
            return Err(MullamaError::ModelLoadError(format!("{} too long", what)));
        }
        // An interior NUL would silently truncate the string on the C side.
        if bytes.contains(&0) {
            return Err(MullamaError::ModelLoadError(format!(
                "{} contains a NUL byte",
                what
            )));
        }
        let mut out = [0u8; 128];
        out[..bytes.len()].copy_from_slice(bytes);
        Ok(out)
    }
}