// llama_cpp_4/model.rs
1//! A safe wrapper around `llama_model`.
2use std::ffi::CStr;
3use std::ffi::CString;
4use std::fmt;
5use std::num::NonZeroU16;
6use std::os::raw::{c_char, c_int};
7use std::path::Path;
8use std::ptr::NonNull;
9
10use llama_cpp_sys_4::{
11    llama_adapter_lora, llama_adapter_lora_init, llama_add_bos_token, llama_add_eos_token,
12    llama_chat_apply_template, llama_chat_builtin_templates, llama_chat_message,
13    llama_detokenize, llama_free_model, llama_load_model_from_file, llama_model,
14    llama_model_cls_label, llama_model_decoder_start_token, llama_model_desc,
15    llama_model_get_vocab, llama_model_has_decoder, llama_model_has_encoder,
16    llama_model_is_diffusion, llama_model_is_hybrid, llama_model_is_recurrent,
17    llama_model_load_from_splits, llama_model_meta_count, llama_model_meta_key_by_index,
18    llama_model_meta_val_str, llama_model_meta_val_str_by_index, llama_model_n_cls_out,
19    llama_model_n_embd_inp, llama_model_n_embd_out, llama_model_n_head_kv, llama_model_n_params,
20    llama_model_n_swa, llama_model_rope_freq_scale_train, llama_model_rope_type,
21    llama_model_save_to_file, llama_model_size, llama_n_ctx_train, llama_n_embd, llama_n_head,
22    llama_n_layer, llama_n_vocab, llama_new_context_with_model, llama_split_path,
23    llama_split_prefix, llama_token_bos, llama_token_cls, llama_token_eos, llama_token_eot,
24    llama_token_fim_mid, llama_token_fim_pad, llama_token_fim_pre, llama_token_fim_rep,
25    llama_token_fim_sep, llama_token_fim_suf, llama_token_get_attr, llama_token_get_score,
26    llama_token_get_text, llama_token_is_control, llama_token_is_eog, llama_token_nl,
27    llama_token_pad, llama_token_sep, llama_token_to_piece, llama_tokenize, llama_vocab,
28    llama_vocab_type, LLAMA_VOCAB_TYPE_BPE, LLAMA_VOCAB_TYPE_SPM,
29};
30
31use crate::context::params::LlamaContextParams;
32use crate::context::LlamaContext;
33use crate::llama_backend::LlamaBackend;
34use crate::model::params::LlamaModelParams;
35use crate::token::LlamaToken;
36use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
37use crate::{
38    ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
39    LlamaModelLoadError, NewLlamaChatMessageError, StringFromModelError, StringToTokenError,
40    TokenToStringError,
41};
42
43pub mod params;
44
/// A safe wrapper around `llama_model`.
///
/// Holds a non-null pointer to the C-side model. `Send`/`Sync` are implemented
/// further down in this file; the model is freed elsewhere (the `Drop` impl is
/// not in this view).
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaModel {
    // Non-null pointer to the underlying `llama_model`; `repr(transparent)`
    // keeps this struct layout-identical to the raw pointer.
    pub(crate) model: NonNull<llama_model>,
}
52
/// A safe wrapper around `llama_vocab`.
///
/// The vocabulary is owned by the model it was obtained from; this wrapper
/// does not free it (no `Drop` impl is visible for it in this file).
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaVocab {
    // Non-null pointer into the model's vocabulary data.
    pub(crate) vocab: NonNull<llama_vocab>,
}
60
impl LlamaVocab {
    /// Get the number of tokens in the vocabulary.
    #[must_use]
    pub fn n_tokens(&self) -> i32 {
        unsafe { llama_cpp_sys_4::llama_vocab_n_tokens(self.vocab.as_ref()) }
    }

    /// Get the vocabulary type as the raw C enum value (e.g. SPM, BPE).
    ///
    /// # Panics
    ///
    /// Panics if the value returned by the C API does not convert into a `u32`
    /// (e.g. a negative enum value, if the sys binding uses a signed type).
    #[must_use]
    pub fn vocab_type(&self) -> u32 {
        unsafe { llama_cpp_sys_4::llama_vocab_type(self.vocab.as_ref()).try_into().unwrap() }
    }

    /// Get the BOS (beginning of stream) token.
    ///
    /// NOTE(review): if the model defines no such token, llama.cpp typically
    /// returns a negative sentinel token id — TODO confirm against the sys crate.
    #[must_use]
    pub fn bos(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_bos(self.vocab.as_ref()) })
    }

    /// Get the EOS (end of stream) token.
    #[must_use]
    pub fn eos(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_eos(self.vocab.as_ref()) })
    }

    /// Get the EOT (end of turn) token.
    #[must_use]
    pub fn eot(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_eot(self.vocab.as_ref()) })
    }

    /// Get the CLS (classification) token.
    #[must_use]
    pub fn cls(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_cls(self.vocab.as_ref()) })
    }

    /// Get the SEP (separator) token.
    #[must_use]
    pub fn sep(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_sep(self.vocab.as_ref()) })
    }

    /// Get the NL (newline) token.
    #[must_use]
    pub fn nl(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_nl(self.vocab.as_ref()) })
    }

    /// Get the PAD (padding) token.
    #[must_use]
    pub fn pad(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_pad(self.vocab.as_ref()) })
    }

    /// Get the FIM (fill-in-the-middle) prefix token.
    #[must_use]
    pub fn fim_pre(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_pre(self.vocab.as_ref()) })
    }

    /// Get the FIM suffix token.
    #[must_use]
    pub fn fim_suf(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_suf(self.vocab.as_ref()) })
    }

    /// Get the FIM middle token.
    #[must_use]
    pub fn fim_mid(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_mid(self.vocab.as_ref()) })
    }

    /// Get the FIM padding token.
    #[must_use]
    pub fn fim_pad(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_pad(self.vocab.as_ref()) })
    }

    /// Get the FIM repository token.
    #[must_use]
    pub fn fim_rep(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_rep(self.vocab.as_ref()) })
    }

    /// Get the FIM separator token.
    #[must_use]
    pub fn fim_sep(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_sep(self.vocab.as_ref()) })
    }

    /// Check whether a BOS token should be added when tokenizing.
    #[must_use]
    pub fn get_add_bos(&self) -> bool {
        unsafe { llama_cpp_sys_4::llama_vocab_get_add_bos(self.vocab.as_ref()) }
    }

    /// Check whether an EOS token should be added when tokenizing.
    #[must_use]
    pub fn get_add_eos(&self) -> bool {
        unsafe { llama_cpp_sys_4::llama_vocab_get_add_eos(self.vocab.as_ref()) }
    }

    /// Check whether a SEP token should be added when tokenizing.
    #[must_use]
    pub fn get_add_sep(&self) -> bool {
        unsafe { llama_cpp_sys_4::llama_vocab_get_add_sep(self.vocab.as_ref()) }
    }

    /// Get the text representation of a token.
    ///
    /// # Errors
    ///
    /// Returns an error if the text pointer is null or not valid UTF-8.
    /// NOTE(review): `-1` is a synthetic error code used for the null-pointer
    /// case; it does not originate from llama.cpp.
    pub fn get_text(&self, token: LlamaToken) -> Result<&str, StringFromModelError> {
        let ptr = unsafe { llama_cpp_sys_4::llama_vocab_get_text(self.vocab.as_ref(), token.0) };
        if ptr.is_null() {
            return Err(StringFromModelError::ReturnedError(-1));
        }
        // SAFETY: `ptr` is non-null; llama.cpp returns NUL-terminated token text
        // owned by the model, valid for at least as long as `self`.
        let cstr = unsafe { CStr::from_ptr(ptr) };
        cstr.to_str().map_err(StringFromModelError::Utf8Error)
    }

    /// Get the score of a token.
    #[must_use]
    pub fn get_score(&self, token: LlamaToken) -> f32 {
        unsafe { llama_cpp_sys_4::llama_vocab_get_score(self.vocab.as_ref(), token.0) }
    }

    /// Get the raw attribute bitflags of a token.
    ///
    /// # Panics
    ///
    /// Panics if the attribute value returned by the C API does not convert
    /// into a `u32`.
    #[must_use]
    pub fn get_attr(&self, token: LlamaToken) -> u32 {
        unsafe { llama_cpp_sys_4::llama_vocab_get_attr(self.vocab.as_ref(), token.0).try_into().unwrap() }
    }

    /// Check if a token is a control token.
    #[must_use]
    pub fn is_control(&self, token: LlamaToken) -> bool {
        unsafe { llama_cpp_sys_4::llama_vocab_is_control(self.vocab.as_ref(), token.0) }
    }

    /// Check if a token is an end-of-generation token (EOS, EOT, …).
    #[must_use]
    pub fn is_eog(&self, token: LlamaToken) -> bool {
        unsafe { llama_cpp_sys_4::llama_vocab_is_eog(self.vocab.as_ref(), token.0) }
    }

    /// Get the MASK token for the vocabulary.
    #[must_use]
    pub fn mask(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_mask(self.vocab.as_ref()) })
    }
}
214
/// A safe wrapper around `llama_adapter_lora`.
///
/// Owns the adapter; the `Drop` impl below frees it via
/// `llama_adapter_lora_free`.
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaLoraAdapter {
    // Non-null pointer to the C-side LoRA adapter; freed in `Drop`.
    pub(crate) lora_adapter: NonNull<llama_adapter_lora>,
}
222
223impl LlamaLoraAdapter {
224    /// Get the number of metadata key-value pairs in the adapter.
225    #[must_use]
226    pub fn meta_count(&self) -> i32 {
227        unsafe { llama_cpp_sys_4::llama_adapter_meta_count(self.lora_adapter.as_ptr()) }
228    }
229
230    /// Get a metadata key by index.
231    ///
232    /// # Errors
233    ///
234    /// Returns an error if the index is out of range or the key is not valid UTF-8.
235    #[allow(clippy::cast_sign_loss)]
236    pub fn meta_key_by_index(
237        &self,
238        index: i32,
239        buf_size: usize,
240    ) -> Result<String, StringFromModelError> {
241        let mut buf = vec![0u8; buf_size];
242        let ret = unsafe {
243            llama_cpp_sys_4::llama_adapter_meta_key_by_index(
244                self.lora_adapter.as_ptr(),
245                index,
246                buf.as_mut_ptr().cast::<c_char>(),
247                buf_size,
248            )
249        };
250        if ret < 0 {
251            return Err(StringFromModelError::ReturnedError(ret));
252        }
253        let len = ret as usize;
254        let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
255        Ok(s.to_owned())
256    }
257
258    /// Get a metadata value by key name.
259    ///
260    /// # Errors
261    ///
262    /// Returns an error if the key is not found or the value is not valid UTF-8.
263    #[allow(clippy::cast_sign_loss)]
264    pub fn meta_val_str(
265        &self,
266        key: &str,
267        buf_size: usize,
268    ) -> Result<String, StringFromModelError> {
269        let c_key =
270            CString::new(key).map_err(|_| StringFromModelError::ReturnedError(-1))?;
271        let mut buf = vec![0u8; buf_size];
272        let ret = unsafe {
273            llama_cpp_sys_4::llama_adapter_meta_val_str(
274                self.lora_adapter.as_ptr(),
275                c_key.as_ptr(),
276                buf.as_mut_ptr().cast::<c_char>(),
277                buf_size,
278            )
279        };
280        if ret < 0 {
281            return Err(StringFromModelError::ReturnedError(ret));
282        }
283        let len = ret as usize;
284        let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
285        Ok(s.to_owned())
286    }
287
288    /// Get a metadata value by index.
289    ///
290    /// # Errors
291    ///
292    /// Returns an error if the index is out of range or the value is not valid UTF-8.
293    #[allow(clippy::cast_sign_loss)]
294    pub fn meta_val_str_by_index(
295        &self,
296        index: i32,
297        buf_size: usize,
298    ) -> Result<String, StringFromModelError> {
299        let mut buf = vec![0u8; buf_size];
300        let ret = unsafe {
301            llama_cpp_sys_4::llama_adapter_meta_val_str_by_index(
302                self.lora_adapter.as_ptr(),
303                index,
304                buf.as_mut_ptr().cast::<c_char>(),
305                buf_size,
306            )
307        };
308        if ret < 0 {
309            return Err(StringFromModelError::ReturnedError(ret));
310        }
311        let len = ret as usize;
312        let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
313        Ok(s.to_owned())
314    }
315
316    /// Get all metadata as a list of `(key, value)` pairs.
317    ///
318    /// # Errors
319    ///
320    /// Returns an error if any key or value cannot be read or is not valid UTF-8.
321    #[allow(clippy::cast_sign_loss)]
322    pub fn metadata(&self) -> Result<Vec<(String, String)>, StringFromModelError> {
323        let count = self.meta_count();
324        let mut result = Vec::with_capacity(count as usize);
325        for i in 0..count {
326            let key = self.meta_key_by_index(i, 256)?;
327            let val = self.meta_val_str_by_index(i, 4096)?;
328            result.push((key, val));
329        }
330        Ok(result)
331    }
332
333    /// Get the number of invocation tokens for this adapter.
334    #[must_use]
335    pub fn n_invocation_tokens(&self) -> u64 {
336        unsafe {
337            llama_cpp_sys_4::llama_adapter_get_alora_n_invocation_tokens(
338                self.lora_adapter.as_ptr(),
339            )
340        }
341    }
342
343    /// Get the invocation tokens for this adapter.
344    ///
345    /// Returns an empty slice if there are no invocation tokens.
346    #[must_use]
347    #[allow(clippy::cast_possible_truncation)]
348    pub fn invocation_tokens(&self) -> &[LlamaToken] {
349        let n = self.n_invocation_tokens() as usize;
350        if n == 0 {
351            return &[];
352        }
353        let ptr = unsafe {
354            llama_cpp_sys_4::llama_adapter_get_alora_invocation_tokens(
355                self.lora_adapter.as_ptr(),
356            )
357        };
358        if ptr.is_null() {
359            return &[];
360        }
361        // LlamaToken is repr(transparent) over llama_token (i32), so this cast is safe
362        unsafe { std::slice::from_raw_parts(ptr.cast::<LlamaToken>(), n) }
363    }
364}
365
impl Drop for LlamaLoraAdapter {
    fn drop(&mut self) {
        // Release the C-side adapter. The pointer is owned exclusively by this
        // struct (no Clone impl), so it is freed exactly once here.
        unsafe {
            llama_cpp_sys_4::llama_adapter_lora_free(self.lora_adapter.as_ptr());
        }
    }
}
373
/// A safe wrapper around `llama_chat_message`.
///
/// Stores role and content as owned `CString`s so they can be passed to the C
/// chat-template API without further conversion.
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct LlamaChatMessage {
    // e.g. "system", "user", "assistant" — not validated here.
    role: CString,
    content: CString,
}
380
381impl LlamaChatMessage {
382    /// Create a new `LlamaChatMessage`.
383    ///
384    /// # Errors
385    ///
386    /// Returns [`NewLlamaChatMessageError`] if the role or content contains a null byte.
387    pub fn new(role: String, content: String) -> Result<Self, NewLlamaChatMessageError> {
388        Ok(Self {
389            role: CString::new(role)?,
390            content: CString::new(content)?,
391        })
392    }
393}
394
/// How to determine if we should prepend a BOS (beginning of stream) token to
/// tokens when tokenizing. Used by [`LlamaModel::str_to_token`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AddBos {
    /// Add the beginning of stream token to the start of the string.
    Always,
    /// Do not add the beginning of stream token to the start of the string.
    Never,
}
403
/// How special/control tokens are handled when converting tokens to text.
/// Used by [`LlamaModel::token_to_str`] and related methods.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Special {
    /// Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space.
    Tokenize,
    /// Treat special and/or control tokens as plaintext.
    Plaintext,
}
412
// SAFETY(review): `LlamaModel` only holds a pointer to the C-side model, which
// is presumably not mutated through this wrapper after loading, so sending and
// sharing it across threads is assumed sound — TODO confirm against llama.cpp's
// thread-safety guarantees.
unsafe impl Send for LlamaModel {}

unsafe impl Sync for LlamaModel {}
416
417impl LlamaModel {
418    /// Retrieves the vocabulary associated with the current Llama model.
419    ///
420    /// This method fetches the vocabulary from the underlying model using an unsafe
421    /// FFI call. The returned `LlamaVocab` struct contains a non-null pointer to
422    /// the vocabulary data, which is wrapped in a `NonNull` for safety.
423    ///
424    /// # Safety
425    /// This method uses an unsafe block to call a C function (`llama_model_get_vocab`),
426    /// which is assumed to return a valid pointer to the vocabulary. The caller should
427    /// ensure that the model object is properly initialized and valid before calling
428    /// this method, as dereferencing invalid pointers can lead to undefined behavior.
429    ///
430    /// # Returns
431    /// A `LlamaVocab` struct containing the vocabulary of the model.
432    ///
433    /// # Panics
434    ///
435    /// Panics if the underlying C function returns a null pointer.
436    ///
437    /// # Example
438    /// ```rust,ignore
439    /// let vocab = model.get_vocab();
440    /// ```
441    #[must_use]
442    pub fn get_vocab(&self) -> LlamaVocab {
443        let llama_vocab = unsafe { llama_model_get_vocab(self.model.as_ptr()) }.cast_mut();
444
445        LlamaVocab {
446            vocab: NonNull::new(llama_vocab).unwrap(),
447        }
448    }
449    /// Get the number of tokens the model was trained on.
450    ///
451    /// This function returns the number of tokens that the model was trained on, represented as a `u32`.
452    ///
453    /// # Panics
454    ///
455    /// This function will panic if the number of tokens the model was trained on does not fit into a `u32`.
456    /// This should be impossible on most platforms since llama.cpp returns a `c_int` (i32 on most platforms),
457    /// which is almost certainly positive.
458    #[must_use]
459    pub fn n_ctx_train(&self) -> u32 {
460        let n_ctx_train = unsafe { llama_n_ctx_train(self.model.as_ptr()) };
461        u32::try_from(n_ctx_train).expect("n_ctx_train fits into an u32")
462    }
463
464    /// Get all tokens in the model.
465    ///
466    /// This function returns an iterator over all the tokens in the model. Each item in the iterator is a tuple
467    /// containing a `LlamaToken` and its corresponding string representation (or an error if the conversion fails).
468    ///
469    /// # Parameters
470    ///
471    /// - `special`: The `Special` value that determines how special tokens (like BOS, EOS, etc.) are handled.
472    pub fn tokens(
473        &self,
474        special: Special,
475    ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
476        (0..self.n_vocab())
477            .map(LlamaToken::new)
478            .map(move |llama_token| (llama_token, self.token_to_str(llama_token, special)))
479    }
480
    /// Get the beginning of stream token.
    ///
    /// This function returns the token that represents the beginning of a stream (BOS token).
    /// NOTE(review): each of these accessors fetches the vocabulary afresh via
    /// `get_vocab()`; for repeated lookups, calling `get_vocab()` once and using
    /// [`LlamaVocab`]'s methods avoids the repeated FFI round-trip.
    #[must_use]
    pub fn token_bos(&self) -> LlamaToken {
        let token = unsafe { llama_token_bos(self.get_vocab().vocab.as_ref()) };
        LlamaToken(token)
    }

    /// Get the end of stream token.
    ///
    /// This function returns the token that represents the end of a stream (EOS token).
    #[must_use]
    pub fn token_eos(&self) -> LlamaToken {
        let token = unsafe { llama_token_eos(self.get_vocab().vocab.as_ref()) };
        LlamaToken(token)
    }

    /// Get the newline token.
    ///
    /// This function returns the token that represents a newline character.
    #[must_use]
    pub fn token_nl(&self) -> LlamaToken {
        let token = unsafe { llama_token_nl(self.get_vocab().vocab.as_ref()) };
        LlamaToken(token)
    }

    /// Check if a token represents the end of generation (end of turn, end of sequence, etc.).
    ///
    /// This function returns `true` if the provided token signifies the end of generation or end of sequence,
    /// such as EOS or other special tokens.
    ///
    /// # Parameters
    ///
    /// - `token`: The `LlamaToken` to check.
    ///
    /// # Returns
    ///
    /// - `true` if the token is an end-of-generation token, otherwise `false`.
    #[must_use]
    pub fn is_eog_token(&self, token: LlamaToken) -> bool {
        unsafe { llama_token_is_eog(self.get_vocab().vocab.as_ref(), token.0) }
    }

    /// Get the classification (CLS) token.
    #[must_use]
    pub fn token_cls(&self) -> LlamaToken {
        let token = unsafe { llama_token_cls(self.get_vocab().vocab.as_ref()) };
        LlamaToken(token)
    }

    /// Get the end-of-turn (EOT) token.
    #[must_use]
    pub fn token_eot(&self) -> LlamaToken {
        let token = unsafe { llama_token_eot(self.get_vocab().vocab.as_ref()) };
        LlamaToken(token)
    }

    /// Get the padding (PAD) token.
    #[must_use]
    pub fn token_pad(&self) -> LlamaToken {
        let token = unsafe { llama_token_pad(self.get_vocab().vocab.as_ref()) };
        LlamaToken(token)
    }

    /// Get the separator (SEP) token.
    #[must_use]
    pub fn token_sep(&self) -> LlamaToken {
        let token = unsafe { llama_token_sep(self.get_vocab().vocab.as_ref()) };
        LlamaToken(token)
    }

    /// Get the fill-in-the-middle (FIM) prefix token.
    #[must_use]
    pub fn token_fim_pre(&self) -> LlamaToken {
        let token = unsafe { llama_token_fim_pre(self.get_vocab().vocab.as_ref()) };
        LlamaToken(token)
    }

    /// Get the fill-in-the-middle suffix token.
    #[must_use]
    pub fn token_fim_suf(&self) -> LlamaToken {
        let token = unsafe { llama_token_fim_suf(self.get_vocab().vocab.as_ref()) };
        LlamaToken(token)
    }

    /// Get the fill-in-the-middle middle token.
    #[must_use]
    pub fn token_fim_mid(&self) -> LlamaToken {
        let token = unsafe { llama_token_fim_mid(self.get_vocab().vocab.as_ref()) };
        LlamaToken(token)
    }

    /// Get the fill-in-the-middle padding token.
    #[must_use]
    pub fn token_fim_pad(&self) -> LlamaToken {
        let token = unsafe { llama_token_fim_pad(self.get_vocab().vocab.as_ref()) };
        LlamaToken(token)
    }

    /// Get the fill-in-the-middle repository token.
    #[must_use]
    pub fn token_fim_rep(&self) -> LlamaToken {
        let token = unsafe { llama_token_fim_rep(self.get_vocab().vocab.as_ref()) };
        LlamaToken(token)
    }

    /// Get the fill-in-the-middle separator token.
    #[must_use]
    pub fn token_fim_sep(&self) -> LlamaToken {
        let token = unsafe { llama_token_fim_sep(self.get_vocab().vocab.as_ref()) };
        LlamaToken(token)
    }

    /// Check if a token is a control token.
    #[must_use]
    pub fn token_is_control(&self, token: LlamaToken) -> bool {
        unsafe { llama_token_is_control(self.get_vocab().vocab.as_ref(), token.0) }
    }

    /// Get the score of a token.
    #[must_use]
    pub fn token_get_score(&self, token: LlamaToken) -> f32 {
        unsafe { llama_token_get_score(self.get_vocab().vocab.as_ref(), token.0) }
    }
606
607    /// Get the raw text of a token.
608    ///
609    /// # Errors
610    ///
611    /// Returns an error if the token text is null or not valid UTF-8.
612    pub fn token_get_text(&self, token: LlamaToken) -> Result<&str, StringFromModelError> {
613        let ptr =
614            unsafe { llama_token_get_text(self.get_vocab().vocab.as_ref(), token.0) };
615        if ptr.is_null() {
616            return Err(StringFromModelError::ReturnedError(-1));
617        }
618        let cstr = unsafe { CStr::from_ptr(ptr) };
619        cstr.to_str()
620            .map_err(StringFromModelError::Utf8Error)
621    }
622
    /// Check if a BOS token should be added when tokenizing.
    #[must_use]
    pub fn add_bos_token(&self) -> bool {
        unsafe { llama_add_bos_token(self.get_vocab().vocab.as_ref()) }
    }

    /// Check if an EOS token should be added when tokenizing.
    #[must_use]
    pub fn add_eos_token(&self) -> bool {
        unsafe { llama_add_eos_token(self.get_vocab().vocab.as_ref()) }
    }

    /// Get the decoder start token.
    ///
    /// This function returns the token used to signal the start of decoding (i.e., the token used at the start
    /// of a sequence generation).
    #[must_use]
    pub fn decode_start_token(&self) -> LlamaToken {
        let token = unsafe { llama_model_decoder_start_token(self.model.as_ptr()) };
        LlamaToken(token)
    }

    /// Convert a single token to a string.
    ///
    /// This function converts a `LlamaToken` into its string representation,
    /// delegating to [`Self::token_to_str_with_size`] with a 32-byte buffer,
    /// which is sufficient for typical token pieces.
    ///
    /// # Errors
    ///
    /// This function returns an error if the token cannot be converted to a string. For more details, refer to
    /// [`TokenToStringError`].
    ///
    /// # Parameters
    ///
    /// - `token`: The `LlamaToken` to convert.
    /// - `special`: The `Special` value used to handle special tokens.
    pub fn token_to_str(
        &self,
        token: LlamaToken,
        special: Special,
    ) -> Result<String, TokenToStringError> {
        self.token_to_str_with_size(token, 32, special)
    }

    /// Convert a single token to bytes.
    ///
    /// This function converts a `LlamaToken` into a byte representation,
    /// delegating to `token_to_bytes_with_size` with a 32-byte buffer.
    ///
    /// # Errors
    ///
    /// This function returns an error if the token cannot be converted to bytes. For more details, refer to
    /// [`TokenToStringError`].
    ///
    /// # Parameters
    ///
    /// - `token`: The `LlamaToken` to convert.
    /// - `special`: The `Special` value used to handle special tokens.
    pub fn token_to_bytes(
        &self,
        token: LlamaToken,
        special: Special,
    ) -> Result<Vec<u8>, TokenToStringError> {
        self.token_to_bytes_with_size(token, 32, special, None)
    }
686
687    /// Convert a vector of tokens to a single string.
688    ///
689    /// This function takes a slice of `LlamaToken`s and converts them into a single string, concatenating their
690    /// string representations.
691    ///
692    /// # Errors
693    ///
694    /// This function returns an error if any token cannot be converted to a string. For more details, refer to
695    /// [`TokenToStringError`].
696    ///
697    /// # Parameters
698    ///
699    /// - `tokens`: A slice of `LlamaToken`s to convert.
700    /// - `special`: The `Special` value used to handle special tokens.
701    pub fn tokens_to_str(
702        &self,
703        tokens: &[LlamaToken],
704        special: Special,
705    ) -> Result<String, TokenToStringError> {
706        let mut builder = String::with_capacity(tokens.len() * 4);
707        for str in tokens
708            .iter()
709            .copied()
710            .map(|t| self.token_to_str(t, special))
711        {
712            builder += &str?;
713        }
714        Ok(builder)
715    }
716
    /// Convert a string to a vector of tokens.
    ///
    /// This function converts a string into a vector of `LlamaToken`s. The function will tokenize the string
    /// and return the corresponding tokens. Internally this is a two-pass call
    /// into `llama_tokenize`: the first attempt uses an estimated buffer, and
    /// if the C API reports the buffer was too small (negative return), the
    /// buffer is grown to the exact required size and the call is retried.
    ///
    /// # Errors
    ///
    /// - This function will return an error if the input string contains a null byte.
    ///
    /// # Panics
    ///
    /// - This function will panic if the number of tokens exceeds `usize::MAX`.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use llama_cpp_4::model::LlamaModel;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// use std::path::Path;
    /// use llama_cpp_4::model::AddBos;
    /// let backend = llama_cpp_4::llama_backend::LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, Path::new("path/to/model"), &Default::default())?;
    /// let tokens = model.str_to_token("Hello, World!", AddBos::Always)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn str_to_token(
        &self,
        str: &str,
        add_bos: AddBos,
    ) -> Result<Vec<LlamaToken>, StringToTokenError> {
        let add_bos = match add_bos {
            AddBos::Always => true,
            AddBos::Never => false,
        };

        // Heuristic: roughly one token per two bytes of input, at least 8.
        let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
        let mut buffer = Vec::with_capacity(tokens_estimation);

        let c_string = CString::new(str)?;
        let buffer_capacity =
            c_int::try_from(buffer.capacity()).expect("buffer capacity should fit into a c_int");

        // First pass: tokenize into the estimated buffer. The final `true`
        // enables parsing of special tokens — NOTE(review): confirm against the
        // sys crate that this parameter is `parse_special`.
        let size = unsafe {
            llama_tokenize(
                self.get_vocab().vocab.as_ref(),
                c_string.as_ptr(),
                c_int::try_from(c_string.as_bytes().len())?,
                buffer.as_mut_ptr(),
                buffer_capacity,
                add_bos,
                true,
            )
        };

        // if we fail the first time we can resize the vector to the correct size and try again. This should never fail.
        // as a result - size is guaranteed to be positive here.
        let size = if size.is_negative() {
            // A negative return is the required token count; grow to exactly
            // that capacity and retry.
            buffer.reserve_exact(usize::try_from(-size).expect("usize's are larger "));
            unsafe {
                llama_tokenize(
                    self.get_vocab().vocab.as_ref(),
                    c_string.as_ptr(),
                    c_int::try_from(c_string.as_bytes().len())?,
                    buffer.as_mut_ptr(),
                    -size,
                    add_bos,
                    true,
                )
            }
        } else {
            size
        };

        let size = usize::try_from(size).expect("size is positive and usize ");

        // Safety: `size` < `capacity` and llama-cpp has initialized elements up to `size`
        unsafe { buffer.set_len(size) }
        Ok(buffer.into_iter().map(LlamaToken).collect())
    }
798
    /// Get the attributes of a token.
    ///
    /// This function retrieves the attributes associated with a given token. The attributes are typically used to
    /// understand whether the token represents a special type of token (e.g., beginning-of-sequence (BOS), end-of-sequence (EOS),
    /// control tokens, etc.).
    ///
    /// # Panics
    ///
    /// - This function will panic if the token type is unknown or cannot be converted to a valid `LlamaTokenAttrs`.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use llama_cpp_4::model::LlamaModel;
    /// use llama_cpp_4::token::LlamaToken;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// use std::path::Path;
    /// let backend = llama_cpp_4::llama_backend::LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, Path::new("path/to/model"), &Default::default())?;
    /// let token = LlamaToken(42);
    /// let token_attrs = model.token_attr(token);
    /// # Ok(())
    /// # }
    /// ```
    #[must_use]
    pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs {
        let token_type = unsafe { llama_token_get_attr(self.get_vocab().vocab.as_ref(), id) };
        LlamaTokenAttrs::try_from(token_type).expect("token type is valid")
    }
826
827    /// Detokenize a slice of tokens into a string.
828    ///
829    /// This is the inverse of [`str_to_token`](Self::str_to_token).
830    ///
831    /// # Parameters
832    ///
833    /// - `tokens`: The tokens to detokenize.
834    /// - `remove_special`: If `true`, special tokens are removed from the output.
835    /// - `unparse_special`: If `true`, special tokens are rendered as their text representation.
836    ///
837    /// # Errors
838    ///
839    /// Returns an error if the detokenized text is not valid UTF-8.
840    #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap, clippy::cast_sign_loss)]
841    pub fn detokenize(
842        &self,
843        tokens: &[LlamaToken],
844        remove_special: bool,
845        unparse_special: bool,
846    ) -> Result<String, StringFromModelError> {
847        // First call with empty buffer to get required size
848        let n_tokens = tokens.len() as i32;
849        let token_ptr = tokens.as_ptr().cast::<llama_cpp_sys_4::llama_token>();
850        let needed = unsafe {
851            llama_detokenize(
852                self.get_vocab().vocab.as_ref(),
853                token_ptr,
854                n_tokens,
855                std::ptr::null_mut(),
856                0,
857                remove_special,
858                unparse_special,
859            )
860        };
861        // llama_detokenize returns negative required size when buffer is too small
862        let buf_size = if needed < 0 { (-needed) as usize } else { needed as usize };
863        let mut buf = vec![0u8; buf_size];
864        let ret = unsafe {
865            llama_detokenize(
866                self.get_vocab().vocab.as_ref(),
867                token_ptr,
868                n_tokens,
869                buf.as_mut_ptr().cast::<c_char>(),
870                buf_size as i32,
871                remove_special,
872                unparse_special,
873            )
874        };
875        if ret < 0 {
876            return Err(StringFromModelError::ReturnedError(ret));
877        }
878        let len = ret as usize;
879        let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
880        Ok(s.to_owned())
881    }
882
883    /// Convert a token to a string with a specified buffer size.
884    ///
885    /// This function allows you to convert a token into a string, with the ability to specify a buffer size for the operation.
886    /// It is generally recommended to use `LlamaModel::token_to_str` instead, as 8 bytes is typically sufficient for most tokens,
887    /// and the extra buffer size doesn't usually matter.
888    ///
889    /// # Errors
890    ///
891    /// - If the token type is unknown, an error will be returned.
892    /// - If the resultant token exceeds the provided `buffer_size`, an error will occur.
893    /// - If the token string returned by `llama-cpp` is not valid UTF-8, it will return an error.
894    ///
895    /// # Panics
896    ///
897    /// - This function will panic if the `buffer_size` does not fit into a `c_int`.
898    /// - It will also panic if the size returned from `llama-cpp` does not fit into a `usize`, which should typically never happen.
899    ///
900    /// # Example
901    ///
902    /// ```no_run
903    /// use llama_cpp_4::model::{LlamaModel, LlamaToken};
904    ///
905    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
906    /// let model = LlamaModel::load_from_file("path/to/model")?;
907    /// let token = LlamaToken(42);
908    /// let token_string = model.token_to_str_with_size(token, 32, Special::Plaintext)?;
909    /// # Ok(())
910    /// # }
911    /// ```
912    pub fn token_to_str_with_size(
913        &self,
914        token: LlamaToken,
915        buffer_size: usize,
916        special: Special,
917    ) -> Result<String, TokenToStringError> {
918        let bytes = self.token_to_bytes_with_size(token, buffer_size, special, None)?;
919        Ok(String::from_utf8(bytes)?)
920    }
921
922    /// Convert a token to bytes with a specified buffer size.
923    ///
924    /// Generally you should use [`LlamaModel::token_to_bytes`] instead as 8 bytes is enough for most words and
925    /// the extra bytes do not really matter.
926    ///
927    /// # Errors
928    ///
929    /// - if the token type is unknown
930    /// - the resultant token is larger than `buffer_size`.
931    ///
932    /// # Panics
933    ///
934    /// - This function will panic if `buffer_size` cannot fit into a `c_int`.
935    /// - It will also panic if the size returned from `llama-cpp` cannot be converted to `usize` (which should not happen).
936    ///
937    /// # Example
938    ///
939    /// ```no_run
940    /// use llama_cpp_4::model::{LlamaModel, LlamaToken};
941    ///
942    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
943    /// let model = LlamaModel::load_from_file("path/to/model")?;
944    /// let token = LlamaToken(42);
945    /// let token_bytes = model.token_to_bytes_with_size(token, 32, Special::Plaintext, None)?;
946    /// # Ok(())
947    /// # }
948    /// ```
949    pub fn token_to_bytes_with_size(
950        &self,
951        token: LlamaToken,
952        buffer_size: usize,
953        special: Special,
954        lstrip: Option<NonZeroU16>,
955    ) -> Result<Vec<u8>, TokenToStringError> {
956        if token == self.token_nl() {
957            return Ok(String::from("\n").into_bytes());
958        }
959
960        // unsure what to do with this in the face of the 'special' arg + attr changes
961        let attrs = self.token_attr(token);
962        if (attrs.contains(LlamaTokenAttr::Control)
963            && (token == self.token_bos() || token == self.token_eos()))
964            || attrs.is_empty()
965            || attrs
966                .intersects(LlamaTokenAttr::Unknown | LlamaTokenAttr::Byte | LlamaTokenAttr::Unused)
967        {
968            return Ok(Vec::new());
969        }
970
971        let special = match special {
972            Special::Tokenize => true,
973            Special::Plaintext => false,
974        };
975
976        let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
977        let len = string.as_bytes().len();
978        let len = c_int::try_from(len).expect("length fits into c_int");
979        let buf = string.into_raw();
980        let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
981        let size = unsafe {
982            llama_token_to_piece(
983                self.get_vocab().vocab.as_ref(),
984                token.0,
985                buf,
986                len,
987                lstrip,
988                special,
989            )
990        };
991
992        match size {
993            0 => Err(TokenToStringError::UnknownTokenType),
994            i if i.is_negative() => Err(TokenToStringError::InsufficientBufferSpace(i)),
995            size => {
996                let string = unsafe { CString::from_raw(buf) };
997                let mut bytes = string.into_bytes();
998                let len = usize::try_from(size).expect("size is positive and fits into usize");
999                bytes.truncate(len);
1000                Ok(bytes)
1001            }
1002        }
1003    }
    /// The size of the model's vocabulary (the number of distinct tokens).
    ///
    /// # Example
    ///
    /// ```no_run
    /// use llama_cpp_4::model::LlamaModel;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let model = LlamaModel::load_from_file("path/to/model")?;
    /// let n_vocab = model.n_vocab();
    /// # Ok(())
    /// # }
    /// ```
    #[must_use]
    pub fn n_vocab(&self) -> i32 {
        // SAFETY: the vocab handle obtained from the model is valid for the lifetime of `self`.
        unsafe { llama_n_vocab(self.get_vocab().vocab.as_ref()) }
    }
1024
1025    /// The type of vocab the model was trained on.
1026    ///
1027    /// This function returns the type of vocabulary used by the model, such as whether it is based on byte-pair encoding (BPE),
1028    /// word-level tokens, or another tokenization scheme.
1029    ///
1030    /// # Panics
1031    ///
1032    /// - This function will panic if `llama-cpp` emits a vocab type that is not recognized or is invalid for this library.
1033    ///
1034    /// # Example
1035    ///
1036    /// ```no_run
1037    /// use llama_cpp_4::model::LlamaModel;
1038    ///
1039    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1040    /// let model = LlamaModel::load_from_file("path/to/model")?;
1041    /// let vocab_type = model.vocab_type();
1042    /// # Ok(())
1043    /// # }
1044    /// ```
1045    #[must_use]
1046    pub fn vocab_type(&self) -> VocabType {
1047        let vocab_type = unsafe { llama_vocab_type(self.get_vocab().vocab.as_ref()) };
1048        VocabType::try_from(vocab_type).expect("invalid vocab type")
1049    }
1050
    /// Returns the number of embedding dimensions for the model.
    ///
    /// This function retrieves the number of embeddings (or embedding dimensions) used by the model. It is typically
    /// used for analyzing model architecture and setting up context parameters or other model configuration aspects.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use llama_cpp_4::model::LlamaModel;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let model = LlamaModel::load_from_file("path/to/model")?;
    /// let n_embd = model.n_embd();
    /// # Ok(())
    /// # }
    /// ```
    #[must_use]
    pub fn n_embd(&self) -> c_int {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_n_embd(self.model.as_ptr()) }
    }
1075
    /// Get the number of transformer layers in the model.
    #[must_use]
    pub fn n_layer(&self) -> c_int {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_n_layer(self.model.as_ptr()) }
    }
1081
    /// Get the number of attention heads in the model.
    #[must_use]
    pub fn n_head(&self) -> c_int {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_n_head(self.model.as_ptr()) }
    }
1087
    /// Get the number of key-value attention heads in the model.
    #[must_use]
    pub fn n_head_kv(&self) -> c_int {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_model_n_head_kv(self.model.as_ptr()) }
    }
1093
    /// Get the input embedding size of the model.
    #[must_use]
    pub fn n_embd_inp(&self) -> c_int {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_model_n_embd_inp(self.model.as_ptr()) }
    }
1099
    /// Get the output embedding size of the model.
    #[must_use]
    pub fn n_embd_out(&self) -> c_int {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_model_n_embd_out(self.model.as_ptr()) }
    }
1105
    /// Get the sliding window attention size of the model.
    /// Returns 0 if the model does not use sliding window attention.
    #[must_use]
    pub fn n_swa(&self) -> c_int {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_model_n_swa(self.model.as_ptr()) }
    }
1112
    /// Get the `RoPE` type used by the model, as the raw C-enum value.
    #[must_use]
    pub fn rope_type(&self) -> i32 {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_model_rope_type(self.model.as_ptr()) }
    }
1118
    /// Get the `RoPE` frequency scale used during training.
    #[must_use]
    pub fn rope_freq_scale_train(&self) -> f32 {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_model_rope_freq_scale_train(self.model.as_ptr()) }
    }
1124
    /// Get the model size in bytes.
    #[must_use]
    pub fn model_size(&self) -> u64 {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_model_size(self.model.as_ptr()) }
    }
1130
    /// Get the number of parameters in the model.
    #[must_use]
    pub fn n_params(&self) -> u64 {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_model_n_params(self.model.as_ptr()) }
    }
1136
    /// Get the number of classification outputs.
    #[must_use]
    pub fn n_cls_out(&self) -> u32 {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_model_n_cls_out(self.model.as_ptr()) }
    }
1142
1143    /// Get the classification label for the given index.
1144    ///
1145    /// # Errors
1146    ///
1147    /// Returns an error if the label is null or not valid UTF-8.
1148    pub fn cls_label(&self, index: u32) -> Result<&str, StringFromModelError> {
1149        let ptr = unsafe { llama_model_cls_label(self.model.as_ptr(), index) };
1150        if ptr.is_null() {
1151            return Err(StringFromModelError::ReturnedError(-1));
1152        }
1153        let cstr = unsafe { CStr::from_ptr(ptr) };
1154        cstr.to_str().map_err(StringFromModelError::Utf8Error)
1155    }
1156
    /// Get the number of metadata key-value pairs.
    #[must_use]
    pub fn meta_count(&self) -> c_int {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_model_meta_count(self.model.as_ptr()) }
    }
1162
1163    /// Get a model description string.
1164    ///
1165    /// The `buf_size` parameter specifies the maximum buffer size for the description.
1166    /// A default of 256 bytes is usually sufficient.
1167    ///
1168    /// # Errors
1169    ///
1170    /// Returns an error if the description could not be retrieved or is not valid UTF-8.
1171    #[allow(clippy::cast_sign_loss)]
1172    pub fn desc(&self, buf_size: usize) -> Result<String, StringFromModelError> {
1173        let mut buf = vec![0u8; buf_size];
1174        let ret = unsafe {
1175            llama_model_desc(
1176                self.model.as_ptr(),
1177                buf.as_mut_ptr().cast::<c_char>(),
1178                buf_size,
1179            )
1180        };
1181        if ret < 0 {
1182            return Err(StringFromModelError::ReturnedError(ret));
1183        }
1184        let len = ret as usize;
1185        let s = std::str::from_utf8(&buf[..len])
1186            .map_err(StringFromModelError::Utf8Error)?;
1187        Ok(s.to_owned())
1188    }
1189
1190    /// Get a metadata key by index.
1191    ///
1192    /// The `buf_size` parameter specifies the maximum buffer size for the key.
1193    /// A default of 256 bytes is usually sufficient.
1194    ///
1195    /// # Errors
1196    ///
1197    /// Returns an error if the index is out of range or the key is not valid UTF-8.
1198    #[allow(clippy::cast_sign_loss)]
1199    pub fn meta_key_by_index(&self, index: i32, buf_size: usize) -> Result<String, StringFromModelError> {
1200        let mut buf = vec![0u8; buf_size];
1201        let ret = unsafe {
1202            llama_model_meta_key_by_index(
1203                self.model.as_ptr(),
1204                index,
1205                buf.as_mut_ptr().cast::<c_char>(),
1206                buf_size,
1207            )
1208        };
1209        if ret < 0 {
1210            return Err(StringFromModelError::ReturnedError(ret));
1211        }
1212        let len = ret as usize;
1213        let s = std::str::from_utf8(&buf[..len])
1214            .map_err(StringFromModelError::Utf8Error)?;
1215        Ok(s.to_owned())
1216    }
1217
1218    /// Get a metadata value string by index.
1219    ///
1220    /// The `buf_size` parameter specifies the maximum buffer size for the value.
1221    /// Values can be large (e.g. chat templates, token lists), so 4096+ may be needed.
1222    ///
1223    /// # Errors
1224    ///
1225    /// Returns an error if the index is out of range or the value is not valid UTF-8.
1226    #[allow(clippy::cast_sign_loss)]
1227    pub fn meta_val_str_by_index(&self, index: i32, buf_size: usize) -> Result<String, StringFromModelError> {
1228        let mut buf = vec![0u8; buf_size];
1229        let ret = unsafe {
1230            llama_model_meta_val_str_by_index(
1231                self.model.as_ptr(),
1232                index,
1233                buf.as_mut_ptr().cast::<c_char>(),
1234                buf_size,
1235            )
1236        };
1237        if ret < 0 {
1238            return Err(StringFromModelError::ReturnedError(ret));
1239        }
1240        let len = ret as usize;
1241        let s = std::str::from_utf8(&buf[..len])
1242            .map_err(StringFromModelError::Utf8Error)?;
1243        Ok(s.to_owned())
1244    }
1245
1246    /// Get a metadata value by key name.
1247    ///
1248    /// This is more convenient than iterating metadata by index when you know the key.
1249    /// The `buf_size` parameter specifies the maximum buffer size for the value.
1250    ///
1251    /// # Errors
1252    ///
1253    /// Returns an error if the key is not found, contains a null byte, or the value is not valid UTF-8.
1254    #[allow(clippy::cast_sign_loss)]
1255    pub fn meta_val_str(&self, key: &str, buf_size: usize) -> Result<String, StringFromModelError> {
1256        let c_key = CString::new(key)
1257            .map_err(|_| StringFromModelError::ReturnedError(-1))?;
1258        let mut buf = vec![0u8; buf_size];
1259        let ret = unsafe {
1260            llama_model_meta_val_str(
1261                self.model.as_ptr(),
1262                c_key.as_ptr(),
1263                buf.as_mut_ptr().cast::<c_char>(),
1264                buf_size,
1265            )
1266        };
1267        if ret < 0 {
1268            return Err(StringFromModelError::ReturnedError(ret));
1269        }
1270        let len = ret as usize;
1271        let s = std::str::from_utf8(&buf[..len])
1272            .map_err(StringFromModelError::Utf8Error)?;
1273        Ok(s.to_owned())
1274    }
1275
1276    /// Get all metadata as a list of `(key, value)` pairs.
1277    ///
1278    /// This is a convenience method that iterates over all metadata entries.
1279    /// Keys use a buffer of 256 bytes and values use 4096 bytes.
1280    /// For values that may be larger (e.g. token lists), use
1281    /// [`meta_val_str_by_index`](Self::meta_val_str_by_index) directly with a larger buffer.
1282    ///
1283    /// # Errors
1284    ///
1285    /// Returns an error if any key or value cannot be read or is not valid UTF-8.
1286    #[allow(clippy::cast_sign_loss)]
1287    pub fn metadata(&self) -> Result<Vec<(String, String)>, StringFromModelError> {
1288        let count = self.meta_count();
1289        let mut result = Vec::with_capacity(count as usize);
1290        for i in 0..count {
1291            let key = self.meta_key_by_index(i, 256)?;
1292            let val = self.meta_val_str_by_index(i, 4096)?;
1293            result.push((key, val));
1294        }
1295        Ok(result)
1296    }
1297
    /// Check if the model has an encoder.
    #[must_use]
    pub fn has_encoder(&self) -> bool {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_model_has_encoder(self.model.as_ptr()) }
    }
1303
    /// Check if the model has a decoder.
    #[must_use]
    pub fn has_decoder(&self) -> bool {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_model_has_decoder(self.model.as_ptr()) }
    }
1309
    /// Check if the model is recurrent (e.g. Mamba, RWKV).
    #[must_use]
    pub fn is_recurrent(&self) -> bool {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_model_is_recurrent(self.model.as_ptr()) }
    }
1315
    /// Check if the model is a hybrid model.
    #[must_use]
    pub fn is_hybrid(&self) -> bool {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_model_is_hybrid(self.model.as_ptr()) }
    }
1321
    /// Check if the model is a diffusion model.
    #[must_use]
    pub fn is_diffusion(&self) -> bool {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_model_is_diffusion(self.model.as_ptr()) }
    }
1327
1328    /// Get chat template from model.
1329    ///
1330    /// # Errors
1331    ///
1332    /// - If the model does not have a chat template, it will return an error.
1333    /// - If the chat template is not a valid `CString`, it will return an error.
1334    ///
1335    /// # Example
1336    ///
1337    /// ```no_run
1338    /// use llama_cpp_4::model::LlamaModel;
1339    ///
1340    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1341    /// let model = LlamaModel::load_from_file("path/to/model")?;
1342    /// let chat_template = model.get_chat_template(1024)?;
1343    /// # Ok(())
1344    /// # }
1345    /// ```
1346    #[allow(clippy::missing_panics_doc)] // We statically know this will not panic as long as the buffer size is sufficient
1347    pub fn get_chat_template(&self, buf_size: usize) -> Result<String, ChatTemplateError> {
1348        // longest known template is about 1200 bytes from llama.cpp
1349        let chat_temp = CString::new(vec![b'*'; buf_size]).expect("no null");
1350        let chat_ptr = chat_temp.into_raw();
1351        let chat_name = CString::new("tokenizer.chat_template").expect("no null bytes");
1352
1353        let ret = unsafe {
1354            llama_model_meta_val_str(self.model.as_ptr(), chat_name.as_ptr(), chat_ptr, buf_size)
1355        };
1356
1357        if ret < 0 {
1358            return Err(ChatTemplateError::MissingTemplate(ret));
1359        }
1360
1361        let template_c = unsafe { CString::from_raw(chat_ptr) };
1362        let template = template_c.to_str()?;
1363
1364        let ret: usize = ret.try_into().unwrap();
1365        if template.len() < ret {
1366            return Err(ChatTemplateError::BuffSizeError(ret + 1));
1367        }
1368
1369        Ok(template.to_owned())
1370    }
1371
1372    /// Loads a model from a file.
1373    ///
1374    /// This function loads a model from a specified file path and returns the corresponding `LlamaModel` instance.
1375    ///
1376    /// # Errors
1377    ///
1378    /// - If the path cannot be converted to a string or if the model file does not exist, it will return an error.
1379    /// - If the model cannot be loaded (e.g., due to an invalid or corrupted model file), it will return a `LlamaModelLoadError`.
1380    ///
1381    /// # Example
1382    ///
1383    /// ```no_run
1384    /// use llama_cpp_4::model::LlamaModel;
1385    /// use std::path::Path;
1386    ///
1387    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1388    /// let model = LlamaModel::load_from_file("path/to/model", &LlamaModelParams::default())?;
1389    /// # Ok(())
1390    /// # }
1391    /// ```
1392    #[tracing::instrument(skip_all, fields(params))]
1393    pub fn load_from_file(
1394        _: &LlamaBackend,
1395        path: impl AsRef<Path>,
1396        params: &LlamaModelParams,
1397    ) -> Result<Self, LlamaModelLoadError> {
1398        let path = path.as_ref();
1399        debug_assert!(
1400            Path::new(path).exists(),
1401            "{} does not exist",
1402            path.display()
1403        );
1404        let path = path
1405            .to_str()
1406            .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
1407
1408        let cstr = CString::new(path)?;
1409        let llama_model = unsafe { llama_load_model_from_file(cstr.as_ptr(), params.params) };
1410
1411        let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
1412
1413        tracing::debug!(?path, "Loaded model");
1414        Ok(LlamaModel { model })
1415    }
1416
1417    /// Load a model from multiple split files.
1418    ///
1419    /// This function loads a model that has been split across multiple files. This is useful for
1420    /// very large models that exceed filesystem limitations or need to be distributed across
1421    /// multiple storage devices.
1422    ///
1423    /// # Arguments
1424    ///
1425    /// * `paths` - A slice of paths to the split model files
1426    /// * `params` - The model parameters
1427    ///
1428    /// # Errors
1429    ///
1430    /// Returns an error if:
1431    /// - Any of the paths cannot be converted to a C string
1432    /// - The model fails to load from the splits
1433    /// - Any path doesn't exist or isn't accessible
1434    ///
1435    /// # Example
1436    ///
1437    /// ```no_run
1438    /// use llama_cpp_4::model::{LlamaModel, params::LlamaModelParams};
1439    /// use llama_cpp_4::llama_backend::LlamaBackend;
1440    ///
1441    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1442    /// let backend = LlamaBackend::init()?;
1443    /// let params = LlamaModelParams::default();
1444    ///
1445    /// let paths = vec![
1446    ///     "model-00001-of-00003.gguf",
1447    ///     "model-00002-of-00003.gguf",
1448    ///     "model-00003-of-00003.gguf",
1449    /// ];
1450    ///
1451    /// let model = LlamaModel::load_from_splits(&backend, &paths, &params)?;
1452    /// # Ok(())
1453    /// # }
1454    /// ```
1455    #[tracing::instrument(skip_all)]
1456    pub fn load_from_splits(
1457        _: &LlamaBackend,
1458        paths: &[impl AsRef<Path>],
1459        params: &LlamaModelParams,
1460    ) -> Result<Self, LlamaModelLoadError> {
1461        // Convert paths to C strings
1462        let c_strings: Vec<CString> = paths
1463            .iter()
1464            .map(|p| {
1465                let path = p.as_ref();
1466                debug_assert!(path.exists(), "{} does not exist", path.display());
1467                let path_str = path
1468                    .to_str()
1469                    .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
1470                CString::new(path_str).map_err(LlamaModelLoadError::from)
1471            })
1472            .collect::<Result<Vec<_>, _>>()?;
1473
1474        // Create array of pointers to C strings
1475        let c_ptrs: Vec<*const c_char> = c_strings.iter().map(|s| s.as_ptr()).collect();
1476
1477        // Load the model from splits
1478        let llama_model = unsafe {
1479            llama_model_load_from_splits(c_ptrs.as_ptr().cast_mut(), c_ptrs.len(), params.params)
1480        };
1481
1482        let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
1483
1484        tracing::debug!("Loaded model from {} splits", paths.len());
1485        Ok(LlamaModel { model })
1486    }
1487
    /// Load a model from a `FILE` pointer.
    ///
    /// # Safety
    ///
    /// The `file` pointer must be a valid, open `FILE*`.
    /// NOTE(review): ownership of the stream after the call is determined by
    /// the C library — confirm whether llama.cpp closes it or the caller must.
    ///
    /// # Errors
    ///
    /// Returns an error if the model cannot be loaded.
    pub unsafe fn load_from_file_ptr(
        file: *mut llama_cpp_sys_4::FILE,
        params: &LlamaModelParams,
    ) -> Result<Self, LlamaModelLoadError> {
        let model = llama_cpp_sys_4::llama_model_load_from_file_ptr(file, params.params);
        // A null return signals the load failed.
        let model = NonNull::new(model).ok_or(LlamaModelLoadError::NullResult)?;
        Ok(LlamaModel { model })
    }
1505
    /// Initialize a model from user-provided data.
    ///
    /// # Safety
    ///
    /// The metadata, callback, and user data must be valid.
    ///
    /// # Errors
    ///
    /// Returns an error if the model cannot be initialized.
    pub unsafe fn init_from_user(
        metadata: *mut llama_cpp_sys_4::gguf_context,
        set_tensor_data: llama_cpp_sys_4::llama_model_set_tensor_data_t,
        set_tensor_data_ud: *mut std::ffi::c_void,
        params: &LlamaModelParams,
    ) -> Result<Self, LlamaModelLoadError> {
        let model = llama_cpp_sys_4::llama_model_init_from_user(
            metadata,
            set_tensor_data,
            set_tensor_data_ud,
            params.params,
        );
        // A null return signals initialization failed.
        let model = NonNull::new(model).ok_or(LlamaModelLoadError::NullResult)?;
        Ok(LlamaModel { model })
    }
1530
    /// Save the model to a file.
    ///
    /// # Panics
    ///
    /// Panics if the path is not valid UTF-8 or contains null bytes.
    pub fn save_to_file(&self, path: impl AsRef<Path>) {
        let path = path.as_ref();
        let path_str = path.to_str().expect("path is not valid UTF-8");
        let c_path = CString::new(path_str).expect("path contains null bytes");
        // SAFETY: `self.model` is valid and `c_path` is a NUL-terminated string
        // that outlives the call.
        unsafe {
            llama_model_save_to_file(self.model.as_ptr(), c_path.as_ptr());
        }
    }
1544
1545    /// Get the list of built-in chat templates.
1546    ///
1547    /// Returns the names of all chat templates that are built into llama.cpp.
1548    ///
1549    /// # Panics
1550    ///
1551    /// Panics if any template name is not valid UTF-8.
1552    #[allow(clippy::cast_sign_loss)]
1553    #[must_use]
1554    pub fn chat_builtin_templates() -> Vec<String> {
1555        // First call to get count
1556        let count = unsafe { llama_chat_builtin_templates(std::ptr::null_mut(), 0) };
1557        if count <= 0 {
1558            return Vec::new();
1559        }
1560        let count = count as usize;
1561        let mut ptrs: Vec<*const c_char> = vec![std::ptr::null(); count];
1562        unsafe {
1563            llama_chat_builtin_templates(ptrs.as_mut_ptr(), count);
1564        }
1565        ptrs.iter()
1566            .map(|&p| {
1567                let cstr = unsafe { CStr::from_ptr(p) };
1568                cstr.to_str()
1569                    .expect("template name is not valid UTF-8")
1570                    .to_owned()
1571            })
1572            .collect()
1573    }
1574
1575    /// Initializes a lora adapter from a file.
1576    ///
1577    /// This function initializes a Lora adapter, which is a model extension used to adapt or fine-tune the existing model
1578    /// to a specific domain or task. The adapter file is typically in the form of a binary or serialized file that can be applied
1579    /// to the model for improved performance on specialized tasks.
1580    ///
1581    /// # Errors
1582    ///
1583    /// - If the adapter file path cannot be converted to a string or if the adapter cannot be initialized, it will return an error.
1584    ///
1585    /// # Example
1586    ///
1587    /// ```no_run
1588    /// use llama_cpp_4::model::{LlamaModel, LlamaLoraAdapter};
1589    /// use std::path::Path;
1590    ///
1591    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1592    /// let model = LlamaModel::load_from_file("path/to/model", &LlamaModelParams::default())?;
1593    /// let adapter = model.lora_adapter_init("path/to/lora/adapter")?;
1594    /// # Ok(())
1595    /// # }
1596    /// ```
1597    pub fn lora_adapter_init(
1598        &self,
1599        path: impl AsRef<Path>,
1600    ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
1601        let path = path.as_ref();
1602        debug_assert!(
1603            Path::new(path).exists(),
1604            "{} does not exist",
1605            path.display()
1606        );
1607
1608        let path = path
1609            .to_str()
1610            .ok_or(LlamaLoraAdapterInitError::PathToStrError(
1611                path.to_path_buf(),
1612            ))?;
1613
1614        let cstr = CString::new(path)?;
1615        let adapter = unsafe { llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr()) };
1616
1617        let adapter = NonNull::new(adapter).ok_or(LlamaLoraAdapterInitError::NullResult)?;
1618
1619        tracing::debug!(?path, "Initialized lora adapter");
1620        Ok(LlamaLoraAdapter {
1621            lora_adapter: adapter,
1622        })
1623    }
1624
    /// Create a new context from this model.
    ///
    /// This function creates a new context for the model, which is used to manage and perform computations for inference,
    /// including token generation, embeddings, and other tasks that the model can perform. The context allows fine-grained
    /// control over model parameters for a specific task.
    ///
    /// # Errors
    ///
    /// - There are various potential failures such as invalid parameters or a failure to allocate the context. See [`LlamaContextLoadError`]
    ///   for more detailed error descriptions.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use llama_cpp_4::model::{LlamaModel, LlamaContext};
    /// use llama_cpp_4::LlamaContextParams;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let model = LlamaModel::load_from_file("path/to/model", &LlamaModelParams::default())?;
    /// let context = model.new_context(&LlamaBackend::init()?, LlamaContextParams::default())?;
    /// # Ok(())
    /// # }
    /// ```
    #[allow(clippy::needless_pass_by_value)]
    pub fn new_context(
        &self,
        _: &LlamaBackend,
        params: LlamaContextParams,
    ) -> Result<LlamaContext<'_>, LlamaContextLoadError> {
        // Apply TurboQuant attn-rotation preference before the KV cache is
        // initialised inside llama_new_context_with_model.
        //
        // NOTE(review): the preference is communicated via a process-wide
        // environment variable, which is not thread-safe — concurrent calls
        // to `new_context` could observe each other's temporary value.
        let prev_rot_var = std::env::var("LLAMA_ATTN_ROT_DISABLE").ok();
        if params.attn_rot_disabled {
            // SAFETY: we restore the value right after the call.
            #[allow(unused_unsafe)]
            unsafe {
                std::env::set_var("LLAMA_ATTN_ROT_DISABLE", "1");
            }
        } else if std::env::var("LLAMA_ATTN_ROT_DISABLE").is_ok() {
            // params say "enabled" – only clear if it was previously unset
            // (respect explicit user env var).
            // Intentionally a no-op: an explicit user-provided env var wins
            // over the per-call parameter.
        }

        let context_params = params.context_params;
        let context = unsafe { llama_new_context_with_model(self.model.as_ptr(), context_params) };

        // Restore the env-var to its previous state.
        #[allow(unused_unsafe)]
        match prev_rot_var {
            // It was set before this call (by the user or a previous caller):
            // put the original value back.
            Some(v) => unsafe { std::env::set_var("LLAMA_ATTN_ROT_DISABLE", v) },
            // We set it ourselves above and it was unset before: remove it again.
            None if params.attn_rot_disabled => unsafe {
                std::env::remove_var("LLAMA_ATTN_ROT_DISABLE");
            },
            None => {}
        }

        let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;
        Ok(LlamaContext::new(self, context, params.embeddings()))
    }
1684
1685    /// Apply the model's chat template to a sequence of messages.
1686    ///
1687    /// This function applies the model's chat template to the provided chat messages, formatting them accordingly. The chat
1688    /// template determines the structure or style of conversation between the system and user, such as token formatting,
1689    /// role separation, and more. The template can be customized by providing an optional template string, or if `None`
1690    /// is provided, the default template used by `llama.cpp` will be applied.
1691    ///
1692    /// For more information on supported templates, visit:
1693    /// <https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template>
1694    ///
1695    /// # Arguments
1696    ///
1697    /// - `tmpl`: An optional custom template string. If `None`, the default template will be used.
1698    /// - `chat`: A vector of `LlamaChatMessage` instances, which represent the conversation between the system and user.
1699    /// - `add_ass`: A boolean flag indicating whether additional system-specific instructions (like "assistant") should be included.
1700    ///
1701    /// # Errors
1702    ///
1703    /// There are several possible points of failure when applying the chat template:
1704    /// - Insufficient buffer size to hold the formatted chat (this will return `ApplyChatTemplateError::BuffSizeError`).
1705    /// - If the template or messages cannot be processed properly, various errors from `ApplyChatTemplateError` may occur.
1706    ///
1707    /// # Example
1708    ///
1709    /// ```no_run
1710    /// use llama_cpp_4::model::{LlamaModel, LlamaChatMessage};
1711    ///
1712    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1713    /// let model = LlamaModel::load_from_file("path/to/model", &LlamaModelParams::default())?;
1714    /// let chat = vec![
1715    ///     LlamaChatMessage::new("user", "Hello!"),
1716    ///     LlamaChatMessage::new("assistant", "Hi! How can I assist you today?"),
1717    /// ];
1718    /// let formatted_chat = model.apply_chat_template(None, chat, true)?;
1719    /// # Ok(())
1720    /// # }
1721    /// ```
1722    ///
1723    /// # Notes
1724    ///
1725    /// The provided buffer is twice the length of the messages by default, which is recommended by the `llama.cpp` documentation.
1726    /// # Panics
1727    ///
1728    /// Panics if the buffer length exceeds `i32::MAX`.
1729    #[tracing::instrument(skip_all)]
1730    pub fn apply_chat_template(
1731        &self,
1732        tmpl: Option<&str>,
1733        chat: &[LlamaChatMessage],
1734        add_ass: bool,
1735    ) -> Result<String, ApplyChatTemplateError> {
1736        // Compute raw message byte total from the original LlamaChatMessage vec
1737        // *before* we shadow `chat` with the sys-type vec below.
1738        let message_length = chat.iter().fold(0usize, |acc, c| {
1739            acc + c.role.to_bytes().len() + c.content.to_bytes().len()
1740        });
1741
1742        // Build our llama_cpp_sys chat messages (raw pointers into CStrings).
1743        let chat_sys: Vec<llama_chat_message> = chat
1744            .iter()
1745            .map(|c| llama_chat_message {
1746                role: c.role.as_ptr(),
1747                content: c.content.as_ptr(),
1748            })
1749            .collect();
1750
1751        // Set the tmpl pointer.
1752        let tmpl_cstring = tmpl.map(CString::new).transpose()?;
1753        let tmpl_ptr = tmpl_cstring
1754            .as_ref()
1755            .map_or(std::ptr::null(), |s| s.as_ptr());
1756
1757        // `message_length * 4` is far too small for models whose built-in chat
1758        // template adds a long default system prompt (e.g. Qwen3.5 prepends
1759        // ~80+ chars of markup even for a one-word user message).  Start with
1760        // at least 4 KiB so short inputs like "hi" always have room.
1761        //
1762        // `llama_chat_apply_template` returns the number of bytes it *actually*
1763        // needed when the buffer was too small, so we retry exactly once with
1764        // that precise size rather than giving up immediately.
1765        let mut buf_size = message_length.saturating_mul(4).max(4096);
1766
1767        for _ in 0..2 {
1768            // Use u8 so that as_mut_ptr()/as_ptr() match the binding (*mut u8 / *const u8).
1769            let mut buff = vec![0u8; buf_size];
1770            let res = unsafe {
1771                llama_chat_apply_template(
1772                    tmpl_ptr,
1773                    chat_sys.as_ptr(),
1774                    chat_sys.len(),
1775                    add_ass,
1776                    buff.as_mut_ptr().cast(),
1777                    i32::try_from(buff.len()).expect("buffer length fits in i32"),
1778                )
1779            };
1780
1781            if res < 0 {
1782                return Err(ApplyChatTemplateError::BuffSizeError);
1783            }
1784
1785            #[allow(clippy::cast_sign_loss)]
1786            let needed = res as usize;
1787            if needed > buf_size {
1788                // Buffer was too small — retry with the exact size llama.cpp reported.
1789                buf_size = needed + 1; // +1 for null terminator
1790                continue;
1791            }
1792
1793            // SAFETY: llama_chat_apply_template wrote a NUL-terminated string
1794            // into `buff`; `needed` bytes were used.
1795            let formatted = unsafe {
1796                CStr::from_ptr(buff.as_ptr().cast())
1797                    .to_string_lossy()
1798                    .into_owned()
1799            };
1800            return Ok(formatted);
1801        }
1802
1803        Err(ApplyChatTemplateError::BuffSizeError)
1804    }
1805
1806    /// Build a split GGUF file path for a specific chunk.
1807    ///
1808    /// This utility function creates the standardized filename for a split model chunk
1809    /// following the pattern: `{prefix}-{split_no:05d}-of-{split_count:05d}.gguf`
1810    ///
1811    /// # Arguments
1812    ///
1813    /// * `path_prefix` - The base path and filename prefix
1814    /// * `split_no` - The split number (1-indexed)
1815    /// * `split_count` - The total number of splits
1816    ///
1817    /// # Returns
1818    ///
1819    /// Returns the formatted split path as a String
1820    ///
1821    /// # Example
1822    ///
1823    /// ```
1824    /// use llama_cpp_4::model::LlamaModel;
1825    ///
1826    /// let path = LlamaModel::split_path("/models/llama", 2, 4);
1827    /// assert_eq!(path, "/models/llama-00002-of-00004.gguf");
1828    /// ```
1829    ///
1830    /// # Panics
1831    ///
1832    /// Panics if the path prefix contains a null byte.
1833    #[must_use]
1834    pub fn split_path(path_prefix: &str, split_no: i32, split_count: i32) -> String {
1835        let mut buffer = vec![0u8; 1024];
1836        let len = unsafe {
1837            llama_split_path(
1838                buffer.as_mut_ptr().cast::<c_char>(),
1839                buffer.len(),
1840                CString::new(path_prefix).unwrap().as_ptr(),
1841                split_no,
1842                split_count,
1843            )
1844        };
1845
1846        let len = usize::try_from(len).expect("split_path length fits in usize");
1847        buffer.truncate(len);
1848        String::from_utf8(buffer).unwrap_or_default()
1849    }
1850
1851    /// Extract the path prefix from a split filename.
1852    ///
1853    /// This function extracts the base path prefix from a split model filename,
1854    /// but only if the `split_no` and `split_count` match the pattern in the filename.
1855    ///
1856    /// # Arguments
1857    ///
1858    /// * `split_path` - The full path to the split file
1859    /// * `split_no` - The expected split number
1860    /// * `split_count` - The expected total number of splits
1861    ///
1862    /// # Returns
1863    ///
1864    /// Returns the path prefix if the pattern matches, or None if it doesn't
1865    ///
1866    /// # Example
1867    ///
1868    /// ```
1869    /// use llama_cpp_4::model::LlamaModel;
1870    ///
1871    /// let prefix = LlamaModel::split_prefix("/models/llama-00002-of-00004.gguf", 2, 4);
1872    /// assert_eq!(prefix, Some("/models/llama".to_string()));
1873    /// ```
1874    ///
1875    /// # Panics
1876    ///
1877    /// Panics if the split path contains a null byte.
1878    #[must_use]
1879    pub fn split_prefix(split_path: &str, split_no: i32, split_count: i32) -> Option<String> {
1880        let mut buffer = vec![0u8; 1024];
1881        let len = unsafe {
1882            llama_split_prefix(
1883                buffer.as_mut_ptr().cast::<c_char>(),
1884                buffer.len(),
1885                CString::new(split_path).unwrap().as_ptr(),
1886                split_no,
1887                split_count,
1888            )
1889        };
1890
1891        if len > 0 {
1892            let len = usize::try_from(len).expect("split_prefix length fits in usize");
1893            buffer.truncate(len);
1894            String::from_utf8(buffer).ok()
1895        } else {
1896            None
1897        }
1898    }
1899}
1900
1901#[allow(clippy::cast_precision_loss)]
1902impl fmt::Display for LlamaModel {
1903    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1904        let desc = self.desc(256).unwrap_or_else(|_| "unknown".to_string());
1905        write!(
1906            f,
1907            "{desc} | {layers}L {heads}H {embd}E | {params} params | {size:.1} MiB",
1908            layers = self.n_layer(),
1909            heads = self.n_head(),
1910            embd = self.n_embd(),
1911            params = self.n_params(),
1912            size = self.model_size() as f64 / (1024.0 * 1024.0),
1913        )
1914    }
1915}
1916
impl Drop for LlamaModel {
    /// Release the underlying `llama_model` when the wrapper is dropped.
    fn drop(&mut self) {
        // SAFETY: `self.model` is a `NonNull` pointer produced by llama.cpp's
        // model loader, and `Drop` runs at most once per value, so the model
        // is freed exactly once here.
        unsafe { llama_free_model(self.model.as_ptr()) }
    }
}
1922
/// Defines the possible types of vocabulary used by the model.
///
/// The model may use different types of vocabulary depending on the tokenization method chosen during training.
/// This enum represents these types, specifically `BPE` (Byte Pair Encoding) and `SPM` (`SentencePiece`).
///
/// # Variants
///
/// - `BPE`: Byte Pair Encoding, a common tokenization method used in NLP tasks.
/// - `SPM`: `SentencePiece`, another popular tokenization method for NLP models.
///
/// # Example
///
/// ```no_run
/// use llama_cpp_4::model::VocabType;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let vocab_type = VocabType::BPE;
/// match vocab_type {
///     VocabType::BPE => println!("The model uses Byte Pair Encoding (BPE)"),
///     VocabType::SPM => println!("The model uses SentencePiece (SPM)"),
/// }
/// # Ok(())
/// # }
/// ```
// repr(u32) with discriminants taken from the llama_cpp_sys_4 constants keeps
// this enum's values in sync with the C API's `llama_vocab_type`; see the
// `TryFrom<llama_vocab_type>` impl below for the checked conversion back.
#[repr(u32)]
#[derive(Debug, Eq, Copy, Clone, PartialEq)]
pub enum VocabType {
    /// Byte Pair Encoding
    BPE = LLAMA_VOCAB_TYPE_BPE as _,
    /// Sentence Piece Tokenizer
    SPM = LLAMA_VOCAB_TYPE_SPM as _,
}
1955
/// Error that occurs when trying to convert a `llama_vocab_type` to a `VocabType`.
///
/// This error is raised when the integer value returned by the system does not correspond to a known vocabulary type.
/// It is produced by the `TryFrom<llama_vocab_type>` impl for [`VocabType`].
///
/// # Variants
///
/// - `UnknownValue`: The error is raised when the value is not a valid `llama_vocab_type`. The invalid value is returned with the error.
///
/// # Example
///
/// ```no_run
/// use llama_cpp_4::model::LlamaTokenTypeFromIntError;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let invalid_value = 999; // Not a valid vocabulary type
/// let error = LlamaTokenTypeFromIntError::UnknownValue(invalid_value);
/// println!("Error: {}", error);
/// # Ok(())
/// # }
/// ```
#[derive(thiserror::Error, Debug, Eq, PartialEq)]
pub enum LlamaTokenTypeFromIntError {
    /// The value is not a valid `llama_token_type`. Contains the int value that was invalid.
    #[error("Unknown Value {0}")]
    UnknownValue(llama_vocab_type),
}
1982
1983impl TryFrom<llama_vocab_type> for VocabType {
1984    type Error = LlamaTokenTypeFromIntError;
1985
1986    fn try_from(value: llama_vocab_type) -> Result<Self, Self::Error> {
1987        match value {
1988            LLAMA_VOCAB_TYPE_BPE => Ok(VocabType::BPE),
1989            LLAMA_VOCAB_TYPE_SPM => Ok(VocabType::SPM),
1990            unknown => Err(LlamaTokenTypeFromIntError::UnknownValue(unknown)),
1991        }
1992    }
1993}