Skip to main content

llama_cpp_4/
sampling.rs

1//! Safe wrapper around `llama_sampler`.
2
3use std::borrow::Borrow;
4use std::ffi::{c_char, CString};
5use std::fmt::{Debug, Formatter};
6use std::ptr::NonNull;
7
8use llama_cpp_sys_4::{
9    common::common_sampler_params, llama_logit_bias, llama_sampler, llama_sampler_accept,
10    llama_sampler_chain_add, llama_sampler_chain_default_params, llama_sampler_chain_init,
11    llama_sampler_chain_n, llama_sampler_chain_remove, llama_sampler_clone, llama_sampler_free,
12    llama_sampler_get_seed, llama_sampler_init_adaptive_p, llama_sampler_init_dist,
13    llama_sampler_init_dry, llama_sampler_init_grammar, llama_sampler_init_grammar_lazy,
14    llama_sampler_init_grammar_lazy_patterns, llama_sampler_init_greedy,
15    llama_sampler_init_infill, llama_sampler_init_logit_bias, llama_sampler_init_min_p,
16    llama_sampler_init_mirostat, llama_sampler_init_mirostat_v2, llama_sampler_init_penalties,
17    llama_sampler_init_temp, llama_sampler_init_temp_ext, llama_sampler_init_top_k,
18    llama_sampler_init_top_n_sigma, llama_sampler_init_top_p, llama_sampler_init_typical,
19    llama_sampler_init_xtc, llama_sampler_name, llama_sampler_reset, llama_sampler_sample,
20};
21
22use crate::context::LlamaContext;
23use crate::model::LlamaModel;
24use crate::token::data_array::LlamaTokenDataArray;
25use crate::token::LlamaToken;
26
27/// A safe wrapper around `llama_sampler`.
28pub struct LlamaSampler {
29    pub(crate) sampler: NonNull<llama_sampler>,
30}
31
32impl Debug for LlamaSampler {
33    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
34        f.debug_struct("LlamaSamplerChain").finish()
35    }
36}
/// Basic sampling settings: top-k, top-p, temperature and seed.
///
/// Only `seed` has accessors in this file; `top_k`, `top_p` and `temp` are stored but never
/// read here, which is why `dead_code` is allowed.
#[derive(Debug, Clone)]
#[allow(clippy::struct_excessive_bools, clippy::module_name_repetitions)]
#[allow(missing_docs, dead_code)]
pub struct LlamaSamplerParams {
    /// Top-k cutoff.
    top_k: i32,
    /// Top-p (nucleus) cutoff.
    top_p: f32,
    /// Sampling temperature.
    temp: f32,
    /// Random seed.
    seed: u32,
}
50
51impl LlamaSamplerParams {
52    /// Set the seed of the context
53    ///
54    /// # Examples
55    ///
56    /// ```rust
57    /// use llama_cpp_4::context::sampler::LlamaSamplerParams;
58    /// let params = LlamaSamplerParams::default();
59    /// let params = params.with_seed(1234);
60    /// assert_eq!(params.seed(), 1234);
61    /// ```
62    #[must_use]
63    pub fn with_seed(mut self, seed: u32) -> Self {
64        self.seed = seed;
65        self
66    }
67
68    /// Get the seed of the context
69    ///
70    /// # Examples
71    ///
72    /// ```rust
73    /// use llama_cpp_4::context::sampler::LlamaSamplerParams;
74    /// let params = LlamaSamplerParams::default();
75    ///     .with_seed(1234);
76    /// assert_eq!(params.seed(), 1234);
77    /// ```
78    #[must_use]
79    pub fn seed(&self) -> u32 {
80        self.seed
81    }
82}
83
84impl Default for LlamaSamplerParams {
85    fn default() -> Self {
86        Self {
87            top_k: 50,
88            top_p: 0.9,
89            temp: 0.8,
90            seed: 1234,
91        }
92    }
93}
94
95impl Default for LlamaSampler {
96    fn default() -> Self {
97        Self::new()
98    }
99}
100
101impl LlamaSampler {
102    /// Create new sampler with default params.
103    ///
104    /// # Panics
105    ///
106    /// Panics if llama.cpp returns a null pointer.
107    #[must_use]
108    pub fn new() -> Self {
109        let sparams = unsafe { llama_sampler_chain_default_params() };
110
111        Self {
112            sampler: NonNull::new(unsafe { llama_sampler_chain_init(sparams) }).unwrap(),
113        }
114    }
115
116    /// Sample and accept a token from the idx-th output of the last evaluation
117    #[must_use]
118    pub fn sample(&self, ctx: &LlamaContext, idx: i32) -> LlamaToken {
119        let token =
120            unsafe { llama_sampler_sample(self.sampler.as_ptr(), ctx.context.as_ptr(), idx) };
121
122        LlamaToken(token)
123    }
124
125    /// Applies this sampler to a [`LlamaTokenDataArray`].
126    pub fn apply(&mut self, data_array: &mut LlamaTokenDataArray) {
127        data_array.apply_sampler(self);
128    }
129
130    /// Accepts a token from the sampler, possibly updating the internal state of certain samplers
131    /// (e.g. grammar, repetition, etc.)
132    pub fn accept(&mut self, token: LlamaToken) {
133        unsafe { llama_sampler_accept(self.sampler.as_ptr(), token.0) }
134    }
135
136    /// Accepts several tokens from the sampler or context, possibly updating the internal state of
137    /// certain samplers (e.g. grammar, repetition, etc.)
138    pub fn accept_many(&mut self, tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>) {
139        for token in tokens {
140            unsafe { llama_sampler_accept(self.sampler.as_ptr(), token.borrow().0) }
141        }
142    }
143
144    /// Accepts several tokens from the sampler or context, possibly updating the internal state of
145    /// certain samplers (e.g. grammar, repetition, etc.)
146    #[must_use]
147    pub fn with_tokens(
148        mut self,
149        tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
150    ) -> Self {
151        self.accept_many(tokens);
152        self
153    }
154
155    /// Combines a list of samplers into a single sampler that applies each component sampler one
156    /// after another.
157    ///
158    /// If you are using a chain to select a token, the chain should always end with one of
159    /// [`LlamaSampler::greedy`], [`LlamaSampler::dist`], [`LlamaSampler::mirostat`], and
160    /// [`LlamaSampler::mirostat_v2`].
161    ///
162    /// # Panics
163    ///
164    /// Panics if llama.cpp returns a null pointer.
165    #[must_use]
166    pub fn chain(samplers: impl IntoIterator<Item = Self>, no_perf: bool) -> Self {
167        unsafe {
168            let mut params = llama_sampler_chain_default_params();
169            params.no_perf = no_perf;
170            let chain = llama_sampler_chain_init(params);
171
172            for sampler in samplers {
173                llama_sampler_chain_add(chain, sampler.sampler.as_ptr());
174
175                // Do not call `llama_sampler_free` on the sampler, as the internal sampler is now
176                // owned by the chain
177                std::mem::forget(sampler);
178            }
179
180            Self {
181                sampler: NonNull::new(chain).unwrap(),
182            }
183        }
184    }
185
186    /// Same as [`Self::chain`] with `no_perf = false`.
187    ///
188    /// # Panics
189    ///
190    /// Panics if llama.cpp returns a null pointer.
191    ///
192    /// # Example
193    /// ```rust
194    /// use llama_cpp_4::token::{
195    ///    LlamaToken,
196    ///    data::LlamaTokenData,
197    ///    data_array::LlamaTokenDataArray
198    /// };
199    /// use llama_cpp_4::sampling::LlamaSampler;
200    ///
201    /// let mut data_array = LlamaTokenDataArray::new(vec![
202    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
203    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
204    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
205    /// ], false);
206    ///
207    /// data_array.apply_sampler(&mut LlamaSampler::chain_simple([
208    ///     LlamaSampler::temp(0.5),
209    ///     LlamaSampler::greedy(),
210    /// ]));
211    ///
212    /// assert_eq!(data_array.data[0].logit(), 0.);
213    /// assert_eq!(data_array.data[1].logit(), 2.);
214    /// assert_eq!(data_array.data[2].logit(), 4.);
215    ///
216    /// assert_eq!(data_array.data.len(), 3);
217    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(2)));
218    /// ```
219    #[must_use]
220    pub fn chain_simple(samplers: impl IntoIterator<Item = Self>) -> Self {
221        Self::chain(samplers, false)
222    }
223
224    /// Updates the logits `l_i`' = `l_i/t`. When `t <= 0.0`, the maximum logit is kept at its original
225    /// value, the rest are set to -inf.
226    ///
227    /// # Panics
228    ///
229    /// Panics if llama.cpp returns a null pointer.
230    ///
231    /// # Example:
232    /// ```rust
233    /// use llama_cpp_4::token::{
234    ///    LlamaToken,
235    ///    data::LlamaTokenData,
236    ///    data_array::LlamaTokenDataArray
237    /// };
238    /// use llama_cpp_4::sampling::LlamaSampler;
239    ///
240    /// let mut data_array = LlamaTokenDataArray::new(vec![
241    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
242    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
243    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
244    /// ], false);
245    ///
246    /// data_array.apply_sampler(&mut LlamaSampler::temp(0.5));
247    ///
248    /// assert_eq!(data_array.data[0].logit(), 0.);
249    /// assert_eq!(data_array.data[1].logit(), 2.);
250    /// assert_eq!(data_array.data[2].logit(), 4.);
251    /// ```
252    #[must_use]
253    pub fn temp(t: f32) -> Self {
254        let sampler = unsafe { llama_sampler_init_temp(t) };
255        Self {
256            sampler: NonNull::new(sampler).unwrap(),
257        }
258    }
259
260    /// Dynamic temperature implementation (a.k.a. entropy) described in the paper
261    /// <https://arxiv.org/abs/2309.02772>.
262    ///
263    /// # Panics
264    ///
265    /// Panics if llama.cpp returns a null pointer.
266    #[must_use]
267    pub fn temp_ext(t: f32, delta: f32, exponent: f32) -> Self {
268        let sampler = unsafe { llama_sampler_init_temp_ext(t, delta, exponent) };
269        Self {
270            sampler: NonNull::new(sampler).unwrap(),
271        }
272    }
273
274    /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration"
275    /// <https://arxiv.org/abs/1904.09751>.
276    ///
277    /// # Panics
278    ///
279    /// Panics if llama.cpp returns a null pointer.
280    ///
281    /// # Example:
282    /// ```rust
283    /// use llama_cpp_4::token::{
284    ///    LlamaToken,
285    ///    data::LlamaTokenData,
286    ///    data_array::LlamaTokenDataArray
287    /// };
288    /// use llama_cpp_4::sampling::LlamaSampler;
289    ///
290    /// let mut data_array = LlamaTokenDataArray::new(vec![
291    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
292    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
293    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
294    ///     LlamaTokenData::new(LlamaToken(3), 3., 0.),
295    /// ], false);
296    ///
297    /// data_array.apply_sampler(&mut LlamaSampler::top_k(2));
298    ///
299    /// assert_eq!(data_array.data.len(), 2);
300    /// assert_eq!(data_array.data[0].id(), LlamaToken(3));
301    /// assert_eq!(data_array.data[1].id(), LlamaToken(2));
302    /// ```
303    #[must_use]
304    pub fn top_k(k: i32) -> Self {
305        let sampler = unsafe { llama_sampler_init_top_k(k) };
306        Self {
307            sampler: NonNull::new(sampler).unwrap(),
308        }
309    }
310
311    /// Locally Typical Sampling implementation described in the paper <https://arxiv.org/abs/2202.00666>.
312    ///
313    /// # Panics
314    ///
315    /// Panics if llama.cpp returns a null pointer.
316    #[must_use]
317    pub fn typical(p: f32, min_keep: usize) -> Self {
318        let sampler = unsafe { llama_sampler_init_typical(p, min_keep) };
319        Self {
320            sampler: NonNull::new(sampler).unwrap(),
321        }
322    }
323
324    /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration"
325    /// <https://arxiv.org/abs/1904.09751>.
326    ///
327    /// # Panics
328    ///
329    /// Panics if llama.cpp returns a null pointer.
330    #[must_use]
331    pub fn top_p(p: f32, min_keep: usize) -> Self {
332        let sampler = unsafe { llama_sampler_init_top_p(p, min_keep) };
333        Self {
334            sampler: NonNull::new(sampler).unwrap(),
335        }
336    }
337
338    /// Minimum P sampling as described in <https://github.com/ggerganov/llama.cpp/pull/3841>.
339    ///
340    /// # Panics
341    ///
342    /// Panics if llama.cpp returns a null pointer.
343    #[must_use]
344    pub fn min_p(p: f32, min_keep: usize) -> Self {
345        let sampler = unsafe { llama_sampler_init_min_p(p, min_keep) };
346        Self {
347            sampler: NonNull::new(sampler).unwrap(),
348        }
349    }
350
351    /// XTC sampler as described in <https://github.com/oobabooga/text-generation-webui/pull/6335>.
352    ///
353    /// # Panics
354    ///
355    /// Panics if llama.cpp returns a null pointer.
356    #[must_use]
357    pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Self {
358        let sampler = unsafe { llama_sampler_init_xtc(p, t, min_keep, seed) };
359        Self {
360            sampler: NonNull::new(sampler).unwrap(),
361        }
362    }
363
364    /// Grammar sampler
365    ///
366    /// # Panics
367    /// - If either of `grammar_str` or `grammar_root` contain null bytes.
368    /// - If llama.cpp returns a null pointer.
369    #[must_use]
370    pub fn grammar(model: &LlamaModel, grammar_str: &str, grammar_root: &str) -> Self {
371        let grammar_str = CString::new(grammar_str).unwrap();
372        let grammar_root = CString::new(grammar_root).unwrap();
373
374        let sampler = unsafe {
375            llama_sampler_init_grammar(
376                model.get_vocab().vocab.as_ref(),
377                grammar_str.as_ptr(),
378                grammar_root.as_ptr(),
379            )
380        };
381        Self {
382            sampler: NonNull::new(sampler).unwrap(),
383        }
384    }
385
386    /// DRY sampler, designed by p-e-w, as described in:
387    /// <https://github.com/oobabooga/text-generation-webui/pull/5677>, porting Koboldcpp
388    /// implementation authored by pi6am: <https://github.com/LostRuins/koboldcpp/pull/982>
389    ///
390    /// # Panics
391    /// - If any string in `seq_breakers` contains null bytes.
392    /// - If llama.cpp returns a null pointer.
393    #[allow(clippy::too_many_arguments)]
394    #[must_use]
395    pub fn dry(
396        &self,
397        model: &LlamaModel,
398        n_ctx_train: i32,
399        multiplier: f32,
400        base: f32,
401        allowed_length: i32,
402        penalty_last_n: i32,
403        seq_breakers: impl IntoIterator<Item = impl AsRef<[u8]>>,
404    ) -> Self {
405        let seq_breakers: Vec<CString> = seq_breakers
406            .into_iter()
407            .map(|s| CString::new(s.as_ref()).unwrap())
408            .collect();
409        // CString::as_ptr() returns *const c_char, which matches what the binding
410        // expects on every platform (signed on macOS/x86 Linux, unsigned on musl ARM).
411        let mut seq_breaker_pointers: Vec<*const c_char> =
412            seq_breakers.iter().map(|s| s.as_ptr()).collect();
413
414        let sampler = unsafe {
415            llama_sampler_init_dry(
416                model.get_vocab().vocab.as_ref(),
417                n_ctx_train,
418                multiplier,
419                base,
420                allowed_length,
421                penalty_last_n,
422                seq_breaker_pointers.as_mut_ptr(),
423                seq_breaker_pointers.len(),
424            )
425        };
426
427        Self {
428            sampler: NonNull::new(sampler).unwrap(),
429        }
430    }
431
432    /// Penalizes tokens for being present in the context.
433    ///
434    /// Parameters:
435    /// - `n_vocab`: [`LlamaModel::n_vocab`]
436    /// - `special_eos_id`: [`LlamaModel::token_eos`]
437    /// - `linefeed_id`: [`LlamaModel::token_nl`]
438    /// - `penalty_last_n`: last n tokens to penalize (0 = disable penalty, -1 = context size)
439    ///
440    /// # Panics
441    ///
442    /// Panics if llama.cpp returns a null pointer.
443    #[allow(clippy::too_many_arguments)]
444    #[must_use]
445    pub fn penalties(
446        n_vocab: i32,
447        special_eos_id: f32,
448        linefeed_id: f32,
449        penalty_last_n: f32,
450        // penalty_repeat: f32,
451        // penalty_freq: f32,
452        // penalty_present: f32,
453        // penalize_nl: bool,
454        // ignore_eos: bool,
455    ) -> Self {
456        let sampler = unsafe {
457            llama_sampler_init_penalties(
458                n_vocab,
459                special_eos_id,
460                linefeed_id,
461                penalty_last_n,
462                // penalty_repeat,
463                // penalty_freq,
464                // penalty_present,
465                // penalize_nl,
466                // ignore_eos,
467            )
468        };
469        Self {
470            sampler: NonNull::new(sampler).unwrap(),
471        }
472    }
473
474    /// Same as [`Self::penalties`], but with `n_vocab`, `special_eos_id`, and `linefeed_id`
475    /// initialized from `model`, `penalize_nl = false`, and `ignore_eos = true`.
476    ///
477    /// Parameters:
478    /// - `model`: The model's tokenizer to use to initialize the sampler.
479    /// - `penalty_last_n`: last n tokens to penalize (0 = disable penalty, -1 = context size)
480    ///
481    /// # Panics
482    ///
483    /// Panics if llama.cpp returns a null pointer.
484    #[must_use]
485    pub fn penalties_simple(
486        model: &LlamaModel,
487        penalty_last_n: i32,
488        // penalty_repeat: f32,
489        // penalty_freq: f32,
490        // penalty_present: f32,
491    ) -> Self {
492        Self::penalties(
493            model.n_vocab(),
494            #[allow(clippy::cast_precision_loss)]
495            {
496                model.token_eos().0 as f32
497            },
498            #[allow(clippy::cast_precision_loss)]
499            {
500                model.token_nl().0 as f32
501            },
502            #[allow(clippy::cast_precision_loss)]
503            {
504                penalty_last_n as f32
505            },
506            // penalty_repeat,
507            // penalty_freq,
508            // penalty_present,
509            // false,
510            // true,
511        )
512    }
513
514    /// Mirostat 1.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
515    ///
516    /// # Panics
517    ///
518    /// Panics if llama.cpp returns a null pointer.
519    ///
520    /// # Parameters:
521    /// - `n_vocab`: [`LlamaModel::n_vocab`]
522    /// - `seed`: Seed to initialize random generation with.
523    /// - `tau`: The target cross-entropy (or surprise) value you want to achieve for the
524    ///   generated text. A higher value corresponds to more surprising or less predictable text,
525    ///   while a lower value corresponds to less surprising or more predictable text.
526    /// - `eta`: The learning rate used to update `mu` based on the error between the target and
527    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
528    ///   updated more quickly, while a smaller learning rate will result in slower updates.
529    /// - `m`: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary
530    ///   value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`.
531    ///   In the paper, they use `m = 100`, but you can experiment with different values to see how
532    ///   it affects the performance of the algorithm.
533    #[must_use]
534    pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self {
535        let sampler = unsafe { llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m) };
536        Self {
537            sampler: NonNull::new(sampler).unwrap(),
538        }
539    }
540
541    /// Mirostat 2.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
542    ///
543    /// # Panics
544    ///
545    /// Panics if llama.cpp returns a null pointer.
546    ///
547    /// # Parameters:
548    /// - `seed`: Seed to initialize random generation with.
549    /// - `tau`: The target cross-entropy (or surprise) value you want to achieve for the
550    ///   generated text. A higher value corresponds to more surprising or less predictable text,
551    ///   while a lower value corresponds to less surprising or more predictable text.
552    /// - `eta`: The learning rate used to update `mu` based on the error between the target and
553    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
554    ///   updated more quickly, while a smaller learning rate will result in slower updates.
555    #[must_use]
556    pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self {
557        let sampler = unsafe { llama_sampler_init_mirostat_v2(seed, tau, eta) };
558        Self {
559            sampler: NonNull::new(sampler).unwrap(),
560        }
561    }
562
563    /// Selects a token at random based on each token's probabilities.
564    ///
565    /// # Panics
566    ///
567    /// Panics if llama.cpp returns a null pointer.
568    #[must_use]
569    pub fn dist(seed: u32) -> Self {
570        let sampler = unsafe { llama_sampler_init_dist(seed) };
571        Self {
572            sampler: NonNull::new(sampler).unwrap(),
573        }
574    }
575
576    /// Selects the most likely token.
577    ///
578    /// # Panics
579    ///
580    /// Panics if llama.cpp returns a null pointer.
581    ///
582    /// # Example:
583    /// ```rust
584    /// use llama_cpp_4::token::{
585    ///    LlamaToken,
586    ///    data::LlamaTokenData,
587    ///    data_array::LlamaTokenDataArray
588    /// };
589    /// use llama_cpp_4::sampling::LlamaSampler;
590    ///
591    /// let mut data_array = LlamaTokenDataArray::new(vec![
592    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
593    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
594    /// ], false);
595    ///
596    /// data_array.apply_sampler(&mut LlamaSampler::greedy());
597    ///
598    /// assert_eq!(data_array.data.len(), 2);
599    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(1)));
600    /// ```
601    #[must_use]
602    pub fn greedy() -> Self {
603        let sampler = unsafe { llama_sampler_init_greedy() };
604        Self {
605            sampler: NonNull::new(sampler).unwrap(),
606        }
607    }
608
609    /// Top-N sigma sampling.
610    ///
611    /// Keeps tokens within N standard deviations of the maximum logit.
612    ///
613    /// # Panics
614    ///
615    /// Panics if llama.cpp returns a null pointer.
616    #[must_use]
617    pub fn top_n_sigma(n: f32) -> Self {
618        let sampler = unsafe { llama_sampler_init_top_n_sigma(n) };
619        Self {
620            sampler: NonNull::new(sampler).unwrap(),
621        }
622    }
623
624    /// Adaptive P sampling.
625    ///
626    /// # Panics
627    ///
628    /// Panics if llama.cpp returns a null pointer.
629    ///
630    /// # Parameters
631    /// - `target`: Target probability.
632    /// - `decay`: Decay rate.
633    /// - `seed`: Random seed.
634    #[must_use]
635    pub fn adaptive_p(target: f32, decay: f32, seed: u32) -> Self {
636        let sampler = unsafe { llama_sampler_init_adaptive_p(target, decay, seed) };
637        Self {
638            sampler: NonNull::new(sampler).unwrap(),
639        }
640    }
641
642    /// Logit bias sampler.
643    ///
644    /// Applies additive bias to specific token logits before sampling.
645    ///
646    /// # Panics
647    ///
648    /// Panics if llama.cpp returns a null pointer.
649    ///
650    /// # Parameters
651    /// - `n_vocab`: Number of tokens in the vocabulary ([`LlamaModel::n_vocab`]).
652    /// - `biases`: Slice of `(token_id, bias)` pairs.
653    #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
654    #[must_use]
655    pub fn logit_bias(n_vocab: i32, biases: &[(LlamaToken, f32)]) -> Self {
656        let logit_biases: Vec<llama_logit_bias> = biases
657            .iter()
658            .map(|(token, bias)| llama_logit_bias {
659                token: token.0,
660                bias: *bias,
661            })
662            .collect();
663
664        let sampler = unsafe {
665            llama_sampler_init_logit_bias(
666                n_vocab,
667                logit_biases.len() as i32,
668                logit_biases.as_ptr(),
669            )
670        };
671        Self {
672            sampler: NonNull::new(sampler).unwrap(),
673        }
674    }
675
676    /// Infill sampler.
677    ///
678    /// Reorders token probabilities for fill-in-the-middle tasks.
679    ///
680    /// # Panics
681    ///
682    /// Panics if llama.cpp returns a null pointer.
683    #[must_use]
684    pub fn infill(model: &LlamaModel) -> Self {
685        let sampler =
686            unsafe { llama_sampler_init_infill(model.get_vocab().vocab.as_ref()) };
687        Self {
688            sampler: NonNull::new(sampler).unwrap(),
689        }
690    }
691
692    /// Get the seed of the sampler.
693    ///
694    /// Returns `LLAMA_DEFAULT_SEED` if the sampler is not seeded.
695    #[must_use]
696    pub fn get_seed(&self) -> u32 {
697        unsafe { llama_sampler_get_seed(self.sampler.as_ptr()) }
698    }
699
700    /// Get the name of the sampler.
701    ///
702    /// # Panics
703    ///
704    /// Panics if the name is not valid UTF-8.
705    #[must_use]
706    pub fn name(&self) -> String {
707        let c_str = unsafe { llama_sampler_name(self.sampler.as_ptr()) };
708        let c_str = unsafe { std::ffi::CStr::from_ptr(c_str) };
709        c_str.to_str().expect("sampler name is not valid UTF-8").to_owned()
710    }
711
712    /// Reset the sampler state (e.g. grammar, repetition penalties).
713    pub fn reset(&mut self) {
714        unsafe { llama_sampler_reset(self.sampler.as_ptr()) }
715    }
716
717    /// Get the number of samplers in a chain.
718    ///
719    /// Returns 0 if this sampler is not a chain.
720    #[must_use]
721    pub fn chain_n(&self) -> i32 {
722        unsafe { llama_sampler_chain_n(self.sampler.as_ptr()) }
723    }
724
725    /// Remove and return the sampler at position `i` from a chain.
726    ///
727    /// The returned sampler is owned by the caller and will be freed on drop.
728    ///
729    /// # Panics
730    ///
731    /// Panics if `i` is out of range or if llama.cpp returns a null pointer.
732    #[must_use]
733    pub fn chain_remove(&mut self, i: i32) -> Self {
734        let sampler = unsafe { llama_sampler_chain_remove(self.sampler.as_ptr(), i) };
735        Self {
736            sampler: NonNull::new(sampler).expect("chain_remove returned null"),
737        }
738    }
739
740    /// Grammar sampler with lazy activation.
741    ///
742    /// The grammar is only activated when one of the trigger words or trigger tokens is encountered.
743    ///
744    /// # Panics
745    /// - If `grammar_str` or `grammar_root` contain null bytes.
746    /// - If any trigger word contains null bytes.
747    /// - If llama.cpp returns a null pointer.
748    #[must_use]
749    pub fn grammar_lazy(
750        model: &LlamaModel,
751        grammar_str: &str,
752        grammar_root: &str,
753        trigger_words: &[&str],
754        trigger_tokens: &[LlamaToken],
755    ) -> Self {
756        let grammar_str = CString::new(grammar_str).unwrap();
757        let grammar_root = CString::new(grammar_root).unwrap();
758        let trigger_cstrings: Vec<CString> = trigger_words
759            .iter()
760            .map(|w| CString::new(*w).unwrap())
761            .collect();
762        let mut trigger_ptrs: Vec<*const c_char> =
763            trigger_cstrings.iter().map(|s| s.as_ptr()).collect();
764
765        let sampler = unsafe {
766            llama_sampler_init_grammar_lazy(
767                model.get_vocab().vocab.as_ref(),
768                grammar_str.as_ptr(),
769                grammar_root.as_ptr(),
770                trigger_ptrs.as_mut_ptr(),
771                trigger_ptrs.len(),
772                trigger_tokens.as_ptr().cast(),
773                trigger_tokens.len(),
774            )
775        };
776        Self {
777            sampler: NonNull::new(sampler).unwrap(),
778        }
779    }
780
781    /// Grammar sampler with lazy activation via regex patterns.
782    ///
783    /// The grammar is only activated when one of the trigger patterns or trigger tokens matches.
784    ///
785    /// # Panics
786    /// - If `grammar_str` or `grammar_root` contain null bytes.
787    /// - If any trigger pattern contains null bytes.
788    /// - If llama.cpp returns a null pointer.
789    #[must_use]
790    pub fn grammar_lazy_patterns(
791        model: &LlamaModel,
792        grammar_str: &str,
793        grammar_root: &str,
794        trigger_patterns: &[&str],
795        trigger_tokens: &[LlamaToken],
796    ) -> Self {
797        let grammar_str = CString::new(grammar_str).unwrap();
798        let grammar_root = CString::new(grammar_root).unwrap();
799        let pattern_cstrings: Vec<CString> = trigger_patterns
800            .iter()
801            .map(|w| CString::new(*w).unwrap())
802            .collect();
803        let mut pattern_ptrs: Vec<*const c_char> =
804            pattern_cstrings.iter().map(|s| s.as_ptr()).collect();
805
806        let sampler = unsafe {
807            llama_sampler_init_grammar_lazy_patterns(
808                model.get_vocab().vocab.as_ref(),
809                grammar_str.as_ptr(),
810                grammar_root.as_ptr(),
811                pattern_ptrs.as_mut_ptr(),
812                pattern_ptrs.len(),
813                trigger_tokens.as_ptr().cast(),
814                trigger_tokens.len(),
815            )
816        };
817        Self {
818            sampler: NonNull::new(sampler).unwrap(),
819        }
820    }
821
822    /// Clone this sampler.
823    ///
824    /// Creates an independent copy of this sampler with the same state.
825    ///
826    /// # Panics
827    ///
828    /// Panics if llama.cpp returns a null pointer.
829    #[must_use]
830    pub fn clone_sampler(&self) -> Self {
831        let sampler = unsafe { llama_sampler_clone(self.sampler.as_ptr()) };
832        Self {
833            sampler: NonNull::new(sampler).expect("sampler_clone returned null"),
834        }
835    }
836
837    /// Print sampler performance data.
838    pub fn perf_print(&self) {
839        unsafe { llama_cpp_sys_4::llama_perf_sampler_print(self.sampler.as_ptr()) }
840    }
841
842    /// Reset sampler performance counters.
843    pub fn perf_reset(&mut self) {
844        unsafe { llama_cpp_sys_4::llama_perf_sampler_reset(self.sampler.as_ptr()) }
845    }
846
847    /// Get sampler performance data.
848    #[must_use]
849    pub fn perf_data(&self) -> llama_cpp_sys_4::llama_perf_sampler_data {
850        unsafe { llama_cpp_sys_4::llama_perf_sampler(self.sampler.as_ptr()) }
851    }
852
853    /// Get a non-owning reference to the `i`th sampler in a chain.
854    ///
855    /// # Safety
856    ///
857    /// The returned pointer is owned by the chain. Do not free it or use it
858    /// after the chain is dropped or modified.
859    #[must_use]
860    pub unsafe fn chain_get_ptr(&self, i: i32) -> *mut llama_sampler {
861        llama_cpp_sys_4::llama_sampler_chain_get(self.sampler.as_ptr(), i)
862    }
863
864    /// Create a sampler from a raw interface and context.
865    ///
866    /// # Safety
867    ///
868    /// The caller must ensure that `iface` and `ctx` are valid and that the
869    /// interface functions properly manage the context lifetime.
870    ///
871    /// # Panics
872    ///
873    /// Panics if llama.cpp returns a null pointer.
874    #[must_use]
875    pub unsafe fn from_raw(
876        iface: *mut llama_cpp_sys_4::llama_sampler_i,
877        ctx: llama_cpp_sys_4::llama_sampler_context_t,
878    ) -> Self {
879        let sampler = llama_cpp_sys_4::llama_sampler_init(iface, ctx);
880        Self {
881            sampler: NonNull::new(sampler).expect("sampler_init returned null"),
882        }
883    }
884
885    /// Creates a new instance of `LlamaSampler` with common sampling parameters.
886    ///
887    /// This function initializes a `LlamaSampler` using default values from `common_sampler_params`
888    /// and configures it with common settings such as `top_k`, `top_p`, `temperature`, and `seed` values.
889    ///
890    /// # Panics
891    ///
892    /// Panics if llama.cpp returns a null pointer.
893    ///
894    /// # Returns
895    /// A `LlamaSampler` instance configured with the common sampling parameters.
896    #[must_use]
897    pub fn common() -> Self {
898        let params = common_sampler_params::default();
899
900        let sampler = unsafe {
901            let mut sparams = llama_sampler_chain_default_params();
902            sparams.no_perf = false;
903
904            let smpl = llama_sampler_chain_init(sparams);
905
906            llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.top_k));
907            llama_sampler_chain_add(
908                smpl,
909                #[allow(clippy::cast_sign_loss)]
910                llama_sampler_init_top_p(params.top_p, params.min_keep as usize),
911            );
912            llama_sampler_chain_add(smpl, llama_sampler_init_temp(params.temp));
913            #[allow(clippy::cast_sign_loss)]
914            llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.seed));
915
916            smpl
917        };
918
919        Self {
920            sampler: NonNull::new(sampler).unwrap(),
921        }
922    }
923}
924
925impl Drop for LlamaSampler {
926    fn drop(&mut self) {
927        unsafe {
928            llama_sampler_free(self.sampler.as_ptr());
929        }
930    }
931}