Skip to main content

llama_cpp_2/
sampling.rs

1//! Safe wrapper around `llama_sampler`.
2
3use std::borrow::Borrow;
4use std::ffi::{c_char, CString};
5use std::fmt::{Debug, Formatter};
6
7use crate::context::LlamaContext;
8use crate::model::LlamaModel;
9#[cfg(feature = "common")]
10use crate::status_is_ok;
11use crate::token::data_array::LlamaTokenDataArray;
12use crate::token::logit_bias::LlamaLogitBias;
13use crate::token::LlamaToken;
14use crate::{GrammarError, SamplerAcceptError};
15
16/// A safe wrapper around `llama_sampler`.
17pub struct LlamaSampler {
18    pub(crate) sampler: *mut llama_cpp_sys_2::llama_sampler,
19}
20
21impl Debug for LlamaSampler {
22    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
23        f.debug_struct("LlamaSamplerChain").finish()
24    }
25}
26
27impl LlamaSampler {
28    /// Sample and accept a token from the idx-th output of the last evaluation
29    #[must_use]
30    pub fn sample(&mut self, ctx: &LlamaContext, idx: i32) -> LlamaToken {
31        let token = unsafe {
32            llama_cpp_sys_2::llama_sampler_sample(self.sampler, ctx.context.as_ptr(), idx)
33        };
34
35        LlamaToken(token)
36    }
37
38    /// Applies this sampler to a [`LlamaTokenDataArray`].
39    pub fn apply(&self, data_array: &mut LlamaTokenDataArray) {
40        data_array.apply_sampler(self);
41    }
42
43    /// Accepts a token from the sampler, possibly updating the internal state of certain samplers
44    /// (e.g. grammar, repetition, etc.)
45    pub fn accept(&mut self, token: LlamaToken) {
46        #[cfg(feature = "common")]
47        {
48            let _ = self.try_accept(token);
49        }
50        #[cfg(not(feature = "common"))]
51        unsafe {
52            llama_cpp_sys_2::llama_sampler_accept(self.sampler, token.0);
53        }
54    }
55
56    /// Accepts several tokens from the sampler or context, possibly updating the internal state of
57    /// certain samplers (e.g. grammar, repetition, etc.)
58    pub fn accept_many(&mut self, tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>) {
59        for token in tokens {
60            self.accept(*token.borrow());
61        }
62    }
63
64    /// Accepts several tokens from the sampler or context, possibly updating the internal state of
65    /// certain samplers (e.g. grammar, repetition, etc.)
66    #[must_use]
67    pub fn with_tokens(
68        mut self,
69        tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
70    ) -> Self {
71        self.accept_many(tokens);
72        self
73    }
74
75    /// Try accepting a token from the sampler. Returns an error if the sampler throws.
76    #[cfg(feature = "common")]
77    pub fn try_accept(&mut self, token: LlamaToken) -> Result<(), SamplerAcceptError> {
78        let sampler_result =
79            unsafe { llama_cpp_sys_2::llama_rs_sampler_accept(self.sampler, token.0) };
80        if status_is_ok(sampler_result) {
81            Ok(())
82        } else {
83            Err(SamplerAcceptError::FfiError(sampler_result))
84        }
85    }
86
87    /// Resets the internal state of the sampler.
88    ///
89    /// This can be useful when you want to start fresh with a sampler without creating a new instance.
90    pub fn reset(&mut self) {
91        unsafe {
92            llama_cpp_sys_2::llama_sampler_reset(self.sampler);
93        }
94    }
95
96    /// Gets the random seed used by this sampler.
97    ///
98    /// Returns:
99    /// - For random samplers (dist, mirostat, mirostat_v2): returns their current seed
100    /// - For sampler chains: returns the first non-default seed found in reverse order
101    /// - For all other samplers: returns 0xFFFFFFFF
102    #[must_use]
103    pub fn get_seed(&self) -> u32 {
104        unsafe { llama_cpp_sys_2::llama_sampler_get_seed(self.sampler) }
105    }
106
107    /// Combines a list of samplers into a single sampler that applies each component sampler one
108    /// after another.
109    ///
110    /// If you are using a chain to select a token, the chain should always end with one of
111    /// [`LlamaSampler::greedy`], [`LlamaSampler::dist`], [`LlamaSampler::mirostat`], and
112    /// [`LlamaSampler::mirostat_v2`].
113    #[must_use]
114    pub fn chain(samplers: impl IntoIterator<Item = Self>, no_perf: bool) -> Self {
115        unsafe {
116            let chain = llama_cpp_sys_2::llama_sampler_chain_init(
117                llama_cpp_sys_2::llama_sampler_chain_params { no_perf },
118            );
119
120            for sampler in samplers {
121                llama_cpp_sys_2::llama_sampler_chain_add(chain, sampler.sampler);
122
123                // Do not call `llama_sampler_free` on the sampler, as the internal sampler is now
124                // owned by the chain
125                std::mem::forget(sampler);
126            }
127
128            Self { sampler: chain }
129        }
130    }
131
132    /// Same as [`Self::chain`] with `no_perf = false`.
133    ///
134    /// # Example
135    /// ```rust
136    /// use llama_cpp_2::token::{
137    ///    LlamaToken,
138    ///    data::LlamaTokenData,
139    ///    data_array::LlamaTokenDataArray
140    /// };
141    /// use llama_cpp_2::sampling::LlamaSampler;
142    /// use llama_cpp_2::llama_backend::LlamaBackend;
143    /// let backend = LlamaBackend::init().unwrap();
144    ///
145    /// let mut data_array = LlamaTokenDataArray::new(vec![
146    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
147    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
148    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
149    /// ], false);
150    ///
151    /// data_array.apply_sampler(&mut LlamaSampler::chain_simple([
152    ///     LlamaSampler::temp(0.5),
153    ///     LlamaSampler::greedy(),
154    /// ]));
155    ///
156    /// assert_eq!(data_array.data[0].logit(), 0.);
157    /// assert_eq!(data_array.data[1].logit(), 2.);
158    /// assert_eq!(data_array.data[2].logit(), 4.);
159    ///
160    /// assert_eq!(data_array.data.len(), 3);
161    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(2)));
162    /// ```
163    #[must_use]
164    pub fn chain_simple(samplers: impl IntoIterator<Item = Self>) -> Self {
165        Self::chain(samplers, false)
166    }
167
168    #[allow(clippy::doc_markdown)]
169    /// Updates the logits l_i' = l_i/t. When t <= 0.0f, the maximum logit is kept at it's original
170    /// value, the rest are set to -inf
171    ///
172    /// # Example:
173    /// ```rust
174    /// use llama_cpp_2::token::{
175    ///    LlamaToken,
176    ///    data::LlamaTokenData,
177    ///    data_array::LlamaTokenDataArray
178    /// };
179    /// use llama_cpp_2::sampling::LlamaSampler;
180    ///
181    /// let mut data_array = LlamaTokenDataArray::new(vec![
182    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
183    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
184    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
185    /// ], false);
186    ///
187    /// data_array.apply_sampler(&mut LlamaSampler::temp(0.5));
188    ///
189    /// assert_eq!(data_array.data[0].logit(), 0.);
190    /// assert_eq!(data_array.data[1].logit(), 2.);
191    /// assert_eq!(data_array.data[2].logit(), 4.);
192    /// ```
193    #[must_use]
194    pub fn temp(t: f32) -> Self {
195        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_temp(t) };
196        Self { sampler }
197    }
198
199    /// Dynamic temperature implementation (a.k.a. entropy) described in the paper
200    /// <https://arxiv.org/abs/2309.02772>.
201    #[must_use]
202    pub fn temp_ext(t: f32, delta: f32, exponent: f32) -> Self {
203        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_temp_ext(t, delta, exponent) };
204        Self { sampler }
205    }
206
207    /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration"
208    /// <https://arxiv.org/abs/1904.09751>
209    ///
210    /// # Example:
211    /// ```rust
212    /// use llama_cpp_2::token::{
213    ///    LlamaToken,
214    ///    data::LlamaTokenData,
215    ///    data_array::LlamaTokenDataArray
216    /// };
217    /// use llama_cpp_2::sampling::LlamaSampler;
218    ///
219    /// let mut data_array = LlamaTokenDataArray::new(vec![
220    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
221    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
222    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
223    ///     LlamaTokenData::new(LlamaToken(3), 3., 0.),
224    /// ], false);
225    ///
226    /// data_array.apply_sampler(&mut LlamaSampler::top_k(2));
227    ///
228    /// assert_eq!(data_array.data.len(), 2);
229    /// assert_eq!(data_array.data[0].id(), LlamaToken(3));
230    /// assert_eq!(data_array.data[1].id(), LlamaToken(2));
231    /// ```
232    #[must_use]
233    pub fn top_k(k: i32) -> Self {
234        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_k(k) };
235        Self { sampler }
236    }
237
238    /// Top-nσ sampling as described in academic paper "Top-nσ: Not All Logits Are You Need"
239    /// <https://arxiv.org/pdf/2411.07641>
240    ///
241    /// This method filters logits by selecting only those within *n* standard deviations of the mean.
242    ///
243    /// # Parameters
244    /// - `n`: Number of standard deviations from the mean to include in sampling
245    ///
246    /// # Example
247    /// ```rust
248    /// use llama_cpp_2::sampling::LlamaSampler;
249    /// use llama_cpp_2::token::{
250    ///     LlamaToken,
251    ///     data::LlamaTokenData,
252    ///     data_array::LlamaTokenDataArray
253    /// };
254    ///
255    /// let mut data_array = LlamaTokenDataArray::new(vec![
256    ///     LlamaTokenData::new(LlamaToken(0), 0.0, 0.0),
257    ///     LlamaTokenData::new(LlamaToken(1), 1.0, 0.0),
258    ///     LlamaTokenData::new(LlamaToken(2), 2.0, 0.0),
259    /// ], false);
260    ///
261    /// data_array.apply_sampler(&mut LlamaSampler::top_n_sigma(2.0));
262    /// ```
263    #[must_use]
264    pub fn top_n_sigma(n: f32) -> Self {
265        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_n_sigma(n) };
266        Self { sampler }
267    }
268
269    /// Locally Typical Sampling implementation described in the paper <https://arxiv.org/abs/2202.00666>.
270    #[must_use]
271    pub fn typical(p: f32, min_keep: usize) -> Self {
272        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_typical(p, min_keep) };
273        Self { sampler }
274    }
275
276    /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration"
277    /// <https://arxiv.org/abs/1904.09751>
278    #[must_use]
279    pub fn top_p(p: f32, min_keep: usize) -> Self {
280        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_p(p, min_keep) };
281        Self { sampler }
282    }
283
284    /// Minimum P sampling as described in <https://github.com/ggerganov/llama.cpp/pull/3841>
285    #[must_use]
286    pub fn min_p(p: f32, min_keep: usize) -> Self {
287        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_min_p(p, min_keep) };
288        Self { sampler }
289    }
290
291    /// XTC sampler as described in <https://github.com/oobabooga/text-generation-webui/pull/6335>
292    #[must_use]
293    pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Self {
294        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_xtc(p, t, min_keep, seed) };
295        Self { sampler }
296    }
297
298    /// Grammar sampler
299    #[cfg(feature = "common")]
300    #[must_use]
301    pub fn grammar(
302        model: &LlamaModel,
303        grammar_str: &str,
304        grammar_root: &str,
305    ) -> Result<Self, GrammarError> {
306        let (grammar_str, grammar_root) =
307            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
308
309        let sampler = unsafe {
310            llama_cpp_sys_2::llama_rs_sampler_init_grammar(
311                model.vocab_ptr(),
312                grammar_str.as_ptr(),
313                grammar_root.as_ptr(),
314            )
315        };
316
317        if sampler.is_null() {
318            Err(GrammarError::NullGrammar)
319        } else {
320            Ok(Self { sampler })
321        }
322    }
323
324    /// Lazy grammar sampler, introduced in <https://github.com/ggerganov/llama.cpp/pull/9639>
325    ///
326    /// This sampler enforces grammar rules only when specific trigger words or tokens are encountered.
327    #[cfg(feature = "common")]
328    #[must_use]
329    pub fn grammar_lazy(
330        model: &LlamaModel,
331        grammar_str: &str,
332        grammar_root: &str,
333        trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
334        trigger_tokens: &[LlamaToken],
335    ) -> Result<Self, GrammarError> {
336        let (grammar_str, grammar_root) =
337            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
338        let trigger_words = Self::sanitize_trigger_words(trigger_words)?;
339
340        let mut trigger_word_ptrs: Vec<*const c_char> =
341            trigger_words.iter().map(|cs| cs.as_ptr()).collect();
342
343        let sampler = unsafe {
344            llama_cpp_sys_2::llama_rs_sampler_init_grammar_lazy(
345                model.vocab_ptr(),
346                grammar_str.as_ptr(),
347                grammar_root.as_ptr(),
348                trigger_word_ptrs.as_mut_ptr(),
349                trigger_word_ptrs.len(),
350                trigger_tokens.as_ptr().cast(),
351                trigger_tokens.len(),
352            )
353        };
354
355        if sampler.is_null() {
356            Err(GrammarError::NullGrammar)
357        } else {
358            Ok(Self { sampler })
359        }
360    }
361
362    /// Lazy grammar sampler using regex trigger patterns.
363    ///
364    /// Trigger patterns are regular expressions matched from the start of the
365    /// generation output. The grammar sampler will be fed content starting from
366    /// the first match group.
367    #[cfg(feature = "common")]
368    #[must_use]
369    pub fn grammar_lazy_patterns(
370        model: &LlamaModel,
371        grammar_str: &str,
372        grammar_root: &str,
373        trigger_patterns: &[String],
374        trigger_tokens: &[LlamaToken],
375    ) -> Result<Self, GrammarError> {
376        let (grammar_str, grammar_root) =
377            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
378        let trigger_patterns = Self::sanitize_trigger_patterns(trigger_patterns)?;
379
380        let mut trigger_pattern_ptrs: Vec<*const c_char> =
381            trigger_patterns.iter().map(|cs| cs.as_ptr()).collect();
382
383        let sampler = unsafe {
384            llama_cpp_sys_2::llama_rs_sampler_init_grammar_lazy_patterns(
385                model.vocab_ptr(),
386                grammar_str.as_ptr(),
387                grammar_root.as_ptr(),
388                trigger_pattern_ptrs.as_mut_ptr(),
389                trigger_pattern_ptrs.len(),
390                trigger_tokens.as_ptr().cast(),
391                trigger_tokens.len(),
392            )
393        };
394
395        if sampler.is_null() {
396            Err(GrammarError::NullGrammar)
397        } else {
398            Ok(Self { sampler })
399        }
400    }
401
402    /// `LLGuidance` sampler for constrained decoding.
403    ///
404    /// Uses the `llguidance` and `toktrie` Rust crates to enforce grammar constraints
405    /// during token sampling. Supports JSON schema, regex, Lark, and other grammar types.
406    ///
407    /// # Errors
408    ///
409    /// Returns [`GrammarError`] if the grammar is invalid or the sampler cannot be initialized.
410    #[cfg(feature = "llguidance")]
411    pub fn llguidance(
412        model: &LlamaModel,
413        grammar_kind: &str,
414        grammar_data: &str,
415    ) -> Result<Self, GrammarError> {
416        crate::llguidance_sampler::create_llg_sampler(model, grammar_kind, grammar_data)
417    }
418
419    fn sanitize_grammar_strings(
420        grammar_str: &str,
421        grammar_root: &str,
422    ) -> Result<(CString, CString), GrammarError> {
423        if !grammar_str.contains(grammar_root) {
424            return Err(GrammarError::RootNotFound);
425        }
426
427        if grammar_str.contains('\0') || grammar_root.contains('\0') {
428            return Err(GrammarError::GrammarNullBytes);
429        }
430
431        Ok((
432            CString::new(grammar_str).unwrap(),
433            CString::new(grammar_root).unwrap(),
434        ))
435    }
436
437    fn sanitize_trigger_words(
438        trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
439    ) -> Result<Vec<CString>, GrammarError> {
440        let trigger_words: Vec<_> = trigger_words.into_iter().collect();
441        if trigger_words
442            .iter()
443            .any(|word| word.as_ref().contains(&b'\0'))
444        {
445            return Err(GrammarError::TriggerWordNullBytes);
446        }
447        Ok(trigger_words
448            .into_iter()
449            .map(|word| CString::new(word.as_ref()).unwrap())
450            .collect())
451    }
452
453    fn sanitize_trigger_patterns(
454        trigger_patterns: &[String],
455    ) -> Result<Vec<CString>, GrammarError> {
456        let mut patterns = Vec::with_capacity(trigger_patterns.len());
457        for pattern in trigger_patterns {
458            if pattern.contains('\0') {
459                return Err(GrammarError::GrammarNullBytes);
460            }
461            patterns.push(CString::new(pattern.as_str()).unwrap());
462        }
463        Ok(patterns)
464    }
465
466    /// DRY sampler, designed by p-e-w, as described in:
467    /// <https://github.com/oobabooga/text-generation-webui/pull/5677>, porting Koboldcpp
468    /// implementation authored by pi6am: <https://github.com/LostRuins/koboldcpp/pull/982>
469    ///
470    /// # Panics
471    /// If any string in ``seq_breakers`` contains null bytes.
472    #[allow(missing_docs)]
473    #[must_use]
474    pub fn dry(
475        model: &LlamaModel,
476        multiplier: f32,
477        base: f32,
478        allowed_length: i32,
479        penalty_last_n: i32,
480        seq_breakers: impl IntoIterator<Item = impl AsRef<[u8]>>,
481    ) -> Self {
482        let seq_breakers: Vec<CString> = seq_breakers
483            .into_iter()
484            .map(|s| CString::new(s.as_ref()).expect("A sequence breaker contains null bytes"))
485            .collect();
486        let mut seq_breaker_pointers: Vec<*const c_char> =
487            seq_breakers.iter().map(|s| s.as_ptr()).collect();
488
489        let sampler = unsafe {
490            llama_cpp_sys_2::llama_sampler_init_dry(
491                model.vocab_ptr(),
492                model
493                    .n_ctx_train()
494                    .try_into()
495                    .expect("n_ctx_train exceeds i32::MAX"),
496                multiplier,
497                base,
498                allowed_length,
499                penalty_last_n,
500                seq_breaker_pointers.as_mut_ptr(),
501                seq_breaker_pointers.len(),
502            )
503        };
504        Self { sampler }
505    }
506
507    /// Penalizes tokens for being present in the context.
508    ///
509    /// Parameters:
510    /// - ``penalty_last_n``: last n tokens to penalize (0 = disable penalty, -1 = context size)
511    /// - ``penalty_repeat``: 1.0 = disabled
512    /// - ``penalty_freq``: 0.0 = disabled
513    /// - ``penalty_present``: 0.0 = disabled
514    #[allow(clippy::too_many_arguments)]
515    #[must_use]
516    pub fn penalties(
517        penalty_last_n: i32,
518        penalty_repeat: f32,
519        penalty_freq: f32,
520        penalty_present: f32,
521    ) -> Self {
522        let sampler = unsafe {
523            llama_cpp_sys_2::llama_sampler_init_penalties(
524                penalty_last_n,
525                penalty_repeat,
526                penalty_freq,
527                penalty_present,
528            )
529        };
530        Self { sampler }
531    }
532
533    /// Mirostat 1.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
534    ///
535    /// # Parameters:
536    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
537    /// - ``seed``: Seed to initialize random generation with.
538    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
539    ///   generated text. A higher value corresponds to more surprising or less predictable text,
540    ///   while a lower value corresponds to less surprising or more predictable text.
541    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
542    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
543    ///   updated more quickly, while a smaller learning rate will result in slower updates.
544    /// - ``m``: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary
545    ///   value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`.
546    ///   In the paper, they use `m = 100`, but you can experiment with different values to see how
547    ///   it affects the performance of the algorithm.
548    #[must_use]
549    pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self {
550        let sampler =
551            unsafe { llama_cpp_sys_2::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m) };
552        Self { sampler }
553    }
554
555    /// Mirostat 2.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
556    ///
557    /// # Parameters:
558    /// - ``seed``: Seed to initialize random generation with.
559    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
560    ///   generated text. A higher value corresponds to more surprising or less predictable text,
561    ///   while a lower value corresponds to less surprising or more predictable text.
562    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
563    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
564    ///   updated more quickly, while a smaller learning rate will result in slower updates.
565    #[must_use]
566    pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self {
567        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_mirostat_v2(seed, tau, eta) };
568        Self { sampler }
569    }
570
571    /// Selects a token at random based on each token's probabilities
572    #[must_use]
573    pub fn dist(seed: u32) -> Self {
574        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_dist(seed) };
575        Self { sampler }
576    }
577
578    /// Selects the most likely token
579    ///
580    /// # Example:
581    /// ```rust
582    /// use llama_cpp_2::token::{
583    ///    LlamaToken,
584    ///    data::LlamaTokenData,
585    ///    data_array::LlamaTokenDataArray
586    /// };
587    /// use llama_cpp_2::sampling::LlamaSampler;
588    ///
589    /// let mut data_array = LlamaTokenDataArray::new(vec![
590    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
591    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
592    /// ], false);
593    ///
594    /// data_array.apply_sampler(&mut LlamaSampler::greedy());
595    ///
596    /// assert_eq!(data_array.data.len(), 2);
597    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(1)));
598    /// ```
599    #[must_use]
600    pub fn greedy() -> Self {
601        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_greedy() };
602        Self { sampler }
603    }
604
605    /// Creates a sampler that applies bias values to specific tokens during sampling.
606    ///
607    /// # Parameters
608    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
609    /// - ``biases``: Slice of [`LlamaLogitBias`] values specifying token-bias pairs
610    ///
611    /// # Example
612    /// ```rust
613    /// use llama_cpp_2::token::{LlamaToken, logit_bias::LlamaLogitBias};
614    /// use llama_cpp_2::sampling::LlamaSampler;
615    ///
616    /// let biases = vec![
617    ///     LlamaLogitBias::new(LlamaToken(1), 1.5),  // Increase probability of token 1
618    ///     LlamaLogitBias::new(LlamaToken(2), -1.0), // Decrease probability of token 2
619    /// ];
620    ///
621    /// // Assuming vocab_size of 32000
622    /// let sampler = LlamaSampler::logit_bias(32000, &biases);
623    /// ```
624    #[must_use]
625    pub fn logit_bias(n_vocab: i32, biases: &[LlamaLogitBias]) -> Self {
626        let data = biases.as_ptr().cast::<llama_cpp_sys_2::llama_logit_bias>();
627
628        let sampler = unsafe {
629            llama_cpp_sys_2::llama_sampler_init_logit_bias(n_vocab, biases.len() as i32, data)
630        };
631
632        Self { sampler }
633    }
634}
635
636impl Drop for LlamaSampler {
637    fn drop(&mut self) {
638        unsafe {
639            llama_cpp_sys_2::llama_sampler_free(self.sampler);
640        }
641    }
642}