Skip to main content

llama_cpp_4/
sampling.rs

1//! Safe wrapper around `llama_sampler`.
2
3use std::borrow::Borrow;
4use std::ffi::{c_char, CString};
5use std::fmt::{Debug, Formatter};
6use std::ptr::NonNull;
7
8use llama_cpp_sys_4::{
9    common::common_sampler_params, llama_sampler, llama_sampler_accept, llama_sampler_chain_add,
10    llama_sampler_chain_default_params, llama_sampler_chain_init, llama_sampler_free,
11    llama_sampler_init_dist, llama_sampler_init_dry, llama_sampler_init_grammar,
12    llama_sampler_init_greedy, llama_sampler_init_min_p, llama_sampler_init_mirostat,
13    llama_sampler_init_mirostat_v2, llama_sampler_init_penalties, llama_sampler_init_temp,
14    llama_sampler_init_temp_ext, llama_sampler_init_top_k, llama_sampler_init_top_p,
15    llama_sampler_init_typical, llama_sampler_init_xtc, llama_sampler_sample,
16};
17
18use crate::context::LlamaContext;
19use crate::model::LlamaModel;
20use crate::token::data_array::LlamaTokenDataArray;
21use crate::token::LlamaToken;
22
/// A safe wrapper around `llama_sampler`.
pub struct LlamaSampler {
    // Owned, non-null pointer to the C-side sampler; freed exactly once in `Drop`.
    pub(crate) sampler: NonNull<llama_sampler>,
}
27
28impl Debug for LlamaSampler {
29    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
30        f.debug_struct("LlamaSamplerChain").finish()
31    }
32}
#[derive(Debug, Clone)]
#[allow(
    missing_docs,
    clippy::struct_excessive_bools,
    clippy::module_name_repetitions,
    dead_code
)]
pub struct LlamaSamplerParams {
    /// Presumably the top-k cutoff; currently unused (no accessor), hence `dead_code` above.
    top_k: i32,
    /// Presumably the nucleus (top-p) cutoff; currently unused.
    top_p: f32,
    /// Presumably the sampling temperature; currently unused.
    temp: f32,
    /// RNG seed; the only field with accessors (`with_seed`/`seed`).
    seed: u32,
}
46
47impl LlamaSamplerParams {
48    /// Set the seed of the context
49    ///
50    /// # Examples
51    ///
52    /// ```rust
53    /// use llama_cpp_4::context::sampler::LlamaSamplerParams;
54    /// let params = LlamaSamplerParams::default();
55    /// let params = params.with_seed(1234);
56    /// assert_eq!(params.seed(), 1234);
57    /// ```
58    #[must_use]
59    pub fn with_seed(mut self, seed: u32) -> Self {
60        self.seed = seed;
61        self
62    }
63
64    /// Get the seed of the context
65    ///
66    /// # Examples
67    ///
68    /// ```rust
69    /// use llama_cpp_4::context::sampler::LlamaSamplerParams;
70    /// let params = LlamaSamplerParams::default();
71    ///     .with_seed(1234);
72    /// assert_eq!(params.seed(), 1234);
73    /// ```
74    #[must_use]
75    pub fn seed(&self) -> u32 {
76        self.seed
77    }
78}
79
80impl Default for LlamaSamplerParams {
81    fn default() -> Self {
82        Self {
83            top_k: 50,
84            top_p: 0.9,
85            temp: 0.8,
86            seed: 1234,
87        }
88    }
89}
90
91impl Default for LlamaSampler {
92    fn default() -> Self {
93        Self::new()
94    }
95}
96
97impl LlamaSampler {
98    /// Create new sampler with default params.
99    ///
100    /// # Panics
101    ///
102    /// Panics if llama.cpp returns a null pointer.
103    #[must_use]
104    pub fn new() -> Self {
105        let sparams = unsafe { llama_sampler_chain_default_params() };
106
107        Self {
108            sampler: NonNull::new(unsafe { llama_sampler_chain_init(sparams) }).unwrap(),
109        }
110    }
111
112    /// Sample and accept a token from the idx-th output of the last evaluation
113    #[must_use]
114    pub fn sample(&self, ctx: &LlamaContext, idx: i32) -> LlamaToken {
115        let token =
116            unsafe { llama_sampler_sample(self.sampler.as_ptr(), ctx.context.as_ptr(), idx) };
117
118        LlamaToken(token)
119    }
120
121    /// Applies this sampler to a [`LlamaTokenDataArray`].
122    pub fn apply(&mut self, data_array: &mut LlamaTokenDataArray) {
123        data_array.apply_sampler(self);
124    }
125
126    /// Accepts a token from the sampler, possibly updating the internal state of certain samplers
127    /// (e.g. grammar, repetition, etc.)
128    pub fn accept(&mut self, token: LlamaToken) {
129        unsafe { llama_sampler_accept(self.sampler.as_ptr(), token.0) }
130    }
131
132    /// Accepts several tokens from the sampler or context, possibly updating the internal state of
133    /// certain samplers (e.g. grammar, repetition, etc.)
134    pub fn accept_many(&mut self, tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>) {
135        for token in tokens {
136            unsafe { llama_sampler_accept(self.sampler.as_ptr(), token.borrow().0) }
137        }
138    }
139
140    /// Accepts several tokens from the sampler or context, possibly updating the internal state of
141    /// certain samplers (e.g. grammar, repetition, etc.)
142    #[must_use]
143    pub fn with_tokens(
144        mut self,
145        tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
146    ) -> Self {
147        self.accept_many(tokens);
148        self
149    }
150
151    /// Combines a list of samplers into a single sampler that applies each component sampler one
152    /// after another.
153    ///
154    /// If you are using a chain to select a token, the chain should always end with one of
155    /// [`LlamaSampler::greedy`], [`LlamaSampler::dist`], [`LlamaSampler::mirostat`], and
156    /// [`LlamaSampler::mirostat_v2`].
157    ///
158    /// # Panics
159    ///
160    /// Panics if llama.cpp returns a null pointer.
161    #[must_use]
162    pub fn chain(samplers: impl IntoIterator<Item = Self>, no_perf: bool) -> Self {
163        unsafe {
164            let mut params = llama_sampler_chain_default_params();
165            params.no_perf = no_perf;
166            let chain = llama_sampler_chain_init(params);
167
168            for sampler in samplers {
169                llama_sampler_chain_add(chain, sampler.sampler.as_ptr());
170
171                // Do not call `llama_sampler_free` on the sampler, as the internal sampler is now
172                // owned by the chain
173                std::mem::forget(sampler);
174            }
175
176            Self {
177                sampler: NonNull::new(chain).unwrap(),
178            }
179        }
180    }
181
182    /// Same as [`Self::chain`] with `no_perf = false`.
183    ///
184    /// # Panics
185    ///
186    /// Panics if llama.cpp returns a null pointer.
187    ///
188    /// # Example
189    /// ```rust
190    /// use llama_cpp_4::token::{
191    ///    LlamaToken,
192    ///    data::LlamaTokenData,
193    ///    data_array::LlamaTokenDataArray
194    /// };
195    /// use llama_cpp_4::sampling::LlamaSampler;
196    ///
197    /// let mut data_array = LlamaTokenDataArray::new(vec![
198    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
199    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
200    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
201    /// ], false);
202    ///
203    /// data_array.apply_sampler(&mut LlamaSampler::chain_simple([
204    ///     LlamaSampler::temp(0.5),
205    ///     LlamaSampler::greedy(),
206    /// ]));
207    ///
208    /// assert_eq!(data_array.data[0].logit(), 0.);
209    /// assert_eq!(data_array.data[1].logit(), 2.);
210    /// assert_eq!(data_array.data[2].logit(), 4.);
211    ///
212    /// assert_eq!(data_array.data.len(), 3);
213    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(2)));
214    /// ```
215    #[must_use]
216    pub fn chain_simple(samplers: impl IntoIterator<Item = Self>) -> Self {
217        Self::chain(samplers, false)
218    }
219
220    /// Updates the logits `l_i`' = `l_i/t`. When `t <= 0.0`, the maximum logit is kept at its original
221    /// value, the rest are set to -inf.
222    ///
223    /// # Panics
224    ///
225    /// Panics if llama.cpp returns a null pointer.
226    ///
227    /// # Example:
228    /// ```rust
229    /// use llama_cpp_4::token::{
230    ///    LlamaToken,
231    ///    data::LlamaTokenData,
232    ///    data_array::LlamaTokenDataArray
233    /// };
234    /// use llama_cpp_4::sampling::LlamaSampler;
235    ///
236    /// let mut data_array = LlamaTokenDataArray::new(vec![
237    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
238    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
239    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
240    /// ], false);
241    ///
242    /// data_array.apply_sampler(&mut LlamaSampler::temp(0.5));
243    ///
244    /// assert_eq!(data_array.data[0].logit(), 0.);
245    /// assert_eq!(data_array.data[1].logit(), 2.);
246    /// assert_eq!(data_array.data[2].logit(), 4.);
247    /// ```
248    #[must_use]
249    pub fn temp(t: f32) -> Self {
250        let sampler = unsafe { llama_sampler_init_temp(t) };
251        Self {
252            sampler: NonNull::new(sampler).unwrap(),
253        }
254    }
255
256    /// Dynamic temperature implementation (a.k.a. entropy) described in the paper
257    /// <https://arxiv.org/abs/2309.02772>.
258    ///
259    /// # Panics
260    ///
261    /// Panics if llama.cpp returns a null pointer.
262    #[must_use]
263    pub fn temp_ext(t: f32, delta: f32, exponent: f32) -> Self {
264        let sampler = unsafe { llama_sampler_init_temp_ext(t, delta, exponent) };
265        Self {
266            sampler: NonNull::new(sampler).unwrap(),
267        }
268    }
269
270    /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration"
271    /// <https://arxiv.org/abs/1904.09751>.
272    ///
273    /// # Panics
274    ///
275    /// Panics if llama.cpp returns a null pointer.
276    ///
277    /// # Example:
278    /// ```rust
279    /// use llama_cpp_4::token::{
280    ///    LlamaToken,
281    ///    data::LlamaTokenData,
282    ///    data_array::LlamaTokenDataArray
283    /// };
284    /// use llama_cpp_4::sampling::LlamaSampler;
285    ///
286    /// let mut data_array = LlamaTokenDataArray::new(vec![
287    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
288    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
289    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
290    ///     LlamaTokenData::new(LlamaToken(3), 3., 0.),
291    /// ], false);
292    ///
293    /// data_array.apply_sampler(&mut LlamaSampler::top_k(2));
294    ///
295    /// assert_eq!(data_array.data.len(), 2);
296    /// assert_eq!(data_array.data[0].id(), LlamaToken(3));
297    /// assert_eq!(data_array.data[1].id(), LlamaToken(2));
298    /// ```
299    #[must_use]
300    pub fn top_k(k: i32) -> Self {
301        let sampler = unsafe { llama_sampler_init_top_k(k) };
302        Self {
303            sampler: NonNull::new(sampler).unwrap(),
304        }
305    }
306
307    /// Locally Typical Sampling implementation described in the paper <https://arxiv.org/abs/2202.00666>.
308    ///
309    /// # Panics
310    ///
311    /// Panics if llama.cpp returns a null pointer.
312    #[must_use]
313    pub fn typical(p: f32, min_keep: usize) -> Self {
314        let sampler = unsafe { llama_sampler_init_typical(p, min_keep) };
315        Self {
316            sampler: NonNull::new(sampler).unwrap(),
317        }
318    }
319
320    /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration"
321    /// <https://arxiv.org/abs/1904.09751>.
322    ///
323    /// # Panics
324    ///
325    /// Panics if llama.cpp returns a null pointer.
326    #[must_use]
327    pub fn top_p(p: f32, min_keep: usize) -> Self {
328        let sampler = unsafe { llama_sampler_init_top_p(p, min_keep) };
329        Self {
330            sampler: NonNull::new(sampler).unwrap(),
331        }
332    }
333
334    /// Minimum P sampling as described in <https://github.com/ggerganov/llama.cpp/pull/3841>.
335    ///
336    /// # Panics
337    ///
338    /// Panics if llama.cpp returns a null pointer.
339    #[must_use]
340    pub fn min_p(p: f32, min_keep: usize) -> Self {
341        let sampler = unsafe { llama_sampler_init_min_p(p, min_keep) };
342        Self {
343            sampler: NonNull::new(sampler).unwrap(),
344        }
345    }
346
347    /// XTC sampler as described in <https://github.com/oobabooga/text-generation-webui/pull/6335>.
348    ///
349    /// # Panics
350    ///
351    /// Panics if llama.cpp returns a null pointer.
352    #[must_use]
353    pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Self {
354        let sampler = unsafe { llama_sampler_init_xtc(p, t, min_keep, seed) };
355        Self {
356            sampler: NonNull::new(sampler).unwrap(),
357        }
358    }
359
360    /// Grammar sampler
361    ///
362    /// # Panics
363    /// - If either of `grammar_str` or `grammar_root` contain null bytes.
364    /// - If llama.cpp returns a null pointer.
365    #[must_use]
366    pub fn grammar(model: &LlamaModel, grammar_str: &str, grammar_root: &str) -> Self {
367        let grammar_str = CString::new(grammar_str).unwrap();
368        let grammar_root = CString::new(grammar_root).unwrap();
369
370        let sampler = unsafe {
371            llama_sampler_init_grammar(
372                model.get_vocab().vocab.as_ref(),
373                grammar_str.as_ptr(),
374                grammar_root.as_ptr(),
375            )
376        };
377        Self {
378            sampler: NonNull::new(sampler).unwrap(),
379        }
380    }
381
382    /// DRY sampler, designed by p-e-w, as described in:
383    /// <https://github.com/oobabooga/text-generation-webui/pull/5677>, porting Koboldcpp
384    /// implementation authored by pi6am: <https://github.com/LostRuins/koboldcpp/pull/982>
385    ///
386    /// # Panics
387    /// - If any string in `seq_breakers` contains null bytes.
388    /// - If llama.cpp returns a null pointer.
389    #[allow(clippy::too_many_arguments)]
390    #[must_use]
391    pub fn dry(
392        &self,
393        model: &LlamaModel,
394        n_ctx_train: i32,
395        multiplier: f32,
396        base: f32,
397        allowed_length: i32,
398        penalty_last_n: i32,
399        seq_breakers: impl IntoIterator<Item = impl AsRef<[u8]>>,
400    ) -> Self {
401        let seq_breakers: Vec<CString> = seq_breakers
402            .into_iter()
403            .map(|s| CString::new(s.as_ref()).unwrap())
404            .collect();
405        // CString::as_ptr() returns *const c_char, which matches what the binding
406        // expects on every platform (signed on macOS/x86 Linux, unsigned on musl ARM).
407        let mut seq_breaker_pointers: Vec<*const c_char> =
408            seq_breakers.iter().map(|s| s.as_ptr()).collect();
409
410        let sampler = unsafe {
411            llama_sampler_init_dry(
412                model.get_vocab().vocab.as_ref(),
413                n_ctx_train,
414                multiplier,
415                base,
416                allowed_length,
417                penalty_last_n,
418                seq_breaker_pointers.as_mut_ptr(),
419                seq_breaker_pointers.len(),
420            )
421        };
422
423        Self {
424            sampler: NonNull::new(sampler).unwrap(),
425        }
426    }
427
428    /// Penalizes tokens for being present in the context.
429    ///
430    /// Parameters:
431    /// - `n_vocab`: [`LlamaModel::n_vocab`]
432    /// - `special_eos_id`: [`LlamaModel::token_eos`]
433    /// - `linefeed_id`: [`LlamaModel::token_nl`]
434    /// - `penalty_last_n`: last n tokens to penalize (0 = disable penalty, -1 = context size)
435    ///
436    /// # Panics
437    ///
438    /// Panics if llama.cpp returns a null pointer.
439    #[allow(clippy::too_many_arguments)]
440    #[must_use]
441    pub fn penalties(
442        n_vocab: i32,
443        special_eos_id: f32,
444        linefeed_id: f32,
445        penalty_last_n: f32,
446        // penalty_repeat: f32,
447        // penalty_freq: f32,
448        // penalty_present: f32,
449        // penalize_nl: bool,
450        // ignore_eos: bool,
451    ) -> Self {
452        let sampler = unsafe {
453            llama_sampler_init_penalties(
454                n_vocab,
455                special_eos_id,
456                linefeed_id,
457                penalty_last_n,
458                // penalty_repeat,
459                // penalty_freq,
460                // penalty_present,
461                // penalize_nl,
462                // ignore_eos,
463            )
464        };
465        Self {
466            sampler: NonNull::new(sampler).unwrap(),
467        }
468    }
469
470    /// Same as [`Self::penalties`], but with `n_vocab`, `special_eos_id`, and `linefeed_id`
471    /// initialized from `model`, `penalize_nl = false`, and `ignore_eos = true`.
472    ///
473    /// Parameters:
474    /// - `model`: The model's tokenizer to use to initialize the sampler.
475    /// - `penalty_last_n`: last n tokens to penalize (0 = disable penalty, -1 = context size)
476    ///
477    /// # Panics
478    ///
479    /// Panics if llama.cpp returns a null pointer.
480    #[must_use]
481    pub fn penalties_simple(
482        model: &LlamaModel,
483        penalty_last_n: i32,
484        // penalty_repeat: f32,
485        // penalty_freq: f32,
486        // penalty_present: f32,
487    ) -> Self {
488        Self::penalties(
489            model.n_vocab(),
490            #[allow(clippy::cast_precision_loss)]
491            {
492                model.token_eos().0 as f32
493            },
494            #[allow(clippy::cast_precision_loss)]
495            {
496                model.token_nl().0 as f32
497            },
498            #[allow(clippy::cast_precision_loss)]
499            {
500                penalty_last_n as f32
501            },
502            // penalty_repeat,
503            // penalty_freq,
504            // penalty_present,
505            // false,
506            // true,
507        )
508    }
509
510    /// Mirostat 1.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
511    ///
512    /// # Panics
513    ///
514    /// Panics if llama.cpp returns a null pointer.
515    ///
516    /// # Parameters:
517    /// - `n_vocab`: [`LlamaModel::n_vocab`]
518    /// - `seed`: Seed to initialize random generation with.
519    /// - `tau`: The target cross-entropy (or surprise) value you want to achieve for the
520    ///   generated text. A higher value corresponds to more surprising or less predictable text,
521    ///   while a lower value corresponds to less surprising or more predictable text.
522    /// - `eta`: The learning rate used to update `mu` based on the error between the target and
523    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
524    ///   updated more quickly, while a smaller learning rate will result in slower updates.
525    /// - `m`: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary
526    ///   value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`.
527    ///   In the paper, they use `m = 100`, but you can experiment with different values to see how
528    ///   it affects the performance of the algorithm.
529    #[must_use]
530    pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self {
531        let sampler = unsafe { llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m) };
532        Self {
533            sampler: NonNull::new(sampler).unwrap(),
534        }
535    }
536
537    /// Mirostat 2.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
538    ///
539    /// # Panics
540    ///
541    /// Panics if llama.cpp returns a null pointer.
542    ///
543    /// # Parameters:
544    /// - `seed`: Seed to initialize random generation with.
545    /// - `tau`: The target cross-entropy (or surprise) value you want to achieve for the
546    ///   generated text. A higher value corresponds to more surprising or less predictable text,
547    ///   while a lower value corresponds to less surprising or more predictable text.
548    /// - `eta`: The learning rate used to update `mu` based on the error between the target and
549    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
550    ///   updated more quickly, while a smaller learning rate will result in slower updates.
551    #[must_use]
552    pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self {
553        let sampler = unsafe { llama_sampler_init_mirostat_v2(seed, tau, eta) };
554        Self {
555            sampler: NonNull::new(sampler).unwrap(),
556        }
557    }
558
559    /// Selects a token at random based on each token's probabilities.
560    ///
561    /// # Panics
562    ///
563    /// Panics if llama.cpp returns a null pointer.
564    #[must_use]
565    pub fn dist(seed: u32) -> Self {
566        let sampler = unsafe { llama_sampler_init_dist(seed) };
567        Self {
568            sampler: NonNull::new(sampler).unwrap(),
569        }
570    }
571
572    /// Selects the most likely token.
573    ///
574    /// # Panics
575    ///
576    /// Panics if llama.cpp returns a null pointer.
577    ///
578    /// # Example:
579    /// ```rust
580    /// use llama_cpp_4::token::{
581    ///    LlamaToken,
582    ///    data::LlamaTokenData,
583    ///    data_array::LlamaTokenDataArray
584    /// };
585    /// use llama_cpp_4::sampling::LlamaSampler;
586    ///
587    /// let mut data_array = LlamaTokenDataArray::new(vec![
588    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
589    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
590    /// ], false);
591    ///
592    /// data_array.apply_sampler(&mut LlamaSampler::greedy());
593    ///
594    /// assert_eq!(data_array.data.len(), 2);
595    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(1)));
596    /// ```
597    #[must_use]
598    pub fn greedy() -> Self {
599        let sampler = unsafe { llama_sampler_init_greedy() };
600        Self {
601            sampler: NonNull::new(sampler).unwrap(),
602        }
603    }
604
605    /// Creates a new instance of `LlamaSampler` with common sampling parameters.
606    ///
607    /// This function initializes a `LlamaSampler` using default values from `common_sampler_params`
608    /// and configures it with common settings such as `top_k`, `top_p`, `temperature`, and `seed` values.
609    ///
610    /// # Panics
611    ///
612    /// Panics if llama.cpp returns a null pointer.
613    ///
614    /// # Returns
615    /// A `LlamaSampler` instance configured with the common sampling parameters.
616    #[must_use]
617    pub fn common() -> Self {
618        let params = common_sampler_params::default();
619
620        let sampler = unsafe {
621            let mut sparams = llama_sampler_chain_default_params();
622            sparams.no_perf = false;
623
624            let smpl = llama_sampler_chain_init(sparams);
625
626            llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.top_k));
627            llama_sampler_chain_add(
628                smpl,
629                #[allow(clippy::cast_sign_loss)]
630                llama_sampler_init_top_p(params.top_p, params.min_keep as usize),
631            );
632            llama_sampler_chain_add(smpl, llama_sampler_init_temp(params.temp));
633            #[allow(clippy::cast_sign_loss)]
634            llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.seed));
635
636            smpl
637        };
638
639        Self {
640            sampler: NonNull::new(sampler).unwrap(),
641        }
642    }
643}
644
impl Drop for LlamaSampler {
    /// Frees the underlying `llama_sampler`.
    fn drop(&mut self) {
        // SAFETY: `self.sampler` came from a `llama_sampler_init_*` /
        // `llama_sampler_chain_init` call and is owned exclusively by this wrapper
        // (samplers moved into a chain are `mem::forget`-ten in `chain`), so it is
        // freed exactly once here.
        unsafe {
            llama_sampler_free(self.sampler.as_ptr());
        }
    }
}