1use std::borrow::Borrow;
4use std::ffi::{CString, c_char};
5use std::fmt::{Debug, Formatter};
6
7use crate::context::LlamaContext;
8use crate::ffi_error_reader::read_and_free_cpp_error;
9use crate::model::LlamaModel;
10use crate::token::LlamaToken;
11use crate::token::data_array::LlamaTokenDataArray;
12use crate::token::logit_bias::LlamaLogitBias;
13use crate::{GrammarError, SampleError, SamplerAcceptError, SamplingError};
14
15fn check_sampler_accept_status(
16 status: llama_cpp_bindings_sys::llama_rs_status,
17 error_ptr: *mut c_char,
18) -> Result<(), SamplerAcceptError> {
19 match status {
20 llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK => Ok(()),
21 llama_cpp_bindings_sys::LLAMA_RS_STATUS_INVALID_ARGUMENT => {
22 Err(SamplerAcceptError::InvalidArgument)
23 }
24 _ => Err(SamplerAcceptError::CppException(unsafe {
25 read_and_free_cpp_error(error_ptr)
26 })),
27 }
28}
29
30fn check_sampler_not_null(
31 sampler: *mut llama_cpp_bindings_sys::llama_sampler,
32 error_ptr: *mut c_char,
33) -> Result<LlamaSampler, GrammarError> {
34 if sampler.is_null() {
35 Err(GrammarError::NullGrammar(unsafe {
36 read_and_free_cpp_error(error_ptr)
37 }))
38 } else {
39 Ok(LlamaSampler { sampler })
40 }
41}
42
43fn checked_u32_as_i32(value: u32) -> Result<i32, GrammarError> {
44 i32::try_from(value).map_err(|convert_error| {
45 GrammarError::IntegerOverflow(format!("value exceeds i32::MAX: {convert_error}"))
46 })
47}
48
49fn checked_usize_as_i32_sampling(value: usize) -> Result<i32, SamplingError> {
50 i32::try_from(value).map_err(|convert_error| {
51 SamplingError::IntegerOverflow(format!("value exceeds i32::MAX: {convert_error}"))
52 })
53}
54
/// Owning handle around a raw `llama_sampler` pointer.
///
/// The pointer is passed to `llama_sampler_free` when the value is dropped, so a
/// `LlamaSampler` must be the only owner of the pointer it holds.
pub struct LlamaSampler {
    /// Raw pointer to the underlying sampler; handed to the FFI functions as-is.
    pub sampler: *mut llama_cpp_bindings_sys::llama_sampler,
}
60
61impl Debug for LlamaSampler {
62 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
63 f.debug_struct("LlamaSamplerChain").finish()
64 }
65}
66
67impl LlamaSampler {
68 pub fn sample(&mut self, ctx: &LlamaContext, idx: i32) -> Result<LlamaToken, SampleError> {
74 let mut token: i32 = -1;
75 let mut error_ptr: *mut c_char = std::ptr::null_mut();
76
77 let status = unsafe {
78 llama_cpp_bindings_sys::llama_rs_sampler_sample(
79 self.sampler,
80 ctx.context.as_ptr(),
81 idx,
82 &raw mut token,
83 &raw mut error_ptr,
84 )
85 };
86
87 match status {
88 llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK => Ok(LlamaToken(token)),
89 llama_cpp_bindings_sys::LLAMA_RS_STATUS_INVALID_ARGUMENT => {
90 Err(SampleError::InvalidArgument)
91 }
92 _ => Err(SampleError::CppException(unsafe {
93 read_and_free_cpp_error(error_ptr)
94 })),
95 }
96 }
97
98 pub fn apply(&self, data_array: &mut LlamaTokenDataArray) {
100 data_array.apply_sampler(self);
101 }
102
103 pub fn accept(&mut self, token: LlamaToken) -> Result<(), SamplerAcceptError> {
109 self.try_accept(token)
110 }
111
112 pub fn accept_many(
118 &mut self,
119 tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
120 ) -> Result<(), SamplerAcceptError> {
121 for token in tokens {
122 self.try_accept(*token.borrow())?;
123 }
124
125 Ok(())
126 }
127
128 pub fn with_tokens(
134 mut self,
135 tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
136 ) -> Result<Self, SamplerAcceptError> {
137 self.accept_many(tokens)?;
138
139 Ok(self)
140 }
141
142 pub fn try_accept(&mut self, token: LlamaToken) -> Result<(), SamplerAcceptError> {
147 let mut error_ptr: *mut c_char = std::ptr::null_mut();
148
149 let status = unsafe {
150 llama_cpp_bindings_sys::llama_rs_sampler_accept(
151 self.sampler,
152 token.0,
153 &raw mut error_ptr,
154 )
155 };
156
157 check_sampler_accept_status(status, error_ptr)
158 }
159
160 pub fn reset(&mut self) {
164 unsafe {
165 llama_cpp_bindings_sys::llama_sampler_reset(self.sampler);
166 }
167 }
168
169 #[must_use]
176 pub fn get_seed(&self) -> u32 {
177 unsafe { llama_cpp_bindings_sys::llama_sampler_get_seed(self.sampler) }
178 }
179
180 #[must_use]
187 pub fn chain(samplers: impl IntoIterator<Item = Self>, no_perf: bool) -> Self {
188 unsafe {
189 let chain = llama_cpp_bindings_sys::llama_sampler_chain_init(
190 llama_cpp_bindings_sys::llama_sampler_chain_params { no_perf },
191 );
192
193 for sampler in samplers {
194 llama_cpp_bindings_sys::llama_sampler_chain_add(chain, sampler.sampler);
195
196 std::mem::forget(sampler);
199 }
200
201 Self { sampler: chain }
202 }
203 }
204
205 #[must_use]
237 pub fn chain_simple(samplers: impl IntoIterator<Item = Self>) -> Self {
238 Self::chain(samplers, false)
239 }
240
241 #[must_use]
266 pub fn temp(t: f32) -> Self {
267 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_temp(t) };
268 Self { sampler }
269 }
270
271 #[must_use]
274 pub fn temp_ext(t: f32, delta: f32, exponent: f32) -> Self {
275 let sampler =
276 unsafe { llama_cpp_bindings_sys::llama_sampler_init_temp_ext(t, delta, exponent) };
277 Self { sampler }
278 }
279
280 #[must_use]
306 pub fn top_k(k: i32) -> Self {
307 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_k(k) };
308 Self { sampler }
309 }
310
311 #[must_use]
337 pub fn top_n_sigma(n: f32) -> Self {
338 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_n_sigma(n) };
339 Self { sampler }
340 }
341
342 #[must_use]
344 pub fn typical(p: f32, min_keep: usize) -> Self {
345 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_typical(p, min_keep) };
346 Self { sampler }
347 }
348
349 #[must_use]
352 pub fn top_p(p: f32, min_keep: usize) -> Self {
353 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_p(p, min_keep) };
354 Self { sampler }
355 }
356
357 #[must_use]
359 pub fn min_p(p: f32, min_keep: usize) -> Self {
360 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_min_p(p, min_keep) };
361 Self { sampler }
362 }
363
364 #[must_use]
366 pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Self {
367 let sampler =
368 unsafe { llama_cpp_bindings_sys::llama_sampler_init_xtc(p, t, min_keep, seed) };
369 Self { sampler }
370 }
371
372 pub fn grammar(
377 model: &LlamaModel,
378 grammar_str: &str,
379 grammar_root: &str,
380 ) -> Result<Self, GrammarError> {
381 let (grammar_str, grammar_root) =
382 Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
383 let mut error_ptr: *mut c_char = std::ptr::null_mut();
384
385 let sampler = unsafe {
386 llama_cpp_bindings_sys::llama_rs_sampler_init_grammar(
387 model.vocab_ptr(),
388 grammar_str.as_ptr(),
389 grammar_root.as_ptr(),
390 &raw mut error_ptr,
391 )
392 };
393
394 check_sampler_not_null(sampler, error_ptr)
395 }
396
397 pub fn grammar_lazy(
404 model: &LlamaModel,
405 grammar_str: &str,
406 grammar_root: &str,
407 trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
408 trigger_tokens: &[LlamaToken],
409 ) -> Result<Self, GrammarError> {
410 let (grammar_str, grammar_root) =
411 Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
412 let trigger_words = Self::sanitize_trigger_words(trigger_words)?;
413 let mut error_ptr: *mut c_char = std::ptr::null_mut();
414
415 let mut trigger_word_ptrs: Vec<*const c_char> =
416 trigger_words.iter().map(|cs| cs.as_ptr()).collect();
417
418 let sampler = unsafe {
419 llama_cpp_bindings_sys::llama_rs_sampler_init_grammar_lazy(
420 model.vocab_ptr(),
421 grammar_str.as_ptr(),
422 grammar_root.as_ptr(),
423 trigger_word_ptrs.as_mut_ptr(),
424 trigger_word_ptrs.len(),
425 trigger_tokens.as_ptr().cast(),
426 trigger_tokens.len(),
427 &raw mut error_ptr,
428 )
429 };
430
431 check_sampler_not_null(sampler, error_ptr)
432 }
433
434 pub fn grammar_lazy_patterns(
443 model: &LlamaModel,
444 grammar_str: &str,
445 grammar_root: &str,
446 trigger_patterns: &[String],
447 trigger_tokens: &[LlamaToken],
448 ) -> Result<Self, GrammarError> {
449 let (grammar_str, grammar_root) =
450 Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
451 let trigger_patterns = Self::sanitize_trigger_patterns(trigger_patterns)?;
452 let mut error_ptr: *mut c_char = std::ptr::null_mut();
453
454 let mut trigger_pattern_ptrs: Vec<*const c_char> =
455 trigger_patterns.iter().map(|cs| cs.as_ptr()).collect();
456
457 let sampler = unsafe {
458 llama_cpp_bindings_sys::llama_rs_sampler_init_grammar_lazy_patterns(
459 model.vocab_ptr(),
460 grammar_str.as_ptr(),
461 grammar_root.as_ptr(),
462 trigger_pattern_ptrs.as_mut_ptr(),
463 trigger_pattern_ptrs.len(),
464 trigger_tokens.as_ptr().cast(),
465 trigger_tokens.len(),
466 &raw mut error_ptr,
467 )
468 };
469
470 check_sampler_not_null(sampler, error_ptr)
471 }
472
473 #[cfg(feature = "llguidance")]
482 pub fn llguidance(
483 model: &LlamaModel,
484 grammar_kind: &str,
485 grammar_data: &str,
486 ) -> Result<Self, GrammarError> {
487 crate::llguidance_sampler::create_llg_sampler(model, grammar_kind, grammar_data)
488 }
489
490 fn sanitize_grammar_strings(
491 grammar_str: &str,
492 grammar_root: &str,
493 ) -> Result<(CString, CString), GrammarError> {
494 if !grammar_str.contains(grammar_root) {
495 return Err(GrammarError::RootNotFound);
496 }
497
498 let grammar = CString::new(grammar_str).map_err(GrammarError::GrammarNullBytes)?;
499 let root = CString::new(grammar_root).map_err(GrammarError::GrammarNullBytes)?;
500
501 Ok((grammar, root))
502 }
503
504 fn sanitize_trigger_words(
505 trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
506 ) -> Result<Vec<CString>, GrammarError> {
507 trigger_words
508 .into_iter()
509 .map(|word| CString::new(word.as_ref()).map_err(GrammarError::TriggerWordNullBytes))
510 .collect()
511 }
512
513 fn sanitize_trigger_patterns(
514 trigger_patterns: &[String],
515 ) -> Result<Vec<CString>, GrammarError> {
516 trigger_patterns
517 .iter()
518 .map(|pattern| CString::new(pattern.as_str()).map_err(GrammarError::GrammarNullBytes))
519 .collect()
520 }
521
522 #[allow(missing_docs)]
529 pub fn dry(
530 model: &LlamaModel,
531 multiplier: f32,
532 base: f32,
533 allowed_length: i32,
534 penalty_last_n: i32,
535 seq_breakers: impl IntoIterator<Item = impl AsRef<[u8]>>,
536 ) -> Result<Self, GrammarError> {
537 let seq_breakers: Vec<CString> = seq_breakers
538 .into_iter()
539 .map(|s| CString::new(s.as_ref()))
540 .collect::<Result<Vec<_>, _>>()?;
541 let mut seq_breaker_pointers: Vec<*const c_char> =
542 seq_breakers.iter().map(|s| s.as_ptr()).collect();
543
544 let n_ctx_train_value = model.n_ctx_train().map_err(|convert_error| {
545 GrammarError::IntegerOverflow(format!(
546 "n_ctx_train does not fit into u32: {convert_error}"
547 ))
548 })?;
549 let n_ctx_train = checked_u32_as_i32(n_ctx_train_value)?;
550 let sampler = unsafe {
551 llama_cpp_bindings_sys::llama_sampler_init_dry(
552 model.vocab_ptr(),
553 n_ctx_train,
554 multiplier,
555 base,
556 allowed_length,
557 penalty_last_n,
558 seq_breaker_pointers.as_mut_ptr(),
559 seq_breaker_pointers.len(),
560 )
561 };
562
563 Ok(Self { sampler })
564 }
565
566 #[must_use]
574 pub fn penalties(
575 penalty_last_n: i32,
576 penalty_repeat: f32,
577 penalty_freq: f32,
578 penalty_present: f32,
579 ) -> Self {
580 let sampler = unsafe {
581 llama_cpp_bindings_sys::llama_sampler_init_penalties(
582 penalty_last_n,
583 penalty_repeat,
584 penalty_freq,
585 penalty_present,
586 )
587 };
588 Self { sampler }
589 }
590
591 #[must_use]
607 pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self {
608 let sampler = unsafe {
609 llama_cpp_bindings_sys::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m)
610 };
611 Self { sampler }
612 }
613
614 #[must_use]
625 pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self {
626 let sampler =
627 unsafe { llama_cpp_bindings_sys::llama_sampler_init_mirostat_v2(seed, tau, eta) };
628 Self { sampler }
629 }
630
631 #[must_use]
633 pub fn dist(seed: u32) -> Self {
634 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_dist(seed) };
635 Self { sampler }
636 }
637
638 #[must_use]
660 pub fn greedy() -> Self {
661 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_greedy() };
662 Self { sampler }
663 }
664
665 pub fn logit_bias(n_vocab: i32, biases: &[LlamaLogitBias]) -> Result<Self, SamplingError> {
688 let bias_count = checked_usize_as_i32_sampling(biases.len())?;
689 let data = biases
690 .as_ptr()
691 .cast::<llama_cpp_bindings_sys::llama_logit_bias>();
692
693 let sampler = unsafe {
694 llama_cpp_bindings_sys::llama_sampler_init_logit_bias(n_vocab, bias_count, data)
695 };
696
697 Ok(Self { sampler })
698 }
699}
700
701impl Drop for LlamaSampler {
702 fn drop(&mut self) {
703 unsafe {
704 llama_cpp_bindings_sys::llama_sampler_free(self.sampler);
705 }
706 }
707}
708
#[cfg(test)]
mod tests {
    use super::LlamaSampler;
    use crate::GrammarError;

    // --- sanitize_grammar_strings -------------------------------------------------------

    #[test]
    fn sanitize_grammar_strings_valid() {
        let result = LlamaSampler::sanitize_grammar_strings("root ::= \"hello\"", "root");

        assert!(result.is_ok());
    }

    #[test]
    fn sanitize_grammar_strings_root_not_found() {
        let result = LlamaSampler::sanitize_grammar_strings("expr ::= \"hello\"", "root");

        assert_eq!(result.err(), Some(GrammarError::RootNotFound));
    }

    #[test]
    fn sanitize_grammar_strings_null_byte_in_grammar() {
        let result = LlamaSampler::sanitize_grammar_strings("root ::= \"\0\"", "root");

        assert!(matches!(
            result.err(),
            Some(GrammarError::GrammarNullBytes(_))
        ));
    }

    #[test]
    fn sanitize_grammar_strings_null_byte_in_root() {
        // The root must also appear in the grammar, otherwise the substring check
        // rejects the input before the NUL check runs.
        let result = LlamaSampler::sanitize_grammar_strings("ro\0ot ::= \"hello\"", "ro\0ot");

        assert!(matches!(
            result.err(),
            Some(GrammarError::GrammarNullBytes(_))
        ));
    }

    // --- sanitize_trigger_words ---------------------------------------------------------

    #[test]
    fn sanitize_trigger_words_valid() {
        let words: Vec<&[u8]> = vec![b"hello", b"world"];
        let result = LlamaSampler::sanitize_trigger_words(words);

        assert!(result.is_ok());
        assert_eq!(result.expect("valid trigger words").len(), 2);
    }

    #[test]
    fn sanitize_trigger_words_empty_list() {
        let words: Vec<&[u8]> = vec![];
        let result = LlamaSampler::sanitize_trigger_words(words);

        assert!(result.is_ok());
        assert!(result.expect("valid trigger words").is_empty());
    }

    #[test]
    fn sanitize_trigger_words_null_byte() {
        let words: Vec<&[u8]> = vec![b"hel\0lo"];
        let result = LlamaSampler::sanitize_trigger_words(words);

        assert!(matches!(
            result.err(),
            Some(GrammarError::TriggerWordNullBytes(_))
        ));
    }

    // --- sanitize_trigger_patterns ------------------------------------------------------

    #[test]
    fn sanitize_trigger_patterns_valid() {
        let patterns = vec!["^hello$".to_string(), "world.*".to_string()];
        let result = LlamaSampler::sanitize_trigger_patterns(&patterns);

        assert!(result.is_ok());
        assert_eq!(result.expect("valid trigger patterns").len(), 2);
    }

    #[test]
    fn sanitize_trigger_patterns_empty_list() {
        let patterns: Vec<String> = vec![];
        let result = LlamaSampler::sanitize_trigger_patterns(&patterns);

        assert!(result.is_ok());
        assert!(result.expect("valid trigger patterns").is_empty());
    }

    #[test]
    fn sanitize_trigger_patterns_null_byte() {
        let patterns = vec!["hel\0lo".to_string()];
        let result = LlamaSampler::sanitize_trigger_patterns(&patterns);

        assert!(matches!(
            result.err(),
            Some(GrammarError::GrammarNullBytes(_))
        ));
    }

    // --- model-free sampler behaviour ---------------------------------------------------

    #[test]
    fn apply_modifies_data_array() {
        use crate::token::LlamaToken;
        use crate::token::data::LlamaTokenData;
        use crate::token::data_array::LlamaTokenDataArray;

        let sampler = LlamaSampler::greedy();
        let mut data_array = LlamaTokenDataArray::new(
            vec![
                LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
                LlamaTokenData::new(LlamaToken::new(1), 5.0, 0.0),
            ],
            false,
        );

        sampler.apply(&mut data_array);

        assert_eq!(data_array.selected_token(), Some(LlamaToken::new(1)));
    }

    #[test]
    fn accept_succeeds() {
        let mut sampler = LlamaSampler::chain_simple([
            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
            LlamaSampler::greedy(),
        ]);

        sampler
            .accept(crate::token::LlamaToken::new(1))
            .expect("test: accept should succeed");
    }

    #[test]
    fn try_accept_succeeds_on_penalties_sampler() {
        let mut sampler = LlamaSampler::chain_simple([
            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
            LlamaSampler::greedy(),
        ]);

        let result = sampler.try_accept(crate::token::LlamaToken::new(42));

        assert!(result.is_ok());
    }

    #[test]
    fn accept_many_multiple_tokens() {
        use crate::token::LlamaToken;

        let mut sampler = LlamaSampler::chain_simple([
            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
            LlamaSampler::greedy(),
        ]);

        sampler
            .accept_many([LlamaToken::new(1), LlamaToken::new(2), LlamaToken::new(3)])
            .expect("test: accept_many should succeed");
    }

    #[test]
    fn with_tokens_builder_pattern() {
        use crate::token::LlamaToken;

        let _sampler = LlamaSampler::chain_simple([
            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
            LlamaSampler::greedy(),
        ])
        .with_tokens([LlamaToken::new(10), LlamaToken::new(20)])
        .expect("test: with_tokens should succeed");
    }

    #[test]
    fn all_sampler_constructors() {
        use crate::token::LlamaToken;
        use crate::token::logit_bias::LlamaLogitBias;

        let _temp = LlamaSampler::temp(0.8);
        let _temp_ext = LlamaSampler::temp_ext(0.8, 0.1, 1.0);
        let _top_k = LlamaSampler::top_k(40);
        let _top_n_sigma = LlamaSampler::top_n_sigma(2.0);
        let _top_p = LlamaSampler::top_p(0.9, 1);
        let _min_p = LlamaSampler::min_p(0.05, 1);
        let _typical = LlamaSampler::typical(0.9, 1);
        let _xtc = LlamaSampler::xtc(0.1, 0.5, 1, 42);
        let _dist = LlamaSampler::dist(42);
        let _mirostat = LlamaSampler::mirostat(32000, 42, 5.0, 0.1, 100);
        let _mirostat_v2 = LlamaSampler::mirostat_v2(42, 5.0, 0.1);
        let biases = vec![LlamaLogitBias::new(LlamaToken::new(0), -100.0)];
        // `logit_bias` is fallible; check the result instead of silently discarding it.
        let logit_bias = LlamaSampler::logit_bias(32000, &biases);
        assert!(logit_bias.is_ok());
        let _chain = LlamaSampler::chain([LlamaSampler::greedy()], true);
    }

    #[test]
    fn reset_and_get_seed() {
        let mut sampler = LlamaSampler::dist(42);
        sampler.reset();
        let _seed = sampler.get_seed();
    }

    #[test]
    fn debug_formatting() {
        let sampler = LlamaSampler::greedy();
        let debug_output = format!("{sampler:?}");
        assert!(debug_output.contains("LlamaSampler"));
    }

    // --- tests that need a loaded model -------------------------------------------------

    #[cfg(feature = "tests_that_use_llms")]
    #[test]
    #[serial_test::serial]
    fn dry_sampler_with_model() {
        let (_backend, model) = crate::test_model::load_default_model().unwrap();
        let breakers: Vec<&[u8]> = vec![b"\n", b"\t"];
        // `dry` is fallible; check the result instead of silently discarding it.
        let sampler = LlamaSampler::dry(&model, 1.5, 2.0, 128, 2, &breakers);

        assert!(sampler.is_ok());
    }

    #[cfg(feature = "tests_that_use_llms")]
    #[test]
    #[serial_test::serial]
    fn dry_sampler_with_null_byte_in_seq_breakers_returns_error() {
        let (_backend, model) = crate::test_model::load_default_model().unwrap();
        let breakers: Vec<&[u8]> = vec![b"hello\0world"];
        let result = LlamaSampler::dry(&model, 1.5, 2.0, 128, 2, breakers);

        assert!(result.is_err());
    }

    #[cfg(feature = "tests_that_use_llms")]
    #[test]
    #[serial_test::serial]
    fn grammar_returns_sampler_for_valid_grammar() {
        let (_backend, model) = crate::test_model::load_default_model().unwrap();
        let sampler = LlamaSampler::grammar(&model, "root ::= \"hello\"", "root");

        assert!(sampler.is_ok());
    }

    #[cfg(feature = "tests_that_use_llms")]
    #[test]
    #[serial_test::serial]
    fn grammar_lazy_returns_sampler_for_valid_grammar_with_triggers() {
        let (_backend, model) = crate::test_model::load_default_model().unwrap();
        let trigger_words: Vec<&[u8]> = vec![b"function"];
        let sampler =
            LlamaSampler::grammar_lazy(&model, "root ::= \"hello\"", "root", trigger_words, &[]);

        assert!(sampler.is_ok());
    }

    #[cfg(feature = "tests_that_use_llms")]
    #[test]
    #[serial_test::serial]
    fn grammar_lazy_patterns_returns_sampler_for_valid_grammar_with_patterns() {
        let (_backend, model) = crate::test_model::load_default_model().unwrap();
        let patterns = vec!["\\{.*".to_string()];
        let sampler = LlamaSampler::grammar_lazy_patterns(
            &model,
            "root ::= \"hello\"",
            "root",
            &patterns,
            &[],
        );

        assert!(sampler.is_ok());
    }

    #[cfg(feature = "tests_that_use_llms")]
    #[test]
    #[serial_test::serial]
    fn sample_returns_token_after_decode() {
        use crate::context::params::LlamaContextParams;
        use crate::llama_batch::LlamaBatch;
        use crate::model::AddBos;
        use crate::token::LlamaToken;

        let (backend, model) = crate::test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("Hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();
        let mut sampler =
            LlamaSampler::chain_simple([LlamaSampler::temp(0.8), LlamaSampler::greedy()]);
        let result = sampler.sample(&context, batch.n_tokens() - 1);

        assert!(result.is_ok());
    }

    // --- free helper functions ----------------------------------------------------------

    #[test]
    fn checked_u32_as_i32_overflow() {
        let result = super::checked_u32_as_i32(u32::MAX);
        assert!(result.is_err());
    }

    #[test]
    fn checked_usize_as_i32_sampling_overflow() {
        let result = super::checked_usize_as_i32_sampling(usize::MAX);
        assert!(result.is_err());
    }

    #[test]
    fn check_sampler_accept_status_ok() {
        let result = super::check_sampler_accept_status(
            llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK,
            std::ptr::null_mut(),
        );

        assert!(result.is_ok());
    }

    #[test]
    fn check_sampler_accept_status_invalid_argument() {
        let result = super::check_sampler_accept_status(
            llama_cpp_bindings_sys::LLAMA_RS_STATUS_INVALID_ARGUMENT,
            std::ptr::null_mut(),
        );

        assert!(matches!(
            result,
            Err(crate::SamplerAcceptError::InvalidArgument)
        ));
    }

    #[test]
    fn check_sampler_accept_status_exception() {
        let result = super::check_sampler_accept_status(
            llama_cpp_bindings_sys::LLAMA_RS_STATUS_EXCEPTION,
            std::ptr::null_mut(),
        );

        assert!(matches!(
            result,
            Err(crate::SamplerAcceptError::CppException(_))
        ));
    }

    #[test]
    fn check_sampler_not_null_returns_error() {
        let result = super::check_sampler_not_null(std::ptr::null_mut(), std::ptr::null_mut());

        assert!(result.is_err());
    }
}