llama_cpp_2/
model.rs

//! A safe wrapper around `llama_model`.
use std::ffi::{c_char, CStr, CString};
use std::num::NonZeroU16;
use std::os::raw::c_int;
use std::path::Path;
use std::ptr::NonNull;
use std::str::Utf8Error;

use crate::context::params::LlamaContextParams;
use crate::context::LlamaContext;
use crate::llama_backend::LlamaBackend;
use crate::model::params::LlamaModelParams;
use crate::token::LlamaToken;
use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
use crate::{
    ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
    LlamaModelLoadError, MetaValError, NewLlamaChatMessageError, StringToTokenError,
    TokenToStringError,
};

pub mod params;

/// A safe wrapper around `llama_model`.
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaModel {
    pub(crate) model: NonNull<llama_cpp_sys_2::llama_model>,
}

/// A safe wrapper around `llama_lora_adapter`.
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaLoraAdapter {
    pub(crate) lora_adapter: NonNull<llama_cpp_sys_2::llama_adapter_lora>,
}

/// A performance-friendly wrapper around [`LlamaModel::chat_template`], which is then
/// fed into [`LlamaModel::apply_chat_template`] to convert a list of messages into an LLM
/// prompt. Internally the template is stored as a `CString` to avoid round-trip conversions
/// across the FFI boundary.
#[derive(Eq, PartialEq, Clone, PartialOrd, Ord, Hash)]
pub struct LlamaChatTemplate(CString);

impl LlamaChatTemplate {
    /// Create a new template from a string. This can either be the name of a llama.cpp [chat template](https://github.com/ggerganov/llama.cpp/blob/8a8c4ceb6050bd9392609114ca56ae6d26f5b8f5/src/llama-chat.cpp#L27-L61)
    /// like "chatml" or "llama3" or an actual Jinja template for llama.cpp to interpret.
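    ///
    /// # Example
    ///
    /// A minimal sketch of both forms; whether a given name or Jinja string is actually
    /// supported depends on the llama.cpp build, so this is illustrative only.
    ///
    /// ```no_run
    /// use llama_cpp_2::model::LlamaChatTemplate;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// // By well-known template name...
    /// let by_name = LlamaChatTemplate::new("chatml")?;
    /// // ...or by passing a Jinja template string directly.
    /// let by_jinja = LlamaChatTemplate::new("{% for m in messages %}{{ m.role }}: {{ m.content }}\n{% endfor %}")?;
    /// # Ok(())
    /// # }
    /// ```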
    pub fn new(template: &str) -> Result<Self, std::ffi::NulError> {
        Ok(Self(CString::new(template)?))
    }

    /// Accesses the template as a c string reference.
    pub fn as_c_str(&self) -> &CStr {
        &self.0
    }

    /// Attempts to convert the CString into a Rust str reference.
    pub fn to_str(&self) -> Result<&str, Utf8Error> {
        self.0.to_str()
    }

    /// Convenience method to create an owned String.
    pub fn to_string(&self) -> Result<String, Utf8Error> {
        self.to_str().map(str::to_string)
    }
}

impl std::fmt::Debug for LlamaChatTemplate {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.0.fmt(f)
    }
}

/// A safe wrapper around `llama_chat_message`
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct LlamaChatMessage {
    role: CString,
    content: CString,
}

impl LlamaChatMessage {
    /// Create a new `LlamaChatMessage`
    ///
    /// # Errors
    /// If either `role` or `content` contains null bytes.
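    ///
    /// # Example
    ///
    /// A minimal sketch; the role names follow the common "system"/"user"/"assistant"
    /// convention, but the set of roles a model understands depends on its chat template.
    ///
    /// ```no_run
    /// use llama_cpp_2::model::LlamaChatMessage;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let messages = vec![
    ///     LlamaChatMessage::new("system".to_string(), "You are a helpful assistant.".to_string())?,
    ///     LlamaChatMessage::new("user".to_string(), "Hello!".to_string())?,
    /// ];
    /// # Ok(())
    /// # }
    /// ```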
    pub fn new(role: String, content: String) -> Result<Self, NewLlamaChatMessageError> {
        Ok(Self {
            role: CString::new(role)?,
            content: CString::new(content)?,
        })
    }
}

/// The RoPE (rotary position embedding) type used within the model.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RopeType {
    Norm,
    NeoX,
    MRope,
    Vision,
}

/// Whether to prepend a BOS (beginning of stream) token when tokenizing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AddBos {
    /// Add the beginning of stream token to the start of the string.
    Always,
    /// Do not add the beginning of stream token to the start of the string.
    Never,
}

/// Whether to tokenize special and/or control tokens.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Special {
    /// Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space.
    Tokenize,
    /// Treat special and/or control tokens as plaintext.
    Plaintext,
}

unsafe impl Send for LlamaModel {}

unsafe impl Sync for LlamaModel {}

impl LlamaModel {
    pub(crate) fn vocab_ptr(&self) -> *const llama_cpp_sys_2::llama_vocab {
        unsafe { llama_cpp_sys_2::llama_model_get_vocab(self.model.as_ptr()) }
    }

    /// Get the context size (in tokens) the model was trained on.
    ///
    /// # Panics
    ///
    /// If the training context size does not fit into a `u32`. This should be impossible on most
    /// platforms due to llama.cpp returning a `c_int` (i32 on most platforms) which is almost certainly positive.
    #[must_use]
    pub fn n_ctx_train(&self) -> u32 {
        let n_ctx_train = unsafe { llama_cpp_sys_2::llama_n_ctx_train(self.model.as_ptr()) };
        u32::try_from(n_ctx_train).expect("n_ctx_train fits into a u32")
    }

    /// Get all tokens in the model.
    pub fn tokens(
        &self,
        special: Special,
    ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
        (0..self.n_vocab())
            .map(LlamaToken::new)
            .map(move |llama_token| (llama_token, self.token_to_str(llama_token, special)))
    }

    /// Get the beginning of stream token.
    #[must_use]
    pub fn token_bos(&self) -> LlamaToken {
        let token = unsafe { llama_cpp_sys_2::llama_token_bos(self.vocab_ptr()) };
        LlamaToken(token)
    }

    /// Get the end of stream token.
    #[must_use]
    pub fn token_eos(&self) -> LlamaToken {
        let token = unsafe { llama_cpp_sys_2::llama_token_eos(self.vocab_ptr()) };
        LlamaToken(token)
    }

    /// Get the newline token.
    #[must_use]
    pub fn token_nl(&self) -> LlamaToken {
        let token = unsafe { llama_cpp_sys_2::llama_token_nl(self.vocab_ptr()) };
        LlamaToken(token)
    }

    /// Check if a token represents the end of generation (end of turn, end of sequence, etc.)
    #[must_use]
    pub fn is_eog_token(&self, token: LlamaToken) -> bool {
        unsafe { llama_cpp_sys_2::llama_token_is_eog(self.vocab_ptr(), token.0) }
    }

    /// Get the decoder start token.
    #[must_use]
    pub fn decode_start_token(&self) -> LlamaToken {
        let token =
            unsafe { llama_cpp_sys_2::llama_model_decoder_start_token(self.model.as_ptr()) };
        LlamaToken(token)
    }

    /// Convert a single token to a string.
    ///
    /// # Errors
    ///
    /// See [`TokenToStringError`] for more information.
    pub fn token_to_str(
        &self,
        token: LlamaToken,
        special: Special,
    ) -> Result<String, TokenToStringError> {
        let bytes = self.token_to_bytes(token, special)?;
        Ok(String::from_utf8(bytes)?)
    }

    /// Convert a single token to bytes.
    ///
    /// # Errors
    /// See [`TokenToStringError`] for more information.
    ///
    /// # Panics
    /// If a [`TokenToStringError::InsufficientBufferSpace`] error returned by
    /// [`Self::token_to_bytes_with_size`] contains a positive nonzero value. This should never
    /// happen.
    pub fn token_to_bytes(
        &self,
        token: LlamaToken,
        special: Special,
    ) -> Result<Vec<u8>, TokenToStringError> {
        match self.token_to_bytes_with_size(token, 8, special, None) {
            Err(TokenToStringError::InsufficientBufferSpace(i)) => self.token_to_bytes_with_size(
                token,
                (-i).try_into().expect("Error buffer size is positive"),
                special,
                None,
            ),
            x => x,
        }
    }

    /// Convert a vector of tokens to a single string.
    ///
    /// # Errors
    ///
    /// See [`TokenToStringError`] for more information.
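    ///
    /// # Example
    ///
    /// A minimal round-trip sketch; the model path is a placeholder, and exact whitespace
    /// handling of the decoded text depends on the model's tokenizer.
    ///
    /// ```no_run
    /// use std::path::Path;
    /// use llama_cpp_2::model::{AddBos, LlamaModel, Special};
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let backend = llama_cpp_2::llama_backend::LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, Path::new("path/to/model"), &Default::default())?;
    /// let tokens = model.str_to_token("Hello, World!", AddBos::Never)?;
    /// let text = model.tokens_to_str(&tokens, Special::Tokenize)?;
    /// println!("{text}");
    /// # Ok(())
    /// # }
    /// ```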
    pub fn tokens_to_str(
        &self,
        tokens: &[LlamaToken],
        special: Special,
    ) -> Result<String, TokenToStringError> {
        let mut builder: Vec<u8> = Vec::with_capacity(tokens.len() * 4);
        for piece in tokens
            .iter()
            .copied()
            .map(|t| self.token_to_bytes(t, special))
        {
            builder.extend_from_slice(&piece?);
        }
        Ok(String::from_utf8(builder)?)
    }

    /// Convert a string to a Vector of tokens.
    ///
    /// # Errors
    ///
    /// - if [`str`] contains a null byte.
    ///
    /// # Panics
    ///
    /// - if there are more than [`usize::MAX`] [`LlamaToken`]s in [`str`].
    ///
    /// # Example
    ///
    /// ```no_run
    /// use llama_cpp_2::model::LlamaModel;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// use std::path::Path;
    /// use llama_cpp_2::model::AddBos;
    /// let backend = llama_cpp_2::llama_backend::LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, Path::new("path/to/model"), &Default::default())?;
    /// let tokens = model.str_to_token("Hello, World!", AddBos::Always)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn str_to_token(
        &self,
        str: &str,
        add_bos: AddBos,
    ) -> Result<Vec<LlamaToken>, StringToTokenError> {
        let add_bos = match add_bos {
            AddBos::Always => true,
            AddBos::Never => false,
        };

        let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
        let mut buffer: Vec<LlamaToken> = Vec::with_capacity(tokens_estimation);

        let c_string = CString::new(str)?;
        let buffer_capacity =
            c_int::try_from(buffer.capacity()).expect("buffer capacity should fit into a c_int");

        let size = unsafe {
            llama_cpp_sys_2::llama_tokenize(
                self.vocab_ptr(),
                c_string.as_ptr(),
                c_int::try_from(c_string.as_bytes().len())?,
                buffer.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>(),
                buffer_capacity,
                add_bos,
                true,
            )
        };

        // If the first call fails, the negative return value is the required token count.
        // Resize the buffer to that size and try again; the second call should never fail,
        // so `size` is guaranteed to be non-negative afterwards.
        let size = if size.is_negative() {
            buffer.reserve_exact(usize::try_from(-size).expect("negative size fits into a usize"));
            unsafe {
                llama_cpp_sys_2::llama_tokenize(
                    self.vocab_ptr(),
                    c_string.as_ptr(),
                    c_int::try_from(c_string.as_bytes().len())?,
                    buffer.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>(),
                    -size,
                    add_bos,
                    true,
                )
            }
        } else {
            size
        };

        let size = usize::try_from(size).expect("size is positive and fits into a usize");

        // Safety: `size` < `capacity` and llama-cpp has initialized elements up to `size`
        unsafe { buffer.set_len(size) }
        Ok(buffer)
    }

    /// Get the type of a token.
    ///
    /// # Panics
    ///
    /// If the token type is not known to this library.
    #[must_use]
    pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs {
        let token_type = unsafe { llama_cpp_sys_2::llama_token_get_attr(self.vocab_ptr(), id) };
        LlamaTokenAttrs::try_from(token_type).expect("token type is valid")
    }

    /// Convert a token to a string with a specified buffer size.
    ///
    /// Generally you should use [`LlamaModel::token_to_str`] as it is able to decode tokens of
    /// any length.
    ///
    /// # Errors
    ///
    /// - if the token type is unknown
    /// - the resultant token is larger than `buffer_size`.
    /// - the string returned by llama-cpp is not valid utf8.
    ///
    /// # Panics
    ///
    /// - if `buffer_size` does not fit into a [`c_int`].
    /// - if the returned size from llama-cpp does not fit into a [`usize`]. (this should never happen)
    pub fn token_to_str_with_size(
        &self,
        token: LlamaToken,
        buffer_size: usize,
        special: Special,
    ) -> Result<String, TokenToStringError> {
        let bytes = self.token_to_bytes_with_size(token, buffer_size, special, None)?;
        Ok(String::from_utf8(bytes)?)
    }

    /// Convert a token to bytes with a specified buffer size.
    ///
    /// Generally you should use [`LlamaModel::token_to_bytes`] as it is able to handle tokens of
    /// any length.
    ///
    /// # Errors
    ///
    /// - if the token type is unknown
    /// - the resultant token is larger than `buffer_size`.
    ///
    /// # Panics
    ///
    /// - if `buffer_size` does not fit into a [`c_int`].
    /// - if the returned size from llama-cpp does not fit into a [`usize`]. (this should never happen)
    pub fn token_to_bytes_with_size(
        &self,
        token: LlamaToken,
        buffer_size: usize,
        special: Special,
        lstrip: Option<NonZeroU16>,
    ) -> Result<Vec<u8>, TokenToStringError> {
        if token == self.token_nl() {
            return Ok(b"\n".to_vec());
        }

        // unsure what to do with this in the face of the 'special' arg + attr changes
        let attrs = self.token_attr(token);
        if attrs.is_empty()
            || attrs
                .intersects(LlamaTokenAttr::Unknown | LlamaTokenAttr::Byte | LlamaTokenAttr::Unused)
            || attrs.contains(LlamaTokenAttr::Control)
                && (token == self.token_bos() || token == self.token_eos())
        {
            return Ok(Vec::new());
        }

        let special = match special {
            Special::Tokenize => true,
            Special::Plaintext => false,
        };

        let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
        let len = string.as_bytes().len();
        let len = c_int::try_from(len).expect("length fits into c_int");
        let buf = string.into_raw();
        let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
        let size = unsafe {
            llama_cpp_sys_2::llama_token_to_piece(
                self.vocab_ptr(),
                token.0,
                buf,
                len,
                lstrip,
                special,
            )
        };

        match size {
            0 => Err(TokenToStringError::UnknownTokenType),
            i if i.is_negative() => Err(TokenToStringError::InsufficientBufferSpace(i)),
            size => {
                let string = unsafe { CString::from_raw(buf) };
                let mut bytes = string.into_bytes();
                let len = usize::try_from(size).expect("size is positive and fits into usize");
                bytes.truncate(len);
                Ok(bytes)
            }
        }
    }

    /// The number of tokens in the model's vocabulary.
    ///
    /// This returns a `c_int` for maximum compatibility. Most of the time it can be cast to an i32
    /// without issue.
    #[must_use]
    pub fn n_vocab(&self) -> i32 {
        unsafe { llama_cpp_sys_2::llama_n_vocab(self.vocab_ptr()) }
    }

    /// The type of vocab the model was trained on.
    ///
    /// # Panics
    ///
    /// If llama-cpp emits a vocab type that is not known to this library.
    #[must_use]
    pub fn vocab_type(&self) -> VocabType {
        // llama_cpp_sys_2::llama_model_get_vocab
        let vocab_type = unsafe { llama_cpp_sys_2::llama_vocab_type(self.vocab_ptr()) };
        VocabType::try_from(vocab_type).expect("invalid vocab type")
    }

    /// The embedding dimension of the model (the length of a single embedding vector).
    ///
    /// This returns a `c_int` for maximum compatibility. Most of the time it can be cast to an i32
    /// without issue.
    #[must_use]
    pub fn n_embd(&self) -> c_int {
        unsafe { llama_cpp_sys_2::llama_n_embd(self.model.as_ptr()) }
    }

    /// Returns the total size of all the tensors in the model in bytes.
    pub fn size(&self) -> u64 {
        unsafe { llama_cpp_sys_2::llama_model_size(self.model.as_ptr()) }
    }

    /// Returns the number of parameters in the model.
    pub fn n_params(&self) -> u64 {
        unsafe { llama_cpp_sys_2::llama_model_n_params(self.model.as_ptr()) }
    }

    /// Returns whether the model is a recurrent network (Mamba, RWKV, etc)
    pub fn is_recurrent(&self) -> bool {
        unsafe { llama_cpp_sys_2::llama_model_is_recurrent(self.model.as_ptr()) }
    }

    /// Returns the number of layers within the model.
    pub fn n_layer(&self) -> u32 {
        // It's never possible for this to panic because while the API interface is defined as an int32_t,
        // the field it's accessing is a uint32_t.
        u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_layer(self.model.as_ptr()) }).unwrap()
    }

    /// Returns the number of attention heads within the model.
    pub fn n_head(&self) -> u32 {
        // It's never possible for this to panic because while the API interface is defined as an int32_t,
        // the field it's accessing is a uint32_t.
        u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_head(self.model.as_ptr()) }).unwrap()
    }

    /// Returns the number of KV attention heads.
    pub fn n_head_kv(&self) -> u32 {
        // It's never possible for this to panic because while the API interface is defined as an int32_t,
        // the field it's accessing is a uint32_t.
        u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_head_kv(self.model.as_ptr()) })
            .unwrap()
    }

    /// Get metadata value as a string by key name
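    ///
    /// # Example
    ///
    /// A minimal sketch; `general.architecture` is a standard GGUF metadata key, but which
    /// keys are present depends on the particular model file.
    ///
    /// ```no_run
    /// # use std::path::Path;
    /// # use llama_cpp_2::model::LlamaModel;
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// # let backend = llama_cpp_2::llama_backend::LlamaBackend::init()?;
    /// # let model = LlamaModel::load_from_file(&backend, Path::new("path/to/model"), &Default::default())?;
    /// let arch = model.meta_val_str("general.architecture")?;
    /// println!("architecture: {arch}");
    /// # Ok(())
    /// # }
    /// ```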
    pub fn meta_val_str(&self, key: &str) -> Result<String, MetaValError> {
        let key_cstring = CString::new(key)?;
        let key_ptr = key_cstring.as_ptr();

        extract_meta_string(
            |buf_ptr, buf_len| unsafe {
                llama_cpp_sys_2::llama_model_meta_val_str(
                    self.model.as_ptr(),
                    key_ptr,
                    buf_ptr,
                    buf_len,
                )
            },
            256,
        )
    }

    /// Get the number of metadata key/value pairs
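    ///
    /// # Example
    ///
    /// A minimal sketch of walking all metadata pairs via [`Self::meta_key_by_index`] and
    /// [`Self::meta_val_str_by_index`]; the model path is a placeholder.
    ///
    /// ```no_run
    /// # use std::path::Path;
    /// # use llama_cpp_2::model::LlamaModel;
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// # let backend = llama_cpp_2::llama_backend::LlamaBackend::init()?;
    /// # let model = LlamaModel::load_from_file(&backend, Path::new("path/to/model"), &Default::default())?;
    /// for i in 0..model.meta_count() {
    ///     let key = model.meta_key_by_index(i)?;
    ///     let val = model.meta_val_str_by_index(i)?;
    ///     println!("{key} = {val}");
    /// }
    /// # Ok(())
    /// # }
    /// ```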
    pub fn meta_count(&self) -> i32 {
        unsafe { llama_cpp_sys_2::llama_model_meta_count(self.model.as_ptr()) }
    }

    /// Get metadata key name by index
    pub fn meta_key_by_index(&self, index: i32) -> Result<String, MetaValError> {
        extract_meta_string(
            |buf_ptr, buf_len| unsafe {
                llama_cpp_sys_2::llama_model_meta_key_by_index(
                    self.model.as_ptr(),
                    index,
                    buf_ptr,
                    buf_len,
                )
            },
            256,
        )
    }

    /// Get metadata value as a string by index
    pub fn meta_val_str_by_index(&self, index: i32) -> Result<String, MetaValError> {
        extract_meta_string(
            |buf_ptr, buf_len| unsafe {
                llama_cpp_sys_2::llama_model_meta_val_str_by_index(
                    self.model.as_ptr(),
                    index,
                    buf_ptr,
                    buf_len,
                )
            },
            256,
        )
    }

    /// Returns the rope type of the model.
    pub fn rope_type(&self) -> Option<RopeType> {
        match unsafe { llama_cpp_sys_2::llama_model_rope_type(self.model.as_ptr()) } {
            llama_cpp_sys_2::LLAMA_ROPE_TYPE_NONE => None,
            llama_cpp_sys_2::LLAMA_ROPE_TYPE_NORM => Some(RopeType::Norm),
            llama_cpp_sys_2::LLAMA_ROPE_TYPE_NEOX => Some(RopeType::NeoX),
            llama_cpp_sys_2::LLAMA_ROPE_TYPE_MROPE => Some(RopeType::MRope),
            llama_cpp_sys_2::LLAMA_ROPE_TYPE_VISION => Some(RopeType::Vision),
            rope_type => {
                tracing::error!(rope_type = rope_type, "Unexpected rope type from llama.cpp");
                None
            }
        }
    }

    /// Get a chat template from the model by name. If the name parameter is `None`, the default chat template will be returned.
    ///
    /// You supply this to [`Self::apply_chat_template`] to get back a string with the appropriate template
    /// substitution applied to convert a list of messages into a prompt the LLM can use to complete
    /// the chat.
    ///
    /// You could also use an external jinja parser, like [minijinja](https://github.com/mitsuhiko/minijinja),
    /// to parse jinja templates not supported by the llama.cpp template engine.
    ///
    /// # Errors
    ///
    /// * If the model has no chat template by that name
    /// * If the chat template is not a valid [`CString`].
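    ///
    /// # Example
    ///
    /// A minimal sketch: prefer the template embedded in the model, falling back to a named
    /// built-in template ("chatml" here) if the model does not ship one.
    ///
    /// ```no_run
    /// # use std::path::Path;
    /// # use llama_cpp_2::model::{LlamaChatTemplate, LlamaModel};
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// # let backend = llama_cpp_2::llama_backend::LlamaBackend::init()?;
    /// # let model = LlamaModel::load_from_file(&backend, Path::new("path/to/model"), &Default::default())?;
    /// let template = model
    ///     .chat_template(None)
    ///     .unwrap_or(LlamaChatTemplate::new("chatml")?);
    /// # Ok(())
    /// # }
    /// ```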
    pub fn chat_template(
        &self,
        name: Option<&str>,
    ) -> Result<LlamaChatTemplate, ChatTemplateError> {
        // Keep the CString alive in this scope so `name_ptr` remains valid for the FFI call.
        let name_cstr = name.map(CString::new);
        let name_ptr = match &name_cstr {
            Some(Ok(name)) => name.as_ptr(),
            _ => std::ptr::null(),
        };
        let result =
            unsafe { llama_cpp_sys_2::llama_model_chat_template(self.model.as_ptr(), name_ptr) };

        // Convert result to Rust String if not null
        if result.is_null() {
            Err(ChatTemplateError::MissingTemplate)
        } else {
            let chat_template_cstr = unsafe { CStr::from_ptr(result) };
            let chat_template = CString::new(chat_template_cstr.to_bytes())?;
            Ok(LlamaChatTemplate(chat_template))
        }
    }

    /// Loads a model from a file.
    ///
    /// # Errors
    ///
    /// See [`LlamaModelLoadError`] for more information.
    #[tracing::instrument(skip_all, fields(params))]
    pub fn load_from_file(
        _: &LlamaBackend,
        path: impl AsRef<Path>,
        params: &LlamaModelParams,
    ) -> Result<Self, LlamaModelLoadError> {
        let path = path.as_ref();
        debug_assert!(Path::new(path).exists(), "{path:?} does not exist");
        let path = path
            .to_str()
            .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;

        let cstr = CString::new(path)?;
        let llama_model =
            unsafe { llama_cpp_sys_2::llama_load_model_from_file(cstr.as_ptr(), params.params) };

        let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;

        tracing::debug!(?path, "Loaded model");
        Ok(LlamaModel { model })
    }

    /// Initializes a lora adapter from a file.
    ///
    /// # Errors
    ///
    /// See [`LlamaLoraAdapterInitError`] for more information.
    pub fn lora_adapter_init(
        &self,
        path: impl AsRef<Path>,
    ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
        let path = path.as_ref();
        debug_assert!(Path::new(path).exists(), "{path:?} does not exist");

        let path = path
            .to_str()
            .ok_or(LlamaLoraAdapterInitError::PathToStrError(
                path.to_path_buf(),
            ))?;

        let cstr = CString::new(path)?;
        let adapter =
            unsafe { llama_cpp_sys_2::llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr()) };

        let adapter = NonNull::new(adapter).ok_or(LlamaLoraAdapterInitError::NullResult)?;

        tracing::debug!(?path, "Initialized lora adapter");
        Ok(LlamaLoraAdapter {
            lora_adapter: adapter,
        })
    }

    /// Create a new context from this model.
    ///
    /// # Errors
    ///
    /// There are many ways this can fail. See [`LlamaContextLoadError`] for more information.
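    ///
    /// # Example
    ///
    /// A minimal sketch using default context parameters; real code would typically
    /// customize [`LlamaContextParams`] (context length, batch size, etc.) first.
    ///
    /// ```no_run
    /// # use std::path::Path;
    /// # use llama_cpp_2::model::LlamaModel;
    /// use llama_cpp_2::context::params::LlamaContextParams;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let backend = llama_cpp_2::llama_backend::LlamaBackend::init()?;
    /// # let model = LlamaModel::load_from_file(&backend, Path::new("path/to/model"), &Default::default())?;
    /// let ctx = model.new_context(&backend, LlamaContextParams::default())?;
    /// # Ok(())
    /// # }
    /// ```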
    // we intentionally do not derive Copy on `LlamaContextParams` to allow llama.cpp to change the type to be non-trivially copyable.
    #[allow(clippy::needless_pass_by_value)]
    pub fn new_context(
        &self,
        _: &LlamaBackend,
        params: LlamaContextParams,
    ) -> Result<LlamaContext, LlamaContextLoadError> {
        let context_params = params.context_params;
        let context = unsafe {
            llama_cpp_sys_2::llama_new_context_with_model(self.model.as_ptr(), context_params)
        };
        let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;

        Ok(LlamaContext::new(self, context, params.embeddings()))
    }

    /// Apply the model's chat template to some messages.
    /// See <https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template>
    ///
    /// Unlike the llama.cpp `apply_chat_template`, which silently falls back to the ChatML template when given
    /// a null pointer for the template, this requires an explicit template to be specified. If you want to
    /// use "chatml", then just do `LlamaChatTemplate::new("chatml")` or any other model name or template
    /// string.
    ///
    /// Use [`Self::chat_template`] to retrieve the template baked into the model (this is the preferred
    /// mechanism as using the wrong chat template can result in really unexpected responses from the LLM).
    ///
    /// You probably want to set `add_ass` to true so that the generated template string ends with the
    /// opening tag of the assistant turn. If you fail to leave a hanging assistant tag, the model will likely
    /// generate one itself and the rest of the output may be unexpected as well.
    ///
    /// # Errors
    /// There are many ways this can fail. See [`ApplyChatTemplateError`] for more information.
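    ///
    /// # Example
    ///
    /// A minimal sketch of the full flow (template lookup, message construction, rendering);
    /// the model path is a placeholder and the rendered prompt depends on the model's template.
    ///
    /// ```no_run
    /// # use std::path::Path;
    /// use llama_cpp_2::model::{LlamaChatMessage, LlamaChatTemplate, LlamaModel};
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// # let backend = llama_cpp_2::llama_backend::LlamaBackend::init()?;
    /// # let model = LlamaModel::load_from_file(&backend, Path::new("path/to/model"), &Default::default())?;
    /// let template = model
    ///     .chat_template(None)
    ///     .unwrap_or(LlamaChatTemplate::new("chatml")?);
    /// let messages = vec![
    ///     LlamaChatMessage::new("system".to_string(), "You are a helpful assistant.".to_string())?,
    ///     LlamaChatMessage::new("user".to_string(), "What is the capital of France?".to_string())?,
    /// ];
    /// // `true` leaves a hanging assistant tag for the model to complete.
    /// let prompt = model.apply_chat_template(&template, &messages, true)?;
    /// # Ok(())
    /// # }
    /// ```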
    #[tracing::instrument(skip_all)]
    pub fn apply_chat_template(
        &self,
        tmpl: &LlamaChatTemplate,
        chat: &[LlamaChatMessage],
        add_ass: bool,
    ) -> Result<String, ApplyChatTemplateError> {
        // Buffer is twice the length of the messages, per llama.cpp's recommendation
        let message_length = chat.iter().fold(0, |acc, c| {
            acc + c.role.to_bytes().len() + c.content.to_bytes().len()
        });
        let mut buff: Vec<u8> = vec![0; message_length * 2];

        // Build our llama_cpp_sys_2 chat messages
        let chat: Vec<llama_cpp_sys_2::llama_chat_message> = chat
            .iter()
            .map(|c| llama_cpp_sys_2::llama_chat_message {
                role: c.role.as_ptr(),
                content: c.content.as_ptr(),
            })
            .collect();

        let tmpl_ptr = tmpl.0.as_ptr();

        let res = unsafe {
            llama_cpp_sys_2::llama_chat_apply_template(
                tmpl_ptr,
                chat.as_ptr(),
                chat.len(),
                add_ass,
                buff.as_mut_ptr().cast::<c_char>(),
                buff.len().try_into().expect("Buffer size exceeds i32::MAX"),
            )
        };

        if res > buff.len().try_into().expect("Buffer size exceeds i32::MAX") {
            buff.resize(res.try_into().expect("res is non-negative"), 0);

            let res = unsafe {
                llama_cpp_sys_2::llama_chat_apply_template(
                    tmpl_ptr,
                    chat.as_ptr(),
                    chat.len(),
                    add_ass,
                    buff.as_mut_ptr().cast::<c_char>(),
                    buff.len().try_into().expect("Buffer size exceeds i32::MAX"),
                )
            };
            assert_eq!(Ok(res), buff.len().try_into());
        }
        buff.truncate(res.try_into().expect("res is non-negative"));
        Ok(String::from_utf8(buff)?)
    }
}

/// Generic helper function for extracting string values from the C API.
/// This is specifically useful for the metadata functions, where we pass in a buffer
/// to be populated by a string, not yet knowing if the buffer is large enough.
/// If the buffer was not large enough, we get the correct length back, which can be used to
/// construct a buffer of appropriate size.
fn extract_meta_string<F>(c_function: F, capacity: usize) -> Result<String, MetaValError>
where
    F: Fn(*mut c_char, usize) -> i32,
{
    let mut buffer = vec![0u8; capacity];

    // call the foreign function
    let result = c_function(buffer.as_mut_ptr() as *mut c_char, buffer.len());
    if result < 0 {
        return Err(MetaValError::NegativeReturn(result));
    }

    // check if the response fit in our buffer
    let returned_len = result as usize;
    if returned_len >= capacity {
        // buffer wasn't large enough, try again with the correct capacity.
        return extract_meta_string(c_function, returned_len + 1);
    }

    // verify null termination
    debug_assert_eq!(
        buffer.get(returned_len),
        Some(&0),
        "should end with null byte"
    );

    // resize, convert, and return
    buffer.truncate(returned_len);
    Ok(String::from_utf8(buffer)?)
}

impl Drop for LlamaModel {
    fn drop(&mut self) {
        unsafe { llama_cpp_sys_2::llama_free_model(self.model.as_ptr()) }
    }
}

/// A rusty equivalent of `llama_vocab_type`
#[repr(u32)]
#[derive(Debug, Eq, Copy, Clone, PartialEq)]
pub enum VocabType {
    /// Byte Pair Encoding
    BPE = llama_cpp_sys_2::LLAMA_VOCAB_TYPE_BPE as _,
    /// Sentence Piece Tokenizer
    SPM = llama_cpp_sys_2::LLAMA_VOCAB_TYPE_SPM as _,
}

/// There was an error converting a `llama_vocab_type` to a `VocabType`.
#[derive(thiserror::Error, Debug, Eq, PartialEq)]
pub enum LlamaTokenTypeFromIntError {
    /// The value is not a valid `llama_token_type`. Contains the int value that was invalid.
    #[error("Unknown Value {0}")]
    UnknownValue(llama_cpp_sys_2::llama_vocab_type),
}

impl TryFrom<llama_cpp_sys_2::llama_vocab_type> for VocabType {
    type Error = LlamaTokenTypeFromIntError;

    fn try_from(value: llama_cpp_sys_2::llama_vocab_type) -> Result<Self, Self::Error> {
        match value {
            llama_cpp_sys_2::LLAMA_VOCAB_TYPE_BPE => Ok(VocabType::BPE),
            llama_cpp_sys_2::LLAMA_VOCAB_TYPE_SPM => Ok(VocabType::SPM),
            unknown => Err(LlamaTokenTypeFromIntError::UnknownValue(unknown)),
        }
    }
}