1use std::borrow::Borrow;
4use std::ffi::{CString, c_char};
5use std::fmt::{Debug, Formatter};
6
7use crate::context::LlamaContext;
8use crate::ffi_error_reader::read_and_free_cpp_error;
9use crate::model::LlamaModel;
10use crate::token::LlamaToken;
11use crate::token::data_array::LlamaTokenDataArray;
12use crate::token::logit_bias::LlamaLogitBias;
13use crate::{GrammarError, SampleError, SamplerAcceptError, SamplingError};
14
15fn check_sampler_accept_status(
16 status: llama_cpp_bindings_sys::llama_rs_sampler_accept_status,
17 error_ptr: *mut c_char,
18) -> Result<(), SamplerAcceptError> {
19 match status {
20 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_ACCEPT_OK => Ok(()),
21 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_ACCEPT_ERROR_STRING_ALLOCATION_FAILED => {
22 Err(SamplerAcceptError::NotEnoughMemory)
23 }
24 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_ACCEPT_VENDORED_THREW_CXX_EXCEPTION => {
25 let message = unsafe { read_and_free_cpp_error(error_ptr) };
26 Err(SamplerAcceptError::GrammarStateCorrupted { message })
27 }
28 other => unreachable!("llama_rs_sampler_accept returned unrecognized status {other}"),
29 }
30}
31
32fn checked_u32_as_i32(value: u32) -> Result<i32, GrammarError> {
33 i32::try_from(value).map_err(|convert_error| {
34 GrammarError::IntegerOverflow(format!("value exceeds i32::MAX: {convert_error}"))
35 })
36}
37
38fn checked_usize_as_i32_sampling(value: usize) -> Result<i32, SamplingError> {
39 i32::try_from(value).map_err(|convert_error| {
40 SamplingError::IntegerOverflow(format!("value exceeds i32::MAX: {convert_error}"))
41 })
42}
43
/// Owning wrapper around a raw llama.cpp `llama_sampler` handle.
///
/// The handle is released with `llama_sampler_free` when the wrapper is dropped, so every
/// wrapper must own its pointer exclusively.
pub struct LlamaSampler {
    /// Raw native sampler handle.
    ///
    /// The field is `pub`, so code outside this module can read or replace it. Anything that
    /// takes ownership of the handle elsewhere (for example a sampler chain) must
    /// `std::mem::forget` the wrapper afterwards to avoid a double free.
    pub sampler: *mut llama_cpp_bindings_sys::llama_sampler,
}
49
50impl Debug for LlamaSampler {
51 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
52 f.debug_struct("LlamaSamplerChain").finish()
53 }
54}
55
56impl LlamaSampler {
57 pub fn sample(&mut self, ctx: &LlamaContext, idx: i32) -> Result<LlamaToken, SampleError> {
63 let mut token: i32 = -1;
64 let mut error_ptr: *mut c_char = std::ptr::null_mut();
65
66 let status = unsafe {
67 llama_cpp_bindings_sys::llama_rs_sampler_sample(
68 self.sampler,
69 ctx.context.as_ptr(),
70 idx,
71 &raw mut token,
72 &raw mut error_ptr,
73 )
74 };
75
76 match status {
77 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_SAMPLE_OK => Ok(LlamaToken(token)),
78 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_SAMPLE_ERROR_STRING_ALLOCATION_FAILED => {
79 Err(SampleError::NotEnoughMemory)
80 }
81 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_SAMPLE_VENDORED_THREW_CXX_EXCEPTION => {
82 let message = unsafe { read_and_free_cpp_error(error_ptr) };
83 Err(SampleError::Reported { message })
84 }
85 other => unreachable!("llama_rs_sampler_sample returned unrecognized status {other}"),
86 }
87 }
88
89 pub fn apply(&self, data_array: &mut LlamaTokenDataArray) {
91 data_array.apply_sampler(self);
92 }
93
94 pub fn accept(&mut self, token: LlamaToken) -> Result<(), SamplerAcceptError> {
100 self.try_accept(token)
101 }
102
103 pub fn accept_many(
109 &mut self,
110 tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
111 ) -> Result<(), SamplerAcceptError> {
112 for token in tokens {
113 self.try_accept(*token.borrow())?;
114 }
115
116 Ok(())
117 }
118
119 pub fn with_tokens(
125 mut self,
126 tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
127 ) -> Result<Self, SamplerAcceptError> {
128 self.accept_many(tokens)?;
129
130 Ok(self)
131 }
132
133 pub fn try_accept(&mut self, token: LlamaToken) -> Result<(), SamplerAcceptError> {
138 let mut error_ptr: *mut c_char = std::ptr::null_mut();
139
140 let status = unsafe {
141 llama_cpp_bindings_sys::llama_rs_sampler_accept(
142 self.sampler,
143 token.0,
144 &raw mut error_ptr,
145 )
146 };
147
148 check_sampler_accept_status(status, error_ptr)
149 }
150
151 pub fn reset(&mut self) {
155 unsafe {
156 llama_cpp_bindings_sys::llama_sampler_reset(self.sampler);
157 }
158 }
159
160 #[must_use]
167 pub fn get_seed(&self) -> u32 {
168 unsafe { llama_cpp_bindings_sys::llama_sampler_get_seed(self.sampler) }
169 }
170
171 #[must_use]
178 pub fn chain(samplers: impl IntoIterator<Item = Self>, no_perf: bool) -> Self {
179 unsafe {
180 let chain = llama_cpp_bindings_sys::llama_sampler_chain_init(
181 llama_cpp_bindings_sys::llama_sampler_chain_params { no_perf },
182 );
183
184 for sampler in samplers {
185 llama_cpp_bindings_sys::llama_sampler_chain_add(chain, sampler.sampler);
186 std::mem::forget(sampler);
187 }
188
189 Self { sampler: chain }
190 }
191 }
192
193 #[must_use]
225 pub fn chain_simple(samplers: impl IntoIterator<Item = Self>) -> Self {
226 Self::chain(samplers, false)
227 }
228
229 #[must_use]
254 pub fn temp(t: f32) -> Self {
255 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_temp(t) };
256 Self { sampler }
257 }
258
259 #[must_use]
262 pub fn temp_ext(t: f32, delta: f32, exponent: f32) -> Self {
263 let sampler =
264 unsafe { llama_cpp_bindings_sys::llama_sampler_init_temp_ext(t, delta, exponent) };
265 Self { sampler }
266 }
267
268 #[must_use]
294 pub fn top_k(k: i32) -> Self {
295 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_k(k) };
296 Self { sampler }
297 }
298
299 #[must_use]
325 pub fn top_n_sigma(n: f32) -> Self {
326 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_n_sigma(n) };
327 Self { sampler }
328 }
329
330 #[must_use]
332 pub fn typical(p: f32, min_keep: usize) -> Self {
333 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_typical(p, min_keep) };
334 Self { sampler }
335 }
336
337 #[must_use]
340 pub fn top_p(p: f32, min_keep: usize) -> Self {
341 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_p(p, min_keep) };
342 Self { sampler }
343 }
344
345 #[must_use]
347 pub fn min_p(p: f32, min_keep: usize) -> Self {
348 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_min_p(p, min_keep) };
349 Self { sampler }
350 }
351
352 #[must_use]
354 pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Self {
355 let sampler =
356 unsafe { llama_cpp_bindings_sys::llama_sampler_init_xtc(p, t, min_keep, seed) };
357 Self { sampler }
358 }
359
360 pub fn grammar(
365 model: &LlamaModel,
366 grammar_str: &str,
367 grammar_root: &str,
368 ) -> Result<Self, GrammarError> {
369 let (grammar_str, grammar_root) =
370 Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
371 let mut sampler: *mut llama_cpp_bindings_sys::llama_sampler = std::ptr::null_mut();
372 let mut error_ptr: *mut c_char = std::ptr::null_mut();
373
374 let status = unsafe {
375 llama_cpp_bindings_sys::llama_rs_sampler_init_grammar(
376 model.vocab_ptr(),
377 grammar_str.as_ptr(),
378 grammar_root.as_ptr(),
379 &raw mut sampler,
380 &raw mut error_ptr,
381 )
382 };
383
384 match status {
385 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_OK => {
386 Ok(Self { sampler })
387 }
388 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_VENDORED_RETURNED_NULL => {
389 Err(GrammarError::GrammarMalformed)
390 }
391 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_ERROR_STRING_ALLOCATION_FAILED => {
392 Err(GrammarError::NotEnoughMemory)
393 }
394 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_VENDORED_THREW_CXX_EXCEPTION => {
395 let message = unsafe { read_and_free_cpp_error(error_ptr) };
396 Err(GrammarError::Reported { message })
397 }
398 other => unreachable!(
399 "llama_rs_sampler_init_grammar returned unrecognized status {other}"
400 ),
401 }
402 }
403
404 pub fn grammar_lazy(
411 model: &LlamaModel,
412 grammar_str: &str,
413 grammar_root: &str,
414 trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
415 trigger_tokens: &[LlamaToken],
416 ) -> Result<Self, GrammarError> {
417 let (grammar_str, grammar_root) =
418 Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
419 let trigger_words = Self::sanitize_trigger_words(trigger_words)?;
420 let mut sampler: *mut llama_cpp_bindings_sys::llama_sampler = std::ptr::null_mut();
421 let mut error_ptr: *mut c_char = std::ptr::null_mut();
422
423 let mut trigger_word_ptrs: Vec<*const c_char> =
424 trigger_words.iter().map(|cs| cs.as_ptr()).collect();
425
426 let status = unsafe {
427 llama_cpp_bindings_sys::llama_rs_sampler_init_grammar_lazy(
428 model.vocab_ptr(),
429 grammar_str.as_ptr(),
430 grammar_root.as_ptr(),
431 trigger_word_ptrs.as_mut_ptr(),
432 trigger_word_ptrs.len(),
433 trigger_tokens.as_ptr().cast(),
434 trigger_tokens.len(),
435 &raw mut sampler,
436 &raw mut error_ptr,
437 )
438 };
439
440 match status {
441 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_OK => {
442 Ok(Self { sampler })
443 }
444 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_VENDORED_RETURNED_NULL => {
445 Err(GrammarError::LazyGrammarMalformed)
446 }
447 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_ERROR_STRING_ALLOCATION_FAILED => {
448 Err(GrammarError::NotEnoughMemory)
449 }
450 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_VENDORED_THREW_CXX_EXCEPTION => {
451 let message = unsafe { read_and_free_cpp_error(error_ptr) };
452 Err(GrammarError::Reported { message })
453 }
454 other => unreachable!(
455 "llama_rs_sampler_init_grammar_lazy returned unrecognized status {other}"
456 ),
457 }
458 }
459
460 pub fn grammar_lazy_patterns(
469 model: &LlamaModel,
470 grammar_str: &str,
471 grammar_root: &str,
472 trigger_patterns: &[String],
473 trigger_tokens: &[LlamaToken],
474 ) -> Result<Self, GrammarError> {
475 let (grammar_str, grammar_root) =
476 Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
477 let trigger_patterns = Self::sanitize_trigger_patterns(trigger_patterns)?;
478 let mut sampler: *mut llama_cpp_bindings_sys::llama_sampler = std::ptr::null_mut();
479 let mut error_ptr: *mut c_char = std::ptr::null_mut();
480
481 let mut trigger_pattern_ptrs: Vec<*const c_char> =
482 trigger_patterns.iter().map(|cs| cs.as_ptr()).collect();
483
484 let status = unsafe {
485 llama_cpp_bindings_sys::llama_rs_sampler_init_grammar_lazy_patterns(
486 model.vocab_ptr(),
487 grammar_str.as_ptr(),
488 grammar_root.as_ptr(),
489 trigger_pattern_ptrs.as_mut_ptr(),
490 trigger_pattern_ptrs.len(),
491 trigger_tokens.as_ptr().cast(),
492 trigger_tokens.len(),
493 &raw mut sampler,
494 &raw mut error_ptr,
495 )
496 };
497
498 match status {
499 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_PATTERNS_OK => {
500 Ok(Self { sampler })
501 }
502 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_PATTERNS_VENDORED_RETURNED_NULL => {
503 Err(GrammarError::LazyPatternsGrammarMalformed)
504 }
505 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_PATTERNS_ERROR_STRING_ALLOCATION_FAILED => {
506 Err(GrammarError::NotEnoughMemory)
507 }
508 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_PATTERNS_INVALID_TRIGGER_PATTERN => {
509 let message = unsafe { read_and_free_cpp_error(error_ptr) };
510 Err(GrammarError::InvalidTriggerPattern { message })
511 }
512 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_PATTERNS_VENDORED_THREW_CXX_EXCEPTION => {
513 let message = unsafe { read_and_free_cpp_error(error_ptr) };
514 Err(GrammarError::Reported { message })
515 }
516 other => unreachable!(
517 "llama_rs_sampler_init_grammar_lazy_patterns returned unrecognized status {other}"
518 ),
519 }
520 }
521
522 pub fn llguidance(
531 model: &LlamaModel,
532 grammar_kind: &str,
533 grammar_data: &str,
534 ) -> Result<Self, GrammarError> {
535 crate::llguidance_sampler::create_llg_sampler(model, grammar_kind, grammar_data)
536 }
537
538 fn sanitize_grammar_strings(
539 grammar_str: &str,
540 grammar_root: &str,
541 ) -> Result<(CString, CString), GrammarError> {
542 if !grammar_str.contains(grammar_root) {
543 return Err(GrammarError::RootNotFound);
544 }
545
546 let grammar = CString::new(grammar_str).map_err(GrammarError::GrammarNullBytes)?;
547 let root = CString::new(grammar_root).map_err(GrammarError::GrammarNullBytes)?;
548
549 Ok((grammar, root))
550 }
551
552 fn sanitize_trigger_words(
553 trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
554 ) -> Result<Vec<CString>, GrammarError> {
555 trigger_words
556 .into_iter()
557 .map(|word| CString::new(word.as_ref()).map_err(GrammarError::TriggerWordNullBytes))
558 .collect()
559 }
560
561 fn sanitize_trigger_patterns(
562 trigger_patterns: &[String],
563 ) -> Result<Vec<CString>, GrammarError> {
564 trigger_patterns
565 .iter()
566 .map(|pattern| CString::new(pattern.as_str()).map_err(GrammarError::GrammarNullBytes))
567 .collect()
568 }
569
570 pub fn dry(
577 model: &LlamaModel,
578 multiplier: f32,
579 base: f32,
580 allowed_length: i32,
581 penalty_last_n: i32,
582 seq_breakers: impl IntoIterator<Item = impl AsRef<[u8]>>,
583 ) -> Result<Self, GrammarError> {
584 let seq_breakers: Vec<CString> = seq_breakers
585 .into_iter()
586 .map(|seq_breaker| CString::new(seq_breaker.as_ref()))
587 .collect::<Result<Vec<_>, _>>()?;
588 let mut seq_breaker_pointers: Vec<*const c_char> = seq_breakers
589 .iter()
590 .map(|seq_breaker| seq_breaker.as_ptr())
591 .collect();
592
593 let n_ctx_train_value = model.n_ctx_train().map_err(|convert_error| {
594 GrammarError::IntegerOverflow(format!(
595 "n_ctx_train does not fit into u32: {convert_error}"
596 ))
597 })?;
598 let n_ctx_train = checked_u32_as_i32(n_ctx_train_value)?;
599 let sampler = unsafe {
600 llama_cpp_bindings_sys::llama_sampler_init_dry(
601 model.vocab_ptr(),
602 n_ctx_train,
603 multiplier,
604 base,
605 allowed_length,
606 penalty_last_n,
607 seq_breaker_pointers.as_mut_ptr(),
608 seq_breaker_pointers.len(),
609 )
610 };
611
612 Ok(Self { sampler })
613 }
614
615 #[must_use]
623 pub fn penalties(
624 penalty_last_n: i32,
625 penalty_repeat: f32,
626 penalty_freq: f32,
627 penalty_present: f32,
628 ) -> Self {
629 let sampler = unsafe {
630 llama_cpp_bindings_sys::llama_sampler_init_penalties(
631 penalty_last_n,
632 penalty_repeat,
633 penalty_freq,
634 penalty_present,
635 )
636 };
637 Self { sampler }
638 }
639
640 #[must_use]
656 pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self {
657 let sampler = unsafe {
658 llama_cpp_bindings_sys::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m)
659 };
660 Self { sampler }
661 }
662
663 #[must_use]
674 pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self {
675 let sampler =
676 unsafe { llama_cpp_bindings_sys::llama_sampler_init_mirostat_v2(seed, tau, eta) };
677 Self { sampler }
678 }
679
680 #[must_use]
682 pub fn dist(seed: u32) -> Self {
683 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_dist(seed) };
684 Self { sampler }
685 }
686
687 #[must_use]
709 pub fn greedy() -> Self {
710 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_greedy() };
711 Self { sampler }
712 }
713
714 pub fn logit_bias(n_vocab: i32, biases: &[LlamaLogitBias]) -> Result<Self, SamplingError> {
737 let bias_count = checked_usize_as_i32_sampling(biases.len())?;
738 let data = biases
739 .as_ptr()
740 .cast::<llama_cpp_bindings_sys::llama_logit_bias>();
741
742 let sampler = unsafe {
743 llama_cpp_bindings_sys::llama_sampler_init_logit_bias(n_vocab, bias_count, data)
744 };
745
746 Ok(Self { sampler })
747 }
748}
749
750impl Drop for LlamaSampler {
751 fn drop(&mut self) {
752 unsafe {
753 llama_cpp_bindings_sys::llama_sampler_free(self.sampler);
754 }
755 }
756}
757
#[cfg(test)]
mod tests {
    use super::LlamaSampler;
    use crate::GrammarError;

    #[test]
    fn sanitize_grammar_strings_valid() {
        let result = LlamaSampler::sanitize_grammar_strings("root ::= \"hello\"", "root");

        assert!(result.is_ok());
    }

    #[test]
    fn sanitize_grammar_strings_root_not_found() {
        let result = LlamaSampler::sanitize_grammar_strings("expr ::= \"hello\"", "root");

        assert!(matches!(result.err(), Some(GrammarError::RootNotFound)));
    }

    #[test]
    fn sanitize_grammar_strings_null_byte_in_grammar() {
        let result = LlamaSampler::sanitize_grammar_strings("root ::= \"\0\"", "root");

        assert!(matches!(
            result.err(),
            Some(GrammarError::GrammarNullBytes(_))
        ));
    }

    #[test]
    fn sanitize_grammar_strings_null_byte_in_root() {
        let result = LlamaSampler::sanitize_grammar_strings("ro\0ot ::= \"hello\"", "ro\0ot");

        assert!(matches!(
            result.err(),
            Some(GrammarError::GrammarNullBytes(_))
        ));
    }

    #[test]
    fn sanitize_trigger_words_valid() {
        let words: Vec<&[u8]> = vec![b"hello", b"world"];
        let result = LlamaSampler::sanitize_trigger_words(words);

        assert!(result.is_ok());
        assert_eq!(result.expect("valid trigger words").len(), 2);
    }

    #[test]
    fn sanitize_trigger_words_empty_list() {
        let words: Vec<&[u8]> = vec![];
        let result = LlamaSampler::sanitize_trigger_words(words);

        assert!(result.is_ok());
        assert!(result.expect("valid trigger words").is_empty());
    }

    #[test]
    fn sanitize_trigger_words_null_byte() {
        let words: Vec<&[u8]> = vec![b"hel\0lo"];
        let result = LlamaSampler::sanitize_trigger_words(words);

        assert!(matches!(
            result.err(),
            Some(GrammarError::TriggerWordNullBytes(_))
        ));
    }

    #[test]
    fn sanitize_trigger_patterns_valid() {
        let patterns = vec!["^hello$".to_string(), "world.*".to_string()];
        let result = LlamaSampler::sanitize_trigger_patterns(&patterns);

        assert!(result.is_ok());
        assert_eq!(result.expect("valid trigger patterns").len(), 2);
    }

    #[test]
    fn sanitize_trigger_patterns_empty_list() {
        let patterns: Vec<String> = vec![];
        let result = LlamaSampler::sanitize_trigger_patterns(&patterns);

        assert!(result.is_ok());
        assert!(result.expect("valid trigger patterns").is_empty());
    }

    #[test]
    fn sanitize_trigger_patterns_null_byte() {
        let patterns = vec!["hel\0lo".to_string()];
        let result = LlamaSampler::sanitize_trigger_patterns(&patterns);

        assert!(matches!(
            result.err(),
            Some(GrammarError::GrammarNullBytes(_))
        ));
    }

    #[test]
    fn apply_modifies_data_array() {
        use crate::token::LlamaToken;
        use crate::token::data::LlamaTokenData;
        use crate::token::data_array::LlamaTokenDataArray;

        let sampler = LlamaSampler::greedy();
        let mut data_array = LlamaTokenDataArray::new(
            vec![
                LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
                LlamaTokenData::new(LlamaToken::new(1), 5.0, 0.0),
            ],
            false,
        );

        sampler.apply(&mut data_array);

        assert_eq!(data_array.selected_token(), Some(LlamaToken::new(1)));
    }

    #[test]
    fn accept_succeeds() {
        let mut sampler = LlamaSampler::chain_simple([
            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
            LlamaSampler::greedy(),
        ]);

        sampler
            .accept(crate::token::LlamaToken::new(1))
            .expect("test: accept should succeed");
    }

    #[test]
    fn try_accept_succeeds_on_penalties_sampler() {
        let mut sampler = LlamaSampler::chain_simple([
            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
            LlamaSampler::greedy(),
        ]);

        let result = sampler.try_accept(crate::token::LlamaToken::new(42));

        assert!(result.is_ok());
    }

    #[test]
    fn accept_many_multiple_tokens() {
        use crate::token::LlamaToken;

        let mut sampler = LlamaSampler::chain_simple([
            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
            LlamaSampler::greedy(),
        ]);

        sampler
            .accept_many([LlamaToken::new(1), LlamaToken::new(2), LlamaToken::new(3)])
            .expect("test: accept_many should succeed");
    }

    #[test]
    fn with_tokens_builder_pattern() {
        use crate::token::LlamaToken;

        let _sampler = LlamaSampler::chain_simple([
            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
            LlamaSampler::greedy(),
        ])
        .with_tokens([LlamaToken::new(10), LlamaToken::new(20)])
        .expect("test: with_tokens should succeed");
    }

    #[test]
    fn chain_simple_accepts_no_samplers() {
        // An empty chain is still a valid, owned native sampler.
        let samplers: Vec<LlamaSampler> = Vec::new();
        let chain = LlamaSampler::chain_simple(samplers);

        assert!(!chain.sampler.is_null());
    }

    #[test]
    fn all_sampler_constructors() {
        use crate::token::LlamaToken;
        use crate::token::logit_bias::LlamaLogitBias;

        let _temp = LlamaSampler::temp(0.8);
        let _temp_ext = LlamaSampler::temp_ext(0.8, 0.1, 1.0);
        let _top_k = LlamaSampler::top_k(40);
        let _top_n_sigma = LlamaSampler::top_n_sigma(2.0);
        let _top_p = LlamaSampler::top_p(0.9, 1);
        let _min_p = LlamaSampler::min_p(0.05, 1);
        let _typical = LlamaSampler::typical(0.9, 1);
        let _xtc = LlamaSampler::xtc(0.1, 0.5, 1, 42);
        let _dist = LlamaSampler::dist(42);
        let _mirostat = LlamaSampler::mirostat(32000, 42, 5.0, 0.1, 100);
        let _mirostat_v2 = LlamaSampler::mirostat_v2(42, 5.0, 0.1);
        let biases = vec![LlamaLogitBias::new(LlamaToken::new(0), -100.0)];
        // The result used to be discarded; a single bias must always be accepted.
        let _logit_bias = LlamaSampler::logit_bias(32000, &biases)
            .expect("test: logit_bias with one entry should succeed");
        let _chain = LlamaSampler::chain([LlamaSampler::greedy()], true);
    }

    #[test]
    fn reset_and_get_seed() {
        let mut sampler = LlamaSampler::dist(42);
        sampler.reset();
        let _seed = sampler.get_seed();
    }

    #[test]
    fn debug_formatting() {
        let sampler = LlamaSampler::greedy();
        let debug_output = format!("{sampler:?}");
        assert!(debug_output.contains("LlamaSampler"));
    }

    #[test]
    fn checked_u32_as_i32_overflow() {
        let result = super::checked_u32_as_i32(u32::MAX);
        assert!(result.is_err());
    }

    #[test]
    fn checked_u32_as_i32_accepts_i32_max() {
        let result = super::checked_u32_as_i32(2_147_483_647);
        assert_eq!(result.ok(), Some(i32::MAX));
    }

    #[test]
    fn checked_u32_as_i32_rejects_one_past_i32_max() {
        let result = super::checked_u32_as_i32(2_147_483_648);
        assert!(matches!(result, Err(GrammarError::IntegerOverflow(_))));
    }

    #[test]
    fn checked_usize_as_i32_sampling_overflow() {
        let result = super::checked_usize_as_i32_sampling(usize::MAX);
        assert!(result.is_err());
    }

    #[test]
    fn checked_usize_as_i32_sampling_accepts_in_range() {
        assert_eq!(super::checked_usize_as_i32_sampling(0).ok(), Some(0));
        assert_eq!(
            super::checked_usize_as_i32_sampling(2_147_483_647).ok(),
            Some(i32::MAX)
        );
    }

    #[test]
    fn check_sampler_accept_status_ok() {
        let result = super::check_sampler_accept_status(
            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_ACCEPT_OK,
            std::ptr::null_mut(),
        );

        assert!(result.is_ok());
    }

    #[test]
    fn check_sampler_accept_status_allocation_failure_maps_to_not_enough_memory() {
        let result = super::check_sampler_accept_status(
            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_ACCEPT_ERROR_STRING_ALLOCATION_FAILED,
            std::ptr::null_mut(),
        );

        assert!(matches!(
            result,
            Err(crate::SamplerAcceptError::NotEnoughMemory)
        ));
    }

    #[test]
    fn check_sampler_accept_status_exception_maps_to_typed_variant() {
        let result = super::check_sampler_accept_status(
            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_ACCEPT_VENDORED_THREW_CXX_EXCEPTION,
            std::ptr::null_mut(),
        );

        assert!(matches!(
            result,
            Err(crate::SamplerAcceptError::GrammarStateCorrupted { .. })
        ));
    }
}