llama_cpp_2/sampling.rs

//! Safe wrapper around `llama_sampler`.

use std::borrow::Borrow;
use std::ffi::{c_char, CString};
use std::fmt::{Debug, Formatter};

use crate::context::LlamaContext;
use crate::model::LlamaModel;
use crate::token::data_array::LlamaTokenDataArray;
use crate::token::logit_bias::LlamaLogitBias;
use crate::token::LlamaToken;
use crate::GrammarError;

/// A safe wrapper around `llama_sampler`.
pub struct LlamaSampler {
    pub(crate) sampler: *mut llama_cpp_sys_2::llama_sampler,
}

impl Debug for LlamaSampler {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LlamaSampler").finish()
    }
}

impl LlamaSampler {
    /// Samples and accepts a token from the `idx`-th output of the last evaluation.
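    ///
    /// # Example
    ///
    /// A minimal sketch (`no_run`: it needs a loaded model and a decoded batch); passing
    /// `-1` as `idx` refers to the last output in the batch:
    /// ```no_run
    /// use llama_cpp_2::context::LlamaContext;
    /// use llama_cpp_2::sampling::LlamaSampler;
    /// use llama_cpp_2::token::LlamaToken;
    ///
    /// fn next_token(sampler: &mut LlamaSampler, ctx: &LlamaContext) -> LlamaToken {
    ///     // `sample` also accepts the chosen token into the sampler's internal state.
    ///     sampler.sample(ctx, -1)
    /// }
    /// ```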
    #[must_use]
    pub fn sample(&mut self, ctx: &LlamaContext, idx: i32) -> LlamaToken {
        let token = unsafe {
            llama_cpp_sys_2::llama_sampler_sample(self.sampler, ctx.context.as_ptr(), idx)
        };

        LlamaToken(token)
    }

    /// Applies this sampler to a [`LlamaTokenDataArray`].
    pub fn apply(&self, data_array: &mut LlamaTokenDataArray) {
        data_array.apply_sampler(self);
    }

    /// Accepts a token into the sampler, possibly updating the internal state of certain
    /// samplers (e.g. grammar, repetition, etc.)
    pub fn accept(&mut self, token: LlamaToken) {
        unsafe { llama_cpp_sys_2::llama_sampler_accept(self.sampler, token.0) }
    }

    /// Accepts several tokens into the sampler, possibly updating the internal state of
    /// certain samplers (e.g. grammar, repetition, etc.)
    pub fn accept_many(&mut self, tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>) {
        for token in tokens {
            unsafe { llama_cpp_sys_2::llama_sampler_accept(self.sampler, token.borrow().0) }
        }
    }

    /// Consumes the sampler, accepts the given tokens as in [`Self::accept_many`], and
    /// returns it. Useful for builder-style setup of stateful samplers.
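    ///
    /// # Example
    ///
    /// A minimal sketch; the token IDs and penalty values are arbitrary:
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    /// use llama_cpp_2::token::LlamaToken;
    ///
    /// // Seed a repetition-penalty sampler with some recent history.
    /// let sampler = LlamaSampler::penalties(64, 1.1, 0.0, 0.0)
    ///     .with_tokens([LlamaToken(1), LlamaToken(2), LlamaToken(2)]);
    /// ```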
    #[must_use]
    pub fn with_tokens(
        mut self,
        tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
    ) -> Self {
        self.accept_many(tokens);
        self
    }

    /// Resets the internal state of the sampler.
    ///
    /// This can be useful when you want to start fresh with a sampler without creating a new instance.
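    ///
    /// # Example
    ///
    /// A minimal sketch reusing one penalties sampler across two sequences:
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    /// use llama_cpp_2::token::LlamaToken;
    ///
    /// let mut sampler = LlamaSampler::penalties(64, 1.1, 0.0, 0.0);
    /// sampler.accept(LlamaToken(1));
    /// // Clear the accepted history before starting a new sequence.
    /// sampler.reset();
    /// ```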
    pub fn reset(&mut self) {
        unsafe {
            llama_cpp_sys_2::llama_sampler_reset(self.sampler);
        }
    }

    /// Gets the random seed used by this sampler.
    ///
    /// Returns:
    /// - For random samplers (dist, mirostat, mirostat_v2): returns their current seed
    /// - For sampler chains: returns the first non-default seed found in reverse order
    /// - For all other samplers: returns 0xFFFFFFFF
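    ///
    /// # Example
    ///
    /// A minimal sketch; `1234` is an arbitrary non-default seed:
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let sampler = LlamaSampler::dist(1234);
    /// assert_eq!(sampler.get_seed(), 1234);
    /// ```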
    #[must_use]
    pub fn get_seed(&self) -> u32 {
        unsafe { llama_cpp_sys_2::llama_sampler_get_seed(self.sampler) }
    }

    /// Combines a list of samplers into a single sampler that applies each component sampler one
    /// after another.
    ///
    /// If you are using a chain to select a token, the chain should always end with one of
    /// [`LlamaSampler::greedy`], [`LlamaSampler::dist`], [`LlamaSampler::mirostat`], or
    /// [`LlamaSampler::mirostat_v2`].
    #[must_use]
    pub fn chain(samplers: impl IntoIterator<Item = Self>, no_perf: bool) -> Self {
        unsafe {
            let chain = llama_cpp_sys_2::llama_sampler_chain_init(
                llama_cpp_sys_2::llama_sampler_chain_params { no_perf },
            );

            for sampler in samplers {
                llama_cpp_sys_2::llama_sampler_chain_add(chain, sampler.sampler);

                // Do not call `llama_sampler_free` on the sampler, as the internal sampler is now
                // owned by the chain
                std::mem::forget(sampler);
            }

            Self { sampler: chain }
        }
    }

    /// Same as [`Self::chain`] with `no_perf = false`.
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    /// use llama_cpp_2::llama_backend::LlamaBackend;
    /// let backend = LlamaBackend::init().unwrap();
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::chain_simple([
    ///     LlamaSampler::temp(0.5),
    ///     LlamaSampler::greedy(),
    /// ]));
    ///
    /// assert_eq!(data_array.data[0].logit(), 0.);
    /// assert_eq!(data_array.data[1].logit(), 2.);
    /// assert_eq!(data_array.data[2].logit(), 4.);
    ///
    /// assert_eq!(data_array.data.len(), 3);
    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(2)));
    /// ```
    #[must_use]
    pub fn chain_simple(samplers: impl IntoIterator<Item = Self>) -> Self {
        Self::chain(samplers, false)
    }

    #[allow(clippy::doc_markdown)]
    /// Updates the logits l_i' = l_i/t. When t <= 0.0, the maximum logit is kept at its original
    /// value and the rest are set to -inf.
    ///
    /// # Example:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::temp(0.5));
    ///
    /// assert_eq!(data_array.data[0].logit(), 0.);
    /// assert_eq!(data_array.data[1].logit(), 2.);
    /// assert_eq!(data_array.data[2].logit(), 4.);
    /// ```
    #[must_use]
    pub fn temp(t: f32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_temp(t) };
        Self { sampler }
    }

    /// Dynamic temperature implementation (a.k.a. entropy) described in the paper
    /// <https://arxiv.org/abs/2309.02772>.
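    ///
    /// # Example
    ///
    /// A minimal sketch in the style of the other doctests in this module; with
    /// `delta = 0.0` this reduces to plain temperature scaling:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::temp_ext(0.5, 0.0, 1.0));
    /// ```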
    #[must_use]
    pub fn temp_ext(t: f32, delta: f32, exponent: f32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_temp_ext(t, delta, exponent) };
        Self { sampler }
    }

    /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration"
    /// <https://arxiv.org/abs/1904.09751>
    ///
    /// # Example:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    ///     LlamaTokenData::new(LlamaToken(3), 3., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::top_k(2));
    ///
    /// assert_eq!(data_array.data.len(), 2);
    /// assert_eq!(data_array.data[0].id(), LlamaToken(3));
    /// assert_eq!(data_array.data[1].id(), LlamaToken(2));
    /// ```
    #[must_use]
    pub fn top_k(k: i32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_k(k) };
        Self { sampler }
    }

    /// Top-nσ sampling as described in academic paper "Top-nσ: Not All Logits Are You Need"
    /// <https://arxiv.org/pdf/2411.07641>
    ///
    /// This method filters logits by selecting only those within *n* standard deviations of the mean.
    ///
    /// # Parameters
    /// - `n`: Number of standard deviations from the mean to include in sampling
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0.0, 0.0),
    ///     LlamaTokenData::new(LlamaToken(1), 1.0, 0.0),
    ///     LlamaTokenData::new(LlamaToken(2), 2.0, 0.0),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::top_n_sigma(2.0));
    /// ```
    #[must_use]
    pub fn top_n_sigma(n: f32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_n_sigma(n) };
        Self { sampler }
    }

    /// Locally Typical Sampling implementation described in the paper <https://arxiv.org/abs/2202.00666>.
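    ///
    /// # Example
    ///
    /// A minimal sketch; the logits are arbitrary:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// // Keep at least one token; filter the rest by local typicality at p = 0.9.
    /// data_array.apply_sampler(&mut LlamaSampler::typical(0.9, 1));
    /// ```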
    #[must_use]
    pub fn typical(p: f32, min_keep: usize) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_typical(p, min_keep) };
        Self { sampler }
    }

    /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration"
    /// <https://arxiv.org/abs/1904.09751>
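    ///
    /// # Example
    ///
    /// A minimal sketch; the logits are arbitrary:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// // Keep the smallest set of tokens whose cumulative probability reaches p = 0.9,
    /// // but never fewer than min_keep = 1.
    /// data_array.apply_sampler(&mut LlamaSampler::top_p(0.9, 1));
    /// ```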
    #[must_use]
    pub fn top_p(p: f32, min_keep: usize) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_p(p, min_keep) };
        Self { sampler }
    }

    /// Minimum P sampling as described in <https://github.com/ggerganov/llama.cpp/pull/3841>
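    ///
    /// # Example
    ///
    /// A minimal sketch; the logits are arbitrary:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// // Drop tokens whose probability is below 0.1 × the most likely token's probability.
    /// data_array.apply_sampler(&mut LlamaSampler::min_p(0.1, 1));
    /// ```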
    #[must_use]
    pub fn min_p(p: f32, min_keep: usize) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_min_p(p, min_keep) };
        Self { sampler }
    }

    /// XTC sampler as described in <https://github.com/oobabooga/text-generation-webui/pull/6335>
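    ///
    /// # Example
    ///
    /// A minimal sketch; the parameter values are arbitrary:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// // With probability p = 0.5 (seeded with 1234), remove every token above the
    /// // threshold t = 0.1 except the least probable of them.
    /// data_array.apply_sampler(&mut LlamaSampler::xtc(0.5, 0.1, 1, 1234));
    /// ```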
    #[must_use]
    pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_xtc(p, t, min_keep, seed) };
        Self { sampler }
    }

    /// Grammar sampler that constrains output to match a given
    /// [GBNF](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md) grammar.
    ///
    /// # Errors
    /// Returns a [`GrammarError`] if either string contains null bytes, if `grammar_root` does
    /// not appear in `grammar_str`, or if llama.cpp fails to build the grammar.
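    ///
    /// # Example
    ///
    /// A minimal sketch (`no_run`: it needs a loaded model); the GBNF snippet is arbitrary:
    /// ```no_run
    /// use llama_cpp_2::model::LlamaModel;
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// fn yes_no_sampler(model: &LlamaModel) -> LlamaSampler {
    ///     LlamaSampler::grammar(model, r#"root ::= "yes" | "no""#, "root")
    ///         .expect("grammar failed to build")
    /// }
    /// ```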
    #[must_use]
    pub fn grammar(
        model: &LlamaModel,
        grammar_str: &str,
        grammar_root: &str,
    ) -> Result<Self, GrammarError> {
        let (grammar_str, grammar_root) =
            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;

        let sampler = unsafe {
            llama_cpp_sys_2::llama_sampler_init_grammar(
                model.vocab_ptr(),
                grammar_str.as_ptr(),
                grammar_root.as_ptr(),
            )
        };

        if sampler.is_null() {
            Err(GrammarError::NullGrammar)
        } else {
            Ok(Self { sampler })
        }
    }

    /// Lazy grammar sampler, introduced in <https://github.com/ggerganov/llama.cpp/pull/9639>
    ///
    /// This sampler enforces grammar rules only when specific trigger words or tokens are encountered.
    ///
    /// # Errors
    /// Returns a [`GrammarError`] if any input contains null bytes, if `grammar_root` does not
    /// appear in `grammar_str`, or if llama.cpp fails to build the grammar.
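    ///
    /// # Example
    ///
    /// A minimal sketch (`no_run`: it needs a loaded model); the grammar and trigger word
    /// are arbitrary:
    /// ```no_run
    /// use llama_cpp_2::model::LlamaModel;
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// fn lazy_sampler(model: &LlamaModel) -> LlamaSampler {
    ///     // The grammar only kicks in once the model emits the trigger word.
    ///     LlamaSampler::grammar_lazy(model, r#"root ::= "{}""#, "root", ["<tool_call>"], &[])
    ///         .expect("grammar failed to build")
    /// }
    /// ```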
    #[must_use]
    pub fn grammar_lazy(
        model: &LlamaModel,
        grammar_str: &str,
        grammar_root: &str,
        trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
        trigger_tokens: &[LlamaToken],
    ) -> Result<Self, GrammarError> {
        let (grammar_str, grammar_root) =
            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
        let trigger_words = Self::sanitize_trigger_words(trigger_words)?;

        let mut trigger_word_ptrs: Vec<*const c_char> =
            trigger_words.iter().map(|cs| cs.as_ptr()).collect();

        let sampler = unsafe {
            llama_cpp_sys_2::llama_sampler_init_grammar_lazy(
                model.vocab_ptr(),
                grammar_str.as_ptr(),
                grammar_root.as_ptr(),
                trigger_word_ptrs.as_mut_ptr(),
                trigger_word_ptrs.len(),
                trigger_tokens.as_ptr().cast(),
                trigger_tokens.len(),
            )
        };

        if sampler.is_null() {
            Err(GrammarError::NullGrammar)
        } else {
            Ok(Self { sampler })
        }
    }

    fn sanitize_grammar_strings(
        grammar_str: &str,
        grammar_root: &str,
    ) -> Result<(CString, CString), GrammarError> {
        if !grammar_str.contains(grammar_root) {
            return Err(GrammarError::RootNotFound);
        }

        if grammar_str.contains('\0') || grammar_root.contains('\0') {
            return Err(GrammarError::GrammarNullBytes);
        }

        Ok((
            CString::new(grammar_str).unwrap(),
            CString::new(grammar_root).unwrap(),
        ))
    }

    fn sanitize_trigger_words(
        trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
    ) -> Result<Vec<CString>, GrammarError> {
        let trigger_words: Vec<_> = trigger_words.into_iter().collect();
        if trigger_words
            .iter()
            .any(|word| word.as_ref().contains(&b'\0'))
        {
            return Err(GrammarError::TriggerWordNullBytes);
        }
        Ok(trigger_words
            .into_iter()
            .map(|word| CString::new(word.as_ref()).unwrap())
            .collect())
    }

    /// DRY sampler, designed by p-e-w, as described in
    /// <https://github.com/oobabooga/text-generation-webui/pull/5677>, porting the Koboldcpp
    /// implementation authored by pi6am: <https://github.com/LostRuins/koboldcpp/pull/982>
    ///
    /// # Panics
    /// Panics if any string in ``seq_breakers`` contains null bytes, or if the model's
    /// ``n_ctx_train`` exceeds ``i32::MAX``.
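    ///
    /// # Example
    ///
    /// A minimal sketch (`no_run`: it needs a loaded model); the parameter values mirror
    /// commonly used defaults but are otherwise arbitrary:
    /// ```no_run
    /// use llama_cpp_2::model::LlamaModel;
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// fn dry_sampler(model: &LlamaModel) -> LlamaSampler {
    ///     LlamaSampler::dry(model, 0.8, 1.75, 2, -1, ["\n", ":", "\"", "*"])
    /// }
    /// ```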
    #[allow(missing_docs)]
    #[must_use]
    pub fn dry(
        model: &LlamaModel,
        multiplier: f32,
        base: f32,
        allowed_length: i32,
        penalty_last_n: i32,
        seq_breakers: impl IntoIterator<Item = impl AsRef<[u8]>>,
    ) -> Self {
        let seq_breakers: Vec<CString> = seq_breakers
            .into_iter()
            .map(|s| CString::new(s.as_ref()).expect("A sequence breaker contains null bytes"))
            .collect();
        let mut seq_breaker_pointers: Vec<*const c_char> =
            seq_breakers.iter().map(|s| s.as_ptr()).collect();

        let sampler = unsafe {
            llama_cpp_sys_2::llama_sampler_init_dry(
                model.vocab_ptr(),
                model
                    .n_ctx_train()
                    .try_into()
                    .expect("n_ctx_train exceeds i32::MAX"),
                multiplier,
                base,
                allowed_length,
                penalty_last_n,
                seq_breaker_pointers.as_mut_ptr(),
                seq_breaker_pointers.len(),
            )
        };
        Self { sampler }
    }

    /// Penalizes tokens for being present in the context.
    ///
    /// Parameters:
    /// - ``penalty_last_n``: last n tokens to penalize (0 = disable penalty, -1 = context size)
    /// - ``penalty_repeat``: 1.0 = disabled
    /// - ``penalty_freq``: 0.0 = disabled
    /// - ``penalty_present``: 0.0 = disabled
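    ///
    /// # Example
    ///
    /// A minimal sketch; after accepting token 2, its logit shrinks:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// let mut sampler = LlamaSampler::penalties(64, 1.1, 0.0, 0.0);
    /// sampler.accept(LlamaToken(2));
    /// data_array.apply_sampler(&mut sampler);
    ///
    /// // Token 2 was seen recently, so its logit drops below the original 2.0.
    /// assert!(data_array.data[2].logit() < 2.0);
    /// ```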
    #[must_use]
    pub fn penalties(
        penalty_last_n: i32,
        penalty_repeat: f32,
        penalty_freq: f32,
        penalty_present: f32,
    ) -> Self {
        let sampler = unsafe {
            llama_cpp_sys_2::llama_sampler_init_penalties(
                penalty_last_n,
                penalty_repeat,
                penalty_freq,
                penalty_present,
            )
        };
        Self { sampler }
    }

    /// Mirostat 1.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
    ///
    /// # Parameters:
    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
    /// - ``seed``: Seed to initialize random generation with.
    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
    ///   generated text. A higher value corresponds to more surprising or less predictable text,
    ///   while a lower value corresponds to less surprising or more predictable text.
    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
    ///   updated more quickly, while a smaller learning rate will result in slower updates.
    /// - ``m``: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary
    ///   value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`.
    ///   In the paper, they use `m = 100`, but you can experiment with different values to see how
    ///   it affects the performance of the algorithm.
    #[must_use]
    pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self {
        let sampler =
            unsafe { llama_cpp_sys_2::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m) };
        Self { sampler }
    }

    /// Mirostat 2.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
    ///
    /// # Parameters:
    /// - ``seed``: Seed to initialize random generation with.
    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
    ///   generated text. A higher value corresponds to more surprising or less predictable text,
    ///   while a lower value corresponds to less surprising or more predictable text.
    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
    ///   updated more quickly, while a smaller learning rate will result in slower updates.
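    ///
    /// # Example
    ///
    /// A minimal sketch; `tau` and `eta` follow commonly used magnitudes:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::mirostat_v2(1234, 5.0, 0.1));
    /// assert!(data_array.selected_token().is_some());
    /// ```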
    #[must_use]
    pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_mirostat_v2(seed, tau, eta) };
        Self { sampler }
    }

    /// Selects a token at random based on each token's probabilities
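    ///
    /// # Example
    ///
    /// A minimal sketch; `1234` is an arbitrary seed:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::dist(1234));
    /// assert!(data_array.selected_token().is_some());
    /// ```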
    #[must_use]
    pub fn dist(seed: u32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_dist(seed) };
        Self { sampler }
    }

    /// Selects the most likely token
    ///
    /// # Example:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::greedy());
    ///
    /// assert_eq!(data_array.data.len(), 2);
    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(1)));
    /// ```
    #[must_use]
    pub fn greedy() -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_greedy() };
        Self { sampler }
    }

    /// Creates a sampler that applies bias values to specific tokens during sampling.
    ///
    /// # Parameters
    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
    /// - ``biases``: Slice of [`LlamaLogitBias`] values specifying token-bias pairs
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{LlamaToken, logit_bias::LlamaLogitBias};
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let biases = vec![
    ///     LlamaLogitBias::new(LlamaToken(1), 1.5),  // Increase probability of token 1
    ///     LlamaLogitBias::new(LlamaToken(2), -1.0), // Decrease probability of token 2
    /// ];
    ///
    /// // Assuming vocab_size of 32000
    /// let sampler = LlamaSampler::logit_bias(32000, &biases);
    /// ```
    #[must_use]
    pub fn logit_bias(n_vocab: i32, biases: &[LlamaLogitBias]) -> Self {
        let data = biases.as_ptr().cast::<llama_cpp_sys_2::llama_logit_bias>();

        let sampler = unsafe {
            llama_cpp_sys_2::llama_sampler_init_logit_bias(n_vocab, biases.len() as i32, data)
        };

        Self { sampler }
    }
}

impl Drop for LlamaSampler {
    fn drop(&mut self) {
        unsafe {
            llama_cpp_sys_2::llama_sampler_free(self.sampler);
        }
    }
}