llama_cpp_4/sampling.rs
1//! Safe wrapper around `llama_sampler`.
2
3use std::borrow::Borrow;
4use std::ffi::{c_char, CString};
5use std::fmt::{Debug, Formatter};
6use std::ptr::NonNull;
7
8use llama_cpp_sys_4::{
9 common::common_sampler_params, llama_logit_bias, llama_sampler, llama_sampler_accept,
10 llama_sampler_chain_add, llama_sampler_chain_default_params, llama_sampler_chain_init,
11 llama_sampler_chain_n, llama_sampler_chain_remove, llama_sampler_clone, llama_sampler_free,
12 llama_sampler_get_seed, llama_sampler_init_adaptive_p, llama_sampler_init_dist,
13 llama_sampler_init_dry, llama_sampler_init_grammar, llama_sampler_init_grammar_lazy,
14 llama_sampler_init_grammar_lazy_patterns, llama_sampler_init_greedy, llama_sampler_init_infill,
15 llama_sampler_init_logit_bias, llama_sampler_init_min_p, llama_sampler_init_mirostat,
16 llama_sampler_init_mirostat_v2, llama_sampler_init_penalties, llama_sampler_init_temp,
17 llama_sampler_init_temp_ext, llama_sampler_init_top_k, llama_sampler_init_top_n_sigma,
18 llama_sampler_init_top_p, llama_sampler_init_typical, llama_sampler_init_xtc,
19 llama_sampler_name, llama_sampler_reset, llama_sampler_sample,
20};
21
22use crate::context::LlamaContext;
23use crate::model::LlamaModel;
24use crate::token::data_array::LlamaTokenDataArray;
25use crate::token::LlamaToken;
26
/// A safe wrapper around `llama_sampler`.
///
/// Owns the underlying C-side sampler; the pointer is released in `Drop` via
/// `llama_sampler_free`.
pub struct LlamaSampler {
    // Non-null pointer to the C sampler. Ownership can be transferred to a
    // chain (see `LlamaSampler::chain`), in which case the wrapper is
    // `mem::forget`-ten so `Drop` does not double-free.
    pub(crate) sampler: NonNull<llama_sampler>,
}
31
32impl Debug for LlamaSampler {
33 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
34 f.debug_struct("LlamaSamplerChain").finish()
35 }
36}
/// Basic sampling parameters.
#[derive(Debug, Clone)]
#[allow(
    missing_docs,
    clippy::struct_excessive_bools,
    clippy::module_name_repetitions,
    dead_code
)]
pub struct LlamaSamplerParams {
    // Top-K cutoff. NOTE(review): not read anywhere in this file yet.
    top_k: i32,
    // Nucleus (top-p) threshold. NOTE(review): not read anywhere in this file yet.
    top_p: f32,
    // Sampling temperature. NOTE(review): not read anywhere in this file yet.
    temp: f32,
    // RNG seed; exposed through `seed()` / `with_seed()`.
    seed: u32,
}
50
51impl LlamaSamplerParams {
52 /// Set the seed of the context
53 ///
54 /// # Examples
55 ///
56 /// ```rust
57 /// use llama_cpp_4::sampling::LlamaSamplerParams;
58 /// let params = LlamaSamplerParams::default();
59 /// let params = params.with_seed(1234);
60 /// assert_eq!(params.seed(), 1234);
61 /// ```
62 #[must_use]
63 pub fn with_seed(mut self, seed: u32) -> Self {
64 self.seed = seed;
65 self
66 }
67
68 /// Get the seed of the context
69 ///
70 /// # Examples
71 ///
72 /// ```rust
73 /// use llama_cpp_4::sampling::LlamaSamplerParams;
74 /// let params = LlamaSamplerParams::default()
75 /// .with_seed(1234);
76 /// assert_eq!(params.seed(), 1234);
77 /// ```
78 #[must_use]
79 pub fn seed(&self) -> u32 {
80 self.seed
81 }
82}
83
84impl Default for LlamaSamplerParams {
85 fn default() -> Self {
86 Self {
87 top_k: 50,
88 top_p: 0.9,
89 temp: 0.8,
90 seed: 1234,
91 }
92 }
93}
94
95impl Default for LlamaSampler {
96 fn default() -> Self {
97 Self::new()
98 }
99}
100
101impl LlamaSampler {
102 /// Create new sampler with default params.
103 ///
104 /// # Panics
105 ///
106 /// Panics if llama.cpp returns a null pointer.
107 #[must_use]
108 pub fn new() -> Self {
109 let sparams = unsafe { llama_sampler_chain_default_params() };
110
111 Self {
112 sampler: NonNull::new(unsafe { llama_sampler_chain_init(sparams) }).unwrap(),
113 }
114 }
115
116 /// Sample and accept a token from the idx-th output of the last evaluation
117 #[must_use]
118 pub fn sample(&self, ctx: &LlamaContext, idx: i32) -> LlamaToken {
119 let token =
120 unsafe { llama_sampler_sample(self.sampler.as_ptr(), ctx.context.as_ptr(), idx) };
121
122 LlamaToken(token)
123 }
124
125 /// Applies this sampler to a [`LlamaTokenDataArray`].
126 pub fn apply(&mut self, data_array: &mut LlamaTokenDataArray) {
127 data_array.apply_sampler(self);
128 }
129
130 /// Accepts a token from the sampler, possibly updating the internal state of certain samplers
131 /// (e.g. grammar, repetition, etc.)
132 pub fn accept(&mut self, token: LlamaToken) {
133 unsafe { llama_sampler_accept(self.sampler.as_ptr(), token.0) }
134 }
135
136 /// Accepts several tokens from the sampler or context, possibly updating the internal state of
137 /// certain samplers (e.g. grammar, repetition, etc.)
138 pub fn accept_many(&mut self, tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>) {
139 for token in tokens {
140 unsafe { llama_sampler_accept(self.sampler.as_ptr(), token.borrow().0) }
141 }
142 }
143
144 /// Accepts several tokens from the sampler or context, possibly updating the internal state of
145 /// certain samplers (e.g. grammar, repetition, etc.)
146 #[must_use]
147 pub fn with_tokens(
148 mut self,
149 tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
150 ) -> Self {
151 self.accept_many(tokens);
152 self
153 }
154
155 /// Combines a list of samplers into a single sampler that applies each component sampler one
156 /// after another.
157 ///
158 /// If you are using a chain to select a token, the chain should always end with one of
159 /// [`LlamaSampler::greedy`], [`LlamaSampler::dist`], [`LlamaSampler::mirostat`], and
160 /// [`LlamaSampler::mirostat_v2`].
161 ///
162 /// # Panics
163 ///
164 /// Panics if llama.cpp returns a null pointer.
165 #[must_use]
166 pub fn chain(samplers: impl IntoIterator<Item = Self>, no_perf: bool) -> Self {
167 unsafe {
168 let mut params = llama_sampler_chain_default_params();
169 params.no_perf = no_perf;
170 let chain = llama_sampler_chain_init(params);
171
172 for sampler in samplers {
173 llama_sampler_chain_add(chain, sampler.sampler.as_ptr());
174
175 // Do not call `llama_sampler_free` on the sampler, as the internal sampler is now
176 // owned by the chain
177 std::mem::forget(sampler);
178 }
179
180 Self {
181 sampler: NonNull::new(chain).unwrap(),
182 }
183 }
184 }
185
186 /// Same as [`Self::chain`] with `no_perf = false`.
187 ///
188 /// # Panics
189 ///
190 /// Panics if llama.cpp returns a null pointer.
191 ///
192 /// # Example
193 /// ```rust
194 /// use llama_cpp_4::token::{
195 /// LlamaToken,
196 /// data::LlamaTokenData,
197 /// data_array::LlamaTokenDataArray
198 /// };
199 /// use llama_cpp_4::sampling::LlamaSampler;
200 ///
201 /// let mut data_array = LlamaTokenDataArray::new(vec![
202 /// LlamaTokenData::new(LlamaToken(0), 0., 0.),
203 /// LlamaTokenData::new(LlamaToken(1), 1., 0.),
204 /// LlamaTokenData::new(LlamaToken(2), 2., 0.),
205 /// ], false);
206 ///
207 /// data_array.apply_sampler(&mut LlamaSampler::chain_simple([
208 /// LlamaSampler::temp(0.5),
209 /// LlamaSampler::greedy(),
210 /// ]));
211 ///
212 /// assert_eq!(data_array.data[0].logit(), 0.);
213 /// assert_eq!(data_array.data[1].logit(), 2.);
214 /// assert_eq!(data_array.data[2].logit(), 4.);
215 ///
216 /// assert_eq!(data_array.data.len(), 3);
217 /// assert_eq!(data_array.selected_token(), Some(LlamaToken(2)));
218 /// ```
219 #[must_use]
220 pub fn chain_simple(samplers: impl IntoIterator<Item = Self>) -> Self {
221 Self::chain(samplers, false)
222 }
223
224 /// Updates the logits `l_i`' = `l_i/t`. When `t <= 0.0`, the maximum logit is kept at its original
225 /// value, the rest are set to -inf.
226 ///
227 /// # Panics
228 ///
229 /// Panics if llama.cpp returns a null pointer.
230 ///
231 /// # Example:
232 /// ```rust
233 /// use llama_cpp_4::token::{
234 /// LlamaToken,
235 /// data::LlamaTokenData,
236 /// data_array::LlamaTokenDataArray
237 /// };
238 /// use llama_cpp_4::sampling::LlamaSampler;
239 ///
240 /// let mut data_array = LlamaTokenDataArray::new(vec![
241 /// LlamaTokenData::new(LlamaToken(0), 0., 0.),
242 /// LlamaTokenData::new(LlamaToken(1), 1., 0.),
243 /// LlamaTokenData::new(LlamaToken(2), 2., 0.),
244 /// ], false);
245 ///
246 /// data_array.apply_sampler(&mut LlamaSampler::temp(0.5));
247 ///
248 /// assert_eq!(data_array.data[0].logit(), 0.);
249 /// assert_eq!(data_array.data[1].logit(), 2.);
250 /// assert_eq!(data_array.data[2].logit(), 4.);
251 /// ```
252 #[must_use]
253 pub fn temp(t: f32) -> Self {
254 let sampler = unsafe { llama_sampler_init_temp(t) };
255 Self {
256 sampler: NonNull::new(sampler).unwrap(),
257 }
258 }
259
260 /// Dynamic temperature implementation (a.k.a. entropy) described in the paper
261 /// <https://arxiv.org/abs/2309.02772>.
262 ///
263 /// # Panics
264 ///
265 /// Panics if llama.cpp returns a null pointer.
266 #[must_use]
267 pub fn temp_ext(t: f32, delta: f32, exponent: f32) -> Self {
268 let sampler = unsafe { llama_sampler_init_temp_ext(t, delta, exponent) };
269 Self {
270 sampler: NonNull::new(sampler).unwrap(),
271 }
272 }
273
274 /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration"
275 /// <https://arxiv.org/abs/1904.09751>.
276 ///
277 /// # Panics
278 ///
279 /// Panics if llama.cpp returns a null pointer.
280 ///
281 /// # Example:
282 /// ```rust
283 /// use llama_cpp_4::token::{
284 /// LlamaToken,
285 /// data::LlamaTokenData,
286 /// data_array::LlamaTokenDataArray
287 /// };
288 /// use llama_cpp_4::sampling::LlamaSampler;
289 ///
290 /// let mut data_array = LlamaTokenDataArray::new(vec![
291 /// LlamaTokenData::new(LlamaToken(0), 0., 0.),
292 /// LlamaTokenData::new(LlamaToken(1), 1., 0.),
293 /// LlamaTokenData::new(LlamaToken(2), 2., 0.),
294 /// LlamaTokenData::new(LlamaToken(3), 3., 0.),
295 /// ], false);
296 ///
297 /// data_array.apply_sampler(&mut LlamaSampler::top_k(2));
298 ///
299 /// assert_eq!(data_array.data.len(), 2);
300 /// assert_eq!(data_array.data[0].id(), LlamaToken(3));
301 /// assert_eq!(data_array.data[1].id(), LlamaToken(2));
302 /// ```
303 #[must_use]
304 pub fn top_k(k: i32) -> Self {
305 let sampler = unsafe { llama_sampler_init_top_k(k) };
306 Self {
307 sampler: NonNull::new(sampler).unwrap(),
308 }
309 }
310
311 /// Locally Typical Sampling implementation described in the paper <https://arxiv.org/abs/2202.00666>.
312 ///
313 /// # Panics
314 ///
315 /// Panics if llama.cpp returns a null pointer.
316 #[must_use]
317 pub fn typical(p: f32, min_keep: usize) -> Self {
318 let sampler = unsafe { llama_sampler_init_typical(p, min_keep) };
319 Self {
320 sampler: NonNull::new(sampler).unwrap(),
321 }
322 }
323
324 /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration"
325 /// <https://arxiv.org/abs/1904.09751>.
326 ///
327 /// # Panics
328 ///
329 /// Panics if llama.cpp returns a null pointer.
330 #[must_use]
331 pub fn top_p(p: f32, min_keep: usize) -> Self {
332 let sampler = unsafe { llama_sampler_init_top_p(p, min_keep) };
333 Self {
334 sampler: NonNull::new(sampler).unwrap(),
335 }
336 }
337
338 /// Minimum P sampling as described in <https://github.com/ggerganov/llama.cpp/pull/3841>.
339 ///
340 /// # Panics
341 ///
342 /// Panics if llama.cpp returns a null pointer.
343 #[must_use]
344 pub fn min_p(p: f32, min_keep: usize) -> Self {
345 let sampler = unsafe { llama_sampler_init_min_p(p, min_keep) };
346 Self {
347 sampler: NonNull::new(sampler).unwrap(),
348 }
349 }
350
351 /// XTC sampler as described in <https://github.com/oobabooga/text-generation-webui/pull/6335>.
352 ///
353 /// # Panics
354 ///
355 /// Panics if llama.cpp returns a null pointer.
356 #[must_use]
357 pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Self {
358 let sampler = unsafe { llama_sampler_init_xtc(p, t, min_keep, seed) };
359 Self {
360 sampler: NonNull::new(sampler).unwrap(),
361 }
362 }
363
364 /// Grammar sampler
365 ///
366 /// # Panics
367 /// - If either of `grammar_str` or `grammar_root` contain null bytes.
368 /// - If llama.cpp returns a null pointer.
369 #[must_use]
370 pub fn grammar(model: &LlamaModel, grammar_str: &str, grammar_root: &str) -> Self {
371 let grammar_str = CString::new(grammar_str).unwrap();
372 let grammar_root = CString::new(grammar_root).unwrap();
373
374 let sampler = unsafe {
375 llama_sampler_init_grammar(
376 model.get_vocab().vocab.as_ref(),
377 grammar_str.as_ptr(),
378 grammar_root.as_ptr(),
379 )
380 };
381 Self {
382 sampler: NonNull::new(sampler).unwrap(),
383 }
384 }
385
    /// DRY sampler, designed by p-e-w, as described in:
    /// <https://github.com/oobabooga/text-generation-webui/pull/5677>, porting Koboldcpp
    /// implementation authored by pi6am: <https://github.com/LostRuins/koboldcpp/pull/982>
    ///
    /// NOTE(review): `&self` is never read in this body — unlike the other
    /// constructors, this is a method only for historical reasons. Removing the
    /// receiver would break existing callers, so it is kept; consider deprecating
    /// it in favor of an associated function.
    ///
    /// # Panics
    /// - If any string in `seq_breakers` contains null bytes.
    /// - If llama.cpp returns a null pointer.
    #[allow(clippy::too_many_arguments)]
    #[must_use]
    pub fn dry(
        &self,
        model: &LlamaModel,
        n_ctx_train: i32,
        multiplier: f32,
        base: f32,
        allowed_length: i32,
        penalty_last_n: i32,
        seq_breakers: impl IntoIterator<Item = impl AsRef<[u8]>>,
    ) -> Self {
        // Owned CStrings must stay alive while the raw pointers below are in use.
        let seq_breakers: Vec<CString> = seq_breakers
            .into_iter()
            .map(|s| CString::new(s.as_ref()).unwrap())
            .collect();
        // CString::as_ptr() returns *const c_char, which matches what the binding
        // expects on every platform (signed on macOS/x86 Linux, unsigned on musl ARM).
        let mut seq_breaker_pointers: Vec<*const c_char> =
            seq_breakers.iter().map(|s| s.as_ptr()).collect();

        let sampler = unsafe {
            llama_sampler_init_dry(
                model.get_vocab().vocab.as_ref(),
                n_ctx_train,
                multiplier,
                base,
                allowed_length,
                penalty_last_n,
                seq_breaker_pointers.as_mut_ptr(),
                seq_breaker_pointers.len(),
            )
        };

        Self {
            sampler: NonNull::new(sampler).unwrap(),
        }
    }
431
432 /// Penalizes tokens for being present in the context.
433 ///
434 /// Parameters:
435 /// - `penalty_last_n`: last n tokens to penalize (0 = disable penalty, -1 = context size)
436 /// - `penalty_repeat`: repetition penalty (1.0 = disabled, >1.0 = penalize repeats)
437 /// - `penalty_freq`: frequency penalty (0.0 = disabled)
438 /// - `penalty_present`: presence penalty (0.0 = disabled)
439 ///
440 /// # Panics
441 ///
442 /// Panics if llama.cpp returns a null pointer.
443 #[allow(clippy::too_many_arguments)]
444 #[must_use]
445 pub fn penalties(
446 penalty_last_n: i32,
447 penalty_repeat: f32,
448 penalty_freq: f32,
449 penalty_present: f32,
450 ) -> Self {
451 let sampler = unsafe {
452 llama_sampler_init_penalties(
453 penalty_last_n,
454 penalty_repeat,
455 penalty_freq,
456 penalty_present,
457 )
458 };
459 Self {
460 sampler: NonNull::new(sampler).unwrap(),
461 }
462 }
463
464 /// Same as [`Self::penalties`] with sensible defaults:
465 /// `penalty_freq = 0.0` and `penalty_present = 0.0`.
466 ///
467 /// Parameters:
468 /// - `penalty_last_n`: last n tokens to penalize (0 = disable, -1 = context size)
469 /// - `penalty_repeat`: repetition penalty (1.0 = disabled)
470 ///
471 /// # Panics
472 ///
473 /// Panics if llama.cpp returns a null pointer.
474 #[must_use]
475 pub fn penalties_simple(penalty_last_n: i32, penalty_repeat: f32) -> Self {
476 Self::penalties(
477 #[allow(clippy::cast_precision_loss)]
478 {
479 penalty_last_n as i32
480 },
481 #[allow(clippy::cast_precision_loss)]
482 {
483 penalty_repeat as f32
484 },
485 #[allow(clippy::cast_precision_loss)]
486 {
487 0.0 as f32
488 },
489 #[allow(clippy::cast_precision_loss)]
490 {
491 0.0 as f32
492 },
493 )
494 }
495
496 /// Mirostat 1.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
497 ///
498 /// # Panics
499 ///
500 /// Panics if llama.cpp returns a null pointer.
501 ///
502 /// # Parameters:
503 /// - `n_vocab`: [`LlamaModel::n_vocab`]
504 /// - `seed`: Seed to initialize random generation with.
505 /// - `tau`: The target cross-entropy (or surprise) value you want to achieve for the
506 /// generated text. A higher value corresponds to more surprising or less predictable text,
507 /// while a lower value corresponds to less surprising or more predictable text.
508 /// - `eta`: The learning rate used to update `mu` based on the error between the target and
509 /// observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
510 /// updated more quickly, while a smaller learning rate will result in slower updates.
511 /// - `m`: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary
512 /// value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`.
513 /// In the paper, they use `m = 100`, but you can experiment with different values to see how
514 /// it affects the performance of the algorithm.
515 #[must_use]
516 pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self {
517 let sampler = unsafe { llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m) };
518 Self {
519 sampler: NonNull::new(sampler).unwrap(),
520 }
521 }
522
523 /// Mirostat 2.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
524 ///
525 /// # Panics
526 ///
527 /// Panics if llama.cpp returns a null pointer.
528 ///
529 /// # Parameters:
530 /// - `seed`: Seed to initialize random generation with.
531 /// - `tau`: The target cross-entropy (or surprise) value you want to achieve for the
532 /// generated text. A higher value corresponds to more surprising or less predictable text,
533 /// while a lower value corresponds to less surprising or more predictable text.
534 /// - `eta`: The learning rate used to update `mu` based on the error between the target and
535 /// observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
536 /// updated more quickly, while a smaller learning rate will result in slower updates.
537 #[must_use]
538 pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self {
539 let sampler = unsafe { llama_sampler_init_mirostat_v2(seed, tau, eta) };
540 Self {
541 sampler: NonNull::new(sampler).unwrap(),
542 }
543 }
544
545 /// Selects a token at random based on each token's probabilities.
546 ///
547 /// # Panics
548 ///
549 /// Panics if llama.cpp returns a null pointer.
550 #[must_use]
551 pub fn dist(seed: u32) -> Self {
552 let sampler = unsafe { llama_sampler_init_dist(seed) };
553 Self {
554 sampler: NonNull::new(sampler).unwrap(),
555 }
556 }
557
558 /// Selects the most likely token.
559 ///
560 /// # Panics
561 ///
562 /// Panics if llama.cpp returns a null pointer.
563 ///
564 /// # Example:
565 /// ```rust
566 /// use llama_cpp_4::token::{
567 /// LlamaToken,
568 /// data::LlamaTokenData,
569 /// data_array::LlamaTokenDataArray
570 /// };
571 /// use llama_cpp_4::sampling::LlamaSampler;
572 ///
573 /// let mut data_array = LlamaTokenDataArray::new(vec![
574 /// LlamaTokenData::new(LlamaToken(0), 0., 0.),
575 /// LlamaTokenData::new(LlamaToken(1), 1., 0.),
576 /// ], false);
577 ///
578 /// data_array.apply_sampler(&mut LlamaSampler::greedy());
579 ///
580 /// assert_eq!(data_array.data.len(), 2);
581 /// assert_eq!(data_array.selected_token(), Some(LlamaToken(1)));
582 /// ```
583 #[must_use]
584 pub fn greedy() -> Self {
585 let sampler = unsafe { llama_sampler_init_greedy() };
586 Self {
587 sampler: NonNull::new(sampler).unwrap(),
588 }
589 }
590
591 /// Top-N sigma sampling.
592 ///
593 /// Keeps tokens within N standard deviations of the maximum logit.
594 ///
595 /// # Panics
596 ///
597 /// Panics if llama.cpp returns a null pointer.
598 #[must_use]
599 pub fn top_n_sigma(n: f32) -> Self {
600 let sampler = unsafe { llama_sampler_init_top_n_sigma(n) };
601 Self {
602 sampler: NonNull::new(sampler).unwrap(),
603 }
604 }
605
606 /// Adaptive P sampling.
607 ///
608 /// # Panics
609 ///
610 /// Panics if llama.cpp returns a null pointer.
611 ///
612 /// # Parameters
613 /// - `target`: Target probability.
614 /// - `decay`: Decay rate.
615 /// - `seed`: Random seed.
616 #[must_use]
617 pub fn adaptive_p(target: f32, decay: f32, seed: u32) -> Self {
618 let sampler = unsafe { llama_sampler_init_adaptive_p(target, decay, seed) };
619 Self {
620 sampler: NonNull::new(sampler).unwrap(),
621 }
622 }
623
624 /// Logit bias sampler.
625 ///
626 /// Applies additive bias to specific token logits before sampling.
627 ///
628 /// # Panics
629 ///
630 /// Panics if llama.cpp returns a null pointer.
631 ///
632 /// # Parameters
633 /// - `n_vocab`: Number of tokens in the vocabulary ([`LlamaModel::n_vocab`]).
634 /// - `biases`: Slice of `(token_id, bias)` pairs.
635 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
636 #[must_use]
637 pub fn logit_bias(n_vocab: i32, biases: &[(LlamaToken, f32)]) -> Self {
638 let logit_biases: Vec<llama_logit_bias> = biases
639 .iter()
640 .map(|(token, bias)| llama_logit_bias {
641 token: token.0,
642 bias: *bias,
643 })
644 .collect();
645
646 let sampler = unsafe {
647 llama_sampler_init_logit_bias(n_vocab, logit_biases.len() as i32, logit_biases.as_ptr())
648 };
649 Self {
650 sampler: NonNull::new(sampler).unwrap(),
651 }
652 }
653
654 /// Infill sampler.
655 ///
656 /// Reorders token probabilities for fill-in-the-middle tasks.
657 ///
658 /// # Panics
659 ///
660 /// Panics if llama.cpp returns a null pointer.
661 #[must_use]
662 pub fn infill(model: &LlamaModel) -> Self {
663 let sampler = unsafe { llama_sampler_init_infill(model.get_vocab().vocab.as_ref()) };
664 Self {
665 sampler: NonNull::new(sampler).unwrap(),
666 }
667 }
668
669 /// Get the seed of the sampler.
670 ///
671 /// Returns `LLAMA_DEFAULT_SEED` if the sampler is not seeded.
672 #[must_use]
673 pub fn get_seed(&self) -> u32 {
674 unsafe { llama_sampler_get_seed(self.sampler.as_ptr()) }
675 }
676
677 /// Get the name of the sampler.
678 ///
679 /// # Panics
680 ///
681 /// Panics if the name is not valid UTF-8.
682 #[must_use]
683 pub fn name(&self) -> String {
684 let c_str = unsafe { llama_sampler_name(self.sampler.as_ptr()) };
685 let c_str = unsafe { std::ffi::CStr::from_ptr(c_str) };
686 c_str
687 .to_str()
688 .expect("sampler name is not valid UTF-8")
689 .to_owned()
690 }
691
692 /// Reset the sampler state (e.g. grammar, repetition penalties).
693 pub fn reset(&mut self) {
694 unsafe { llama_sampler_reset(self.sampler.as_ptr()) }
695 }
696
697 /// Get the number of samplers in a chain.
698 ///
699 /// Returns 0 if this sampler is not a chain.
700 #[must_use]
701 pub fn chain_n(&self) -> i32 {
702 unsafe { llama_sampler_chain_n(self.sampler.as_ptr()) }
703 }
704
705 /// Remove and return the sampler at position `i` from a chain.
706 ///
707 /// The returned sampler is owned by the caller and will be freed on drop.
708 ///
709 /// # Panics
710 ///
711 /// Panics if `i` is out of range or if llama.cpp returns a null pointer.
712 #[must_use]
713 pub fn chain_remove(&mut self, i: i32) -> Self {
714 let sampler = unsafe { llama_sampler_chain_remove(self.sampler.as_ptr(), i) };
715 Self {
716 sampler: NonNull::new(sampler).expect("chain_remove returned null"),
717 }
718 }
719
720 /// Grammar sampler with lazy activation.
721 ///
722 /// The grammar is only activated when one of the trigger words or trigger tokens is encountered.
723 ///
724 /// # Panics
725 /// - If `grammar_str` or `grammar_root` contain null bytes.
726 /// - If any trigger word contains null bytes.
727 /// - If llama.cpp returns a null pointer.
728 #[must_use]
729 #[deprecated(note = "use grammar_lazy_patterns instead")]
730 pub fn grammar_lazy(
731 model: &LlamaModel,
732 grammar_str: &str,
733 grammar_root: &str,
734 trigger_words: &[&str],
735 trigger_tokens: &[LlamaToken],
736 ) -> Self {
737 let grammar_str = CString::new(grammar_str).unwrap();
738 let grammar_root = CString::new(grammar_root).unwrap();
739 let trigger_cstrings: Vec<CString> = trigger_words
740 .iter()
741 .map(|w| CString::new(*w).unwrap())
742 .collect();
743 let mut trigger_ptrs: Vec<*const c_char> =
744 trigger_cstrings.iter().map(|s| s.as_ptr()).collect();
745
746 let sampler = unsafe {
747 llama_sampler_init_grammar_lazy(
748 model.get_vocab().vocab.as_ref(),
749 grammar_str.as_ptr(),
750 grammar_root.as_ptr(),
751 trigger_ptrs.as_mut_ptr(),
752 trigger_ptrs.len(),
753 trigger_tokens.as_ptr().cast(),
754 trigger_tokens.len(),
755 )
756 };
757 Self {
758 sampler: NonNull::new(sampler).unwrap(),
759 }
760 }
761
762 /// Grammar sampler with lazy activation via regex patterns.
763 ///
764 /// The grammar is only activated when one of the trigger patterns or trigger tokens matches.
765 ///
766 /// # Panics
767 /// - If `grammar_str` or `grammar_root` contain null bytes.
768 /// - If any trigger pattern contains null bytes.
769 /// - If llama.cpp returns a null pointer.
770 #[must_use]
771 pub fn grammar_lazy_patterns(
772 model: &LlamaModel,
773 grammar_str: &str,
774 grammar_root: &str,
775 trigger_patterns: &[&str],
776 trigger_tokens: &[LlamaToken],
777 ) -> Self {
778 let grammar_str = CString::new(grammar_str).unwrap();
779 let grammar_root = CString::new(grammar_root).unwrap();
780 let pattern_cstrings: Vec<CString> = trigger_patterns
781 .iter()
782 .map(|w| CString::new(*w).unwrap())
783 .collect();
784 let mut pattern_ptrs: Vec<*const c_char> =
785 pattern_cstrings.iter().map(|s| s.as_ptr()).collect();
786
787 let sampler = unsafe {
788 llama_sampler_init_grammar_lazy_patterns(
789 model.get_vocab().vocab.as_ref(),
790 grammar_str.as_ptr(),
791 grammar_root.as_ptr(),
792 pattern_ptrs.as_mut_ptr(),
793 pattern_ptrs.len(),
794 trigger_tokens.as_ptr().cast(),
795 trigger_tokens.len(),
796 )
797 };
798 Self {
799 sampler: NonNull::new(sampler).unwrap(),
800 }
801 }
802
803 /// Clone this sampler.
804 ///
805 /// Creates an independent copy of this sampler with the same state.
806 ///
807 /// # Panics
808 ///
809 /// Panics if llama.cpp returns a null pointer.
810 #[must_use]
811 pub fn clone_sampler(&self) -> Self {
812 let sampler = unsafe { llama_sampler_clone(self.sampler.as_ptr()) };
813 Self {
814 sampler: NonNull::new(sampler).expect("sampler_clone returned null"),
815 }
816 }
817
818 /// Print sampler performance data.
819 pub fn perf_print(&self) {
820 unsafe { llama_cpp_sys_4::llama_perf_sampler_print(self.sampler.as_ptr()) }
821 }
822
823 /// Reset sampler performance counters.
824 pub fn perf_reset(&mut self) {
825 unsafe { llama_cpp_sys_4::llama_perf_sampler_reset(self.sampler.as_ptr()) }
826 }
827
828 /// Get sampler performance data.
829 #[must_use]
830 pub fn perf_data(&self) -> llama_cpp_sys_4::llama_perf_sampler_data {
831 unsafe { llama_cpp_sys_4::llama_perf_sampler(self.sampler.as_ptr()) }
832 }
833
834 /// Get a non-owning reference to the `i`th sampler in a chain.
835 ///
836 /// # Safety
837 ///
838 /// The returned pointer is owned by the chain. Do not free it or use it
839 /// after the chain is dropped or modified.
840 #[must_use]
841 pub unsafe fn chain_get_ptr(&self, i: i32) -> *mut llama_sampler {
842 llama_cpp_sys_4::llama_sampler_chain_get(self.sampler.as_ptr(), i)
843 }
844
845 /// Create a sampler from a raw interface and context.
846 ///
847 /// # Safety
848 ///
849 /// The caller must ensure that `iface` and `ctx` are valid and that the
850 /// interface functions properly manage the context lifetime.
851 ///
852 /// # Panics
853 ///
854 /// Panics if llama.cpp returns a null pointer.
855 #[must_use]
856 pub unsafe fn from_raw(
857 iface: *mut llama_cpp_sys_4::llama_sampler_i,
858 ctx: llama_cpp_sys_4::llama_sampler_context_t,
859 ) -> Self {
860 let sampler = llama_cpp_sys_4::llama_sampler_init(iface, ctx);
861 Self {
862 sampler: NonNull::new(sampler).expect("sampler_init returned null"),
863 }
864 }
865
866 /// Creates a new instance of `LlamaSampler` with common sampling parameters.
867 ///
868 /// This function initializes a `LlamaSampler` using default values from `common_sampler_params`
869 /// and configures it with common settings such as `top_k`, `top_p`, `temperature`, and `seed` values.
870 ///
871 /// # Panics
872 ///
873 /// Panics if llama.cpp returns a null pointer.
874 ///
875 /// # Returns
876 /// A `LlamaSampler` instance configured with the common sampling parameters.
877 #[must_use]
878 pub fn common() -> Self {
879 let params = common_sampler_params::default();
880
881 let sampler = unsafe {
882 let mut sparams = llama_sampler_chain_default_params();
883 sparams.no_perf = false;
884
885 let smpl = llama_sampler_chain_init(sparams);
886
887 llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.top_k));
888 llama_sampler_chain_add(
889 smpl,
890 #[allow(clippy::cast_sign_loss)]
891 llama_sampler_init_top_p(params.top_p, params.min_keep as usize),
892 );
893 llama_sampler_chain_add(smpl, llama_sampler_init_temp(params.temp));
894 #[allow(clippy::cast_sign_loss)]
895 llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.seed));
896
897 smpl
898 };
899
900 Self {
901 sampler: NonNull::new(sampler).unwrap(),
902 }
903 }
904}
905
906impl Drop for LlamaSampler {
907 fn drop(&mut self) {
908 unsafe {
909 llama_sampler_free(self.sampler.as_ptr());
910 }
911 }
912}