Skip to main content

llama_cpp_bindings/
sampling.rs

1use std::borrow::Borrow;
2use std::ffi::{CString, c_char};
3use std::fmt::{Debug, Formatter};
4
5use crate::context::LlamaContext;
6use crate::ffi_error_reader::read_and_free_cpp_error;
7use crate::model::LlamaModel;
8use crate::token::LlamaToken;
9use crate::token::data_array::LlamaTokenDataArray;
10use crate::token::logit_bias::LlamaLogitBias;
11use crate::{GrammarError, SampleError, SamplerAcceptError, SamplingError};
12
13fn check_sampler_accept_status(
14    status: llama_cpp_bindings_sys::llama_rs_sampler_accept_status,
15    error_ptr: *mut c_char,
16) -> Result<(), SamplerAcceptError> {
17    match status {
18        llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_ACCEPT_OK => Ok(()),
19        llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_ACCEPT_ERROR_STRING_ALLOCATION_FAILED => {
20            Err(SamplerAcceptError::NotEnoughMemory)
21        }
22        llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_ACCEPT_VENDORED_THREW_CXX_EXCEPTION => {
23            let message = unsafe { read_and_free_cpp_error(error_ptr) };
24            Err(SamplerAcceptError::GrammarStateCorrupted { message })
25        }
26        other => unreachable!("llama_rs_sampler_accept returned unrecognized status {other}"),
27    }
28}
29
30fn checked_u32_as_i32(value: u32) -> Result<i32, GrammarError> {
31    i32::try_from(value).map_err(|convert_error| {
32        GrammarError::IntegerOverflow(format!("value exceeds i32::MAX: {convert_error}"))
33    })
34}
35
36fn checked_usize_as_i32_sampling(value: usize) -> Result<i32, SamplingError> {
37    i32::try_from(value).map_err(|convert_error| {
38        SamplingError::IntegerOverflow(format!("value exceeds i32::MAX: {convert_error}"))
39    })
40}
41
42pub struct LlamaSampler {
43    pub sampler: *mut llama_cpp_bindings_sys::llama_sampler,
44}
45
46impl Debug for LlamaSampler {
47    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
48        f.debug_struct("LlamaSamplerChain").finish()
49    }
50}
51
52impl LlamaSampler {
53    /// # Errors
54    ///
55    /// Returns [`SampleError`] if the C++ sampler throws an exception or if the index is invalid.
56    pub fn sample(&mut self, ctx: &LlamaContext, idx: i32) -> Result<LlamaToken, SampleError> {
57        let mut token: i32 = -1;
58        let mut error_ptr: *mut c_char = std::ptr::null_mut();
59
60        let status = unsafe {
61            llama_cpp_bindings_sys::llama_rs_sampler_sample(
62                self.sampler,
63                ctx.context.as_ptr(),
64                idx,
65                &raw mut token,
66                &raw mut error_ptr,
67            )
68        };
69
70        match status {
71            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_SAMPLE_OK => Ok(LlamaToken(token)),
72            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_SAMPLE_ERROR_STRING_ALLOCATION_FAILED => {
73                Err(SampleError::NotEnoughMemory)
74            }
75            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_SAMPLE_VENDORED_THREW_CXX_EXCEPTION => {
76                let message = unsafe { read_and_free_cpp_error(error_ptr) };
77                Err(SampleError::Reported { message })
78            }
79            other => unreachable!("llama_rs_sampler_sample returned unrecognized status {other}"),
80        }
81    }
82
83    pub fn apply(&self, data_array: &mut LlamaTokenDataArray) {
84        data_array.apply_sampler(self);
85    }
86
87    /// # Errors
88    /// Returns [`SamplerAcceptError`] if the underlying sampler rejects the token.
89    pub fn accept(&mut self, token: LlamaToken) -> Result<(), SamplerAcceptError> {
90        self.try_accept(token)
91    }
92
93    /// # Errors
94    /// Returns [`SamplerAcceptError`] if the underlying sampler rejects any token.
95    pub fn accept_many(
96        &mut self,
97        tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
98    ) -> Result<(), SamplerAcceptError> {
99        for token in tokens {
100            self.try_accept(*token.borrow())?;
101        }
102
103        Ok(())
104    }
105
106    /// # Errors
107    /// Returns [`SamplerAcceptError`] if the underlying sampler rejects any token.
108    pub fn with_tokens(
109        mut self,
110        tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
111    ) -> Result<Self, SamplerAcceptError> {
112        self.accept_many(tokens)?;
113
114        Ok(self)
115    }
116
117    /// # Errors
118    /// Returns an error if the underlying sampler rejects the token.
119    pub fn try_accept(&mut self, token: LlamaToken) -> Result<(), SamplerAcceptError> {
120        let mut error_ptr: *mut c_char = std::ptr::null_mut();
121
122        let status = unsafe {
123            llama_cpp_bindings_sys::llama_rs_sampler_accept(
124                self.sampler,
125                token.0,
126                &raw mut error_ptr,
127            )
128        };
129
130        check_sampler_accept_status(status, error_ptr)
131    }
132
133    pub fn reset(&mut self) {
134        unsafe {
135            llama_cpp_bindings_sys::llama_sampler_reset(self.sampler);
136        }
137    }
138
139    #[must_use]
140    pub fn get_seed(&self) -> u32 {
141        unsafe { llama_cpp_bindings_sys::llama_sampler_get_seed(self.sampler) }
142    }
143
144    #[must_use]
145    pub fn chain(samplers: impl IntoIterator<Item = Self>, no_perf: bool) -> Self {
146        unsafe {
147            let chain = llama_cpp_bindings_sys::llama_sampler_chain_init(
148                llama_cpp_bindings_sys::llama_sampler_chain_params { no_perf },
149            );
150
151            for sampler in samplers {
152                llama_cpp_bindings_sys::llama_sampler_chain_add(chain, sampler.sampler);
153                std::mem::forget(sampler);
154            }
155
156            Self { sampler: chain }
157        }
158    }
159
160    #[must_use]
161    pub fn chain_simple(samplers: impl IntoIterator<Item = Self>) -> Self {
162        Self::chain(samplers, false)
163    }
164
165    #[must_use]
166    pub fn temp(t: f32) -> Self {
167        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_temp(t) };
168        Self { sampler }
169    }
170
171    #[must_use]
172    pub fn temp_ext(t: f32, delta: f32, exponent: f32) -> Self {
173        let sampler =
174            unsafe { llama_cpp_bindings_sys::llama_sampler_init_temp_ext(t, delta, exponent) };
175        Self { sampler }
176    }
177
178    #[must_use]
179    pub fn top_k(k: i32) -> Self {
180        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_k(k) };
181        Self { sampler }
182    }
183
184    #[must_use]
185    pub fn top_n_sigma(n: f32) -> Self {
186        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_n_sigma(n) };
187        Self { sampler }
188    }
189
190    #[must_use]
191    pub fn typical(p: f32, min_keep: usize) -> Self {
192        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_typical(p, min_keep) };
193        Self { sampler }
194    }
195
196    #[must_use]
197    pub fn top_p(p: f32, min_keep: usize) -> Self {
198        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_p(p, min_keep) };
199        Self { sampler }
200    }
201
202    #[must_use]
203    pub fn min_p(p: f32, min_keep: usize) -> Self {
204        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_min_p(p, min_keep) };
205        Self { sampler }
206    }
207
208    #[must_use]
209    pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Self {
210        let sampler =
211            unsafe { llama_cpp_bindings_sys::llama_sampler_init_xtc(p, t, min_keep, seed) };
212        Self { sampler }
213    }
214
215    /// # Errors
216    /// Returns an error if the grammar is invalid or the sampler cannot be initialized.
217    pub fn grammar(
218        model: &LlamaModel,
219        grammar_str: &str,
220        grammar_root: &str,
221    ) -> Result<Self, GrammarError> {
222        let (grammar_str, grammar_root) =
223            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
224        let mut sampler: *mut llama_cpp_bindings_sys::llama_sampler = std::ptr::null_mut();
225        let mut error_ptr: *mut c_char = std::ptr::null_mut();
226
227        let status = unsafe {
228            llama_cpp_bindings_sys::llama_rs_sampler_init_grammar(
229                model.vocab_ptr(),
230                grammar_str.as_ptr(),
231                grammar_root.as_ptr(),
232                &raw mut sampler,
233                &raw mut error_ptr,
234            )
235        };
236
237        match status {
238            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_OK => {
239                Ok(Self { sampler })
240            }
241            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_VENDORED_RETURNED_NULL => {
242                Err(GrammarError::GrammarMalformed)
243            }
244            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_ERROR_STRING_ALLOCATION_FAILED => {
245                Err(GrammarError::NotEnoughMemory)
246            }
247            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_VENDORED_THREW_CXX_EXCEPTION => {
248                let message = unsafe { read_and_free_cpp_error(error_ptr) };
249                Err(GrammarError::Reported { message })
250            }
251            other => unreachable!(
252                "llama_rs_sampler_init_grammar returned unrecognized status {other}"
253            ),
254        }
255    }
256
257    /// # Errors
258    /// Returns an error if the grammar or trigger words are invalid.
259    pub fn grammar_lazy(
260        model: &LlamaModel,
261        grammar_str: &str,
262        grammar_root: &str,
263        trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
264        trigger_tokens: &[LlamaToken],
265    ) -> Result<Self, GrammarError> {
266        let (grammar_str, grammar_root) =
267            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
268        let trigger_words = Self::sanitize_trigger_words(trigger_words)?;
269        let mut sampler: *mut llama_cpp_bindings_sys::llama_sampler = std::ptr::null_mut();
270        let mut error_ptr: *mut c_char = std::ptr::null_mut();
271
272        let mut trigger_word_ptrs: Vec<*const c_char> =
273            trigger_words.iter().map(|cs| cs.as_ptr()).collect();
274
275        let status = unsafe {
276            llama_cpp_bindings_sys::llama_rs_sampler_init_grammar_lazy(
277                model.vocab_ptr(),
278                grammar_str.as_ptr(),
279                grammar_root.as_ptr(),
280                trigger_word_ptrs.as_mut_ptr(),
281                trigger_word_ptrs.len(),
282                trigger_tokens.as_ptr().cast(),
283                trigger_tokens.len(),
284                &raw mut sampler,
285                &raw mut error_ptr,
286            )
287        };
288
289        match status {
290            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_OK => {
291                Ok(Self { sampler })
292            }
293            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_VENDORED_RETURNED_NULL => {
294                Err(GrammarError::LazyGrammarMalformed)
295            }
296            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_ERROR_STRING_ALLOCATION_FAILED => {
297                Err(GrammarError::NotEnoughMemory)
298            }
299            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_VENDORED_THREW_CXX_EXCEPTION => {
300                let message = unsafe { read_and_free_cpp_error(error_ptr) };
301                Err(GrammarError::Reported { message })
302            }
303            other => unreachable!(
304                "llama_rs_sampler_init_grammar_lazy returned unrecognized status {other}"
305            ),
306        }
307    }
308
309    /// # Errors
310    /// Returns an error if the grammar or trigger patterns are invalid.
311    pub fn grammar_lazy_patterns(
312        model: &LlamaModel,
313        grammar_str: &str,
314        grammar_root: &str,
315        trigger_patterns: &[String],
316        trigger_tokens: &[LlamaToken],
317    ) -> Result<Self, GrammarError> {
318        let (grammar_str, grammar_root) =
319            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
320        let trigger_patterns = Self::sanitize_trigger_patterns(trigger_patterns)?;
321        let mut sampler: *mut llama_cpp_bindings_sys::llama_sampler = std::ptr::null_mut();
322        let mut error_ptr: *mut c_char = std::ptr::null_mut();
323
324        let mut trigger_pattern_ptrs: Vec<*const c_char> =
325            trigger_patterns.iter().map(|cs| cs.as_ptr()).collect();
326
327        let status = unsafe {
328            llama_cpp_bindings_sys::llama_rs_sampler_init_grammar_lazy_patterns(
329                model.vocab_ptr(),
330                grammar_str.as_ptr(),
331                grammar_root.as_ptr(),
332                trigger_pattern_ptrs.as_mut_ptr(),
333                trigger_pattern_ptrs.len(),
334                trigger_tokens.as_ptr().cast(),
335                trigger_tokens.len(),
336                &raw mut sampler,
337                &raw mut error_ptr,
338            )
339        };
340
341        match status {
342            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_PATTERNS_OK => {
343                Ok(Self { sampler })
344            }
345            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_PATTERNS_VENDORED_RETURNED_NULL => {
346                Err(GrammarError::LazyPatternsGrammarMalformed)
347            }
348            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_PATTERNS_ERROR_STRING_ALLOCATION_FAILED => {
349                Err(GrammarError::NotEnoughMemory)
350            }
351            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_PATTERNS_INVALID_TRIGGER_PATTERN => {
352                let message = unsafe { read_and_free_cpp_error(error_ptr) };
353                Err(GrammarError::InvalidTriggerPattern { message })
354            }
355            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_PATTERNS_VENDORED_THREW_CXX_EXCEPTION => {
356                let message = unsafe { read_and_free_cpp_error(error_ptr) };
357                Err(GrammarError::Reported { message })
358            }
359            other => unreachable!(
360                "llama_rs_sampler_init_grammar_lazy_patterns returned unrecognized status {other}"
361            ),
362        }
363    }
364
365    /// # Errors
366    ///
367    /// Returns [`GrammarError`] if the grammar is invalid or the sampler cannot be initialized.
368    pub fn llguidance(
369        model: &LlamaModel,
370        grammar_kind: &str,
371        grammar_data: &str,
372    ) -> Result<Self, GrammarError> {
373        crate::llguidance_sampler::create_llg_sampler(model, grammar_kind, grammar_data)
374    }
375
376    fn sanitize_grammar_strings(
377        grammar_str: &str,
378        grammar_root: &str,
379    ) -> Result<(CString, CString), GrammarError> {
380        if !grammar_str.contains(grammar_root) {
381            return Err(GrammarError::RootNotFound);
382        }
383
384        let grammar = CString::new(grammar_str).map_err(GrammarError::GrammarNullBytes)?;
385        let root = CString::new(grammar_root).map_err(GrammarError::GrammarNullBytes)?;
386
387        Ok((grammar, root))
388    }
389
390    fn sanitize_trigger_words(
391        trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
392    ) -> Result<Vec<CString>, GrammarError> {
393        trigger_words
394            .into_iter()
395            .map(|word| CString::new(word.as_ref()).map_err(GrammarError::TriggerWordNullBytes))
396            .collect()
397    }
398
399    fn sanitize_trigger_patterns(
400        trigger_patterns: &[String],
401    ) -> Result<Vec<CString>, GrammarError> {
402        trigger_patterns
403            .iter()
404            .map(|pattern| CString::new(pattern.as_str()).map_err(GrammarError::GrammarNullBytes))
405            .collect()
406    }
407
408    /// # Errors
409    /// Returns an error if any string in `seq_breakers` contains null bytes.
410    pub fn dry(
411        model: &LlamaModel,
412        multiplier: f32,
413        base: f32,
414        allowed_length: i32,
415        penalty_last_n: i32,
416        seq_breakers: impl IntoIterator<Item = impl AsRef<[u8]>>,
417    ) -> Result<Self, GrammarError> {
418        let seq_breakers: Vec<CString> = seq_breakers
419            .into_iter()
420            .map(|seq_breaker| CString::new(seq_breaker.as_ref()))
421            .collect::<Result<Vec<_>, _>>()?;
422        let mut seq_breaker_pointers: Vec<*const c_char> = seq_breakers
423            .iter()
424            .map(|seq_breaker| seq_breaker.as_ptr())
425            .collect();
426
427        let n_ctx_train_value = model.n_ctx_train().map_err(|convert_error| {
428            GrammarError::IntegerOverflow(format!(
429                "n_ctx_train does not fit into u32: {convert_error}"
430            ))
431        })?;
432        let n_ctx_train = checked_u32_as_i32(n_ctx_train_value)?;
433        let sampler = unsafe {
434            llama_cpp_bindings_sys::llama_sampler_init_dry(
435                model.vocab_ptr(),
436                n_ctx_train,
437                multiplier,
438                base,
439                allowed_length,
440                penalty_last_n,
441                seq_breaker_pointers.as_mut_ptr(),
442                seq_breaker_pointers.len(),
443            )
444        };
445
446        Ok(Self { sampler })
447    }
448
449    #[must_use]
450    pub fn penalties(
451        penalty_last_n: i32,
452        penalty_repeat: f32,
453        penalty_freq: f32,
454        penalty_present: f32,
455    ) -> Self {
456        let sampler = unsafe {
457            llama_cpp_bindings_sys::llama_sampler_init_penalties(
458                penalty_last_n,
459                penalty_repeat,
460                penalty_freq,
461                penalty_present,
462            )
463        };
464        Self { sampler }
465    }
466
467    #[must_use]
468    pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self {
469        let sampler = unsafe {
470            llama_cpp_bindings_sys::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m)
471        };
472        Self { sampler }
473    }
474
475    #[must_use]
476    pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self {
477        let sampler =
478            unsafe { llama_cpp_bindings_sys::llama_sampler_init_mirostat_v2(seed, tau, eta) };
479        Self { sampler }
480    }
481
482    #[must_use]
483    pub fn dist(seed: u32) -> Self {
484        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_dist(seed) };
485        Self { sampler }
486    }
487
488    #[must_use]
489    pub fn greedy() -> Self {
490        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_greedy() };
491        Self { sampler }
492    }
493
494    /// # Errors
495    /// Returns [`SamplingError::IntegerOverflow`] if `biases.len()` exceeds `i32::MAX`.
496    ///
497    pub fn logit_bias(n_vocab: i32, biases: &[LlamaLogitBias]) -> Result<Self, SamplingError> {
498        let bias_count = checked_usize_as_i32_sampling(biases.len())?;
499        let data = biases
500            .as_ptr()
501            .cast::<llama_cpp_bindings_sys::llama_logit_bias>();
502
503        let sampler = unsafe {
504            llama_cpp_bindings_sys::llama_sampler_init_logit_bias(n_vocab, bias_count, data)
505        };
506
507        Ok(Self { sampler })
508    }
509}
510
511impl Drop for LlamaSampler {
512    fn drop(&mut self) {
513        unsafe {
514            llama_cpp_bindings_sys::llama_sampler_free(self.sampler);
515        }
516    }
517}
518
#[cfg(test)]
mod tests {
    use std::ffi::CString;
    use std::mem::Discriminant;

    use super::LlamaSampler;
    use crate::GrammarError;
    use crate::token::LlamaToken;

    /// Produces a real `NulError` so the null-byte error variants can be constructed.
    fn nul_error() -> std::ffi::NulError {
        CString::new(b"a\0b".to_vec()).unwrap_err()
    }

    fn root_not_found_disc() -> Discriminant<GrammarError> {
        std::mem::discriminant(&GrammarError::RootNotFound)
    }

    fn grammar_null_bytes_disc() -> Discriminant<GrammarError> {
        std::mem::discriminant(&GrammarError::GrammarNullBytes(nul_error()))
    }

    fn trigger_word_null_bytes_disc() -> Discriminant<GrammarError> {
        std::mem::discriminant(&GrammarError::TriggerWordNullBytes(nul_error()))
    }

    /// Chain shared by the accept tests: a penalties sampler followed by a greedy one.
    fn penalties_then_greedy() -> LlamaSampler {
        LlamaSampler::chain_simple([
            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
            LlamaSampler::greedy(),
        ])
    }

    #[test]
    fn sanitize_grammar_strings_valid() {
        let result = LlamaSampler::sanitize_grammar_strings("root ::= \"hello\"", "root");

        assert!(result.is_ok());
    }

    #[test]
    fn sanitize_grammar_strings_root_not_found() {
        let err = LlamaSampler::sanitize_grammar_strings("expr ::= \"hello\"", "root").unwrap_err();

        assert_eq!(std::mem::discriminant(&err), root_not_found_disc());
    }

    #[test]
    fn sanitize_grammar_strings_null_byte_in_grammar() {
        let err = LlamaSampler::sanitize_grammar_strings("root ::= \"\0\"", "root").unwrap_err();

        assert_eq!(std::mem::discriminant(&err), grammar_null_bytes_disc());
    }

    #[test]
    fn sanitize_grammar_strings_null_byte_in_root() {
        let err =
            LlamaSampler::sanitize_grammar_strings("ro\0ot ::= \"hello\"", "ro\0ot").unwrap_err();

        assert_eq!(std::mem::discriminant(&err), grammar_null_bytes_disc());
    }

    #[test]
    fn sanitize_trigger_words_valid() {
        let words: Vec<&[u8]> = vec![b"hello", b"world"];
        let sanitized =
            LlamaSampler::sanitize_trigger_words(words).expect("valid trigger words");

        assert_eq!(sanitized.len(), 2);
    }

    #[test]
    fn sanitize_trigger_words_empty_list() {
        let words: Vec<&[u8]> = vec![];
        let sanitized =
            LlamaSampler::sanitize_trigger_words(words).expect("valid trigger words");

        assert!(sanitized.is_empty());
    }

    #[test]
    fn sanitize_trigger_words_null_byte() {
        let words: Vec<&[u8]> = vec![b"hel\0lo"];
        let err = LlamaSampler::sanitize_trigger_words(words).unwrap_err();

        assert_eq!(std::mem::discriminant(&err), trigger_word_null_bytes_disc());
    }

    #[test]
    fn sanitize_trigger_patterns_valid() {
        let patterns = vec!["^hello$".to_string(), "world.*".to_string()];
        let sanitized =
            LlamaSampler::sanitize_trigger_patterns(&patterns).expect("valid trigger patterns");

        assert_eq!(sanitized.len(), 2);
    }

    #[test]
    fn sanitize_trigger_patterns_empty_list() {
        let patterns: Vec<String> = vec![];
        let sanitized =
            LlamaSampler::sanitize_trigger_patterns(&patterns).expect("valid trigger patterns");

        assert!(sanitized.is_empty());
    }

    #[test]
    fn sanitize_trigger_patterns_null_byte() {
        let patterns = vec!["hel\0lo".to_string()];
        let err = LlamaSampler::sanitize_trigger_patterns(&patterns).unwrap_err();

        assert_eq!(std::mem::discriminant(&err), grammar_null_bytes_disc());
    }

    #[test]
    fn apply_modifies_data_array() {
        use crate::token::data::LlamaTokenData;
        use crate::token::data_array::LlamaTokenDataArray;

        let sampler = LlamaSampler::greedy();
        let mut data_array = LlamaTokenDataArray::new(
            vec![
                LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
                LlamaTokenData::new(LlamaToken::new(1), 5.0, 0.0),
            ],
            false,
        );

        sampler.apply(&mut data_array);

        assert_eq!(data_array.selected_token(), Some(LlamaToken::new(1)));
    }

    #[test]
    fn accept_succeeds() {
        let mut sampler = penalties_then_greedy();

        sampler
            .accept(LlamaToken::new(1))
            .expect("test: accept should succeed");
    }

    #[test]
    fn try_accept_succeeds_on_penalties_sampler() {
        let mut sampler = penalties_then_greedy();

        let result = sampler.try_accept(LlamaToken::new(42));

        assert!(result.is_ok());
    }

    #[test]
    fn accept_many_multiple_tokens() {
        let mut sampler = penalties_then_greedy();

        sampler
            .accept_many([LlamaToken::new(1), LlamaToken::new(2), LlamaToken::new(3)])
            .expect("test: accept_many should succeed");
    }

    #[test]
    fn with_tokens_builder_pattern() {
        let _sampler = penalties_then_greedy()
            .with_tokens([LlamaToken::new(10), LlamaToken::new(20)])
            .expect("test: with_tokens should succeed");
    }

    #[test]
    fn all_sampler_constructors() {
        use crate::token::logit_bias::LlamaLogitBias;

        let _temp = LlamaSampler::temp(0.8);
        let _temp_ext = LlamaSampler::temp_ext(0.8, 0.1, 1.0);
        let _top_k = LlamaSampler::top_k(40);
        let _top_n_sigma = LlamaSampler::top_n_sigma(2.0);
        let _top_p = LlamaSampler::top_p(0.9, 1);
        let _min_p = LlamaSampler::min_p(0.05, 1);
        let _typical = LlamaSampler::typical(0.9, 1);
        let _xtc = LlamaSampler::xtc(0.1, 0.5, 1, 42);
        let _dist = LlamaSampler::dist(42);
        let _mirostat = LlamaSampler::mirostat(32000, 42, 5.0, 0.1, 100);
        let _mirostat_v2 = LlamaSampler::mirostat_v2(42, 5.0, 0.1);

        // `logit_bias` is fallible; check the result instead of discarding it so a failing
        // constructor cannot pass unnoticed.
        let biases = vec![LlamaLogitBias::new(LlamaToken::new(0), -100.0)];
        let logit_bias = LlamaSampler::logit_bias(32000, &biases);
        assert!(logit_bias.is_ok(), "test: logit_bias should succeed");

        let _chain = LlamaSampler::chain([LlamaSampler::greedy()], true);
    }

    #[test]
    fn reset_and_get_seed() {
        let mut sampler = LlamaSampler::dist(42);
        sampler.reset();
        let _seed = sampler.get_seed();
    }

    #[test]
    fn debug_formatting() {
        let sampler = LlamaSampler::greedy();
        let debug_output = format!("{sampler:?}");
        assert!(debug_output.contains("LlamaSampler"));
    }

    #[test]
    fn checked_u32_as_i32_overflow() {
        let result = super::checked_u32_as_i32(u32::MAX);
        assert!(result.is_err());
    }

    #[test]
    fn checked_usize_as_i32_sampling_overflow() {
        let result = super::checked_usize_as_i32_sampling(usize::MAX);
        assert!(result.is_err());
    }

    #[test]
    fn check_sampler_accept_status_ok() {
        let result = super::check_sampler_accept_status(
            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_ACCEPT_OK,
            std::ptr::null_mut(),
        );

        assert!(result.is_ok());
    }

    #[test]
    fn check_sampler_accept_status_exception_maps_to_typed_variant() {
        let err = super::check_sampler_accept_status(
            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_ACCEPT_VENDORED_THREW_CXX_EXCEPTION,
            std::ptr::null_mut(),
        )
        .unwrap_err();
        let grammar_state_corrupted_disc =
            std::mem::discriminant(&crate::SamplerAcceptError::GrammarStateCorrupted {
                message: String::new(),
            });

        assert_eq!(std::mem::discriminant(&err), grammar_state_corrupted_disc);
    }
}