Skip to main content

llama_cpp_bindings/
model.rs

//! A safe wrapper around `llama_model`.
use std::ffi::{CStr, CString, c_char};
use std::num::NonZeroU16;
use std::os::raw::c_int;
use std::path::Path;
use std::ptr::{self, NonNull};

use crate::context::LlamaContext;
use crate::context::params::LlamaContextParams;
use crate::llama_backend::LlamaBackend;
use crate::openai::OpenAIChatTemplateParams;
use crate::token::LlamaToken;
use crate::token_type::LlamaTokenAttrs;
use crate::{
    ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
    LlamaModelLoadError, MetaValError, StringToTokenError, TokenToStringError,
};

pub mod add_bos;
pub mod chat_template_result;
pub mod grammar_trigger;
pub mod llama_chat_message;
pub mod llama_chat_template;
pub mod llama_lora_adapter;
pub mod params;
pub mod rope_type;
pub mod split_mode;
pub mod vocab_type;

pub use add_bos::AddBos;
pub use chat_template_result::ChatTemplateResult;
pub use grammar_trigger::{GrammarTrigger, GrammarTriggerType};
pub use llama_chat_message::LlamaChatMessage;
pub use llama_chat_template::LlamaChatTemplate;
pub use llama_lora_adapter::LlamaLoraAdapter;
pub use rope_type::RopeType;
pub use vocab_type::{LlamaTokenTypeFromIntError, VocabType};

use chat_template_result::{new_empty_chat_template_raw_result, parse_chat_template_raw_result};
use params::LlamaModelParams;

42/// A safe wrapper around `llama_model`.
43#[derive(Debug)]
44#[repr(transparent)]
45pub struct LlamaModel {
46    /// Raw pointer to the underlying `llama_model`.
47    pub model: NonNull<llama_cpp_bindings_sys::llama_model>,
48}
49
50unsafe impl Send for LlamaModel {}
51
52unsafe impl Sync for LlamaModel {}
53
54impl LlamaModel {
55    /// Returns a raw pointer to the model's vocabulary.
56    #[must_use]
57    pub fn vocab_ptr(&self) -> *const llama_cpp_bindings_sys::llama_vocab {
58        unsafe { llama_cpp_bindings_sys::llama_model_get_vocab(self.model.as_ptr()) }
59    }
60
61    /// get the number of tokens the model was trained on
62    ///
63    /// # Panics
64    ///
65    /// If the number of tokens the model was trained on does not fit into an `u32`. This should be impossible on most
66    /// platforms due to llama.cpp returning a `c_int` (i32 on most platforms) which is almost certainly positive.
67    #[must_use]
68    pub fn n_ctx_train(&self) -> u32 {
69        let n_ctx_train = unsafe { llama_cpp_bindings_sys::llama_n_ctx_train(self.model.as_ptr()) };
70        u32::try_from(n_ctx_train).expect("n_ctx_train fits into an u32")
71    }
72
73    /// Get all tokens in the model.
74    pub fn tokens(
75        &self,
76        decode_special: bool,
77    ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
78        (0..self.n_vocab())
79            .map(LlamaToken::new)
80            .map(move |llama_token| {
81                let mut decoder = encoding_rs::UTF_8.new_decoder();
82                (
83                    llama_token,
84                    self.token_to_piece(llama_token, &mut decoder, decode_special, None),
85                )
86            })
87    }
88
89    /// Get the beginning of stream token.
90    #[must_use]
91    pub fn token_bos(&self) -> LlamaToken {
92        let token = unsafe { llama_cpp_bindings_sys::llama_token_bos(self.vocab_ptr()) };
93        LlamaToken(token)
94    }
95
96    /// Get the end of stream token.
97    #[must_use]
98    pub fn token_eos(&self) -> LlamaToken {
99        let token = unsafe { llama_cpp_bindings_sys::llama_token_eos(self.vocab_ptr()) };
100        LlamaToken(token)
101    }
102
103    /// Get the newline token.
104    #[must_use]
105    pub fn token_nl(&self) -> LlamaToken {
106        let token = unsafe { llama_cpp_bindings_sys::llama_token_nl(self.vocab_ptr()) };
107        LlamaToken(token)
108    }
109
110    /// Check if a token represents the end of generation (end of turn, end of sequence, etc.)
111    #[must_use]
112    pub fn is_eog_token(&self, token: LlamaToken) -> bool {
113        unsafe { llama_cpp_bindings_sys::llama_token_is_eog(self.vocab_ptr(), token.0) }
114    }
115
116    /// Get the decoder start token.
117    #[must_use]
118    pub fn decode_start_token(&self) -> LlamaToken {
119        let token =
120            unsafe { llama_cpp_bindings_sys::llama_model_decoder_start_token(self.model.as_ptr()) };
121        LlamaToken(token)
122    }
123
124    /// Get the separator token (SEP).
125    #[must_use]
126    pub fn token_sep(&self) -> LlamaToken {
127        let token = unsafe { llama_cpp_bindings_sys::llama_vocab_sep(self.vocab_ptr()) };
128        LlamaToken(token)
129    }
130
131    /// Convert a string to a Vector of tokens.
132    ///
133    /// # Errors
134    ///
135    /// - if [`str`] contains a null byte.
136    ///
137    /// # Panics
138    ///
139    /// - if there is more than [`usize::MAX`] [`LlamaToken`]s in [`str`].
140    ///
141    ///
142    /// ```no_run
143    /// use llama_cpp_bindings::model::LlamaModel;
144    ///
145    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
146    /// use std::path::Path;
147    /// use llama_cpp_bindings::model::AddBos;
148    /// let backend = llama_cpp_bindings::llama_backend::LlamaBackend::init()?;
149    /// let model = LlamaModel::load_from_file(&backend, Path::new("path/to/model"), &Default::default())?;
150    /// let tokens = model.str_to_token("Hello, World!", AddBos::Always)?;
151    /// # Ok(())
152    /// # }
153    pub fn str_to_token(
154        &self,
155        str: &str,
156        add_bos: AddBos,
157    ) -> Result<Vec<LlamaToken>, StringToTokenError> {
158        let add_bos = match add_bos {
159            AddBos::Always => true,
160            AddBos::Never => false,
161        };
162
163        let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
164        let mut buffer: Vec<LlamaToken> = Vec::with_capacity(tokens_estimation);
165
166        let c_string = CString::new(str)?;
167        let buffer_capacity =
168            c_int::try_from(buffer.capacity()).expect("buffer capacity should fit into a c_int");
169
170        let size = unsafe {
171            llama_cpp_bindings_sys::llama_tokenize(
172                self.vocab_ptr(),
173                c_string.as_ptr(),
174                c_int::try_from(c_string.as_bytes().len())?,
175                buffer
176                    .as_mut_ptr()
177                    .cast::<llama_cpp_bindings_sys::llama_token>(),
178                buffer_capacity,
179                add_bos,
180                true,
181            )
182        };
183
184        let size = if size.is_negative() {
185            buffer.reserve_exact(usize::try_from(-size).expect("negated size fits into usize"));
186            unsafe {
187                llama_cpp_bindings_sys::llama_tokenize(
188                    self.vocab_ptr(),
189                    c_string.as_ptr(),
190                    c_int::try_from(c_string.as_bytes().len())?,
191                    buffer
192                        .as_mut_ptr()
193                        .cast::<llama_cpp_bindings_sys::llama_token>(),
194                    -size,
195                    add_bos,
196                    true,
197                )
198            }
199        } else {
200            size
201        };
202
203        let size = usize::try_from(size).expect("size is positive and fits into usize");
204
205        // SAFETY: `size` < `capacity` and llama-cpp has initialized elements up to `size`
206        unsafe { buffer.set_len(size) }
207
208        Ok(buffer)
209    }
210
211    /// Get the type of a token.
212    ///
213    /// # Panics
214    ///
215    /// If the token type is not known to this library.
216    #[must_use]
217    pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs {
218        let token_type =
219            unsafe { llama_cpp_bindings_sys::llama_token_get_attr(self.vocab_ptr(), id) };
220        LlamaTokenAttrs::try_from(token_type).expect("token type is valid")
221    }
222
223    /// Convert a token to a string using the underlying llama.cpp `llama_token_to_piece` function.
224    ///
225    /// This is the new default function for token decoding and provides direct access to
226    /// the llama.cpp token decoding functionality without any special logic or filtering.
227    ///
228    /// Decoding raw string requires using an decoder, tokens from language models may not always map
229    /// to full characters depending on the encoding so stateful decoding is required, otherwise partial strings may be lost!
230    /// Invalid characters are mapped to REPLACEMENT CHARACTER making the method safe to use even if the model inherently produces
231    /// garbage.
232    ///
233    /// # Errors
234    ///
235    /// - if the token type is unknown
236    ///
237    /// # Panics
238    ///
239    /// - if the returned size from llama-cpp does not fit into a [`usize`]. (this should never happen)
240    pub fn token_to_piece(
241        &self,
242        token: LlamaToken,
243        decoder: &mut encoding_rs::Decoder,
244        special: bool,
245        lstrip: Option<NonZeroU16>,
246    ) -> Result<String, TokenToStringError> {
247        let bytes = match self.token_to_piece_bytes(token, 8, special, lstrip) {
248            Err(TokenToStringError::InsufficientBufferSpace(required_size)) => self
249                .token_to_piece_bytes(
250                    token,
251                    (-required_size)
252                        .try_into()
253                        .expect("Error buffer size is positive"),
254                    special,
255                    lstrip,
256                ),
257            other => other,
258        }?;
259
260        let mut output_piece = String::with_capacity(bytes.len());
261        let (_result, _decoded_size, _had_replacements) =
262            decoder.decode_to_string(&bytes, &mut output_piece, false);
263
264        Ok(output_piece)
265    }
266
267    /// Raw token decoding to bytes, use if you want to handle the decoding model output yourself
268    ///
269    /// Convert a token to bytes using the underlying llama.cpp `llama_token_to_piece` function. This is mostly
270    /// a thin wrapper around `llama_token_to_piece` function, that handles rust <-> c type conversions while
271    /// letting the caller handle errors. For a safer interface returning rust strings directly use `token_to_piece` instead!
272    ///
273    /// # Errors
274    ///
275    /// - if the token type is unknown
276    /// - the resultant token is larger than `buffer_size`.
277    ///
278    /// # Panics
279    ///
280    /// - if `buffer_size` does not fit into a [`c_int`].
281    /// - if the returned size from llama-cpp does not fit into a [`usize`]. (this should never happen)
282    pub fn token_to_piece_bytes(
283        &self,
284        token: LlamaToken,
285        buffer_size: usize,
286        special: bool,
287        lstrip: Option<NonZeroU16>,
288    ) -> Result<Vec<u8>, TokenToStringError> {
289        // SAFETY: `*` (0x2A) is never `\0`, so CString::new cannot fail here
290        let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
291        let len = string.as_bytes().len();
292        let len = c_int::try_from(len).expect("length fits into c_int");
293        let buf = string.into_raw();
294        let lstrip = lstrip.map_or(0, |strip_count| i32::from(strip_count.get()));
295        let size = unsafe {
296            llama_cpp_bindings_sys::llama_token_to_piece(
297                self.vocab_ptr(),
298                token.0,
299                buf,
300                len,
301                lstrip,
302                special,
303            )
304        };
305
306        match size {
307            0 => Err(TokenToStringError::UnknownTokenType),
308            error_code if error_code.is_negative() => {
309                Err(TokenToStringError::InsufficientBufferSpace(error_code))
310            }
311            size => {
312                let string = unsafe { CString::from_raw(buf) };
313                let mut bytes = string.into_bytes();
314                let len = usize::try_from(size).expect("size is positive and fits into usize");
315                bytes.truncate(len);
316
317                Ok(bytes)
318            }
319        }
320    }
321
322    /// The number of tokens the model was trained on.
323    ///
324    /// This returns a `c_int` for maximum compatibility. Most of the time it can be cast to an i32
325    /// without issue.
326    #[must_use]
327    pub fn n_vocab(&self) -> i32 {
328        unsafe { llama_cpp_bindings_sys::llama_n_vocab(self.vocab_ptr()) }
329    }
330
331    /// The type of vocab the model was trained on.
332    ///
333    /// # Panics
334    ///
335    /// If llama-cpp emits a vocab type that is not known to this library.
336    #[must_use]
337    pub fn vocab_type(&self) -> VocabType {
338        let vocab_type = unsafe { llama_cpp_bindings_sys::llama_vocab_type(self.vocab_ptr()) };
339        VocabType::try_from(vocab_type).expect("invalid vocab type")
340    }
341
342    /// This returns a `c_int` for maximum compatibility. Most of the time it can be cast to an i32
343    /// without issue.
344    #[must_use]
345    pub fn n_embd(&self) -> c_int {
346        unsafe { llama_cpp_bindings_sys::llama_n_embd(self.model.as_ptr()) }
347    }
348
349    /// Returns the total size of all the tensors in the model in bytes.
350    #[must_use]
351    pub fn size(&self) -> u64 {
352        unsafe { llama_cpp_bindings_sys::llama_model_size(self.model.as_ptr()) }
353    }
354
355    /// Returns the number of parameters in the model.
356    #[must_use]
357    pub fn n_params(&self) -> u64 {
358        unsafe { llama_cpp_bindings_sys::llama_model_n_params(self.model.as_ptr()) }
359    }
360
361    /// Returns whether the model is a recurrent network (Mamba, RWKV, etc)
362    #[must_use]
363    pub fn is_recurrent(&self) -> bool {
364        unsafe { llama_cpp_bindings_sys::llama_model_is_recurrent(self.model.as_ptr()) }
365    }
366
367    /// Returns the number of layers within the model.
368    ///
369    /// # Panics
370    /// Panics if the layer count returned by llama.cpp is negative.
371    #[must_use]
372    pub fn n_layer(&self) -> u32 {
373        // llama.cpp API returns int32_t but the underlying field is uint32_t, so this is safe
374        u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_layer(self.model.as_ptr()) })
375            .expect("llama.cpp returns a positive value for n_layer")
376    }
377
378    /// Returns the number of attention heads within the model.
379    ///
380    /// # Panics
381    /// Panics if the head count returned by llama.cpp is negative.
382    #[must_use]
383    pub fn n_head(&self) -> u32 {
384        // llama.cpp API returns int32_t but the underlying field is uint32_t, so this is safe
385        u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_head(self.model.as_ptr()) })
386            .expect("llama.cpp returns a positive value for n_head")
387    }
388
389    /// Returns the number of KV attention heads.
390    ///
391    /// # Panics
392    /// Panics if the KV head count returned by llama.cpp is negative.
393    #[must_use]
394    pub fn n_head_kv(&self) -> u32 {
395        // llama.cpp API returns int32_t but the underlying field is uint32_t, so this is safe
396        u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_head_kv(self.model.as_ptr()) })
397            .expect("llama.cpp returns a positive value for n_head_kv")
398    }
399
400    /// Get metadata value as a string by key name
401    ///
402    /// # Errors
403    /// Returns an error if the key is not found or the value is not valid UTF-8.
404    pub fn meta_val_str(&self, key: &str) -> Result<String, MetaValError> {
405        let key_cstring = CString::new(key)?;
406        let key_ptr = key_cstring.as_ptr();
407
408        extract_meta_string(
409            |buf_ptr, buf_len| unsafe {
410                llama_cpp_bindings_sys::llama_model_meta_val_str(
411                    self.model.as_ptr(),
412                    key_ptr,
413                    buf_ptr,
414                    buf_len,
415                )
416            },
417            256,
418        )
419    }
420
421    /// Get the number of metadata key/value pairs
422    #[must_use]
423    pub fn meta_count(&self) -> i32 {
424        unsafe { llama_cpp_bindings_sys::llama_model_meta_count(self.model.as_ptr()) }
425    }
426
427    /// Get metadata key name by index
428    ///
429    /// # Errors
430    /// Returns an error if the index is out of range or the key is not valid UTF-8.
431    pub fn meta_key_by_index(&self, index: i32) -> Result<String, MetaValError> {
432        extract_meta_string(
433            |buf_ptr, buf_len| unsafe {
434                llama_cpp_bindings_sys::llama_model_meta_key_by_index(
435                    self.model.as_ptr(),
436                    index,
437                    buf_ptr,
438                    buf_len,
439                )
440            },
441            256,
442        )
443    }
444
445    /// Get metadata value as a string by index
446    ///
447    /// # Errors
448    /// Returns an error if the index is out of range or the value is not valid UTF-8.
449    pub fn meta_val_str_by_index(&self, index: i32) -> Result<String, MetaValError> {
450        extract_meta_string(
451            |buf_ptr, buf_len| unsafe {
452                llama_cpp_bindings_sys::llama_model_meta_val_str_by_index(
453                    self.model.as_ptr(),
454                    index,
455                    buf_ptr,
456                    buf_len,
457                )
458            },
459            256,
460        )
461    }
462
463    /// Returns the rope type of the model.
464    pub fn rope_type(&self) -> Option<RopeType> {
465        match unsafe { llama_cpp_bindings_sys::llama_model_rope_type(self.model.as_ptr()) } {
466            llama_cpp_bindings_sys::LLAMA_ROPE_TYPE_NONE => None,
467            llama_cpp_bindings_sys::LLAMA_ROPE_TYPE_NORM => Some(RopeType::Norm),
468            llama_cpp_bindings_sys::LLAMA_ROPE_TYPE_NEOX => Some(RopeType::NeoX),
469            llama_cpp_bindings_sys::LLAMA_ROPE_TYPE_MROPE => Some(RopeType::MRope),
470            llama_cpp_bindings_sys::LLAMA_ROPE_TYPE_VISION => Some(RopeType::Vision),
471            rope_type => {
472                tracing::error!(rope_type = rope_type, "Unexpected rope type from llama.cpp");
473                None
474            }
475        }
476    }
477
478    /// Get chat template from model by name. If the name parameter is None, the default chat template will be returned.
479    ///
480    /// You supply this into [`Self::apply_chat_template`] to get back a string with the appropriate template
481    /// substitution applied to convert a list of messages into a prompt the LLM can use to complete
482    /// the chat.
483    ///
484    /// You could also use an external jinja parser, like [minijinja](https://github.com/mitsuhiko/minijinja),
485    /// to parse jinja templates not supported by the llama.cpp template engine.
486    ///
487    /// # Errors
488    ///
489    /// * If the model has no chat template by that name
490    /// * If the chat template is not a valid [`CString`].
491    pub fn chat_template(
492        &self,
493        name: Option<&str>,
494    ) -> Result<LlamaChatTemplate, ChatTemplateError> {
495        let name_cstr = name.map(CString::new);
496        let name_ptr = match name_cstr {
497            Some(Ok(name)) => name.as_ptr(),
498            _ => std::ptr::null(),
499        };
500        let result = unsafe {
501            llama_cpp_bindings_sys::llama_model_chat_template(self.model.as_ptr(), name_ptr)
502        };
503
504        if result.is_null() {
505            Err(ChatTemplateError::MissingTemplate)
506        } else {
507            let chat_template_cstr = unsafe { CStr::from_ptr(result) };
508            let chat_template = CString::new(chat_template_cstr.to_bytes())?;
509
510            Ok(LlamaChatTemplate(chat_template))
511        }
512    }
513
514    /// Loads a model from a file.
515    ///
516    /// # Errors
517    ///
518    /// See [`LlamaModelLoadError`] for more information.
519    #[tracing::instrument(skip_all, fields(params))]
520    pub fn load_from_file(
521        _: &LlamaBackend,
522        path: impl AsRef<Path>,
523        params: &LlamaModelParams,
524    ) -> Result<Self, LlamaModelLoadError> {
525        let path = path.as_ref();
526        debug_assert!(
527            Path::new(path).exists(),
528            "{} does not exist",
529            path.display()
530        );
531        let path = path
532            .to_str()
533            .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
534
535        let cstr = CString::new(path)?;
536        let llama_model = unsafe {
537            llama_cpp_bindings_sys::llama_load_model_from_file(cstr.as_ptr(), params.params)
538        };
539
540        let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
541
542        tracing::debug!(?path, "Loaded model");
543
544        Ok(LlamaModel { model })
545    }
546
547    /// Initializes a lora adapter from a file.
548    ///
549    /// # Errors
550    ///
551    /// See [`LlamaLoraAdapterInitError`] for more information.
552    pub fn lora_adapter_init(
553        &self,
554        path: impl AsRef<Path>,
555    ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
556        let path = path.as_ref();
557        debug_assert!(
558            Path::new(path).exists(),
559            "{} does not exist",
560            path.display()
561        );
562
563        let path = path
564            .to_str()
565            .ok_or(LlamaLoraAdapterInitError::PathToStrError(
566                path.to_path_buf(),
567            ))?;
568
569        let cstr = CString::new(path)?;
570        let adapter = unsafe {
571            llama_cpp_bindings_sys::llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr())
572        };
573
574        let adapter = NonNull::new(adapter).ok_or(LlamaLoraAdapterInitError::NullResult)?;
575
576        tracing::debug!(?path, "Initialized lora adapter");
577
578        Ok(LlamaLoraAdapter {
579            lora_adapter: adapter,
580        })
581    }
582
583    /// Create a new context from this model.
584    ///
585    /// # Errors
586    ///
587    /// There is many ways this can fail. See [`LlamaContextLoadError`] for more information.
588    // we intentionally do not derive Copy on `LlamaContextParams` to allow llama.cpp to change the type to be non-trivially copyable.
589    pub fn new_context<'model>(
590        &'model self,
591        _: &LlamaBackend,
592        params: LlamaContextParams,
593    ) -> Result<LlamaContext<'model>, LlamaContextLoadError> {
594        let context_params = params.context_params;
595        let context = unsafe {
596            llama_cpp_bindings_sys::llama_new_context_with_model(
597                self.model.as_ptr(),
598                context_params,
599            )
600        };
601        let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;
602
603        Ok(LlamaContext::new(self, context, params.embeddings()))
604    }
605
606    /// Apply the models chat template to some messages.
607    /// See <https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template>
608    ///
609    /// Unlike the llama.cpp `apply_chat_template` which just randomly uses the `ChatML` template when given
610    /// a null pointer for the template, this requires an explicit template to be specified. If you want to
611    /// use "chatml", then just do `LlamaChatTemplate::new("chatml")` or any other model name or template
612    /// string.
613    ///
614    /// Use [`Self::chat_template`] to retrieve the template baked into the model (this is the preferred
615    /// mechanism as using the wrong chat template can result in really unexpected responses from the LLM).
616    ///
617    /// You probably want to set `add_ass` to true so that the generated template string ends with a the
618    /// opening tag of the assistant. If you fail to leave a hanging chat tag, the model will likely generate
619    /// one into the output and the output may also have unexpected output aside from that.
620    ///
621    /// # Errors
622    /// There are many ways this can fail. See [`ApplyChatTemplateError`] for more information.
623    ///
624    /// # Panics
625    /// Panics if the buffer size exceeds `i32::MAX`.
626    #[tracing::instrument(skip_all)]
627    pub fn apply_chat_template(
628        &self,
629        tmpl: &LlamaChatTemplate,
630        chat: &[LlamaChatMessage],
631        add_ass: bool,
632    ) -> Result<String, ApplyChatTemplateError> {
633        let message_length = chat.iter().fold(0, |acc, chat_message| {
634            acc + chat_message.role.to_bytes().len() + chat_message.content.to_bytes().len()
635        });
636        let mut buff: Vec<u8> = vec![0; message_length * 2];
637
638        let chat: Vec<llama_cpp_bindings_sys::llama_chat_message> = chat
639            .iter()
640            .map(|chat_message| llama_cpp_bindings_sys::llama_chat_message {
641                role: chat_message.role.as_ptr(),
642                content: chat_message.content.as_ptr(),
643            })
644            .collect();
645
646        let tmpl_ptr = tmpl.0.as_ptr();
647
648        let res = unsafe {
649            llama_cpp_bindings_sys::llama_chat_apply_template(
650                tmpl_ptr,
651                chat.as_ptr(),
652                chat.len(),
653                add_ass,
654                buff.as_mut_ptr().cast::<c_char>(),
655                buff.len().try_into().expect("Buffer size exceeds i32::MAX"),
656            )
657        };
658
659        if res > buff.len().try_into().expect("Buffer size exceeds i32::MAX") {
660            buff.resize(res.try_into().expect("res is negative"), 0);
661
662            let res = unsafe {
663                llama_cpp_bindings_sys::llama_chat_apply_template(
664                    tmpl_ptr,
665                    chat.as_ptr(),
666                    chat.len(),
667                    add_ass,
668                    buff.as_mut_ptr().cast::<c_char>(),
669                    buff.len().try_into().expect("Buffer size exceeds i32::MAX"),
670                )
671            };
672            assert_eq!(Ok(res), buff.len().try_into());
673        }
674        buff.truncate(res.try_into().expect("res is negative"));
675
676        Ok(String::from_utf8(buff)?)
677    }
678
679    /// Apply the models chat template to some messages and return an optional tool grammar.
680    /// `tools_json` should be an OpenAI-compatible tool definition JSON array string.
681    /// `json_schema` should be a JSON schema string.
682    ///
683    /// # Errors
684    /// Returns an error if the FFI call fails or the result contains invalid data.
685    #[tracing::instrument(skip_all)]
686    pub fn apply_chat_template_with_tools_oaicompat(
687        &self,
688        tmpl: &LlamaChatTemplate,
689        messages: &[LlamaChatMessage],
690        tools_json: Option<&str>,
691        json_schema: Option<&str>,
692        add_generation_prompt: bool,
693    ) -> Result<ChatTemplateResult, ApplyChatTemplateError> {
694        let chat: Vec<llama_cpp_bindings_sys::llama_chat_message> = messages
695            .iter()
696            .map(|chat_message| llama_cpp_bindings_sys::llama_chat_message {
697                role: chat_message.role.as_ptr(),
698                content: chat_message.content.as_ptr(),
699            })
700            .collect();
701
702        let tools_cstr = tools_json.map(CString::new).transpose()?;
703        let json_schema_cstr = json_schema.map(CString::new).transpose()?;
704
705        let mut raw_result = new_empty_chat_template_raw_result();
706
707        let rc = unsafe {
708            llama_cpp_bindings_sys::llama_rs_apply_chat_template_with_tools_oaicompat(
709                self.model.as_ptr(),
710                tmpl.0.as_ptr(),
711                chat.as_ptr(),
712                chat.len(),
713                tools_cstr
714                    .as_ref()
715                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
716                json_schema_cstr
717                    .as_ref()
718                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
719                add_generation_prompt,
720                &raw mut raw_result,
721            )
722        };
723
724        let parse_tool_calls = tools_json.is_some_and(|tools| !tools.is_empty());
725
726        unsafe { parse_chat_template_raw_result(rc, &raw mut raw_result, parse_tool_calls) }
727    }
728
729    /// Apply the model chat template using OpenAI-compatible JSON messages.
730    ///
731    /// # Errors
732    /// Returns an error if the FFI call fails or the result contains invalid data.
733    #[tracing::instrument(skip_all)]
734    pub fn apply_chat_template_oaicompat(
735        &self,
736        tmpl: &LlamaChatTemplate,
737        params: &OpenAIChatTemplateParams<'_>,
738    ) -> Result<ChatTemplateResult, ApplyChatTemplateError> {
739        let parse_tool_calls = params.parse_tool_calls;
740        let messages_cstr = CString::new(params.messages_json)?;
741        let tools_cstr = params.tools_json.map(CString::new).transpose()?;
742        let tool_choice_cstr = params.tool_choice.map(CString::new).transpose()?;
743        let json_schema_cstr = params.json_schema.map(CString::new).transpose()?;
744        let grammar_cstr = params.grammar.map(CString::new).transpose()?;
745        let reasoning_cstr = params.reasoning_format.map(CString::new).transpose()?;
746        let kwargs_cstr = params.chat_template_kwargs.map(CString::new).transpose()?;
747
748        let mut raw_result = new_empty_chat_template_raw_result();
749
750        let ffi_params = llama_cpp_bindings_sys::llama_rs_chat_template_oaicompat_params {
751            messages: messages_cstr.as_ptr(),
752            tools: tools_cstr
753                .as_ref()
754                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
755            tool_choice: tool_choice_cstr
756                .as_ref()
757                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
758            json_schema: json_schema_cstr
759                .as_ref()
760                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
761            grammar: grammar_cstr
762                .as_ref()
763                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
764            reasoning_format: reasoning_cstr
765                .as_ref()
766                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
767            chat_template_kwargs: kwargs_cstr
768                .as_ref()
769                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
770            add_generation_prompt: params.add_generation_prompt,
771            use_jinja: params.use_jinja,
772            parallel_tool_calls: params.parallel_tool_calls,
773            enable_thinking: params.enable_thinking,
774            add_bos: params.add_bos,
775            add_eos: params.add_eos,
776        };
777
778        let rc = unsafe {
779            llama_cpp_bindings_sys::llama_rs_apply_chat_template_oaicompat(
780                self.model.as_ptr(),
781                tmpl.0.as_ptr(),
782                &raw const ffi_params,
783                &raw mut raw_result,
784            )
785        };
786
787        unsafe { parse_chat_template_raw_result(rc, &raw mut raw_result, parse_tool_calls) }
788    }
789}
790
791fn extract_meta_string<TCFunction>(
792    c_function: TCFunction,
793    capacity: usize,
794) -> Result<String, MetaValError>
795where
796    TCFunction: Fn(*mut c_char, usize) -> i32,
797{
798    let mut buffer = vec![0u8; capacity];
799    let result = c_function(buffer.as_mut_ptr().cast::<c_char>(), buffer.len());
800
801    if result < 0 {
802        return Err(MetaValError::NegativeReturn(result));
803    }
804
805    let returned_len = result.cast_unsigned() as usize;
806
807    if returned_len >= capacity {
808        return extract_meta_string(c_function, returned_len + 1);
809    }
810
811    debug_assert_eq!(
812        buffer.get(returned_len),
813        Some(&0),
814        "should end with null byte"
815    );
816
817    buffer.truncate(returned_len);
818
819    Ok(String::from_utf8(buffer)?)
820}
821
822impl Drop for LlamaModel {
823    fn drop(&mut self) {
824        unsafe { llama_cpp_bindings_sys::llama_free_model(self.model.as_ptr()) }
825    }
826}