Skip to main content

llama_cpp_bindings/
sampling.rs

1//! Safe wrapper around `llama_sampler`.
2
3use std::borrow::Borrow;
4use std::ffi::{CString, c_char};
5use std::fmt::{Debug, Formatter};
6
7use crate::context::LlamaContext;
8use crate::model::LlamaModel;
9use crate::token::LlamaToken;
10use crate::token::data_array::LlamaTokenDataArray;
11use crate::token::logit_bias::LlamaLogitBias;
12use crate::{GrammarError, SamplerAcceptError, SamplingError, status_is_ok, status_to_i32};
13
14/// A safe wrapper around `llama_sampler`.
15pub struct LlamaSampler {
16    /// Raw pointer to the underlying `llama_sampler`.
17    pub sampler: *mut llama_cpp_bindings_sys::llama_sampler,
18}
19
20impl Debug for LlamaSampler {
21    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
22        f.debug_struct("LlamaSamplerChain").finish()
23    }
24}
25
26impl LlamaSampler {
27    /// Sample and accept a token from the idx-th output of the last evaluation
28    #[must_use]
29    pub fn sample(&mut self, ctx: &LlamaContext, idx: i32) -> LlamaToken {
30        let token = unsafe {
31            llama_cpp_bindings_sys::llama_sampler_sample(self.sampler, ctx.context.as_ptr(), idx)
32        };
33
34        LlamaToken(token)
35    }
36
37    /// Applies this sampler to a [`LlamaTokenDataArray`].
38    pub fn apply(&self, data_array: &mut LlamaTokenDataArray) {
39        data_array.apply_sampler(self);
40    }
41
42    /// Accepts a token from the sampler, possibly updating the internal state of certain samplers
43    /// (e.g. grammar, repetition, etc.)
44    ///
45    /// # Errors
46    /// Returns [`SamplerAcceptError`] if the underlying sampler rejects the token.
47    pub fn accept(&mut self, token: LlamaToken) -> Result<(), SamplerAcceptError> {
48        self.try_accept(token)
49    }
50
51    /// Accepts several tokens from the sampler or context, possibly updating the internal state of
52    /// certain samplers (e.g. grammar, repetition, etc.)
53    ///
54    /// # Errors
55    /// Returns [`SamplerAcceptError`] if the underlying sampler rejects any token.
56    pub fn accept_many(
57        &mut self,
58        tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
59    ) -> Result<(), SamplerAcceptError> {
60        for token in tokens {
61            self.try_accept(*token.borrow())?;
62        }
63
64        Ok(())
65    }
66
67    /// Accepts several tokens from the sampler or context, possibly updating the internal state of
68    /// certain samplers (e.g. grammar, repetition, etc.)
69    ///
70    /// # Errors
71    /// Returns [`SamplerAcceptError`] if the underlying sampler rejects any token.
72    pub fn with_tokens(
73        mut self,
74        tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
75    ) -> Result<Self, SamplerAcceptError> {
76        self.accept_many(tokens)?;
77
78        Ok(self)
79    }
80
81    /// Try accepting a token from the sampler. Returns an error if the sampler throws.
82    ///
83    /// # Errors
84    /// Returns an error if the underlying sampler rejects the token.
85    pub fn try_accept(&mut self, token: LlamaToken) -> Result<(), SamplerAcceptError> {
86        let sampler_result =
87            unsafe { llama_cpp_bindings_sys::llama_rs_sampler_accept(self.sampler, token.0) };
88        if status_is_ok(sampler_result) {
89            Ok(())
90        } else {
91            Err(SamplerAcceptError::FfiError(status_to_i32(sampler_result)))
92        }
93    }
94
95    /// Resets the internal state of the sampler.
96    ///
97    /// This can be useful when you want to start fresh with a sampler without creating a new instance.
98    pub fn reset(&mut self) {
99        unsafe {
100            llama_cpp_bindings_sys::llama_sampler_reset(self.sampler);
101        }
102    }
103
104    /// Gets the random seed used by this sampler.
105    ///
106    /// Returns:
107    /// - For random samplers (dist, mirostat, `mirostat_v2)`: returns their current seed
108    /// - For sampler chains: returns the first non-default seed found in reverse order
109    /// - For all other samplers: returns 0xFFFFFFFF
110    #[must_use]
111    pub fn get_seed(&self) -> u32 {
112        unsafe { llama_cpp_bindings_sys::llama_sampler_get_seed(self.sampler) }
113    }
114
115    /// Combines a list of samplers into a single sampler that applies each component sampler one
116    /// after another.
117    ///
118    /// If you are using a chain to select a token, the chain should always end with one of
119    /// [`LlamaSampler::greedy`], [`LlamaSampler::dist`], [`LlamaSampler::mirostat`], and
120    /// [`LlamaSampler::mirostat_v2`].
121    #[must_use]
122    pub fn chain(samplers: impl IntoIterator<Item = Self>, no_perf: bool) -> Self {
123        unsafe {
124            let chain = llama_cpp_bindings_sys::llama_sampler_chain_init(
125                llama_cpp_bindings_sys::llama_sampler_chain_params { no_perf },
126            );
127
128            for sampler in samplers {
129                llama_cpp_bindings_sys::llama_sampler_chain_add(chain, sampler.sampler);
130
131                // Do not call `llama_sampler_free` on the sampler, as the internal sampler is now
132                // owned by the chain
133                std::mem::forget(sampler);
134            }
135
136            Self { sampler: chain }
137        }
138    }
139
140    /// Same as [`Self::chain`] with `no_perf = false`.
141    ///
142    /// # Example
143    /// ```rust
144    /// use llama_cpp_bindings::token::{
145    ///    LlamaToken,
146    ///    data::LlamaTokenData,
147    ///    data_array::LlamaTokenDataArray
148    /// };
149    /// use llama_cpp_bindings::sampling::LlamaSampler;
150    /// use llama_cpp_bindings::llama_backend::LlamaBackend;
151    /// let backend = LlamaBackend::init().unwrap();
152    ///
153    /// let mut data_array = LlamaTokenDataArray::new(vec![
154    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
155    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
156    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
157    /// ], false);
158    ///
159    /// data_array.apply_sampler(&mut LlamaSampler::chain_simple([
160    ///     LlamaSampler::temp(0.5),
161    ///     LlamaSampler::greedy(),
162    /// ]));
163    ///
164    /// assert_eq!(data_array.data[0].logit(), 0.);
165    /// assert_eq!(data_array.data[1].logit(), 2.);
166    /// assert_eq!(data_array.data[2].logit(), 4.);
167    ///
168    /// assert_eq!(data_array.data.len(), 3);
169    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(2)));
170    /// ```
171    #[must_use]
172    pub fn chain_simple(samplers: impl IntoIterator<Item = Self>) -> Self {
173        Self::chain(samplers, false)
174    }
175
176    /// Updates the logits `l_i' = l_i/t`. When `t <= 0.0`, the maximum logit is kept at its original
177    /// value, the rest are set to -inf
178    ///
179    /// # Example:
180    /// ```rust
181    /// use llama_cpp_bindings::token::{
182    ///    LlamaToken,
183    ///    data::LlamaTokenData,
184    ///    data_array::LlamaTokenDataArray
185    /// };
186    /// use llama_cpp_bindings::sampling::LlamaSampler;
187    ///
188    /// let mut data_array = LlamaTokenDataArray::new(vec![
189    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
190    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
191    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
192    /// ], false);
193    ///
194    /// data_array.apply_sampler(&mut LlamaSampler::temp(0.5));
195    ///
196    /// assert_eq!(data_array.data[0].logit(), 0.);
197    /// assert_eq!(data_array.data[1].logit(), 2.);
198    /// assert_eq!(data_array.data[2].logit(), 4.);
199    /// ```
200    #[must_use]
201    pub fn temp(t: f32) -> Self {
202        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_temp(t) };
203        Self { sampler }
204    }
205
206    /// Dynamic temperature implementation (a.k.a. entropy) described in the paper
207    /// <https://arxiv.org/abs/2309.02772>.
208    #[must_use]
209    pub fn temp_ext(t: f32, delta: f32, exponent: f32) -> Self {
210        let sampler =
211            unsafe { llama_cpp_bindings_sys::llama_sampler_init_temp_ext(t, delta, exponent) };
212        Self { sampler }
213    }
214
215    /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration"
216    /// <https://arxiv.org/abs/1904.09751>
217    ///
218    /// # Example:
219    /// ```rust
220    /// use llama_cpp_bindings::token::{
221    ///    LlamaToken,
222    ///    data::LlamaTokenData,
223    ///    data_array::LlamaTokenDataArray
224    /// };
225    /// use llama_cpp_bindings::sampling::LlamaSampler;
226    ///
227    /// let mut data_array = LlamaTokenDataArray::new(vec![
228    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
229    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
230    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
231    ///     LlamaTokenData::new(LlamaToken(3), 3., 0.),
232    /// ], false);
233    ///
234    /// data_array.apply_sampler(&mut LlamaSampler::top_k(2));
235    ///
236    /// assert_eq!(data_array.data.len(), 2);
237    /// assert_eq!(data_array.data[0].id(), LlamaToken(3));
238    /// assert_eq!(data_array.data[1].id(), LlamaToken(2));
239    /// ```
240    #[must_use]
241    pub fn top_k(k: i32) -> Self {
242        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_k(k) };
243        Self { sampler }
244    }
245
246    /// Top-nσ sampling as described in academic paper "Top-nσ: Not All Logits Are You Need"
247    /// <https://arxiv.org/pdf/2411.07641>
248    ///
249    /// This method filters logits by selecting only those within *n* standard deviations of the mean.
250    ///
251    /// # Parameters
252    /// - `n`: Number of standard deviations from the mean to include in sampling
253    ///
254    /// # Example
255    /// ```rust
256    /// use llama_cpp_bindings::sampling::LlamaSampler;
257    /// use llama_cpp_bindings::token::{
258    ///     LlamaToken,
259    ///     data::LlamaTokenData,
260    ///     data_array::LlamaTokenDataArray
261    /// };
262    ///
263    /// let mut data_array = LlamaTokenDataArray::new(vec![
264    ///     LlamaTokenData::new(LlamaToken(0), 0.0, 0.0),
265    ///     LlamaTokenData::new(LlamaToken(1), 1.0, 0.0),
266    ///     LlamaTokenData::new(LlamaToken(2), 2.0, 0.0),
267    /// ], false);
268    ///
269    /// data_array.apply_sampler(&mut LlamaSampler::top_n_sigma(2.0));
270    /// ```
271    #[must_use]
272    pub fn top_n_sigma(n: f32) -> Self {
273        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_n_sigma(n) };
274        Self { sampler }
275    }
276
277    /// Locally Typical Sampling implementation described in the paper <https://arxiv.org/abs/2202.00666>.
278    #[must_use]
279    pub fn typical(p: f32, min_keep: usize) -> Self {
280        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_typical(p, min_keep) };
281        Self { sampler }
282    }
283
284    /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration"
285    /// <https://arxiv.org/abs/1904.09751>
286    #[must_use]
287    pub fn top_p(p: f32, min_keep: usize) -> Self {
288        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_p(p, min_keep) };
289        Self { sampler }
290    }
291
292    /// Minimum P sampling as described in <https://github.com/ggerganov/llama.cpp/pull/3841>
293    #[must_use]
294    pub fn min_p(p: f32, min_keep: usize) -> Self {
295        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_min_p(p, min_keep) };
296        Self { sampler }
297    }
298
299    /// XTC sampler as described in <https://github.com/oobabooga/text-generation-webui/pull/6335>
300    #[must_use]
301    pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Self {
302        let sampler =
303            unsafe { llama_cpp_bindings_sys::llama_sampler_init_xtc(p, t, min_keep, seed) };
304        Self { sampler }
305    }
306
307    /// Grammar sampler
308    ///
309    /// # Errors
310    /// Returns an error if the grammar is invalid or the sampler cannot be initialized.
311    pub fn grammar(
312        model: &LlamaModel,
313        grammar_str: &str,
314        grammar_root: &str,
315    ) -> Result<Self, GrammarError> {
316        let (grammar_str, grammar_root) =
317            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
318
319        let sampler = unsafe {
320            llama_cpp_bindings_sys::llama_rs_sampler_init_grammar(
321                model.vocab_ptr(),
322                grammar_str.as_ptr(),
323                grammar_root.as_ptr(),
324            )
325        };
326
327        if sampler.is_null() {
328            Err(GrammarError::NullGrammar)
329        } else {
330            Ok(Self { sampler })
331        }
332    }
333
334    /// Lazy grammar sampler, introduced in <https://github.com/ggerganov/llama.cpp/pull/9639>
335    ///
336    /// This sampler enforces grammar rules only when specific trigger words or tokens are encountered.
337    ///
338    /// # Errors
339    /// Returns an error if the grammar or trigger words are invalid.
340    pub fn grammar_lazy(
341        model: &LlamaModel,
342        grammar_str: &str,
343        grammar_root: &str,
344        trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
345        trigger_tokens: &[LlamaToken],
346    ) -> Result<Self, GrammarError> {
347        let (grammar_str, grammar_root) =
348            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
349        let trigger_words = Self::sanitize_trigger_words(trigger_words)?;
350
351        let mut trigger_word_ptrs: Vec<*const c_char> =
352            trigger_words.iter().map(|cs| cs.as_ptr()).collect();
353
354        let sampler = unsafe {
355            llama_cpp_bindings_sys::llama_rs_sampler_init_grammar_lazy(
356                model.vocab_ptr(),
357                grammar_str.as_ptr(),
358                grammar_root.as_ptr(),
359                trigger_word_ptrs.as_mut_ptr(),
360                trigger_word_ptrs.len(),
361                trigger_tokens.as_ptr().cast(),
362                trigger_tokens.len(),
363            )
364        };
365
366        if sampler.is_null() {
367            Err(GrammarError::NullGrammar)
368        } else {
369            Ok(Self { sampler })
370        }
371    }
372
373    /// Lazy grammar sampler using regex trigger patterns.
374    ///
375    /// Trigger patterns are regular expressions matched from the start of the
376    /// generation output. The grammar sampler will be fed content starting from
377    /// the first match group.
378    ///
379    /// # Errors
380    /// Returns an error if the grammar or trigger patterns are invalid.
381    pub fn grammar_lazy_patterns(
382        model: &LlamaModel,
383        grammar_str: &str,
384        grammar_root: &str,
385        trigger_patterns: &[String],
386        trigger_tokens: &[LlamaToken],
387    ) -> Result<Self, GrammarError> {
388        let (grammar_str, grammar_root) =
389            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
390        let trigger_patterns = Self::sanitize_trigger_patterns(trigger_patterns)?;
391
392        let mut trigger_pattern_ptrs: Vec<*const c_char> =
393            trigger_patterns.iter().map(|cs| cs.as_ptr()).collect();
394
395        let sampler = unsafe {
396            llama_cpp_bindings_sys::llama_rs_sampler_init_grammar_lazy_patterns(
397                model.vocab_ptr(),
398                grammar_str.as_ptr(),
399                grammar_root.as_ptr(),
400                trigger_pattern_ptrs.as_mut_ptr(),
401                trigger_pattern_ptrs.len(),
402                trigger_tokens.as_ptr().cast(),
403                trigger_tokens.len(),
404            )
405        };
406
407        if sampler.is_null() {
408            Err(GrammarError::NullGrammar)
409        } else {
410            Ok(Self { sampler })
411        }
412    }
413
414    /// `LLGuidance` sampler for constrained decoding.
415    ///
416    /// Uses the `llguidance` and `toktrie` Rust crates to enforce grammar constraints
417    /// during token sampling. Supports JSON schema, regex, Lark, and other grammar types.
418    ///
419    /// # Errors
420    ///
421    /// Returns [`GrammarError`] if the grammar is invalid or the sampler cannot be initialized.
422    #[cfg(feature = "llguidance")]
423    pub fn llguidance(
424        model: &LlamaModel,
425        grammar_kind: &str,
426        grammar_data: &str,
427    ) -> Result<Self, GrammarError> {
428        crate::llguidance_sampler::create_llg_sampler(model, grammar_kind, grammar_data)
429    }
430
431    fn sanitize_grammar_strings(
432        grammar_str: &str,
433        grammar_root: &str,
434    ) -> Result<(CString, CString), GrammarError> {
435        if !grammar_str.contains(grammar_root) {
436            return Err(GrammarError::RootNotFound);
437        }
438
439        if grammar_str.contains('\0') || grammar_root.contains('\0') {
440            return Err(GrammarError::GrammarNullBytes);
441        }
442
443        Ok((CString::new(grammar_str)?, CString::new(grammar_root)?))
444    }
445
446    fn sanitize_trigger_words(
447        trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
448    ) -> Result<Vec<CString>, GrammarError> {
449        let trigger_words: Vec<_> = trigger_words.into_iter().collect();
450        if trigger_words
451            .iter()
452            .any(|word| word.as_ref().contains(&b'\0'))
453        {
454            return Err(GrammarError::TriggerWordNullBytes);
455        }
456        trigger_words
457            .into_iter()
458            .map(|word| Ok(CString::new(word.as_ref())?))
459            .collect()
460    }
461
462    fn sanitize_trigger_patterns(
463        trigger_patterns: &[String],
464    ) -> Result<Vec<CString>, GrammarError> {
465        let mut patterns = Vec::with_capacity(trigger_patterns.len());
466        for pattern in trigger_patterns {
467            if pattern.contains('\0') {
468                return Err(GrammarError::GrammarNullBytes);
469            }
470            patterns.push(CString::new(pattern.as_str())?);
471        }
472        Ok(patterns)
473    }
474
475    /// DRY sampler, designed by p-e-w, as described in:
476    /// <https://github.com/oobabooga/text-generation-webui/pull/5677>, porting Koboldcpp
477    /// implementation authored by pi6am: <https://github.com/LostRuins/koboldcpp/pull/982>
478    ///
479    /// # Errors
480    /// Returns an error if any string in `seq_breakers` contains null bytes.
481    #[allow(missing_docs)]
482    pub fn dry(
483        model: &LlamaModel,
484        multiplier: f32,
485        base: f32,
486        allowed_length: i32,
487        penalty_last_n: i32,
488        seq_breakers: impl IntoIterator<Item = impl AsRef<[u8]>>,
489    ) -> Result<Self, GrammarError> {
490        let seq_breakers: Vec<CString> = seq_breakers
491            .into_iter()
492            .map(|s| CString::new(s.as_ref()))
493            .collect::<Result<Vec<_>, _>>()?;
494        let mut seq_breaker_pointers: Vec<*const c_char> =
495            seq_breakers.iter().map(|s| s.as_ptr()).collect();
496
497        let n_ctx_train = model.n_ctx_train().try_into().map_err(|convert_error| {
498            GrammarError::IntegerOverflow(format!("n_ctx_train exceeds i32::MAX: {convert_error}"))
499        })?;
500        let sampler = unsafe {
501            llama_cpp_bindings_sys::llama_sampler_init_dry(
502                model.vocab_ptr(),
503                n_ctx_train,
504                multiplier,
505                base,
506                allowed_length,
507                penalty_last_n,
508                seq_breaker_pointers.as_mut_ptr(),
509                seq_breaker_pointers.len(),
510            )
511        };
512
513        Ok(Self { sampler })
514    }
515
516    /// Penalizes tokens for being present in the context.
517    ///
518    /// Parameters:
519    /// - ``penalty_last_n``: last n tokens to penalize (0 = disable penalty, -1 = context size)
520    /// - ``penalty_repeat``: 1.0 = disabled
521    /// - ``penalty_freq``: 0.0 = disabled
522    /// - ``penalty_present``: 0.0 = disabled
523    #[must_use]
524    pub fn penalties(
525        penalty_last_n: i32,
526        penalty_repeat: f32,
527        penalty_freq: f32,
528        penalty_present: f32,
529    ) -> Self {
530        let sampler = unsafe {
531            llama_cpp_bindings_sys::llama_sampler_init_penalties(
532                penalty_last_n,
533                penalty_repeat,
534                penalty_freq,
535                penalty_present,
536            )
537        };
538        Self { sampler }
539    }
540
541    /// Mirostat 1.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
542    ///
543    /// # Parameters:
544    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
545    /// - ``seed``: Seed to initialize random generation with.
546    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
547    ///   generated text. A higher value corresponds to more surprising or less predictable text,
548    ///   while a lower value corresponds to less surprising or more predictable text.
549    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
550    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
551    ///   updated more quickly, while a smaller learning rate will result in slower updates.
552    /// - ``m``: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary
553    ///   value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`.
554    ///   In the paper, they use `m = 100`, but you can experiment with different values to see how
555    ///   it affects the performance of the algorithm.
556    #[must_use]
557    pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self {
558        let sampler = unsafe {
559            llama_cpp_bindings_sys::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m)
560        };
561        Self { sampler }
562    }
563
564    /// Mirostat 2.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
565    ///
566    /// # Parameters:
567    /// - ``seed``: Seed to initialize random generation with.
568    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
569    ///   generated text. A higher value corresponds to more surprising or less predictable text,
570    ///   while a lower value corresponds to less surprising or more predictable text.
571    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
572    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
573    ///   updated more quickly, while a smaller learning rate will result in slower updates.
574    #[must_use]
575    pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self {
576        let sampler =
577            unsafe { llama_cpp_bindings_sys::llama_sampler_init_mirostat_v2(seed, tau, eta) };
578        Self { sampler }
579    }
580
581    /// Selects a token at random based on each token's probabilities
582    #[must_use]
583    pub fn dist(seed: u32) -> Self {
584        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_dist(seed) };
585        Self { sampler }
586    }
587
588    /// Selects the most likely token
589    ///
590    /// # Example:
591    /// ```rust
592    /// use llama_cpp_bindings::token::{
593    ///    LlamaToken,
594    ///    data::LlamaTokenData,
595    ///    data_array::LlamaTokenDataArray
596    /// };
597    /// use llama_cpp_bindings::sampling::LlamaSampler;
598    ///
599    /// let mut data_array = LlamaTokenDataArray::new(vec![
600    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
601    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
602    /// ], false);
603    ///
604    /// data_array.apply_sampler(&mut LlamaSampler::greedy());
605    ///
606    /// assert_eq!(data_array.data.len(), 2);
607    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(1)));
608    /// ```
609    #[must_use]
610    pub fn greedy() -> Self {
611        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_greedy() };
612        Self { sampler }
613    }
614
615    /// Creates a sampler that applies bias values to specific tokens during sampling.
616    ///
617    /// # Parameters
618    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
619    /// - ``biases``: Slice of [`LlamaLogitBias`] values specifying token-bias pairs
620    ///
621    /// # Errors
622    /// Returns [`SamplingError::IntegerOverflow`] if `biases.len()` exceeds `i32::MAX`.
623    ///
624    /// # Example
625    /// ```rust
626    /// use llama_cpp_bindings::token::{LlamaToken, logit_bias::LlamaLogitBias};
627    /// use llama_cpp_bindings::sampling::LlamaSampler;
628    ///
629    /// let biases = vec![
630    ///     LlamaLogitBias::new(LlamaToken(1), 1.5),  // Increase probability of token 1
631    ///     LlamaLogitBias::new(LlamaToken(2), -1.0), // Decrease probability of token 2
632    /// ];
633    ///
634    /// // Assuming vocab_size of 32000
635    /// let sampler = LlamaSampler::logit_bias(32000, &biases).unwrap();
636    /// ```
637    pub fn logit_bias(n_vocab: i32, biases: &[LlamaLogitBias]) -> Result<Self, SamplingError> {
638        let bias_count = i32::try_from(biases.len()).map_err(|convert_error| {
639            SamplingError::IntegerOverflow(format!("bias count exceeds i32::MAX: {convert_error}"))
640        })?;
641        let data = biases
642            .as_ptr()
643            .cast::<llama_cpp_bindings_sys::llama_logit_bias>();
644
645        let sampler = unsafe {
646            llama_cpp_bindings_sys::llama_sampler_init_logit_bias(n_vocab, bias_count, data)
647        };
648
649        Ok(Self { sampler })
650    }
651}
652
653impl Drop for LlamaSampler {
654    fn drop(&mut self) {
655        unsafe {
656            llama_cpp_bindings_sys::llama_sampler_free(self.sampler);
657        }
658    }
659}
660
#[cfg(test)]
mod tests {
    use super::LlamaSampler;
    use crate::GrammarError;
    use crate::token::LlamaToken;
    use crate::token::data::LlamaTokenData;
    use crate::token::data_array::LlamaTokenDataArray;

    /// Builds the `penalties -> greedy` chain shared by the accept-family tests.
    fn penalties_then_greedy() -> LlamaSampler {
        LlamaSampler::chain_simple([
            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
            LlamaSampler::greedy(),
        ])
    }

    #[test]
    fn sanitize_grammar_strings_valid() {
        let sanitized = LlamaSampler::sanitize_grammar_strings("root ::= \"hello\"", "root");

        assert!(sanitized.is_ok());
    }

    #[test]
    fn sanitize_grammar_strings_root_not_found() {
        let outcome = LlamaSampler::sanitize_grammar_strings("expr ::= \"hello\"", "root");

        assert_eq!(outcome.err(), Some(GrammarError::RootNotFound));
    }

    #[test]
    fn sanitize_grammar_strings_null_byte_in_grammar() {
        let outcome = LlamaSampler::sanitize_grammar_strings("root ::= \"\0\"", "root");

        assert_eq!(outcome.err(), Some(GrammarError::GrammarNullBytes));
    }

    #[test]
    fn sanitize_grammar_strings_null_byte_in_root() {
        let outcome = LlamaSampler::sanitize_grammar_strings("ro\0ot ::= \"hello\"", "ro\0ot");

        assert_eq!(outcome.err(), Some(GrammarError::GrammarNullBytes));
    }

    #[test]
    fn sanitize_trigger_words_valid() {
        let words = vec![b"hello".as_slice(), b"world".as_slice()];

        let sanitized =
            LlamaSampler::sanitize_trigger_words(words).expect("valid trigger words");

        assert_eq!(sanitized.len(), 2);
    }

    #[test]
    fn sanitize_trigger_words_empty_list() {
        let words: Vec<&[u8]> = Vec::new();

        let sanitized =
            LlamaSampler::sanitize_trigger_words(words).expect("valid trigger words");

        assert!(sanitized.is_empty());
    }

    #[test]
    fn sanitize_trigger_words_null_byte() {
        let words = vec![b"hel\0lo".as_slice()];

        let outcome = LlamaSampler::sanitize_trigger_words(words);

        assert_eq!(outcome.err(), Some(GrammarError::TriggerWordNullBytes));
    }

    #[test]
    fn sanitize_trigger_patterns_valid() {
        let patterns = vec![String::from("^hello$"), String::from("world.*")];

        let sanitized =
            LlamaSampler::sanitize_trigger_patterns(&patterns).expect("valid trigger patterns");

        assert_eq!(sanitized.len(), 2);
    }

    #[test]
    fn sanitize_trigger_patterns_empty_list() {
        let patterns: Vec<String> = Vec::new();

        let sanitized =
            LlamaSampler::sanitize_trigger_patterns(&patterns).expect("valid trigger patterns");

        assert!(sanitized.is_empty());
    }

    #[test]
    fn sanitize_trigger_patterns_null_byte() {
        let patterns = vec![String::from("hel\0lo")];

        let outcome = LlamaSampler::sanitize_trigger_patterns(&patterns);

        assert_eq!(outcome.err(), Some(GrammarError::GrammarNullBytes));
    }

    #[test]
    fn apply_modifies_data_array() {
        let greedy = LlamaSampler::greedy();
        let candidates = vec![
            LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
            LlamaTokenData::new(LlamaToken::new(1), 5.0, 0.0),
        ];
        let mut data_array = LlamaTokenDataArray::new(candidates, false);

        greedy.apply(&mut data_array);

        assert_eq!(data_array.selected_token(), Some(LlamaToken::new(1)));
    }

    #[test]
    fn accept_succeeds() {
        let mut chain = penalties_then_greedy();

        chain
            .accept(LlamaToken::new(1))
            .expect("test: accept should succeed");
    }

    #[test]
    fn try_accept_succeeds_on_penalties_sampler() {
        let mut chain = penalties_then_greedy();

        let outcome = chain.try_accept(LlamaToken::new(42));

        assert!(outcome.is_ok());
    }

    #[test]
    fn accept_many_multiple_tokens() {
        let mut chain = penalties_then_greedy();
        let tokens = [LlamaToken::new(1), LlamaToken::new(2), LlamaToken::new(3)];

        chain
            .accept_many(tokens)
            .expect("test: accept_many should succeed");
    }

    #[test]
    fn with_tokens_builder_pattern() {
        let tokens = [LlamaToken::new(10), LlamaToken::new(20)];

        let _sampler = penalties_then_greedy()
            .with_tokens(tokens)
            .expect("test: with_tokens should succeed");
    }
}