Skip to main content

llama_cpp_bindings/
sampling.rs

1//! Safe wrapper around `llama_sampler`.
2
3use std::borrow::Borrow;
4use std::ffi::{CString, c_char};
5use std::fmt::{Debug, Formatter};
6
7use crate::context::LlamaContext;
8use crate::ffi_error_reader::read_and_free_cpp_error;
9use crate::model::LlamaModel;
10use crate::token::LlamaToken;
11use crate::token::data_array::LlamaTokenDataArray;
12use crate::token::logit_bias::LlamaLogitBias;
13use crate::{GrammarError, SampleError, SamplerAcceptError, SamplingError};
14
15fn check_sampler_accept_status(
16    status: llama_cpp_bindings_sys::llama_rs_status,
17    error_ptr: *mut c_char,
18) -> Result<(), SamplerAcceptError> {
19    match status {
20        llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK => Ok(()),
21        llama_cpp_bindings_sys::LLAMA_RS_STATUS_INVALID_ARGUMENT => {
22            Err(SamplerAcceptError::InvalidArgument)
23        }
24        _ => Err(SamplerAcceptError::CppException(unsafe {
25            read_and_free_cpp_error(error_ptr)
26        })),
27    }
28}
29
30fn check_sampler_not_null(
31    sampler: *mut llama_cpp_bindings_sys::llama_sampler,
32    error_ptr: *mut c_char,
33) -> Result<LlamaSampler, GrammarError> {
34    if sampler.is_null() {
35        Err(GrammarError::NullGrammar(unsafe {
36            read_and_free_cpp_error(error_ptr)
37        }))
38    } else {
39        Ok(LlamaSampler { sampler })
40    }
41}
42
43fn checked_u32_as_i32(value: u32) -> Result<i32, GrammarError> {
44    i32::try_from(value).map_err(|convert_error| {
45        GrammarError::IntegerOverflow(format!("value exceeds i32::MAX: {convert_error}"))
46    })
47}
48
49fn checked_usize_as_i32_sampling(value: usize) -> Result<i32, SamplingError> {
50    i32::try_from(value).map_err(|convert_error| {
51        SamplingError::IntegerOverflow(format!("value exceeds i32::MAX: {convert_error}"))
52    })
53}
54
55/// A safe wrapper around `llama_sampler`.
56pub struct LlamaSampler {
57    /// Raw pointer to the underlying `llama_sampler`.
58    pub sampler: *mut llama_cpp_bindings_sys::llama_sampler,
59}
60
61impl Debug for LlamaSampler {
62    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
63        f.debug_struct("LlamaSamplerChain").finish()
64    }
65}
66
67impl LlamaSampler {
68    /// Sample and accept a token from the idx-th output of the last evaluation.
69    ///
70    /// # Errors
71    ///
72    /// Returns [`SampleError`] if the C++ sampler throws an exception or if the index is invalid.
73    pub fn sample(&mut self, ctx: &LlamaContext, idx: i32) -> Result<LlamaToken, SampleError> {
74        let mut token: i32 = -1;
75        let mut error_ptr: *mut c_char = std::ptr::null_mut();
76
77        let status = unsafe {
78            llama_cpp_bindings_sys::llama_rs_sampler_sample(
79                self.sampler,
80                ctx.context.as_ptr(),
81                idx,
82                &raw mut token,
83                &raw mut error_ptr,
84            )
85        };
86
87        match status {
88            llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK => Ok(LlamaToken(token)),
89            llama_cpp_bindings_sys::LLAMA_RS_STATUS_INVALID_ARGUMENT => {
90                Err(SampleError::InvalidArgument)
91            }
92            _ => Err(SampleError::CppException(unsafe {
93                read_and_free_cpp_error(error_ptr)
94            })),
95        }
96    }
97
98    /// Applies this sampler to a [`LlamaTokenDataArray`].
99    pub fn apply(&self, data_array: &mut LlamaTokenDataArray) {
100        data_array.apply_sampler(self);
101    }
102
103    /// Accepts a token from the sampler, possibly updating the internal state of certain samplers
104    /// (e.g. grammar, repetition, etc.)
105    ///
106    /// # Errors
107    /// Returns [`SamplerAcceptError`] if the underlying sampler rejects the token.
108    pub fn accept(&mut self, token: LlamaToken) -> Result<(), SamplerAcceptError> {
109        self.try_accept(token)
110    }
111
112    /// Accepts several tokens from the sampler or context, possibly updating the internal state of
113    /// certain samplers (e.g. grammar, repetition, etc.)
114    ///
115    /// # Errors
116    /// Returns [`SamplerAcceptError`] if the underlying sampler rejects any token.
117    pub fn accept_many(
118        &mut self,
119        tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
120    ) -> Result<(), SamplerAcceptError> {
121        for token in tokens {
122            self.try_accept(*token.borrow())?;
123        }
124
125        Ok(())
126    }
127
128    /// Accepts several tokens from the sampler or context, possibly updating the internal state of
129    /// certain samplers (e.g. grammar, repetition, etc.)
130    ///
131    /// # Errors
132    /// Returns [`SamplerAcceptError`] if the underlying sampler rejects any token.
133    pub fn with_tokens(
134        mut self,
135        tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
136    ) -> Result<Self, SamplerAcceptError> {
137        self.accept_many(tokens)?;
138
139        Ok(self)
140    }
141
142    /// Try accepting a token from the sampler. Returns an error if the sampler throws.
143    ///
144    /// # Errors
145    /// Returns an error if the underlying sampler rejects the token.
146    pub fn try_accept(&mut self, token: LlamaToken) -> Result<(), SamplerAcceptError> {
147        let mut error_ptr: *mut c_char = std::ptr::null_mut();
148
149        let status = unsafe {
150            llama_cpp_bindings_sys::llama_rs_sampler_accept(
151                self.sampler,
152                token.0,
153                &raw mut error_ptr,
154            )
155        };
156
157        check_sampler_accept_status(status, error_ptr)
158    }
159
160    /// Resets the internal state of the sampler.
161    ///
162    /// This can be useful when you want to start fresh with a sampler without creating a new instance.
163    pub fn reset(&mut self) {
164        unsafe {
165            llama_cpp_bindings_sys::llama_sampler_reset(self.sampler);
166        }
167    }
168
169    /// Gets the random seed used by this sampler.
170    ///
171    /// Returns:
172    /// - For random samplers (dist, mirostat, `mirostat_v2)`: returns their current seed
173    /// - For sampler chains: returns the first non-default seed found in reverse order
174    /// - For all other samplers: returns 0xFFFFFFFF
175    #[must_use]
176    pub fn get_seed(&self) -> u32 {
177        unsafe { llama_cpp_bindings_sys::llama_sampler_get_seed(self.sampler) }
178    }
179
180    /// Combines a list of samplers into a single sampler that applies each component sampler one
181    /// after another.
182    ///
183    /// If you are using a chain to select a token, the chain should always end with one of
184    /// [`LlamaSampler::greedy`], [`LlamaSampler::dist`], [`LlamaSampler::mirostat`], and
185    /// [`LlamaSampler::mirostat_v2`].
186    #[must_use]
187    pub fn chain(samplers: impl IntoIterator<Item = Self>, no_perf: bool) -> Self {
188        unsafe {
189            let chain = llama_cpp_bindings_sys::llama_sampler_chain_init(
190                llama_cpp_bindings_sys::llama_sampler_chain_params { no_perf },
191            );
192
193            for sampler in samplers {
194                llama_cpp_bindings_sys::llama_sampler_chain_add(chain, sampler.sampler);
195
196                // Do not call `llama_sampler_free` on the sampler, as the internal sampler is now
197                // owned by the chain
198                std::mem::forget(sampler);
199            }
200
201            Self { sampler: chain }
202        }
203    }
204
205    /// Same as [`Self::chain`] with `no_perf = false`.
206    ///
207    /// # Example
208    /// ```rust
209    /// use llama_cpp_bindings::token::{
210    ///    LlamaToken,
211    ///    data::LlamaTokenData,
212    ///    data_array::LlamaTokenDataArray
213    /// };
214    /// use llama_cpp_bindings::sampling::LlamaSampler;
215    /// use llama_cpp_bindings::llama_backend::LlamaBackend;
216    /// let backend = LlamaBackend::init().unwrap();
217    ///
218    /// let mut data_array = LlamaTokenDataArray::new(vec![
219    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
220    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
221    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
222    /// ], false);
223    ///
224    /// data_array.apply_sampler(&mut LlamaSampler::chain_simple([
225    ///     LlamaSampler::temp(0.5),
226    ///     LlamaSampler::greedy(),
227    /// ]));
228    ///
229    /// assert_eq!(data_array.data[0].logit(), 0.);
230    /// assert_eq!(data_array.data[1].logit(), 2.);
231    /// assert_eq!(data_array.data[2].logit(), 4.);
232    ///
233    /// assert_eq!(data_array.data.len(), 3);
234    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(2)));
235    /// ```
236    #[must_use]
237    pub fn chain_simple(samplers: impl IntoIterator<Item = Self>) -> Self {
238        Self::chain(samplers, false)
239    }
240
241    /// Updates the logits `l_i' = l_i/t`. When `t <= 0.0`, the maximum logit is kept at its original
242    /// value, the rest are set to -inf
243    ///
244    /// # Example:
245    /// ```rust
246    /// use llama_cpp_bindings::token::{
247    ///    LlamaToken,
248    ///    data::LlamaTokenData,
249    ///    data_array::LlamaTokenDataArray
250    /// };
251    /// use llama_cpp_bindings::sampling::LlamaSampler;
252    ///
253    /// let mut data_array = LlamaTokenDataArray::new(vec![
254    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
255    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
256    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
257    /// ], false);
258    ///
259    /// data_array.apply_sampler(&mut LlamaSampler::temp(0.5));
260    ///
261    /// assert_eq!(data_array.data[0].logit(), 0.);
262    /// assert_eq!(data_array.data[1].logit(), 2.);
263    /// assert_eq!(data_array.data[2].logit(), 4.);
264    /// ```
265    #[must_use]
266    pub fn temp(t: f32) -> Self {
267        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_temp(t) };
268        Self { sampler }
269    }
270
271    /// Dynamic temperature implementation (a.k.a. entropy) described in the paper
272    /// <https://arxiv.org/abs/2309.02772>.
273    #[must_use]
274    pub fn temp_ext(t: f32, delta: f32, exponent: f32) -> Self {
275        let sampler =
276            unsafe { llama_cpp_bindings_sys::llama_sampler_init_temp_ext(t, delta, exponent) };
277        Self { sampler }
278    }
279
280    /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration"
281    /// <https://arxiv.org/abs/1904.09751>
282    ///
283    /// # Example:
284    /// ```rust
285    /// use llama_cpp_bindings::token::{
286    ///    LlamaToken,
287    ///    data::LlamaTokenData,
288    ///    data_array::LlamaTokenDataArray
289    /// };
290    /// use llama_cpp_bindings::sampling::LlamaSampler;
291    ///
292    /// let mut data_array = LlamaTokenDataArray::new(vec![
293    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
294    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
295    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
296    ///     LlamaTokenData::new(LlamaToken(3), 3., 0.),
297    /// ], false);
298    ///
299    /// data_array.apply_sampler(&mut LlamaSampler::top_k(2));
300    ///
301    /// assert_eq!(data_array.data.len(), 2);
302    /// assert_eq!(data_array.data[0].id(), LlamaToken(3));
303    /// assert_eq!(data_array.data[1].id(), LlamaToken(2));
304    /// ```
305    #[must_use]
306    pub fn top_k(k: i32) -> Self {
307        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_k(k) };
308        Self { sampler }
309    }
310
311    /// Top-nσ sampling as described in academic paper "Top-nσ: Not All Logits Are You Need"
312    /// <https://arxiv.org/pdf/2411.07641>
313    ///
314    /// This method filters logits by selecting only those within *n* standard deviations of the mean.
315    ///
316    /// # Parameters
317    /// - `n`: Number of standard deviations from the mean to include in sampling
318    ///
319    /// # Example
320    /// ```rust
321    /// use llama_cpp_bindings::sampling::LlamaSampler;
322    /// use llama_cpp_bindings::token::{
323    ///     LlamaToken,
324    ///     data::LlamaTokenData,
325    ///     data_array::LlamaTokenDataArray
326    /// };
327    ///
328    /// let mut data_array = LlamaTokenDataArray::new(vec![
329    ///     LlamaTokenData::new(LlamaToken(0), 0.0, 0.0),
330    ///     LlamaTokenData::new(LlamaToken(1), 1.0, 0.0),
331    ///     LlamaTokenData::new(LlamaToken(2), 2.0, 0.0),
332    /// ], false);
333    ///
334    /// data_array.apply_sampler(&mut LlamaSampler::top_n_sigma(2.0));
335    /// ```
336    #[must_use]
337    pub fn top_n_sigma(n: f32) -> Self {
338        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_n_sigma(n) };
339        Self { sampler }
340    }
341
342    /// Locally Typical Sampling implementation described in the paper <https://arxiv.org/abs/2202.00666>.
343    #[must_use]
344    pub fn typical(p: f32, min_keep: usize) -> Self {
345        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_typical(p, min_keep) };
346        Self { sampler }
347    }
348
349    /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration"
350    /// <https://arxiv.org/abs/1904.09751>
351    #[must_use]
352    pub fn top_p(p: f32, min_keep: usize) -> Self {
353        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_p(p, min_keep) };
354        Self { sampler }
355    }
356
357    /// Minimum P sampling as described in <https://github.com/ggerganov/llama.cpp/pull/3841>
358    #[must_use]
359    pub fn min_p(p: f32, min_keep: usize) -> Self {
360        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_min_p(p, min_keep) };
361        Self { sampler }
362    }
363
364    /// XTC sampler as described in <https://github.com/oobabooga/text-generation-webui/pull/6335>
365    #[must_use]
366    pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Self {
367        let sampler =
368            unsafe { llama_cpp_bindings_sys::llama_sampler_init_xtc(p, t, min_keep, seed) };
369        Self { sampler }
370    }
371
372    /// Grammar sampler
373    ///
374    /// # Errors
375    /// Returns an error if the grammar is invalid or the sampler cannot be initialized.
376    pub fn grammar(
377        model: &LlamaModel,
378        grammar_str: &str,
379        grammar_root: &str,
380    ) -> Result<Self, GrammarError> {
381        let (grammar_str, grammar_root) =
382            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
383        let mut error_ptr: *mut c_char = std::ptr::null_mut();
384
385        let sampler = unsafe {
386            llama_cpp_bindings_sys::llama_rs_sampler_init_grammar(
387                model.vocab_ptr(),
388                grammar_str.as_ptr(),
389                grammar_root.as_ptr(),
390                &raw mut error_ptr,
391            )
392        };
393
394        check_sampler_not_null(sampler, error_ptr)
395    }
396
397    /// Lazy grammar sampler, introduced in <https://github.com/ggerganov/llama.cpp/pull/9639>
398    ///
399    /// This sampler enforces grammar rules only when specific trigger words or tokens are encountered.
400    ///
401    /// # Errors
402    /// Returns an error if the grammar or trigger words are invalid.
403    pub fn grammar_lazy(
404        model: &LlamaModel,
405        grammar_str: &str,
406        grammar_root: &str,
407        trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
408        trigger_tokens: &[LlamaToken],
409    ) -> Result<Self, GrammarError> {
410        let (grammar_str, grammar_root) =
411            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
412        let trigger_words = Self::sanitize_trigger_words(trigger_words)?;
413        let mut error_ptr: *mut c_char = std::ptr::null_mut();
414
415        let mut trigger_word_ptrs: Vec<*const c_char> =
416            trigger_words.iter().map(|cs| cs.as_ptr()).collect();
417
418        let sampler = unsafe {
419            llama_cpp_bindings_sys::llama_rs_sampler_init_grammar_lazy(
420                model.vocab_ptr(),
421                grammar_str.as_ptr(),
422                grammar_root.as_ptr(),
423                trigger_word_ptrs.as_mut_ptr(),
424                trigger_word_ptrs.len(),
425                trigger_tokens.as_ptr().cast(),
426                trigger_tokens.len(),
427                &raw mut error_ptr,
428            )
429        };
430
431        check_sampler_not_null(sampler, error_ptr)
432    }
433
434    /// Lazy grammar sampler using regex trigger patterns.
435    ///
436    /// Trigger patterns are regular expressions matched from the start of the
437    /// generation output. The grammar sampler will be fed content starting from
438    /// the first match group.
439    ///
440    /// # Errors
441    /// Returns an error if the grammar or trigger patterns are invalid.
442    pub fn grammar_lazy_patterns(
443        model: &LlamaModel,
444        grammar_str: &str,
445        grammar_root: &str,
446        trigger_patterns: &[String],
447        trigger_tokens: &[LlamaToken],
448    ) -> Result<Self, GrammarError> {
449        let (grammar_str, grammar_root) =
450            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
451        let trigger_patterns = Self::sanitize_trigger_patterns(trigger_patterns)?;
452        let mut error_ptr: *mut c_char = std::ptr::null_mut();
453
454        let mut trigger_pattern_ptrs: Vec<*const c_char> =
455            trigger_patterns.iter().map(|cs| cs.as_ptr()).collect();
456
457        let sampler = unsafe {
458            llama_cpp_bindings_sys::llama_rs_sampler_init_grammar_lazy_patterns(
459                model.vocab_ptr(),
460                grammar_str.as_ptr(),
461                grammar_root.as_ptr(),
462                trigger_pattern_ptrs.as_mut_ptr(),
463                trigger_pattern_ptrs.len(),
464                trigger_tokens.as_ptr().cast(),
465                trigger_tokens.len(),
466                &raw mut error_ptr,
467            )
468        };
469
470        check_sampler_not_null(sampler, error_ptr)
471    }
472
473    /// `LLGuidance` sampler for constrained decoding.
474    ///
475    /// Uses the `llguidance` and `toktrie` Rust crates to enforce grammar constraints
476    /// during token sampling. Supports JSON schema, regex, Lark, and other grammar types.
477    ///
478    /// # Errors
479    ///
480    /// Returns [`GrammarError`] if the grammar is invalid or the sampler cannot be initialized.
481    #[cfg(feature = "llguidance")]
482    pub fn llguidance(
483        model: &LlamaModel,
484        grammar_kind: &str,
485        grammar_data: &str,
486    ) -> Result<Self, GrammarError> {
487        crate::llguidance_sampler::create_llg_sampler(model, grammar_kind, grammar_data)
488    }
489
490    fn sanitize_grammar_strings(
491        grammar_str: &str,
492        grammar_root: &str,
493    ) -> Result<(CString, CString), GrammarError> {
494        if !grammar_str.contains(grammar_root) {
495            return Err(GrammarError::RootNotFound);
496        }
497
498        let grammar = CString::new(grammar_str).map_err(GrammarError::GrammarNullBytes)?;
499        let root = CString::new(grammar_root).map_err(GrammarError::GrammarNullBytes)?;
500
501        Ok((grammar, root))
502    }
503
504    fn sanitize_trigger_words(
505        trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
506    ) -> Result<Vec<CString>, GrammarError> {
507        trigger_words
508            .into_iter()
509            .map(|word| CString::new(word.as_ref()).map_err(GrammarError::TriggerWordNullBytes))
510            .collect()
511    }
512
513    fn sanitize_trigger_patterns(
514        trigger_patterns: &[String],
515    ) -> Result<Vec<CString>, GrammarError> {
516        trigger_patterns
517            .iter()
518            .map(|pattern| CString::new(pattern.as_str()).map_err(GrammarError::GrammarNullBytes))
519            .collect()
520    }
521
522    /// DRY sampler, designed by p-e-w, as described in:
523    /// <https://github.com/oobabooga/text-generation-webui/pull/5677>, porting Koboldcpp
524    /// implementation authored by pi6am: <https://github.com/LostRuins/koboldcpp/pull/982>
525    ///
526    /// # Errors
527    /// Returns an error if any string in `seq_breakers` contains null bytes.
528    #[allow(missing_docs)]
529    pub fn dry(
530        model: &LlamaModel,
531        multiplier: f32,
532        base: f32,
533        allowed_length: i32,
534        penalty_last_n: i32,
535        seq_breakers: impl IntoIterator<Item = impl AsRef<[u8]>>,
536    ) -> Result<Self, GrammarError> {
537        let seq_breakers: Vec<CString> = seq_breakers
538            .into_iter()
539            .map(|s| CString::new(s.as_ref()))
540            .collect::<Result<Vec<_>, _>>()?;
541        let mut seq_breaker_pointers: Vec<*const c_char> =
542            seq_breakers.iter().map(|s| s.as_ptr()).collect();
543
544        let n_ctx_train_value = model.n_ctx_train().map_err(|convert_error| {
545            GrammarError::IntegerOverflow(format!(
546                "n_ctx_train does not fit into u32: {convert_error}"
547            ))
548        })?;
549        let n_ctx_train = checked_u32_as_i32(n_ctx_train_value)?;
550        let sampler = unsafe {
551            llama_cpp_bindings_sys::llama_sampler_init_dry(
552                model.vocab_ptr(),
553                n_ctx_train,
554                multiplier,
555                base,
556                allowed_length,
557                penalty_last_n,
558                seq_breaker_pointers.as_mut_ptr(),
559                seq_breaker_pointers.len(),
560            )
561        };
562
563        Ok(Self { sampler })
564    }
565
566    /// Penalizes tokens for being present in the context.
567    ///
568    /// Parameters:
569    /// - ``penalty_last_n``: last n tokens to penalize (0 = disable penalty, -1 = context size)
570    /// - ``penalty_repeat``: 1.0 = disabled
571    /// - ``penalty_freq``: 0.0 = disabled
572    /// - ``penalty_present``: 0.0 = disabled
573    #[must_use]
574    pub fn penalties(
575        penalty_last_n: i32,
576        penalty_repeat: f32,
577        penalty_freq: f32,
578        penalty_present: f32,
579    ) -> Self {
580        let sampler = unsafe {
581            llama_cpp_bindings_sys::llama_sampler_init_penalties(
582                penalty_last_n,
583                penalty_repeat,
584                penalty_freq,
585                penalty_present,
586            )
587        };
588        Self { sampler }
589    }
590
591    /// Mirostat 1.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
592    ///
593    /// # Parameters:
594    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
595    /// - ``seed``: Seed to initialize random generation with.
596    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
597    ///   generated text. A higher value corresponds to more surprising or less predictable text,
598    ///   while a lower value corresponds to less surprising or more predictable text.
599    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
600    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
601    ///   updated more quickly, while a smaller learning rate will result in slower updates.
602    /// - ``m``: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary
603    ///   value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`.
604    ///   In the paper, they use `m = 100`, but you can experiment with different values to see how
605    ///   it affects the performance of the algorithm.
606    #[must_use]
607    pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self {
608        let sampler = unsafe {
609            llama_cpp_bindings_sys::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m)
610        };
611        Self { sampler }
612    }
613
614    /// Mirostat 2.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
615    ///
616    /// # Parameters:
617    /// - ``seed``: Seed to initialize random generation with.
618    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
619    ///   generated text. A higher value corresponds to more surprising or less predictable text,
620    ///   while a lower value corresponds to less surprising or more predictable text.
621    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
622    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
623    ///   updated more quickly, while a smaller learning rate will result in slower updates.
624    #[must_use]
625    pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self {
626        let sampler =
627            unsafe { llama_cpp_bindings_sys::llama_sampler_init_mirostat_v2(seed, tau, eta) };
628        Self { sampler }
629    }
630
631    /// Selects a token at random based on each token's probabilities
632    #[must_use]
633    pub fn dist(seed: u32) -> Self {
634        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_dist(seed) };
635        Self { sampler }
636    }
637
638    /// Selects the most likely token
639    ///
640    /// # Example:
641    /// ```rust
642    /// use llama_cpp_bindings::token::{
643    ///    LlamaToken,
644    ///    data::LlamaTokenData,
645    ///    data_array::LlamaTokenDataArray
646    /// };
647    /// use llama_cpp_bindings::sampling::LlamaSampler;
648    ///
649    /// let mut data_array = LlamaTokenDataArray::new(vec![
650    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
651    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
652    /// ], false);
653    ///
654    /// data_array.apply_sampler(&mut LlamaSampler::greedy());
655    ///
656    /// assert_eq!(data_array.data.len(), 2);
657    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(1)));
658    /// ```
659    #[must_use]
660    pub fn greedy() -> Self {
661        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_greedy() };
662        Self { sampler }
663    }
664
665    /// Creates a sampler that applies bias values to specific tokens during sampling.
666    ///
667    /// # Parameters
668    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
669    /// - ``biases``: Slice of [`LlamaLogitBias`] values specifying token-bias pairs
670    ///
671    /// # Errors
672    /// Returns [`SamplingError::IntegerOverflow`] if `biases.len()` exceeds `i32::MAX`.
673    ///
674    /// # Example
675    /// ```rust
676    /// use llama_cpp_bindings::token::{LlamaToken, logit_bias::LlamaLogitBias};
677    /// use llama_cpp_bindings::sampling::LlamaSampler;
678    ///
679    /// let biases = vec![
680    ///     LlamaLogitBias::new(LlamaToken(1), 1.5),  // Increase probability of token 1
681    ///     LlamaLogitBias::new(LlamaToken(2), -1.0), // Decrease probability of token 2
682    /// ];
683    ///
684    /// // Assuming vocab_size of 32000
685    /// let sampler = LlamaSampler::logit_bias(32000, &biases).unwrap();
686    /// ```
687    pub fn logit_bias(n_vocab: i32, biases: &[LlamaLogitBias]) -> Result<Self, SamplingError> {
688        let bias_count = checked_usize_as_i32_sampling(biases.len())?;
689        let data = biases
690            .as_ptr()
691            .cast::<llama_cpp_bindings_sys::llama_logit_bias>();
692
693        let sampler = unsafe {
694            llama_cpp_bindings_sys::llama_sampler_init_logit_bias(n_vocab, bias_count, data)
695        };
696
697        Ok(Self { sampler })
698    }
699}
700
701impl Drop for LlamaSampler {
702    fn drop(&mut self) {
703        unsafe {
704            llama_cpp_bindings_sys::llama_sampler_free(self.sampler);
705        }
706    }
707}
708
709#[cfg(test)]
710mod tests {
711    use super::LlamaSampler;
712    use crate::GrammarError;
713
714    #[test]
715    fn sanitize_grammar_strings_valid() {
716        let result = LlamaSampler::sanitize_grammar_strings("root ::= \"hello\"", "root");
717
718        assert!(result.is_ok());
719    }
720
721    #[test]
722    fn sanitize_grammar_strings_root_not_found() {
723        let result = LlamaSampler::sanitize_grammar_strings("expr ::= \"hello\"", "root");
724
725        assert_eq!(result.err(), Some(GrammarError::RootNotFound));
726    }
727
728    #[test]
729    fn sanitize_grammar_strings_null_byte_in_grammar() {
730        let result = LlamaSampler::sanitize_grammar_strings("root ::= \"\0\"", "root");
731
732        assert!(matches!(
733            result.err(),
734            Some(GrammarError::GrammarNullBytes(_))
735        ));
736    }
737
738    #[test]
739    fn sanitize_grammar_strings_null_byte_in_root() {
740        let result = LlamaSampler::sanitize_grammar_strings("ro\0ot ::= \"hello\"", "ro\0ot");
741
742        assert!(matches!(
743            result.err(),
744            Some(GrammarError::GrammarNullBytes(_))
745        ));
746    }
747
748    #[test]
749    fn sanitize_trigger_words_valid() {
750        let words: Vec<&[u8]> = vec![b"hello", b"world"];
751        let result = LlamaSampler::sanitize_trigger_words(words);
752
753        assert!(result.is_ok());
754        assert_eq!(result.expect("valid trigger words").len(), 2);
755    }
756
757    #[test]
758    fn sanitize_trigger_words_empty_list() {
759        let words: Vec<&[u8]> = vec![];
760        let result = LlamaSampler::sanitize_trigger_words(words);
761
762        assert!(result.is_ok());
763        assert!(result.expect("valid trigger words").is_empty());
764    }
765
766    #[test]
767    fn sanitize_trigger_words_null_byte() {
768        let words: Vec<&[u8]> = vec![b"hel\0lo"];
769        let result = LlamaSampler::sanitize_trigger_words(words);
770
771        assert!(matches!(
772            result.err(),
773            Some(GrammarError::TriggerWordNullBytes(_))
774        ));
775    }
776
777    #[test]
778    fn sanitize_trigger_patterns_valid() {
779        let patterns = vec!["^hello$".to_string(), "world.*".to_string()];
780        let result = LlamaSampler::sanitize_trigger_patterns(&patterns);
781
782        assert!(result.is_ok());
783        assert_eq!(result.expect("valid trigger patterns").len(), 2);
784    }
785
786    #[test]
787    fn sanitize_trigger_patterns_empty_list() {
788        let patterns: Vec<String> = vec![];
789        let result = LlamaSampler::sanitize_trigger_patterns(&patterns);
790
791        assert!(result.is_ok());
792        assert!(result.expect("valid trigger patterns").is_empty());
793    }
794
795    #[test]
796    fn sanitize_trigger_patterns_null_byte() {
797        let patterns = vec!["hel\0lo".to_string()];
798        let result = LlamaSampler::sanitize_trigger_patterns(&patterns);
799
800        assert!(matches!(
801            result.err(),
802            Some(GrammarError::GrammarNullBytes(_))
803        ));
804    }
805
806    #[test]
807    fn apply_modifies_data_array() {
808        use crate::token::LlamaToken;
809        use crate::token::data::LlamaTokenData;
810        use crate::token::data_array::LlamaTokenDataArray;
811
812        let sampler = LlamaSampler::greedy();
813        let mut data_array = LlamaTokenDataArray::new(
814            vec![
815                LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
816                LlamaTokenData::new(LlamaToken::new(1), 5.0, 0.0),
817            ],
818            false,
819        );
820
821        sampler.apply(&mut data_array);
822
823        assert_eq!(data_array.selected_token(), Some(LlamaToken::new(1)));
824    }
825
826    #[test]
827    fn accept_succeeds() {
828        let mut sampler = LlamaSampler::chain_simple([
829            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
830            LlamaSampler::greedy(),
831        ]);
832
833        sampler
834            .accept(crate::token::LlamaToken::new(1))
835            .expect("test: accept should succeed");
836    }
837
838    #[test]
839    fn try_accept_succeeds_on_penalties_sampler() {
840        let mut sampler = LlamaSampler::chain_simple([
841            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
842            LlamaSampler::greedy(),
843        ]);
844
845        let result = sampler.try_accept(crate::token::LlamaToken::new(42));
846
847        assert!(result.is_ok());
848    }
849
850    #[test]
851    fn accept_many_multiple_tokens() {
852        use crate::token::LlamaToken;
853
854        let mut sampler = LlamaSampler::chain_simple([
855            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
856            LlamaSampler::greedy(),
857        ]);
858
859        sampler
860            .accept_many([LlamaToken::new(1), LlamaToken::new(2), LlamaToken::new(3)])
861            .expect("test: accept_many should succeed");
862    }
863
864    #[test]
865    fn with_tokens_builder_pattern() {
866        use crate::token::LlamaToken;
867
868        let _sampler = LlamaSampler::chain_simple([
869            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
870            LlamaSampler::greedy(),
871        ])
872        .with_tokens([LlamaToken::new(10), LlamaToken::new(20)])
873        .expect("test: with_tokens should succeed");
874    }
875
876    #[test]
877    fn all_sampler_constructors() {
878        use crate::token::LlamaToken;
879        use crate::token::logit_bias::LlamaLogitBias;
880
881        let _temp = LlamaSampler::temp(0.8);
882        let _temp_ext = LlamaSampler::temp_ext(0.8, 0.1, 1.0);
883        let _top_k = LlamaSampler::top_k(40);
884        let _top_n_sigma = LlamaSampler::top_n_sigma(2.0);
885        let _top_p = LlamaSampler::top_p(0.9, 1);
886        let _min_p = LlamaSampler::min_p(0.05, 1);
887        let _typical = LlamaSampler::typical(0.9, 1);
888        let _xtc = LlamaSampler::xtc(0.1, 0.5, 1, 42);
889        let _dist = LlamaSampler::dist(42);
890        let _mirostat = LlamaSampler::mirostat(32000, 42, 5.0, 0.1, 100);
891        let _mirostat_v2 = LlamaSampler::mirostat_v2(42, 5.0, 0.1);
892        let biases = vec![LlamaLogitBias::new(LlamaToken::new(0), -100.0)];
893        let _logit_bias = LlamaSampler::logit_bias(32000, &biases);
894        let _chain = LlamaSampler::chain([LlamaSampler::greedy()], true);
895    }
896
897    #[test]
898    fn reset_and_get_seed() {
899        let mut sampler = LlamaSampler::dist(42);
900        sampler.reset();
901        let _seed = sampler.get_seed();
902    }
903
904    #[test]
905    fn debug_formatting() {
906        let sampler = LlamaSampler::greedy();
907        let debug_output = format!("{sampler:?}");
908        assert!(debug_output.contains("LlamaSampler"));
909    }
910
911    #[cfg(feature = "tests_that_use_llms")]
912    #[test]
913    #[serial_test::serial]
914    fn dry_sampler_with_model() {
915        let (_backend, model) = crate::test_model::load_default_model().unwrap();
916        let breakers: Vec<&[u8]> = vec![b"\n", b"\t"];
917        let _sampler = LlamaSampler::dry(&model, 1.5, 2.0, 128, 2, &breakers);
918    }
919
920    #[cfg(feature = "tests_that_use_llms")]
921    #[test]
922    #[serial_test::serial]
923    fn dry_sampler_with_null_byte_in_seq_breakers_returns_error() {
924        let (_backend, model) = crate::test_model::load_default_model().unwrap();
925        let breakers: Vec<&[u8]> = vec![b"hello\0world"];
926        let result = LlamaSampler::dry(&model, 1.5, 2.0, 128, 2, breakers);
927
928        assert!(result.is_err());
929    }
930
931    #[cfg(feature = "tests_that_use_llms")]
932    #[test]
933    #[serial_test::serial]
934    fn grammar_returns_sampler_for_valid_grammar() {
935        let (_backend, model) = crate::test_model::load_default_model().unwrap();
936        let sampler = LlamaSampler::grammar(&model, "root ::= \"hello\"", "root");
937
938        assert!(sampler.is_ok());
939    }
940
941    #[cfg(feature = "tests_that_use_llms")]
942    #[test]
943    #[serial_test::serial]
944    fn grammar_lazy_returns_sampler_for_valid_grammar_with_triggers() {
945        let (_backend, model) = crate::test_model::load_default_model().unwrap();
946        let trigger_words: Vec<&[u8]> = vec![b"function"];
947        let sampler =
948            LlamaSampler::grammar_lazy(&model, "root ::= \"hello\"", "root", trigger_words, &[]);
949
950        assert!(sampler.is_ok());
951    }
952
953    #[cfg(feature = "tests_that_use_llms")]
954    #[test]
955    #[serial_test::serial]
956    fn grammar_lazy_patterns_returns_sampler_for_valid_grammar_with_patterns() {
957        let (_backend, model) = crate::test_model::load_default_model().unwrap();
958        let patterns = vec!["\\{.*".to_string()];
959        let sampler = LlamaSampler::grammar_lazy_patterns(
960            &model,
961            "root ::= \"hello\"",
962            "root",
963            &patterns,
964            &[],
965        );
966
967        assert!(sampler.is_ok());
968    }
969
970    #[cfg(feature = "tests_that_use_llms")]
971    #[test]
972    #[serial_test::serial]
973    fn sample_returns_token_after_decode() {
974        use crate::context::params::LlamaContextParams;
975        use crate::llama_batch::LlamaBatch;
976        use crate::model::AddBos;
977        use crate::token::LlamaToken;
978
979        let (backend, model) = crate::test_model::load_default_model().unwrap();
980        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
981        let mut context = model.new_context(&backend, ctx_params).unwrap();
982        let tokens = model.str_to_token("Hello", AddBos::Always).unwrap();
983        let mut batch = LlamaBatch::new(512, 1).unwrap();
984        batch.add_sequence(&tokens, 0, false).unwrap();
985        context.decode(&mut batch).unwrap();
986        let mut sampler =
987            LlamaSampler::chain_simple([LlamaSampler::temp(0.8), LlamaSampler::greedy()]);
988        let result = sampler.sample(&context, batch.n_tokens() - 1);
989
990        assert!(result.is_ok());
991    }
992
993    #[test]
994    fn checked_u32_as_i32_overflow() {
995        let result = super::checked_u32_as_i32(u32::MAX);
996        assert!(result.is_err());
997    }
998
999    #[test]
1000    fn checked_usize_as_i32_sampling_overflow() {
1001        let result = super::checked_usize_as_i32_sampling(usize::MAX);
1002        assert!(result.is_err());
1003    }
1004
1005    #[test]
1006    fn check_sampler_accept_status_ok() {
1007        let result = super::check_sampler_accept_status(
1008            llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK,
1009            std::ptr::null_mut(),
1010        );
1011
1012        assert!(result.is_ok());
1013    }
1014
1015    #[test]
1016    fn check_sampler_accept_status_invalid_argument() {
1017        let result = super::check_sampler_accept_status(
1018            llama_cpp_bindings_sys::LLAMA_RS_STATUS_INVALID_ARGUMENT,
1019            std::ptr::null_mut(),
1020        );
1021
1022        assert!(matches!(
1023            result,
1024            Err(crate::SamplerAcceptError::InvalidArgument)
1025        ));
1026    }
1027
1028    #[test]
1029    fn check_sampler_accept_status_exception() {
1030        let result = super::check_sampler_accept_status(
1031            llama_cpp_bindings_sys::LLAMA_RS_STATUS_EXCEPTION,
1032            std::ptr::null_mut(),
1033        );
1034
1035        assert!(matches!(
1036            result,
1037            Err(crate::SamplerAcceptError::CppException(_))
1038        ));
1039    }
1040
1041    #[test]
1042    fn check_sampler_not_null_returns_error() {
1043        let result = super::check_sampler_not_null(std::ptr::null_mut(), std::ptr::null_mut());
1044
1045        assert!(result.is_err());
1046    }
1047}