Skip to main content

llama_cpp_2/
model.rs

1//! A safe wrapper around `llama_model`.
2use std::ffi::{c_char, CStr, CString};
3use std::num::NonZeroU16;
4use std::os::raw::c_int;
5use std::path::Path;
6use std::ptr::{self, NonNull};
7use std::slice;
8use std::str::Utf8Error;
9
10use crate::context::params::LlamaContextParams;
11use crate::context::LlamaContext;
12use crate::llama_backend::LlamaBackend;
13use crate::model::params::LlamaModelParams;
14use crate::token::LlamaToken;
15use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
16use crate::{
17    ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
18    LlamaModelLoadError, MetaValError, NewLlamaChatMessageError, StringToTokenError,
19    TokenToStringError,
20};
21
22pub mod params;
23
24/// A safe wrapper around `llama_model`.
25#[derive(Debug)]
26#[repr(transparent)]
27#[allow(clippy::module_name_repetitions)]
28pub struct LlamaModel {
29    pub(crate) model: NonNull<llama_cpp_sys_2::llama_model>,
30}
31
32/// A safe wrapper around `llama_lora_adapter`.
33#[derive(Debug)]
34#[repr(transparent)]
35#[allow(clippy::module_name_repetitions)]
36pub struct LlamaLoraAdapter {
37    pub(crate) lora_adapter: NonNull<llama_cpp_sys_2::llama_adapter_lora>,
38}
39
/// A chat template ready to be handed to [`LlamaModel::apply_chat_template`].
///
/// Obtain one from [`LlamaModel::chat_template`] (the template shipped with the model) or build one
/// with [`LlamaChatTemplate::new`]. The template is kept as a `CString`, so it can be passed across
/// the FFI boundary without being re-encoded on every call.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct LlamaChatTemplate(CString);

impl LlamaChatTemplate {
    /// Builds a template from `template`. This is either the name of a built-in llama.cpp
    /// [chat template](https://github.com/ggerganov/llama.cpp/blob/8a8c4ceb6050bd9392609114ca56ae6d26f5b8f5/src/llama-chat.cpp#L27-L61),
    /// such as "chatml" or "llama3", or a full Jinja template for llama.cpp to interpret.
    ///
    /// # Errors
    ///
    /// Returns a [`std::ffi::NulError`] if `template` contains an interior NUL byte.
    pub fn new(template: &str) -> Result<Self, std::ffi::NulError> {
        CString::new(template).map(Self)
    }

    /// Borrows the template as a C string.
    pub fn as_c_str(&self) -> &CStr {
        self.0.as_c_str()
    }

    /// Borrows the template as UTF-8 text.
    ///
    /// # Errors
    ///
    /// Returns a [`Utf8Error`] if the template bytes are not valid UTF-8.
    pub fn to_str(&self) -> Result<&str, Utf8Error> {
        self.0.to_str()
    }

    /// Copies the template into a freshly allocated `String`.
    ///
    /// # Errors
    ///
    /// Returns a [`Utf8Error`] if the template bytes are not valid UTF-8.
    pub fn to_string(&self) -> Result<String, Utf8Error> {
        let text = self.to_str()?;
        Ok(text.to_owned())
    }
}

impl std::fmt::Debug for LlamaChatTemplate {
    /// Formats exactly like the inner `CString` (a quoted, escaped string).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Debug::fmt(&self.0, f)
    }
}
75
76/// A Safe wrapper around `llama_chat_message`
77#[derive(Debug, Eq, PartialEq, Clone)]
78pub struct LlamaChatMessage {
79    role: CString,
80    content: CString,
81}
82
83impl LlamaChatMessage {
84    /// Create a new `LlamaChatMessage`
85    ///
86    /// # Errors
87    /// If either of ``role`` or ``content`` contain null bytes.
88    pub fn new(role: String, content: String) -> Result<Self, NewLlamaChatMessageError> {
89        Ok(Self {
90            role: CString::new(role)?,
91            content: CString::new(content)?,
92        })
93    }
94}
95
/// The RoPE (rotary position embedding) variant used by a model, as reported by
/// [`LlamaModel::rope_type`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RopeType {
    /// Corresponds to llama.cpp's `LLAMA_ROPE_TYPE_NORM`.
    Norm,
    /// Corresponds to llama.cpp's `LLAMA_ROPE_TYPE_NEOX`.
    NeoX,
    /// Corresponds to llama.cpp's `LLAMA_ROPE_TYPE_MROPE`.
    MRope,
    /// Corresponds to llama.cpp's `LLAMA_ROPE_TYPE_VISION`.
    Vision,
}
104
/// Whether tokenization should prepend the beginning-of-stream (BOS) token.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AddBos {
    /// Prepend the beginning-of-stream token to the tokenized string.
    Always,
    /// Leave the tokenized string without a leading beginning-of-stream token.
    Never,
}
113
/// Whether special/control tokens are tokenized as such or treated as plain text.
#[deprecated(
    since = "0.1.0",
    note = "This enum is a mixture of options for llama cpp providing less flexibility it only used with deprecated methods and will be removed in the future."
)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Special {
    /// Tokenize special and/or control tokens that would otherwise be hidden and treated as plain
    /// text. No leading space is inserted.
    Tokenize,
    /// Treat special and/or control tokens as plain text.
    Plaintext,
}
126
// SAFETY: NOTE(review): these impls allow a `LlamaModel` to be moved to, and shared between,
// threads. That relies on llama.cpp tolerating use of one `llama_model` from several threads
// through the `&self` accessors in this file. Confirm this against llama.cpp's threading guarantees.
unsafe impl Send for LlamaModel {}

// SAFETY: see the note on the `Send` impl directly above.
unsafe impl Sync for LlamaModel {}
130
131impl LlamaModel {
132    pub(crate) fn vocab_ptr(&self) -> *const llama_cpp_sys_2::llama_vocab {
133        unsafe { llama_cpp_sys_2::llama_model_get_vocab(self.model.as_ptr()) }
134    }
135
136    /// get the number of tokens the model was trained on
137    ///
138    /// # Panics
139    ///
140    /// If the number of tokens the model was trained on does not fit into an `u32`. This should be impossible on most
141    /// platforms due to llama.cpp returning a `c_int` (i32 on most platforms) which is almost certainly positive.
142    #[must_use]
143    pub fn n_ctx_train(&self) -> u32 {
144        let n_ctx_train = unsafe { llama_cpp_sys_2::llama_n_ctx_train(self.model.as_ptr()) };
145        u32::try_from(n_ctx_train).expect("n_ctx_train fits into an u32")
146    }
147
148    /// Get all tokens in the model.
149    pub fn tokens(
150        &self,
151        decode_special: bool,
152    ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
153        (0..self.n_vocab())
154            .map(LlamaToken::new)
155            .map(move |llama_token| {
156                let mut decoder = encoding_rs::UTF_8.new_decoder();
157                (
158                    llama_token,
159                    self.token_to_piece(llama_token, &mut decoder, decode_special, None),
160                )
161            })
162    }
163
164    /// Get the beginning of stream token.
165    #[must_use]
166    pub fn token_bos(&self) -> LlamaToken {
167        let token = unsafe { llama_cpp_sys_2::llama_token_bos(self.vocab_ptr()) };
168        LlamaToken(token)
169    }
170
171    /// Get the end of stream token.
172    #[must_use]
173    pub fn token_eos(&self) -> LlamaToken {
174        let token = unsafe { llama_cpp_sys_2::llama_token_eos(self.vocab_ptr()) };
175        LlamaToken(token)
176    }
177
178    /// Get the newline token.
179    #[must_use]
180    pub fn token_nl(&self) -> LlamaToken {
181        let token = unsafe { llama_cpp_sys_2::llama_token_nl(self.vocab_ptr()) };
182        LlamaToken(token)
183    }
184
185    /// Check if a token represents the end of generation (end of turn, end of sequence, etc.)
186    #[must_use]
187    pub fn is_eog_token(&self, token: LlamaToken) -> bool {
188        unsafe { llama_cpp_sys_2::llama_token_is_eog(self.vocab_ptr(), token.0) }
189    }
190
191    /// Get the decoder start token.
192    #[must_use]
193    pub fn decode_start_token(&self) -> LlamaToken {
194        let token =
195            unsafe { llama_cpp_sys_2::llama_model_decoder_start_token(self.model.as_ptr()) };
196        LlamaToken(token)
197    }
198
199    /// Get the separator token (SEP).
200    #[must_use]
201    pub fn token_sep(&self) -> LlamaToken {
202        let token = unsafe { llama_cpp_sys_2::llama_vocab_sep(self.vocab_ptr()) };
203        LlamaToken(token)
204    }
205
206    /// Convert single token to a string.
207    ///
208    /// # Errors
209    ///
210    /// See [`TokenToStringError`] for more information.
211    #[deprecated(since = "0.1.0", note = "Use `token_to_piece` instead")]
212    pub fn token_to_str(
213        &self,
214        token: LlamaToken,
215        special: Special,
216    ) -> Result<String, TokenToStringError> {
217        // TODO lsptrip None is acutally not quite the origignal behavior of this function,
218        let mut decoder = encoding_rs::UTF_8.new_decoder();
219        self.token_to_piece(
220            token,
221            &mut decoder,
222            matches!(special, Special::Tokenize),
223            None,
224        )
225    }
226
227    /// Convert single token to bytes.
228    ///
229    /// # Errors
230    /// See [`TokenToStringError`] for more information.
231    ///
232    /// # Panics
233    /// If a [`TokenToStringError::InsufficientBufferSpace`] error returned by
234    /// [`Self::token_to_bytes_with_size`] contains a positive nonzero value. This should never
235    /// happen.
236    #[deprecated(since = "0.1.0", note = "Use `token_to_piece_bytes` instead")]
237    pub fn token_to_bytes(
238        &self,
239        token: LlamaToken,
240        special: Special,
241    ) -> Result<Vec<u8>, TokenToStringError> {
242        // TODO lsptrip None is acutally not quite the origignal behavior of this function,
243        match self.token_to_piece_bytes(token, 8, matches!(special, Special::Tokenize), None) {
244            Err(TokenToStringError::InsufficientBufferSpace(i)) => self.token_to_piece_bytes(
245                token,
246                (-i).try_into().expect("Error buffer size is positive"),
247                matches!(special, Special::Tokenize),
248                None,
249            ),
250            x => x,
251        }
252    }
253
254    /// Convert a vector of tokens to a single string.
255    ///
256    /// # Errors
257    ///
258    /// See [`TokenToStringError`] for more information.
259    #[deprecated(
260        since = "0.1.0",
261        note = "Use `token_to_piece` for each token individually instead"
262    )]
263    pub fn tokens_to_str(
264        &self,
265        tokens: &[LlamaToken],
266        special: Special,
267    ) -> Result<String, TokenToStringError> {
268        let mut builder: Vec<u8> = Vec::with_capacity(tokens.len() * 4);
269        for piece in tokens
270            .iter()
271            .copied()
272            .map(|t| self.token_to_piece_bytes(t, 8, matches!(special, Special::Tokenize), None))
273        {
274            builder.extend_from_slice(&piece?);
275        }
276        Ok(String::from_utf8(builder)?)
277    }
278
279    /// Convert a string to a Vector of tokens.
280    ///
281    /// # Errors
282    ///
283    /// - if [`str`] contains a null byte.
284    ///
285    /// # Panics
286    ///
287    /// - if there is more than [`usize::MAX`] [`LlamaToken`]s in [`str`].
288    ///
289    ///
290    /// ```no_run
291    /// use llama_cpp_2::model::LlamaModel;
292    ///
293    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
294    /// use std::path::Path;
295    /// use llama_cpp_2::model::AddBos;
296    /// let backend = llama_cpp_2::llama_backend::LlamaBackend::init()?;
297    /// let model = LlamaModel::load_from_file(&backend, Path::new("path/to/model"), &Default::default())?;
298    /// let tokens = model.str_to_token("Hello, World!", AddBos::Always)?;
299    /// # Ok(())
300    /// # }
301    pub fn str_to_token(
302        &self,
303        str: &str,
304        add_bos: AddBos,
305    ) -> Result<Vec<LlamaToken>, StringToTokenError> {
306        let add_bos = match add_bos {
307            AddBos::Always => true,
308            AddBos::Never => false,
309        };
310
311        let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
312        let mut buffer: Vec<LlamaToken> = Vec::with_capacity(tokens_estimation);
313
314        let c_string = CString::new(str)?;
315        let buffer_capacity =
316            c_int::try_from(buffer.capacity()).expect("buffer capacity should fit into a c_int");
317
318        let size = unsafe {
319            llama_cpp_sys_2::llama_tokenize(
320                self.vocab_ptr(),
321                c_string.as_ptr(),
322                c_int::try_from(c_string.as_bytes().len())?,
323                buffer.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>(),
324                buffer_capacity,
325                add_bos,
326                true,
327            )
328        };
329
330        // if we fail the first time we can resize the vector to the correct size and try again. This should never fail.
331        // as a result - size is guaranteed to be positive here.
332        let size = if size.is_negative() {
333            buffer.reserve_exact(usize::try_from(-size).expect("usize's are larger "));
334            unsafe {
335                llama_cpp_sys_2::llama_tokenize(
336                    self.vocab_ptr(),
337                    c_string.as_ptr(),
338                    c_int::try_from(c_string.as_bytes().len())?,
339                    buffer.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>(),
340                    -size,
341                    add_bos,
342                    true,
343                )
344            }
345        } else {
346            size
347        };
348
349        let size = usize::try_from(size).expect("size is positive and usize ");
350
351        // Safety: `size` < `capacity` and llama-cpp has initialized elements up to `size`
352        unsafe { buffer.set_len(size) }
353        Ok(buffer)
354    }
355
356    /// Get the type of a token.
357    ///
358    /// # Panics
359    ///
360    /// If the token type is not known to this library.
361    #[must_use]
362    pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs {
363        let token_type = unsafe { llama_cpp_sys_2::llama_token_get_attr(self.vocab_ptr(), id) };
364        LlamaTokenAttrs::try_from(token_type).expect("token type is valid")
365    }
366
367    /// Convert a token to a string using the underlying llama.cpp `llama_token_to_piece` function.
368    ///
369    /// This is the new default function for token decoding and provides direct access to
370    /// the llama.cpp token decoding functionality without any special logic or filtering.
371    ///
372    /// Decoding raw string requires using an decoder, tokens from language models may not always map
373    /// to full characters depending on the encoding so stateful decoding is required, otherwise partial strings may be lost!
374    /// Invalid characters are mapped to REPLACEMENT CHARACTER making the method safe to use even if the model inherently produces
375    /// garbage.
376    ///
377    /// # Errors
378    ///
379    /// - if the token type is unknown
380    ///
381    /// # Panics
382    ///
383    /// - if the returned size from llama-cpp does not fit into a [`usize`]. (this should never happen)
384    pub fn token_to_piece(
385        &self,
386        token: LlamaToken,
387        decoder: &mut encoding_rs::Decoder,
388        special: bool,
389        lstrip: Option<NonZeroU16>,
390    ) -> Result<String, TokenToStringError> {
391        let bytes = match self.token_to_piece_bytes(token, 8, special, lstrip) {
392            // when there is insufficient space `token_to_piece` will return a negative number with the size that would have been returned
393            // https://github.com/abetlen/llama-cpp-python/blob/c37132bac860fcc333255c36313f89c4f49d4c8d/llama_cpp/llama_cpp.py#L3461
394            Err(TokenToStringError::InsufficientBufferSpace(i)) => self.token_to_piece_bytes(
395                token,
396                (-i).try_into().expect("Error buffer size is positive"),
397                special,
398                lstrip,
399            ),
400            x => x,
401        }?;
402        // here the assumption is that each byte from the output may map to at most one output charakter
403        let mut output_piece = String::with_capacity(bytes.len());
404        // _result only tells if there is nothing more in the input, or if the output was full
405        // but further decoding will happen on the next interation anyway
406        let (_result, _somesize, _truthy) =
407            decoder.decode_to_string(&bytes, &mut output_piece, false);
408        Ok(output_piece)
409    }
410
411    /// Raw token decoding to bytes, use if you want to handle the decoding model output yourself
412    ///
413    /// Convert a token to bytes using the underlying llama.cpp `llama_token_to_piece` function. This is mostly
414    /// a thin wrapper around `llama_token_to_piece` function, that handles rust <-> c type conversions while
415    /// letting the caller handle errors. For a safer inteface returing rust strings directly use `token_to_piece` instead!
416    ///
417    /// # Errors
418    ///
419    /// - if the token type is unknown
420    /// - the resultant token is larger than `buffer_size`.
421    ///
422    /// # Panics
423    ///
424    /// - if `buffer_size` does not fit into a [`c_int`].
425    /// - if the returned size from llama-cpp does not fit into a [`usize`]. (this should never happen)
426    pub fn token_to_piece_bytes(
427        &self,
428        token: LlamaToken,
429        buffer_size: usize,
430        special: bool,
431        lstrip: Option<NonZeroU16>,
432    ) -> Result<Vec<u8>, TokenToStringError> {
433        let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
434        let len = string.as_bytes().len();
435        let len = c_int::try_from(len).expect("length fits into c_int");
436        let buf = string.into_raw();
437        let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
438        let size = unsafe {
439            llama_cpp_sys_2::llama_token_to_piece(
440                self.vocab_ptr(),
441                token.0,
442                buf,
443                len,
444                lstrip,
445                special,
446            )
447        };
448
449        match size {
450            0 => Err(TokenToStringError::UnknownTokenType),
451            i if i.is_negative() => Err(TokenToStringError::InsufficientBufferSpace(i)),
452            size => {
453                let string = unsafe { CString::from_raw(buf) };
454                let mut bytes = string.into_bytes();
455                let len = usize::try_from(size).expect("size is positive and fits into usize");
456                bytes.truncate(len);
457                Ok(bytes)
458            }
459        }
460    }
461
462    /// Convert a token to a string with a specified buffer size.
463    ///
464    /// Generally you should use [`LlamaModel::token_to_str`] as it is able to decode tokens with
465    /// any length.
466    ///
467    /// # Errors
468    ///
469    /// - if the token type is unknown
470    /// - the resultant token is larger than `buffer_size`.
471    /// - the string returend by llama-cpp is not valid utf8.
472    ///
473    /// # Panics
474    ///
475    /// - if `buffer_size` does not fit into a [`c_int`].
476    /// - if the returned size from llama-cpp does not fit into a [`usize`]. (this should never happen)
477    #[deprecated(since = "0.1.0", note = "Use `token_to_piece` instead")]
478    pub fn token_to_str_with_size(
479        &self,
480        token: LlamaToken,
481        buffer_size: usize,
482        special: Special,
483    ) -> Result<String, TokenToStringError> {
484        let bytes = self.token_to_piece_bytes(
485            token,
486            buffer_size,
487            matches!(special, Special::Tokenize),
488            None,
489        )?;
490        Ok(String::from_utf8(bytes)?)
491    }
492
493    /// Convert a token to bytes with a specified buffer size.
494    ///
495    /// Generally you should use [`LlamaModel::token_to_bytes`] as it is able to handle tokens of
496    /// any length.
497    ///
498    /// # Errors
499    ///
500    /// - if the token type is unknown
501    /// - the resultant token is larger than `buffer_size`.
502    ///
503    /// # Panics
504    ///
505    /// - if `buffer_size` does not fit into a [`c_int`].
506    /// - if the returned size from llama-cpp does not fit into a [`usize`]. (this should never happen)
507    #[deprecated(since = "0.1.0", note = "Use `token_to_piece_bytes` instead")]
508    pub fn token_to_bytes_with_size(
509        &self,
510        token: LlamaToken,
511        buffer_size: usize,
512        special: Special,
513        lstrip: Option<NonZeroU16>,
514    ) -> Result<Vec<u8>, TokenToStringError> {
515        if token == self.token_nl() {
516            return Ok(b"\n".to_vec());
517        }
518
519        // unsure what to do with this in the face of the 'special' arg + attr changes
520        let attrs = self.token_attr(token);
521        if attrs.is_empty()
522            || attrs
523                .intersects(LlamaTokenAttr::Unknown | LlamaTokenAttr::Byte | LlamaTokenAttr::Unused)
524            || attrs.contains(LlamaTokenAttr::Control)
525                && (token == self.token_bos() || token == self.token_eos())
526        {
527            return Ok(Vec::new());
528        }
529
530        let special = match special {
531            Special::Tokenize => true,
532            Special::Plaintext => false,
533        };
534
535        let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
536        let len = string.as_bytes().len();
537        let len = c_int::try_from(len).expect("length fits into c_int");
538        let buf = string.into_raw();
539        let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
540        let size = unsafe {
541            llama_cpp_sys_2::llama_token_to_piece(
542                self.vocab_ptr(),
543                token.0,
544                buf,
545                len,
546                lstrip,
547                special,
548            )
549        };
550
551        match size {
552            0 => Err(TokenToStringError::UnknownTokenType),
553            i if i.is_negative() => Err(TokenToStringError::InsufficientBufferSpace(i)),
554            size => {
555                let string = unsafe { CString::from_raw(buf) };
556                let mut bytes = string.into_bytes();
557                let len = usize::try_from(size).expect("size is positive and fits into usize");
558                bytes.truncate(len);
559                Ok(bytes)
560            }
561        }
562    }
563    /// The number of tokens the model was trained on.
564    ///
565    /// This returns a `c_int` for maximum compatibility. Most of the time it can be cast to an i32
566    /// without issue.
567    #[must_use]
568    pub fn n_vocab(&self) -> i32 {
569        unsafe { llama_cpp_sys_2::llama_n_vocab(self.vocab_ptr()) }
570    }
571
572    /// The type of vocab the model was trained on.
573    ///
574    /// # Panics
575    ///
576    /// If llama-cpp emits a vocab type that is not known to this library.
577    #[must_use]
578    pub fn vocab_type(&self) -> VocabType {
579        // llama_cpp_sys_2::llama_model_get_vocab
580        let vocab_type = unsafe { llama_cpp_sys_2::llama_vocab_type(self.vocab_ptr()) };
581        VocabType::try_from(vocab_type).expect("invalid vocab type")
582    }
583
584    /// This returns a `c_int` for maximum compatibility. Most of the time it can be cast to an i32
585    /// without issue.
586    #[must_use]
587    pub fn n_embd(&self) -> c_int {
588        unsafe { llama_cpp_sys_2::llama_n_embd(self.model.as_ptr()) }
589    }
590
591    /// Returns the total size of all the tensors in the model in bytes.
592    pub fn size(&self) -> u64 {
593        unsafe { llama_cpp_sys_2::llama_model_size(self.model.as_ptr()) }
594    }
595
596    /// Returns the number of parameters in the model.
597    pub fn n_params(&self) -> u64 {
598        unsafe { llama_cpp_sys_2::llama_model_n_params(self.model.as_ptr()) }
599    }
600
601    /// Returns whether the model is a recurrent network (Mamba, RWKV, etc)
602    pub fn is_recurrent(&self) -> bool {
603        unsafe { llama_cpp_sys_2::llama_model_is_recurrent(self.model.as_ptr()) }
604    }
605
606    /// Returns whether the model is a hybrid network (Jamba, Granite, Qwen3xx, etc)
607    ///
608    /// Hybrid models have both attention layers and recurrent/SSM layers.
609    /// They require special handling for state checkpointing.
610    pub fn is_hybrid(&self) -> bool {
611        unsafe { llama_cpp_sys_2::llama_model_is_hybrid(self.model.as_ptr()) }
612    }
613
614    /// Returns the number of layers within the model.
615    pub fn n_layer(&self) -> u32 {
616        // It's never possible for this to panic because while the API interface is defined as an int32_t,
617        // the field it's accessing is a uint32_t.
618        u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_layer(self.model.as_ptr()) }).unwrap()
619    }
620
621    /// Returns the number of attention heads within the model.
622    pub fn n_head(&self) -> u32 {
623        // It's never possible for this to panic because while the API interface is defined as an int32_t,
624        // the field it's accessing is a uint32_t.
625        u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_head(self.model.as_ptr()) }).unwrap()
626    }
627
628    /// Returns the number of KV attention heads.
629    pub fn n_head_kv(&self) -> u32 {
630        // It's never possible for this to panic because while the API interface is defined as an int32_t,
631        // the field it's accessing is a uint32_t.
632        u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_head_kv(self.model.as_ptr()) })
633            .unwrap()
634    }
635
636    /// Get metadata value as a string by key name
637    pub fn meta_val_str(&self, key: &str) -> Result<String, MetaValError> {
638        let key_cstring = CString::new(key)?;
639        let key_ptr = key_cstring.as_ptr();
640
641        extract_meta_string(
642            |buf_ptr, buf_len| unsafe {
643                llama_cpp_sys_2::llama_model_meta_val_str(
644                    self.model.as_ptr(),
645                    key_ptr,
646                    buf_ptr,
647                    buf_len,
648                )
649            },
650            256,
651        )
652    }
653
654    /// Get the number of metadata key/value pairs
655    pub fn meta_count(&self) -> i32 {
656        unsafe { llama_cpp_sys_2::llama_model_meta_count(self.model.as_ptr()) }
657    }
658
659    /// Get metadata key name by index
660    pub fn meta_key_by_index(&self, index: i32) -> Result<String, MetaValError> {
661        extract_meta_string(
662            |buf_ptr, buf_len| unsafe {
663                llama_cpp_sys_2::llama_model_meta_key_by_index(
664                    self.model.as_ptr(),
665                    index,
666                    buf_ptr,
667                    buf_len,
668                )
669            },
670            256,
671        )
672    }
673
674    /// Get metadata value as a string by index
675    pub fn meta_val_str_by_index(&self, index: i32) -> Result<String, MetaValError> {
676        extract_meta_string(
677            |buf_ptr, buf_len| unsafe {
678                llama_cpp_sys_2::llama_model_meta_val_str_by_index(
679                    self.model.as_ptr(),
680                    index,
681                    buf_ptr,
682                    buf_len,
683                )
684            },
685            256,
686        )
687    }
688
689    /// Returns the rope type of the model.
690    pub fn rope_type(&self) -> Option<RopeType> {
691        match unsafe { llama_cpp_sys_2::llama_model_rope_type(self.model.as_ptr()) } {
692            llama_cpp_sys_2::LLAMA_ROPE_TYPE_NONE => None,
693            llama_cpp_sys_2::LLAMA_ROPE_TYPE_NORM => Some(RopeType::Norm),
694            llama_cpp_sys_2::LLAMA_ROPE_TYPE_NEOX => Some(RopeType::NeoX),
695            llama_cpp_sys_2::LLAMA_ROPE_TYPE_MROPE => Some(RopeType::MRope),
696            llama_cpp_sys_2::LLAMA_ROPE_TYPE_VISION => Some(RopeType::Vision),
697            rope_type => {
698                tracing::error!(rope_type = rope_type, "Unexpected rope type from llama.cpp");
699                None
700            }
701        }
702    }
703
704    /// Get chat template from model by name. If the name parameter is None, the default chat template will be returned.
705    ///
706    /// You supply this into [`Self::apply_chat_template`] to get back a string with the appropriate template
707    /// substitution applied to convert a list of messages into a prompt the LLM can use to complete
708    /// the chat.
709    ///
710    /// You could also use an external jinja parser, like [minijinja](https://github.com/mitsuhiko/minijinja),
711    /// to parse jinja templates not supported by the llama.cpp template engine.
712    ///
713    /// # Errors
714    ///
715    /// * If the model has no chat template by that name
716    /// * If the chat template is not a valid [`CString`].
717    pub fn chat_template(
718        &self,
719        name: Option<&str>,
720    ) -> Result<LlamaChatTemplate, ChatTemplateError> {
721        let name_cstr = name.map(CString::new);
722        let name_ptr = match name_cstr {
723            Some(Ok(name)) => name.as_ptr(),
724            _ => std::ptr::null(),
725        };
726        let result =
727            unsafe { llama_cpp_sys_2::llama_model_chat_template(self.model.as_ptr(), name_ptr) };
728
729        // Convert result to Rust String if not null
730        if result.is_null() {
731            Err(ChatTemplateError::MissingTemplate)
732        } else {
733            let chat_template_cstr = unsafe { CStr::from_ptr(result) };
734            let chat_template = CString::new(chat_template_cstr.to_bytes())?;
735            Ok(LlamaChatTemplate(chat_template))
736        }
737    }
738
739    /// Loads a model from a file.
740    ///
741    /// # Errors
742    ///
743    /// See [`LlamaModelLoadError`] for more information.
744    #[tracing::instrument(skip_all, fields(params))]
745    pub fn load_from_file(
746        _: &LlamaBackend,
747        path: impl AsRef<Path>,
748        params: &LlamaModelParams,
749    ) -> Result<Self, LlamaModelLoadError> {
750        let path = path.as_ref();
751        debug_assert!(Path::new(path).exists(), "{path:?} does not exist");
752        let path = path
753            .to_str()
754            .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
755
756        let cstr = CString::new(path)?;
757        let llama_model =
758            unsafe { llama_cpp_sys_2::llama_load_model_from_file(cstr.as_ptr(), params.params) };
759
760        let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
761
762        tracing::debug!(?path, "Loaded model");
763        Ok(LlamaModel { model })
764    }
765
766    /// Initializes a lora adapter from a file.
767    ///
768    /// # Errors
769    ///
770    /// See [`LlamaLoraAdapterInitError`] for more information.
771    pub fn lora_adapter_init(
772        &self,
773        path: impl AsRef<Path>,
774    ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
775        let path = path.as_ref();
776        debug_assert!(Path::new(path).exists(), "{path:?} does not exist");
777
778        let path = path
779            .to_str()
780            .ok_or(LlamaLoraAdapterInitError::PathToStrError(
781                path.to_path_buf(),
782            ))?;
783
784        let cstr = CString::new(path)?;
785        let adapter =
786            unsafe { llama_cpp_sys_2::llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr()) };
787
788        let adapter = NonNull::new(adapter).ok_or(LlamaLoraAdapterInitError::NullResult)?;
789
790        tracing::debug!(?path, "Initialized lora adapter");
791        Ok(LlamaLoraAdapter {
792            lora_adapter: adapter,
793        })
794    }
795
796    /// Create a new context from this model.
797    ///
798    /// # Errors
799    ///
800    /// There is many ways this can fail. See [`LlamaContextLoadError`] for more information.
801    // we intentionally do not derive Copy on `LlamaContextParams` to allow llama.cpp to change the type to be non-trivially copyable.
802    #[allow(clippy::needless_pass_by_value)]
803    pub fn new_context<'a>(
804        &'a self,
805        _: &LlamaBackend,
806        params: LlamaContextParams,
807    ) -> Result<LlamaContext<'a>, LlamaContextLoadError> {
808        let context_params = params.context_params;
809        let context = unsafe {
810            llama_cpp_sys_2::llama_new_context_with_model(self.model.as_ptr(), context_params)
811        };
812        let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;
813
814        Ok(LlamaContext::new(self, context, params.embeddings()))
815    }
816
817    /// Apply the models chat template to some messages.
818    /// See <https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template>
819    ///
820    /// Unlike the llama.cpp `apply_chat_template` which just randomly uses the ChatML template when given
821    /// a null pointer for the template, this requires an explicit template to be specified. If you want to
822    /// use "chatml", then just do `LlamaChatTemplate::new("chatml")` or any other model name or template
823    /// string.
824    ///
825    /// Use [`Self::chat_template`] to retrieve the template baked into the model (this is the preferred
826    /// mechanism as using the wrong chat template can result in really unexpected responses from the LLM).
827    ///
828    /// You probably want to set `add_ass` to true so that the generated template string ends with a the
829    /// opening tag of the assistant. If you fail to leave a hanging chat tag, the model will likely generate
830    /// one into the output and the output may also have unexpected output aside from that.
831    ///
832    /// # Errors
833    /// There are many ways this can fail. See [`ApplyChatTemplateError`] for more information.
834    #[tracing::instrument(skip_all)]
835    pub fn apply_chat_template(
836        &self,
837        tmpl: &LlamaChatTemplate,
838        chat: &[LlamaChatMessage],
839        add_ass: bool,
840    ) -> Result<String, ApplyChatTemplateError> {
841        // Buffer is twice the length of messages per their recommendation
842        let message_length = chat.iter().fold(0, |acc, c| {
843            acc + c.role.to_bytes().len() + c.content.to_bytes().len()
844        });
845        let mut buff: Vec<u8> = vec![0; message_length * 2];
846
847        // Build our llama_cpp_sys_2 chat messages
848        let chat: Vec<llama_cpp_sys_2::llama_chat_message> = chat
849            .iter()
850            .map(|c| llama_cpp_sys_2::llama_chat_message {
851                role: c.role.as_ptr(),
852                content: c.content.as_ptr(),
853            })
854            .collect();
855
856        let tmpl_ptr = tmpl.0.as_ptr();
857
858        let res = unsafe {
859            llama_cpp_sys_2::llama_chat_apply_template(
860                tmpl_ptr,
861                chat.as_ptr(),
862                chat.len(),
863                add_ass,
864                buff.as_mut_ptr().cast::<c_char>(),
865                buff.len().try_into().expect("Buffer size exceeds i32::MAX"),
866            )
867        };
868
869        if res < 0 {
870            return Err(ApplyChatTemplateError::FfiError(res));
871        }
872
873        if res > buff.len().try_into().expect("Buffer size exceeds i32::MAX") {
874            buff.resize(res.try_into().expect("res is negative"), 0);
875
876            let res = unsafe {
877                llama_cpp_sys_2::llama_chat_apply_template(
878                    tmpl_ptr,
879                    chat.as_ptr(),
880                    chat.len(),
881                    add_ass,
882                    buff.as_mut_ptr().cast::<c_char>(),
883                    buff.len().try_into().expect("Buffer size exceeds i32::MAX"),
884                )
885            };
886            if res < 0 {
887                return Err(ApplyChatTemplateError::FfiError(res));
888            }
889            assert_eq!(Ok(res), buff.len().try_into());
890        }
891        buff.truncate(res.try_into().expect("res is negative"));
892        Ok(String::from_utf8(buff)?)
893    }
894}
895
896/// Generic helper function for extracting string values from the C API
897/// This are specifically useful for the the metadata functions, where we pass in a buffer
898/// to be populated by a string, not yet knowing if the buffer is large enough.
899/// If the buffer was not large enough, we get the correct length back, which can be used to
900/// construct a buffer of appropriate size.
901fn extract_meta_string<F>(c_function: F, capacity: usize) -> Result<String, MetaValError>
902where
903    F: Fn(*mut c_char, usize) -> i32,
904{
905    let mut buffer = vec![0u8; capacity];
906
907    // call the foreign function
908    let result = c_function(buffer.as_mut_ptr().cast::<c_char>(), buffer.len());
909    if result < 0 {
910        return Err(MetaValError::NegativeReturn(result));
911    }
912
913    // check if the response fit in our buffer
914    let returned_len = result as usize;
915    if returned_len >= capacity {
916        // buffer wasn't large enough, try again with the correct capacity.
917        return extract_meta_string(c_function, returned_len + 1);
918    }
919
920    // verify null termination
921    debug_assert_eq!(
922        buffer.get(returned_len),
923        Some(&0),
924        "should end with null byte"
925    );
926
927    // resize, convert, and return
928    buffer.truncate(returned_len);
929    Ok(String::from_utf8(buffer)?)
930}
931
932impl Drop for LlamaModel {
933    fn drop(&mut self) {
934        unsafe { llama_cpp_sys_2::llama_free_model(self.model.as_ptr()) }
935    }
936}
937
938/// a rusty equivalent of `llama_vocab_type`
939#[repr(u32)]
940#[derive(Debug, Eq, Copy, Clone, PartialEq)]
941pub enum VocabType {
942    /// Byte Pair Encoding
943    BPE = llama_cpp_sys_2::LLAMA_VOCAB_TYPE_BPE as _,
944    /// Sentence Piece Tokenizer
945    SPM = llama_cpp_sys_2::LLAMA_VOCAB_TYPE_SPM as _,
946}
947
948/// There was an error converting a `llama_vocab_type` to a `VocabType`.
949#[derive(thiserror::Error, Debug, Eq, PartialEq)]
950pub enum LlamaTokenTypeFromIntError {
951    /// The value is not a valid `llama_token_type`. Contains the int value that was invalid.
952    #[error("Unknown Value {0}")]
953    UnknownValue(llama_cpp_sys_2::llama_vocab_type),
954}
955
956impl TryFrom<llama_cpp_sys_2::llama_vocab_type> for VocabType {
957    type Error = LlamaTokenTypeFromIntError;
958
959    fn try_from(value: llama_cpp_sys_2::llama_vocab_type) -> Result<Self, Self::Error> {
960        match value {
961            llama_cpp_sys_2::LLAMA_VOCAB_TYPE_BPE => Ok(VocabType::BPE),
962            llama_cpp_sys_2::LLAMA_VOCAB_TYPE_SPM => Ok(VocabType::SPM),
963            unknown => Err(LlamaTokenTypeFromIntError::UnknownValue(unknown)),
964        }
965    }
966}