1use std::borrow::Borrow;
4use std::ffi::{CString, c_char};
5use std::fmt::{Debug, Formatter};
6
7use crate::context::LlamaContext;
8use crate::model::LlamaModel;
9use crate::token::LlamaToken;
10use crate::token::data_array::LlamaTokenDataArray;
11use crate::token::logit_bias::LlamaLogitBias;
12use crate::{GrammarError, SamplerAcceptError, SamplingError, status_is_ok, status_to_i32};
13
14const fn check_sampler_accept_status(
15 status: llama_cpp_bindings_sys::llama_rs_status,
16) -> Result<(), SamplerAcceptError> {
17 if status_is_ok(status) {
18 Ok(())
19 } else {
20 Err(SamplerAcceptError::FfiError(status_to_i32(status)))
21 }
22}
23
24const fn check_sampler_not_null(
25 sampler: *mut llama_cpp_bindings_sys::llama_sampler,
26) -> Result<LlamaSampler, GrammarError> {
27 if sampler.is_null() {
28 Err(GrammarError::NullGrammar)
29 } else {
30 Ok(LlamaSampler { sampler })
31 }
32}
33
34fn checked_u32_as_i32(value: u32) -> Result<i32, GrammarError> {
35 i32::try_from(value).map_err(|convert_error| {
36 GrammarError::IntegerOverflow(format!("value exceeds i32::MAX: {convert_error}"))
37 })
38}
39
40fn checked_usize_as_i32_sampling(value: usize) -> Result<i32, SamplingError> {
41 i32::try_from(value).map_err(|convert_error| {
42 SamplingError::IntegerOverflow(format!("value exceeds i32::MAX: {convert_error}"))
43 })
44}
45
/// Owning wrapper around a raw `llama_sampler` pointer from `llama_cpp_bindings_sys`.
///
/// The `Drop` implementation in this file passes the pointer to `llama_sampler_free`, so a value
/// of this type is expected to be the only owner of the native sampler it points at.
pub struct LlamaSampler {
    /// Raw native sampler pointer used by every FFI call made through this wrapper.
    ///
    /// NOTE(review): the field is public, so nothing in this file can guarantee it is non-null
    /// or still live; callers that overwrite it take on that responsibility.
    pub sampler: *mut llama_cpp_bindings_sys::llama_sampler,
}
51
52impl Debug for LlamaSampler {
53 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
54 f.debug_struct("LlamaSamplerChain").finish()
55 }
56}
57
58impl LlamaSampler {
59 #[must_use]
61 pub fn sample(&mut self, ctx: &LlamaContext, idx: i32) -> LlamaToken {
62 let token = unsafe {
63 llama_cpp_bindings_sys::llama_sampler_sample(self.sampler, ctx.context.as_ptr(), idx)
64 };
65
66 LlamaToken(token)
67 }
68
69 pub fn apply(&self, data_array: &mut LlamaTokenDataArray) {
71 data_array.apply_sampler(self);
72 }
73
74 pub fn accept(&mut self, token: LlamaToken) -> Result<(), SamplerAcceptError> {
80 self.try_accept(token)
81 }
82
83 pub fn accept_many(
89 &mut self,
90 tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
91 ) -> Result<(), SamplerAcceptError> {
92 for token in tokens {
93 self.try_accept(*token.borrow())?;
94 }
95
96 Ok(())
97 }
98
99 pub fn with_tokens(
105 mut self,
106 tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
107 ) -> Result<Self, SamplerAcceptError> {
108 self.accept_many(tokens)?;
109
110 Ok(self)
111 }
112
113 pub fn try_accept(&mut self, token: LlamaToken) -> Result<(), SamplerAcceptError> {
118 let sampler_result =
119 unsafe { llama_cpp_bindings_sys::llama_rs_sampler_accept(self.sampler, token.0) };
120
121 check_sampler_accept_status(sampler_result)
122 }
123
124 pub fn reset(&mut self) {
128 unsafe {
129 llama_cpp_bindings_sys::llama_sampler_reset(self.sampler);
130 }
131 }
132
133 #[must_use]
140 pub fn get_seed(&self) -> u32 {
141 unsafe { llama_cpp_bindings_sys::llama_sampler_get_seed(self.sampler) }
142 }
143
144 #[must_use]
151 pub fn chain(samplers: impl IntoIterator<Item = Self>, no_perf: bool) -> Self {
152 unsafe {
153 let chain = llama_cpp_bindings_sys::llama_sampler_chain_init(
154 llama_cpp_bindings_sys::llama_sampler_chain_params { no_perf },
155 );
156
157 for sampler in samplers {
158 llama_cpp_bindings_sys::llama_sampler_chain_add(chain, sampler.sampler);
159
160 std::mem::forget(sampler);
163 }
164
165 Self { sampler: chain }
166 }
167 }
168
169 #[must_use]
201 pub fn chain_simple(samplers: impl IntoIterator<Item = Self>) -> Self {
202 Self::chain(samplers, false)
203 }
204
205 #[must_use]
230 pub fn temp(t: f32) -> Self {
231 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_temp(t) };
232 Self { sampler }
233 }
234
235 #[must_use]
238 pub fn temp_ext(t: f32, delta: f32, exponent: f32) -> Self {
239 let sampler =
240 unsafe { llama_cpp_bindings_sys::llama_sampler_init_temp_ext(t, delta, exponent) };
241 Self { sampler }
242 }
243
244 #[must_use]
270 pub fn top_k(k: i32) -> Self {
271 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_k(k) };
272 Self { sampler }
273 }
274
275 #[must_use]
301 pub fn top_n_sigma(n: f32) -> Self {
302 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_n_sigma(n) };
303 Self { sampler }
304 }
305
306 #[must_use]
308 pub fn typical(p: f32, min_keep: usize) -> Self {
309 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_typical(p, min_keep) };
310 Self { sampler }
311 }
312
313 #[must_use]
316 pub fn top_p(p: f32, min_keep: usize) -> Self {
317 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_p(p, min_keep) };
318 Self { sampler }
319 }
320
321 #[must_use]
323 pub fn min_p(p: f32, min_keep: usize) -> Self {
324 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_min_p(p, min_keep) };
325 Self { sampler }
326 }
327
328 #[must_use]
330 pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Self {
331 let sampler =
332 unsafe { llama_cpp_bindings_sys::llama_sampler_init_xtc(p, t, min_keep, seed) };
333 Self { sampler }
334 }
335
336 pub fn grammar(
341 model: &LlamaModel,
342 grammar_str: &str,
343 grammar_root: &str,
344 ) -> Result<Self, GrammarError> {
345 let (grammar_str, grammar_root) =
346 Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
347
348 let sampler = unsafe {
349 llama_cpp_bindings_sys::llama_rs_sampler_init_grammar(
350 model.vocab_ptr(),
351 grammar_str.as_ptr(),
352 grammar_root.as_ptr(),
353 )
354 };
355
356 check_sampler_not_null(sampler)
357 }
358
359 pub fn grammar_lazy(
366 model: &LlamaModel,
367 grammar_str: &str,
368 grammar_root: &str,
369 trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
370 trigger_tokens: &[LlamaToken],
371 ) -> Result<Self, GrammarError> {
372 let (grammar_str, grammar_root) =
373 Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
374 let trigger_words = Self::sanitize_trigger_words(trigger_words)?;
375
376 let mut trigger_word_ptrs: Vec<*const c_char> =
377 trigger_words.iter().map(|cs| cs.as_ptr()).collect();
378
379 let sampler = unsafe {
380 llama_cpp_bindings_sys::llama_rs_sampler_init_grammar_lazy(
381 model.vocab_ptr(),
382 grammar_str.as_ptr(),
383 grammar_root.as_ptr(),
384 trigger_word_ptrs.as_mut_ptr(),
385 trigger_word_ptrs.len(),
386 trigger_tokens.as_ptr().cast(),
387 trigger_tokens.len(),
388 )
389 };
390
391 check_sampler_not_null(sampler)
392 }
393
394 pub fn grammar_lazy_patterns(
403 model: &LlamaModel,
404 grammar_str: &str,
405 grammar_root: &str,
406 trigger_patterns: &[String],
407 trigger_tokens: &[LlamaToken],
408 ) -> Result<Self, GrammarError> {
409 let (grammar_str, grammar_root) =
410 Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
411 let trigger_patterns = Self::sanitize_trigger_patterns(trigger_patterns)?;
412
413 let mut trigger_pattern_ptrs: Vec<*const c_char> =
414 trigger_patterns.iter().map(|cs| cs.as_ptr()).collect();
415
416 let sampler = unsafe {
417 llama_cpp_bindings_sys::llama_rs_sampler_init_grammar_lazy_patterns(
418 model.vocab_ptr(),
419 grammar_str.as_ptr(),
420 grammar_root.as_ptr(),
421 trigger_pattern_ptrs.as_mut_ptr(),
422 trigger_pattern_ptrs.len(),
423 trigger_tokens.as_ptr().cast(),
424 trigger_tokens.len(),
425 )
426 };
427
428 check_sampler_not_null(sampler)
429 }
430
431 #[cfg(feature = "llguidance")]
440 pub fn llguidance(
441 model: &LlamaModel,
442 grammar_kind: &str,
443 grammar_data: &str,
444 ) -> Result<Self, GrammarError> {
445 crate::llguidance_sampler::create_llg_sampler(model, grammar_kind, grammar_data)
446 }
447
448 fn sanitize_grammar_strings(
449 grammar_str: &str,
450 grammar_root: &str,
451 ) -> Result<(CString, CString), GrammarError> {
452 if !grammar_str.contains(grammar_root) {
453 return Err(GrammarError::RootNotFound);
454 }
455
456 let grammar = CString::new(grammar_str).map_err(GrammarError::GrammarNullBytes)?;
457 let root = CString::new(grammar_root).map_err(GrammarError::GrammarNullBytes)?;
458
459 Ok((grammar, root))
460 }
461
462 fn sanitize_trigger_words(
463 trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
464 ) -> Result<Vec<CString>, GrammarError> {
465 trigger_words
466 .into_iter()
467 .map(|word| CString::new(word.as_ref()).map_err(GrammarError::TriggerWordNullBytes))
468 .collect()
469 }
470
471 fn sanitize_trigger_patterns(
472 trigger_patterns: &[String],
473 ) -> Result<Vec<CString>, GrammarError> {
474 trigger_patterns
475 .iter()
476 .map(|pattern| CString::new(pattern.as_str()).map_err(GrammarError::GrammarNullBytes))
477 .collect()
478 }
479
480 #[allow(missing_docs)]
487 pub fn dry(
488 model: &LlamaModel,
489 multiplier: f32,
490 base: f32,
491 allowed_length: i32,
492 penalty_last_n: i32,
493 seq_breakers: impl IntoIterator<Item = impl AsRef<[u8]>>,
494 ) -> Result<Self, GrammarError> {
495 let seq_breakers: Vec<CString> = seq_breakers
496 .into_iter()
497 .map(|s| CString::new(s.as_ref()))
498 .collect::<Result<Vec<_>, _>>()?;
499 let mut seq_breaker_pointers: Vec<*const c_char> =
500 seq_breakers.iter().map(|s| s.as_ptr()).collect();
501
502 let n_ctx_train = checked_u32_as_i32(model.n_ctx_train())?;
503 let sampler = unsafe {
504 llama_cpp_bindings_sys::llama_sampler_init_dry(
505 model.vocab_ptr(),
506 n_ctx_train,
507 multiplier,
508 base,
509 allowed_length,
510 penalty_last_n,
511 seq_breaker_pointers.as_mut_ptr(),
512 seq_breaker_pointers.len(),
513 )
514 };
515
516 Ok(Self { sampler })
517 }
518
519 #[must_use]
527 pub fn penalties(
528 penalty_last_n: i32,
529 penalty_repeat: f32,
530 penalty_freq: f32,
531 penalty_present: f32,
532 ) -> Self {
533 let sampler = unsafe {
534 llama_cpp_bindings_sys::llama_sampler_init_penalties(
535 penalty_last_n,
536 penalty_repeat,
537 penalty_freq,
538 penalty_present,
539 )
540 };
541 Self { sampler }
542 }
543
544 #[must_use]
560 pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self {
561 let sampler = unsafe {
562 llama_cpp_bindings_sys::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m)
563 };
564 Self { sampler }
565 }
566
567 #[must_use]
578 pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self {
579 let sampler =
580 unsafe { llama_cpp_bindings_sys::llama_sampler_init_mirostat_v2(seed, tau, eta) };
581 Self { sampler }
582 }
583
584 #[must_use]
586 pub fn dist(seed: u32) -> Self {
587 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_dist(seed) };
588 Self { sampler }
589 }
590
591 #[must_use]
613 pub fn greedy() -> Self {
614 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_greedy() };
615 Self { sampler }
616 }
617
618 pub fn logit_bias(n_vocab: i32, biases: &[LlamaLogitBias]) -> Result<Self, SamplingError> {
641 let bias_count = checked_usize_as_i32_sampling(biases.len())?;
642 let data = biases
643 .as_ptr()
644 .cast::<llama_cpp_bindings_sys::llama_logit_bias>();
645
646 let sampler = unsafe {
647 llama_cpp_bindings_sys::llama_sampler_init_logit_bias(n_vocab, bias_count, data)
648 };
649
650 Ok(Self { sampler })
651 }
652}
653
654impl Drop for LlamaSampler {
655 fn drop(&mut self) {
656 unsafe {
657 llama_cpp_bindings_sys::llama_sampler_free(self.sampler);
658 }
659 }
660}
661
#[cfg(test)]
mod tests {
    use super::{
        LlamaSampler, check_sampler_accept_status, check_sampler_not_null, checked_u32_as_i32,
        checked_usize_as_i32_sampling,
    };
    use crate::GrammarError;
    use crate::token::LlamaToken;

    /// Minimal grammar whose root rule is named `root`.
    const HELLO_GRAMMAR: &str = "root ::= \"hello\"";

    /// Builds the penalties-then-greedy chain shared by the accept tests.
    fn penalties_then_greedy() -> LlamaSampler {
        LlamaSampler::chain_simple([
            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
            LlamaSampler::greedy(),
        ])
    }

    #[test]
    fn sanitize_grammar_strings_valid() {
        let sanitized = LlamaSampler::sanitize_grammar_strings(HELLO_GRAMMAR, "root");

        assert!(sanitized.is_ok());
    }

    #[test]
    fn sanitize_grammar_strings_root_not_found() {
        let outcome = LlamaSampler::sanitize_grammar_strings("expr ::= \"hello\"", "root");

        assert!(matches!(outcome, Err(GrammarError::RootNotFound)));
    }

    #[test]
    fn sanitize_grammar_strings_null_byte_in_grammar() {
        let outcome = LlamaSampler::sanitize_grammar_strings("root ::= \"\0\"", "root");

        assert!(matches!(outcome, Err(GrammarError::GrammarNullBytes(_))));
    }

    #[test]
    fn sanitize_grammar_strings_null_byte_in_root() {
        let outcome = LlamaSampler::sanitize_grammar_strings("ro\0ot ::= \"hello\"", "ro\0ot");

        assert!(matches!(outcome, Err(GrammarError::GrammarNullBytes(_))));
    }

    #[test]
    fn sanitize_trigger_words_valid() {
        let words: [&[u8]; 2] = [b"hello", b"world"];

        let sanitized = LlamaSampler::sanitize_trigger_words(words).expect("valid trigger words");

        assert_eq!(sanitized.len(), 2);
    }

    #[test]
    fn sanitize_trigger_words_empty_list() {
        let words: Vec<&[u8]> = Vec::new();

        let sanitized = LlamaSampler::sanitize_trigger_words(words).expect("valid trigger words");

        assert!(sanitized.is_empty());
    }

    #[test]
    fn sanitize_trigger_words_null_byte() {
        let words: [&[u8]; 1] = [b"hel\0lo"];

        let outcome = LlamaSampler::sanitize_trigger_words(words);

        assert!(matches!(
            outcome,
            Err(GrammarError::TriggerWordNullBytes(_))
        ));
    }

    #[test]
    fn sanitize_trigger_patterns_valid() {
        let patterns = [String::from("^hello$"), String::from("world.*")];

        let sanitized =
            LlamaSampler::sanitize_trigger_patterns(&patterns).expect("valid trigger patterns");

        assert_eq!(sanitized.len(), 2);
    }

    #[test]
    fn sanitize_trigger_patterns_empty_list() {
        let patterns: Vec<String> = Vec::new();

        let sanitized =
            LlamaSampler::sanitize_trigger_patterns(&patterns).expect("valid trigger patterns");

        assert!(sanitized.is_empty());
    }

    #[test]
    fn sanitize_trigger_patterns_null_byte() {
        let patterns = [String::from("hel\0lo")];

        let outcome = LlamaSampler::sanitize_trigger_patterns(&patterns);

        assert!(matches!(outcome, Err(GrammarError::GrammarNullBytes(_))));
    }

    #[test]
    fn apply_modifies_data_array() {
        use crate::token::data::LlamaTokenData;
        use crate::token::data_array::LlamaTokenDataArray;

        // Token 1 carries the larger logit, so the greedy sampler is expected to select it.
        let candidates = vec![
            LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
            LlamaTokenData::new(LlamaToken::new(1), 5.0, 0.0),
        ];
        let mut data_array = LlamaTokenDataArray::new(candidates, false);
        let greedy = LlamaSampler::greedy();

        greedy.apply(&mut data_array);

        assert_eq!(data_array.selected_token(), Some(LlamaToken::new(1)));
    }

    #[test]
    fn accept_succeeds() {
        let mut chain = penalties_then_greedy();

        chain
            .accept(LlamaToken::new(1))
            .expect("test: accept should succeed");
    }

    #[test]
    fn try_accept_succeeds_on_penalties_sampler() {
        let mut chain = penalties_then_greedy();

        let outcome = chain.try_accept(LlamaToken::new(42));

        assert!(outcome.is_ok());
    }

    #[test]
    fn accept_many_multiple_tokens() {
        let mut chain = penalties_then_greedy();
        let history = [LlamaToken::new(1), LlamaToken::new(2), LlamaToken::new(3)];

        chain
            .accept_many(history)
            .expect("test: accept_many should succeed");
    }

    #[test]
    fn with_tokens_builder_pattern() {
        let history = [LlamaToken::new(10), LlamaToken::new(20)];

        let _sampler = penalties_then_greedy()
            .with_tokens(history)
            .expect("test: with_tokens should succeed");
    }

    #[test]
    fn all_sampler_constructors() {
        use crate::token::logit_bias::LlamaLogitBias;

        // Every binding is kept alive until the end of the test, so each constructor's result
        // is also exercised through `Drop`.
        let _temp = LlamaSampler::temp(0.8);
        let _temp_ext = LlamaSampler::temp_ext(0.8, 0.1, 1.0);
        let _top_k = LlamaSampler::top_k(40);
        let _top_n_sigma = LlamaSampler::top_n_sigma(2.0);
        let _top_p = LlamaSampler::top_p(0.9, 1);
        let _min_p = LlamaSampler::min_p(0.05, 1);
        let _typical = LlamaSampler::typical(0.9, 1);
        let _xtc = LlamaSampler::xtc(0.1, 0.5, 1, 42);
        let _dist = LlamaSampler::dist(42);
        let _mirostat = LlamaSampler::mirostat(32000, 42, 5.0, 0.1, 100);
        let _mirostat_v2 = LlamaSampler::mirostat_v2(42, 5.0, 0.1);
        let biases = vec![LlamaLogitBias::new(LlamaToken::new(0), -100.0)];
        let _logit_bias = LlamaSampler::logit_bias(32000, &biases);
        let _chain = LlamaSampler::chain([LlamaSampler::greedy()], true);
    }

    #[test]
    fn reset_and_get_seed() {
        let mut dist = LlamaSampler::dist(42);

        dist.reset();

        let _seed = dist.get_seed();
    }

    #[test]
    fn debug_formatting() {
        let greedy = LlamaSampler::greedy();

        let rendered = format!("{greedy:?}");

        assert!(rendered.contains("LlamaSampler"));
    }

    #[cfg(feature = "tests_that_use_llms")]
    #[test]
    #[serial_test::serial]
    fn dry_sampler_with_model() {
        let (_backend, model) = crate::test_model::load_default_model().unwrap();
        let breakers: [&[u8]; 2] = [b"\n", b"\t"];

        let _sampler = LlamaSampler::dry(&model, 1.5, 2.0, 128, 2, &breakers);
    }

    #[cfg(feature = "tests_that_use_llms")]
    #[test]
    #[serial_test::serial]
    fn dry_sampler_with_null_byte_in_seq_breakers_returns_error() {
        let (_backend, model) = crate::test_model::load_default_model().unwrap();
        let breakers: [&[u8]; 1] = [b"hello\0world"];

        let outcome = LlamaSampler::dry(&model, 1.5, 2.0, 128, 2, breakers);

        assert!(outcome.is_err());
    }

    #[cfg(feature = "tests_that_use_llms")]
    #[test]
    #[serial_test::serial]
    fn grammar_returns_sampler_for_valid_grammar() {
        let (_backend, model) = crate::test_model::load_default_model().unwrap();

        let outcome = LlamaSampler::grammar(&model, HELLO_GRAMMAR, "root");

        assert!(outcome.is_ok());
    }

    #[cfg(feature = "tests_that_use_llms")]
    #[test]
    #[serial_test::serial]
    fn grammar_lazy_returns_sampler_for_valid_grammar_with_triggers() {
        let (_backend, model) = crate::test_model::load_default_model().unwrap();
        let trigger_words: [&[u8]; 1] = [b"function"];

        let outcome = LlamaSampler::grammar_lazy(&model, HELLO_GRAMMAR, "root", trigger_words, &[]);

        assert!(outcome.is_ok());
    }

    #[cfg(feature = "tests_that_use_llms")]
    #[test]
    #[serial_test::serial]
    fn grammar_lazy_patterns_returns_sampler_for_valid_grammar_with_patterns() {
        let (_backend, model) = crate::test_model::load_default_model().unwrap();
        let patterns = [String::from("\\{.*")];

        let outcome =
            LlamaSampler::grammar_lazy_patterns(&model, HELLO_GRAMMAR, "root", &patterns, &[]);

        assert!(outcome.is_ok());
    }

    #[cfg(feature = "tests_that_use_llms")]
    #[test]
    #[serial_test::serial]
    fn sample_returns_token_after_decode() {
        use crate::context::params::LlamaContextParams;
        use crate::llama_batch::LlamaBatch;
        use crate::model::AddBos;

        let (backend, model) = crate::test_model::load_default_model().unwrap();

        // Decode a short prompt so the context holds logits to sample from.
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let prompt = model.str_to_token("Hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&prompt, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let mut chain =
            LlamaSampler::chain_simple([LlamaSampler::temp(0.8), LlamaSampler::greedy()]);
        let last_index = batch.n_tokens() - 1;
        let sampled = chain.sample(&context, last_index);

        assert_ne!(sampled, LlamaToken::new(-1));
    }

    #[test]
    fn checked_u32_as_i32_overflow() {
        let outcome = checked_u32_as_i32(u32::MAX);

        assert!(outcome.is_err());
    }

    #[test]
    fn checked_usize_as_i32_sampling_overflow() {
        let outcome = checked_usize_as_i32_sampling(usize::MAX);

        assert!(outcome.is_err());
    }

    #[test]
    fn check_sampler_accept_status_error() {
        let failing_status = llama_cpp_bindings_sys::LLAMA_RS_STATUS_EXCEPTION;

        let outcome = check_sampler_accept_status(failing_status);

        assert!(outcome.is_err());
    }

    #[test]
    fn check_sampler_not_null_returns_error() {
        let outcome = check_sampler_not_null(std::ptr::null_mut());

        assert!(outcome.is_err());
    }
}