1use std::ffi::{CStr, CString, c_char};
3use std::num::NonZeroU16;
4use std::os::raw::c_int;
5use std::path::Path;
6
7fn truncated_buffer_to_string(
8 mut buffer: Vec<u8>,
9 length: usize,
10) -> Result<String, ApplyChatTemplateError> {
11 buffer.truncate(length);
12
13 Ok(String::from_utf8(buffer)?)
14}
15
16fn validate_string_length_for_tokenizer(length: usize) -> Result<c_int, StringToTokenError> {
17 Ok(c_int::try_from(length)?)
18}
19
20fn cstring_with_validated_len(str: &str) -> Result<(CString, c_int), StringToTokenError> {
21 let c_string = CString::new(str)?;
22 let len = validate_string_length_for_tokenizer(c_string.as_bytes().len())?;
23 Ok((c_string, len))
24}
25use std::ptr::{self, NonNull};
26
27use crate::context::LlamaContext;
28use crate::context::params::LlamaContextParams;
29use crate::llama_backend::LlamaBackend;
30use crate::openai::OpenAIChatTemplateParams;
31use crate::token::LlamaToken;
32use crate::token_type::LlamaTokenAttrs;
33use crate::{
34 ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
35 LlamaModelLoadError, MetaValError, StringToTokenError, TokenToStringError,
36};
37
38pub mod add_bos;
39pub mod chat_template_result;
40pub mod grammar_trigger;
41pub mod llama_chat_message;
42pub mod llama_chat_template;
43pub mod llama_lora_adapter;
44pub mod params;
45pub mod rope_type;
46pub mod split_mode;
47pub mod vocab_type;
48
49pub use add_bos::AddBos;
50pub use chat_template_result::ChatTemplateResult;
51pub use grammar_trigger::{GrammarTrigger, GrammarTriggerType};
52pub use llama_chat_message::LlamaChatMessage;
53pub use llama_chat_template::LlamaChatTemplate;
54pub use llama_lora_adapter::LlamaLoraAdapter;
55pub use rope_type::RopeType;
56pub use vocab_type::{LlamaTokenTypeFromIntError, VocabType};
57
58use chat_template_result::{new_empty_chat_template_raw_result, parse_chat_template_raw_result};
59use params::LlamaModelParams;
60
/// Handle around a raw `llama_model` pointer from `llama_cpp_bindings_sys`.
///
/// The struct is `#[repr(transparent)]` over the pointer. The pointed-to model
/// is released with `llama_free_model` when the value is dropped.
#[derive(Debug)]
#[repr(transparent)]
pub struct LlamaModel {
    /// Non-null pointer to the underlying `llama_model`.
    pub model: NonNull<llama_cpp_bindings_sys::llama_model>,
}
68
// SAFETY: NOTE(review): assumes the `llama_model` behind the pointer may be
// moved to, used from and freed on another thread — confirm against llama.cpp.
unsafe impl Send for LlamaModel {}
70
// SAFETY: NOTE(review): assumes the llama.cpp functions called through
// `&LlamaModel` tolerate concurrent use of the same model — confirm.
unsafe impl Sync for LlamaModel {}
72
73impl LlamaModel {
74 #[must_use]
76 pub fn vocab_ptr(&self) -> *const llama_cpp_bindings_sys::llama_vocab {
77 unsafe { llama_cpp_bindings_sys::llama_model_get_vocab(self.model.as_ptr()) }
78 }
79
80 pub fn n_ctx_train(&self) -> Result<u32, std::num::TryFromIntError> {
86 let n_ctx_train = unsafe { llama_cpp_bindings_sys::llama_n_ctx_train(self.model.as_ptr()) };
87
88 u32::try_from(n_ctx_train)
89 }
90
91 pub fn tokens(
93 &self,
94 decode_special: bool,
95 ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
96 (0..self.n_vocab())
97 .map(LlamaToken::new)
98 .map(move |llama_token| {
99 let mut decoder = encoding_rs::UTF_8.new_decoder();
100 (
101 llama_token,
102 self.token_to_piece(llama_token, &mut decoder, decode_special, None),
103 )
104 })
105 }
106
107 #[must_use]
109 pub fn token_bos(&self) -> LlamaToken {
110 let token = unsafe { llama_cpp_bindings_sys::llama_token_bos(self.vocab_ptr()) };
111 LlamaToken(token)
112 }
113
114 #[must_use]
116 pub fn token_eos(&self) -> LlamaToken {
117 let token = unsafe { llama_cpp_bindings_sys::llama_token_eos(self.vocab_ptr()) };
118 LlamaToken(token)
119 }
120
121 #[must_use]
123 pub fn token_nl(&self) -> LlamaToken {
124 let token = unsafe { llama_cpp_bindings_sys::llama_token_nl(self.vocab_ptr()) };
125 LlamaToken(token)
126 }
127
128 #[must_use]
130 pub fn is_eog_token(&self, token: LlamaToken) -> bool {
131 unsafe { llama_cpp_bindings_sys::llama_token_is_eog(self.vocab_ptr(), token.0) }
132 }
133
134 #[must_use]
136 pub fn decode_start_token(&self) -> LlamaToken {
137 let token =
138 unsafe { llama_cpp_bindings_sys::llama_model_decoder_start_token(self.model.as_ptr()) };
139 LlamaToken(token)
140 }
141
142 #[must_use]
144 pub fn token_sep(&self) -> LlamaToken {
145 let token = unsafe { llama_cpp_bindings_sys::llama_vocab_sep(self.vocab_ptr()) };
146 LlamaToken(token)
147 }
148
149 pub fn str_to_token(
169 &self,
170 str: &str,
171 add_bos: AddBos,
172 ) -> Result<Vec<LlamaToken>, StringToTokenError> {
173 let add_bos = match add_bos {
174 AddBos::Always => true,
175 AddBos::Never => false,
176 };
177
178 let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
179 let mut buffer: Vec<LlamaToken> = Vec::with_capacity(tokens_estimation);
180
181 let (c_string, c_string_len) = cstring_with_validated_len(str)?;
182 let buffer_capacity = c_int::try_from(buffer.capacity())?;
183
184 let size = unsafe {
185 llama_cpp_bindings_sys::llama_tokenize(
186 self.vocab_ptr(),
187 c_string.as_ptr(),
188 c_string_len,
189 buffer
190 .as_mut_ptr()
191 .cast::<llama_cpp_bindings_sys::llama_token>(),
192 buffer_capacity,
193 add_bos,
194 true,
195 )
196 };
197
198 let size = if size.is_negative() {
199 buffer.reserve_exact(usize::try_from(-size)?);
200 unsafe {
201 llama_cpp_bindings_sys::llama_tokenize(
202 self.vocab_ptr(),
203 c_string.as_ptr(),
204 c_string_len,
205 buffer
206 .as_mut_ptr()
207 .cast::<llama_cpp_bindings_sys::llama_token>(),
208 -size,
209 add_bos,
210 true,
211 )
212 }
213 } else {
214 size
215 };
216
217 let size = usize::try_from(size)?;
218
219 unsafe { buffer.set_len(size) }
221
222 Ok(buffer)
223 }
224
225 pub fn token_attr(
231 &self,
232 LlamaToken(id): LlamaToken,
233 ) -> Result<LlamaTokenAttrs, crate::token_type::LlamaTokenTypeFromIntError> {
234 let token_type =
235 unsafe { llama_cpp_bindings_sys::llama_token_get_attr(self.vocab_ptr(), id) };
236
237 LlamaTokenAttrs::try_from(token_type)
238 }
239
240 pub fn token_to_piece(
256 &self,
257 token: LlamaToken,
258 decoder: &mut encoding_rs::Decoder,
259 special: bool,
260 lstrip: Option<NonZeroU16>,
261 ) -> Result<String, TokenToStringError> {
262 let bytes = match self.token_to_piece_bytes(token, 8, special, lstrip) {
263 Err(TokenToStringError::InsufficientBufferSpace(required_size)) => {
264 let buffer_size: usize = (-required_size).try_into()?;
265
266 self.token_to_piece_bytes(token, buffer_size, special, lstrip)
267 }
268 other => other,
269 }?;
270
271 let mut output_piece = String::with_capacity(bytes.len());
272 let (_result, _decoded_size, _had_replacements) =
273 decoder.decode_to_string(&bytes, &mut output_piece, false);
274
275 Ok(output_piece)
276 }
277
278 #[allow(clippy::missing_panics_doc)]
290 pub fn token_to_piece_bytes(
291 &self,
292 token: LlamaToken,
293 buffer_size: usize,
294 special: bool,
295 lstrip: Option<NonZeroU16>,
296 ) -> Result<Vec<u8>, TokenToStringError> {
297 let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
299 let len = string.as_bytes().len();
300 let len = c_int::try_from(len)?;
301 let buf = string.into_raw();
302 let lstrip = lstrip.map_or(0, |strip_count| i32::from(strip_count.get()));
303 let size = unsafe {
304 llama_cpp_bindings_sys::llama_token_to_piece(
305 self.vocab_ptr(),
306 token.0,
307 buf,
308 len,
309 lstrip,
310 special,
311 )
312 };
313
314 match size {
315 0 => Err(TokenToStringError::UnknownTokenType),
316 error_code if error_code.is_negative() => {
317 Err(TokenToStringError::InsufficientBufferSpace(error_code))
318 }
319 size => {
320 let string = unsafe { CString::from_raw(buf) };
321 let mut bytes = string.into_bytes();
322 let len = usize::try_from(size)?;
323 bytes.truncate(len);
324
325 Ok(bytes)
326 }
327 }
328 }
329
330 #[must_use]
335 pub fn n_vocab(&self) -> i32 {
336 unsafe { llama_cpp_bindings_sys::llama_n_vocab(self.vocab_ptr()) }
337 }
338
339 pub fn vocab_type(&self) -> Result<VocabType, LlamaTokenTypeFromIntError> {
345 let vocab_type = unsafe { llama_cpp_bindings_sys::llama_vocab_type(self.vocab_ptr()) };
346
347 VocabType::try_from(vocab_type)
348 }
349
350 #[must_use]
353 pub fn n_embd(&self) -> c_int {
354 unsafe { llama_cpp_bindings_sys::llama_n_embd(self.model.as_ptr()) }
355 }
356
357 #[must_use]
359 pub fn size(&self) -> u64 {
360 unsafe { llama_cpp_bindings_sys::llama_model_size(self.model.as_ptr()) }
361 }
362
363 #[must_use]
365 pub fn n_params(&self) -> u64 {
366 unsafe { llama_cpp_bindings_sys::llama_model_n_params(self.model.as_ptr()) }
367 }
368
369 #[must_use]
371 pub fn is_recurrent(&self) -> bool {
372 unsafe { llama_cpp_bindings_sys::llama_model_is_recurrent(self.model.as_ptr()) }
373 }
374
375 pub fn n_layer(&self) -> Result<u32, std::num::TryFromIntError> {
381 u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_layer(self.model.as_ptr()) })
382 }
383
384 pub fn n_head(&self) -> Result<u32, std::num::TryFromIntError> {
390 u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_head(self.model.as_ptr()) })
391 }
392
393 pub fn n_head_kv(&self) -> Result<u32, std::num::TryFromIntError> {
399 u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_head_kv(self.model.as_ptr()) })
400 }
401
402 #[must_use]
406 pub fn is_hybrid(&self) -> bool {
407 unsafe { llama_cpp_bindings_sys::llama_model_is_hybrid(self.model.as_ptr()) }
408 }
409
410 pub fn meta_val_str(&self, key: &str) -> Result<String, MetaValError> {
415 let key_cstring = CString::new(key)?;
416 let key_ptr = key_cstring.as_ptr();
417
418 extract_meta_string(
419 |buf_ptr, buf_len| unsafe {
420 llama_cpp_bindings_sys::llama_model_meta_val_str(
421 self.model.as_ptr(),
422 key_ptr,
423 buf_ptr,
424 buf_len,
425 )
426 },
427 256,
428 )
429 }
430
431 #[must_use]
433 pub fn meta_count(&self) -> i32 {
434 unsafe { llama_cpp_bindings_sys::llama_model_meta_count(self.model.as_ptr()) }
435 }
436
437 pub fn meta_key_by_index(&self, index: i32) -> Result<String, MetaValError> {
442 extract_meta_string(
443 |buf_ptr, buf_len| unsafe {
444 llama_cpp_bindings_sys::llama_model_meta_key_by_index(
445 self.model.as_ptr(),
446 index,
447 buf_ptr,
448 buf_len,
449 )
450 },
451 256,
452 )
453 }
454
455 pub fn meta_val_str_by_index(&self, index: i32) -> Result<String, MetaValError> {
460 extract_meta_string(
461 |buf_ptr, buf_len| unsafe {
462 llama_cpp_bindings_sys::llama_model_meta_val_str_by_index(
463 self.model.as_ptr(),
464 index,
465 buf_ptr,
466 buf_len,
467 )
468 },
469 256,
470 )
471 }
472
473 #[must_use]
475 pub fn rope_type(&self) -> Option<RopeType> {
476 let raw = unsafe { llama_cpp_bindings_sys::llama_model_rope_type(self.model.as_ptr()) };
477
478 rope_type::rope_type_from_raw(raw)
479 }
480
481 pub fn chat_template(
499 &self,
500 name: Option<&str>,
501 ) -> Result<LlamaChatTemplate, ChatTemplateError> {
502 let name_cstr = name.map(CString::new);
503 let name_ptr = match name_cstr {
504 Some(Ok(name)) => name.as_ptr(),
505 _ => ptr::null(),
506 };
507 let result = unsafe {
508 llama_cpp_bindings_sys::llama_model_chat_template(self.model.as_ptr(), name_ptr)
509 };
510
511 if result.is_null() {
512 Err(ChatTemplateError::MissingTemplate)
513 } else {
514 let chat_template_cstr = unsafe { CStr::from_ptr(result) };
515 let chat_template = CString::new(chat_template_cstr.to_bytes())
516 .expect("CStr bytes cannot contain interior null bytes");
517
518 Ok(LlamaChatTemplate(chat_template))
519 }
520 }
521
522 #[tracing::instrument(skip_all, fields(params))]
532 pub fn load_from_file(
533 _: &LlamaBackend,
534 path: impl AsRef<Path>,
535 params: &LlamaModelParams,
536 ) -> Result<Self, LlamaModelLoadError> {
537 let path = path.as_ref();
538
539 let path_str = path
540 .to_str()
541 .ok_or_else(|| LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
542
543 if !path.exists() {
544 return Err(LlamaModelLoadError::FileNotFound(path.to_path_buf()));
545 }
546
547 let cstr = CString::new(path_str)?;
548 let llama_model = unsafe {
549 llama_cpp_bindings_sys::llama_load_model_from_file(cstr.as_ptr(), params.params)
550 };
551
552 let model = match NonNull::new(llama_model) {
553 Some(ptr) => ptr,
554 None if !path.exists() => {
555 return Err(LlamaModelLoadError::FileNotFound(path.to_path_buf()));
556 }
557 None => return Err(LlamaModelLoadError::NullResult),
558 };
559
560 Ok(Self { model })
561 }
562
563 pub fn lora_adapter_init(
569 &self,
570 path: impl AsRef<Path>,
571 ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
572 let path = path.as_ref();
573
574 let path_str = path
575 .to_str()
576 .ok_or_else(|| LlamaLoraAdapterInitError::PathToStrError(path.to_path_buf()))?;
577
578 if !path.exists() {
579 return Err(LlamaLoraAdapterInitError::FileNotFound(path.to_path_buf()));
580 }
581
582 let cstr = CString::new(path_str)?;
583 let raw_adapter = unsafe {
584 llama_cpp_bindings_sys::llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr())
585 };
586
587 let Some(adapter) = NonNull::new(raw_adapter) else {
588 return Err(LlamaLoraAdapterInitError::NullResult);
589 };
590
591 Ok(LlamaLoraAdapter {
592 lora_adapter: adapter,
593 })
594 }
595
596 #[allow(clippy::needless_pass_by_value)]
603 pub fn new_context<'model>(
604 &'model self,
605 _: &LlamaBackend,
606 params: LlamaContextParams,
607 ) -> Result<LlamaContext<'model>, LlamaContextLoadError> {
608 let context_params = params.context_params;
609 let context = unsafe {
610 llama_cpp_bindings_sys::llama_new_context_with_model(
611 self.model.as_ptr(),
612 context_params,
613 )
614 };
615 let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;
616
617 Ok(LlamaContext::new(self, context, params.embeddings()))
618 }
619
620 #[tracing::instrument(skip_all)]
638 pub fn apply_chat_template(
639 &self,
640 tmpl: &LlamaChatTemplate,
641 chat: &[LlamaChatMessage],
642 add_ass: bool,
643 ) -> Result<String, ApplyChatTemplateError> {
644 let message_length = chat.iter().fold(0, |acc, chat_message| {
645 acc + chat_message.role.to_bytes().len() + chat_message.content.to_bytes().len()
646 });
647 let mut buff: Vec<u8> = vec![0; message_length * 2];
648
649 let chat: Vec<llama_cpp_bindings_sys::llama_chat_message> = chat
650 .iter()
651 .map(|chat_message| llama_cpp_bindings_sys::llama_chat_message {
652 role: chat_message.role.as_ptr(),
653 content: chat_message.content.as_ptr(),
654 })
655 .collect();
656
657 let tmpl_ptr = tmpl.0.as_ptr();
658
659 let buff_len: i32 = buff.len().try_into()?;
660
661 let res = unsafe {
662 llama_cpp_bindings_sys::llama_chat_apply_template(
663 tmpl_ptr,
664 chat.as_ptr(),
665 chat.len(),
666 add_ass,
667 buff.as_mut_ptr().cast::<c_char>(),
668 buff_len,
669 )
670 };
671
672 if res > buff_len {
673 let required_size: usize = res.try_into()?;
674 buff.resize(required_size, 0);
675
676 let new_buff_len: i32 = buff.len().try_into()?;
677
678 let res = unsafe {
679 llama_cpp_bindings_sys::llama_chat_apply_template(
680 tmpl_ptr,
681 chat.as_ptr(),
682 chat.len(),
683 add_ass,
684 buff.as_mut_ptr().cast::<c_char>(),
685 new_buff_len,
686 )
687 };
688 let final_size: usize = res.try_into()?;
689
690 return truncated_buffer_to_string(buff, final_size);
691 }
692
693 let final_size: usize = res.try_into()?;
694
695 truncated_buffer_to_string(buff, final_size)
696 }
697
698 #[tracing::instrument(skip_all)]
705 pub fn apply_chat_template_with_tools_oaicompat(
706 &self,
707 tmpl: &LlamaChatTemplate,
708 messages: &[LlamaChatMessage],
709 tools_json: Option<&str>,
710 json_schema: Option<&str>,
711 add_generation_prompt: bool,
712 ) -> Result<ChatTemplateResult, ApplyChatTemplateError> {
713 let chat: Vec<llama_cpp_bindings_sys::llama_chat_message> = messages
714 .iter()
715 .map(|chat_message| llama_cpp_bindings_sys::llama_chat_message {
716 role: chat_message.role.as_ptr(),
717 content: chat_message.content.as_ptr(),
718 })
719 .collect();
720
721 let tools_cstr = tools_json.map(CString::new).transpose()?;
722 let json_schema_cstr = json_schema.map(CString::new).transpose()?;
723
724 let mut raw_result = new_empty_chat_template_raw_result();
725
726 let rc = unsafe {
727 llama_cpp_bindings_sys::llama_rs_apply_chat_template_with_tools_oaicompat(
728 self.model.as_ptr(),
729 tmpl.0.as_ptr(),
730 chat.as_ptr(),
731 chat.len(),
732 tools_cstr
733 .as_ref()
734 .map_or(ptr::null(), |cstr| cstr.as_ptr()),
735 json_schema_cstr
736 .as_ref()
737 .map_or(ptr::null(), |cstr| cstr.as_ptr()),
738 add_generation_prompt,
739 &raw mut raw_result,
740 )
741 };
742
743 let parse_tool_calls = tools_json.is_some_and(|tools| !tools.is_empty());
744
745 unsafe { parse_chat_template_raw_result(rc, &raw mut raw_result, parse_tool_calls) }
746 }
747
748 #[tracing::instrument(skip_all)]
753 pub fn apply_chat_template_oaicompat(
754 &self,
755 tmpl: &LlamaChatTemplate,
756 params: &OpenAIChatTemplateParams<'_>,
757 ) -> Result<ChatTemplateResult, ApplyChatTemplateError> {
758 let parse_tool_calls = params.parse_tool_calls;
759 let messages_cstr = CString::new(params.messages_json)?;
760 let tools_cstr = params.tools_json.map(CString::new).transpose()?;
761 let tool_choice_cstr = params.tool_choice.map(CString::new).transpose()?;
762 let json_schema_cstr = params.json_schema.map(CString::new).transpose()?;
763 let grammar_cstr = params.grammar.map(CString::new).transpose()?;
764 let reasoning_cstr = params.reasoning_format.map(CString::new).transpose()?;
765 let kwargs_cstr = params.chat_template_kwargs.map(CString::new).transpose()?;
766
767 let mut raw_result = new_empty_chat_template_raw_result();
768
769 let ffi_params = llama_cpp_bindings_sys::llama_rs_chat_template_oaicompat_params {
770 messages: messages_cstr.as_ptr(),
771 tools: tools_cstr
772 .as_ref()
773 .map_or(ptr::null(), |cstr| cstr.as_ptr()),
774 tool_choice: tool_choice_cstr
775 .as_ref()
776 .map_or(ptr::null(), |cstr| cstr.as_ptr()),
777 json_schema: json_schema_cstr
778 .as_ref()
779 .map_or(ptr::null(), |cstr| cstr.as_ptr()),
780 grammar: grammar_cstr
781 .as_ref()
782 .map_or(ptr::null(), |cstr| cstr.as_ptr()),
783 reasoning_format: reasoning_cstr
784 .as_ref()
785 .map_or(ptr::null(), |cstr| cstr.as_ptr()),
786 chat_template_kwargs: kwargs_cstr
787 .as_ref()
788 .map_or(ptr::null(), |cstr| cstr.as_ptr()),
789 add_generation_prompt: params.add_generation_prompt,
790 use_jinja: params.use_jinja,
791 parallel_tool_calls: params.parallel_tool_calls,
792 enable_thinking: params.enable_thinking,
793 add_bos: params.add_bos,
794 add_eos: params.add_eos,
795 };
796
797 let rc = unsafe {
798 llama_cpp_bindings_sys::llama_rs_apply_chat_template_oaicompat(
799 self.model.as_ptr(),
800 tmpl.0.as_ptr(),
801 &raw const ffi_params,
802 &raw mut raw_result,
803 )
804 };
805
806 unsafe { parse_chat_template_raw_result(rc, &raw mut raw_result, parse_tool_calls) }
807 }
808}
809
810fn extract_meta_string<TCFunction>(
811 c_function: TCFunction,
812 capacity: usize,
813) -> Result<String, MetaValError>
814where
815 TCFunction: Fn(*mut c_char, usize) -> i32,
816{
817 let mut buffer = vec![0u8; capacity];
818 let result = c_function(buffer.as_mut_ptr().cast::<c_char>(), buffer.len());
819
820 if result < 0 {
821 return Err(MetaValError::NegativeReturn(result));
822 }
823
824 let returned_len = result.cast_unsigned() as usize;
825
826 if returned_len >= capacity {
827 return extract_meta_string(c_function, returned_len + 1);
828 }
829
830 if buffer.get(returned_len) != Some(&0) {
831 return Err(MetaValError::NegativeReturn(-1));
832 }
833
834 buffer.truncate(returned_len);
835
836 Ok(String::from_utf8(buffer)?)
837}
838
839impl Drop for LlamaModel {
840 fn drop(&mut self) {
841 unsafe { llama_cpp_bindings_sys::llama_free_model(self.model.as_ptr()) }
842 }
843}
844
/// Unit tests for the free helper functions of this module that need no model:
/// `extract_meta_string`, `cstring_with_validated_len`,
/// `validate_string_length_for_tokenizer` and `truncated_buffer_to_string`.
#[cfg(test)]
mod extract_meta_string_tests {
    use super::extract_meta_string;
    use crate::MetaValError;

    /// The callback reports length 2 but leaves a non-NUL byte at index 2.
    #[test]
    fn returns_error_when_null_terminator_missing() {
        let result = extract_meta_string(
            |buf_ptr, buf_len| {
                let buffer =
                    unsafe { std::slice::from_raw_parts_mut(buf_ptr.cast::<u8>(), buf_len) };
                buffer[0] = b'a';
                buffer[1] = b'b';
                buffer[2] = b'c';
                // Claims a 2-byte string although index 2 holds `c`, not NUL.
                2
            },
            4,
        );

        assert_eq!(result.unwrap_err(), MetaValError::NegativeReturn(-1));
    }

    /// A negative return value from the callback is passed through in the error.
    #[test]
    fn returns_error_for_negative_return_value() {
        let result = extract_meta_string(|_buf_ptr, _buf_len| -5, 4);

        assert_eq!(result.unwrap_err(), MetaValError::NegativeReturn(-5));
    }

    /// Properly terminated bytes that are not valid UTF-8 produce a UTF-8 error.
    #[test]
    fn returns_error_for_invalid_utf8_data() {
        let result = extract_meta_string(
            |buf_ptr, buf_len| {
                let buffer =
                    unsafe { std::slice::from_raw_parts_mut(buf_ptr.cast::<u8>(), buf_len) };
                buffer[0] = 0xFF;
                buffer[1] = 0xFE;
                buffer[2] = 0;
                2
            },
            4,
        );

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("FromUtf8Error"));
    }

    /// A reported length that exceeds the capacity triggers a second call with
    /// a larger buffer.
    #[test]
    fn triggers_buffer_resize_when_returned_len_exceeds_capacity() {
        let call_count = std::cell::Cell::new(0);
        let result = extract_meta_string(
            |buf_ptr, buf_len| {
                let count = call_count.get();
                call_count.set(count + 1);
                if count == 0 {
                    // First call: report a length larger than the 4-byte buffer.
                    10
                } else {
                    // Later call: write a short, NUL-terminated string.
                    let buffer =
                        unsafe { std::slice::from_raw_parts_mut(buf_ptr.cast::<u8>(), buf_len) };
                    buffer[0] = b'h';
                    buffer[1] = b'i';
                    buffer[2] = 0;
                    2
                }
            },
            4,
        );

        assert_eq!(result.unwrap(), "hi");
    }

    /// Input with an interior NUL byte cannot be turned into a `CString`.
    #[test]
    fn cstring_with_validated_len_null_byte_returns_error() {
        let result = super::cstring_with_validated_len("null\0byte");

        assert!(result.is_err());
    }

    /// `usize::MAX` does not fit in a C `int`.
    #[test]
    fn validate_string_length_overflow_returns_error() {
        let result = super::validate_string_length_for_tokenizer(usize::MAX);

        assert!(result.is_err());
    }

    /// Bytes that are not valid UTF-8 are rejected even when nothing is truncated.
    #[test]
    fn truncated_buffer_to_string_with_invalid_utf8_returns_error() {
        let invalid_utf8 = vec![0xff, 0xfe, 0xfd];
        let result = super::truncated_buffer_to_string(invalid_utf8, 3);

        assert!(result.is_err());
    }
}
941
942#[cfg(test)]
943#[cfg(feature = "tests_that_use_llms")]
944mod tests {
945 use serial_test::serial;
946
947 use super::LlamaModel;
948 use crate::llama_backend::LlamaBackend;
949 use crate::model::AddBos;
950 use crate::model::params::LlamaModelParams;
951 use crate::test_model;
952
953 #[test]
954 #[serial]
955 fn model_loads_with_valid_metadata() {
956 let (_backend, model) = test_model::load_default_model().unwrap();
957 assert!(model.n_vocab() > 0);
958 assert!(model.n_embd() > 0);
959 assert!(model.n_params() > 0);
960 assert!(model.n_ctx_train().unwrap() > 0);
961 }
962
963 #[test]
964 #[serial]
965 fn special_tokens_exist() {
966 let (_backend, model) = test_model::load_default_model().unwrap();
967 let bos = model.token_bos();
968 let eos = model.token_eos();
969 assert_ne!(bos, eos);
970 assert!(model.is_eog_token(eos));
971 assert!(!model.is_eog_token(bos));
972 }
973
974 #[test]
975 #[serial]
976 fn str_to_token_roundtrip() {
977 let (_backend, model) = test_model::load_default_model().unwrap();
978 let tokens = model.str_to_token("hello world", AddBos::Never).unwrap();
979 assert!(!tokens.is_empty());
980 let mut decoder = encoding_rs::UTF_8.new_decoder();
981 let piece = model
982 .token_to_piece(tokens[0], &mut decoder, false, None)
983 .unwrap();
984 assert!(!piece.is_empty());
985 }
986
987 #[test]
988 #[serial]
989 fn chat_template_returns_non_empty() {
990 let (_backend, model) = test_model::load_default_model().unwrap();
991 let template = model.chat_template(None);
992 assert!(template.is_ok());
993 }
994
995 #[test]
996 #[serial]
997 fn apply_chat_template_produces_prompt() {
998 let (_backend, model) = test_model::load_default_model().unwrap();
999 let template = model.chat_template(None).unwrap();
1000 let message =
1001 crate::model::LlamaChatMessage::new("user".to_string(), "hello".to_string()).unwrap();
1002 let prompt = model.apply_chat_template(&template, &[message], true);
1003 assert!(prompt.is_ok());
1004 assert!(!prompt.unwrap().is_empty());
1005 }
1006
1007 #[test]
1008 #[serial]
1009 fn apply_chat_template_oaicompat_produces_result() {
1010 let (_backend, model) = test_model::load_default_model().unwrap();
1011 let template = model.chat_template(None).unwrap();
1012 let params = crate::openai::OpenAIChatTemplateParams {
1013 messages_json: r#"[{"role":"user","content":"hello"}]"#,
1014 tools_json: None,
1015 tool_choice: None,
1016 json_schema: None,
1017 grammar: None,
1018 reasoning_format: Some("none"),
1019 chat_template_kwargs: None,
1020 add_generation_prompt: true,
1021 use_jinja: true,
1022 parallel_tool_calls: false,
1023 enable_thinking: false,
1024 add_bos: false,
1025 add_eos: false,
1026 parse_tool_calls: false,
1027 };
1028 let result = model.apply_chat_template_oaicompat(&template, ¶ms);
1029 assert!(result.is_ok());
1030 assert!(!result.unwrap().prompt.is_empty());
1031 }
1032
1033 #[test]
1034 #[serial]
1035 fn meta_count_returns_positive() {
1036 let (_backend, model) = test_model::load_default_model().unwrap();
1037 assert!(model.meta_count() > 0);
1038 }
1039
1040 #[test]
1041 #[serial]
1042 fn tokens_iterator_produces_valid_entries() {
1043 let (_backend, model) = test_model::load_default_model().unwrap();
1044 let mut count = 0;
1045
1046 for (token, piece_result) in model.tokens(false) {
1047 assert!(token.0 >= 0);
1048 let _ = piece_result;
1050 count += 1;
1051
1052 if count >= 100 {
1053 break;
1054 }
1055 }
1056
1057 assert_eq!(count, 100);
1058 }
1059
1060 #[test]
1061 #[serial]
1062 fn token_to_piece_bytes_returns_bytes_for_known_token() {
1063 let (_backend, model) = test_model::load_default_model().unwrap();
1064 let tokens = model.str_to_token("hello", AddBos::Never).unwrap();
1065 let bytes = model
1066 .token_to_piece_bytes(tokens[0], 32, false, None)
1067 .unwrap();
1068
1069 assert!(!bytes.is_empty());
1070 }
1071
1072 #[test]
1073 #[serial]
1074 fn n_layer_returns_positive() {
1075 let (_backend, model) = test_model::load_default_model().unwrap();
1076
1077 assert!(model.n_layer().unwrap() > 0);
1078 }
1079
1080 #[test]
1081 #[serial]
1082 fn n_head_returns_positive() {
1083 let (_backend, model) = test_model::load_default_model().unwrap();
1084
1085 assert!(model.n_head().unwrap() > 0);
1086 }
1087
1088 #[test]
1089 #[serial]
1090 fn n_head_kv_returns_positive() {
1091 let (_backend, model) = test_model::load_default_model().unwrap();
1092
1093 assert!(model.n_head_kv().unwrap() > 0);
1094 }
1095
1096 #[test]
1097 #[serial]
1098 fn is_hybrid_returns_bool_for_test_model() {
1099 let (_backend, model) = test_model::load_default_model().unwrap();
1100
1101 let _ = model.is_hybrid();
1102 }
1103
1104 #[test]
1105 #[serial]
1106 fn meta_key_by_index_returns_valid_key() {
1107 let (_backend, model) = test_model::load_default_model().unwrap();
1108 let key = model.meta_key_by_index(0).unwrap();
1109
1110 assert!(!key.is_empty());
1111 }
1112
1113 #[test]
1114 #[serial]
1115 fn meta_val_str_by_index_returns_valid_value() {
1116 let (_backend, model) = test_model::load_default_model().unwrap();
1117 let value = model.meta_val_str_by_index(0).unwrap();
1118
1119 assert!(!value.is_empty());
1120 }
1121
1122 #[test]
1123 #[serial]
1124 fn meta_key_by_index_out_of_range_returns_error() {
1125 let (_backend, model) = test_model::load_default_model().unwrap();
1126 let result = model.meta_key_by_index(999_999);
1127
1128 assert!(result.is_err());
1129 }
1130
1131 #[test]
1132 #[serial]
1133 fn meta_val_str_by_index_out_of_range_returns_error() {
1134 let (_backend, model) = test_model::load_default_model().unwrap();
1135 let result = model.meta_val_str_by_index(999_999);
1136
1137 assert!(result.is_err());
1138 }
1139
1140 #[test]
1141 #[serial]
1142 fn meta_val_str_returns_value_for_known_key() {
1143 let (_backend, model) = test_model::load_default_model().unwrap();
1144 let first_key = model.meta_key_by_index(0).unwrap();
1145 let value = model.meta_val_str(&first_key).unwrap();
1146
1147 assert!(!value.is_empty());
1148 }
1149
1150 #[test]
1151 #[serial]
1152 fn model_size_returns_nonzero() {
1153 let (_backend, model) = test_model::load_default_model().unwrap();
1154
1155 assert!(model.size() > 0);
1156 }
1157
1158 #[test]
1159 #[serial]
1160 fn is_recurrent_returns_false_for_transformer() {
1161 let (_backend, model) = test_model::load_default_model().unwrap();
1162
1163 assert!(!model.is_recurrent());
1164 }
1165
1166 #[test]
1167 #[serial]
1168 fn rope_type_does_not_panic() {
1169 let (_backend, model) = test_model::load_default_model().unwrap();
1170 let _rope_type = model.rope_type();
1171 }
1172
1173 #[test]
1174 #[serial]
1175 fn load_model_with_invalid_path_returns_error() {
1176 let backend = LlamaBackend::init().unwrap();
1177 let model_params = LlamaModelParams::default();
1178 let result = LlamaModel::load_from_file(&backend, "/nonexistent/model.gguf", &model_params);
1179
1180 assert_eq!(
1181 result.unwrap_err(),
1182 crate::LlamaModelLoadError::FileNotFound(std::path::PathBuf::from(
1183 "/nonexistent/model.gguf"
1184 ))
1185 );
1186 }
1187
1188 #[test]
1189 #[serial]
1190 fn load_model_with_invalid_file_content_returns_null_result() {
1191 let backend = LlamaBackend::init().unwrap();
1192 let model_params = LlamaModelParams::default();
1193 let dummy_path = std::env::temp_dir().join("llama_test_invalid_model.gguf");
1194 std::fs::write(&dummy_path, b"not a valid gguf model file").unwrap();
1195
1196 let result = LlamaModel::load_from_file(&backend, &dummy_path, &model_params);
1197
1198 assert_eq!(result.unwrap_err(), crate::LlamaModelLoadError::NullResult);
1199 let _ = std::fs::remove_file(&dummy_path);
1200 }
1201
1202 #[cfg(unix)]
1203 #[test]
1204 #[serial]
1205 fn load_model_with_non_utf8_path_returns_path_to_str_error() {
1206 use std::ffi::OsStr;
1207 use std::os::unix::ffi::OsStrExt;
1208
1209 let backend = LlamaBackend::init().unwrap();
1210 let model_params = LlamaModelParams::default();
1211 let non_utf8_path = std::path::Path::new(OsStr::from_bytes(b"/tmp/\xff\xfe.gguf"));
1212
1213 let result = LlamaModel::load_from_file(&backend, non_utf8_path, &model_params);
1214
1215 assert_eq!(
1216 result.unwrap_err(),
1217 crate::LlamaModelLoadError::PathToStrError(non_utf8_path.to_path_buf())
1218 );
1219 }
1220
1221 #[cfg(unix)]
1222 #[test]
1223 #[serial]
1224 fn lora_adapter_init_with_non_utf8_path_returns_error() {
1225 use std::ffi::OsStr;
1226 use std::os::unix::ffi::OsStrExt;
1227
1228 let (_backend, model) = test_model::load_default_model().unwrap();
1229 let non_utf8_path = std::path::Path::new(OsStr::from_bytes(b"/tmp/\xff\xfe.gguf"));
1230
1231 let result = model.lora_adapter_init(non_utf8_path);
1232
1233 assert_eq!(
1234 result.unwrap_err(),
1235 crate::LlamaLoraAdapterInitError::PathToStrError(non_utf8_path.to_path_buf())
1236 );
1237 }
1238
1239 #[test]
1240 #[serial]
1241 fn lora_adapter_init_with_invalid_path_returns_error() {
1242 let (_backend, model) = test_model::load_default_model().unwrap();
1243 let result = model.lora_adapter_init("/nonexistent/path/lora.gguf");
1244
1245 assert_eq!(
1246 result.unwrap_err(),
1247 crate::LlamaLoraAdapterInitError::FileNotFound(std::path::PathBuf::from(
1248 "/nonexistent/path/lora.gguf"
1249 ))
1250 );
1251 }
1252
1253 #[test]
1254 #[serial]
1255 fn new_context_returns_valid_context() {
1256 let (backend, model) = test_model::load_default_model().unwrap();
1257 let ctx_params = crate::context::params::LlamaContextParams::default()
1258 .with_n_ctx(std::num::NonZeroU32::new(256));
1259 let context = model.new_context(&backend, ctx_params).unwrap();
1260
1261 assert!(context.n_ctx() > 0);
1262 }
1263
1264 #[test]
1265 #[serial]
1266 fn token_nl_returns_valid_token() {
1267 let (_backend, model) = test_model::load_default_model().unwrap();
1268 let nl_token = model.token_nl();
1269
1270 assert!(nl_token.0 >= 0);
1271 }
1272
1273 #[test]
1274 #[serial]
1275 fn decode_start_token_returns_valid_token() {
1276 let (_backend, model) = test_model::load_default_model().unwrap();
1277 let _decode_start = model.decode_start_token();
1278 }
1279
1280 #[test]
1281 #[serial]
1282 fn token_sep_returns_valid_token() {
1283 let (_backend, model) = test_model::load_default_model().unwrap();
1284 let _sep_token = model.token_sep();
1285 }
1286
/// Converts each of the first 200 tokens yielded by `model.tokens(true)` to a piece,
/// sharing one UTF-8 decoder, and requires every conversion to succeed.
#[test]
#[serial]
fn token_to_piece_handles_large_token_requiring_buffer_resize() {
    let (_backend, model) = test_model::load_default_model().unwrap();
    let mut decoder = encoding_rs::UTF_8.new_decoder();

    model.tokens(true).take(200).for_each(|(token, _)| {
        let result = model.token_to_piece(token, &mut decoder, true, None);
        assert!(result.is_ok());
    });
}
1298
/// Calls `token_to_piece_bytes` for the first token of "hello" with `1` as the second
/// argument (presumably the buffer size; confirm against the method signature) and
/// expects an error whose message contains "Insufficient Buffer Space".
#[test]
#[serial]
fn token_to_piece_bytes_insufficient_buffer_returns_error() {
    let (_backend, model) = test_model::load_default_model().unwrap();
    let tokens = model.str_to_token("hello", AddBos::Never).unwrap();

    let error = model
        .token_to_piece_bytes(tokens[0], 1, false, None)
        .unwrap_err();
    let message = error.to_string();

    assert!(message.contains("Insufficient Buffer Space"));
}
1313
/// Converts the first token of "hello" to a piece while passing `NonZeroU16::new(1)` as
/// the final argument (the lstrip value, going by the test name) and expects success.
#[test]
#[serial]
fn token_to_piece_with_lstrip() {
    let (_backend, model) = test_model::load_default_model().unwrap();
    let mut decoder = encoding_rs::UTF_8.new_decoder();
    let tokens = model.str_to_token("hello", AddBos::Never).unwrap();
    let lstrip = std::num::NonZeroU16::new(1);

    let result = model.token_to_piece(tokens[0], &mut decoder, false, lstrip);

    assert!(result.is_ok());
}
1325
/// The number of items produced by `model.tokens(false)` must equal `model.n_vocab()`.
#[test]
#[serial]
fn n_vocab_matches_tokens_iterator_count() {
    let (_backend, model) = test_model::load_default_model().unwrap();

    let expected = model.n_vocab() as usize;
    let actual = model.tokens(false).count();

    assert_eq!(actual, expected);
}
1335
/// Looking up the attributes of the BOS token must succeed; the attribute value itself
/// is not inspected.
#[test]
#[serial]
fn token_attr_returns_valid_attr() {
    let (_backend, model) = test_model::load_default_model().unwrap();

    let bos_token = model.token_bos();

    let _attrs = model.token_attr(bos_token).unwrap();
}
1343
/// `vocab_type` must return `Ok` for the default test model; the concrete type is not
/// inspected.
#[test]
#[serial]
fn vocab_type_returns_valid_type() {
    let (_backend, model) = test_model::load_default_model().unwrap();

    let _kind = model.vocab_type().unwrap();
}
1350
/// Applies the template returned by `chat_template(None)` to a single user message of
/// 2000 `a` characters and expects a non-empty prompt.
#[test]
#[serial]
fn apply_chat_template_buffer_resize_with_long_messages() {
    let (_backend, model) = test_model::load_default_model().unwrap();
    let template = model.chat_template(None).unwrap();
    let user_message =
        crate::model::LlamaChatMessage::new(String::from("user"), "a".repeat(2000)).unwrap();

    let prompt = model.apply_chat_template(&template, &[user_message], true);

    assert!(prompt.is_ok());
    assert!(!prompt.unwrap().is_empty());
}
1364
/// Walks every metadata index in `0..meta_count()` and requires both the key lookup and
/// the string-value lookup to succeed for each one.
#[test]
#[serial]
fn meta_val_str_with_long_value_triggers_buffer_resize() {
    let (_backend, model) = test_model::load_default_model().unwrap();

    (0..model.meta_count()).for_each(|index| {
        let key = model.meta_key_by_index(index);
        let value = model.meta_val_str_by_index(index);

        assert!(key.is_ok());
        assert!(value.is_ok());
    });
}
1378
/// Tokenizing "hello" with `AddBos::Always` must yield at least as many tokens as
/// tokenizing it with `AddBos::Never`.
#[test]
#[serial]
fn str_to_token_with_add_bos_never() {
    let (_backend, model) = test_model::load_default_model().unwrap();
    let text = "hello";

    let tokens_with_bos = model.str_to_token(text, AddBos::Always).unwrap();
    let tokens_without_bos = model.str_to_token(text, AddBos::Never).unwrap();

    assert!(tokens_with_bos.len() >= tokens_without_bos.len());
}
1388
/// With neither tools nor a JSON schema supplied, `apply_chat_template_with_tools_oaicompat`
/// must succeed and produce a non-empty prompt for a single user message.
#[test]
#[serial]
fn apply_chat_template_with_tools_oaicompat_produces_result() {
    let (_backend, model) = test_model::load_default_model().unwrap();
    let template = model.chat_template(None).unwrap();
    let user_message =
        crate::model::LlamaChatMessage::new(String::from("user"), String::from("hello")).unwrap();
    let messages = [user_message];

    let result =
        model.apply_chat_template_with_tools_oaicompat(&template, &messages, None, None, true);

    assert!(result.is_ok());
    assert!(!result.unwrap().prompt.is_empty());
}
1402
/// Supplying a tools JSON array (one function named `test`) and no JSON schema must make
/// `apply_chat_template_with_tools_oaicompat` succeed.
#[test]
#[serial]
fn apply_chat_template_with_tools_oaicompat_with_tools_json() {
    let (_backend, model) = test_model::load_default_model().unwrap();
    let template = model.chat_template(None).unwrap();
    let user_message =
        crate::model::LlamaChatMessage::new(String::from("user"), String::from("hello")).unwrap();
    let messages = [user_message];
    let tools_json =
        r#"[{"type":"function","function":{"name":"test","parameters":{"type":"object"}}}]"#;

    let result = model.apply_chat_template_with_tools_oaicompat(
        &template,
        &messages,
        Some(tools_json),
        None,
        true,
    );

    assert!(result.is_ok());
}
1422
/// Supplying a JSON schema and no tools must make
/// `apply_chat_template_with_tools_oaicompat` succeed.
#[test]
#[serial]
fn apply_chat_template_with_tools_oaicompat_with_json_schema() {
    let (_backend, model) = test_model::load_default_model().unwrap();
    let template = model.chat_template(None).unwrap();
    let user_message =
        crate::model::LlamaChatMessage::new(String::from("user"), String::from("hello")).unwrap();
    let messages = [user_message];
    let json_schema = r#"{"type":"object","properties":{"name":{"type":"string"}}}"#;

    let result = model.apply_chat_template_with_tools_oaicompat(
        &template,
        &messages,
        None,
        Some(json_schema),
        true,
    );

    assert!(result.is_ok());
}
1441
/// `apply_chat_template_oaicompat` must succeed when given a tools JSON array together
/// with `tool_choice` set to "auto" and `parse_tool_calls` enabled.
#[test]
#[serial]
fn apply_chat_template_oaicompat_with_tools_and_tool_choice() {
    let (_backend, model) = test_model::load_default_model().unwrap();
    let template = model.chat_template(None).unwrap();
    let params = crate::openai::OpenAIChatTemplateParams {
        messages_json: r#"[{"role":"user","content":"hello"}]"#,
        tools_json: Some(
            r#"[{"type":"function","function":{"name":"test","parameters":{"type":"object","properties":{}}}}]"#,
        ),
        tool_choice: Some("auto"),
        json_schema: None,
        grammar: None,
        reasoning_format: Some("none"),
        chat_template_kwargs: None,
        add_generation_prompt: true,
        use_jinja: true,
        parallel_tool_calls: false,
        enable_thinking: false,
        add_bos: false,
        add_eos: false,
        parse_tool_calls: true,
    };

    // The parameters are passed by shared reference.
    let result = model.apply_chat_template_oaicompat(&template, &params);

    assert!(result.is_ok());
}
1469
/// `apply_chat_template_oaicompat` must succeed when only the `json_schema` field is
/// populated alongside the messages.
#[test]
#[serial]
fn apply_chat_template_oaicompat_with_json_schema_field() {
    let (_backend, model) = test_model::load_default_model().unwrap();
    let template = model.chat_template(None).unwrap();
    let params = crate::openai::OpenAIChatTemplateParams {
        messages_json: r#"[{"role":"user","content":"hello"}]"#,
        tools_json: None,
        tool_choice: None,
        json_schema: Some(r#"{"type":"object","properties":{"name":{"type":"string"}}}"#),
        grammar: None,
        reasoning_format: Some("none"),
        chat_template_kwargs: None,
        add_generation_prompt: true,
        use_jinja: true,
        parallel_tool_calls: false,
        enable_thinking: false,
        add_bos: false,
        add_eos: false,
        parse_tool_calls: false,
    };

    // The parameters are passed by shared reference.
    let result = model.apply_chat_template_oaicompat(&template, &params);

    assert!(result.is_ok());
}
1495
/// `apply_chat_template_oaicompat` must succeed when only the `grammar` field is
/// populated alongside the messages.
#[test]
#[serial]
fn apply_chat_template_oaicompat_with_grammar_field() {
    let (_backend, model) = test_model::load_default_model().unwrap();
    let template = model.chat_template(None).unwrap();
    let params = crate::openai::OpenAIChatTemplateParams {
        messages_json: r#"[{"role":"user","content":"hello"}]"#,
        tools_json: None,
        tool_choice: None,
        json_schema: None,
        grammar: Some("root ::= \"hello\""),
        reasoning_format: Some("none"),
        chat_template_kwargs: None,
        add_generation_prompt: true,
        use_jinja: true,
        parallel_tool_calls: false,
        enable_thinking: false,
        add_bos: false,
        add_eos: false,
        parse_tool_calls: false,
    };

    // The parameters are passed by shared reference.
    let result = model.apply_chat_template_oaicompat(&template, &params);

    assert!(result.is_ok());
}
1521
/// `apply_chat_template_oaicompat` must succeed when `chat_template_kwargs` carries a
/// JSON object (here overriding `bos_token`).
#[test]
#[serial]
fn apply_chat_template_oaicompat_with_kwargs_field() {
    let (_backend, model) = test_model::load_default_model().unwrap();
    let template = model.chat_template(None).unwrap();
    let params = crate::openai::OpenAIChatTemplateParams {
        messages_json: r#"[{"role":"user","content":"hello"}]"#,
        tools_json: None,
        tool_choice: None,
        json_schema: None,
        grammar: None,
        reasoning_format: Some("none"),
        chat_template_kwargs: Some(r#"{"bos_token": "<|im_start|>"}"#),
        add_generation_prompt: true,
        use_jinja: true,
        parallel_tool_calls: false,
        enable_thinking: false,
        add_bos: false,
        add_eos: false,
        parse_tool_calls: false,
    };

    // The parameters are passed by shared reference.
    let result = model.apply_chat_template_oaicompat(&template, &params);

    assert!(result.is_ok());
}
1547
/// Requesting a chat template by a name the model does not contain must yield
/// `ChatTemplateError::MissingTemplate`.
#[test]
#[serial]
fn chat_template_with_nonexistent_name_returns_error() {
    let (_backend, model) = test_model::load_default_model().unwrap();
    let unknown_name = "nonexistent_template_name_xyz";

    let error = model.chat_template(Some(unknown_name)).unwrap_err();

    assert_eq!(error, crate::ChatTemplateError::MissingTemplate);
}
1560
/// Writes a file that is not a valid GGUF into the system temp directory, tries to load
/// it as a LoRA adapter, and expects `LlamaLoraAdapterInitError::NullResult`.
#[test]
#[serial]
fn lora_adapter_init_with_invalid_gguf_returns_null_result() {
    let (_backend, model) = test_model::load_default_model().unwrap();
    let dummy_path = std::env::temp_dir().join("llama_test_dummy_lora.gguf");
    std::fs::write(&dummy_path, b"not a valid gguf").unwrap();

    let result = model.lora_adapter_init(&dummy_path);

    // Clean up before asserting so a failed assertion cannot leave the fixture behind.
    // Removal is best-effort: a failure to delete must not mask the real test outcome.
    let _ = std::fs::remove_file(&dummy_path);

    assert_eq!(
        result.unwrap_err(),
        crate::LlamaLoraAdapterInitError::NullResult
    );
}
1576
/// Tokenizes the numbers 0 through 1999, each followed by a space, and checks that the
/// token count exceeds half the byte length of the input string.
#[test]
#[serial]
fn str_to_token_with_many_tokens_triggers_buffer_resize() {
    let (_backend, model) = test_model::load_default_model().unwrap();
    let many_numbers = (0..2000)
        .map(|number| format!("{number} "))
        .collect::<String>();

    let tokens = model
        .str_to_token(many_numbers.as_str(), AddBos::Always)
        .unwrap();

    assert!(tokens.len() > many_numbers.len() / 2);
}
1590
/// Smoke test: calling `rope_type` on the default test model must not panic.
/// The returned value is deliberately not inspected.
#[test]
#[serial]
fn rope_type_returns_valid_result_for_test_model() {
    let (_backend, model) = test_model::load_default_model().unwrap();

    let _rope = model.rope_type();
}
1598
/// A metadata key containing an interior NUL byte must make `meta_val_str` return an
/// error.
#[test]
#[serial]
fn meta_val_str_with_null_byte_in_key_returns_error() {
    let (_backend, model) = test_model::load_default_model().unwrap();
    let key_with_nul = "key\0with_null";

    let result = model.meta_val_str(key_with_nul);

    assert!(result.is_err());
}
1607
/// A tools string containing an interior NUL byte must make
/// `apply_chat_template_with_tools_oaicompat` return an error.
#[test]
#[serial]
fn apply_chat_template_with_tools_null_byte_in_tools_returns_error() {
    let (_backend, model) = test_model::load_default_model().unwrap();
    let template = model.chat_template(None).unwrap();
    let user_message =
        crate::model::LlamaChatMessage::new(String::from("user"), String::from("hello")).unwrap();
    let messages = [user_message];
    let tools_with_nul = "tools\0null";

    let result = model.apply_chat_template_with_tools_oaicompat(
        &template,
        &messages,
        Some(tools_with_nul),
        None,
        true,
    );

    assert!(result.is_err());
}
1625
/// A JSON schema string containing an interior NUL byte must make
/// `apply_chat_template_with_tools_oaicompat` return an error.
#[test]
#[serial]
fn apply_chat_template_with_tools_null_byte_in_json_schema_returns_error() {
    let (_backend, model) = test_model::load_default_model().unwrap();
    let template = model.chat_template(None).unwrap();
    let user_message =
        crate::model::LlamaChatMessage::new(String::from("user"), String::from("hello")).unwrap();
    let messages = [user_message];
    let schema_with_nul = "schema\0null";

    let result = model.apply_chat_template_with_tools_oaicompat(
        &template,
        &messages,
        None,
        Some(schema_with_nul),
        true,
    );

    assert!(result.is_err());
}
1643
/// An interior NUL byte in `messages_json` must make `apply_chat_template_oaicompat`
/// return an error.
#[test]
#[serial]
fn apply_chat_template_oaicompat_with_null_byte_in_messages_returns_error() {
    let (_backend, model) = test_model::load_default_model().unwrap();
    let template = model.chat_template(None).unwrap();
    let params = crate::openai::OpenAIChatTemplateParams {
        messages_json: "messages\0null",
        tools_json: None,
        tool_choice: None,
        json_schema: None,
        grammar: None,
        reasoning_format: None,
        chat_template_kwargs: None,
        add_generation_prompt: true,
        use_jinja: true,
        parallel_tool_calls: false,
        enable_thinking: false,
        add_bos: false,
        add_eos: false,
        parse_tool_calls: false,
    };

    // The parameters are passed by shared reference.
    let result = model.apply_chat_template_oaicompat(&template, &params);

    assert!(result.is_err());
}
1669
/// An interior NUL byte in `tools_json` must make `apply_chat_template_oaicompat`
/// return an error.
#[test]
#[serial]
fn apply_chat_template_oaicompat_with_null_byte_in_tools_returns_error() {
    let (_backend, model) = test_model::load_default_model().unwrap();
    let template = model.chat_template(None).unwrap();
    let params = crate::openai::OpenAIChatTemplateParams {
        messages_json: r#"[{"role":"user","content":"hello"}]"#,
        tools_json: Some("tools\0null"),
        tool_choice: None,
        json_schema: None,
        grammar: None,
        reasoning_format: None,
        chat_template_kwargs: None,
        add_generation_prompt: true,
        use_jinja: true,
        parallel_tool_calls: false,
        enable_thinking: false,
        add_bos: false,
        add_eos: false,
        parse_tool_calls: false,
    };

    // The parameters are passed by shared reference.
    let result = model.apply_chat_template_oaicompat(&template, &params);

    assert!(result.is_err());
}
1695
/// An interior NUL byte in `tool_choice` must make `apply_chat_template_oaicompat`
/// return an error.
#[test]
#[serial]
fn apply_chat_template_oaicompat_with_null_byte_in_tool_choice_returns_error() {
    let (_backend, model) = test_model::load_default_model().unwrap();
    let template = model.chat_template(None).unwrap();
    let params = crate::openai::OpenAIChatTemplateParams {
        messages_json: r#"[{"role":"user","content":"hello"}]"#,
        tools_json: None,
        tool_choice: Some("choice\0null"),
        json_schema: None,
        grammar: None,
        reasoning_format: None,
        chat_template_kwargs: None,
        add_generation_prompt: true,
        use_jinja: true,
        parallel_tool_calls: false,
        enable_thinking: false,
        add_bos: false,
        add_eos: false,
        parse_tool_calls: false,
    };

    // The parameters are passed by shared reference.
    let result = model.apply_chat_template_oaicompat(&template, &params);

    assert!(result.is_err());
}
1721
/// An interior NUL byte in `json_schema` must make `apply_chat_template_oaicompat`
/// return an error.
#[test]
#[serial]
fn apply_chat_template_oaicompat_with_null_byte_in_json_schema_returns_error() {
    let (_backend, model) = test_model::load_default_model().unwrap();
    let template = model.chat_template(None).unwrap();
    let params = crate::openai::OpenAIChatTemplateParams {
        messages_json: r#"[{"role":"user","content":"hello"}]"#,
        tools_json: None,
        tool_choice: None,
        json_schema: Some("schema\0null"),
        grammar: None,
        reasoning_format: None,
        chat_template_kwargs: None,
        add_generation_prompt: true,
        use_jinja: true,
        parallel_tool_calls: false,
        enable_thinking: false,
        add_bos: false,
        add_eos: false,
        parse_tool_calls: false,
    };

    // The parameters are passed by shared reference.
    let result = model.apply_chat_template_oaicompat(&template, &params);

    assert!(result.is_err());
}
1747
/// An interior NUL byte in `grammar` must make `apply_chat_template_oaicompat`
/// return an error.
#[test]
#[serial]
fn apply_chat_template_oaicompat_with_null_byte_in_grammar_returns_error() {
    let (_backend, model) = test_model::load_default_model().unwrap();
    let template = model.chat_template(None).unwrap();
    let params = crate::openai::OpenAIChatTemplateParams {
        messages_json: r#"[{"role":"user","content":"hello"}]"#,
        tools_json: None,
        tool_choice: None,
        json_schema: None,
        grammar: Some("grammar\0null"),
        reasoning_format: None,
        chat_template_kwargs: None,
        add_generation_prompt: true,
        use_jinja: true,
        parallel_tool_calls: false,
        enable_thinking: false,
        add_bos: false,
        add_eos: false,
        parse_tool_calls: false,
    };

    // The parameters are passed by shared reference.
    let result = model.apply_chat_template_oaicompat(&template, &params);

    assert!(result.is_err());
}
1773
/// An interior NUL byte in `reasoning_format` must make `apply_chat_template_oaicompat`
/// return an error.
#[test]
#[serial]
fn apply_chat_template_oaicompat_with_null_byte_in_reasoning_format_returns_error() {
    let (_backend, model) = test_model::load_default_model().unwrap();
    let template = model.chat_template(None).unwrap();
    let params = crate::openai::OpenAIChatTemplateParams {
        messages_json: r#"[{"role":"user","content":"hello"}]"#,
        tools_json: None,
        tool_choice: None,
        json_schema: None,
        grammar: None,
        reasoning_format: Some("format\0null"),
        chat_template_kwargs: None,
        add_generation_prompt: true,
        use_jinja: true,
        parallel_tool_calls: false,
        enable_thinking: false,
        add_bos: false,
        add_eos: false,
        parse_tool_calls: false,
    };

    // The parameters are passed by shared reference.
    let result = model.apply_chat_template_oaicompat(&template, &params);

    assert!(result.is_err());
}
1799
/// An interior NUL byte in `chat_template_kwargs` must make
/// `apply_chat_template_oaicompat` return an error.
#[test]
#[serial]
fn apply_chat_template_oaicompat_with_null_byte_in_kwargs_returns_error() {
    let (_backend, model) = test_model::load_default_model().unwrap();
    let template = model.chat_template(None).unwrap();
    let params = crate::openai::OpenAIChatTemplateParams {
        messages_json: r#"[{"role":"user","content":"hello"}]"#,
        tools_json: None,
        tool_choice: None,
        json_schema: None,
        grammar: None,
        reasoning_format: None,
        chat_template_kwargs: Some("kwargs\0null"),
        add_generation_prompt: true,
        use_jinja: true,
        parallel_tool_calls: false,
        enable_thinking: false,
        add_bos: false,
        add_eos: false,
        parse_tool_calls: false,
    };

    // The parameters are passed by shared reference.
    let result = model.apply_chat_template_oaicompat(&template, &params);

    assert!(result.is_err());
}
1825
/// Requesting a context with `n_ctx` set to `u32::MAX` must make `new_context` return
/// an error instead of a usable context.
#[test]
#[serial]
fn new_context_with_huge_ctx_returns_null_error() {
    // The backend is passed to `new_context` below, so it gets a regular name rather
    // than an underscore-prefixed "unused" one.
    let (backend, model) = test_model::load_default_model().unwrap();
    let ctx_params = crate::context::params::LlamaContextParams::default()
        .with_n_ctx(std::num::NonZeroU32::new(u32::MAX));

    let result = model.new_context(&backend, ctx_params);

    assert!(result.is_err());
}
1837
/// Decodes the prompt "Hello" in a 256-token context and checks that sampling at the
/// index of the last token in the batch returns `Ok`.
#[test]
#[serial]
fn sample_returns_result_and_succeeds_with_valid_index() {
    use crate::sampling::LlamaSampler;

    let (backend, model) = test_model::load_default_model().unwrap();
    let ctx_params = crate::context::params::LlamaContextParams::default()
        .with_n_ctx(std::num::NonZeroU32::new(256));
    let mut context = model.new_context(&backend, ctx_params).unwrap();

    let tokens = model.str_to_token("Hello", AddBos::Always).unwrap();
    let mut batch = crate::llama_batch::LlamaBatch::new(512, 1).unwrap();

    batch.add_sequence(&tokens, 0, false).unwrap();

    context.decode(&mut batch).unwrap();

    let mut sampler =
        LlamaSampler::chain_simple([LlamaSampler::temp(0.8), LlamaSampler::greedy()]);

    // Index of the final token that was just decoded.
    let result = sampler.sample(&context, batch.n_tokens() - 1);

    assert!(result.is_ok());
}
1865
/// Samples one token under a grammar that only admits "yes"/"no" (any letter case) and
/// checks that the token is not end-of-generation and that its text starts with `y` or
/// `n`, compared case-insensitively.
#[test]
#[serial]
fn grammar_sampler_constrains_output_to_yes_or_no() {
    use crate::sampling::LlamaSampler;

    let (backend, model) = test_model::load_default_model().unwrap();

    let ctx_params = crate::context::params::LlamaContextParams::default()
        .with_n_ctx(std::num::NonZeroU32::new(512));
    let mut context = model.new_context(&backend, ctx_params).unwrap();

    let prompt = "<|im_start|>user\nIs the sky blue? Answer yes or no.<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
    let tokens = model.str_to_token(prompt, AddBos::Always).unwrap();
    let mut batch = crate::llama_batch::LlamaBatch::new(512, 1).unwrap();

    batch.add_sequence(&tokens, 0, false).unwrap();

    context.decode(&mut batch).unwrap();

    let mut sampler = LlamaSampler::chain_simple([
        LlamaSampler::grammar(&model, r#"root ::= [Yy] [Ee] [Ss] | [Nn] [Oo]"#, "root")
            .unwrap(),
        LlamaSampler::temp(0.8),
        LlamaSampler::greedy(),
    ]);

    // Sample at the index of the last prompt token.
    let token = sampler.sample(&context, batch.n_tokens() - 1).unwrap();

    assert!(
        !model.is_eog_token(token),
        "Grammar sampler should not allow EOS as first token"
    );

    let mut decoder = encoding_rs::UTF_8.new_decoder();
    let piece = model
        .token_to_piece(token, &mut decoder, true, None)
        .unwrap();
    // Lowercase the first character so the check is case-insensitive.
    let first_char = piece.chars().next().unwrap().to_lowercase().next().unwrap();

    assert!(
        first_char == 'y' || first_char == 'n',
        "Grammar should constrain first token to start with y/n, got: '{piece}'"
    );
}
1911
/// Converts a JSON schema to a grammar via `json_schema_to_grammar`, samples one token
/// under it, and checks that the token is not end-of-generation and that its text
/// starts with `{`.
#[test]
#[serial]
fn json_schema_grammar_sampler_constrains_output_to_json() {
    use crate::sampling::LlamaSampler;

    let (backend, model) = test_model::load_default_model().unwrap();

    let ctx_params = crate::context::params::LlamaContextParams::default()
        .with_n_ctx(std::num::NonZeroU32::new(512));
    let mut context = model.new_context(&backend, ctx_params).unwrap();

    let prompt = "<|im_start|>user\nWhat is 2+2? Respond with a JSON object.<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
    let tokens = model.str_to_token(prompt, AddBos::Always).unwrap();
    let mut batch = crate::llama_batch::LlamaBatch::new(512, 1).unwrap();

    batch.add_sequence(&tokens, 0, false).unwrap();

    context.decode(&mut batch).unwrap();

    let grammar_str = crate::json_schema_to_grammar(
        r#"{"type": "object", "properties": {"answer": {"type": "string"}}, "required": ["answer"]}"#
    ).unwrap();

    let mut sampler = LlamaSampler::chain_simple([
        LlamaSampler::grammar(&model, &grammar_str, "root").unwrap(),
        LlamaSampler::temp(0.8),
        LlamaSampler::greedy(),
    ]);

    // Sample at the index of the last prompt token.
    let token = sampler.sample(&context, batch.n_tokens() - 1).unwrap();

    assert!(
        !model.is_eog_token(token),
        "Grammar sampler should not allow EOS as first token"
    );

    let mut decoder = encoding_rs::UTF_8.new_decoder();
    let piece = model
        .token_to_piece(token, &mut decoder, true, None)
        .unwrap();

    assert!(
        piece.starts_with('{'),
        "JSON schema grammar should constrain first token to start with '{{', got: '{piece}'"
    );
}
1959
/// Generates up to ten tokens under a grammar that only admits "yes" or "no", feeding
/// each sampled token back into the context, and checks that the concatenated output,
/// lowercased, is exactly "yes" or "no".
#[test]
#[serial]
fn sample_with_grammar_produces_constrained_output_in_loop() {
    use crate::sampling::LlamaSampler;

    let (backend, model) = test_model::load_default_model().unwrap();

    let ctx_params = crate::context::params::LlamaContextParams::default()
        .with_n_ctx(std::num::NonZeroU32::new(512));
    let mut context = model.new_context(&backend, ctx_params).unwrap();

    let prompt = "<|im_start|>user\nIs the sky blue? yes or no<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
    let tokens = model.str_to_token(prompt, AddBos::Always).unwrap();
    let mut batch = crate::llama_batch::LlamaBatch::new(512, 1).unwrap();

    batch.add_sequence(&tokens, 0, false).unwrap();

    context.decode(&mut batch).unwrap();

    let mut sampler = LlamaSampler::chain_simple([
        LlamaSampler::grammar(&model, r#"root ::= "yes" | "no""#, "root").unwrap(),
        LlamaSampler::temp(0.8),
        LlamaSampler::greedy(),
    ]);

    let mut generated = String::new();
    let mut decoder = encoding_rs::UTF_8.new_decoder();
    // Position of the next token to append: one past the prompt tokens in the batch.
    let mut position = batch.n_tokens();

    for iteration in 0..10 {
        let token = sampler.sample(&context, -1).unwrap();
        let is_eog = model.is_eog_token(token);

        eprintln!("  iteration={iteration} token={} eog={is_eog}", token.0);

        if is_eog {
            break;
        }

        let piece = model
            .token_to_piece(token, &mut decoder, true, None)
            .unwrap();

        eprintln!("  piece='{piece}'");

        generated.push_str(&piece);

        // Feed the sampled token back as a single-token batch for the next decode.
        batch.clear();
        batch.add(token, position, &[0], true).unwrap();
        position += 1;

        context.decode(&mut batch).unwrap();
    }

    let lowercase = generated.to_lowercase();

    assert!(
        lowercase == "yes" || lowercase == "no",
        "Grammar loop should produce 'yes' or 'no', got: '{generated}'"
    );
}
2022
/// Generates up to five tokens without any grammar constraint, feeding each sampled
/// token back into the context, and checks that at least one non-end-of-generation
/// token was produced.
#[test]
#[serial]
fn sample_without_grammar_produces_multiple_tokens() {
    use crate::sampling::LlamaSampler;

    let (backend, model) = test_model::load_default_model().unwrap();

    let ctx_params = crate::context::params::LlamaContextParams::default()
        .with_n_ctx(std::num::NonZeroU32::new(512));
    let mut context = model.new_context(&backend, ctx_params).unwrap();

    let prompt =
        "<|im_start|>user\nSay hello<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
    let tokens = model.str_to_token(prompt, AddBos::Always).unwrap();
    let mut batch = crate::llama_batch::LlamaBatch::new(512, 1).unwrap();

    batch.add_sequence(&tokens, 0, false).unwrap();

    context.decode(&mut batch).unwrap();

    let mut sampler =
        LlamaSampler::chain_simple([LlamaSampler::temp(0.8), LlamaSampler::greedy()]);

    let mut token_count = 0;
    // Position of the next token to append: one past the prompt tokens in the batch.
    let mut position = batch.n_tokens();

    for _ in 0..5 {
        let token = sampler.sample(&context, -1).unwrap();

        if model.is_eog_token(token) {
            break;
        }

        token_count += 1;

        // Feed the sampled token back as a single-token batch for the next decode.
        batch.clear();
        batch.add(token, position, &[0], true).unwrap();
        position += 1;

        context.decode(&mut batch).unwrap();
    }

    assert!(
        token_count > 0,
        "Should produce at least one token without grammar"
    );
}
2071}