//! Thin wrappers around the llama.cpp C API, loaded at runtime through
//! `llama_cpp_sys_v3`.

use llama_cpp_sys_v3::{LlamaLib, LoadError};
use std::path::{Path, PathBuf};
use std::sync::Arc;
pub mod backend;
pub mod downloader;
pub use backend::Backend;
/// Errors produced while loading the library, loading a model, or running inference.
#[derive(Debug, thiserror::Error)]
pub enum LlamaError {
#[error("Failed to load DLL: {0}")]
DllLoad(#[from] LoadError),
#[error("Failed to download backend: {0}")]
Download(#[from] downloader::DownloadError),
#[error("Failed to initialize backend")]
BackendInit,
#[error("Failed to load model from file")]
ModelLoad,
#[error("Failed to create context")]
ContextCreate,
#[error("Decode error with status code {0}")]
Decode(i32),
#[error("Missing or empty chat template")]
MissingChatTemplate,
}
/// Options controlling how the llama.cpp shared library is located and loaded.
pub struct LoadOptions<'a> {
/// Which prebuilt backend flavour to download if no local library is found.
pub backend: Backend,
/// Application name, used to namespace the download cache.
pub app_name: &'a str,
/// Specific backend version to fetch; `None` uses the downloader's default.
pub version: Option<&'a str>,
/// Explicit path to an existing shared library; skips the downloader.
pub explicit_path: Option<&'a Path>,
/// Override for the cache directory used by the downloader.
pub cache_dir: Option<PathBuf>,
}
/// Holds the dynamically loaded library and owns the global llama.cpp backend state.
pub struct LlamaBackend {
pub lib: Arc<LlamaLib>,
}
impl Drop for LlamaBackend {
fn drop(&mut self) {
// Only free the global backend when this is the last live handle to the
// library; models, contexts, and samplers keep their own Arc clones.
if Arc::strong_count(&self.lib) == 1 {
unsafe {
(self.lib.symbols.llama_backend_free)();
}
}
}
}
impl LlamaBackend {
/// Loads the llama.cpp shared library and initializes the backend.
///
/// The library is resolved in order: `explicit_path`, the `LLAMA_DLL_PATH`
/// environment variable, then a download into the cache directory.
pub fn load(options: LoadOptions<'_>) -> Result<Self, LlamaError> {
let dll_path = if let Some(path) = options.explicit_path {
path.to_path_buf()
} else if let Ok(env_path) = std::env::var("LLAMA_DLL_PATH") {
PathBuf::from(env_path)
} else {
downloader::Downloader::ensure_dll(
options.backend,
options.app_name,
options.version,
options.cache_dir,
)?
};
// Prepend the library's directory to PATH so that dependent shared
// libraries placed next to it can be found by the OS loader.
if let Some(parent) = dll_path.parent() {
if let Some(path_ext) = std::env::var_os("PATH") {
let mut paths = std::env::split_paths(&path_ext).collect::<Vec<_>>();
let parent_buf = parent.to_path_buf();
if !paths.contains(&parent_buf) {
paths.insert(0, parent_buf);
if let Ok(new_path) = std::env::join_paths(paths) {
std::env::set_var("PATH", new_path);
}
}
}
}
let lib = LlamaLib::open(&dll_path)?;
if let Some(parent) = dll_path.parent() {
let parent_str = parent.to_string_lossy().to_string();
let c_parent = std::ffi::CString::new(parent_str).unwrap();
unsafe {
(lib.symbols.ggml_backend_load_all_from_path)(c_parent.as_ptr());
}
} else {
unsafe {
(lib.symbols.ggml_backend_load_all)();
}
}
unsafe {
(lib.symbols.llama_backend_init)();
}
Ok(Self { lib: Arc::new(lib) })
}
}
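// Example (sketch, not part of the API): loading the backend through the
// downloader path. The `Backend` variant is supplied by the caller because
// the available variants live in the `backend` module; "my-app" and the other
// values below are placeholders.
#[allow(dead_code)]
fn example_load_backend(selected: Backend) -> Result<LlamaBackend, LlamaError> {
    LlamaBackend::load(LoadOptions {
        backend: selected,
        app_name: "my-app",   // placeholder; used to namespace the download cache
        version: None,        // None lets the downloader pick its default release
        explicit_path: None,  // Some(path) would skip both env-var and download
        cache_dir: None,      // None uses the downloader's default cache location
    })
}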
/// A loaded model together with the library handle that keeps its symbols valid.
pub struct LlamaModel {
pub backend: Arc<LlamaLib>,
pub handle: *mut llama_cpp_sys_v3::llama_model,
}
impl Drop for LlamaModel {
fn drop(&mut self) {
unsafe {
(self.backend.symbols.llama_model_free)(self.handle);
}
}
}
// SAFETY: the model handle is treated as read-only after loading, and
// llama.cpp allows a model to be shared across threads and contexts.
unsafe impl Send for LlamaModel {}
unsafe impl Sync for LlamaModel {}
impl LlamaModel {
pub fn load_from_file(
backend: &LlamaBackend,
path: &str,
params: llama_cpp_sys_v3::llama_model_params,
) -> Result<Self, LlamaError> {
let c_path = std::ffi::CString::new(path).map_err(|_| LlamaError::ModelLoad)?;
let handle =
unsafe { (backend.lib.symbols.llama_model_load_from_file)(c_path.as_ptr(), params) };
if handle.is_null() {
return Err(LlamaError::ModelLoad);
}
Ok(Self {
backend: backend.lib.clone(),
handle,
})
}
pub fn default_params(backend: &LlamaBackend) -> llama_cpp_sys_v3::llama_model_params {
unsafe { (backend.lib.symbols.llama_model_default_params)() }
}
pub fn get_vocab(&self) -> LlamaVocab {
let handle = unsafe { (self.backend.symbols.llama_model_get_vocab)(self.handle) };
LlamaVocab {
backend: self.backend.clone(),
handle,
}
}
/// Tokenizes `text`, sizing the output buffer from an initial probing call.
pub fn tokenize(
&self,
text: &str,
add_special: bool,
parse_special: bool,
) -> Result<Vec<llama_cpp_sys_v3::llama_token>, LlamaError> {
let vocab = self.get_vocab();
let c_text = std::ffi::CString::new(text).map_err(|_| LlamaError::ModelLoad)?;
// Probe with a null buffer first; the C API returns the negated number of
// tokens required when the buffer is too small.
let n_tokens = unsafe {
(self.backend.symbols.llama_tokenize)(
vocab.handle,
c_text.as_ptr(),
text.len() as i32,
std::ptr::null_mut(),
0,
add_special,
parse_special,
)
};
let mut tokens = vec![0; n_tokens.unsigned_abs() as usize];
let actual_tokens = unsafe {
(self.backend.symbols.llama_tokenize)(
vocab.handle,
c_text.as_ptr(),
text.len() as i32,
tokens.as_mut_ptr(),
tokens.len() as i32,
add_special,
parse_special,
)
};
if actual_tokens < 0 {
return Err(LlamaError::Decode(actual_tokens));
}
tokens.truncate(actual_tokens as usize);
Ok(tokens)
}
/// Converts one token to its text piece, growing the buffer if 128 bytes are
/// not enough; invalid UTF-8 is replaced lossily.
pub fn token_to_piece(&self, token: llama_cpp_sys_v3::llama_token) -> String {
let vocab = self.get_vocab();
let mut buf = vec![0u8; 128];
let n = unsafe {
(self.backend.symbols.llama_token_to_piece)(
vocab.handle,
token,
buf.as_mut_ptr() as *mut std::ffi::c_char,
buf.len() as i32,
0,
true,
)
};
if n < 0 {
buf.resize((-n) as usize, 0);
unsafe {
(self.backend.symbols.llama_token_to_piece)(
vocab.handle,
token,
buf.as_mut_ptr() as *mut std::ffi::c_char,
buf.len() as i32,
0,
true,
);
}
} else {
buf.truncate(n as usize);
}
String::from_utf8_lossy(&buf).to_string()
}
/// Renders `messages` through a chat template: `tmpl` if provided, otherwise
/// the template embedded in the model. `add_ass` appends the assistant prompt.
pub fn apply_chat_template(
&self,
tmpl: Option<&str>,
messages: &[ChatMessage],
add_ass: bool,
) -> Result<String, LlamaError> {
let resolved_tmpl = match tmpl {
Some(s) => s.to_string(),
None => self
.get_chat_template(None)
.ok_or(LlamaError::MissingChatTemplate)?,
};
if resolved_tmpl.trim().is_empty() {
return Err(LlamaError::MissingChatTemplate);
}
let c_tmpl = std::ffi::CString::new(resolved_tmpl).map_err(|_| LlamaError::ModelLoad)?;
let mut c_messages = Vec::with_capacity(messages.len());
let mut c_strings = Vec::with_capacity(messages.len() * 2);
for msg in messages {
let role =
std::ffi::CString::new(msg.role.as_str()).map_err(|_| LlamaError::ModelLoad)?;
let content =
std::ffi::CString::new(msg.content.as_str()).map_err(|_| LlamaError::ModelLoad)?;
let msg_struct = llama_cpp_sys_v3::llama_chat_message {
role: role.as_ptr(),
content: content.as_ptr(),
};
c_messages.push(msg_struct);
c_strings.push(role);
c_strings.push(content);
}
let n_chars = unsafe {
(self.backend.symbols.llama_chat_apply_template)(
c_tmpl.as_ptr(),
c_messages.as_ptr(),
c_messages.len(),
add_ass,
std::ptr::null_mut(),
0,
)
};
if n_chars < 0 {
return Err(LlamaError::Decode(n_chars));
}
let mut buf = vec![0u8; n_chars as usize + 1];
let actual_chars = unsafe {
(self.backend.symbols.llama_chat_apply_template)(
c_tmpl.as_ptr(),
c_messages.as_ptr(),
c_messages.len(),
add_ass,
buf.as_mut_ptr() as *mut std::ffi::c_char,
buf.len() as i32,
)
};
if actual_chars < 0 {
return Err(LlamaError::Decode(actual_chars));
}
buf.truncate(actual_chars as usize);
Ok(String::from_utf8_lossy(&buf).to_string())
}
/// Returns the chat template stored in the model metadata, if any; `name`
/// selects a named template, `None` the default one.
pub fn get_chat_template(&self, name: Option<&str>) -> Option<String> {
let c_name = name.and_then(|s| std::ffi::CString::new(s).ok());
let name_ptr = c_name
.as_ref()
.map(|c| c.as_ptr())
.unwrap_or(std::ptr::null());
let mut buf = vec![0u8; 1024];
let n = unsafe {
(self.backend.symbols.llama_model_chat_template)(
self.handle,
name_ptr,
buf.as_mut_ptr() as *mut std::ffi::c_char,
buf.len(),
)
};
if n < 0 {
return None;
}
if n as usize >= buf.len() {
buf.resize(n as usize + 1, 0);
unsafe {
(self.backend.symbols.llama_model_chat_template)(
self.handle,
name_ptr,
buf.as_mut_ptr() as *mut std::ffi::c_char,
buf.len(),
);
}
}
buf.truncate(n as usize);
Some(String::from_utf8_lossy(&buf).to_string())
}
}
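// Example (sketch): loading a model, rendering a chat prompt, and tokenizing
// it. "model.gguf" is a placeholder path; parameters are the library defaults.
#[allow(dead_code)]
fn example_tokenize_prompt(
    backend: &LlamaBackend,
) -> Result<Vec<llama_cpp_sys_v3::llama_token>, LlamaError> {
    let params = LlamaModel::default_params(backend);
    let model = LlamaModel::load_from_file(backend, "model.gguf", params)?;
    let messages = [ChatMessage {
        role: "user".to_string(),
        content: "Hello!".to_string(),
    }];
    // `None` falls back to the chat template stored in the model, if present.
    let prompt = model.apply_chat_template(None, &messages, true)?;
    model.tokenize(&prompt, true, true)
}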
/// A single role/content pair passed to chat-template rendering.
pub struct ChatMessage {
pub role: String,
pub content: String,
}
/// Borrowed view of a model's vocabulary. The raw handle is owned by the
/// model, so it must not outlive the [`LlamaModel`] it came from.
pub struct LlamaVocab {
pub backend: Arc<LlamaLib>,
pub handle: *const llama_cpp_sys_v3::llama_vocab,
}
impl LlamaVocab {
pub fn bos(&self) -> llama_cpp_sys_v3::llama_token {
unsafe { (self.backend.symbols.llama_vocab_bos)(self.handle) }
}
pub fn eos(&self) -> llama_cpp_sys_v3::llama_token {
unsafe { (self.backend.symbols.llama_vocab_eos)(self.handle) }
}
pub fn is_eog(&self, token: llama_cpp_sys_v3::llama_token) -> bool {
unsafe { (self.backend.symbols.llama_vocab_is_eog)(self.handle, token) }
}
}
/// A llama.cpp sampler, either a single stage or a chain of stages.
pub struct LlamaSampler {
pub backend: Arc<LlamaLib>,
pub handle: *mut llama_cpp_sys_v3::llama_sampler,
}
impl Drop for LlamaSampler {
fn drop(&mut self) {
unsafe {
(self.backend.symbols.llama_sampler_free)(self.handle);
}
}
}
impl LlamaSampler {
/// Creates an empty sampler chain; stages added via [`Self::add`] run in order.
pub fn new_chain(backend: Arc<LlamaLib>, no_perf: bool) -> Self {
let params = llama_cpp_sys_v3::llama_sampler_chain_params { no_perf };
let handle = unsafe { (backend.symbols.llama_sampler_chain_init)(params) };
Self { backend, handle }
}
pub fn new_greedy(backend: Arc<LlamaLib>) -> Self {
let handle = unsafe { (backend.symbols.llama_sampler_init_greedy)() };
Self { backend, handle }
}
pub fn new_temp(backend: Arc<LlamaLib>, temp: f32) -> Self {
let handle = unsafe { (backend.symbols.llama_sampler_init_temp)(temp) };
Self { backend, handle }
}
pub fn new_top_k(backend: Arc<LlamaLib>, k: i32) -> Self {
let handle = unsafe { (backend.symbols.llama_sampler_init_top_k)(k) };
Self { backend, handle }
}
pub fn new_top_p(backend: Arc<LlamaLib>, p: f32, min_keep: usize) -> Self {
let handle = unsafe { (backend.symbols.llama_sampler_init_top_p)(p, min_keep) };
Self { backend, handle }
}
pub fn new_min_p(backend: Arc<LlamaLib>, p: f32, min_keep: usize) -> Self {
let handle = unsafe { (backend.symbols.llama_sampler_init_min_p)(p, min_keep) };
Self { backend, handle }
}
pub fn new_typical(backend: Arc<LlamaLib>, p: f32, min_keep: usize) -> Self {
let handle = unsafe { (backend.symbols.llama_sampler_init_typical)(p, min_keep) };
Self { backend, handle }
}
pub fn new_mirostat_v2(backend: Arc<LlamaLib>, seed: u32, tau: f32, eta: f32) -> Self {
let handle = unsafe { (backend.symbols.llama_sampler_init_mirostat_v2)(seed, tau, eta) };
Self { backend, handle }
}
pub fn new_penalties(
backend: Arc<LlamaLib>,
last_n: i32,
repeat: f32,
freq: f32,
present: f32,
) -> Self {
let handle = unsafe {
(backend.symbols.llama_sampler_init_penalties)(last_n, repeat, freq, present)
};
Self { backend, handle }
}
pub fn new_dist(backend: Arc<LlamaLib>, seed: u32) -> Self {
let handle = unsafe { (backend.symbols.llama_sampler_init_dist)(seed) };
Self { backend, handle }
}
/// Moves `other` into this chain. The chain takes ownership of the raw
/// handle, so `other` is forgotten to avoid a double free in its Drop impl.
pub fn add(&mut self, other: LlamaSampler) {
unsafe {
(self.backend.symbols.llama_sampler_chain_add)(self.handle, other.handle);
}
std::mem::forget(other);
}
pub fn sample(&self, ctx: &LlamaContext, idx: i32) -> llama_cpp_sys_v3::llama_token {
unsafe { (self.backend.symbols.llama_sampler_sample)(self.handle, ctx.handle, idx) }
}
pub fn accept(&self, token: llama_cpp_sys_v3::llama_token) {
unsafe {
(self.backend.symbols.llama_sampler_accept)(self.handle, token);
}
}
}
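// Example (sketch): composing a sampler chain. Filtering stages (top-k, top-p,
// temperature) run in the order they are added, and the final `dist` stage
// draws the token; the values here are illustrative, not recommendations.
#[allow(dead_code)]
fn example_sampler_chain(backend: &LlamaBackend) -> LlamaSampler {
    let lib = backend.lib.clone();
    let mut chain = LlamaSampler::new_chain(lib.clone(), true);
    chain.add(LlamaSampler::new_top_k(lib.clone(), 40));
    chain.add(LlamaSampler::new_top_p(lib.clone(), 0.95, 1));
    chain.add(LlamaSampler::new_temp(lib.clone(), 0.8));
    chain.add(LlamaSampler::new_dist(lib, 1234));
    chain
}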
/// An inference context created from a loaded model.
pub struct LlamaContext {
pub backend: Arc<LlamaLib>,
pub handle: *mut llama_cpp_sys_v3::llama_context,
}
impl Drop for LlamaContext {
fn drop(&mut self) {
unsafe {
(self.backend.symbols.llama_free)(self.handle);
}
}
}
impl LlamaContext {
pub fn new(
model: &LlamaModel,
params: llama_cpp_sys_v3::llama_context_params,
) -> Result<Self, LlamaError> {
let handle = unsafe { (model.backend.symbols.llama_init_from_model)(model.handle, params) };
if handle.is_null() {
return Err(LlamaError::ContextCreate);
}
Ok(Self {
backend: model.backend.clone(),
handle,
})
}
pub fn default_params(model: &LlamaModel) -> llama_cpp_sys_v3::llama_context_params {
unsafe { (model.backend.symbols.llama_context_default_params)() }
}
pub fn decode(&mut self, batch: &LlamaBatch) -> Result<(), LlamaError> {
let res = unsafe { (self.backend.symbols.llama_decode)(self.handle, batch.handle) };
if res != 0 {
Err(LlamaError::Decode(res))
} else {
Ok(())
}
}
/// Clears the context's memory (KV cache) across all sequences.
pub fn kv_cache_clear(&mut self) {
unsafe {
let memory = (self.backend.symbols.llama_get_memory)(self.handle);
(self.backend.symbols.llama_memory_clear)(memory, true);
}
}
}
/// An owned token batch allocated via `llama_batch_init` and freed on drop.
pub struct LlamaBatch {
pub backend: Arc<LlamaLib>,
pub handle: llama_cpp_sys_v3::llama_batch,
}
impl Drop for LlamaBatch {
fn drop(&mut self) {
unsafe {
(self.backend.symbols.llama_batch_free)(self.handle);
}
}
}
impl LlamaBatch {
pub fn new(backend: Arc<LlamaLib>, n_tokens: i32, embd: i32, n_seq_max: i32) -> Self {
let handle = unsafe { (backend.symbols.llama_batch_init)(n_tokens, embd, n_seq_max) };
Self { backend, handle }
}
pub fn clear(&mut self) {
self.handle.n_tokens = 0;
}
/// Appends a token at `pos` for the given sequence ids; `logits` marks whether
/// logits should be computed for this position. The caller must stay within
/// the capacity passed to [`LlamaBatch::new`].
pub fn add(
&mut self,
token: llama_cpp_sys_v3::llama_token,
pos: llama_cpp_sys_v3::llama_pos,
seq_ids: &[i32],
logits: bool,
) {
let n = self.handle.n_tokens as usize;
unsafe {
*self.handle.token.add(n) = token;
*self.handle.pos.add(n) = pos;
*self.handle.n_seq_id.add(n) = seq_ids.len() as i32;
for (j, &seq_id) in seq_ids.iter().enumerate() {
*(*self.handle.seq_id.add(n)).add(j) = seq_id;
}
*self.handle.logits.add(n) = if logits { 1 } else { 0 };
}
self.handle.n_tokens += 1;
}
}
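// Example (sketch): a minimal greedy generation loop over the wrappers above.
// It assumes the prompt fits into the batch capacity (512 tokens here) and
// stops at an end-of-generation token or after `max_tokens` steps.
#[allow(dead_code)]
fn example_generate(
    backend: &LlamaBackend,
    model: &LlamaModel,
    prompt_tokens: &[llama_cpp_sys_v3::llama_token],
    max_tokens: usize,
) -> Result<String, LlamaError> {
    let ctx_params = LlamaContext::default_params(model);
    let mut ctx = LlamaContext::new(model, ctx_params)?;
    let sampler = LlamaSampler::new_greedy(backend.lib.clone());
    let vocab = model.get_vocab();
    let mut batch = LlamaBatch::new(backend.lib.clone(), 512, 0, 1);

    // Feed the whole prompt in one batch; only the last position needs logits.
    batch.clear();
    for (i, &tok) in prompt_tokens.iter().enumerate() {
        let is_last = i + 1 == prompt_tokens.len();
        batch.add(tok, i as llama_cpp_sys_v3::llama_pos, &[0], is_last);
    }
    ctx.decode(&batch)?;

    let mut output = String::new();
    let mut pos = prompt_tokens.len();
    for _ in 0..max_tokens {
        // Sample from the logits of the last token in the previous batch.
        let token = sampler.sample(&ctx, batch.handle.n_tokens - 1);
        sampler.accept(token);
        if vocab.is_eog(token) {
            break;
        }
        output.push_str(&model.token_to_piece(token));
        batch.clear();
        batch.add(token, pos as llama_cpp_sys_v3::llama_pos, &[0], true);
        ctx.decode(&batch)?;
        pos += 1;
    }
    Ok(output)
}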