//! A safe wrapper around `llama_model`.
use std::ffi::CStr;
use std::ffi::CString;
use std::fmt;
use std::num::NonZeroU16;
use std::os::raw::{c_char, c_int};
use std::path::Path;
use std::ptr::NonNull;

use llama_cpp_sys_4::{
    llama_adapter_lora, llama_adapter_lora_init, llama_chat_apply_template,
    llama_chat_builtin_templates, llama_chat_message, llama_detokenize, llama_init_from_model,
    llama_model, llama_model_cls_label, llama_model_decoder_start_token, llama_model_desc,
    llama_model_free, llama_model_get_vocab, llama_model_has_decoder, llama_model_has_encoder,
    llama_model_is_diffusion, llama_model_is_hybrid, llama_model_is_recurrent,
    llama_model_load_from_file, llama_model_load_from_splits, llama_model_meta_count,
    llama_model_meta_key_by_index, llama_model_meta_val_str, llama_model_meta_val_str_by_index,
    llama_model_n_cls_out, llama_model_n_ctx_train, llama_model_n_embd, llama_model_n_embd_inp,
    llama_model_n_embd_out, llama_model_n_head, llama_model_n_head_kv, llama_model_n_layer,
    llama_model_n_params, llama_model_n_swa, llama_model_rope_freq_scale_train,
    llama_model_rope_type, llama_model_save_to_file, llama_model_size, llama_split_path,
    llama_split_prefix, llama_token_to_piece, llama_tokenize, llama_vocab, llama_vocab_type,
    LLAMA_VOCAB_TYPE_BPE, LLAMA_VOCAB_TYPE_SPM,
};

use crate::context::params::LlamaContextParams;
use crate::context::LlamaContext;
use crate::llama_backend::LlamaBackend;
use crate::model::params::LlamaModelParams;
use crate::token::LlamaToken;
use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
use crate::{
    ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
    LlamaModelLoadError, NewLlamaChatMessageError, StringFromModelError, StringToTokenError,
    TokenToStringError,
};

pub mod params;

/// A safe wrapper around `llama_model`.
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaModel {
    pub(crate) model: NonNull<llama_model>,
}

/// A safe wrapper around `llama_vocab`.
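///
/// # Example
///
/// A minimal sketch of querying the vocabulary (assumes `model` is an already loaded
/// [`LlamaModel`]):
///
/// ```rust,ignore
/// let vocab = model.get_vocab();
/// // Size of the vocabulary and a few well-known special tokens.
/// let n_tokens = vocab.n_tokens();
/// let bos = vocab.bos();
/// let eos = vocab.eos();
/// ```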
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaVocab {
    pub(crate) vocab: NonNull<llama_vocab>,
}

impl LlamaVocab {
    /// Get the number of tokens in the vocabulary.
    #[must_use]
    pub fn n_tokens(&self) -> i32 {
        unsafe { llama_cpp_sys_4::llama_vocab_n_tokens(self.vocab.as_ref()) }
    }

    /// Get the vocabulary type.
    #[must_use]
    pub fn vocab_type(&self) -> u32 {
        unsafe {
            llama_cpp_sys_4::llama_vocab_type(self.vocab.as_ref())
                .try_into()
                .unwrap()
        }
    }

    /// Get the BOS token.
    #[must_use]
    pub fn bos(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_bos(self.vocab.as_ref()) })
    }

    /// Get the EOS token.
    #[must_use]
    pub fn eos(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_eos(self.vocab.as_ref()) })
    }

    /// Get the EOT (end of turn) token.
    #[must_use]
    pub fn eot(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_eot(self.vocab.as_ref()) })
    }

    /// Get the CLS (classification) token.
    #[must_use]
    pub fn cls(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_cls(self.vocab.as_ref()) })
    }

    /// Get the SEP (separator) token.
    #[must_use]
    pub fn sep(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_sep(self.vocab.as_ref()) })
    }

    /// Get the NL (newline) token.
    #[must_use]
    pub fn nl(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_nl(self.vocab.as_ref()) })
    }

    /// Get the PAD (padding) token.
    #[must_use]
    pub fn pad(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_pad(self.vocab.as_ref()) })
    }

    /// Get the FIM prefix token.
    #[must_use]
    pub fn fim_pre(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_pre(self.vocab.as_ref()) })
    }

    /// Get the FIM suffix token.
    #[must_use]
    pub fn fim_suf(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_suf(self.vocab.as_ref()) })
    }

    /// Get the FIM middle token.
    #[must_use]
    pub fn fim_mid(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_mid(self.vocab.as_ref()) })
    }

    /// Get the FIM padding token.
    #[must_use]
    pub fn fim_pad(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_pad(self.vocab.as_ref()) })
    }

    /// Get the FIM repository token.
    #[must_use]
    pub fn fim_rep(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_rep(self.vocab.as_ref()) })
    }

    /// Get the FIM separator token.
    #[must_use]
    pub fn fim_sep(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_sep(self.vocab.as_ref()) })
    }

    /// Check whether BOS should be added.
    #[must_use]
    pub fn get_add_bos(&self) -> bool {
        unsafe { llama_cpp_sys_4::llama_vocab_get_add_bos(self.vocab.as_ref()) }
    }

    /// Check whether EOS should be added.
    #[must_use]
    pub fn get_add_eos(&self) -> bool {
        unsafe { llama_cpp_sys_4::llama_vocab_get_add_eos(self.vocab.as_ref()) }
    }

    /// Check whether SEP should be added.
    #[must_use]
    pub fn get_add_sep(&self) -> bool {
        unsafe { llama_cpp_sys_4::llama_vocab_get_add_sep(self.vocab.as_ref()) }
    }

    /// Get the text representation of a token.
    ///
    /// # Errors
    ///
    /// Returns an error if the text pointer is null or not valid UTF-8.
    pub fn get_text(&self, token: LlamaToken) -> Result<&str, StringFromModelError> {
        let ptr = unsafe { llama_cpp_sys_4::llama_vocab_get_text(self.vocab.as_ref(), token.0) };
        if ptr.is_null() {
            return Err(StringFromModelError::ReturnedError(-1));
        }
        let cstr = unsafe { CStr::from_ptr(ptr) };
        cstr.to_str().map_err(StringFromModelError::Utf8Error)
    }

    /// Get the score of a token.
    #[must_use]
    pub fn get_score(&self, token: LlamaToken) -> f32 {
        unsafe { llama_cpp_sys_4::llama_vocab_get_score(self.vocab.as_ref(), token.0) }
    }

    /// Get the attributes of a token.
    #[must_use]
    pub fn get_attr(&self, token: LlamaToken) -> u32 {
        unsafe {
            llama_cpp_sys_4::llama_vocab_get_attr(self.vocab.as_ref(), token.0)
                .try_into()
                .unwrap()
        }
    }

    /// Check if a token is a control token.
    #[must_use]
    pub fn is_control(&self, token: LlamaToken) -> bool {
        unsafe { llama_cpp_sys_4::llama_vocab_is_control(self.vocab.as_ref(), token.0) }
    }

    /// Check if a token is an end-of-generation token.
    #[must_use]
    pub fn is_eog(&self, token: LlamaToken) -> bool {
        unsafe { llama_cpp_sys_4::llama_vocab_is_eog(self.vocab.as_ref(), token.0) }
    }

    /// Get the MASK token.
    #[must_use]
    pub fn mask(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_mask(self.vocab.as_ref()) })
    }
}

/// A safe wrapper around `llama_adapter_lora`.
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaLoraAdapter {
    pub(crate) lora_adapter: NonNull<llama_adapter_lora>,
}

impl LlamaLoraAdapter {
    /// Get the number of metadata key-value pairs in the adapter.
    #[must_use]
    pub fn meta_count(&self) -> i32 {
        unsafe { llama_cpp_sys_4::llama_adapter_meta_count(self.lora_adapter.as_ptr()) }
    }

    /// Get a metadata key by index.
    ///
    /// # Errors
    ///
    /// Returns an error if the index is out of range or the key is not valid UTF-8.
    #[allow(clippy::cast_sign_loss)]
    pub fn meta_key_by_index(
        &self,
        index: i32,
        buf_size: usize,
    ) -> Result<String, StringFromModelError> {
        let mut buf = vec![0u8; buf_size];
        let ret = unsafe {
            llama_cpp_sys_4::llama_adapter_meta_key_by_index(
                self.lora_adapter.as_ptr(),
                index,
                buf.as_mut_ptr().cast::<c_char>(),
                buf_size,
            )
        };
        if ret < 0 {
            return Err(StringFromModelError::ReturnedError(ret));
        }
        let len = ret as usize;
        let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
        Ok(s.to_owned())
    }

    /// Get a metadata value by key name.
    ///
    /// # Errors
    ///
    /// Returns an error if the key is not found or the value is not valid UTF-8.
    #[allow(clippy::cast_sign_loss)]
    pub fn meta_val_str(&self, key: &str, buf_size: usize) -> Result<String, StringFromModelError> {
        let c_key = CString::new(key).map_err(|_| StringFromModelError::ReturnedError(-1))?;
        let mut buf = vec![0u8; buf_size];
        let ret = unsafe {
            llama_cpp_sys_4::llama_adapter_meta_val_str(
                self.lora_adapter.as_ptr(),
                c_key.as_ptr(),
                buf.as_mut_ptr().cast::<c_char>(),
                buf_size,
            )
        };
        if ret < 0 {
            return Err(StringFromModelError::ReturnedError(ret));
        }
        let len = ret as usize;
        let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
        Ok(s.to_owned())
    }

    /// Get a metadata value by index.
    ///
    /// # Errors
    ///
    /// Returns an error if the index is out of range or the value is not valid UTF-8.
    #[allow(clippy::cast_sign_loss)]
    pub fn meta_val_str_by_index(
        &self,
        index: i32,
        buf_size: usize,
    ) -> Result<String, StringFromModelError> {
        let mut buf = vec![0u8; buf_size];
        let ret = unsafe {
            llama_cpp_sys_4::llama_adapter_meta_val_str_by_index(
                self.lora_adapter.as_ptr(),
                index,
                buf.as_mut_ptr().cast::<c_char>(),
                buf_size,
            )
        };
        if ret < 0 {
            return Err(StringFromModelError::ReturnedError(ret));
        }
        let len = ret as usize;
        let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
        Ok(s.to_owned())
    }

    /// Get all metadata as a list of `(key, value)` pairs.
    ///
    /// # Errors
    ///
    /// Returns an error if any key or value cannot be read or is not valid UTF-8.
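    ///
    /// # Example
    ///
    /// A minimal sketch (assumes `adapter` is an already initialized [`LlamaLoraAdapter`]):
    ///
    /// ```rust,ignore
    /// for (key, value) in adapter.metadata()? {
    ///     println!("{key} = {value}");
    /// }
    /// ```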
    #[allow(clippy::cast_sign_loss)]
    pub fn metadata(&self) -> Result<Vec<(String, String)>, StringFromModelError> {
        let count = self.meta_count();
        let mut result = Vec::with_capacity(count as usize);
        for i in 0..count {
            let key = self.meta_key_by_index(i, 256)?;
            let val = self.meta_val_str_by_index(i, 4096)?;
            result.push((key, val));
        }
        Ok(result)
    }

    /// Get the number of invocation tokens for this adapter.
    #[must_use]
    pub fn n_invocation_tokens(&self) -> u64 {
        unsafe {
            llama_cpp_sys_4::llama_adapter_get_alora_n_invocation_tokens(self.lora_adapter.as_ptr())
        }
    }

    /// Get the invocation tokens for this adapter.
    ///
    /// Returns an empty slice if there are no invocation tokens.
    #[must_use]
    #[allow(clippy::cast_possible_truncation)]
    pub fn invocation_tokens(&self) -> &[LlamaToken] {
        let n = self.n_invocation_tokens() as usize;
        if n == 0 {
            return &[];
        }
        let ptr = unsafe {
            llama_cpp_sys_4::llama_adapter_get_alora_invocation_tokens(self.lora_adapter.as_ptr())
        };
        if ptr.is_null() {
            return &[];
        }
        // LlamaToken is repr(transparent) over llama_token (i32), so this cast is safe
        unsafe { std::slice::from_raw_parts(ptr.cast::<LlamaToken>(), n) }
    }
}

impl Drop for LlamaLoraAdapter {
    fn drop(&mut self) {
        unsafe {
            llama_cpp_sys_4::llama_adapter_lora_free(self.lora_adapter.as_ptr());
        }
    }
}

/// A safe wrapper around `llama_chat_message`.
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct LlamaChatMessage {
    role: CString,
    content: CString,
}

impl LlamaChatMessage {
    /// Create a new `LlamaChatMessage`.
    ///
    /// # Errors
    ///
    /// Returns [`NewLlamaChatMessageError`] if the role or content contains a null byte.
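    ///
    /// # Example
    ///
    /// A minimal sketch; the role and content strings are illustrative:
    ///
    /// ```rust,ignore
    /// use llama_cpp_4::model::LlamaChatMessage;
    ///
    /// let message = LlamaChatMessage::new("user".to_string(), "Hello!".to_string())?;
    /// ```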
    pub fn new(role: String, content: String) -> Result<Self, NewLlamaChatMessageError> {
        Ok(Self {
            role: CString::new(role)?,
            content: CString::new(content)?,
        })
    }
}

/// How to determine whether a BOS (beginning of stream) token should be prepended when tokenizing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AddBos {
    /// Add the beginning of stream token to the start of the string.
    Always,
    /// Do not add the beginning of stream token to the start of the string.
    Never,
}

/// How special tokens should be handled when tokenizing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Special {
    /// Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space.
    Tokenize,
    /// Treat special and/or control tokens as plaintext.
    Plaintext,
}

unsafe impl Send for LlamaModel {}

unsafe impl Sync for LlamaModel {}

impl LlamaModel {
    /// Retrieves the vocabulary associated with the current Llama model.
    ///
    /// This method fetches the vocabulary from the underlying model using an unsafe
    /// FFI call. The returned `LlamaVocab` struct contains a non-null pointer to
    /// the vocabulary data, which is wrapped in a `NonNull` for safety.
    ///
    /// # Safety
    /// This method uses an unsafe block to call a C function (`llama_model_get_vocab`),
    /// which is assumed to return a valid pointer to the vocabulary. The caller should
    /// ensure that the model object is properly initialized and valid before calling
    /// this method, as dereferencing invalid pointers can lead to undefined behavior.
    ///
    /// # Returns
    /// A `LlamaVocab` struct containing the vocabulary of the model.
    ///
    /// # Panics
    ///
    /// Panics if the underlying C function returns a null pointer.
    ///
    /// # Example
    /// ```rust,ignore
    /// let vocab = model.get_vocab();
    /// ```
    #[must_use]
    pub fn get_vocab(&self) -> LlamaVocab {
        let llama_vocab = unsafe { llama_model_get_vocab(self.model.as_ptr()) }.cast_mut();

        LlamaVocab {
            vocab: NonNull::new(llama_vocab).unwrap(),
        }
    }
    /// Get the context size (in tokens) the model was trained with.
    ///
    /// This function returns the training context length of the model, represented as a `u32`.
    ///
    /// # Panics
    ///
    /// This function will panic if the training context size does not fit into a `u32`.
    /// This should be impossible in practice since llama.cpp returns a `c_int` (`i32` on most platforms)
    /// and the value is non-negative.
    #[must_use]
    pub fn n_ctx_train(&self) -> u32 {
        let n_ctx_train = unsafe { llama_model_n_ctx_train(self.model.as_ptr()) };
        u32::try_from(n_ctx_train).expect("n_ctx_train fits into an u32")
    }

    /// Get all tokens in the model.
    ///
    /// This function returns an iterator over all the tokens in the model. Each item in the iterator is a tuple
    /// containing a `LlamaToken` and its corresponding string representation (or an error if the conversion fails).
    ///
    /// # Parameters
    ///
    /// - `special`: The `Special` value that determines how special tokens (like BOS, EOS, etc.) are handled.
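    ///
    /// # Example
    ///
    /// A minimal sketch that prints the first few tokens (assumes `model` is an already loaded
    /// [`LlamaModel`]):
    ///
    /// ```rust,ignore
    /// use llama_cpp_4::model::Special;
    ///
    /// for (token, text) in model.tokens(Special::Plaintext).take(10) {
    ///     println!("{token:?}: {text:?}");
    /// }
    /// ```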
    pub fn tokens(
        &self,
        special: Special,
    ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
        (0..self.n_vocab())
            .map(LlamaToken::new)
            .map(move |llama_token| (llama_token, self.token_to_str(llama_token, special)))
    }

    /// Get the beginning of stream token.
    ///
    /// This function returns the token that represents the beginning of a stream (BOS token).
    #[must_use]
    pub fn token_bos(&self) -> LlamaToken {
        self.get_vocab().bos()
    }

    /// Get the end of stream token.
    ///
    /// This function returns the token that represents the end of a stream (EOS token).
    #[must_use]
    pub fn token_eos(&self) -> LlamaToken {
        self.get_vocab().eos()
    }

    /// Get the newline token.
    ///
    /// This function returns the token that represents a newline character.
    #[must_use]
    pub fn token_nl(&self) -> LlamaToken {
        self.get_vocab().nl()
    }

    /// Check if a token represents the end of generation (end of turn, end of sequence, etc.).
    ///
    /// This function returns `true` if the provided token signifies the end of generation or end of sequence,
    /// such as EOS or other special tokens.
    ///
    /// # Parameters
    ///
    /// - `token`: The `LlamaToken` to check.
    ///
    /// # Returns
    ///
    /// - `true` if the token is an end-of-generation token, otherwise `false`.
    #[must_use]
    pub fn is_eog_token(&self, token: LlamaToken) -> bool {
        self.get_vocab().is_eog(token)
    }

    /// Get the classification token.
    #[must_use]
    pub fn token_cls(&self) -> LlamaToken {
        self.get_vocab().cls()
    }

    /// Get the end-of-turn token.
    #[must_use]
    pub fn token_eot(&self) -> LlamaToken {
        self.get_vocab().eot()
    }

    /// Get the padding token.
    #[must_use]
    pub fn token_pad(&self) -> LlamaToken {
        self.get_vocab().pad()
    }

    /// Get the separator token.
    #[must_use]
    pub fn token_sep(&self) -> LlamaToken {
        self.get_vocab().sep()
    }

    /// Get the fill-in-the-middle prefix token.
    #[must_use]
    pub fn token_fim_pre(&self) -> LlamaToken {
        self.get_vocab().fim_pre()
    }

    /// Get the fill-in-the-middle suffix token.
    #[must_use]
    pub fn token_fim_suf(&self) -> LlamaToken {
        self.get_vocab().fim_suf()
    }

    /// Get the fill-in-the-middle middle token.
    #[must_use]
    pub fn token_fim_mid(&self) -> LlamaToken {
        self.get_vocab().fim_mid()
    }

    /// Get the fill-in-the-middle padding token.
    #[must_use]
    pub fn token_fim_pad(&self) -> LlamaToken {
        self.get_vocab().fim_pad()
    }

    /// Get the fill-in-the-middle repository token.
    #[must_use]
    pub fn token_fim_rep(&self) -> LlamaToken {
        self.get_vocab().fim_rep()
    }

    /// Get the fill-in-the-middle separator token.
    #[must_use]
    pub fn token_fim_sep(&self) -> LlamaToken {
        self.get_vocab().fim_sep()
    }

    /// Check if a token is a control token.
    #[must_use]
    pub fn token_is_control(&self, token: LlamaToken) -> bool {
        self.get_vocab().is_control(token)
    }

    /// Get the score of a token.
    #[must_use]
    pub fn token_get_score(&self, token: LlamaToken) -> f32 {
        self.get_vocab().get_score(token)
    }

    /// Get the raw text of a token.
    ///
    /// # Errors
    ///
    /// Returns an error if the token text is null or not valid UTF-8.
    pub fn token_get_text(&self, token: LlamaToken) -> Result<&str, StringFromModelError> {
        let ptr = unsafe {
            llama_cpp_sys_4::llama_vocab_get_text(self.get_vocab().vocab.as_ref(), token.0)
        };
        if ptr.is_null() {
            return Err(StringFromModelError::ReturnedError(-1));
        }
        let cstr = unsafe { CStr::from_ptr(ptr) };
        cstr.to_str().map_err(StringFromModelError::Utf8Error)
    }

    /// Check if a BOS token should be added when tokenizing.
    #[must_use]
    pub fn add_bos_token(&self) -> bool {
        self.get_vocab().get_add_bos()
    }

    /// Check if an EOS token should be added when tokenizing.
    #[must_use]
    pub fn add_eos_token(&self) -> bool {
        self.get_vocab().get_add_eos()
    }

    /// Get the decoder start token.
    ///
    /// This function returns the token used to signal the start of decoding (i.e., the token used at the start
    /// of a sequence generation).
    #[must_use]
    pub fn decode_start_token(&self) -> LlamaToken {
        let token = unsafe { llama_model_decoder_start_token(self.model.as_ptr()) };
        LlamaToken(token)
    }

    /// Convert a single token to a string.
    ///
    /// This function converts a `LlamaToken` into its string representation.
    ///
    /// # Errors
    ///
    /// This function returns an error if the token cannot be converted to a string. For more details, refer to
    /// [`TokenToStringError`].
    ///
    /// # Parameters
    ///
    /// - `token`: The `LlamaToken` to convert.
    /// - `special`: The `Special` value used to handle special tokens.
    pub fn token_to_str(
        &self,
        token: LlamaToken,
        special: Special,
    ) -> Result<String, TokenToStringError> {
        self.token_to_str_with_size(token, 32, special)
    }

    /// Convert a single token to bytes.
    ///
    /// This function converts a `LlamaToken` into a byte representation.
    ///
    /// # Errors
    ///
    /// This function returns an error if the token cannot be converted to bytes. For more details, refer to
    /// [`TokenToStringError`].
    ///
    /// # Parameters
    ///
    /// - `token`: The `LlamaToken` to convert.
    /// - `special`: The `Special` value used to handle special tokens.
    pub fn token_to_bytes(
        &self,
        token: LlamaToken,
        special: Special,
    ) -> Result<Vec<u8>, TokenToStringError> {
        self.token_to_bytes_with_size(token, 32, special, None)
    }

    /// Convert a vector of tokens to a single string.
    ///
    /// This function takes a slice of `LlamaToken`s and converts them into a single string, concatenating their
    /// string representations.
    ///
    /// # Errors
    ///
    /// This function returns an error if any token cannot be converted to a string. For more details, refer to
    /// [`TokenToStringError`].
    ///
    /// # Parameters
    ///
    /// - `tokens`: A slice of `LlamaToken`s to convert.
    /// - `special`: The `Special` value used to handle special tokens.
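    ///
    /// # Example
    ///
    /// A minimal round-trip sketch (assumes `model` is an already loaded [`LlamaModel`]):
    ///
    /// ```rust,ignore
    /// use llama_cpp_4::model::{AddBos, Special};
    ///
    /// let tokens = model.str_to_token("Hello, World!", AddBos::Never)?;
    /// let text = model.tokens_to_str(&tokens, Special::Plaintext)?;
    /// println!("{text}");
    /// ```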
    pub fn tokens_to_str(
        &self,
        tokens: &[LlamaToken],
        special: Special,
    ) -> Result<String, TokenToStringError> {
        let mut builder = String::with_capacity(tokens.len() * 4);
        for str in tokens
            .iter()
            .copied()
            .map(|t| self.token_to_str(t, special))
        {
            builder += &str?;
        }
        Ok(builder)
    }

    /// Convert a string to a vector of tokens.
    ///
    /// This function converts a string into a vector of `LlamaToken`s. The function will tokenize the string
    /// and return the corresponding tokens.
    ///
    /// # Errors
    ///
    /// - This function will return an error if the input string contains a null byte.
    ///
    /// # Panics
    ///
    /// - This function will panic if the number of tokens exceeds `usize::MAX`.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use llama_cpp_4::model::LlamaModel;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// use std::path::Path;
    /// use llama_cpp_4::model::AddBos;
    /// let backend = llama_cpp_4::llama_backend::LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, Path::new("path/to/model"), &Default::default())?;
    /// let tokens = model.str_to_token("Hello, World!", AddBos::Always)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn str_to_token(
        &self,
        str: &str,
        add_bos: AddBos,
    ) -> Result<Vec<LlamaToken>, StringToTokenError> {
        let add_bos = match add_bos {
            AddBos::Always => true,
            AddBos::Never => false,
        };

        let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
        let mut buffer = Vec::with_capacity(tokens_estimation);

        let c_string = CString::new(str)?;
        let buffer_capacity =
            c_int::try_from(buffer.capacity()).expect("buffer capacity should fit into a c_int");

        let size = unsafe {
            llama_tokenize(
                self.get_vocab().vocab.as_ref(),
                c_string.as_ptr(),
                c_int::try_from(c_string.as_bytes().len())?,
                buffer.as_mut_ptr(),
                buffer_capacity,
                add_bos,
                true,
            )
        };

        // If the first call fails because the buffer is too small, the negative return value is
        // the required token count. Resize the buffer and try again; the retry should never fail,
        // so `size` is guaranteed to be non-negative afterwards.
        let size = if size.is_negative() {
            buffer.reserve_exact(usize::try_from(-size).expect("negative size fits into a usize"));
            unsafe {
                llama_tokenize(
                    self.get_vocab().vocab.as_ref(),
                    c_string.as_ptr(),
                    c_int::try_from(c_string.as_bytes().len())?,
                    buffer.as_mut_ptr(),
                    -size,
                    add_bos,
                    true,
                )
            }
        } else {
            size
        };

        let size = usize::try_from(size).expect("size is non-negative and fits into a usize");

        // Safety: `size` <= `capacity` and llama-cpp has initialized elements up to `size`
        unsafe { buffer.set_len(size) }
        Ok(buffer.into_iter().map(LlamaToken).collect())
    }

    /// Get the type of a token.
    ///
    /// This function retrieves the attributes associated with a given token. The attributes are typically used to
    /// understand whether the token represents a special type of token (e.g., beginning-of-sequence (BOS), end-of-sequence (EOS),
    /// control tokens, etc.).
    ///
    /// # Panics
    ///
    /// - This function will panic if the token type is unknown or cannot be converted to a valid `LlamaTokenAttrs`.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use llama_cpp_4::model::LlamaModel;
    /// use llama_cpp_4::model::params::LlamaModelParams;
    /// use llama_cpp_4::llama_backend::LlamaBackend;
    /// use llama_cpp_4::token::LlamaToken;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
    /// let token = LlamaToken::new(42);
    /// let token_attrs = model.token_attr(token);
    /// # Ok(())
    /// # }
    /// ```
    #[must_use]
    pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs {
        let token_type =
            unsafe { llama_cpp_sys_4::llama_vocab_get_attr(self.get_vocab().vocab.as_ref(), id) };
        LlamaTokenAttrs::try_from(token_type).expect("token type is valid")
    }

    /// Detokenize a slice of tokens into a string.
    ///
    /// This is the inverse of [`str_to_token`](Self::str_to_token).
    ///
    /// # Parameters
    ///
    /// - `tokens`: The tokens to detokenize.
    /// - `remove_special`: If `true`, special tokens are removed from the output.
    /// - `unparse_special`: If `true`, special tokens are rendered as their text representation.
    ///
    /// # Errors
    ///
    /// Returns an error if the detokenized text is not valid UTF-8.
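    ///
    /// # Example
    ///
    /// A minimal sketch (assumes `model` is an already loaded [`LlamaModel`]):
    ///
    /// ```rust,ignore
    /// use llama_cpp_4::model::AddBos;
    ///
    /// let tokens = model.str_to_token("Hello, World!", AddBos::Always)?;
    /// // Strip special tokens (such as BOS) from the detokenized output.
    /// let text = model.detokenize(&tokens, true, false)?;
    /// println!("{text}");
    /// ```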
    #[allow(
        clippy::cast_possible_truncation,
        clippy::cast_possible_wrap,
        clippy::cast_sign_loss
    )]
    pub fn detokenize(
        &self,
        tokens: &[LlamaToken],
        remove_special: bool,
        unparse_special: bool,
    ) -> Result<String, StringFromModelError> {
        // First call with empty buffer to get required size
        let n_tokens = tokens.len() as i32;
        let token_ptr = tokens.as_ptr().cast::<llama_cpp_sys_4::llama_token>();
        let needed = unsafe {
            llama_detokenize(
                self.get_vocab().vocab.as_ref(),
                token_ptr,
                n_tokens,
                std::ptr::null_mut(),
                0,
                remove_special,
                unparse_special,
            )
        };
        // llama_detokenize returns negative required size when buffer is too small
        let buf_size = if needed < 0 {
            (-needed) as usize
        } else {
            needed as usize
        };
        let mut buf = vec![0u8; buf_size];
        let ret = unsafe {
            llama_detokenize(
                self.get_vocab().vocab.as_ref(),
                token_ptr,
                n_tokens,
                buf.as_mut_ptr().cast::<c_char>(),
                buf_size as i32,
                remove_special,
                unparse_special,
            )
        };
        if ret < 0 {
            return Err(StringFromModelError::ReturnedError(ret));
        }
        let len = ret as usize;
        let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
        Ok(s.to_owned())
    }

    /// Convert a token to a string with a specified buffer size.
    ///
    /// This function allows you to convert a token into a string, with the ability to specify a buffer size for the operation.
    /// It is generally recommended to use `LlamaModel::token_to_str` instead, as its default 32-byte buffer is sufficient for most tokens,
    /// and manually choosing a larger buffer size rarely matters.
    ///
    /// # Errors
    ///
    /// - If the token type is unknown, an error will be returned.
    /// - If the resultant token exceeds the provided `buffer_size`, an error will occur.
    /// - If the token string returned by `llama-cpp` is not valid UTF-8, it will return an error.
    ///
    /// # Panics
    ///
    /// - This function will panic if the `buffer_size` does not fit into a `c_int`.
    /// - It will also panic if the size returned from `llama-cpp` does not fit into a `usize`, which should typically never happen.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use llama_cpp_4::model::{LlamaModel, Special};
    /// use llama_cpp_4::model::params::LlamaModelParams;
    /// use llama_cpp_4::llama_backend::LlamaBackend;
    /// use llama_cpp_4::token::LlamaToken;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
    /// let token = LlamaToken::new(42);
    /// let token_string = model.token_to_str_with_size(token, 32, Special::Plaintext)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn token_to_str_with_size(
        &self,
        token: LlamaToken,
        buffer_size: usize,
        special: Special,
    ) -> Result<String, TokenToStringError> {
        let bytes = self.token_to_bytes_with_size(token, buffer_size, special, None)?;
        Ok(String::from_utf8(bytes)?)
    }

    /// Convert a token to bytes with a specified buffer size.
    ///
    /// Generally you should use [`LlamaModel::token_to_bytes`] instead, as its default 32-byte buffer is enough for most tokens and
    /// the extra bytes do not really matter.
    ///
    /// # Errors
    ///
    /// - if the token type is unknown
    /// - the resultant token is larger than `buffer_size`.
    ///
    /// # Panics
    ///
    /// - This function will panic if `buffer_size` cannot fit into a `c_int`.
    /// - It will also panic if the size returned from `llama-cpp` cannot be converted to `usize` (which should not happen).
    ///
    /// # Example
    ///
    /// ```no_run
    /// use llama_cpp_4::model::{LlamaModel, Special};
    /// use llama_cpp_4::model::params::LlamaModelParams;
    /// use llama_cpp_4::llama_backend::LlamaBackend;
    /// use llama_cpp_4::token::LlamaToken;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
    /// let token = LlamaToken::new(42);
    /// let token_bytes = model.token_to_bytes_with_size(token, 32, Special::Plaintext, None)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn token_to_bytes_with_size(
        &self,
        token: LlamaToken,
        buffer_size: usize,
        special: Special,
        lstrip: Option<NonZeroU16>,
    ) -> Result<Vec<u8>, TokenToStringError> {
        if token == self.token_nl() {
            return Ok(String::from("\n").into_bytes());
        }

        // unsure what to do with this in the face of the 'special' arg + attr changes
        let attrs = self.token_attr(token);
        if (attrs.contains(LlamaTokenAttr::Control)
            && (token == self.token_bos() || token == self.token_eos()))
            || attrs.is_empty()
            || attrs
                .intersects(LlamaTokenAttr::Unknown | LlamaTokenAttr::Byte | LlamaTokenAttr::Unused)
        {
            return Ok(Vec::new());
        }

        let special = match special {
            Special::Tokenize => true,
            Special::Plaintext => false,
        };

        let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
        let len = string.as_bytes().len();
        let len = c_int::try_from(len).expect("length fits into c_int");
        let buf = string.into_raw();
        let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
        let size = unsafe {
            llama_token_to_piece(
                self.get_vocab().vocab.as_ref(),
                token.0,
                buf,
                len,
                lstrip,
                special,
            )
        };

        match size {
            0 => Err(TokenToStringError::UnknownTokenType),
            i if i.is_negative() => Err(TokenToStringError::InsufficientBufferSpace(i)),
            size => {
                let string = unsafe { CString::from_raw(buf) };
                let mut bytes = string.into_bytes();
                let len = usize::try_from(size).expect("size is positive and fits into usize");
                bytes.truncate(len);
                Ok(bytes)
            }
        }
    }
    /// The number of tokens in the model's vocabulary.
    ///
    /// This function returns the size of the model's vocabulary (the number of distinct tokens the
    /// tokenizer defines), as an `i32`.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use llama_cpp_4::model::LlamaModel;
    /// use llama_cpp_4::model::params::LlamaModelParams;
    /// use llama_cpp_4::llama_backend::LlamaBackend;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
    /// let n_vocab = model.n_vocab();
    /// # Ok(())
    /// # }
    /// ```
    #[must_use]
    pub fn n_vocab(&self) -> i32 {
        self.get_vocab().n_tokens()
    }

    /// The type of vocab the model was trained on.
    ///
    /// This function returns the type of vocabulary used by the model, such as whether it is based on byte-pair encoding (BPE),
    /// word-level tokens, or another tokenization scheme.
    ///
    /// # Panics
    ///
    /// - This function will panic if `llama-cpp` emits a vocab type that is not recognized or is invalid for this library.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use llama_cpp_4::model::LlamaModel;
    /// use llama_cpp_4::model::params::LlamaModelParams;
    /// use llama_cpp_4::llama_backend::LlamaBackend;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
    /// let vocab_type = model.vocab_type();
    /// # Ok(())
    /// # }
    /// ```
    #[must_use]
    pub fn vocab_type(&self) -> VocabType {
        let vocab_type = unsafe { llama_vocab_type(self.get_vocab().vocab.as_ref()) };
        VocabType::try_from(vocab_type).expect("invalid vocab type")
    }

    /// Returns the number of embedding dimensions for the model.
    ///
    /// This function retrieves the number of embeddings (or embedding dimensions) used by the model. It is typically
    /// used for analyzing model architecture and setting up context parameters or other model configuration aspects.
    ///
    /// # Panics
    ///
    /// - This function may panic if the underlying `llama-cpp` library returns an invalid embedding dimension value.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use llama_cpp_4::model::LlamaModel;
    /// use llama_cpp_4::model::params::LlamaModelParams;
    /// use llama_cpp_4::llama_backend::LlamaBackend;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
    /// let n_embd = model.n_embd();
    /// # Ok(())
    /// # }
    /// ```
    #[must_use]
    pub fn n_embd(&self) -> c_int {
        unsafe { llama_model_n_embd(self.model.as_ptr()) }
    }

    /// Get the number of transformer layers in the model.
    #[must_use]
    pub fn n_layer(&self) -> c_int {
        unsafe { llama_model_n_layer(self.model.as_ptr()) }
    }

    /// Get the number of attention heads in the model.
    #[must_use]
    pub fn n_head(&self) -> c_int {
        unsafe { llama_model_n_head(self.model.as_ptr()) }
    }

    /// Get the number of key-value attention heads in the model.
    #[must_use]
    pub fn n_head_kv(&self) -> c_int {
        unsafe { llama_model_n_head_kv(self.model.as_ptr()) }
    }

    /// Get the input embedding size of the model.
    #[must_use]
    pub fn n_embd_inp(&self) -> c_int {
        unsafe { llama_model_n_embd_inp(self.model.as_ptr()) }
    }

    /// Get the output embedding size of the model.
    #[must_use]
    pub fn n_embd_out(&self) -> c_int {
        unsafe { llama_model_n_embd_out(self.model.as_ptr()) }
    }

    /// Get the sliding window attention size of the model.
    /// Returns 0 if the model does not use sliding window attention.
    #[must_use]
    pub fn n_swa(&self) -> c_int {
        unsafe { llama_model_n_swa(self.model.as_ptr()) }
    }

    /// Get the `RoPE` type used by the model.
    #[must_use]
    pub fn rope_type(&self) -> i32 {
        unsafe { llama_model_rope_type(self.model.as_ptr()) }
    }

    /// Get the `RoPE` frequency scale used during training.
    #[must_use]
    pub fn rope_freq_scale_train(&self) -> f32 {
        unsafe { llama_model_rope_freq_scale_train(self.model.as_ptr()) }
    }

    /// Get the model size in bytes.
    #[must_use]
    pub fn model_size(&self) -> u64 {
        unsafe { llama_model_size(self.model.as_ptr()) }
    }

    /// Get the number of parameters in the model.
    #[must_use]
    pub fn n_params(&self) -> u64 {
        unsafe { llama_model_n_params(self.model.as_ptr()) }
    }

    /// Get the number of classification outputs.
    #[must_use]
    pub fn n_cls_out(&self) -> u32 {
        unsafe { llama_model_n_cls_out(self.model.as_ptr()) }
    }

    /// Get the classification label for the given index.
    ///
    /// # Errors
    ///
    /// Returns an error if the label is null or not valid UTF-8.
    pub fn cls_label(&self, index: u32) -> Result<&str, StringFromModelError> {
        let ptr = unsafe { llama_model_cls_label(self.model.as_ptr(), index) };
        if ptr.is_null() {
            return Err(StringFromModelError::ReturnedError(-1));
        }
        let cstr = unsafe { CStr::from_ptr(ptr) };
        cstr.to_str().map_err(StringFromModelError::Utf8Error)
    }

    /// Get the number of metadata key-value pairs.
    #[must_use]
    pub fn meta_count(&self) -> c_int {
        unsafe { llama_model_meta_count(self.model.as_ptr()) }
    }

    /// Get a model description string.
    ///
    /// The `buf_size` parameter specifies the maximum buffer size for the description.
    /// A default of 256 bytes is usually sufficient.
    ///
    /// # Errors
    ///
    /// Returns an error if the description could not be retrieved or is not valid UTF-8.
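    ///
    /// # Example
    ///
    /// A minimal sketch (assumes `model` is an already loaded [`LlamaModel`]; the 256-byte
    /// buffer size is an illustrative default):
    ///
    /// ```rust,ignore
    /// let description = model.desc(256)?;
    /// println!("{description}");
    /// ```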
    #[allow(clippy::cast_sign_loss)]
    pub fn desc(&self, buf_size: usize) -> Result<String, StringFromModelError> {
        let mut buf = vec![0u8; buf_size];
        let ret = unsafe {
            llama_model_desc(
                self.model.as_ptr(),
                buf.as_mut_ptr().cast::<c_char>(),
                buf_size,
            )
        };
        if ret < 0 {
            return Err(StringFromModelError::ReturnedError(ret));
        }
        let len = ret as usize;
        let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
        Ok(s.to_owned())
    }

    /// Get a metadata key by index.
    ///
    /// The `buf_size` parameter specifies the maximum buffer size for the key.
    /// A default of 256 bytes is usually sufficient.
    ///
    /// # Errors
    ///
    /// Returns an error if the index is out of range or the key is not valid UTF-8.
    #[allow(clippy::cast_sign_loss)]
    pub fn meta_key_by_index(
        &self,
        index: i32,
        buf_size: usize,
    ) -> Result<String, StringFromModelError> {
        let mut buf = vec![0u8; buf_size];
        let ret = unsafe {
            llama_model_meta_key_by_index(
                self.model.as_ptr(),
                index,
                buf.as_mut_ptr().cast::<c_char>(),
                buf_size,
            )
        };
        if ret < 0 {
            return Err(StringFromModelError::ReturnedError(ret));
        }
        let len = ret as usize;
        let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
        Ok(s.to_owned())
    }

    /// Get a metadata value string by index.
    ///
    /// The `buf_size` parameter specifies the maximum buffer size for the value.
    /// Values can be large (e.g. chat templates, token lists), so 4096+ may be needed.
    ///
    /// # Errors
    ///
    /// Returns an error if the index is out of range or the value is not valid UTF-8.
    #[allow(clippy::cast_sign_loss)]
    pub fn meta_val_str_by_index(
        &self,
        index: i32,
        buf_size: usize,
    ) -> Result<String, StringFromModelError> {
        let mut buf = vec![0u8; buf_size];
        let ret = unsafe {
            llama_model_meta_val_str_by_index(
                self.model.as_ptr(),
                index,
                buf.as_mut_ptr().cast::<c_char>(),
                buf_size,
            )
        };
        if ret < 0 {
            return Err(StringFromModelError::ReturnedError(ret));
        }
        let len = ret as usize;
        let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
        Ok(s.to_owned())
    }

    /// Get a metadata value by key name.
    ///
    /// This is more convenient than iterating metadata by index when you know the key.
    /// The `buf_size` parameter specifies the maximum buffer size for the value.
    ///
    /// # Errors
    ///
    /// Returns an error if the key is not found, contains a null byte, or the value is not valid UTF-8.
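    ///
    /// # Example
    ///
    /// A minimal sketch (assumes `model` is an already loaded [`LlamaModel`]; the key and
    /// buffer size are illustrative):
    ///
    /// ```rust,ignore
    /// let architecture = model.meta_val_str("general.architecture", 256)?;
    /// println!("architecture: {architecture}");
    /// ```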
    #[allow(clippy::cast_sign_loss)]
    pub fn meta_val_str(&self, key: &str, buf_size: usize) -> Result<String, StringFromModelError> {
        let c_key = CString::new(key).map_err(|_| StringFromModelError::ReturnedError(-1))?;
        let mut buf = vec![0u8; buf_size];
        let ret = unsafe {
            llama_model_meta_val_str(
                self.model.as_ptr(),
                c_key.as_ptr(),
                buf.as_mut_ptr().cast::<c_char>(),
                buf_size,
            )
        };
        if ret < 0 {
            return Err(StringFromModelError::ReturnedError(ret));
        }
        let len = ret as usize;
        let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
        Ok(s.to_owned())
    }

    /// Get all metadata as a list of `(key, value)` pairs.
    ///
    /// This is a convenience method that iterates over all metadata entries.
    /// Keys use a buffer of 256 bytes and values use 4096 bytes.
    /// For values that may be larger (e.g. token lists), use
    /// [`meta_val_str_by_index`](Self::meta_val_str_by_index) directly with a larger buffer.
    ///
    /// # Errors
    ///
    /// Returns an error if any key or value cannot be read or is not valid UTF-8.
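    ///
    /// # Example
    ///
    /// A minimal sketch (assumes `model` is an already loaded [`LlamaModel`]):
    ///
    /// ```rust,ignore
    /// for (key, value) in model.metadata()? {
    ///     println!("{key} = {value}");
    /// }
    /// ```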
1300    #[allow(clippy::cast_sign_loss)]
1301    pub fn metadata(&self) -> Result<Vec<(String, String)>, StringFromModelError> {
1302        let count = self.meta_count();
1303        let mut result = Vec::with_capacity(count as usize);
1304        for i in 0..count {
1305            let key = self.meta_key_by_index(i, 256)?;
1306            let val = self.meta_val_str_by_index(i, 4096)?;
1307            result.push((key, val));
1308        }
1309        Ok(result)
1310    }
1311
1312    /// Check if the model has an encoder.
1313    #[must_use]
1314    pub fn has_encoder(&self) -> bool {
1315        unsafe { llama_model_has_encoder(self.model.as_ptr()) }
1316    }
1317
1318    /// Check if the model has a decoder.
1319    #[must_use]
1320    pub fn has_decoder(&self) -> bool {
1321        unsafe { llama_model_has_decoder(self.model.as_ptr()) }
1322    }
1323
1324    /// Check if the model is recurrent (e.g. Mamba, RWKV).
1325    #[must_use]
1326    pub fn is_recurrent(&self) -> bool {
1327        unsafe { llama_model_is_recurrent(self.model.as_ptr()) }
1328    }
1329
1330    /// Check if the model is a hybrid model.
1331    #[must_use]
1332    pub fn is_hybrid(&self) -> bool {
1333        unsafe { llama_model_is_hybrid(self.model.as_ptr()) }
1334    }
1335
1336    /// Check if the model is a diffusion model.
1337    #[must_use]
1338    pub fn is_diffusion(&self) -> bool {
1339        unsafe { llama_model_is_diffusion(self.model.as_ptr()) }
1340    }
1341
1342    /// Get chat template from model.
1343    ///
1344    /// # Errors
1345    ///
1346    /// - If the model does not have a chat template, it will return an error.
1347    /// - If the chat template is not a valid `CString`, it will return an error.
1348    ///
1349    /// # Example
1350    ///
1351    /// ```no_run
1352    /// use llama_cpp_4::model::LlamaModel;
1353    /// use llama_cpp_4::model::params::LlamaModelParams;
1354    /// use llama_cpp_4::llama_backend::LlamaBackend;
1355    ///
1356    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1357    /// let backend = LlamaBackend::init()?;
1358    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1359    /// let chat_template = model.get_chat_template(1024)?;
1360    /// # Ok(())
1361    /// # }
1362    /// ```
1363    #[allow(clippy::missing_panics_doc)] // We statically know this will not panic as long as the buffer size is sufficient
1364    pub fn get_chat_template(&self, buf_size: usize) -> Result<String, ChatTemplateError> {
1365        // longest known template is about 1200 bytes from llama.cpp
1366        let chat_temp = CString::new(vec![b'*'; buf_size]).expect("no null");
1367        let chat_ptr = chat_temp.into_raw();
1368        let chat_name = CString::new("tokenizer.chat_template").expect("no null bytes");
1369
1370        let ret = unsafe {
1371            llama_model_meta_val_str(self.model.as_ptr(), chat_name.as_ptr(), chat_ptr, buf_size)
1372        };
1373
1374        if ret < 0 {
1375            return Err(ChatTemplateError::MissingTemplate(ret));
1376        }
1377
1378        let template_c = unsafe { CString::from_raw(chat_ptr) };
1379        let template = template_c.to_str()?;
1380
1381        let ret: usize = ret.try_into().unwrap();
1382        if template.len() < ret {
1383            return Err(ChatTemplateError::BuffSizeError(ret + 1));
1384        }
1385
1386        Ok(template.to_owned())
1387    }
1388
1389    /// Loads a model from a file.
1390    ///
1391    /// This function loads a model from a specified file path and returns the corresponding `LlamaModel` instance.
1392    ///
1393    /// # Errors
1394    ///
1395    /// - If the path cannot be converted to a string or if the model file does not exist, it will return an error.
1396    /// - If the model cannot be loaded (e.g., due to an invalid or corrupted model file), it will return a `LlamaModelLoadError`.
1397    ///
1398    /// # Example
1399    ///
1400    /// ```no_run
1401    /// use llama_cpp_4::model::LlamaModel;
1402    /// use llama_cpp_4::model::params::LlamaModelParams;
1403    /// use llama_cpp_4::llama_backend::LlamaBackend;
1404    ///
1405    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1406    /// let backend = LlamaBackend::init()?;
1407    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1408    /// # Ok(())
1409    /// # }
1410    /// ```
1411    #[tracing::instrument(skip_all, fields(params))]
1412    pub fn load_from_file(
1413        _: &LlamaBackend,
1414        path: impl AsRef<Path>,
1415        params: &LlamaModelParams,
1416    ) -> Result<Self, LlamaModelLoadError> {
1417        let path = path.as_ref();
1418        debug_assert!(
1419            Path::new(path).exists(),
1420            "{} does not exist",
1421            path.display()
1422        );
1423        let path = path
1424            .to_str()
1425            .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
1426
1427        let cstr = CString::new(path)?;
1428        let llama_model = unsafe { llama_model_load_from_file(cstr.as_ptr(), params.params) };
1429
1430        let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
1431
1432        tracing::debug!(?path, "Loaded model");
1433        Ok(LlamaModel { model })
1434    }
1435
1436    /// Load a model from multiple split files.
1437    ///
1438    /// This function loads a model that has been split across multiple files. This is useful for
1439    /// very large models that exceed filesystem limitations or need to be distributed across
1440    /// multiple storage devices.
1441    ///
1442    /// # Arguments
1443    ///
1444    /// * `paths` - A slice of paths to the split model files
1445    /// * `params` - The model parameters
1446    ///
1447    /// # Errors
1448    ///
1449    /// Returns an error if:
1450    /// - Any of the paths cannot be converted to a C string
1451    /// - The model fails to load from the splits
1452    /// - Any path doesn't exist or isn't accessible
1453    ///
1454    /// # Example
1455    ///
1456    /// ```no_run
1457    /// use llama_cpp_4::model::{LlamaModel, params::LlamaModelParams};
1458    /// use llama_cpp_4::llama_backend::LlamaBackend;
1459    ///
1460    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1461    /// let backend = LlamaBackend::init()?;
1462    /// let params = LlamaModelParams::default();
1463    ///
1464    /// let paths = vec![
1465    ///     "model-00001-of-00003.gguf",
1466    ///     "model-00002-of-00003.gguf",
1467    ///     "model-00003-of-00003.gguf",
1468    /// ];
1469    ///
1470    /// let model = LlamaModel::load_from_splits(&backend, &paths, &params)?;
1471    /// # Ok(())
1472    /// # }
1473    /// ```
1474    #[tracing::instrument(skip_all)]
1475    pub fn load_from_splits(
1476        _: &LlamaBackend,
1477        paths: &[impl AsRef<Path>],
1478        params: &LlamaModelParams,
1479    ) -> Result<Self, LlamaModelLoadError> {
1480        // Convert paths to C strings
1481        let c_strings: Vec<CString> = paths
1482            .iter()
1483            .map(|p| {
1484                let path = p.as_ref();
1485                debug_assert!(path.exists(), "{} does not exist", path.display());
1486                let path_str = path
1487                    .to_str()
1488                    .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
1489                CString::new(path_str).map_err(LlamaModelLoadError::from)
1490            })
1491            .collect::<Result<Vec<_>, _>>()?;
1492
1493        // Create array of pointers to C strings
1494        let c_ptrs: Vec<*const c_char> = c_strings.iter().map(|s| s.as_ptr()).collect();
1495
1496        // Load the model from splits
1497        let llama_model = unsafe {
1498            llama_model_load_from_splits(c_ptrs.as_ptr().cast_mut(), c_ptrs.len(), params.params)
1499        };
1500
1501        let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
1502
1503        tracing::debug!("Loaded model from {} splits", paths.len());
1504        Ok(LlamaModel { model })
1505    }
1506
1507    /// Load a model from a `FILE` pointer.
1508    ///
1509    /// # Safety
1510    ///
1511    /// The `file` pointer must be a valid, open `FILE*`.
1512    ///
1513    /// # Errors
1514    ///
1515    /// Returns an error if the model cannot be loaded.
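    ///
    /// # Example
    ///
    /// A minimal sketch, not compiled as a doctest: it assumes the `libc` crate is available
    /// for `fopen` and that `libc::FILE` is layout-compatible with `llama_cpp_sys_4::FILE`
    /// (hence the pointer cast). The path is a placeholder.
    ///
    /// ```ignore
    /// use std::ffi::CString;
    /// use llama_cpp_4::model::{LlamaModel, params::LlamaModelParams};
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let path = CString::new("path/to/model.gguf")?;
    /// let mode = CString::new("rb")?;
    /// // Open the model with the C runtime; the FILE* must stay open for the whole call.
    /// let file = unsafe { libc::fopen(path.as_ptr(), mode.as_ptr()) };
    /// assert!(!file.is_null(), "failed to open model file");
    /// let params = LlamaModelParams::default();
    /// let model = unsafe { LlamaModel::load_from_file_ptr(file.cast(), &params) }?;
    /// // Whether the loader closes the stream itself is not specified here; check the
    /// // underlying C API before adding an `fclose`.
    /// # Ok(())
    /// # }
    /// ```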
1516    pub unsafe fn load_from_file_ptr(
1517        file: *mut llama_cpp_sys_4::FILE,
1518        params: &LlamaModelParams,
1519    ) -> Result<Self, LlamaModelLoadError> {
1520        let model = llama_cpp_sys_4::llama_model_load_from_file_ptr(file, params.params);
1521        let model = NonNull::new(model).ok_or(LlamaModelLoadError::NullResult)?;
1522        Ok(LlamaModel { model })
1523    }
1524
1525    /// Initialize a model from user-provided data.
1526    ///
1527    /// # Safety
1528    ///
1529    /// `metadata` must point to a valid `gguf_context`, `set_tensor_data` must be a valid callback, and `set_tensor_data_ud` must remain valid for the duration of the call.
1530    ///
1531    /// # Errors
1532    ///
1533    /// Returns an error if the model cannot be initialized.
1534    pub unsafe fn init_from_user(
1535        metadata: *mut llama_cpp_sys_4::gguf_context,
1536        set_tensor_data: llama_cpp_sys_4::llama_model_set_tensor_data_t,
1537        set_tensor_data_ud: *mut std::ffi::c_void,
1538        params: &LlamaModelParams,
1539    ) -> Result<Self, LlamaModelLoadError> {
1540        let model = llama_cpp_sys_4::llama_model_init_from_user(
1541            metadata,
1542            set_tensor_data,
1543            set_tensor_data_ud,
1544            params.params,
1545        );
1546        let model = NonNull::new(model).ok_or(LlamaModelLoadError::NullResult)?;
1547        Ok(LlamaModel { model })
1548    }
1549
1550    /// Save the model to a file.
1551    ///
1552    /// # Panics
1553    ///
1554    /// Panics if the path is not valid UTF-8 or contains null bytes.
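    ///
    /// # Example
    ///
    /// A minimal sketch; the model and output paths are placeholders.
    ///
    /// ```no_run
    /// use llama_cpp_4::model::{LlamaModel, params::LlamaModelParams};
    /// use llama_cpp_4::llama_backend::LlamaBackend;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
    /// // Write the loaded model back to disk.
    /// model.save_to_file("path/to/output.gguf");
    /// # Ok(())
    /// # }
    /// ```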
1555    pub fn save_to_file(&self, path: impl AsRef<Path>) {
1556        let path = path.as_ref();
1557        let path_str = path.to_str().expect("path is not valid UTF-8");
1558        let c_path = CString::new(path_str).expect("path contains null bytes");
1559        unsafe {
1560            llama_model_save_to_file(self.model.as_ptr(), c_path.as_ptr());
1561        }
1562    }
1563
1564    /// Get the list of built-in chat templates.
1565    ///
1566    /// Returns the names of all chat templates that are built into llama.cpp.
1567    ///
1568    /// # Panics
1569    ///
1570    /// Panics if any template name is not valid UTF-8.
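    ///
    /// # Example
    ///
    /// A short sketch that lists the template names known to the linked `llama.cpp` build.
    ///
    /// ```no_run
    /// use llama_cpp_4::model::LlamaModel;
    ///
    /// let templates = LlamaModel::chat_builtin_templates();
    /// for name in &templates {
    ///     println!("built-in chat template: {name}");
    /// }
    /// ```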
1571    #[allow(clippy::cast_sign_loss)]
1572    #[must_use]
1573    pub fn chat_builtin_templates() -> Vec<String> {
1574        // First call to get count
1575        let count = unsafe { llama_chat_builtin_templates(std::ptr::null_mut(), 0) };
1576        if count <= 0 {
1577            return Vec::new();
1578        }
1579        let count = count as usize;
1580        let mut ptrs: Vec<*const c_char> = vec![std::ptr::null(); count];
1581        unsafe {
1582            llama_chat_builtin_templates(ptrs.as_mut_ptr(), count);
1583        }
1584        ptrs.iter()
1585            .map(|&p| {
1586                let cstr = unsafe { CStr::from_ptr(p) };
1587                cstr.to_str()
1588                    .expect("template name is not valid UTF-8")
1589                    .to_owned()
1590            })
1591            .collect()
1592    }
1593
1594    /// Initializes a LoRA adapter from a file.
1595    ///
1596    /// This function initializes a LoRA (low-rank adaptation) adapter, a model extension used to adapt or fine-tune the base model
1597    /// for a specific domain or task. The adapter file is applied on top of the model's weights at inference time to improve
1598    /// performance on specialized tasks.
1599    ///
1600    /// # Errors
1601    ///
1602    /// - Returns an error if the adapter file path cannot be converted to a C string or if the adapter fails to initialize.
1603    ///
1604    /// # Example
1605    ///
1606    /// ```no_run
1607    /// use llama_cpp_4::model::LlamaModel;
1608    /// use llama_cpp_4::model::params::LlamaModelParams;
1609    /// use llama_cpp_4::llama_backend::LlamaBackend;
1610    ///
1611    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1612    /// let backend = LlamaBackend::init()?;
1613    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1614    /// let adapter = model.lora_adapter_init("path/to/lora/adapter")?;
1615    /// # Ok(())
1616    /// # }
1617    /// ```
1618    pub fn lora_adapter_init(
1619        &self,
1620        path: impl AsRef<Path>,
1621    ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
1622        let path = path.as_ref();
1623        debug_assert!(
1624            path.exists(),
1625            "{} does not exist",
1626            path.display()
1627        );
1628
1629        let path = path
1630            .to_str()
1631            .ok_or(LlamaLoraAdapterInitError::PathToStrError(
1632                path.to_path_buf(),
1633            ))?;
1634
1635        let cstr = CString::new(path)?;
1636        let adapter = unsafe { llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr()) };
1637
1638        let adapter = NonNull::new(adapter).ok_or(LlamaLoraAdapterInitError::NullResult)?;
1639
1640        tracing::debug!(?path, "Initialized lora adapter");
1641        Ok(LlamaLoraAdapter {
1642            lora_adapter: adapter,
1643        })
1644    }
1645
1646    /// Create a new context from this model.
1647    ///
1648    /// This function creates a new context for the model, which is used to manage and perform computations for inference,
1649    /// including token generation, embeddings, and other tasks that the model can perform. The context allows fine-grained
1650    /// control over model parameters for a specific task.
1651    ///
1652    /// # Errors
1653    ///
1654    /// - There are various potential failures such as invalid parameters or a failure to allocate the context. See [`LlamaContextLoadError`]
1655    ///   for more detailed error descriptions.
1656    ///
1657    /// # Example
1658    ///
1659    /// ```no_run
1660    /// use llama_cpp_4::model::LlamaModel;
1661    /// use llama_cpp_4::model::params::LlamaModelParams;
1662    /// use llama_cpp_4::context::params::LlamaContextParams;
1663    /// use llama_cpp_4::llama_backend::LlamaBackend;
1664    ///
1665    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1666    /// let backend = LlamaBackend::init()?;
1667    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1668    /// let context = model.new_context(&backend, LlamaContextParams::default())?;
1669    /// # Ok(())
1670    /// # }
1671    /// ```
1672    #[allow(clippy::needless_pass_by_value)]
1673    pub fn new_context(
1674        &self,
1675        _: &LlamaBackend,
1676        params: LlamaContextParams,
1677    ) -> Result<LlamaContext<'_>, LlamaContextLoadError> {
1678        // Apply TurboQuant attn-rotation preference before the KV cache is
1679        // initialised inside llama_init_from_model.
1680        let prev_rot_var = std::env::var("LLAMA_ATTN_ROT_DISABLE").ok();
1681        if params.attn_rot_disabled {
1682            // SAFETY: we restore the value right after the call.
1683            #[allow(unused_unsafe)]
1684            unsafe {
1685                std::env::set_var("LLAMA_ATTN_ROT_DISABLE", "1");
1686            }
1687        } else if std::env::var("LLAMA_ATTN_ROT_DISABLE").is_ok() {
1688            // Attention rotation stays enabled, but the env var was set by the
1689            // user; leave it untouched (respect explicit user configuration).
1690        }
1691
1692        let context_params = params.context_params;
1693        let context = unsafe { llama_init_from_model(self.model.as_ptr(), context_params) };
1694
1695        // Restore the env-var to its previous state.
1696        #[allow(unused_unsafe)]
1697        match prev_rot_var {
1698            Some(v) => unsafe { std::env::set_var("LLAMA_ATTN_ROT_DISABLE", v) },
1699            None if params.attn_rot_disabled => unsafe {
1700                std::env::remove_var("LLAMA_ATTN_ROT_DISABLE");
1701            },
1702            None => {}
1703        }
1704
1705        let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;
1706        Ok(LlamaContext::new(self, context, params.embeddings()))
1707    }
1708
1709    /// Apply the model's chat template to a sequence of messages.
1710    ///
1711    /// This function applies the model's chat template to the provided chat messages, formatting them accordingly. The chat
1712    /// template determines the structure or style of conversation between the system and user, such as token formatting,
1713    /// role separation, and more. The template can be customized by providing an optional template string, or if `None`
1714    /// is provided, the default template used by `llama.cpp` will be applied.
1715    ///
1716    /// For more information on supported templates, visit:
1717    /// <https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template>
1718    ///
1719    /// # Arguments
1720    ///
1721    /// - `tmpl`: An optional custom template string. If `None`, the default template will be used.
1722    /// - `chat`: A slice of `LlamaChatMessage` values representing the conversation between the system, user, and assistant.
1723    /// - `add_ass`: Whether to append the assistant role header so that the model generates the assistant's reply next.
1724    ///
1725    /// # Errors
1726    ///
1727    /// There are several possible points of failure when applying the chat template:
1728    /// - Insufficient buffer size to hold the formatted chat (this will return `ApplyChatTemplateError::BuffSizeError`).
1729    /// - If the template or messages cannot be processed properly, various errors from `ApplyChatTemplateError` may occur.
1730    ///
1731    /// # Example
1732    ///
1733    /// ```no_run
1734    /// use llama_cpp_4::model::{LlamaModel, LlamaChatMessage};
1735    /// use llama_cpp_4::model::params::LlamaModelParams;
1736    /// use llama_cpp_4::llama_backend::LlamaBackend;
1737    ///
1738    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1739    /// let backend = LlamaBackend::init()?;
1740    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1741    /// let chat = vec![
1742    ///     LlamaChatMessage::new("user".to_string(), "Hello!".to_string())?,
1743    ///     LlamaChatMessage::new("assistant".to_string(), "Hi! How can I assist you today?".to_string())?,
1744    /// ];
1745    /// let formatted_chat = model.apply_chat_template(None, &chat, true)?;
1746    /// # Ok(())
1747    /// # }
1748    /// ```
1749    ///
1750    /// # Notes
1751    ///
1752    /// The output buffer starts at four times the total message length (with a 4 KiB floor) and is grown once to the exact size reported by `llama.cpp` if the first attempt does not fit.
1753    /// # Panics
1754    ///
1755    /// Panics if the buffer length exceeds `i32::MAX`.
1756    #[tracing::instrument(skip_all)]
1757    pub fn apply_chat_template(
1758        &self,
1759        tmpl: Option<&str>,
1760        chat: &[LlamaChatMessage],
1761        add_ass: bool,
1762    ) -> Result<String, ApplyChatTemplateError> {
1763        // Compute the raw message byte total from the original `LlamaChatMessage`
1764        // slice; this drives the initial output buffer size below.
1765        let message_length = chat.iter().fold(0usize, |acc, c| {
1766            acc + c.role.to_bytes().len() + c.content.to_bytes().len()
1767        });
1768
1769        // Build our llama_cpp_sys chat messages (raw pointers into CStrings).
1770        let chat_sys: Vec<llama_chat_message> = chat
1771            .iter()
1772            .map(|c| llama_chat_message {
1773                role: c.role.as_ptr(),
1774                content: c.content.as_ptr(),
1775            })
1776            .collect();
1777
1778        // Set the tmpl pointer.
1779        let tmpl_cstring = tmpl.map(CString::new).transpose()?;
1780        let tmpl_ptr = tmpl_cstring
1781            .as_ref()
1782            .map_or(std::ptr::null(), |s| s.as_ptr());
1783
1784        // `message_length * 4` is far too small for models whose built-in chat
1785        // template adds a long default system prompt (e.g. Qwen3.5 prepends
1786        // ~80+ chars of markup even for a one-word user message).  Start with
1787        // at least 4 KiB so short inputs like "hi" always have room.
1788        //
1789        // `llama_chat_apply_template` returns the number of bytes it *actually*
1790        // needed when the buffer was too small, so we retry exactly once with
1791        // that precise size rather than giving up immediately.
1792        let mut buf_size = message_length.saturating_mul(4).max(4096);
1793
1794        for _ in 0..2 {
1795            // Use u8 so that as_mut_ptr()/as_ptr() match the binding (*mut u8 / *const u8).
1796            let mut buff = vec![0u8; buf_size];
1797            let res = unsafe {
1798                llama_chat_apply_template(
1799                    tmpl_ptr,
1800                    chat_sys.as_ptr(),
1801                    chat_sys.len(),
1802                    add_ass,
1803                    buff.as_mut_ptr().cast(),
1804                    i32::try_from(buff.len()).expect("buffer length fits in i32"),
1805                )
1806            };
1807
1808            if res < 0 {
1809                return Err(ApplyChatTemplateError::BuffSizeError);
1810            }
1811
1812            #[allow(clippy::cast_sign_loss)]
1813            let needed = res as usize;
1814            if needed >= buf_size {
1815                // Too small (or no room left for the NUL terminator); retry with the exact size llama.cpp reported.
1816                buf_size = needed + 1; // +1 for null terminator
1817                continue;
1818            }
1819
1820            // SAFETY: `needed < buf_size`, so at least one zero-initialised byte
1821            // follows the written data and `buff` holds a NUL-terminated C string.
1822            let formatted = unsafe {
1823                CStr::from_ptr(buff.as_ptr().cast())
1824                    .to_string_lossy()
1825                    .into_owned()
1826            };
1827            return Ok(formatted);
1828        }
1829
1830        Err(ApplyChatTemplateError::BuffSizeError)
1831    }
1832
1833    /// Build a split GGUF file path for a specific chunk.
1834    ///
1835    /// This utility function creates the standardized filename for a split model chunk
1836    /// following the pattern `{prefix}-{n:05}-of-{split_count:05}.gguf`, where `n` is `split_no + 1`.
1837    ///
1838    /// # Arguments
1839    ///
1840    /// * `path_prefix` - The base path and filename prefix
1841    /// * `split_no` - The split number (0-indexed; the generated filename uses `split_no + 1`)
1842    /// * `split_count` - The total number of splits
1843    ///
1844    /// # Returns
1845    ///
1846    /// Returns the formatted split path as a String
1847    ///
1848    /// # Example
1849    ///
1850    /// ```
1851    /// use llama_cpp_4::model::LlamaModel;
1852    ///
1853    /// let path = LlamaModel::split_path("/models/llama", 1, 4);
1854    /// assert_eq!(path, "/models/llama-00002-of-00004.gguf");
1855    /// ```
1856    ///
1857    /// # Panics
1858    ///
1859    /// Panics if the path prefix contains a null byte.
1860    #[must_use]
1861    pub fn split_path(path_prefix: &str, split_no: i32, split_count: i32) -> String {
1862        let mut buffer = vec![0u8; 1024];
1863        let len = unsafe {
1864            llama_split_path(
1865                buffer.as_mut_ptr().cast::<c_char>(),
1866                buffer.len(),
1867                CString::new(path_prefix).unwrap().as_ptr(),
1868                split_no,
1869                split_count,
1870            )
1871        };
1872
1873        let len = usize::try_from(len).expect("split_path length fits in usize");
1874        buffer.truncate(len);
1875        String::from_utf8(buffer).unwrap_or_default()
1876    }
1877
1878    /// Extract the path prefix from a split filename.
1879    ///
1880    /// This function extracts the base path prefix from a split model filename,
1881    /// but only if the `split_no` and `split_count` match the pattern in the filename.
1882    ///
1883    /// # Arguments
1884    ///
1885    /// * `split_path` - The full path to the split file
1886    /// * `split_no` - The expected split number (0-indexed, matching the `split_no + 1` component of the filename)
1887    /// * `split_count` - The expected total number of splits
1888    ///
1889    /// # Returns
1890    ///
1891    /// Returns the path prefix if the pattern matches, or None if it doesn't
1892    ///
1893    /// # Example
1894    ///
1895    /// ```
1896    /// use llama_cpp_4::model::LlamaModel;
1897    ///
1898    /// let prefix = LlamaModel::split_prefix("/models/llama-00002-of-00004.gguf", 1, 4);
1899    /// assert_eq!(prefix, Some("/models/llama".to_string()));
1900    /// ```
1901    ///
1902    /// # Panics
1903    ///
1904    /// Panics if the split path contains a null byte.
1905    #[must_use]
1906    pub fn split_prefix(split_path: &str, split_no: i32, split_count: i32) -> Option<String> {
1907        let mut buffer = vec![0u8; 1024];
1908        let len = unsafe {
1909            llama_split_prefix(
1910                buffer.as_mut_ptr().cast::<c_char>(),
1911                buffer.len(),
1912                CString::new(split_path).unwrap().as_ptr(),
1913                split_no,
1914                split_count,
1915            )
1916        };
1917
1918        if len > 0 {
1919            let len = usize::try_from(len).expect("split_prefix length fits in usize");
1920            buffer.truncate(len);
1921            String::from_utf8(buffer).ok()
1922        } else {
1923            None
1924        }
1925    }
1926}
1927
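/// Formats a human-readable one-line summary of the model.
///
/// A sketch of typical usage; the printed values depend entirely on the loaded model,
/// and the model path is a placeholder.
///
/// ```no_run
/// # use llama_cpp_4::model::{LlamaModel, params::LlamaModelParams};
/// # use llama_cpp_4::llama_backend::LlamaBackend;
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let backend = LlamaBackend::init()?;
/// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
/// // Prints something like "<desc> | 32L 32H 4096E | 7241732096 params | 3917.9 MiB".
/// println!("{model}");
/// # Ok(())
/// # }
/// ```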
1928#[allow(clippy::cast_precision_loss)]
1929impl fmt::Display for LlamaModel {
1930    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1931        let desc = self.desc(256).unwrap_or_else(|_| "unknown".to_string());
1932        write!(
1933            f,
1934            "{desc} | {layers}L {heads}H {embd}E | {params} params | {size:.1} MiB",
1935            layers = self.n_layer(),
1936            heads = self.n_head(),
1937            embd = self.n_embd(),
1938            params = self.n_params(),
1939            size = self.model_size() as f64 / (1024.0 * 1024.0),
1940        )
1941    }
1942}
1943
1944impl Drop for LlamaModel {
1945    fn drop(&mut self) {
1946        unsafe { llama_model_free(self.model.as_ptr()) }
1947    }
1948}
1949
1950/// Defines the possible types of vocabulary used by the model.
1951///
1952/// The model may use different types of vocabulary depending on the tokenization method chosen during training.
1953/// This enum represents these types, specifically `BPE` (Byte Pair Encoding) and `SPM` (`SentencePiece`).
1954///
1955/// # Variants
1956///
1957/// - `BPE`: Byte Pair Encoding, a common tokenization method used in NLP tasks.
1958/// - `SPM`: `SentencePiece`, another popular tokenization method for NLP models.
1959///
1960/// # Example
1961///
1962/// ```no_run
1963/// use llama_cpp_4::model::VocabType;
1964///
1965/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1966/// let vocab_type = VocabType::BPE;
1967/// match vocab_type {
1968///     VocabType::BPE => println!("The model uses Byte Pair Encoding (BPE)"),
1969///     VocabType::SPM => println!("The model uses SentencePiece (SPM)"),
1970/// }
1971/// # Ok(())
1972/// # }
1973/// ```
1974#[repr(u32)]
1975#[derive(Debug, Eq, Copy, Clone, PartialEq)]
1976pub enum VocabType {
1977    /// Byte Pair Encoding
1978    BPE = LLAMA_VOCAB_TYPE_BPE as _,
1979    /// Sentence Piece Tokenizer
1980    SPM = LLAMA_VOCAB_TYPE_SPM as _,
1981}
1982
1983/// Error that occurs when trying to convert a `llama_vocab_type` to a `VocabType`.
1984///
1985/// This error is raised when the integer value returned by the system does not correspond to a known vocabulary type.
1986///
1987/// # Variants
1988///
1989/// - `UnknownValue`: The error is raised when the value is not a valid `llama_vocab_type`. The invalid value is returned with the error.
1990///
1991/// # Example
1992///
1993/// ```no_run
1994/// use llama_cpp_4::model::LlamaTokenTypeFromIntError;
1995///
1996/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1997/// let invalid_value = 999; // Not a valid vocabulary type
1998/// let error = LlamaTokenTypeFromIntError::UnknownValue(invalid_value);
1999/// println!("Error: {}", error);
2000/// # Ok(())
2001/// # }
2002/// ```
2003#[derive(thiserror::Error, Debug, Eq, PartialEq)]
2004pub enum LlamaTokenTypeFromIntError {
2005    /// The value is not a valid `llama_vocab_type`. Contains the int value that was invalid.
2006    #[error("Unknown Value {0}")]
2007    UnknownValue(llama_vocab_type),
2008}
2009
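/// Fallible conversion from the raw `llama_vocab_type` value reported by `llama.cpp`.
///
/// A small sketch of the conversion; `LLAMA_VOCAB_TYPE_BPE` comes from the `llama_cpp_sys_4`
/// bindings, and any unrecognized value yields `LlamaTokenTypeFromIntError::UnknownValue`.
///
/// ```no_run
/// use llama_cpp_4::model::VocabType;
/// use llama_cpp_sys_4::LLAMA_VOCAB_TYPE_BPE;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let vocab_type = VocabType::try_from(LLAMA_VOCAB_TYPE_BPE)?;
/// assert_eq!(vocab_type, VocabType::BPE);
/// # Ok(())
/// # }
/// ```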
2010impl TryFrom<llama_vocab_type> for VocabType {
2011    type Error = LlamaTokenTypeFromIntError;
2012
2013    fn try_from(value: llama_vocab_type) -> Result<Self, Self::Error> {
2014        match value {
2015            LLAMA_VOCAB_TYPE_BPE => Ok(VocabType::BPE),
2016            LLAMA_VOCAB_TYPE_SPM => Ok(VocabType::SPM),
2017            unknown => Err(LlamaTokenTypeFromIntError::UnknownValue(unknown)),
2018        }
2019    }
2020}