//! Safe wrapper around `llama_sampler`.

use std::borrow::Borrow;
use std::ffi::{c_char, CString};
use std::fmt::{Debug, Formatter};

use crate::context::LlamaContext;
use crate::model::LlamaModel;
use crate::token::data_array::LlamaTokenDataArray;
use crate::token::logit_bias::LlamaLogitBias;
use crate::token::LlamaToken;

/// A safe wrapper around `llama_sampler`.
pub struct LlamaSampler {
    pub(crate) sampler: *mut llama_cpp_sys_2::llama_sampler,
}

impl Debug for LlamaSampler {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LlamaSampler").finish()
    }
}

impl LlamaSampler {
    /// Samples and accepts a token from the `idx`-th output of the last evaluation.
    #[must_use]
    pub fn sample(&mut self, ctx: &LlamaContext, idx: i32) -> LlamaToken {
        let token = unsafe {
            llama_cpp_sys_2::llama_sampler_sample(self.sampler, ctx.context.as_ptr(), idx)
        };

        LlamaToken(token)
    }

    /// Applies this sampler to a [`LlamaTokenDataArray`].
    pub fn apply(&self, data_array: &mut LlamaTokenDataArray) {
        data_array.apply_sampler(self);
    }

    /// Accepts a token into the sampler, possibly updating the internal state of certain samplers
    /// (e.g. grammar, repetition).
    pub fn accept(&mut self, token: LlamaToken) {
        unsafe { llama_cpp_sys_2::llama_sampler_accept(self.sampler, token.0) }
    }

    /// Accepts several tokens into the sampler, possibly updating the internal state of certain
    /// samplers (e.g. grammar, repetition).
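    ///
    /// # Example
    /// A minimal sketch, assuming (as in llama.cpp) that the repetition penalty divides positive
    /// logits of previously accepted tokens by ``penalty_repeat``:
    /// ```rust
    /// use llama_cpp_2::token::{LlamaToken, data::LlamaTokenData, data_array::LlamaTokenDataArray};
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 2., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 2., 0.),
    /// ], false);
    ///
    /// let mut sampler = LlamaSampler::penalties(64, 2.0, 0.0, 0.0);
    /// sampler.accept_many(&[LlamaToken(0)]);
    /// data_array.apply_sampler(&mut sampler);
    ///
    /// // Token 0 was accepted, so its logit is halved; token 1 is untouched.
    /// assert_eq!(data_array.data[0].logit(), 1.);
    /// assert_eq!(data_array.data[1].logit(), 2.);
    /// ```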
    pub fn accept_many(&mut self, tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>) {
        for token in tokens {
            unsafe { llama_cpp_sys_2::llama_sampler_accept(self.sampler, token.borrow().0) }
        }
    }

    /// Consumes the sampler, accepts the given tokens, and returns it. A chainable version of
    /// [`Self::accept_many`].
    #[must_use]
    pub fn with_tokens(
        mut self,
        tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
    ) -> Self {
        self.accept_many(tokens);
        self
    }

    /// Resets the internal state of the sampler.
    ///
    /// This can be useful when you want to start fresh with a sampler without creating a new instance.
    pub fn reset(&mut self) {
        unsafe {
            llama_cpp_sys_2::llama_sampler_reset(self.sampler);
        }
    }

    /// Gets the random seed used by this sampler.
    ///
    /// Returns:
    /// - For random samplers (dist, mirostat, mirostat_v2): returns their current seed
    /// - For sampler chains: returns the first non-default seed found in reverse order
    /// - For all other samplers: returns 0xFFFFFFFF
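    ///
    /// # Example
    /// A minimal sketch, assuming ``dist`` reports the seed it was constructed with:
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let sampler = LlamaSampler::dist(1234);
    /// assert_eq!(sampler.get_seed(), 1234);
    ///
    /// let sampler = LlamaSampler::greedy();
    /// assert_eq!(sampler.get_seed(), 0xFFFFFFFF);
    /// ```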
    #[must_use]
    pub fn get_seed(&self) -> u32 {
        unsafe { llama_cpp_sys_2::llama_sampler_get_seed(self.sampler) }
    }

    /// Combines a list of samplers into a single sampler that applies each component sampler one
    /// after another.
    ///
    /// If you are using a chain to select a token, the chain should always end with one of
    /// [`LlamaSampler::greedy`], [`LlamaSampler::dist`], [`LlamaSampler::mirostat`], or
    /// [`LlamaSampler::mirostat_v2`].
    ///
    /// Setting `no_perf` to `true` disables performance timing for the chain.
    #[must_use]
    pub fn chain(samplers: impl IntoIterator<Item = Self>, no_perf: bool) -> Self {
        unsafe {
            let chain = llama_cpp_sys_2::llama_sampler_chain_init(
                llama_cpp_sys_2::llama_sampler_chain_params { no_perf },
            );

            for sampler in samplers {
                llama_cpp_sys_2::llama_sampler_chain_add(chain, sampler.sampler);

                // Do not call `llama_sampler_free` on the sampler, as the internal sampler is now
                // owned by the chain
                std::mem::forget(sampler);
            }

            Self { sampler: chain }
        }
    }

    /// Same as [`Self::chain`] with `no_perf = false`.
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    /// use llama_cpp_2::llama_backend::LlamaBackend;
    /// let backend = LlamaBackend::init().unwrap();
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::chain_simple([
    ///     LlamaSampler::temp(0.5),
    ///     LlamaSampler::greedy(),
    /// ]));
    ///
    /// assert_eq!(data_array.data[0].logit(), 0.);
    /// assert_eq!(data_array.data[1].logit(), 2.);
    /// assert_eq!(data_array.data[2].logit(), 4.);
    ///
    /// assert_eq!(data_array.data.len(), 3);
    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(2)));
    /// ```
    #[must_use]
    pub fn chain_simple(samplers: impl IntoIterator<Item = Self>) -> Self {
        Self::chain(samplers, false)
    }

    #[allow(clippy::doc_markdown)]
    /// Updates the logits `l_i' = l_i / t`. When `t <= 0.0`, the maximum logit is kept at its
    /// original value and the rest are set to `-inf`.
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::temp(0.5));
    ///
    /// assert_eq!(data_array.data[0].logit(), 0.);
    /// assert_eq!(data_array.data[1].logit(), 2.);
    /// assert_eq!(data_array.data[2].logit(), 4.);
    /// ```
    #[must_use]
    pub fn temp(t: f32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_temp(t) };
        Self { sampler }
    }

    /// Dynamic temperature implementation (a.k.a. entropy) described in the paper
    /// <https://arxiv.org/abs/2309.02772>.
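    ///
    /// # Example
    /// A minimal sketch, assuming (as in the llama.cpp implementation) that `delta = 0.0` reduces
    /// this to plain temperature scaling:
    /// ```rust
    /// use llama_cpp_2::token::{LlamaToken, data::LlamaTokenData, data_array::LlamaTokenDataArray};
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::temp_ext(0.5, 0.0, 1.0));
    ///
    /// assert_eq!(data_array.data[0].logit(), 2.);
    /// assert_eq!(data_array.data[1].logit(), 4.);
    /// ```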
    #[must_use]
    pub fn temp_ext(t: f32, delta: f32, exponent: f32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_temp_ext(t, delta, exponent) };
        Self { sampler }
    }

    /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration"
    /// <https://arxiv.org/abs/1904.09751>
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    ///     LlamaTokenData::new(LlamaToken(3), 3., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::top_k(2));
    ///
    /// assert_eq!(data_array.data.len(), 2);
    /// assert_eq!(data_array.data[0].id(), LlamaToken(3));
    /// assert_eq!(data_array.data[1].id(), LlamaToken(2));
    /// ```
    #[must_use]
    pub fn top_k(k: i32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_k(k) };
        Self { sampler }
    }

    /// Top-nσ sampling as described in academic paper "Top-nσ: Not All Logits Are You Need"
    /// <https://arxiv.org/pdf/2411.07641>
    ///
    /// This method filters logits by selecting only those within *n* standard deviations of the mean.
    ///
    /// # Parameters
    /// - `n`: Number of standard deviations from the mean to include in sampling
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0.0, 0.0),
    ///     LlamaTokenData::new(LlamaToken(1), 1.0, 0.0),
    ///     LlamaTokenData::new(LlamaToken(2), 2.0, 0.0),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::top_n_sigma(2.0));
    /// ```
    #[must_use]
    pub fn top_n_sigma(n: f32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_n_sigma(n) };
        Self { sampler }
    }

    /// Locally Typical Sampling implementation described in the paper <https://arxiv.org/abs/2202.00666>.
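    ///
    /// # Example
    /// A minimal usage sketch; which tokens survive depends on the entropy of the distribution,
    /// so no exact outcome is asserted:
    /// ```rust
    /// use llama_cpp_2::token::{LlamaToken, data::LlamaTokenData, data_array::LlamaTokenDataArray};
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::typical(0.9, 1));
    ///
    /// // At least `min_keep` tokens always remain.
    /// assert!(!data_array.data.is_empty());
    /// ```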
    #[must_use]
    pub fn typical(p: f32, min_keep: usize) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_typical(p, min_keep) };
        Self { sampler }
    }

    /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration"
    /// <https://arxiv.org/abs/1904.09751>
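    ///
    /// # Example
    /// A minimal sketch, assuming (as in llama.cpp) that the cumulative-probability cutoff is
    /// inclusive, so the smallest set of tokens whose probabilities sum to at least ``p`` survives:
    /// ```rust
    /// use llama_cpp_2::token::{LlamaToken, data::LlamaTokenData, data_array::LlamaTokenDataArray};
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// // softmax(0, 1, 2) ≈ (0.09, 0.24, 0.67); the top two sum to ≈ 0.91 ≥ 0.9.
    /// data_array.apply_sampler(&mut LlamaSampler::top_p(0.9, 1));
    ///
    /// assert_eq!(data_array.data.len(), 2);
    /// ```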
    #[must_use]
    pub fn top_p(p: f32, min_keep: usize) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_p(p, min_keep) };
        Self { sampler }
    }

    /// Minimum P sampling as described in <https://github.com/ggerganov/llama.cpp/pull/3841>
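    ///
    /// # Example
    /// A minimal sketch, assuming tokens are kept when their probability is at least ``p`` times
    /// that of the most likely token:
    /// ```rust
    /// use llama_cpp_2::token::{LlamaToken, data::LlamaTokenData, data_array::LlamaTokenDataArray};
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// // Relative to the max: e^-2 ≈ 0.14, e^-1 ≈ 0.37, 1.0; only the last two pass p = 0.2.
    /// data_array.apply_sampler(&mut LlamaSampler::min_p(0.2, 1));
    ///
    /// assert_eq!(data_array.data.len(), 2);
    /// ```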
    #[must_use]
    pub fn min_p(p: f32, min_keep: usize) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_min_p(p, min_keep) };
        Self { sampler }
    }

    /// XTC sampler as described in <https://github.com/oobabooga/text-generation-webui/pull/6335>
    #[must_use]
    pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_xtc(p, t, min_keep, seed) };
        Self { sampler }
    }

    /// Grammar sampler. Returns `None` if llama.cpp fails to load the grammar.
    ///
    /// # Panics
    /// If either ``grammar_str`` or ``grammar_root`` contains null bytes.
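    ///
    /// # Example
    /// A minimal sketch, not run; the model path is hypothetical:
    /// ```no_run
    /// use llama_cpp_2::llama_backend::LlamaBackend;
    /// use llama_cpp_2::model::{params::LlamaModelParams, LlamaModel};
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let backend = LlamaBackend::init().unwrap();
    /// let model = LlamaModel::load_from_file(&backend, "model.gguf", &LlamaModelParams::default())
    ///     .unwrap();
    ///
    /// // A GBNF grammar that only accepts "yes" or "no".
    /// let sampler = LlamaSampler::grammar(&model, r#"root ::= "yes" | "no""#, "root")
    ///     .expect("the grammar failed to load");
    /// ```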
    #[must_use]
    pub fn grammar(model: &LlamaModel, grammar_str: &str, grammar_root: &str) -> Option<Self> {
        let grammar_str = CString::new(grammar_str).unwrap();
        let grammar_root = CString::new(grammar_root).unwrap();

        let sampler = unsafe {
            llama_cpp_sys_2::llama_sampler_init_grammar(
                model.vocab_ptr(),
                grammar_str.as_ptr(),
                grammar_root.as_ptr(),
            )
        };

        if sampler.is_null() {
            None
        } else {
            Some(Self { sampler })
        }
    }

    /// Lazy grammar sampler, introduced in <https://github.com/ggerganov/llama.cpp/pull/9639>
    ///
    /// This sampler enforces grammar rules only when specific trigger words or tokens are encountered.
    ///
    /// # Panics
    /// - If `grammar_str` or `grammar_root` contains null bytes
    /// - If any trigger word contains null bytes
    #[must_use]
    pub fn grammar_lazy(
        model: &LlamaModel,
        grammar_str: &str,
        grammar_root: &str,
        trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
        trigger_tokens: &[LlamaToken],
    ) -> Option<Self> {
        let grammar_str = CString::new(grammar_str).unwrap();
        let grammar_root = CString::new(grammar_root).unwrap();

        let trigger_word_cstrings: Vec<CString> = trigger_words
            .into_iter()
            .map(|word| CString::new(word.as_ref()).unwrap())
            .collect();

        let mut trigger_word_ptrs: Vec<*const c_char> =
            trigger_word_cstrings.iter().map(|cs| cs.as_ptr()).collect();

        let sampler = unsafe {
            llama_cpp_sys_2::llama_sampler_init_grammar_lazy(
                model.vocab_ptr(),
                grammar_str.as_ptr(),
                grammar_root.as_ptr(),
                trigger_word_ptrs.as_mut_ptr(),
                trigger_word_ptrs.len(),
                trigger_tokens.as_ptr().cast(),
                trigger_tokens.len(),
            )
        };

        if sampler.is_null() {
            None
        } else {
            Some(Self { sampler })
        }
    }

    /// DRY sampler, designed by p-e-w, as described in:
    /// <https://github.com/oobabooga/text-generation-webui/pull/5677>, porting the Koboldcpp
    /// implementation authored by pi6am: <https://github.com/LostRuins/koboldcpp/pull/982>
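    ///
    /// # Parameters (a best-effort summary of the llama.cpp implementation)
    /// - ``multiplier``: penalty multiplier (`0.0` disables DRY)
    /// - ``base``: exponential base of the penalty
    /// - ``allowed_length``: longest repetition that goes unpenalized
    /// - ``penalty_last_n``: how many recent tokens to scan for repeats (`-1` = context size)
    /// - ``seq_breakers``: strings that interrupt sequence matching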
    ///
    /// # Panics
    /// If any string in ``seq_breakers`` contains null bytes.
    #[allow(missing_docs)]
    #[must_use]
    pub fn dry(
        model: &LlamaModel,
        multiplier: f32,
        base: f32,
        allowed_length: i32,
        penalty_last_n: i32,
        seq_breakers: impl IntoIterator<Item = impl AsRef<[u8]>>,
    ) -> Self {
        let seq_breakers: Vec<CString> = seq_breakers
            .into_iter()
            .map(|s| CString::new(s.as_ref()).expect("A sequence breaker contains null bytes"))
            .collect();
        let mut seq_breaker_pointers: Vec<*const c_char> =
            seq_breakers.iter().map(|s| s.as_ptr()).collect();

        let sampler = unsafe {
            llama_cpp_sys_2::llama_sampler_init_dry(
                model.vocab_ptr(),
                model
                    .n_ctx_train()
                    .try_into()
                    .expect("n_ctx_train exceeds i32::MAX"),
                multiplier,
                base,
                allowed_length,
                penalty_last_n,
                seq_breaker_pointers.as_mut_ptr(),
                seq_breaker_pointers.len(),
            )
        };
        Self { sampler }
    }

    /// Penalizes tokens for being present in the context.
    ///
    /// # Parameters
    /// - ``penalty_last_n``: last n tokens to penalize (0 = disable penalty, -1 = context size)
    /// - ``penalty_repeat``: 1.0 = disabled
    /// - ``penalty_freq``: 0.0 = disabled
    /// - ``penalty_present``: 0.0 = disabled
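    ///
    /// # Example
    /// A minimal sketch, assuming (as in llama.cpp) that the repeat penalty divides positive
    /// logits:
    /// ```rust
    /// use llama_cpp_2::token::{LlamaToken, data::LlamaTokenData, data_array::LlamaTokenDataArray};
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 2., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 2., 0.),
    /// ], false);
    ///
    /// // Token 0 has already been generated once, so it is penalized.
    /// data_array.apply_sampler(
    ///     &mut LlamaSampler::penalties(64, 2.0, 0.0, 0.0).with_tokens([LlamaToken(0)]),
    /// );
    ///
    /// assert_eq!(data_array.data[0].logit(), 1.);
    /// assert_eq!(data_array.data[1].logit(), 2.);
    /// ```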
    #[allow(clippy::too_many_arguments)]
    #[must_use]
    pub fn penalties(
        penalty_last_n: i32,
        penalty_repeat: f32,
        penalty_freq: f32,
        penalty_present: f32,
    ) -> Self {
        let sampler = unsafe {
            llama_cpp_sys_2::llama_sampler_init_penalties(
                penalty_last_n,
                penalty_repeat,
                penalty_freq,
                penalty_present,
            )
        };
        Self { sampler }
    }

    /// Mirostat 1.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
    ///
    /// # Parameters
    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
    /// - ``seed``: Seed to initialize random generation with.
    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
    ///     generated text. A higher value corresponds to more surprising or less predictable text,
    ///     while a lower value corresponds to less surprising or more predictable text.
    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
    ///     observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
    ///     updated more quickly, while a smaller learning rate will result in slower updates.
    /// - ``m``: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary
    ///     value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`.
    ///     In the paper, they use `m = 100`, but you can experiment with different values to see how
    ///     it affects the performance of the algorithm.
    #[must_use]
    pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self {
        let sampler =
            unsafe { llama_cpp_sys_2::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m) };
        Self { sampler }
    }

    /// Mirostat 2.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
    ///
    /// # Parameters
    /// - ``seed``: Seed to initialize random generation with.
    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
    ///     generated text. A higher value corresponds to more surprising or less predictable text,
    ///     while a lower value corresponds to less surprising or more predictable text.
    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
    ///     observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
    ///     updated more quickly, while a smaller learning rate will result in slower updates.
    #[must_use]
    pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_mirostat_v2(seed, tau, eta) };
        Self { sampler }
    }

    /// Selects a token at random based on each token's probabilities
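    ///
    /// # Example
    /// A minimal sketch; which token is selected depends on the seeded RNG:
    /// ```rust
    /// use llama_cpp_2::token::{LlamaToken, data::LlamaTokenData, data_array::LlamaTokenDataArray};
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::dist(1234));
    ///
    /// // `dist` always selects one of the candidates.
    /// assert!(data_array.selected_token().is_some());
    /// ```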
    #[must_use]
    pub fn dist(seed: u32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_dist(seed) };
        Self { sampler }
    }

    /// Selects the most likely token
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::greedy());
    ///
    /// assert_eq!(data_array.data.len(), 2);
    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(1)));
    /// ```
    #[must_use]
    pub fn greedy() -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_greedy() };
        Self { sampler }
    }

    /// Creates a sampler that applies bias values to specific tokens during sampling.
    ///
    /// # Parameters
    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
    /// - ``biases``: Slice of [`LlamaLogitBias`] values specifying token-bias pairs
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{LlamaToken, logit_bias::LlamaLogitBias};
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let biases = vec![
    ///     LlamaLogitBias::new(LlamaToken(1), 1.5),  // Increase probability of token 1
    ///     LlamaLogitBias::new(LlamaToken(2), -1.0), // Decrease probability of token 2
    /// ];
    ///
    /// // Assuming vocab_size of 32000
    /// let sampler = LlamaSampler::logit_bias(32000, &biases);
    /// ```
    #[must_use]
    pub fn logit_bias(n_vocab: i32, biases: &[LlamaLogitBias]) -> Self {
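        // This cast assumes `LlamaLogitBias` is layout-compatible with
        // `llama_cpp_sys_2::llama_logit_bias`.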
        let data = biases.as_ptr().cast::<llama_cpp_sys_2::llama_logit_bias>();

        let sampler = unsafe {
            llama_cpp_sys_2::llama_sampler_init_logit_bias(n_vocab, biases.len() as i32, data)
        };

        Self { sampler }
    }
}

impl Drop for LlamaSampler {
    fn drop(&mut self) {
        unsafe {
            llama_cpp_sys_2::llama_sampler_free(self.sampler);
        }
    }
}