// llama_cpp_bindings/sampling.rs
1//! Safe wrapper around `llama_sampler`.
2
3use std::borrow::Borrow;
4use std::ffi::{CString, c_char};
5use std::fmt::{Debug, Formatter};
6
7use crate::context::LlamaContext;
8use crate::model::LlamaModel;
9use crate::token::LlamaToken;
10use crate::token::data_array::LlamaTokenDataArray;
11use crate::token::logit_bias::LlamaLogitBias;
12use crate::{GrammarError, SamplerAcceptError, status_is_ok, status_to_i32};
13
/// A safe wrapper around `llama_sampler`.
pub struct LlamaSampler {
    /// Raw pointer to the underlying `llama_sampler`.
    ///
    /// Owned by this wrapper: `Drop` passes it to `llama_sampler_free`, and
    /// `LlamaSampler::chain` uses `mem::forget` when ownership moves into a chain.
    /// NOTE(review): the field is `pub`, so external code can construct or swap
    /// this pointer directly; soundness relies on it pointing at a live sampler —
    /// consider whether it should be private.
    pub sampler: *mut llama_cpp_bindings_sys::llama_sampler,
}
19
20impl Debug for LlamaSampler {
21    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
22        f.debug_struct("LlamaSamplerChain").finish()
23    }
24}
25
26impl LlamaSampler {
27    /// Sample and accept a token from the idx-th output of the last evaluation
28    #[must_use]
29    pub fn sample(&mut self, ctx: &LlamaContext, idx: i32) -> LlamaToken {
30        let token = unsafe {
31            llama_cpp_bindings_sys::llama_sampler_sample(self.sampler, ctx.context.as_ptr(), idx)
32        };
33
34        LlamaToken(token)
35    }
36
    /// Applies this sampler to a [`LlamaTokenDataArray`].
    ///
    /// Thin delegator: all work happens in [`LlamaTokenDataArray::apply_sampler`].
    pub fn apply(&self, data_array: &mut LlamaTokenDataArray) {
        data_array.apply_sampler(self);
    }
41
    /// Accepts a token from the sampler, possibly updating the internal state of certain samplers
    /// (e.g. grammar, repetition, etc.)
    ///
    /// Any error from the underlying sampler is silently discarded; call
    /// [`Self::try_accept`] instead if you need to observe failures.
    pub fn accept(&mut self, token: LlamaToken) {
        // Deliberate best-effort: the Result is intentionally ignored.
        let _ = self.try_accept(token);
    }
47
48    /// Accepts several tokens from the sampler or context, possibly updating the internal state of
49    /// certain samplers (e.g. grammar, repetition, etc.)
50    pub fn accept_many(&mut self, tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>) {
51        for token in tokens {
52            let _ = self.try_accept(*token.borrow());
53        }
54    }
55
56    /// Accepts several tokens from the sampler or context, possibly updating the internal state of
57    /// certain samplers (e.g. grammar, repetition, etc.)
58    #[must_use]
59    pub fn with_tokens(
60        mut self,
61        tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
62    ) -> Self {
63        self.accept_many(tokens);
64        self
65    }
66
67    /// Try accepting a token from the sampler. Returns an error if the sampler throws.
68    ///
69    /// # Errors
70    /// Returns an error if the underlying sampler rejects the token.
71    pub fn try_accept(&mut self, token: LlamaToken) -> Result<(), SamplerAcceptError> {
72        let sampler_result =
73            unsafe { llama_cpp_bindings_sys::llama_rs_sampler_accept(self.sampler, token.0) };
74        if status_is_ok(sampler_result) {
75            Ok(())
76        } else {
77            Err(SamplerAcceptError::FfiError(status_to_i32(sampler_result)))
78        }
79    }
80
81    /// Resets the internal state of the sampler.
82    ///
83    /// This can be useful when you want to start fresh with a sampler without creating a new instance.
84    pub fn reset(&mut self) {
85        unsafe {
86            llama_cpp_bindings_sys::llama_sampler_reset(self.sampler);
87        }
88    }
89
90    /// Gets the random seed used by this sampler.
91    ///
92    /// Returns:
93    /// - For random samplers (dist, mirostat, `mirostat_v2)`: returns their current seed
94    /// - For sampler chains: returns the first non-default seed found in reverse order
95    /// - For all other samplers: returns 0xFFFFFFFF
96    #[must_use]
97    pub fn get_seed(&self) -> u32 {
98        unsafe { llama_cpp_bindings_sys::llama_sampler_get_seed(self.sampler) }
99    }
100
101    /// Combines a list of samplers into a single sampler that applies each component sampler one
102    /// after another.
103    ///
104    /// If you are using a chain to select a token, the chain should always end with one of
105    /// [`LlamaSampler::greedy`], [`LlamaSampler::dist`], [`LlamaSampler::mirostat`], and
106    /// [`LlamaSampler::mirostat_v2`].
107    #[must_use]
108    pub fn chain(samplers: impl IntoIterator<Item = Self>, no_perf: bool) -> Self {
109        unsafe {
110            let chain = llama_cpp_bindings_sys::llama_sampler_chain_init(
111                llama_cpp_bindings_sys::llama_sampler_chain_params { no_perf },
112            );
113
114            for sampler in samplers {
115                llama_cpp_bindings_sys::llama_sampler_chain_add(chain, sampler.sampler);
116
117                // Do not call `llama_sampler_free` on the sampler, as the internal sampler is now
118                // owned by the chain
119                std::mem::forget(sampler);
120            }
121
122            Self { sampler: chain }
123        }
124    }
125
126    /// Same as [`Self::chain`] with `no_perf = false`.
127    ///
128    /// # Example
129    /// ```rust
130    /// use llama_cpp_bindings::token::{
131    ///    LlamaToken,
132    ///    data::LlamaTokenData,
133    ///    data_array::LlamaTokenDataArray
134    /// };
135    /// use llama_cpp_bindings::sampling::LlamaSampler;
136    /// use llama_cpp_bindings::llama_backend::LlamaBackend;
137    /// let backend = LlamaBackend::init().unwrap();
138    ///
139    /// let mut data_array = LlamaTokenDataArray::new(vec![
140    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
141    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
142    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
143    /// ], false);
144    ///
145    /// data_array.apply_sampler(&mut LlamaSampler::chain_simple([
146    ///     LlamaSampler::temp(0.5),
147    ///     LlamaSampler::greedy(),
148    /// ]));
149    ///
150    /// assert_eq!(data_array.data[0].logit(), 0.);
151    /// assert_eq!(data_array.data[1].logit(), 2.);
152    /// assert_eq!(data_array.data[2].logit(), 4.);
153    ///
154    /// assert_eq!(data_array.data.len(), 3);
155    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(2)));
156    /// ```
157    #[must_use]
158    pub fn chain_simple(samplers: impl IntoIterator<Item = Self>) -> Self {
159        Self::chain(samplers, false)
160    }
161
162    /// Updates the logits `l_i' = l_i/t`. When `t <= 0.0`, the maximum logit is kept at its original
163    /// value, the rest are set to -inf
164    ///
165    /// # Example:
166    /// ```rust
167    /// use llama_cpp_bindings::token::{
168    ///    LlamaToken,
169    ///    data::LlamaTokenData,
170    ///    data_array::LlamaTokenDataArray
171    /// };
172    /// use llama_cpp_bindings::sampling::LlamaSampler;
173    ///
174    /// let mut data_array = LlamaTokenDataArray::new(vec![
175    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
176    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
177    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
178    /// ], false);
179    ///
180    /// data_array.apply_sampler(&mut LlamaSampler::temp(0.5));
181    ///
182    /// assert_eq!(data_array.data[0].logit(), 0.);
183    /// assert_eq!(data_array.data[1].logit(), 2.);
184    /// assert_eq!(data_array.data[2].logit(), 4.);
185    /// ```
186    #[must_use]
187    pub fn temp(t: f32) -> Self {
188        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_temp(t) };
189        Self { sampler }
190    }
191
192    /// Dynamic temperature implementation (a.k.a. entropy) described in the paper
193    /// <https://arxiv.org/abs/2309.02772>.
194    #[must_use]
195    pub fn temp_ext(t: f32, delta: f32, exponent: f32) -> Self {
196        let sampler =
197            unsafe { llama_cpp_bindings_sys::llama_sampler_init_temp_ext(t, delta, exponent) };
198        Self { sampler }
199    }
200
201    /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration"
202    /// <https://arxiv.org/abs/1904.09751>
203    ///
204    /// # Example:
205    /// ```rust
206    /// use llama_cpp_bindings::token::{
207    ///    LlamaToken,
208    ///    data::LlamaTokenData,
209    ///    data_array::LlamaTokenDataArray
210    /// };
211    /// use llama_cpp_bindings::sampling::LlamaSampler;
212    ///
213    /// let mut data_array = LlamaTokenDataArray::new(vec![
214    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
215    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
216    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
217    ///     LlamaTokenData::new(LlamaToken(3), 3., 0.),
218    /// ], false);
219    ///
220    /// data_array.apply_sampler(&mut LlamaSampler::top_k(2));
221    ///
222    /// assert_eq!(data_array.data.len(), 2);
223    /// assert_eq!(data_array.data[0].id(), LlamaToken(3));
224    /// assert_eq!(data_array.data[1].id(), LlamaToken(2));
225    /// ```
226    #[must_use]
227    pub fn top_k(k: i32) -> Self {
228        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_k(k) };
229        Self { sampler }
230    }
231
232    /// Top-nσ sampling as described in academic paper "Top-nσ: Not All Logits Are You Need"
233    /// <https://arxiv.org/pdf/2411.07641>
234    ///
235    /// This method filters logits by selecting only those within *n* standard deviations of the mean.
236    ///
237    /// # Parameters
238    /// - `n`: Number of standard deviations from the mean to include in sampling
239    ///
240    /// # Example
241    /// ```rust
242    /// use llama_cpp_bindings::sampling::LlamaSampler;
243    /// use llama_cpp_bindings::token::{
244    ///     LlamaToken,
245    ///     data::LlamaTokenData,
246    ///     data_array::LlamaTokenDataArray
247    /// };
248    ///
249    /// let mut data_array = LlamaTokenDataArray::new(vec![
250    ///     LlamaTokenData::new(LlamaToken(0), 0.0, 0.0),
251    ///     LlamaTokenData::new(LlamaToken(1), 1.0, 0.0),
252    ///     LlamaTokenData::new(LlamaToken(2), 2.0, 0.0),
253    /// ], false);
254    ///
255    /// data_array.apply_sampler(&mut LlamaSampler::top_n_sigma(2.0));
256    /// ```
257    #[must_use]
258    pub fn top_n_sigma(n: f32) -> Self {
259        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_n_sigma(n) };
260        Self { sampler }
261    }
262
263    /// Locally Typical Sampling implementation described in the paper <https://arxiv.org/abs/2202.00666>.
264    #[must_use]
265    pub fn typical(p: f32, min_keep: usize) -> Self {
266        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_typical(p, min_keep) };
267        Self { sampler }
268    }
269
270    /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration"
271    /// <https://arxiv.org/abs/1904.09751>
272    #[must_use]
273    pub fn top_p(p: f32, min_keep: usize) -> Self {
274        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_p(p, min_keep) };
275        Self { sampler }
276    }
277
278    /// Minimum P sampling as described in <https://github.com/ggerganov/llama.cpp/pull/3841>
279    #[must_use]
280    pub fn min_p(p: f32, min_keep: usize) -> Self {
281        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_min_p(p, min_keep) };
282        Self { sampler }
283    }
284
285    /// XTC sampler as described in <https://github.com/oobabooga/text-generation-webui/pull/6335>
286    #[must_use]
287    pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Self {
288        let sampler =
289            unsafe { llama_cpp_bindings_sys::llama_sampler_init_xtc(p, t, min_keep, seed) };
290        Self { sampler }
291    }
292
293    /// Grammar sampler
294    ///
295    /// # Errors
296    /// Returns an error if the grammar is invalid or the sampler cannot be initialized.
297    pub fn grammar(
298        model: &LlamaModel,
299        grammar_str: &str,
300        grammar_root: &str,
301    ) -> Result<Self, GrammarError> {
302        let (grammar_str, grammar_root) =
303            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
304
305        let sampler = unsafe {
306            llama_cpp_bindings_sys::llama_rs_sampler_init_grammar(
307                model.vocab_ptr(),
308                grammar_str.as_ptr(),
309                grammar_root.as_ptr(),
310            )
311        };
312
313        if sampler.is_null() {
314            Err(GrammarError::NullGrammar)
315        } else {
316            Ok(Self { sampler })
317        }
318    }
319
    /// Lazy grammar sampler, introduced in <https://github.com/ggerganov/llama.cpp/pull/9639>
    ///
    /// This sampler enforces grammar rules only when specific trigger words or tokens are encountered.
    ///
    /// # Errors
    /// Returns an error if the grammar or trigger words are invalid.
    pub fn grammar_lazy(
        model: &LlamaModel,
        grammar_str: &str,
        grammar_root: &str,
        trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
        trigger_tokens: &[LlamaToken],
    ) -> Result<Self, GrammarError> {
        // Validate inputs and convert them to NUL-terminated strings up front.
        let (grammar_str, grammar_root) =
            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
        let trigger_words = Self::sanitize_trigger_words(trigger_words)?;

        // `trigger_words` must stay alive while this raw-pointer array is used
        // in the FFI call below.
        let mut trigger_word_ptrs: Vec<*const c_char> =
            trigger_words.iter().map(|cs| cs.as_ptr()).collect();

        // SAFETY: all pointers (vocab, strings, pointer array, token slice) are
        // valid for the duration of the call.
        let sampler = unsafe {
            llama_cpp_bindings_sys::llama_rs_sampler_init_grammar_lazy(
                model.vocab_ptr(),
                grammar_str.as_ptr(),
                grammar_root.as_ptr(),
                trigger_word_ptrs.as_mut_ptr(),
                trigger_word_ptrs.len(),
                // Cast reinterprets &[LlamaToken] as raw token ids; presumably
                // LlamaToken is repr(transparent) over the FFI token — confirm.
                trigger_tokens.as_ptr().cast(),
                trigger_tokens.len(),
            )
        };

        // A null return signals an invalid grammar or triggers.
        if sampler.is_null() {
            Err(GrammarError::NullGrammar)
        } else {
            Ok(Self { sampler })
        }
    }
358
    /// Lazy grammar sampler using regex trigger patterns.
    ///
    /// Trigger patterns are regular expressions matched from the start of the
    /// generation output. The grammar sampler will be fed content starting from
    /// the first match group.
    ///
    /// # Errors
    /// Returns an error if the grammar or trigger patterns are invalid.
    pub fn grammar_lazy_patterns(
        model: &LlamaModel,
        grammar_str: &str,
        grammar_root: &str,
        trigger_patterns: &[String],
        trigger_tokens: &[LlamaToken],
    ) -> Result<Self, GrammarError> {
        // Validate inputs and convert them to NUL-terminated strings up front.
        let (grammar_str, grammar_root) =
            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
        let trigger_patterns = Self::sanitize_trigger_patterns(trigger_patterns)?;

        // `trigger_patterns` must stay alive while this raw-pointer array is
        // used in the FFI call below.
        let mut trigger_pattern_ptrs: Vec<*const c_char> =
            trigger_patterns.iter().map(|cs| cs.as_ptr()).collect();

        // SAFETY: all pointers (vocab, strings, pointer array, token slice) are
        // valid for the duration of the call.
        let sampler = unsafe {
            llama_cpp_bindings_sys::llama_rs_sampler_init_grammar_lazy_patterns(
                model.vocab_ptr(),
                grammar_str.as_ptr(),
                grammar_root.as_ptr(),
                trigger_pattern_ptrs.as_mut_ptr(),
                trigger_pattern_ptrs.len(),
                // Cast reinterprets &[LlamaToken] as raw token ids; presumably
                // LlamaToken is repr(transparent) over the FFI token — confirm.
                trigger_tokens.as_ptr().cast(),
                trigger_tokens.len(),
            )
        };

        // A null return signals an invalid grammar or patterns.
        if sampler.is_null() {
            Err(GrammarError::NullGrammar)
        } else {
            Ok(Self { sampler })
        }
    }
399
    /// `LLGuidance` sampler for constrained decoding.
    ///
    /// Uses the `llguidance` and `toktrie` Rust crates to enforce grammar constraints
    /// during token sampling. Supports JSON schema, regex, Lark, and other grammar types.
    ///
    /// # Errors
    ///
    /// Returns [`GrammarError`] if the grammar is invalid or the sampler cannot be initialized.
    #[cfg(feature = "llguidance")]
    pub fn llguidance(
        model: &LlamaModel,
        grammar_kind: &str,
        grammar_data: &str,
    ) -> Result<Self, GrammarError> {
        // Thin delegator: construction logic lives in `crate::llguidance_sampler`.
        crate::llguidance_sampler::create_llg_sampler(model, grammar_kind, grammar_data)
    }
416
417    fn sanitize_grammar_strings(
418        grammar_str: &str,
419        grammar_root: &str,
420    ) -> Result<(CString, CString), GrammarError> {
421        if !grammar_str.contains(grammar_root) {
422            return Err(GrammarError::RootNotFound);
423        }
424
425        if grammar_str.contains('\0') || grammar_root.contains('\0') {
426            return Err(GrammarError::GrammarNullBytes);
427        }
428
429        Ok((
430            CString::new(grammar_str).unwrap(),
431            CString::new(grammar_root).unwrap(),
432        ))
433    }
434
435    fn sanitize_trigger_words(
436        trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
437    ) -> Result<Vec<CString>, GrammarError> {
438        let trigger_words: Vec<_> = trigger_words.into_iter().collect();
439        if trigger_words
440            .iter()
441            .any(|word| word.as_ref().contains(&b'\0'))
442        {
443            return Err(GrammarError::TriggerWordNullBytes);
444        }
445        Ok(trigger_words
446            .into_iter()
447            .map(|word| CString::new(word.as_ref()).unwrap())
448            .collect())
449    }
450
451    fn sanitize_trigger_patterns(
452        trigger_patterns: &[String],
453    ) -> Result<Vec<CString>, GrammarError> {
454        let mut patterns = Vec::with_capacity(trigger_patterns.len());
455        for pattern in trigger_patterns {
456            if pattern.contains('\0') {
457                return Err(GrammarError::GrammarNullBytes);
458            }
459            patterns.push(CString::new(pattern.as_str()).unwrap());
460        }
461        Ok(patterns)
462    }
463
464    /// DRY sampler, designed by p-e-w, as described in:
465    /// <https://github.com/oobabooga/text-generation-webui/pull/5677>, porting Koboldcpp
466    /// implementation authored by pi6am: <https://github.com/LostRuins/koboldcpp/pull/982>
467    ///
468    /// # Panics
469    /// If any string in ``seq_breakers`` contains null bytes.
470    #[allow(missing_docs)]
471    #[must_use]
472    pub fn dry(
473        model: &LlamaModel,
474        multiplier: f32,
475        base: f32,
476        allowed_length: i32,
477        penalty_last_n: i32,
478        seq_breakers: impl IntoIterator<Item = impl AsRef<[u8]>>,
479    ) -> Self {
480        let seq_breakers: Vec<CString> = seq_breakers
481            .into_iter()
482            .map(|s| CString::new(s.as_ref()).expect("A sequence breaker contains null bytes"))
483            .collect();
484        let mut seq_breaker_pointers: Vec<*const c_char> =
485            seq_breakers.iter().map(|s| s.as_ptr()).collect();
486
487        let sampler = unsafe {
488            llama_cpp_bindings_sys::llama_sampler_init_dry(
489                model.vocab_ptr(),
490                model
491                    .n_ctx_train()
492                    .try_into()
493                    .expect("n_ctx_train exceeds i32::MAX"),
494                multiplier,
495                base,
496                allowed_length,
497                penalty_last_n,
498                seq_breaker_pointers.as_mut_ptr(),
499                seq_breaker_pointers.len(),
500            )
501        };
502        Self { sampler }
503    }
504
505    /// Penalizes tokens for being present in the context.
506    ///
507    /// Parameters:
508    /// - ``penalty_last_n``: last n tokens to penalize (0 = disable penalty, -1 = context size)
509    /// - ``penalty_repeat``: 1.0 = disabled
510    /// - ``penalty_freq``: 0.0 = disabled
511    /// - ``penalty_present``: 0.0 = disabled
512    #[must_use]
513    pub fn penalties(
514        penalty_last_n: i32,
515        penalty_repeat: f32,
516        penalty_freq: f32,
517        penalty_present: f32,
518    ) -> Self {
519        let sampler = unsafe {
520            llama_cpp_bindings_sys::llama_sampler_init_penalties(
521                penalty_last_n,
522                penalty_repeat,
523                penalty_freq,
524                penalty_present,
525            )
526        };
527        Self { sampler }
528    }
529
530    /// Mirostat 1.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
531    ///
532    /// # Parameters:
533    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
534    /// - ``seed``: Seed to initialize random generation with.
535    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
536    ///   generated text. A higher value corresponds to more surprising or less predictable text,
537    ///   while a lower value corresponds to less surprising or more predictable text.
538    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
539    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
540    ///   updated more quickly, while a smaller learning rate will result in slower updates.
541    /// - ``m``: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary
542    ///   value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`.
543    ///   In the paper, they use `m = 100`, but you can experiment with different values to see how
544    ///   it affects the performance of the algorithm.
545    #[must_use]
546    pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self {
547        let sampler = unsafe {
548            llama_cpp_bindings_sys::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m)
549        };
550        Self { sampler }
551    }
552
553    /// Mirostat 2.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
554    ///
555    /// # Parameters:
556    /// - ``seed``: Seed to initialize random generation with.
557    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
558    ///   generated text. A higher value corresponds to more surprising or less predictable text,
559    ///   while a lower value corresponds to less surprising or more predictable text.
560    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
561    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
562    ///   updated more quickly, while a smaller learning rate will result in slower updates.
563    #[must_use]
564    pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self {
565        let sampler =
566            unsafe { llama_cpp_bindings_sys::llama_sampler_init_mirostat_v2(seed, tau, eta) };
567        Self { sampler }
568    }
569
570    /// Selects a token at random based on each token's probabilities
571    #[must_use]
572    pub fn dist(seed: u32) -> Self {
573        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_dist(seed) };
574        Self { sampler }
575    }
576
577    /// Selects the most likely token
578    ///
579    /// # Example:
580    /// ```rust
581    /// use llama_cpp_bindings::token::{
582    ///    LlamaToken,
583    ///    data::LlamaTokenData,
584    ///    data_array::LlamaTokenDataArray
585    /// };
586    /// use llama_cpp_bindings::sampling::LlamaSampler;
587    ///
588    /// let mut data_array = LlamaTokenDataArray::new(vec![
589    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
590    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
591    /// ], false);
592    ///
593    /// data_array.apply_sampler(&mut LlamaSampler::greedy());
594    ///
595    /// assert_eq!(data_array.data.len(), 2);
596    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(1)));
597    /// ```
598    #[must_use]
599    pub fn greedy() -> Self {
600        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_greedy() };
601        Self { sampler }
602    }
603
604    /// Creates a sampler that applies bias values to specific tokens during sampling.
605    ///
606    /// # Parameters
607    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
608    /// - ``biases``: Slice of [`LlamaLogitBias`] values specifying token-bias pairs
609    ///
610    /// # Panics
611    ///
612    /// Panics if `biases.len()` exceeds `i32::MAX`.
613    ///
614    /// # Example
615    /// ```rust
616    /// use llama_cpp_bindings::token::{LlamaToken, logit_bias::LlamaLogitBias};
617    /// use llama_cpp_bindings::sampling::LlamaSampler;
618    ///
619    /// let biases = vec![
620    ///     LlamaLogitBias::new(LlamaToken(1), 1.5),  // Increase probability of token 1
621    ///     LlamaLogitBias::new(LlamaToken(2), -1.0), // Decrease probability of token 2
622    /// ];
623    ///
624    /// // Assuming vocab_size of 32000
625    /// let sampler = LlamaSampler::logit_bias(32000, &biases);
626    /// ```
627    #[must_use]
628    pub fn logit_bias(n_vocab: i32, biases: &[LlamaLogitBias]) -> Self {
629        let data = biases
630            .as_ptr()
631            .cast::<llama_cpp_bindings_sys::llama_logit_bias>();
632
633        let sampler = unsafe {
634            llama_cpp_bindings_sys::llama_sampler_init_logit_bias(
635                n_vocab,
636                i32::try_from(biases.len()).expect("bias count exceeds i32::MAX"),
637                data,
638            )
639        };
640
641        Self { sampler }
642    }
643}
644
impl Drop for LlamaSampler {
    /// Frees the underlying `llama_sampler`.
    fn drop(&mut self) {
        // SAFETY: `self.sampler` was produced by a llama.cpp init function and
        // ownership was never transferred elsewhere (`chain` uses `mem::forget`
        // on its components precisely to avoid a double free here).
        unsafe {
            llama_cpp_bindings_sys::llama_sampler_free(self.sampler);
        }
    }
}
651}