Skip to main content

llama_cpp_4/
model.rs

1//! A safe wrapper around `llama_model`.
2use std::ffi::CStr;
3use std::ffi::CString;
4use std::fmt;
5use std::num::NonZeroU16;
6use std::os::raw::{c_char, c_int};
7use std::path::Path;
8use std::ptr::NonNull;
9use std::slice;
10
11use llama_cpp_sys_4::{
12    llama_adapter_lora, llama_adapter_lora_init, llama_chat_apply_template,
13    llama_chat_builtin_templates, llama_chat_message, llama_detokenize, llama_init_from_model,
14    llama_model, llama_model_cls_label, llama_model_decoder_start_token, llama_model_desc,
15    llama_model_free, llama_model_get_device, llama_model_get_vocab, llama_model_has_decoder,
16    llama_model_has_encoder, llama_model_is_diffusion, llama_model_is_hybrid,
17    llama_model_is_recurrent, llama_model_load_from_file, llama_model_load_from_splits,
18    llama_model_meta_count, llama_model_meta_key_by_index, llama_model_meta_val_str,
19    llama_model_meta_val_str_by_index, llama_model_n_cls_out, llama_model_n_ctx_train,
20    llama_model_n_devices, llama_model_n_embd, llama_model_n_embd_inp, llama_model_n_embd_out,
21    llama_model_n_expert, llama_model_n_head, llama_model_n_head_kv, llama_model_n_layer,
22    llama_model_n_layer_nextn, llama_model_n_params, llama_model_n_swa,
23    llama_model_rope_freq_scale_train, llama_model_rope_type, llama_model_save_to_file,
24    llama_model_size, llama_model_target_layer_ids, llama_model_target_layer_ids_n,
25    llama_split_path, llama_split_prefix, llama_token_to_piece, llama_tokenize, llama_vocab,
26    llama_vocab_type, LLAMA_VOCAB_TYPE_BPE, LLAMA_VOCAB_TYPE_SPM,
27};
28
29use crate::context::params::LlamaContextParams;
30use crate::context::LlamaContext;
31use crate::llama_backend::LlamaBackend;
32use crate::model::params::LlamaModelParams;
33use crate::token::LlamaToken;
34use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
35use crate::{
36    ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
37    LlamaModelLoadError, NewLlamaChatMessageError, StringFromModelError, StringToTokenError,
38    TokenToStringError,
39};
40
41pub mod params;
42
43/// Opaque ggml backend device handle returned by [`LlamaModel::get_device`].
44///
45/// Use [`Self::name`], [`Self::description`], [`Self::device_type`], and
46/// [`Self::memory`] to inspect the device. The handle is valid for the lifetime
47/// of the parent [`LlamaModel`].
48#[derive(Debug, Copy, Clone, PartialEq, Eq)]
49pub struct LlamaBackendDevice {
50    pub(crate) dev: llama_cpp_sys_4::ggml_backend_dev_t,
51}
52
53/// Backend device class (CPU, discrete GPU, integrated GPU, …).
54#[repr(i32)]
55#[derive(Copy, Clone, Debug, PartialEq, Eq)]
56pub enum LlamaBackendDeviceType {
57    /// Host CPU backend.
58    Cpu = llama_cpp_sys_4::GGML_BACKEND_DEVICE_TYPE_CPU.cast_signed(),
59    /// Discrete GPU.
60    Gpu = llama_cpp_sys_4::GGML_BACKEND_DEVICE_TYPE_GPU.cast_signed(),
61    /// Integrated GPU.
62    IntegratedGpu = llama_cpp_sys_4::GGML_BACKEND_DEVICE_TYPE_IGPU.cast_signed(),
63    /// Accelerator device (e.g. BLAS / Hexagon).
64    Accel = llama_cpp_sys_4::GGML_BACKEND_DEVICE_TYPE_ACCEL.cast_signed(),
65    /// Meta / placeholder device entry.
66    Meta = llama_cpp_sys_4::GGML_BACKEND_DEVICE_TYPE_META.cast_signed(),
67}
68
69impl From<llama_cpp_sys_4::ggml_backend_dev_type> for LlamaBackendDeviceType {
70    fn from(value: llama_cpp_sys_4::ggml_backend_dev_type) -> Self {
71        match value {
72            llama_cpp_sys_4::GGML_BACKEND_DEVICE_TYPE_CPU => Self::Cpu,
73            llama_cpp_sys_4::GGML_BACKEND_DEVICE_TYPE_GPU => Self::Gpu,
74            llama_cpp_sys_4::GGML_BACKEND_DEVICE_TYPE_IGPU => Self::IntegratedGpu,
75            llama_cpp_sys_4::GGML_BACKEND_DEVICE_TYPE_ACCEL => Self::Accel,
76            _ => Self::Meta,
77        }
78    }
79}
80
81impl LlamaBackendDevice {
82    /// Human-readable device name (e.g. `CUDA0`, `Metal`).
83    ///
84    /// # Errors
85    ///
86    /// Returns an error when the name pointer is null or not valid UTF-8.
87    pub fn name(&self) -> Result<&str, StringFromModelError> {
88        let ptr = unsafe { llama_cpp_sys_4::ggml_backend_dev_name(self.dev) };
89        if ptr.is_null() {
90            return Err(StringFromModelError::ReturnedError(-1));
91        }
92        let cstr = unsafe { CStr::from_ptr(ptr) };
93        cstr.to_str().map_err(StringFromModelError::Utf8Error)
94    }
95
96    /// Longer device description (often includes hardware name).
97    ///
98    /// # Errors
99    ///
100    /// Returns an error when the description pointer is null or not valid UTF-8.
101    pub fn description(&self) -> Result<&str, StringFromModelError> {
102        let ptr = unsafe { llama_cpp_sys_4::ggml_backend_dev_description(self.dev) };
103        if ptr.is_null() {
104            return Err(StringFromModelError::ReturnedError(-1));
105        }
106        let cstr = unsafe { CStr::from_ptr(ptr) };
107        cstr.to_str().map_err(StringFromModelError::Utf8Error)
108    }
109
110    /// Device class (CPU, GPU, integrated GPU, …).
111    #[must_use]
112    pub fn device_type(&self) -> LlamaBackendDeviceType {
113        unsafe { llama_cpp_sys_4::ggml_backend_dev_type(self.dev).into() }
114    }
115
116    /// Device memory `(free_bytes, total_bytes)`.
117    #[must_use]
118    pub fn memory(&self) -> (usize, usize) {
119        let mut free = 0usize;
120        let mut total = 0usize;
121        unsafe {
122            llama_cpp_sys_4::ggml_backend_dev_memory(self.dev, &raw mut free, &raw mut total);
123        }
124        (free, total)
125    }
126}
127
128/// Iterator over [`LlamaBackendDevice`] handles for a loaded model.
129///
130/// # Examples
131///
132/// ```no_run
133/// use llama_cpp_4::prelude::*;
134///
135/// fn main() {
136///     let backend = LlamaBackend::init().unwrap();
137///     let model = LlamaModel::load_from_file(&backend, "model.gguf", &LlamaModelParams::default()).unwrap();
138///     for dev in model.devices() {
139///         let (free, total) = dev.memory();
140///         println!("{}: {} / {} bytes free", dev.name().unwrap(), free, total);
141///     }
142/// }
143/// ```
144#[derive(Debug, Clone, Copy)]
145pub struct LlamaBackendDevices<'a> {
146    model: &'a LlamaModel,
147    next: i32,
148}
149
150#[allow(clippy::copy_iterator)]
151impl Iterator for LlamaBackendDevices<'_> {
152    type Item = LlamaBackendDevice;
153
154    fn next(&mut self) -> Option<Self::Item> {
155        let dev = self.model.get_device(self.next)?;
156        self.next += 1;
157        Some(dev)
158    }
159
160    fn size_hint(&self) -> (usize, Option<usize>) {
161        let remaining = usize::try_from((self.model.n_devices() - self.next).max(0)).unwrap_or(0);
162        (remaining, Some(remaining))
163    }
164}
165
166impl ExactSizeIterator for LlamaBackendDevices<'_> {}
167
168/// A safe wrapper around `llama_model`.
169#[derive(Debug)]
170#[repr(transparent)]
171#[allow(clippy::module_name_repetitions)]
172pub struct LlamaModel {
173    pub(crate) model: NonNull<llama_model>,
174}
175
176/// A safe wrapper around `llama_vocab`.
177#[derive(Debug)]
178#[repr(transparent)]
179#[allow(clippy::module_name_repetitions)]
180pub struct LlamaVocab {
181    pub(crate) vocab: NonNull<llama_vocab>,
182}
183
184impl LlamaVocab {
185    /// Get the number of tokens in the vocabulary.
186    #[must_use]
187    pub fn n_tokens(&self) -> i32 {
188        unsafe { llama_cpp_sys_4::llama_vocab_n_tokens(self.vocab.as_ref()) }
189    }
190
191    /// Get the vocabulary type.
192    ///
193    /// # Panics
194    ///
195    /// Panics if the C API returns a vocabulary type that does not fit in `u32`.
196    #[must_use]
197    pub fn vocab_type(&self) -> u32 {
198        unsafe { llama_cpp_sys_4::llama_vocab_type(self.vocab.as_ref()) as u32 }
199    }
200
201    /// Get the BOS token.
202    #[must_use]
203    pub fn bos(&self) -> LlamaToken {
204        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_bos(self.vocab.as_ref()) })
205    }
206
207    /// Get the EOS token.
208    #[must_use]
209    pub fn eos(&self) -> LlamaToken {
210        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_eos(self.vocab.as_ref()) })
211    }
212
213    /// Get the EOT (end of turn) token.
214    #[must_use]
215    pub fn eot(&self) -> LlamaToken {
216        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_eot(self.vocab.as_ref()) })
217    }
218
219    /// Get the CLS (classification) token.
220    #[must_use]
221    pub fn cls(&self) -> LlamaToken {
222        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_cls(self.vocab.as_ref()) })
223    }
224
225    /// Get the SEP (separator) token.
226    #[must_use]
227    pub fn sep(&self) -> LlamaToken {
228        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_sep(self.vocab.as_ref()) })
229    }
230
231    /// Get the NL (newline) token.
232    #[must_use]
233    pub fn nl(&self) -> LlamaToken {
234        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_nl(self.vocab.as_ref()) })
235    }
236
237    /// Get the PAD (padding) token.
238    #[must_use]
239    pub fn pad(&self) -> LlamaToken {
240        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_pad(self.vocab.as_ref()) })
241    }
242
243    /// Get the FIM prefix token.
244    #[must_use]
245    pub fn fim_pre(&self) -> LlamaToken {
246        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_pre(self.vocab.as_ref()) })
247    }
248
249    /// Get the FIM suffix token.
250    #[must_use]
251    pub fn fim_suf(&self) -> LlamaToken {
252        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_suf(self.vocab.as_ref()) })
253    }
254
255    /// Get the FIM middle token.
256    #[must_use]
257    pub fn fim_mid(&self) -> LlamaToken {
258        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_mid(self.vocab.as_ref()) })
259    }
260
261    /// Get the FIM padding token.
262    #[must_use]
263    pub fn fim_pad(&self) -> LlamaToken {
264        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_pad(self.vocab.as_ref()) })
265    }
266
267    /// Get the FIM repository token.
268    #[must_use]
269    pub fn fim_rep(&self) -> LlamaToken {
270        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_rep(self.vocab.as_ref()) })
271    }
272
273    /// Get the FIM separator token.
274    #[must_use]
275    pub fn fim_sep(&self) -> LlamaToken {
276        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_sep(self.vocab.as_ref()) })
277    }
278
279    /// Check whether BOS should be added.
280    #[must_use]
281    pub fn get_add_bos(&self) -> bool {
282        unsafe { llama_cpp_sys_4::llama_vocab_get_add_bos(self.vocab.as_ref()) }
283    }
284
285    /// Check whether EOS should be added.
286    #[must_use]
287    pub fn get_add_eos(&self) -> bool {
288        unsafe { llama_cpp_sys_4::llama_vocab_get_add_eos(self.vocab.as_ref()) }
289    }
290
291    /// Check whether SEP should be added.
292    #[must_use]
293    pub fn get_add_sep(&self) -> bool {
294        unsafe { llama_cpp_sys_4::llama_vocab_get_add_sep(self.vocab.as_ref()) }
295    }
296
297    /// Get the text representation of a token.
298    ///
299    /// # Errors
300    ///
301    /// Returns an error if the text pointer is null or not valid UTF-8.
302    pub fn get_text(&self, token: LlamaToken) -> Result<&str, StringFromModelError> {
303        let ptr = unsafe { llama_cpp_sys_4::llama_vocab_get_text(self.vocab.as_ref(), token.0) };
304        if ptr.is_null() {
305            return Err(StringFromModelError::ReturnedError(-1));
306        }
307        let cstr = unsafe { CStr::from_ptr(ptr) };
308        cstr.to_str().map_err(StringFromModelError::Utf8Error)
309    }
310
311    /// Get the score of a token.
312    #[must_use]
313    pub fn get_score(&self, token: LlamaToken) -> f32 {
314        unsafe { llama_cpp_sys_4::llama_vocab_get_score(self.vocab.as_ref(), token.0) }
315    }
316
317    /// Get the attributes of a token.
318    ///
319    /// # Panics
320    ///
321    /// Panics if the C API returns attributes that do not fit in `u32`.
322    #[must_use]
323    pub fn get_attr(&self, token: LlamaToken) -> u32 {
324        unsafe { llama_cpp_sys_4::llama_vocab_get_attr(self.vocab.as_ref(), token.0) as u32 }
325    }
326
327    /// Check if a token is a control token.
328    #[must_use]
329    pub fn is_control(&self, token: LlamaToken) -> bool {
330        unsafe { llama_cpp_sys_4::llama_vocab_is_control(self.vocab.as_ref(), token.0) }
331    }
332
333    /// Check if a token is an end-of-generation token.
334    #[must_use]
335    pub fn is_eog(&self, token: LlamaToken) -> bool {
336        unsafe { llama_cpp_sys_4::llama_vocab_is_eog(self.vocab.as_ref(), token.0) }
337    }
338
339    /// Get the token mask value for the vocabulary.
340    #[must_use]
341    pub fn mask(&self) -> LlamaToken {
342        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_mask(self.vocab.as_ref()) })
343    }
344}
345
346/// A safe wrapper around `llama_adapter_lora`.
347#[derive(Debug)]
348#[repr(transparent)]
349#[allow(clippy::module_name_repetitions)]
350pub struct LlamaLoraAdapter {
351    pub(crate) lora_adapter: NonNull<llama_adapter_lora>,
352}
353
354impl LlamaLoraAdapter {
355    /// Get the number of metadata key-value pairs in the adapter.
356    #[must_use]
357    pub fn meta_count(&self) -> i32 {
358        unsafe { llama_cpp_sys_4::llama_adapter_meta_count(self.lora_adapter.as_ptr()) }
359    }
360
361    /// Get a metadata key by index.
362    ///
363    /// # Errors
364    ///
365    /// Returns an error if the index is out of range or the key is not valid UTF-8.
366    #[allow(clippy::cast_sign_loss)]
367    pub fn meta_key_by_index(
368        &self,
369        index: i32,
370        buf_size: usize,
371    ) -> Result<String, StringFromModelError> {
372        let mut buf = vec![0u8; buf_size];
373        let ret = unsafe {
374            llama_cpp_sys_4::llama_adapter_meta_key_by_index(
375                self.lora_adapter.as_ptr(),
376                index,
377                buf.as_mut_ptr().cast::<c_char>(),
378                buf_size,
379            )
380        };
381        if ret < 0 {
382            return Err(StringFromModelError::ReturnedError(ret));
383        }
384        let len = ret as usize;
385        let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
386        Ok(s.to_owned())
387    }
388
389    /// Get a metadata value by key name.
390    ///
391    /// # Errors
392    ///
393    /// Returns an error if the key is not found or the value is not valid UTF-8.
394    #[allow(clippy::cast_sign_loss)]
395    pub fn meta_val_str(&self, key: &str, buf_size: usize) -> Result<String, StringFromModelError> {
396        let c_key = CString::new(key).map_err(|_| StringFromModelError::ReturnedError(-1))?;
397        let mut buf = vec![0u8; buf_size];
398        let ret = unsafe {
399            llama_cpp_sys_4::llama_adapter_meta_val_str(
400                self.lora_adapter.as_ptr(),
401                c_key.as_ptr(),
402                buf.as_mut_ptr().cast::<c_char>(),
403                buf_size,
404            )
405        };
406        if ret < 0 {
407            return Err(StringFromModelError::ReturnedError(ret));
408        }
409        let len = ret as usize;
410        let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
411        Ok(s.to_owned())
412    }
413
414    /// Get a metadata value by index.
415    ///
416    /// # Errors
417    ///
418    /// Returns an error if the index is out of range or the value is not valid UTF-8.
419    #[allow(clippy::cast_sign_loss)]
420    pub fn meta_val_str_by_index(
421        &self,
422        index: i32,
423        buf_size: usize,
424    ) -> Result<String, StringFromModelError> {
425        let mut buf = vec![0u8; buf_size];
426        let ret = unsafe {
427            llama_cpp_sys_4::llama_adapter_meta_val_str_by_index(
428                self.lora_adapter.as_ptr(),
429                index,
430                buf.as_mut_ptr().cast::<c_char>(),
431                buf_size,
432            )
433        };
434        if ret < 0 {
435            return Err(StringFromModelError::ReturnedError(ret));
436        }
437        let len = ret as usize;
438        let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
439        Ok(s.to_owned())
440    }
441
442    /// Get all metadata as a list of `(key, value)` pairs.
443    ///
444    /// # Errors
445    ///
446    /// Returns an error if any key or value cannot be read or is not valid UTF-8.
447    #[allow(clippy::cast_sign_loss)]
448    pub fn metadata(&self) -> Result<Vec<(String, String)>, StringFromModelError> {
449        let count = self.meta_count();
450        let mut result = Vec::with_capacity(count as usize);
451        for i in 0..count {
452            let key = self.meta_key_by_index(i, 256)?;
453            let val = self.meta_val_str_by_index(i, 4096)?;
454            result.push((key, val));
455        }
456        Ok(result)
457    }
458
459    /// Get the number of invocation tokens for this adapter.
460    #[must_use]
461    pub fn n_invocation_tokens(&self) -> u64 {
462        unsafe {
463            llama_cpp_sys_4::llama_adapter_get_alora_n_invocation_tokens(self.lora_adapter.as_ptr())
464        }
465    }
466
467    /// Get the invocation tokens for this adapter.
468    ///
469    /// Returns an empty slice if there are no invocation tokens.
470    #[must_use]
471    #[allow(clippy::cast_possible_truncation)]
472    pub fn invocation_tokens(&self) -> &[LlamaToken] {
473        let n = self.n_invocation_tokens() as usize;
474        if n == 0 {
475            return &[];
476        }
477        let ptr = unsafe {
478            llama_cpp_sys_4::llama_adapter_get_alora_invocation_tokens(self.lora_adapter.as_ptr())
479        };
480        if ptr.is_null() {
481            return &[];
482        }
483        // LlamaToken is repr(transparent) over llama_token (i32), so this cast is safe
484        unsafe { std::slice::from_raw_parts(ptr.cast::<LlamaToken>(), n) }
485    }
486}
487
488impl Drop for LlamaLoraAdapter {
489    fn drop(&mut self) {
490        unsafe {
491            llama_cpp_sys_4::llama_adapter_lora_free(self.lora_adapter.as_ptr());
492        }
493    }
494}
495
496/// A Safe wrapper around `llama_chat_message`
497#[derive(Debug, Eq, PartialEq, Clone)]
498pub struct LlamaChatMessage {
499    role: CString,
500    content: CString,
501}
502
503impl LlamaChatMessage {
504    /// Create a new `LlamaChatMessage`.
505    ///
506    /// # Errors
507    ///
508    /// Returns [`NewLlamaChatMessageError`] if the role or content contains a null byte.
509    pub fn new(role: String, content: String) -> Result<Self, NewLlamaChatMessageError> {
510        Ok(Self {
511            role: CString::new(role)?,
512            content: CString::new(content)?,
513        })
514    }
515}
516
/// Whether tokenization should request a beginning-of-stream (BOS) token.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AddBos {
    /// Prepend the beginning-of-stream token to the string.
    Always,
    /// Leave the beginning-of-stream token out.
    Never,
}
525
/// How special and control tokens are treated when rendering or tokenizing.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Special {
    /// Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space.
    Tokenize,
    /// Treat special and/or control tokens as ordinary plaintext.
    Plaintext,
}
534
535unsafe impl Send for LlamaModel {}
536
537unsafe impl Sync for LlamaModel {}
538
539impl LlamaModel {
540    /// Retrieves the vocabulary associated with the current Llama model.
541    ///
542    /// This method fetches the vocabulary from the underlying model using an unsafe
543    /// FFI call. The returned `LlamaVocab` struct contains a non-null pointer to
544    /// the vocabulary data, which is wrapped in a `NonNull` for safety.
545    ///
546    /// # Safety
547    /// This method uses an unsafe block to call a C function (`llama_model_get_vocab`),
548    /// which is assumed to return a valid pointer to the vocabulary. The caller should
549    /// ensure that the model object is properly initialized and valid before calling
550    /// this method, as dereferencing invalid pointers can lead to undefined behavior.
551    ///
552    /// # Returns
553    /// A `LlamaVocab` struct containing the vocabulary of the model.
554    ///
555    /// # Panics
556    ///
557    /// Panics if the underlying C function returns a null pointer.
558    ///
559    /// # Example
560    /// ```rust,ignore
561    /// let vocab = model.get_vocab();
562    /// ```
563    #[must_use]
564    pub fn get_vocab(&self) -> LlamaVocab {
565        let llama_vocab = unsafe { llama_model_get_vocab(self.model.as_ptr()) }.cast_mut();
566
567        LlamaVocab {
568            vocab: NonNull::new(llama_vocab).unwrap(),
569        }
570    }
571    /// Get the number of tokens the model was trained on.
572    ///
573    /// This function returns the number of tokens that the model was trained on, represented as a `u32`.
574    ///
575    /// # Panics
576    ///
577    /// This function will panic if the number of tokens the model was trained on does not fit into a `u32`.
578    /// This should be impossible on most platforms since llama.cpp returns a `c_int` (i32 on most platforms),
579    /// which is almost certainly positive.
580    #[must_use]
581    pub fn n_ctx_train(&self) -> u32 {
582        let n_ctx_train = unsafe { llama_model_n_ctx_train(self.model.as_ptr()) };
583        u32::try_from(n_ctx_train).expect("n_ctx_train fits into an u32")
584    }
585
586    /// Get all tokens in the model.
587    ///
588    /// This function returns an iterator over all the tokens in the model. Each item in the iterator is a tuple
589    /// containing a `LlamaToken` and its corresponding string representation (or an error if the conversion fails).
590    ///
591    /// # Parameters
592    ///
593    /// - `special`: The `Special` value that determines how special tokens (like BOS, EOS, etc.) are handled.
594    pub fn tokens(
595        &self,
596        special: Special,
597    ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
598        (0..self.n_vocab())
599            .map(LlamaToken::new)
600            .map(move |llama_token| (llama_token, self.token_to_str(llama_token, special)))
601    }
602
603    /// Get the beginning of stream token.
604    ///
605    /// This function returns the token that represents the beginning of a stream (BOS token).
606    #[must_use]
607    pub fn token_bos(&self) -> LlamaToken {
608        self.get_vocab().bos()
609    }
610
611    /// Get the end of stream token.
612    ///
613    /// This function returns the token that represents the end of a stream (EOS token).
614    #[must_use]
615    pub fn token_eos(&self) -> LlamaToken {
616        self.get_vocab().eos()
617    }
618
619    /// Get the newline token.
620    ///
621    /// This function returns the token that represents a newline character.
622    #[must_use]
623    pub fn token_nl(&self) -> LlamaToken {
624        self.get_vocab().nl()
625    }
626
627    /// Check if a token represents the end of generation (end of turn, end of sequence, etc.).
628    ///
629    /// This function returns `true` if the provided token signifies the end of generation or end of sequence,
630    /// such as EOS or other special tokens.
631    ///
632    /// # Parameters
633    ///
634    /// - `token`: The `LlamaToken` to check.
635    ///
636    /// # Returns
637    ///
638    /// - `true` if the token is an end-of-generation token, otherwise `false`.
639    #[must_use]
640    pub fn is_eog_token(&self, token: LlamaToken) -> bool {
641        self.get_vocab().is_eog(token)
642    }
643
644    /// Get the classification token.
645    #[must_use]
646    pub fn token_cls(&self) -> LlamaToken {
647        self.get_vocab().cls()
648    }
649
650    /// Get the end-of-turn token.
651    #[must_use]
652    pub fn token_eot(&self) -> LlamaToken {
653        self.get_vocab().eot()
654    }
655
656    /// Get the padding token.
657    #[must_use]
658    pub fn token_pad(&self) -> LlamaToken {
659        self.get_vocab().pad()
660    }
661
662    /// Get the separator token.
663    #[must_use]
664    pub fn token_sep(&self) -> LlamaToken {
665        self.get_vocab().sep()
666    }
667
668    /// Get the fill-in-the-middle prefix token.
669    #[must_use]
670    pub fn token_fim_pre(&self) -> LlamaToken {
671        self.get_vocab().fim_pre()
672    }
673
674    /// Get the fill-in-the-middle suffix token.
675    #[must_use]
676    pub fn token_fim_suf(&self) -> LlamaToken {
677        self.get_vocab().fim_suf()
678    }
679
680    /// Get the fill-in-the-middle middle token.
681    #[must_use]
682    pub fn token_fim_mid(&self) -> LlamaToken {
683        self.get_vocab().fim_mid()
684    }
685
686    /// Get the fill-in-the-middle padding token.
687    #[must_use]
688    pub fn token_fim_pad(&self) -> LlamaToken {
689        self.get_vocab().fim_pad()
690    }
691
692    /// Get the fill-in-the-middle repository token.
693    #[must_use]
694    pub fn token_fim_rep(&self) -> LlamaToken {
695        self.get_vocab().fim_rep()
696    }
697
698    /// Get the fill-in-the-middle separator token.
699    #[must_use]
700    pub fn token_fim_sep(&self) -> LlamaToken {
701        self.get_vocab().fim_sep()
702    }
703
704    /// Check if a token is a control token.
705    #[must_use]
706    pub fn token_is_control(&self, token: LlamaToken) -> bool {
707        self.get_vocab().is_control(token)
708    }
709
710    /// Get the score of a token.
711    #[must_use]
712    pub fn token_get_score(&self, token: LlamaToken) -> f32 {
713        self.get_vocab().get_score(token)
714    }
715
716    /// Get the raw text of a token.
717    ///
718    /// # Errors
719    ///
720    /// Returns an error if the token text is null or not valid UTF-8.
721    pub fn token_get_text(&self, token: LlamaToken) -> Result<&str, StringFromModelError> {
722        let ptr = unsafe {
723            llama_cpp_sys_4::llama_vocab_get_text(self.get_vocab().vocab.as_ref(), token.0)
724        };
725        if ptr.is_null() {
726            return Err(StringFromModelError::ReturnedError(-1));
727        }
728        let cstr = unsafe { CStr::from_ptr(ptr) };
729        cstr.to_str().map_err(StringFromModelError::Utf8Error)
730    }
731
732    /// Check if a BOS token should be added when tokenizing.
733    #[must_use]
734    pub fn add_bos_token(&self) -> bool {
735        self.get_vocab().get_add_bos()
736    }
737
738    /// Check if an EOS token should be added when tokenizing.
739    #[must_use]
740    pub fn add_eos_token(&self) -> bool {
741        self.get_vocab().get_add_eos()
742    }
743
744    /// Get the decoder start token.
745    ///
746    /// This function returns the token used to signal the start of decoding (i.e., the token used at the start
747    /// of a sequence generation).
748    #[must_use]
749    pub fn decode_start_token(&self) -> LlamaToken {
750        let token = unsafe { llama_model_decoder_start_token(self.model.as_ptr()) };
751        LlamaToken(token)
752    }
753
754    /// Convert a single token to a string.
755    ///
756    /// This function converts a `LlamaToken` into its string representation.
757    ///
758    /// # Errors
759    ///
760    /// This function returns an error if the token cannot be converted to a string. For more details, refer to
761    /// [`TokenToStringError`].
762    ///
763    /// # Parameters
764    ///
765    /// - `token`: The `LlamaToken` to convert.
766    /// - `special`: The `Special` value used to handle special tokens.
767    pub fn token_to_str(
768        &self,
769        token: LlamaToken,
770        special: Special,
771    ) -> Result<String, TokenToStringError> {
772        self.token_to_str_with_size(token, 32, special)
773    }
774
775    /// Convert a single token to bytes.
776    ///
777    /// This function converts a `LlamaToken` into a byte representation.
778    ///
779    /// # Errors
780    ///
781    /// This function returns an error if the token cannot be converted to bytes. For more details, refer to
782    /// [`TokenToStringError`].
783    ///
784    /// # Parameters
785    ///
786    /// - `token`: The `LlamaToken` to convert.
787    /// - `special`: The `Special` value used to handle special tokens.
788    pub fn token_to_bytes(
789        &self,
790        token: LlamaToken,
791        special: Special,
792    ) -> Result<Vec<u8>, TokenToStringError> {
793        self.token_to_bytes_with_size(token, 32, special, None)
794    }
795
796    /// Convert a vector of tokens to a single string.
797    ///
798    /// This function takes a slice of `LlamaToken`s and converts them into a single string, concatenating their
799    /// string representations.
800    ///
801    /// # Errors
802    ///
803    /// This function returns an error if any token cannot be converted to a string. For more details, refer to
804    /// [`TokenToStringError`].
805    ///
806    /// # Parameters
807    ///
808    /// - `tokens`: A slice of `LlamaToken`s to convert.
809    /// - `special`: The `Special` value used to handle special tokens.
810    pub fn tokens_to_str(
811        &self,
812        tokens: &[LlamaToken],
813        special: Special,
814    ) -> Result<String, TokenToStringError> {
815        let mut builder = String::with_capacity(tokens.len() * 4);
816        for str in tokens
817            .iter()
818            .copied()
819            .map(|t| self.token_to_str(t, special))
820        {
821            builder += &str?;
822        }
823        Ok(builder)
824    }
825
826    /// Convert a string to a vector of tokens.
827    ///
828    /// This function converts a string into a vector of `LlamaToken`s. The function will tokenize the string
829    /// and return the corresponding tokens.
830    ///
831    /// # Errors
832    ///
833    /// - This function will return an error if the input string contains a null byte.
834    ///
835    /// # Panics
836    ///
837    /// - This function will panic if the number of tokens exceeds `usize::MAX`.
838    ///
839    /// # Example
840    ///
841    /// ```no_run
842    /// use llama_cpp_4::model::LlamaModel;
843    ///
844    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
845    /// use std::path::Path;
846    /// use llama_cpp_4::model::AddBos;
847    /// let backend = llama_cpp_4::llama_backend::LlamaBackend::init()?;
848    /// let model = LlamaModel::load_from_file(&backend, Path::new("path/to/model"), &Default::default())?;
849    /// let tokens = model.str_to_token("Hello, World!", AddBos::Always)?;
850    /// # Ok(())
851    /// # }
852    /// ```
853    pub fn str_to_token(
854        &self,
855        str: &str,
856        add_bos: AddBos,
857    ) -> Result<Vec<LlamaToken>, StringToTokenError> {
858        let add_bos = match add_bos {
859            AddBos::Always => true,
860            AddBos::Never => false,
861        };
862
863        let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
864        let mut buffer = Vec::with_capacity(tokens_estimation);
865
866        let c_string = CString::new(str)?;
867        let buffer_capacity =
868            c_int::try_from(buffer.capacity()).expect("buffer capacity should fit into a c_int");
869
870        let size = unsafe {
871            llama_tokenize(
872                self.get_vocab().vocab.as_ref(),
873                c_string.as_ptr(),
874                c_int::try_from(c_string.as_bytes().len())?,
875                buffer.as_mut_ptr(),
876                buffer_capacity,
877                add_bos,
878                true,
879            )
880        };
881
882        // if we fail the first time we can resize the vector to the correct size and try again. This should never fail.
883        // as a result - size is guaranteed to be positive here.
884        let size = if size.is_negative() {
885            buffer.reserve_exact(usize::try_from(-size).expect("usize's are larger "));
886            unsafe {
887                llama_tokenize(
888                    self.get_vocab().vocab.as_ref(),
889                    c_string.as_ptr(),
890                    c_int::try_from(c_string.as_bytes().len())?,
891                    buffer.as_mut_ptr(),
892                    -size,
893                    add_bos,
894                    true,
895                )
896            }
897        } else {
898            size
899        };
900
901        let size = usize::try_from(size).expect("size is positive and usize ");
902
903        // Safety: `size` < `capacity` and llama-cpp has initialized elements up to `size`
904        unsafe { buffer.set_len(size) }
905        Ok(buffer.into_iter().map(LlamaToken).collect())
906    }
907
908    /// Get the type of a token.
909    ///
910    /// This function retrieves the attributes associated with a given token. The attributes are typically used to
911    /// understand whether the token represents a special type of token (e.g., beginning-of-sequence (BOS), end-of-sequence (EOS),
912    /// control tokens, etc.).
913    ///
914    /// # Panics
915    ///
916    /// - This function will panic if the token type is unknown or cannot be converted to a valid `LlamaTokenAttrs`.
917    ///
918    /// # Example
919    ///
920    /// ```no_run
921    /// use llama_cpp_4::model::LlamaModel;
922    /// use llama_cpp_4::model::params::LlamaModelParams;
923    /// use llama_cpp_4::llama_backend::LlamaBackend;
924    /// use llama_cpp_4::token::LlamaToken;
925    ///
926    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
927    /// let backend = LlamaBackend::init()?;
928    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
929    /// let token = LlamaToken::new(42);
930    /// let token_attrs = model.token_attr(token);
931    /// # Ok(())
932    /// # }
933    /// ```
934    #[must_use]
935    pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs {
936        let token_type =
937            unsafe { llama_cpp_sys_4::llama_vocab_get_attr(self.get_vocab().vocab.as_ref(), id) };
938        LlamaTokenAttrs::try_from(token_type).expect("token type is valid")
939    }
940
941    /// Detokenize a slice of tokens into a string.
942    ///
943    /// This is the inverse of [`str_to_token`](Self::str_to_token).
944    ///
945    /// # Parameters
946    ///
947    /// - `tokens`: The tokens to detokenize.
948    /// - `remove_special`: If `true`, special tokens are removed from the output.
949    /// - `unparse_special`: If `true`, special tokens are rendered as their text representation.
950    ///
951    /// # Errors
952    ///
953    /// Returns an error if the detokenized text is not valid UTF-8.
954    #[allow(
955        clippy::cast_possible_truncation,
956        clippy::cast_possible_wrap,
957        clippy::cast_sign_loss
958    )]
959    pub fn detokenize(
960        &self,
961        tokens: &[LlamaToken],
962        remove_special: bool,
963        unparse_special: bool,
964    ) -> Result<String, StringFromModelError> {
965        // First call with empty buffer to get required size
966        let n_tokens = tokens.len() as i32;
967        let token_ptr = tokens.as_ptr().cast::<llama_cpp_sys_4::llama_token>();
968        let needed = unsafe {
969            llama_detokenize(
970                self.get_vocab().vocab.as_ref(),
971                token_ptr,
972                n_tokens,
973                std::ptr::null_mut(),
974                0,
975                remove_special,
976                unparse_special,
977            )
978        };
979        // llama_detokenize returns negative required size when buffer is too small
980        let buf_size = if needed < 0 {
981            (-needed) as usize
982        } else {
983            needed as usize
984        };
985        let mut buf = vec![0u8; buf_size];
986        let ret = unsafe {
987            llama_detokenize(
988                self.get_vocab().vocab.as_ref(),
989                token_ptr,
990                n_tokens,
991                buf.as_mut_ptr().cast::<c_char>(),
992                buf_size as i32,
993                remove_special,
994                unparse_special,
995            )
996        };
997        if ret < 0 {
998            return Err(StringFromModelError::ReturnedError(ret));
999        }
1000        let len = ret as usize;
1001        let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
1002        Ok(s.to_owned())
1003    }
1004
1005    /// Convert a token to a string with a specified buffer size.
1006    ///
1007    /// This function allows you to convert a token into a string, with the ability to specify a buffer size for the operation.
1008    /// It is generally recommended to use `LlamaModel::token_to_str` instead, as 8 bytes is typically sufficient for most tokens,
1009    /// and the extra buffer size doesn't usually matter.
1010    ///
1011    /// # Errors
1012    ///
1013    /// - If the token type is unknown, an error will be returned.
1014    /// - If the resultant token exceeds the provided `buffer_size`, an error will occur.
1015    /// - If the token string returned by `llama-cpp` is not valid UTF-8, it will return an error.
1016    ///
1017    /// # Panics
1018    ///
1019    /// - This function will panic if the `buffer_size` does not fit into a `c_int`.
1020    /// - It will also panic if the size returned from `llama-cpp` does not fit into a `usize`, which should typically never happen.
1021    ///
1022    /// # Example
1023    ///
1024    /// ```no_run
1025    /// use llama_cpp_4::model::{LlamaModel, Special};
1026    /// use llama_cpp_4::model::params::LlamaModelParams;
1027    /// use llama_cpp_4::llama_backend::LlamaBackend;
1028    /// use llama_cpp_4::token::LlamaToken;
1029    ///
1030    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1031    /// let backend = LlamaBackend::init()?;
1032    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1033    /// let token = LlamaToken::new(42);
1034    /// let token_string = model.token_to_str_with_size(token, 32, Special::Plaintext)?;
1035    /// # Ok(())
1036    /// # }
1037    /// ```
1038    pub fn token_to_str_with_size(
1039        &self,
1040        token: LlamaToken,
1041        buffer_size: usize,
1042        special: Special,
1043    ) -> Result<String, TokenToStringError> {
1044        let bytes = self.token_to_bytes_with_size(token, buffer_size, special, None)?;
1045        Ok(String::from_utf8(bytes)?)
1046    }
1047
1048    /// Convert a token to bytes with a specified buffer size.
1049    ///
1050    /// Generally you should use [`LlamaModel::token_to_bytes`] instead as 8 bytes is enough for most words and
1051    /// the extra bytes do not really matter.
1052    ///
1053    /// # Errors
1054    ///
1055    /// - if the token type is unknown
1056    /// - the resultant token is larger than `buffer_size`.
1057    ///
1058    /// # Panics
1059    ///
1060    /// - This function will panic if `buffer_size` cannot fit into a `c_int`.
1061    /// - It will also panic if the size returned from `llama-cpp` cannot be converted to `usize` (which should not happen).
1062    ///
1063    /// # Example
1064    ///
1065    /// ```no_run
1066    /// use llama_cpp_4::model::{LlamaModel, Special};
1067    /// use llama_cpp_4::model::params::LlamaModelParams;
1068    /// use llama_cpp_4::llama_backend::LlamaBackend;
1069    /// use llama_cpp_4::token::LlamaToken;
1070    ///
1071    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1072    /// let backend = LlamaBackend::init()?;
1073    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1074    /// let token = LlamaToken::new(42);
1075    /// let token_bytes = model.token_to_bytes_with_size(token, 32, Special::Plaintext, None)?;
1076    /// # Ok(())
1077    /// # }
1078    /// ```
1079    pub fn token_to_bytes_with_size(
1080        &self,
1081        token: LlamaToken,
1082        buffer_size: usize,
1083        special: Special,
1084        lstrip: Option<NonZeroU16>,
1085    ) -> Result<Vec<u8>, TokenToStringError> {
1086        if token == self.token_nl() {
1087            return Ok(String::from("\n").into_bytes());
1088        }
1089
1090        // unsure what to do with this in the face of the 'special' arg + attr changes
1091        let attrs = self.token_attr(token);
1092        if (attrs.contains(LlamaTokenAttr::Control)
1093            && (token == self.token_bos() || token == self.token_eos()))
1094            || attrs.is_empty()
1095            || attrs
1096                .intersects(LlamaTokenAttr::Unknown | LlamaTokenAttr::Byte | LlamaTokenAttr::Unused)
1097        {
1098            return Ok(Vec::new());
1099        }
1100
1101        let special = match special {
1102            Special::Tokenize => true,
1103            Special::Plaintext => false,
1104        };
1105
1106        let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
1107        let len = string.as_bytes().len();
1108        let len = c_int::try_from(len).expect("length fits into c_int");
1109        let buf = string.into_raw();
1110        let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
1111        let size = unsafe {
1112            llama_token_to_piece(
1113                self.get_vocab().vocab.as_ref(),
1114                token.0,
1115                buf,
1116                len,
1117                lstrip,
1118                special,
1119            )
1120        };
1121
1122        match size {
1123            0 => Err(TokenToStringError::UnknownTokenType),
1124            i if i.is_negative() => Err(TokenToStringError::InsufficientBufferSpace(i)),
1125            size => {
1126                let string = unsafe { CString::from_raw(buf) };
1127                let mut bytes = string.into_bytes();
1128                let len = usize::try_from(size).expect("size is positive and fits into usize");
1129                bytes.truncate(len);
1130                Ok(bytes)
1131            }
1132        }
1133    }
1134    /// The number of tokens the model was trained on.
1135    ///
1136    /// This function returns the number of tokens the model was trained on. It is returned as a `c_int` for maximum
1137    /// compatibility with the underlying llama-cpp library, though it can typically be cast to an `i32` without issue.
1138    ///
1139    /// # Example
1140    ///
1141    /// ```no_run
1142    /// use llama_cpp_4::model::LlamaModel;
1143    /// use llama_cpp_4::model::params::LlamaModelParams;
1144    /// use llama_cpp_4::llama_backend::LlamaBackend;
1145    ///
1146    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1147    /// let backend = LlamaBackend::init()?;
1148    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1149    /// let n_vocab = model.n_vocab();
1150    /// # Ok(())
1151    /// # }
1152    /// ```
1153    #[must_use]
1154    pub fn n_vocab(&self) -> i32 {
1155        self.get_vocab().n_tokens()
1156    }
1157
1158    /// The type of vocab the model was trained on.
1159    ///
1160    /// This function returns the type of vocabulary used by the model, such as whether it is based on byte-pair encoding (BPE),
1161    /// word-level tokens, or another tokenization scheme.
1162    ///
1163    /// # Panics
1164    ///
1165    /// - This function will panic if `llama-cpp` emits a vocab type that is not recognized or is invalid for this library.
1166    ///
1167    /// # Example
1168    ///
1169    /// ```no_run
1170    /// use llama_cpp_4::model::LlamaModel;
1171    /// use llama_cpp_4::model::params::LlamaModelParams;
1172    /// use llama_cpp_4::llama_backend::LlamaBackend;
1173    ///
1174    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1175    /// let backend = LlamaBackend::init()?;
1176    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1177    /// let vocab_type = model.vocab_type();
1178    /// # Ok(())
1179    /// # }
1180    /// ```
1181    #[must_use]
1182    pub fn vocab_type(&self) -> VocabType {
1183        let vocab_type = unsafe { llama_vocab_type(self.get_vocab().vocab.as_ref()) };
1184        VocabType::try_from(vocab_type).expect("invalid vocab type")
1185    }
1186
1187    /// Returns the number of embedding dimensions for the model.
1188    ///
1189    /// This function retrieves the number of embeddings (or embedding dimensions) used by the model. It is typically
1190    /// used for analyzing model architecture and setting up context parameters or other model configuration aspects.
1191    ///
1192    /// # Panics
1193    ///
1194    /// - This function may panic if the underlying `llama-cpp` library returns an invalid embedding dimension value.
1195    ///
1196    /// # Example
1197    ///
1198    /// ```no_run
1199    /// use llama_cpp_4::model::LlamaModel;
1200    /// use llama_cpp_4::model::params::LlamaModelParams;
1201    /// use llama_cpp_4::llama_backend::LlamaBackend;
1202    ///
1203    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1204    /// let backend = LlamaBackend::init()?;
1205    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1206    /// let n_embd = model.n_embd();
1207    /// # Ok(())
1208    /// # }
1209    /// ```
1210    #[must_use]
1211    pub fn n_embd(&self) -> c_int {
1212        unsafe { llama_model_n_embd(self.model.as_ptr()) }
1213    }
1214
1215    /// Get the number of transformer layers in the model.
1216    #[must_use]
1217    pub fn n_layer(&self) -> c_int {
1218        unsafe { llama_model_n_layer(self.model.as_ptr()) }
1219    }
1220
1221    /// Get the number of `NextN` / MTP prediction heads bundled with the model.
1222    ///
1223    /// Returns `0` when the checkpoint has no `NextN` blocks. Multi-head models
1224    /// (e.g. Step3.5) return values greater than `1`; pair with
1225    /// [`crate::context::LlamaContext::set_nextn_layer_offset`] on the draft
1226    /// context. See [`crate::mtp`] for the speculative-decoding workflow.
1227    ///
1228    /// # Examples
1229    ///
1230    /// ```no_run
1231    /// # use llama_cpp_4::llama_backend::LlamaBackend;
1232    /// # use llama_cpp_4::model::{LlamaModel, params::LlamaModelParams};
1233    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1234    /// # let backend = LlamaBackend::init()?;
1235    /// # let model = LlamaModel::load_from_file(&backend, "model.gguf", &LlamaModelParams::default())?;
1236    /// if model.n_layer_nextn() > 0 {
1237    ///     println!("MTP model with {} NextN heads", model.n_layer_nextn());
1238    /// }
1239    /// # Ok(())
1240    /// # }
1241    /// ```
1242    #[must_use]
1243    pub fn n_layer_nextn(&self) -> c_int {
1244        unsafe { llama_model_n_layer_nextn(self.model.as_ptr()) }
1245    }
1246
1247    /// Get the number of mixture-of-experts (`MoE`) layers in the model.
1248    ///
1249    /// Returns `0` for dense (non-MoE) checkpoints.
1250    #[must_use]
1251    pub fn n_expert(&self) -> c_int {
1252        unsafe { llama_model_n_expert(self.model.as_ptr()) }
1253    }
1254
1255    /// Number of backend devices the model tensors are spread across.
1256    ///
1257    /// Use with [`Self::get_device`] to inspect each device. Returns `0` when
1258    /// the model is not yet loaded onto any device.
1259    #[must_use]
1260    pub fn n_devices(&self) -> c_int {
1261        unsafe { llama_model_n_devices(self.model.as_ptr()) }
1262    }
1263
1264    /// Get the backend device at `index`.
1265    ///
1266    /// Valid indices satisfy `0 <= index < n_devices()`. Returns `None` for
1267    /// out-of-range indices or when the device pointer is null.
1268    #[must_use]
1269    pub fn get_device(&self, index: i32) -> Option<LlamaBackendDevice> {
1270        if index < 0 || index >= self.n_devices() {
1271            return None;
1272        }
1273        let dev = unsafe { llama_model_get_device(self.model.as_ptr(), index) };
1274        if dev.is_null() {
1275            None
1276        } else {
1277            Some(LlamaBackendDevice { dev })
1278        }
1279    }
1280
1281    /// Iterate backend devices the model tensors are spread across.
1282    ///
1283    /// Equivalent to calling [`Self::get_device`] for `0..self.n_devices()`.
1284    /// Use [`LlamaBackendDevice::memory`] to inspect free/total bytes per device.
1285    #[must_use]
1286    pub fn devices(&self) -> LlamaBackendDevices<'_> {
1287        LlamaBackendDevices {
1288            model: self,
1289            next: 0,
1290        }
1291    }
1292
1293    /// Target-model layer indices stored in this checkpoint.
1294    ///
1295    /// Populated for EAGLE / distillation draft models that record which target
1296    /// layers they were trained against. Returns an empty slice when the
1297    /// metadata is absent.
1298    #[must_use]
1299    pub fn target_layer_ids(&self) -> &[i32] {
1300        let n = unsafe { llama_model_target_layer_ids_n(self.model.as_ptr()) };
1301        if n == 0 {
1302            return &[];
1303        }
1304        let ptr = unsafe { llama_model_target_layer_ids(self.model.as_ptr()) };
1305        if ptr.is_null() {
1306            &[]
1307        } else {
1308            unsafe { slice::from_raw_parts(ptr, n as usize) }
1309        }
1310    }
1311
1312    /// Get the number of attention heads in the model.
1313    #[must_use]
1314    pub fn n_head(&self) -> c_int {
1315        unsafe { llama_model_n_head(self.model.as_ptr()) }
1316    }
1317
1318    /// Get the number of key-value attention heads in the model.
1319    #[must_use]
1320    pub fn n_head_kv(&self) -> c_int {
1321        unsafe { llama_model_n_head_kv(self.model.as_ptr()) }
1322    }
1323
1324    /// Get the input embedding size of the model.
1325    #[must_use]
1326    pub fn n_embd_inp(&self) -> c_int {
1327        unsafe { llama_model_n_embd_inp(self.model.as_ptr()) }
1328    }
1329
1330    /// Get the output embedding size of the model.
1331    #[must_use]
1332    pub fn n_embd_out(&self) -> c_int {
1333        unsafe { llama_model_n_embd_out(self.model.as_ptr()) }
1334    }
1335
1336    /// Get the sliding window attention size of the model.
1337    /// Returns 0 if the model does not use sliding window attention.
1338    #[must_use]
1339    pub fn n_swa(&self) -> c_int {
1340        unsafe { llama_model_n_swa(self.model.as_ptr()) }
1341    }
1342
1343    /// Get the `RoPE` type used by the model.
1344    #[must_use]
1345    pub fn rope_type(&self) -> i32 {
1346        unsafe { llama_model_rope_type(self.model.as_ptr()) }
1347    }
1348
1349    /// Get the `RoPE` frequency scale used during training.
1350    #[must_use]
1351    pub fn rope_freq_scale_train(&self) -> f32 {
1352        unsafe { llama_model_rope_freq_scale_train(self.model.as_ptr()) }
1353    }
1354
1355    /// Get the model size in bytes.
1356    #[must_use]
1357    pub fn model_size(&self) -> u64 {
1358        unsafe { llama_model_size(self.model.as_ptr()) }
1359    }
1360
1361    /// Get the number of parameters in the model.
1362    #[must_use]
1363    pub fn n_params(&self) -> u64 {
1364        unsafe { llama_model_n_params(self.model.as_ptr()) }
1365    }
1366
1367    /// Get the number of classification outputs.
1368    #[must_use]
1369    pub fn n_cls_out(&self) -> u32 {
1370        unsafe { llama_model_n_cls_out(self.model.as_ptr()) }
1371    }
1372
1373    /// Get the classification label for the given index.
1374    ///
1375    /// # Errors
1376    ///
1377    /// Returns an error if the label is null or not valid UTF-8.
1378    pub fn cls_label(&self, index: u32) -> Result<&str, StringFromModelError> {
1379        let ptr = unsafe { llama_model_cls_label(self.model.as_ptr(), index) };
1380        if ptr.is_null() {
1381            return Err(StringFromModelError::ReturnedError(-1));
1382        }
1383        let cstr = unsafe { CStr::from_ptr(ptr) };
1384        cstr.to_str().map_err(StringFromModelError::Utf8Error)
1385    }
1386
1387    /// Get the number of metadata key-value pairs.
1388    #[must_use]
1389    pub fn meta_count(&self) -> c_int {
1390        unsafe { llama_model_meta_count(self.model.as_ptr()) }
1391    }
1392
1393    /// Get a model description string.
1394    ///
1395    /// The `buf_size` parameter specifies the maximum buffer size for the description.
1396    /// A default of 256 bytes is usually sufficient.
1397    ///
1398    /// # Errors
1399    ///
1400    /// Returns an error if the description could not be retrieved or is not valid UTF-8.
1401    #[allow(clippy::cast_sign_loss)]
1402    pub fn desc(&self, buf_size: usize) -> Result<String, StringFromModelError> {
1403        let mut buf = vec![0u8; buf_size];
1404        let ret = unsafe {
1405            llama_model_desc(
1406                self.model.as_ptr(),
1407                buf.as_mut_ptr().cast::<c_char>(),
1408                buf_size,
1409            )
1410        };
1411        if ret < 0 {
1412            return Err(StringFromModelError::ReturnedError(ret));
1413        }
1414        let len = ret as usize;
1415        let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
1416        Ok(s.to_owned())
1417    }
1418
1419    /// Get a metadata key by index.
1420    ///
1421    /// The `buf_size` parameter specifies the maximum buffer size for the key.
1422    /// A default of 256 bytes is usually sufficient.
1423    ///
1424    /// # Errors
1425    ///
1426    /// Returns an error if the index is out of range or the key is not valid UTF-8.
1427    #[allow(clippy::cast_sign_loss)]
1428    pub fn meta_key_by_index(
1429        &self,
1430        index: i32,
1431        buf_size: usize,
1432    ) -> Result<String, StringFromModelError> {
1433        let mut buf = vec![0u8; buf_size];
1434        let ret = unsafe {
1435            llama_model_meta_key_by_index(
1436                self.model.as_ptr(),
1437                index,
1438                buf.as_mut_ptr().cast::<c_char>(),
1439                buf_size,
1440            )
1441        };
1442        if ret < 0 {
1443            return Err(StringFromModelError::ReturnedError(ret));
1444        }
1445        let len = ret as usize;
1446        let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
1447        Ok(s.to_owned())
1448    }
1449
1450    /// Get a metadata value string by index.
1451    ///
1452    /// The `buf_size` parameter specifies the maximum buffer size for the value.
1453    /// Values can be large (e.g. chat templates, token lists), so 4096+ may be needed.
1454    ///
1455    /// # Errors
1456    ///
1457    /// Returns an error if the index is out of range or the value is not valid UTF-8.
1458    #[allow(clippy::cast_sign_loss)]
1459    pub fn meta_val_str_by_index(
1460        &self,
1461        index: i32,
1462        buf_size: usize,
1463    ) -> Result<String, StringFromModelError> {
1464        let mut buf = vec![0u8; buf_size];
1465        let ret = unsafe {
1466            llama_model_meta_val_str_by_index(
1467                self.model.as_ptr(),
1468                index,
1469                buf.as_mut_ptr().cast::<c_char>(),
1470                buf_size,
1471            )
1472        };
1473        if ret < 0 {
1474            return Err(StringFromModelError::ReturnedError(ret));
1475        }
1476        let len = ret as usize;
1477        let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
1478        Ok(s.to_owned())
1479    }
1480
1481    /// Get a metadata value by key name.
1482    ///
1483    /// This is more convenient than iterating metadata by index when you know the key.
1484    /// The `buf_size` parameter specifies the maximum buffer size for the value.
1485    ///
1486    /// # Errors
1487    ///
1488    /// Returns an error if the key is not found, contains a null byte, or the value is not valid UTF-8.
1489    #[allow(clippy::cast_sign_loss)]
1490    pub fn meta_val_str(&self, key: &str, buf_size: usize) -> Result<String, StringFromModelError> {
1491        let c_key = CString::new(key).map_err(|_| StringFromModelError::ReturnedError(-1))?;
1492        let mut buf = vec![0u8; buf_size];
1493        let ret = unsafe {
1494            llama_model_meta_val_str(
1495                self.model.as_ptr(),
1496                c_key.as_ptr(),
1497                buf.as_mut_ptr().cast::<c_char>(),
1498                buf_size,
1499            )
1500        };
1501        if ret < 0 {
1502            return Err(StringFromModelError::ReturnedError(ret));
1503        }
1504        let len = ret as usize;
1505        let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
1506        Ok(s.to_owned())
1507    }
1508
1509    /// Get all metadata as a list of `(key, value)` pairs.
1510    ///
1511    /// This is a convenience method that iterates over all metadata entries.
1512    /// Keys use a buffer of 256 bytes and values use 4096 bytes.
1513    /// For values that may be larger (e.g. token lists), use
1514    /// [`meta_val_str_by_index`](Self::meta_val_str_by_index) directly with a larger buffer.
1515    ///
1516    /// # Errors
1517    ///
1518    /// Returns an error if any key or value cannot be read or is not valid UTF-8.
1519    #[allow(clippy::cast_sign_loss)]
1520    pub fn metadata(&self) -> Result<Vec<(String, String)>, StringFromModelError> {
1521        let count = self.meta_count();
1522        let mut result = Vec::with_capacity(count as usize);
1523        for i in 0..count {
1524            let key = self.meta_key_by_index(i, 256)?;
1525            let val = self.meta_val_str_by_index(i, 4096)?;
1526            result.push((key, val));
1527        }
1528        Ok(result)
1529    }
1530
1531    /// Check if the model has an encoder.
1532    #[must_use]
1533    pub fn has_encoder(&self) -> bool {
1534        unsafe { llama_model_has_encoder(self.model.as_ptr()) }
1535    }
1536
1537    /// Check if the model has a decoder.
1538    #[must_use]
1539    pub fn has_decoder(&self) -> bool {
1540        unsafe { llama_model_has_decoder(self.model.as_ptr()) }
1541    }
1542
1543    /// Check if the model is recurrent (e.g. Mamba, RWKV).
1544    #[must_use]
1545    pub fn is_recurrent(&self) -> bool {
1546        unsafe { llama_model_is_recurrent(self.model.as_ptr()) }
1547    }
1548
1549    /// Check if the model is a hybrid model.
1550    #[must_use]
1551    pub fn is_hybrid(&self) -> bool {
1552        unsafe { llama_model_is_hybrid(self.model.as_ptr()) }
1553    }
1554
1555    /// Check if the model is a diffusion model.
1556    #[must_use]
1557    pub fn is_diffusion(&self) -> bool {
1558        unsafe { llama_model_is_diffusion(self.model.as_ptr()) }
1559    }
1560
1561    /// Get chat template from model.
1562    ///
1563    /// # Errors
1564    ///
1565    /// - If the model does not have a chat template, it will return an error.
1566    /// - If the chat template is not a valid `CString`, it will return an error.
1567    ///
1568    /// # Example
1569    ///
1570    /// ```no_run
1571    /// use llama_cpp_4::model::LlamaModel;
1572    /// use llama_cpp_4::model::params::LlamaModelParams;
1573    /// use llama_cpp_4::llama_backend::LlamaBackend;
1574    ///
1575    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1576    /// let backend = LlamaBackend::init()?;
1577    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1578    /// let chat_template = model.get_chat_template(1024)?;
1579    /// # Ok(())
1580    /// # }
1581    /// ```
1582    #[allow(clippy::missing_panics_doc)] // We statically know this will not panic as long as the buffer size is sufficient
1583    pub fn get_chat_template(&self, buf_size: usize) -> Result<String, ChatTemplateError> {
1584        // longest known template is about 1200 bytes from llama.cpp
1585        let chat_temp = CString::new(vec![b'*'; buf_size]).expect("no null");
1586        let chat_ptr = chat_temp.into_raw();
1587        let chat_name = CString::new("tokenizer.chat_template").expect("no null bytes");
1588
1589        let ret = unsafe {
1590            llama_model_meta_val_str(self.model.as_ptr(), chat_name.as_ptr(), chat_ptr, buf_size)
1591        };
1592
1593        if ret < 0 {
1594            return Err(ChatTemplateError::MissingTemplate(ret));
1595        }
1596
1597        let template_c = unsafe { CString::from_raw(chat_ptr) };
1598        let template = template_c.to_str()?;
1599
1600        let ret: usize = ret.try_into().unwrap();
1601        if template.len() < ret {
1602            return Err(ChatTemplateError::BuffSizeError(ret + 1));
1603        }
1604
1605        Ok(template.to_owned())
1606    }
1607
1608    /// Loads a model from a file.
1609    ///
1610    /// This function loads a model from a specified file path and returns the corresponding `LlamaModel` instance.
1611    ///
1612    /// # Errors
1613    ///
1614    /// - If the path cannot be converted to a string or if the model file does not exist, it will return an error.
1615    /// - If the model cannot be loaded (e.g., due to an invalid or corrupted model file), it will return a `LlamaModelLoadError`.
1616    ///
1617    /// # Example
1618    ///
1619    /// ```no_run
1620    /// use llama_cpp_4::model::LlamaModel;
1621    /// use llama_cpp_4::model::params::LlamaModelParams;
1622    /// use llama_cpp_4::llama_backend::LlamaBackend;
1623    ///
1624    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1625    /// let backend = LlamaBackend::init()?;
1626    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1627    /// # Ok(())
1628    /// # }
1629    /// ```
1630    #[tracing::instrument(skip_all, fields(params))]
1631    pub fn load_from_file(
1632        _: &LlamaBackend,
1633        path: impl AsRef<Path>,
1634        params: &LlamaModelParams,
1635    ) -> Result<Self, LlamaModelLoadError> {
1636        let path = path.as_ref();
1637        debug_assert!(
1638            Path::new(path).exists(),
1639            "{} does not exist",
1640            path.display()
1641        );
1642        let path = path
1643            .to_str()
1644            .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
1645
1646        let cstr = CString::new(path)?;
1647        let llama_model = unsafe { llama_model_load_from_file(cstr.as_ptr(), params.params) };
1648
1649        let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
1650
1651        tracing::debug!(?path, "Loaded model");
1652        Ok(LlamaModel { model })
1653    }
1654
1655    /// Load a model from multiple split files.
1656    ///
1657    /// This function loads a model that has been split across multiple files. This is useful for
1658    /// very large models that exceed filesystem limitations or need to be distributed across
1659    /// multiple storage devices.
1660    ///
1661    /// # Arguments
1662    ///
1663    /// * `paths` - A slice of paths to the split model files
1664    /// * `params` - The model parameters
1665    ///
1666    /// # Errors
1667    ///
1668    /// Returns an error if:
1669    /// - Any of the paths cannot be converted to a C string
1670    /// - The model fails to load from the splits
1671    /// - Any path doesn't exist or isn't accessible
1672    ///
1673    /// # Example
1674    ///
1675    /// ```no_run
1676    /// use llama_cpp_4::model::{LlamaModel, params::LlamaModelParams};
1677    /// use llama_cpp_4::llama_backend::LlamaBackend;
1678    ///
1679    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1680    /// let backend = LlamaBackend::init()?;
1681    /// let params = LlamaModelParams::default();
1682    ///
1683    /// let paths = vec![
1684    ///     "model-00001-of-00003.gguf",
1685    ///     "model-00002-of-00003.gguf",
1686    ///     "model-00003-of-00003.gguf",
1687    /// ];
1688    ///
1689    /// let model = LlamaModel::load_from_splits(&backend, &paths, &params)?;
1690    /// # Ok(())
1691    /// # }
1692    /// ```
1693    #[tracing::instrument(skip_all)]
1694    pub fn load_from_splits(
1695        _: &LlamaBackend,
1696        paths: &[impl AsRef<Path>],
1697        params: &LlamaModelParams,
1698    ) -> Result<Self, LlamaModelLoadError> {
1699        // Convert paths to C strings
1700        let c_strings: Vec<CString> = paths
1701            .iter()
1702            .map(|p| {
1703                let path = p.as_ref();
1704                debug_assert!(path.exists(), "{} does not exist", path.display());
1705                let path_str = path
1706                    .to_str()
1707                    .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
1708                CString::new(path_str).map_err(LlamaModelLoadError::from)
1709            })
1710            .collect::<Result<Vec<_>, _>>()?;
1711
1712        // Create array of pointers to C strings
1713        let c_ptrs: Vec<*const c_char> = c_strings.iter().map(|s| s.as_ptr()).collect();
1714
1715        // Load the model from splits
1716        let llama_model = unsafe {
1717            llama_model_load_from_splits(c_ptrs.as_ptr().cast_mut(), c_ptrs.len(), params.params)
1718        };
1719
1720        let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
1721
1722        tracing::debug!("Loaded model from {} splits", paths.len());
1723        Ok(LlamaModel { model })
1724    }
1725
1726    /// Load a model from a `FILE` pointer.
1727    ///
1728    /// # Safety
1729    ///
1730    /// The `file` pointer must be a valid, open `FILE*`.
1731    ///
1732    /// # Errors
1733    ///
1734    /// Returns an error if the model cannot be loaded.
1735    pub unsafe fn load_from_file_ptr(
1736        file: *mut llama_cpp_sys_4::FILE,
1737        params: &LlamaModelParams,
1738    ) -> Result<Self, LlamaModelLoadError> {
1739        let model = llama_cpp_sys_4::llama_model_load_from_file_ptr(file, params.params);
1740        let model = NonNull::new(model).ok_or(LlamaModelLoadError::NullResult)?;
1741        Ok(LlamaModel { model })
1742    }
1743
1744    /// Initialize a model from user-provided data.
1745    ///
1746    /// # Safety
1747    ///
1748    /// The metadata, callback, and user data must be valid.
1749    ///
1750    /// # Errors
1751    ///
1752    /// Returns an error if the model cannot be initialized.
1753    pub unsafe fn init_from_user(
1754        metadata: *mut llama_cpp_sys_4::gguf_context,
1755        set_tensor_data: llama_cpp_sys_4::llama_model_set_tensor_data_t,
1756        set_tensor_data_ud: *mut std::ffi::c_void,
1757        params: &LlamaModelParams,
1758    ) -> Result<Self, LlamaModelLoadError> {
1759        let model = llama_cpp_sys_4::llama_model_init_from_user(
1760            metadata,
1761            set_tensor_data,
1762            set_tensor_data_ud,
1763            params.params,
1764        );
1765        let model = NonNull::new(model).ok_or(LlamaModelLoadError::NullResult)?;
1766        Ok(LlamaModel { model })
1767    }
1768
1769    /// Save the model to a file.
1770    ///
1771    /// # Panics
1772    ///
1773    /// Panics if the path contains null bytes.
1774    pub fn save_to_file(&self, path: impl AsRef<Path>) {
1775        let path = path.as_ref();
1776        let path_str = path.to_str().expect("path is not valid UTF-8");
1777        let c_path = CString::new(path_str).expect("path contains null bytes");
1778        unsafe {
1779            llama_model_save_to_file(self.model.as_ptr(), c_path.as_ptr());
1780        }
1781    }
1782
1783    /// Get the list of built-in chat templates.
1784    ///
1785    /// Returns the names of all chat templates that are built into llama.cpp.
1786    ///
1787    /// # Panics
1788    ///
1789    /// Panics if any template name is not valid UTF-8.
1790    #[allow(clippy::cast_sign_loss)]
1791    #[must_use]
1792    pub fn chat_builtin_templates() -> Vec<String> {
1793        // First call to get count
1794        let count = unsafe { llama_chat_builtin_templates(std::ptr::null_mut(), 0) };
1795        if count <= 0 {
1796            return Vec::new();
1797        }
1798        let count = count as usize;
1799        let mut ptrs: Vec<*const c_char> = vec![std::ptr::null(); count];
1800        unsafe {
1801            llama_chat_builtin_templates(ptrs.as_mut_ptr(), count);
1802        }
1803        ptrs.iter()
1804            .map(|&p| {
1805                let cstr = unsafe { CStr::from_ptr(p) };
1806                cstr.to_str()
1807                    .expect("template name is not valid UTF-8")
1808                    .to_owned()
1809            })
1810            .collect()
1811    }
1812
1813    /// Initializes a lora adapter from a file.
1814    ///
1815    /// This function initializes a Lora adapter, which is a model extension used to adapt or fine-tune the existing model
1816    /// to a specific domain or task. The adapter file is typically in the form of a binary or serialized file that can be applied
1817    /// to the model for improved performance on specialized tasks.
1818    ///
1819    /// # Errors
1820    ///
1821    /// - If the adapter file path cannot be converted to a string or if the adapter cannot be initialized, it will return an error.
1822    ///
1823    /// # Example
1824    ///
1825    /// ```no_run
1826    /// use llama_cpp_4::model::LlamaModel;
1827    /// use llama_cpp_4::model::params::LlamaModelParams;
1828    /// use llama_cpp_4::llama_backend::LlamaBackend;
1829    ///
1830    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1831    /// let backend = LlamaBackend::init()?;
1832    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1833    /// let adapter = model.lora_adapter_init("path/to/lora/adapter")?;
1834    /// # Ok(())
1835    /// # }
1836    /// ```
1837    pub fn lora_adapter_init(
1838        &self,
1839        path: impl AsRef<Path>,
1840    ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
1841        let path = path.as_ref();
1842        debug_assert!(
1843            Path::new(path).exists(),
1844            "{} does not exist",
1845            path.display()
1846        );
1847
1848        let path = path
1849            .to_str()
1850            .ok_or(LlamaLoraAdapterInitError::PathToStrError(
1851                path.to_path_buf(),
1852            ))?;
1853
1854        let cstr = CString::new(path)?;
1855        let adapter = unsafe { llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr()) };
1856
1857        let adapter = NonNull::new(adapter).ok_or(LlamaLoraAdapterInitError::NullResult)?;
1858
1859        tracing::debug!(?path, "Initialized lora adapter");
1860        Ok(LlamaLoraAdapter {
1861            lora_adapter: adapter,
1862        })
1863    }
1864
1865    /// Create a new context from this model.
1866    ///
1867    /// This function creates a new context for the model, which is used to manage and perform computations for inference,
1868    /// including token generation, embeddings, and other tasks that the model can perform. The context allows fine-grained
1869    /// control over model parameters for a specific task.
1870    ///
1871    /// # Errors
1872    ///
1873    /// - There are various potential failures such as invalid parameters or a failure to allocate the context. See [`LlamaContextLoadError`]
1874    ///   for more detailed error descriptions.
1875    ///
1876    /// # Example
1877    ///
1878    /// ```no_run
1879    /// use llama_cpp_4::model::LlamaModel;
1880    /// use llama_cpp_4::model::params::LlamaModelParams;
1881    /// use llama_cpp_4::context::params::LlamaContextParams;
1882    /// use llama_cpp_4::llama_backend::LlamaBackend;
1883    ///
1884    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1885    /// let backend = LlamaBackend::init()?;
1886    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1887    /// let context = model.new_context(&backend, LlamaContextParams::default())?;
1888    /// # Ok(())
1889    /// # }
1890    /// ```
1891    #[allow(clippy::needless_pass_by_value)]
1892    pub fn new_context(
1893        &self,
1894        _: &LlamaBackend,
1895        params: LlamaContextParams,
1896    ) -> Result<LlamaContext<'_>, LlamaContextLoadError> {
1897        // Apply TurboQuant attn-rotation preference before the KV cache is
1898        // initialised inside llama_init_from_model.
1899        let prev_rot_var = std::env::var("LLAMA_ATTN_ROT_DISABLE").ok();
1900        if params.attn_rot_disabled {
1901            // SAFETY: we restore the value right after the call.
1902            #[allow(unused_unsafe)]
1903            unsafe {
1904                std::env::set_var("LLAMA_ATTN_ROT_DISABLE", "1");
1905            }
1906        } else if std::env::var("LLAMA_ATTN_ROT_DISABLE").is_ok() {
1907            // params say "enabled" – only clear if it was previously unset
1908            // (respect explicit user env var).
1909        }
1910
1911        let context_params = params.context_params;
1912        let context = unsafe { llama_init_from_model(self.model.as_ptr(), context_params) };
1913
1914        // Restore the env-var to its previous state.
1915        #[allow(unused_unsafe)]
1916        match prev_rot_var {
1917            Some(v) => unsafe { std::env::set_var("LLAMA_ATTN_ROT_DISABLE", v) },
1918            None if params.attn_rot_disabled => unsafe {
1919                std::env::remove_var("LLAMA_ATTN_ROT_DISABLE");
1920            },
1921            None => {}
1922        }
1923
1924        let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;
1925        Ok(LlamaContext::new(self, context, params.embeddings()))
1926    }
1927
1928    /// Apply the model's chat template to a sequence of messages.
1929    ///
1930    /// This function applies the model's chat template to the provided chat messages, formatting them accordingly. The chat
1931    /// template determines the structure or style of conversation between the system and user, such as token formatting,
1932    /// role separation, and more. The template can be customized by providing an optional template string, or if `None`
1933    /// is provided, the default template used by `llama.cpp` will be applied.
1934    ///
1935    /// For more information on supported templates, visit:
1936    /// <https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template>
1937    ///
1938    /// # Arguments
1939    ///
1940    /// - `tmpl`: An optional custom template string. If `None`, the default template will be used.
1941    /// - `chat`: A vector of `LlamaChatMessage` instances, which represent the conversation between the system and user.
1942    /// - `add_ass`: A boolean flag indicating whether additional system-specific instructions (like "assistant") should be included.
1943    ///
1944    /// # Errors
1945    ///
1946    /// There are several possible points of failure when applying the chat template:
1947    /// - Insufficient buffer size to hold the formatted chat (this will return `ApplyChatTemplateError::BuffSizeError`).
1948    /// - If the template or messages cannot be processed properly, various errors from `ApplyChatTemplateError` may occur.
1949    ///
1950    /// # Example
1951    ///
1952    /// ```no_run
1953    /// use llama_cpp_4::model::{LlamaModel, LlamaChatMessage};
1954    /// use llama_cpp_4::model::params::LlamaModelParams;
1955    /// use llama_cpp_4::llama_backend::LlamaBackend;
1956    ///
1957    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1958    /// let backend = LlamaBackend::init()?;
1959    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1960    /// let chat = vec![
1961    ///     LlamaChatMessage::new("user".to_string(), "Hello!".to_string())?,
1962    ///     LlamaChatMessage::new("assistant".to_string(), "Hi! How can I assist you today?".to_string())?,
1963    /// ];
1964    /// let formatted_chat = model.apply_chat_template(None, &chat, true)?;
1965    /// # Ok(())
1966    /// # }
1967    /// ```
1968    ///
1969    /// # Notes
1970    ///
1971    /// The provided buffer is twice the length of the messages by default, which is recommended by the `llama.cpp` documentation.
1972    /// # Panics
1973    ///
1974    /// Panics if the buffer length exceeds `i32::MAX`.
1975    #[tracing::instrument(skip_all)]
1976    pub fn apply_chat_template(
1977        &self,
1978        tmpl: Option<&str>,
1979        chat: &[LlamaChatMessage],
1980        add_ass: bool,
1981    ) -> Result<String, ApplyChatTemplateError> {
1982        // Compute raw message byte total from the original LlamaChatMessage vec
1983        // *before* we shadow `chat` with the sys-type vec below.
1984        let message_length = chat.iter().fold(0usize, |acc, c| {
1985            acc + c.role.to_bytes().len() + c.content.to_bytes().len()
1986        });
1987
1988        // Build our llama_cpp_sys chat messages (raw pointers into CStrings).
1989        let chat_sys: Vec<llama_chat_message> = chat
1990            .iter()
1991            .map(|c| llama_chat_message {
1992                role: c.role.as_ptr(),
1993                content: c.content.as_ptr(),
1994            })
1995            .collect();
1996
1997        // Set the tmpl pointer.
1998        let tmpl_cstring = tmpl.map(CString::new).transpose()?;
1999        let tmpl_ptr = tmpl_cstring
2000            .as_ref()
2001            .map_or(std::ptr::null(), |s| s.as_ptr());
2002
2003        // `message_length * 4` is far too small for models whose built-in chat
2004        // template adds a long default system prompt (e.g. Qwen3.5 prepends
2005        // ~80+ chars of markup even for a one-word user message).  Start with
2006        // at least 4 KiB so short inputs like "hi" always have room.
2007        //
2008        // `llama_chat_apply_template` returns the number of bytes it *actually*
2009        // needed when the buffer was too small, so we retry exactly once with
2010        // that precise size rather than giving up immediately.
2011        let mut buf_size = message_length.saturating_mul(4).max(4096);
2012
2013        for _ in 0..2 {
2014            // Use u8 so that as_mut_ptr()/as_ptr() match the binding (*mut u8 / *const u8).
2015            let mut buff = vec![0u8; buf_size];
2016            let res = unsafe {
2017                llama_chat_apply_template(
2018                    tmpl_ptr,
2019                    chat_sys.as_ptr(),
2020                    chat_sys.len(),
2021                    add_ass,
2022                    buff.as_mut_ptr().cast(),
2023                    i32::try_from(buff.len()).expect("buffer length fits in i32"),
2024                )
2025            };
2026
2027            if res < 0 {
2028                return Err(ApplyChatTemplateError::BuffSizeError);
2029            }
2030
2031            #[allow(clippy::cast_sign_loss)]
2032            let needed = res as usize;
2033            if needed > buf_size {
2034                // Buffer was too small — retry with the exact size llama.cpp reported.
2035                buf_size = needed + 1; // +1 for null terminator
2036                continue;
2037            }
2038
2039            // SAFETY: llama_chat_apply_template wrote a NUL-terminated string
2040            // into `buff`; `needed` bytes were used.
2041            let formatted = unsafe {
2042                CStr::from_ptr(buff.as_ptr().cast())
2043                    .to_string_lossy()
2044                    .into_owned()
2045            };
2046            return Ok(formatted);
2047        }
2048
2049        Err(ApplyChatTemplateError::BuffSizeError)
2050    }
2051
2052    /// Build a split GGUF file path for a specific chunk.
2053    ///
2054    /// This utility function creates the standardized filename for a split model chunk
2055    /// following the pattern: `{prefix}-{split_no:05d}-of-{split_count:05d}.gguf`
2056    ///
2057    /// # Arguments
2058    ///
2059    /// * `path_prefix` - The base path and filename prefix
2060    /// * `split_no` - The split number (1-indexed)
2061    /// * `split_count` - The total number of splits
2062    ///
2063    /// # Returns
2064    ///
2065    /// Returns the formatted split path as a String
2066    ///
2067    /// # Example
2068    ///
2069    /// ```
2070    /// use llama_cpp_4::model::LlamaModel;
2071    ///
2072    /// let path = LlamaModel::split_path("/models/llama", 1, 4);
2073    /// assert_eq!(path, "/models/llama-00002-of-00004.gguf");
2074    /// ```
2075    ///
2076    /// # Panics
2077    ///
2078    /// Panics if the path prefix contains a null byte.
2079    #[must_use]
2080    pub fn split_path(path_prefix: &str, split_no: i32, split_count: i32) -> String {
2081        let mut buffer = vec![0u8; 1024];
2082        let len = unsafe {
2083            llama_split_path(
2084                buffer.as_mut_ptr().cast::<c_char>(),
2085                buffer.len(),
2086                CString::new(path_prefix).unwrap().as_ptr(),
2087                split_no,
2088                split_count,
2089            )
2090        };
2091
2092        let len = usize::try_from(len).expect("split_path length fits in usize");
2093        buffer.truncate(len);
2094        String::from_utf8(buffer).unwrap_or_default()
2095    }
2096
2097    /// Extract the path prefix from a split filename.
2098    ///
2099    /// This function extracts the base path prefix from a split model filename,
2100    /// but only if the `split_no` and `split_count` match the pattern in the filename.
2101    ///
2102    /// # Arguments
2103    ///
2104    /// * `split_path` - The full path to the split file
2105    /// * `split_no` - The expected split number
2106    /// * `split_count` - The expected total number of splits
2107    ///
2108    /// # Returns
2109    ///
2110    /// Returns the path prefix if the pattern matches, or None if it doesn't
2111    ///
2112    /// # Example
2113    ///
2114    /// ```
2115    /// use llama_cpp_4::model::LlamaModel;
2116    ///
2117    /// let prefix = LlamaModel::split_prefix("/models/llama-00002-of-00004.gguf", 1, 4);
2118    /// assert_eq!(prefix, Some("/models/llama".to_string()));
2119    /// ```
2120    ///
2121    /// # Panics
2122    ///
2123    /// Panics if the split path contains a null byte.
2124    #[must_use]
2125    pub fn split_prefix(split_path: &str, split_no: i32, split_count: i32) -> Option<String> {
2126        let mut buffer = vec![0u8; 1024];
2127        let len = unsafe {
2128            llama_split_prefix(
2129                buffer.as_mut_ptr().cast::<c_char>(),
2130                buffer.len(),
2131                CString::new(split_path).unwrap().as_ptr(),
2132                split_no,
2133                split_count,
2134            )
2135        };
2136
2137        if len > 0 {
2138            let len = usize::try_from(len).expect("split_prefix length fits in usize");
2139            buffer.truncate(len);
2140            String::from_utf8(buffer).ok()
2141        } else {
2142            None
2143        }
2144    }
2145}
2146
2147#[allow(clippy::cast_precision_loss)]
2148impl fmt::Display for LlamaModel {
2149    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2150        let desc = self.desc(256).unwrap_or_else(|_| "unknown".to_string());
2151        write!(
2152            f,
2153            "{desc} | {layers}L {heads}H {embd}E | {params} params | {size:.1} MiB",
2154            layers = self.n_layer(),
2155            heads = self.n_head(),
2156            embd = self.n_embd(),
2157            params = self.n_params(),
2158            size = self.model_size() as f64 / (1024.0 * 1024.0),
2159        )
2160    }
2161}
2162
2163impl Drop for LlamaModel {
2164    fn drop(&mut self) {
2165        unsafe { llama_model_free(self.model.as_ptr()) }
2166    }
2167}
2168
2169/// Defines the possible types of vocabulary used by the model.
2170///
2171/// The model may use different types of vocabulary depending on the tokenization method chosen during training.
2172/// This enum represents these types, specifically `BPE` (Byte Pair Encoding) and `SPM` (`SentencePiece`).
2173///
2174/// # Variants
2175///
2176/// - `BPE`: Byte Pair Encoding, a common tokenization method used in NLP tasks.
2177/// - `SPM`: `SentencePiece`, another popular tokenization method for NLP models.
2178///
2179/// # Example
2180///
2181/// ```no_run
2182/// use llama_cpp_4::model::VocabType;
2183///
2184/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
2185/// let vocab_type = VocabType::BPE;
2186/// match vocab_type {
2187///     VocabType::BPE => println!("The model uses Byte Pair Encoding (BPE)"),
2188///     VocabType::SPM => println!("The model uses SentencePiece (SPM)"),
2189/// }
2190/// # Ok(())
2191/// # }
2192/// ```
2193#[repr(u32)]
2194#[derive(Debug, Eq, Copy, Clone, PartialEq)]
2195pub enum VocabType {
2196    /// Byte Pair Encoding
2197    BPE = LLAMA_VOCAB_TYPE_BPE as _,
2198    /// Sentence Piece Tokenizer
2199    SPM = LLAMA_VOCAB_TYPE_SPM as _,
2200}
2201
2202/// Error that occurs when trying to convert a `llama_vocab_type` to a `VocabType`.
2203///
2204/// This error is raised when the integer value returned by the system does not correspond to a known vocabulary type.
2205///
2206/// # Variants
2207///
2208/// - `UnknownValue`: The error is raised when the value is not a valid `llama_vocab_type`. The invalid value is returned with the error.
2209///
2210/// # Example
2211///
2212/// ```no_run
2213/// use llama_cpp_4::model::LlamaTokenTypeFromIntError;
2214///
2215/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
2216/// let invalid_value = 999; // Not a valid vocabulary type
2217/// let error = LlamaTokenTypeFromIntError::UnknownValue(invalid_value);
2218/// println!("Error: {}", error);
2219/// # Ok(())
2220/// # }
2221/// ```
2222#[derive(thiserror::Error, Debug, Eq, PartialEq)]
2223pub enum LlamaTokenTypeFromIntError {
2224    /// The value is not a valid `llama_token_type`. Contains the int value that was invalid.
2225    #[error("Unknown Value {0}")]
2226    UnknownValue(llama_vocab_type),
2227}
2228
2229impl TryFrom<llama_vocab_type> for VocabType {
2230    type Error = LlamaTokenTypeFromIntError;
2231
2232    fn try_from(value: llama_vocab_type) -> Result<Self, Self::Error> {
2233        match value {
2234            LLAMA_VOCAB_TYPE_BPE => Ok(VocabType::BPE),
2235            LLAMA_VOCAB_TYPE_SPM => Ok(VocabType::SPM),
2236            unknown => Err(LlamaTokenTypeFromIntError::UnknownValue(unknown)),
2237        }
2238    }
2239}