use crate::llama_batch::BatchAddError;
use std::ffi::NulError;
use std::fmt::Debug;
use std::num::NonZeroI32;
use std::os::raw::c_int;
use std::path::PathBuf;
use std::string::FromUtf8Error;
pub mod common;
pub mod context;
#[cfg(feature = "ggml")]
pub mod ggml;
pub mod llama_backend;
pub mod llama_batch;
pub mod model;
pub mod quantize;
pub mod sampling;
pub mod token;
pub mod token_type;
#[cfg(feature = "rpc")]
pub mod rpc;
#[cfg(feature = "mtmd")]
pub mod mtmd;
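/// A convenience `Result` alias whose error type is [`LLamaCppError`].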
pub type Result<T> = std::result::Result<T, LLamaCppError>;
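/// All errors that this crate can produce, aggregated into a single enum.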
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum LLamaCppError {
#[error("BackendAlreadyInitialized")]
BackendAlreadyInitialized,
#[error("{0}")]
ChatTemplateError(#[from] ChatTemplateError),
#[error("{0}")]
DecodeError(#[from] DecodeError),
#[error("{0}")]
EncodeError(#[from] EncodeError),
#[error("{0}")]
LlamaModelLoadError(#[from] LlamaModelLoadError),
#[error("{0}")]
LlamaContextLoadError(#[from] LlamaContextLoadError),
#[error("{0}")]
BatchAddError(#[from] BatchAddError),
#[error(transparent)]
EmbeddingError(#[from] EmbeddingsError),
}
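/// Errors that can occur when retrieving a chat template from a model.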
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum ChatTemplateError {
#[error("The buffer was too small. However, a buffer size of {0} would be just large enough.")]
BuffSizeError(usize),
#[error("the model has no meta val - returned code {0}")]
MissingTemplate(i32),
#[error(transparent)]
Utf8Error(#[from] std::str::Utf8Error),
}
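/// Errors that can occur when reading a string value out of a model.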
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum StringFromModelError {
#[error("llama.cpp returned error code {0}")]
ReturnedError(i32),
#[error(transparent)]
Utf8Error(#[from] std::str::Utf8Error),
}
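/// Errors that can occur when creating a new context from a model.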
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum LlamaContextLoadError {
#[error("null reference from llama.cpp")]
NullReturn,
}
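/// Errors that can occur when decoding a batch, constructed from non-zero llama.cpp return codes.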
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum DecodeError {
#[error("Decode Error 1: NoKvCacheSlot")]
NoKvCacheSlot,
#[error("Decode Error -1: n_tokens == 0")]
NTokensZero,
#[error("Decode Error {0}: unknown")]
Unknown(c_int),
}
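/// Errors that can occur when encoding a batch, constructed from non-zero llama.cpp return codes.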
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum EncodeError {
#[error("Encode Error 1: NoKvCacheSlot")]
NoKvCacheSlot,
#[error("Encode Error -1: n_tokens == 0")]
NTokensZero,
#[error("Encode Error {0}: unknown")]
Unknown(c_int),
}
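/// Errors that can occur when extracting embeddings from a context.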
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum EmbeddingsError {
#[error("Embeddings weren't enabled in the context options")]
NotEnabled,
#[error("Logits were not enabled for the given token")]
LogitsNotEnabled,
#[error("Can't use sequence embeddings with a model supporting only LLAMA_POOLING_TYPE_NONE")]
NonePoolType,
}
impl From<NonZeroI32> for DecodeError {
fn from(value: NonZeroI32) -> Self {
match value.get() {
1 => DecodeError::NoKvCacheSlot,
-1 => DecodeError::NTokensZero,
i => DecodeError::Unknown(i),
}
}
}
impl From<NonZeroI32> for EncodeError {
fn from(value: NonZeroI32) -> Self {
match value.get() {
1 => EncodeError::NoKvCacheSlot,
-1 => EncodeError::NTokensZero,
i => EncodeError::Unknown(i),
}
}
}
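/// Errors that can occur when loading a model from a file.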
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum LlamaModelLoadError {
#[error("null byte in string {0}")]
NullError(#[from] NulError),
#[error("null result from llama cpp")]
NullResult,
#[error("failed to convert path {0} to str")]
PathToStrError(PathBuf),
}
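/// Errors that can occur when initializing a LoRA adapter from a file.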
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum LlamaLoraAdapterInitError {
#[error("null byte in string {0}")]
NullError(#[from] NulError),
#[error("null result from llama cpp")]
NullResult,
#[error("failed to convert path {0} to str")]
PathToStrError(PathBuf),
}
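/// Errors that can occur when applying a LoRA adapter to a context.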
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum LlamaLoraAdapterSetError {
#[error("error code from llama cpp")]
ErrorResult(i32),
}
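/// Errors that can occur when removing a LoRA adapter from a context.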
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum LlamaLoraAdapterRemoveError {
#[error("error code from llama cpp")]
ErrorResult(i32),
}
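/// Returns the current time in microseconds, as reported by llama.cpp's `llama_time_us`.
///
/// # Examples
///
/// A minimal sketch of timing a block of work; the `llama_cpp_4` crate path is an assumption.
///
/// ```ignore
/// let start = llama_cpp_4::llama_time_us();
/// // ... do some work ...
/// let elapsed_us = llama_cpp_4::llama_time_us() - start;
/// ```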
#[must_use]
pub fn llama_time_us() -> i64 {
unsafe { llama_cpp_sys_4::llama_time_us() }
}
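/// Returns the maximum number of devices that llama.cpp was built to support.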
#[must_use]
pub fn max_devices() -> usize {
unsafe { llama_cpp_sys_4::llama_max_devices() }
}
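/// Returns `true` if llama.cpp supports memory-mapped (mmap) model loading.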
#[must_use]
pub fn mmap_supported() -> bool {
unsafe { llama_cpp_sys_4::llama_supports_mmap() }
}
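/// Returns `true` if llama.cpp supports locking model memory with mlock.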
#[must_use]
pub fn mlock_supported() -> bool {
unsafe { llama_cpp_sys_4::llama_supports_mlock() }
}
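/// Errors that can occur when converting a token to a string.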
#[derive(Debug, thiserror::Error, Clone)]
#[non_exhaustive]
pub enum TokenToStringError {
#[error("Unknown Token Type")]
UnknownTokenType,
#[error("Insufficient Buffer Space {0}")]
InsufficientBufferSpace(c_int),
#[error("FromUtf8Error {0}")]
FromUtf8Error(#[from] FromUtf8Error),
}
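/// Errors that can occur when converting a string to tokens.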
#[derive(Debug, thiserror::Error)]
pub enum StringToTokenError {
#[error("{0}")]
NulError(#[from] NulError),
#[error("{0}")]
CIntConversionError(#[from] std::num::TryFromIntError),
}
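/// Errors that can occur when constructing a new chat message.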
#[derive(Debug, thiserror::Error)]
pub enum NewLlamaChatMessageError {
#[error("{0}")]
NulError(#[from] NulError),
}
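/// Errors that can occur when applying a chat template to a list of messages.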
#[derive(Debug, thiserror::Error)]
pub enum ApplyChatTemplateError {
#[error("The buffer was too small. Please contact a maintainer and we will update it.")]
BuffSizeError,
#[error("{0}")]
NulError(#[from] NulError),
#[error("{0}")]
FromUtf8Error(#[from] FromUtf8Error),
}
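/// Returns the current time in microseconds, as reported by ggml's `ggml_time_us`.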
#[must_use]
pub fn ggml_time_us() -> i64 {
unsafe { llama_cpp_sys_4::ggml_time_us() }
}
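/// Returns `true` if llama.cpp supports mlock; equivalent to [`mlock_supported`].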
#[must_use]
pub fn llama_supports_mlock() -> bool {
unsafe { llama_cpp_sys_4::llama_supports_mlock() }
}
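/// Returns `true` if llama.cpp was built with GPU offload support.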
#[must_use]
pub fn supports_gpu_offload() -> bool {
unsafe { llama_cpp_sys_4::llama_supports_gpu_offload() }
}
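/// Returns `true` if llama.cpp was built with RPC support.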
#[must_use]
pub fn supports_rpc() -> bool {
unsafe { llama_cpp_sys_4::llama_supports_rpc() }
}
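/// Returns llama.cpp's system information string (build features, SIMD support, and so on).
///
/// # Panics
///
/// Panics if the string returned by llama.cpp is not valid UTF-8.
///
/// # Examples
///
/// A minimal sketch; the `llama_cpp_4` crate path is an assumption.
///
/// ```ignore
/// println!("{}", llama_cpp_4::print_system_info());
/// ```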
#[must_use]
pub fn print_system_info() -> String {
let c_str = unsafe { llama_cpp_sys_4::llama_print_system_info() };
let c_str = unsafe { std::ffi::CStr::from_ptr(c_str) };
c_str.to_str().expect("system info is not valid UTF-8").to_owned()
}
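/// Returns the maximum number of parallel sequences supported by llama.cpp.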
#[must_use]
pub fn max_parallel_sequences() -> usize {
unsafe { llama_cpp_sys_4::llama_max_parallel_sequences() }
}
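/// Returns the maximum number of tensor buffer type overrides supported by llama.cpp.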
#[must_use]
pub fn max_tensor_buft_overrides() -> usize {
unsafe { llama_cpp_sys_4::llama_max_tensor_buft_overrides() }
}
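/// Returns the human-readable name of a flash attention type value.
///
/// # Panics
///
/// Panics if the string returned by llama.cpp is not valid UTF-8.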
#[must_use]
pub fn flash_attn_type_name(flash_attn_type: i32) -> String {
let c_str = unsafe { llama_cpp_sys_4::llama_flash_attn_type_name(flash_attn_type) };
let c_str = unsafe { std::ffi::CStr::from_ptr(c_str) };
c_str.to_str().expect("flash_attn_type_name is not valid UTF-8").to_owned()
}
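/// Returns the string form of a model metadata key.
///
/// # Panics
///
/// Panics if `key` cannot be converted to the underlying C type or if the returned string is
/// not valid UTF-8.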
#[must_use]
pub fn model_meta_key_str(key: u32) -> String {
let c_str = unsafe { llama_cpp_sys_4::llama_model_meta_key_str(key.try_into().unwrap()) };
let c_str = unsafe { std::ffi::CStr::from_ptr(c_str) };
c_str.to_str().expect("meta_key_str is not valid UTF-8").to_owned()
}
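/// Quantizes the model at `fname_inp` and writes the result to `fname_out` using the given
/// quantization parameters.
///
/// # Errors
///
/// Returns the raw llama.cpp error code if quantization fails.
///
/// # Panics
///
/// Panics if either path contains an interior null byte.
///
/// # Examples
///
/// A minimal sketch; the `llama_cpp_4` crate path, the file names, and the
/// `QuantizeParams::new()` constructor are assumptions.
///
/// ```ignore
/// let params = llama_cpp_4::quantize::QuantizeParams::new();
/// llama_cpp_4::model_quantize("model-f16.gguf", "model-q4_0.gguf", &params)
///     .expect("quantization failed");
/// ```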
pub fn model_quantize(
fname_inp: &str,
fname_out: &str,
params: &quantize::QuantizeParams,
) -> std::result::Result<(), u32> {
let c_inp = std::ffi::CString::new(fname_inp).expect("input path contains null bytes");
let c_out = std::ffi::CString::new(fname_out).expect("output path contains null bytes");
let guard = params.to_raw();
let rc =
unsafe { llama_cpp_sys_4::llama_model_quantize(c_inp.as_ptr(), c_out.as_ptr(), &guard.raw) };
if rc == 0 { Ok(()) } else { Err(rc) }
}
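/// Returns llama.cpp's default quantization parameters as a raw sys struct.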
#[must_use]
#[deprecated(since = "0.2.19", note = "use `QuantizeParams::new` instead")]
pub fn model_quantize_default_params() -> llama_cpp_sys_4::llama_model_quantize_params {
unsafe { llama_cpp_sys_4::llama_model_quantize_default_params() }
}
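/// Installs a log callback for llama.cpp / ggml logging.
///
/// # Safety
///
/// `callback` must be a valid log callback (or `None`), and `user_data` must remain valid for
/// as long as the callback can be invoked.
///
/// # Examples
///
/// A minimal sketch that silences all logging. The `llama_cpp_4` crate path and the exact
/// `ggml_log_callback` signature shown here are assumptions.
///
/// ```ignore
/// unsafe extern "C" fn silent(
///     _level: llama_cpp_sys_4::ggml_log_level,
///     _text: *const std::os::raw::c_char,
///     _user_data: *mut std::ffi::c_void,
/// ) {}
///
/// unsafe { llama_cpp_4::log_set(Some(silent), std::ptr::null_mut()) };
/// ```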
pub unsafe fn log_set(
callback: llama_cpp_sys_4::ggml_log_callback,
user_data: *mut std::ffi::c_void,
) {
llama_cpp_sys_4::llama_log_set(callback, user_data);
}
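/// Retrieves the currently installed log callback and its user data.
///
/// # Safety
///
/// Both pointers must be valid, writable locations for llama.cpp to store the callback and
/// user data into.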
pub unsafe fn log_get(
log_callback: *mut llama_cpp_sys_4::ggml_log_callback,
user_data: *mut *mut std::ffi::c_void,
) {
llama_cpp_sys_4::llama_log_get(log_callback, user_data);
}
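/// Thin wrapper over llama.cpp's `llama_opt_init`, which initializes optimization (training)
/// state for the given context and model.
///
/// # Safety
///
/// `ctx` and `model` must be valid, non-null pointers obtained from llama.cpp.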
pub unsafe fn opt_init(
ctx: *mut llama_cpp_sys_4::llama_context,
model: *mut llama_cpp_sys_4::llama_model,
params: llama_cpp_sys_4::llama_opt_params,
) {
llama_cpp_sys_4::llama_opt_init(ctx, model, params);
}
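/// Thin wrapper over llama.cpp's `llama_opt_epoch`, which runs one optimization (training)
/// epoch over `dataset`.
///
/// # Safety
///
/// All pointers and handles must be valid objects obtained from llama.cpp/ggml, and the
/// callbacks, if provided, must be safe to call with the arguments llama.cpp passes them.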
#[allow(clippy::too_many_arguments)]
pub unsafe fn opt_epoch(
ctx: *mut llama_cpp_sys_4::llama_context,
dataset: llama_cpp_sys_4::ggml_opt_dataset_t,
result_train: llama_cpp_sys_4::ggml_opt_result_t,
result_eval: llama_cpp_sys_4::ggml_opt_result_t,
idata_split: i64,
callback_train: llama_cpp_sys_4::ggml_opt_epoch_callback,
callback_eval: llama_cpp_sys_4::ggml_opt_epoch_callback,
) {
llama_cpp_sys_4::llama_opt_epoch(
ctx,
dataset,
result_train,
result_eval,
idata_split,
callback_train,
callback_eval,
);
}
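/// Thin wrapper over llama.cpp's `llama_opt_param_filter_all`, a parameter filter that accepts
/// every tensor.
///
/// # Safety
///
/// `tensor` must be a valid pointer to a ggml tensor; `userdata` is forwarded to llama.cpp
/// unchanged.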
pub unsafe fn opt_param_filter_all(
tensor: *const llama_cpp_sys_4::ggml_tensor,
userdata: *mut std::ffi::c_void,
) -> bool {
llama_cpp_sys_4::llama_opt_param_filter_all(tensor, userdata)
}
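/// Thin wrapper over llama.cpp's `llama_params_fit`; see the upstream documentation for the
/// exact semantics of each argument.
///
/// # Safety
///
/// All pointers must be valid for the duration of the call, and `path_model` must point to a
/// null-terminated string.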
#[allow(clippy::too_many_arguments)]
pub unsafe fn params_fit(
path_model: *const std::ffi::c_char,
mparams: *mut llama_cpp_sys_4::llama_model_params,
cparams: *mut llama_cpp_sys_4::llama_context_params,
tensor_split: *mut f32,
tensor_buft_overrides: *mut llama_cpp_sys_4::llama_model_tensor_buft_override,
margins: *mut usize,
n_ctx_min: u32,
log_level: llama_cpp_sys_4::ggml_log_level,
) -> llama_cpp_sys_4::llama_params_fit_status {
llama_cpp_sys_4::llama_params_fit(
path_model,
mparams,
cparams,
tensor_split,
tensor_buft_overrides,
margins,
n_ctx_min,
log_level,
)
}