// kbolt-core 0.1.7
//
// Core engine for kbolt local-first retrieval.
//
// Documentation
use std::ffi::{c_char, c_float, c_int, c_void, CString};
use std::path::{Path, PathBuf};
use std::ptr;
use std::sync::{Arc, Mutex, OnceLock};

use kbolt_types::KboltError;

use crate::Result;

#[cfg(unix)]
use std::os::unix::ffi::OsStrExt;

/// Environment variable naming a libllama shared library to try before the built-in
/// platform candidates (see `library_candidates`); also mentioned in the load-failure message.
const LLAMA_DYLIB_ENV: &str = "KBOLT_LLAMA_DYLIB";

/// Opaque stand-in for llama.cpp's `llama_model`.
///
/// It is only ever handled behind raw pointers returned by libllama; the zero-sized field
/// gives it no Rust-visible contents.
#[repr(C)]
struct LlamaModel {
    _private: [u8; 0],
}

/// Opaque stand-in for llama.cpp's `llama_vocab`.
///
/// Obtained from `llama_model_get_vocab` and only handled behind raw pointers.
#[repr(C)]
struct LlamaVocab {
    _private: [u8; 0],
}

/// Rust mirror of llama.cpp's `struct llama_model_params`, passed and returned by value
/// across the FFI boundary.
///
/// Field order and types must match the C definition of the libllama build loaded at
/// runtime. Any layout mismatch is undefined behavior, not a load error. This module only
/// overrides `vocab_only` and `n_gpu_layers` (in `LlamaCppVocabOnlyTokenizer::load`) and
/// otherwise forwards libllama's own defaults.
///
/// NOTE(review): the trailing booleans (e.g. `use_direct_io`, `no_host`, `no_alloc`) vary
/// between llama.cpp releases. Confirm this list against the targeted `llama.h`.
#[repr(C)]
#[derive(Clone, Copy)]
struct LlamaModelParams {
    devices: *mut *mut c_void,
    tensor_buft_overrides: *const c_void,
    // Set to 0 by the tokenizer loader.
    n_gpu_layers: i32,
    // C enum; assumed here to be `int`-sized.
    split_mode: i32,
    main_gpu: i32,
    tensor_split: *const c_float,
    progress_callback: Option<unsafe extern "C" fn(c_float, *mut c_void) -> bool>,
    progress_callback_user_data: *mut c_void,
    kv_overrides: *const c_void,
    // Set to `true` by the tokenizer loader so only the vocabulary is loaded.
    vocab_only: bool,
    use_mmap: bool,
    use_direct_io: bool,
    use_mlock: bool,
    check_tensors: bool,
    use_extra_bufts: bool,
    no_host: bool,
    no_alloc: bool,
}

// Signatures of the libllama C entry points resolved at runtime with `dlsym` (see
// `load_symbol`). Each alias must match the C prototype in the loaded libllama exactly;
// a mismatch is undefined behavior rather than a load error.

/// `llama_model_default_params`: returns the library's default `LlamaModelParams` by value.
type LlamaModelDefaultParams = unsafe extern "C" fn() -> LlamaModelParams;
/// `llama_model_load_from_file`: loads a model from a NUL-terminated path. This module treats
/// a null return as failure.
type LlamaModelLoadFromFile =
    unsafe extern "C" fn(*const c_char, LlamaModelParams) -> *mut LlamaModel;
/// `llama_model_free`: releases a model returned by `llama_model_load_from_file`.
type LlamaModelFree = unsafe extern "C" fn(*mut LlamaModel);
/// `llama_model_get_vocab`: returns the model's vocabulary. Null is treated as "no vocab".
type LlamaModelGetVocab = unsafe extern "C" fn(*const LlamaModel) -> *const LlamaVocab;
/// `llama_log_set`: installs a log callback `(level, text, user_data)` plus its user-data
/// pointer.
type LlamaLogSet = unsafe extern "C" fn(
    Option<unsafe extern "C" fn(c_int, *const c_char, *mut c_void)>,
    *mut c_void,
);
/// `llama_tokenize`: tokenizes a byte range of text into a caller-provided buffer.
///
/// Argument order, as declared in llama.cpp's `llama.h` (confirm for the targeted version):
/// vocab, text, text_len, tokens, n_tokens_max, add_special, parse_special. Returns the token
/// count, or a negative value when the buffer is too small.
type LlamaTokenize = unsafe extern "C" fn(
    *const LlamaVocab,
    *const c_char,
    c_int,
    *mut i32,
    c_int,
    bool,
    bool,
) -> c_int;

/// A dynamically loaded libllama together with the entry points this module calls.
struct LlamaLibrary {
    // Handle returned by `dlopen`; closed by the `Drop` impl for this type.
    handle: *mut c_void,
    model_default_params: LlamaModelDefaultParams,
    model_load_from_file: LlamaModelLoadFromFile,
    model_free: LlamaModelFree,
    model_get_vocab: LlamaModelGetVocab,
    log_set: LlamaLogSet,
    tokenize: LlamaTokenize,
    // The candidate name/path the library was loaded from; used in error messages.
    source: String,
}

/// Token counter backed by llama.cpp that loads only a model's vocabulary.
pub(super) struct LlamaCppVocabOnlyTokenizer {
    // Keeps libllama loaded for as long as `model` may still be freed through it.
    library: Arc<LlamaLibrary>,
    // Owned model handle; freed in `Drop`.
    model: *mut LlamaModel,
    // Obtained from `llama_model_get_vocab(model)`; never freed separately.
    vocab: *const LlamaVocab,
    // Kept only for error messages.
    model_path: PathBuf,
}

// SAFETY: `LlamaLibrary` holds a `dlopen` handle and plain function pointers that are never
// mutated after construction.
// NOTE(review): this assumes the libllama entry points are callable from any thread.
unsafe impl Send for LlamaLibrary {}
unsafe impl Sync for LlamaLibrary {}
// SAFETY: the tokenizer's raw pointers are only passed back to libllama. After `load`, the
// only call made through `&self` is `llama_tokenize`.
// NOTE(review): `Sync` assumes concurrent `llama_tokenize` calls on the same vocab are
// thread-safe in the loaded llama.cpp version. Confirm.
unsafe impl Send for LlamaCppVocabOnlyTokenizer {}
unsafe impl Sync for LlamaCppVocabOnlyTokenizer {}

impl LlamaCppVocabOnlyTokenizer {
    /// Loads only the tokenizer vocabulary of the model file at `model_path`.
    ///
    /// Uses the shared libllama instance and opens the model with `vocab_only = true` and
    /// `n_gpu_layers = 0`.
    ///
    /// # Errors
    ///
    /// Fails if libllama cannot be loaded, if `model_path` cannot be turned into a C string,
    /// if llama.cpp cannot load the model, or if the loaded model exposes no vocab.
    pub(super) fn load(model_path: &Path) -> Result<Self> {
        let library = LlamaLibrary::shared()?;
        let c_path = cstring_path(model_path)?;

        let params = {
            // SAFETY: `llama_model_default_params` takes no arguments and only returns a value.
            let mut defaults = unsafe { (library.model_default_params)() };
            // Only the vocabulary is needed to count tokens, and nothing is offloaded to a GPU.
            defaults.vocab_only = true;
            defaults.n_gpu_layers = 0;
            defaults
        };

        // SAFETY: `c_path` is NUL-terminated and outlives the call; `params` started from
        // libllama's own defaults.
        let model = unsafe { (library.model_load_from_file)(c_path.as_ptr(), params) };
        if model.is_null() {
            return Err(KboltError::Inference(format!(
                "failed to load llama.cpp tokenizer vocab from {} using {}",
                model_path.display(),
                library.source
            ))
            .into());
        }

        // SAFETY: `model` is a non-null pointer just returned by `llama_model_load_from_file`.
        let vocab = unsafe { (library.model_get_vocab)(model) };
        if vocab.is_null() {
            // SAFETY: nothing else references `model`, and it is not used after this call.
            unsafe { (library.model_free)(model) };
            return Err(KboltError::Inference(format!(
                "failed to load llama.cpp tokenizer vocab from {}: model has no vocab",
                model_path.display()
            ))
            .into());
        }

        Ok(Self {
            model_path: model_path.to_path_buf(),
            library,
            model,
            vocab,
        })
    }

    /// Returns how many tokens llama.cpp produces for `text`.
    ///
    /// Special tokens are added and special-token text is parsed
    /// (`add_special = true`, `parse_special = true`).
    ///
    /// # Errors
    ///
    /// Fails if `text` is longer than `c_int::MAX` bytes, or if llama.cpp reports an output
    /// overflow by returning `c_int::MIN`.
    pub(super) fn count_tokens(&self, text: &str) -> Result<usize> {
        let byte_len = text.len();
        if byte_len > c_int::MAX as usize {
            return Err(KboltError::Inference(format!(
                "llama.cpp tokenizer input is too large: {} bytes",
                byte_len
            ))
            .into());
        }

        // llama.cpp tokenizes the whole input before checking `n_tokens_max`. When the buffer
        // is too small, it returns `-required_count` without writing anything. The public
        // header still asks for a large-enough buffer, so this null-buffer / capacity-0 call
        // knowingly relies on that implementation detail. The ignored live parity test in
        // this module exercises it.
        //
        // SAFETY: `self.vocab` came from `self.model`, which is freed only in `Drop`. `text`
        // provides `byte_len` readable bytes (bounded by `c_int::MAX` above). With capacity 0
        // the null output buffer is never written.
        let raw = unsafe {
            (self.library.tokenize)(
                self.vocab,
                text.as_ptr().cast::<c_char>(),
                byte_len as c_int,
                ptr::null_mut(),
                0,
                true,
                true,
            )
        };

        // `c_int::MIN` signals overflow, and it could not be negated anyway.
        if raw == c_int::MIN {
            return Err(KboltError::Inference(format!(
                "llama.cpp tokenizer output overflow for {}",
                self.model_path.display()
            ))
            .into());
        }

        // A negative result is `-required_count`; zero or positive is already the exact count.
        Ok(raw.unsigned_abs() as usize)
    }
}

impl Drop for LlamaCppVocabOnlyTokenizer {
    fn drop(&mut self) {
        // Fields (including the `Arc<LlamaLibrary>`) are dropped only after this method
        // returns, so libllama is still loaded while the model is freed.
        let free_model = self.library.model_free;
        // SAFETY: `self.model` was returned non-null by `llama_model_load_from_file` in
        // `load`, and this is the only place it is freed.
        unsafe { free_model(self.model) };
    }
}

impl LlamaLibrary {
    /// Returns the process-wide libllama instance, loading it on first use.
    ///
    /// A successful load is cached in a static and handed out as `Arc` clones. A failed load
    /// is not cached, so the next call tries every candidate again. The cache lock is held
    /// while loading, so concurrent first callers do not load the library twice.
    ///
    /// # Errors
    ///
    /// Returns `KboltError::Internal` if the cache mutex is poisoned. Otherwise returns the
    /// aggregated error from `load` when no candidate could be opened.
    fn shared() -> Result<Arc<Self>> {
        static LIBRARY: OnceLock<Mutex<Option<Arc<LlamaLibrary>>>> = OnceLock::new();

        let cache = LIBRARY.get_or_init(|| Mutex::new(None));
        let mut slot = match cache.lock() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(
                    KboltError::Internal("llama.cpp library cache lock poisoned".to_string())
                        .into(),
                );
            }
        };

        if let Some(cached) = slot.as_ref().map(Arc::clone) {
            return Ok(cached);
        }

        let fresh = Arc::new(Self::load()?);
        *slot = Some(Arc::clone(&fresh));
        Ok(fresh)
    }

    /// Tries each entry from `library_candidates` in order and returns the first that loads.
    ///
    /// # Errors
    ///
    /// Returns `KboltError::Inference` listing every failed attempt when none succeeds.
    fn load() -> Result<Self> {
        let mut failures = Vec::new();
        for candidate in library_candidates() {
            // SAFETY: see `load_from`; each candidate is expected to be a libllama build.
            let attempt = unsafe { Self::load_from(&candidate) };
            match attempt {
                Ok(library) => return Ok(library),
                Err(err) => failures.push(err.to_string()),
            }
        }

        Err(KboltError::Inference(format!(
            "failed to load llama.cpp tokenizer library. Set {LLAMA_DYLIB_ENV} to libllama; attempts: {}",
            failures.join("; ")
        ))
        .into())
    }

    /// Opens `candidate` with `dlopen(RTLD_NOW)` and resolves the required llama.cpp symbols.
    ///
    /// On success, libllama's log callback is replaced with `silent_llama_log`. If any symbol
    /// is missing, the handle is closed again before the error is returned.
    ///
    /// # Errors
    ///
    /// Returns `KboltError::Inference` if `candidate` contains a NUL byte, if `dlopen` fails,
    /// or if a required symbol cannot be resolved.
    ///
    /// # Safety
    ///
    /// Loading a shared library runs its initialization code. The resolved symbols are trusted
    /// to have the signatures of the `Llama*` function-pointer aliases in this module.
    unsafe fn load_from(candidate: &str) -> Result<Self> {
        let Ok(name) = CString::new(candidate) else {
            return Err(KboltError::Inference(format!(
                "invalid llama.cpp library path contains NUL: {candidate}"
            ))
            .into());
        };

        let handle = dlopen(name.as_ptr(), RTLD_NOW);
        if handle.is_null() {
            let reason = dlerror_message();
            return Err(KboltError::Inference(format!("{candidate}: {reason}")).into());
        }

        match Self::resolve_symbols(handle, candidate) {
            Ok(library) => {
                // Route libllama's log output to a no-op callback.
                (library.log_set)(Some(silent_llama_log), ptr::null_mut());
                Ok(library)
            }
            Err(err) => {
                // No `LlamaLibrary` was built, so its `Drop` will never close this handle.
                dlclose(handle);
                Err(err)
            }
        }
    }

    /// Resolves every llama.cpp entry point this module needs from an open `handle`.
    ///
    /// Symbols are looked up in field order, so an error names the first missing one.
    ///
    /// # Safety
    ///
    /// `handle` must be a live `dlopen` handle whose exports match the `Llama*` aliases.
    unsafe fn resolve_symbols(handle: *mut c_void, candidate: &str) -> Result<Self> {
        Ok(Self {
            handle,
            model_default_params: load_symbol(handle, "llama_model_default_params")?,
            model_load_from_file: load_symbol(handle, "llama_model_load_from_file")?,
            model_free: load_symbol(handle, "llama_model_free")?,
            model_get_vocab: load_symbol(handle, "llama_model_get_vocab")?,
            log_set: load_symbol(handle, "llama_log_set")?,
            tokenize: load_symbol(handle, "llama_tokenize")?,
            source: candidate.to_string(),
        })
    }
}

/// No-op llama.cpp log callback, installed through `llama_log_set` to discard libllama's log
/// output.
///
/// All arguments (level, message text, user data) are ignored, so null pointers are accepted.
unsafe extern "C" fn silent_llama_log(
    _log_level: c_int,
    _message: *const c_char,
    _context: *mut c_void,
) {
    // Intentionally empty: dropping every message is the whole purpose of this callback.
}

impl Drop for LlamaLibrary {
    fn drop(&mut self) {
        let handle = self.handle;
        // SAFETY: `handle` came from a successful `dlopen` in `load_from`, and a constructed
        // `LlamaLibrary` is the only owner that closes it. The status code is ignored because
        // `drop` has no way to report it.
        let _ = unsafe { dlclose(handle) };
    }
}

/// Returns the library names/paths to try with `dlopen`, in priority order.
///
/// A non-empty `KBOLT_LLAMA_DYLIB` override (`LLAMA_DYLIB_ENV`) always comes first.
/// Platform defaults follow:
/// - macOS: Homebrew's `/opt/homebrew/lib`, then `/usr/local/lib`, then a bare
///   `libllama.dylib`.
/// - Other Unix targets: a bare `libllama.so`.
///
/// Bare names are resolved through the dynamic loader's normal search path. On non-Unix
/// targets only the override is returned.
fn library_candidates() -> Vec<String> {
    let mut candidates = Vec::new();

    // Treat an empty override (`KBOLT_LLAMA_DYLIB=`) as unset instead of handing `""` to
    // `dlopen`, whose interpretation of an empty filename is loader-specific (glibc resolves
    // it to the main program rather than failing).
    // NOTE(review): non-UTF-8 override paths are converted lossily and then cannot be opened.
    if let Some(path) = std::env::var_os(LLAMA_DYLIB_ENV).filter(|value| !value.is_empty()) {
        candidates.push(path.to_string_lossy().into_owned());
    }

    #[cfg(target_os = "macos")]
    {
        candidates.push("/opt/homebrew/lib/libllama.dylib".to_string());
        candidates.push("/usr/local/lib/libllama.dylib".to_string());
        candidates.push("libllama.dylib".to_string());
    }

    #[cfg(all(unix, not(target_os = "macos")))]
    {
        candidates.push("libllama.so".to_string());
    }

    candidates
}

/// Converts `path` into a NUL-terminated C string for llama.cpp, keeping its raw bytes.
///
/// # Errors
///
/// Returns `KboltError::Inference` if the path contains an interior NUL byte.
#[cfg(unix)]
fn cstring_path(path: &Path) -> Result<CString> {
    let raw_bytes = path.as_os_str().as_bytes();
    CString::new(raw_bytes).map_err(|_| {
        KboltError::Inference(format!("model path contains NUL: {}", path.display())).into()
    })
}

/// Non-Unix fallback: path conversion for the llama.cpp FFI is not implemented here.
///
/// # Errors
///
/// Always returns `KboltError::Inference`.
#[cfg(not(unix))]
fn cstring_path(path: &Path) -> Result<CString> {
    let _ = path;
    Err(KboltError::Inference(
        "llama.cpp tokenizer FFI is only implemented on Unix platforms".to_string(),
    )
    .into())
}

unsafe fn load_symbol<T>(handle: *mut c_void, name: &str) -> Result<T>
where
    T: Copy,
{
    let symbol_name = CString::new(name)
        .map_err(|_| KboltError::Internal(format!("invalid symbol name: {name}")))?;
    let symbol = dlsym(handle, symbol_name.as_ptr());
    if symbol.is_null() {
        return Err(KboltError::Inference(format!(
            "missing llama.cpp symbol {name}: {}",
            dlerror_message()
        ))
        .into());
    }

    Ok(std::mem::transmute_copy(&symbol))
}

/// Returns the dynamic loader's pending error message, or a generic fallback if none is set.
///
/// # Safety
///
/// Calls `dlerror`, which reports and clears the pending loader error. The returned C string
/// is copied into an owned `String` before any other loader call can invalidate it.
unsafe fn dlerror_message() -> String {
    let message = dlerror();
    if message.is_null() {
        String::from("unknown dynamic loader error")
    } else {
        std::ffi::CStr::from_ptr(message)
            .to_string_lossy()
            .into_owned()
    }
}

/// `dlopen` flag requesting that all undefined symbols be resolved at load time.
///
/// NOTE(review): hard-coded to the value in the glibc and macOS `<dlfcn.h>`; confirm it for
/// any other Unix target. It is defined only under `cfg(unix)`, but `LlamaLibrary::load_from`
/// uses it unconditionally. Non-Unix builds therefore presumably rely on this module being
/// gated elsewhere.
#[cfg(unix)]
const RTLD_NOW: c_int = 2;

/// Raw bindings to the POSIX dynamic-loader API used to load libllama at runtime.
///
/// On Linux and Android these are linked from `libdl`. Other targets presumably provide
/// them from their default C library.
#[cfg_attr(any(target_os = "linux", target_os = "android"), link(name = "dl"))]
extern "C" {
    fn dlopen(filename: *const c_char, flags: c_int) -> *mut c_void;
    fn dlsym(handle: *mut c_void, symbol: *const c_char) -> *mut c_void;
    fn dlclose(handle: *mut c_void) -> c_int;
    fn dlerror() -> *const c_char;
}

#[cfg(test)]
mod tests {
    use std::path::PathBuf;
    use std::process::Command;

    use super::LlamaCppVocabOnlyTokenizer;

    /// Checks that the FFI count-only path agrees with the `llama-tokenize` CLI.
    ///
    /// The model path comes from `KBOLT_LLAMA_FFI_TEST_MODEL`, falling back to
    /// `default_embedding_model_path`. `llama-tokenize` must be on `PATH`.
    #[test]
    #[ignore = "requires local embedding GGUF, libllama, and llama-tokenize"]
    fn vocab_only_tokenizer_matches_llama_tokenize_counts() {
        let model_path = std::env::var_os("KBOLT_LLAMA_FFI_TEST_MODEL")
            .map(PathBuf::from)
            .unwrap_or_else(default_embedding_model_path);
        let tokenizer =
            LlamaCppVocabOnlyTokenizer::load(&model_path).expect("load vocab-only tokenizer");

        // Samples cover the following:
        // - plain prose
        // - code with quotes
        // - percent-escaped URLs
        // - CJK text with a 4-byte emoji (the literal was previously cp1253 mojibake and
        //   contained neither)
        // - blank lines
        for text in [
            "hello world",
            "fn main() { println!(\"hi\"); }",
            "https://example.com/path?q=hello%20world",
            "中文 text with emoji 😀",
            "first line\nsecond line\n\nthird line",
        ] {
            let ffi_count = tokenizer.count_tokens(text).expect("ffi token count");
            let cli_count = llama_tokenize_count(&model_path, text);
            assert_eq!(ffi_count, cli_count, "token count mismatch for {text:?}");
        }
    }

    /// Default model location:
    /// `$HOME/Library/Caches/kbolt/models/embedder/embeddinggemma-300M-Q8_0.gguf`.
    ///
    /// NOTE(review): this is a macOS-style cache layout. On other platforms, set
    /// `KBOLT_LLAMA_FFI_TEST_MODEL` instead.
    fn default_embedding_model_path() -> PathBuf {
        let home = std::env::var_os("HOME").expect("HOME");
        PathBuf::from(home)
            .join("Library")
            .join("Caches")
            .join("kbolt")
            .join("models")
            .join("embedder")
            .join("embeddinggemma-300M-Q8_0.gguf")
    }

    /// Runs `llama-tokenize` on `text` and parses its `Total number of tokens:` stdout line.
    ///
    /// Panics if the command cannot be run, exits unsuccessfully, writes non-UTF-8 output,
    /// or prints no parsable count line.
    fn llama_tokenize_count(model_path: &std::path::Path, text: &str) -> usize {
        let output = Command::new("llama-tokenize")
            .arg("--log-disable")
            .arg("-m")
            .arg(model_path)
            .arg("--ids")
            .arg("--show-count")
            .arg("-p")
            .arg(text)
            .output()
            .expect("run llama-tokenize");
        assert!(
            output.status.success(),
            "llama-tokenize failed: {}",
            String::from_utf8_lossy(&output.stderr)
        );
        let stdout = String::from_utf8(output.stdout).expect("utf8 stdout");
        stdout
            .lines()
            .find_map(|line| {
                line.strip_prefix("Total number of tokens: ")
                    .and_then(|count| count.trim().parse().ok())
            })
            .expect("token count line")
    }
}