Skip to main content

llama_cpp_bindings/
sampling.rs

1//! Safe wrapper around `llama_sampler`.
2
3use std::borrow::Borrow;
4use std::ffi::{CString, c_char};
5use std::fmt::{Debug, Formatter};
6
7use crate::context::LlamaContext;
8use crate::model::LlamaModel;
9use crate::token::LlamaToken;
10use crate::token::data_array::LlamaTokenDataArray;
11use crate::token::logit_bias::LlamaLogitBias;
12use crate::{GrammarError, SamplerAcceptError, SamplingError, status_is_ok, status_to_i32};
13
14const fn check_sampler_accept_status(
15    status: llama_cpp_bindings_sys::llama_rs_status,
16) -> Result<(), SamplerAcceptError> {
17    if status_is_ok(status) {
18        Ok(())
19    } else {
20        Err(SamplerAcceptError::FfiError(status_to_i32(status)))
21    }
22}
23
24const fn check_sampler_not_null(
25    sampler: *mut llama_cpp_bindings_sys::llama_sampler,
26) -> Result<LlamaSampler, GrammarError> {
27    if sampler.is_null() {
28        Err(GrammarError::NullGrammar)
29    } else {
30        Ok(LlamaSampler { sampler })
31    }
32}
33
34fn checked_u32_as_i32(value: u32) -> Result<i32, GrammarError> {
35    i32::try_from(value).map_err(|convert_error| {
36        GrammarError::IntegerOverflow(format!("value exceeds i32::MAX: {convert_error}"))
37    })
38}
39
40fn checked_usize_as_i32_sampling(value: usize) -> Result<i32, SamplingError> {
41    i32::try_from(value).map_err(|convert_error| {
42        SamplingError::IntegerOverflow(format!("value exceeds i32::MAX: {convert_error}"))
43    })
44}
45
46/// A safe wrapper around `llama_sampler`.
47pub struct LlamaSampler {
48    /// Raw pointer to the underlying `llama_sampler`.
49    pub sampler: *mut llama_cpp_bindings_sys::llama_sampler,
50}
51
52impl Debug for LlamaSampler {
53    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
54        f.debug_struct("LlamaSamplerChain").finish()
55    }
56}
57
58impl LlamaSampler {
59    /// Sample and accept a token from the idx-th output of the last evaluation
60    #[must_use]
61    pub fn sample(&mut self, ctx: &LlamaContext, idx: i32) -> LlamaToken {
62        let token = unsafe {
63            llama_cpp_bindings_sys::llama_sampler_sample(self.sampler, ctx.context.as_ptr(), idx)
64        };
65
66        LlamaToken(token)
67    }
68
69    /// Applies this sampler to a [`LlamaTokenDataArray`].
70    pub fn apply(&self, data_array: &mut LlamaTokenDataArray) {
71        data_array.apply_sampler(self);
72    }
73
74    /// Accepts a token from the sampler, possibly updating the internal state of certain samplers
75    /// (e.g. grammar, repetition, etc.)
76    ///
77    /// # Errors
78    /// Returns [`SamplerAcceptError`] if the underlying sampler rejects the token.
79    pub fn accept(&mut self, token: LlamaToken) -> Result<(), SamplerAcceptError> {
80        self.try_accept(token)
81    }
82
83    /// Accepts several tokens from the sampler or context, possibly updating the internal state of
84    /// certain samplers (e.g. grammar, repetition, etc.)
85    ///
86    /// # Errors
87    /// Returns [`SamplerAcceptError`] if the underlying sampler rejects any token.
88    pub fn accept_many(
89        &mut self,
90        tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
91    ) -> Result<(), SamplerAcceptError> {
92        for token in tokens {
93            self.try_accept(*token.borrow())?;
94        }
95
96        Ok(())
97    }
98
99    /// Accepts several tokens from the sampler or context, possibly updating the internal state of
100    /// certain samplers (e.g. grammar, repetition, etc.)
101    ///
102    /// # Errors
103    /// Returns [`SamplerAcceptError`] if the underlying sampler rejects any token.
104    pub fn with_tokens(
105        mut self,
106        tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
107    ) -> Result<Self, SamplerAcceptError> {
108        self.accept_many(tokens)?;
109
110        Ok(self)
111    }
112
113    /// Try accepting a token from the sampler. Returns an error if the sampler throws.
114    ///
115    /// # Errors
116    /// Returns an error if the underlying sampler rejects the token.
117    pub fn try_accept(&mut self, token: LlamaToken) -> Result<(), SamplerAcceptError> {
118        let sampler_result =
119            unsafe { llama_cpp_bindings_sys::llama_rs_sampler_accept(self.sampler, token.0) };
120
121        check_sampler_accept_status(sampler_result)
122    }
123
124    /// Resets the internal state of the sampler.
125    ///
126    /// This can be useful when you want to start fresh with a sampler without creating a new instance.
127    pub fn reset(&mut self) {
128        unsafe {
129            llama_cpp_bindings_sys::llama_sampler_reset(self.sampler);
130        }
131    }
132
133    /// Gets the random seed used by this sampler.
134    ///
135    /// Returns:
136    /// - For random samplers (dist, mirostat, `mirostat_v2)`: returns their current seed
137    /// - For sampler chains: returns the first non-default seed found in reverse order
138    /// - For all other samplers: returns 0xFFFFFFFF
139    #[must_use]
140    pub fn get_seed(&self) -> u32 {
141        unsafe { llama_cpp_bindings_sys::llama_sampler_get_seed(self.sampler) }
142    }
143
144    /// Combines a list of samplers into a single sampler that applies each component sampler one
145    /// after another.
146    ///
147    /// If you are using a chain to select a token, the chain should always end with one of
148    /// [`LlamaSampler::greedy`], [`LlamaSampler::dist`], [`LlamaSampler::mirostat`], and
149    /// [`LlamaSampler::mirostat_v2`].
150    #[must_use]
151    pub fn chain(samplers: impl IntoIterator<Item = Self>, no_perf: bool) -> Self {
152        unsafe {
153            let chain = llama_cpp_bindings_sys::llama_sampler_chain_init(
154                llama_cpp_bindings_sys::llama_sampler_chain_params { no_perf },
155            );
156
157            for sampler in samplers {
158                llama_cpp_bindings_sys::llama_sampler_chain_add(chain, sampler.sampler);
159
160                // Do not call `llama_sampler_free` on the sampler, as the internal sampler is now
161                // owned by the chain
162                std::mem::forget(sampler);
163            }
164
165            Self { sampler: chain }
166        }
167    }
168
169    /// Same as [`Self::chain`] with `no_perf = false`.
170    ///
171    /// # Example
172    /// ```rust
173    /// use llama_cpp_bindings::token::{
174    ///    LlamaToken,
175    ///    data::LlamaTokenData,
176    ///    data_array::LlamaTokenDataArray
177    /// };
178    /// use llama_cpp_bindings::sampling::LlamaSampler;
179    /// use llama_cpp_bindings::llama_backend::LlamaBackend;
180    /// let backend = LlamaBackend::init().unwrap();
181    ///
182    /// let mut data_array = LlamaTokenDataArray::new(vec![
183    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
184    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
185    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
186    /// ], false);
187    ///
188    /// data_array.apply_sampler(&mut LlamaSampler::chain_simple([
189    ///     LlamaSampler::temp(0.5),
190    ///     LlamaSampler::greedy(),
191    /// ]));
192    ///
193    /// assert_eq!(data_array.data[0].logit(), 0.);
194    /// assert_eq!(data_array.data[1].logit(), 2.);
195    /// assert_eq!(data_array.data[2].logit(), 4.);
196    ///
197    /// assert_eq!(data_array.data.len(), 3);
198    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(2)));
199    /// ```
200    #[must_use]
201    pub fn chain_simple(samplers: impl IntoIterator<Item = Self>) -> Self {
202        Self::chain(samplers, false)
203    }
204
205    /// Updates the logits `l_i' = l_i/t`. When `t <= 0.0`, the maximum logit is kept at its original
206    /// value, the rest are set to -inf
207    ///
208    /// # Example:
209    /// ```rust
210    /// use llama_cpp_bindings::token::{
211    ///    LlamaToken,
212    ///    data::LlamaTokenData,
213    ///    data_array::LlamaTokenDataArray
214    /// };
215    /// use llama_cpp_bindings::sampling::LlamaSampler;
216    ///
217    /// let mut data_array = LlamaTokenDataArray::new(vec![
218    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
219    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
220    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
221    /// ], false);
222    ///
223    /// data_array.apply_sampler(&mut LlamaSampler::temp(0.5));
224    ///
225    /// assert_eq!(data_array.data[0].logit(), 0.);
226    /// assert_eq!(data_array.data[1].logit(), 2.);
227    /// assert_eq!(data_array.data[2].logit(), 4.);
228    /// ```
229    #[must_use]
230    pub fn temp(t: f32) -> Self {
231        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_temp(t) };
232        Self { sampler }
233    }
234
235    /// Dynamic temperature implementation (a.k.a. entropy) described in the paper
236    /// <https://arxiv.org/abs/2309.02772>.
237    #[must_use]
238    pub fn temp_ext(t: f32, delta: f32, exponent: f32) -> Self {
239        let sampler =
240            unsafe { llama_cpp_bindings_sys::llama_sampler_init_temp_ext(t, delta, exponent) };
241        Self { sampler }
242    }
243
244    /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration"
245    /// <https://arxiv.org/abs/1904.09751>
246    ///
247    /// # Example:
248    /// ```rust
249    /// use llama_cpp_bindings::token::{
250    ///    LlamaToken,
251    ///    data::LlamaTokenData,
252    ///    data_array::LlamaTokenDataArray
253    /// };
254    /// use llama_cpp_bindings::sampling::LlamaSampler;
255    ///
256    /// let mut data_array = LlamaTokenDataArray::new(vec![
257    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
258    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
259    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
260    ///     LlamaTokenData::new(LlamaToken(3), 3., 0.),
261    /// ], false);
262    ///
263    /// data_array.apply_sampler(&mut LlamaSampler::top_k(2));
264    ///
265    /// assert_eq!(data_array.data.len(), 2);
266    /// assert_eq!(data_array.data[0].id(), LlamaToken(3));
267    /// assert_eq!(data_array.data[1].id(), LlamaToken(2));
268    /// ```
269    #[must_use]
270    pub fn top_k(k: i32) -> Self {
271        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_k(k) };
272        Self { sampler }
273    }
274
275    /// Top-nσ sampling as described in academic paper "Top-nσ: Not All Logits Are You Need"
276    /// <https://arxiv.org/pdf/2411.07641>
277    ///
278    /// This method filters logits by selecting only those within *n* standard deviations of the mean.
279    ///
280    /// # Parameters
281    /// - `n`: Number of standard deviations from the mean to include in sampling
282    ///
283    /// # Example
284    /// ```rust
285    /// use llama_cpp_bindings::sampling::LlamaSampler;
286    /// use llama_cpp_bindings::token::{
287    ///     LlamaToken,
288    ///     data::LlamaTokenData,
289    ///     data_array::LlamaTokenDataArray
290    /// };
291    ///
292    /// let mut data_array = LlamaTokenDataArray::new(vec![
293    ///     LlamaTokenData::new(LlamaToken(0), 0.0, 0.0),
294    ///     LlamaTokenData::new(LlamaToken(1), 1.0, 0.0),
295    ///     LlamaTokenData::new(LlamaToken(2), 2.0, 0.0),
296    /// ], false);
297    ///
298    /// data_array.apply_sampler(&mut LlamaSampler::top_n_sigma(2.0));
299    /// ```
300    #[must_use]
301    pub fn top_n_sigma(n: f32) -> Self {
302        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_n_sigma(n) };
303        Self { sampler }
304    }
305
306    /// Locally Typical Sampling implementation described in the paper <https://arxiv.org/abs/2202.00666>.
307    #[must_use]
308    pub fn typical(p: f32, min_keep: usize) -> Self {
309        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_typical(p, min_keep) };
310        Self { sampler }
311    }
312
313    /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration"
314    /// <https://arxiv.org/abs/1904.09751>
315    #[must_use]
316    pub fn top_p(p: f32, min_keep: usize) -> Self {
317        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_p(p, min_keep) };
318        Self { sampler }
319    }
320
321    /// Minimum P sampling as described in <https://github.com/ggerganov/llama.cpp/pull/3841>
322    #[must_use]
323    pub fn min_p(p: f32, min_keep: usize) -> Self {
324        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_min_p(p, min_keep) };
325        Self { sampler }
326    }
327
328    /// XTC sampler as described in <https://github.com/oobabooga/text-generation-webui/pull/6335>
329    #[must_use]
330    pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Self {
331        let sampler =
332            unsafe { llama_cpp_bindings_sys::llama_sampler_init_xtc(p, t, min_keep, seed) };
333        Self { sampler }
334    }
335
336    /// Grammar sampler
337    ///
338    /// # Errors
339    /// Returns an error if the grammar is invalid or the sampler cannot be initialized.
340    pub fn grammar(
341        model: &LlamaModel,
342        grammar_str: &str,
343        grammar_root: &str,
344    ) -> Result<Self, GrammarError> {
345        let (grammar_str, grammar_root) =
346            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
347
348        let sampler = unsafe {
349            llama_cpp_bindings_sys::llama_rs_sampler_init_grammar(
350                model.vocab_ptr(),
351                grammar_str.as_ptr(),
352                grammar_root.as_ptr(),
353            )
354        };
355
356        check_sampler_not_null(sampler)
357    }
358
359    /// Lazy grammar sampler, introduced in <https://github.com/ggerganov/llama.cpp/pull/9639>
360    ///
361    /// This sampler enforces grammar rules only when specific trigger words or tokens are encountered.
362    ///
363    /// # Errors
364    /// Returns an error if the grammar or trigger words are invalid.
365    pub fn grammar_lazy(
366        model: &LlamaModel,
367        grammar_str: &str,
368        grammar_root: &str,
369        trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
370        trigger_tokens: &[LlamaToken],
371    ) -> Result<Self, GrammarError> {
372        let (grammar_str, grammar_root) =
373            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
374        let trigger_words = Self::sanitize_trigger_words(trigger_words)?;
375
376        let mut trigger_word_ptrs: Vec<*const c_char> =
377            trigger_words.iter().map(|cs| cs.as_ptr()).collect();
378
379        let sampler = unsafe {
380            llama_cpp_bindings_sys::llama_rs_sampler_init_grammar_lazy(
381                model.vocab_ptr(),
382                grammar_str.as_ptr(),
383                grammar_root.as_ptr(),
384                trigger_word_ptrs.as_mut_ptr(),
385                trigger_word_ptrs.len(),
386                trigger_tokens.as_ptr().cast(),
387                trigger_tokens.len(),
388            )
389        };
390
391        check_sampler_not_null(sampler)
392    }
393
394    /// Lazy grammar sampler using regex trigger patterns.
395    ///
396    /// Trigger patterns are regular expressions matched from the start of the
397    /// generation output. The grammar sampler will be fed content starting from
398    /// the first match group.
399    ///
400    /// # Errors
401    /// Returns an error if the grammar or trigger patterns are invalid.
402    pub fn grammar_lazy_patterns(
403        model: &LlamaModel,
404        grammar_str: &str,
405        grammar_root: &str,
406        trigger_patterns: &[String],
407        trigger_tokens: &[LlamaToken],
408    ) -> Result<Self, GrammarError> {
409        let (grammar_str, grammar_root) =
410            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
411        let trigger_patterns = Self::sanitize_trigger_patterns(trigger_patterns)?;
412
413        let mut trigger_pattern_ptrs: Vec<*const c_char> =
414            trigger_patterns.iter().map(|cs| cs.as_ptr()).collect();
415
416        let sampler = unsafe {
417            llama_cpp_bindings_sys::llama_rs_sampler_init_grammar_lazy_patterns(
418                model.vocab_ptr(),
419                grammar_str.as_ptr(),
420                grammar_root.as_ptr(),
421                trigger_pattern_ptrs.as_mut_ptr(),
422                trigger_pattern_ptrs.len(),
423                trigger_tokens.as_ptr().cast(),
424                trigger_tokens.len(),
425            )
426        };
427
428        check_sampler_not_null(sampler)
429    }
430
431    /// `LLGuidance` sampler for constrained decoding.
432    ///
433    /// Uses the `llguidance` and `toktrie` Rust crates to enforce grammar constraints
434    /// during token sampling. Supports JSON schema, regex, Lark, and other grammar types.
435    ///
436    /// # Errors
437    ///
438    /// Returns [`GrammarError`] if the grammar is invalid or the sampler cannot be initialized.
439    #[cfg(feature = "llguidance")]
440    pub fn llguidance(
441        model: &LlamaModel,
442        grammar_kind: &str,
443        grammar_data: &str,
444    ) -> Result<Self, GrammarError> {
445        crate::llguidance_sampler::create_llg_sampler(model, grammar_kind, grammar_data)
446    }
447
448    fn sanitize_grammar_strings(
449        grammar_str: &str,
450        grammar_root: &str,
451    ) -> Result<(CString, CString), GrammarError> {
452        if !grammar_str.contains(grammar_root) {
453            return Err(GrammarError::RootNotFound);
454        }
455
456        let grammar = CString::new(grammar_str).map_err(GrammarError::GrammarNullBytes)?;
457        let root = CString::new(grammar_root).map_err(GrammarError::GrammarNullBytes)?;
458
459        Ok((grammar, root))
460    }
461
462    fn sanitize_trigger_words(
463        trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
464    ) -> Result<Vec<CString>, GrammarError> {
465        trigger_words
466            .into_iter()
467            .map(|word| CString::new(word.as_ref()).map_err(GrammarError::TriggerWordNullBytes))
468            .collect()
469    }
470
471    fn sanitize_trigger_patterns(
472        trigger_patterns: &[String],
473    ) -> Result<Vec<CString>, GrammarError> {
474        trigger_patterns
475            .iter()
476            .map(|pattern| CString::new(pattern.as_str()).map_err(GrammarError::GrammarNullBytes))
477            .collect()
478    }
479
480    /// DRY sampler, designed by p-e-w, as described in:
481    /// <https://github.com/oobabooga/text-generation-webui/pull/5677>, porting Koboldcpp
482    /// implementation authored by pi6am: <https://github.com/LostRuins/koboldcpp/pull/982>
483    ///
484    /// # Errors
485    /// Returns an error if any string in `seq_breakers` contains null bytes.
486    #[allow(missing_docs)]
487    pub fn dry(
488        model: &LlamaModel,
489        multiplier: f32,
490        base: f32,
491        allowed_length: i32,
492        penalty_last_n: i32,
493        seq_breakers: impl IntoIterator<Item = impl AsRef<[u8]>>,
494    ) -> Result<Self, GrammarError> {
495        let seq_breakers: Vec<CString> = seq_breakers
496            .into_iter()
497            .map(|s| CString::new(s.as_ref()))
498            .collect::<Result<Vec<_>, _>>()?;
499        let mut seq_breaker_pointers: Vec<*const c_char> =
500            seq_breakers.iter().map(|s| s.as_ptr()).collect();
501
502        let n_ctx_train = checked_u32_as_i32(model.n_ctx_train())?;
503        let sampler = unsafe {
504            llama_cpp_bindings_sys::llama_sampler_init_dry(
505                model.vocab_ptr(),
506                n_ctx_train,
507                multiplier,
508                base,
509                allowed_length,
510                penalty_last_n,
511                seq_breaker_pointers.as_mut_ptr(),
512                seq_breaker_pointers.len(),
513            )
514        };
515
516        Ok(Self { sampler })
517    }
518
519    /// Penalizes tokens for being present in the context.
520    ///
521    /// Parameters:
522    /// - ``penalty_last_n``: last n tokens to penalize (0 = disable penalty, -1 = context size)
523    /// - ``penalty_repeat``: 1.0 = disabled
524    /// - ``penalty_freq``: 0.0 = disabled
525    /// - ``penalty_present``: 0.0 = disabled
526    #[must_use]
527    pub fn penalties(
528        penalty_last_n: i32,
529        penalty_repeat: f32,
530        penalty_freq: f32,
531        penalty_present: f32,
532    ) -> Self {
533        let sampler = unsafe {
534            llama_cpp_bindings_sys::llama_sampler_init_penalties(
535                penalty_last_n,
536                penalty_repeat,
537                penalty_freq,
538                penalty_present,
539            )
540        };
541        Self { sampler }
542    }
543
544    /// Mirostat 1.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
545    ///
546    /// # Parameters:
547    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
548    /// - ``seed``: Seed to initialize random generation with.
549    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
550    ///   generated text. A higher value corresponds to more surprising or less predictable text,
551    ///   while a lower value corresponds to less surprising or more predictable text.
552    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
553    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
554    ///   updated more quickly, while a smaller learning rate will result in slower updates.
555    /// - ``m``: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary
556    ///   value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`.
557    ///   In the paper, they use `m = 100`, but you can experiment with different values to see how
558    ///   it affects the performance of the algorithm.
559    #[must_use]
560    pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self {
561        let sampler = unsafe {
562            llama_cpp_bindings_sys::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m)
563        };
564        Self { sampler }
565    }
566
567    /// Mirostat 2.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
568    ///
569    /// # Parameters:
570    /// - ``seed``: Seed to initialize random generation with.
571    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
572    ///   generated text. A higher value corresponds to more surprising or less predictable text,
573    ///   while a lower value corresponds to less surprising or more predictable text.
574    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
575    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
576    ///   updated more quickly, while a smaller learning rate will result in slower updates.
577    #[must_use]
578    pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self {
579        let sampler =
580            unsafe { llama_cpp_bindings_sys::llama_sampler_init_mirostat_v2(seed, tau, eta) };
581        Self { sampler }
582    }
583
584    /// Selects a token at random based on each token's probabilities
585    #[must_use]
586    pub fn dist(seed: u32) -> Self {
587        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_dist(seed) };
588        Self { sampler }
589    }
590
591    /// Selects the most likely token
592    ///
593    /// # Example:
594    /// ```rust
595    /// use llama_cpp_bindings::token::{
596    ///    LlamaToken,
597    ///    data::LlamaTokenData,
598    ///    data_array::LlamaTokenDataArray
599    /// };
600    /// use llama_cpp_bindings::sampling::LlamaSampler;
601    ///
602    /// let mut data_array = LlamaTokenDataArray::new(vec![
603    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
604    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
605    /// ], false);
606    ///
607    /// data_array.apply_sampler(&mut LlamaSampler::greedy());
608    ///
609    /// assert_eq!(data_array.data.len(), 2);
610    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(1)));
611    /// ```
612    #[must_use]
613    pub fn greedy() -> Self {
614        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_greedy() };
615        Self { sampler }
616    }
617
618    /// Creates a sampler that applies bias values to specific tokens during sampling.
619    ///
620    /// # Parameters
621    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
622    /// - ``biases``: Slice of [`LlamaLogitBias`] values specifying token-bias pairs
623    ///
624    /// # Errors
625    /// Returns [`SamplingError::IntegerOverflow`] if `biases.len()` exceeds `i32::MAX`.
626    ///
627    /// # Example
628    /// ```rust
629    /// use llama_cpp_bindings::token::{LlamaToken, logit_bias::LlamaLogitBias};
630    /// use llama_cpp_bindings::sampling::LlamaSampler;
631    ///
632    /// let biases = vec![
633    ///     LlamaLogitBias::new(LlamaToken(1), 1.5),  // Increase probability of token 1
634    ///     LlamaLogitBias::new(LlamaToken(2), -1.0), // Decrease probability of token 2
635    /// ];
636    ///
637    /// // Assuming vocab_size of 32000
638    /// let sampler = LlamaSampler::logit_bias(32000, &biases).unwrap();
639    /// ```
640    pub fn logit_bias(n_vocab: i32, biases: &[LlamaLogitBias]) -> Result<Self, SamplingError> {
641        let bias_count = checked_usize_as_i32_sampling(biases.len())?;
642        let data = biases
643            .as_ptr()
644            .cast::<llama_cpp_bindings_sys::llama_logit_bias>();
645
646        let sampler = unsafe {
647            llama_cpp_bindings_sys::llama_sampler_init_logit_bias(n_vocab, bias_count, data)
648        };
649
650        Ok(Self { sampler })
651    }
652}
653
654impl Drop for LlamaSampler {
655    fn drop(&mut self) {
656        unsafe {
657            llama_cpp_bindings_sys::llama_sampler_free(self.sampler);
658        }
659    }
660}
661
662#[cfg(test)]
663mod tests {
664    use super::LlamaSampler;
665    use crate::GrammarError;
666
667    #[test]
668    fn sanitize_grammar_strings_valid() {
669        let result = LlamaSampler::sanitize_grammar_strings("root ::= \"hello\"", "root");
670
671        assert!(result.is_ok());
672    }
673
674    #[test]
675    fn sanitize_grammar_strings_root_not_found() {
676        let result = LlamaSampler::sanitize_grammar_strings("expr ::= \"hello\"", "root");
677
678        assert_eq!(result.err(), Some(GrammarError::RootNotFound));
679    }
680
681    #[test]
682    fn sanitize_grammar_strings_null_byte_in_grammar() {
683        let result = LlamaSampler::sanitize_grammar_strings("root ::= \"\0\"", "root");
684
685        assert!(matches!(
686            result.err(),
687            Some(GrammarError::GrammarNullBytes(_))
688        ));
689    }
690
691    #[test]
692    fn sanitize_grammar_strings_null_byte_in_root() {
693        let result = LlamaSampler::sanitize_grammar_strings("ro\0ot ::= \"hello\"", "ro\0ot");
694
695        assert!(matches!(
696            result.err(),
697            Some(GrammarError::GrammarNullBytes(_))
698        ));
699    }
700
701    #[test]
702    fn sanitize_trigger_words_valid() {
703        let words: Vec<&[u8]> = vec![b"hello", b"world"];
704        let result = LlamaSampler::sanitize_trigger_words(words);
705
706        assert!(result.is_ok());
707        assert_eq!(result.expect("valid trigger words").len(), 2);
708    }
709
710    #[test]
711    fn sanitize_trigger_words_empty_list() {
712        let words: Vec<&[u8]> = vec![];
713        let result = LlamaSampler::sanitize_trigger_words(words);
714
715        assert!(result.is_ok());
716        assert!(result.expect("valid trigger words").is_empty());
717    }
718
719    #[test]
720    fn sanitize_trigger_words_null_byte() {
721        let words: Vec<&[u8]> = vec![b"hel\0lo"];
722        let result = LlamaSampler::sanitize_trigger_words(words);
723
724        assert!(matches!(
725            result.err(),
726            Some(GrammarError::TriggerWordNullBytes(_))
727        ));
728    }
729
730    #[test]
731    fn sanitize_trigger_patterns_valid() {
732        let patterns = vec!["^hello$".to_string(), "world.*".to_string()];
733        let result = LlamaSampler::sanitize_trigger_patterns(&patterns);
734
735        assert!(result.is_ok());
736        assert_eq!(result.expect("valid trigger patterns").len(), 2);
737    }
738
739    #[test]
740    fn sanitize_trigger_patterns_empty_list() {
741        let patterns: Vec<String> = vec![];
742        let result = LlamaSampler::sanitize_trigger_patterns(&patterns);
743
744        assert!(result.is_ok());
745        assert!(result.expect("valid trigger patterns").is_empty());
746    }
747
748    #[test]
749    fn sanitize_trigger_patterns_null_byte() {
750        let patterns = vec!["hel\0lo".to_string()];
751        let result = LlamaSampler::sanitize_trigger_patterns(&patterns);
752
753        assert!(matches!(
754            result.err(),
755            Some(GrammarError::GrammarNullBytes(_))
756        ));
757    }
758
759    #[test]
760    fn apply_modifies_data_array() {
761        use crate::token::LlamaToken;
762        use crate::token::data::LlamaTokenData;
763        use crate::token::data_array::LlamaTokenDataArray;
764
765        let sampler = LlamaSampler::greedy();
766        let mut data_array = LlamaTokenDataArray::new(
767            vec![
768                LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
769                LlamaTokenData::new(LlamaToken::new(1), 5.0, 0.0),
770            ],
771            false,
772        );
773
774        sampler.apply(&mut data_array);
775
776        assert_eq!(data_array.selected_token(), Some(LlamaToken::new(1)));
777    }
778
779    #[test]
780    fn accept_succeeds() {
781        let mut sampler = LlamaSampler::chain_simple([
782            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
783            LlamaSampler::greedy(),
784        ]);
785
786        sampler
787            .accept(crate::token::LlamaToken::new(1))
788            .expect("test: accept should succeed");
789    }
790
791    #[test]
792    fn try_accept_succeeds_on_penalties_sampler() {
793        let mut sampler = LlamaSampler::chain_simple([
794            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
795            LlamaSampler::greedy(),
796        ]);
797
798        let result = sampler.try_accept(crate::token::LlamaToken::new(42));
799
800        assert!(result.is_ok());
801    }
802
803    #[test]
804    fn accept_many_multiple_tokens() {
805        use crate::token::LlamaToken;
806
807        let mut sampler = LlamaSampler::chain_simple([
808            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
809            LlamaSampler::greedy(),
810        ]);
811
812        sampler
813            .accept_many([LlamaToken::new(1), LlamaToken::new(2), LlamaToken::new(3)])
814            .expect("test: accept_many should succeed");
815    }
816
817    #[test]
818    fn with_tokens_builder_pattern() {
819        use crate::token::LlamaToken;
820
821        let _sampler = LlamaSampler::chain_simple([
822            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
823            LlamaSampler::greedy(),
824        ])
825        .with_tokens([LlamaToken::new(10), LlamaToken::new(20)])
826        .expect("test: with_tokens should succeed");
827    }
828
829    #[test]
830    fn all_sampler_constructors() {
831        use crate::token::LlamaToken;
832        use crate::token::logit_bias::LlamaLogitBias;
833
834        let _temp = LlamaSampler::temp(0.8);
835        let _temp_ext = LlamaSampler::temp_ext(0.8, 0.1, 1.0);
836        let _top_k = LlamaSampler::top_k(40);
837        let _top_n_sigma = LlamaSampler::top_n_sigma(2.0);
838        let _top_p = LlamaSampler::top_p(0.9, 1);
839        let _min_p = LlamaSampler::min_p(0.05, 1);
840        let _typical = LlamaSampler::typical(0.9, 1);
841        let _xtc = LlamaSampler::xtc(0.1, 0.5, 1, 42);
842        let _dist = LlamaSampler::dist(42);
843        let _mirostat = LlamaSampler::mirostat(32000, 42, 5.0, 0.1, 100);
844        let _mirostat_v2 = LlamaSampler::mirostat_v2(42, 5.0, 0.1);
845        let biases = vec![LlamaLogitBias::new(LlamaToken::new(0), -100.0)];
846        let _logit_bias = LlamaSampler::logit_bias(32000, &biases);
847        let _chain = LlamaSampler::chain([LlamaSampler::greedy()], true);
848    }
849
850    #[test]
851    fn reset_and_get_seed() {
852        let mut sampler = LlamaSampler::dist(42);
853        sampler.reset();
854        let _seed = sampler.get_seed();
855    }
856
857    #[test]
858    fn debug_formatting() {
859        let sampler = LlamaSampler::greedy();
860        let debug_output = format!("{sampler:?}");
861        assert!(debug_output.contains("LlamaSampler"));
862    }
863
864    #[cfg(feature = "tests_that_use_llms")]
865    #[test]
866    #[serial_test::serial]
867    fn dry_sampler_with_model() {
868        let (_backend, model) = crate::test_model::load_default_model().unwrap();
869        let breakers: Vec<&[u8]> = vec![b"\n", b"\t"];
870        let _sampler = LlamaSampler::dry(&model, 1.5, 2.0, 128, 2, &breakers);
871    }
872
873    #[cfg(feature = "tests_that_use_llms")]
874    #[test]
875    #[serial_test::serial]
876    fn dry_sampler_with_null_byte_in_seq_breakers_returns_error() {
877        let (_backend, model) = crate::test_model::load_default_model().unwrap();
878        let breakers: Vec<&[u8]> = vec![b"hello\0world"];
879        let result = LlamaSampler::dry(&model, 1.5, 2.0, 128, 2, breakers);
880
881        assert!(result.is_err());
882    }
883
884    #[cfg(feature = "tests_that_use_llms")]
885    #[test]
886    #[serial_test::serial]
887    fn grammar_returns_sampler_for_valid_grammar() {
888        let (_backend, model) = crate::test_model::load_default_model().unwrap();
889        let sampler = LlamaSampler::grammar(&model, "root ::= \"hello\"", "root");
890
891        assert!(sampler.is_ok());
892    }
893
894    #[cfg(feature = "tests_that_use_llms")]
895    #[test]
896    #[serial_test::serial]
897    fn grammar_lazy_returns_sampler_for_valid_grammar_with_triggers() {
898        let (_backend, model) = crate::test_model::load_default_model().unwrap();
899        let trigger_words: Vec<&[u8]> = vec![b"function"];
900        let sampler =
901            LlamaSampler::grammar_lazy(&model, "root ::= \"hello\"", "root", trigger_words, &[]);
902
903        assert!(sampler.is_ok());
904    }
905
906    #[cfg(feature = "tests_that_use_llms")]
907    #[test]
908    #[serial_test::serial]
909    fn grammar_lazy_patterns_returns_sampler_for_valid_grammar_with_patterns() {
910        let (_backend, model) = crate::test_model::load_default_model().unwrap();
911        let patterns = vec!["\\{.*".to_string()];
912        let sampler = LlamaSampler::grammar_lazy_patterns(
913            &model,
914            "root ::= \"hello\"",
915            "root",
916            &patterns,
917            &[],
918        );
919
920        assert!(sampler.is_ok());
921    }
922
923    #[cfg(feature = "tests_that_use_llms")]
924    #[test]
925    #[serial_test::serial]
926    fn sample_returns_token_after_decode() {
927        use crate::context::params::LlamaContextParams;
928        use crate::llama_batch::LlamaBatch;
929        use crate::model::AddBos;
930        use crate::token::LlamaToken;
931
932        let (backend, model) = crate::test_model::load_default_model().unwrap();
933        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
934        let mut context = model.new_context(&backend, ctx_params).unwrap();
935        let tokens = model.str_to_token("Hello", AddBos::Always).unwrap();
936        let mut batch = LlamaBatch::new(512, 1).unwrap();
937        batch.add_sequence(&tokens, 0, false).unwrap();
938        context.decode(&mut batch).unwrap();
939        let mut sampler =
940            LlamaSampler::chain_simple([LlamaSampler::temp(0.8), LlamaSampler::greedy()]);
941        let token = sampler.sample(&context, batch.n_tokens() - 1);
942
943        assert_ne!(token, LlamaToken::new(-1));
944    }
945
946    #[test]
947    fn checked_u32_as_i32_overflow() {
948        let result = super::checked_u32_as_i32(u32::MAX);
949        assert!(result.is_err());
950    }
951
952    #[test]
953    fn checked_usize_as_i32_sampling_overflow() {
954        let result = super::checked_usize_as_i32_sampling(usize::MAX);
955        assert!(result.is_err());
956    }
957
958    #[test]
959    fn check_sampler_accept_status_error() {
960        let result =
961            super::check_sampler_accept_status(llama_cpp_bindings_sys::LLAMA_RS_STATUS_EXCEPTION);
962        assert!(result.is_err());
963    }
964
965    #[test]
966    fn check_sampler_not_null_returns_error() {
967        let result = super::check_sampler_not_null(std::ptr::null_mut());
968        assert!(result.is_err());
969    }
970}