// llama_cpp_bindings/model.rs

1//! A safe wrapper around `llama_model`.
2use std::ffi::{CStr, CString, c_char};
3use std::num::NonZeroU16;
4use std::os::raw::c_int;
5use std::path::Path;
6
7fn truncated_buffer_to_string(
8    mut buffer: Vec<u8>,
9    length: usize,
10) -> Result<String, ApplyChatTemplateError> {
11    buffer.truncate(length);
12
13    Ok(String::from_utf8(buffer)?)
14}
15
/// Ensure a byte length fits into the `c_int` expected by the llama.cpp tokenizer API.
fn validate_string_length_for_tokenizer(length: usize) -> Result<c_int, StringToTokenError> {
    Ok(c_int::try_from(length)?)
}
19
/// Build a `CString` from `str` and validate that its byte length (excluding
/// the trailing NUL) fits into a `c_int`, as required by `llama_tokenize`.
fn cstring_with_validated_len(str: &str) -> Result<(CString, c_int), StringToTokenError> {
    let c_string = CString::new(str)?;
    let len = validate_string_length_for_tokenizer(c_string.as_bytes().len())?;
    Ok((c_string, len))
}
25use std::ptr::{self, NonNull};
26
27use crate::context::LlamaContext;
28use crate::context::params::LlamaContextParams;
29use crate::llama_backend::LlamaBackend;
30use crate::openai::OpenAIChatTemplateParams;
31use crate::token::LlamaToken;
32use crate::token_type::LlamaTokenAttrs;
33use crate::{
34    ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
35    LlamaModelLoadError, MetaValError, StringToTokenError, TokenToStringError,
36};
37
38pub mod add_bos;
39pub mod chat_template_result;
40pub mod grammar_trigger;
41pub mod llama_chat_message;
42pub mod llama_chat_template;
43pub mod llama_lora_adapter;
44pub mod params;
45pub mod rope_type;
46pub mod split_mode;
47pub mod vocab_type;
48
49pub use add_bos::AddBos;
50pub use chat_template_result::ChatTemplateResult;
51pub use grammar_trigger::{GrammarTrigger, GrammarTriggerType};
52pub use llama_chat_message::LlamaChatMessage;
53pub use llama_chat_template::LlamaChatTemplate;
54pub use llama_lora_adapter::LlamaLoraAdapter;
55pub use rope_type::RopeType;
56pub use vocab_type::{LlamaTokenTypeFromIntError, VocabType};
57
58use chat_template_result::{new_empty_chat_template_raw_result, parse_chat_template_raw_result};
59use params::LlamaModelParams;
60
/// A safe wrapper around `llama_model`.
#[derive(Debug)]
#[repr(transparent)]
pub struct LlamaModel {
    /// Raw pointer to the underlying `llama_model`.
    ///
    /// Invariant: non-null, owned by this struct, and freed exactly once in `Drop`.
    pub model: NonNull<llama_cpp_bindings_sys::llama_model>,
}
68
// SAFETY: NOTE(review): assumes a loaded `llama_model` can be moved to and used
// from another thread — confirm against llama.cpp's threading guarantees.
unsafe impl Send for LlamaModel {}

// SAFETY: NOTE(review): same assumption as `Send`; shared `&LlamaModel` access in
// this module only invokes llama.cpp accessors on the model pointer — confirm
// those are safe to call concurrently.
unsafe impl Sync for LlamaModel {}
72
73impl LlamaModel {
    /// Returns a raw pointer to the model's vocabulary.
    ///
    /// The pointer is valid for as long as `self` is alive.
    #[must_use]
    pub fn vocab_ptr(&self) -> *const llama_cpp_bindings_sys::llama_vocab {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_cpp_bindings_sys::llama_model_get_vocab(self.model.as_ptr()) }
    }
79
    /// Get the number of tokens the model was trained on.
    ///
    /// # Panics
    ///
    /// If the number of tokens the model was trained on does not fit into an `u32`. This should be impossible on most
    /// platforms due to llama.cpp returning a `c_int` (i32 on most platforms) which is almost certainly positive.
    #[must_use]
    pub fn n_ctx_train(&self) -> u32 {
        // SAFETY: `self.model` is a valid model pointer for the lifetime of `self`.
        let n_ctx_train = unsafe { llama_cpp_bindings_sys::llama_n_ctx_train(self.model.as_ptr()) };
        u32::try_from(n_ctx_train).expect("n_ctx_train fits into an u32")
    }
91
92    /// Get all tokens in the model.
93    pub fn tokens(
94        &self,
95        decode_special: bool,
96    ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
97        (0..self.n_vocab())
98            .map(LlamaToken::new)
99            .map(move |llama_token| {
100                let mut decoder = encoding_rs::UTF_8.new_decoder();
101                (
102                    llama_token,
103                    self.token_to_piece(llama_token, &mut decoder, decode_special, None),
104                )
105            })
106    }
107
    /// Get the beginning of stream token.
    #[must_use]
    pub fn token_bos(&self) -> LlamaToken {
        // SAFETY: `vocab_ptr` returns a valid vocab pointer tied to `self`.
        let token = unsafe { llama_cpp_bindings_sys::llama_token_bos(self.vocab_ptr()) };
        LlamaToken(token)
    }

    /// Get the end of stream token.
    #[must_use]
    pub fn token_eos(&self) -> LlamaToken {
        // SAFETY: `vocab_ptr` returns a valid vocab pointer tied to `self`.
        let token = unsafe { llama_cpp_bindings_sys::llama_token_eos(self.vocab_ptr()) };
        LlamaToken(token)
    }

    /// Get the newline token.
    #[must_use]
    pub fn token_nl(&self) -> LlamaToken {
        // SAFETY: `vocab_ptr` returns a valid vocab pointer tied to `self`.
        let token = unsafe { llama_cpp_bindings_sys::llama_token_nl(self.vocab_ptr()) };
        LlamaToken(token)
    }

    /// Check if a token represents the end of generation (end of turn, end of sequence, etc.)
    #[must_use]
    pub fn is_eog_token(&self, token: LlamaToken) -> bool {
        // SAFETY: `vocab_ptr` returns a valid vocab pointer tied to `self`.
        unsafe { llama_cpp_bindings_sys::llama_token_is_eog(self.vocab_ptr(), token.0) }
    }

    /// Get the decoder start token.
    #[must_use]
    pub fn decode_start_token(&self) -> LlamaToken {
        // SAFETY: `self.model` is a valid model pointer for the lifetime of `self`.
        let token =
            unsafe { llama_cpp_bindings_sys::llama_model_decoder_start_token(self.model.as_ptr()) };
        LlamaToken(token)
    }

    /// Get the separator token (SEP).
    #[must_use]
    pub fn token_sep(&self) -> LlamaToken {
        // SAFETY: `vocab_ptr` returns a valid vocab pointer tied to `self`.
        let token = unsafe { llama_cpp_bindings_sys::llama_vocab_sep(self.vocab_ptr()) };
        LlamaToken(token)
    }
149
    /// Convert a string to a Vector of tokens.
    ///
    /// # Errors
    ///
    /// - if [`str`] contains a null byte.
    ///
    /// # Panics
    ///
    /// - if there is more than [`usize::MAX`] [`LlamaToken`]s in [`str`].
    ///
    ///
    /// ```no_run
    /// use llama_cpp_bindings::model::LlamaModel;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// use std::path::Path;
    /// use llama_cpp_bindings::model::AddBos;
    /// let backend = llama_cpp_bindings::llama_backend::LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, Path::new("path/to/model"), &Default::default())?;
    /// let tokens = model.str_to_token("Hello, World!", AddBos::Always)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn str_to_token(
        &self,
        str: &str,
        add_bos: AddBos,
    ) -> Result<Vec<LlamaToken>, StringToTokenError> {
        let add_bos = match add_bos {
            AddBos::Always => true,
            AddBos::Never => false,
        };

        // Heuristic pre-allocation: roughly one token per two bytes of input,
        // plus one slot for the BOS token, with a floor of 8.
        let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
        let mut buffer: Vec<LlamaToken> = Vec::with_capacity(tokens_estimation);

        let (c_string, c_string_len) = cstring_with_validated_len(str)?;
        let buffer_capacity = c_int::try_from(buffer.capacity())?;

        // SAFETY: `buffer` is valid for writes of `buffer_capacity` tokens and
        // `c_string` is a valid nul-terminated string of `c_string_len` bytes.
        let size = unsafe {
            llama_cpp_bindings_sys::llama_tokenize(
                self.vocab_ptr(),
                c_string.as_ptr(),
                c_string_len,
                buffer
                    .as_mut_ptr()
                    .cast::<llama_cpp_bindings_sys::llama_token>(),
                buffer_capacity,
                add_bos,
                true,
            )
        };

        // A negative return value is the negated number of tokens required:
        // grow the buffer to exactly that size and tokenize a second time.
        let size = if size.is_negative() {
            buffer.reserve_exact(usize::try_from(-size).expect("negated size fits into usize"));
            // SAFETY: `buffer` now has capacity for at least `-size` tokens.
            unsafe {
                llama_cpp_bindings_sys::llama_tokenize(
                    self.vocab_ptr(),
                    c_string.as_ptr(),
                    c_string_len,
                    buffer
                        .as_mut_ptr()
                        .cast::<llama_cpp_bindings_sys::llama_token>(),
                    -size,
                    add_bos,
                    true,
                )
            }
        } else {
            size
        };

        let size = usize::try_from(size)?;

        // SAFETY: `size` < `capacity` and llama-cpp has initialized elements up to `size`
        unsafe { buffer.set_len(size) }

        Ok(buffer)
    }
228
    /// Get the type of a token.
    ///
    /// # Panics
    ///
    /// If the token type is not known to this library.
    #[must_use]
    pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs {
        // SAFETY: `vocab_ptr` returns a valid vocab pointer tied to `self`.
        let token_type =
            unsafe { llama_cpp_bindings_sys::llama_token_get_attr(self.vocab_ptr(), id) };
        LlamaTokenAttrs::try_from(token_type).expect("token type is valid")
    }
240
241    /// Convert a token to a string using the underlying llama.cpp `llama_token_to_piece` function.
242    ///
243    /// This is the new default function for token decoding and provides direct access to
244    /// the llama.cpp token decoding functionality without any special logic or filtering.
245    ///
246    /// Decoding raw string requires using an decoder, tokens from language models may not always map
247    /// to full characters depending on the encoding so stateful decoding is required, otherwise partial strings may be lost!
248    /// Invalid characters are mapped to REPLACEMENT CHARACTER making the method safe to use even if the model inherently produces
249    /// garbage.
250    ///
251    /// # Errors
252    ///
253    /// - if the token type is unknown
254    ///
255    /// # Panics
256    ///
257    /// - if the returned size from llama-cpp does not fit into a [`usize`]. (this should never happen)
258    pub fn token_to_piece(
259        &self,
260        token: LlamaToken,
261        decoder: &mut encoding_rs::Decoder,
262        special: bool,
263        lstrip: Option<NonZeroU16>,
264    ) -> Result<String, TokenToStringError> {
265        let bytes = match self.token_to_piece_bytes(token, 8, special, lstrip) {
266            Err(TokenToStringError::InsufficientBufferSpace(required_size)) => self
267                .token_to_piece_bytes(
268                    token,
269                    (-required_size)
270                        .try_into()
271                        .expect("Error buffer size is positive"),
272                    special,
273                    lstrip,
274                ),
275            other => other,
276        }?;
277
278        let mut output_piece = String::with_capacity(bytes.len());
279        let (_result, _decoded_size, _had_replacements) =
280            decoder.decode_to_string(&bytes, &mut output_piece, false);
281
282        Ok(output_piece)
283    }
284
285    /// Raw token decoding to bytes, use if you want to handle the decoding model output yourself
286    ///
287    /// Convert a token to bytes using the underlying llama.cpp `llama_token_to_piece` function. This is mostly
288    /// a thin wrapper around `llama_token_to_piece` function, that handles rust <-> c type conversions while
289    /// letting the caller handle errors. For a safer interface returning rust strings directly use `token_to_piece` instead!
290    ///
291    /// # Errors
292    ///
293    /// - if the token type is unknown
294    /// - the resultant token is larger than `buffer_size`.
295    #[allow(clippy::missing_panics_doc)]
296    pub fn token_to_piece_bytes(
297        &self,
298        token: LlamaToken,
299        buffer_size: usize,
300        special: bool,
301        lstrip: Option<NonZeroU16>,
302    ) -> Result<Vec<u8>, TokenToStringError> {
303        // SAFETY: `*` (0x2A) is never `\0`, so CString::new cannot fail here
304        let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
305        let len = string.as_bytes().len();
306        let len = c_int::try_from(len)?;
307        let buf = string.into_raw();
308        let lstrip = lstrip.map_or(0, |strip_count| i32::from(strip_count.get()));
309        let size = unsafe {
310            llama_cpp_bindings_sys::llama_token_to_piece(
311                self.vocab_ptr(),
312                token.0,
313                buf,
314                len,
315                lstrip,
316                special,
317            )
318        };
319
320        match size {
321            0 => Err(TokenToStringError::UnknownTokenType),
322            error_code if error_code.is_negative() => {
323                Err(TokenToStringError::InsufficientBufferSpace(error_code))
324            }
325            size => {
326                let string = unsafe { CString::from_raw(buf) };
327                let mut bytes = string.into_bytes();
328                let len = usize::try_from(size).expect("size is positive and fits into usize");
329                bytes.truncate(len);
330
331                Ok(bytes)
332            }
333        }
334    }
335
    /// The number of tokens in the model's vocabulary.
    ///
    /// (The previous doc claimed "tokens the model was trained on", but this
    /// wraps `llama_n_vocab`, which reports vocabulary size.)
    ///
    /// This returns a `c_int` for maximum compatibility. Most of the time it can be cast to an i32
    /// without issue.
    #[must_use]
    pub fn n_vocab(&self) -> i32 {
        // SAFETY: `vocab_ptr` returns a valid vocab pointer tied to `self`.
        unsafe { llama_cpp_bindings_sys::llama_n_vocab(self.vocab_ptr()) }
    }

    /// The type of vocab the model was trained on.
    ///
    /// # Panics
    ///
    /// If llama-cpp emits a vocab type that is not known to this library.
    #[must_use]
    pub fn vocab_type(&self) -> VocabType {
        // SAFETY: `vocab_ptr` returns a valid vocab pointer tied to `self`.
        let vocab_type = unsafe { llama_cpp_bindings_sys::llama_vocab_type(self.vocab_ptr()) };
        VocabType::try_from(vocab_type).expect("invalid vocab type")
    }

    /// The embedding dimension of the model.
    ///
    /// This returns a `c_int` for maximum compatibility. Most of the time it can be cast to an i32
    /// without issue.
    #[must_use]
    pub fn n_embd(&self) -> c_int {
        // SAFETY: `self.model` is a valid model pointer for the lifetime of `self`.
        unsafe { llama_cpp_bindings_sys::llama_n_embd(self.model.as_ptr()) }
    }

    /// Returns the total size of all the tensors in the model in bytes.
    #[must_use]
    pub fn size(&self) -> u64 {
        // SAFETY: `self.model` is a valid model pointer for the lifetime of `self`.
        unsafe { llama_cpp_bindings_sys::llama_model_size(self.model.as_ptr()) }
    }

    /// Returns the number of parameters in the model.
    #[must_use]
    pub fn n_params(&self) -> u64 {
        // SAFETY: `self.model` is a valid model pointer for the lifetime of `self`.
        unsafe { llama_cpp_bindings_sys::llama_model_n_params(self.model.as_ptr()) }
    }
374
    /// Returns whether the model is a recurrent network (Mamba, RWKV, etc)
    #[must_use]
    pub fn is_recurrent(&self) -> bool {
        // SAFETY: `self.model` is a valid model pointer for the lifetime of `self`.
        unsafe { llama_cpp_bindings_sys::llama_model_is_recurrent(self.model.as_ptr()) }
    }

    /// Returns the number of layers within the model.
    ///
    /// # Panics
    /// Panics if the layer count returned by llama.cpp is negative.
    #[must_use]
    pub fn n_layer(&self) -> u32 {
        // llama.cpp API returns int32_t but the underlying field is uint32_t, so this is safe
        u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_layer(self.model.as_ptr()) })
            .expect("llama.cpp returns a positive value for n_layer")
    }

    /// Returns the number of attention heads within the model.
    ///
    /// # Panics
    /// Panics if the head count returned by llama.cpp is negative.
    #[must_use]
    pub fn n_head(&self) -> u32 {
        // llama.cpp API returns int32_t but the underlying field is uint32_t, so this is safe
        u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_head(self.model.as_ptr()) })
            .expect("llama.cpp returns a positive value for n_head")
    }

    /// Returns the number of KV attention heads.
    ///
    /// # Panics
    /// Panics if the KV head count returned by llama.cpp is negative.
    #[must_use]
    pub fn n_head_kv(&self) -> u32 {
        // llama.cpp API returns int32_t but the underlying field is uint32_t, so this is safe
        u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_head_kv(self.model.as_ptr()) })
            .expect("llama.cpp returns a positive value for n_head_kv")
    }

    /// Returns whether the model is a hybrid network (Jamba, Granite, Qwen3xx, etc.)
    ///
    /// Hybrid models have both attention layers and recurrent/SSM layers.
    #[must_use]
    pub fn is_hybrid(&self) -> bool {
        // SAFETY: `self.model` is a valid model pointer for the lifetime of `self`.
        unsafe { llama_cpp_bindings_sys::llama_model_is_hybrid(self.model.as_ptr()) }
    }
421
    /// Get metadata value as a string by key name
    ///
    /// # Errors
    /// Returns an error if the key is not found or the value is not valid UTF-8.
    pub fn meta_val_str(&self, key: &str) -> Result<String, MetaValError> {
        let key_cstring = CString::new(key)?;
        // `key_cstring` stays alive for this whole function, so the raw pointer
        // captured by the closure below remains valid across retries.
        let key_ptr = key_cstring.as_ptr();

        extract_meta_string(
            |buf_ptr, buf_len| unsafe {
                llama_cpp_bindings_sys::llama_model_meta_val_str(
                    self.model.as_ptr(),
                    key_ptr,
                    buf_ptr,
                    buf_len,
                )
            },
            256, // initial capacity; `extract_meta_string` retries larger if needed
        )
    }

    /// Get the number of metadata key/value pairs
    #[must_use]
    pub fn meta_count(&self) -> i32 {
        // SAFETY: `self.model` is a valid model pointer for the lifetime of `self`.
        unsafe { llama_cpp_bindings_sys::llama_model_meta_count(self.model.as_ptr()) }
    }

    /// Get metadata key name by index
    ///
    /// # Errors
    /// Returns an error if the index is out of range or the key is not valid UTF-8.
    pub fn meta_key_by_index(&self, index: i32) -> Result<String, MetaValError> {
        extract_meta_string(
            |buf_ptr, buf_len| unsafe {
                llama_cpp_bindings_sys::llama_model_meta_key_by_index(
                    self.model.as_ptr(),
                    index,
                    buf_ptr,
                    buf_len,
                )
            },
            256, // initial capacity; `extract_meta_string` retries larger if needed
        )
    }

    /// Get metadata value as a string by index
    ///
    /// # Errors
    /// Returns an error if the index is out of range or the value is not valid UTF-8.
    pub fn meta_val_str_by_index(&self, index: i32) -> Result<String, MetaValError> {
        extract_meta_string(
            |buf_ptr, buf_len| unsafe {
                llama_cpp_bindings_sys::llama_model_meta_val_str_by_index(
                    self.model.as_ptr(),
                    index,
                    buf_ptr,
                    buf_len,
                )
            },
            256, // initial capacity; `extract_meta_string` retries larger if needed
        )
    }
484
    /// Returns the rope type of the model.
    ///
    /// `None` means `rope_type::rope_type_from_raw` could not map the raw value.
    #[must_use]
    pub fn rope_type(&self) -> Option<RopeType> {
        // SAFETY: `self.model` is a valid model pointer for the lifetime of `self`.
        let raw = unsafe { llama_cpp_bindings_sys::llama_model_rope_type(self.model.as_ptr()) };

        rope_type::rope_type_from_raw(raw)
    }
492
493    /// Get chat template from model by name. If the name parameter is None, the default chat template will be returned.
494    ///
495    /// You supply this into [`Self::apply_chat_template`] to get back a string with the appropriate template
496    /// substitution applied to convert a list of messages into a prompt the LLM can use to complete
497    /// the chat.
498    ///
499    /// You could also use an external jinja parser, like [minijinja](https://github.com/mitsuhiko/minijinja),
500    /// to parse jinja templates not supported by the llama.cpp template engine.
501    ///
502    /// # Errors
503    ///
504    /// * If the model has no chat template by that name
505    ///
506    /// # Panics
507    ///
508    /// Panics if the C-returned chat template string contains interior null bytes
509    /// (should never happen with valid model data).
510    pub fn chat_template(
511        &self,
512        name: Option<&str>,
513    ) -> Result<LlamaChatTemplate, ChatTemplateError> {
514        let name_cstr = name.map(CString::new);
515        let name_ptr = match name_cstr {
516            Some(Ok(name)) => name.as_ptr(),
517            _ => ptr::null(),
518        };
519        let result = unsafe {
520            llama_cpp_bindings_sys::llama_model_chat_template(self.model.as_ptr(), name_ptr)
521        };
522
523        if result.is_null() {
524            Err(ChatTemplateError::MissingTemplate)
525        } else {
526            let chat_template_cstr = unsafe { CStr::from_ptr(result) };
527            let chat_template = CString::new(chat_template_cstr.to_bytes())
528                .expect("CStr bytes cannot contain interior null bytes");
529
530            Ok(LlamaChatTemplate(chat_template))
531        }
532    }
533
    /// Loads a model from a file.
    ///
    /// # Errors
    ///
    /// See [`LlamaModelLoadError`] for more information.
    ///
    /// # Panics
    ///
    /// Panics if a valid UTF-8 path somehow contains interior null bytes (should never happen).
    #[tracing::instrument(skip_all, fields(params))]
    pub fn load_from_file(
        _: &LlamaBackend,
        path: impl AsRef<Path>,
        params: &LlamaModelParams,
    ) -> Result<Self, LlamaModelLoadError> {
        let path = path.as_ref();

        let path_str = path
            .to_str()
            .ok_or_else(|| LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;

        // Pre-flight existence check so a missing file yields a specific error
        // instead of an opaque null pointer from llama.cpp.
        if !path.exists() {
            return Err(LlamaModelLoadError::FileNotFound(path.to_path_buf()));
        }

        let cstr = CString::new(path_str)?;
        let llama_model = unsafe {
            llama_cpp_bindings_sys::llama_load_model_from_file(cstr.as_ptr(), params.params)
        };

        // On a null result, re-check existence: a file deleted between the check
        // above and the load (TOCTOU window) is still reported as `FileNotFound`
        // rather than the generic `NullResult`.
        let model = match NonNull::new(llama_model) {
            Some(ptr) => ptr,
            None if !path.exists() => {
                return Err(LlamaModelLoadError::FileNotFound(path.to_path_buf()));
            }
            None => return Err(LlamaModelLoadError::NullResult),
        };

        Ok(Self { model })
    }
574
575    /// Initializes a lora adapter from a file.
576    ///
577    /// # Errors
578    ///
579    /// See [`LlamaLoraAdapterInitError`] for more information.
580    pub fn lora_adapter_init(
581        &self,
582        path: impl AsRef<Path>,
583    ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
584        let path = path.as_ref();
585
586        let path_str = path
587            .to_str()
588            .ok_or_else(|| LlamaLoraAdapterInitError::PathToStrError(path.to_path_buf()))?;
589
590        if !path.exists() {
591            return Err(LlamaLoraAdapterInitError::FileNotFound(path.to_path_buf()));
592        }
593
594        let cstr = CString::new(path_str)?;
595        let raw_adapter = unsafe {
596            llama_cpp_bindings_sys::llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr())
597        };
598
599        let Some(adapter) = NonNull::new(raw_adapter) else {
600            return Err(LlamaLoraAdapterInitError::NullResult);
601        };
602
603        Ok(LlamaLoraAdapter {
604            lora_adapter: adapter,
605        })
606    }
607
    /// Create a new context from this model.
    ///
    /// # Errors
    ///
    /// There are many ways this can fail. See [`LlamaContextLoadError`] for more information.
    // we intentionally do not derive Copy on `LlamaContextParams` to allow llama.cpp to change the type to be non-trivially copyable.
    #[allow(clippy::needless_pass_by_value)]
    pub fn new_context<'model>(
        &'model self,
        _: &LlamaBackend,
        params: LlamaContextParams,
    ) -> Result<LlamaContext<'model>, LlamaContextLoadError> {
        let context_params = params.context_params;
        // SAFETY: `self.model` is a valid model pointer; NOTE(review): assumes
        // llama.cpp copies what it needs out of `context_params` — confirm.
        let context = unsafe {
            llama_cpp_bindings_sys::llama_new_context_with_model(
                self.model.as_ptr(),
                context_params,
            )
        };
        let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;

        Ok(LlamaContext::new(self, context, params.embeddings()))
    }
631
    /// Apply the models chat template to some messages.
    /// See <https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template>
    ///
    /// Unlike the llama.cpp `apply_chat_template` which just randomly uses the `ChatML` template when given
    /// a null pointer for the template, this requires an explicit template to be specified. If you want to
    /// use "chatml", then just do `LlamaChatTemplate::new("chatml")` or any other model name or template
    /// string.
    ///
    /// Use [`Self::chat_template`] to retrieve the template baked into the model (this is the preferred
    /// mechanism as using the wrong chat template can result in really unexpected responses from the LLM).
    ///
    /// You probably want to set `add_ass` to true so that the generated template string ends with a the
    /// opening tag of the assistant. If you fail to leave a hanging chat tag, the model will likely generate
    /// one into the output and the output may also have unexpected output aside from that.
    ///
    /// # Errors
    /// There are many ways this can fail. See [`ApplyChatTemplateError`] for more information.
    #[tracing::instrument(skip_all)]
    pub fn apply_chat_template(
        &self,
        tmpl: &LlamaChatTemplate,
        chat: &[LlamaChatMessage],
        add_ass: bool,
    ) -> Result<String, ApplyChatTemplateError> {
        // Size the output buffer at twice the total raw message bytes as a
        // first guess; the retry path below handles templates that need more.
        let message_length = chat.iter().fold(0, |acc, chat_message| {
            acc + chat_message.role.to_bytes().len() + chat_message.content.to_bytes().len()
        });
        let mut buff: Vec<u8> = vec![0; message_length * 2];

        // Build the C-side message array; the pointers borrow from `chat`'s
        // CStrings, which outlive both FFI calls below.
        let chat: Vec<llama_cpp_bindings_sys::llama_chat_message> = chat
            .iter()
            .map(|chat_message| llama_cpp_bindings_sys::llama_chat_message {
                role: chat_message.role.as_ptr(),
                content: chat_message.content.as_ptr(),
            })
            .collect();

        let tmpl_ptr = tmpl.0.as_ptr();

        let buff_len: i32 = buff.len().try_into()?;

        // SAFETY: all pointers are valid for the duration of the call and
        // `buff` is writable for `buff_len` bytes.
        let res = unsafe {
            llama_cpp_bindings_sys::llama_chat_apply_template(
                tmpl_ptr,
                chat.as_ptr(),
                chat.len(),
                add_ass,
                buff.as_mut_ptr().cast::<c_char>(),
                buff_len,
            )
        };

        // A return value larger than the buffer is the required size: grow and
        // apply the template once more. (A negative `res` falls through to the
        // failing `try_into` below and surfaces as a conversion error.)
        if res > buff_len {
            let required_size: usize = res.try_into()?;
            buff.resize(required_size, 0);

            let new_buff_len: i32 = buff.len().try_into()?;

            // SAFETY: `buff` was resized to hold `new_buff_len` bytes.
            let res = unsafe {
                llama_cpp_bindings_sys::llama_chat_apply_template(
                    tmpl_ptr,
                    chat.as_ptr(),
                    chat.len(),
                    add_ass,
                    buff.as_mut_ptr().cast::<c_char>(),
                    new_buff_len,
                )
            };
            let final_size: usize = res.try_into()?;

            return truncated_buffer_to_string(buff, final_size);
        }

        let final_size: usize = res.try_into()?;

        truncated_buffer_to_string(buff, final_size)
    }
709
    /// Apply the models chat template to some messages and return an optional tool grammar.
    /// `tools_json` should be an OpenAI-compatible tool definition JSON array string.
    /// `json_schema` should be a JSON schema string.
    ///
    /// # Errors
    /// Returns an error if the FFI call fails or the result contains invalid data.
    #[tracing::instrument(skip_all)]
    pub fn apply_chat_template_with_tools_oaicompat(
        &self,
        tmpl: &LlamaChatTemplate,
        messages: &[LlamaChatMessage],
        tools_json: Option<&str>,
        json_schema: Option<&str>,
        add_generation_prompt: bool,
    ) -> Result<ChatTemplateResult, ApplyChatTemplateError> {
        // Build the C-side message array; the pointers borrow from `messages`'
        // CStrings, which outlive the FFI call below.
        let chat: Vec<llama_cpp_bindings_sys::llama_chat_message> = messages
            .iter()
            .map(|chat_message| llama_cpp_bindings_sys::llama_chat_message {
                role: chat_message.role.as_ptr(),
                content: chat_message.content.as_ptr(),
            })
            .collect();

        // These locals must stay alive until after the FFI call so the raw
        // pointers passed below remain valid.
        let tools_cstr = tools_json.map(CString::new).transpose()?;
        let json_schema_cstr = json_schema.map(CString::new).transpose()?;

        // Out-parameter filled in by the C side.
        let mut raw_result = new_empty_chat_template_raw_result();

        let rc = unsafe {
            llama_cpp_bindings_sys::llama_rs_apply_chat_template_with_tools_oaicompat(
                self.model.as_ptr(),
                tmpl.0.as_ptr(),
                chat.as_ptr(),
                chat.len(),
                tools_cstr
                    .as_ref()
                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
                json_schema_cstr
                    .as_ref()
                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
                add_generation_prompt,
                &raw mut raw_result,
            )
        };

        // Only attempt tool-call parsing when a non-empty tool list was supplied.
        let parse_tool_calls = tools_json.is_some_and(|tools| !tools.is_empty());

        // SAFETY: `raw_result` was initialized above and filled by the call; the
        // parser takes ownership of its contents.
        unsafe { parse_chat_template_raw_result(rc, &raw mut raw_result, parse_tool_calls) }
    }
759
    /// Apply the model chat template using OpenAI-compatible JSON messages.
    ///
    /// # Errors
    /// Returns an error if the FFI call fails or the result contains invalid data.
    #[tracing::instrument(skip_all)]
    pub fn apply_chat_template_oaicompat(
        &self,
        tmpl: &LlamaChatTemplate,
        params: &OpenAIChatTemplateParams<'_>,
    ) -> Result<ChatTemplateResult, ApplyChatTemplateError> {
        let parse_tool_calls = params.parse_tool_calls;
        // All CString locals below must stay alive until after the FFI call so
        // the raw pointers stored in `ffi_params` remain valid.
        let messages_cstr = CString::new(params.messages_json)?;
        let tools_cstr = params.tools_json.map(CString::new).transpose()?;
        let tool_choice_cstr = params.tool_choice.map(CString::new).transpose()?;
        let json_schema_cstr = params.json_schema.map(CString::new).transpose()?;
        let grammar_cstr = params.grammar.map(CString::new).transpose()?;
        let reasoning_cstr = params.reasoning_format.map(CString::new).transpose()?;
        let kwargs_cstr = params.chat_template_kwargs.map(CString::new).transpose()?;

        // Out-parameter filled in by the C side.
        let mut raw_result = new_empty_chat_template_raw_result();

        // Absent optional fields are passed as null pointers.
        let ffi_params = llama_cpp_bindings_sys::llama_rs_chat_template_oaicompat_params {
            messages: messages_cstr.as_ptr(),
            tools: tools_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            tool_choice: tool_choice_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            json_schema: json_schema_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            grammar: grammar_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            reasoning_format: reasoning_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            chat_template_kwargs: kwargs_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            add_generation_prompt: params.add_generation_prompt,
            use_jinja: params.use_jinja,
            parallel_tool_calls: params.parallel_tool_calls,
            enable_thinking: params.enable_thinking,
            add_bos: params.add_bos,
            add_eos: params.add_eos,
        };

        let rc = unsafe {
            llama_cpp_bindings_sys::llama_rs_apply_chat_template_oaicompat(
                self.model.as_ptr(),
                tmpl.0.as_ptr(),
                &raw const ffi_params,
                &raw mut raw_result,
            )
        };

        // SAFETY: `raw_result` was initialized above and filled by the call; the
        // parser takes ownership of its contents.
        unsafe { parse_chat_template_raw_result(rc, &raw mut raw_result, parse_tool_calls) }
    }
820}
821
822fn extract_meta_string<TCFunction>(
823    c_function: TCFunction,
824    capacity: usize,
825) -> Result<String, MetaValError>
826where
827    TCFunction: Fn(*mut c_char, usize) -> i32,
828{
829    let mut buffer = vec![0u8; capacity];
830    let result = c_function(buffer.as_mut_ptr().cast::<c_char>(), buffer.len());
831
832    if result < 0 {
833        return Err(MetaValError::NegativeReturn(result));
834    }
835
836    let returned_len = result.cast_unsigned() as usize;
837
838    if returned_len >= capacity {
839        return extract_meta_string(c_function, returned_len + 1);
840    }
841
842    if buffer.get(returned_len) != Some(&0) {
843        return Err(MetaValError::NegativeReturn(-1));
844    }
845
846    buffer.truncate(returned_len);
847
848    Ok(String::from_utf8(buffer)?)
849}
850
impl Drop for LlamaModel {
    fn drop(&mut self) {
        // SAFETY: `self.model` is a `NonNull` pointer, so `as_ptr()` is never
        // null. Drop runs at most once per value, so the underlying
        // `llama_model` is freed exactly once.
        // NOTE(review): assumes no other copy of this raw pointer outlives the
        // wrapper — confirm `LlamaModel` is the sole owner.
        unsafe { llama_cpp_bindings_sys::llama_free_model(self.model.as_ptr()) }
    }
}
856
#[cfg(test)]
// Unit tests for `extract_meta_string` and the small tokenizer string helpers.
// NOTE(review): the module name undersells its contents — the last three tests
// cover `cstring_with_validated_len`, `validate_string_length_for_tokenizer`,
// and `truncated_buffer_to_string`; consider renaming or splitting.
mod extract_meta_string_tests {
    use super::extract_meta_string;
    use crate::MetaValError;

    // A getter that reports length 2 but writes a non-NUL byte at index 2
    // must be rejected as a protocol violation (reported as NegativeReturn(-1)).
    #[test]
    fn returns_error_when_null_terminator_missing() {
        let result = extract_meta_string(
            |buf_ptr, buf_len| {
                let buffer =
                    unsafe { std::slice::from_raw_parts_mut(buf_ptr.cast::<u8>(), buf_len) };
                buffer[0] = b'a';
                buffer[1] = b'b';
                // Intentionally do NOT write a null terminator at position 2
                buffer[2] = b'c';
                2
            },
            4,
        );

        assert_eq!(result.unwrap_err(), MetaValError::NegativeReturn(-1));
    }

    // A negative return code from the getter is surfaced verbatim.
    #[test]
    fn returns_error_for_negative_return_value() {
        let result = extract_meta_string(|_buf_ptr, _buf_len| -5, 4);

        assert_eq!(result.unwrap_err(), MetaValError::NegativeReturn(-5));
    }

    // Well-terminated but non-UTF-8 bytes must produce a UTF-8 conversion error.
    #[test]
    fn returns_error_for_invalid_utf8_data() {
        let result = extract_meta_string(
            |buf_ptr, buf_len| {
                let buffer =
                    unsafe { std::slice::from_raw_parts_mut(buf_ptr.cast::<u8>(), buf_len) };
                buffer[0] = 0xFF;
                buffer[1] = 0xFE;
                buffer[2] = 0;
                2
            },
            4,
        );

        assert!(result.is_err());
        // Relies on MetaValError's Display including the wrapped error's type name.
        assert!(result.unwrap_err().to_string().contains("FromUtf8Error"));
    }

    // When the first call reports a length >= capacity, the helper must retry
    // with a larger buffer; the second call then succeeds.
    #[test]
    fn triggers_buffer_resize_when_returned_len_exceeds_capacity() {
        // Cell tracks invocations because the helper takes Fn, not FnMut.
        let call_count = std::cell::Cell::new(0);
        let result = extract_meta_string(
            |buf_ptr, buf_len| {
                let count = call_count.get();
                call_count.set(count + 1);
                if count == 0 {
                    // First call: return length larger than capacity to trigger resize
                    10
                } else {
                    // Second call with larger buffer: write valid data
                    let buffer =
                        unsafe { std::slice::from_raw_parts_mut(buf_ptr.cast::<u8>(), buf_len) };
                    buffer[0] = b'h';
                    buffer[1] = b'i';
                    buffer[2] = 0;
                    2
                }
            },
            4,
        );

        assert_eq!(result.unwrap(), "hi");
    }

    // Interior NUL bytes are rejected by CString::new inside the helper.
    #[test]
    fn cstring_with_validated_len_null_byte_returns_error() {
        let result = super::cstring_with_validated_len("null\0byte");

        assert!(result.is_err());
    }

    // Lengths that do not fit in c_int must fail the try_from conversion.
    #[test]
    fn validate_string_length_overflow_returns_error() {
        let result = super::validate_string_length_for_tokenizer(usize::MAX);

        assert!(result.is_err());
    }

    // Invalid UTF-8 propagates as an error through truncated_buffer_to_string.
    #[test]
    fn truncated_buffer_to_string_with_invalid_utf8_returns_error() {
        let invalid_utf8 = vec![0xff, 0xfe, 0xfd];
        let result = super::truncated_buffer_to_string(invalid_utf8, 3);

        assert!(result.is_err());
    }
}
953
954#[cfg(test)]
955#[cfg(feature = "tests_that_use_llms")]
956mod tests {
957    use serial_test::serial;
958
959    use super::LlamaModel;
960    use crate::llama_backend::LlamaBackend;
961    use crate::model::AddBos;
962    use crate::model::params::LlamaModelParams;
963    use crate::test_model;
964
965    #[test]
966    #[serial]
967    fn model_loads_with_valid_metadata() {
968        let (_backend, model) = test_model::load_default_model().unwrap();
969        assert!(model.n_vocab() > 0);
970        assert!(model.n_embd() > 0);
971        assert!(model.n_params() > 0);
972        assert!(model.n_ctx_train() > 0);
973    }
974
975    #[test]
976    #[serial]
977    fn special_tokens_exist() {
978        let (_backend, model) = test_model::load_default_model().unwrap();
979        let bos = model.token_bos();
980        let eos = model.token_eos();
981        assert_ne!(bos, eos);
982        assert!(model.is_eog_token(eos));
983        assert!(!model.is_eog_token(bos));
984    }
985
986    #[test]
987    #[serial]
988    fn str_to_token_roundtrip() {
989        let (_backend, model) = test_model::load_default_model().unwrap();
990        let tokens = model.str_to_token("hello world", AddBos::Never).unwrap();
991        assert!(!tokens.is_empty());
992        let mut decoder = encoding_rs::UTF_8.new_decoder();
993        let piece = model
994            .token_to_piece(tokens[0], &mut decoder, false, None)
995            .unwrap();
996        assert!(!piece.is_empty());
997    }
998
999    #[test]
1000    #[serial]
1001    fn chat_template_returns_non_empty() {
1002        let (_backend, model) = test_model::load_default_model().unwrap();
1003        let template = model.chat_template(None);
1004        assert!(template.is_ok());
1005    }
1006
1007    #[test]
1008    #[serial]
1009    fn apply_chat_template_produces_prompt() {
1010        let (_backend, model) = test_model::load_default_model().unwrap();
1011        let template = model.chat_template(None).unwrap();
1012        let message =
1013            crate::model::LlamaChatMessage::new("user".to_string(), "hello".to_string()).unwrap();
1014        let prompt = model.apply_chat_template(&template, &[message], true);
1015        assert!(prompt.is_ok());
1016        assert!(!prompt.unwrap().is_empty());
1017    }
1018
1019    #[test]
1020    #[serial]
1021    fn apply_chat_template_oaicompat_produces_result() {
1022        let (_backend, model) = test_model::load_default_model().unwrap();
1023        let template = model.chat_template(None).unwrap();
1024        let params = crate::openai::OpenAIChatTemplateParams {
1025            messages_json: r#"[{"role":"user","content":"hello"}]"#,
1026            tools_json: None,
1027            tool_choice: None,
1028            json_schema: None,
1029            grammar: None,
1030            reasoning_format: Some("none"),
1031            chat_template_kwargs: None,
1032            add_generation_prompt: true,
1033            use_jinja: true,
1034            parallel_tool_calls: false,
1035            enable_thinking: false,
1036            add_bos: false,
1037            add_eos: false,
1038            parse_tool_calls: false,
1039        };
1040        let result = model.apply_chat_template_oaicompat(&template, &params);
1041        assert!(result.is_ok());
1042        assert!(!result.unwrap().prompt.is_empty());
1043    }
1044
1045    #[test]
1046    #[serial]
1047    fn meta_count_returns_positive() {
1048        let (_backend, model) = test_model::load_default_model().unwrap();
1049        assert!(model.meta_count() > 0);
1050    }
1051
1052    #[test]
1053    #[serial]
1054    fn tokens_iterator_produces_valid_entries() {
1055        let (_backend, model) = test_model::load_default_model().unwrap();
1056        let mut count = 0;
1057
1058        for (token, piece_result) in model.tokens(false) {
1059            assert!(token.0 >= 0);
1060            // Not all tokens decode successfully, but the iterator should not panic
1061            let _ = piece_result;
1062            count += 1;
1063
1064            if count >= 100 {
1065                break;
1066            }
1067        }
1068
1069        assert_eq!(count, 100);
1070    }
1071
1072    #[test]
1073    #[serial]
1074    fn token_to_piece_bytes_returns_bytes_for_known_token() {
1075        let (_backend, model) = test_model::load_default_model().unwrap();
1076        let tokens = model.str_to_token("hello", AddBos::Never).unwrap();
1077        let bytes = model
1078            .token_to_piece_bytes(tokens[0], 32, false, None)
1079            .unwrap();
1080
1081        assert!(!bytes.is_empty());
1082    }
1083
1084    #[test]
1085    #[serial]
1086    fn n_layer_returns_positive() {
1087        let (_backend, model) = test_model::load_default_model().unwrap();
1088
1089        assert!(model.n_layer() > 0);
1090    }
1091
1092    #[test]
1093    #[serial]
1094    fn n_head_returns_positive() {
1095        let (_backend, model) = test_model::load_default_model().unwrap();
1096
1097        assert!(model.n_head() > 0);
1098    }
1099
1100    #[test]
1101    #[serial]
1102    fn n_head_kv_returns_positive() {
1103        let (_backend, model) = test_model::load_default_model().unwrap();
1104
1105        assert!(model.n_head_kv() > 0);
1106    }
1107
1108    #[test]
1109    #[serial]
1110    fn meta_key_by_index_returns_valid_key() {
1111        let (_backend, model) = test_model::load_default_model().unwrap();
1112        let key = model.meta_key_by_index(0).unwrap();
1113
1114        assert!(!key.is_empty());
1115    }
1116
1117    #[test]
1118    #[serial]
1119    fn meta_val_str_by_index_returns_valid_value() {
1120        let (_backend, model) = test_model::load_default_model().unwrap();
1121        let value = model.meta_val_str_by_index(0).unwrap();
1122
1123        assert!(!value.is_empty());
1124    }
1125
1126    #[test]
1127    #[serial]
1128    fn meta_key_by_index_out_of_range_returns_error() {
1129        let (_backend, model) = test_model::load_default_model().unwrap();
1130        let result = model.meta_key_by_index(999_999);
1131
1132        assert!(result.is_err());
1133    }
1134
1135    #[test]
1136    #[serial]
1137    fn meta_val_str_by_index_out_of_range_returns_error() {
1138        let (_backend, model) = test_model::load_default_model().unwrap();
1139        let result = model.meta_val_str_by_index(999_999);
1140
1141        assert!(result.is_err());
1142    }
1143
1144    #[test]
1145    #[serial]
1146    fn meta_val_str_returns_value_for_known_key() {
1147        let (_backend, model) = test_model::load_default_model().unwrap();
1148        let first_key = model.meta_key_by_index(0).unwrap();
1149        let value = model.meta_val_str(&first_key).unwrap();
1150
1151        assert!(!value.is_empty());
1152    }
1153
1154    #[test]
1155    #[serial]
1156    fn model_size_returns_nonzero() {
1157        let (_backend, model) = test_model::load_default_model().unwrap();
1158
1159        assert!(model.size() > 0);
1160    }
1161
1162    #[test]
1163    #[serial]
1164    fn is_recurrent_returns_false_for_transformer() {
1165        let (_backend, model) = test_model::load_default_model().unwrap();
1166
1167        assert!(!model.is_recurrent());
1168    }
1169
1170    #[test]
1171    #[serial]
1172    fn rope_type_does_not_panic() {
1173        let (_backend, model) = test_model::load_default_model().unwrap();
1174        let _rope_type = model.rope_type();
1175    }
1176
1177    #[test]
1178    #[serial]
1179    fn load_model_with_invalid_path_returns_error() {
1180        let backend = LlamaBackend::init().unwrap();
1181        let model_params = LlamaModelParams::default();
1182        let result = LlamaModel::load_from_file(&backend, "/nonexistent/model.gguf", &model_params);
1183
1184        assert_eq!(
1185            result.unwrap_err(),
1186            crate::LlamaModelLoadError::FileNotFound(std::path::PathBuf::from(
1187                "/nonexistent/model.gguf"
1188            ))
1189        );
1190    }
1191
1192    #[test]
1193    #[serial]
1194    fn load_model_with_invalid_file_content_returns_null_result() {
1195        let backend = LlamaBackend::init().unwrap();
1196        let model_params = LlamaModelParams::default();
1197        let dummy_path = std::env::temp_dir().join("llama_test_invalid_model.gguf");
1198        std::fs::write(&dummy_path, b"not a valid gguf model file").unwrap();
1199
1200        let result = LlamaModel::load_from_file(&backend, &dummy_path, &model_params);
1201
1202        assert_eq!(result.unwrap_err(), crate::LlamaModelLoadError::NullResult);
1203        let _ = std::fs::remove_file(&dummy_path);
1204    }
1205
1206    #[cfg(unix)]
1207    #[test]
1208    #[serial]
1209    fn load_model_with_non_utf8_path_returns_path_to_str_error() {
1210        use std::ffi::OsStr;
1211        use std::os::unix::ffi::OsStrExt;
1212
1213        let backend = LlamaBackend::init().unwrap();
1214        let model_params = LlamaModelParams::default();
1215        let non_utf8_path = std::path::Path::new(OsStr::from_bytes(b"/tmp/\xff\xfe.gguf"));
1216
1217        let result = LlamaModel::load_from_file(&backend, non_utf8_path, &model_params);
1218
1219        assert_eq!(
1220            result.unwrap_err(),
1221            crate::LlamaModelLoadError::PathToStrError(non_utf8_path.to_path_buf())
1222        );
1223    }
1224
1225    #[cfg(unix)]
1226    #[test]
1227    #[serial]
1228    fn lora_adapter_init_with_non_utf8_path_returns_error() {
1229        use std::ffi::OsStr;
1230        use std::os::unix::ffi::OsStrExt;
1231
1232        let (_backend, model) = test_model::load_default_model().unwrap();
1233        let non_utf8_path = std::path::Path::new(OsStr::from_bytes(b"/tmp/\xff\xfe.gguf"));
1234
1235        let result = model.lora_adapter_init(non_utf8_path);
1236
1237        assert_eq!(
1238            result.unwrap_err(),
1239            crate::LlamaLoraAdapterInitError::PathToStrError(non_utf8_path.to_path_buf())
1240        );
1241    }
1242
1243    #[test]
1244    #[serial]
1245    fn lora_adapter_init_with_invalid_path_returns_error() {
1246        let (_backend, model) = test_model::load_default_model().unwrap();
1247        let result = model.lora_adapter_init("/nonexistent/path/lora.gguf");
1248
1249        assert_eq!(
1250            result.unwrap_err(),
1251            crate::LlamaLoraAdapterInitError::FileNotFound(std::path::PathBuf::from(
1252                "/nonexistent/path/lora.gguf"
1253            ))
1254        );
1255    }
1256
1257    #[test]
1258    #[serial]
1259    fn new_context_returns_valid_context() {
1260        let (backend, model) = test_model::load_default_model().unwrap();
1261        let ctx_params = crate::context::params::LlamaContextParams::default()
1262            .with_n_ctx(std::num::NonZeroU32::new(256));
1263        let context = model.new_context(&backend, ctx_params).unwrap();
1264
1265        assert!(context.n_ctx() > 0);
1266    }
1267
1268    #[test]
1269    #[serial]
1270    fn token_nl_returns_valid_token() {
1271        let (_backend, model) = test_model::load_default_model().unwrap();
1272        let nl_token = model.token_nl();
1273
1274        assert!(nl_token.0 >= 0);
1275    }
1276
1277    #[test]
1278    #[serial]
1279    fn decode_start_token_returns_valid_token() {
1280        let (_backend, model) = test_model::load_default_model().unwrap();
1281        let _decode_start = model.decode_start_token();
1282    }
1283
1284    #[test]
1285    #[serial]
1286    fn token_sep_returns_valid_token() {
1287        let (_backend, model) = test_model::load_default_model().unwrap();
1288        let _sep_token = model.token_sep();
1289    }
1290
1291    #[test]
1292    #[serial]
1293    fn token_to_piece_handles_large_token_requiring_buffer_resize() {
1294        let (_backend, model) = test_model::load_default_model().unwrap();
1295        let mut decoder = encoding_rs::UTF_8.new_decoder();
1296
1297        for (token, _) in model.tokens(true).take(200) {
1298            let result = model.token_to_piece(token, &mut decoder, true, None);
1299            assert!(result.is_ok());
1300        }
1301    }
1302
1303    #[test]
1304    #[serial]
1305    fn token_to_piece_bytes_insufficient_buffer_returns_error() {
1306        let (_backend, model) = test_model::load_default_model().unwrap();
1307        let tokens = model.str_to_token("hello", AddBos::Never).unwrap();
1308        let result = model.token_to_piece_bytes(tokens[0], 1, false, None);
1309
1310        assert!(
1311            result
1312                .unwrap_err()
1313                .to_string()
1314                .contains("Insufficient Buffer Space")
1315        );
1316    }
1317
1318    #[test]
1319    #[serial]
1320    fn token_to_piece_with_lstrip() {
1321        let (_backend, model) = test_model::load_default_model().unwrap();
1322        let mut decoder = encoding_rs::UTF_8.new_decoder();
1323        let tokens = model.str_to_token("hello", AddBos::Never).unwrap();
1324        let result =
1325            model.token_to_piece(tokens[0], &mut decoder, false, std::num::NonZeroU16::new(1));
1326
1327        assert!(result.is_ok());
1328    }
1329
1330    #[test]
1331    #[serial]
1332    fn n_vocab_matches_tokens_iterator_count() {
1333        let (_backend, model) = test_model::load_default_model().unwrap();
1334        let n_vocab = model.n_vocab();
1335        let count = model.tokens(false).count();
1336
1337        assert_eq!(count, n_vocab as usize);
1338    }
1339
1340    #[test]
1341    #[serial]
1342    fn token_attr_returns_valid_attr() {
1343        let (_backend, model) = test_model::load_default_model().unwrap();
1344        let bos = model.token_bos();
1345        let _attr = model.token_attr(bos);
1346    }
1347
1348    #[test]
1349    #[serial]
1350    fn vocab_type_returns_valid_type() {
1351        let (_backend, model) = test_model::load_default_model().unwrap();
1352        let _vocab_type = model.vocab_type();
1353    }
1354
1355    #[test]
1356    #[serial]
1357    fn apply_chat_template_buffer_resize_with_long_messages() {
1358        let (_backend, model) = test_model::load_default_model().unwrap();
1359        let template = model.chat_template(None).unwrap();
1360        let long_content = "a".repeat(2000);
1361        let message =
1362            crate::model::LlamaChatMessage::new("user".to_string(), long_content).unwrap();
1363        let prompt = model.apply_chat_template(&template, &[message], true);
1364
1365        assert!(prompt.is_ok());
1366        assert!(!prompt.unwrap().is_empty());
1367    }
1368
1369    #[test]
1370    #[serial]
1371    fn meta_val_str_with_long_value_triggers_buffer_resize() {
1372        let (_backend, model) = test_model::load_default_model().unwrap();
1373        let count = model.meta_count();
1374
1375        for index in 0..count {
1376            let key = model.meta_key_by_index(index);
1377            let value = model.meta_val_str_by_index(index);
1378            assert!(key.is_ok());
1379            assert!(value.is_ok());
1380        }
1381    }
1382
1383    #[test]
1384    #[serial]
1385    fn str_to_token_with_add_bos_never() {
1386        let (_backend, model) = test_model::load_default_model().unwrap();
1387        let tokens_with_bos = model.str_to_token("hello", AddBos::Always).unwrap();
1388        let tokens_without_bos = model.str_to_token("hello", AddBos::Never).unwrap();
1389
1390        assert!(tokens_with_bos.len() >= tokens_without_bos.len());
1391    }
1392
1393    #[test]
1394    #[serial]
1395    fn apply_chat_template_with_tools_oaicompat_produces_result() {
1396        let (_backend, model) = test_model::load_default_model().unwrap();
1397        let template = model.chat_template(None).unwrap();
1398        let message =
1399            crate::model::LlamaChatMessage::new("user".to_string(), "hello".to_string()).unwrap();
1400        let result =
1401            model.apply_chat_template_with_tools_oaicompat(&template, &[message], None, None, true);
1402
1403        assert!(result.is_ok());
1404        assert!(!result.unwrap().prompt.is_empty());
1405    }
1406
1407    #[test]
1408    #[serial]
1409    fn apply_chat_template_with_tools_oaicompat_with_tools_json() {
1410        let (_backend, model) = test_model::load_default_model().unwrap();
1411        let template = model.chat_template(None).unwrap();
1412        let message =
1413            crate::model::LlamaChatMessage::new("user".to_string(), "hello".to_string()).unwrap();
1414        let tools =
1415            r#"[{"type":"function","function":{"name":"test","parameters":{"type":"object"}}}]"#;
1416        let result = model.apply_chat_template_with_tools_oaicompat(
1417            &template,
1418            &[message],
1419            Some(tools),
1420            None,
1421            true,
1422        );
1423
1424        assert!(result.is_ok());
1425    }
1426
1427    #[test]
1428    #[serial]
1429    fn apply_chat_template_with_tools_oaicompat_with_json_schema() {
1430        let (_backend, model) = test_model::load_default_model().unwrap();
1431        let template = model.chat_template(None).unwrap();
1432        let message =
1433            crate::model::LlamaChatMessage::new("user".to_string(), "hello".to_string()).unwrap();
1434        let schema = r#"{"type":"object","properties":{"name":{"type":"string"}}}"#;
1435        let result = model.apply_chat_template_with_tools_oaicompat(
1436            &template,
1437            &[message],
1438            None,
1439            Some(schema),
1440            true,
1441        );
1442
1443        assert!(result.is_ok());
1444    }
1445
1446    #[test]
1447    #[serial]
1448    fn apply_chat_template_oaicompat_with_tools_and_tool_choice() {
1449        let (_backend, model) = test_model::load_default_model().unwrap();
1450        let template = model.chat_template(None).unwrap();
1451        let params = crate::openai::OpenAIChatTemplateParams {
1452            messages_json: r#"[{"role":"user","content":"hello"}]"#,
1453            tools_json: Some(
1454                r#"[{"type":"function","function":{"name":"test","parameters":{"type":"object","properties":{}}}}]"#,
1455            ),
1456            tool_choice: Some("auto"),
1457            json_schema: None,
1458            grammar: None,
1459            reasoning_format: Some("none"),
1460            chat_template_kwargs: None,
1461            add_generation_prompt: true,
1462            use_jinja: true,
1463            parallel_tool_calls: false,
1464            enable_thinking: false,
1465            add_bos: false,
1466            add_eos: false,
1467            parse_tool_calls: true,
1468        };
1469        let result = model.apply_chat_template_oaicompat(&template, &params);
1470
1471        assert!(result.is_ok());
1472    }
1473
1474    #[test]
1475    #[serial]
1476    fn apply_chat_template_oaicompat_with_json_schema_field() {
1477        let (_backend, model) = test_model::load_default_model().unwrap();
1478        let template = model.chat_template(None).unwrap();
1479        let params = crate::openai::OpenAIChatTemplateParams {
1480            messages_json: r#"[{"role":"user","content":"hello"}]"#,
1481            tools_json: None,
1482            tool_choice: None,
1483            json_schema: Some(r#"{"type":"object","properties":{"name":{"type":"string"}}}"#),
1484            grammar: None,
1485            reasoning_format: Some("none"),
1486            chat_template_kwargs: None,
1487            add_generation_prompt: true,
1488            use_jinja: true,
1489            parallel_tool_calls: false,
1490            enable_thinking: false,
1491            add_bos: false,
1492            add_eos: false,
1493            parse_tool_calls: false,
1494        };
1495        let result = model.apply_chat_template_oaicompat(&template, &params);
1496
1497        assert!(result.is_ok());
1498    }
1499
1500    #[test]
1501    #[serial]
1502    fn apply_chat_template_oaicompat_with_grammar_field() {
1503        let (_backend, model) = test_model::load_default_model().unwrap();
1504        let template = model.chat_template(None).unwrap();
1505        let params = crate::openai::OpenAIChatTemplateParams {
1506            messages_json: r#"[{"role":"user","content":"hello"}]"#,
1507            tools_json: None,
1508            tool_choice: None,
1509            json_schema: None,
1510            grammar: Some("root ::= \"hello\""),
1511            reasoning_format: Some("none"),
1512            chat_template_kwargs: None,
1513            add_generation_prompt: true,
1514            use_jinja: true,
1515            parallel_tool_calls: false,
1516            enable_thinking: false,
1517            add_bos: false,
1518            add_eos: false,
1519            parse_tool_calls: false,
1520        };
1521        let result = model.apply_chat_template_oaicompat(&template, &params);
1522
1523        assert!(result.is_ok());
1524    }
1525
1526    #[test]
1527    #[serial]
1528    fn apply_chat_template_oaicompat_with_kwargs_field() {
1529        let (_backend, model) = test_model::load_default_model().unwrap();
1530        let template = model.chat_template(None).unwrap();
1531        let params = crate::openai::OpenAIChatTemplateParams {
1532            messages_json: r#"[{"role":"user","content":"hello"}]"#,
1533            tools_json: None,
1534            tool_choice: None,
1535            json_schema: None,
1536            grammar: None,
1537            reasoning_format: Some("none"),
1538            chat_template_kwargs: Some(r#"{"bos_token": "<|im_start|>"}"#),
1539            add_generation_prompt: true,
1540            use_jinja: true,
1541            parallel_tool_calls: false,
1542            enable_thinking: false,
1543            add_bos: false,
1544            add_eos: false,
1545            parse_tool_calls: false,
1546        };
1547        let result = model.apply_chat_template_oaicompat(&template, &params);
1548
1549        assert!(result.is_ok());
1550    }
1551
1552    #[test]
1553    #[serial]
1554    fn chat_template_with_nonexistent_name_returns_error() {
1555        let (_backend, model) = test_model::load_default_model().unwrap();
1556
1557        let result = model.chat_template(Some("nonexistent_template_name_xyz"));
1558
1559        assert_eq!(
1560            result.unwrap_err(),
1561            crate::ChatTemplateError::MissingTemplate
1562        );
1563    }
1564
1565    #[test]
1566    #[serial]
1567    fn lora_adapter_init_with_invalid_gguf_returns_null_result() {
1568        let (_backend, model) = test_model::load_default_model().unwrap();
1569        let dummy_path = std::env::temp_dir().join("llama_test_dummy_lora.gguf");
1570        std::fs::write(&dummy_path, b"not a valid gguf").unwrap();
1571
1572        let result = model.lora_adapter_init(&dummy_path);
1573
1574        assert_eq!(
1575            result.unwrap_err(),
1576            crate::LlamaLoraAdapterInitError::NullResult
1577        );
1578        let _ = std::fs::remove_file(&dummy_path);
1579    }
1580
1581    #[test]
1582    #[serial]
1583    fn str_to_token_with_many_tokens_triggers_buffer_resize() {
1584        let (_backend, model) = test_model::load_default_model().unwrap();
1585        // Each digit typically becomes its own token, but the buffer estimate
1586        // is len/2 which is smaller than the actual token count for
1587        // single-char-token strings like "1 2 3 4 ..."
1588        let many_numbers: String = (0..2000).map(|number| format!("{number} ")).collect();
1589
1590        let tokens = model.str_to_token(&many_numbers, AddBos::Always).unwrap();
1591
1592        assert!(tokens.len() > many_numbers.len() / 2);
1593    }
1594
1595    #[test]
1596    #[serial]
1597    fn rope_type_returns_valid_result_for_test_model() {
1598        let (_backend, model) = test_model::load_default_model().unwrap();
1599
1600        let _rope_type = model.rope_type();
1601    }
1602
1603    #[test]
1604    #[serial]
1605    fn meta_val_str_with_null_byte_in_key_returns_error() {
1606        let (_backend, model) = test_model::load_default_model().unwrap();
1607        let result = model.meta_val_str("key\0with_null");
1608
1609        assert!(result.is_err());
1610    }
1611
1612    #[test]
1613    #[serial]
1614    fn apply_chat_template_with_tools_null_byte_in_tools_returns_error() {
1615        let (_backend, model) = test_model::load_default_model().unwrap();
1616        let template = model.chat_template(None).unwrap();
1617        let message =
1618            crate::model::LlamaChatMessage::new("user".to_string(), "hello".to_string()).unwrap();
1619        let result = model.apply_chat_template_with_tools_oaicompat(
1620            &template,
1621            &[message],
1622            Some("tools\0null"),
1623            None,
1624            true,
1625        );
1626
1627        assert!(result.is_err());
1628    }
1629
1630    #[test]
1631    #[serial]
1632    fn apply_chat_template_with_tools_null_byte_in_json_schema_returns_error() {
1633        let (_backend, model) = test_model::load_default_model().unwrap();
1634        let template = model.chat_template(None).unwrap();
1635        let message =
1636            crate::model::LlamaChatMessage::new("user".to_string(), "hello".to_string()).unwrap();
1637        let result = model.apply_chat_template_with_tools_oaicompat(
1638            &template,
1639            &[message],
1640            None,
1641            Some("schema\0null"),
1642            true,
1643        );
1644
1645        assert!(result.is_err());
1646    }
1647
1648    #[test]
1649    #[serial]
1650    fn apply_chat_template_oaicompat_with_null_byte_in_messages_returns_error() {
1651        let (_backend, model) = test_model::load_default_model().unwrap();
1652        let template = model.chat_template(None).unwrap();
1653        let params = crate::openai::OpenAIChatTemplateParams {
1654            messages_json: "messages\0null",
1655            tools_json: None,
1656            tool_choice: None,
1657            json_schema: None,
1658            grammar: None,
1659            reasoning_format: None,
1660            chat_template_kwargs: None,
1661            add_generation_prompt: true,
1662            use_jinja: true,
1663            parallel_tool_calls: false,
1664            enable_thinking: false,
1665            add_bos: false,
1666            add_eos: false,
1667            parse_tool_calls: false,
1668        };
1669        let result = model.apply_chat_template_oaicompat(&template, &params);
1670
1671        assert!(result.is_err());
1672    }
1673
1674    #[test]
1675    #[serial]
1676    fn apply_chat_template_oaicompat_with_null_byte_in_tools_returns_error() {
1677        let (_backend, model) = test_model::load_default_model().unwrap();
1678        let template = model.chat_template(None).unwrap();
1679        let params = crate::openai::OpenAIChatTemplateParams {
1680            messages_json: r#"[{"role":"user","content":"hello"}]"#,
1681            tools_json: Some("tools\0null"),
1682            tool_choice: None,
1683            json_schema: None,
1684            grammar: None,
1685            reasoning_format: None,
1686            chat_template_kwargs: None,
1687            add_generation_prompt: true,
1688            use_jinja: true,
1689            parallel_tool_calls: false,
1690            enable_thinking: false,
1691            add_bos: false,
1692            add_eos: false,
1693            parse_tool_calls: false,
1694        };
1695        let result = model.apply_chat_template_oaicompat(&template, &params);
1696
1697        assert!(result.is_err());
1698    }
1699
1700    #[test]
1701    #[serial]
1702    fn apply_chat_template_oaicompat_with_null_byte_in_tool_choice_returns_error() {
1703        let (_backend, model) = test_model::load_default_model().unwrap();
1704        let template = model.chat_template(None).unwrap();
1705        let params = crate::openai::OpenAIChatTemplateParams {
1706            messages_json: r#"[{"role":"user","content":"hello"}]"#,
1707            tools_json: None,
1708            tool_choice: Some("choice\0null"),
1709            json_schema: None,
1710            grammar: None,
1711            reasoning_format: None,
1712            chat_template_kwargs: None,
1713            add_generation_prompt: true,
1714            use_jinja: true,
1715            parallel_tool_calls: false,
1716            enable_thinking: false,
1717            add_bos: false,
1718            add_eos: false,
1719            parse_tool_calls: false,
1720        };
1721        let result = model.apply_chat_template_oaicompat(&template, &params);
1722
1723        assert!(result.is_err());
1724    }
1725
1726    #[test]
1727    #[serial]
1728    fn apply_chat_template_oaicompat_with_null_byte_in_json_schema_returns_error() {
1729        let (_backend, model) = test_model::load_default_model().unwrap();
1730        let template = model.chat_template(None).unwrap();
1731        let params = crate::openai::OpenAIChatTemplateParams {
1732            messages_json: r#"[{"role":"user","content":"hello"}]"#,
1733            tools_json: None,
1734            tool_choice: None,
1735            json_schema: Some("schema\0null"),
1736            grammar: None,
1737            reasoning_format: None,
1738            chat_template_kwargs: None,
1739            add_generation_prompt: true,
1740            use_jinja: true,
1741            parallel_tool_calls: false,
1742            enable_thinking: false,
1743            add_bos: false,
1744            add_eos: false,
1745            parse_tool_calls: false,
1746        };
1747        let result = model.apply_chat_template_oaicompat(&template, &params);
1748
1749        assert!(result.is_err());
1750    }
1751
1752    #[test]
1753    #[serial]
1754    fn apply_chat_template_oaicompat_with_null_byte_in_grammar_returns_error() {
1755        let (_backend, model) = test_model::load_default_model().unwrap();
1756        let template = model.chat_template(None).unwrap();
1757        let params = crate::openai::OpenAIChatTemplateParams {
1758            messages_json: r#"[{"role":"user","content":"hello"}]"#,
1759            tools_json: None,
1760            tool_choice: None,
1761            json_schema: None,
1762            grammar: Some("grammar\0null"),
1763            reasoning_format: None,
1764            chat_template_kwargs: None,
1765            add_generation_prompt: true,
1766            use_jinja: true,
1767            parallel_tool_calls: false,
1768            enable_thinking: false,
1769            add_bos: false,
1770            add_eos: false,
1771            parse_tool_calls: false,
1772        };
1773        let result = model.apply_chat_template_oaicompat(&template, &params);
1774
1775        assert!(result.is_err());
1776    }
1777
1778    #[test]
1779    #[serial]
1780    fn apply_chat_template_oaicompat_with_null_byte_in_reasoning_format_returns_error() {
1781        let (_backend, model) = test_model::load_default_model().unwrap();
1782        let template = model.chat_template(None).unwrap();
1783        let params = crate::openai::OpenAIChatTemplateParams {
1784            messages_json: r#"[{"role":"user","content":"hello"}]"#,
1785            tools_json: None,
1786            tool_choice: None,
1787            json_schema: None,
1788            grammar: None,
1789            reasoning_format: Some("format\0null"),
1790            chat_template_kwargs: None,
1791            add_generation_prompt: true,
1792            use_jinja: true,
1793            parallel_tool_calls: false,
1794            enable_thinking: false,
1795            add_bos: false,
1796            add_eos: false,
1797            parse_tool_calls: false,
1798        };
1799        let result = model.apply_chat_template_oaicompat(&template, &params);
1800
1801        assert!(result.is_err());
1802    }
1803
1804    #[test]
1805    #[serial]
1806    fn apply_chat_template_oaicompat_with_null_byte_in_kwargs_returns_error() {
1807        let (_backend, model) = test_model::load_default_model().unwrap();
1808        let template = model.chat_template(None).unwrap();
1809        let params = crate::openai::OpenAIChatTemplateParams {
1810            messages_json: r#"[{"role":"user","content":"hello"}]"#,
1811            tools_json: None,
1812            tool_choice: None,
1813            json_schema: None,
1814            grammar: None,
1815            reasoning_format: None,
1816            chat_template_kwargs: Some("kwargs\0null"),
1817            add_generation_prompt: true,
1818            use_jinja: true,
1819            parallel_tool_calls: false,
1820            enable_thinking: false,
1821            add_bos: false,
1822            add_eos: false,
1823            parse_tool_calls: false,
1824        };
1825        let result = model.apply_chat_template_oaicompat(&template, &params);
1826
1827        assert!(result.is_err());
1828    }
1829
1830    #[test]
1831    #[serial]
1832    fn new_context_with_huge_ctx_returns_null_error() {
1833        let (_backend, model) = test_model::load_default_model().unwrap();
1834        let ctx_params = crate::context::params::LlamaContextParams::default()
1835            .with_n_ctx(std::num::NonZeroU32::new(u32::MAX));
1836
1837        let result = model.new_context(&_backend, ctx_params);
1838
1839        assert!(result.is_err());
1840    }
1841}