//! Safe wrapper around `llama_sampler`.

use std::borrow::Borrow;
use std::ffi::{c_char, CString};
use std::fmt::{Debug, Formatter};

use crate::context::LlamaContext;
use crate::model::LlamaModel;
use crate::token::data_array::LlamaTokenDataArray;
use crate::token::logit_bias::LlamaLogitBias;
use crate::token::LlamaToken;
use crate::{status_is_ok, status_to_i32, GrammarError, SamplerAcceptError};

/// A safe wrapper around `llama_sampler`.
pub struct LlamaSampler {
    pub(crate) sampler: *mut llama_cpp_sys_2::llama_sampler,
}

impl Debug for LlamaSampler {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LlamaSampler").finish()
    }
}

impl LlamaSampler {
    /// Samples and accepts a token from the `idx`-th output of the last evaluation.
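    ///
    /// # Example
    ///
    /// A minimal sketch of a decode-and-sample step; the model path is a placeholder and prompt
    /// handling is reduced to a single `decode` call:
    /// ```rust,no_run
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// use llama_cpp_2::llama_backend::LlamaBackend;
    /// use llama_cpp_2::llama_batch::LlamaBatch;
    /// use llama_cpp_2::model::{params::LlamaModelParams, AddBos, LlamaModel};
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let backend = LlamaBackend::init().unwrap();
    /// let model = LlamaModel::load_from_file(&backend, "model.gguf", &LlamaModelParams::default())
    ///     .unwrap();
    /// let mut ctx = model.new_context(&backend, LlamaContextParams::default()).unwrap();
    ///
    /// // Feed the prompt to the model, requesting logits for the last token only.
    /// let tokens = model.str_to_token("Hello", AddBos::Always).unwrap();
    /// let mut batch = LlamaBatch::new(512, 1);
    /// let last_idx = tokens.len() as i32 - 1;
    /// for (i, token) in tokens.into_iter().enumerate() {
    ///     batch.add(token, i as i32, &[0], i as i32 == last_idx).unwrap();
    /// }
    /// ctx.decode(&mut batch).unwrap();
    ///
    /// // Sample from the logits of the last output in the batch.
    /// let mut sampler = LlamaSampler::chain_simple([LlamaSampler::dist(1234)]);
    /// let token = sampler.sample(&ctx, batch.n_tokens() - 1);
    /// ```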
    #[must_use]
    pub fn sample(&mut self, ctx: &LlamaContext, idx: i32) -> LlamaToken {
        let token = unsafe {
            llama_cpp_sys_2::llama_sampler_sample(self.sampler, ctx.context.as_ptr(), idx)
        };

        LlamaToken(token)
    }

    /// Applies this sampler to a [`LlamaTokenDataArray`].
    pub fn apply(&self, data_array: &mut LlamaTokenDataArray) {
        data_array.apply_sampler(self);
    }

    /// Accepts a token into the sampler, possibly updating the internal state of certain
    /// samplers (e.g. grammar, repetition).
    ///
    /// Errors from the underlying sampler are silently ignored; use [`Self::try_accept`] if you
    /// need to handle them.
    pub fn accept(&mut self, token: LlamaToken) {
        let _ = self.try_accept(token);
    }

    /// Accepts several tokens into the sampler, possibly updating the internal state of certain
    /// samplers (e.g. grammar, repetition). Errors from the underlying sampler are silently
    /// ignored.
    pub fn accept_many(&mut self, tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>) {
        for token in tokens {
            let _ = self.try_accept(*token.borrow());
        }
    }

    /// Accepts several tokens into the sampler and returns it, allowing builder-style
    /// construction. Equivalent to calling [`Self::accept_many`] and then returning `self`.
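    ///
    /// # Example
    ///
    /// A minimal sketch using a penalties sampler, which tracks accepted tokens:
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    /// use llama_cpp_2::token::LlamaToken;
    ///
    /// // Seed the repetition penalty with some previously generated tokens.
    /// let sampler = LlamaSampler::penalties(64, 1.1, 0.0, 0.0)
    ///     .with_tokens([LlamaToken(1), LlamaToken(2), LlamaToken(3)]);
    /// ```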
    #[must_use]
    pub fn with_tokens(
        mut self,
        tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
    ) -> Self {
        self.accept_many(tokens);
        self
    }

    /// Tries to accept a token into the sampler. Returns an error if the underlying sampler
    /// reports a failure.
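    ///
    /// # Example
    ///
    /// A minimal sketch; a stateless sampler such as [`Self::greedy`] accepts any token:
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    /// use llama_cpp_2::token::LlamaToken;
    ///
    /// let mut sampler = LlamaSampler::greedy();
    /// assert!(sampler.try_accept(LlamaToken(0)).is_ok());
    /// ```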
    pub fn try_accept(&mut self, token: LlamaToken) -> Result<(), SamplerAcceptError> {
        let sampler_result =
            unsafe { llama_cpp_sys_2::llama_rs_sampler_accept(self.sampler, token.0) };
        if status_is_ok(sampler_result) {
            Ok(())
        } else {
            Err(SamplerAcceptError::FfiError(status_to_i32(sampler_result)))
        }
    }

    /// Resets the internal state of the sampler.
    ///
    /// This can be useful when you want to start fresh with a sampler without creating a new instance.
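    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    /// use llama_cpp_2::token::LlamaToken;
    ///
    /// let mut sampler = LlamaSampler::penalties(64, 1.1, 0.0, 0.0);
    /// sampler.accept(LlamaToken(7));
    ///
    /// // Forget all accepted tokens and start fresh.
    /// sampler.reset();
    /// ```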
    pub fn reset(&mut self) {
        unsafe {
            llama_cpp_sys_2::llama_sampler_reset(self.sampler);
        }
    }

    /// Gets the random seed used by this sampler.
    ///
    /// # Returns
    /// - For random samplers (dist, mirostat, mirostat_v2): their current seed
    /// - For sampler chains: the first non-default seed found when scanning the chain in reverse
    /// - For all other samplers: `0xFFFF_FFFF`
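    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let sampler = LlamaSampler::dist(1234);
    /// assert_eq!(sampler.get_seed(), 1234);
    ///
    /// // Samplers without a seed report the default sentinel value.
    /// let sampler = LlamaSampler::greedy();
    /// assert_eq!(sampler.get_seed(), 0xFFFF_FFFF);
    /// ```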
    #[must_use]
    pub fn get_seed(&self) -> u32 {
        unsafe { llama_cpp_sys_2::llama_sampler_get_seed(self.sampler) }
    }

    /// Combines a list of samplers into a single sampler that applies each component sampler one
    /// after another.
    ///
    /// If you are using a chain to select a token, the chain should always end with one of
    /// [`LlamaSampler::greedy`], [`LlamaSampler::dist`], [`LlamaSampler::mirostat`], or
    /// [`LlamaSampler::mirostat_v2`].
    #[must_use]
    pub fn chain(samplers: impl IntoIterator<Item = Self>, no_perf: bool) -> Self {
        unsafe {
            let chain = llama_cpp_sys_2::llama_sampler_chain_init(
                llama_cpp_sys_2::llama_sampler_chain_params { no_perf },
            );

            for sampler in samplers {
                llama_cpp_sys_2::llama_sampler_chain_add(chain, sampler.sampler);

                // Do not call `llama_sampler_free` on the sampler, as the internal sampler is now
                // owned by the chain.
                std::mem::forget(sampler);
            }

            Self { sampler: chain }
        }
    }

    /// Same as [`Self::chain`] with `no_perf = false`.
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    /// use llama_cpp_2::llama_backend::LlamaBackend;
    /// let backend = LlamaBackend::init().unwrap();
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::chain_simple([
    ///     LlamaSampler::temp(0.5),
    ///     LlamaSampler::greedy(),
    /// ]));
    ///
    /// assert_eq!(data_array.data[0].logit(), 0.);
    /// assert_eq!(data_array.data[1].logit(), 2.);
    /// assert_eq!(data_array.data[2].logit(), 4.);
    ///
    /// assert_eq!(data_array.data.len(), 3);
    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(2)));
    /// ```
    #[must_use]
    pub fn chain_simple(samplers: impl IntoIterator<Item = Self>) -> Self {
        Self::chain(samplers, false)
    }

    #[allow(clippy::doc_markdown)]
    /// Updates the logits as `l_i' = l_i / t`. When `t <= 0.0`, the maximum logit is kept at its
    /// original value and the rest are set to `-inf`.
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::temp(0.5));
    ///
    /// assert_eq!(data_array.data[0].logit(), 0.);
    /// assert_eq!(data_array.data[1].logit(), 2.);
    /// assert_eq!(data_array.data[2].logit(), 4.);
    /// ```
    #[must_use]
    pub fn temp(t: f32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_temp(t) };
        Self { sampler }
    }

    /// Dynamic temperature implementation (a.k.a. entropy sampling) described in the paper
    /// <https://arxiv.org/abs/2309.02772>.
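    ///
    /// # Example
    ///
    /// A minimal sketch; with `delta = 0.0` this reduces to plain temperature scaling like
    /// [`Self::temp`]:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::temp_ext(0.5, 0.0, 1.0));
    ///
    /// // No tokens are removed; only the logits are rescaled.
    /// assert_eq!(data_array.data.len(), 3);
    /// ```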
    #[must_use]
    pub fn temp_ext(t: f32, delta: f32, exponent: f32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_temp_ext(t, delta, exponent) };
        Self { sampler }
    }

    /// Top-K sampling described in the academic paper "The Curious Case of Neural Text
    /// Degeneration" <https://arxiv.org/abs/1904.09751>
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    ///     LlamaTokenData::new(LlamaToken(3), 3., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::top_k(2));
    ///
    /// assert_eq!(data_array.data.len(), 2);
    /// assert_eq!(data_array.data[0].id(), LlamaToken(3));
    /// assert_eq!(data_array.data[1].id(), LlamaToken(2));
    /// ```
    #[must_use]
    pub fn top_k(k: i32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_k(k) };
        Self { sampler }
    }

    /// Top-nσ sampling as described in the academic paper "Top-nσ: Not All Logits Are You Need"
    /// <https://arxiv.org/pdf/2411.07641>
    ///
    /// This method filters logits by selecting only those within *n* standard deviations of the mean.
    ///
    /// # Parameters
    /// - `n`: Number of standard deviations from the mean to include in sampling
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0.0, 0.0),
    ///     LlamaTokenData::new(LlamaToken(1), 1.0, 0.0),
    ///     LlamaTokenData::new(LlamaToken(2), 2.0, 0.0),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::top_n_sigma(2.0));
    /// ```
    #[must_use]
    pub fn top_n_sigma(n: f32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_n_sigma(n) };
        Self { sampler }
    }

    /// Locally Typical Sampling implementation described in the paper <https://arxiv.org/abs/2202.00666>.
    #[must_use]
    pub fn typical(p: f32, min_keep: usize) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_typical(p, min_keep) };
        Self { sampler }
    }

    /// Nucleus sampling described in the academic paper "The Curious Case of Neural Text
    /// Degeneration" <https://arxiv.org/abs/1904.09751>
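    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// // After softmax the top token alone carries ~0.67 of the probability mass, so
    /// // `p = 0.5` keeps only that token.
    /// data_array.apply_sampler(&mut LlamaSampler::top_p(0.5, 1));
    ///
    /// assert_eq!(data_array.data.len(), 1);
    /// assert_eq!(data_array.data[0].id(), LlamaToken(2));
    /// ```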
    #[must_use]
    pub fn top_p(p: f32, min_keep: usize) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_p(p, min_keep) };
        Self { sampler }
    }

    /// Minimum P sampling as described in <https://github.com/ggerganov/llama.cpp/pull/3841>
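    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// // Keep only tokens whose probability is at least half that of the most likely token.
    /// data_array.apply_sampler(&mut LlamaSampler::min_p(0.5, 1));
    ///
    /// assert_eq!(data_array.data.len(), 1);
    /// assert_eq!(data_array.data[0].id(), LlamaToken(2));
    /// ```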
    #[must_use]
    pub fn min_p(p: f32, min_keep: usize) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_min_p(p, min_keep) };
        Self { sampler }
    }

    /// XTC sampler as described in <https://github.com/oobabooga/text-generation-webui/pull/6335>.
    /// With probability `p`, removes all but the least probable of the tokens whose probability
    /// is at least `t`.
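    ///
    /// # Example
    ///
    /// A minimal sketch; `p = 0.5` applies the cutoff on about half of the sampling steps:
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let sampler = LlamaSampler::chain_simple([
    ///     LlamaSampler::xtc(0.5, 0.1, 1, 1234),
    ///     LlamaSampler::dist(1234),
    /// ]);
    /// ```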
    #[must_use]
    pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_xtc(p, t, min_keep, seed) };
        Self { sampler }
    }

    /// Grammar sampler that constrains generation to match a GBNF grammar.
    ///
    /// # Errors
    /// Returns a [`GrammarError`] if `grammar_root` does not appear in `grammar_str`, if either
    /// string contains null bytes, or if llama.cpp fails to build the grammar.
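    ///
    /// # Example
    ///
    /// A minimal sketch; the model path is a placeholder:
    /// ```rust,no_run
    /// use llama_cpp_2::llama_backend::LlamaBackend;
    /// use llama_cpp_2::model::{params::LlamaModelParams, LlamaModel};
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let backend = LlamaBackend::init().unwrap();
    /// let model = LlamaModel::load_from_file(&backend, "model.gguf", &LlamaModelParams::default())
    ///     .unwrap();
    ///
    /// // Constrain output to a single "yes" or "no".
    /// let sampler = LlamaSampler::grammar(&model, r#"root ::= "yes" | "no""#, "root").unwrap();
    /// ```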
    pub fn grammar(
        model: &LlamaModel,
        grammar_str: &str,
        grammar_root: &str,
    ) -> Result<Self, GrammarError> {
        let (grammar_str, grammar_root) =
            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;

        let sampler = unsafe {
            llama_cpp_sys_2::llama_rs_sampler_init_grammar(
                model.vocab_ptr(),
                grammar_str.as_ptr(),
                grammar_root.as_ptr(),
            )
        };

        if sampler.is_null() {
            Err(GrammarError::NullGrammar)
        } else {
            Ok(Self { sampler })
        }
    }

    /// Lazy grammar sampler, introduced in <https://github.com/ggerganov/llama.cpp/pull/9639>.
    ///
    /// This sampler enforces grammar rules only when specific trigger words or tokens are encountered.
    ///
    /// # Errors
    /// Returns a [`GrammarError`] if the grammar or a trigger word contains null bytes, if
    /// `grammar_root` does not appear in `grammar_str`, or if llama.cpp fails to build the grammar.
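    ///
    /// # Example
    ///
    /// A minimal sketch; the model path is a placeholder, and the grammar only kicks in once the
    /// model emits the trigger word:
    /// ```rust,no_run
    /// use llama_cpp_2::llama_backend::LlamaBackend;
    /// use llama_cpp_2::model::{params::LlamaModelParams, LlamaModel};
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let backend = LlamaBackend::init().unwrap();
    /// let model = LlamaModel::load_from_file(&backend, "model.gguf", &LlamaModelParams::default())
    ///     .unwrap();
    ///
    /// let sampler = LlamaSampler::grammar_lazy(
    ///     &model,
    ///     r#"root ::= "{" [^}]* "}""#,
    ///     "root",
    ///     ["<tool_call>"], // trigger words
    ///     &[],             // no trigger tokens
    /// ).unwrap();
    /// ```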
    pub fn grammar_lazy(
        model: &LlamaModel,
        grammar_str: &str,
        grammar_root: &str,
        trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
        trigger_tokens: &[LlamaToken],
    ) -> Result<Self, GrammarError> {
        let (grammar_str, grammar_root) =
            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
        let trigger_words = Self::sanitize_trigger_words(trigger_words)?;

        let mut trigger_word_ptrs: Vec<*const c_char> =
            trigger_words.iter().map(|cs| cs.as_ptr()).collect();

        let sampler = unsafe {
            llama_cpp_sys_2::llama_rs_sampler_init_grammar_lazy(
                model.vocab_ptr(),
                grammar_str.as_ptr(),
                grammar_root.as_ptr(),
                trigger_word_ptrs.as_mut_ptr(),
                trigger_word_ptrs.len(),
                trigger_tokens.as_ptr().cast(),
                trigger_tokens.len(),
            )
        };

        if sampler.is_null() {
            Err(GrammarError::NullGrammar)
        } else {
            Ok(Self { sampler })
        }
    }

    /// Lazy grammar sampler using regex trigger patterns.
    ///
    /// Trigger patterns are regular expressions matched from the start of the
    /// generation output. The grammar sampler will be fed content starting from
    /// the first match group.
    ///
    /// # Errors
    /// Returns a [`GrammarError`] if the grammar or a trigger pattern contains null bytes, if
    /// `grammar_root` does not appear in `grammar_str`, or if llama.cpp fails to build the grammar.
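    ///
    /// # Example
    ///
    /// A minimal sketch; the model path is a placeholder, and the match group captures everything
    /// after the trigger tag:
    /// ```rust,no_run
    /// use llama_cpp_2::llama_backend::LlamaBackend;
    /// use llama_cpp_2::model::{params::LlamaModelParams, LlamaModel};
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let backend = LlamaBackend::init().unwrap();
    /// let model = LlamaModel::load_from_file(&backend, "model.gguf", &LlamaModelParams::default())
    ///     .unwrap();
    ///
    /// let patterns = vec![String::from(r"<tool_call>([\s\S]*)")];
    /// let sampler = LlamaSampler::grammar_lazy_patterns(
    ///     &model,
    ///     r#"root ::= "{" [^}]* "}""#,
    ///     "root",
    ///     &patterns,
    ///     &[], // no trigger tokens
    /// ).unwrap();
    /// ```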
    pub fn grammar_lazy_patterns(
        model: &LlamaModel,
        grammar_str: &str,
        grammar_root: &str,
        trigger_patterns: &[String],
        trigger_tokens: &[LlamaToken],
    ) -> Result<Self, GrammarError> {
        let (grammar_str, grammar_root) =
            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
        let trigger_patterns = Self::sanitize_trigger_patterns(trigger_patterns)?;

        let mut trigger_pattern_ptrs: Vec<*const c_char> =
            trigger_patterns.iter().map(|cs| cs.as_ptr()).collect();

        let sampler = unsafe {
            llama_cpp_sys_2::llama_rs_sampler_init_grammar_lazy_patterns(
                model.vocab_ptr(),
                grammar_str.as_ptr(),
                grammar_root.as_ptr(),
                trigger_pattern_ptrs.as_mut_ptr(),
                trigger_pattern_ptrs.len(),
                trigger_tokens.as_ptr().cast(),
                trigger_tokens.len(),
            )
        };

        if sampler.is_null() {
            Err(GrammarError::NullGrammar)
        } else {
            Ok(Self { sampler })
        }
    }

    /// `LLGuidance` sampler for constrained decoding.
    ///
    /// Uses the `llguidance` and `toktrie` Rust crates to enforce grammar constraints
    /// during token sampling. Supports JSON schema, regex, Lark, and other grammar types.
    ///
    /// # Errors
    ///
    /// Returns [`GrammarError`] if the grammar is invalid or the sampler cannot be initialized.
    #[cfg(feature = "llguidance")]
    pub fn llguidance(
        model: &LlamaModel,
        grammar_kind: &str,
        grammar_data: &str,
    ) -> Result<Self, GrammarError> {
        crate::llguidance_sampler::create_llg_sampler(model, grammar_kind, grammar_data)
    }

    fn sanitize_grammar_strings(
        grammar_str: &str,
        grammar_root: &str,
    ) -> Result<(CString, CString), GrammarError> {
        if !grammar_str.contains(grammar_root) {
            return Err(GrammarError::RootNotFound);
        }

        if grammar_str.contains('\0') || grammar_root.contains('\0') {
            return Err(GrammarError::GrammarNullBytes);
        }

        Ok((
            CString::new(grammar_str).unwrap(),
            CString::new(grammar_root).unwrap(),
        ))
    }

    fn sanitize_trigger_words(
        trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
    ) -> Result<Vec<CString>, GrammarError> {
        let trigger_words: Vec<_> = trigger_words.into_iter().collect();
        if trigger_words
            .iter()
            .any(|word| word.as_ref().contains(&b'\0'))
        {
            return Err(GrammarError::TriggerWordNullBytes);
        }
        Ok(trigger_words
            .into_iter()
            .map(|word| CString::new(word.as_ref()).unwrap())
            .collect())
    }

    fn sanitize_trigger_patterns(
        trigger_patterns: &[String],
    ) -> Result<Vec<CString>, GrammarError> {
        let mut patterns = Vec::with_capacity(trigger_patterns.len());
        for pattern in trigger_patterns {
            if pattern.contains('\0') {
                return Err(GrammarError::GrammarNullBytes);
            }
            patterns.push(CString::new(pattern.as_str()).unwrap());
        }
        Ok(patterns)
    }

    /// DRY sampler, designed by p-e-w, as described in
    /// <https://github.com/oobabooga/text-generation-webui/pull/5677>, porting the Koboldcpp
    /// implementation authored by pi6am: <https://github.com/LostRuins/koboldcpp/pull/982>
    ///
    /// # Panics
    /// If any string in ``seq_breakers`` contains null bytes.
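    ///
    /// # Example
    ///
    /// A minimal sketch using commonly cited default values; the model path is a placeholder:
    /// ```rust,no_run
    /// use llama_cpp_2::llama_backend::LlamaBackend;
    /// use llama_cpp_2::model::{params::LlamaModelParams, LlamaModel};
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let backend = LlamaBackend::init().unwrap();
    /// let model = LlamaModel::load_from_file(&backend, "model.gguf", &LlamaModelParams::default())
    ///     .unwrap();
    ///
    /// // multiplier 0.8, base 1.75, allowed_length 2, penalty_last_n 512
    /// let sampler = LlamaSampler::dry(&model, 0.8, 1.75, 2, 512, ["\n"]);
    /// ```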
    #[must_use]
    pub fn dry(
        model: &LlamaModel,
        multiplier: f32,
        base: f32,
        allowed_length: i32,
        penalty_last_n: i32,
        seq_breakers: impl IntoIterator<Item = impl AsRef<[u8]>>,
    ) -> Self {
        let seq_breakers: Vec<CString> = seq_breakers
            .into_iter()
            .map(|s| CString::new(s.as_ref()).expect("A sequence breaker contains null bytes"))
            .collect();
        let mut seq_breaker_pointers: Vec<*const c_char> =
            seq_breakers.iter().map(|s| s.as_ptr()).collect();

        let sampler = unsafe {
            llama_cpp_sys_2::llama_sampler_init_dry(
                model.vocab_ptr(),
                model
                    .n_ctx_train()
                    .try_into()
                    .expect("n_ctx_train exceeds i32::MAX"),
                multiplier,
                base,
                allowed_length,
                penalty_last_n,
                seq_breaker_pointers.as_mut_ptr(),
                seq_breaker_pointers.len(),
            )
        };
        Self { sampler }
    }

    /// Penalizes tokens for being present in the context.
    ///
    /// # Parameters
    /// - ``penalty_last_n``: last n tokens to penalize (0 = disable penalty, -1 = context size)
    /// - ``penalty_repeat``: 1.0 = disabled
    /// - ``penalty_freq``: 0.0 = disabled
    /// - ``penalty_present``: 0.0 = disabled
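    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut sampler = LlamaSampler::penalties(64, 2.0, 0.0, 0.0);
    /// sampler.accept(LlamaToken(0));
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 2., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&sampler);
    ///
    /// // Token 0 was seen recently, so its positive logit is divided by `penalty_repeat`.
    /// assert_eq!(data_array.data[0].logit(), 1.);
    /// assert_eq!(data_array.data[1].logit(), 2.);
    /// ```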
    #[must_use]
    pub fn penalties(
        penalty_last_n: i32,
        penalty_repeat: f32,
        penalty_freq: f32,
        penalty_present: f32,
    ) -> Self {
        let sampler = unsafe {
            llama_cpp_sys_2::llama_sampler_init_penalties(
                penalty_last_n,
                penalty_repeat,
                penalty_freq,
                penalty_present,
            )
        };
        Self { sampler }
    }

    /// Mirostat 1.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
    ///
    /// # Parameters
    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
    /// - ``seed``: Seed to initialize random generation with.
    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
    ///   generated text. A higher value corresponds to more surprising or less predictable text,
    ///   while a lower value corresponds to less surprising or more predictable text.
    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
    ///   updated more quickly, while a smaller learning rate will result in slower updates.
    /// - ``m``: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary
    ///   value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`.
    ///   In the paper, they use `m = 100`, but you can experiment with different values to see how
    ///   it affects the performance of the algorithm.
    #[must_use]
    pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self {
        let sampler =
            unsafe { llama_cpp_sys_2::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m) };
        Self { sampler }
    }

    /// Mirostat 2.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
    ///
    /// # Parameters
    /// - ``seed``: Seed to initialize random generation with.
    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
    ///   generated text. A higher value corresponds to more surprising or less predictable text,
    ///   while a lower value corresponds to less surprising or more predictable text.
    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
    ///   updated more quickly, while a smaller learning rate will result in slower updates.
    #[must_use]
    pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_mirostat_v2(seed, tau, eta) };
        Self { sampler }
    }

    /// Selects a token at random based on each token's probabilities.
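    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::dist(1234));
    ///
    /// // One of the two tokens is now selected, weighted by its probability.
    /// assert!(data_array.selected_token().is_some());
    /// ```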
    #[must_use]
    pub fn dist(seed: u32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_dist(seed) };
        Self { sampler }
    }

    /// Selects the most likely token.
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::greedy());
    ///
    /// assert_eq!(data_array.data.len(), 2);
    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(1)));
    /// ```
    #[must_use]
    pub fn greedy() -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_greedy() };
        Self { sampler }
    }

    /// Creates a sampler that applies bias values to specific tokens during sampling.
    ///
    /// # Parameters
    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
    /// - ``biases``: Slice of [`LlamaLogitBias`] values specifying token-bias pairs
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{LlamaToken, logit_bias::LlamaLogitBias};
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let biases = vec![
    ///     LlamaLogitBias::new(LlamaToken(1), 1.5),  // Increase probability of token 1
    ///     LlamaLogitBias::new(LlamaToken(2), -1.0), // Decrease probability of token 2
    /// ];
    ///
    /// // Assuming a vocab size of 32000
    /// let sampler = LlamaSampler::logit_bias(32000, &biases);
    /// ```
    #[must_use]
    pub fn logit_bias(n_vocab: i32, biases: &[LlamaLogitBias]) -> Self {
        let data = biases.as_ptr().cast::<llama_cpp_sys_2::llama_logit_bias>();

        let sampler = unsafe {
            llama_cpp_sys_2::llama_sampler_init_logit_bias(n_vocab, biases.len() as i32, data)
        };

        Self { sampler }
    }
}

impl Drop for LlamaSampler {
    fn drop(&mut self) {
        unsafe {
            llama_cpp_sys_2::llama_sampler_free(self.sampler);
        }
    }
}