Skip to main content

llama_cpp_bindings/
sampling.rs

1//! Safe wrapper around `llama_sampler`.
2
3use std::borrow::Borrow;
4use std::ffi::{CString, c_char};
5use std::fmt::{Debug, Formatter};
6
7use crate::context::LlamaContext;
8use crate::ffi_error_reader::read_and_free_cpp_error;
9use crate::model::LlamaModel;
10use crate::token::LlamaToken;
11use crate::token::data_array::LlamaTokenDataArray;
12use crate::token::logit_bias::LlamaLogitBias;
13use crate::{GrammarError, SampleError, SamplerAcceptError, SamplingError};
14
15fn check_sampler_accept_status(
16    status: llama_cpp_bindings_sys::llama_rs_sampler_accept_status,
17    error_ptr: *mut c_char,
18) -> Result<(), SamplerAcceptError> {
19    match status {
20        llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_ACCEPT_OK => Ok(()),
21        llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_ACCEPT_ERROR_STRING_ALLOCATION_FAILED => {
22            Err(SamplerAcceptError::NotEnoughMemory)
23        }
24        llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_ACCEPT_VENDORED_THREW_CXX_EXCEPTION => {
25            let message = unsafe { read_and_free_cpp_error(error_ptr) };
26            Err(SamplerAcceptError::GrammarStateCorrupted { message })
27        }
28        other => unreachable!("llama_rs_sampler_accept returned unrecognized status {other}"),
29    }
30}
31
32fn checked_u32_as_i32(value: u32) -> Result<i32, GrammarError> {
33    i32::try_from(value).map_err(|convert_error| {
34        GrammarError::IntegerOverflow(format!("value exceeds i32::MAX: {convert_error}"))
35    })
36}
37
38fn checked_usize_as_i32_sampling(value: usize) -> Result<i32, SamplingError> {
39    i32::try_from(value).map_err(|convert_error| {
40        SamplingError::IntegerOverflow(format!("value exceeds i32::MAX: {convert_error}"))
41    })
42}
43
44/// A safe wrapper around `llama_sampler`.
45pub struct LlamaSampler {
46    /// Raw pointer to the underlying `llama_sampler`.
47    pub sampler: *mut llama_cpp_bindings_sys::llama_sampler,
48}
49
50impl Debug for LlamaSampler {
51    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
52        f.debug_struct("LlamaSamplerChain").finish()
53    }
54}
55
56impl LlamaSampler {
57    /// Sample and accept a token from the idx-th output of the last evaluation.
58    ///
59    /// # Errors
60    ///
61    /// Returns [`SampleError`] if the C++ sampler throws an exception or if the index is invalid.
62    pub fn sample(&mut self, ctx: &LlamaContext, idx: i32) -> Result<LlamaToken, SampleError> {
63        let mut token: i32 = -1;
64        let mut error_ptr: *mut c_char = std::ptr::null_mut();
65
66        let status = unsafe {
67            llama_cpp_bindings_sys::llama_rs_sampler_sample(
68                self.sampler,
69                ctx.context.as_ptr(),
70                idx,
71                &raw mut token,
72                &raw mut error_ptr,
73            )
74        };
75
76        match status {
77            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_SAMPLE_OK => Ok(LlamaToken(token)),
78            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_SAMPLE_ERROR_STRING_ALLOCATION_FAILED => {
79                Err(SampleError::NotEnoughMemory)
80            }
81            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_SAMPLE_VENDORED_THREW_CXX_EXCEPTION => {
82                let message = unsafe { read_and_free_cpp_error(error_ptr) };
83                Err(SampleError::Reported { message })
84            }
85            other => unreachable!("llama_rs_sampler_sample returned unrecognized status {other}"),
86        }
87    }
88
89    /// Applies this sampler to a [`LlamaTokenDataArray`].
90    pub fn apply(&self, data_array: &mut LlamaTokenDataArray) {
91        data_array.apply_sampler(self);
92    }
93
94    /// Accepts a token from the sampler, possibly updating the internal state of certain samplers
95    /// (e.g. grammar, repetition, etc.)
96    ///
97    /// # Errors
98    /// Returns [`SamplerAcceptError`] if the underlying sampler rejects the token.
99    pub fn accept(&mut self, token: LlamaToken) -> Result<(), SamplerAcceptError> {
100        self.try_accept(token)
101    }
102
103    /// Accepts several tokens from the sampler or context, possibly updating the internal state of
104    /// certain samplers (e.g. grammar, repetition, etc.)
105    ///
106    /// # Errors
107    /// Returns [`SamplerAcceptError`] if the underlying sampler rejects any token.
108    pub fn accept_many(
109        &mut self,
110        tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
111    ) -> Result<(), SamplerAcceptError> {
112        for token in tokens {
113            self.try_accept(*token.borrow())?;
114        }
115
116        Ok(())
117    }
118
119    /// Accepts several tokens from the sampler or context, possibly updating the internal state of
120    /// certain samplers (e.g. grammar, repetition, etc.)
121    ///
122    /// # Errors
123    /// Returns [`SamplerAcceptError`] if the underlying sampler rejects any token.
124    pub fn with_tokens(
125        mut self,
126        tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
127    ) -> Result<Self, SamplerAcceptError> {
128        self.accept_many(tokens)?;
129
130        Ok(self)
131    }
132
133    /// Try accepting a token from the sampler. Returns an error if the sampler throws.
134    ///
135    /// # Errors
136    /// Returns an error if the underlying sampler rejects the token.
137    pub fn try_accept(&mut self, token: LlamaToken) -> Result<(), SamplerAcceptError> {
138        let mut error_ptr: *mut c_char = std::ptr::null_mut();
139
140        let status = unsafe {
141            llama_cpp_bindings_sys::llama_rs_sampler_accept(
142                self.sampler,
143                token.0,
144                &raw mut error_ptr,
145            )
146        };
147
148        check_sampler_accept_status(status, error_ptr)
149    }
150
151    /// Resets the internal state of the sampler.
152    ///
153    /// This can be useful when you want to start fresh with a sampler without creating a new instance.
154    pub fn reset(&mut self) {
155        unsafe {
156            llama_cpp_bindings_sys::llama_sampler_reset(self.sampler);
157        }
158    }
159
160    /// Gets the random seed used by this sampler.
161    ///
162    /// Returns:
163    /// - For random samplers (dist, mirostat, `mirostat_v2)`: returns their current seed
164    /// - For sampler chains: returns the first non-default seed found in reverse order
165    /// - For all other samplers: returns 0xFFFFFFFF
166    #[must_use]
167    pub fn get_seed(&self) -> u32 {
168        unsafe { llama_cpp_bindings_sys::llama_sampler_get_seed(self.sampler) }
169    }
170
171    /// Combines a list of samplers into a single sampler that applies each component sampler one
172    /// after another.
173    ///
174    /// If you are using a chain to select a token, the chain should always end with one of
175    /// [`LlamaSampler::greedy`], [`LlamaSampler::dist`], [`LlamaSampler::mirostat`], and
176    /// [`LlamaSampler::mirostat_v2`].
177    #[must_use]
178    pub fn chain(samplers: impl IntoIterator<Item = Self>, no_perf: bool) -> Self {
179        unsafe {
180            let chain = llama_cpp_bindings_sys::llama_sampler_chain_init(
181                llama_cpp_bindings_sys::llama_sampler_chain_params { no_perf },
182            );
183
184            for sampler in samplers {
185                llama_cpp_bindings_sys::llama_sampler_chain_add(chain, sampler.sampler);
186                std::mem::forget(sampler);
187            }
188
189            Self { sampler: chain }
190        }
191    }
192
193    /// Same as [`Self::chain`] with `no_perf = false`.
194    ///
195    /// # Example
196    /// ```rust
197    /// use llama_cpp_bindings::token::{
198    ///    LlamaToken,
199    ///    data::LlamaTokenData,
200    ///    data_array::LlamaTokenDataArray
201    /// };
202    /// use llama_cpp_bindings::sampling::LlamaSampler;
203    /// use llama_cpp_bindings::llama_backend::LlamaBackend;
204    /// let backend = LlamaBackend::init().unwrap();
205    ///
206    /// let mut data_array = LlamaTokenDataArray::new(vec![
207    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
208    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
209    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
210    /// ], false);
211    ///
212    /// data_array.apply_sampler(&mut LlamaSampler::chain_simple([
213    ///     LlamaSampler::temp(0.5),
214    ///     LlamaSampler::greedy(),
215    /// ]));
216    ///
217    /// assert_eq!(data_array.data[0].logit(), 0.);
218    /// assert_eq!(data_array.data[1].logit(), 2.);
219    /// assert_eq!(data_array.data[2].logit(), 4.);
220    ///
221    /// assert_eq!(data_array.data.len(), 3);
222    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(2)));
223    /// ```
224    #[must_use]
225    pub fn chain_simple(samplers: impl IntoIterator<Item = Self>) -> Self {
226        Self::chain(samplers, false)
227    }
228
229    /// Updates the logits `l_i' = l_i/t`. When `t <= 0.0`, the maximum logit is kept at its original
230    /// value, the rest are set to -inf
231    ///
232    /// # Example:
233    /// ```rust
234    /// use llama_cpp_bindings::token::{
235    ///    LlamaToken,
236    ///    data::LlamaTokenData,
237    ///    data_array::LlamaTokenDataArray
238    /// };
239    /// use llama_cpp_bindings::sampling::LlamaSampler;
240    ///
241    /// let mut data_array = LlamaTokenDataArray::new(vec![
242    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
243    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
244    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
245    /// ], false);
246    ///
247    /// data_array.apply_sampler(&mut LlamaSampler::temp(0.5));
248    ///
249    /// assert_eq!(data_array.data[0].logit(), 0.);
250    /// assert_eq!(data_array.data[1].logit(), 2.);
251    /// assert_eq!(data_array.data[2].logit(), 4.);
252    /// ```
253    #[must_use]
254    pub fn temp(t: f32) -> Self {
255        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_temp(t) };
256        Self { sampler }
257    }
258
259    /// Dynamic temperature implementation (a.k.a. entropy) described in the paper
260    /// <https://arxiv.org/abs/2309.02772>.
261    #[must_use]
262    pub fn temp_ext(t: f32, delta: f32, exponent: f32) -> Self {
263        let sampler =
264            unsafe { llama_cpp_bindings_sys::llama_sampler_init_temp_ext(t, delta, exponent) };
265        Self { sampler }
266    }
267
268    /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration"
269    /// <https://arxiv.org/abs/1904.09751>
270    ///
271    /// # Example:
272    /// ```rust
273    /// use llama_cpp_bindings::token::{
274    ///    LlamaToken,
275    ///    data::LlamaTokenData,
276    ///    data_array::LlamaTokenDataArray
277    /// };
278    /// use llama_cpp_bindings::sampling::LlamaSampler;
279    ///
280    /// let mut data_array = LlamaTokenDataArray::new(vec![
281    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
282    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
283    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
284    ///     LlamaTokenData::new(LlamaToken(3), 3., 0.),
285    /// ], false);
286    ///
287    /// data_array.apply_sampler(&mut LlamaSampler::top_k(2));
288    ///
289    /// assert_eq!(data_array.data.len(), 2);
290    /// assert_eq!(data_array.data[0].id(), LlamaToken(3));
291    /// assert_eq!(data_array.data[1].id(), LlamaToken(2));
292    /// ```
293    #[must_use]
294    pub fn top_k(k: i32) -> Self {
295        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_k(k) };
296        Self { sampler }
297    }
298
299    /// Top-nσ sampling as described in academic paper "Top-nσ: Not All Logits Are You Need"
300    /// <https://arxiv.org/pdf/2411.07641>
301    ///
302    /// This method filters logits by selecting only those within *n* standard deviations of the mean.
303    ///
304    /// # Parameters
305    /// - `n`: Number of standard deviations from the mean to include in sampling
306    ///
307    /// # Example
308    /// ```rust
309    /// use llama_cpp_bindings::sampling::LlamaSampler;
310    /// use llama_cpp_bindings::token::{
311    ///     LlamaToken,
312    ///     data::LlamaTokenData,
313    ///     data_array::LlamaTokenDataArray
314    /// };
315    ///
316    /// let mut data_array = LlamaTokenDataArray::new(vec![
317    ///     LlamaTokenData::new(LlamaToken(0), 0.0, 0.0),
318    ///     LlamaTokenData::new(LlamaToken(1), 1.0, 0.0),
319    ///     LlamaTokenData::new(LlamaToken(2), 2.0, 0.0),
320    /// ], false);
321    ///
322    /// data_array.apply_sampler(&mut LlamaSampler::top_n_sigma(2.0));
323    /// ```
324    #[must_use]
325    pub fn top_n_sigma(n: f32) -> Self {
326        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_n_sigma(n) };
327        Self { sampler }
328    }
329
330    /// Locally Typical Sampling implementation described in the paper <https://arxiv.org/abs/2202.00666>.
331    #[must_use]
332    pub fn typical(p: f32, min_keep: usize) -> Self {
333        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_typical(p, min_keep) };
334        Self { sampler }
335    }
336
337    /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration"
338    /// <https://arxiv.org/abs/1904.09751>
339    #[must_use]
340    pub fn top_p(p: f32, min_keep: usize) -> Self {
341        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_p(p, min_keep) };
342        Self { sampler }
343    }
344
345    /// Minimum P sampling as described in <https://github.com/ggerganov/llama.cpp/pull/3841>
346    #[must_use]
347    pub fn min_p(p: f32, min_keep: usize) -> Self {
348        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_min_p(p, min_keep) };
349        Self { sampler }
350    }
351
352    /// XTC sampler as described in <https://github.com/oobabooga/text-generation-webui/pull/6335>
353    #[must_use]
354    pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Self {
355        let sampler =
356            unsafe { llama_cpp_bindings_sys::llama_sampler_init_xtc(p, t, min_keep, seed) };
357        Self { sampler }
358    }
359
360    /// Grammar sampler
361    ///
362    /// # Errors
363    /// Returns an error if the grammar is invalid or the sampler cannot be initialized.
364    pub fn grammar(
365        model: &LlamaModel,
366        grammar_str: &str,
367        grammar_root: &str,
368    ) -> Result<Self, GrammarError> {
369        let (grammar_str, grammar_root) =
370            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
371        let mut sampler: *mut llama_cpp_bindings_sys::llama_sampler = std::ptr::null_mut();
372        let mut error_ptr: *mut c_char = std::ptr::null_mut();
373
374        let status = unsafe {
375            llama_cpp_bindings_sys::llama_rs_sampler_init_grammar(
376                model.vocab_ptr(),
377                grammar_str.as_ptr(),
378                grammar_root.as_ptr(),
379                &raw mut sampler,
380                &raw mut error_ptr,
381            )
382        };
383
384        match status {
385            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_OK => {
386                Ok(Self { sampler })
387            }
388            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_VENDORED_RETURNED_NULL => {
389                Err(GrammarError::GrammarMalformed)
390            }
391            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_ERROR_STRING_ALLOCATION_FAILED => {
392                Err(GrammarError::NotEnoughMemory)
393            }
394            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_VENDORED_THREW_CXX_EXCEPTION => {
395                let message = unsafe { read_and_free_cpp_error(error_ptr) };
396                Err(GrammarError::Reported { message })
397            }
398            other => unreachable!(
399                "llama_rs_sampler_init_grammar returned unrecognized status {other}"
400            ),
401        }
402    }
403
404    /// Lazy grammar sampler, introduced in <https://github.com/ggerganov/llama.cpp/pull/9639>
405    ///
406    /// This sampler enforces grammar rules only when specific trigger words or tokens are encountered.
407    ///
408    /// # Errors
409    /// Returns an error if the grammar or trigger words are invalid.
410    pub fn grammar_lazy(
411        model: &LlamaModel,
412        grammar_str: &str,
413        grammar_root: &str,
414        trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
415        trigger_tokens: &[LlamaToken],
416    ) -> Result<Self, GrammarError> {
417        let (grammar_str, grammar_root) =
418            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
419        let trigger_words = Self::sanitize_trigger_words(trigger_words)?;
420        let mut sampler: *mut llama_cpp_bindings_sys::llama_sampler = std::ptr::null_mut();
421        let mut error_ptr: *mut c_char = std::ptr::null_mut();
422
423        let mut trigger_word_ptrs: Vec<*const c_char> =
424            trigger_words.iter().map(|cs| cs.as_ptr()).collect();
425
426        let status = unsafe {
427            llama_cpp_bindings_sys::llama_rs_sampler_init_grammar_lazy(
428                model.vocab_ptr(),
429                grammar_str.as_ptr(),
430                grammar_root.as_ptr(),
431                trigger_word_ptrs.as_mut_ptr(),
432                trigger_word_ptrs.len(),
433                trigger_tokens.as_ptr().cast(),
434                trigger_tokens.len(),
435                &raw mut sampler,
436                &raw mut error_ptr,
437            )
438        };
439
440        match status {
441            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_OK => {
442                Ok(Self { sampler })
443            }
444            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_VENDORED_RETURNED_NULL => {
445                Err(GrammarError::LazyGrammarMalformed)
446            }
447            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_ERROR_STRING_ALLOCATION_FAILED => {
448                Err(GrammarError::NotEnoughMemory)
449            }
450            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_VENDORED_THREW_CXX_EXCEPTION => {
451                let message = unsafe { read_and_free_cpp_error(error_ptr) };
452                Err(GrammarError::Reported { message })
453            }
454            other => unreachable!(
455                "llama_rs_sampler_init_grammar_lazy returned unrecognized status {other}"
456            ),
457        }
458    }
459
460    /// Lazy grammar sampler using regex trigger patterns.
461    ///
462    /// Trigger patterns are regular expressions matched from the start of the
463    /// generation output. The grammar sampler will be fed content starting from
464    /// the first match group.
465    ///
466    /// # Errors
467    /// Returns an error if the grammar or trigger patterns are invalid.
468    pub fn grammar_lazy_patterns(
469        model: &LlamaModel,
470        grammar_str: &str,
471        grammar_root: &str,
472        trigger_patterns: &[String],
473        trigger_tokens: &[LlamaToken],
474    ) -> Result<Self, GrammarError> {
475        let (grammar_str, grammar_root) =
476            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
477        let trigger_patterns = Self::sanitize_trigger_patterns(trigger_patterns)?;
478        let mut sampler: *mut llama_cpp_bindings_sys::llama_sampler = std::ptr::null_mut();
479        let mut error_ptr: *mut c_char = std::ptr::null_mut();
480
481        let mut trigger_pattern_ptrs: Vec<*const c_char> =
482            trigger_patterns.iter().map(|cs| cs.as_ptr()).collect();
483
484        let status = unsafe {
485            llama_cpp_bindings_sys::llama_rs_sampler_init_grammar_lazy_patterns(
486                model.vocab_ptr(),
487                grammar_str.as_ptr(),
488                grammar_root.as_ptr(),
489                trigger_pattern_ptrs.as_mut_ptr(),
490                trigger_pattern_ptrs.len(),
491                trigger_tokens.as_ptr().cast(),
492                trigger_tokens.len(),
493                &raw mut sampler,
494                &raw mut error_ptr,
495            )
496        };
497
498        match status {
499            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_PATTERNS_OK => {
500                Ok(Self { sampler })
501            }
502            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_PATTERNS_VENDORED_RETURNED_NULL => {
503                Err(GrammarError::LazyPatternsGrammarMalformed)
504            }
505            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_PATTERNS_ERROR_STRING_ALLOCATION_FAILED => {
506                Err(GrammarError::NotEnoughMemory)
507            }
508            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_PATTERNS_INVALID_TRIGGER_PATTERN => {
509                let message = unsafe { read_and_free_cpp_error(error_ptr) };
510                Err(GrammarError::InvalidTriggerPattern { message })
511            }
512            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_PATTERNS_VENDORED_THREW_CXX_EXCEPTION => {
513                let message = unsafe { read_and_free_cpp_error(error_ptr) };
514                Err(GrammarError::Reported { message })
515            }
516            other => unreachable!(
517                "llama_rs_sampler_init_grammar_lazy_patterns returned unrecognized status {other}"
518            ),
519        }
520    }
521
522    /// `LLGuidance` sampler for constrained decoding.
523    ///
524    /// Uses the `llguidance` and `toktrie` Rust crates to enforce grammar constraints
525    /// during token sampling. Supports JSON schema, regex, Lark, and other grammar types.
526    ///
527    /// # Errors
528    ///
529    /// Returns [`GrammarError`] if the grammar is invalid or the sampler cannot be initialized.
530    pub fn llguidance(
531        model: &LlamaModel,
532        grammar_kind: &str,
533        grammar_data: &str,
534    ) -> Result<Self, GrammarError> {
535        crate::llguidance_sampler::create_llg_sampler(model, grammar_kind, grammar_data)
536    }
537
538    fn sanitize_grammar_strings(
539        grammar_str: &str,
540        grammar_root: &str,
541    ) -> Result<(CString, CString), GrammarError> {
542        if !grammar_str.contains(grammar_root) {
543            return Err(GrammarError::RootNotFound);
544        }
545
546        let grammar = CString::new(grammar_str).map_err(GrammarError::GrammarNullBytes)?;
547        let root = CString::new(grammar_root).map_err(GrammarError::GrammarNullBytes)?;
548
549        Ok((grammar, root))
550    }
551
552    fn sanitize_trigger_words(
553        trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
554    ) -> Result<Vec<CString>, GrammarError> {
555        trigger_words
556            .into_iter()
557            .map(|word| CString::new(word.as_ref()).map_err(GrammarError::TriggerWordNullBytes))
558            .collect()
559    }
560
561    fn sanitize_trigger_patterns(
562        trigger_patterns: &[String],
563    ) -> Result<Vec<CString>, GrammarError> {
564        trigger_patterns
565            .iter()
566            .map(|pattern| CString::new(pattern.as_str()).map_err(GrammarError::GrammarNullBytes))
567            .collect()
568    }
569
570    /// DRY sampler, designed by p-e-w, as described in:
571    /// <https://github.com/oobabooga/text-generation-webui/pull/5677>, porting Koboldcpp
572    /// implementation authored by pi6am: <https://github.com/LostRuins/koboldcpp/pull/982>
573    ///
574    /// # Errors
575    /// Returns an error if any string in `seq_breakers` contains null bytes.
576    pub fn dry(
577        model: &LlamaModel,
578        multiplier: f32,
579        base: f32,
580        allowed_length: i32,
581        penalty_last_n: i32,
582        seq_breakers: impl IntoIterator<Item = impl AsRef<[u8]>>,
583    ) -> Result<Self, GrammarError> {
584        let seq_breakers: Vec<CString> = seq_breakers
585            .into_iter()
586            .map(|seq_breaker| CString::new(seq_breaker.as_ref()))
587            .collect::<Result<Vec<_>, _>>()?;
588        let mut seq_breaker_pointers: Vec<*const c_char> = seq_breakers
589            .iter()
590            .map(|seq_breaker| seq_breaker.as_ptr())
591            .collect();
592
593        let n_ctx_train_value = model.n_ctx_train().map_err(|convert_error| {
594            GrammarError::IntegerOverflow(format!(
595                "n_ctx_train does not fit into u32: {convert_error}"
596            ))
597        })?;
598        let n_ctx_train = checked_u32_as_i32(n_ctx_train_value)?;
599        let sampler = unsafe {
600            llama_cpp_bindings_sys::llama_sampler_init_dry(
601                model.vocab_ptr(),
602                n_ctx_train,
603                multiplier,
604                base,
605                allowed_length,
606                penalty_last_n,
607                seq_breaker_pointers.as_mut_ptr(),
608                seq_breaker_pointers.len(),
609            )
610        };
611
612        Ok(Self { sampler })
613    }
614
615    /// Penalizes tokens for being present in the context.
616    ///
617    /// Parameters:
618    /// - ``penalty_last_n``: last n tokens to penalize (0 = disable penalty, -1 = context size)
619    /// - ``penalty_repeat``: 1.0 = disabled
620    /// - ``penalty_freq``: 0.0 = disabled
621    /// - ``penalty_present``: 0.0 = disabled
622    #[must_use]
623    pub fn penalties(
624        penalty_last_n: i32,
625        penalty_repeat: f32,
626        penalty_freq: f32,
627        penalty_present: f32,
628    ) -> Self {
629        let sampler = unsafe {
630            llama_cpp_bindings_sys::llama_sampler_init_penalties(
631                penalty_last_n,
632                penalty_repeat,
633                penalty_freq,
634                penalty_present,
635            )
636        };
637        Self { sampler }
638    }
639
640    /// Mirostat 1.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
641    ///
642    /// # Parameters:
643    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
644    /// - ``seed``: Seed to initialize random generation with.
645    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
646    ///   generated text. A higher value corresponds to more surprising or less predictable text,
647    ///   while a lower value corresponds to less surprising or more predictable text.
648    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
649    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
650    ///   updated more quickly, while a smaller learning rate will result in slower updates.
651    /// - ``m``: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary
652    ///   value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`.
653    ///   In the paper, they use `m = 100`, but you can experiment with different values to see how
654    ///   it affects the performance of the algorithm.
655    #[must_use]
656    pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self {
657        let sampler = unsafe {
658            llama_cpp_bindings_sys::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m)
659        };
660        Self { sampler }
661    }
662
663    /// Mirostat 2.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
664    ///
665    /// # Parameters:
666    /// - ``seed``: Seed to initialize random generation with.
667    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
668    ///   generated text. A higher value corresponds to more surprising or less predictable text,
669    ///   while a lower value corresponds to less surprising or more predictable text.
670    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
671    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
672    ///   updated more quickly, while a smaller learning rate will result in slower updates.
673    #[must_use]
674    pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self {
675        let sampler =
676            unsafe { llama_cpp_bindings_sys::llama_sampler_init_mirostat_v2(seed, tau, eta) };
677        Self { sampler }
678    }
679
680    /// Selects a token at random based on each token's probabilities
681    #[must_use]
682    pub fn dist(seed: u32) -> Self {
683        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_dist(seed) };
684        Self { sampler }
685    }
686
687    /// Selects the most likely token
688    ///
689    /// # Example:
690    /// ```rust
691    /// use llama_cpp_bindings::token::{
692    ///    LlamaToken,
693    ///    data::LlamaTokenData,
694    ///    data_array::LlamaTokenDataArray
695    /// };
696    /// use llama_cpp_bindings::sampling::LlamaSampler;
697    ///
698    /// let mut data_array = LlamaTokenDataArray::new(vec![
699    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
700    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
701    /// ], false);
702    ///
703    /// data_array.apply_sampler(&mut LlamaSampler::greedy());
704    ///
705    /// assert_eq!(data_array.data.len(), 2);
706    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(1)));
707    /// ```
708    #[must_use]
709    pub fn greedy() -> Self {
710        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_greedy() };
711        Self { sampler }
712    }
713
714    /// Creates a sampler that applies bias values to specific tokens during sampling.
715    ///
716    /// # Parameters
717    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
718    /// - ``biases``: Slice of [`LlamaLogitBias`] values specifying token-bias pairs
719    ///
720    /// # Errors
721    /// Returns [`SamplingError::IntegerOverflow`] if `biases.len()` exceeds `i32::MAX`.
722    ///
723    /// # Example
724    /// ```rust
725    /// use llama_cpp_bindings::token::{LlamaToken, logit_bias::LlamaLogitBias};
726    /// use llama_cpp_bindings::sampling::LlamaSampler;
727    ///
728    /// let biases = vec![
729    ///     LlamaLogitBias::new(LlamaToken(1), 1.5),  // Increase probability of token 1
730    ///     LlamaLogitBias::new(LlamaToken(2), -1.0), // Decrease probability of token 2
731    /// ];
732    ///
733    /// // Assuming vocab_size of 32000
734    /// let sampler = LlamaSampler::logit_bias(32000, &biases).unwrap();
735    /// ```
736    pub fn logit_bias(n_vocab: i32, biases: &[LlamaLogitBias]) -> Result<Self, SamplingError> {
737        let bias_count = checked_usize_as_i32_sampling(biases.len())?;
738        let data = biases
739            .as_ptr()
740            .cast::<llama_cpp_bindings_sys::llama_logit_bias>();
741
742        let sampler = unsafe {
743            llama_cpp_bindings_sys::llama_sampler_init_logit_bias(n_vocab, bias_count, data)
744        };
745
746        Ok(Self { sampler })
747    }
748}
749
750impl Drop for LlamaSampler {
751    fn drop(&mut self) {
752        unsafe {
753            llama_cpp_bindings_sys::llama_sampler_free(self.sampler);
754        }
755    }
756}
757
#[cfg(test)]
mod tests {
    use super::LlamaSampler;
    use crate::GrammarError;

    #[test]
    fn sanitize_grammar_strings_valid() {
        let result = LlamaSampler::sanitize_grammar_strings("root ::= \"hello\"", "root");

        assert!(result.is_ok());
    }

    #[test]
    fn sanitize_grammar_strings_root_not_found() {
        let result = LlamaSampler::sanitize_grammar_strings("expr ::= \"hello\"", "root");

        assert!(matches!(result.err(), Some(GrammarError::RootNotFound)));
    }

    #[test]
    fn sanitize_grammar_strings_null_byte_in_grammar() {
        let result = LlamaSampler::sanitize_grammar_strings("root ::= \"\0\"", "root");

        assert!(matches!(
            result.err(),
            Some(GrammarError::GrammarNullBytes(_))
        ));
    }

    #[test]
    fn sanitize_grammar_strings_null_byte_in_root() {
        let result = LlamaSampler::sanitize_grammar_strings("ro\0ot ::= \"hello\"", "ro\0ot");

        assert!(matches!(
            result.err(),
            Some(GrammarError::GrammarNullBytes(_))
        ));
    }

    #[test]
    fn sanitize_trigger_words_valid() {
        let words: Vec<&[u8]> = vec![b"hello", b"world"];
        let result = LlamaSampler::sanitize_trigger_words(words);

        assert!(result.is_ok());
        assert_eq!(result.expect("valid trigger words").len(), 2);
    }

    #[test]
    fn sanitize_trigger_words_empty_list() {
        let words: Vec<&[u8]> = vec![];
        let result = LlamaSampler::sanitize_trigger_words(words);

        assert!(result.is_ok());
        assert!(result.expect("valid trigger words").is_empty());
    }

    #[test]
    fn sanitize_trigger_words_null_byte() {
        let words: Vec<&[u8]> = vec![b"hel\0lo"];
        let result = LlamaSampler::sanitize_trigger_words(words);

        assert!(matches!(
            result.err(),
            Some(GrammarError::TriggerWordNullBytes(_))
        ));
    }

    #[test]
    fn sanitize_trigger_patterns_valid() {
        let patterns = vec!["^hello$".to_string(), "world.*".to_string()];
        let result = LlamaSampler::sanitize_trigger_patterns(&patterns);

        assert!(result.is_ok());
        assert_eq!(result.expect("valid trigger patterns").len(), 2);
    }

    #[test]
    fn sanitize_trigger_patterns_empty_list() {
        let patterns: Vec<String> = vec![];
        let result = LlamaSampler::sanitize_trigger_patterns(&patterns);

        assert!(result.is_ok());
        assert!(result.expect("valid trigger patterns").is_empty());
    }

    #[test]
    fn sanitize_trigger_patterns_null_byte() {
        let patterns = vec!["hel\0lo".to_string()];
        let result = LlamaSampler::sanitize_trigger_patterns(&patterns);

        assert!(matches!(
            result.err(),
            Some(GrammarError::GrammarNullBytes(_))
        ));
    }

    #[test]
    fn apply_modifies_data_array() {
        use crate::token::LlamaToken;
        use crate::token::data::LlamaTokenData;
        use crate::token::data_array::LlamaTokenDataArray;

        let sampler = LlamaSampler::greedy();
        let mut data_array = LlamaTokenDataArray::new(
            vec![
                LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
                LlamaTokenData::new(LlamaToken::new(1), 5.0, 0.0),
            ],
            false,
        );

        sampler.apply(&mut data_array);

        assert_eq!(data_array.selected_token(), Some(LlamaToken::new(1)));
    }

    #[test]
    fn accept_succeeds() {
        let mut sampler = LlamaSampler::chain_simple([
            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
            LlamaSampler::greedy(),
        ]);

        sampler
            .accept(crate::token::LlamaToken::new(1))
            .expect("test: accept should succeed");
    }

    #[test]
    fn try_accept_succeeds_on_penalties_sampler() {
        let mut sampler = LlamaSampler::chain_simple([
            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
            LlamaSampler::greedy(),
        ]);

        let result = sampler.try_accept(crate::token::LlamaToken::new(42));

        assert!(result.is_ok());
    }

    #[test]
    fn accept_many_multiple_tokens() {
        use crate::token::LlamaToken;

        let mut sampler = LlamaSampler::chain_simple([
            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
            LlamaSampler::greedy(),
        ]);

        sampler
            .accept_many([LlamaToken::new(1), LlamaToken::new(2), LlamaToken::new(3)])
            .expect("test: accept_many should succeed");
    }

    #[test]
    fn with_tokens_builder_pattern() {
        use crate::token::LlamaToken;

        let _sampler = LlamaSampler::chain_simple([
            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
            LlamaSampler::greedy(),
        ])
        .with_tokens([LlamaToken::new(10), LlamaToken::new(20)])
        .expect("test: with_tokens should succeed");
    }

    #[test]
    fn all_sampler_constructors() {
        use crate::token::LlamaToken;
        use crate::token::logit_bias::LlamaLogitBias;

        let _temp = LlamaSampler::temp(0.8);
        let _temp_ext = LlamaSampler::temp_ext(0.8, 0.1, 1.0);
        let _top_k = LlamaSampler::top_k(40);
        let _top_n_sigma = LlamaSampler::top_n_sigma(2.0);
        let _top_p = LlamaSampler::top_p(0.9, 1);
        let _min_p = LlamaSampler::min_p(0.05, 1);
        let _typical = LlamaSampler::typical(0.9, 1);
        let _xtc = LlamaSampler::xtc(0.1, 0.5, 1, 42);
        let _dist = LlamaSampler::dist(42);
        let _mirostat = LlamaSampler::mirostat(32000, 42, 5.0, 0.1, 100);
        let _mirostat_v2 = LlamaSampler::mirostat_v2(42, 5.0, 0.1);
        // Only borrowed, so a fixed-size array is enough (avoids clippy::useless_vec).
        let biases = [LlamaLogitBias::new(LlamaToken::new(0), -100.0)];
        // `logit_bias` is fallible; assert success instead of silently dropping the Result.
        let _logit_bias = LlamaSampler::logit_bias(32000, &biases)
            .expect("test: logit_bias should succeed");
        let _chain = LlamaSampler::chain([LlamaSampler::greedy()], true);
    }

    #[test]
    fn logit_bias_accepts_empty_biases() {
        let result = LlamaSampler::logit_bias(32000, &[]);

        assert!(result.is_ok());
    }

    #[test]
    fn reset_and_get_seed() {
        let mut sampler = LlamaSampler::dist(42);
        sampler.reset();
        let _seed = sampler.get_seed();
    }

    #[test]
    fn debug_formatting() {
        let sampler = LlamaSampler::greedy();
        let debug_output = format!("{sampler:?}");
        assert!(debug_output.contains("LlamaSampler"));
    }

    #[test]
    fn checked_u32_as_i32_overflow() {
        let result = super::checked_u32_as_i32(u32::MAX);
        assert!(result.is_err());
    }

    #[test]
    fn checked_u32_as_i32_accepts_i32_max() {
        let result = super::checked_u32_as_i32(0x7FFF_FFFF_u32);
        assert_eq!(result.ok(), Some(i32::MAX));
    }

    #[test]
    fn checked_usize_as_i32_sampling_overflow() {
        let result = super::checked_usize_as_i32_sampling(usize::MAX);
        assert!(result.is_err());
    }

    #[test]
    fn checked_usize_as_i32_sampling_accepts_zero() {
        let result = super::checked_usize_as_i32_sampling(0);
        assert_eq!(result.ok(), Some(0));
    }

    #[test]
    fn check_sampler_accept_status_ok() {
        let result = super::check_sampler_accept_status(
            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_ACCEPT_OK,
            std::ptr::null_mut(),
        );

        assert!(result.is_ok());
    }

    #[test]
    fn check_sampler_accept_status_allocation_failure_maps_to_not_enough_memory() {
        let result = super::check_sampler_accept_status(
            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_ACCEPT_ERROR_STRING_ALLOCATION_FAILED,
            std::ptr::null_mut(),
        );

        assert!(matches!(
            result,
            Err(crate::SamplerAcceptError::NotEnoughMemory)
        ));
    }

    #[test]
    fn check_sampler_accept_status_exception_maps_to_typed_variant() {
        let result = super::check_sampler_accept_status(
            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_ACCEPT_VENDORED_THREW_CXX_EXCEPTION,
            std::ptr::null_mut(),
        );

        assert!(matches!(
            result,
            Err(crate::SamplerAcceptError::GrammarStateCorrupted { .. })
        ));
    }
}