use std::ffi::{CStr, CString, c_char};
use std::num::NonZeroU16;
use std::os::raw::c_int;
use std::path::Path;
use std::ptr::{self, NonNull};

use crate::context::LlamaContext;
use crate::context::params::LlamaContextParams;
use crate::llama_backend::LlamaBackend;
use crate::openai::OpenAIChatTemplateParams;
use crate::token::LlamaToken;
use crate::token_type::LlamaTokenAttrs;
use crate::{
    ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
    LlamaModelLoadError, MetaValError, StringToTokenError, TokenToStringError,
};

fn truncated_buffer_to_string(
    mut buffer: Vec<u8>,
    length: usize,
) -> Result<String, ApplyChatTemplateError> {
    buffer.truncate(length);

    Ok(String::from_utf8(buffer)?)
}

fn validate_string_length_for_tokenizer(length: usize) -> Result<c_int, StringToTokenError> {
    Ok(c_int::try_from(length)?)
}

fn cstring_with_validated_len(str: &str) -> Result<(CString, c_int), StringToTokenError> {
    let c_string = CString::new(str)?;
    let len = validate_string_length_for_tokenizer(c_string.as_bytes().len())?;
    Ok((c_string, len))
}

pub mod add_bos;
pub mod chat_template_result;
pub mod grammar_trigger;
pub mod llama_chat_message;
pub mod llama_chat_template;
pub mod llama_lora_adapter;
pub mod params;
pub mod rope_type;
pub mod split_mode;
pub mod vocab_type;

pub use add_bos::AddBos;
pub use chat_template_result::ChatTemplateResult;
pub use grammar_trigger::{GrammarTrigger, GrammarTriggerType};
pub use llama_chat_message::LlamaChatMessage;
pub use llama_chat_template::LlamaChatTemplate;
pub use llama_lora_adapter::LlamaLoraAdapter;
pub use rope_type::RopeType;
pub use vocab_type::{LlamaTokenTypeFromIntError, VocabType};

use chat_template_result::{new_empty_chat_template_raw_result, parse_chat_template_raw_result};
use params::LlamaModelParams;

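/// A loaded llama.cpp model. Owns the underlying `llama_model` pointer and
/// frees it in `Drop`.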
#[derive(Debug)]
#[repr(transparent)]
pub struct LlamaModel {
    pub model: NonNull<llama_cpp_bindings_sys::llama_model>,
}

unsafe impl Send for LlamaModel {}

unsafe impl Sync for LlamaModel {}

impl LlamaModel {
    #[must_use]
    pub fn vocab_ptr(&self) -> *const llama_cpp_bindings_sys::llama_vocab {
        unsafe { llama_cpp_bindings_sys::llama_model_get_vocab(self.model.as_ptr()) }
    }

    #[must_use]
    pub fn n_ctx_train(&self) -> u32 {
        let n_ctx_train = unsafe { llama_cpp_bindings_sys::llama_n_ctx_train(self.model.as_ptr()) };
        u32::try_from(n_ctx_train).expect("n_ctx_train fits into a u32")
    }

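    /// Iterates over every token id in the vocabulary, paired with the result
    /// of decoding that token to a string piece. `decode_special` controls
    /// whether special tokens are rendered.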
    pub fn tokens(
        &self,
        decode_special: bool,
    ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
        (0..self.n_vocab())
            .map(LlamaToken::new)
            .map(move |llama_token| {
                let mut decoder = encoding_rs::UTF_8.new_decoder();
                (
                    llama_token,
                    self.token_to_piece(llama_token, &mut decoder, decode_special, None),
                )
            })
    }

    #[must_use]
    pub fn token_bos(&self) -> LlamaToken {
        let token = unsafe { llama_cpp_bindings_sys::llama_token_bos(self.vocab_ptr()) };
        LlamaToken(token)
    }

    #[must_use]
    pub fn token_eos(&self) -> LlamaToken {
        let token = unsafe { llama_cpp_bindings_sys::llama_token_eos(self.vocab_ptr()) };
        LlamaToken(token)
    }

    #[must_use]
    pub fn token_nl(&self) -> LlamaToken {
        let token = unsafe { llama_cpp_bindings_sys::llama_token_nl(self.vocab_ptr()) };
        LlamaToken(token)
    }

    #[must_use]
    pub fn is_eog_token(&self, token: LlamaToken) -> bool {
        unsafe { llama_cpp_bindings_sys::llama_token_is_eog(self.vocab_ptr(), token.0) }
    }

    #[must_use]
    pub fn decode_start_token(&self) -> LlamaToken {
        let token =
            unsafe { llama_cpp_bindings_sys::llama_model_decoder_start_token(self.model.as_ptr()) };
        LlamaToken(token)
    }

    #[must_use]
    pub fn token_sep(&self) -> LlamaToken {
        let token = unsafe { llama_cpp_bindings_sys::llama_vocab_sep(self.vocab_ptr()) };
        LlamaToken(token)
    }

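    /// Tokenizes a string. The tokenizer is first called with a capacity
    /// estimated from the input length; when llama.cpp signals (via a negative
    /// return value) that more room is needed, the buffer grows to the exact
    /// required size and the call is retried once.
    ///
    /// A minimal usage sketch; marked `ignore` since it needs a loaded model
    /// (see [`LlamaModel::load_from_file`]):
    ///
    /// ```ignore
    /// let tokens = model.str_to_token("hello world", AddBos::Always)?;
    /// assert!(!tokens.is_empty());
    /// ```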
    pub fn str_to_token(
        &self,
        str: &str,
        add_bos: AddBos,
    ) -> Result<Vec<LlamaToken>, StringToTokenError> {
        let add_bos = match add_bos {
            AddBos::Always => true,
            AddBos::Never => false,
        };

        let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
        let mut buffer: Vec<LlamaToken> = Vec::with_capacity(tokens_estimation);

        let (c_string, c_string_len) = cstring_with_validated_len(str)?;
        let buffer_capacity = c_int::try_from(buffer.capacity())?;

        let size = unsafe {
            llama_cpp_bindings_sys::llama_tokenize(
                self.vocab_ptr(),
                c_string.as_ptr(),
                c_string_len,
                buffer
                    .as_mut_ptr()
                    .cast::<llama_cpp_bindings_sys::llama_token>(),
                buffer_capacity,
                add_bos,
                true,
            )
        };

        let size = if size.is_negative() {
            buffer.reserve_exact(usize::try_from(-size).expect("negated size fits into usize"));
            unsafe {
                llama_cpp_bindings_sys::llama_tokenize(
                    self.vocab_ptr(),
                    c_string.as_ptr(),
                    c_string_len,
                    buffer
                        .as_mut_ptr()
                        .cast::<llama_cpp_bindings_sys::llama_token>(),
                    -size,
                    add_bos,
                    true,
                )
            }
        } else {
            size
        };

        let size = usize::try_from(size)?;

        unsafe { buffer.set_len(size) }

        Ok(buffer)
    }

    #[must_use]
    pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs {
        let token_type =
            unsafe { llama_cpp_bindings_sys::llama_token_get_attr(self.vocab_ptr(), id) };
        LlamaTokenAttrs::try_from(token_type).expect("token type is valid")
    }

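    /// Decodes a single token to a UTF-8 string piece. An initial 8-byte
    /// buffer is retried with the exact size llama.cpp reports when it is too
    /// small. The caller-supplied `decoder` carries state so that UTF-8
    /// sequences split across tokens still decode correctly.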
    pub fn token_to_piece(
        &self,
        token: LlamaToken,
        decoder: &mut encoding_rs::Decoder,
        special: bool,
        lstrip: Option<NonZeroU16>,
    ) -> Result<String, TokenToStringError> {
        let bytes = match self.token_to_piece_bytes(token, 8, special, lstrip) {
            Err(TokenToStringError::InsufficientBufferSpace(required_size)) => self
                .token_to_piece_bytes(
                    token,
                    (-required_size)
                        .try_into()
                        .expect("Error buffer size is positive"),
                    special,
                    lstrip,
                ),
            other => other,
        }?;

        let mut output_piece = String::with_capacity(bytes.len());
        let (_result, _decoded_size, _had_replacements) =
            decoder.decode_to_string(&bytes, &mut output_piece, false);

        Ok(output_piece)
    }

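    /// Decodes a single token to raw bytes using a caller-chosen buffer size.
    /// A zero return from llama.cpp maps to
    /// [`TokenToStringError::UnknownTokenType`]; a negative return is surfaced
    /// as [`TokenToStringError::InsufficientBufferSpace`] carrying the negated
    /// required size, so callers can retry with a large enough buffer.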
    #[allow(clippy::missing_panics_doc)]
    pub fn token_to_piece_bytes(
        &self,
        token: LlamaToken,
        buffer_size: usize,
        special: bool,
        lstrip: Option<NonZeroU16>,
    ) -> Result<Vec<u8>, TokenToStringError> {
        let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
        let len = string.as_bytes().len();
        let len = c_int::try_from(len)?;
        let buf = string.into_raw();
        let lstrip = lstrip.map_or(0, |strip_count| i32::from(strip_count.get()));
        let size = unsafe {
            llama_cpp_bindings_sys::llama_token_to_piece(
                self.vocab_ptr(),
                token.0,
                buf,
                len,
                lstrip,
                special,
            )
        };

        match size {
            0 => Err(TokenToStringError::UnknownTokenType),
            error_code if error_code.is_negative() => {
                Err(TokenToStringError::InsufficientBufferSpace(error_code))
            }
            size => {
                let string = unsafe { CString::from_raw(buf) };
                let mut bytes = string.into_bytes();
                let len = usize::try_from(size).expect("size is positive and fits into usize");
                bytes.truncate(len);

                Ok(bytes)
            }
        }
    }

    #[must_use]
    pub fn n_vocab(&self) -> i32 {
        unsafe { llama_cpp_bindings_sys::llama_n_vocab(self.vocab_ptr()) }
    }

    #[must_use]
    pub fn vocab_type(&self) -> VocabType {
        let vocab_type = unsafe { llama_cpp_bindings_sys::llama_vocab_type(self.vocab_ptr()) };
        VocabType::try_from(vocab_type).expect("vocab type is valid")
    }

    #[must_use]
    pub fn n_embd(&self) -> c_int {
        unsafe { llama_cpp_bindings_sys::llama_n_embd(self.model.as_ptr()) }
    }

    #[must_use]
    pub fn size(&self) -> u64 {
        unsafe { llama_cpp_bindings_sys::llama_model_size(self.model.as_ptr()) }
    }

    #[must_use]
    pub fn n_params(&self) -> u64 {
        unsafe { llama_cpp_bindings_sys::llama_model_n_params(self.model.as_ptr()) }
    }

    #[must_use]
    pub fn is_recurrent(&self) -> bool {
        unsafe { llama_cpp_bindings_sys::llama_model_is_recurrent(self.model.as_ptr()) }
    }

    #[must_use]
    pub fn n_layer(&self) -> u32 {
        u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_layer(self.model.as_ptr()) })
            .expect("llama.cpp returns a positive value for n_layer")
    }

    #[must_use]
    pub fn n_head(&self) -> u32 {
        u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_head(self.model.as_ptr()) })
            .expect("llama.cpp returns a positive value for n_head")
    }

    #[must_use]
    pub fn n_head_kv(&self) -> u32 {
        u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_head_kv(self.model.as_ptr()) })
            .expect("llama.cpp returns a positive value for n_head_kv")
    }

    #[must_use]
    pub fn is_hybrid(&self) -> bool {
        unsafe { llama_cpp_bindings_sys::llama_model_is_hybrid(self.model.as_ptr()) }
    }

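    /// Looks up a metadata value by key (for example `"general.architecture"`),
    /// growing the read buffer as needed via `extract_meta_string`.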
    pub fn meta_val_str(&self, key: &str) -> Result<String, MetaValError> {
        let key_cstring = CString::new(key)?;
        let key_ptr = key_cstring.as_ptr();

        extract_meta_string(
            |buf_ptr, buf_len| unsafe {
                llama_cpp_bindings_sys::llama_model_meta_val_str(
                    self.model.as_ptr(),
                    key_ptr,
                    buf_ptr,
                    buf_len,
                )
            },
            256,
        )
    }

    #[must_use]
    pub fn meta_count(&self) -> i32 {
        unsafe { llama_cpp_bindings_sys::llama_model_meta_count(self.model.as_ptr()) }
    }

    pub fn meta_key_by_index(&self, index: i32) -> Result<String, MetaValError> {
        extract_meta_string(
            |buf_ptr, buf_len| unsafe {
                llama_cpp_bindings_sys::llama_model_meta_key_by_index(
                    self.model.as_ptr(),
                    index,
                    buf_ptr,
                    buf_len,
                )
            },
            256,
        )
    }

    pub fn meta_val_str_by_index(&self, index: i32) -> Result<String, MetaValError> {
        extract_meta_string(
            |buf_ptr, buf_len| unsafe {
                llama_cpp_bindings_sys::llama_model_meta_val_str_by_index(
                    self.model.as_ptr(),
                    index,
                    buf_ptr,
                    buf_len,
                )
            },
            256,
        )
    }

    #[must_use]
    pub fn rope_type(&self) -> Option<RopeType> {
        let raw = unsafe { llama_cpp_bindings_sys::llama_model_rope_type(self.model.as_ptr()) };

        rope_type::rope_type_from_raw(raw)
    }

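    /// Returns the chat template stored in the model's metadata, or the named
    /// variant when `name` is `Some`. Note that a `name` containing an interior
    /// null byte silently falls back to the default template lookup.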
    pub fn chat_template(
        &self,
        name: Option<&str>,
    ) -> Result<LlamaChatTemplate, ChatTemplateError> {
        let name_cstr = name.map(CString::new);
        // Borrow rather than move so the CString stays alive across the FFI
        // call below; moving it out of the match would leave `name_ptr`
        // dangling.
        let name_ptr = match &name_cstr {
            Some(Ok(name)) => name.as_ptr(),
            _ => ptr::null(),
        };
        let result = unsafe {
            llama_cpp_bindings_sys::llama_model_chat_template(self.model.as_ptr(), name_ptr)
        };

        if result.is_null() {
            Err(ChatTemplateError::MissingTemplate)
        } else {
            let chat_template_cstr = unsafe { CStr::from_ptr(result) };
            let chat_template = CString::new(chat_template_cstr.to_bytes())
                .expect("CStr bytes cannot contain interior null bytes");

            Ok(LlamaChatTemplate(chat_template))
        }
    }

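    /// Loads a GGUF model from disk. The [`LlamaBackend`] argument is unused
    /// except to guarantee the backend was initialized before loading.
    ///
    /// A minimal usage sketch; marked `ignore` because it needs a real model
    /// file, and the path shown is illustrative:
    ///
    /// ```ignore
    /// let backend = LlamaBackend::init()?;
    /// let params = LlamaModelParams::default();
    /// let model = LlamaModel::load_from_file(&backend, "models/model.gguf", &params)?;
    /// ```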
    #[tracing::instrument(skip_all, fields(params))]
    pub fn load_from_file(
        _: &LlamaBackend,
        path: impl AsRef<Path>,
        params: &LlamaModelParams,
    ) -> Result<Self, LlamaModelLoadError> {
        let path = path.as_ref();

        let path_str = path
            .to_str()
            .ok_or_else(|| LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;

        if !path.exists() {
            return Err(LlamaModelLoadError::FileNotFound(path.to_path_buf()));
        }

        let cstr = CString::new(path_str)?;
        let llama_model = unsafe {
            llama_cpp_bindings_sys::llama_load_model_from_file(cstr.as_ptr(), params.params)
        };

        let model = match NonNull::new(llama_model) {
            Some(ptr) => ptr,
            None if !path.exists() => {
                return Err(LlamaModelLoadError::FileNotFound(path.to_path_buf()));
            }
            None => return Err(LlamaModelLoadError::NullResult),
        };

        Ok(Self { model })
    }

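    /// Initializes a LoRA adapter from a GGUF file against this model. Fails
    /// with `FileNotFound` before touching FFI when the path does not exist,
    /// and with `NullResult` when llama.cpp rejects the file.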
    pub fn lora_adapter_init(
        &self,
        path: impl AsRef<Path>,
    ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
        let path = path.as_ref();

        let path_str = path
            .to_str()
            .ok_or_else(|| LlamaLoraAdapterInitError::PathToStrError(path.to_path_buf()))?;

        if !path.exists() {
            return Err(LlamaLoraAdapterInitError::FileNotFound(path.to_path_buf()));
        }

        let cstr = CString::new(path_str)?;
        let raw_adapter = unsafe {
            llama_cpp_bindings_sys::llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr())
        };

        let Some(adapter) = NonNull::new(raw_adapter) else {
            return Err(LlamaLoraAdapterInitError::NullResult);
        };

        Ok(LlamaLoraAdapter {
            lora_adapter: adapter,
        })
    }

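    /// Creates an inference context that borrows this model.
    ///
    /// A minimal sketch; marked `ignore`, and it assumes `backend` and `model`
    /// were obtained as in [`LlamaModel::load_from_file`]:
    ///
    /// ```ignore
    /// let ctx_params = LlamaContextParams::default();
    /// let ctx = model.new_context(&backend, ctx_params)?;
    /// ```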
    #[allow(clippy::needless_pass_by_value)]
    pub fn new_context<'model>(
        &'model self,
        _: &LlamaBackend,
        params: LlamaContextParams,
    ) -> Result<LlamaContext<'model>, LlamaContextLoadError> {
        let context_params = params.context_params;
        let context = unsafe {
            llama_cpp_bindings_sys::llama_new_context_with_model(
                self.model.as_ptr(),
                context_params,
            )
        };
        let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;

        Ok(LlamaContext::new(self, context, params.embeddings()))
    }

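    /// Renders a chat transcript into a prompt string with the given template.
    /// The output buffer starts at twice the summed message length; if
    /// `llama_chat_apply_template` reports a larger required size, the buffer
    /// is resized once and the call retried.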
    #[tracing::instrument(skip_all)]
    pub fn apply_chat_template(
        &self,
        tmpl: &LlamaChatTemplate,
        chat: &[LlamaChatMessage],
        add_ass: bool,
    ) -> Result<String, ApplyChatTemplateError> {
        let message_length = chat.iter().fold(0, |acc, chat_message| {
            acc + chat_message.role.to_bytes().len() + chat_message.content.to_bytes().len()
        });
        let mut buff: Vec<u8> = vec![0; message_length * 2];

        let chat: Vec<llama_cpp_bindings_sys::llama_chat_message> = chat
            .iter()
            .map(|chat_message| llama_cpp_bindings_sys::llama_chat_message {
                role: chat_message.role.as_ptr(),
                content: chat_message.content.as_ptr(),
            })
            .collect();

        let tmpl_ptr = tmpl.0.as_ptr();

        let buff_len: i32 = buff.len().try_into()?;

        let res = unsafe {
            llama_cpp_bindings_sys::llama_chat_apply_template(
                tmpl_ptr,
                chat.as_ptr(),
                chat.len(),
                add_ass,
                buff.as_mut_ptr().cast::<c_char>(),
                buff_len,
            )
        };

        if res > buff_len {
            let required_size: usize = res.try_into()?;
            buff.resize(required_size, 0);

            let new_buff_len: i32 = buff.len().try_into()?;

            let res = unsafe {
                llama_cpp_bindings_sys::llama_chat_apply_template(
                    tmpl_ptr,
                    chat.as_ptr(),
                    chat.len(),
                    add_ass,
                    buff.as_mut_ptr().cast::<c_char>(),
                    new_buff_len,
                )
            };
            let final_size: usize = res.try_into()?;

            return truncated_buffer_to_string(buff, final_size);
        }

        let final_size: usize = res.try_into()?;

        truncated_buffer_to_string(buff, final_size)
    }

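    /// OpenAI-compatible template application with optional tool and JSON
    /// schema definitions. Tool calls are parsed from the output only when
    /// `tools_json` is a non-empty string.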
    #[tracing::instrument(skip_all)]
    pub fn apply_chat_template_with_tools_oaicompat(
        &self,
        tmpl: &LlamaChatTemplate,
        messages: &[LlamaChatMessage],
        tools_json: Option<&str>,
        json_schema: Option<&str>,
        add_generation_prompt: bool,
    ) -> Result<ChatTemplateResult, ApplyChatTemplateError> {
        let chat: Vec<llama_cpp_bindings_sys::llama_chat_message> = messages
            .iter()
            .map(|chat_message| llama_cpp_bindings_sys::llama_chat_message {
                role: chat_message.role.as_ptr(),
                content: chat_message.content.as_ptr(),
            })
            .collect();

        let tools_cstr = tools_json.map(CString::new).transpose()?;
        let json_schema_cstr = json_schema.map(CString::new).transpose()?;

        let mut raw_result = new_empty_chat_template_raw_result();

        let rc = unsafe {
            llama_cpp_bindings_sys::llama_rs_apply_chat_template_with_tools_oaicompat(
                self.model.as_ptr(),
                tmpl.0.as_ptr(),
                chat.as_ptr(),
                chat.len(),
                tools_cstr
                    .as_ref()
                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
                json_schema_cstr
                    .as_ref()
                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
                add_generation_prompt,
                &raw mut raw_result,
            )
        };

        let parse_tool_calls = tools_json.is_some_and(|tools| !tools.is_empty());

        unsafe { parse_chat_template_raw_result(rc, &raw mut raw_result, parse_tool_calls) }
    }

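    /// Full OpenAI-compatible template application driven by
    /// [`OpenAIChatTemplateParams`]. Each optional string field is converted to
    /// a `CString` up front (so interior null bytes fail early) and passed to
    /// the FFI layer as a null pointer when absent.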
    #[tracing::instrument(skip_all)]
    pub fn apply_chat_template_oaicompat(
        &self,
        tmpl: &LlamaChatTemplate,
        params: &OpenAIChatTemplateParams<'_>,
    ) -> Result<ChatTemplateResult, ApplyChatTemplateError> {
        let parse_tool_calls = params.parse_tool_calls;
        let messages_cstr = CString::new(params.messages_json)?;
        let tools_cstr = params.tools_json.map(CString::new).transpose()?;
        let tool_choice_cstr = params.tool_choice.map(CString::new).transpose()?;
        let json_schema_cstr = params.json_schema.map(CString::new).transpose()?;
        let grammar_cstr = params.grammar.map(CString::new).transpose()?;
        let reasoning_cstr = params.reasoning_format.map(CString::new).transpose()?;
        let kwargs_cstr = params.chat_template_kwargs.map(CString::new).transpose()?;

        let mut raw_result = new_empty_chat_template_raw_result();

        let ffi_params = llama_cpp_bindings_sys::llama_rs_chat_template_oaicompat_params {
            messages: messages_cstr.as_ptr(),
            tools: tools_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            tool_choice: tool_choice_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            json_schema: json_schema_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            grammar: grammar_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            reasoning_format: reasoning_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            chat_template_kwargs: kwargs_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            add_generation_prompt: params.add_generation_prompt,
            use_jinja: params.use_jinja,
            parallel_tool_calls: params.parallel_tool_calls,
            enable_thinking: params.enable_thinking,
            add_bos: params.add_bos,
            add_eos: params.add_eos,
        };

        let rc = unsafe {
            llama_cpp_bindings_sys::llama_rs_apply_chat_template_oaicompat(
                self.model.as_ptr(),
                tmpl.0.as_ptr(),
                &raw const ffi_params,
                &raw mut raw_result,
            )
        };

        unsafe { parse_chat_template_raw_result(rc, &raw mut raw_result, parse_tool_calls) }
    }
}

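/// Calls `c_function` with a buffer of `capacity` bytes, interpreting the
/// returned `i32` as the written string length. Negative returns become
/// [`MetaValError::NegativeReturn`]; a length that does not fit triggers a
/// recursive retry with the exact required capacity; a missing trailing null
/// terminator is likewise reported as `NegativeReturn(-1)`.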
fn extract_meta_string<TCFunction>(
    c_function: TCFunction,
    capacity: usize,
) -> Result<String, MetaValError>
where
    TCFunction: Fn(*mut c_char, usize) -> i32,
{
    let mut buffer = vec![0u8; capacity];
    let result = c_function(buffer.as_mut_ptr().cast::<c_char>(), buffer.len());

    if result < 0 {
        return Err(MetaValError::NegativeReturn(result));
    }

    let returned_len = result.cast_unsigned() as usize;

    if returned_len >= capacity {
        return extract_meta_string(c_function, returned_len + 1);
    }

    if buffer.get(returned_len) != Some(&0) {
        return Err(MetaValError::NegativeReturn(-1));
    }

    buffer.truncate(returned_len);

    Ok(String::from_utf8(buffer)?)
}

impl Drop for LlamaModel {
    fn drop(&mut self) {
        unsafe { llama_cpp_bindings_sys::llama_free_model(self.model.as_ptr()) }
    }
}

#[cfg(test)]
mod extract_meta_string_tests {
    use super::extract_meta_string;
    use crate::MetaValError;

    #[test]
    fn returns_error_when_null_terminator_missing() {
        let result = extract_meta_string(
            |buf_ptr, buf_len| {
                let buffer =
                    unsafe { std::slice::from_raw_parts_mut(buf_ptr.cast::<u8>(), buf_len) };
                buffer[0] = b'a';
                buffer[1] = b'b';
                buffer[2] = b'c';
                // Claim a length of 2, but index 2 holds b'c' instead of a
                // null terminator.
                2
            },
            4,
        );

        assert_eq!(result.unwrap_err(), MetaValError::NegativeReturn(-1));
    }

    #[test]
    fn returns_error_for_negative_return_value() {
        let result = extract_meta_string(|_buf_ptr, _buf_len| -5, 4);

        assert_eq!(result.unwrap_err(), MetaValError::NegativeReturn(-5));
    }

    #[test]
    fn returns_error_for_invalid_utf8_data() {
        let result = extract_meta_string(
            |buf_ptr, buf_len| {
                let buffer =
                    unsafe { std::slice::from_raw_parts_mut(buf_ptr.cast::<u8>(), buf_len) };
                buffer[0] = 0xFF;
                buffer[1] = 0xFE;
                buffer[2] = 0;
                2
            },
            4,
        );

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("FromUtf8Error"));
    }

    #[test]
    fn triggers_buffer_resize_when_returned_len_exceeds_capacity() {
        let call_count = std::cell::Cell::new(0);
        let result = extract_meta_string(
            |buf_ptr, buf_len| {
                let count = call_count.get();
                call_count.set(count + 1);
                if count == 0 {
                    // First call: report that 10 bytes are required, which
                    // exceeds the initial capacity of 4 and forces a retry.
                    10
                } else {
                    let buffer =
                        unsafe { std::slice::from_raw_parts_mut(buf_ptr.cast::<u8>(), buf_len) };
                    buffer[0] = b'h';
                    buffer[1] = b'i';
                    buffer[2] = 0;
                    2
                }
            },
            4,
        );

        assert_eq!(result.unwrap(), "hi");
    }

    #[test]
    fn cstring_with_validated_len_null_byte_returns_error() {
        let result = super::cstring_with_validated_len("null\0byte");

        assert!(result.is_err());
    }

    #[test]
    fn validate_string_length_overflow_returns_error() {
        let result = super::validate_string_length_for_tokenizer(usize::MAX);

        assert!(result.is_err());
    }

    #[test]
    fn truncated_buffer_to_string_with_invalid_utf8_returns_error() {
        let invalid_utf8 = vec![0xff, 0xfe, 0xfd];
        let result = super::truncated_buffer_to_string(invalid_utf8, 3);

        assert!(result.is_err());
    }
}

#[cfg(test)]
#[cfg(feature = "tests_that_use_llms")]
mod tests {
    use serial_test::serial;

    use super::LlamaModel;
    use crate::llama_backend::LlamaBackend;
    use crate::model::AddBos;
    use crate::model::params::LlamaModelParams;
    use crate::test_model;

    #[test]
    #[serial]
    fn model_loads_with_valid_metadata() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        assert!(model.n_vocab() > 0);
        assert!(model.n_embd() > 0);
        assert!(model.n_params() > 0);
        assert!(model.n_ctx_train() > 0);
    }

    #[test]
    #[serial]
    fn special_tokens_exist() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let bos = model.token_bos();
        let eos = model.token_eos();
        assert_ne!(bos, eos);
        assert!(model.is_eog_token(eos));
        assert!(!model.is_eog_token(bos));
    }

    #[test]
    #[serial]
    fn str_to_token_roundtrip() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let tokens = model.str_to_token("hello world", AddBos::Never).unwrap();
        assert!(!tokens.is_empty());
        let mut decoder = encoding_rs::UTF_8.new_decoder();
        let piece = model
            .token_to_piece(tokens[0], &mut decoder, false, None)
            .unwrap();
        assert!(!piece.is_empty());
    }

    #[test]
    #[serial]
    fn chat_template_returns_non_empty() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let template = model.chat_template(None);
        assert!(template.is_ok());
    }

    #[test]
    #[serial]
    fn apply_chat_template_produces_prompt() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let template = model.chat_template(None).unwrap();
        let message =
            crate::model::LlamaChatMessage::new("user".to_string(), "hello".to_string()).unwrap();
        let prompt = model.apply_chat_template(&template, &[message], true);
        assert!(prompt.is_ok());
        assert!(!prompt.unwrap().is_empty());
    }

    #[test]
    #[serial]
    fn apply_chat_template_oaicompat_produces_result() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let template = model.chat_template(None).unwrap();
        let params = crate::openai::OpenAIChatTemplateParams {
            messages_json: r#"[{"role":"user","content":"hello"}]"#,
            tools_json: None,
            tool_choice: None,
            json_schema: None,
            grammar: None,
            reasoning_format: Some("none"),
            chat_template_kwargs: None,
            add_generation_prompt: true,
            use_jinja: true,
            parallel_tool_calls: false,
            enable_thinking: false,
            add_bos: false,
            add_eos: false,
            parse_tool_calls: false,
        };
        let result = model.apply_chat_template_oaicompat(&template, &params);
        assert!(result.is_ok());
        assert!(!result.unwrap().prompt.is_empty());
    }

    #[test]
    #[serial]
    fn meta_count_returns_positive() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        assert!(model.meta_count() > 0);
    }

    #[test]
    #[serial]
    fn tokens_iterator_produces_valid_entries() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let mut count = 0;

        for (token, piece_result) in model.tokens(false) {
            assert!(token.0 >= 0);
            let _ = piece_result;
            count += 1;

            if count >= 100 {
                break;
            }
        }

        assert_eq!(count, 100);
    }

    #[test]
    #[serial]
    fn token_to_piece_bytes_returns_bytes_for_known_token() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let tokens = model.str_to_token("hello", AddBos::Never).unwrap();
        let bytes = model
            .token_to_piece_bytes(tokens[0], 32, false, None)
            .unwrap();

        assert!(!bytes.is_empty());
    }

    #[test]
    #[serial]
    fn n_layer_returns_positive() {
        let (_backend, model) = test_model::load_default_model().unwrap();

        assert!(model.n_layer() > 0);
    }

    #[test]
    #[serial]
    fn n_head_returns_positive() {
        let (_backend, model) = test_model::load_default_model().unwrap();

        assert!(model.n_head() > 0);
    }

    #[test]
    #[serial]
    fn n_head_kv_returns_positive() {
        let (_backend, model) = test_model::load_default_model().unwrap();

        assert!(model.n_head_kv() > 0);
    }

    #[test]
    #[serial]
    fn meta_key_by_index_returns_valid_key() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let key = model.meta_key_by_index(0).unwrap();

        assert!(!key.is_empty());
    }

    #[test]
    #[serial]
    fn meta_val_str_by_index_returns_valid_value() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let value = model.meta_val_str_by_index(0).unwrap();

        assert!(!value.is_empty());
    }

    #[test]
    #[serial]
    fn meta_key_by_index_out_of_range_returns_error() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let result = model.meta_key_by_index(999_999);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn meta_val_str_by_index_out_of_range_returns_error() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let result = model.meta_val_str_by_index(999_999);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn meta_val_str_returns_value_for_known_key() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let first_key = model.meta_key_by_index(0).unwrap();
        let value = model.meta_val_str(&first_key).unwrap();

        assert!(!value.is_empty());
    }

    #[test]
    #[serial]
    fn model_size_returns_nonzero() {
        let (_backend, model) = test_model::load_default_model().unwrap();

        assert!(model.size() > 0);
    }

    #[test]
    #[serial]
    fn is_recurrent_returns_false_for_transformer() {
        let (_backend, model) = test_model::load_default_model().unwrap();

        assert!(!model.is_recurrent());
    }

    #[test]
    #[serial]
    fn rope_type_does_not_panic() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let _rope_type = model.rope_type();
    }

    #[test]
    #[serial]
    fn load_model_with_invalid_path_returns_error() {
        let backend = LlamaBackend::init().unwrap();
        let model_params = LlamaModelParams::default();
        let result = LlamaModel::load_from_file(&backend, "/nonexistent/model.gguf", &model_params);

        assert_eq!(
            result.unwrap_err(),
            crate::LlamaModelLoadError::FileNotFound(std::path::PathBuf::from(
                "/nonexistent/model.gguf"
            ))
        );
    }

    #[test]
    #[serial]
    fn load_model_with_invalid_file_content_returns_null_result() {
        let backend = LlamaBackend::init().unwrap();
        let model_params = LlamaModelParams::default();
        let dummy_path = std::env::temp_dir().join("llama_test_invalid_model.gguf");
        std::fs::write(&dummy_path, b"not a valid gguf model file").unwrap();

        let result = LlamaModel::load_from_file(&backend, &dummy_path, &model_params);

        assert_eq!(result.unwrap_err(), crate::LlamaModelLoadError::NullResult);
        let _ = std::fs::remove_file(&dummy_path);
    }

    #[cfg(unix)]
    #[test]
    #[serial]
    fn load_model_with_non_utf8_path_returns_path_to_str_error() {
        use std::ffi::OsStr;
        use std::os::unix::ffi::OsStrExt;

        let backend = LlamaBackend::init().unwrap();
        let model_params = LlamaModelParams::default();
        let non_utf8_path = std::path::Path::new(OsStr::from_bytes(b"/tmp/\xff\xfe.gguf"));

        let result = LlamaModel::load_from_file(&backend, non_utf8_path, &model_params);

        assert_eq!(
            result.unwrap_err(),
            crate::LlamaModelLoadError::PathToStrError(non_utf8_path.to_path_buf())
        );
    }

    #[cfg(unix)]
    #[test]
    #[serial]
    fn lora_adapter_init_with_non_utf8_path_returns_error() {
        use std::ffi::OsStr;
        use std::os::unix::ffi::OsStrExt;

        let (_backend, model) = test_model::load_default_model().unwrap();
        let non_utf8_path = std::path::Path::new(OsStr::from_bytes(b"/tmp/\xff\xfe.gguf"));

        let result = model.lora_adapter_init(non_utf8_path);

        assert_eq!(
            result.unwrap_err(),
            crate::LlamaLoraAdapterInitError::PathToStrError(non_utf8_path.to_path_buf())
        );
    }

    #[test]
    #[serial]
    fn lora_adapter_init_with_invalid_path_returns_error() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let result = model.lora_adapter_init("/nonexistent/path/lora.gguf");

        assert_eq!(
            result.unwrap_err(),
            crate::LlamaLoraAdapterInitError::FileNotFound(std::path::PathBuf::from(
                "/nonexistent/path/lora.gguf"
            ))
        );
    }

    #[test]
    #[serial]
    fn new_context_returns_valid_context() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = crate::context::params::LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(256));
        let context = model.new_context(&backend, ctx_params).unwrap();

        assert!(context.n_ctx() > 0);
    }

    #[test]
    #[serial]
    fn token_nl_returns_valid_token() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let nl_token = model.token_nl();

        assert!(nl_token.0 >= 0);
    }

    #[test]
    #[serial]
    fn decode_start_token_returns_valid_token() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let _decode_start = model.decode_start_token();
    }

    #[test]
    #[serial]
    fn token_sep_returns_valid_token() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let _sep_token = model.token_sep();
    }

    #[test]
    #[serial]
    fn token_to_piece_handles_large_token_requiring_buffer_resize() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let mut decoder = encoding_rs::UTF_8.new_decoder();

        for (token, _) in model.tokens(true).take(200) {
            let result = model.token_to_piece(token, &mut decoder, true, None);
            assert!(result.is_ok());
        }
    }

    #[test]
    #[serial]
    fn token_to_piece_bytes_insufficient_buffer_returns_error() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let tokens = model.str_to_token("hello", AddBos::Never).unwrap();
        let result = model.token_to_piece_bytes(tokens[0], 1, false, None);

        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Insufficient Buffer Space")
        );
    }

    #[test]
    #[serial]
    fn token_to_piece_with_lstrip() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let mut decoder = encoding_rs::UTF_8.new_decoder();
        let tokens = model.str_to_token("hello", AddBos::Never).unwrap();
        let result =
            model.token_to_piece(tokens[0], &mut decoder, false, std::num::NonZeroU16::new(1));

        assert!(result.is_ok());
    }

    #[test]
    #[serial]
    fn n_vocab_matches_tokens_iterator_count() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let n_vocab = model.n_vocab();
        let count = model.tokens(false).count();

        assert_eq!(count, n_vocab as usize);
    }

    #[test]
    #[serial]
    fn token_attr_returns_valid_attr() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let bos = model.token_bos();
        let _attr = model.token_attr(bos);
    }

    #[test]
    #[serial]
    fn vocab_type_returns_valid_type() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let _vocab_type = model.vocab_type();
    }

    #[test]
    #[serial]
    fn apply_chat_template_buffer_resize_with_long_messages() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let template = model.chat_template(None).unwrap();
        let long_content = "a".repeat(2000);
        let message =
            crate::model::LlamaChatMessage::new("user".to_string(), long_content).unwrap();
        let prompt = model.apply_chat_template(&template, &[message], true);

        assert!(prompt.is_ok());
        assert!(!prompt.unwrap().is_empty());
    }

    #[test]
    #[serial]
    fn meta_val_str_with_long_value_triggers_buffer_resize() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let count = model.meta_count();

        for index in 0..count {
            let key = model.meta_key_by_index(index);
            let value = model.meta_val_str_by_index(index);
            assert!(key.is_ok());
            assert!(value.is_ok());
        }
    }

    #[test]
    #[serial]
    fn str_to_token_with_add_bos_never() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let tokens_with_bos = model.str_to_token("hello", AddBos::Always).unwrap();
        let tokens_without_bos = model.str_to_token("hello", AddBos::Never).unwrap();

        assert!(tokens_with_bos.len() >= tokens_without_bos.len());
    }

    #[test]
    #[serial]
    fn apply_chat_template_with_tools_oaicompat_produces_result() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let template = model.chat_template(None).unwrap();
        let message =
            crate::model::LlamaChatMessage::new("user".to_string(), "hello".to_string()).unwrap();
        let result =
            model.apply_chat_template_with_tools_oaicompat(&template, &[message], None, None, true);

        assert!(result.is_ok());
        assert!(!result.unwrap().prompt.is_empty());
    }

    #[test]
    #[serial]
    fn apply_chat_template_with_tools_oaicompat_with_tools_json() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let template = model.chat_template(None).unwrap();
        let message =
            crate::model::LlamaChatMessage::new("user".to_string(), "hello".to_string()).unwrap();
        let tools =
            r#"[{"type":"function","function":{"name":"test","parameters":{"type":"object"}}}]"#;
        let result = model.apply_chat_template_with_tools_oaicompat(
            &template,
            &[message],
            Some(tools),
            None,
            true,
        );

        assert!(result.is_ok());
    }

    #[test]
    #[serial]
    fn apply_chat_template_with_tools_oaicompat_with_json_schema() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let template = model.chat_template(None).unwrap();
        let message =
            crate::model::LlamaChatMessage::new("user".to_string(), "hello".to_string()).unwrap();
        let schema = r#"{"type":"object","properties":{"name":{"type":"string"}}}"#;
        let result = model.apply_chat_template_with_tools_oaicompat(
            &template,
            &[message],
            None,
            Some(schema),
            true,
        );

        assert!(result.is_ok());
    }

    #[test]
    #[serial]
    fn apply_chat_template_oaicompat_with_tools_and_tool_choice() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let template = model.chat_template(None).unwrap();
        let params = crate::openai::OpenAIChatTemplateParams {
            messages_json: r#"[{"role":"user","content":"hello"}]"#,
            tools_json: Some(
                r#"[{"type":"function","function":{"name":"test","parameters":{"type":"object","properties":{}}}}]"#,
            ),
            tool_choice: Some("auto"),
            json_schema: None,
            grammar: None,
            reasoning_format: Some("none"),
            chat_template_kwargs: None,
            add_generation_prompt: true,
            use_jinja: true,
            parallel_tool_calls: false,
            enable_thinking: false,
            add_bos: false,
            add_eos: false,
            parse_tool_calls: true,
        };
        let result = model.apply_chat_template_oaicompat(&template, &params);

        assert!(result.is_ok());
    }

    #[test]
    #[serial]
    fn apply_chat_template_oaicompat_with_json_schema_field() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let template = model.chat_template(None).unwrap();
        let params = crate::openai::OpenAIChatTemplateParams {
            messages_json: r#"[{"role":"user","content":"hello"}]"#,
            tools_json: None,
            tool_choice: None,
            json_schema: Some(r#"{"type":"object","properties":{"name":{"type":"string"}}}"#),
            grammar: None,
            reasoning_format: Some("none"),
            chat_template_kwargs: None,
            add_generation_prompt: true,
            use_jinja: true,
            parallel_tool_calls: false,
            enable_thinking: false,
            add_bos: false,
            add_eos: false,
            parse_tool_calls: false,
        };
        let result = model.apply_chat_template_oaicompat(&template, &params);

        assert!(result.is_ok());
    }

    #[test]
    #[serial]
    fn apply_chat_template_oaicompat_with_grammar_field() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let template = model.chat_template(None).unwrap();
        let params = crate::openai::OpenAIChatTemplateParams {
            messages_json: r#"[{"role":"user","content":"hello"}]"#,
            tools_json: None,
            tool_choice: None,
            json_schema: None,
            grammar: Some("root ::= \"hello\""),
            reasoning_format: Some("none"),
            chat_template_kwargs: None,
            add_generation_prompt: true,
            use_jinja: true,
            parallel_tool_calls: false,
            enable_thinking: false,
            add_bos: false,
            add_eos: false,
            parse_tool_calls: false,
        };
        let result = model.apply_chat_template_oaicompat(&template, &params);

        assert!(result.is_ok());
    }

    #[test]
    #[serial]
    fn apply_chat_template_oaicompat_with_kwargs_field() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let template = model.chat_template(None).unwrap();
        let params = crate::openai::OpenAIChatTemplateParams {
            messages_json: r#"[{"role":"user","content":"hello"}]"#,
            tools_json: None,
            tool_choice: None,
            json_schema: None,
            grammar: None,
            reasoning_format: Some("none"),
            chat_template_kwargs: Some(r#"{"bos_token": "<|im_start|>"}"#),
            add_generation_prompt: true,
            use_jinja: true,
            parallel_tool_calls: false,
            enable_thinking: false,
            add_bos: false,
            add_eos: false,
            parse_tool_calls: false,
        };
        let result = model.apply_chat_template_oaicompat(&template, &params);

        assert!(result.is_ok());
    }

    #[test]
    #[serial]
    fn chat_template_with_nonexistent_name_returns_error() {
        let (_backend, model) = test_model::load_default_model().unwrap();

        let result = model.chat_template(Some("nonexistent_template_name_xyz"));

        assert_eq!(
            result.unwrap_err(),
            crate::ChatTemplateError::MissingTemplate
        );
    }

    #[test]
    #[serial]
    fn lora_adapter_init_with_invalid_gguf_returns_null_result() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let dummy_path = std::env::temp_dir().join("llama_test_dummy_lora.gguf");
        std::fs::write(&dummy_path, b"not a valid gguf").unwrap();

        let result = model.lora_adapter_init(&dummy_path);

        assert_eq!(
            result.unwrap_err(),
            crate::LlamaLoraAdapterInitError::NullResult
        );
        let _ = std::fs::remove_file(&dummy_path);
    }

    #[test]
    #[serial]
    fn str_to_token_with_many_tokens_triggers_buffer_resize() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let many_numbers: String = (0..2000).map(|number| format!("{number} ")).collect();

        let tokens = model.str_to_token(&many_numbers, AddBos::Always).unwrap();

        // More tokens than the initial estimate of `str.len() / 2` proves the
        // resize path in `str_to_token` ran.
        assert!(tokens.len() > many_numbers.len() / 2);
    }

    #[test]
    #[serial]
    fn rope_type_returns_valid_result_for_test_model() {
        let (_backend, model) = test_model::load_default_model().unwrap();

        let _rope_type = model.rope_type();
    }

    #[test]
    #[serial]
    fn meta_val_str_with_null_byte_in_key_returns_error() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let result = model.meta_val_str("key\0with_null");

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn apply_chat_template_with_tools_null_byte_in_tools_returns_error() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let template = model.chat_template(None).unwrap();
        let message =
            crate::model::LlamaChatMessage::new("user".to_string(), "hello".to_string()).unwrap();
        let result = model.apply_chat_template_with_tools_oaicompat(
            &template,
            &[message],
            Some("tools\0null"),
            None,
            true,
        );

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn apply_chat_template_with_tools_null_byte_in_json_schema_returns_error() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let template = model.chat_template(None).unwrap();
        let message =
            crate::model::LlamaChatMessage::new("user".to_string(), "hello".to_string()).unwrap();
        let result = model.apply_chat_template_with_tools_oaicompat(
            &template,
            &[message],
            None,
            Some("schema\0null"),
            true,
        );

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn apply_chat_template_oaicompat_with_null_byte_in_messages_returns_error() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let template = model.chat_template(None).unwrap();
        let params = crate::openai::OpenAIChatTemplateParams {
            messages_json: "messages\0null",
            tools_json: None,
            tool_choice: None,
            json_schema: None,
            grammar: None,
            reasoning_format: None,
            chat_template_kwargs: None,
            add_generation_prompt: true,
            use_jinja: true,
            parallel_tool_calls: false,
            enable_thinking: false,
            add_bos: false,
            add_eos: false,
            parse_tool_calls: false,
        };
        let result = model.apply_chat_template_oaicompat(&template, &params);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn apply_chat_template_oaicompat_with_null_byte_in_tools_returns_error() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let template = model.chat_template(None).unwrap();
        let params = crate::openai::OpenAIChatTemplateParams {
            messages_json: r#"[{"role":"user","content":"hello"}]"#,
            tools_json: Some("tools\0null"),
            tool_choice: None,
            json_schema: None,
            grammar: None,
            reasoning_format: None,
            chat_template_kwargs: None,
            add_generation_prompt: true,
            use_jinja: true,
            parallel_tool_calls: false,
            enable_thinking: false,
            add_bos: false,
            add_eos: false,
            parse_tool_calls: false,
        };
        let result = model.apply_chat_template_oaicompat(&template, &params);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn apply_chat_template_oaicompat_with_null_byte_in_tool_choice_returns_error() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let template = model.chat_template(None).unwrap();
        let params = crate::openai::OpenAIChatTemplateParams {
            messages_json: r#"[{"role":"user","content":"hello"}]"#,
            tools_json: None,
            tool_choice: Some("choice\0null"),
            json_schema: None,
            grammar: None,
            reasoning_format: None,
            chat_template_kwargs: None,
            add_generation_prompt: true,
            use_jinja: true,
            parallel_tool_calls: false,
            enable_thinking: false,
            add_bos: false,
            add_eos: false,
            parse_tool_calls: false,
        };
        let result = model.apply_chat_template_oaicompat(&template, &params);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn apply_chat_template_oaicompat_with_null_byte_in_json_schema_returns_error() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let template = model.chat_template(None).unwrap();
        let params = crate::openai::OpenAIChatTemplateParams {
            messages_json: r#"[{"role":"user","content":"hello"}]"#,
            tools_json: None,
            tool_choice: None,
            json_schema: Some("schema\0null"),
            grammar: None,
            reasoning_format: None,
            chat_template_kwargs: None,
            add_generation_prompt: true,
            use_jinja: true,
            parallel_tool_calls: false,
            enable_thinking: false,
            add_bos: false,
            add_eos: false,
            parse_tool_calls: false,
        };
        let result = model.apply_chat_template_oaicompat(&template, &params);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn apply_chat_template_oaicompat_with_null_byte_in_grammar_returns_error() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let template = model.chat_template(None).unwrap();
        let params = crate::openai::OpenAIChatTemplateParams {
            messages_json: r#"[{"role":"user","content":"hello"}]"#,
            tools_json: None,
            tool_choice: None,
            json_schema: None,
            grammar: Some("grammar\0null"),
            reasoning_format: None,
            chat_template_kwargs: None,
            add_generation_prompt: true,
            use_jinja: true,
            parallel_tool_calls: false,
            enable_thinking: false,
            add_bos: false,
            add_eos: false,
            parse_tool_calls: false,
        };
        let result = model.apply_chat_template_oaicompat(&template, &params);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn apply_chat_template_oaicompat_with_null_byte_in_reasoning_format_returns_error() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let template = model.chat_template(None).unwrap();
        let params = crate::openai::OpenAIChatTemplateParams {
            messages_json: r#"[{"role":"user","content":"hello"}]"#,
            tools_json: None,
            tool_choice: None,
            json_schema: None,
            grammar: None,
            reasoning_format: Some("format\0null"),
            chat_template_kwargs: None,
            add_generation_prompt: true,
            use_jinja: true,
            parallel_tool_calls: false,
            enable_thinking: false,
            add_bos: false,
            add_eos: false,
            parse_tool_calls: false,
        };
        let result = model.apply_chat_template_oaicompat(&template, &params);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn apply_chat_template_oaicompat_with_null_byte_in_kwargs_returns_error() {
        let (_backend, model) = test_model::load_default_model().unwrap();
        let template = model.chat_template(None).unwrap();
        let params = crate::openai::OpenAIChatTemplateParams {
            messages_json: r#"[{"role":"user","content":"hello"}]"#,
            tools_json: None,
            tool_choice: None,
            json_schema: None,
            grammar: None,
            reasoning_format: None,
            chat_template_kwargs: Some("kwargs\0null"),
            add_generation_prompt: true,
            use_jinja: true,
            parallel_tool_calls: false,
            enable_thinking: false,
            add_bos: false,
            add_eos: false,
            parse_tool_calls: false,
        };
        let result = model.apply_chat_template_oaicompat(&template, &params);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn new_context_with_huge_ctx_returns_null_error() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = crate::context::params::LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(u32::MAX));

        let result = model.new_context(&backend, ctx_params);

        assert!(result.is_err());
    }
}