// llama_cpp_bindings/model.rs

1//! A safe wrapper around `llama_model`.
2use std::ffi::{CStr, CString, c_char};
3use std::num::NonZeroU16;
4use std::os::raw::c_int;
5use std::path::Path;
6use std::ptr::{self, NonNull};
7
8use crate::context::LlamaContext;
9use crate::context::params::LlamaContextParams;
10use crate::llama_backend::LlamaBackend;
11use crate::openai::OpenAIChatTemplateParams;
12use crate::token::LlamaToken;
13use crate::token_type::LlamaTokenAttrs;
14use crate::{
15    ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
16    LlamaModelLoadError, MetaValError, StringToTokenError, TokenToStringError,
17};
18
19pub mod add_bos;
20pub mod chat_template_result;
21pub mod grammar_trigger;
22pub mod llama_chat_message;
23pub mod llama_chat_template;
24pub mod llama_lora_adapter;
25pub mod params;
26pub mod rope_type;
27pub mod split_mode;
28pub mod vocab_type;
29
30pub use add_bos::AddBos;
31pub use chat_template_result::ChatTemplateResult;
32pub use grammar_trigger::{GrammarTrigger, GrammarTriggerType};
33pub use llama_chat_message::LlamaChatMessage;
34pub use llama_chat_template::LlamaChatTemplate;
35pub use llama_lora_adapter::LlamaLoraAdapter;
36pub use rope_type::RopeType;
37pub use vocab_type::{LlamaTokenTypeFromIntError, VocabType};
38
39use chat_template_result::{new_empty_chat_template_raw_result, parse_chat_template_raw_result};
40use params::LlamaModelParams;
41
/// A safe wrapper around `llama_model`.
///
/// Owns the underlying llama.cpp model handle; the pointer is released in the
/// [`Drop`] impl at the bottom of this file, so it must be freed nowhere else.
#[derive(Debug)]
#[repr(transparent)]
pub struct LlamaModel {
    /// Raw pointer to the underlying `llama_model`. Non-null by construction
    /// (see `load_from_file`).
    pub model: NonNull<llama_cpp_bindings_sys::llama_model>,
}
49
// SAFETY: NOTE(review): assumes a `llama_model` handle may be moved to another
// thread — confirm against llama.cpp's documented thread-safety guarantees.
unsafe impl Send for LlamaModel {}
51
// SAFETY: NOTE(review): assumes concurrent `&LlamaModel` access is sound; the
// wrapper only exposes read-style FFI calls, but llama.cpp's internal
// synchronization should be verified.
unsafe impl Sync for LlamaModel {}
53
54impl LlamaModel {
    /// Returns a raw pointer to the model's vocabulary.
    ///
    /// Used by the token-oriented methods below; presumably valid for as long
    /// as the model itself is alive.
    #[must_use]
    pub fn vocab_ptr(&self) -> *const llama_cpp_bindings_sys::llama_vocab {
        unsafe { llama_cpp_bindings_sys::llama_model_get_vocab(self.model.as_ptr()) }
    }
60
61    /// get the number of tokens the model was trained on
62    ///
63    /// # Panics
64    ///
65    /// If the number of tokens the model was trained on does not fit into an `u32`. This should be impossible on most
66    /// platforms due to llama.cpp returning a `c_int` (i32 on most platforms) which is almost certainly positive.
67    #[must_use]
68    pub fn n_ctx_train(&self) -> u32 {
69        let n_ctx_train = unsafe { llama_cpp_bindings_sys::llama_n_ctx_train(self.model.as_ptr()) };
70        u32::try_from(n_ctx_train).expect("n_ctx_train fits into an u32")
71    }
72
73    /// Get all tokens in the model.
74    pub fn tokens(
75        &self,
76        decode_special: bool,
77    ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
78        (0..self.n_vocab())
79            .map(LlamaToken::new)
80            .map(move |llama_token| {
81                let mut decoder = encoding_rs::UTF_8.new_decoder();
82                (
83                    llama_token,
84                    self.token_to_piece(llama_token, &mut decoder, decode_special, None),
85                )
86            })
87    }
88
89    /// Get the beginning of stream token.
90    #[must_use]
91    pub fn token_bos(&self) -> LlamaToken {
92        let token = unsafe { llama_cpp_bindings_sys::llama_token_bos(self.vocab_ptr()) };
93        LlamaToken(token)
94    }
95
96    /// Get the end of stream token.
97    #[must_use]
98    pub fn token_eos(&self) -> LlamaToken {
99        let token = unsafe { llama_cpp_bindings_sys::llama_token_eos(self.vocab_ptr()) };
100        LlamaToken(token)
101    }
102
103    /// Get the newline token.
104    #[must_use]
105    pub fn token_nl(&self) -> LlamaToken {
106        let token = unsafe { llama_cpp_bindings_sys::llama_token_nl(self.vocab_ptr()) };
107        LlamaToken(token)
108    }
109
    /// Check if a token represents the end of generation (end of turn, end of sequence, etc.)
    #[must_use]
    pub fn is_eog_token(&self, token: LlamaToken) -> bool {
        // Delegates entirely to llama.cpp's own end-of-generation classification.
        unsafe { llama_cpp_bindings_sys::llama_token_is_eog(self.vocab_ptr(), token.0) }
    }
115
116    /// Get the decoder start token.
117    #[must_use]
118    pub fn decode_start_token(&self) -> LlamaToken {
119        let token =
120            unsafe { llama_cpp_bindings_sys::llama_model_decoder_start_token(self.model.as_ptr()) };
121        LlamaToken(token)
122    }
123
124    /// Get the separator token (SEP).
125    #[must_use]
126    pub fn token_sep(&self) -> LlamaToken {
127        let token = unsafe { llama_cpp_bindings_sys::llama_vocab_sep(self.vocab_ptr()) };
128        LlamaToken(token)
129    }
130
    /// Convert a string to a Vector of tokens.
    ///
    /// # Errors
    ///
    /// - if [`str`] contains a null byte.
    ///
    /// # Panics
    ///
    /// - if there is more than [`usize::MAX`] [`LlamaToken`]s in [`str`].
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use llama_cpp_bindings::model::LlamaModel;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// use std::path::Path;
    /// use llama_cpp_bindings::model::AddBos;
    /// let backend = llama_cpp_bindings::llama_backend::LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, Path::new("path/to/model"), &Default::default())?;
    /// let tokens = model.str_to_token("Hello, World!", AddBos::Always)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn str_to_token(
        &self,
        str: &str,
        add_bos: AddBos,
    ) -> Result<Vec<LlamaToken>, StringToTokenError> {
        let add_bos = match add_bos {
            AddBos::Always => true,
            AddBos::Never => false,
        };

        // Heuristic: roughly 2 bytes of input per token, plus room for BOS,
        // with a floor of 8 so tiny inputs still get a usable buffer.
        let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
        let mut buffer: Vec<LlamaToken> = Vec::with_capacity(tokens_estimation);

        let c_string = CString::new(str)?;
        let buffer_capacity =
            c_int::try_from(buffer.capacity()).expect("buffer capacity should fit into a c_int");

        // SAFETY: `buffer` has at least `buffer_capacity` elements of spare
        // capacity; the pointer cast relies on `LlamaToken` being
        // layout-compatible with `llama_token`.
        let size = unsafe {
            llama_cpp_bindings_sys::llama_tokenize(
                self.vocab_ptr(),
                c_string.as_ptr(),
                c_int::try_from(c_string.as_bytes().len())?,
                buffer
                    .as_mut_ptr()
                    .cast::<llama_cpp_bindings_sys::llama_token>(),
                buffer_capacity,
                add_bos,
                true,
            )
        };

        // A negative return value is the required token count negated.
        // if we fail the first time we can resize the vector to the correct size and try again. This should never fail.
        // as a result - size is guaranteed to be positive here.
        let size = if size.is_negative() {
            buffer.reserve_exact(usize::try_from(-size).expect("usize's are larger "));
            // SAFETY: `buffer` now has capacity for at least `-size` tokens
            // (same layout argument as above).
            unsafe {
                llama_cpp_bindings_sys::llama_tokenize(
                    self.vocab_ptr(),
                    c_string.as_ptr(),
                    c_int::try_from(c_string.as_bytes().len())?,
                    buffer
                        .as_mut_ptr()
                        .cast::<llama_cpp_bindings_sys::llama_token>(),
                    -size,
                    add_bos,
                    true,
                )
            }
        } else {
            size
        };

        let size = usize::try_from(size).expect("size is positive and usize ");

        // Safety: `size` < `capacity` and llama-cpp has initialized elements up to `size`
        unsafe { buffer.set_len(size) }

        Ok(buffer)
    }
212
    /// Get the type of a token.
    ///
    /// # Panics
    ///
    /// If the token type is not known to this library.
    #[must_use]
    pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs {
        let token_type =
            unsafe { llama_cpp_bindings_sys::llama_token_get_attr(self.vocab_ptr(), id) };
        // Conversion fails only for attribute bits this library does not model.
        LlamaTokenAttrs::try_from(token_type).expect("token type is valid")
    }
224
    /// Convert a token to a string using the underlying llama.cpp `llama_token_to_piece` function.
    ///
    /// This is the new default function for token decoding and provides direct access to
    /// the llama.cpp token decoding functionality without any special logic or filtering.
    ///
    /// Decoding raw string requires using an decoder, tokens from language models may not always map
    /// to full characters depending on the encoding so stateful decoding is required, otherwise partial strings may be lost!
    /// Invalid characters are mapped to REPLACEMENT CHARACTER making the method safe to use even if the model inherently produces
    /// garbage.
    ///
    /// # Errors
    ///
    /// - if the token type is unknown
    ///
    /// # Panics
    ///
    /// - if the returned size from llama-cpp does not fit into a [`usize`]. (this should never happen)
    pub fn token_to_piece(
        &self,
        token: LlamaToken,
        decoder: &mut encoding_rs::Decoder,
        special: bool,
        lstrip: Option<NonZeroU16>,
    ) -> Result<String, TokenToStringError> {
        // First attempt with a small 8-byte buffer; retry once with the exact
        // size if llama.cpp reports the buffer was too small.
        let bytes = match self.token_to_piece_bytes(token, 8, special, lstrip) {
            // when there is insufficient space `token_to_piece` will return a negative number with the size that would have been returned
            // https://github.com/abetlen/llama-cpp-python/blob/c37132bac860fcc333255c36313f89c4f49d4c8d/llama_cpp/llama_cpp.py#L3461
            Err(TokenToStringError::InsufficientBufferSpace(i)) => self.token_to_piece_bytes(
                token,
                (-i).try_into().expect("Error buffer size is positive"),
                special,
                lstrip,
            ),
            x => x,
        }?;
        // here the assumption is that each byte from the output may map to at most one output character
        let mut output_piece = String::with_capacity(bytes.len());
        // _result only tells if there is nothing more in the input, or if the output was full
        // but further decoding will happen on the next iteration anyway
        let (_result, _somesize, _truthy) =
            decoder.decode_to_string(&bytes, &mut output_piece, false);

        Ok(output_piece)
    }
269
    /// Raw token decoding to bytes, use if you want to handle the decoding model output yourself
    ///
    /// Convert a token to bytes using the underlying llama.cpp `llama_token_to_piece` function. This is mostly
    /// a thin wrapper around `llama_token_to_piece` function, that handles rust <-> c type conversions while
    /// letting the caller handle errors. For a safer interface returning rust strings directly use `token_to_piece` instead!
    ///
    /// # Errors
    ///
    /// - if the token type is unknown
    /// - the resultant token is larger than `buffer_size`.
    ///
    /// # Panics
    ///
    /// - if `buffer_size` does not fit into a [`c_int`].
    /// - if the returned size from llama-cpp does not fit into a [`usize`]. (this should never happen)
    pub fn token_to_piece_bytes(
        &self,
        token: LlamaToken,
        buffer_size: usize,
        special: bool,
        lstrip: Option<NonZeroU16>,
    ) -> Result<Vec<u8>, TokenToStringError> {
        // SAFETY: `*` (0x2A) is never `\0`, so CString::new cannot fail here
        let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
        let len = string.as_bytes().len();
        let len = c_int::try_from(len).expect("length fits into c_int");
        // Ownership of the allocation is temporarily handed to C; it is
        // reclaimed below via `CString::from_raw` on the success path.
        let buf = string.into_raw();
        let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
        // SAFETY: `buf` points to `len` writable bytes plus a NUL terminator.
        let size = unsafe {
            llama_cpp_bindings_sys::llama_token_to_piece(
                self.vocab_ptr(),
                token.0,
                buf,
                len,
                lstrip,
                special,
            )
        };

        match size {
            0 => Err(TokenToStringError::UnknownTokenType),
            // Negative return: buffer too small; the value carries the needed size.
            // NOTE(review): on these two error paths `buf` is not reclaimed via
            // `CString::from_raw`, so the allocation leaks — confirm and fix.
            i if i.is_negative() => Err(TokenToStringError::InsufficientBufferSpace(i)),
            size => {
                // SAFETY: `buf` came from `CString::into_raw` above and is
                // reclaimed exactly once here.
                let string = unsafe { CString::from_raw(buf) };
                let mut bytes = string.into_bytes();
                let len = usize::try_from(size).expect("size is positive and fits into usize");
                bytes.truncate(len);

                Ok(bytes)
            }
        }
    }
322
    /// The number of tokens the model was trained on.
    ///
    /// This returns a `c_int` for maximum compatibility. Most of the time it can be cast to an i32
    /// without issue.
    #[must_use]
    pub fn n_vocab(&self) -> i32 {
        // Vocabulary size as reported by llama.cpp.
        unsafe { llama_cpp_bindings_sys::llama_n_vocab(self.vocab_ptr()) }
    }
331
    /// The type of vocab the model was trained on.
    ///
    /// # Panics
    ///
    /// If llama-cpp emits a vocab type that is not known to this library.
    #[must_use]
    pub fn vocab_type(&self) -> VocabType {
        let vocab_type = unsafe { llama_cpp_bindings_sys::llama_vocab_type(self.vocab_ptr()) };
        // Only fails for enum values newer than this library's `VocabType`.
        VocabType::try_from(vocab_type).expect("invalid vocab type")
    }
342
    /// The embedding dimension of the model.
    ///
    /// This returns a `c_int` for maximum compatibility. Most of the time it can be cast to an i32
    /// without issue.
    #[must_use]
    pub fn n_embd(&self) -> c_int {
        unsafe { llama_cpp_bindings_sys::llama_n_embd(self.model.as_ptr()) }
    }
349
    /// Returns the total size of all the tensors in the model in bytes.
    #[must_use]
    pub fn size(&self) -> u64 {
        unsafe { llama_cpp_bindings_sys::llama_model_size(self.model.as_ptr()) }
    }
355
    /// Returns the number of parameters in the model.
    #[must_use]
    pub fn n_params(&self) -> u64 {
        unsafe { llama_cpp_bindings_sys::llama_model_n_params(self.model.as_ptr()) }
    }
361
    /// Returns whether the model is a recurrent network (Mamba, RWKV, etc)
    #[must_use]
    pub fn is_recurrent(&self) -> bool {
        unsafe { llama_cpp_bindings_sys::llama_model_is_recurrent(self.model.as_ptr()) }
    }
367
368    /// Returns the number of layers within the model.
369    ///
370    /// # Panics
371    /// Panics if the layer count returned by llama.cpp is negative.
372    #[must_use]
373    pub fn n_layer(&self) -> u32 {
374        // llama.cpp API returns int32_t but the underlying field is uint32_t, so this is safe
375        u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_layer(self.model.as_ptr()) })
376            .expect("llama.cpp returns a positive value for n_layer")
377    }
378
379    /// Returns the number of attention heads within the model.
380    ///
381    /// # Panics
382    /// Panics if the head count returned by llama.cpp is negative.
383    #[must_use]
384    pub fn n_head(&self) -> u32 {
385        // llama.cpp API returns int32_t but the underlying field is uint32_t, so this is safe
386        u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_head(self.model.as_ptr()) })
387            .expect("llama.cpp returns a positive value for n_head")
388    }
389
390    /// Returns the number of KV attention heads.
391    ///
392    /// # Panics
393    /// Panics if the KV head count returned by llama.cpp is negative.
394    #[must_use]
395    pub fn n_head_kv(&self) -> u32 {
396        // llama.cpp API returns int32_t but the underlying field is uint32_t, so this is safe
397        u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_head_kv(self.model.as_ptr()) })
398            .expect("llama.cpp returns a positive value for n_head_kv")
399    }
400
    /// Get metadata value as a string by key name
    ///
    /// # Errors
    /// Returns an error if the key is not found, the value is not valid UTF-8,
    /// or `key` contains a null byte.
    pub fn meta_val_str(&self, key: &str) -> Result<String, MetaValError> {
        let key_cstring = CString::new(key)?;
        let key_ptr = key_cstring.as_ptr();

        // 256 bytes covers typical metadata values; `extract_meta_string`
        // retries with the exact size if the value is longer.
        extract_meta_string(
            |buf_ptr, buf_len| unsafe {
                llama_cpp_bindings_sys::llama_model_meta_val_str(
                    self.model.as_ptr(),
                    key_ptr,
                    buf_ptr,
                    buf_len,
                )
            },
            256,
        )
    }
421
    /// Get the number of metadata key/value pairs
    #[must_use]
    pub fn meta_count(&self) -> i32 {
        unsafe { llama_cpp_bindings_sys::llama_model_meta_count(self.model.as_ptr()) }
    }
427
    /// Get metadata key name by index
    ///
    /// # Errors
    /// Returns an error if the index is out of range or the key is not valid UTF-8.
    pub fn meta_key_by_index(&self, index: i32) -> Result<String, MetaValError> {
        // 256 bytes covers typical key names; `extract_meta_string` retries
        // with the exact size if the key is longer.
        extract_meta_string(
            |buf_ptr, buf_len| unsafe {
                llama_cpp_bindings_sys::llama_model_meta_key_by_index(
                    self.model.as_ptr(),
                    index,
                    buf_ptr,
                    buf_len,
                )
            },
            256,
        )
    }
445
    /// Get metadata value as a string by index
    ///
    /// # Errors
    /// Returns an error if the index is out of range or the value is not valid UTF-8.
    pub fn meta_val_str_by_index(&self, index: i32) -> Result<String, MetaValError> {
        // 256 bytes covers typical metadata values; `extract_meta_string`
        // retries with the exact size if the value is longer.
        extract_meta_string(
            |buf_ptr, buf_len| unsafe {
                llama_cpp_bindings_sys::llama_model_meta_val_str_by_index(
                    self.model.as_ptr(),
                    index,
                    buf_ptr,
                    buf_len,
                )
            },
            256,
        )
    }
463
    /// Returns the rope type of the model.
    ///
    /// Returns `None` both when llama.cpp reports `LLAMA_ROPE_TYPE_NONE` and
    /// when it reports a value unknown to this library (the latter is also
    /// logged as an error rather than panicking).
    pub fn rope_type(&self) -> Option<RopeType> {
        match unsafe { llama_cpp_bindings_sys::llama_model_rope_type(self.model.as_ptr()) } {
            llama_cpp_bindings_sys::LLAMA_ROPE_TYPE_NONE => None,
            llama_cpp_bindings_sys::LLAMA_ROPE_TYPE_NORM => Some(RopeType::Norm),
            llama_cpp_bindings_sys::LLAMA_ROPE_TYPE_NEOX => Some(RopeType::NeoX),
            llama_cpp_bindings_sys::LLAMA_ROPE_TYPE_MROPE => Some(RopeType::MRope),
            llama_cpp_bindings_sys::LLAMA_ROPE_TYPE_VISION => Some(RopeType::Vision),
            rope_type => {
                tracing::error!(rope_type = rope_type, "Unexpected rope type from llama.cpp");
                None
            }
        }
    }
478
479    /// Get chat template from model by name. If the name parameter is None, the default chat template will be returned.
480    ///
481    /// You supply this into [`Self::apply_chat_template`] to get back a string with the appropriate template
482    /// substitution applied to convert a list of messages into a prompt the LLM can use to complete
483    /// the chat.
484    ///
485    /// You could also use an external jinja parser, like [minijinja](https://github.com/mitsuhiko/minijinja),
486    /// to parse jinja templates not supported by the llama.cpp template engine.
487    ///
488    /// # Errors
489    ///
490    /// * If the model has no chat template by that name
491    /// * If the chat template is not a valid [`CString`].
492    pub fn chat_template(
493        &self,
494        name: Option<&str>,
495    ) -> Result<LlamaChatTemplate, ChatTemplateError> {
496        let name_cstr = name.map(CString::new);
497        let name_ptr = match name_cstr {
498            Some(Ok(name)) => name.as_ptr(),
499            _ => std::ptr::null(),
500        };
501        let result = unsafe {
502            llama_cpp_bindings_sys::llama_model_chat_template(self.model.as_ptr(), name_ptr)
503        };
504
505        // Convert result to Rust String if not null
506        if result.is_null() {
507            Err(ChatTemplateError::MissingTemplate)
508        } else {
509            let chat_template_cstr = unsafe { CStr::from_ptr(result) };
510            let chat_template = CString::new(chat_template_cstr.to_bytes())?;
511
512            Ok(LlamaChatTemplate(chat_template))
513        }
514    }
515
516    /// Loads a model from a file.
517    ///
518    /// # Errors
519    ///
520    /// See [`LlamaModelLoadError`] for more information.
521    #[tracing::instrument(skip_all, fields(params))]
522    pub fn load_from_file(
523        _: &LlamaBackend,
524        path: impl AsRef<Path>,
525        params: &LlamaModelParams,
526    ) -> Result<Self, LlamaModelLoadError> {
527        let path = path.as_ref();
528        debug_assert!(
529            Path::new(path).exists(),
530            "{} does not exist",
531            path.display()
532        );
533        let path = path
534            .to_str()
535            .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
536
537        let cstr = CString::new(path)?;
538        let llama_model = unsafe {
539            llama_cpp_bindings_sys::llama_load_model_from_file(cstr.as_ptr(), params.params)
540        };
541
542        let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
543
544        tracing::debug!(?path, "Loaded model");
545
546        Ok(LlamaModel { model })
547    }
548
549    /// Initializes a lora adapter from a file.
550    ///
551    /// # Errors
552    ///
553    /// See [`LlamaLoraAdapterInitError`] for more information.
554    pub fn lora_adapter_init(
555        &self,
556        path: impl AsRef<Path>,
557    ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
558        let path = path.as_ref();
559        debug_assert!(
560            Path::new(path).exists(),
561            "{} does not exist",
562            path.display()
563        );
564
565        let path = path
566            .to_str()
567            .ok_or(LlamaLoraAdapterInitError::PathToStrError(
568                path.to_path_buf(),
569            ))?;
570
571        let cstr = CString::new(path)?;
572        let adapter = unsafe {
573            llama_cpp_bindings_sys::llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr())
574        };
575
576        let adapter = NonNull::new(adapter).ok_or(LlamaLoraAdapterInitError::NullResult)?;
577
578        tracing::debug!(?path, "Initialized lora adapter");
579
580        Ok(LlamaLoraAdapter {
581            lora_adapter: adapter,
582        })
583    }
584
    /// Create a new context from this model.
    ///
    /// The returned context borrows `self`, which keeps the model alive for the
    /// context's entire lifetime.
    ///
    /// # Errors
    ///
    /// There is many ways this can fail. See [`LlamaContextLoadError`] for more information.
    // we intentionally do not derive Copy on `LlamaContextParams` to allow llama.cpp to change the type to be non-trivially copyable.
    pub fn new_context<'model>(
        &'model self,
        _: &LlamaBackend,
        params: LlamaContextParams,
    ) -> Result<LlamaContext<'model>, LlamaContextLoadError> {
        let context_params = params.context_params;
        let context = unsafe {
            llama_cpp_bindings_sys::llama_new_context_with_model(
                self.model.as_ptr(),
                context_params,
            )
        };
        let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;

        Ok(LlamaContext::new(self, context, params.embeddings()))
    }
607
    /// Apply the models chat template to some messages.
    /// See <https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template>
    ///
    /// Unlike the llama.cpp `apply_chat_template` which just randomly uses the `ChatML` template when given
    /// a null pointer for the template, this requires an explicit template to be specified. If you want to
    /// use "chatml", then just do `LlamaChatTemplate::new("chatml")` or any other model name or template
    /// string.
    ///
    /// Use [`Self::chat_template`] to retrieve the template baked into the model (this is the preferred
    /// mechanism as using the wrong chat template can result in really unexpected responses from the LLM).
    ///
    /// You probably want to set `add_ass` to true so that the generated template string ends with the
    /// opening tag of the assistant. If you fail to leave a hanging chat tag, the model will likely generate
    /// one into the output and the output may also have unexpected output aside from that.
    ///
    /// # Errors
    /// There are many ways this can fail. See [`ApplyChatTemplateError`] for more information.
    ///
    /// # Panics
    /// Panics if the buffer size exceeds `i32::MAX`.
    #[tracing::instrument(skip_all)]
    pub fn apply_chat_template(
        &self,
        tmpl: &LlamaChatTemplate,
        chat: &[LlamaChatMessage],
        add_ass: bool,
    ) -> Result<String, ApplyChatTemplateError> {
        // Buffer is twice the length of messages per their recommendation
        let message_length = chat.iter().fold(0, |acc, c| {
            acc + c.role.to_bytes().len() + c.content.to_bytes().len()
        });
        let mut buff: Vec<u8> = vec![0; message_length * 2];

        // Build our llama_cpp_bindings_sys chat messages
        let chat: Vec<llama_cpp_bindings_sys::llama_chat_message> = chat
            .iter()
            .map(|c| llama_cpp_bindings_sys::llama_chat_message {
                role: c.role.as_ptr(),
                content: c.content.as_ptr(),
            })
            .collect();

        let tmpl_ptr = tmpl.0.as_ptr();

        // First attempt: llama.cpp returns the number of bytes it needed,
        // which may exceed the buffer we provided.
        let res = unsafe {
            llama_cpp_bindings_sys::llama_chat_apply_template(
                tmpl_ptr,
                chat.as_ptr(),
                chat.len(),
                add_ass,
                buff.as_mut_ptr().cast::<c_char>(),
                buff.len().try_into().expect("Buffer size exceeds i32::MAX"),
            )
        };

        if res > buff.len().try_into().expect("Buffer size exceeds i32::MAX") {
            // Output did not fit: grow to exactly `res` bytes and retry.
            buff.resize(res.try_into().expect("res is negative"), 0);

            // Note: this `res` shadows the outer one only inside this block;
            // the outer `res` (the required length, now equal to buff.len())
            // is what the truncate below uses, so the logic stays correct.
            let res = unsafe {
                llama_cpp_bindings_sys::llama_chat_apply_template(
                    tmpl_ptr,
                    chat.as_ptr(),
                    chat.len(),
                    add_ass,
                    buff.as_mut_ptr().cast::<c_char>(),
                    buff.len().try_into().expect("Buffer size exceeds i32::MAX"),
                )
            };
            assert_eq!(Ok(res), buff.len().try_into());
        }
        buff.truncate(res.try_into().expect("res is negative"));

        Ok(String::from_utf8(buff)?)
    }
682
    /// Apply the models chat template to some messages and return an optional tool grammar.
    /// `tools_json` should be an OpenAI-compatible tool definition JSON array string.
    /// `json_schema` should be a JSON schema string.
    ///
    /// # Errors
    /// Returns an error if the FFI call fails or the result contains invalid data.
    #[tracing::instrument(skip_all)]
    pub fn apply_chat_template_with_tools_oaicompat(
        &self,
        tmpl: &LlamaChatTemplate,
        messages: &[LlamaChatMessage],
        tools_json: Option<&str>,
        json_schema: Option<&str>,
        add_generation_prompt: bool,
    ) -> Result<ChatTemplateResult, ApplyChatTemplateError> {
        // Borrow role/content pointers from `messages`; `messages` outlives the
        // FFI call, so the pointers stay valid.
        let chat: Vec<llama_cpp_bindings_sys::llama_chat_message> = messages
            .iter()
            .map(|c| llama_cpp_bindings_sys::llama_chat_message {
                role: c.role.as_ptr(),
                content: c.content.as_ptr(),
            })
            .collect();

        // Kept in bindings (not inlined) so the CStrings live across the call.
        let tools_cstr = tools_json.map(CString::new).transpose()?;
        let json_schema_cstr = json_schema.map(CString::new).transpose()?;

        let mut raw_result = new_empty_chat_template_raw_result();

        let rc = unsafe {
            llama_cpp_bindings_sys::llama_rs_apply_chat_template_with_tools_oaicompat(
                self.model.as_ptr(),
                tmpl.0.as_ptr(),
                chat.as_ptr(),
                chat.len(),
                tools_cstr
                    .as_ref()
                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
                json_schema_cstr
                    .as_ref()
                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
                add_generation_prompt,
                &raw mut raw_result,
            )
        };

        // Only attempt tool-call parsing when a non-empty tool list was supplied.
        let parse_tool_calls = tools_json.is_some_and(|tools| !tools.is_empty());

        // `parse_chat_template_raw_result` takes over `raw_result` regardless of `rc`;
        // see `chat_template_result` for the ownership details.
        unsafe { parse_chat_template_raw_result(rc, &raw mut raw_result, parse_tool_calls) }
    }
732
    /// Apply the model chat template using OpenAI-compatible JSON messages.
    ///
    /// # Errors
    /// Returns an error if the FFI call fails or the result contains invalid data.
    #[tracing::instrument(skip_all)]
    pub fn apply_chat_template_oaicompat(
        &self,
        tmpl: &LlamaChatTemplate,
        params: &OpenAIChatTemplateParams<'_>,
    ) -> Result<ChatTemplateResult, ApplyChatTemplateError> {
        let parse_tool_calls = params.parse_tool_calls;
        // Each optional string is converted to a CString held in its own
        // binding so all pointers remain valid for the duration of the call;
        // `?` propagates interior-NUL errors.
        let messages_cstr = CString::new(params.messages_json)?;
        let tools_cstr = params.tools_json.map(CString::new).transpose()?;
        let tool_choice_cstr = params.tool_choice.map(CString::new).transpose()?;
        let json_schema_cstr = params.json_schema.map(CString::new).transpose()?;
        let grammar_cstr = params.grammar.map(CString::new).transpose()?;
        let reasoning_cstr = params.reasoning_format.map(CString::new).transpose()?;
        let kwargs_cstr = params.chat_template_kwargs.map(CString::new).transpose()?;

        let mut raw_result = new_empty_chat_template_raw_result();

        // Absent optional fields are passed as null pointers.
        let ffi_params = llama_cpp_bindings_sys::llama_rs_chat_template_oaicompat_params {
            messages: messages_cstr.as_ptr(),
            tools: tools_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            tool_choice: tool_choice_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            json_schema: json_schema_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            grammar: grammar_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            reasoning_format: reasoning_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            chat_template_kwargs: kwargs_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            add_generation_prompt: params.add_generation_prompt,
            use_jinja: params.use_jinja,
            parallel_tool_calls: params.parallel_tool_calls,
            enable_thinking: params.enable_thinking,
            add_bos: params.add_bos,
            add_eos: params.add_eos,
        };

        let rc = unsafe {
            llama_cpp_bindings_sys::llama_rs_apply_chat_template_oaicompat(
                self.model.as_ptr(),
                tmpl.0.as_ptr(),
                &raw const ffi_params,
                &raw mut raw_result,
            )
        };

        // `parse_chat_template_raw_result` takes over `raw_result` regardless of `rc`;
        // see `chat_template_result` for the ownership details.
        unsafe { parse_chat_template_raw_result(rc, &raw mut raw_result, parse_tool_calls) }
    }
793}
794
795/// Generic helper function for extracting string values from the C API
796/// These are specifically useful for the metadata functions, where we pass in a buffer
797/// to be populated by a string, not yet knowing if the buffer is large enough.
798/// If the buffer was not large enough, we get the correct length back, which can be used to
799/// construct a buffer of appropriate size.
800fn extract_meta_string<TCFunction>(
801    c_function: TCFunction,
802    capacity: usize,
803) -> Result<String, MetaValError>
804where
805    TCFunction: Fn(*mut c_char, usize) -> i32,
806{
807    let mut buffer = vec![0u8; capacity];
808
809    // call the foreign function
810    let result = c_function(buffer.as_mut_ptr().cast::<c_char>(), buffer.len());
811    if result < 0 {
812        return Err(MetaValError::NegativeReturn(result));
813    }
814
815    // check if the response fit in our buffer
816    let returned_len = result as usize;
817    if returned_len >= capacity {
818        // buffer wasn't large enough, try again with the correct capacity.
819        return extract_meta_string(c_function, returned_len + 1);
820    }
821
822    // verify null termination
823    debug_assert_eq!(
824        buffer.get(returned_len),
825        Some(&0),
826        "should end with null byte"
827    );
828
829    // resize, convert, and return
830    buffer.truncate(returned_len);
831
832    Ok(String::from_utf8(buffer)?)
833}
834
impl Drop for LlamaModel {
    fn drop(&mut self) {
        // SAFETY: `self.model` is a valid owned handle (non-null by
        // construction) and is freed exactly once here.
        unsafe { llama_cpp_bindings_sys::llama_free_model(self.model.as_ptr()) }
    }
}