llama_cpp_2/
model.rs

1//! A safe wrapper around `llama_model`.
2use std::ffi::{c_char, CStr, CString};
3use std::num::NonZeroU16;
4use std::os::raw::c_int;
5use std::path::Path;
6use std::ptr::NonNull;
7use std::str::Utf8Error;
8
9use crate::context::params::LlamaContextParams;
10use crate::context::LlamaContext;
11use crate::llama_backend::LlamaBackend;
12use crate::model::params::LlamaModelParams;
13use crate::token::LlamaToken;
14use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
15use crate::{
16    ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
17    LlamaModelLoadError, MetaValError, NewLlamaChatMessageError, StringToTokenError,
18    TokenToStringError,
19};
20
21pub mod params;
22
23/// A safe wrapper around `llama_model`.
24#[derive(Debug)]
25#[repr(transparent)]
26#[allow(clippy::module_name_repetitions)]
27pub struct LlamaModel {
28    pub(crate) model: NonNull<llama_cpp_sys_2::llama_model>,
29}
30
31/// A safe wrapper around `llama_lora_adapter`.
32#[derive(Debug)]
33#[repr(transparent)]
34#[allow(clippy::module_name_repetitions)]
35pub struct LlamaLoraAdapter {
36    pub(crate) lora_adapter: NonNull<llama_cpp_sys_2::llama_adapter_lora>,
37}
38
/// A chat template ready to be handed to llama.cpp.
///
/// Obtain one from [`LlamaModel::chat_template`] (the template embedded in the model) or build
/// one with [`LlamaChatTemplate::new`]. Then pass it to [`LlamaModel::apply_chat_template`] to
/// turn a list of messages into a prompt. The template is kept as a [`CString`] so it can cross
/// the FFI boundary without being re-encoded on every call.
#[derive(Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct LlamaChatTemplate(CString);
45
46impl LlamaChatTemplate {
47    /// Create a new template from a string. This can either be the name of a llama.cpp [chat template](https://github.com/ggerganov/llama.cpp/blob/8a8c4ceb6050bd9392609114ca56ae6d26f5b8f5/src/llama-chat.cpp#L27-L61)
48    /// like "chatml" or "llama3" or an actual Jinja template for llama.cpp to interpret.
49    pub fn new(template: &str) -> Result<Self, std::ffi::NulError> {
50        Ok(Self(CString::new(template)?))
51    }
52
53    /// Accesses the template as a c string reference.
54    pub fn as_c_str(&self) -> &CStr {
55        &self.0
56    }
57
58    /// Attempts to convert the CString into a Rust str reference.
59    pub fn to_str(&self) -> Result<&str, Utf8Error> {
60        self.0.to_str()
61    }
62
63    /// Convenience method to create an owned String.
64    pub fn to_string(&self) -> Result<String, Utf8Error> {
65        self.to_str().map(str::to_string)
66    }
67}
68
69impl std::fmt::Debug for LlamaChatTemplate {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        self.0.fmt(f)
72    }
73}
74
/// A single chat message (role plus content), mirroring llama.cpp's `llama_chat_message`.
///
/// Both fields are stored as [`CString`]s so they can be handed to llama.cpp directly.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct LlamaChatMessage {
    /// Who sent the message.
    role: CString,
    /// Text of the message.
    content: CString,
}
81
82impl LlamaChatMessage {
83    /// Create a new `LlamaChatMessage`
84    ///
85    /// # Errors
86    /// If either of ``role`` or ``content`` contain null bytes.
87    pub fn new(role: String, content: String) -> Result<Self, NewLlamaChatMessageError> {
88        Ok(Self {
89            role: CString::new(role)?,
90            content: CString::new(content)?,
91        })
92    }
93}
94
/// The kind of rotary position embedding (RoPE) used by a model.
///
/// Returned by [`LlamaModel::rope_type`]. A model without RoPE (`LLAMA_ROPE_TYPE_NONE`) is
/// reported there as `None` rather than as a variant of this enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RopeType {
    /// Corresponds to `LLAMA_ROPE_TYPE_NORM`.
    Norm,
    /// Corresponds to `LLAMA_ROPE_TYPE_NEOX`.
    NeoX,
    /// Corresponds to `LLAMA_ROPE_TYPE_MROPE`.
    MRope,
    /// Corresponds to `LLAMA_ROPE_TYPE_VISION`.
    Vision,
}
103
/// Whether a beginning-of-stream (BOS) token should be prepended when tokenizing.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AddBos {
    /// Prepend the beginning of stream token to the tokenized string.
    Always,
    /// Leave the tokenized string without a leading beginning of stream token.
    Never,
}
112
/// Selects how special and/or control tokens are handled.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Special {
    /// Treat special and/or control tokens as real tokens instead of plaintext, exposing tokens
    /// that would otherwise stay hidden. No leading space is inserted.
    Tokenize,
    /// Treat special and/or control tokens as ordinary plaintext.
    Plaintext,
}
121
// NOTE(review): these impls let a `LlamaModel` be moved to and shared between threads. They are
// only sound if llama.cpp permits one `llama_model` to be used from several threads at once
// through the functions this wrapper calls with `&self` (vocab lookups, tokenization, metadata,
// context creation). Confirm this against llama.cpp's threading guarantees.
unsafe impl Send for LlamaModel {}

// NOTE(review): see the note on the `Send` impl above; the same llama.cpp guarantee applies.
unsafe impl Sync for LlamaModel {}
125
126impl LlamaModel {
127    pub(crate) fn vocab_ptr(&self) -> *const llama_cpp_sys_2::llama_vocab {
128        unsafe { llama_cpp_sys_2::llama_model_get_vocab(self.model.as_ptr()) }
129    }
130
131    /// get the number of tokens the model was trained on
132    ///
133    /// # Panics
134    ///
135    /// If the number of tokens the model was trained on does not fit into an `u32`. This should be impossible on most
136    /// platforms due to llama.cpp returning a `c_int` (i32 on most platforms) which is almost certainly positive.
137    #[must_use]
138    pub fn n_ctx_train(&self) -> u32 {
139        let n_ctx_train = unsafe { llama_cpp_sys_2::llama_n_ctx_train(self.model.as_ptr()) };
140        u32::try_from(n_ctx_train).expect("n_ctx_train fits into an u32")
141    }
142
143    /// Get all tokens in the model.
144    pub fn tokens(
145        &self,
146        special: Special,
147    ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
148        (0..self.n_vocab())
149            .map(LlamaToken::new)
150            .map(move |llama_token| (llama_token, self.token_to_str(llama_token, special)))
151    }
152
153    /// Get the beginning of stream token.
154    #[must_use]
155    pub fn token_bos(&self) -> LlamaToken {
156        let token = unsafe { llama_cpp_sys_2::llama_token_bos(self.vocab_ptr()) };
157        LlamaToken(token)
158    }
159
160    /// Get the end of stream token.
161    #[must_use]
162    pub fn token_eos(&self) -> LlamaToken {
163        let token = unsafe { llama_cpp_sys_2::llama_token_eos(self.vocab_ptr()) };
164        LlamaToken(token)
165    }
166
167    /// Get the newline token.
168    #[must_use]
169    pub fn token_nl(&self) -> LlamaToken {
170        let token = unsafe { llama_cpp_sys_2::llama_token_nl(self.vocab_ptr()) };
171        LlamaToken(token)
172    }
173
174    /// Check if a token represents the end of generation (end of turn, end of sequence, etc.)
175    #[must_use]
176    pub fn is_eog_token(&self, token: LlamaToken) -> bool {
177        unsafe { llama_cpp_sys_2::llama_token_is_eog(self.vocab_ptr(), token.0) }
178    }
179
180    /// Get the decoder start token.
181    #[must_use]
182    pub fn decode_start_token(&self) -> LlamaToken {
183        let token =
184            unsafe { llama_cpp_sys_2::llama_model_decoder_start_token(self.model.as_ptr()) };
185        LlamaToken(token)
186    }
187
188    /// Get the separator token (SEP).
189    #[must_use]
190    pub fn token_sep(&self) -> LlamaToken {
191        let token = unsafe { llama_cpp_sys_2::llama_vocab_sep(self.vocab_ptr()) };
192        LlamaToken(token)
193    }
194
195    /// Convert single token to a string.
196    ///
197    /// # Errors
198    ///
199    /// See [`TokenToStringError`] for more information.
200    pub fn token_to_str(
201        &self,
202        token: LlamaToken,
203        special: Special,
204    ) -> Result<String, TokenToStringError> {
205        let bytes = self.token_to_bytes(token, special)?;
206        Ok(String::from_utf8(bytes)?)
207    }
208
209    /// Convert single token to bytes.
210    ///
211    /// # Errors
212    /// See [`TokenToStringError`] for more information.
213    ///
214    /// # Panics
215    /// If a [`TokenToStringError::InsufficientBufferSpace`] error returned by
216    /// [`Self::token_to_bytes_with_size`] contains a positive nonzero value. This should never
217    /// happen.
218    pub fn token_to_bytes(
219        &self,
220        token: LlamaToken,
221        special: Special,
222    ) -> Result<Vec<u8>, TokenToStringError> {
223        match self.token_to_bytes_with_size(token, 8, special, None) {
224            Err(TokenToStringError::InsufficientBufferSpace(i)) => self.token_to_bytes_with_size(
225                token,
226                (-i).try_into().expect("Error buffer size is positive"),
227                special,
228                None,
229            ),
230            x => x,
231        }
232    }
233
234    /// Convert a vector of tokens to a single string.
235    ///
236    /// # Errors
237    ///
238    /// See [`TokenToStringError`] for more information.
239    pub fn tokens_to_str(
240        &self,
241        tokens: &[LlamaToken],
242        special: Special,
243    ) -> Result<String, TokenToStringError> {
244        let mut builder: Vec<u8> = Vec::with_capacity(tokens.len() * 4);
245        for piece in tokens
246            .iter()
247            .copied()
248            .map(|t| self.token_to_bytes(t, special))
249        {
250            builder.extend_from_slice(&piece?);
251        }
252        Ok(String::from_utf8(builder)?)
253    }
254
255    /// Convert a string to a Vector of tokens.
256    ///
257    /// # Errors
258    ///
259    /// - if [`str`] contains a null byte.
260    ///
261    /// # Panics
262    ///
263    /// - if there is more than [`usize::MAX`] [`LlamaToken`]s in [`str`].
264    ///
265    ///
266    /// ```no_run
267    /// use llama_cpp_2::model::LlamaModel;
268    ///
269    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
270    /// use std::path::Path;
271    /// use llama_cpp_2::model::AddBos;
272    /// let backend = llama_cpp_2::llama_backend::LlamaBackend::init()?;
273    /// let model = LlamaModel::load_from_file(&backend, Path::new("path/to/model"), &Default::default())?;
274    /// let tokens = model.str_to_token("Hello, World!", AddBos::Always)?;
275    /// # Ok(())
276    /// # }
277    pub fn str_to_token(
278        &self,
279        str: &str,
280        add_bos: AddBos,
281    ) -> Result<Vec<LlamaToken>, StringToTokenError> {
282        let add_bos = match add_bos {
283            AddBos::Always => true,
284            AddBos::Never => false,
285        };
286
287        let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
288        let mut buffer: Vec<LlamaToken> = Vec::with_capacity(tokens_estimation);
289
290        let c_string = CString::new(str)?;
291        let buffer_capacity =
292            c_int::try_from(buffer.capacity()).expect("buffer capacity should fit into a c_int");
293
294        let size = unsafe {
295            llama_cpp_sys_2::llama_tokenize(
296                self.vocab_ptr(),
297                c_string.as_ptr(),
298                c_int::try_from(c_string.as_bytes().len())?,
299                buffer.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>(),
300                buffer_capacity,
301                add_bos,
302                true,
303            )
304        };
305
306        // if we fail the first time we can resize the vector to the correct size and try again. This should never fail.
307        // as a result - size is guaranteed to be positive here.
308        let size = if size.is_negative() {
309            buffer.reserve_exact(usize::try_from(-size).expect("usize's are larger "));
310            unsafe {
311                llama_cpp_sys_2::llama_tokenize(
312                    self.vocab_ptr(),
313                    c_string.as_ptr(),
314                    c_int::try_from(c_string.as_bytes().len())?,
315                    buffer.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>(),
316                    -size,
317                    add_bos,
318                    true,
319                )
320            }
321        } else {
322            size
323        };
324
325        let size = usize::try_from(size).expect("size is positive and usize ");
326
327        // Safety: `size` < `capacity` and llama-cpp has initialized elements up to `size`
328        unsafe { buffer.set_len(size) }
329        Ok(buffer)
330    }
331
332    /// Get the type of a token.
333    ///
334    /// # Panics
335    ///
336    /// If the token type is not known to this library.
337    #[must_use]
338    pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs {
339        let token_type = unsafe { llama_cpp_sys_2::llama_token_get_attr(self.vocab_ptr(), id) };
340        LlamaTokenAttrs::try_from(token_type).expect("token type is valid")
341    }
342
343    /// Convert a token to a string with a specified buffer size.
344    ///
345    /// Generally you should use [`LlamaModel::token_to_str`] as it is able to decode tokens with
346    /// any length.
347    ///
348    /// # Errors
349    ///
350    /// - if the token type is unknown
351    /// - the resultant token is larger than `buffer_size`.
352    /// - the string returend by llama-cpp is not valid utf8.
353    ///
354    /// # Panics
355    ///
356    /// - if `buffer_size` does not fit into a [`c_int`].
357    /// - if the returned size from llama-cpp does not fit into a [`usize`]. (this should never happen)
358    pub fn token_to_str_with_size(
359        &self,
360        token: LlamaToken,
361        buffer_size: usize,
362        special: Special,
363    ) -> Result<String, TokenToStringError> {
364        let bytes = self.token_to_bytes_with_size(token, buffer_size, special, None)?;
365        Ok(String::from_utf8(bytes)?)
366    }
367
368    /// Convert a token to bytes with a specified buffer size.
369    ///
370    /// Generally you should use [`LlamaModel::token_to_bytes`] as it is able to handle tokens of
371    /// any length.
372    ///
373    /// # Errors
374    ///
375    /// - if the token type is unknown
376    /// - the resultant token is larger than `buffer_size`.
377    ///
378    /// # Panics
379    ///
380    /// - if `buffer_size` does not fit into a [`c_int`].
381    /// - if the returned size from llama-cpp does not fit into a [`usize`]. (this should never happen)
382    pub fn token_to_bytes_with_size(
383        &self,
384        token: LlamaToken,
385        buffer_size: usize,
386        special: Special,
387        lstrip: Option<NonZeroU16>,
388    ) -> Result<Vec<u8>, TokenToStringError> {
389        if token == self.token_nl() {
390            return Ok(b"\n".to_vec());
391        }
392
393        // unsure what to do with this in the face of the 'special' arg + attr changes
394        let attrs = self.token_attr(token);
395        if attrs.is_empty()
396            || attrs
397                .intersects(LlamaTokenAttr::Unknown | LlamaTokenAttr::Byte | LlamaTokenAttr::Unused)
398            || attrs.contains(LlamaTokenAttr::Control)
399                && (token == self.token_bos() || token == self.token_eos())
400        {
401            return Ok(Vec::new());
402        }
403
404        let special = match special {
405            Special::Tokenize => true,
406            Special::Plaintext => false,
407        };
408
409        let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
410        let len = string.as_bytes().len();
411        let len = c_int::try_from(len).expect("length fits into c_int");
412        let buf = string.into_raw();
413        let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
414        let size = unsafe {
415            llama_cpp_sys_2::llama_token_to_piece(
416                self.vocab_ptr(),
417                token.0,
418                buf,
419                len,
420                lstrip,
421                special,
422            )
423        };
424
425        match size {
426            0 => Err(TokenToStringError::UnknownTokenType),
427            i if i.is_negative() => Err(TokenToStringError::InsufficientBufferSpace(i)),
428            size => {
429                let string = unsafe { CString::from_raw(buf) };
430                let mut bytes = string.into_bytes();
431                let len = usize::try_from(size).expect("size is positive and fits into usize");
432                bytes.truncate(len);
433                Ok(bytes)
434            }
435        }
436    }
437    /// The number of tokens the model was trained on.
438    ///
439    /// This returns a `c_int` for maximum compatibility. Most of the time it can be cast to an i32
440    /// without issue.
441    #[must_use]
442    pub fn n_vocab(&self) -> i32 {
443        unsafe { llama_cpp_sys_2::llama_n_vocab(self.vocab_ptr()) }
444    }
445
446    /// The type of vocab the model was trained on.
447    ///
448    /// # Panics
449    ///
450    /// If llama-cpp emits a vocab type that is not known to this library.
451    #[must_use]
452    pub fn vocab_type(&self) -> VocabType {
453        // llama_cpp_sys_2::llama_model_get_vocab
454        let vocab_type = unsafe { llama_cpp_sys_2::llama_vocab_type(self.vocab_ptr()) };
455        VocabType::try_from(vocab_type).expect("invalid vocab type")
456    }
457
458    /// This returns a `c_int` for maximum compatibility. Most of the time it can be cast to an i32
459    /// without issue.
460    #[must_use]
461    pub fn n_embd(&self) -> c_int {
462        unsafe { llama_cpp_sys_2::llama_n_embd(self.model.as_ptr()) }
463    }
464
465    /// Returns the total size of all the tensors in the model in bytes.
466    pub fn size(&self) -> u64 {
467        unsafe { llama_cpp_sys_2::llama_model_size(self.model.as_ptr()) }
468    }
469
470    /// Returns the number of parameters in the model.
471    pub fn n_params(&self) -> u64 {
472        unsafe { llama_cpp_sys_2::llama_model_n_params(self.model.as_ptr()) }
473    }
474
475    /// Returns whether the model is a recurrent network (Mamba, RWKV, etc)
476    pub fn is_recurrent(&self) -> bool {
477        unsafe { llama_cpp_sys_2::llama_model_is_recurrent(self.model.as_ptr()) }
478    }
479
480    /// Returns the number of layers within the model.
481    pub fn n_layer(&self) -> u32 {
482        // It's never possible for this to panic because while the API interface is defined as an int32_t,
483        // the field it's accessing is a uint32_t.
484        u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_layer(self.model.as_ptr()) }).unwrap()
485    }
486
487    /// Returns the number of attention heads within the model.
488    pub fn n_head(&self) -> u32 {
489        // It's never possible for this to panic because while the API interface is defined as an int32_t,
490        // the field it's accessing is a uint32_t.
491        u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_head(self.model.as_ptr()) }).unwrap()
492    }
493
494    /// Returns the number of KV attention heads.
495    pub fn n_head_kv(&self) -> u32 {
496        // It's never possible for this to panic because while the API interface is defined as an int32_t,
497        // the field it's accessing is a uint32_t.
498        u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_head_kv(self.model.as_ptr()) })
499            .unwrap()
500    }
501
502    /// Get metadata value as a string by key name
503    pub fn meta_val_str(&self, key: &str) -> Result<String, MetaValError> {
504        let key_cstring = CString::new(key)?;
505        let key_ptr = key_cstring.as_ptr();
506
507        extract_meta_string(
508            |buf_ptr, buf_len| unsafe {
509                llama_cpp_sys_2::llama_model_meta_val_str(
510                    self.model.as_ptr(),
511                    key_ptr,
512                    buf_ptr,
513                    buf_len,
514                )
515            },
516            256,
517        )
518    }
519
520    /// Get the number of metadata key/value pairs
521    pub fn meta_count(&self) -> i32 {
522        unsafe { llama_cpp_sys_2::llama_model_meta_count(self.model.as_ptr()) }
523    }
524
525    /// Get metadata key name by index
526    pub fn meta_key_by_index(&self, index: i32) -> Result<String, MetaValError> {
527        extract_meta_string(
528            |buf_ptr, buf_len| unsafe {
529                llama_cpp_sys_2::llama_model_meta_key_by_index(
530                    self.model.as_ptr(),
531                    index,
532                    buf_ptr,
533                    buf_len,
534                )
535            },
536            256,
537        )
538    }
539
540    /// Get metadata value as a string by index
541    pub fn meta_val_str_by_index(&self, index: i32) -> Result<String, MetaValError> {
542        extract_meta_string(
543            |buf_ptr, buf_len| unsafe {
544                llama_cpp_sys_2::llama_model_meta_val_str_by_index(
545                    self.model.as_ptr(),
546                    index,
547                    buf_ptr,
548                    buf_len,
549                )
550            },
551            256,
552        )
553    }
554
555    /// Returns the rope type of the model.
556    pub fn rope_type(&self) -> Option<RopeType> {
557        match unsafe { llama_cpp_sys_2::llama_model_rope_type(self.model.as_ptr()) } {
558            llama_cpp_sys_2::LLAMA_ROPE_TYPE_NONE => None,
559            llama_cpp_sys_2::LLAMA_ROPE_TYPE_NORM => Some(RopeType::Norm),
560            llama_cpp_sys_2::LLAMA_ROPE_TYPE_NEOX => Some(RopeType::NeoX),
561            llama_cpp_sys_2::LLAMA_ROPE_TYPE_MROPE => Some(RopeType::MRope),
562            llama_cpp_sys_2::LLAMA_ROPE_TYPE_VISION => Some(RopeType::Vision),
563            rope_type => {
564                tracing::error!(rope_type = rope_type, "Unexpected rope type from llama.cpp");
565                None
566            }
567        }
568    }
569
570    /// Get chat template from model by name. If the name parameter is None, the default chat template will be returned.
571    ///
572    /// You supply this into [Self::apply_chat_template] to get back a string with the appropriate template
573    /// substitution applied to convert a list of messages into a prompt the LLM can use to complete
574    /// the chat.
575    ///
576    /// You could also use an external jinja parser, like [minijinja](https://github.com/mitsuhiko/minijinja),
577    /// to parse jinja templates not supported by the llama.cpp template engine.
578    ///
579    /// # Errors
580    ///
581    /// * If the model has no chat template by that name
582    /// * If the chat template is not a valid [`CString`].
583    pub fn chat_template(
584        &self,
585        name: Option<&str>,
586    ) -> Result<LlamaChatTemplate, ChatTemplateError> {
587        let name_cstr = name.map(CString::new);
588        let name_ptr = match name_cstr {
589            Some(Ok(name)) => name.as_ptr(),
590            _ => std::ptr::null(),
591        };
592        let result =
593            unsafe { llama_cpp_sys_2::llama_model_chat_template(self.model.as_ptr(), name_ptr) };
594
595        // Convert result to Rust String if not null
596        if result.is_null() {
597            Err(ChatTemplateError::MissingTemplate)
598        } else {
599            let chat_template_cstr = unsafe { CStr::from_ptr(result) };
600            let chat_template = CString::new(chat_template_cstr.to_bytes())?;
601            Ok(LlamaChatTemplate(chat_template))
602        }
603    }
604
605    /// Loads a model from a file.
606    ///
607    /// # Errors
608    ///
609    /// See [`LlamaModelLoadError`] for more information.
610    #[tracing::instrument(skip_all, fields(params))]
611    pub fn load_from_file(
612        _: &LlamaBackend,
613        path: impl AsRef<Path>,
614        params: &LlamaModelParams,
615    ) -> Result<Self, LlamaModelLoadError> {
616        let path = path.as_ref();
617        debug_assert!(Path::new(path).exists(), "{path:?} does not exist");
618        let path = path
619            .to_str()
620            .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
621
622        let cstr = CString::new(path)?;
623        let llama_model =
624            unsafe { llama_cpp_sys_2::llama_load_model_from_file(cstr.as_ptr(), params.params) };
625
626        let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
627
628        tracing::debug!(?path, "Loaded model");
629        Ok(LlamaModel { model })
630    }
631
632    /// Initializes a lora adapter from a file.
633    ///
634    /// # Errors
635    ///
636    /// See [`LlamaLoraAdapterInitError`] for more information.
637    pub fn lora_adapter_init(
638        &self,
639        path: impl AsRef<Path>,
640    ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
641        let path = path.as_ref();
642        debug_assert!(Path::new(path).exists(), "{path:?} does not exist");
643
644        let path = path
645            .to_str()
646            .ok_or(LlamaLoraAdapterInitError::PathToStrError(
647                path.to_path_buf(),
648            ))?;
649
650        let cstr = CString::new(path)?;
651        let adapter =
652            unsafe { llama_cpp_sys_2::llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr()) };
653
654        let adapter = NonNull::new(adapter).ok_or(LlamaLoraAdapterInitError::NullResult)?;
655
656        tracing::debug!(?path, "Initialized lora adapter");
657        Ok(LlamaLoraAdapter {
658            lora_adapter: adapter,
659        })
660    }
661
662    /// Create a new context from this model.
663    ///
664    /// # Errors
665    ///
666    /// There is many ways this can fail. See [`LlamaContextLoadError`] for more information.
667    // we intentionally do not derive Copy on `LlamaContextParams` to allow llama.cpp to change the type to be non-trivially copyable.
668    #[allow(clippy::needless_pass_by_value)]
669    pub fn new_context(
670        &self,
671        _: &LlamaBackend,
672        params: LlamaContextParams,
673    ) -> Result<LlamaContext, LlamaContextLoadError> {
674        let context_params = params.context_params;
675        let context = unsafe {
676            llama_cpp_sys_2::llama_new_context_with_model(self.model.as_ptr(), context_params)
677        };
678        let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;
679
680        Ok(LlamaContext::new(self, context, params.embeddings()))
681    }
682
683    /// Apply the models chat template to some messages.
684    /// See https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
685    ///
686    /// Unlike the llama.cpp apply_chat_template which just randomly uses the ChatML template when given
687    /// a null pointer for the template, this requires an explicit template to be specified. If you want to
688    /// use "chatml", then just do `LlamaChatTemplate::new("chatml")` or any other model name or template
689    /// string.
690    ///
691    /// Use [Self::chat_template] to retrieve the template baked into the model (this is the preferred
692    /// mechanism as using the wrong chat template can result in really unexpected responses from the LLM).
693    ///
694    /// You probably want to set `add_ass` to true so that the generated template string ends with a the
695    /// opening tag of the assistant. If you fail to leave a hanging chat tag, the model will likely generate
696    /// one into the output and the output may also have unexpected output aside from that.
697    ///
698    /// # Errors
699    /// There are many ways this can fail. See [`ApplyChatTemplateError`] for more information.
700    #[tracing::instrument(skip_all)]
701    pub fn apply_chat_template(
702        &self,
703        tmpl: &LlamaChatTemplate,
704        chat: &[LlamaChatMessage],
705        add_ass: bool,
706    ) -> Result<String, ApplyChatTemplateError> {
707        // Buffer is twice the length of messages per their recommendation
708        let message_length = chat.iter().fold(0, |acc, c| {
709            acc + c.role.to_bytes().len() + c.content.to_bytes().len()
710        });
711        let mut buff: Vec<u8> = vec![0; message_length * 2];
712
713        // Build our llama_cpp_sys_2 chat messages
714        let chat: Vec<llama_cpp_sys_2::llama_chat_message> = chat
715            .iter()
716            .map(|c| llama_cpp_sys_2::llama_chat_message {
717                role: c.role.as_ptr(),
718                content: c.content.as_ptr(),
719            })
720            .collect();
721
722        let tmpl_ptr = tmpl.0.as_ptr();
723
724        let res = unsafe {
725            llama_cpp_sys_2::llama_chat_apply_template(
726                tmpl_ptr,
727                chat.as_ptr(),
728                chat.len(),
729                add_ass,
730                buff.as_mut_ptr().cast::<c_char>(),
731                buff.len().try_into().expect("Buffer size exceeds i32::MAX"),
732            )
733        };
734
735        if res > buff.len().try_into().expect("Buffer size exceeds i32::MAX") {
736            buff.resize(res.try_into().expect("res is negative"), 0);
737
738            let res = unsafe {
739                llama_cpp_sys_2::llama_chat_apply_template(
740                    tmpl_ptr,
741                    chat.as_ptr(),
742                    chat.len(),
743                    add_ass,
744                    buff.as_mut_ptr().cast::<c_char>(),
745                    buff.len().try_into().expect("Buffer size exceeds i32::MAX"),
746                )
747            };
748            assert_eq!(Ok(res), buff.len().try_into());
749        }
750        buff.truncate(res.try_into().expect("res is negative"));
751        Ok(String::from_utf8(buff)?)
752    }
753}
754
755/// Generic helper function for extracting string values from the C API
756/// This are specifically useful for the the metadata functions, where we pass in a buffer
757/// to be populated by a string, not yet knowing if the buffer is large enough.
758/// If the buffer was not large enough, we get the correct length back, which can be used to
759/// construct a buffer of appropriate size.
760fn extract_meta_string<F>(c_function: F, capacity: usize) -> Result<String, MetaValError>
761where
762    F: Fn(*mut c_char, usize) -> i32,
763{
764    let mut buffer = vec![0u8; capacity];
765
766    // call the foreign function
767    let result = c_function(buffer.as_mut_ptr() as *mut c_char, buffer.len());
768    if result < 0 {
769        return Err(MetaValError::NegativeReturn(result));
770    }
771
772    // check if the response fit in our buffer
773    let returned_len = result as usize;
774    if returned_len >= capacity {
775        // buffer wasn't large enough, try again with the correct capacity.
776        return extract_meta_string(c_function, returned_len + 1);
777    }
778
779    // verify null termination
780    debug_assert_eq!(
781        buffer.get(returned_len),
782        Some(&0),
783        "should end with null byte"
784    );
785
786    // resize, convert, and return
787    buffer.truncate(returned_len);
788    Ok(String::from_utf8(buffer)?)
789}
790
791impl Drop for LlamaModel {
792    fn drop(&mut self) {
793        unsafe { llama_cpp_sys_2::llama_free_model(self.model.as_ptr()) }
794    }
795}
796
797/// a rusty equivalent of `llama_vocab_type`
798#[repr(u32)]
799#[derive(Debug, Eq, Copy, Clone, PartialEq)]
800pub enum VocabType {
801    /// Byte Pair Encoding
802    BPE = llama_cpp_sys_2::LLAMA_VOCAB_TYPE_BPE as _,
803    /// Sentence Piece Tokenizer
804    SPM = llama_cpp_sys_2::LLAMA_VOCAB_TYPE_SPM as _,
805}
806
807/// There was an error converting a `llama_vocab_type` to a `VocabType`.
808#[derive(thiserror::Error, Debug, Eq, PartialEq)]
809pub enum LlamaTokenTypeFromIntError {
810    /// The value is not a valid `llama_token_type`. Contains the int value that was invalid.
811    #[error("Unknown Value {0}")]
812    UnknownValue(llama_cpp_sys_2::llama_vocab_type),
813}
814
815impl TryFrom<llama_cpp_sys_2::llama_vocab_type> for VocabType {
816    type Error = LlamaTokenTypeFromIntError;
817
818    fn try_from(value: llama_cpp_sys_2::llama_vocab_type) -> Result<Self, Self::Error> {
819        match value {
820            llama_cpp_sys_2::LLAMA_VOCAB_TYPE_BPE => Ok(VocabType::BPE),
821            llama_cpp_sys_2::LLAMA_VOCAB_TYPE_SPM => Ok(VocabType::SPM),
822            unknown => Err(LlamaTokenTypeFromIntError::UnknownValue(unknown)),
823        }
824    }
825}