use std::ffi::{c_char, c_float, c_int, c_void, CString};
use std::path::{Path, PathBuf};
use std::ptr;
use std::sync::{Arc, Mutex, OnceLock};
use kbolt_types::KboltError;
use crate::Result;
#[cfg(unix)]
use std::os::unix::ffi::OsStrExt;
/// Name of the environment variable that, when set, supplies the first candidate
/// path tried when loading the llama.cpp shared library. It is also named in the
/// error message produced when no candidate can be loaded.
const LLAMA_DYLIB_ENV: &str = "KBOLT_LLAMA_DYLIB";
/// Opaque stand-in for a llama.cpp model. It is zero-sized and is only ever used
/// behind raw pointers handed out by the loaded library.
#[repr(C)]
struct LlamaModel {
_private: [u8; 0],
}
/// Opaque stand-in for a llama.cpp vocabulary. It is zero-sized and is only ever
/// used behind raw pointers obtained from a loaded model.
#[repr(C)]
struct LlamaVocab {
_private: [u8; 0],
}
/// Model-loading parameters passed by value across the FFI boundary: returned by
/// the library's default-params function and handed back to its load function.
///
/// Only `vocab_only` and `n_gpu_layers` are modified by this file; every other
/// field keeps the value the library returned.
///
/// NOTE(review): field order and types must match the `llama_model_params`
/// definition of the libllama build that is actually loaded; the layout is
/// presumably version-specific — confirm against that build's `llama.h`.
#[repr(C)]
#[derive(Clone, Copy)]
struct LlamaModelParams {
devices: *mut *mut c_void,
tensor_buft_overrides: *const c_void,
// Forced to 0 by the tokenizer loader.
n_gpu_layers: i32,
split_mode: i32,
main_gpu: i32,
tensor_split: *const c_float,
progress_callback: Option<unsafe extern "C" fn(c_float, *mut c_void) -> bool>,
progress_callback_user_data: *mut c_void,
kv_overrides: *const c_void,
// Forced to `true` by the tokenizer loader.
vocab_only: bool,
use_mmap: bool,
use_direct_io: bool,
use_mlock: bool,
check_tensors: bool,
use_extra_bufts: bool,
no_host: bool,
no_alloc: bool,
}
/// Signature used for the `llama_model_default_params` symbol: returns a parameter
/// struct by value.
type LlamaModelDefaultParams = unsafe extern "C" fn() -> LlamaModelParams;
/// Signature used for the `llama_model_load_from_file` symbol: takes a
/// NUL-terminated path plus parameters. Callers in this file treat a null result
/// as a load failure.
type LlamaModelLoadFromFile =
unsafe extern "C" fn(*const c_char, LlamaModelParams) -> *mut LlamaModel;
/// Signature used for the `llama_model_free` symbol: releases a model pointer.
type LlamaModelFree = unsafe extern "C" fn(*mut LlamaModel);
/// Signature used for the `llama_model_get_vocab` symbol: returns the vocabulary
/// pointer belonging to a model.
type LlamaModelGetVocab = unsafe extern "C" fn(*const LlamaModel) -> *const LlamaVocab;
/// Signature used for the `llama_log_set` symbol: installs an optional log
/// callback together with an opaque user-data pointer.
type LlamaLogSet = unsafe extern "C" fn(
Option<unsafe extern "C" fn(c_int, *const c_char, *mut c_void)>,
*mut c_void,
);
/// Signature used for the `llama_tokenize` symbol. In call order the arguments are:
/// vocabulary, text pointer, text length in bytes, output token buffer, output
/// buffer capacity, and two boolean flags.
///
/// NOTE(review): the two flags are presumably `add_special` and `parse_special`
/// as named in llama.h — confirm against the header of the loaded library.
type LlamaTokenize = unsafe extern "C" fn(
*const LlamaVocab,
*const c_char,
c_int,
*mut i32,
c_int,
bool,
bool,
) -> c_int;
/// A dynamically loaded llama.cpp library: the loader handle plus the function
/// pointers resolved from it.
struct LlamaLibrary {
// Handle returned by `dlopen`; passed to `dlclose` when the value is dropped.
handle: *mut c_void,
// Resolved from the `llama_model_default_params` symbol.
model_default_params: LlamaModelDefaultParams,
// Resolved from the `llama_model_load_from_file` symbol.
model_load_from_file: LlamaModelLoadFromFile,
// Resolved from the `llama_model_free` symbol.
model_free: LlamaModelFree,
// Resolved from the `llama_model_get_vocab` symbol.
model_get_vocab: LlamaModelGetVocab,
// Resolved from the `llama_log_set` symbol.
log_set: LlamaLogSet,
// Resolved from the `llama_tokenize` symbol.
tokenize: LlamaTokenize,
// The candidate string that was successfully opened; reported in error messages.
source: String,
}
/// Token counter backed by a llama.cpp model that was loaded with `vocab_only`
/// enabled.
pub(super) struct LlamaCppVocabOnlyTokenizer {
// Shared library whose function pointers are used for every call on this value.
library: Arc<LlamaLibrary>,
// Model pointer returned by the library; freed when this value is dropped.
model: *mut LlamaModel,
// Vocabulary pointer obtained from `model`; passed to the tokenize function.
vocab: *const LlamaVocab,
// Path the model was loaded from; kept for error messages.
model_path: PathBuf,
}
// Raw pointers make these types `!Send`/`!Sync` by default, so the markers are
// asserted manually. Nothing in this file mutates either struct's fields after
// construction.
// NOTE(review): soundness additionally assumes the resolved llama.cpp functions
// (tokenize, model free, dlclose) may be called from any thread and, for `Sync`,
// concurrently on the same vocab pointer — confirm against llama.cpp's guarantees.
unsafe impl Send for LlamaLibrary {}
unsafe impl Sync for LlamaLibrary {}
unsafe impl Send for LlamaCppVocabOnlyTokenizer {}
unsafe impl Sync for LlamaCppVocabOnlyTokenizer {}
impl LlamaCppVocabOnlyTokenizer {
    /// Loads the vocabulary of the model file at `model_path` through the shared
    /// llama.cpp library.
    ///
    /// The library's default parameters are used with `vocab_only` switched on and
    /// `n_gpu_layers` set to zero.
    ///
    /// # Errors
    ///
    /// Fails when the shared library cannot be obtained, when the path cannot be
    /// turned into a C string, when the library returns a null model, or when the
    /// loaded model reports a null vocabulary (the model is freed in that case).
    pub(super) fn load(model_path: &Path) -> Result<Self> {
        let library = LlamaLibrary::shared()?;
        let c_path = cstring_path(model_path)?;

        let params = {
            // SAFETY: the function pointer was resolved from the loaded library and
            // takes no arguments.
            let mut defaults = unsafe { (library.model_default_params)() };
            defaults.vocab_only = true;
            defaults.n_gpu_layers = 0;
            defaults
        };

        // SAFETY: `c_path` is a live NUL-terminated string for the whole call.
        let model = unsafe { (library.model_load_from_file)(c_path.as_ptr(), params) };
        if model.is_null() {
            let message = format!(
                "failed to load llama.cpp tokenizer vocab from {} using {}",
                model_path.display(),
                library.source
            );
            return Err(KboltError::Inference(message).into());
        }

        // SAFETY: `model` was just checked to be non-null.
        let vocab = unsafe { (library.model_get_vocab)(model) };
        if vocab.is_null() {
            // No `Self` exists yet, so the model has to be released by hand here.
            // SAFETY: `model` is non-null and has not been freed.
            unsafe { (library.model_free)(model) };
            let message = format!(
                "failed to load llama.cpp tokenizer vocab from {}: model has no vocab",
                model_path.display()
            );
            return Err(KboltError::Inference(message).into());
        }

        let model_path = model_path.to_path_buf();
        Ok(Self {
            library,
            model,
            vocab,
            model_path,
        })
    }

    /// Returns the number of tokens the library reports for `text`.
    ///
    /// The tokenize function is called with a null output buffer of capacity zero,
    /// so no tokens are stored; a negative return value is interpreted as the
    /// negated token count.
    ///
    /// # Errors
    ///
    /// Fails when `text` is longer than `c_int::MAX` bytes or when the library
    /// returns `c_int::MIN`, which is reported as an output overflow.
    pub(super) fn count_tokens(&self, text: &str) -> Result<usize> {
        let byte_len = text.len();
        if byte_len > c_int::MAX as usize {
            let message = format!(
                "llama.cpp tokenizer input is too large: {} bytes",
                byte_len
            );
            return Err(KboltError::Inference(message).into());
        }

        let tokenize = self.library.tokenize;
        // SAFETY: `self.vocab` stays valid while `self.model` is alive, the text
        // pointer/length pair describes `text`, and the null output buffer is
        // paired with a capacity of zero.
        let reported = unsafe {
            tokenize(
                self.vocab,
                text.as_ptr().cast::<c_char>(),
                byte_len as c_int, // cannot truncate: checked against c_int::MAX above
                ptr::null_mut(),
                0,
                true,
                true,
            )
        };

        if reported == c_int::MIN {
            let message = format!(
                "llama.cpp tokenizer output overflow for {}",
                self.model_path.display()
            );
            return Err(KboltError::Inference(message).into());
        }

        // `reported` is not `c_int::MIN` here, so negating it cannot overflow.
        let token_count = if reported < 0 { -reported } else { reported };
        Ok(token_count as usize)
    }
}
impl Drop for LlamaCppVocabOnlyTokenizer {
    /// Releases the model through the library's free function.
    fn drop(&mut self) {
        let free_model = self.library.model_free;
        // SAFETY: `self.model` was checked to be non-null when this value was
        // constructed, and this is the only place a constructed value frees it.
        unsafe { free_model(self.model) };
    }
}
impl LlamaLibrary {
    /// Returns the process-wide library instance, loading it on first use.
    ///
    /// A successfully loaded library is cached; a failed load leaves the cache
    /// empty, so the next call tries again.
    ///
    /// # Errors
    ///
    /// Fails when the cache lock is poisoned or when no candidate can be loaded.
    fn shared() -> Result<Arc<Self>> {
        static LIBRARY: OnceLock<Mutex<Option<Arc<LlamaLibrary>>>> = OnceLock::new();

        let cell = LIBRARY.get_or_init(|| Mutex::new(None));
        let mut slot = cell.lock().map_err(|_| {
            KboltError::Internal("llama.cpp library cache lock poisoned".to_string())
        })?;

        match slot.as_ref() {
            Some(cached) => Ok(Arc::clone(cached)),
            None => {
                let fresh = Arc::new(Self::load()?);
                *slot = Some(Arc::clone(&fresh));
                Ok(fresh)
            }
        }
    }

    /// Tries every candidate from `library_candidates` in order and returns the
    /// first one that loads.
    ///
    /// # Errors
    ///
    /// When every candidate fails, the error lists each individual failure.
    fn load() -> Result<Self> {
        let mut failures: Vec<String> = Vec::new();
        for path in library_candidates() {
            // SAFETY: `load_from` only passes the string to the dynamic loader and
            // resolves symbols from the resulting handle.
            let attempt = unsafe { Self::load_from(&path) };
            match attempt {
                Ok(library) => return Ok(library),
                Err(err) => failures.push(err.to_string()),
            }
        }

        let message = format!(
            "failed to load llama.cpp tokenizer library. Set {LLAMA_DYLIB_ENV} to libllama; attempts: {}",
            failures.join("; ")
        );
        Err(KboltError::Inference(message).into())
    }

    /// Opens `candidate` with `dlopen`, resolves every required symbol, and
    /// installs the silent log callback.
    ///
    /// If any symbol is missing, the handle is closed again before the error is
    /// returned.
    ///
    /// # Safety
    ///
    /// Opening a shared library runs its initialisers, and the resolved symbols are
    /// trusted to have the signatures declared in this file.
    unsafe fn load_from(candidate: &str) -> Result<Self> {
        let c_name = match CString::new(candidate) {
            Ok(c_name) => c_name,
            Err(_) => {
                return Err(KboltError::Inference(format!(
                    "invalid llama.cpp library path contains NUL: {candidate}"
                ))
                .into());
            }
        };

        let handle = dlopen(c_name.as_ptr(), RTLD_NOW);
        if handle.is_null() {
            let reason = dlerror_message();
            return Err(KboltError::Inference(format!("{candidate}: {}", reason)).into());
        }

        // Resolution happens in a closure so that `?` can bail out early while the
        // handle is still closed in one place below.
        let resolve = || -> Result<Self> {
            let model_default_params = load_symbol(handle, "llama_model_default_params")?;
            let model_load_from_file = load_symbol(handle, "llama_model_load_from_file")?;
            let model_free = load_symbol(handle, "llama_model_free")?;
            let model_get_vocab = load_symbol(handle, "llama_model_get_vocab")?;
            let log_set = load_symbol(handle, "llama_log_set")?;
            let tokenize = load_symbol(handle, "llama_tokenize")?;
            Ok(Self {
                handle,
                model_default_params,
                model_load_from_file,
                model_free,
                model_get_vocab,
                log_set,
                tokenize,
                source: candidate.to_string(),
            })
        };

        match resolve() {
            Ok(library) => {
                (library.log_set)(Some(silent_llama_log), ptr::null_mut());
                Ok(library)
            }
            Err(err) => {
                dlclose(handle);
                Err(err)
            }
        }
    }
}
/// Log callback that discards everything it is given; installed through the
/// library's log-set function when a library is loaded.
///
/// # Safety
///
/// The arguments are never read, so any values are acceptable.
unsafe extern "C" fn silent_llama_log(_: c_int, _: *const c_char, _: *mut c_void) {}
impl Drop for LlamaLibrary {
    /// Closes the loader handle; the status returned by `dlclose` is ignored.
    fn drop(&mut self) {
        let handle = self.handle;
        // SAFETY: `handle` came from a successful `dlopen`, and a constructed
        // `LlamaLibrary` closes it only here.
        unsafe {
            dlclose(handle);
        }
    }
}
/// Builds the ordered list of library names/paths to try with the dynamic loader.
///
/// The value of the `LLAMA_DYLIB_ENV` environment variable comes first when it is
/// set and non-empty. An empty value is treated like an unset variable so that it
/// does not reach the loader as an empty path or appear as a blank entry in the
/// "attempts" error text. Platform defaults follow: Homebrew and `/usr/local`
/// locations plus the bare file name on macOS, the bare `libllama.so` on other
/// Unix targets.
///
/// The override is converted lossily, so bytes that are not valid UTF-8 are
/// replaced before the string is used.
fn library_candidates() -> Vec<String> {
    let mut candidates = Vec::new();
    if let Some(path) = std::env::var_os(LLAMA_DYLIB_ENV) {
        if !path.is_empty() {
            candidates.push(path.to_string_lossy().into_owned());
        }
    }
    #[cfg(target_os = "macos")]
    {
        candidates.push("/opt/homebrew/lib/libllama.dylib".to_string());
        candidates.push("/usr/local/lib/libllama.dylib".to_string());
        candidates.push("libllama.dylib".to_string());
    }
    #[cfg(all(unix, not(target_os = "macos")))]
    {
        candidates.push("libllama.so".to_string());
    }
    candidates
}
/// Converts `path` into a NUL-terminated C string for the FFI layer.
///
/// On Unix the raw bytes of the path are used unchanged. On every other platform
/// the function always fails, because the FFI tokenizer is Unix-only.
///
/// # Errors
///
/// Fails on Unix when the path contains an interior NUL byte, and unconditionally
/// on non-Unix platforms.
fn cstring_path(path: &Path) -> Result<CString> {
    #[cfg(unix)]
    {
        let raw = path.as_os_str().as_bytes();
        return match CString::new(raw) {
            Ok(c_path) => Ok(c_path),
            Err(_) => Err(KboltError::Inference(format!(
                "model path contains NUL: {}",
                path.display()
            ))
            .into()),
        };
    }
    #[cfg(not(unix))]
    {
        // Keeps the parameter "used" on targets where the Unix branch is compiled out.
        let _ = path;
        let message = "llama.cpp tokenizer FFI is only implemented on Unix platforms".to_string();
        Err(KboltError::Inference(message).into())
    }
}
/// Resolves `name` from `handle` with `dlsym` and reinterprets the address as `T`.
///
/// `T` is expected to be a function-pointer type. The address is reinterpreted
/// with `transmute_copy`, which reads `size_of::<T>()` bytes from a pointer-sized
/// local, so any `T` that is not exactly pointer-sized is rejected up front rather
/// than read out of bounds.
///
/// # Errors
///
/// Fails when `T` is not pointer-sized, when `name` contains a NUL byte, or when
/// the loader returns a null address for the symbol.
///
/// # Safety
///
/// `handle` must be a live handle returned by `dlopen`, and `T` must match the
/// real signature of the symbol; calling the result through a mismatched type is
/// undefined behaviour.
unsafe fn load_symbol<T>(handle: *mut c_void, name: &str) -> Result<T>
where
    T: Copy,
{
    if std::mem::size_of::<T>() != std::mem::size_of::<*mut c_void>() {
        return Err(KboltError::Internal(format!(
            "llama.cpp symbol {name} requested as a type that is not pointer-sized"
        ))
        .into());
    }
    let symbol_name = CString::new(name)
        .map_err(|_| KboltError::Internal(format!("invalid symbol name: {name}")))?;
    let symbol = dlsym(handle, symbol_name.as_ptr());
    if symbol.is_null() {
        return Err(KboltError::Inference(format!(
            "missing llama.cpp symbol {name}: {}",
            dlerror_message()
        ))
        .into());
    }
    // The size check above guarantees this copies exactly the pointer's bytes.
    Ok(std::mem::transmute_copy::<*mut c_void, T>(&symbol))
}
/// Returns the dynamic loader's most recent error as an owned string, or a fixed
/// placeholder when the loader reports no message.
///
/// # Safety
///
/// The pointer returned by `dlerror` is read as a NUL-terminated string, so it
/// must still be valid when this function copies it.
unsafe fn dlerror_message() -> String {
    let raw = dlerror();
    if raw.is_null() {
        "unknown dynamic loader error".to_string()
    } else {
        let message = std::ffi::CStr::from_ptr(raw);
        message.to_string_lossy().into_owned()
    }
}
/// Flag value passed to `dlopen` when loading the llama.cpp library.
/// NOTE(review): 2 is presumably the `RTLD_NOW` value from the glibc and macOS
/// headers; confirm it for any other Unix target this is built for.
#[cfg(unix)]
const RTLD_NOW: c_int = 2;
// Dynamic loader entry points declared by hand. On Linux and Android the block is
// linked against `dl`; on other targets no library is named here.
#[cfg_attr(any(target_os = "linux", target_os = "android"), link(name = "dl"))]
extern "C" {
// Opens the library named by `filename`; this file treats a null result as failure.
fn dlopen(filename: *const c_char, flags: c_int) -> *mut c_void;
// Looks up `symbol` in `handle`; this file treats a null result as a missing symbol.
fn dlsym(handle: *mut c_void, symbol: *const c_char) -> *mut c_void;
// Closes a handle; callers in this file ignore the returned status.
fn dlclose(handle: *mut c_void) -> c_int;
// Returns the loader's latest error text; null is handled as "no message".
fn dlerror() -> *const c_char;
}
#[cfg(test)]
mod tests {
    use std::path::PathBuf;
    use std::process::Command;

    use super::LlamaCppVocabOnlyTokenizer;

    /// Checks that the FFI token counts agree with the `llama-tokenize` CLI for a
    /// handful of ASCII, code, URL, non-ASCII and multi-line samples.
    ///
    /// Ignored by default because it needs a local GGUF model, a loadable
    /// libllama and the `llama-tokenize` binary on `PATH`. The model path can be
    /// overridden with `KBOLT_LLAMA_FFI_TEST_MODEL`.
    #[test]
    #[ignore = "requires local embedding GGUF, libllama, and llama-tokenize"]
    fn vocab_only_tokenizer_matches_llama_tokenize_counts() {
        let model_path = std::env::var_os("KBOLT_LLAMA_FFI_TEST_MODEL")
            .map(PathBuf::from)
            .unwrap_or_else(default_embedding_model_path);
        let tokenizer =
            LlamaCppVocabOnlyTokenizer::load(&model_path).expect("load vocab-only tokenizer");
        for text in [
            "hello world",
            "fn main() { println!(\"hi\"); }",
            "https://example.com/path?q=hello%20world",
            // CJK characters and an emoji, written as escapes so the sample keeps
            // its intended code points regardless of how this file is re-encoded.
            "\u{4e2d}\u{6587} text with emoji \u{1f680}",
            "first line\nsecond line\n\nthird line",
        ] {
            let ffi_count = tokenizer.count_tokens(text).expect("ffi token count");
            let cli_count = llama_tokenize_count(&model_path, text);
            assert_eq!(ffi_count, cli_count, "token count mismatch for {text:?}");
        }
    }

    /// Default model location used when no override is set: a fixed file under
    /// `$HOME/Library/Caches/kbolt/models/embedder`. Panics when `HOME` is unset.
    fn default_embedding_model_path() -> PathBuf {
        let home = std::env::var_os("HOME").expect("HOME");
        PathBuf::from(home)
            .join("Library")
            .join("Caches")
            .join("kbolt")
            .join("models")
            .join("embedder")
            .join("embeddinggemma-300M-Q8_0.gguf")
    }

    /// Runs `llama-tokenize` on `text` and parses the count from its
    /// "Total number of tokens: " output line. Panics when the command cannot be
    /// run, exits unsuccessfully, or prints no parsable count line.
    fn llama_tokenize_count(model_path: &std::path::Path, text: &str) -> usize {
        let output = Command::new("llama-tokenize")
            .arg("--log-disable")
            .arg("-m")
            .arg(model_path)
            .arg("--ids")
            .arg("--show-count")
            .arg("-p")
            .arg(text)
            .output()
            .expect("run llama-tokenize");
        assert!(
            output.status.success(),
            "llama-tokenize failed: {}",
            String::from_utf8_lossy(&output.stderr)
        );
        let stdout = String::from_utf8(output.stdout).expect("utf8 stdout");
        stdout
            .lines()
            .find_map(|line| {
                line.strip_prefix("Total number of tokens: ")
                    .and_then(|count| count.trim().parse().ok())
            })
            .expect("token count line")
    }
}