Skip to main content

llama_cpp_bindings/
sampling.rs

1//! Safe wrapper around `llama_sampler`.
2
3use std::borrow::Borrow;
4use std::ffi::{CString, c_char};
5use std::fmt::{Debug, Formatter};
6
7use crate::context::LlamaContext;
8use crate::ffi_error_reader::read_and_free_cpp_error;
9use crate::model::LlamaModel;
10use crate::token::LlamaToken;
11use crate::token::data_array::LlamaTokenDataArray;
12use crate::token::logit_bias::LlamaLogitBias;
13use crate::{GrammarError, SampleError, SamplerAcceptError, SamplingError};
14
15fn check_sampler_accept_status(
16    status: llama_cpp_bindings_sys::llama_rs_status,
17    error_ptr: *mut c_char,
18) -> Result<(), SamplerAcceptError> {
19    match status {
20        llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK => Ok(()),
21        llama_cpp_bindings_sys::LLAMA_RS_STATUS_INVALID_ARGUMENT => {
22            Err(SamplerAcceptError::InvalidArgument)
23        }
24        _ => Err(SamplerAcceptError::CppException(unsafe {
25            read_and_free_cpp_error(error_ptr)
26        })),
27    }
28}
29
30fn check_sampler_not_null(
31    sampler: *mut llama_cpp_bindings_sys::llama_sampler,
32    error_ptr: *mut c_char,
33) -> Result<LlamaSampler, GrammarError> {
34    if sampler.is_null() {
35        Err(GrammarError::NullGrammar(unsafe {
36            read_and_free_cpp_error(error_ptr)
37        }))
38    } else {
39        Ok(LlamaSampler { sampler })
40    }
41}
42
43fn checked_u32_as_i32(value: u32) -> Result<i32, GrammarError> {
44    i32::try_from(value).map_err(|convert_error| {
45        GrammarError::IntegerOverflow(format!("value exceeds i32::MAX: {convert_error}"))
46    })
47}
48
49fn checked_usize_as_i32_sampling(value: usize) -> Result<i32, SamplingError> {
50    i32::try_from(value).map_err(|convert_error| {
51        SamplingError::IntegerOverflow(format!("value exceeds i32::MAX: {convert_error}"))
52    })
53}
54
55/// A safe wrapper around `llama_sampler`.
56pub struct LlamaSampler {
57    /// Raw pointer to the underlying `llama_sampler`.
58    pub sampler: *mut llama_cpp_bindings_sys::llama_sampler,
59}
60
61impl Debug for LlamaSampler {
62    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
63        f.debug_struct("LlamaSamplerChain").finish()
64    }
65}
66
67impl LlamaSampler {
68    /// Sample and accept a token from the idx-th output of the last evaluation.
69    ///
70    /// # Errors
71    ///
72    /// Returns [`SampleError`] if the C++ sampler throws an exception or if the index is invalid.
73    pub fn sample(&mut self, ctx: &LlamaContext, idx: i32) -> Result<LlamaToken, SampleError> {
74        let mut token: i32 = -1;
75        let mut error_ptr: *mut c_char = std::ptr::null_mut();
76
77        let status = unsafe {
78            llama_cpp_bindings_sys::llama_rs_sampler_sample(
79                self.sampler,
80                ctx.context.as_ptr(),
81                idx,
82                &raw mut token,
83                &raw mut error_ptr,
84            )
85        };
86
87        match status {
88            llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK => Ok(LlamaToken(token)),
89            llama_cpp_bindings_sys::LLAMA_RS_STATUS_INVALID_ARGUMENT => {
90                Err(SampleError::InvalidArgument)
91            }
92            _ => Err(SampleError::CppException(unsafe {
93                read_and_free_cpp_error(error_ptr)
94            })),
95        }
96    }
97
98    /// Applies this sampler to a [`LlamaTokenDataArray`].
99    pub fn apply(&self, data_array: &mut LlamaTokenDataArray) {
100        data_array.apply_sampler(self);
101    }
102
103    /// Accepts a token from the sampler, possibly updating the internal state of certain samplers
104    /// (e.g. grammar, repetition, etc.)
105    ///
106    /// # Errors
107    /// Returns [`SamplerAcceptError`] if the underlying sampler rejects the token.
108    pub fn accept(&mut self, token: LlamaToken) -> Result<(), SamplerAcceptError> {
109        self.try_accept(token)
110    }
111
112    /// Accepts several tokens from the sampler or context, possibly updating the internal state of
113    /// certain samplers (e.g. grammar, repetition, etc.)
114    ///
115    /// # Errors
116    /// Returns [`SamplerAcceptError`] if the underlying sampler rejects any token.
117    pub fn accept_many(
118        &mut self,
119        tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
120    ) -> Result<(), SamplerAcceptError> {
121        for token in tokens {
122            self.try_accept(*token.borrow())?;
123        }
124
125        Ok(())
126    }
127
128    /// Accepts several tokens from the sampler or context, possibly updating the internal state of
129    /// certain samplers (e.g. grammar, repetition, etc.)
130    ///
131    /// # Errors
132    /// Returns [`SamplerAcceptError`] if the underlying sampler rejects any token.
133    pub fn with_tokens(
134        mut self,
135        tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
136    ) -> Result<Self, SamplerAcceptError> {
137        self.accept_many(tokens)?;
138
139        Ok(self)
140    }
141
142    /// Try accepting a token from the sampler. Returns an error if the sampler throws.
143    ///
144    /// # Errors
145    /// Returns an error if the underlying sampler rejects the token.
146    pub fn try_accept(&mut self, token: LlamaToken) -> Result<(), SamplerAcceptError> {
147        let mut error_ptr: *mut c_char = std::ptr::null_mut();
148
149        let status = unsafe {
150            llama_cpp_bindings_sys::llama_rs_sampler_accept(
151                self.sampler,
152                token.0,
153                &raw mut error_ptr,
154            )
155        };
156
157        check_sampler_accept_status(status, error_ptr)
158    }
159
160    /// Resets the internal state of the sampler.
161    ///
162    /// This can be useful when you want to start fresh with a sampler without creating a new instance.
163    pub fn reset(&mut self) {
164        unsafe {
165            llama_cpp_bindings_sys::llama_sampler_reset(self.sampler);
166        }
167    }
168
169    /// Gets the random seed used by this sampler.
170    ///
171    /// Returns:
172    /// - For random samplers (dist, mirostat, `mirostat_v2)`: returns their current seed
173    /// - For sampler chains: returns the first non-default seed found in reverse order
174    /// - For all other samplers: returns 0xFFFFFFFF
175    #[must_use]
176    pub fn get_seed(&self) -> u32 {
177        unsafe { llama_cpp_bindings_sys::llama_sampler_get_seed(self.sampler) }
178    }
179
180    /// Combines a list of samplers into a single sampler that applies each component sampler one
181    /// after another.
182    ///
183    /// If you are using a chain to select a token, the chain should always end with one of
184    /// [`LlamaSampler::greedy`], [`LlamaSampler::dist`], [`LlamaSampler::mirostat`], and
185    /// [`LlamaSampler::mirostat_v2`].
186    #[must_use]
187    pub fn chain(samplers: impl IntoIterator<Item = Self>, no_perf: bool) -> Self {
188        unsafe {
189            let chain = llama_cpp_bindings_sys::llama_sampler_chain_init(
190                llama_cpp_bindings_sys::llama_sampler_chain_params { no_perf },
191            );
192
193            for sampler in samplers {
194                llama_cpp_bindings_sys::llama_sampler_chain_add(chain, sampler.sampler);
195                std::mem::forget(sampler);
196            }
197
198            Self { sampler: chain }
199        }
200    }
201
202    /// Same as [`Self::chain`] with `no_perf = false`.
203    ///
204    /// # Example
205    /// ```rust
206    /// use llama_cpp_bindings::token::{
207    ///    LlamaToken,
208    ///    data::LlamaTokenData,
209    ///    data_array::LlamaTokenDataArray
210    /// };
211    /// use llama_cpp_bindings::sampling::LlamaSampler;
212    /// use llama_cpp_bindings::llama_backend::LlamaBackend;
213    /// let backend = LlamaBackend::init().unwrap();
214    ///
215    /// let mut data_array = LlamaTokenDataArray::new(vec![
216    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
217    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
218    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
219    /// ], false);
220    ///
221    /// data_array.apply_sampler(&mut LlamaSampler::chain_simple([
222    ///     LlamaSampler::temp(0.5),
223    ///     LlamaSampler::greedy(),
224    /// ]));
225    ///
226    /// assert_eq!(data_array.data[0].logit(), 0.);
227    /// assert_eq!(data_array.data[1].logit(), 2.);
228    /// assert_eq!(data_array.data[2].logit(), 4.);
229    ///
230    /// assert_eq!(data_array.data.len(), 3);
231    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(2)));
232    /// ```
233    #[must_use]
234    pub fn chain_simple(samplers: impl IntoIterator<Item = Self>) -> Self {
235        Self::chain(samplers, false)
236    }
237
238    /// Updates the logits `l_i' = l_i/t`. When `t <= 0.0`, the maximum logit is kept at its original
239    /// value, the rest are set to -inf
240    ///
241    /// # Example:
242    /// ```rust
243    /// use llama_cpp_bindings::token::{
244    ///    LlamaToken,
245    ///    data::LlamaTokenData,
246    ///    data_array::LlamaTokenDataArray
247    /// };
248    /// use llama_cpp_bindings::sampling::LlamaSampler;
249    ///
250    /// let mut data_array = LlamaTokenDataArray::new(vec![
251    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
252    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
253    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
254    /// ], false);
255    ///
256    /// data_array.apply_sampler(&mut LlamaSampler::temp(0.5));
257    ///
258    /// assert_eq!(data_array.data[0].logit(), 0.);
259    /// assert_eq!(data_array.data[1].logit(), 2.);
260    /// assert_eq!(data_array.data[2].logit(), 4.);
261    /// ```
262    #[must_use]
263    pub fn temp(t: f32) -> Self {
264        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_temp(t) };
265        Self { sampler }
266    }
267
268    /// Dynamic temperature implementation (a.k.a. entropy) described in the paper
269    /// <https://arxiv.org/abs/2309.02772>.
270    #[must_use]
271    pub fn temp_ext(t: f32, delta: f32, exponent: f32) -> Self {
272        let sampler =
273            unsafe { llama_cpp_bindings_sys::llama_sampler_init_temp_ext(t, delta, exponent) };
274        Self { sampler }
275    }
276
277    /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration"
278    /// <https://arxiv.org/abs/1904.09751>
279    ///
280    /// # Example:
281    /// ```rust
282    /// use llama_cpp_bindings::token::{
283    ///    LlamaToken,
284    ///    data::LlamaTokenData,
285    ///    data_array::LlamaTokenDataArray
286    /// };
287    /// use llama_cpp_bindings::sampling::LlamaSampler;
288    ///
289    /// let mut data_array = LlamaTokenDataArray::new(vec![
290    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
291    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
292    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
293    ///     LlamaTokenData::new(LlamaToken(3), 3., 0.),
294    /// ], false);
295    ///
296    /// data_array.apply_sampler(&mut LlamaSampler::top_k(2));
297    ///
298    /// assert_eq!(data_array.data.len(), 2);
299    /// assert_eq!(data_array.data[0].id(), LlamaToken(3));
300    /// assert_eq!(data_array.data[1].id(), LlamaToken(2));
301    /// ```
302    #[must_use]
303    pub fn top_k(k: i32) -> Self {
304        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_k(k) };
305        Self { sampler }
306    }
307
308    /// Top-nσ sampling as described in academic paper "Top-nσ: Not All Logits Are You Need"
309    /// <https://arxiv.org/pdf/2411.07641>
310    ///
311    /// This method filters logits by selecting only those within *n* standard deviations of the mean.
312    ///
313    /// # Parameters
314    /// - `n`: Number of standard deviations from the mean to include in sampling
315    ///
316    /// # Example
317    /// ```rust
318    /// use llama_cpp_bindings::sampling::LlamaSampler;
319    /// use llama_cpp_bindings::token::{
320    ///     LlamaToken,
321    ///     data::LlamaTokenData,
322    ///     data_array::LlamaTokenDataArray
323    /// };
324    ///
325    /// let mut data_array = LlamaTokenDataArray::new(vec![
326    ///     LlamaTokenData::new(LlamaToken(0), 0.0, 0.0),
327    ///     LlamaTokenData::new(LlamaToken(1), 1.0, 0.0),
328    ///     LlamaTokenData::new(LlamaToken(2), 2.0, 0.0),
329    /// ], false);
330    ///
331    /// data_array.apply_sampler(&mut LlamaSampler::top_n_sigma(2.0));
332    /// ```
333    #[must_use]
334    pub fn top_n_sigma(n: f32) -> Self {
335        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_n_sigma(n) };
336        Self { sampler }
337    }
338
339    /// Locally Typical Sampling implementation described in the paper <https://arxiv.org/abs/2202.00666>.
340    #[must_use]
341    pub fn typical(p: f32, min_keep: usize) -> Self {
342        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_typical(p, min_keep) };
343        Self { sampler }
344    }
345
346    /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration"
347    /// <https://arxiv.org/abs/1904.09751>
348    #[must_use]
349    pub fn top_p(p: f32, min_keep: usize) -> Self {
350        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_p(p, min_keep) };
351        Self { sampler }
352    }
353
354    /// Minimum P sampling as described in <https://github.com/ggerganov/llama.cpp/pull/3841>
355    #[must_use]
356    pub fn min_p(p: f32, min_keep: usize) -> Self {
357        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_min_p(p, min_keep) };
358        Self { sampler }
359    }
360
361    /// XTC sampler as described in <https://github.com/oobabooga/text-generation-webui/pull/6335>
362    #[must_use]
363    pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Self {
364        let sampler =
365            unsafe { llama_cpp_bindings_sys::llama_sampler_init_xtc(p, t, min_keep, seed) };
366        Self { sampler }
367    }
368
369    /// Grammar sampler
370    ///
371    /// # Errors
372    /// Returns an error if the grammar is invalid or the sampler cannot be initialized.
373    pub fn grammar(
374        model: &LlamaModel,
375        grammar_str: &str,
376        grammar_root: &str,
377    ) -> Result<Self, GrammarError> {
378        let (grammar_str, grammar_root) =
379            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
380        let mut error_ptr: *mut c_char = std::ptr::null_mut();
381
382        let sampler = unsafe {
383            llama_cpp_bindings_sys::llama_rs_sampler_init_grammar(
384                model.vocab_ptr(),
385                grammar_str.as_ptr(),
386                grammar_root.as_ptr(),
387                &raw mut error_ptr,
388            )
389        };
390
391        check_sampler_not_null(sampler, error_ptr)
392    }
393
394    /// Lazy grammar sampler, introduced in <https://github.com/ggerganov/llama.cpp/pull/9639>
395    ///
396    /// This sampler enforces grammar rules only when specific trigger words or tokens are encountered.
397    ///
398    /// # Errors
399    /// Returns an error if the grammar or trigger words are invalid.
400    pub fn grammar_lazy(
401        model: &LlamaModel,
402        grammar_str: &str,
403        grammar_root: &str,
404        trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
405        trigger_tokens: &[LlamaToken],
406    ) -> Result<Self, GrammarError> {
407        let (grammar_str, grammar_root) =
408            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
409        let trigger_words = Self::sanitize_trigger_words(trigger_words)?;
410        let mut error_ptr: *mut c_char = std::ptr::null_mut();
411
412        let mut trigger_word_ptrs: Vec<*const c_char> =
413            trigger_words.iter().map(|cs| cs.as_ptr()).collect();
414
415        let sampler = unsafe {
416            llama_cpp_bindings_sys::llama_rs_sampler_init_grammar_lazy(
417                model.vocab_ptr(),
418                grammar_str.as_ptr(),
419                grammar_root.as_ptr(),
420                trigger_word_ptrs.as_mut_ptr(),
421                trigger_word_ptrs.len(),
422                trigger_tokens.as_ptr().cast(),
423                trigger_tokens.len(),
424                &raw mut error_ptr,
425            )
426        };
427
428        check_sampler_not_null(sampler, error_ptr)
429    }
430
431    /// Lazy grammar sampler using regex trigger patterns.
432    ///
433    /// Trigger patterns are regular expressions matched from the start of the
434    /// generation output. The grammar sampler will be fed content starting from
435    /// the first match group.
436    ///
437    /// # Errors
438    /// Returns an error if the grammar or trigger patterns are invalid.
439    pub fn grammar_lazy_patterns(
440        model: &LlamaModel,
441        grammar_str: &str,
442        grammar_root: &str,
443        trigger_patterns: &[String],
444        trigger_tokens: &[LlamaToken],
445    ) -> Result<Self, GrammarError> {
446        let (grammar_str, grammar_root) =
447            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
448        let trigger_patterns = Self::sanitize_trigger_patterns(trigger_patterns)?;
449        let mut error_ptr: *mut c_char = std::ptr::null_mut();
450
451        let mut trigger_pattern_ptrs: Vec<*const c_char> =
452            trigger_patterns.iter().map(|cs| cs.as_ptr()).collect();
453
454        let sampler = unsafe {
455            llama_cpp_bindings_sys::llama_rs_sampler_init_grammar_lazy_patterns(
456                model.vocab_ptr(),
457                grammar_str.as_ptr(),
458                grammar_root.as_ptr(),
459                trigger_pattern_ptrs.as_mut_ptr(),
460                trigger_pattern_ptrs.len(),
461                trigger_tokens.as_ptr().cast(),
462                trigger_tokens.len(),
463                &raw mut error_ptr,
464            )
465        };
466
467        check_sampler_not_null(sampler, error_ptr)
468    }
469
470    /// `LLGuidance` sampler for constrained decoding.
471    ///
472    /// Uses the `llguidance` and `toktrie` Rust crates to enforce grammar constraints
473    /// during token sampling. Supports JSON schema, regex, Lark, and other grammar types.
474    ///
475    /// # Errors
476    ///
477    /// Returns [`GrammarError`] if the grammar is invalid or the sampler cannot be initialized.
478    pub fn llguidance(
479        model: &LlamaModel,
480        grammar_kind: &str,
481        grammar_data: &str,
482    ) -> Result<Self, GrammarError> {
483        crate::llguidance_sampler::create_llg_sampler(model, grammar_kind, grammar_data)
484    }
485
486    fn sanitize_grammar_strings(
487        grammar_str: &str,
488        grammar_root: &str,
489    ) -> Result<(CString, CString), GrammarError> {
490        if !grammar_str.contains(grammar_root) {
491            return Err(GrammarError::RootNotFound);
492        }
493
494        let grammar = CString::new(grammar_str).map_err(GrammarError::GrammarNullBytes)?;
495        let root = CString::new(grammar_root).map_err(GrammarError::GrammarNullBytes)?;
496
497        Ok((grammar, root))
498    }
499
500    fn sanitize_trigger_words(
501        trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
502    ) -> Result<Vec<CString>, GrammarError> {
503        trigger_words
504            .into_iter()
505            .map(|word| CString::new(word.as_ref()).map_err(GrammarError::TriggerWordNullBytes))
506            .collect()
507    }
508
509    fn sanitize_trigger_patterns(
510        trigger_patterns: &[String],
511    ) -> Result<Vec<CString>, GrammarError> {
512        trigger_patterns
513            .iter()
514            .map(|pattern| CString::new(pattern.as_str()).map_err(GrammarError::GrammarNullBytes))
515            .collect()
516    }
517
518    /// DRY sampler, designed by p-e-w, as described in:
519    /// <https://github.com/oobabooga/text-generation-webui/pull/5677>, porting Koboldcpp
520    /// implementation authored by pi6am: <https://github.com/LostRuins/koboldcpp/pull/982>
521    ///
522    /// # Errors
523    /// Returns an error if any string in `seq_breakers` contains null bytes.
524    pub fn dry(
525        model: &LlamaModel,
526        multiplier: f32,
527        base: f32,
528        allowed_length: i32,
529        penalty_last_n: i32,
530        seq_breakers: impl IntoIterator<Item = impl AsRef<[u8]>>,
531    ) -> Result<Self, GrammarError> {
532        let seq_breakers: Vec<CString> = seq_breakers
533            .into_iter()
534            .map(|seq_breaker| CString::new(seq_breaker.as_ref()))
535            .collect::<Result<Vec<_>, _>>()?;
536        let mut seq_breaker_pointers: Vec<*const c_char> = seq_breakers
537            .iter()
538            .map(|seq_breaker| seq_breaker.as_ptr())
539            .collect();
540
541        let n_ctx_train_value = model.n_ctx_train().map_err(|convert_error| {
542            GrammarError::IntegerOverflow(format!(
543                "n_ctx_train does not fit into u32: {convert_error}"
544            ))
545        })?;
546        let n_ctx_train = checked_u32_as_i32(n_ctx_train_value)?;
547        let sampler = unsafe {
548            llama_cpp_bindings_sys::llama_sampler_init_dry(
549                model.vocab_ptr(),
550                n_ctx_train,
551                multiplier,
552                base,
553                allowed_length,
554                penalty_last_n,
555                seq_breaker_pointers.as_mut_ptr(),
556                seq_breaker_pointers.len(),
557            )
558        };
559
560        Ok(Self { sampler })
561    }
562
563    /// Penalizes tokens for being present in the context.
564    ///
565    /// Parameters:
566    /// - ``penalty_last_n``: last n tokens to penalize (0 = disable penalty, -1 = context size)
567    /// - ``penalty_repeat``: 1.0 = disabled
568    /// - ``penalty_freq``: 0.0 = disabled
569    /// - ``penalty_present``: 0.0 = disabled
570    #[must_use]
571    pub fn penalties(
572        penalty_last_n: i32,
573        penalty_repeat: f32,
574        penalty_freq: f32,
575        penalty_present: f32,
576    ) -> Self {
577        let sampler = unsafe {
578            llama_cpp_bindings_sys::llama_sampler_init_penalties(
579                penalty_last_n,
580                penalty_repeat,
581                penalty_freq,
582                penalty_present,
583            )
584        };
585        Self { sampler }
586    }
587
588    /// Mirostat 1.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
589    ///
590    /// # Parameters:
591    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
592    /// - ``seed``: Seed to initialize random generation with.
593    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
594    ///   generated text. A higher value corresponds to more surprising or less predictable text,
595    ///   while a lower value corresponds to less surprising or more predictable text.
596    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
597    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
598    ///   updated more quickly, while a smaller learning rate will result in slower updates.
599    /// - ``m``: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary
600    ///   value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`.
601    ///   In the paper, they use `m = 100`, but you can experiment with different values to see how
602    ///   it affects the performance of the algorithm.
603    #[must_use]
604    pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self {
605        let sampler = unsafe {
606            llama_cpp_bindings_sys::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m)
607        };
608        Self { sampler }
609    }
610
611    /// Mirostat 2.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
612    ///
613    /// # Parameters:
614    /// - ``seed``: Seed to initialize random generation with.
615    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
616    ///   generated text. A higher value corresponds to more surprising or less predictable text,
617    ///   while a lower value corresponds to less surprising or more predictable text.
618    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
619    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
620    ///   updated more quickly, while a smaller learning rate will result in slower updates.
621    #[must_use]
622    pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self {
623        let sampler =
624            unsafe { llama_cpp_bindings_sys::llama_sampler_init_mirostat_v2(seed, tau, eta) };
625        Self { sampler }
626    }
627
628    /// Selects a token at random based on each token's probabilities
629    #[must_use]
630    pub fn dist(seed: u32) -> Self {
631        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_dist(seed) };
632        Self { sampler }
633    }
634
635    /// Selects the most likely token
636    ///
637    /// # Example:
638    /// ```rust
639    /// use llama_cpp_bindings::token::{
640    ///    LlamaToken,
641    ///    data::LlamaTokenData,
642    ///    data_array::LlamaTokenDataArray
643    /// };
644    /// use llama_cpp_bindings::sampling::LlamaSampler;
645    ///
646    /// let mut data_array = LlamaTokenDataArray::new(vec![
647    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
648    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
649    /// ], false);
650    ///
651    /// data_array.apply_sampler(&mut LlamaSampler::greedy());
652    ///
653    /// assert_eq!(data_array.data.len(), 2);
654    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(1)));
655    /// ```
656    #[must_use]
657    pub fn greedy() -> Self {
658        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_greedy() };
659        Self { sampler }
660    }
661
662    /// Creates a sampler that applies bias values to specific tokens during sampling.
663    ///
664    /// # Parameters
665    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
666    /// - ``biases``: Slice of [`LlamaLogitBias`] values specifying token-bias pairs
667    ///
668    /// # Errors
669    /// Returns [`SamplingError::IntegerOverflow`] if `biases.len()` exceeds `i32::MAX`.
670    ///
671    /// # Example
672    /// ```rust
673    /// use llama_cpp_bindings::token::{LlamaToken, logit_bias::LlamaLogitBias};
674    /// use llama_cpp_bindings::sampling::LlamaSampler;
675    ///
676    /// let biases = vec![
677    ///     LlamaLogitBias::new(LlamaToken(1), 1.5),  // Increase probability of token 1
678    ///     LlamaLogitBias::new(LlamaToken(2), -1.0), // Decrease probability of token 2
679    /// ];
680    ///
681    /// // Assuming vocab_size of 32000
682    /// let sampler = LlamaSampler::logit_bias(32000, &biases).unwrap();
683    /// ```
684    pub fn logit_bias(n_vocab: i32, biases: &[LlamaLogitBias]) -> Result<Self, SamplingError> {
685        let bias_count = checked_usize_as_i32_sampling(biases.len())?;
686        let data = biases
687            .as_ptr()
688            .cast::<llama_cpp_bindings_sys::llama_logit_bias>();
689
690        let sampler = unsafe {
691            llama_cpp_bindings_sys::llama_sampler_init_logit_bias(n_vocab, bias_count, data)
692        };
693
694        Ok(Self { sampler })
695    }
696}
697
698impl Drop for LlamaSampler {
699    fn drop(&mut self) {
700        unsafe {
701            llama_cpp_bindings_sys::llama_sampler_free(self.sampler);
702        }
703    }
704}
705
706#[cfg(test)]
707mod tests {
708    use super::LlamaSampler;
709    use crate::GrammarError;
710
711    #[test]
712    fn sanitize_grammar_strings_valid() {
713        let result = LlamaSampler::sanitize_grammar_strings("root ::= \"hello\"", "root");
714
715        assert!(result.is_ok());
716    }
717
718    #[test]
719    fn sanitize_grammar_strings_root_not_found() {
720        let result = LlamaSampler::sanitize_grammar_strings("expr ::= \"hello\"", "root");
721
722        assert_eq!(result.err(), Some(GrammarError::RootNotFound));
723    }
724
725    #[test]
726    fn sanitize_grammar_strings_null_byte_in_grammar() {
727        let result = LlamaSampler::sanitize_grammar_strings("root ::= \"\0\"", "root");
728
729        assert!(matches!(
730            result.err(),
731            Some(GrammarError::GrammarNullBytes(_))
732        ));
733    }
734
735    #[test]
736    fn sanitize_grammar_strings_null_byte_in_root() {
737        let result = LlamaSampler::sanitize_grammar_strings("ro\0ot ::= \"hello\"", "ro\0ot");
738
739        assert!(matches!(
740            result.err(),
741            Some(GrammarError::GrammarNullBytes(_))
742        ));
743    }
744
745    #[test]
746    fn sanitize_trigger_words_valid() {
747        let words: Vec<&[u8]> = vec![b"hello", b"world"];
748        let result = LlamaSampler::sanitize_trigger_words(words);
749
750        assert!(result.is_ok());
751        assert_eq!(result.expect("valid trigger words").len(), 2);
752    }
753
754    #[test]
755    fn sanitize_trigger_words_empty_list() {
756        let words: Vec<&[u8]> = vec![];
757        let result = LlamaSampler::sanitize_trigger_words(words);
758
759        assert!(result.is_ok());
760        assert!(result.expect("valid trigger words").is_empty());
761    }
762
763    #[test]
764    fn sanitize_trigger_words_null_byte() {
765        let words: Vec<&[u8]> = vec![b"hel\0lo"];
766        let result = LlamaSampler::sanitize_trigger_words(words);
767
768        assert!(matches!(
769            result.err(),
770            Some(GrammarError::TriggerWordNullBytes(_))
771        ));
772    }
773
774    #[test]
775    fn sanitize_trigger_patterns_valid() {
776        let patterns = vec!["^hello$".to_string(), "world.*".to_string()];
777        let result = LlamaSampler::sanitize_trigger_patterns(&patterns);
778
779        assert!(result.is_ok());
780        assert_eq!(result.expect("valid trigger patterns").len(), 2);
781    }
782
783    #[test]
784    fn sanitize_trigger_patterns_empty_list() {
785        let patterns: Vec<String> = vec![];
786        let result = LlamaSampler::sanitize_trigger_patterns(&patterns);
787
788        assert!(result.is_ok());
789        assert!(result.expect("valid trigger patterns").is_empty());
790    }
791
792    #[test]
793    fn sanitize_trigger_patterns_null_byte() {
794        let patterns = vec!["hel\0lo".to_string()];
795        let result = LlamaSampler::sanitize_trigger_patterns(&patterns);
796
797        assert!(matches!(
798            result.err(),
799            Some(GrammarError::GrammarNullBytes(_))
800        ));
801    }
802
803    #[test]
804    fn apply_modifies_data_array() {
805        use crate::token::LlamaToken;
806        use crate::token::data::LlamaTokenData;
807        use crate::token::data_array::LlamaTokenDataArray;
808
809        let sampler = LlamaSampler::greedy();
810        let mut data_array = LlamaTokenDataArray::new(
811            vec![
812                LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
813                LlamaTokenData::new(LlamaToken::new(1), 5.0, 0.0),
814            ],
815            false,
816        );
817
818        sampler.apply(&mut data_array);
819
820        assert_eq!(data_array.selected_token(), Some(LlamaToken::new(1)));
821    }
822
823    #[test]
824    fn accept_succeeds() {
825        let mut sampler = LlamaSampler::chain_simple([
826            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
827            LlamaSampler::greedy(),
828        ]);
829
830        sampler
831            .accept(crate::token::LlamaToken::new(1))
832            .expect("test: accept should succeed");
833    }
834
835    #[test]
836    fn try_accept_succeeds_on_penalties_sampler() {
837        let mut sampler = LlamaSampler::chain_simple([
838            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
839            LlamaSampler::greedy(),
840        ]);
841
842        let result = sampler.try_accept(crate::token::LlamaToken::new(42));
843
844        assert!(result.is_ok());
845    }
846
847    #[test]
848    fn accept_many_multiple_tokens() {
849        use crate::token::LlamaToken;
850
851        let mut sampler = LlamaSampler::chain_simple([
852            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
853            LlamaSampler::greedy(),
854        ]);
855
856        sampler
857            .accept_many([LlamaToken::new(1), LlamaToken::new(2), LlamaToken::new(3)])
858            .expect("test: accept_many should succeed");
859    }
860
861    #[test]
862    fn with_tokens_builder_pattern() {
863        use crate::token::LlamaToken;
864
865        let _sampler = LlamaSampler::chain_simple([
866            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
867            LlamaSampler::greedy(),
868        ])
869        .with_tokens([LlamaToken::new(10), LlamaToken::new(20)])
870        .expect("test: with_tokens should succeed");
871    }
872
873    #[test]
874    fn all_sampler_constructors() {
875        use crate::token::LlamaToken;
876        use crate::token::logit_bias::LlamaLogitBias;
877
878        let _temp = LlamaSampler::temp(0.8);
879        let _temp_ext = LlamaSampler::temp_ext(0.8, 0.1, 1.0);
880        let _top_k = LlamaSampler::top_k(40);
881        let _top_n_sigma = LlamaSampler::top_n_sigma(2.0);
882        let _top_p = LlamaSampler::top_p(0.9, 1);
883        let _min_p = LlamaSampler::min_p(0.05, 1);
884        let _typical = LlamaSampler::typical(0.9, 1);
885        let _xtc = LlamaSampler::xtc(0.1, 0.5, 1, 42);
886        let _dist = LlamaSampler::dist(42);
887        let _mirostat = LlamaSampler::mirostat(32000, 42, 5.0, 0.1, 100);
888        let _mirostat_v2 = LlamaSampler::mirostat_v2(42, 5.0, 0.1);
889        let biases = vec![LlamaLogitBias::new(LlamaToken::new(0), -100.0)];
890        let _logit_bias = LlamaSampler::logit_bias(32000, &biases);
891        let _chain = LlamaSampler::chain([LlamaSampler::greedy()], true);
892    }
893
894    #[test]
895    fn reset_and_get_seed() {
896        let mut sampler = LlamaSampler::dist(42);
897        sampler.reset();
898        let _seed = sampler.get_seed();
899    }
900
901    #[test]
902    fn debug_formatting() {
903        let sampler = LlamaSampler::greedy();
904        let debug_output = format!("{sampler:?}");
905        assert!(debug_output.contains("LlamaSampler"));
906    }
907
908    #[test]
909    fn checked_u32_as_i32_overflow() {
910        let result = super::checked_u32_as_i32(u32::MAX);
911        assert!(result.is_err());
912    }
913
914    #[test]
915    fn checked_usize_as_i32_sampling_overflow() {
916        let result = super::checked_usize_as_i32_sampling(usize::MAX);
917        assert!(result.is_err());
918    }
919
920    #[test]
921    fn check_sampler_accept_status_ok() {
922        let result = super::check_sampler_accept_status(
923            llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK,
924            std::ptr::null_mut(),
925        );
926
927        assert!(result.is_ok());
928    }
929
930    #[test]
931    fn check_sampler_accept_status_invalid_argument() {
932        let result = super::check_sampler_accept_status(
933            llama_cpp_bindings_sys::LLAMA_RS_STATUS_INVALID_ARGUMENT,
934            std::ptr::null_mut(),
935        );
936
937        assert!(matches!(
938            result,
939            Err(crate::SamplerAcceptError::InvalidArgument)
940        ));
941    }
942
943    #[test]
944    fn check_sampler_accept_status_exception() {
945        let result = super::check_sampler_accept_status(
946            llama_cpp_bindings_sys::LLAMA_RS_STATUS_EXCEPTION,
947            std::ptr::null_mut(),
948        );
949
950        assert!(matches!(
951            result,
952            Err(crate::SamplerAcceptError::CppException(_))
953        ));
954    }
955
956    #[test]
957    fn check_sampler_not_null_returns_error() {
958        let result = super::check_sampler_not_null(std::ptr::null_mut(), std::ptr::null_mut());
959
960        assert!(result.is_err());
961    }
962}