Skip to main content

llama_cpp_4/
model.rs

1//! A safe wrapper around `llama_model`.
2use std::ffi::CStr;
3use std::ffi::CString;
4use std::num::NonZeroU16;
5use std::os::raw::{c_char, c_int};
6use std::path::Path;
7use std::ptr::NonNull;
8
9use llama_cpp_sys_4::{
10    llama_adapter_lora, llama_adapter_lora_init, llama_chat_apply_template, llama_chat_message,
11    llama_free_model, llama_load_model_from_file, llama_model, llama_model_decoder_start_token,
12    llama_model_get_vocab, llama_model_load_from_splits, llama_model_meta_val_str,
13    llama_n_ctx_train, llama_n_embd, llama_n_vocab, llama_new_context_with_model, llama_split_path,
14    llama_split_prefix, llama_token_bos, llama_token_eos, llama_token_get_attr, llama_token_is_eog,
15    llama_token_nl, llama_token_to_piece, llama_tokenize, llama_vocab, llama_vocab_type,
16    LLAMA_VOCAB_TYPE_BPE, LLAMA_VOCAB_TYPE_SPM,
17};
18
19use crate::context::params::LlamaContextParams;
20use crate::context::LlamaContext;
21use crate::llama_backend::LlamaBackend;
22use crate::model::params::LlamaModelParams;
23use crate::token::LlamaToken;
24use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
25use crate::{
26    ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
27    LlamaModelLoadError, NewLlamaChatMessageError, StringToTokenError, TokenToStringError,
28};
29
30pub mod params;
31
32/// A safe wrapper around `llama_model`.
33#[derive(Debug)]
34#[repr(transparent)]
35#[allow(clippy::module_name_repetitions)]
36pub struct LlamaModel {
37    pub(crate) model: NonNull<llama_model>,
38}
39
40/// A safe wrapper around `llama_vocab`.
41#[derive(Debug)]
42#[repr(transparent)]
43#[allow(clippy::module_name_repetitions)]
44pub struct LlamaVocab {
45    pub(crate) vocab: NonNull<llama_vocab>,
46}
47
48/// A safe wrapper around `llama_adapter_lora`.
49#[derive(Debug)]
50#[repr(transparent)]
51#[allow(clippy::module_name_repetitions)]
52pub struct LlamaLoraAdapter {
53    pub(crate) lora_adapter: NonNull<llama_adapter_lora>,
54}
55
/// A safe counterpart of llama.cpp's `llama_chat_message`.
///
/// Both fields are stored as [`CString`]s so they can be handed to C without re-encoding.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LlamaChatMessage {
    /// Speaker role of the message, as a NUL-terminated string.
    role: CString,
    /// Message text, as a NUL-terminated string.
    content: CString,
}
62
63impl LlamaChatMessage {
64    /// Create a new `LlamaChatMessage`.
65    ///
66    /// # Errors
67    ///
68    /// Returns [`NewLlamaChatMessageError`] if the role or content contains a null byte.
69    pub fn new(role: String, content: String) -> Result<Self, NewLlamaChatMessageError> {
70        Ok(Self {
71            role: CString::new(role)?,
72            content: CString::new(content)?,
73        })
74    }
75}
76
/// Whether tokenization should prepend the beginning-of-stream (BOS) token.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AddBos {
    /// Prepend the BOS token to the tokenized string.
    Always,
    /// Tokenize the string without a leading BOS token.
    Never,
}
85
/// Whether special/control tokens are recognised during tokenization and detokenization.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Special {
    /// Recognise special and/or control tokens that would otherwise be hidden and treated as
    /// plain text. Does not insert a leading space.
    Tokenize,
    /// Handle special and/or control tokens as ordinary plain text.
    Plaintext,
}
94
95unsafe impl Send for LlamaModel {}
96
97unsafe impl Sync for LlamaModel {}
98
99impl LlamaModel {
100    /// Retrieves the vocabulary associated with the current Llama model.
101    ///
102    /// This method fetches the vocabulary from the underlying model using an unsafe
103    /// FFI call. The returned `LlamaVocab` struct contains a non-null pointer to
104    /// the vocabulary data, which is wrapped in a `NonNull` for safety.
105    ///
106    /// # Safety
107    /// This method uses an unsafe block to call a C function (`llama_model_get_vocab`),
108    /// which is assumed to return a valid pointer to the vocabulary. The caller should
109    /// ensure that the model object is properly initialized and valid before calling
110    /// this method, as dereferencing invalid pointers can lead to undefined behavior.
111    ///
112    /// # Returns
113    /// A `LlamaVocab` struct containing the vocabulary of the model.
114    ///
115    /// # Panics
116    ///
117    /// Panics if the underlying C function returns a null pointer.
118    ///
119    /// # Example
120    /// ```rust,ignore
121    /// let vocab = model.get_vocab();
122    /// ```
123    #[must_use]
124    pub fn get_vocab(&self) -> LlamaVocab {
125        let llama_vocab = unsafe { llama_model_get_vocab(self.model.as_ptr()) }.cast_mut();
126
127        LlamaVocab {
128            vocab: NonNull::new(llama_vocab).unwrap(),
129        }
130    }
131    /// Get the number of tokens the model was trained on.
132    ///
133    /// This function returns the number of tokens that the model was trained on, represented as a `u32`.
134    ///
135    /// # Panics
136    ///
137    /// This function will panic if the number of tokens the model was trained on does not fit into a `u32`.
138    /// This should be impossible on most platforms since llama.cpp returns a `c_int` (i32 on most platforms),
139    /// which is almost certainly positive.
140    #[must_use]
141    pub fn n_ctx_train(&self) -> u32 {
142        let n_ctx_train = unsafe { llama_n_ctx_train(self.model.as_ptr()) };
143        u32::try_from(n_ctx_train).expect("n_ctx_train fits into an u32")
144    }
145
146    /// Get all tokens in the model.
147    ///
148    /// This function returns an iterator over all the tokens in the model. Each item in the iterator is a tuple
149    /// containing a `LlamaToken` and its corresponding string representation (or an error if the conversion fails).
150    ///
151    /// # Parameters
152    ///
153    /// - `special`: The `Special` value that determines how special tokens (like BOS, EOS, etc.) are handled.
154    pub fn tokens(
155        &self,
156        special: Special,
157    ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
158        (0..self.n_vocab())
159            .map(LlamaToken::new)
160            .map(move |llama_token| (llama_token, self.token_to_str(llama_token, special)))
161    }
162
163    /// Get the beginning of stream token.
164    ///
165    /// This function returns the token that represents the beginning of a stream (BOS token).
166    #[must_use]
167    pub fn token_bos(&self) -> LlamaToken {
168        let token = unsafe { llama_token_bos(self.get_vocab().vocab.as_ref()) };
169        LlamaToken(token)
170    }
171
172    /// Get the end of stream token.
173    ///
174    /// This function returns the token that represents the end of a stream (EOS token).
175    #[must_use]
176    pub fn token_eos(&self) -> LlamaToken {
177        let token = unsafe { llama_token_eos(self.get_vocab().vocab.as_ref()) };
178        LlamaToken(token)
179    }
180
181    /// Get the newline token.
182    ///
183    /// This function returns the token that represents a newline character.
184    #[must_use]
185    pub fn token_nl(&self) -> LlamaToken {
186        let token = unsafe { llama_token_nl(self.get_vocab().vocab.as_ref()) };
187        LlamaToken(token)
188    }
189
190    /// Check if a token represents the end of generation (end of turn, end of sequence, etc.).
191    ///
192    /// This function returns `true` if the provided token signifies the end of generation or end of sequence,
193    /// such as EOS or other special tokens.
194    ///
195    /// # Parameters
196    ///
197    /// - `token`: The `LlamaToken` to check.
198    ///
199    /// # Returns
200    ///
201    /// - `true` if the token is an end-of-generation token, otherwise `false`.
202    #[must_use]
203    pub fn is_eog_token(&self, token: LlamaToken) -> bool {
204        unsafe { llama_token_is_eog(self.get_vocab().vocab.as_ref(), token.0) }
205    }
206
207    /// Get the decoder start token.
208    ///
209    /// This function returns the token used to signal the start of decoding (i.e., the token used at the start
210    /// of a sequence generation).
211    #[must_use]
212    pub fn decode_start_token(&self) -> LlamaToken {
213        let token = unsafe { llama_model_decoder_start_token(self.model.as_ptr()) };
214        LlamaToken(token)
215    }
216
217    /// Convert a single token to a string.
218    ///
219    /// This function converts a `LlamaToken` into its string representation.
220    ///
221    /// # Errors
222    ///
223    /// This function returns an error if the token cannot be converted to a string. For more details, refer to
224    /// [`TokenToStringError`].
225    ///
226    /// # Parameters
227    ///
228    /// - `token`: The `LlamaToken` to convert.
229    /// - `special`: The `Special` value used to handle special tokens.
230    pub fn token_to_str(
231        &self,
232        token: LlamaToken,
233        special: Special,
234    ) -> Result<String, TokenToStringError> {
235        self.token_to_str_with_size(token, 32, special)
236    }
237
238    /// Convert a single token to bytes.
239    ///
240    /// This function converts a `LlamaToken` into a byte representation.
241    ///
242    /// # Errors
243    ///
244    /// This function returns an error if the token cannot be converted to bytes. For more details, refer to
245    /// [`TokenToStringError`].
246    ///
247    /// # Parameters
248    ///
249    /// - `token`: The `LlamaToken` to convert.
250    /// - `special`: The `Special` value used to handle special tokens.
251    pub fn token_to_bytes(
252        &self,
253        token: LlamaToken,
254        special: Special,
255    ) -> Result<Vec<u8>, TokenToStringError> {
256        self.token_to_bytes_with_size(token, 32, special, None)
257    }
258
259    /// Convert a vector of tokens to a single string.
260    ///
261    /// This function takes a slice of `LlamaToken`s and converts them into a single string, concatenating their
262    /// string representations.
263    ///
264    /// # Errors
265    ///
266    /// This function returns an error if any token cannot be converted to a string. For more details, refer to
267    /// [`TokenToStringError`].
268    ///
269    /// # Parameters
270    ///
271    /// - `tokens`: A slice of `LlamaToken`s to convert.
272    /// - `special`: The `Special` value used to handle special tokens.
273    pub fn tokens_to_str(
274        &self,
275        tokens: &[LlamaToken],
276        special: Special,
277    ) -> Result<String, TokenToStringError> {
278        let mut builder = String::with_capacity(tokens.len() * 4);
279        for str in tokens
280            .iter()
281            .copied()
282            .map(|t| self.token_to_str(t, special))
283        {
284            builder += &str?;
285        }
286        Ok(builder)
287    }
288
289    /// Convert a string to a vector of tokens.
290    ///
291    /// This function converts a string into a vector of `LlamaToken`s. The function will tokenize the string
292    /// and return the corresponding tokens.
293    ///
294    /// # Errors
295    ///
296    /// - This function will return an error if the input string contains a null byte.
297    ///
298    /// # Panics
299    ///
300    /// - This function will panic if the number of tokens exceeds `usize::MAX`.
301    ///
302    /// # Example
303    ///
304    /// ```no_run
305    /// use llama_cpp_4::model::LlamaModel;
306    ///
307    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
308    /// use std::path::Path;
309    /// use llama_cpp_4::model::AddBos;
310    /// let backend = llama_cpp_4::llama_backend::LlamaBackend::init()?;
311    /// let model = LlamaModel::load_from_file(&backend, Path::new("path/to/model"), &Default::default())?;
312    /// let tokens = model.str_to_token("Hello, World!", AddBos::Always)?;
313    /// # Ok(())
314    /// # }
315    /// ```
316    pub fn str_to_token(
317        &self,
318        str: &str,
319        add_bos: AddBos,
320    ) -> Result<Vec<LlamaToken>, StringToTokenError> {
321        let add_bos = match add_bos {
322            AddBos::Always => true,
323            AddBos::Never => false,
324        };
325
326        let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
327        let mut buffer = Vec::with_capacity(tokens_estimation);
328
329        let c_string = CString::new(str)?;
330        let buffer_capacity =
331            c_int::try_from(buffer.capacity()).expect("buffer capacity should fit into a c_int");
332
333        let size = unsafe {
334            llama_tokenize(
335                self.get_vocab().vocab.as_ref(),
336                c_string.as_ptr(),
337                c_int::try_from(c_string.as_bytes().len())?,
338                buffer.as_mut_ptr(),
339                buffer_capacity,
340                add_bos,
341                true,
342            )
343        };
344
345        // if we fail the first time we can resize the vector to the correct size and try again. This should never fail.
346        // as a result - size is guaranteed to be positive here.
347        let size = if size.is_negative() {
348            buffer.reserve_exact(usize::try_from(-size).expect("usize's are larger "));
349            unsafe {
350                llama_tokenize(
351                    self.get_vocab().vocab.as_ref(),
352                    c_string.as_ptr(),
353                    c_int::try_from(c_string.as_bytes().len())?,
354                    buffer.as_mut_ptr(),
355                    -size,
356                    add_bos,
357                    true,
358                )
359            }
360        } else {
361            size
362        };
363
364        let size = usize::try_from(size).expect("size is positive and usize ");
365
366        // Safety: `size` < `capacity` and llama-cpp has initialized elements up to `size`
367        unsafe { buffer.set_len(size) }
368        Ok(buffer.into_iter().map(LlamaToken).collect())
369    }
370
371    /// Get the type of a token.
372    ///
373    /// This function retrieves the attributes associated with a given token. The attributes are typically used to
374    /// understand whether the token represents a special type of token (e.g., beginning-of-sequence (BOS), end-of-sequence (EOS),
375    /// control tokens, etc.).
376    ///
377    /// # Panics
378    ///
379    /// - This function will panic if the token type is unknown or cannot be converted to a valid `LlamaTokenAttrs`.
380    ///
381    /// # Example
382    ///
383    /// ```no_run
384    /// use llama_cpp_4::model::{LlamaModel, LlamaToken};
385    ///
386    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
387    /// let model = LlamaModel::load_from_file("path/to/model")?;
388    /// let token = LlamaToken(42);
389    /// let token_attrs = model.token_attr(token);
390    /// # Ok(())
391    /// # }
392    /// ```
393    #[must_use]
394    pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs {
395        let token_type = unsafe { llama_token_get_attr(self.get_vocab().vocab.as_ref(), id) };
396        LlamaTokenAttrs::try_from(token_type).expect("token type is valid")
397    }
398
399    /// Convert a token to a string with a specified buffer size.
400    ///
401    /// This function allows you to convert a token into a string, with the ability to specify a buffer size for the operation.
402    /// It is generally recommended to use `LlamaModel::token_to_str` instead, as 8 bytes is typically sufficient for most tokens,
403    /// and the extra buffer size doesn't usually matter.
404    ///
405    /// # Errors
406    ///
407    /// - If the token type is unknown, an error will be returned.
408    /// - If the resultant token exceeds the provided `buffer_size`, an error will occur.
409    /// - If the token string returned by `llama-cpp` is not valid UTF-8, it will return an error.
410    ///
411    /// # Panics
412    ///
413    /// - This function will panic if the `buffer_size` does not fit into a `c_int`.
414    /// - It will also panic if the size returned from `llama-cpp` does not fit into a `usize`, which should typically never happen.
415    ///
416    /// # Example
417    ///
418    /// ```no_run
419    /// use llama_cpp_4::model::{LlamaModel, LlamaToken};
420    ///
421    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
422    /// let model = LlamaModel::load_from_file("path/to/model")?;
423    /// let token = LlamaToken(42);
424    /// let token_string = model.token_to_str_with_size(token, 32, Special::Plaintext)?;
425    /// # Ok(())
426    /// # }
427    /// ```
428    pub fn token_to_str_with_size(
429        &self,
430        token: LlamaToken,
431        buffer_size: usize,
432        special: Special,
433    ) -> Result<String, TokenToStringError> {
434        let bytes = self.token_to_bytes_with_size(token, buffer_size, special, None)?;
435        Ok(String::from_utf8(bytes)?)
436    }
437
438    /// Convert a token to bytes with a specified buffer size.
439    ///
440    /// Generally you should use [`LlamaModel::token_to_bytes`] instead as 8 bytes is enough for most words and
441    /// the extra bytes do not really matter.
442    ///
443    /// # Errors
444    ///
445    /// - if the token type is unknown
446    /// - the resultant token is larger than `buffer_size`.
447    ///
448    /// # Panics
449    ///
450    /// - This function will panic if `buffer_size` cannot fit into a `c_int`.
451    /// - It will also panic if the size returned from `llama-cpp` cannot be converted to `usize` (which should not happen).
452    ///
453    /// # Example
454    ///
455    /// ```no_run
456    /// use llama_cpp_4::model::{LlamaModel, LlamaToken};
457    ///
458    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
459    /// let model = LlamaModel::load_from_file("path/to/model")?;
460    /// let token = LlamaToken(42);
461    /// let token_bytes = model.token_to_bytes_with_size(token, 32, Special::Plaintext, None)?;
462    /// # Ok(())
463    /// # }
464    /// ```
465    pub fn token_to_bytes_with_size(
466        &self,
467        token: LlamaToken,
468        buffer_size: usize,
469        special: Special,
470        lstrip: Option<NonZeroU16>,
471    ) -> Result<Vec<u8>, TokenToStringError> {
472        if token == self.token_nl() {
473            return Ok(String::from("\n").into_bytes());
474        }
475
476        // unsure what to do with this in the face of the 'special' arg + attr changes
477        let attrs = self.token_attr(token);
478        if (attrs.contains(LlamaTokenAttr::Control)
479            && (token == self.token_bos() || token == self.token_eos()))
480            || attrs.is_empty()
481            || attrs
482                .intersects(LlamaTokenAttr::Unknown | LlamaTokenAttr::Byte | LlamaTokenAttr::Unused)
483        {
484            return Ok(Vec::new());
485        }
486
487        let special = match special {
488            Special::Tokenize => true,
489            Special::Plaintext => false,
490        };
491
492        let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
493        let len = string.as_bytes().len();
494        let len = c_int::try_from(len).expect("length fits into c_int");
495        let buf = string.into_raw();
496        let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
497        let size = unsafe {
498            llama_token_to_piece(
499                self.get_vocab().vocab.as_ref(),
500                token.0,
501                buf,
502                len,
503                lstrip,
504                special,
505            )
506        };
507
508        match size {
509            0 => Err(TokenToStringError::UnknownTokenType),
510            i if i.is_negative() => Err(TokenToStringError::InsufficientBufferSpace(i)),
511            size => {
512                let string = unsafe { CString::from_raw(buf) };
513                let mut bytes = string.into_bytes();
514                let len = usize::try_from(size).expect("size is positive and fits into usize");
515                bytes.truncate(len);
516                Ok(bytes)
517            }
518        }
519    }
520    /// The number of tokens the model was trained on.
521    ///
522    /// This function returns the number of tokens the model was trained on. It is returned as a `c_int` for maximum
523    /// compatibility with the underlying llama-cpp library, though it can typically be cast to an `i32` without issue.
524    ///
525    /// # Example
526    ///
527    /// ```no_run
528    /// use llama_cpp_4::model::LlamaModel;
529    ///
530    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
531    /// let model = LlamaModel::load_from_file("path/to/model")?;
532    /// let n_vocab = model.n_vocab();
533    /// # Ok(())
534    /// # }
535    /// ```
536    #[must_use]
537    pub fn n_vocab(&self) -> i32 {
538        unsafe { llama_n_vocab(self.get_vocab().vocab.as_ref()) }
539    }
540
541    /// The type of vocab the model was trained on.
542    ///
543    /// This function returns the type of vocabulary used by the model, such as whether it is based on byte-pair encoding (BPE),
544    /// word-level tokens, or another tokenization scheme.
545    ///
546    /// # Panics
547    ///
548    /// - This function will panic if `llama-cpp` emits a vocab type that is not recognized or is invalid for this library.
549    ///
550    /// # Example
551    ///
552    /// ```no_run
553    /// use llama_cpp_4::model::LlamaModel;
554    ///
555    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
556    /// let model = LlamaModel::load_from_file("path/to/model")?;
557    /// let vocab_type = model.vocab_type();
558    /// # Ok(())
559    /// # }
560    /// ```
561    #[must_use]
562    pub fn vocab_type(&self) -> VocabType {
563        let vocab_type = unsafe { llama_vocab_type(self.get_vocab().vocab.as_ref()) };
564        VocabType::try_from(vocab_type).expect("invalid vocab type")
565    }
566
567    /// Returns the number of embedding dimensions for the model.
568    ///
569    /// This function retrieves the number of embeddings (or embedding dimensions) used by the model. It is typically
570    /// used for analyzing model architecture and setting up context parameters or other model configuration aspects.
571    ///
572    /// # Panics
573    ///
574    /// - This function may panic if the underlying `llama-cpp` library returns an invalid embedding dimension value.
575    ///
576    /// # Example
577    ///
578    /// ```no_run
579    /// use llama_cpp_4::model::LlamaModel;
580    ///
581    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
582    /// let model = LlamaModel::load_from_file("path/to/model")?;
583    /// let n_embd = model.n_embd();
584    /// # Ok(())
585    /// # }
586    /// ```
587    #[must_use]
588    pub fn n_embd(&self) -> c_int {
589        unsafe { llama_n_embd(self.model.as_ptr()) }
590    }
591
592    /// Get chat template from model.
593    ///
594    /// # Errors
595    ///
596    /// - If the model does not have a chat template, it will return an error.
597    /// - If the chat template is not a valid `CString`, it will return an error.
598    ///
599    /// # Example
600    ///
601    /// ```no_run
602    /// use llama_cpp_4::model::LlamaModel;
603    ///
604    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
605    /// let model = LlamaModel::load_from_file("path/to/model")?;
606    /// let chat_template = model.get_chat_template(1024)?;
607    /// # Ok(())
608    /// # }
609    /// ```
610    #[allow(clippy::missing_panics_doc)] // We statically know this will not panic as long as the buffer size is sufficient
611    pub fn get_chat_template(&self, buf_size: usize) -> Result<String, ChatTemplateError> {
612        // longest known template is about 1200 bytes from llama.cpp
613        let chat_temp = CString::new(vec![b'*'; buf_size]).expect("no null");
614        let chat_ptr = chat_temp.into_raw();
615        let chat_name = CString::new("tokenizer.chat_template").expect("no null bytes");
616
617        let ret = unsafe {
618            llama_model_meta_val_str(self.model.as_ptr(), chat_name.as_ptr(), chat_ptr, buf_size)
619        };
620
621        if ret < 0 {
622            return Err(ChatTemplateError::MissingTemplate(ret));
623        }
624
625        let template_c = unsafe { CString::from_raw(chat_ptr) };
626        let template = template_c.to_str()?;
627
628        let ret: usize = ret.try_into().unwrap();
629        if template.len() < ret {
630            return Err(ChatTemplateError::BuffSizeError(ret + 1));
631        }
632
633        Ok(template.to_owned())
634    }
635
636    /// Loads a model from a file.
637    ///
638    /// This function loads a model from a specified file path and returns the corresponding `LlamaModel` instance.
639    ///
640    /// # Errors
641    ///
642    /// - If the path cannot be converted to a string or if the model file does not exist, it will return an error.
643    /// - If the model cannot be loaded (e.g., due to an invalid or corrupted model file), it will return a `LlamaModelLoadError`.
644    ///
645    /// # Example
646    ///
647    /// ```no_run
648    /// use llama_cpp_4::model::LlamaModel;
649    /// use std::path::Path;
650    ///
651    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
652    /// let model = LlamaModel::load_from_file("path/to/model", &LlamaModelParams::default())?;
653    /// # Ok(())
654    /// # }
655    /// ```
656    #[tracing::instrument(skip_all, fields(params))]
657    pub fn load_from_file(
658        _: &LlamaBackend,
659        path: impl AsRef<Path>,
660        params: &LlamaModelParams,
661    ) -> Result<Self, LlamaModelLoadError> {
662        let path = path.as_ref();
663        debug_assert!(
664            Path::new(path).exists(),
665            "{} does not exist",
666            path.display()
667        );
668        let path = path
669            .to_str()
670            .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
671
672        let cstr = CString::new(path)?;
673        let llama_model = unsafe { llama_load_model_from_file(cstr.as_ptr(), params.params) };
674
675        let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
676
677        tracing::debug!(?path, "Loaded model");
678        Ok(LlamaModel { model })
679    }
680
681    /// Load a model from multiple split files.
682    ///
683    /// This function loads a model that has been split across multiple files. This is useful for
684    /// very large models that exceed filesystem limitations or need to be distributed across
685    /// multiple storage devices.
686    ///
687    /// # Arguments
688    ///
689    /// * `paths` - A slice of paths to the split model files
690    /// * `params` - The model parameters
691    ///
692    /// # Errors
693    ///
694    /// Returns an error if:
695    /// - Any of the paths cannot be converted to a C string
696    /// - The model fails to load from the splits
697    /// - Any path doesn't exist or isn't accessible
698    ///
699    /// # Example
700    ///
701    /// ```no_run
702    /// use llama_cpp_4::model::{LlamaModel, params::LlamaModelParams};
703    /// use llama_cpp_4::llama_backend::LlamaBackend;
704    ///
705    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
706    /// let backend = LlamaBackend::init()?;
707    /// let params = LlamaModelParams::default();
708    ///
709    /// let paths = vec![
710    ///     "model-00001-of-00003.gguf",
711    ///     "model-00002-of-00003.gguf",
712    ///     "model-00003-of-00003.gguf",
713    /// ];
714    ///
715    /// let model = LlamaModel::load_from_splits(&backend, &paths, &params)?;
716    /// # Ok(())
717    /// # }
718    /// ```
719    #[tracing::instrument(skip_all)]
720    pub fn load_from_splits(
721        _: &LlamaBackend,
722        paths: &[impl AsRef<Path>],
723        params: &LlamaModelParams,
724    ) -> Result<Self, LlamaModelLoadError> {
725        // Convert paths to C strings
726        let c_strings: Vec<CString> = paths
727            .iter()
728            .map(|p| {
729                let path = p.as_ref();
730                debug_assert!(path.exists(), "{} does not exist", path.display());
731                let path_str = path
732                    .to_str()
733                    .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
734                CString::new(path_str).map_err(LlamaModelLoadError::from)
735            })
736            .collect::<Result<Vec<_>, _>>()?;
737
738        // Create array of pointers to C strings
739        let c_ptrs: Vec<*const c_char> = c_strings.iter().map(|s| s.as_ptr()).collect();
740
741        // Load the model from splits
742        let llama_model = unsafe {
743            llama_model_load_from_splits(c_ptrs.as_ptr().cast_mut(), c_ptrs.len(), params.params)
744        };
745
746        let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
747
748        tracing::debug!("Loaded model from {} splits", paths.len());
749        Ok(LlamaModel { model })
750    }
751
752    /// Initializes a lora adapter from a file.
753    ///
754    /// This function initializes a Lora adapter, which is a model extension used to adapt or fine-tune the existing model
755    /// to a specific domain or task. The adapter file is typically in the form of a binary or serialized file that can be applied
756    /// to the model for improved performance on specialized tasks.
757    ///
758    /// # Errors
759    ///
760    /// - If the adapter file path cannot be converted to a string or if the adapter cannot be initialized, it will return an error.
761    ///
762    /// # Example
763    ///
764    /// ```no_run
765    /// use llama_cpp_4::model::{LlamaModel, LlamaLoraAdapter};
766    /// use std::path::Path;
767    ///
768    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
769    /// let model = LlamaModel::load_from_file("path/to/model", &LlamaModelParams::default())?;
770    /// let adapter = model.lora_adapter_init("path/to/lora/adapter")?;
771    /// # Ok(())
772    /// # }
773    /// ```
774    pub fn lora_adapter_init(
775        &self,
776        path: impl AsRef<Path>,
777    ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
778        let path = path.as_ref();
779        debug_assert!(
780            Path::new(path).exists(),
781            "{} does not exist",
782            path.display()
783        );
784
785        let path = path
786            .to_str()
787            .ok_or(LlamaLoraAdapterInitError::PathToStrError(
788                path.to_path_buf(),
789            ))?;
790
791        let cstr = CString::new(path)?;
792        let adapter = unsafe { llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr()) };
793
794        let adapter = NonNull::new(adapter).ok_or(LlamaLoraAdapterInitError::NullResult)?;
795
796        tracing::debug!(?path, "Initialized lora adapter");
797        Ok(LlamaLoraAdapter {
798            lora_adapter: adapter,
799        })
800    }
801
802    /// Create a new context from this model.
803    ///
804    /// This function creates a new context for the model, which is used to manage and perform computations for inference,
805    /// including token generation, embeddings, and other tasks that the model can perform. The context allows fine-grained
806    /// control over model parameters for a specific task.
807    ///
808    /// # Errors
809    ///
810    /// - There are various potential failures such as invalid parameters or a failure to allocate the context. See [`LlamaContextLoadError`]
811    ///   for more detailed error descriptions.
812    ///
813    /// # Example
814    ///
815    /// ```no_run
816    /// use llama_cpp_4::model::{LlamaModel, LlamaContext};
817    /// use llama_cpp_4::LlamaContextParams;
818    ///
819    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
820    /// let model = LlamaModel::load_from_file("path/to/model", &LlamaModelParams::default())?;
821    /// let context = model.new_context(&LlamaBackend::init()?, LlamaContextParams::default())?;
822    /// # Ok(())
823    /// # }
824    /// ```
825    #[allow(clippy::needless_pass_by_value)]
826    pub fn new_context(
827        &self,
828        _: &LlamaBackend,
829        params: LlamaContextParams,
830    ) -> Result<LlamaContext<'_>, LlamaContextLoadError> {
831        let context_params = params.context_params;
832        let context = unsafe { llama_new_context_with_model(self.model.as_ptr(), context_params) };
833        let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;
834
835        Ok(LlamaContext::new(self, context, params.embeddings()))
836    }
837
838    /// Apply the model's chat template to a sequence of messages.
839    ///
840    /// This function applies the model's chat template to the provided chat messages, formatting them accordingly. The chat
841    /// template determines the structure or style of conversation between the system and user, such as token formatting,
842    /// role separation, and more. The template can be customized by providing an optional template string, or if `None`
843    /// is provided, the default template used by `llama.cpp` will be applied.
844    ///
845    /// For more information on supported templates, visit:
846    /// <https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template>
847    ///
848    /// # Arguments
849    ///
850    /// - `tmpl`: An optional custom template string. If `None`, the default template will be used.
851    /// - `chat`: A vector of `LlamaChatMessage` instances, which represent the conversation between the system and user.
852    /// - `add_ass`: A boolean flag indicating whether additional system-specific instructions (like "assistant") should be included.
853    ///
854    /// # Errors
855    ///
856    /// There are several possible points of failure when applying the chat template:
857    /// - Insufficient buffer size to hold the formatted chat (this will return `ApplyChatTemplateError::BuffSizeError`).
858    /// - If the template or messages cannot be processed properly, various errors from `ApplyChatTemplateError` may occur.
859    ///
860    /// # Example
861    ///
862    /// ```no_run
863    /// use llama_cpp_4::model::{LlamaModel, LlamaChatMessage};
864    ///
865    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
866    /// let model = LlamaModel::load_from_file("path/to/model", &LlamaModelParams::default())?;
867    /// let chat = vec![
868    ///     LlamaChatMessage::new("user", "Hello!"),
869    ///     LlamaChatMessage::new("assistant", "Hi! How can I assist you today?"),
870    /// ];
871    /// let formatted_chat = model.apply_chat_template(None, chat, true)?;
872    /// # Ok(())
873    /// # }
874    /// ```
875    ///
876    /// # Notes
877    ///
878    /// The provided buffer is twice the length of the messages by default, which is recommended by the `llama.cpp` documentation.
879    /// # Panics
880    ///
881    /// Panics if the buffer length exceeds `i32::MAX`.
882    #[tracing::instrument(skip_all)]
883    pub fn apply_chat_template(
884        &self,
885        tmpl: Option<&str>,
886        chat: &[LlamaChatMessage],
887        add_ass: bool,
888    ) -> Result<String, ApplyChatTemplateError> {
889        // Compute raw message byte total from the original LlamaChatMessage vec
890        // *before* we shadow `chat` with the sys-type vec below.
891        let message_length = chat.iter().fold(0usize, |acc, c| {
892            acc + c.role.to_bytes().len() + c.content.to_bytes().len()
893        });
894
895        // Build our llama_cpp_sys chat messages (raw pointers into CStrings).
896        let chat_sys: Vec<llama_chat_message> = chat
897            .iter()
898            .map(|c| llama_chat_message {
899                role: c.role.as_ptr(),
900                content: c.content.as_ptr(),
901            })
902            .collect();
903
904        // Set the tmpl pointer.
905        let tmpl_cstring = tmpl.map(CString::new).transpose()?;
906        let tmpl_ptr = tmpl_cstring
907            .as_ref()
908            .map_or(std::ptr::null(), |s| s.as_ptr());
909
910        // `message_length * 4` is far too small for models whose built-in chat
911        // template adds a long default system prompt (e.g. Qwen3.5 prepends
912        // ~80+ chars of markup even for a one-word user message).  Start with
913        // at least 4 KiB so short inputs like "hi" always have room.
914        //
915        // `llama_chat_apply_template` returns the number of bytes it *actually*
916        // needed when the buffer was too small, so we retry exactly once with
917        // that precise size rather than giving up immediately.
918        let mut buf_size = message_length.saturating_mul(4).max(4096);
919
920        for _ in 0..2 {
921            // Use u8 so that as_mut_ptr()/as_ptr() match the binding (*mut u8 / *const u8).
922            let mut buff = vec![0u8; buf_size];
923            let res = unsafe {
924                llama_chat_apply_template(
925                    tmpl_ptr,
926                    chat_sys.as_ptr(),
927                    chat_sys.len(),
928                    add_ass,
929                    buff.as_mut_ptr().cast(),
930                    i32::try_from(buff.len()).expect("buffer length fits in i32"),
931                )
932            };
933
934            if res < 0 {
935                return Err(ApplyChatTemplateError::BuffSizeError);
936            }
937
938            #[allow(clippy::cast_sign_loss)]
939            let needed = res as usize;
940            if needed > buf_size {
941                // Buffer was too small — retry with the exact size llama.cpp reported.
942                buf_size = needed + 1; // +1 for null terminator
943                continue;
944            }
945
946            // SAFETY: llama_chat_apply_template wrote a NUL-terminated string
947            // into `buff`; `needed` bytes were used.
948            let formatted = unsafe {
949                CStr::from_ptr(buff.as_ptr().cast())
950                    .to_string_lossy()
951                    .into_owned()
952            };
953            return Ok(formatted);
954        }
955
956        Err(ApplyChatTemplateError::BuffSizeError)
957    }
958
959    /// Build a split GGUF file path for a specific chunk.
960    ///
961    /// This utility function creates the standardized filename for a split model chunk
962    /// following the pattern: `{prefix}-{split_no:05d}-of-{split_count:05d}.gguf`
963    ///
964    /// # Arguments
965    ///
966    /// * `path_prefix` - The base path and filename prefix
967    /// * `split_no` - The split number (1-indexed)
968    /// * `split_count` - The total number of splits
969    ///
970    /// # Returns
971    ///
972    /// Returns the formatted split path as a String
973    ///
974    /// # Example
975    ///
976    /// ```
977    /// use llama_cpp_4::model::LlamaModel;
978    ///
979    /// let path = LlamaModel::split_path("/models/llama", 2, 4);
980    /// assert_eq!(path, "/models/llama-00002-of-00004.gguf");
981    /// ```
982    ///
983    /// # Panics
984    ///
985    /// Panics if the path prefix contains a null byte.
986    #[must_use]
987    pub fn split_path(path_prefix: &str, split_no: i32, split_count: i32) -> String {
988        let mut buffer = vec![0u8; 1024];
989        let len = unsafe {
990            llama_split_path(
991                buffer.as_mut_ptr().cast::<c_char>(),
992                buffer.len(),
993                CString::new(path_prefix).unwrap().as_ptr(),
994                split_no,
995                split_count,
996            )
997        };
998
999        let len = usize::try_from(len).expect("split_path length fits in usize");
1000        buffer.truncate(len);
1001        String::from_utf8(buffer).unwrap_or_default()
1002    }
1003
1004    /// Extract the path prefix from a split filename.
1005    ///
1006    /// This function extracts the base path prefix from a split model filename,
1007    /// but only if the `split_no` and `split_count` match the pattern in the filename.
1008    ///
1009    /// # Arguments
1010    ///
1011    /// * `split_path` - The full path to the split file
1012    /// * `split_no` - The expected split number
1013    /// * `split_count` - The expected total number of splits
1014    ///
1015    /// # Returns
1016    ///
1017    /// Returns the path prefix if the pattern matches, or None if it doesn't
1018    ///
1019    /// # Example
1020    ///
1021    /// ```
1022    /// use llama_cpp_4::model::LlamaModel;
1023    ///
1024    /// let prefix = LlamaModel::split_prefix("/models/llama-00002-of-00004.gguf", 2, 4);
1025    /// assert_eq!(prefix, Some("/models/llama".to_string()));
1026    /// ```
1027    ///
1028    /// # Panics
1029    ///
1030    /// Panics if the split path contains a null byte.
1031    #[must_use]
1032    pub fn split_prefix(split_path: &str, split_no: i32, split_count: i32) -> Option<String> {
1033        let mut buffer = vec![0u8; 1024];
1034        let len = unsafe {
1035            llama_split_prefix(
1036                buffer.as_mut_ptr().cast::<c_char>(),
1037                buffer.len(),
1038                CString::new(split_path).unwrap().as_ptr(),
1039                split_no,
1040                split_count,
1041            )
1042        };
1043
1044        if len > 0 {
1045            let len = usize::try_from(len).expect("split_prefix length fits in usize");
1046            buffer.truncate(len);
1047            String::from_utf8(buffer).ok()
1048        } else {
1049            None
1050        }
1051    }
1052}
1053
1054impl Drop for LlamaModel {
1055    fn drop(&mut self) {
1056        unsafe { llama_free_model(self.model.as_ptr()) }
1057    }
1058}
1059
1060/// Defines the possible types of vocabulary used by the model.
1061///
1062/// The model may use different types of vocabulary depending on the tokenization method chosen during training.
1063/// This enum represents these types, specifically `BPE` (Byte Pair Encoding) and `SPM` (`SentencePiece`).
1064///
1065/// # Variants
1066///
1067/// - `BPE`: Byte Pair Encoding, a common tokenization method used in NLP tasks.
1068/// - `SPM`: `SentencePiece`, another popular tokenization method for NLP models.
1069///
1070/// # Example
1071///
1072/// ```no_run
1073/// use llama_cpp_4::model::VocabType;
1074///
1075/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1076/// let vocab_type = VocabType::BPE;
1077/// match vocab_type {
1078///     VocabType::BPE => println!("The model uses Byte Pair Encoding (BPE)"),
1079///     VocabType::SPM => println!("The model uses SentencePiece (SPM)"),
1080/// }
1081/// # Ok(())
1082/// # }
1083/// ```
1084#[repr(u32)]
1085#[derive(Debug, Eq, Copy, Clone, PartialEq)]
1086pub enum VocabType {
1087    /// Byte Pair Encoding
1088    BPE = LLAMA_VOCAB_TYPE_BPE as _,
1089    /// Sentence Piece Tokenizer
1090    SPM = LLAMA_VOCAB_TYPE_SPM as _,
1091}
1092
1093/// Error that occurs when trying to convert a `llama_vocab_type` to a `VocabType`.
1094///
1095/// This error is raised when the integer value returned by the system does not correspond to a known vocabulary type.
1096///
1097/// # Variants
1098///
1099/// - `UnknownValue`: The error is raised when the value is not a valid `llama_vocab_type`. The invalid value is returned with the error.
1100///
1101/// # Example
1102///
1103/// ```no_run
1104/// use llama_cpp_4::model::LlamaTokenTypeFromIntError;
1105///
1106/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1107/// let invalid_value = 999; // Not a valid vocabulary type
1108/// let error = LlamaTokenTypeFromIntError::UnknownValue(invalid_value);
1109/// println!("Error: {}", error);
1110/// # Ok(())
1111/// # }
1112/// ```
1113#[derive(thiserror::Error, Debug, Eq, PartialEq)]
1114pub enum LlamaTokenTypeFromIntError {
1115    /// The value is not a valid `llama_token_type`. Contains the int value that was invalid.
1116    #[error("Unknown Value {0}")]
1117    UnknownValue(llama_vocab_type),
1118}
1119
1120impl TryFrom<llama_vocab_type> for VocabType {
1121    type Error = LlamaTokenTypeFromIntError;
1122
1123    fn try_from(value: llama_vocab_type) -> Result<Self, Self::Error> {
1124        match value {
1125            LLAMA_VOCAB_TYPE_BPE => Ok(VocabType::BPE),
1126            LLAMA_VOCAB_TYPE_SPM => Ok(VocabType::SPM),
1127            unknown => Err(LlamaTokenTypeFromIntError::UnknownValue(unknown)),
1128        }
1129    }
1130}