// llama_cpp_4/sampling.rs

//! Safe wrapper around `llama_sampler`.
2
3use std::borrow::Borrow;
4use std::ffi::{c_char, CString};
5use std::fmt::{Debug, Formatter};
6use std::ptr::NonNull;
7
8use llama_cpp_sys_4::{
9    common::common_sampler_params, llama_logit_bias, llama_sampler, llama_sampler_accept,
10    llama_sampler_chain_add, llama_sampler_chain_default_params, llama_sampler_chain_init,
11    llama_sampler_chain_n, llama_sampler_chain_remove, llama_sampler_clone, llama_sampler_free,
12    llama_sampler_get_seed, llama_sampler_init_adaptive_p, llama_sampler_init_dist,
13    llama_sampler_init_dry, llama_sampler_init_grammar, llama_sampler_init_grammar_lazy,
14    llama_sampler_init_grammar_lazy_patterns, llama_sampler_init_greedy, llama_sampler_init_infill,
15    llama_sampler_init_logit_bias, llama_sampler_init_min_p, llama_sampler_init_mirostat,
16    llama_sampler_init_mirostat_v2, llama_sampler_init_penalties, llama_sampler_init_temp,
17    llama_sampler_init_temp_ext, llama_sampler_init_top_k, llama_sampler_init_top_n_sigma,
18    llama_sampler_init_top_p, llama_sampler_init_typical, llama_sampler_init_xtc,
19    llama_sampler_name, llama_sampler_reset, llama_sampler_sample,
20};
21
22use crate::context::LlamaContext;
23use crate::model::LlamaModel;
24use crate::token::data_array::LlamaTokenDataArray;
25use crate::token::LlamaToken;
26
/// A safe wrapper around `llama_sampler`.
pub struct LlamaSampler {
    // Owned pointer to the underlying C sampler; freed by the `Drop` impl via
    // `llama_sampler_free`.
    pub(crate) sampler: NonNull<llama_sampler>,
}
31
32impl Debug for LlamaSampler {
33    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
34        f.debug_struct("LlamaSamplerChain").finish()
35    }
36}
#[derive(Debug, Clone)]
#[allow(
    missing_docs,
    clippy::struct_excessive_bools,
    clippy::module_name_repetitions,
    dead_code
)]
/// Common sampling parameters used when building a sampler chain.
pub struct LlamaSamplerParams {
    // Keep only the k most likely tokens.
    top_k: i32,
    // Nucleus (top-p) probability mass cutoff.
    top_p: f32,
    // Sampling temperature; lower = more deterministic.
    temp: f32,
    // RNG seed for stochastic samplers.
    seed: u32,
}
50
51impl LlamaSamplerParams {
52    /// Set the seed of the context
53    ///
54    /// # Examples
55    ///
56    /// ```rust
57    /// use llama_cpp_4::sampling::LlamaSamplerParams;
58    /// let params = LlamaSamplerParams::default();
59    /// let params = params.with_seed(1234);
60    /// assert_eq!(params.seed(), 1234);
61    /// ```
62    #[must_use]
63    pub fn with_seed(mut self, seed: u32) -> Self {
64        self.seed = seed;
65        self
66    }
67
68    /// Get the seed of the context
69    ///
70    /// # Examples
71    ///
72    /// ```rust
73    /// use llama_cpp_4::sampling::LlamaSamplerParams;
74    /// let params = LlamaSamplerParams::default()
75    ///     .with_seed(1234);
76    /// assert_eq!(params.seed(), 1234);
77    /// ```
78    #[must_use]
79    pub fn seed(&self) -> u32 {
80        self.seed
81    }
82}
83
84impl Default for LlamaSamplerParams {
85    fn default() -> Self {
86        Self {
87            top_k: 50,
88            top_p: 0.9,
89            temp: 0.8,
90            seed: 1234,
91        }
92    }
93}
94
95impl Default for LlamaSampler {
96    fn default() -> Self {
97        Self::new()
98    }
99}
100
101impl LlamaSampler {
102    /// Create new sampler with default params.
103    ///
104    /// # Panics
105    ///
106    /// Panics if llama.cpp returns a null pointer.
107    #[must_use]
108    pub fn new() -> Self {
109        let sparams = unsafe { llama_sampler_chain_default_params() };
110
111        Self {
112            sampler: NonNull::new(unsafe { llama_sampler_chain_init(sparams) }).unwrap(),
113        }
114    }
115
116    /// Sample and accept a token from the idx-th output of the last evaluation
117    #[must_use]
118    pub fn sample(&self, ctx: &LlamaContext, idx: i32) -> LlamaToken {
119        let token =
120            unsafe { llama_sampler_sample(self.sampler.as_ptr(), ctx.context.as_ptr(), idx) };
121
122        LlamaToken(token)
123    }
124
    /// Applies this sampler to a [`LlamaTokenDataArray`].
    pub fn apply(&mut self, data_array: &mut LlamaTokenDataArray) {
        // Delegates to the data array, which drives the actual FFI call.
        data_array.apply_sampler(self);
    }
129
    /// Accepts a token from the sampler, possibly updating the internal state of certain samplers
    /// (e.g. grammar, repetition, etc.)
    pub fn accept(&mut self, token: LlamaToken) {
        // SAFETY: `self.sampler` is a valid pointer while `&mut self` is held.
        unsafe { llama_sampler_accept(self.sampler.as_ptr(), token.0) }
    }
135
136    /// Accepts several tokens from the sampler or context, possibly updating the internal state of
137    /// certain samplers (e.g. grammar, repetition, etc.)
138    pub fn accept_many(&mut self, tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>) {
139        for token in tokens {
140            unsafe { llama_sampler_accept(self.sampler.as_ptr(), token.borrow().0) }
141        }
142    }
143
    /// Accepts several tokens from the sampler or context, possibly updating the internal state of
    /// certain samplers (e.g. grammar, repetition, etc.)
    ///
    /// Builder-style variant of [`Self::accept_many`]: consumes and returns `self`.
    #[must_use]
    pub fn with_tokens(
        mut self,
        tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
    ) -> Self {
        self.accept_many(tokens);
        self
    }
154
155    /// Combines a list of samplers into a single sampler that applies each component sampler one
156    /// after another.
157    ///
158    /// If you are using a chain to select a token, the chain should always end with one of
159    /// [`LlamaSampler::greedy`], [`LlamaSampler::dist`], [`LlamaSampler::mirostat`], and
160    /// [`LlamaSampler::mirostat_v2`].
161    ///
162    /// # Panics
163    ///
164    /// Panics if llama.cpp returns a null pointer.
165    #[must_use]
166    pub fn chain(samplers: impl IntoIterator<Item = Self>, no_perf: bool) -> Self {
167        unsafe {
168            let mut params = llama_sampler_chain_default_params();
169            params.no_perf = no_perf;
170            let chain = llama_sampler_chain_init(params);
171
172            for sampler in samplers {
173                llama_sampler_chain_add(chain, sampler.sampler.as_ptr());
174
175                // Do not call `llama_sampler_free` on the sampler, as the internal sampler is now
176                // owned by the chain
177                std::mem::forget(sampler);
178            }
179
180            Self {
181                sampler: NonNull::new(chain).unwrap(),
182            }
183        }
184    }
185
    /// Same as [`Self::chain`] with `no_perf = false`.
    ///
    /// # Panics
    ///
    /// Panics if llama.cpp returns a null pointer.
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_4::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_4::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::chain_simple([
    ///     LlamaSampler::temp(0.5),
    ///     LlamaSampler::greedy(),
    /// ]));
    ///
    /// assert_eq!(data_array.data[0].logit(), 0.);
    /// assert_eq!(data_array.data[1].logit(), 2.);
    /// assert_eq!(data_array.data[2].logit(), 4.);
    ///
    /// assert_eq!(data_array.data.len(), 3);
    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(2)));
    /// ```
    #[must_use]
    pub fn chain_simple(samplers: impl IntoIterator<Item = Self>) -> Self {
        // `no_perf = false` keeps performance counters enabled for the chain.
        Self::chain(samplers, false)
    }
223
224    /// Updates the logits `l_i`' = `l_i/t`. When `t <= 0.0`, the maximum logit is kept at its original
225    /// value, the rest are set to -inf.
226    ///
227    /// # Panics
228    ///
229    /// Panics if llama.cpp returns a null pointer.
230    ///
231    /// # Example:
232    /// ```rust
233    /// use llama_cpp_4::token::{
234    ///    LlamaToken,
235    ///    data::LlamaTokenData,
236    ///    data_array::LlamaTokenDataArray
237    /// };
238    /// use llama_cpp_4::sampling::LlamaSampler;
239    ///
240    /// let mut data_array = LlamaTokenDataArray::new(vec![
241    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
242    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
243    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
244    /// ], false);
245    ///
246    /// data_array.apply_sampler(&mut LlamaSampler::temp(0.5));
247    ///
248    /// assert_eq!(data_array.data[0].logit(), 0.);
249    /// assert_eq!(data_array.data[1].logit(), 2.);
250    /// assert_eq!(data_array.data[2].logit(), 4.);
251    /// ```
252    #[must_use]
253    pub fn temp(t: f32) -> Self {
254        let sampler = unsafe { llama_sampler_init_temp(t) };
255        Self {
256            sampler: NonNull::new(sampler).unwrap(),
257        }
258    }
259
260    /// Dynamic temperature implementation (a.k.a. entropy) described in the paper
261    /// <https://arxiv.org/abs/2309.02772>.
262    ///
263    /// # Panics
264    ///
265    /// Panics if llama.cpp returns a null pointer.
266    #[must_use]
267    pub fn temp_ext(t: f32, delta: f32, exponent: f32) -> Self {
268        let sampler = unsafe { llama_sampler_init_temp_ext(t, delta, exponent) };
269        Self {
270            sampler: NonNull::new(sampler).unwrap(),
271        }
272    }
273
274    /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration"
275    /// <https://arxiv.org/abs/1904.09751>.
276    ///
277    /// # Panics
278    ///
279    /// Panics if llama.cpp returns a null pointer.
280    ///
281    /// # Example:
282    /// ```rust
283    /// use llama_cpp_4::token::{
284    ///    LlamaToken,
285    ///    data::LlamaTokenData,
286    ///    data_array::LlamaTokenDataArray
287    /// };
288    /// use llama_cpp_4::sampling::LlamaSampler;
289    ///
290    /// let mut data_array = LlamaTokenDataArray::new(vec![
291    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
292    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
293    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
294    ///     LlamaTokenData::new(LlamaToken(3), 3., 0.),
295    /// ], false);
296    ///
297    /// data_array.apply_sampler(&mut LlamaSampler::top_k(2));
298    ///
299    /// assert_eq!(data_array.data.len(), 2);
300    /// assert_eq!(data_array.data[0].id(), LlamaToken(3));
301    /// assert_eq!(data_array.data[1].id(), LlamaToken(2));
302    /// ```
303    #[must_use]
304    pub fn top_k(k: i32) -> Self {
305        let sampler = unsafe { llama_sampler_init_top_k(k) };
306        Self {
307            sampler: NonNull::new(sampler).unwrap(),
308        }
309    }
310
311    /// Locally Typical Sampling implementation described in the paper <https://arxiv.org/abs/2202.00666>.
312    ///
313    /// # Panics
314    ///
315    /// Panics if llama.cpp returns a null pointer.
316    #[must_use]
317    pub fn typical(p: f32, min_keep: usize) -> Self {
318        let sampler = unsafe { llama_sampler_init_typical(p, min_keep) };
319        Self {
320            sampler: NonNull::new(sampler).unwrap(),
321        }
322    }
323
324    /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration"
325    /// <https://arxiv.org/abs/1904.09751>.
326    ///
327    /// # Panics
328    ///
329    /// Panics if llama.cpp returns a null pointer.
330    #[must_use]
331    pub fn top_p(p: f32, min_keep: usize) -> Self {
332        let sampler = unsafe { llama_sampler_init_top_p(p, min_keep) };
333        Self {
334            sampler: NonNull::new(sampler).unwrap(),
335        }
336    }
337
338    /// Minimum P sampling as described in <https://github.com/ggerganov/llama.cpp/pull/3841>.
339    ///
340    /// # Panics
341    ///
342    /// Panics if llama.cpp returns a null pointer.
343    #[must_use]
344    pub fn min_p(p: f32, min_keep: usize) -> Self {
345        let sampler = unsafe { llama_sampler_init_min_p(p, min_keep) };
346        Self {
347            sampler: NonNull::new(sampler).unwrap(),
348        }
349    }
350
351    /// XTC sampler as described in <https://github.com/oobabooga/text-generation-webui/pull/6335>.
352    ///
353    /// # Panics
354    ///
355    /// Panics if llama.cpp returns a null pointer.
356    #[must_use]
357    pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Self {
358        let sampler = unsafe { llama_sampler_init_xtc(p, t, min_keep, seed) };
359        Self {
360            sampler: NonNull::new(sampler).unwrap(),
361        }
362    }
363
364    /// Grammar sampler
365    ///
366    /// # Panics
367    /// - If either of `grammar_str` or `grammar_root` contain null bytes.
368    /// - If llama.cpp returns a null pointer.
369    #[must_use]
370    pub fn grammar(model: &LlamaModel, grammar_str: &str, grammar_root: &str) -> Self {
371        let grammar_str = CString::new(grammar_str).unwrap();
372        let grammar_root = CString::new(grammar_root).unwrap();
373
374        let sampler = unsafe {
375            llama_sampler_init_grammar(
376                model.get_vocab().vocab.as_ref(),
377                grammar_str.as_ptr(),
378                grammar_root.as_ptr(),
379            )
380        };
381        Self {
382            sampler: NonNull::new(sampler).unwrap(),
383        }
384    }
385
386    /// DRY sampler, designed by p-e-w, as described in:
387    /// <https://github.com/oobabooga/text-generation-webui/pull/5677>, porting Koboldcpp
388    /// implementation authored by pi6am: <https://github.com/LostRuins/koboldcpp/pull/982>
389    ///
390    /// # Panics
391    /// - If any string in `seq_breakers` contains null bytes.
392    /// - If llama.cpp returns a null pointer.
393    #[allow(clippy::too_many_arguments)]
394    #[must_use]
395    pub fn dry(
396        &self,
397        model: &LlamaModel,
398        n_ctx_train: i32,
399        multiplier: f32,
400        base: f32,
401        allowed_length: i32,
402        penalty_last_n: i32,
403        seq_breakers: impl IntoIterator<Item = impl AsRef<[u8]>>,
404    ) -> Self {
405        let seq_breakers: Vec<CString> = seq_breakers
406            .into_iter()
407            .map(|s| CString::new(s.as_ref()).unwrap())
408            .collect();
409        // CString::as_ptr() returns *const c_char, which matches what the binding
410        // expects on every platform (signed on macOS/x86 Linux, unsigned on musl ARM).
411        let mut seq_breaker_pointers: Vec<*const c_char> =
412            seq_breakers.iter().map(|s| s.as_ptr()).collect();
413
414        let sampler = unsafe {
415            llama_sampler_init_dry(
416                model.get_vocab().vocab.as_ref(),
417                n_ctx_train,
418                multiplier,
419                base,
420                allowed_length,
421                penalty_last_n,
422                seq_breaker_pointers.as_mut_ptr(),
423                seq_breaker_pointers.len(),
424            )
425        };
426
427        Self {
428            sampler: NonNull::new(sampler).unwrap(),
429        }
430    }
431
432    /// Penalizes tokens for being present in the context.
433    ///
434    /// Parameters:
435    /// - `penalty_last_n`: last n tokens to penalize (0 = disable penalty, -1 = context size)
436    /// - `penalty_repeat`: repetition penalty (1.0 = disabled, >1.0 = penalize repeats)
437    /// - `penalty_freq`: frequency penalty (0.0 = disabled)
438    /// - `penalty_present`: presence penalty (0.0 = disabled)
439    ///
440    /// # Panics
441    ///
442    /// Panics if llama.cpp returns a null pointer.
443    #[allow(clippy::too_many_arguments)]
444    #[must_use]
445    pub fn penalties(
446        penalty_last_n: i32,
447        penalty_repeat: f32,
448        penalty_freq: f32,
449        penalty_present: f32,
450    ) -> Self {
451        let sampler = unsafe {
452            llama_sampler_init_penalties(
453                penalty_last_n,
454                penalty_repeat,
455                penalty_freq,
456                penalty_present,
457            )
458        };
459        Self {
460            sampler: NonNull::new(sampler).unwrap(),
461        }
462    }
463
464    /// Same as [`Self::penalties`] with sensible defaults:
465    /// `penalty_freq = 0.0` and `penalty_present = 0.0`.
466    ///
467    /// Parameters:
468    /// - `penalty_last_n`: last n tokens to penalize (0 = disable, -1 = context size)
469    /// - `penalty_repeat`: repetition penalty (1.0 = disabled)
470    ///
471    /// # Panics
472    ///
473    /// Panics if llama.cpp returns a null pointer.
474    #[must_use]
475    pub fn penalties_simple(penalty_last_n: i32, penalty_repeat: f32) -> Self {
476        Self::penalties(
477            #[allow(clippy::cast_precision_loss)]
478            {
479                penalty_last_n as i32
480            },
481            #[allow(clippy::cast_precision_loss)]
482            {
483                penalty_repeat as f32
484            },
485            #[allow(clippy::cast_precision_loss)]
486            {
487                0.0 as f32
488            },
489            #[allow(clippy::cast_precision_loss)]
490            {
491                0.0 as f32
492            },
493        )
494    }
495
496    /// Mirostat 1.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
497    ///
498    /// # Panics
499    ///
500    /// Panics if llama.cpp returns a null pointer.
501    ///
502    /// # Parameters:
503    /// - `n_vocab`: [`LlamaModel::n_vocab`]
504    /// - `seed`: Seed to initialize random generation with.
505    /// - `tau`: The target cross-entropy (or surprise) value you want to achieve for the
506    ///   generated text. A higher value corresponds to more surprising or less predictable text,
507    ///   while a lower value corresponds to less surprising or more predictable text.
508    /// - `eta`: The learning rate used to update `mu` based on the error between the target and
509    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
510    ///   updated more quickly, while a smaller learning rate will result in slower updates.
511    /// - `m`: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary
512    ///   value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`.
513    ///   In the paper, they use `m = 100`, but you can experiment with different values to see how
514    ///   it affects the performance of the algorithm.
515    #[must_use]
516    pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self {
517        let sampler = unsafe { llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m) };
518        Self {
519            sampler: NonNull::new(sampler).unwrap(),
520        }
521    }
522
523    /// Mirostat 2.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
524    ///
525    /// # Panics
526    ///
527    /// Panics if llama.cpp returns a null pointer.
528    ///
529    /// # Parameters:
530    /// - `seed`: Seed to initialize random generation with.
531    /// - `tau`: The target cross-entropy (or surprise) value you want to achieve for the
532    ///   generated text. A higher value corresponds to more surprising or less predictable text,
533    ///   while a lower value corresponds to less surprising or more predictable text.
534    /// - `eta`: The learning rate used to update `mu` based on the error between the target and
535    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
536    ///   updated more quickly, while a smaller learning rate will result in slower updates.
537    #[must_use]
538    pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self {
539        let sampler = unsafe { llama_sampler_init_mirostat_v2(seed, tau, eta) };
540        Self {
541            sampler: NonNull::new(sampler).unwrap(),
542        }
543    }
544
545    /// Selects a token at random based on each token's probabilities.
546    ///
547    /// # Panics
548    ///
549    /// Panics if llama.cpp returns a null pointer.
550    #[must_use]
551    pub fn dist(seed: u32) -> Self {
552        let sampler = unsafe { llama_sampler_init_dist(seed) };
553        Self {
554            sampler: NonNull::new(sampler).unwrap(),
555        }
556    }
557
558    /// Selects the most likely token.
559    ///
560    /// # Panics
561    ///
562    /// Panics if llama.cpp returns a null pointer.
563    ///
564    /// # Example:
565    /// ```rust
566    /// use llama_cpp_4::token::{
567    ///    LlamaToken,
568    ///    data::LlamaTokenData,
569    ///    data_array::LlamaTokenDataArray
570    /// };
571    /// use llama_cpp_4::sampling::LlamaSampler;
572    ///
573    /// let mut data_array = LlamaTokenDataArray::new(vec![
574    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
575    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
576    /// ], false);
577    ///
578    /// data_array.apply_sampler(&mut LlamaSampler::greedy());
579    ///
580    /// assert_eq!(data_array.data.len(), 2);
581    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(1)));
582    /// ```
583    #[must_use]
584    pub fn greedy() -> Self {
585        let sampler = unsafe { llama_sampler_init_greedy() };
586        Self {
587            sampler: NonNull::new(sampler).unwrap(),
588        }
589    }
590
591    /// Top-N sigma sampling.
592    ///
593    /// Keeps tokens within N standard deviations of the maximum logit.
594    ///
595    /// # Panics
596    ///
597    /// Panics if llama.cpp returns a null pointer.
598    #[must_use]
599    pub fn top_n_sigma(n: f32) -> Self {
600        let sampler = unsafe { llama_sampler_init_top_n_sigma(n) };
601        Self {
602            sampler: NonNull::new(sampler).unwrap(),
603        }
604    }
605
606    /// Adaptive P sampling.
607    ///
608    /// # Panics
609    ///
610    /// Panics if llama.cpp returns a null pointer.
611    ///
612    /// # Parameters
613    /// - `target`: Target probability.
614    /// - `decay`: Decay rate.
615    /// - `seed`: Random seed.
616    #[must_use]
617    pub fn adaptive_p(target: f32, decay: f32, seed: u32) -> Self {
618        let sampler = unsafe { llama_sampler_init_adaptive_p(target, decay, seed) };
619        Self {
620            sampler: NonNull::new(sampler).unwrap(),
621        }
622    }
623
624    /// Logit bias sampler.
625    ///
626    /// Applies additive bias to specific token logits before sampling.
627    ///
628    /// # Panics
629    ///
630    /// Panics if llama.cpp returns a null pointer.
631    ///
632    /// # Parameters
633    /// - `n_vocab`: Number of tokens in the vocabulary ([`LlamaModel::n_vocab`]).
634    /// - `biases`: Slice of `(token_id, bias)` pairs.
635    #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
636    #[must_use]
637    pub fn logit_bias(n_vocab: i32, biases: &[(LlamaToken, f32)]) -> Self {
638        let logit_biases: Vec<llama_logit_bias> = biases
639            .iter()
640            .map(|(token, bias)| llama_logit_bias {
641                token: token.0,
642                bias: *bias,
643            })
644            .collect();
645
646        let sampler = unsafe {
647            llama_sampler_init_logit_bias(n_vocab, logit_biases.len() as i32, logit_biases.as_ptr())
648        };
649        Self {
650            sampler: NonNull::new(sampler).unwrap(),
651        }
652    }
653
654    /// Infill sampler.
655    ///
656    /// Reorders token probabilities for fill-in-the-middle tasks.
657    ///
658    /// # Panics
659    ///
660    /// Panics if llama.cpp returns a null pointer.
661    #[must_use]
662    pub fn infill(model: &LlamaModel) -> Self {
663        let sampler = unsafe { llama_sampler_init_infill(model.get_vocab().vocab.as_ref()) };
664        Self {
665            sampler: NonNull::new(sampler).unwrap(),
666        }
667    }
668
    /// Get the seed of the sampler.
    ///
    /// Returns `LLAMA_DEFAULT_SEED` if the sampler is not seeded.
    #[must_use]
    pub fn get_seed(&self) -> u32 {
        // SAFETY: `self.sampler` is a valid pointer for the duration of `&self`.
        unsafe { llama_sampler_get_seed(self.sampler.as_ptr()) }
    }
676
677    /// Get the name of the sampler.
678    ///
679    /// # Panics
680    ///
681    /// Panics if the name is not valid UTF-8.
682    #[must_use]
683    pub fn name(&self) -> String {
684        let c_str = unsafe { llama_sampler_name(self.sampler.as_ptr()) };
685        let c_str = unsafe { std::ffi::CStr::from_ptr(c_str) };
686        c_str
687            .to_str()
688            .expect("sampler name is not valid UTF-8")
689            .to_owned()
690    }
691
    /// Reset the sampler state (e.g. grammar, repetition penalties).
    pub fn reset(&mut self) {
        // SAFETY: `self.sampler` is a valid pointer while `&mut self` is held.
        unsafe { llama_sampler_reset(self.sampler.as_ptr()) }
    }
696
    /// Get the number of samplers in a chain.
    ///
    /// Returns 0 if this sampler is not a chain.
    #[must_use]
    pub fn chain_n(&self) -> i32 {
        // SAFETY: `self.sampler` is a valid pointer for the duration of `&self`.
        unsafe { llama_sampler_chain_n(self.sampler.as_ptr()) }
    }
704
705    /// Remove and return the sampler at position `i` from a chain.
706    ///
707    /// The returned sampler is owned by the caller and will be freed on drop.
708    ///
709    /// # Panics
710    ///
711    /// Panics if `i` is out of range or if llama.cpp returns a null pointer.
712    #[must_use]
713    pub fn chain_remove(&mut self, i: i32) -> Self {
714        let sampler = unsafe { llama_sampler_chain_remove(self.sampler.as_ptr(), i) };
715        Self {
716            sampler: NonNull::new(sampler).expect("chain_remove returned null"),
717        }
718    }
719
720    /// Grammar sampler with lazy activation.
721    ///
722    /// The grammar is only activated when one of the trigger words or trigger tokens is encountered.
723    ///
724    /// # Panics
725    /// - If `grammar_str` or `grammar_root` contain null bytes.
726    /// - If any trigger word contains null bytes.
727    /// - If llama.cpp returns a null pointer.
728    #[must_use]
729    #[deprecated(note = "use grammar_lazy_patterns instead")]
730    pub fn grammar_lazy(
731        model: &LlamaModel,
732        grammar_str: &str,
733        grammar_root: &str,
734        trigger_words: &[&str],
735        trigger_tokens: &[LlamaToken],
736    ) -> Self {
737        let grammar_str = CString::new(grammar_str).unwrap();
738        let grammar_root = CString::new(grammar_root).unwrap();
739        let trigger_cstrings: Vec<CString> = trigger_words
740            .iter()
741            .map(|w| CString::new(*w).unwrap())
742            .collect();
743        let mut trigger_ptrs: Vec<*const c_char> =
744            trigger_cstrings.iter().map(|s| s.as_ptr()).collect();
745
746        let sampler = unsafe {
747            llama_sampler_init_grammar_lazy(
748                model.get_vocab().vocab.as_ref(),
749                grammar_str.as_ptr(),
750                grammar_root.as_ptr(),
751                trigger_ptrs.as_mut_ptr(),
752                trigger_ptrs.len(),
753                trigger_tokens.as_ptr().cast(),
754                trigger_tokens.len(),
755            )
756        };
757        Self {
758            sampler: NonNull::new(sampler).unwrap(),
759        }
760    }
761
762    /// Grammar sampler with lazy activation via regex patterns.
763    ///
764    /// The grammar is only activated when one of the trigger patterns or trigger tokens matches.
765    ///
766    /// # Panics
767    /// - If `grammar_str` or `grammar_root` contain null bytes.
768    /// - If any trigger pattern contains null bytes.
769    /// - If llama.cpp returns a null pointer.
770    #[must_use]
771    pub fn grammar_lazy_patterns(
772        model: &LlamaModel,
773        grammar_str: &str,
774        grammar_root: &str,
775        trigger_patterns: &[&str],
776        trigger_tokens: &[LlamaToken],
777    ) -> Self {
778        let grammar_str = CString::new(grammar_str).unwrap();
779        let grammar_root = CString::new(grammar_root).unwrap();
780        let pattern_cstrings: Vec<CString> = trigger_patterns
781            .iter()
782            .map(|w| CString::new(*w).unwrap())
783            .collect();
784        let mut pattern_ptrs: Vec<*const c_char> =
785            pattern_cstrings.iter().map(|s| s.as_ptr()).collect();
786
787        let sampler = unsafe {
788            llama_sampler_init_grammar_lazy_patterns(
789                model.get_vocab().vocab.as_ref(),
790                grammar_str.as_ptr(),
791                grammar_root.as_ptr(),
792                pattern_ptrs.as_mut_ptr(),
793                pattern_ptrs.len(),
794                trigger_tokens.as_ptr().cast(),
795                trigger_tokens.len(),
796            )
797        };
798        Self {
799            sampler: NonNull::new(sampler).unwrap(),
800        }
801    }
802
803    /// Clone this sampler.
804    ///
805    /// Creates an independent copy of this sampler with the same state.
806    ///
807    /// # Panics
808    ///
809    /// Panics if llama.cpp returns a null pointer.
810    #[must_use]
811    pub fn clone_sampler(&self) -> Self {
812        let sampler = unsafe { llama_sampler_clone(self.sampler.as_ptr()) };
813        Self {
814            sampler: NonNull::new(sampler).expect("sampler_clone returned null"),
815        }
816    }
817
    /// Print sampler performance data.
    pub fn perf_print(&self) {
        // SAFETY: `self.sampler` is a valid pointer for the duration of `&self`.
        unsafe { llama_cpp_sys_4::llama_perf_sampler_print(self.sampler.as_ptr()) }
    }
822
    /// Reset sampler performance counters.
    pub fn perf_reset(&mut self) {
        // SAFETY: `self.sampler` is a valid pointer while `&mut self` is held.
        unsafe { llama_cpp_sys_4::llama_perf_sampler_reset(self.sampler.as_ptr()) }
    }
827
    /// Get sampler performance data.
    #[must_use]
    pub fn perf_data(&self) -> llama_cpp_sys_4::llama_perf_sampler_data {
        // SAFETY: `self.sampler` is a valid pointer; the returned struct is a
        // plain value copy.
        unsafe { llama_cpp_sys_4::llama_perf_sampler(self.sampler.as_ptr()) }
    }
833
    /// Get a non-owning reference to the `i`th sampler in a chain.
    ///
    /// # Safety
    ///
    /// The returned pointer is owned by the chain. Do not free it or use it
    /// after the chain is dropped or modified.
    #[must_use]
    pub unsafe fn chain_get_ptr(&self, i: i32) -> *mut llama_sampler {
        // Thin passthrough; the caller upholds the contract documented above.
        llama_cpp_sys_4::llama_sampler_chain_get(self.sampler.as_ptr(), i)
    }
844
845    /// Create a sampler from a raw interface and context.
846    ///
847    /// # Safety
848    ///
849    /// The caller must ensure that `iface` and `ctx` are valid and that the
850    /// interface functions properly manage the context lifetime.
851    ///
852    /// # Panics
853    ///
854    /// Panics if llama.cpp returns a null pointer.
855    #[must_use]
856    pub unsafe fn from_raw(
857        iface: *mut llama_cpp_sys_4::llama_sampler_i,
858        ctx: llama_cpp_sys_4::llama_sampler_context_t,
859    ) -> Self {
860        let sampler = llama_cpp_sys_4::llama_sampler_init(iface, ctx);
861        Self {
862            sampler: NonNull::new(sampler).expect("sampler_init returned null"),
863        }
864    }
865
866    /// Creates a new instance of `LlamaSampler` with common sampling parameters.
867    ///
868    /// This function initializes a `LlamaSampler` using default values from `common_sampler_params`
869    /// and configures it with common settings such as `top_k`, `top_p`, `temperature`, and `seed` values.
870    ///
871    /// # Panics
872    ///
873    /// Panics if llama.cpp returns a null pointer.
874    ///
875    /// # Returns
876    /// A `LlamaSampler` instance configured with the common sampling parameters.
877    #[must_use]
878    pub fn common() -> Self {
879        let params = common_sampler_params::default();
880
881        let sampler = unsafe {
882            let mut sparams = llama_sampler_chain_default_params();
883            sparams.no_perf = false;
884
885            let smpl = llama_sampler_chain_init(sparams);
886
887            llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.top_k));
888            llama_sampler_chain_add(
889                smpl,
890                #[allow(clippy::cast_sign_loss)]
891                llama_sampler_init_top_p(params.top_p, params.min_keep as usize),
892            );
893            llama_sampler_chain_add(smpl, llama_sampler_init_temp(params.temp));
894            #[allow(clippy::cast_sign_loss)]
895            llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.seed));
896
897            smpl
898        };
899
900        Self {
901            sampler: NonNull::new(sampler).unwrap(),
902        }
903    }
904}
905
impl Drop for LlamaSampler {
    /// Frees the underlying `llama_sampler`.
    fn drop(&mut self) {
        // SAFETY: `self.sampler` is an owned, still-valid pointer (samplers moved
        // into a chain are `mem::forget`-ed and never reach this Drop), and it is
        // never used again after this call.
        unsafe {
            llama_sampler_free(self.sampler.as_ptr());
        }
    }
}