Skip to main content

llama_cpp_2/
lib.rs

1//! Bindings to the llama.cpp library.
2//!
3//! As llama.cpp is a very fast moving target, this crate does not attempt to create a stable API
//! with all the rust idioms. Instead it provides safe wrappers around nearly direct bindings to
5//! llama.cpp. This makes it easier to keep up with the changes in llama.cpp, but does mean that
6//! the API is not as nice as it could be.
7//!
8//! # Examples
9//!
10//! - [simple](https://github.com/utilityai/llama-cpp-rs/tree/main/examples/simple)
11//! - [tools](https://github.com/utilityai/llama-cpp-rs/tree/main/examples/tools)
12//!
13//! # Feature Flags
14//!
15//! - `cuda` enables CUDA gpu support.
16//! - `sampler` adds the [`context::sample::sampler`] struct for a more rusty way of sampling.
17use std::ffi::{c_char, CStr, CString, NulError};
18use std::fmt::Debug;
19use std::num::NonZeroI32;
20
21use crate::llama_batch::BatchAddError;
22use std::os::raw::c_int;
23use std::path::PathBuf;
24use std::string::FromUtf8Error;
25
26pub mod context;
27pub mod gguf;
28pub mod llama_backend;
29pub mod llama_batch;
30#[cfg(feature = "llguidance")]
31pub(crate) mod llguidance_sampler;
32mod log;
33pub mod model;
34#[cfg(feature = "mtmd")]
35pub mod mtmd;
36pub mod sampling;
37pub mod timing;
38pub mod token;
39pub mod token_type;
40
41pub use crate::context::session::LlamaStateSeqFlags;
42
/// Returns `true` when a status code reported by one of the `llama_rs_*` shim functions is
/// `LLAMA_RS_STATUS_OK`.
#[cfg(feature = "common")]
pub(crate) fn status_is_ok(status: llama_cpp_sys_2::llama_rs_status) -> bool {
    llama_cpp_sys_2::LLAMA_RS_STATUS_OK == status
}
47
48/// A failable result from a llama.cpp function.
49pub type Result<T> = std::result::Result<T, LlamaCppError>;
50
51/// All errors that can occur in the llama-cpp crate.
52#[derive(Debug, Eq, PartialEq, thiserror::Error)]
53pub enum LlamaCppError {
54    /// The backend was already initialized. This can generally be ignored as initializing the backend
55    /// is idempotent.
56    #[error("BackendAlreadyInitialized")]
57    BackendAlreadyInitialized,
58    /// There was an error while get the chat template from model.
59    #[error("{0}")]
60    ChatTemplateError(#[from] ChatTemplateError),
61    /// There was an error while decoding a batch.
62    #[error("{0}")]
63    DecodeError(#[from] DecodeError),
64    /// There was an error while encoding a batch.
65    #[error("{0}")]
66    EncodeError(#[from] EncodeError),
67    /// There was an error loading a model.
68    #[error("{0}")]
69    LlamaModelLoadError(#[from] LlamaModelLoadError),
70    /// There was an error creating a new model context.
71    #[error("{0}")]
72    LlamaContextLoadError(#[from] LlamaContextLoadError),
73    /// There was an error adding a token to a batch.
74    #[error["{0}"]]
75    BatchAddError(#[from] BatchAddError),
76    /// see [`EmbeddingsError`]
77    #[error(transparent)]
78    EmbeddingError(#[from] EmbeddingsError),
79    // See [`LlamaSamplerError`]
80    /// Backend device not found
81    #[error("Backend device {0} not found")]
82    BackendDeviceNotFound(usize),
83    /// Max devices exceeded
84    #[error("Max devices exceeded. Max devices is {0}")]
85    MaxDevicesExceeded(usize),
86    /// Failed to convert JSON schema to grammar.
87    #[cfg(feature = "common")]
88    #[error("JsonSchemaToGrammarError: {0}")]
89    JsonSchemaToGrammarError(String),
90    /// There was an error fitting model parameters to available memory.
91    #[cfg(feature = "common")]
92    #[error("{0}")]
93    FitError(#[from] crate::model::params::FitError),
94}
95
96/// There was an error while getting the chat template from a model.
97#[derive(Debug, Eq, PartialEq, thiserror::Error)]
98pub enum ChatTemplateError {
99    /// gguf has no chat template (by that name)
100    #[error("chat template not found - returned null pointer")]
101    MissingTemplate,
102
103    /// chat template contained a null byte
104    #[error("null byte in string {0}")]
105    NullError(#[from] NulError),
106
107    /// The chat template was not valid utf8.
108    #[error(transparent)]
109    Utf8Error(#[from] std::str::Utf8Error),
110}
111
112/// Failed fetching metadata value
113#[derive(Debug, Eq, PartialEq, thiserror::Error)]
114pub enum MetaValError {
115    /// The provided string contains an unexpected null-byte
116    #[error("null byte in string {0}")]
117    NullError(#[from] NulError),
118
119    /// The returned data contains invalid UTF8 data
120    #[error("FromUtf8Error {0}")]
121    FromUtf8Error(#[from] FromUtf8Error),
122
123    /// Got negative return value. This happens if the key or index queried does not exist.
124    #[error("Negative return value. Likely due to a missing index or key. Got return value: {0}")]
125    NegativeReturn(i32),
126}
127
128/// Failed to Load context
129#[derive(Debug, Eq, PartialEq, thiserror::Error)]
130pub enum LlamaContextLoadError {
131    /// llama.cpp returned null
132    #[error("null reference from llama.cpp")]
133    NullReturn,
134}
135
136/// Failed to decode a batch.
137#[derive(Debug, Eq, PartialEq, thiserror::Error)]
138pub enum DecodeError {
139    /// No kv cache slot was available.
140    #[error("Decode Error 1: NoKvCacheSlot")]
141    NoKvCacheSlot,
142    /// The number of tokens in the batch was 0.
143    #[error("Decode Error -1: n_tokens == 0")]
144    NTokensZero,
145    /// An unknown error occurred.
146    #[error("Decode Error {0}: unknown")]
147    Unknown(c_int),
148}
149
150/// Failed to decode a batch.
151#[derive(Debug, Eq, PartialEq, thiserror::Error)]
152pub enum EncodeError {
153    /// No kv cache slot was available.
154    #[error("Encode Error 1: NoKvCacheSlot")]
155    NoKvCacheSlot,
156    /// The number of tokens in the batch was 0.
157    #[error("Encode Error -1: n_tokens == 0")]
158    NTokensZero,
159    /// An unknown error occurred.
160    #[error("Encode Error {0}: unknown")]
161    Unknown(c_int),
162}
163
164/// When embedding related functions fail
165#[derive(Debug, Eq, PartialEq, thiserror::Error)]
166pub enum EmbeddingsError {
167    /// Embeddings weren't enabled in the context options
168    #[error("Embeddings weren't enabled in the context options")]
169    NotEnabled,
170    /// Logits weren't enabled for the given token
171    #[error("Logits were not enabled for the given token")]
172    LogitsNotEnabled,
173    /// The given sequence index exceeds the max sequence id
174    #[error("Can't use sequence embeddings with a model supporting only LLAMA_POOLING_TYPE_NONE")]
175    NonePoolType,
176}
177
178/// Errors that can occur when initializing a grammar sampler
179#[derive(Debug, Eq, PartialEq, thiserror::Error)]
180pub enum GrammarError {
181    /// The grammar root was not found in the grammar string
182    #[error("Grammar root not found in grammar string")]
183    RootNotFound,
184    /// The trigger word contains null bytes
185    #[error("Trigger word contains null bytes")]
186    TriggerWordNullBytes,
187    /// The grammar string or root contains null bytes
188    #[error("Grammar string or root contains null bytes")]
189    GrammarNullBytes,
190    /// The grammar call returned null
191    #[error("Grammar call returned null")]
192    NullGrammar,
193}
194
195/// Decode a error from llama.cpp into a [`DecodeError`].
196impl From<NonZeroI32> for DecodeError {
197    fn from(value: NonZeroI32) -> Self {
198        match value.get() {
199            1 => DecodeError::NoKvCacheSlot,
200            -1 => DecodeError::NTokensZero,
201            i => DecodeError::Unknown(i),
202        }
203    }
204}
205
206/// Encode a error from llama.cpp into a [`EncodeError`].
207impl From<NonZeroI32> for EncodeError {
208    fn from(value: NonZeroI32) -> Self {
209        match value.get() {
210            1 => EncodeError::NoKvCacheSlot,
211            -1 => EncodeError::NTokensZero,
212            i => EncodeError::Unknown(i),
213        }
214    }
215}
216
217/// An error that can occur when loading a model.
218#[derive(Debug, Eq, PartialEq, thiserror::Error)]
219pub enum LlamaModelLoadError {
220    /// There was a null byte in a provided string and thus it could not be converted to a C string.
221    #[error("null byte in string {0}")]
222    NullError(#[from] NulError),
223    /// llama.cpp returned a nullptr - this could be many different causes.
224    #[error("null result from llama cpp")]
225    NullResult,
226    /// Failed to convert the path to a rust str. This means the path was not valid unicode
227    #[error("failed to convert path {0} to str")]
228    PathToStrError(PathBuf),
229}
230
231/// An error that can occur when loading a model.
232#[derive(Debug, Eq, PartialEq, thiserror::Error)]
233pub enum LlamaLoraAdapterInitError {
234    /// There was a null byte in a provided string and thus it could not be converted to a C string.
235    #[error("null byte in string {0}")]
236    NullError(#[from] NulError),
237    /// llama.cpp returned a nullptr - this could be many different causes.
238    #[error("null result from llama cpp")]
239    NullResult,
240    /// Failed to convert the path to a rust str. This means the path was not valid unicode
241    #[error("failed to convert path {0} to str")]
242    PathToStrError(PathBuf),
243}
244
245/// An error that can occur when loading a model.
246#[derive(Debug, Eq, PartialEq, thiserror::Error)]
247pub enum LlamaLoraAdapterSetError {
248    /// llama.cpp returned a non-zero error code.
249    #[error("error code from llama cpp")]
250    ErrorResult(i32),
251}
252
253/// An error that can occur when loading a model.
254#[derive(Debug, Eq, PartialEq, thiserror::Error)]
255pub enum LlamaLoraAdapterRemoveError {
256    /// llama.cpp returned a non-zero error code.
257    #[error("error code from llama cpp")]
258    ErrorResult(i32),
259}
260
261/// get the time (in microseconds) according to llama.cpp
262/// ```
263/// # use llama_cpp_2::llama_time_us;
264/// # use llama_cpp_2::llama_backend::LlamaBackend;
265/// let backend = LlamaBackend::init().unwrap();
266/// let time = llama_time_us();
267/// assert!(time > 0);
268/// ```
269#[must_use]
270pub fn llama_time_us() -> i64 {
271    unsafe { llama_cpp_sys_2::llama_time_us() }
272}
273
274/// get the max number of devices according to llama.cpp (this is generally cuda devices)
275/// ```
276/// # use llama_cpp_2::max_devices;
277/// let max_devices = max_devices();
278/// assert!(max_devices >= 0);
279/// ```
280#[must_use]
281pub fn max_devices() -> usize {
282    unsafe { llama_cpp_sys_2::llama_max_devices() }
283}
284
285/// is memory mapping supported according to llama.cpp
286/// ```
287/// # use llama_cpp_2::mmap_supported;
288/// let mmap_supported = mmap_supported();
289/// if mmap_supported {
290///   println!("mmap_supported!");
291/// }
292/// ```
293#[must_use]
294pub fn mmap_supported() -> bool {
295    unsafe { llama_cpp_sys_2::llama_supports_mmap() }
296}
297
298/// is memory locking supported according to llama.cpp
299/// ```
300/// # use llama_cpp_2::mlock_supported;
301/// let mlock_supported = mlock_supported();
302/// if mlock_supported {
303///    println!("mlock_supported!");
304/// }
305/// ```
306#[must_use]
307pub fn mlock_supported() -> bool {
308    unsafe { llama_cpp_sys_2::llama_supports_mlock() }
309}
310
/// Convert a JSON schema string into a llama.cpp grammar string.
///
/// # Errors
///
/// Returns [`LlamaCppError::JsonSchemaToGrammarError`] when `schema_json` contains an interior
/// NUL byte, when the native conversion reports a non-OK status or produces no output, or when
/// the produced grammar is not valid UTF-8.
#[cfg(feature = "common")]
pub fn json_schema_to_grammar(schema_json: &str) -> Result<String> {
    let schema_cstr = CString::new(schema_json)
        .map_err(|err| LlamaCppError::JsonSchemaToGrammarError(err.to_string()))?;
    let mut out = std::ptr::null_mut();
    // SAFETY: `schema_cstr` is a valid NUL-terminated string that outlives the call, and `out`
    // is a valid location for the shim to store its result pointer.
    let rc = unsafe {
        llama_cpp_sys_2::llama_rs_json_schema_to_grammar(schema_cstr.as_ptr(), false, &mut out)
    };

    // Copy the grammar into Rust-owned memory only when the call succeeded.
    let grammar_bytes = if status_is_ok(rc) && !out.is_null() {
        // SAFETY: on success `out` points to a NUL-terminated string produced by the shim.
        Some(unsafe { CStr::from_ptr(out) }.to_bytes().to_vec())
    } else {
        None
    };

    // Free the shim-owned buffer BEFORE any early return below. Previously, the error return and
    // the `?` on UTF-8 conversion both left the function without reaching the free, leaking `out`.
    if !out.is_null() {
        // SAFETY: `out` came from the shim and is freed exactly once, after its last use above.
        unsafe { llama_cpp_sys_2::llama_rs_string_free(out) };
    }

    let Some(grammar_bytes) = grammar_bytes else {
        return Err(LlamaCppError::JsonSchemaToGrammarError(format!(
            "ffi error {}",
            rc
        )));
    };

    String::from_utf8(grammar_bytes)
        .map_err(|err| LlamaCppError::JsonSchemaToGrammarError(err.to_string()))
}
337
#[cfg(all(test, feature = "common"))]
mod tests {
    use super::json_schema_to_grammar;

    /// A small object schema should convert to a grammar that defines a `root` rule.
    #[test]
    fn json_schema_string_api_returns_grammar() {
        let schema: &str = r#"{
            "type": "object",
            "properties": {
                "city": { "type": "string" },
                "unit": { "enum": ["c", "f"] }
            },
            "required": ["city"]
        }"#;

        let converted = json_schema_to_grammar(schema);
        let grammar = converted.expect("string-based schema conversion should succeed");

        let has_root_rule = grammar.contains("root ::=");
        assert!(has_root_rule);
    }
}
359
360/// An error that can occur when converting a token to a string.
361#[derive(Debug, thiserror::Error, Clone)]
362#[non_exhaustive]
363pub enum TokenToStringError {
364    /// the token type was unknown
365    #[error("Unknown Token Type")]
366    UnknownTokenType,
367    /// There was insufficient buffer space to convert the token to a string.
368    #[error("Insufficient Buffer Space {0}")]
369    InsufficientBufferSpace(c_int),
370    /// The token was not valid utf8.
371    #[error("FromUtf8Error {0}")]
372    FromUtf8Error(#[from] FromUtf8Error),
373}
374
375/// Failed to convert a string to a token sequence.
376#[derive(Debug, thiserror::Error)]
377pub enum StringToTokenError {
378    /// the string contained a null byte and thus could not be converted to a c string.
379    #[error("{0}")]
380    NulError(#[from] NulError),
381    #[error("{0}")]
382    /// Failed to convert a provided integer to a [`c_int`].
383    CIntConversionError(#[from] std::num::TryFromIntError),
384}
385
386/// Failed to apply model chat template.
387#[derive(Debug, thiserror::Error)]
388pub enum NewLlamaChatMessageError {
389    /// the string contained a null byte and thus could not be converted to a c string.
390    #[error("{0}")]
391    NulError(#[from] NulError),
392}
393
394/// Failed to apply model chat template.
395#[derive(Debug, thiserror::Error)]
396pub enum ApplyChatTemplateError {
397    /// the string contained a null byte and thus could not be converted to a c string.
398    #[error("{0}")]
399    NulError(#[from] NulError),
400    /// the string could not be converted to utf8.
401    #[error("{0}")]
402    FromUtf8Error(#[from] FromUtf8Error),
403    /// llama.cpp returned a null pointer for the template result.
404    #[error("null result from llama.cpp")]
405    NullResult,
406    /// llama.cpp returned an error code.
407    #[error("ffi error {0}")]
408    FfiError(i32),
409}
410
411/// Failed to accept a token in a sampler.
412#[derive(Debug, thiserror::Error)]
413pub enum SamplerAcceptError {
414    /// llama.cpp returned an error code.
415    #[error("ffi error {0}")]
416    FfiError(i32),
417}
418
419/// Get the time in microseconds according to ggml
420///
421/// ```
422/// # use std::time::Duration;
423/// # use llama_cpp_2::llama_backend::LlamaBackend;
424/// let backend = LlamaBackend::init().unwrap();
425/// use llama_cpp_2::ggml_time_us;
426///
427/// let start = ggml_time_us();
428///
429/// std::thread::sleep(Duration::from_micros(10));
430///
431/// let end = ggml_time_us();
432///
433/// let elapsed = end - start;
434///
435/// assert!(elapsed >= 10)
436#[must_use]
437pub fn ggml_time_us() -> i64 {
438    unsafe { llama_cpp_sys_2::ggml_time_us() }
439}
440
441/// checks if mlock is supported
442///
443/// ```
444/// # use llama_cpp_2::llama_supports_mlock;
445///
446/// if llama_supports_mlock() {
447///   println!("mlock is supported!");
448/// } else {
449///   println!("mlock is not supported!");
450/// }
451/// ```
452#[must_use]
453pub fn llama_supports_mlock() -> bool {
454    unsafe { llama_cpp_sys_2::llama_supports_mlock() }
455}
456
/// The kind of hardware a ggml backend device represents.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LlamaBackendDeviceType {
    /// A CPU device.
    Cpu,
    /// An accelerator (ACCEL) device.
    Accelerator,
    /// A discrete GPU device.
    Gpu,
    /// An integrated GPU (iGPU) device.
    IntegratedGpu,
    /// ggml reported a device type this crate does not recognise.
    Unknown,
}
471
472/// A ggml backend device
473///
474/// The index is can be used from `LlamaModelParams::with_devices` to select specific devices.
475#[derive(Debug, Clone)]
476pub struct LlamaBackendDevice {
477    /// The index of the device
478    ///
479    /// The index is can be used from `LlamaModelParams::with_devices` to select specific devices.
480    pub index: usize,
481    /// The name of the device (e.g. "Vulkan0")
482    pub name: String,
483    /// A description of the device (e.g. "NVIDIA GeForce RTX 3080")
484    pub description: String,
485    /// The backend of the device (e.g. "Vulkan", "CUDA", "CPU")
486    pub backend: String,
487    /// Total memory of the device in bytes
488    pub memory_total: usize,
489    /// Free memory of the device in bytes
490    pub memory_free: usize,
491    /// Device type
492    pub device_type: LlamaBackendDeviceType,
493}
494
495/// List ggml backend devices
496#[must_use]
497pub fn list_llama_ggml_backend_devices() -> Vec<LlamaBackendDevice> {
498    let mut devices = Vec::new();
499    for i in 0..unsafe { llama_cpp_sys_2::ggml_backend_dev_count() } {
500        fn cstr_to_string(ptr: *const c_char) -> String {
501            if ptr.is_null() {
502                String::new()
503            } else {
504                unsafe { std::ffi::CStr::from_ptr(ptr) }
505                    .to_string_lossy()
506                    .to_string()
507            }
508        }
509        let dev = unsafe { llama_cpp_sys_2::ggml_backend_dev_get(i) };
510        let props = unsafe {
511            let mut props = std::mem::zeroed();
512            llama_cpp_sys_2::ggml_backend_dev_get_props(dev, &raw mut props);
513            props
514        };
515        let name = cstr_to_string(props.name);
516        let description = cstr_to_string(props.description);
517        let backend = unsafe { llama_cpp_sys_2::ggml_backend_dev_backend_reg(dev) };
518        let backend_name = unsafe { llama_cpp_sys_2::ggml_backend_reg_name(backend) };
519        let backend = cstr_to_string(backend_name);
520        let memory_total = props.memory_total;
521        let memory_free = props.memory_free;
522        let device_type = match props.type_ {
523            llama_cpp_sys_2::GGML_BACKEND_DEVICE_TYPE_CPU => LlamaBackendDeviceType::Cpu,
524            llama_cpp_sys_2::GGML_BACKEND_DEVICE_TYPE_ACCEL => LlamaBackendDeviceType::Accelerator,
525            llama_cpp_sys_2::GGML_BACKEND_DEVICE_TYPE_GPU => LlamaBackendDeviceType::Gpu,
526            llama_cpp_sys_2::GGML_BACKEND_DEVICE_TYPE_IGPU => LlamaBackendDeviceType::IntegratedGpu,
527            _ => LlamaBackendDeviceType::Unknown,
528        };
529        devices.push(LlamaBackendDevice {
530            index: i,
531            name,
532            description,
533            backend,
534            memory_total,
535            memory_free,
536            device_type,
537        });
538    }
539    devices
540}
541
/// Settings that control how llama.cpp log output is intercepted.
#[derive(Clone, Debug, Default)]
pub struct LogOptions {
    // `true` suppresses every log; the derived default (`false`) leaves logging on.
    disabled: bool,
}

impl LogOptions {
    /// Turns log forwarding on (`true`) or off (`false`). When on, logs go to tracing; when off,
    /// every log is dropped. Logs are forwarded to tracing by default.
    #[must_use]
    pub fn with_logs_enabled(self, enabled: bool) -> Self {
        Self { disabled: !enabled }
    }
}
557
/// C log callback registered through `llama_log_set` / `ggml_log_set` in `send_logs_to_tracing`.
/// It forwards llama.cpp and ggml log messages to the `log::State` passed as `data`.
///
/// * `level` - the ggml log level of this message (`GGML_LOG_LEVEL_CONT` marks a continuation).
/// * `text` - the message; assumed to be a non-null, NUL-terminated C string (not checked here).
/// * `data` - the user-data pointer registered with the callback. It must point to a `log::State`
///   obtained from one of the `log::*_STATE` statics.
extern "C" fn logs_to_trace(
    level: llama_cpp_sys_2::ggml_log_level,
    text: *const ::std::os::raw::c_char,
    data: *mut ::std::os::raw::c_void,
) {
    // In the "fast-path" (i.e. the vast majority of logs) we want to avoid needing to take the log state
    // lock at all. Similarly, we try to avoid any heap allocations within this function. This is accomplished
    // by being a dummy pass-through to tracing in the normal case of DEBUG/INFO/WARN/ERROR logs that are
    // newline terminated and limiting the slow-path of locks and/or heap allocations for other cases.
    use std::borrow::Borrow;

    // SAFETY: `data` is the `log::State` pointer registered alongside this callback in
    // `send_logs_to_tracing`; that state is held by a static, so the reference stays valid.
    let log_state = unsafe { &*(data as *const log::State) };

    // Logging was switched off via `LogOptions::with_logs_enabled(false)`: drop everything.
    if log_state.options.disabled {
        return;
    }

    // If the log level is disabled, we can just return early
    // NOTE(review): the level is still recorded for disabled logs, presumably so that a following
    // CONT message is handled consistently with its (suppressed) parent - confirm in `log`.
    if !log_state.is_enabled_for_level(level) {
        log_state.update_previous_level_for_disabled_log(level);
        return;
    }

    // SAFETY: assumes llama.cpp/ggml always pass a valid NUL-terminated string; null is not checked.
    let text = unsafe { std::ffi::CStr::from_ptr(text) };
    // Lossy conversion: borrows when the text is valid UTF-8, allocates only otherwise.
    let text = text.to_string_lossy();
    let text: &str = text.borrow();

    // As best I can tell llama.cpp / ggml require all log format strings at call sites to have the '\n'.
    // If it's missing, it means that you expect more logs via CONT (or there's a typo in the codebase). To
    // distinguish typo from intentional support for CONT, we have to buffer until the next message comes in
    // to know how to flush it.

    // Three cases: a continuation of a buffered line, a complete line, or a partial line to buffer.
    if level == llama_cpp_sys_2::GGML_LOG_LEVEL_CONT {
        log_state.cont_buffered_log(text);
    } else if text.ends_with('\n') {
        log_state.emit_non_cont_line(level, text);
    } else {
        log_state.buffer_non_cont(level, text);
    }
}
598
/// Redirect llama.cpp logs into tracing.
///
/// Installs `logs_to_trace` as the log callback for both llama.cpp and ggml, each with its own
/// `log::State`. The states are created with `get_or_init`, so only the `options` passed on the
/// first call take effect; later calls re-register the callbacks with the already-initialized
/// states and ignore the new `options`.
pub fn send_logs_to_tracing(options: LogOptions) {
    // TODO: Reinitialize the state to support calling send_logs_to_tracing multiple times.

    // We set up separate log states for llama.cpp and ggml to make sure that CONT logs between the two
    // can't possibly interfere with each other. In other words, if llama.cpp emits a log without a trailing
    // newline and then calls a GGML function, the logs won't be weirdly intermixed; instead, llama.cpp logs
    // will CONT previous llama.cpp logs and GGML logs will CONT previous ggml logs.
    // The states live in boxes held by statics, so the raw pointers below stay valid for as long as
    // the callbacks can be invoked.
    let llama_heap_state = Box::as_ref(
        log::LLAMA_STATE
            .get_or_init(|| Box::new(log::State::new(log::Module::LlamaCpp, options.clone()))),
    ) as *const _;
    let ggml_heap_state = Box::as_ref(
        log::GGML_STATE.get_or_init(|| Box::new(log::State::new(log::Module::GGML, options))),
    ) as *const _;

    // SAFETY: both pointers refer to static-owned `log::State` values, which is what
    // `logs_to_trace` expects to receive as its `data` argument.
    unsafe {
        // GGML has to be set after llama since setting llama sets ggml as well.
        llama_cpp_sys_2::llama_log_set(Some(logs_to_trace), llama_heap_state as *mut _);
        llama_cpp_sys_2::ggml_log_set(Some(logs_to_trace), ggml_heap_state as *mut _);
    }
}