Skip to main content

llama_cpp_4/
model.rs

1//! A safe wrapper around `llama_model`.
2use std::ffi::CStr;
3use std::ffi::CString;
4use std::fmt;
5use std::num::NonZeroU16;
6use std::os::raw::{c_char, c_int};
7use std::path::Path;
8use std::ptr::NonNull;
9
10use llama_cpp_sys_4::{
11    llama_adapter_lora, llama_adapter_lora_init, llama_add_bos_token, llama_add_eos_token,
12    llama_chat_apply_template, llama_chat_builtin_templates, llama_chat_message,
13    llama_detokenize, llama_free_model, llama_load_model_from_file, llama_model,
14    llama_model_cls_label, llama_model_decoder_start_token, llama_model_desc,
15    llama_model_get_vocab, llama_model_has_decoder, llama_model_has_encoder,
16    llama_model_is_diffusion, llama_model_is_hybrid, llama_model_is_recurrent,
17    llama_model_load_from_splits, llama_model_meta_count, llama_model_meta_key_by_index,
18    llama_model_meta_val_str, llama_model_meta_val_str_by_index, llama_model_n_cls_out,
19    llama_model_n_embd_inp, llama_model_n_embd_out, llama_model_n_head_kv, llama_model_n_params,
20    llama_model_n_swa, llama_model_rope_freq_scale_train, llama_model_rope_type,
21    llama_model_save_to_file, llama_model_size, llama_n_ctx_train, llama_n_embd, llama_n_head,
22    llama_n_layer, llama_n_vocab, llama_new_context_with_model, llama_split_path,
23    llama_split_prefix, llama_token_bos, llama_token_cls, llama_token_eos, llama_token_eot,
24    llama_token_fim_mid, llama_token_fim_pad, llama_token_fim_pre, llama_token_fim_rep,
25    llama_token_fim_sep, llama_token_fim_suf, llama_token_get_attr, llama_token_get_score,
26    llama_token_get_text, llama_token_is_control, llama_token_is_eog, llama_token_nl,
27    llama_token_pad, llama_token_sep, llama_token_to_piece, llama_tokenize, llama_vocab,
28    llama_vocab_type, LLAMA_VOCAB_TYPE_BPE, LLAMA_VOCAB_TYPE_SPM,
29};
30
31use crate::context::params::LlamaContextParams;
32use crate::context::LlamaContext;
33use crate::llama_backend::LlamaBackend;
34use crate::model::params::LlamaModelParams;
35use crate::token::LlamaToken;
36use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
37use crate::{
38    ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
39    LlamaModelLoadError, NewLlamaChatMessageError, StringFromModelError, StringToTokenError,
40    TokenToStringError,
41};
42
43pub mod params;
44
45/// A safe wrapper around `llama_model`.
46#[derive(Debug)]
47#[repr(transparent)]
48#[allow(clippy::module_name_repetitions)]
49pub struct LlamaModel {
50    pub(crate) model: NonNull<llama_model>,
51}
52
53/// A safe wrapper around `llama_vocab`.
54#[derive(Debug)]
55#[repr(transparent)]
56#[allow(clippy::module_name_repetitions)]
57pub struct LlamaVocab {
58    pub(crate) vocab: NonNull<llama_vocab>,
59}
60
61impl LlamaVocab {
62    /// Get the number of tokens in the vocabulary.
63    #[must_use]
64    pub fn n_tokens(&self) -> i32 {
65        unsafe { llama_cpp_sys_4::llama_vocab_n_tokens(self.vocab.as_ref()) }
66    }
67
68    /// Get the vocabulary type.
69    #[must_use]
70    pub fn vocab_type(&self) -> u32 {
71        unsafe { llama_cpp_sys_4::llama_vocab_type(self.vocab.as_ref()).try_into().unwrap() }
72    }
73
74    /// Get the BOS token.
75    #[must_use]
76    pub fn bos(&self) -> LlamaToken {
77        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_bos(self.vocab.as_ref()) })
78    }
79
80    /// Get the EOS token.
81    #[must_use]
82    pub fn eos(&self) -> LlamaToken {
83        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_eos(self.vocab.as_ref()) })
84    }
85
86    /// Get the EOT (end of turn) token.
87    #[must_use]
88    pub fn eot(&self) -> LlamaToken {
89        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_eot(self.vocab.as_ref()) })
90    }
91
92    /// Get the CLS (classification) token.
93    #[must_use]
94    pub fn cls(&self) -> LlamaToken {
95        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_cls(self.vocab.as_ref()) })
96    }
97
98    /// Get the SEP (separator) token.
99    #[must_use]
100    pub fn sep(&self) -> LlamaToken {
101        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_sep(self.vocab.as_ref()) })
102    }
103
104    /// Get the NL (newline) token.
105    #[must_use]
106    pub fn nl(&self) -> LlamaToken {
107        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_nl(self.vocab.as_ref()) })
108    }
109
110    /// Get the PAD (padding) token.
111    #[must_use]
112    pub fn pad(&self) -> LlamaToken {
113        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_pad(self.vocab.as_ref()) })
114    }
115
116    /// Get the FIM prefix token.
117    #[must_use]
118    pub fn fim_pre(&self) -> LlamaToken {
119        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_pre(self.vocab.as_ref()) })
120    }
121
122    /// Get the FIM suffix token.
123    #[must_use]
124    pub fn fim_suf(&self) -> LlamaToken {
125        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_suf(self.vocab.as_ref()) })
126    }
127
128    /// Get the FIM middle token.
129    #[must_use]
130    pub fn fim_mid(&self) -> LlamaToken {
131        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_mid(self.vocab.as_ref()) })
132    }
133
134    /// Get the FIM padding token.
135    #[must_use]
136    pub fn fim_pad(&self) -> LlamaToken {
137        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_pad(self.vocab.as_ref()) })
138    }
139
140    /// Get the FIM repository token.
141    #[must_use]
142    pub fn fim_rep(&self) -> LlamaToken {
143        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_rep(self.vocab.as_ref()) })
144    }
145
146    /// Get the FIM separator token.
147    #[must_use]
148    pub fn fim_sep(&self) -> LlamaToken {
149        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_sep(self.vocab.as_ref()) })
150    }
151
152    /// Check whether BOS should be added.
153    #[must_use]
154    pub fn get_add_bos(&self) -> bool {
155        unsafe { llama_cpp_sys_4::llama_vocab_get_add_bos(self.vocab.as_ref()) }
156    }
157
158    /// Check whether EOS should be added.
159    #[must_use]
160    pub fn get_add_eos(&self) -> bool {
161        unsafe { llama_cpp_sys_4::llama_vocab_get_add_eos(self.vocab.as_ref()) }
162    }
163
164    /// Check whether SEP should be added.
165    #[must_use]
166    pub fn get_add_sep(&self) -> bool {
167        unsafe { llama_cpp_sys_4::llama_vocab_get_add_sep(self.vocab.as_ref()) }
168    }
169
170    /// Get the text representation of a token.
171    ///
172    /// # Errors
173    ///
174    /// Returns an error if the text pointer is null or not valid UTF-8.
175    pub fn get_text(&self, token: LlamaToken) -> Result<&str, StringFromModelError> {
176        let ptr = unsafe { llama_cpp_sys_4::llama_vocab_get_text(self.vocab.as_ref(), token.0) };
177        if ptr.is_null() {
178            return Err(StringFromModelError::ReturnedError(-1));
179        }
180        let cstr = unsafe { CStr::from_ptr(ptr) };
181        cstr.to_str().map_err(StringFromModelError::Utf8Error)
182    }
183
184    /// Get the score of a token.
185    #[must_use]
186    pub fn get_score(&self, token: LlamaToken) -> f32 {
187        unsafe { llama_cpp_sys_4::llama_vocab_get_score(self.vocab.as_ref(), token.0) }
188    }
189
190    /// Get the attributes of a token.
191    #[must_use]
192    pub fn get_attr(&self, token: LlamaToken) -> u32 {
193        unsafe { llama_cpp_sys_4::llama_vocab_get_attr(self.vocab.as_ref(), token.0).try_into().unwrap() }
194    }
195
196    /// Check if a token is a control token.
197    #[must_use]
198    pub fn is_control(&self, token: LlamaToken) -> bool {
199        unsafe { llama_cpp_sys_4::llama_vocab_is_control(self.vocab.as_ref(), token.0) }
200    }
201
202    /// Check if a token is an end-of-generation token.
203    #[must_use]
204    pub fn is_eog(&self, token: LlamaToken) -> bool {
205        unsafe { llama_cpp_sys_4::llama_vocab_is_eog(self.vocab.as_ref(), token.0) }
206    }
207
208    /// Get the token mask value for the vocabulary.
209    #[must_use]
210    pub fn mask(&self) -> LlamaToken {
211        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_mask(self.vocab.as_ref()) })
212    }
213}
214
215/// A safe wrapper around `llama_adapter_lora`.
216#[derive(Debug)]
217#[repr(transparent)]
218#[allow(clippy::module_name_repetitions)]
219pub struct LlamaLoraAdapter {
220    pub(crate) lora_adapter: NonNull<llama_adapter_lora>,
221}
222
223impl LlamaLoraAdapter {
224    /// Get the number of metadata key-value pairs in the adapter.
225    #[must_use]
226    pub fn meta_count(&self) -> i32 {
227        unsafe { llama_cpp_sys_4::llama_adapter_meta_count(self.lora_adapter.as_ptr()) }
228    }
229
230    /// Get a metadata key by index.
231    ///
232    /// # Errors
233    ///
234    /// Returns an error if the index is out of range or the key is not valid UTF-8.
235    #[allow(clippy::cast_sign_loss)]
236    pub fn meta_key_by_index(
237        &self,
238        index: i32,
239        buf_size: usize,
240    ) -> Result<String, StringFromModelError> {
241        let mut buf = vec![0u8; buf_size];
242        let ret = unsafe {
243            llama_cpp_sys_4::llama_adapter_meta_key_by_index(
244                self.lora_adapter.as_ptr(),
245                index,
246                buf.as_mut_ptr().cast::<c_char>(),
247                buf_size,
248            )
249        };
250        if ret < 0 {
251            return Err(StringFromModelError::ReturnedError(ret));
252        }
253        let len = ret as usize;
254        let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
255        Ok(s.to_owned())
256    }
257
258    /// Get a metadata value by key name.
259    ///
260    /// # Errors
261    ///
262    /// Returns an error if the key is not found or the value is not valid UTF-8.
263    #[allow(clippy::cast_sign_loss)]
264    pub fn meta_val_str(
265        &self,
266        key: &str,
267        buf_size: usize,
268    ) -> Result<String, StringFromModelError> {
269        let c_key =
270            CString::new(key).map_err(|_| StringFromModelError::ReturnedError(-1))?;
271        let mut buf = vec![0u8; buf_size];
272        let ret = unsafe {
273            llama_cpp_sys_4::llama_adapter_meta_val_str(
274                self.lora_adapter.as_ptr(),
275                c_key.as_ptr(),
276                buf.as_mut_ptr().cast::<c_char>(),
277                buf_size,
278            )
279        };
280        if ret < 0 {
281            return Err(StringFromModelError::ReturnedError(ret));
282        }
283        let len = ret as usize;
284        let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
285        Ok(s.to_owned())
286    }
287
288    /// Get a metadata value by index.
289    ///
290    /// # Errors
291    ///
292    /// Returns an error if the index is out of range or the value is not valid UTF-8.
293    #[allow(clippy::cast_sign_loss)]
294    pub fn meta_val_str_by_index(
295        &self,
296        index: i32,
297        buf_size: usize,
298    ) -> Result<String, StringFromModelError> {
299        let mut buf = vec![0u8; buf_size];
300        let ret = unsafe {
301            llama_cpp_sys_4::llama_adapter_meta_val_str_by_index(
302                self.lora_adapter.as_ptr(),
303                index,
304                buf.as_mut_ptr().cast::<c_char>(),
305                buf_size,
306            )
307        };
308        if ret < 0 {
309            return Err(StringFromModelError::ReturnedError(ret));
310        }
311        let len = ret as usize;
312        let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
313        Ok(s.to_owned())
314    }
315
316    /// Get all metadata as a list of `(key, value)` pairs.
317    ///
318    /// # Errors
319    ///
320    /// Returns an error if any key or value cannot be read or is not valid UTF-8.
321    #[allow(clippy::cast_sign_loss)]
322    pub fn metadata(&self) -> Result<Vec<(String, String)>, StringFromModelError> {
323        let count = self.meta_count();
324        let mut result = Vec::with_capacity(count as usize);
325        for i in 0..count {
326            let key = self.meta_key_by_index(i, 256)?;
327            let val = self.meta_val_str_by_index(i, 4096)?;
328            result.push((key, val));
329        }
330        Ok(result)
331    }
332
333    /// Get the number of invocation tokens for this adapter.
334    #[must_use]
335    pub fn n_invocation_tokens(&self) -> u64 {
336        unsafe {
337            llama_cpp_sys_4::llama_adapter_get_alora_n_invocation_tokens(
338                self.lora_adapter.as_ptr(),
339            )
340        }
341    }
342
343    /// Get the invocation tokens for this adapter.
344    ///
345    /// Returns an empty slice if there are no invocation tokens.
346    #[must_use]
347    #[allow(clippy::cast_possible_truncation)]
348    pub fn invocation_tokens(&self) -> &[LlamaToken] {
349        let n = self.n_invocation_tokens() as usize;
350        if n == 0 {
351            return &[];
352        }
353        let ptr = unsafe {
354            llama_cpp_sys_4::llama_adapter_get_alora_invocation_tokens(
355                self.lora_adapter.as_ptr(),
356            )
357        };
358        if ptr.is_null() {
359            return &[];
360        }
361        // LlamaToken is repr(transparent) over llama_token (i32), so this cast is safe
362        unsafe { std::slice::from_raw_parts(ptr.cast::<LlamaToken>(), n) }
363    }
364}
365
366impl Drop for LlamaLoraAdapter {
367    fn drop(&mut self) {
368        unsafe {
369            llama_cpp_sys_4::llama_adapter_lora_free(self.lora_adapter.as_ptr());
370        }
371    }
372}
373
/// One chat message (role plus content), stored as NUL-terminated C strings so it can be
/// handed to llama.cpp as a `llama_chat_message`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct LlamaChatMessage {
    /// Speaker role of the message.
    role: CString,
    /// Message text.
    content: CString,
}
380
381impl LlamaChatMessage {
382    /// Create a new `LlamaChatMessage`.
383    ///
384    /// # Errors
385    ///
386    /// Returns [`NewLlamaChatMessageError`] if the role or content contains a null byte.
387    pub fn new(role: String, content: String) -> Result<Self, NewLlamaChatMessageError> {
388        Ok(Self {
389            role: CString::new(role)?,
390            content: CString::new(content)?,
391        })
392    }
393}
394
/// Whether a beginning-of-stream (BOS) token is prepended when tokenizing.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AddBos {
    /// Prepend the BOS token to the start of the string.
    Always,
    /// Leave the string without a leading BOS token.
    Never,
}
403
/// Whether special/control tokens written in the input text are recognized during
/// tokenization.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Special {
    /// Recognize special and/or control tokens that would otherwise stay hidden and be
    /// treated as plaintext. Does not insert a leading space.
    Tokenize,
    /// Treat special and/or control tokens as ordinary plaintext.
    Plaintext,
}
412
// SAFETY: `LlamaModel` holds nothing but a non-null pointer to a loaded `llama_model`.
// NOTE(review): these impls assume llama.cpp allows a loaded model to be shared across
// threads (e.g. several contexts on one model) and that the `&self` methods on
// `LlamaModel` only perform read-only queries. Confirm against upstream llama.cpp.
unsafe impl Send for LlamaModel {}

// SAFETY: relies on the same thread-sharing assumption as the `Send` impl directly above.
unsafe impl Sync for LlamaModel {}
416
417impl LlamaModel {
418    /// Retrieves the vocabulary associated with the current Llama model.
419    ///
420    /// This method fetches the vocabulary from the underlying model using an unsafe
421    /// FFI call. The returned `LlamaVocab` struct contains a non-null pointer to
422    /// the vocabulary data, which is wrapped in a `NonNull` for safety.
423    ///
424    /// # Safety
425    /// This method uses an unsafe block to call a C function (`llama_model_get_vocab`),
426    /// which is assumed to return a valid pointer to the vocabulary. The caller should
427    /// ensure that the model object is properly initialized and valid before calling
428    /// this method, as dereferencing invalid pointers can lead to undefined behavior.
429    ///
430    /// # Returns
431    /// A `LlamaVocab` struct containing the vocabulary of the model.
432    ///
433    /// # Panics
434    ///
435    /// Panics if the underlying C function returns a null pointer.
436    ///
437    /// # Example
438    /// ```rust,ignore
439    /// let vocab = model.get_vocab();
440    /// ```
441    #[must_use]
442    pub fn get_vocab(&self) -> LlamaVocab {
443        let llama_vocab = unsafe { llama_model_get_vocab(self.model.as_ptr()) }.cast_mut();
444
445        LlamaVocab {
446            vocab: NonNull::new(llama_vocab).unwrap(),
447        }
448    }
449    /// Get the number of tokens the model was trained on.
450    ///
451    /// This function returns the number of tokens that the model was trained on, represented as a `u32`.
452    ///
453    /// # Panics
454    ///
455    /// This function will panic if the number of tokens the model was trained on does not fit into a `u32`.
456    /// This should be impossible on most platforms since llama.cpp returns a `c_int` (i32 on most platforms),
457    /// which is almost certainly positive.
458    #[must_use]
459    pub fn n_ctx_train(&self) -> u32 {
460        let n_ctx_train = unsafe { llama_n_ctx_train(self.model.as_ptr()) };
461        u32::try_from(n_ctx_train).expect("n_ctx_train fits into an u32")
462    }
463
464    /// Get all tokens in the model.
465    ///
466    /// This function returns an iterator over all the tokens in the model. Each item in the iterator is a tuple
467    /// containing a `LlamaToken` and its corresponding string representation (or an error if the conversion fails).
468    ///
469    /// # Parameters
470    ///
471    /// - `special`: The `Special` value that determines how special tokens (like BOS, EOS, etc.) are handled.
472    pub fn tokens(
473        &self,
474        special: Special,
475    ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
476        (0..self.n_vocab())
477            .map(LlamaToken::new)
478            .map(move |llama_token| (llama_token, self.token_to_str(llama_token, special)))
479    }
480
481    /// Get the beginning of stream token.
482    ///
483    /// This function returns the token that represents the beginning of a stream (BOS token).
484    #[must_use]
485    pub fn token_bos(&self) -> LlamaToken {
486        let token = unsafe { llama_token_bos(self.get_vocab().vocab.as_ref()) };
487        LlamaToken(token)
488    }
489
490    /// Get the end of stream token.
491    ///
492    /// This function returns the token that represents the end of a stream (EOS token).
493    #[must_use]
494    pub fn token_eos(&self) -> LlamaToken {
495        let token = unsafe { llama_token_eos(self.get_vocab().vocab.as_ref()) };
496        LlamaToken(token)
497    }
498
499    /// Get the newline token.
500    ///
501    /// This function returns the token that represents a newline character.
502    #[must_use]
503    pub fn token_nl(&self) -> LlamaToken {
504        let token = unsafe { llama_token_nl(self.get_vocab().vocab.as_ref()) };
505        LlamaToken(token)
506    }
507
508    /// Check if a token represents the end of generation (end of turn, end of sequence, etc.).
509    ///
510    /// This function returns `true` if the provided token signifies the end of generation or end of sequence,
511    /// such as EOS or other special tokens.
512    ///
513    /// # Parameters
514    ///
515    /// - `token`: The `LlamaToken` to check.
516    ///
517    /// # Returns
518    ///
519    /// - `true` if the token is an end-of-generation token, otherwise `false`.
520    #[must_use]
521    pub fn is_eog_token(&self, token: LlamaToken) -> bool {
522        unsafe { llama_token_is_eog(self.get_vocab().vocab.as_ref(), token.0) }
523    }
524
525    /// Get the classification token.
526    #[must_use]
527    pub fn token_cls(&self) -> LlamaToken {
528        let token = unsafe { llama_token_cls(self.get_vocab().vocab.as_ref()) };
529        LlamaToken(token)
530    }
531
532    /// Get the end-of-turn token.
533    #[must_use]
534    pub fn token_eot(&self) -> LlamaToken {
535        let token = unsafe { llama_token_eot(self.get_vocab().vocab.as_ref()) };
536        LlamaToken(token)
537    }
538
539    /// Get the padding token.
540    #[must_use]
541    pub fn token_pad(&self) -> LlamaToken {
542        let token = unsafe { llama_token_pad(self.get_vocab().vocab.as_ref()) };
543        LlamaToken(token)
544    }
545
546    /// Get the separator token.
547    #[must_use]
548    pub fn token_sep(&self) -> LlamaToken {
549        let token = unsafe { llama_token_sep(self.get_vocab().vocab.as_ref()) };
550        LlamaToken(token)
551    }
552
553    /// Get the fill-in-the-middle prefix token.
554    #[must_use]
555    pub fn token_fim_pre(&self) -> LlamaToken {
556        let token = unsafe { llama_token_fim_pre(self.get_vocab().vocab.as_ref()) };
557        LlamaToken(token)
558    }
559
560    /// Get the fill-in-the-middle suffix token.
561    #[must_use]
562    pub fn token_fim_suf(&self) -> LlamaToken {
563        let token = unsafe { llama_token_fim_suf(self.get_vocab().vocab.as_ref()) };
564        LlamaToken(token)
565    }
566
567    /// Get the fill-in-the-middle middle token.
568    #[must_use]
569    pub fn token_fim_mid(&self) -> LlamaToken {
570        let token = unsafe { llama_token_fim_mid(self.get_vocab().vocab.as_ref()) };
571        LlamaToken(token)
572    }
573
574    /// Get the fill-in-the-middle padding token.
575    #[must_use]
576    pub fn token_fim_pad(&self) -> LlamaToken {
577        let token = unsafe { llama_token_fim_pad(self.get_vocab().vocab.as_ref()) };
578        LlamaToken(token)
579    }
580
581    /// Get the fill-in-the-middle repository token.
582    #[must_use]
583    pub fn token_fim_rep(&self) -> LlamaToken {
584        let token = unsafe { llama_token_fim_rep(self.get_vocab().vocab.as_ref()) };
585        LlamaToken(token)
586    }
587
588    /// Get the fill-in-the-middle separator token.
589    #[must_use]
590    pub fn token_fim_sep(&self) -> LlamaToken {
591        let token = unsafe { llama_token_fim_sep(self.get_vocab().vocab.as_ref()) };
592        LlamaToken(token)
593    }
594
595    /// Check if a token is a control token.
596    #[must_use]
597    pub fn token_is_control(&self, token: LlamaToken) -> bool {
598        unsafe { llama_token_is_control(self.get_vocab().vocab.as_ref(), token.0) }
599    }
600
601    /// Get the score of a token.
602    #[must_use]
603    pub fn token_get_score(&self, token: LlamaToken) -> f32 {
604        unsafe { llama_token_get_score(self.get_vocab().vocab.as_ref(), token.0) }
605    }
606
607    /// Get the raw text of a token.
608    ///
609    /// # Errors
610    ///
611    /// Returns an error if the token text is null or not valid UTF-8.
612    pub fn token_get_text(&self, token: LlamaToken) -> Result<&str, StringFromModelError> {
613        let ptr =
614            unsafe { llama_token_get_text(self.get_vocab().vocab.as_ref(), token.0) };
615        if ptr.is_null() {
616            return Err(StringFromModelError::ReturnedError(-1));
617        }
618        let cstr = unsafe { CStr::from_ptr(ptr) };
619        cstr.to_str()
620            .map_err(StringFromModelError::Utf8Error)
621    }
622
623    /// Check if a BOS token should be added when tokenizing.
624    #[must_use]
625    pub fn add_bos_token(&self) -> bool {
626        unsafe { llama_add_bos_token(self.get_vocab().vocab.as_ref()) }
627    }
628
629    /// Check if an EOS token should be added when tokenizing.
630    #[must_use]
631    pub fn add_eos_token(&self) -> bool {
632        unsafe { llama_add_eos_token(self.get_vocab().vocab.as_ref()) }
633    }
634
635    /// Get the decoder start token.
636    ///
637    /// This function returns the token used to signal the start of decoding (i.e., the token used at the start
638    /// of a sequence generation).
639    #[must_use]
640    pub fn decode_start_token(&self) -> LlamaToken {
641        let token = unsafe { llama_model_decoder_start_token(self.model.as_ptr()) };
642        LlamaToken(token)
643    }
644
645    /// Convert a single token to a string.
646    ///
647    /// This function converts a `LlamaToken` into its string representation.
648    ///
649    /// # Errors
650    ///
651    /// This function returns an error if the token cannot be converted to a string. For more details, refer to
652    /// [`TokenToStringError`].
653    ///
654    /// # Parameters
655    ///
656    /// - `token`: The `LlamaToken` to convert.
657    /// - `special`: The `Special` value used to handle special tokens.
658    pub fn token_to_str(
659        &self,
660        token: LlamaToken,
661        special: Special,
662    ) -> Result<String, TokenToStringError> {
663        self.token_to_str_with_size(token, 32, special)
664    }
665
666    /// Convert a single token to bytes.
667    ///
668    /// This function converts a `LlamaToken` into a byte representation.
669    ///
670    /// # Errors
671    ///
672    /// This function returns an error if the token cannot be converted to bytes. For more details, refer to
673    /// [`TokenToStringError`].
674    ///
675    /// # Parameters
676    ///
677    /// - `token`: The `LlamaToken` to convert.
678    /// - `special`: The `Special` value used to handle special tokens.
679    pub fn token_to_bytes(
680        &self,
681        token: LlamaToken,
682        special: Special,
683    ) -> Result<Vec<u8>, TokenToStringError> {
684        self.token_to_bytes_with_size(token, 32, special, None)
685    }
686
687    /// Convert a vector of tokens to a single string.
688    ///
689    /// This function takes a slice of `LlamaToken`s and converts them into a single string, concatenating their
690    /// string representations.
691    ///
692    /// # Errors
693    ///
694    /// This function returns an error if any token cannot be converted to a string. For more details, refer to
695    /// [`TokenToStringError`].
696    ///
697    /// # Parameters
698    ///
699    /// - `tokens`: A slice of `LlamaToken`s to convert.
700    /// - `special`: The `Special` value used to handle special tokens.
701    pub fn tokens_to_str(
702        &self,
703        tokens: &[LlamaToken],
704        special: Special,
705    ) -> Result<String, TokenToStringError> {
706        let mut builder = String::with_capacity(tokens.len() * 4);
707        for str in tokens
708            .iter()
709            .copied()
710            .map(|t| self.token_to_str(t, special))
711        {
712            builder += &str?;
713        }
714        Ok(builder)
715    }
716
717    /// Convert a string to a vector of tokens.
718    ///
719    /// This function converts a string into a vector of `LlamaToken`s. The function will tokenize the string
720    /// and return the corresponding tokens.
721    ///
722    /// # Errors
723    ///
724    /// - This function will return an error if the input string contains a null byte.
725    ///
726    /// # Panics
727    ///
728    /// - This function will panic if the number of tokens exceeds `usize::MAX`.
729    ///
730    /// # Example
731    ///
732    /// ```no_run
733    /// use llama_cpp_4::model::LlamaModel;
734    ///
735    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
736    /// use std::path::Path;
737    /// use llama_cpp_4::model::AddBos;
738    /// let backend = llama_cpp_4::llama_backend::LlamaBackend::init()?;
739    /// let model = LlamaModel::load_from_file(&backend, Path::new("path/to/model"), &Default::default())?;
740    /// let tokens = model.str_to_token("Hello, World!", AddBos::Always)?;
741    /// # Ok(())
742    /// # }
743    /// ```
744    pub fn str_to_token(
745        &self,
746        str: &str,
747        add_bos: AddBos,
748    ) -> Result<Vec<LlamaToken>, StringToTokenError> {
749        let add_bos = match add_bos {
750            AddBos::Always => true,
751            AddBos::Never => false,
752        };
753
754        let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
755        let mut buffer = Vec::with_capacity(tokens_estimation);
756
757        let c_string = CString::new(str)?;
758        let buffer_capacity =
759            c_int::try_from(buffer.capacity()).expect("buffer capacity should fit into a c_int");
760
761        let size = unsafe {
762            llama_tokenize(
763                self.get_vocab().vocab.as_ref(),
764                c_string.as_ptr(),
765                c_int::try_from(c_string.as_bytes().len())?,
766                buffer.as_mut_ptr(),
767                buffer_capacity,
768                add_bos,
769                true,
770            )
771        };
772
773        // if we fail the first time we can resize the vector to the correct size and try again. This should never fail.
774        // as a result - size is guaranteed to be positive here.
775        let size = if size.is_negative() {
776            buffer.reserve_exact(usize::try_from(-size).expect("usize's are larger "));
777            unsafe {
778                llama_tokenize(
779                    self.get_vocab().vocab.as_ref(),
780                    c_string.as_ptr(),
781                    c_int::try_from(c_string.as_bytes().len())?,
782                    buffer.as_mut_ptr(),
783                    -size,
784                    add_bos,
785                    true,
786                )
787            }
788        } else {
789            size
790        };
791
792        let size = usize::try_from(size).expect("size is positive and usize ");
793
794        // Safety: `size` < `capacity` and llama-cpp has initialized elements up to `size`
795        unsafe { buffer.set_len(size) }
796        Ok(buffer.into_iter().map(LlamaToken).collect())
797    }
798
799    /// Get the type of a token.
800    ///
801    /// This function retrieves the attributes associated with a given token. The attributes are typically used to
802    /// understand whether the token represents a special type of token (e.g., beginning-of-sequence (BOS), end-of-sequence (EOS),
803    /// control tokens, etc.).
804    ///
805    /// # Panics
806    ///
807    /// - This function will panic if the token type is unknown or cannot be converted to a valid `LlamaTokenAttrs`.
808    ///
809    /// # Example
810    ///
811    /// ```no_run
812    /// use llama_cpp_4::model::LlamaModel;
813    /// use llama_cpp_4::model::params::LlamaModelParams;
814    /// use llama_cpp_4::llama_backend::LlamaBackend;
815    /// use llama_cpp_4::token::LlamaToken;
816    ///
817    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
818    /// let backend = LlamaBackend::init()?;
819    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
820    /// let token = LlamaToken::new(42);
821    /// let token_attrs = model.token_attr(token);
822    /// # Ok(())
823    /// # }
824    /// ```
825    #[must_use]
826    pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs {
827        let token_type = unsafe { llama_token_get_attr(self.get_vocab().vocab.as_ref(), id) };
828        LlamaTokenAttrs::try_from(token_type).expect("token type is valid")
829    }
830
831    /// Detokenize a slice of tokens into a string.
832    ///
833    /// This is the inverse of [`str_to_token`](Self::str_to_token).
834    ///
835    /// # Parameters
836    ///
837    /// - `tokens`: The tokens to detokenize.
838    /// - `remove_special`: If `true`, special tokens are removed from the output.
839    /// - `unparse_special`: If `true`, special tokens are rendered as their text representation.
840    ///
841    /// # Errors
842    ///
843    /// Returns an error if the detokenized text is not valid UTF-8.
844    #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap, clippy::cast_sign_loss)]
845    pub fn detokenize(
846        &self,
847        tokens: &[LlamaToken],
848        remove_special: bool,
849        unparse_special: bool,
850    ) -> Result<String, StringFromModelError> {
851        // First call with empty buffer to get required size
852        let n_tokens = tokens.len() as i32;
853        let token_ptr = tokens.as_ptr().cast::<llama_cpp_sys_4::llama_token>();
854        let needed = unsafe {
855            llama_detokenize(
856                self.get_vocab().vocab.as_ref(),
857                token_ptr,
858                n_tokens,
859                std::ptr::null_mut(),
860                0,
861                remove_special,
862                unparse_special,
863            )
864        };
865        // llama_detokenize returns negative required size when buffer is too small
866        let buf_size = if needed < 0 { (-needed) as usize } else { needed as usize };
867        let mut buf = vec![0u8; buf_size];
868        let ret = unsafe {
869            llama_detokenize(
870                self.get_vocab().vocab.as_ref(),
871                token_ptr,
872                n_tokens,
873                buf.as_mut_ptr().cast::<c_char>(),
874                buf_size as i32,
875                remove_special,
876                unparse_special,
877            )
878        };
879        if ret < 0 {
880            return Err(StringFromModelError::ReturnedError(ret));
881        }
882        let len = ret as usize;
883        let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
884        Ok(s.to_owned())
885    }
886
887    /// Convert a token to a string with a specified buffer size.
888    ///
889    /// This function allows you to convert a token into a string, with the ability to specify a buffer size for the operation.
890    /// It is generally recommended to use `LlamaModel::token_to_str` instead, as 8 bytes is typically sufficient for most tokens,
891    /// and the extra buffer size doesn't usually matter.
892    ///
893    /// # Errors
894    ///
895    /// - If the token type is unknown, an error will be returned.
896    /// - If the resultant token exceeds the provided `buffer_size`, an error will occur.
897    /// - If the token string returned by `llama-cpp` is not valid UTF-8, it will return an error.
898    ///
899    /// # Panics
900    ///
901    /// - This function will panic if the `buffer_size` does not fit into a `c_int`.
902    /// - It will also panic if the size returned from `llama-cpp` does not fit into a `usize`, which should typically never happen.
903    ///
904    /// # Example
905    ///
906    /// ```no_run
907    /// use llama_cpp_4::model::{LlamaModel, Special};
908    /// use llama_cpp_4::model::params::LlamaModelParams;
909    /// use llama_cpp_4::llama_backend::LlamaBackend;
910    /// use llama_cpp_4::token::LlamaToken;
911    ///
912    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
913    /// let backend = LlamaBackend::init()?;
914    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
915    /// let token = LlamaToken::new(42);
916    /// let token_string = model.token_to_str_with_size(token, 32, Special::Plaintext)?;
917    /// # Ok(())
918    /// # }
919    /// ```
920    pub fn token_to_str_with_size(
921        &self,
922        token: LlamaToken,
923        buffer_size: usize,
924        special: Special,
925    ) -> Result<String, TokenToStringError> {
926        let bytes = self.token_to_bytes_with_size(token, buffer_size, special, None)?;
927        Ok(String::from_utf8(bytes)?)
928    }
929
930    /// Convert a token to bytes with a specified buffer size.
931    ///
932    /// Generally you should use [`LlamaModel::token_to_bytes`] instead as 8 bytes is enough for most words and
933    /// the extra bytes do not really matter.
934    ///
935    /// # Errors
936    ///
937    /// - if the token type is unknown
938    /// - the resultant token is larger than `buffer_size`.
939    ///
940    /// # Panics
941    ///
942    /// - This function will panic if `buffer_size` cannot fit into a `c_int`.
943    /// - It will also panic if the size returned from `llama-cpp` cannot be converted to `usize` (which should not happen).
944    ///
945    /// # Example
946    ///
947    /// ```no_run
948    /// use llama_cpp_4::model::{LlamaModel, Special};
949    /// use llama_cpp_4::model::params::LlamaModelParams;
950    /// use llama_cpp_4::llama_backend::LlamaBackend;
951    /// use llama_cpp_4::token::LlamaToken;
952    ///
953    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
954    /// let backend = LlamaBackend::init()?;
955    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
956    /// let token = LlamaToken::new(42);
957    /// let token_bytes = model.token_to_bytes_with_size(token, 32, Special::Plaintext, None)?;
958    /// # Ok(())
959    /// # }
960    /// ```
961    pub fn token_to_bytes_with_size(
962        &self,
963        token: LlamaToken,
964        buffer_size: usize,
965        special: Special,
966        lstrip: Option<NonZeroU16>,
967    ) -> Result<Vec<u8>, TokenToStringError> {
968        if token == self.token_nl() {
969            return Ok(String::from("\n").into_bytes());
970        }
971
972        // unsure what to do with this in the face of the 'special' arg + attr changes
973        let attrs = self.token_attr(token);
974        if (attrs.contains(LlamaTokenAttr::Control)
975            && (token == self.token_bos() || token == self.token_eos()))
976            || attrs.is_empty()
977            || attrs
978                .intersects(LlamaTokenAttr::Unknown | LlamaTokenAttr::Byte | LlamaTokenAttr::Unused)
979        {
980            return Ok(Vec::new());
981        }
982
983        let special = match special {
984            Special::Tokenize => true,
985            Special::Plaintext => false,
986        };
987
988        let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
989        let len = string.as_bytes().len();
990        let len = c_int::try_from(len).expect("length fits into c_int");
991        let buf = string.into_raw();
992        let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
993        let size = unsafe {
994            llama_token_to_piece(
995                self.get_vocab().vocab.as_ref(),
996                token.0,
997                buf,
998                len,
999                lstrip,
1000                special,
1001            )
1002        };
1003
1004        match size {
1005            0 => Err(TokenToStringError::UnknownTokenType),
1006            i if i.is_negative() => Err(TokenToStringError::InsufficientBufferSpace(i)),
1007            size => {
1008                let string = unsafe { CString::from_raw(buf) };
1009                let mut bytes = string.into_bytes();
1010                let len = usize::try_from(size).expect("size is positive and fits into usize");
1011                bytes.truncate(len);
1012                Ok(bytes)
1013            }
1014        }
1015    }
1016    /// The number of tokens the model was trained on.
1017    ///
1018    /// This function returns the number of tokens the model was trained on. It is returned as a `c_int` for maximum
1019    /// compatibility with the underlying llama-cpp library, though it can typically be cast to an `i32` without issue.
1020    ///
1021    /// # Example
1022    ///
1023    /// ```no_run
1024    /// use llama_cpp_4::model::LlamaModel;
1025    /// use llama_cpp_4::model::params::LlamaModelParams;
1026    /// use llama_cpp_4::llama_backend::LlamaBackend;
1027    ///
1028    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1029    /// let backend = LlamaBackend::init()?;
1030    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1031    /// let n_vocab = model.n_vocab();
1032    /// # Ok(())
1033    /// # }
1034    /// ```
1035    #[must_use]
1036    pub fn n_vocab(&self) -> i32 {
1037        unsafe { llama_n_vocab(self.get_vocab().vocab.as_ref()) }
1038    }
1039
1040    /// The type of vocab the model was trained on.
1041    ///
1042    /// This function returns the type of vocabulary used by the model, such as whether it is based on byte-pair encoding (BPE),
1043    /// word-level tokens, or another tokenization scheme.
1044    ///
1045    /// # Panics
1046    ///
1047    /// - This function will panic if `llama-cpp` emits a vocab type that is not recognized or is invalid for this library.
1048    ///
1049    /// # Example
1050    ///
1051    /// ```no_run
1052    /// use llama_cpp_4::model::LlamaModel;
1053    /// use llama_cpp_4::model::params::LlamaModelParams;
1054    /// use llama_cpp_4::llama_backend::LlamaBackend;
1055    ///
1056    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1057    /// let backend = LlamaBackend::init()?;
1058    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1059    /// let vocab_type = model.vocab_type();
1060    /// # Ok(())
1061    /// # }
1062    /// ```
1063    #[must_use]
1064    pub fn vocab_type(&self) -> VocabType {
1065        let vocab_type = unsafe { llama_vocab_type(self.get_vocab().vocab.as_ref()) };
1066        VocabType::try_from(vocab_type).expect("invalid vocab type")
1067    }
1068
1069    /// Returns the number of embedding dimensions for the model.
1070    ///
1071    /// This function retrieves the number of embeddings (or embedding dimensions) used by the model. It is typically
1072    /// used for analyzing model architecture and setting up context parameters or other model configuration aspects.
1073    ///
1074    /// # Panics
1075    ///
1076    /// - This function may panic if the underlying `llama-cpp` library returns an invalid embedding dimension value.
1077    ///
1078    /// # Example
1079    ///
1080    /// ```no_run
1081    /// use llama_cpp_4::model::LlamaModel;
1082    /// use llama_cpp_4::model::params::LlamaModelParams;
1083    /// use llama_cpp_4::llama_backend::LlamaBackend;
1084    ///
1085    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1086    /// let backend = LlamaBackend::init()?;
1087    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1088    /// let n_embd = model.n_embd();
1089    /// # Ok(())
1090    /// # }
1091    /// ```
1092    #[must_use]
1093    pub fn n_embd(&self) -> c_int {
1094        unsafe { llama_n_embd(self.model.as_ptr()) }
1095    }
1096
1097    /// Get the number of transformer layers in the model.
1098    #[must_use]
1099    pub fn n_layer(&self) -> c_int {
1100        unsafe { llama_n_layer(self.model.as_ptr()) }
1101    }
1102
1103    /// Get the number of attention heads in the model.
1104    #[must_use]
1105    pub fn n_head(&self) -> c_int {
1106        unsafe { llama_n_head(self.model.as_ptr()) }
1107    }
1108
1109    /// Get the number of key-value attention heads in the model.
1110    #[must_use]
1111    pub fn n_head_kv(&self) -> c_int {
1112        unsafe { llama_model_n_head_kv(self.model.as_ptr()) }
1113    }
1114
1115    /// Get the input embedding size of the model.
1116    #[must_use]
1117    pub fn n_embd_inp(&self) -> c_int {
1118        unsafe { llama_model_n_embd_inp(self.model.as_ptr()) }
1119    }
1120
1121    /// Get the output embedding size of the model.
1122    #[must_use]
1123    pub fn n_embd_out(&self) -> c_int {
1124        unsafe { llama_model_n_embd_out(self.model.as_ptr()) }
1125    }
1126
1127    /// Get the sliding window attention size of the model.
1128    /// Returns 0 if the model does not use sliding window attention.
1129    #[must_use]
1130    pub fn n_swa(&self) -> c_int {
1131        unsafe { llama_model_n_swa(self.model.as_ptr()) }
1132    }
1133
1134    /// Get the `RoPE` type used by the model.
1135    #[must_use]
1136    pub fn rope_type(&self) -> i32 {
1137        unsafe { llama_model_rope_type(self.model.as_ptr()) }
1138    }
1139
1140    /// Get the `RoPE` frequency scale used during training.
1141    #[must_use]
1142    pub fn rope_freq_scale_train(&self) -> f32 {
1143        unsafe { llama_model_rope_freq_scale_train(self.model.as_ptr()) }
1144    }
1145
1146    /// Get the model size in bytes.
1147    #[must_use]
1148    pub fn model_size(&self) -> u64 {
1149        unsafe { llama_model_size(self.model.as_ptr()) }
1150    }
1151
1152    /// Get the number of parameters in the model.
1153    #[must_use]
1154    pub fn n_params(&self) -> u64 {
1155        unsafe { llama_model_n_params(self.model.as_ptr()) }
1156    }
1157
1158    /// Get the number of classification outputs.
1159    #[must_use]
1160    pub fn n_cls_out(&self) -> u32 {
1161        unsafe { llama_model_n_cls_out(self.model.as_ptr()) }
1162    }
1163
1164    /// Get the classification label for the given index.
1165    ///
1166    /// # Errors
1167    ///
1168    /// Returns an error if the label is null or not valid UTF-8.
1169    pub fn cls_label(&self, index: u32) -> Result<&str, StringFromModelError> {
1170        let ptr = unsafe { llama_model_cls_label(self.model.as_ptr(), index) };
1171        if ptr.is_null() {
1172            return Err(StringFromModelError::ReturnedError(-1));
1173        }
1174        let cstr = unsafe { CStr::from_ptr(ptr) };
1175        cstr.to_str().map_err(StringFromModelError::Utf8Error)
1176    }
1177
1178    /// Get the number of metadata key-value pairs.
1179    #[must_use]
1180    pub fn meta_count(&self) -> c_int {
1181        unsafe { llama_model_meta_count(self.model.as_ptr()) }
1182    }
1183
1184    /// Get a model description string.
1185    ///
1186    /// The `buf_size` parameter specifies the maximum buffer size for the description.
1187    /// A default of 256 bytes is usually sufficient.
1188    ///
1189    /// # Errors
1190    ///
1191    /// Returns an error if the description could not be retrieved or is not valid UTF-8.
1192    #[allow(clippy::cast_sign_loss)]
1193    pub fn desc(&self, buf_size: usize) -> Result<String, StringFromModelError> {
1194        let mut buf = vec![0u8; buf_size];
1195        let ret = unsafe {
1196            llama_model_desc(
1197                self.model.as_ptr(),
1198                buf.as_mut_ptr().cast::<c_char>(),
1199                buf_size,
1200            )
1201        };
1202        if ret < 0 {
1203            return Err(StringFromModelError::ReturnedError(ret));
1204        }
1205        let len = ret as usize;
1206        let s = std::str::from_utf8(&buf[..len])
1207            .map_err(StringFromModelError::Utf8Error)?;
1208        Ok(s.to_owned())
1209    }
1210
1211    /// Get a metadata key by index.
1212    ///
1213    /// The `buf_size` parameter specifies the maximum buffer size for the key.
1214    /// A default of 256 bytes is usually sufficient.
1215    ///
1216    /// # Errors
1217    ///
1218    /// Returns an error if the index is out of range or the key is not valid UTF-8.
1219    #[allow(clippy::cast_sign_loss)]
1220    pub fn meta_key_by_index(&self, index: i32, buf_size: usize) -> Result<String, StringFromModelError> {
1221        let mut buf = vec![0u8; buf_size];
1222        let ret = unsafe {
1223            llama_model_meta_key_by_index(
1224                self.model.as_ptr(),
1225                index,
1226                buf.as_mut_ptr().cast::<c_char>(),
1227                buf_size,
1228            )
1229        };
1230        if ret < 0 {
1231            return Err(StringFromModelError::ReturnedError(ret));
1232        }
1233        let len = ret as usize;
1234        let s = std::str::from_utf8(&buf[..len])
1235            .map_err(StringFromModelError::Utf8Error)?;
1236        Ok(s.to_owned())
1237    }
1238
1239    /// Get a metadata value string by index.
1240    ///
1241    /// The `buf_size` parameter specifies the maximum buffer size for the value.
1242    /// Values can be large (e.g. chat templates, token lists), so 4096+ may be needed.
1243    ///
1244    /// # Errors
1245    ///
1246    /// Returns an error if the index is out of range or the value is not valid UTF-8.
1247    #[allow(clippy::cast_sign_loss)]
1248    pub fn meta_val_str_by_index(&self, index: i32, buf_size: usize) -> Result<String, StringFromModelError> {
1249        let mut buf = vec![0u8; buf_size];
1250        let ret = unsafe {
1251            llama_model_meta_val_str_by_index(
1252                self.model.as_ptr(),
1253                index,
1254                buf.as_mut_ptr().cast::<c_char>(),
1255                buf_size,
1256            )
1257        };
1258        if ret < 0 {
1259            return Err(StringFromModelError::ReturnedError(ret));
1260        }
1261        let len = ret as usize;
1262        let s = std::str::from_utf8(&buf[..len])
1263            .map_err(StringFromModelError::Utf8Error)?;
1264        Ok(s.to_owned())
1265    }
1266
1267    /// Get a metadata value by key name.
1268    ///
1269    /// This is more convenient than iterating metadata by index when you know the key.
1270    /// The `buf_size` parameter specifies the maximum buffer size for the value.
1271    ///
1272    /// # Errors
1273    ///
1274    /// Returns an error if the key is not found, contains a null byte, or the value is not valid UTF-8.
1275    #[allow(clippy::cast_sign_loss)]
1276    pub fn meta_val_str(&self, key: &str, buf_size: usize) -> Result<String, StringFromModelError> {
1277        let c_key = CString::new(key)
1278            .map_err(|_| StringFromModelError::ReturnedError(-1))?;
1279        let mut buf = vec![0u8; buf_size];
1280        let ret = unsafe {
1281            llama_model_meta_val_str(
1282                self.model.as_ptr(),
1283                c_key.as_ptr(),
1284                buf.as_mut_ptr().cast::<c_char>(),
1285                buf_size,
1286            )
1287        };
1288        if ret < 0 {
1289            return Err(StringFromModelError::ReturnedError(ret));
1290        }
1291        let len = ret as usize;
1292        let s = std::str::from_utf8(&buf[..len])
1293            .map_err(StringFromModelError::Utf8Error)?;
1294        Ok(s.to_owned())
1295    }
1296
1297    /// Get all metadata as a list of `(key, value)` pairs.
1298    ///
1299    /// This is a convenience method that iterates over all metadata entries.
1300    /// Keys use a buffer of 256 bytes and values use 4096 bytes.
1301    /// For values that may be larger (e.g. token lists), use
1302    /// [`meta_val_str_by_index`](Self::meta_val_str_by_index) directly with a larger buffer.
1303    ///
1304    /// # Errors
1305    ///
1306    /// Returns an error if any key or value cannot be read or is not valid UTF-8.
1307    #[allow(clippy::cast_sign_loss)]
1308    pub fn metadata(&self) -> Result<Vec<(String, String)>, StringFromModelError> {
1309        let count = self.meta_count();
1310        let mut result = Vec::with_capacity(count as usize);
1311        for i in 0..count {
1312            let key = self.meta_key_by_index(i, 256)?;
1313            let val = self.meta_val_str_by_index(i, 4096)?;
1314            result.push((key, val));
1315        }
1316        Ok(result)
1317    }
1318
1319    /// Check if the model has an encoder.
1320    #[must_use]
1321    pub fn has_encoder(&self) -> bool {
1322        unsafe { llama_model_has_encoder(self.model.as_ptr()) }
1323    }
1324
1325    /// Check if the model has a decoder.
1326    #[must_use]
1327    pub fn has_decoder(&self) -> bool {
1328        unsafe { llama_model_has_decoder(self.model.as_ptr()) }
1329    }
1330
1331    /// Check if the model is recurrent (e.g. Mamba, RWKV).
1332    #[must_use]
1333    pub fn is_recurrent(&self) -> bool {
1334        unsafe { llama_model_is_recurrent(self.model.as_ptr()) }
1335    }
1336
1337    /// Check if the model is a hybrid model.
1338    #[must_use]
1339    pub fn is_hybrid(&self) -> bool {
1340        unsafe { llama_model_is_hybrid(self.model.as_ptr()) }
1341    }
1342
1343    /// Check if the model is a diffusion model.
1344    #[must_use]
1345    pub fn is_diffusion(&self) -> bool {
1346        unsafe { llama_model_is_diffusion(self.model.as_ptr()) }
1347    }
1348
1349    /// Get chat template from model.
1350    ///
1351    /// # Errors
1352    ///
1353    /// - If the model does not have a chat template, it will return an error.
1354    /// - If the chat template is not a valid `CString`, it will return an error.
1355    ///
1356    /// # Example
1357    ///
1358    /// ```no_run
1359    /// use llama_cpp_4::model::LlamaModel;
1360    /// use llama_cpp_4::model::params::LlamaModelParams;
1361    /// use llama_cpp_4::llama_backend::LlamaBackend;
1362    ///
1363    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1364    /// let backend = LlamaBackend::init()?;
1365    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1366    /// let chat_template = model.get_chat_template(1024)?;
1367    /// # Ok(())
1368    /// # }
1369    /// ```
1370    #[allow(clippy::missing_panics_doc)] // We statically know this will not panic as long as the buffer size is sufficient
1371    pub fn get_chat_template(&self, buf_size: usize) -> Result<String, ChatTemplateError> {
1372        // longest known template is about 1200 bytes from llama.cpp
1373        let chat_temp = CString::new(vec![b'*'; buf_size]).expect("no null");
1374        let chat_ptr = chat_temp.into_raw();
1375        let chat_name = CString::new("tokenizer.chat_template").expect("no null bytes");
1376
1377        let ret = unsafe {
1378            llama_model_meta_val_str(self.model.as_ptr(), chat_name.as_ptr(), chat_ptr, buf_size)
1379        };
1380
1381        if ret < 0 {
1382            return Err(ChatTemplateError::MissingTemplate(ret));
1383        }
1384
1385        let template_c = unsafe { CString::from_raw(chat_ptr) };
1386        let template = template_c.to_str()?;
1387
1388        let ret: usize = ret.try_into().unwrap();
1389        if template.len() < ret {
1390            return Err(ChatTemplateError::BuffSizeError(ret + 1));
1391        }
1392
1393        Ok(template.to_owned())
1394    }
1395
1396    /// Loads a model from a file.
1397    ///
1398    /// This function loads a model from a specified file path and returns the corresponding `LlamaModel` instance.
1399    ///
1400    /// # Errors
1401    ///
1402    /// - If the path cannot be converted to a string or if the model file does not exist, it will return an error.
1403    /// - If the model cannot be loaded (e.g., due to an invalid or corrupted model file), it will return a `LlamaModelLoadError`.
1404    ///
1405    /// # Example
1406    ///
1407    /// ```no_run
1408    /// use llama_cpp_4::model::LlamaModel;
1409    /// use llama_cpp_4::model::params::LlamaModelParams;
1410    /// use llama_cpp_4::llama_backend::LlamaBackend;
1411    ///
1412    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1413    /// let backend = LlamaBackend::init()?;
1414    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1415    /// # Ok(())
1416    /// # }
1417    /// ```
1418    #[tracing::instrument(skip_all, fields(params))]
1419    pub fn load_from_file(
1420        _: &LlamaBackend,
1421        path: impl AsRef<Path>,
1422        params: &LlamaModelParams,
1423    ) -> Result<Self, LlamaModelLoadError> {
1424        let path = path.as_ref();
1425        debug_assert!(
1426            Path::new(path).exists(),
1427            "{} does not exist",
1428            path.display()
1429        );
1430        let path = path
1431            .to_str()
1432            .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
1433
1434        let cstr = CString::new(path)?;
1435        let llama_model = unsafe { llama_load_model_from_file(cstr.as_ptr(), params.params) };
1436
1437        let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
1438
1439        tracing::debug!(?path, "Loaded model");
1440        Ok(LlamaModel { model })
1441    }
1442
1443    /// Load a model from multiple split files.
1444    ///
1445    /// This function loads a model that has been split across multiple files. This is useful for
1446    /// very large models that exceed filesystem limitations or need to be distributed across
1447    /// multiple storage devices.
1448    ///
1449    /// # Arguments
1450    ///
1451    /// * `paths` - A slice of paths to the split model files
1452    /// * `params` - The model parameters
1453    ///
1454    /// # Errors
1455    ///
1456    /// Returns an error if:
1457    /// - Any of the paths cannot be converted to a C string
1458    /// - The model fails to load from the splits
1459    /// - Any path doesn't exist or isn't accessible
1460    ///
1461    /// # Example
1462    ///
1463    /// ```no_run
1464    /// use llama_cpp_4::model::{LlamaModel, params::LlamaModelParams};
1465    /// use llama_cpp_4::llama_backend::LlamaBackend;
1466    ///
1467    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1468    /// let backend = LlamaBackend::init()?;
1469    /// let params = LlamaModelParams::default();
1470    ///
1471    /// let paths = vec![
1472    ///     "model-00001-of-00003.gguf",
1473    ///     "model-00002-of-00003.gguf",
1474    ///     "model-00003-of-00003.gguf",
1475    /// ];
1476    ///
1477    /// let model = LlamaModel::load_from_splits(&backend, &paths, &params)?;
1478    /// # Ok(())
1479    /// # }
1480    /// ```
1481    #[tracing::instrument(skip_all)]
1482    pub fn load_from_splits(
1483        _: &LlamaBackend,
1484        paths: &[impl AsRef<Path>],
1485        params: &LlamaModelParams,
1486    ) -> Result<Self, LlamaModelLoadError> {
1487        // Convert paths to C strings
1488        let c_strings: Vec<CString> = paths
1489            .iter()
1490            .map(|p| {
1491                let path = p.as_ref();
1492                debug_assert!(path.exists(), "{} does not exist", path.display());
1493                let path_str = path
1494                    .to_str()
1495                    .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
1496                CString::new(path_str).map_err(LlamaModelLoadError::from)
1497            })
1498            .collect::<Result<Vec<_>, _>>()?;
1499
1500        // Create array of pointers to C strings
1501        let c_ptrs: Vec<*const c_char> = c_strings.iter().map(|s| s.as_ptr()).collect();
1502
1503        // Load the model from splits
1504        let llama_model = unsafe {
1505            llama_model_load_from_splits(c_ptrs.as_ptr().cast_mut(), c_ptrs.len(), params.params)
1506        };
1507
1508        let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
1509
1510        tracing::debug!("Loaded model from {} splits", paths.len());
1511        Ok(LlamaModel { model })
1512    }
1513
1514    /// Load a model from a `FILE` pointer.
1515    ///
1516    /// # Safety
1517    ///
1518    /// The `file` pointer must be a valid, open `FILE*`.
1519    ///
1520    /// # Errors
1521    ///
1522    /// Returns an error if the model cannot be loaded.
1523    pub unsafe fn load_from_file_ptr(
1524        file: *mut llama_cpp_sys_4::FILE,
1525        params: &LlamaModelParams,
1526    ) -> Result<Self, LlamaModelLoadError> {
1527        let model = llama_cpp_sys_4::llama_model_load_from_file_ptr(file, params.params);
1528        let model = NonNull::new(model).ok_or(LlamaModelLoadError::NullResult)?;
1529        Ok(LlamaModel { model })
1530    }
1531
1532    /// Initialize a model from user-provided data.
1533    ///
1534    /// # Safety
1535    ///
1536    /// The metadata, callback, and user data must be valid.
1537    ///
1538    /// # Errors
1539    ///
1540    /// Returns an error if the model cannot be initialized.
1541    pub unsafe fn init_from_user(
1542        metadata: *mut llama_cpp_sys_4::gguf_context,
1543        set_tensor_data: llama_cpp_sys_4::llama_model_set_tensor_data_t,
1544        set_tensor_data_ud: *mut std::ffi::c_void,
1545        params: &LlamaModelParams,
1546    ) -> Result<Self, LlamaModelLoadError> {
1547        let model = llama_cpp_sys_4::llama_model_init_from_user(
1548            metadata,
1549            set_tensor_data,
1550            set_tensor_data_ud,
1551            params.params,
1552        );
1553        let model = NonNull::new(model).ok_or(LlamaModelLoadError::NullResult)?;
1554        Ok(LlamaModel { model })
1555    }
1556
1557    /// Save the model to a file.
1558    ///
1559    /// # Panics
1560    ///
1561    /// Panics if the path contains null bytes.
1562    pub fn save_to_file(&self, path: impl AsRef<Path>) {
1563        let path = path.as_ref();
1564        let path_str = path.to_str().expect("path is not valid UTF-8");
1565        let c_path = CString::new(path_str).expect("path contains null bytes");
1566        unsafe {
1567            llama_model_save_to_file(self.model.as_ptr(), c_path.as_ptr());
1568        }
1569    }
1570
1571    /// Get the list of built-in chat templates.
1572    ///
1573    /// Returns the names of all chat templates that are built into llama.cpp.
1574    ///
1575    /// # Panics
1576    ///
1577    /// Panics if any template name is not valid UTF-8.
1578    #[allow(clippy::cast_sign_loss)]
1579    #[must_use]
1580    pub fn chat_builtin_templates() -> Vec<String> {
1581        // First call to get count
1582        let count = unsafe { llama_chat_builtin_templates(std::ptr::null_mut(), 0) };
1583        if count <= 0 {
1584            return Vec::new();
1585        }
1586        let count = count as usize;
1587        let mut ptrs: Vec<*const c_char> = vec![std::ptr::null(); count];
1588        unsafe {
1589            llama_chat_builtin_templates(ptrs.as_mut_ptr(), count);
1590        }
1591        ptrs.iter()
1592            .map(|&p| {
1593                let cstr = unsafe { CStr::from_ptr(p) };
1594                cstr.to_str()
1595                    .expect("template name is not valid UTF-8")
1596                    .to_owned()
1597            })
1598            .collect()
1599    }
1600
1601    /// Initializes a lora adapter from a file.
1602    ///
1603    /// This function initializes a Lora adapter, which is a model extension used to adapt or fine-tune the existing model
1604    /// to a specific domain or task. The adapter file is typically in the form of a binary or serialized file that can be applied
1605    /// to the model for improved performance on specialized tasks.
1606    ///
1607    /// # Errors
1608    ///
1609    /// - If the adapter file path cannot be converted to a string or if the adapter cannot be initialized, it will return an error.
1610    ///
1611    /// # Example
1612    ///
1613    /// ```no_run
1614    /// use llama_cpp_4::model::LlamaModel;
1615    /// use llama_cpp_4::model::params::LlamaModelParams;
1616    /// use llama_cpp_4::llama_backend::LlamaBackend;
1617    ///
1618    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1619    /// let backend = LlamaBackend::init()?;
1620    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1621    /// let adapter = model.lora_adapter_init("path/to/lora/adapter")?;
1622    /// # Ok(())
1623    /// # }
1624    /// ```
1625    pub fn lora_adapter_init(
1626        &self,
1627        path: impl AsRef<Path>,
1628    ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
1629        let path = path.as_ref();
1630        debug_assert!(
1631            Path::new(path).exists(),
1632            "{} does not exist",
1633            path.display()
1634        );
1635
1636        let path = path
1637            .to_str()
1638            .ok_or(LlamaLoraAdapterInitError::PathToStrError(
1639                path.to_path_buf(),
1640            ))?;
1641
1642        let cstr = CString::new(path)?;
1643        let adapter = unsafe { llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr()) };
1644
1645        let adapter = NonNull::new(adapter).ok_or(LlamaLoraAdapterInitError::NullResult)?;
1646
1647        tracing::debug!(?path, "Initialized lora adapter");
1648        Ok(LlamaLoraAdapter {
1649            lora_adapter: adapter,
1650        })
1651    }
1652
1653    /// Create a new context from this model.
1654    ///
1655    /// This function creates a new context for the model, which is used to manage and perform computations for inference,
1656    /// including token generation, embeddings, and other tasks that the model can perform. The context allows fine-grained
1657    /// control over model parameters for a specific task.
1658    ///
1659    /// # Errors
1660    ///
1661    /// - There are various potential failures such as invalid parameters or a failure to allocate the context. See [`LlamaContextLoadError`]
1662    ///   for more detailed error descriptions.
1663    ///
1664    /// # Example
1665    ///
1666    /// ```no_run
1667    /// use llama_cpp_4::model::LlamaModel;
1668    /// use llama_cpp_4::model::params::LlamaModelParams;
1669    /// use llama_cpp_4::context::params::LlamaContextParams;
1670    /// use llama_cpp_4::llama_backend::LlamaBackend;
1671    ///
1672    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1673    /// let backend = LlamaBackend::init()?;
1674    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1675    /// let context = model.new_context(&backend, LlamaContextParams::default())?;
1676    /// # Ok(())
1677    /// # }
1678    /// ```
1679    #[allow(clippy::needless_pass_by_value)]
1680    pub fn new_context(
1681        &self,
1682        _: &LlamaBackend,
1683        params: LlamaContextParams,
1684    ) -> Result<LlamaContext<'_>, LlamaContextLoadError> {
1685        // Apply TurboQuant attn-rotation preference before the KV cache is
1686        // initialised inside llama_new_context_with_model.
1687        let prev_rot_var = std::env::var("LLAMA_ATTN_ROT_DISABLE").ok();
1688        if params.attn_rot_disabled {
1689            // SAFETY: we restore the value right after the call.
1690            #[allow(unused_unsafe)]
1691            unsafe {
1692                std::env::set_var("LLAMA_ATTN_ROT_DISABLE", "1");
1693            }
1694        } else if std::env::var("LLAMA_ATTN_ROT_DISABLE").is_ok() {
1695            // params say "enabled" – only clear if it was previously unset
1696            // (respect explicit user env var).
1697        }
1698
1699        let context_params = params.context_params;
1700        let context = unsafe { llama_new_context_with_model(self.model.as_ptr(), context_params) };
1701
1702        // Restore the env-var to its previous state.
1703        #[allow(unused_unsafe)]
1704        match prev_rot_var {
1705            Some(v) => unsafe { std::env::set_var("LLAMA_ATTN_ROT_DISABLE", v) },
1706            None if params.attn_rot_disabled => unsafe {
1707                std::env::remove_var("LLAMA_ATTN_ROT_DISABLE");
1708            },
1709            None => {}
1710        }
1711
1712        let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;
1713        Ok(LlamaContext::new(self, context, params.embeddings()))
1714    }
1715
1716    /// Apply the model's chat template to a sequence of messages.
1717    ///
1718    /// This function applies the model's chat template to the provided chat messages, formatting them accordingly. The chat
1719    /// template determines the structure or style of conversation between the system and user, such as token formatting,
1720    /// role separation, and more. The template can be customized by providing an optional template string, or if `None`
1721    /// is provided, the default template used by `llama.cpp` will be applied.
1722    ///
1723    /// For more information on supported templates, visit:
1724    /// <https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template>
1725    ///
1726    /// # Arguments
1727    ///
1728    /// - `tmpl`: An optional custom template string. If `None`, the default template will be used.
1729    /// - `chat`: A vector of `LlamaChatMessage` instances, which represent the conversation between the system and user.
1730    /// - `add_ass`: A boolean flag indicating whether additional system-specific instructions (like "assistant") should be included.
1731    ///
1732    /// # Errors
1733    ///
1734    /// There are several possible points of failure when applying the chat template:
1735    /// - Insufficient buffer size to hold the formatted chat (this will return `ApplyChatTemplateError::BuffSizeError`).
1736    /// - If the template or messages cannot be processed properly, various errors from `ApplyChatTemplateError` may occur.
1737    ///
1738    /// # Example
1739    ///
1740    /// ```no_run
1741    /// use llama_cpp_4::model::{LlamaModel, LlamaChatMessage};
1742    /// use llama_cpp_4::model::params::LlamaModelParams;
1743    /// use llama_cpp_4::llama_backend::LlamaBackend;
1744    ///
1745    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1746    /// let backend = LlamaBackend::init()?;
1747    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1748    /// let chat = vec![
1749    ///     LlamaChatMessage::new("user".to_string(), "Hello!".to_string())?,
1750    ///     LlamaChatMessage::new("assistant".to_string(), "Hi! How can I assist you today?".to_string())?,
1751    /// ];
1752    /// let formatted_chat = model.apply_chat_template(None, &chat, true)?;
1753    /// # Ok(())
1754    /// # }
1755    /// ```
1756    ///
1757    /// # Notes
1758    ///
1759    /// The provided buffer is twice the length of the messages by default, which is recommended by the `llama.cpp` documentation.
1760    /// # Panics
1761    ///
1762    /// Panics if the buffer length exceeds `i32::MAX`.
1763    #[tracing::instrument(skip_all)]
1764    pub fn apply_chat_template(
1765        &self,
1766        tmpl: Option<&str>,
1767        chat: &[LlamaChatMessage],
1768        add_ass: bool,
1769    ) -> Result<String, ApplyChatTemplateError> {
1770        // Compute raw message byte total from the original LlamaChatMessage vec
1771        // *before* we shadow `chat` with the sys-type vec below.
1772        let message_length = chat.iter().fold(0usize, |acc, c| {
1773            acc + c.role.to_bytes().len() + c.content.to_bytes().len()
1774        });
1775
1776        // Build our llama_cpp_sys chat messages (raw pointers into CStrings).
1777        let chat_sys: Vec<llama_chat_message> = chat
1778            .iter()
1779            .map(|c| llama_chat_message {
1780                role: c.role.as_ptr(),
1781                content: c.content.as_ptr(),
1782            })
1783            .collect();
1784
1785        // Set the tmpl pointer.
1786        let tmpl_cstring = tmpl.map(CString::new).transpose()?;
1787        let tmpl_ptr = tmpl_cstring
1788            .as_ref()
1789            .map_or(std::ptr::null(), |s| s.as_ptr());
1790
1791        // `message_length * 4` is far too small for models whose built-in chat
1792        // template adds a long default system prompt (e.g. Qwen3.5 prepends
1793        // ~80+ chars of markup even for a one-word user message).  Start with
1794        // at least 4 KiB so short inputs like "hi" always have room.
1795        //
1796        // `llama_chat_apply_template` returns the number of bytes it *actually*
1797        // needed when the buffer was too small, so we retry exactly once with
1798        // that precise size rather than giving up immediately.
1799        let mut buf_size = message_length.saturating_mul(4).max(4096);
1800
1801        for _ in 0..2 {
1802            // Use u8 so that as_mut_ptr()/as_ptr() match the binding (*mut u8 / *const u8).
1803            let mut buff = vec![0u8; buf_size];
1804            let res = unsafe {
1805                llama_chat_apply_template(
1806                    tmpl_ptr,
1807                    chat_sys.as_ptr(),
1808                    chat_sys.len(),
1809                    add_ass,
1810                    buff.as_mut_ptr().cast(),
1811                    i32::try_from(buff.len()).expect("buffer length fits in i32"),
1812                )
1813            };
1814
1815            if res < 0 {
1816                return Err(ApplyChatTemplateError::BuffSizeError);
1817            }
1818
1819            #[allow(clippy::cast_sign_loss)]
1820            let needed = res as usize;
1821            if needed > buf_size {
1822                // Buffer was too small — retry with the exact size llama.cpp reported.
1823                buf_size = needed + 1; // +1 for null terminator
1824                continue;
1825            }
1826
1827            // SAFETY: llama_chat_apply_template wrote a NUL-terminated string
1828            // into `buff`; `needed` bytes were used.
1829            let formatted = unsafe {
1830                CStr::from_ptr(buff.as_ptr().cast())
1831                    .to_string_lossy()
1832                    .into_owned()
1833            };
1834            return Ok(formatted);
1835        }
1836
1837        Err(ApplyChatTemplateError::BuffSizeError)
1838    }
1839
1840    /// Build a split GGUF file path for a specific chunk.
1841    ///
1842    /// This utility function creates the standardized filename for a split model chunk
1843    /// following the pattern: `{prefix}-{split_no:05d}-of-{split_count:05d}.gguf`
1844    ///
1845    /// # Arguments
1846    ///
1847    /// * `path_prefix` - The base path and filename prefix
1848    /// * `split_no` - The split number (1-indexed)
1849    /// * `split_count` - The total number of splits
1850    ///
1851    /// # Returns
1852    ///
1853    /// Returns the formatted split path as a String
1854    ///
1855    /// # Example
1856    ///
1857    /// ```
1858    /// use llama_cpp_4::model::LlamaModel;
1859    ///
1860    /// let path = LlamaModel::split_path("/models/llama", 1, 4);
1861    /// assert_eq!(path, "/models/llama-00002-of-00004.gguf");
1862    /// ```
1863    ///
1864    /// # Panics
1865    ///
1866    /// Panics if the path prefix contains a null byte.
1867    #[must_use]
1868    pub fn split_path(path_prefix: &str, split_no: i32, split_count: i32) -> String {
1869        let mut buffer = vec![0u8; 1024];
1870        let len = unsafe {
1871            llama_split_path(
1872                buffer.as_mut_ptr().cast::<c_char>(),
1873                buffer.len(),
1874                CString::new(path_prefix).unwrap().as_ptr(),
1875                split_no,
1876                split_count,
1877            )
1878        };
1879
1880        let len = usize::try_from(len).expect("split_path length fits in usize");
1881        buffer.truncate(len);
1882        String::from_utf8(buffer).unwrap_or_default()
1883    }
1884
1885    /// Extract the path prefix from a split filename.
1886    ///
1887    /// This function extracts the base path prefix from a split model filename,
1888    /// but only if the `split_no` and `split_count` match the pattern in the filename.
1889    ///
1890    /// # Arguments
1891    ///
1892    /// * `split_path` - The full path to the split file
1893    /// * `split_no` - The expected split number
1894    /// * `split_count` - The expected total number of splits
1895    ///
1896    /// # Returns
1897    ///
1898    /// Returns the path prefix if the pattern matches, or None if it doesn't
1899    ///
1900    /// # Example
1901    ///
1902    /// ```
1903    /// use llama_cpp_4::model::LlamaModel;
1904    ///
1905    /// let prefix = LlamaModel::split_prefix("/models/llama-00002-of-00004.gguf", 1, 4);
1906    /// assert_eq!(prefix, Some("/models/llama".to_string()));
1907    /// ```
1908    ///
1909    /// # Panics
1910    ///
1911    /// Panics if the split path contains a null byte.
1912    #[must_use]
1913    pub fn split_prefix(split_path: &str, split_no: i32, split_count: i32) -> Option<String> {
1914        let mut buffer = vec![0u8; 1024];
1915        let len = unsafe {
1916            llama_split_prefix(
1917                buffer.as_mut_ptr().cast::<c_char>(),
1918                buffer.len(),
1919                CString::new(split_path).unwrap().as_ptr(),
1920                split_no,
1921                split_count,
1922            )
1923        };
1924
1925        if len > 0 {
1926            let len = usize::try_from(len).expect("split_prefix length fits in usize");
1927            buffer.truncate(len);
1928            String::from_utf8(buffer).ok()
1929        } else {
1930            None
1931        }
1932    }
1933}
1934
1935#[allow(clippy::cast_precision_loss)]
1936impl fmt::Display for LlamaModel {
1937    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1938        let desc = self.desc(256).unwrap_or_else(|_| "unknown".to_string());
1939        write!(
1940            f,
1941            "{desc} | {layers}L {heads}H {embd}E | {params} params | {size:.1} MiB",
1942            layers = self.n_layer(),
1943            heads = self.n_head(),
1944            embd = self.n_embd(),
1945            params = self.n_params(),
1946            size = self.model_size() as f64 / (1024.0 * 1024.0),
1947        )
1948    }
1949}
1950
1951impl Drop for LlamaModel {
1952    fn drop(&mut self) {
1953        unsafe { llama_free_model(self.model.as_ptr()) }
1954    }
1955}
1956
1957/// Defines the possible types of vocabulary used by the model.
1958///
1959/// The model may use different types of vocabulary depending on the tokenization method chosen during training.
1960/// This enum represents these types, specifically `BPE` (Byte Pair Encoding) and `SPM` (`SentencePiece`).
1961///
1962/// # Variants
1963///
1964/// - `BPE`: Byte Pair Encoding, a common tokenization method used in NLP tasks.
1965/// - `SPM`: `SentencePiece`, another popular tokenization method for NLP models.
1966///
1967/// # Example
1968///
1969/// ```no_run
1970/// use llama_cpp_4::model::VocabType;
1971///
1972/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1973/// let vocab_type = VocabType::BPE;
1974/// match vocab_type {
1975///     VocabType::BPE => println!("The model uses Byte Pair Encoding (BPE)"),
1976///     VocabType::SPM => println!("The model uses SentencePiece (SPM)"),
1977/// }
1978/// # Ok(())
1979/// # }
1980/// ```
1981#[repr(u32)]
1982#[derive(Debug, Eq, Copy, Clone, PartialEq)]
1983pub enum VocabType {
1984    /// Byte Pair Encoding
1985    BPE = LLAMA_VOCAB_TYPE_BPE as _,
1986    /// Sentence Piece Tokenizer
1987    SPM = LLAMA_VOCAB_TYPE_SPM as _,
1988}
1989
1990/// Error that occurs when trying to convert a `llama_vocab_type` to a `VocabType`.
1991///
1992/// This error is raised when the integer value returned by the system does not correspond to a known vocabulary type.
1993///
1994/// # Variants
1995///
1996/// - `UnknownValue`: The error is raised when the value is not a valid `llama_vocab_type`. The invalid value is returned with the error.
1997///
1998/// # Example
1999///
2000/// ```no_run
2001/// use llama_cpp_4::model::LlamaTokenTypeFromIntError;
2002///
2003/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
2004/// let invalid_value = 999; // Not a valid vocabulary type
2005/// let error = LlamaTokenTypeFromIntError::UnknownValue(invalid_value);
2006/// println!("Error: {}", error);
2007/// # Ok(())
2008/// # }
2009/// ```
2010#[derive(thiserror::Error, Debug, Eq, PartialEq)]
2011pub enum LlamaTokenTypeFromIntError {
2012    /// The value is not a valid `llama_token_type`. Contains the int value that was invalid.
2013    #[error("Unknown Value {0}")]
2014    UnknownValue(llama_vocab_type),
2015}
2016
2017impl TryFrom<llama_vocab_type> for VocabType {
2018    type Error = LlamaTokenTypeFromIntError;
2019
2020    fn try_from(value: llama_vocab_type) -> Result<Self, Self::Error> {
2021        match value {
2022            LLAMA_VOCAB_TYPE_BPE => Ok(VocabType::BPE),
2023            LLAMA_VOCAB_TYPE_SPM => Ok(VocabType::SPM),
2024            unknown => Err(LlamaTokenTypeFromIntError::UnknownValue(unknown)),
2025        }
2026    }
2027}