Skip to main content

llama_cpp_4/
lib.rs

1//! Bindings to the llama.cpp library.
2//!
3//! As llama.cpp is a very fast moving target, this crate does not attempt to create a stable API
4//! with all the rust idioms. Instead it provides safe wrappers around nearly direct bindings to
5//! llama.cpp. This makes it easier to keep up with the changes in llama.cpp, but does mean that
6//! the API is not as nice as it could be.
7//!
8//! # Examples
9//!
10//! - [simple](https://github.com/eugenehp/llama-cpp-rs/tree/main/examples/simple)
11//! - [chat](https://github.com/eugenehp/llama-cpp-rs/tree/main/examples/chat)
12//! - [embeddings](https://github.com/eugenehp/llama-cpp-rs/tree/main/examples/embeddings)
13//! - [server](https://github.com/eugenehp/llama-cpp-rs/tree/main/examples/server)
14//!
15//! # Feature Flags
16//!
17//! - `cuda` enables CUDA GPU support.
18//! - `metal` enables Apple Metal GPU support.
19//! - `vulkan` enables Vulkan GPU support (AMD / Intel / cross-platform).
20//! - `native` enables host-CPU optimisations (`-march=native`).
21//! - `openmp` enables OpenMP multi-core CPU parallelism (on by default).
22//! - `rpc` enables RPC backend support for distributed inference across multiple machines.
23//! - `mtmd` enables multimodal (image + audio) support via `libmtmd`.
24use std::ffi::NulError;
25use std::fmt::Debug;
26use std::num::NonZeroI32;
27
28use crate::llama_batch::BatchAddError;
29use std::os::raw::c_int;
30use std::path::PathBuf;
31use std::string::FromUtf8Error;
32
// Public modules exposing safe wrappers over the corresponding llama.cpp subsystems.
pub mod common;
pub mod context;
#[cfg(feature = "ggml")]
pub mod ggml;
pub mod llama_backend;
pub mod llama_batch;
pub mod model;
pub mod sampling;
pub mod token;
pub mod token_type;

// RPC backend for distributed inference; gated behind the `rpc` feature.
#[cfg(feature = "rpc")]
pub mod rpc;

// Multimodal (image + audio) support via `libmtmd`; gated behind the `mtmd` feature.
#[cfg(feature = "mtmd")]
pub mod mtmd;
49
/// A fallible result from a llama.cpp function.
pub type Result<T> = std::result::Result<T, LLamaCppError>;
52
53/// All errors that can occur in the llama-cpp crate.
54#[derive(Debug, Eq, PartialEq, thiserror::Error)]
55pub enum LLamaCppError {
56    /// The backend was already initialized. This can generally be ignored as initializing the backend
57    /// is idempotent.
58    #[error("BackendAlreadyInitialized")]
59    BackendAlreadyInitialized,
60    /// There was an error while get the chat template from model.
61    #[error("{0}")]
62    ChatTemplateError(#[from] ChatTemplateError),
63    /// There was an error while decoding a batch.
64    #[error("{0}")]
65    DecodeError(#[from] DecodeError),
66    /// There was an error while encoding a batch.
67    #[error("{0}")]
68    EncodeError(#[from] EncodeError),
69    /// There was an error loading a model.
70    #[error("{0}")]
71    LlamaModelLoadError(#[from] LlamaModelLoadError),
72    /// There was an error creating a new model context.
73    #[error("{0}")]
74    LlamaContextLoadError(#[from] LlamaContextLoadError),
75    /// There was an error adding a token to a batch.
76    #[error["{0}"]]
77    BatchAddError(#[from] BatchAddError),
78    /// see [`EmbeddingsError`]
79    #[error(transparent)]
80    EmbeddingError(#[from] EmbeddingsError),
81}
82
/// There was an error while getting the chat template from a model.
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum ChatTemplateError {
    /// The buffer was too small; the payload is the buffer size that would be just large enough.
    #[error("The buffer was too small. However, a buffer size of {0} would be just large enough.")]
    BuffSizeError(usize),
    /// The GGUF model has no chat template; the payload is the raw return code from llama.cpp.
    #[error("the model has no meta val - returned code {0}")]
    MissingTemplate(i32),
    /// The chat template was not valid utf8.
    #[error(transparent)]
    Utf8Error(#[from] std::str::Utf8Error),
}
96
/// Error retrieving a string from the model (e.g. description, metadata key/value).
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum StringFromModelError {
    /// The C function returned a negative error code.
    #[error("llama.cpp returned error code {0}")]
    ReturnedError(i32),
    /// The returned bytes were not valid UTF-8.
    #[error(transparent)]
    Utf8Error(#[from] std::str::Utf8Error),
}
107
/// Failed to load (create) a model context.
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum LlamaContextLoadError {
    /// llama.cpp returned a null pointer when asked to create the context.
    #[error("null reference from llama.cpp")]
    NullReturn,
}
115
/// Failed to decode a batch.
///
/// The variants mirror llama.cpp's `llama_decode` return codes (see [`From<NonZeroI32>`]).
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum DecodeError {
    /// No kv cache slot was available (return code 1).
    #[error("Decode Error 1: NoKvCacheSlot")]
    NoKvCacheSlot,
    /// The number of tokens in the batch was 0 (return code -1).
    #[error("Decode Error -1: n_tokens == 0")]
    NTokensZero,
    /// An unknown error occurred; the payload is the raw return code.
    #[error("Decode Error {0}: unknown")]
    Unknown(c_int),
}
129
/// Failed to encode a batch.
///
/// (Header previously said "decode" — copy-paste from [`DecodeError`].)
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum EncodeError {
    /// No kv cache slot was available (return code 1).
    #[error("Encode Error 1: NoKvCacheSlot")]
    NoKvCacheSlot,
    /// The number of tokens in the batch was 0 (return code -1).
    #[error("Encode Error -1: n_tokens == 0")]
    NTokensZero,
    /// An unknown error occurred; the payload is the raw return code.
    #[error("Encode Error {0}: unknown")]
    Unknown(c_int),
}
143
/// When embedding related functions fail
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum EmbeddingsError {
    /// Embeddings weren't enabled in the context options
    #[error("Embeddings weren't enabled in the context options")]
    NotEnabled,
    /// Logits weren't enabled for the given token
    #[error("Logits were not enabled for the given token")]
    LogitsNotEnabled,
    /// Sequence embeddings were requested but the pooling type is `LLAMA_POOLING_TYPE_NONE`,
    /// which does not support them. (Previous comment about "sequence index exceeds the max
    /// sequence id" was stale and did not match this variant.)
    #[error("Can't use sequence embeddings with a model supporting only LLAMA_POOLING_TYPE_NONE")]
    NonePoolType,
}
157
158/// Decode a error from llama.cpp into a [`DecodeError`].
159impl From<NonZeroI32> for DecodeError {
160    fn from(value: NonZeroI32) -> Self {
161        match value.get() {
162            1 => DecodeError::NoKvCacheSlot,
163            -1 => DecodeError::NTokensZero,
164            i => DecodeError::Unknown(i),
165        }
166    }
167}
168
169/// Encode a error from llama.cpp into a [`EncodeError`].
170impl From<NonZeroI32> for EncodeError {
171    fn from(value: NonZeroI32) -> Self {
172        match value.get() {
173            1 => EncodeError::NoKvCacheSlot,
174            -1 => EncodeError::NTokensZero,
175            i => EncodeError::Unknown(i),
176        }
177    }
178}
179
/// An error that can occur when loading a model.
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum LlamaModelLoadError {
    /// There was a null byte in a provided string and thus it could not be converted to a C string.
    #[error("null byte in string {0}")]
    NullError(#[from] NulError),
    /// llama.cpp returned a nullptr - this could be many different causes.
    #[error("null result from llama cpp")]
    NullResult,
    /// Failed to convert the path to a rust str. This means the path was not valid unicode
    #[error("failed to convert path {0} to str")]
    PathToStrError(PathBuf),
}
193
/// An error that can occur when initializing a LoRA adapter.
/// (Header previously said "when loading a model" — copy-paste from [`LlamaModelLoadError`].)
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum LlamaLoraAdapterInitError {
    /// There was a null byte in a provided string and thus it could not be converted to a C string.
    #[error("null byte in string {0}")]
    NullError(#[from] NulError),
    /// llama.cpp returned a nullptr - this could be many different causes.
    #[error("null result from llama cpp")]
    NullResult,
    /// Failed to convert the path to a rust str. This means the path was not valid unicode
    #[error("failed to convert path {0} to str")]
    PathToStrError(PathBuf),
}
207
/// An error that can occur when setting (applying) a LoRA adapter.
/// (Header previously said "when loading a model" — copy-paste.)
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum LlamaLoraAdapterSetError {
    /// llama.cpp returned a non-zero error code; the payload is that code.
    #[error("error code from llama cpp")]
    ErrorResult(i32),
}
215
/// An error that can occur when removing a LoRA adapter.
/// (Header previously said "when loading a model" — copy-paste.)
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum LlamaLoraAdapterRemoveError {
    /// llama.cpp returned a non-zero error code; the payload is that code.
    #[error("error code from llama cpp")]
    ErrorResult(i32),
}
223
/// Get the time (in microseconds) according to llama.cpp.
///
/// Thin wrapper over the `llama_time_us` FFI call.
/// ```
/// # use llama_cpp_4::llama_time_us;
/// let time = llama_time_us();
/// assert!(time > 0);
/// ```
#[must_use]
pub fn llama_time_us() -> i64 {
    unsafe { llama_cpp_sys_4::llama_time_us() }
}
234
/// Get the max number of devices according to llama.cpp (this is generally cuda devices).
///
/// Thin wrapper over the `llama_max_devices` FFI call.
/// ```
/// # use llama_cpp_4::max_devices;
/// let max_devices = max_devices();
/// // NOTE(review): previous example asserted `max_devices >= 0`, which is a
/// // tautology for `usize` (clippy: useless comparison).
/// println!("max devices: {max_devices}");
/// ```
#[must_use]
pub fn max_devices() -> usize {
    unsafe { llama_cpp_sys_4::llama_max_devices() }
}
245
/// Is memory mapping supported according to llama.cpp.
///
/// Thin wrapper over the `llama_supports_mmap` FFI call.
/// ```
/// # use llama_cpp_4::mmap_supported;
/// let mmap_supported = mmap_supported();
/// if mmap_supported {
///   println!("mmap_supported!");
/// }
/// ```
#[must_use]
pub fn mmap_supported() -> bool {
    unsafe { llama_cpp_sys_4::llama_supports_mmap() }
}
258
/// Is memory locking supported according to llama.cpp.
///
/// Thin wrapper over the `llama_supports_mlock` FFI call.
/// NOTE(review): [`llama_supports_mlock`] below wraps the same FFI function —
/// consider deprecating one of the two in a future release.
/// ```
/// # use llama_cpp_4::mlock_supported;
/// let mlock_supported = mlock_supported();
/// if mlock_supported {
///    println!("mlock_supported!");
/// }
/// ```
#[must_use]
pub fn mlock_supported() -> bool {
    unsafe { llama_cpp_sys_4::llama_supports_mlock() }
}
271
/// An error that can occur when converting a token to a string.
#[derive(Debug, thiserror::Error, Clone)]
#[non_exhaustive]
pub enum TokenToStringError {
    /// the token type was unknown
    #[error("Unknown Token Type")]
    UnknownTokenType,
    /// There was insufficient buffer space to convert the token to a string;
    /// the payload is the (negative) size reported by llama.cpp.
    #[error("Insufficient Buffer Space {0}")]
    InsufficientBufferSpace(c_int),
    /// The token was not valid utf8.
    #[error("FromUtf8Error {0}")]
    FromUtf8Error(#[from] FromUtf8Error),
}
286
287/// Failed to convert a string to a token sequence.
288#[derive(Debug, thiserror::Error)]
289pub enum StringToTokenError {
290    /// the string contained a null byte and thus could not be converted to a c string.
291    #[error("{0}")]
292    NulError(#[from] NulError),
293    #[error("{0}")]
294    /// Failed to convert a provided integer to a [`c_int`].
295    CIntConversionError(#[from] std::num::TryFromIntError),
296}
297
/// Failed to create a new chat message.
/// (Header previously said "Failed to apply model chat template." — copy-paste
/// from [`ApplyChatTemplateError`].)
#[derive(Debug, thiserror::Error)]
pub enum NewLlamaChatMessageError {
    /// the string contained a null byte and thus could not be converted to a c string.
    #[error("{0}")]
    NulError(#[from] NulError),
}
305
/// Failed to apply model chat template.
#[derive(Debug, thiserror::Error)]
pub enum ApplyChatTemplateError {
    /// the buffer was too small.
    #[error("The buffer was too small. Please contact a maintainer and we will update it.")]
    BuffSizeError,
    /// the string contained a null byte and thus could not be converted to a c string.
    #[error("{0}")]
    NulError(#[from] NulError),
    /// the string could not be converted to utf8.
    #[error("{0}")]
    FromUtf8Error(#[from] FromUtf8Error),
}
319
/// Get the time in microseconds according to ggml
///
/// ```
/// # use std::time::Duration;
/// use llama_cpp_4::ggml_time_us;
///
/// let start = ggml_time_us();
///
/// std::thread::sleep(Duration::from_micros(10));
///
/// let end = ggml_time_us();
///
/// let elapsed = end - start;
///
/// assert!(elapsed >= 10)
/// ```
#[must_use]
pub fn ggml_time_us() -> i64 {
    unsafe { llama_cpp_sys_4::ggml_time_us() }
}
339
/// Checks if mlock is supported.
///
/// NOTE(review): duplicates [`mlock_supported`] — both wrap the same
/// `llama_supports_mlock` FFI call; consider deprecating one.
///
/// ```
/// # use llama_cpp_4::llama_supports_mlock;
///
/// if llama_supports_mlock() {
///   println!("mlock is supported!");
/// } else {
///   println!("mlock is not supported!");
/// }
/// ```
#[must_use]
pub fn llama_supports_mlock() -> bool {
    unsafe { llama_cpp_sys_4::llama_supports_mlock() }
}
355
/// Checks if GPU offload is supported.
///
/// Returns `true` if the library was compiled with GPU support (CUDA, Metal, Vulkan, etc.).
/// Thin wrapper over the `llama_supports_gpu_offload` FFI call.
#[must_use]
pub fn supports_gpu_offload() -> bool {
    unsafe { llama_cpp_sys_4::llama_supports_gpu_offload() }
}
363
/// Checks if RPC backend is supported.
///
/// Returns `true` if the library was compiled with RPC support.
/// Thin wrapper over the `llama_supports_rpc` FFI call.
#[must_use]
pub fn supports_rpc() -> bool {
    unsafe { llama_cpp_sys_4::llama_supports_rpc() }
}
371
372/// Get system information string.
373///
374/// Returns a string containing CPU features, build info, and other system details.
375///
376/// # Panics
377///
378/// Panics if the returned string is not valid UTF-8.
379#[must_use]
380pub fn print_system_info() -> String {
381    let c_str = unsafe { llama_cpp_sys_4::llama_print_system_info() };
382    let c_str = unsafe { std::ffi::CStr::from_ptr(c_str) };
383    c_str.to_str().expect("system info is not valid UTF-8").to_owned()
384}
385
/// Get the maximum number of parallel sequences supported.
///
/// Thin wrapper over the `llama_max_parallel_sequences` FFI call.
#[must_use]
pub fn max_parallel_sequences() -> usize {
    unsafe { llama_cpp_sys_4::llama_max_parallel_sequences() }
}
391
/// Get the maximum number of tensor buffer type overrides.
///
/// Thin wrapper over the `llama_max_tensor_buft_overrides` FFI call.
#[must_use]
pub fn max_tensor_buft_overrides() -> usize {
    unsafe { llama_cpp_sys_4::llama_max_tensor_buft_overrides() }
}
397
398/// Get the name of a flash attention type.
399///
400/// # Panics
401///
402/// Panics if the returned string is not valid UTF-8.
403#[must_use]
404pub fn flash_attn_type_name(flash_attn_type: i32) -> String {
405    let c_str = unsafe { llama_cpp_sys_4::llama_flash_attn_type_name(flash_attn_type) };
406    let c_str = unsafe { std::ffi::CStr::from_ptr(c_str) };
407    c_str.to_str().expect("flash_attn_type_name is not valid UTF-8").to_owned()
408}
409
/// Get the string representation of a model metadata key.
///
/// # Panics
///
/// Panics if the returned string is not valid UTF-8, or if `key` does not fit
/// in a C `int` (the `key.try_into().unwrap()` below).
#[must_use]
pub fn model_meta_key_str(key: u32) -> String {
    let c_str = unsafe { llama_cpp_sys_4::llama_model_meta_key_str(key.try_into().unwrap()) };
    let c_str = unsafe { std::ffi::CStr::from_ptr(c_str) };
    c_str.to_str().expect("meta_key_str is not valid UTF-8").to_owned()
}
421
422/// Quantize a model file.
423///
424/// # Parameters
425///
426/// - `fname_inp`: Path to the input model file.
427/// - `fname_out`: Path to the output quantized model file.
428/// - `params`: Quantization parameters. Use `None` for defaults.
429///
430/// # Returns
431///
432/// Returns 0 on success, non-zero on failure.
433///
434/// # Panics
435///
436/// Panics if the paths contain null bytes.
437#[must_use]
438pub fn model_quantize(
439    fname_inp: &str,
440    fname_out: &str,
441    params: Option<&llama_cpp_sys_4::llama_model_quantize_params>,
442) -> u32 {
443    let c_inp = std::ffi::CString::new(fname_inp).expect("input path contains null bytes");
444    let c_out = std::ffi::CString::new(fname_out).expect("output path contains null bytes");
445    let default_params = unsafe { llama_cpp_sys_4::llama_model_quantize_default_params() };
446    let params = params.unwrap_or(&default_params);
447    unsafe { llama_cpp_sys_4::llama_model_quantize(c_inp.as_ptr(), c_out.as_ptr(), params) }
448}
449
/// Get default quantization parameters.
///
/// Thin wrapper over the `llama_model_quantize_default_params` FFI call.
#[must_use]
pub fn model_quantize_default_params() -> llama_cpp_sys_4::llama_model_quantize_params {
    unsafe { llama_cpp_sys_4::llama_model_quantize_default_params() }
}
455
/// Set the log callback.
///
/// Passing a `None` callback presumably restores llama.cpp's default logging —
/// TODO confirm against the llama.cpp headers.
///
/// # Safety
///
/// The callback and user data must remain valid for the lifetime of the application
/// or until the callback is replaced.
pub unsafe fn log_set(
    callback: llama_cpp_sys_4::ggml_log_callback,
    user_data: *mut std::ffi::c_void,
) {
    llama_cpp_sys_4::llama_log_set(callback, user_data);
}
468
/// Get the current log callback and user data.
///
/// The current values are written through the provided out-pointers.
///
/// # Safety
///
/// The caller must ensure the pointers are valid.
pub unsafe fn log_get(
    log_callback: *mut llama_cpp_sys_4::ggml_log_callback,
    user_data: *mut *mut std::ffi::c_void,
) {
    llama_cpp_sys_4::llama_log_get(log_callback, user_data);
}
480
/// Initialize optimizer state for fine-tuning.
///
/// Thin wrapper over the `llama_opt_init` FFI call.
///
/// # Safety
///
/// The context and model must be valid and compatible.
pub unsafe fn opt_init(
    ctx: *mut llama_cpp_sys_4::llama_context,
    model: *mut llama_cpp_sys_4::llama_model,
    params: llama_cpp_sys_4::llama_opt_params,
) {
    llama_cpp_sys_4::llama_opt_init(ctx, model, params);
}
493
/// Run one training epoch.
///
/// Direct pass-through to `llama_opt_epoch`; argument order matches the C API.
/// `idata_split` presumably selects where the dataset is split between the
/// training and evaluation portions — TODO confirm against the llama.cpp headers.
///
/// # Safety
///
/// All pointers and handles must be valid.
#[allow(clippy::too_many_arguments)]
pub unsafe fn opt_epoch(
    ctx: *mut llama_cpp_sys_4::llama_context,
    dataset: llama_cpp_sys_4::ggml_opt_dataset_t,
    result_train: llama_cpp_sys_4::ggml_opt_result_t,
    result_eval: llama_cpp_sys_4::ggml_opt_result_t,
    idata_split: i64,
    callback_train: llama_cpp_sys_4::ggml_opt_epoch_callback,
    callback_eval: llama_cpp_sys_4::ggml_opt_epoch_callback,
) {
    llama_cpp_sys_4::llama_opt_epoch(
        ctx,
        dataset,
        result_train,
        result_eval,
        idata_split,
        callback_train,
        callback_eval,
    );
}
519
/// Parameter filter that accepts all tensors (for use with [`opt_init`]).
///
/// Thin wrapper over the `llama_opt_param_filter_all` FFI call.
///
/// # Safety
///
/// The tensor pointer must be valid.
pub unsafe fn opt_param_filter_all(
    tensor: *const llama_cpp_sys_4::ggml_tensor,
    userdata: *mut std::ffi::c_void,
) -> bool {
    llama_cpp_sys_4::llama_opt_param_filter_all(tensor, userdata)
}
531
/// Auto-fit model and context parameters for available memory.
///
/// Direct pass-through to `llama_params_fit`; `mparams`/`cparams` are in-out
/// pointers the C side may adjust. The exact semantics of `tensor_split`,
/// `margins`, and `n_ctx_min` are defined by llama.cpp — consult its headers.
///
/// # Safety
///
/// All pointers must be valid.
#[allow(clippy::too_many_arguments)]
pub unsafe fn params_fit(
    path_model: *const std::ffi::c_char,
    mparams: *mut llama_cpp_sys_4::llama_model_params,
    cparams: *mut llama_cpp_sys_4::llama_context_params,
    tensor_split: *mut f32,
    tensor_buft_overrides: *mut llama_cpp_sys_4::llama_model_tensor_buft_override,
    margins: *mut usize,
    n_ctx_min: u32,
    log_level: llama_cpp_sys_4::ggml_log_level,
) -> llama_cpp_sys_4::llama_params_fit_status {
    llama_cpp_sys_4::llama_params_fit(
        path_model,
        mparams,
        cparams,
        tensor_split,
        tensor_buft_overrides,
        margins,
        n_ctx_min,
        log_level,
    )
}