Skip to main content

llama_cpp_bindings/
model.rs

1//! A safe wrapper around `llama_model`.
2use std::ffi::{CStr, CString, c_char};
3use std::num::NonZeroU16;
4use std::os::raw::c_int;
5use std::path::Path;
6
7fn truncated_buffer_to_string(
8    mut buffer: Vec<u8>,
9    length: usize,
10) -> Result<String, ApplyChatTemplateError> {
11    buffer.truncate(length);
12
13    Ok(String::from_utf8(buffer)?)
14}
15
16fn validate_string_length_for_tokenizer(length: usize) -> Result<c_int, StringToTokenError> {
17    Ok(c_int::try_from(length)?)
18}
19
20fn cstring_with_validated_len(str: &str) -> Result<(CString, c_int), StringToTokenError> {
21    let c_string = CString::new(str)?;
22    let len = validate_string_length_for_tokenizer(c_string.as_bytes().len())?;
23    Ok((c_string, len))
24}
25use std::ptr::{self, NonNull};
26
27use crate::context::LlamaContext;
28use crate::context::params::LlamaContextParams;
29use crate::llama_backend::LlamaBackend;
30use crate::openai::OpenAIChatTemplateParams;
31use crate::token::LlamaToken;
32use crate::token_type::LlamaTokenAttrs;
33use crate::{
34    ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
35    LlamaModelLoadError, MetaValError, StringToTokenError, TokenToStringError,
36};
37
38pub mod add_bos;
39pub mod chat_template_result;
40pub mod grammar_trigger;
41pub mod llama_chat_message;
42pub mod llama_chat_template;
43pub mod llama_lora_adapter;
44pub mod params;
45pub mod rope_type;
46pub mod split_mode;
47pub mod vocab_type;
48
49pub use add_bos::AddBos;
50pub use chat_template_result::ChatTemplateResult;
51pub use grammar_trigger::{GrammarTrigger, GrammarTriggerType};
52pub use llama_chat_message::LlamaChatMessage;
53pub use llama_chat_template::LlamaChatTemplate;
54pub use llama_lora_adapter::LlamaLoraAdapter;
55pub use rope_type::RopeType;
56pub use vocab_type::{LlamaTokenTypeFromIntError, VocabType};
57
58use chat_template_result::{new_empty_chat_template_raw_result, parse_chat_template_raw_result};
59use params::LlamaModelParams;
60
/// A safe wrapper around `llama_model`.
///
/// The wrapper owns the model. It is normally created by [`LlamaModel::load_from_file`], and
/// the `Drop` implementation in this module releases it with `llama_free_model`.
#[derive(Debug)]
#[repr(transparent)]
pub struct LlamaModel {
    /// Raw pointer to the underlying `llama_model`.
    ///
    /// NOTE(review): this field is `pub`, so code outside this module can build a `LlamaModel`
    /// around an arbitrary pointer, and `Drop` will still free it. Confirm that callers only
    /// store pointers obtained from llama.cpp's model loader.
    pub model: NonNull<llama_cpp_bindings_sys::llama_model>,
}

// SAFETY: NOTE(review): asserts that a `llama_model` may be moved to another thread. Nothing in
// this file enforces this; it relies on llama.cpp not tying a model to its creating thread.
// Confirm against upstream.
unsafe impl Send for LlamaModel {}

// SAFETY: NOTE(review): asserts that `&LlamaModel` may be shared across threads. Every method in
// this file takes `&self`, so soundness presumably relies on llama.cpp treating a loaded model
// as read-only. Confirm against upstream.
unsafe impl Sync for LlamaModel {}
72
73impl LlamaModel {
74    /// Returns a raw pointer to the model's vocabulary.
75    #[must_use]
76    pub fn vocab_ptr(&self) -> *const llama_cpp_bindings_sys::llama_vocab {
77        unsafe { llama_cpp_bindings_sys::llama_model_get_vocab(self.model.as_ptr()) }
78    }
79
80    /// Get the number of tokens the model was trained on.
81    ///
82    /// # Errors
83    ///
84    /// Returns an error if the value returned by llama.cpp does not fit into a `u32`.
85    pub fn n_ctx_train(&self) -> Result<u32, std::num::TryFromIntError> {
86        let n_ctx_train = unsafe { llama_cpp_bindings_sys::llama_n_ctx_train(self.model.as_ptr()) };
87
88        u32::try_from(n_ctx_train)
89    }
90
91    /// Get all tokens in the model.
92    pub fn tokens(
93        &self,
94        decode_special: bool,
95    ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
96        (0..self.n_vocab())
97            .map(LlamaToken::new)
98            .map(move |llama_token| {
99                let mut decoder = encoding_rs::UTF_8.new_decoder();
100                (
101                    llama_token,
102                    self.token_to_piece(llama_token, &mut decoder, decode_special, None),
103                )
104            })
105    }
106
107    /// Get the beginning of stream token.
108    #[must_use]
109    pub fn token_bos(&self) -> LlamaToken {
110        let token = unsafe { llama_cpp_bindings_sys::llama_token_bos(self.vocab_ptr()) };
111        LlamaToken(token)
112    }
113
114    /// Get the end of stream token.
115    #[must_use]
116    pub fn token_eos(&self) -> LlamaToken {
117        let token = unsafe { llama_cpp_bindings_sys::llama_token_eos(self.vocab_ptr()) };
118        LlamaToken(token)
119    }
120
121    /// Get the newline token.
122    #[must_use]
123    pub fn token_nl(&self) -> LlamaToken {
124        let token = unsafe { llama_cpp_bindings_sys::llama_token_nl(self.vocab_ptr()) };
125        LlamaToken(token)
126    }
127
128    /// Check if a token represents the end of generation (end of turn, end of sequence, etc.)
129    #[must_use]
130    pub fn is_eog_token(&self, token: LlamaToken) -> bool {
131        unsafe { llama_cpp_bindings_sys::llama_token_is_eog(self.vocab_ptr(), token.0) }
132    }
133
134    /// Get the decoder start token.
135    #[must_use]
136    pub fn decode_start_token(&self) -> LlamaToken {
137        let token =
138            unsafe { llama_cpp_bindings_sys::llama_model_decoder_start_token(self.model.as_ptr()) };
139        LlamaToken(token)
140    }
141
142    /// Get the separator token (SEP).
143    #[must_use]
144    pub fn token_sep(&self) -> LlamaToken {
145        let token = unsafe { llama_cpp_bindings_sys::llama_vocab_sep(self.vocab_ptr()) };
146        LlamaToken(token)
147    }
148
149    /// Convert a string to a Vector of tokens.
150    ///
151    /// # Errors
152    ///
153    /// - if [`str`] contains a null byte
154    /// - if an integer conversion fails during tokenization
155    ///
156    ///
157    /// ```no_run
158    /// use llama_cpp_bindings::model::LlamaModel;
159    ///
160    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
161    /// use std::path::Path;
162    /// use llama_cpp_bindings::model::AddBos;
163    /// let backend = llama_cpp_bindings::llama_backend::LlamaBackend::init()?;
164    /// let model = LlamaModel::load_from_file(&backend, Path::new("path/to/model"), &Default::default())?;
165    /// let tokens = model.str_to_token("Hello, World!", AddBos::Always)?;
166    /// # Ok(())
167    /// # }
168    pub fn str_to_token(
169        &self,
170        str: &str,
171        add_bos: AddBos,
172    ) -> Result<Vec<LlamaToken>, StringToTokenError> {
173        let add_bos = match add_bos {
174            AddBos::Always => true,
175            AddBos::Never => false,
176        };
177
178        let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
179        let mut buffer: Vec<LlamaToken> = Vec::with_capacity(tokens_estimation);
180
181        let (c_string, c_string_len) = cstring_with_validated_len(str)?;
182        let buffer_capacity = c_int::try_from(buffer.capacity())?;
183
184        let size = unsafe {
185            llama_cpp_bindings_sys::llama_tokenize(
186                self.vocab_ptr(),
187                c_string.as_ptr(),
188                c_string_len,
189                buffer
190                    .as_mut_ptr()
191                    .cast::<llama_cpp_bindings_sys::llama_token>(),
192                buffer_capacity,
193                add_bos,
194                true,
195            )
196        };
197
198        let size = if size.is_negative() {
199            buffer.reserve_exact(usize::try_from(-size)?);
200            unsafe {
201                llama_cpp_bindings_sys::llama_tokenize(
202                    self.vocab_ptr(),
203                    c_string.as_ptr(),
204                    c_string_len,
205                    buffer
206                        .as_mut_ptr()
207                        .cast::<llama_cpp_bindings_sys::llama_token>(),
208                    -size,
209                    add_bos,
210                    true,
211                )
212            }
213        } else {
214            size
215        };
216
217        let size = usize::try_from(size)?;
218
219        // SAFETY: `size` < `capacity` and llama-cpp has initialized elements up to `size`
220        unsafe { buffer.set_len(size) }
221
222        Ok(buffer)
223    }
224
225    /// Get the type of a token.
226    ///
227    /// # Errors
228    ///
229    /// Returns an error if the token type is not known to this library.
230    pub fn token_attr(
231        &self,
232        LlamaToken(id): LlamaToken,
233    ) -> Result<LlamaTokenAttrs, crate::token_type::LlamaTokenTypeFromIntError> {
234        let token_type =
235            unsafe { llama_cpp_bindings_sys::llama_token_get_attr(self.vocab_ptr(), id) };
236
237        LlamaTokenAttrs::try_from(token_type)
238    }
239
240    /// Convert a token to a string using the underlying llama.cpp `llama_token_to_piece` function.
241    ///
242    /// This is the new default function for token decoding and provides direct access to
243    /// the llama.cpp token decoding functionality without any special logic or filtering.
244    ///
245    /// Decoding raw string requires using an decoder, tokens from language models may not always map
246    /// to full characters depending on the encoding so stateful decoding is required, otherwise partial strings may be lost!
247    /// Invalid characters are mapped to REPLACEMENT CHARACTER making the method safe to use even if the model inherently produces
248    /// garbage.
249    ///
250    /// # Errors
251    ///
252    /// - if the token type is unknown
253    ///
254    /// - if the returned size from llama.cpp does not fit into a `usize`
255    pub fn token_to_piece(
256        &self,
257        token: LlamaToken,
258        decoder: &mut encoding_rs::Decoder,
259        special: bool,
260        lstrip: Option<NonZeroU16>,
261    ) -> Result<String, TokenToStringError> {
262        let bytes = match self.token_to_piece_bytes(token, 8, special, lstrip) {
263            Err(TokenToStringError::InsufficientBufferSpace(required_size)) => {
264                let buffer_size: usize = (-required_size).try_into()?;
265
266                self.token_to_piece_bytes(token, buffer_size, special, lstrip)
267            }
268            other => other,
269        }?;
270
271        let mut output_piece = String::with_capacity(bytes.len());
272        let (_result, _decoded_size, _had_replacements) =
273            decoder.decode_to_string(&bytes, &mut output_piece, false);
274
275        Ok(output_piece)
276    }
277
278    /// Raw token decoding to bytes, use if you want to handle the decoding model output yourself
279    ///
280    /// Convert a token to bytes using the underlying llama.cpp `llama_token_to_piece` function. This is mostly
281    /// a thin wrapper around `llama_token_to_piece` function, that handles rust <-> c type conversions while
282    /// letting the caller handle errors. For a safer interface returning rust strings directly use `token_to_piece` instead!
283    ///
284    /// # Errors
285    ///
286    /// - if the token type is unknown
287    /// - the resultant token is larger than `buffer_size`.
288    /// - if an integer conversion fails
289    #[allow(clippy::missing_panics_doc)]
290    pub fn token_to_piece_bytes(
291        &self,
292        token: LlamaToken,
293        buffer_size: usize,
294        special: bool,
295        lstrip: Option<NonZeroU16>,
296    ) -> Result<Vec<u8>, TokenToStringError> {
297        // SAFETY: `*` (0x2A) is never `\0`, so CString::new cannot fail here
298        let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
299        let len = string.as_bytes().len();
300        let len = c_int::try_from(len)?;
301        let buf = string.into_raw();
302        let lstrip = lstrip.map_or(0, |strip_count| i32::from(strip_count.get()));
303        let size = unsafe {
304            llama_cpp_bindings_sys::llama_token_to_piece(
305                self.vocab_ptr(),
306                token.0,
307                buf,
308                len,
309                lstrip,
310                special,
311            )
312        };
313
314        match size {
315            0 => Err(TokenToStringError::UnknownTokenType),
316            error_code if error_code.is_negative() => {
317                Err(TokenToStringError::InsufficientBufferSpace(error_code))
318            }
319            size => {
320                let string = unsafe { CString::from_raw(buf) };
321                let mut bytes = string.into_bytes();
322                let len = usize::try_from(size)?;
323                bytes.truncate(len);
324
325                Ok(bytes)
326            }
327        }
328    }
329
330    /// The number of tokens the model was trained on.
331    ///
332    /// This returns a `c_int` for maximum compatibility. Most of the time it can be cast to an i32
333    /// without issue.
334    #[must_use]
335    pub fn n_vocab(&self) -> i32 {
336        unsafe { llama_cpp_bindings_sys::llama_n_vocab(self.vocab_ptr()) }
337    }
338
339    /// The type of vocab the model was trained on.
340    ///
341    /// # Errors
342    ///
343    /// Returns an error if llama.cpp emits a vocab type that is not known to this library.
344    pub fn vocab_type(&self) -> Result<VocabType, LlamaTokenTypeFromIntError> {
345        let vocab_type = unsafe { llama_cpp_bindings_sys::llama_vocab_type(self.vocab_ptr()) };
346
347        VocabType::try_from(vocab_type)
348    }
349
350    /// This returns a `c_int` for maximum compatibility. Most of the time it can be cast to an i32
351    /// without issue.
352    #[must_use]
353    pub fn n_embd(&self) -> c_int {
354        unsafe { llama_cpp_bindings_sys::llama_n_embd(self.model.as_ptr()) }
355    }
356
357    /// Returns the total size of all the tensors in the model in bytes.
358    #[must_use]
359    pub fn size(&self) -> u64 {
360        unsafe { llama_cpp_bindings_sys::llama_model_size(self.model.as_ptr()) }
361    }
362
363    /// Returns the number of parameters in the model.
364    #[must_use]
365    pub fn n_params(&self) -> u64 {
366        unsafe { llama_cpp_bindings_sys::llama_model_n_params(self.model.as_ptr()) }
367    }
368
369    /// Returns whether the model is a recurrent network (Mamba, RWKV, etc)
370    #[must_use]
371    pub fn is_recurrent(&self) -> bool {
372        unsafe { llama_cpp_bindings_sys::llama_model_is_recurrent(self.model.as_ptr()) }
373    }
374
375    /// Returns the number of layers within the model.
376    ///
377    /// # Errors
378    ///
379    /// Returns an error if the layer count returned by llama.cpp does not fit into a `u32`.
380    pub fn n_layer(&self) -> Result<u32, std::num::TryFromIntError> {
381        u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_layer(self.model.as_ptr()) })
382    }
383
384    /// Returns the number of attention heads within the model.
385    ///
386    /// # Errors
387    ///
388    /// Returns an error if the head count returned by llama.cpp does not fit into a `u32`.
389    pub fn n_head(&self) -> Result<u32, std::num::TryFromIntError> {
390        u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_head(self.model.as_ptr()) })
391    }
392
393    /// Returns the number of KV attention heads.
394    ///
395    /// # Errors
396    ///
397    /// Returns an error if the KV head count returned by llama.cpp does not fit into a `u32`.
398    pub fn n_head_kv(&self) -> Result<u32, std::num::TryFromIntError> {
399        u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_head_kv(self.model.as_ptr()) })
400    }
401
402    /// Returns whether the model is a hybrid network (Jamba, Granite, Qwen3xx, etc.)
403    ///
404    /// Hybrid models have both attention layers and recurrent/SSM layers.
405    #[must_use]
406    pub fn is_hybrid(&self) -> bool {
407        unsafe { llama_cpp_bindings_sys::llama_model_is_hybrid(self.model.as_ptr()) }
408    }
409
410    /// Get metadata value as a string by key name
411    ///
412    /// # Errors
413    /// Returns an error if the key is not found or the value is not valid UTF-8.
414    pub fn meta_val_str(&self, key: &str) -> Result<String, MetaValError> {
415        let key_cstring = CString::new(key)?;
416        let key_ptr = key_cstring.as_ptr();
417
418        extract_meta_string(
419            |buf_ptr, buf_len| unsafe {
420                llama_cpp_bindings_sys::llama_model_meta_val_str(
421                    self.model.as_ptr(),
422                    key_ptr,
423                    buf_ptr,
424                    buf_len,
425                )
426            },
427            256,
428        )
429    }
430
431    /// Get the number of metadata key/value pairs
432    #[must_use]
433    pub fn meta_count(&self) -> i32 {
434        unsafe { llama_cpp_bindings_sys::llama_model_meta_count(self.model.as_ptr()) }
435    }
436
437    /// Get metadata key name by index
438    ///
439    /// # Errors
440    /// Returns an error if the index is out of range or the key is not valid UTF-8.
441    pub fn meta_key_by_index(&self, index: i32) -> Result<String, MetaValError> {
442        extract_meta_string(
443            |buf_ptr, buf_len| unsafe {
444                llama_cpp_bindings_sys::llama_model_meta_key_by_index(
445                    self.model.as_ptr(),
446                    index,
447                    buf_ptr,
448                    buf_len,
449                )
450            },
451            256,
452        )
453    }
454
455    /// Get metadata value as a string by index
456    ///
457    /// # Errors
458    /// Returns an error if the index is out of range or the value is not valid UTF-8.
459    pub fn meta_val_str_by_index(&self, index: i32) -> Result<String, MetaValError> {
460        extract_meta_string(
461            |buf_ptr, buf_len| unsafe {
462                llama_cpp_bindings_sys::llama_model_meta_val_str_by_index(
463                    self.model.as_ptr(),
464                    index,
465                    buf_ptr,
466                    buf_len,
467                )
468            },
469            256,
470        )
471    }
472
473    /// Returns the rope type of the model.
474    #[must_use]
475    pub fn rope_type(&self) -> Option<RopeType> {
476        let raw = unsafe { llama_cpp_bindings_sys::llama_model_rope_type(self.model.as_ptr()) };
477
478        rope_type::rope_type_from_raw(raw)
479    }
480
481    /// Get chat template from model by name. If the name parameter is None, the default chat template will be returned.
482    ///
483    /// You supply this into [`Self::apply_chat_template`] to get back a string with the appropriate template
484    /// substitution applied to convert a list of messages into a prompt the LLM can use to complete
485    /// the chat.
486    ///
487    /// You could also use an external jinja parser, like [minijinja](https://github.com/mitsuhiko/minijinja),
488    /// to parse jinja templates not supported by the llama.cpp template engine.
489    ///
490    /// # Errors
491    ///
492    /// * If the model has no chat template by that name
493    ///
494    /// # Panics
495    ///
496    /// Panics if the C-returned chat template string contains interior null bytes
497    /// (should never happen with valid model data).
498    pub fn chat_template(
499        &self,
500        name: Option<&str>,
501    ) -> Result<LlamaChatTemplate, ChatTemplateError> {
502        let name_cstr = name.map(CString::new);
503        let name_ptr = match name_cstr {
504            Some(Ok(name)) => name.as_ptr(),
505            _ => ptr::null(),
506        };
507        let result = unsafe {
508            llama_cpp_bindings_sys::llama_model_chat_template(self.model.as_ptr(), name_ptr)
509        };
510
511        if result.is_null() {
512            Err(ChatTemplateError::MissingTemplate)
513        } else {
514            let chat_template_cstr = unsafe { CStr::from_ptr(result) };
515            let chat_template = CString::new(chat_template_cstr.to_bytes())
516                .expect("CStr bytes cannot contain interior null bytes");
517
518            Ok(LlamaChatTemplate(chat_template))
519        }
520    }
521
522    /// Loads a model from a file.
523    ///
524    /// # Errors
525    ///
526    /// See [`LlamaModelLoadError`] for more information.
527    ///
528    /// # Panics
529    ///
530    /// Panics if a valid UTF-8 path somehow contains interior null bytes (should never happen).
531    #[tracing::instrument(skip_all, fields(params))]
532    pub fn load_from_file(
533        _: &LlamaBackend,
534        path: impl AsRef<Path>,
535        params: &LlamaModelParams,
536    ) -> Result<Self, LlamaModelLoadError> {
537        let path = path.as_ref();
538
539        let path_str = path
540            .to_str()
541            .ok_or_else(|| LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
542
543        if !path.exists() {
544            return Err(LlamaModelLoadError::FileNotFound(path.to_path_buf()));
545        }
546
547        let cstr = CString::new(path_str)?;
548        let llama_model = unsafe {
549            llama_cpp_bindings_sys::llama_load_model_from_file(cstr.as_ptr(), params.params)
550        };
551
552        let model = match NonNull::new(llama_model) {
553            Some(ptr) => ptr,
554            None if !path.exists() => {
555                return Err(LlamaModelLoadError::FileNotFound(path.to_path_buf()));
556            }
557            None => return Err(LlamaModelLoadError::NullResult),
558        };
559
560        Ok(Self { model })
561    }
562
563    /// Initializes a lora adapter from a file.
564    ///
565    /// # Errors
566    ///
567    /// See [`LlamaLoraAdapterInitError`] for more information.
568    pub fn lora_adapter_init(
569        &self,
570        path: impl AsRef<Path>,
571    ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
572        let path = path.as_ref();
573
574        let path_str = path
575            .to_str()
576            .ok_or_else(|| LlamaLoraAdapterInitError::PathToStrError(path.to_path_buf()))?;
577
578        if !path.exists() {
579            return Err(LlamaLoraAdapterInitError::FileNotFound(path.to_path_buf()));
580        }
581
582        let cstr = CString::new(path_str)?;
583        let raw_adapter = unsafe {
584            llama_cpp_bindings_sys::llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr())
585        };
586
587        let Some(adapter) = NonNull::new(raw_adapter) else {
588            return Err(LlamaLoraAdapterInitError::NullResult);
589        };
590
591        Ok(LlamaLoraAdapter {
592            lora_adapter: adapter,
593        })
594    }
595
596    /// Create a new context from this model.
597    ///
598    /// # Errors
599    ///
600    /// There is many ways this can fail. See [`LlamaContextLoadError`] for more information.
601    // we intentionally do not derive Copy on `LlamaContextParams` to allow llama.cpp to change the type to be non-trivially copyable.
602    #[allow(clippy::needless_pass_by_value)]
603    pub fn new_context<'model>(
604        &'model self,
605        _: &LlamaBackend,
606        params: LlamaContextParams,
607    ) -> Result<LlamaContext<'model>, LlamaContextLoadError> {
608        let context_params = params.context_params;
609        let context = unsafe {
610            llama_cpp_bindings_sys::llama_new_context_with_model(
611                self.model.as_ptr(),
612                context_params,
613            )
614        };
615        let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;
616
617        Ok(LlamaContext::new(self, context, params.embeddings()))
618    }
619
620    /// Apply the models chat template to some messages.
621    /// See <https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template>
622    ///
623    /// Unlike the llama.cpp `apply_chat_template` which just randomly uses the `ChatML` template when given
624    /// a null pointer for the template, this requires an explicit template to be specified. If you want to
625    /// use "chatml", then just do `LlamaChatTemplate::new("chatml")` or any other model name or template
626    /// string.
627    ///
628    /// Use [`Self::chat_template`] to retrieve the template baked into the model (this is the preferred
629    /// mechanism as using the wrong chat template can result in really unexpected responses from the LLM).
630    ///
631    /// You probably want to set `add_ass` to true so that the generated template string ends with a the
632    /// opening tag of the assistant. If you fail to leave a hanging chat tag, the model will likely generate
633    /// one into the output and the output may also have unexpected output aside from that.
634    ///
635    /// # Errors
636    /// There are many ways this can fail. See [`ApplyChatTemplateError`] for more information.
637    #[tracing::instrument(skip_all)]
638    pub fn apply_chat_template(
639        &self,
640        tmpl: &LlamaChatTemplate,
641        chat: &[LlamaChatMessage],
642        add_ass: bool,
643    ) -> Result<String, ApplyChatTemplateError> {
644        let message_length = chat.iter().fold(0, |acc, chat_message| {
645            acc + chat_message.role.to_bytes().len() + chat_message.content.to_bytes().len()
646        });
647        let mut buff: Vec<u8> = vec![0; message_length * 2];
648
649        let chat: Vec<llama_cpp_bindings_sys::llama_chat_message> = chat
650            .iter()
651            .map(|chat_message| llama_cpp_bindings_sys::llama_chat_message {
652                role: chat_message.role.as_ptr(),
653                content: chat_message.content.as_ptr(),
654            })
655            .collect();
656
657        let tmpl_ptr = tmpl.0.as_ptr();
658
659        let buff_len: i32 = buff.len().try_into()?;
660
661        let res = unsafe {
662            llama_cpp_bindings_sys::llama_chat_apply_template(
663                tmpl_ptr,
664                chat.as_ptr(),
665                chat.len(),
666                add_ass,
667                buff.as_mut_ptr().cast::<c_char>(),
668                buff_len,
669            )
670        };
671
672        if res > buff_len {
673            let required_size: usize = res.try_into()?;
674            buff.resize(required_size, 0);
675
676            let new_buff_len: i32 = buff.len().try_into()?;
677
678            let res = unsafe {
679                llama_cpp_bindings_sys::llama_chat_apply_template(
680                    tmpl_ptr,
681                    chat.as_ptr(),
682                    chat.len(),
683                    add_ass,
684                    buff.as_mut_ptr().cast::<c_char>(),
685                    new_buff_len,
686                )
687            };
688            let final_size: usize = res.try_into()?;
689
690            return truncated_buffer_to_string(buff, final_size);
691        }
692
693        let final_size: usize = res.try_into()?;
694
695        truncated_buffer_to_string(buff, final_size)
696    }
697
698    /// Apply the models chat template to some messages and return an optional tool grammar.
699    /// `tools_json` should be an OpenAI-compatible tool definition JSON array string.
700    /// `json_schema` should be a JSON schema string.
701    ///
702    /// # Errors
703    /// Returns an error if the FFI call fails or the result contains invalid data.
704    #[tracing::instrument(skip_all)]
705    pub fn apply_chat_template_with_tools_oaicompat(
706        &self,
707        tmpl: &LlamaChatTemplate,
708        messages: &[LlamaChatMessage],
709        tools_json: Option<&str>,
710        json_schema: Option<&str>,
711        add_generation_prompt: bool,
712    ) -> Result<ChatTemplateResult, ApplyChatTemplateError> {
713        let chat: Vec<llama_cpp_bindings_sys::llama_chat_message> = messages
714            .iter()
715            .map(|chat_message| llama_cpp_bindings_sys::llama_chat_message {
716                role: chat_message.role.as_ptr(),
717                content: chat_message.content.as_ptr(),
718            })
719            .collect();
720
721        let tools_cstr = tools_json.map(CString::new).transpose()?;
722        let json_schema_cstr = json_schema.map(CString::new).transpose()?;
723
724        let mut raw_result = new_empty_chat_template_raw_result();
725
726        let rc = unsafe {
727            llama_cpp_bindings_sys::llama_rs_apply_chat_template_with_tools_oaicompat(
728                self.model.as_ptr(),
729                tmpl.0.as_ptr(),
730                chat.as_ptr(),
731                chat.len(),
732                tools_cstr
733                    .as_ref()
734                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
735                json_schema_cstr
736                    .as_ref()
737                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
738                add_generation_prompt,
739                &raw mut raw_result,
740            )
741        };
742
743        let parse_tool_calls = tools_json.is_some_and(|tools| !tools.is_empty());
744
745        unsafe { parse_chat_template_raw_result(rc, &raw mut raw_result, parse_tool_calls) }
746    }
747
748    /// Apply the model chat template using OpenAI-compatible JSON messages.
749    ///
750    /// # Errors
751    /// Returns an error if the FFI call fails or the result contains invalid data.
752    #[tracing::instrument(skip_all)]
753    pub fn apply_chat_template_oaicompat(
754        &self,
755        tmpl: &LlamaChatTemplate,
756        params: &OpenAIChatTemplateParams<'_>,
757    ) -> Result<ChatTemplateResult, ApplyChatTemplateError> {
758        let parse_tool_calls = params.parse_tool_calls;
759        let messages_cstr = CString::new(params.messages_json)?;
760        let tools_cstr = params.tools_json.map(CString::new).transpose()?;
761        let tool_choice_cstr = params.tool_choice.map(CString::new).transpose()?;
762        let json_schema_cstr = params.json_schema.map(CString::new).transpose()?;
763        let grammar_cstr = params.grammar.map(CString::new).transpose()?;
764        let reasoning_cstr = params.reasoning_format.map(CString::new).transpose()?;
765        let kwargs_cstr = params.chat_template_kwargs.map(CString::new).transpose()?;
766
767        let mut raw_result = new_empty_chat_template_raw_result();
768
769        let ffi_params = llama_cpp_bindings_sys::llama_rs_chat_template_oaicompat_params {
770            messages: messages_cstr.as_ptr(),
771            tools: tools_cstr
772                .as_ref()
773                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
774            tool_choice: tool_choice_cstr
775                .as_ref()
776                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
777            json_schema: json_schema_cstr
778                .as_ref()
779                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
780            grammar: grammar_cstr
781                .as_ref()
782                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
783            reasoning_format: reasoning_cstr
784                .as_ref()
785                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
786            chat_template_kwargs: kwargs_cstr
787                .as_ref()
788                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
789            add_generation_prompt: params.add_generation_prompt,
790            use_jinja: params.use_jinja,
791            parallel_tool_calls: params.parallel_tool_calls,
792            enable_thinking: params.enable_thinking,
793            add_bos: params.add_bos,
794            add_eos: params.add_eos,
795        };
796
797        let rc = unsafe {
798            llama_cpp_bindings_sys::llama_rs_apply_chat_template_oaicompat(
799                self.model.as_ptr(),
800                tmpl.0.as_ptr(),
801                &raw const ffi_params,
802                &raw mut raw_result,
803            )
804        };
805
806        unsafe { parse_chat_template_raw_result(rc, &raw mut raw_result, parse_tool_calls) }
807    }
808}
809
810fn extract_meta_string<TCFunction>(
811    c_function: TCFunction,
812    capacity: usize,
813) -> Result<String, MetaValError>
814where
815    TCFunction: Fn(*mut c_char, usize) -> i32,
816{
817    let mut buffer = vec![0u8; capacity];
818    let result = c_function(buffer.as_mut_ptr().cast::<c_char>(), buffer.len());
819
820    if result < 0 {
821        return Err(MetaValError::NegativeReturn(result));
822    }
823
824    let returned_len = result.cast_unsigned() as usize;
825
826    if returned_len >= capacity {
827        return extract_meta_string(c_function, returned_len + 1);
828    }
829
830    if buffer.get(returned_len) != Some(&0) {
831        return Err(MetaValError::NegativeReturn(-1));
832    }
833
834    buffer.truncate(returned_len);
835
836    Ok(String::from_utf8(buffer)?)
837}
838
839impl Drop for LlamaModel {
840    fn drop(&mut self) {
841        unsafe { llama_cpp_bindings_sys::llama_free_model(self.model.as_ptr()) }
842    }
843}
844
#[cfg(test)]
mod extract_meta_string_tests {
    //! Model-free unit tests for the buffer and string helpers in this file.

    use super::extract_meta_string;
    use crate::MetaValError;

    #[test]
    fn returns_error_when_null_terminator_missing() {
        let outcome = extract_meta_string(
            |raw, raw_len| {
                // SAFETY: the helper passes a live buffer of exactly `raw_len` bytes.
                let bytes = unsafe { std::slice::from_raw_parts_mut(raw.cast::<u8>(), raw_len) };
                // Report a length of 2 but put 'c' (not a NUL terminator) at index 2.
                bytes[..3].copy_from_slice(b"abc");
                2
            },
            4,
        );

        assert_eq!(outcome.unwrap_err(), MetaValError::NegativeReturn(-1));
    }

    #[test]
    fn returns_error_for_negative_return_value() {
        let outcome = extract_meta_string(|_raw, _raw_len| -5, 4);

        assert_eq!(outcome.unwrap_err(), MetaValError::NegativeReturn(-5));
    }

    #[test]
    fn returns_error_for_invalid_utf8_data() {
        let outcome = extract_meta_string(
            |raw, raw_len| {
                // SAFETY: the helper passes a live buffer of exactly `raw_len` bytes.
                let bytes = unsafe { std::slice::from_raw_parts_mut(raw.cast::<u8>(), raw_len) };
                // Two bytes that are never valid UTF-8, followed by a proper terminator.
                bytes[..3].copy_from_slice(&[0xFF, 0xFE, 0x00]);
                2
            },
            4,
        );

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("FromUtf8Error"));
    }

    #[test]
    fn triggers_buffer_resize_when_returned_len_exceeds_capacity() {
        let calls_so_far = std::cell::Cell::new(0);
        let outcome = extract_meta_string(
            |raw, raw_len| {
                let previous_calls = calls_so_far.replace(calls_so_far.get() + 1);
                if previous_calls == 0 {
                    // First call: claim a value longer than the 4-byte buffer to force a resize.
                    return 10;
                }
                // Retry with the enlarged buffer: write "hi" plus its NUL terminator.
                // SAFETY: the helper passes a live buffer of exactly `raw_len` bytes.
                let bytes = unsafe { std::slice::from_raw_parts_mut(raw.cast::<u8>(), raw_len) };
                bytes[..3].copy_from_slice(b"hi\0");
                2
            },
            4,
        );

        assert_eq!(outcome.unwrap(), "hi");
    }

    #[test]
    fn cstring_with_validated_len_null_byte_returns_error() {
        // An interior NUL cannot be represented in a C string.
        assert!(super::cstring_with_validated_len("null\0byte").is_err());
    }

    #[test]
    fn validate_string_length_overflow_returns_error() {
        // `usize::MAX` does not fit in a `c_int`.
        assert!(super::validate_string_length_for_tokenizer(usize::MAX).is_err());
    }

    #[test]
    fn truncated_buffer_to_string_with_invalid_utf8_returns_error() {
        let not_utf8 = vec![0xff, 0xfe, 0xfd];

        assert!(super::truncated_buffer_to_string(not_utf8, 3).is_err());
    }
}
941
942#[cfg(test)]
943#[cfg(feature = "tests_that_use_llms")]
944mod tests {
945    use serial_test::serial;
946
947    use super::LlamaModel;
948    use crate::llama_backend::LlamaBackend;
949    use crate::model::AddBos;
950    use crate::model::params::LlamaModelParams;
951    use crate::test_model;
952
953    #[test]
954    #[serial]
955    fn model_loads_with_valid_metadata() {
956        let (_backend, model) = test_model::load_default_model().unwrap();
957        assert!(model.n_vocab() > 0);
958        assert!(model.n_embd() > 0);
959        assert!(model.n_params() > 0);
960        assert!(model.n_ctx_train().unwrap() > 0);
961    }
962
963    #[test]
964    #[serial]
965    fn special_tokens_exist() {
966        let (_backend, model) = test_model::load_default_model().unwrap();
967        let bos = model.token_bos();
968        let eos = model.token_eos();
969        assert_ne!(bos, eos);
970        assert!(model.is_eog_token(eos));
971        assert!(!model.is_eog_token(bos));
972    }
973
974    #[test]
975    #[serial]
976    fn str_to_token_roundtrip() {
977        let (_backend, model) = test_model::load_default_model().unwrap();
978        let tokens = model.str_to_token("hello world", AddBos::Never).unwrap();
979        assert!(!tokens.is_empty());
980        let mut decoder = encoding_rs::UTF_8.new_decoder();
981        let piece = model
982            .token_to_piece(tokens[0], &mut decoder, false, None)
983            .unwrap();
984        assert!(!piece.is_empty());
985    }
986
987    #[test]
988    #[serial]
989    fn chat_template_returns_non_empty() {
990        let (_backend, model) = test_model::load_default_model().unwrap();
991        let template = model.chat_template(None);
992        assert!(template.is_ok());
993    }
994
995    #[test]
996    #[serial]
997    fn apply_chat_template_produces_prompt() {
998        let (_backend, model) = test_model::load_default_model().unwrap();
999        let template = model.chat_template(None).unwrap();
1000        let message =
1001            crate::model::LlamaChatMessage::new("user".to_string(), "hello".to_string()).unwrap();
1002        let prompt = model.apply_chat_template(&template, &[message], true);
1003        assert!(prompt.is_ok());
1004        assert!(!prompt.unwrap().is_empty());
1005    }
1006
1007    #[test]
1008    #[serial]
1009    fn apply_chat_template_oaicompat_produces_result() {
1010        let (_backend, model) = test_model::load_default_model().unwrap();
1011        let template = model.chat_template(None).unwrap();
1012        let params = crate::openai::OpenAIChatTemplateParams {
1013            messages_json: r#"[{"role":"user","content":"hello"}]"#,
1014            tools_json: None,
1015            tool_choice: None,
1016            json_schema: None,
1017            grammar: None,
1018            reasoning_format: Some("none"),
1019            chat_template_kwargs: None,
1020            add_generation_prompt: true,
1021            use_jinja: true,
1022            parallel_tool_calls: false,
1023            enable_thinking: false,
1024            add_bos: false,
1025            add_eos: false,
1026            parse_tool_calls: false,
1027        };
1028        let result = model.apply_chat_template_oaicompat(&template, &params);
1029        assert!(result.is_ok());
1030        assert!(!result.unwrap().prompt.is_empty());
1031    }
1032
1033    #[test]
1034    #[serial]
1035    fn meta_count_returns_positive() {
1036        let (_backend, model) = test_model::load_default_model().unwrap();
1037        assert!(model.meta_count() > 0);
1038    }
1039
1040    #[test]
1041    #[serial]
1042    fn tokens_iterator_produces_valid_entries() {
1043        let (_backend, model) = test_model::load_default_model().unwrap();
1044        let mut count = 0;
1045
1046        for (token, piece_result) in model.tokens(false) {
1047            assert!(token.0 >= 0);
1048            // Not all tokens decode successfully, but the iterator should not panic
1049            let _ = piece_result;
1050            count += 1;
1051
1052            if count >= 100 {
1053                break;
1054            }
1055        }
1056
1057        assert_eq!(count, 100);
1058    }
1059
1060    #[test]
1061    #[serial]
1062    fn token_to_piece_bytes_returns_bytes_for_known_token() {
1063        let (_backend, model) = test_model::load_default_model().unwrap();
1064        let tokens = model.str_to_token("hello", AddBos::Never).unwrap();
1065        let bytes = model
1066            .token_to_piece_bytes(tokens[0], 32, false, None)
1067            .unwrap();
1068
1069        assert!(!bytes.is_empty());
1070    }
1071
1072    #[test]
1073    #[serial]
1074    fn n_layer_returns_positive() {
1075        let (_backend, model) = test_model::load_default_model().unwrap();
1076
1077        assert!(model.n_layer().unwrap() > 0);
1078    }
1079
1080    #[test]
1081    #[serial]
1082    fn n_head_returns_positive() {
1083        let (_backend, model) = test_model::load_default_model().unwrap();
1084
1085        assert!(model.n_head().unwrap() > 0);
1086    }
1087
1088    #[test]
1089    #[serial]
1090    fn n_head_kv_returns_positive() {
1091        let (_backend, model) = test_model::load_default_model().unwrap();
1092
1093        assert!(model.n_head_kv().unwrap() > 0);
1094    }
1095
1096    #[test]
1097    #[serial]
1098    fn is_hybrid_returns_bool_for_test_model() {
1099        let (_backend, model) = test_model::load_default_model().unwrap();
1100
1101        let _ = model.is_hybrid();
1102    }
1103
1104    #[test]
1105    #[serial]
1106    fn meta_key_by_index_returns_valid_key() {
1107        let (_backend, model) = test_model::load_default_model().unwrap();
1108        let key = model.meta_key_by_index(0).unwrap();
1109
1110        assert!(!key.is_empty());
1111    }
1112
1113    #[test]
1114    #[serial]
1115    fn meta_val_str_by_index_returns_valid_value() {
1116        let (_backend, model) = test_model::load_default_model().unwrap();
1117        let value = model.meta_val_str_by_index(0).unwrap();
1118
1119        assert!(!value.is_empty());
1120    }
1121
1122    #[test]
1123    #[serial]
1124    fn meta_key_by_index_out_of_range_returns_error() {
1125        let (_backend, model) = test_model::load_default_model().unwrap();
1126        let result = model.meta_key_by_index(999_999);
1127
1128        assert!(result.is_err());
1129    }
1130
1131    #[test]
1132    #[serial]
1133    fn meta_val_str_by_index_out_of_range_returns_error() {
1134        let (_backend, model) = test_model::load_default_model().unwrap();
1135        let result = model.meta_val_str_by_index(999_999);
1136
1137        assert!(result.is_err());
1138    }
1139
1140    #[test]
1141    #[serial]
1142    fn meta_val_str_returns_value_for_known_key() {
1143        let (_backend, model) = test_model::load_default_model().unwrap();
1144        let first_key = model.meta_key_by_index(0).unwrap();
1145        let value = model.meta_val_str(&first_key).unwrap();
1146
1147        assert!(!value.is_empty());
1148    }
1149
1150    #[test]
1151    #[serial]
1152    fn model_size_returns_nonzero() {
1153        let (_backend, model) = test_model::load_default_model().unwrap();
1154
1155        assert!(model.size() > 0);
1156    }
1157
1158    #[test]
1159    #[serial]
1160    fn is_recurrent_returns_false_for_transformer() {
1161        let (_backend, model) = test_model::load_default_model().unwrap();
1162
1163        assert!(!model.is_recurrent());
1164    }
1165
1166    #[test]
1167    #[serial]
1168    fn rope_type_does_not_panic() {
1169        let (_backend, model) = test_model::load_default_model().unwrap();
1170        let _rope_type = model.rope_type();
1171    }
1172
1173    #[test]
1174    #[serial]
1175    fn load_model_with_invalid_path_returns_error() {
1176        let backend = LlamaBackend::init().unwrap();
1177        let model_params = LlamaModelParams::default();
1178        let result = LlamaModel::load_from_file(&backend, "/nonexistent/model.gguf", &model_params);
1179
1180        assert_eq!(
1181            result.unwrap_err(),
1182            crate::LlamaModelLoadError::FileNotFound(std::path::PathBuf::from(
1183                "/nonexistent/model.gguf"
1184            ))
1185        );
1186    }
1187
1188    #[test]
1189    #[serial]
1190    fn load_model_with_invalid_file_content_returns_null_result() {
1191        let backend = LlamaBackend::init().unwrap();
1192        let model_params = LlamaModelParams::default();
1193        let dummy_path = std::env::temp_dir().join("llama_test_invalid_model.gguf");
1194        std::fs::write(&dummy_path, b"not a valid gguf model file").unwrap();
1195
1196        let result = LlamaModel::load_from_file(&backend, &dummy_path, &model_params);
1197
1198        assert_eq!(result.unwrap_err(), crate::LlamaModelLoadError::NullResult);
1199        let _ = std::fs::remove_file(&dummy_path);
1200    }
1201
1202    #[cfg(unix)]
1203    #[test]
1204    #[serial]
1205    fn load_model_with_non_utf8_path_returns_path_to_str_error() {
1206        use std::ffi::OsStr;
1207        use std::os::unix::ffi::OsStrExt;
1208
1209        let backend = LlamaBackend::init().unwrap();
1210        let model_params = LlamaModelParams::default();
1211        let non_utf8_path = std::path::Path::new(OsStr::from_bytes(b"/tmp/\xff\xfe.gguf"));
1212
1213        let result = LlamaModel::load_from_file(&backend, non_utf8_path, &model_params);
1214
1215        assert_eq!(
1216            result.unwrap_err(),
1217            crate::LlamaModelLoadError::PathToStrError(non_utf8_path.to_path_buf())
1218        );
1219    }
1220
1221    #[cfg(unix)]
1222    #[test]
1223    #[serial]
1224    fn lora_adapter_init_with_non_utf8_path_returns_error() {
1225        use std::ffi::OsStr;
1226        use std::os::unix::ffi::OsStrExt;
1227
1228        let (_backend, model) = test_model::load_default_model().unwrap();
1229        let non_utf8_path = std::path::Path::new(OsStr::from_bytes(b"/tmp/\xff\xfe.gguf"));
1230
1231        let result = model.lora_adapter_init(non_utf8_path);
1232
1233        assert_eq!(
1234            result.unwrap_err(),
1235            crate::LlamaLoraAdapterInitError::PathToStrError(non_utf8_path.to_path_buf())
1236        );
1237    }
1238
1239    #[test]
1240    #[serial]
1241    fn lora_adapter_init_with_invalid_path_returns_error() {
1242        let (_backend, model) = test_model::load_default_model().unwrap();
1243        let result = model.lora_adapter_init("/nonexistent/path/lora.gguf");
1244
1245        assert_eq!(
1246            result.unwrap_err(),
1247            crate::LlamaLoraAdapterInitError::FileNotFound(std::path::PathBuf::from(
1248                "/nonexistent/path/lora.gguf"
1249            ))
1250        );
1251    }
1252
1253    #[test]
1254    #[serial]
1255    fn new_context_returns_valid_context() {
1256        let (backend, model) = test_model::load_default_model().unwrap();
1257        let ctx_params = crate::context::params::LlamaContextParams::default()
1258            .with_n_ctx(std::num::NonZeroU32::new(256));
1259        let context = model.new_context(&backend, ctx_params).unwrap();
1260
1261        assert!(context.n_ctx() > 0);
1262    }
1263
1264    #[test]
1265    #[serial]
1266    fn token_nl_returns_valid_token() {
1267        let (_backend, model) = test_model::load_default_model().unwrap();
1268        let nl_token = model.token_nl();
1269
1270        assert!(nl_token.0 >= 0);
1271    }
1272
1273    #[test]
1274    #[serial]
1275    fn decode_start_token_returns_valid_token() {
1276        let (_backend, model) = test_model::load_default_model().unwrap();
1277        let _decode_start = model.decode_start_token();
1278    }
1279
1280    #[test]
1281    #[serial]
1282    fn token_sep_returns_valid_token() {
1283        let (_backend, model) = test_model::load_default_model().unwrap();
1284        let _sep_token = model.token_sep();
1285    }
1286
1287    #[test]
1288    #[serial]
1289    fn token_to_piece_handles_large_token_requiring_buffer_resize() {
1290        let (_backend, model) = test_model::load_default_model().unwrap();
1291        let mut decoder = encoding_rs::UTF_8.new_decoder();
1292
1293        for (token, _) in model.tokens(true).take(200) {
1294            let result = model.token_to_piece(token, &mut decoder, true, None);
1295            assert!(result.is_ok());
1296        }
1297    }
1298
1299    #[test]
1300    #[serial]
1301    fn token_to_piece_bytes_insufficient_buffer_returns_error() {
1302        let (_backend, model) = test_model::load_default_model().unwrap();
1303        let tokens = model.str_to_token("hello", AddBos::Never).unwrap();
1304        let result = model.token_to_piece_bytes(tokens[0], 1, false, None);
1305
1306        assert!(
1307            result
1308                .unwrap_err()
1309                .to_string()
1310                .contains("Insufficient Buffer Space")
1311        );
1312    }
1313
1314    #[test]
1315    #[serial]
1316    fn token_to_piece_with_lstrip() {
1317        let (_backend, model) = test_model::load_default_model().unwrap();
1318        let mut decoder = encoding_rs::UTF_8.new_decoder();
1319        let tokens = model.str_to_token("hello", AddBos::Never).unwrap();
1320        let result =
1321            model.token_to_piece(tokens[0], &mut decoder, false, std::num::NonZeroU16::new(1));
1322
1323        assert!(result.is_ok());
1324    }
1325
1326    #[test]
1327    #[serial]
1328    fn n_vocab_matches_tokens_iterator_count() {
1329        let (_backend, model) = test_model::load_default_model().unwrap();
1330        let n_vocab = model.n_vocab();
1331        let count = model.tokens(false).count();
1332
1333        assert_eq!(count, n_vocab as usize);
1334    }
1335
1336    #[test]
1337    #[serial]
1338    fn token_attr_returns_valid_attr() {
1339        let (_backend, model) = test_model::load_default_model().unwrap();
1340        let bos = model.token_bos();
1341        let _attr = model.token_attr(bos).unwrap();
1342    }
1343
1344    #[test]
1345    #[serial]
1346    fn vocab_type_returns_valid_type() {
1347        let (_backend, model) = test_model::load_default_model().unwrap();
1348        let _vocab_type = model.vocab_type().unwrap();
1349    }
1350
1351    #[test]
1352    #[serial]
1353    fn apply_chat_template_buffer_resize_with_long_messages() {
1354        let (_backend, model) = test_model::load_default_model().unwrap();
1355        let template = model.chat_template(None).unwrap();
1356        let long_content = "a".repeat(2000);
1357        let message =
1358            crate::model::LlamaChatMessage::new("user".to_string(), long_content).unwrap();
1359        let prompt = model.apply_chat_template(&template, &[message], true);
1360
1361        assert!(prompt.is_ok());
1362        assert!(!prompt.unwrap().is_empty());
1363    }
1364
1365    #[test]
1366    #[serial]
1367    fn meta_val_str_with_long_value_triggers_buffer_resize() {
1368        let (_backend, model) = test_model::load_default_model().unwrap();
1369        let count = model.meta_count();
1370
1371        for index in 0..count {
1372            let key = model.meta_key_by_index(index);
1373            let value = model.meta_val_str_by_index(index);
1374            assert!(key.is_ok());
1375            assert!(value.is_ok());
1376        }
1377    }
1378
1379    #[test]
1380    #[serial]
1381    fn str_to_token_with_add_bos_never() {
1382        let (_backend, model) = test_model::load_default_model().unwrap();
1383        let tokens_with_bos = model.str_to_token("hello", AddBos::Always).unwrap();
1384        let tokens_without_bos = model.str_to_token("hello", AddBos::Never).unwrap();
1385
1386        assert!(tokens_with_bos.len() >= tokens_without_bos.len());
1387    }
1388
1389    #[test]
1390    #[serial]
1391    fn apply_chat_template_with_tools_oaicompat_produces_result() {
1392        let (_backend, model) = test_model::load_default_model().unwrap();
1393        let template = model.chat_template(None).unwrap();
1394        let message =
1395            crate::model::LlamaChatMessage::new("user".to_string(), "hello".to_string()).unwrap();
1396        let result =
1397            model.apply_chat_template_with_tools_oaicompat(&template, &[message], None, None, true);
1398
1399        assert!(result.is_ok());
1400        assert!(!result.unwrap().prompt.is_empty());
1401    }
1402
1403    #[test]
1404    #[serial]
1405    fn apply_chat_template_with_tools_oaicompat_with_tools_json() {
1406        let (_backend, model) = test_model::load_default_model().unwrap();
1407        let template = model.chat_template(None).unwrap();
1408        let message =
1409            crate::model::LlamaChatMessage::new("user".to_string(), "hello".to_string()).unwrap();
1410        let tools =
1411            r#"[{"type":"function","function":{"name":"test","parameters":{"type":"object"}}}]"#;
1412        let result = model.apply_chat_template_with_tools_oaicompat(
1413            &template,
1414            &[message],
1415            Some(tools),
1416            None,
1417            true,
1418        );
1419
1420        assert!(result.is_ok());
1421    }
1422
1423    #[test]
1424    #[serial]
1425    fn apply_chat_template_with_tools_oaicompat_with_json_schema() {
1426        let (_backend, model) = test_model::load_default_model().unwrap();
1427        let template = model.chat_template(None).unwrap();
1428        let message =
1429            crate::model::LlamaChatMessage::new("user".to_string(), "hello".to_string()).unwrap();
1430        let schema = r#"{"type":"object","properties":{"name":{"type":"string"}}}"#;
1431        let result = model.apply_chat_template_with_tools_oaicompat(
1432            &template,
1433            &[message],
1434            None,
1435            Some(schema),
1436            true,
1437        );
1438
1439        assert!(result.is_ok());
1440    }
1441
1442    #[test]
1443    #[serial]
1444    fn apply_chat_template_oaicompat_with_tools_and_tool_choice() {
1445        let (_backend, model) = test_model::load_default_model().unwrap();
1446        let template = model.chat_template(None).unwrap();
1447        let params = crate::openai::OpenAIChatTemplateParams {
1448            messages_json: r#"[{"role":"user","content":"hello"}]"#,
1449            tools_json: Some(
1450                r#"[{"type":"function","function":{"name":"test","parameters":{"type":"object","properties":{}}}}]"#,
1451            ),
1452            tool_choice: Some("auto"),
1453            json_schema: None,
1454            grammar: None,
1455            reasoning_format: Some("none"),
1456            chat_template_kwargs: None,
1457            add_generation_prompt: true,
1458            use_jinja: true,
1459            parallel_tool_calls: false,
1460            enable_thinking: false,
1461            add_bos: false,
1462            add_eos: false,
1463            parse_tool_calls: true,
1464        };
1465        let result = model.apply_chat_template_oaicompat(&template, &params);
1466
1467        assert!(result.is_ok());
1468    }
1469
1470    #[test]
1471    #[serial]
1472    fn apply_chat_template_oaicompat_with_json_schema_field() {
1473        let (_backend, model) = test_model::load_default_model().unwrap();
1474        let template = model.chat_template(None).unwrap();
1475        let params = crate::openai::OpenAIChatTemplateParams {
1476            messages_json: r#"[{"role":"user","content":"hello"}]"#,
1477            tools_json: None,
1478            tool_choice: None,
1479            json_schema: Some(r#"{"type":"object","properties":{"name":{"type":"string"}}}"#),
1480            grammar: None,
1481            reasoning_format: Some("none"),
1482            chat_template_kwargs: None,
1483            add_generation_prompt: true,
1484            use_jinja: true,
1485            parallel_tool_calls: false,
1486            enable_thinking: false,
1487            add_bos: false,
1488            add_eos: false,
1489            parse_tool_calls: false,
1490        };
1491        let result = model.apply_chat_template_oaicompat(&template, &params);
1492
1493        assert!(result.is_ok());
1494    }
1495
1496    #[test]
1497    #[serial]
1498    fn apply_chat_template_oaicompat_with_grammar_field() {
1499        let (_backend, model) = test_model::load_default_model().unwrap();
1500        let template = model.chat_template(None).unwrap();
1501        let params = crate::openai::OpenAIChatTemplateParams {
1502            messages_json: r#"[{"role":"user","content":"hello"}]"#,
1503            tools_json: None,
1504            tool_choice: None,
1505            json_schema: None,
1506            grammar: Some("root ::= \"hello\""),
1507            reasoning_format: Some("none"),
1508            chat_template_kwargs: None,
1509            add_generation_prompt: true,
1510            use_jinja: true,
1511            parallel_tool_calls: false,
1512            enable_thinking: false,
1513            add_bos: false,
1514            add_eos: false,
1515            parse_tool_calls: false,
1516        };
1517        let result = model.apply_chat_template_oaicompat(&template, &params);
1518
1519        assert!(result.is_ok());
1520    }
1521
1522    #[test]
1523    #[serial]
1524    fn apply_chat_template_oaicompat_with_kwargs_field() {
1525        let (_backend, model) = test_model::load_default_model().unwrap();
1526        let template = model.chat_template(None).unwrap();
1527        let params = crate::openai::OpenAIChatTemplateParams {
1528            messages_json: r#"[{"role":"user","content":"hello"}]"#,
1529            tools_json: None,
1530            tool_choice: None,
1531            json_schema: None,
1532            grammar: None,
1533            reasoning_format: Some("none"),
1534            chat_template_kwargs: Some(r#"{"bos_token": "<|im_start|>"}"#),
1535            add_generation_prompt: true,
1536            use_jinja: true,
1537            parallel_tool_calls: false,
1538            enable_thinking: false,
1539            add_bos: false,
1540            add_eos: false,
1541            parse_tool_calls: false,
1542        };
1543        let result = model.apply_chat_template_oaicompat(&template, &params);
1544
1545        assert!(result.is_ok());
1546    }
1547
1548    #[test]
1549    #[serial]
1550    fn chat_template_with_nonexistent_name_returns_error() {
1551        let (_backend, model) = test_model::load_default_model().unwrap();
1552
1553        let result = model.chat_template(Some("nonexistent_template_name_xyz"));
1554
1555        assert_eq!(
1556            result.unwrap_err(),
1557            crate::ChatTemplateError::MissingTemplate
1558        );
1559    }
1560
1561    #[test]
1562    #[serial]
1563    fn lora_adapter_init_with_invalid_gguf_returns_null_result() {
1564        let (_backend, model) = test_model::load_default_model().unwrap();
1565        let dummy_path = std::env::temp_dir().join("llama_test_dummy_lora.gguf");
1566        std::fs::write(&dummy_path, b"not a valid gguf").unwrap();
1567
1568        let result = model.lora_adapter_init(&dummy_path);
1569
1570        assert_eq!(
1571            result.unwrap_err(),
1572            crate::LlamaLoraAdapterInitError::NullResult
1573        );
1574        let _ = std::fs::remove_file(&dummy_path);
1575    }
1576
1577    #[test]
1578    #[serial]
1579    fn str_to_token_with_many_tokens_triggers_buffer_resize() {
1580        let (_backend, model) = test_model::load_default_model().unwrap();
1581        // Each digit typically becomes its own token, but the buffer estimate
1582        // is len/2 which is smaller than the actual token count for
1583        // single-char-token strings like "1 2 3 4 ..."
1584        let many_numbers: String = (0..2000).map(|number| format!("{number} ")).collect();
1585
1586        let tokens = model.str_to_token(&many_numbers, AddBos::Always).unwrap();
1587
1588        assert!(tokens.len() > many_numbers.len() / 2);
1589    }
1590
1591    #[test]
1592    #[serial]
1593    fn rope_type_returns_valid_result_for_test_model() {
1594        let (_backend, model) = test_model::load_default_model().unwrap();
1595
1596        let _rope_type = model.rope_type();
1597    }
1598
1599    #[test]
1600    #[serial]
1601    fn meta_val_str_with_null_byte_in_key_returns_error() {
1602        let (_backend, model) = test_model::load_default_model().unwrap();
1603        let result = model.meta_val_str("key\0with_null");
1604
1605        assert!(result.is_err());
1606    }
1607
1608    #[test]
1609    #[serial]
1610    fn apply_chat_template_with_tools_null_byte_in_tools_returns_error() {
1611        let (_backend, model) = test_model::load_default_model().unwrap();
1612        let template = model.chat_template(None).unwrap();
1613        let message =
1614            crate::model::LlamaChatMessage::new("user".to_string(), "hello".to_string()).unwrap();
1615        let result = model.apply_chat_template_with_tools_oaicompat(
1616            &template,
1617            &[message],
1618            Some("tools\0null"),
1619            None,
1620            true,
1621        );
1622
1623        assert!(result.is_err());
1624    }
1625
1626    #[test]
1627    #[serial]
1628    fn apply_chat_template_with_tools_null_byte_in_json_schema_returns_error() {
1629        let (_backend, model) = test_model::load_default_model().unwrap();
1630        let template = model.chat_template(None).unwrap();
1631        let message =
1632            crate::model::LlamaChatMessage::new("user".to_string(), "hello".to_string()).unwrap();
1633        let result = model.apply_chat_template_with_tools_oaicompat(
1634            &template,
1635            &[message],
1636            None,
1637            Some("schema\0null"),
1638            true,
1639        );
1640
1641        assert!(result.is_err());
1642    }
1643
1644    #[test]
1645    #[serial]
1646    fn apply_chat_template_oaicompat_with_null_byte_in_messages_returns_error() {
1647        let (_backend, model) = test_model::load_default_model().unwrap();
1648        let template = model.chat_template(None).unwrap();
1649        let params = crate::openai::OpenAIChatTemplateParams {
1650            messages_json: "messages\0null",
1651            tools_json: None,
1652            tool_choice: None,
1653            json_schema: None,
1654            grammar: None,
1655            reasoning_format: None,
1656            chat_template_kwargs: None,
1657            add_generation_prompt: true,
1658            use_jinja: true,
1659            parallel_tool_calls: false,
1660            enable_thinking: false,
1661            add_bos: false,
1662            add_eos: false,
1663            parse_tool_calls: false,
1664        };
1665        let result = model.apply_chat_template_oaicompat(&template, &params);
1666
1667        assert!(result.is_err());
1668    }
1669
1670    #[test]
1671    #[serial]
1672    fn apply_chat_template_oaicompat_with_null_byte_in_tools_returns_error() {
1673        let (_backend, model) = test_model::load_default_model().unwrap();
1674        let template = model.chat_template(None).unwrap();
1675        let params = crate::openai::OpenAIChatTemplateParams {
1676            messages_json: r#"[{"role":"user","content":"hello"}]"#,
1677            tools_json: Some("tools\0null"),
1678            tool_choice: None,
1679            json_schema: None,
1680            grammar: None,
1681            reasoning_format: None,
1682            chat_template_kwargs: None,
1683            add_generation_prompt: true,
1684            use_jinja: true,
1685            parallel_tool_calls: false,
1686            enable_thinking: false,
1687            add_bos: false,
1688            add_eos: false,
1689            parse_tool_calls: false,
1690        };
1691        let result = model.apply_chat_template_oaicompat(&template, &params);
1692
1693        assert!(result.is_err());
1694    }
1695
1696    #[test]
1697    #[serial]
1698    fn apply_chat_template_oaicompat_with_null_byte_in_tool_choice_returns_error() {
1699        let (_backend, model) = test_model::load_default_model().unwrap();
1700        let template = model.chat_template(None).unwrap();
1701        let params = crate::openai::OpenAIChatTemplateParams {
1702            messages_json: r#"[{"role":"user","content":"hello"}]"#,
1703            tools_json: None,
1704            tool_choice: Some("choice\0null"),
1705            json_schema: None,
1706            grammar: None,
1707            reasoning_format: None,
1708            chat_template_kwargs: None,
1709            add_generation_prompt: true,
1710            use_jinja: true,
1711            parallel_tool_calls: false,
1712            enable_thinking: false,
1713            add_bos: false,
1714            add_eos: false,
1715            parse_tool_calls: false,
1716        };
1717        let result = model.apply_chat_template_oaicompat(&template, &params);
1718
1719        assert!(result.is_err());
1720    }
1721
1722    #[test]
1723    #[serial]
1724    fn apply_chat_template_oaicompat_with_null_byte_in_json_schema_returns_error() {
1725        let (_backend, model) = test_model::load_default_model().unwrap();
1726        let template = model.chat_template(None).unwrap();
1727        let params = crate::openai::OpenAIChatTemplateParams {
1728            messages_json: r#"[{"role":"user","content":"hello"}]"#,
1729            tools_json: None,
1730            tool_choice: None,
1731            json_schema: Some("schema\0null"),
1732            grammar: None,
1733            reasoning_format: None,
1734            chat_template_kwargs: None,
1735            add_generation_prompt: true,
1736            use_jinja: true,
1737            parallel_tool_calls: false,
1738            enable_thinking: false,
1739            add_bos: false,
1740            add_eos: false,
1741            parse_tool_calls: false,
1742        };
1743        let result = model.apply_chat_template_oaicompat(&template, &params);
1744
1745        assert!(result.is_err());
1746    }
1747
1748    #[test]
1749    #[serial]
1750    fn apply_chat_template_oaicompat_with_null_byte_in_grammar_returns_error() {
1751        let (_backend, model) = test_model::load_default_model().unwrap();
1752        let template = model.chat_template(None).unwrap();
1753        let params = crate::openai::OpenAIChatTemplateParams {
1754            messages_json: r#"[{"role":"user","content":"hello"}]"#,
1755            tools_json: None,
1756            tool_choice: None,
1757            json_schema: None,
1758            grammar: Some("grammar\0null"),
1759            reasoning_format: None,
1760            chat_template_kwargs: None,
1761            add_generation_prompt: true,
1762            use_jinja: true,
1763            parallel_tool_calls: false,
1764            enable_thinking: false,
1765            add_bos: false,
1766            add_eos: false,
1767            parse_tool_calls: false,
1768        };
1769        let result = model.apply_chat_template_oaicompat(&template, &params);
1770
1771        assert!(result.is_err());
1772    }
1773
1774    #[test]
1775    #[serial]
1776    fn apply_chat_template_oaicompat_with_null_byte_in_reasoning_format_returns_error() {
1777        let (_backend, model) = test_model::load_default_model().unwrap();
1778        let template = model.chat_template(None).unwrap();
1779        let params = crate::openai::OpenAIChatTemplateParams {
1780            messages_json: r#"[{"role":"user","content":"hello"}]"#,
1781            tools_json: None,
1782            tool_choice: None,
1783            json_schema: None,
1784            grammar: None,
1785            reasoning_format: Some("format\0null"),
1786            chat_template_kwargs: None,
1787            add_generation_prompt: true,
1788            use_jinja: true,
1789            parallel_tool_calls: false,
1790            enable_thinking: false,
1791            add_bos: false,
1792            add_eos: false,
1793            parse_tool_calls: false,
1794        };
1795        let result = model.apply_chat_template_oaicompat(&template, &params);
1796
1797        assert!(result.is_err());
1798    }
1799
1800    #[test]
1801    #[serial]
1802    fn apply_chat_template_oaicompat_with_null_byte_in_kwargs_returns_error() {
1803        let (_backend, model) = test_model::load_default_model().unwrap();
1804        let template = model.chat_template(None).unwrap();
1805        let params = crate::openai::OpenAIChatTemplateParams {
1806            messages_json: r#"[{"role":"user","content":"hello"}]"#,
1807            tools_json: None,
1808            tool_choice: None,
1809            json_schema: None,
1810            grammar: None,
1811            reasoning_format: None,
1812            chat_template_kwargs: Some("kwargs\0null"),
1813            add_generation_prompt: true,
1814            use_jinja: true,
1815            parallel_tool_calls: false,
1816            enable_thinking: false,
1817            add_bos: false,
1818            add_eos: false,
1819            parse_tool_calls: false,
1820        };
1821        let result = model.apply_chat_template_oaicompat(&template, &params);
1822
1823        assert!(result.is_err());
1824    }
1825
1826    #[test]
1827    #[serial]
1828    fn new_context_with_huge_ctx_returns_null_error() {
1829        let (_backend, model) = test_model::load_default_model().unwrap();
1830        let ctx_params = crate::context::params::LlamaContextParams::default()
1831            .with_n_ctx(std::num::NonZeroU32::new(u32::MAX));
1832
1833        let result = model.new_context(&_backend, ctx_params);
1834
1835        assert!(result.is_err());
1836    }
1837
1838    #[test]
1839    #[serial]
1840    fn sample_returns_result_and_succeeds_with_valid_index() {
1841        use crate::sampling::LlamaSampler;
1842        use crate::token::LlamaToken;
1843
1844        let (backend, model) = test_model::load_default_model().unwrap();
1845        let ctx_params = crate::context::params::LlamaContextParams::default()
1846            .with_n_ctx(std::num::NonZeroU32::new(256));
1847        let mut context = model.new_context(&backend, ctx_params).unwrap();
1848
1849        let tokens = model.str_to_token("Hello", AddBos::Always).unwrap();
1850        let mut batch = crate::llama_batch::LlamaBatch::new(512, 1).unwrap();
1851
1852        batch.add_sequence(&tokens, 0, false).unwrap();
1853
1854        context.decode(&mut batch).unwrap();
1855
1856        let mut sampler =
1857            LlamaSampler::chain_simple([LlamaSampler::temp(0.8), LlamaSampler::greedy()]);
1858
1859        // sample() now returns Result to catch C++ exceptions at the FFI
1860        // boundary instead of aborting the process.
1861        let result = sampler.sample(&context, batch.n_tokens() - 1);
1862
1863        assert!(result.is_ok());
1864    }
1865
1866    #[test]
1867    #[serial]
1868    fn grammar_sampler_constrains_output_to_yes_or_no() {
1869        use crate::sampling::LlamaSampler;
1870        use std::sync::Arc;
1871
1872        let (backend, model) = test_model::load_default_model().unwrap();
1873
1874        let ctx_params = crate::context::params::LlamaContextParams::default()
1875            .with_n_ctx(std::num::NonZeroU32::new(512));
1876        let mut context = model.new_context(&backend, ctx_params).unwrap();
1877
1878        let prompt = "<|im_start|>user\nIs the sky blue? Answer yes or no.<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
1879        let tokens = model.str_to_token(prompt, AddBos::Always).unwrap();
1880        let mut batch = crate::llama_batch::LlamaBatch::new(512, 1).unwrap();
1881
1882        batch.add_sequence(&tokens, 0, false).unwrap();
1883
1884        context.decode(&mut batch).unwrap();
1885
1886        let mut sampler = LlamaSampler::chain_simple([
1887            LlamaSampler::grammar(&model, r#"root ::= [Yy] [Ee] [Ss] | [Nn] [Oo]"#, "root")
1888                .unwrap(),
1889            LlamaSampler::temp(0.8),
1890            LlamaSampler::greedy(),
1891        ]);
1892
1893        let token = sampler.sample(&context, batch.n_tokens() - 1).unwrap();
1894
1895        assert!(
1896            !model.is_eog_token(token),
1897            "Grammar sampler should not allow EOS as first token"
1898        );
1899
1900        let mut decoder = encoding_rs::UTF_8.new_decoder();
1901        let piece = model
1902            .token_to_piece(token, &mut decoder, true, None)
1903            .unwrap();
1904        let first_char = piece.chars().next().unwrap().to_lowercase().next().unwrap();
1905
1906        assert!(
1907            first_char == 'y' || first_char == 'n',
1908            "Grammar should constrain first token to start with y/n, got: '{piece}'"
1909        );
1910    }
1911
1912    #[test]
1913    #[serial]
1914    fn json_schema_grammar_sampler_constrains_output_to_json() {
1915        use crate::sampling::LlamaSampler;
1916        use std::sync::Arc;
1917
1918        let (backend, model) = test_model::load_default_model().unwrap();
1919
1920        let ctx_params = crate::context::params::LlamaContextParams::default()
1921            .with_n_ctx(std::num::NonZeroU32::new(512));
1922        let mut context = model.new_context(&backend, ctx_params).unwrap();
1923
1924        let prompt = "<|im_start|>user\nWhat is 2+2? Respond with a JSON object.<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
1925        let tokens = model.str_to_token(prompt, AddBos::Always).unwrap();
1926        let mut batch = crate::llama_batch::LlamaBatch::new(512, 1).unwrap();
1927
1928        batch.add_sequence(&tokens, 0, false).unwrap();
1929
1930        context.decode(&mut batch).unwrap();
1931
1932        let grammar_str = crate::json_schema_to_grammar(
1933            r#"{"type": "object", "properties": {"answer": {"type": "string"}}, "required": ["answer"]}"#
1934        ).unwrap();
1935
1936        let mut sampler = LlamaSampler::chain_simple([
1937            LlamaSampler::grammar(&model, &grammar_str, "root").unwrap(),
1938            LlamaSampler::temp(0.8),
1939            LlamaSampler::greedy(),
1940        ]);
1941
1942        let token = sampler.sample(&context, batch.n_tokens() - 1).unwrap();
1943
1944        assert!(
1945            !model.is_eog_token(token),
1946            "Grammar sampler should not allow EOS as first token"
1947        );
1948
1949        let mut decoder = encoding_rs::UTF_8.new_decoder();
1950        let piece = model
1951            .token_to_piece(token, &mut decoder, true, None)
1952            .unwrap();
1953
1954        assert!(
1955            piece.starts_with('{'),
1956            "JSON schema grammar should constrain first token to start with '{{', got: '{piece}'"
1957        );
1958    }
1959
1960    #[test]
1961    #[serial]
1962    fn sample_with_grammar_produces_constrained_output_in_loop() {
1963        use crate::sampling::LlamaSampler;
1964        use std::sync::Arc;
1965
1966        let (backend, model) = test_model::load_default_model().unwrap();
1967
1968        let ctx_params = crate::context::params::LlamaContextParams::default()
1969            .with_n_ctx(std::num::NonZeroU32::new(512));
1970        let mut context = model.new_context(&backend, ctx_params).unwrap();
1971
1972        let prompt = "<|im_start|>user\nIs the sky blue? yes or no<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
1973        let tokens = model.str_to_token(prompt, AddBos::Always).unwrap();
1974        let mut batch = crate::llama_batch::LlamaBatch::new(512, 1).unwrap();
1975
1976        batch.add_sequence(&tokens, 0, false).unwrap();
1977
1978        context.decode(&mut batch).unwrap();
1979
1980        let mut sampler = LlamaSampler::chain_simple([
1981            LlamaSampler::grammar(&model, r#"root ::= "yes" | "no""#, "root").unwrap(),
1982            LlamaSampler::temp(0.8),
1983            LlamaSampler::greedy(),
1984        ]);
1985
1986        let mut generated = String::new();
1987        let mut decoder = encoding_rs::UTF_8.new_decoder();
1988        let mut position = batch.n_tokens();
1989
1990        for iteration in 0..10 {
1991            let token = sampler.sample(&context, -1).unwrap();
1992            let is_eog = model.is_eog_token(token);
1993
1994            eprintln!("  iteration={iteration} token={} eog={is_eog}", token.0);
1995
1996            if is_eog {
1997                break;
1998            }
1999
2000            let piece = model
2001                .token_to_piece(token, &mut decoder, true, None)
2002                .unwrap();
2003
2004            eprintln!("  piece='{piece}'");
2005
2006            generated.push_str(&piece);
2007
2008            batch.clear();
2009            batch.add(token, position, &[0], true).unwrap();
2010            position += 1;
2011
2012            context.decode(&mut batch).unwrap();
2013        }
2014
2015        let lowercase = generated.to_lowercase();
2016
2017        assert!(
2018            lowercase == "yes" || lowercase == "no",
2019            "Grammar loop should produce 'yes' or 'no', got: '{generated}'"
2020        );
2021    }
2022
2023    #[test]
2024    #[serial]
2025    fn sample_without_grammar_produces_multiple_tokens() {
2026        use crate::sampling::LlamaSampler;
2027        use std::sync::Arc;
2028
2029        let (backend, model) = test_model::load_default_model().unwrap();
2030
2031        let ctx_params = crate::context::params::LlamaContextParams::default()
2032            .with_n_ctx(std::num::NonZeroU32::new(512));
2033        let mut context = model.new_context(&backend, ctx_params).unwrap();
2034
2035        let prompt =
2036            "<|im_start|>user\nSay hello<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
2037        let tokens = model.str_to_token(prompt, AddBos::Always).unwrap();
2038        let mut batch = crate::llama_batch::LlamaBatch::new(512, 1).unwrap();
2039
2040        batch.add_sequence(&tokens, 0, false).unwrap();
2041
2042        context.decode(&mut batch).unwrap();
2043
2044        let mut sampler =
2045            LlamaSampler::chain_simple([LlamaSampler::temp(0.8), LlamaSampler::greedy()]);
2046
2047        let mut token_count = 0;
2048        let mut position = batch.n_tokens();
2049
2050        for _ in 0..5 {
2051            let token = sampler.sample(&context, -1).unwrap();
2052
2053            if model.is_eog_token(token) {
2054                break;
2055            }
2056
2057            token_count += 1;
2058
2059            batch.clear();
2060            batch.add(token, position, &[0], true).unwrap();
2061            position += 1;
2062
2063            context.decode(&mut batch).unwrap();
2064        }
2065
2066        assert!(
2067            token_count > 0,
2068            "Should produce at least one token without grammar"
2069        );
2070    }
2071}