llama_cpp_4/sampling.rs
1//! Safe wrapper around `llama_sampler`.
2
3use std::borrow::Borrow;
4use std::ffi::{c_char, CString};
5use std::fmt::{Debug, Formatter};
6use std::ptr::NonNull;
7
8use llama_cpp_sys_4::{
9 common::common_sampler_params, llama_logit_bias, llama_sampler, llama_sampler_accept,
10 llama_sampler_chain_add, llama_sampler_chain_default_params, llama_sampler_chain_init,
11 llama_sampler_chain_n, llama_sampler_chain_remove, llama_sampler_clone, llama_sampler_free,
12 llama_sampler_get_seed, llama_sampler_init_adaptive_p, llama_sampler_init_dist,
13 llama_sampler_init_dry, llama_sampler_init_grammar, llama_sampler_init_grammar_lazy,
14 llama_sampler_init_grammar_lazy_patterns, llama_sampler_init_greedy,
15 llama_sampler_init_infill, llama_sampler_init_logit_bias, llama_sampler_init_min_p,
16 llama_sampler_init_mirostat, llama_sampler_init_mirostat_v2, llama_sampler_init_penalties,
17 llama_sampler_init_temp, llama_sampler_init_temp_ext, llama_sampler_init_top_k,
18 llama_sampler_init_top_n_sigma, llama_sampler_init_top_p, llama_sampler_init_typical,
19 llama_sampler_init_xtc, llama_sampler_name, llama_sampler_reset, llama_sampler_sample,
20};
21
22use crate::context::LlamaContext;
23use crate::model::LlamaModel;
24use crate::token::data_array::LlamaTokenDataArray;
25use crate::token::LlamaToken;
26
/// A safe wrapper around `llama_sampler`.
pub struct LlamaSampler {
    /// Owned, non-null pointer to the underlying C sampler.
    /// Freed in the [`Drop`] impl via `llama_sampler_free`.
    pub(crate) sampler: NonNull<llama_sampler>,
}
31
32impl Debug for LlamaSampler {
33 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
34 f.debug_struct("LlamaSamplerChain").finish()
35 }
36}
#[derive(Debug, Clone)]
#[allow(
    missing_docs,
    clippy::struct_excessive_bools,
    clippy::module_name_repetitions,
    dead_code
)]
pub struct LlamaSamplerParams {
    top_k: i32,
    top_p: f32,
    temp: f32,
    seed: u32,
}

impl LlamaSamplerParams {
    /// Set the seed of the context
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::sampling::LlamaSamplerParams;
    /// let params = LlamaSamplerParams::default().with_seed(1234);
    /// assert_eq!(params.seed(), 1234);
    /// ```
    #[must_use]
    pub fn with_seed(mut self, seed: u32) -> Self {
        self.seed = seed;
        self
    }

    /// Get the seed of the context
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::sampling::LlamaSamplerParams;
    /// let params = LlamaSamplerParams::default().with_seed(1234);
    /// assert_eq!(params.seed(), 1234);
    /// ```
    // NOTE: the doctests previously used the path `llama_cpp_4::context::sampler`
    // (wrong module) and contained a dangling `.with_seed(1234);` statement that
    // failed to compile; both are fixed above.
    #[must_use]
    pub fn seed(&self) -> u32 {
        self.seed
    }
}

impl Default for LlamaSamplerParams {
    /// Default parameters: `top_k = 50`, `top_p = 0.9`, `temp = 0.8`,
    /// `seed = 1234`.
    fn default() -> Self {
        Self {
            top_k: 50,
            top_p: 0.9,
            temp: 0.8,
            seed: 1234,
        }
    }
}
94
impl Default for LlamaSampler {
    /// Equivalent to [`LlamaSampler::new`]: an empty sampler chain created
    /// with default chain parameters.
    fn default() -> Self {
        Self::new()
    }
}
100
101impl LlamaSampler {
102 /// Create new sampler with default params.
103 ///
104 /// # Panics
105 ///
106 /// Panics if llama.cpp returns a null pointer.
107 #[must_use]
108 pub fn new() -> Self {
109 let sparams = unsafe { llama_sampler_chain_default_params() };
110
111 Self {
112 sampler: NonNull::new(unsafe { llama_sampler_chain_init(sparams) }).unwrap(),
113 }
114 }
115
116 /// Sample and accept a token from the idx-th output of the last evaluation
117 #[must_use]
118 pub fn sample(&self, ctx: &LlamaContext, idx: i32) -> LlamaToken {
119 let token =
120 unsafe { llama_sampler_sample(self.sampler.as_ptr(), ctx.context.as_ptr(), idx) };
121
122 LlamaToken(token)
123 }
124
125 /// Applies this sampler to a [`LlamaTokenDataArray`].
126 pub fn apply(&mut self, data_array: &mut LlamaTokenDataArray) {
127 data_array.apply_sampler(self);
128 }
129
130 /// Accepts a token from the sampler, possibly updating the internal state of certain samplers
131 /// (e.g. grammar, repetition, etc.)
132 pub fn accept(&mut self, token: LlamaToken) {
133 unsafe { llama_sampler_accept(self.sampler.as_ptr(), token.0) }
134 }
135
136 /// Accepts several tokens from the sampler or context, possibly updating the internal state of
137 /// certain samplers (e.g. grammar, repetition, etc.)
138 pub fn accept_many(&mut self, tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>) {
139 for token in tokens {
140 unsafe { llama_sampler_accept(self.sampler.as_ptr(), token.borrow().0) }
141 }
142 }
143
    /// Accepts several tokens from the sampler or context, possibly updating the internal state of
    /// certain samplers (e.g. grammar, repetition, etc.)
    ///
    /// Builder-style variant of [`Self::accept_many`]: consumes and returns the
    /// sampler so calls can be chained during construction.
    #[must_use]
    pub fn with_tokens(
        mut self,
        tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
    ) -> Self {
        self.accept_many(tokens);
        self
    }
154
    /// Combines a list of samplers into a single sampler that applies each component sampler one
    /// after another.
    ///
    /// If you are using a chain to select a token, the chain should always end with one of
    /// [`LlamaSampler::greedy`], [`LlamaSampler::dist`], [`LlamaSampler::mirostat`], and
    /// [`LlamaSampler::mirostat_v2`].
    ///
    /// # Panics
    ///
    /// Panics if llama.cpp returns a null pointer.
    #[must_use]
    pub fn chain(samplers: impl IntoIterator<Item = Self>, no_perf: bool) -> Self {
        unsafe {
            let mut params = llama_sampler_chain_default_params();
            params.no_perf = no_perf;
            let chain = llama_sampler_chain_init(params);

            for sampler in samplers {
                // Ownership of the inner pointer transfers to the chain here.
                llama_sampler_chain_add(chain, sampler.sampler.as_ptr());

                // Do not call `llama_sampler_free` on the sampler, as the internal sampler is now
                // owned by the chain. `forget` skips our `Drop` impl, preventing a double free.
                std::mem::forget(sampler);
            }

            Self {
                sampler: NonNull::new(chain).unwrap(),
            }
        }
    }
185
    /// Same as [`Self::chain`] with `no_perf = false`.
    ///
    /// This is the usual entry point for building a sampling pipeline.
    ///
    /// # Panics
    ///
    /// Panics if llama.cpp returns a null pointer.
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_4::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_4::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::chain_simple([
    ///     LlamaSampler::temp(0.5),
    ///     LlamaSampler::greedy(),
    /// ]));
    ///
    /// assert_eq!(data_array.data[0].logit(), 0.);
    /// assert_eq!(data_array.data[1].logit(), 2.);
    /// assert_eq!(data_array.data[2].logit(), 4.);
    ///
    /// assert_eq!(data_array.data.len(), 3);
    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(2)));
    /// ```
    #[must_use]
    pub fn chain_simple(samplers: impl IntoIterator<Item = Self>) -> Self {
        Self::chain(samplers, false)
    }
223
224 /// Updates the logits `l_i`' = `l_i/t`. When `t <= 0.0`, the maximum logit is kept at its original
225 /// value, the rest are set to -inf.
226 ///
227 /// # Panics
228 ///
229 /// Panics if llama.cpp returns a null pointer.
230 ///
231 /// # Example:
232 /// ```rust
233 /// use llama_cpp_4::token::{
234 /// LlamaToken,
235 /// data::LlamaTokenData,
236 /// data_array::LlamaTokenDataArray
237 /// };
238 /// use llama_cpp_4::sampling::LlamaSampler;
239 ///
240 /// let mut data_array = LlamaTokenDataArray::new(vec![
241 /// LlamaTokenData::new(LlamaToken(0), 0., 0.),
242 /// LlamaTokenData::new(LlamaToken(1), 1., 0.),
243 /// LlamaTokenData::new(LlamaToken(2), 2., 0.),
244 /// ], false);
245 ///
246 /// data_array.apply_sampler(&mut LlamaSampler::temp(0.5));
247 ///
248 /// assert_eq!(data_array.data[0].logit(), 0.);
249 /// assert_eq!(data_array.data[1].logit(), 2.);
250 /// assert_eq!(data_array.data[2].logit(), 4.);
251 /// ```
252 #[must_use]
253 pub fn temp(t: f32) -> Self {
254 let sampler = unsafe { llama_sampler_init_temp(t) };
255 Self {
256 sampler: NonNull::new(sampler).unwrap(),
257 }
258 }
259
260 /// Dynamic temperature implementation (a.k.a. entropy) described in the paper
261 /// <https://arxiv.org/abs/2309.02772>.
262 ///
263 /// # Panics
264 ///
265 /// Panics if llama.cpp returns a null pointer.
266 #[must_use]
267 pub fn temp_ext(t: f32, delta: f32, exponent: f32) -> Self {
268 let sampler = unsafe { llama_sampler_init_temp_ext(t, delta, exponent) };
269 Self {
270 sampler: NonNull::new(sampler).unwrap(),
271 }
272 }
273
274 /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration"
275 /// <https://arxiv.org/abs/1904.09751>.
276 ///
277 /// # Panics
278 ///
279 /// Panics if llama.cpp returns a null pointer.
280 ///
281 /// # Example:
282 /// ```rust
283 /// use llama_cpp_4::token::{
284 /// LlamaToken,
285 /// data::LlamaTokenData,
286 /// data_array::LlamaTokenDataArray
287 /// };
288 /// use llama_cpp_4::sampling::LlamaSampler;
289 ///
290 /// let mut data_array = LlamaTokenDataArray::new(vec![
291 /// LlamaTokenData::new(LlamaToken(0), 0., 0.),
292 /// LlamaTokenData::new(LlamaToken(1), 1., 0.),
293 /// LlamaTokenData::new(LlamaToken(2), 2., 0.),
294 /// LlamaTokenData::new(LlamaToken(3), 3., 0.),
295 /// ], false);
296 ///
297 /// data_array.apply_sampler(&mut LlamaSampler::top_k(2));
298 ///
299 /// assert_eq!(data_array.data.len(), 2);
300 /// assert_eq!(data_array.data[0].id(), LlamaToken(3));
301 /// assert_eq!(data_array.data[1].id(), LlamaToken(2));
302 /// ```
303 #[must_use]
304 pub fn top_k(k: i32) -> Self {
305 let sampler = unsafe { llama_sampler_init_top_k(k) };
306 Self {
307 sampler: NonNull::new(sampler).unwrap(),
308 }
309 }
310
311 /// Locally Typical Sampling implementation described in the paper <https://arxiv.org/abs/2202.00666>.
312 ///
313 /// # Panics
314 ///
315 /// Panics if llama.cpp returns a null pointer.
316 #[must_use]
317 pub fn typical(p: f32, min_keep: usize) -> Self {
318 let sampler = unsafe { llama_sampler_init_typical(p, min_keep) };
319 Self {
320 sampler: NonNull::new(sampler).unwrap(),
321 }
322 }
323
324 /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration"
325 /// <https://arxiv.org/abs/1904.09751>.
326 ///
327 /// # Panics
328 ///
329 /// Panics if llama.cpp returns a null pointer.
330 #[must_use]
331 pub fn top_p(p: f32, min_keep: usize) -> Self {
332 let sampler = unsafe { llama_sampler_init_top_p(p, min_keep) };
333 Self {
334 sampler: NonNull::new(sampler).unwrap(),
335 }
336 }
337
338 /// Minimum P sampling as described in <https://github.com/ggerganov/llama.cpp/pull/3841>.
339 ///
340 /// # Panics
341 ///
342 /// Panics if llama.cpp returns a null pointer.
343 #[must_use]
344 pub fn min_p(p: f32, min_keep: usize) -> Self {
345 let sampler = unsafe { llama_sampler_init_min_p(p, min_keep) };
346 Self {
347 sampler: NonNull::new(sampler).unwrap(),
348 }
349 }
350
351 /// XTC sampler as described in <https://github.com/oobabooga/text-generation-webui/pull/6335>.
352 ///
353 /// # Panics
354 ///
355 /// Panics if llama.cpp returns a null pointer.
356 #[must_use]
357 pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Self {
358 let sampler = unsafe { llama_sampler_init_xtc(p, t, min_keep, seed) };
359 Self {
360 sampler: NonNull::new(sampler).unwrap(),
361 }
362 }
363
364 /// Grammar sampler
365 ///
366 /// # Panics
367 /// - If either of `grammar_str` or `grammar_root` contain null bytes.
368 /// - If llama.cpp returns a null pointer.
369 #[must_use]
370 pub fn grammar(model: &LlamaModel, grammar_str: &str, grammar_root: &str) -> Self {
371 let grammar_str = CString::new(grammar_str).unwrap();
372 let grammar_root = CString::new(grammar_root).unwrap();
373
374 let sampler = unsafe {
375 llama_sampler_init_grammar(
376 model.get_vocab().vocab.as_ref(),
377 grammar_str.as_ptr(),
378 grammar_root.as_ptr(),
379 )
380 };
381 Self {
382 sampler: NonNull::new(sampler).unwrap(),
383 }
384 }
385
    /// DRY sampler, designed by p-e-w, as described in:
    /// <https://github.com/oobabooga/text-generation-webui/pull/5677>, porting Koboldcpp
    /// implementation authored by pi6am: <https://github.com/LostRuins/koboldcpp/pull/982>
    ///
    /// NOTE(review): the `&self` receiver is never used in the body — this behaves
    /// like the other associated constructors. Kept as-is for API compatibility.
    ///
    /// # Panics
    /// - If any string in `seq_breakers` contains null bytes.
    /// - If llama.cpp returns a null pointer.
    #[allow(clippy::too_many_arguments)]
    #[must_use]
    pub fn dry(
        &self,
        model: &LlamaModel,
        n_ctx_train: i32,
        multiplier: f32,
        base: f32,
        allowed_length: i32,
        penalty_last_n: i32,
        seq_breakers: impl IntoIterator<Item = impl AsRef<[u8]>>,
    ) -> Self {
        // The owned CStrings must outlive the FFI call: the pointer vector
        // below only borrows into them.
        let seq_breakers: Vec<CString> = seq_breakers
            .into_iter()
            .map(|s| CString::new(s.as_ref()).unwrap())
            .collect();
        // CString::as_ptr() returns *const c_char, which matches what the binding
        // expects on every platform (signed on macOS/x86 Linux, unsigned on musl ARM).
        let mut seq_breaker_pointers: Vec<*const c_char> =
            seq_breakers.iter().map(|s| s.as_ptr()).collect();

        let sampler = unsafe {
            llama_sampler_init_dry(
                model.get_vocab().vocab.as_ref(),
                n_ctx_train,
                multiplier,
                base,
                allowed_length,
                penalty_last_n,
                seq_breaker_pointers.as_mut_ptr(),
                seq_breaker_pointers.len(),
            )
        };

        Self {
            sampler: NonNull::new(sampler).unwrap(),
        }
    }
431
    /// Penalizes tokens for being present in the context.
    ///
    /// Parameters:
    /// - `n_vocab`: [`LlamaModel::n_vocab`]
    /// - `special_eos_id`: [`LlamaModel::token_eos`]
    /// - `linefeed_id`: [`LlamaModel::token_nl`]
    /// - `penalty_last_n`: last n tokens to penalize (0 = disable penalty, -1 = context size)
    ///
    /// NOTE(review): token ids and the window length are passed as `f32` here,
    /// presumably to match the `llama_sampler_init_penalties` signature exposed
    /// by `llama_cpp_sys_4` — confirm against the sys crate, since upstream
    /// llama.cpp uses integer parameters for these.
    ///
    /// # Panics
    ///
    /// Panics if llama.cpp returns a null pointer.
    #[allow(clippy::too_many_arguments)]
    #[must_use]
    pub fn penalties(
        n_vocab: i32,
        special_eos_id: f32,
        linefeed_id: f32,
        penalty_last_n: f32,
        // penalty_repeat: f32,
        // penalty_freq: f32,
        // penalty_present: f32,
        // penalize_nl: bool,
        // ignore_eos: bool,
    ) -> Self {
        let sampler = unsafe {
            llama_sampler_init_penalties(
                n_vocab,
                special_eos_id,
                linefeed_id,
                penalty_last_n,
                // penalty_repeat,
                // penalty_freq,
                // penalty_present,
                // penalize_nl,
                // ignore_eos,
            )
        };
        Self {
            sampler: NonNull::new(sampler).unwrap(),
        }
    }
473
    /// Same as [`Self::penalties`], but with `n_vocab`, `special_eos_id`, and `linefeed_id`
    /// initialized from `model`, `penalize_nl = false`, and `ignore_eos = true`.
    ///
    /// Parameters:
    /// - `model`: The model's tokenizer to use to initialize the sampler.
    /// - `penalty_last_n`: last n tokens to penalize (0 = disable penalty, -1 = context size)
    ///
    /// # Panics
    ///
    /// Panics if llama.cpp returns a null pointer.
    #[must_use]
    pub fn penalties_simple(
        model: &LlamaModel,
        penalty_last_n: i32,
        // penalty_repeat: f32,
        // penalty_freq: f32,
        // penalty_present: f32,
    ) -> Self {
        // The `as f32` casts mirror the f32 parameters of `Self::penalties`;
        // precision loss is only possible for ids above 2^24 (hence the allows).
        Self::penalties(
            model.n_vocab(),
            #[allow(clippy::cast_precision_loss)]
            {
                model.token_eos().0 as f32
            },
            #[allow(clippy::cast_precision_loss)]
            {
                model.token_nl().0 as f32
            },
            #[allow(clippy::cast_precision_loss)]
            {
                penalty_last_n as f32
            },
            // penalty_repeat,
            // penalty_freq,
            // penalty_present,
            // false,
            // true,
        )
    }
513
514 /// Mirostat 1.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
515 ///
516 /// # Panics
517 ///
518 /// Panics if llama.cpp returns a null pointer.
519 ///
520 /// # Parameters:
521 /// - `n_vocab`: [`LlamaModel::n_vocab`]
522 /// - `seed`: Seed to initialize random generation with.
523 /// - `tau`: The target cross-entropy (or surprise) value you want to achieve for the
524 /// generated text. A higher value corresponds to more surprising or less predictable text,
525 /// while a lower value corresponds to less surprising or more predictable text.
526 /// - `eta`: The learning rate used to update `mu` based on the error between the target and
527 /// observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
528 /// updated more quickly, while a smaller learning rate will result in slower updates.
529 /// - `m`: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary
530 /// value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`.
531 /// In the paper, they use `m = 100`, but you can experiment with different values to see how
532 /// it affects the performance of the algorithm.
533 #[must_use]
534 pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self {
535 let sampler = unsafe { llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m) };
536 Self {
537 sampler: NonNull::new(sampler).unwrap(),
538 }
539 }
540
541 /// Mirostat 2.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
542 ///
543 /// # Panics
544 ///
545 /// Panics if llama.cpp returns a null pointer.
546 ///
547 /// # Parameters:
548 /// - `seed`: Seed to initialize random generation with.
549 /// - `tau`: The target cross-entropy (or surprise) value you want to achieve for the
550 /// generated text. A higher value corresponds to more surprising or less predictable text,
551 /// while a lower value corresponds to less surprising or more predictable text.
552 /// - `eta`: The learning rate used to update `mu` based on the error between the target and
553 /// observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
554 /// updated more quickly, while a smaller learning rate will result in slower updates.
555 #[must_use]
556 pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self {
557 let sampler = unsafe { llama_sampler_init_mirostat_v2(seed, tau, eta) };
558 Self {
559 sampler: NonNull::new(sampler).unwrap(),
560 }
561 }
562
563 /// Selects a token at random based on each token's probabilities.
564 ///
565 /// # Panics
566 ///
567 /// Panics if llama.cpp returns a null pointer.
568 #[must_use]
569 pub fn dist(seed: u32) -> Self {
570 let sampler = unsafe { llama_sampler_init_dist(seed) };
571 Self {
572 sampler: NonNull::new(sampler).unwrap(),
573 }
574 }
575
576 /// Selects the most likely token.
577 ///
578 /// # Panics
579 ///
580 /// Panics if llama.cpp returns a null pointer.
581 ///
582 /// # Example:
583 /// ```rust
584 /// use llama_cpp_4::token::{
585 /// LlamaToken,
586 /// data::LlamaTokenData,
587 /// data_array::LlamaTokenDataArray
588 /// };
589 /// use llama_cpp_4::sampling::LlamaSampler;
590 ///
591 /// let mut data_array = LlamaTokenDataArray::new(vec![
592 /// LlamaTokenData::new(LlamaToken(0), 0., 0.),
593 /// LlamaTokenData::new(LlamaToken(1), 1., 0.),
594 /// ], false);
595 ///
596 /// data_array.apply_sampler(&mut LlamaSampler::greedy());
597 ///
598 /// assert_eq!(data_array.data.len(), 2);
599 /// assert_eq!(data_array.selected_token(), Some(LlamaToken(1)));
600 /// ```
601 #[must_use]
602 pub fn greedy() -> Self {
603 let sampler = unsafe { llama_sampler_init_greedy() };
604 Self {
605 sampler: NonNull::new(sampler).unwrap(),
606 }
607 }
608
609 /// Top-N sigma sampling.
610 ///
611 /// Keeps tokens within N standard deviations of the maximum logit.
612 ///
613 /// # Panics
614 ///
615 /// Panics if llama.cpp returns a null pointer.
616 #[must_use]
617 pub fn top_n_sigma(n: f32) -> Self {
618 let sampler = unsafe { llama_sampler_init_top_n_sigma(n) };
619 Self {
620 sampler: NonNull::new(sampler).unwrap(),
621 }
622 }
623
624 /// Adaptive P sampling.
625 ///
626 /// # Panics
627 ///
628 /// Panics if llama.cpp returns a null pointer.
629 ///
630 /// # Parameters
631 /// - `target`: Target probability.
632 /// - `decay`: Decay rate.
633 /// - `seed`: Random seed.
634 #[must_use]
635 pub fn adaptive_p(target: f32, decay: f32, seed: u32) -> Self {
636 let sampler = unsafe { llama_sampler_init_adaptive_p(target, decay, seed) };
637 Self {
638 sampler: NonNull::new(sampler).unwrap(),
639 }
640 }
641
642 /// Logit bias sampler.
643 ///
644 /// Applies additive bias to specific token logits before sampling.
645 ///
646 /// # Panics
647 ///
648 /// Panics if llama.cpp returns a null pointer.
649 ///
650 /// # Parameters
651 /// - `n_vocab`: Number of tokens in the vocabulary ([`LlamaModel::n_vocab`]).
652 /// - `biases`: Slice of `(token_id, bias)` pairs.
653 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
654 #[must_use]
655 pub fn logit_bias(n_vocab: i32, biases: &[(LlamaToken, f32)]) -> Self {
656 let logit_biases: Vec<llama_logit_bias> = biases
657 .iter()
658 .map(|(token, bias)| llama_logit_bias {
659 token: token.0,
660 bias: *bias,
661 })
662 .collect();
663
664 let sampler = unsafe {
665 llama_sampler_init_logit_bias(
666 n_vocab,
667 logit_biases.len() as i32,
668 logit_biases.as_ptr(),
669 )
670 };
671 Self {
672 sampler: NonNull::new(sampler).unwrap(),
673 }
674 }
675
676 /// Infill sampler.
677 ///
678 /// Reorders token probabilities for fill-in-the-middle tasks.
679 ///
680 /// # Panics
681 ///
682 /// Panics if llama.cpp returns a null pointer.
683 #[must_use]
684 pub fn infill(model: &LlamaModel) -> Self {
685 let sampler =
686 unsafe { llama_sampler_init_infill(model.get_vocab().vocab.as_ref()) };
687 Self {
688 sampler: NonNull::new(sampler).unwrap(),
689 }
690 }
691
692 /// Get the seed of the sampler.
693 ///
694 /// Returns `LLAMA_DEFAULT_SEED` if the sampler is not seeded.
695 #[must_use]
696 pub fn get_seed(&self) -> u32 {
697 unsafe { llama_sampler_get_seed(self.sampler.as_ptr()) }
698 }
699
700 /// Get the name of the sampler.
701 ///
702 /// # Panics
703 ///
704 /// Panics if the name is not valid UTF-8.
705 #[must_use]
706 pub fn name(&self) -> String {
707 let c_str = unsafe { llama_sampler_name(self.sampler.as_ptr()) };
708 let c_str = unsafe { std::ffi::CStr::from_ptr(c_str) };
709 c_str.to_str().expect("sampler name is not valid UTF-8").to_owned()
710 }
711
712 /// Reset the sampler state (e.g. grammar, repetition penalties).
713 pub fn reset(&mut self) {
714 unsafe { llama_sampler_reset(self.sampler.as_ptr()) }
715 }
716
717 /// Get the number of samplers in a chain.
718 ///
719 /// Returns 0 if this sampler is not a chain.
720 #[must_use]
721 pub fn chain_n(&self) -> i32 {
722 unsafe { llama_sampler_chain_n(self.sampler.as_ptr()) }
723 }
724
725 /// Remove and return the sampler at position `i` from a chain.
726 ///
727 /// The returned sampler is owned by the caller and will be freed on drop.
728 ///
729 /// # Panics
730 ///
731 /// Panics if `i` is out of range or if llama.cpp returns a null pointer.
732 #[must_use]
733 pub fn chain_remove(&mut self, i: i32) -> Self {
734 let sampler = unsafe { llama_sampler_chain_remove(self.sampler.as_ptr(), i) };
735 Self {
736 sampler: NonNull::new(sampler).expect("chain_remove returned null"),
737 }
738 }
739
740 /// Grammar sampler with lazy activation.
741 ///
742 /// The grammar is only activated when one of the trigger words or trigger tokens is encountered.
743 ///
744 /// # Panics
745 /// - If `grammar_str` or `grammar_root` contain null bytes.
746 /// - If any trigger word contains null bytes.
747 /// - If llama.cpp returns a null pointer.
748 #[must_use]
749 pub fn grammar_lazy(
750 model: &LlamaModel,
751 grammar_str: &str,
752 grammar_root: &str,
753 trigger_words: &[&str],
754 trigger_tokens: &[LlamaToken],
755 ) -> Self {
756 let grammar_str = CString::new(grammar_str).unwrap();
757 let grammar_root = CString::new(grammar_root).unwrap();
758 let trigger_cstrings: Vec<CString> = trigger_words
759 .iter()
760 .map(|w| CString::new(*w).unwrap())
761 .collect();
762 let mut trigger_ptrs: Vec<*const c_char> =
763 trigger_cstrings.iter().map(|s| s.as_ptr()).collect();
764
765 let sampler = unsafe {
766 llama_sampler_init_grammar_lazy(
767 model.get_vocab().vocab.as_ref(),
768 grammar_str.as_ptr(),
769 grammar_root.as_ptr(),
770 trigger_ptrs.as_mut_ptr(),
771 trigger_ptrs.len(),
772 trigger_tokens.as_ptr().cast(),
773 trigger_tokens.len(),
774 )
775 };
776 Self {
777 sampler: NonNull::new(sampler).unwrap(),
778 }
779 }
780
781 /// Grammar sampler with lazy activation via regex patterns.
782 ///
783 /// The grammar is only activated when one of the trigger patterns or trigger tokens matches.
784 ///
785 /// # Panics
786 /// - If `grammar_str` or `grammar_root` contain null bytes.
787 /// - If any trigger pattern contains null bytes.
788 /// - If llama.cpp returns a null pointer.
789 #[must_use]
790 pub fn grammar_lazy_patterns(
791 model: &LlamaModel,
792 grammar_str: &str,
793 grammar_root: &str,
794 trigger_patterns: &[&str],
795 trigger_tokens: &[LlamaToken],
796 ) -> Self {
797 let grammar_str = CString::new(grammar_str).unwrap();
798 let grammar_root = CString::new(grammar_root).unwrap();
799 let pattern_cstrings: Vec<CString> = trigger_patterns
800 .iter()
801 .map(|w| CString::new(*w).unwrap())
802 .collect();
803 let mut pattern_ptrs: Vec<*const c_char> =
804 pattern_cstrings.iter().map(|s| s.as_ptr()).collect();
805
806 let sampler = unsafe {
807 llama_sampler_init_grammar_lazy_patterns(
808 model.get_vocab().vocab.as_ref(),
809 grammar_str.as_ptr(),
810 grammar_root.as_ptr(),
811 pattern_ptrs.as_mut_ptr(),
812 pattern_ptrs.len(),
813 trigger_tokens.as_ptr().cast(),
814 trigger_tokens.len(),
815 )
816 };
817 Self {
818 sampler: NonNull::new(sampler).unwrap(),
819 }
820 }
821
822 /// Clone this sampler.
823 ///
824 /// Creates an independent copy of this sampler with the same state.
825 ///
826 /// # Panics
827 ///
828 /// Panics if llama.cpp returns a null pointer.
829 #[must_use]
830 pub fn clone_sampler(&self) -> Self {
831 let sampler = unsafe { llama_sampler_clone(self.sampler.as_ptr()) };
832 Self {
833 sampler: NonNull::new(sampler).expect("sampler_clone returned null"),
834 }
835 }
836
    /// Print sampler performance data.
    // NOTE(review): presumably only meaningful when the chain was created with
    // `no_perf = false` — confirm against llama.cpp's perf API.
    pub fn perf_print(&self) {
        unsafe { llama_cpp_sys_4::llama_perf_sampler_print(self.sampler.as_ptr()) }
    }

    /// Reset sampler performance counters.
    pub fn perf_reset(&mut self) {
        unsafe { llama_cpp_sys_4::llama_perf_sampler_reset(self.sampler.as_ptr()) }
    }

    /// Get sampler performance data.
    #[must_use]
    pub fn perf_data(&self) -> llama_cpp_sys_4::llama_perf_sampler_data {
        unsafe { llama_cpp_sys_4::llama_perf_sampler(self.sampler.as_ptr()) }
    }

    /// Get a non-owning reference to the `i`th sampler in a chain.
    ///
    /// # Safety
    ///
    /// The returned pointer is owned by the chain. Do not free it or use it
    /// after the chain is dropped or modified.
    #[must_use]
    pub unsafe fn chain_get_ptr(&self, i: i32) -> *mut llama_sampler {
        llama_cpp_sys_4::llama_sampler_chain_get(self.sampler.as_ptr(), i)
    }
863
    /// Create a sampler from a raw interface and context.
    ///
    /// The resulting [`LlamaSampler`] takes ownership and frees the sampler on drop.
    ///
    /// # Safety
    ///
    /// The caller must ensure that `iface` and `ctx` are valid and that the
    /// interface functions properly manage the context lifetime.
    ///
    /// # Panics
    ///
    /// Panics if llama.cpp returns a null pointer.
    #[must_use]
    pub unsafe fn from_raw(
        iface: *mut llama_cpp_sys_4::llama_sampler_i,
        ctx: llama_cpp_sys_4::llama_sampler_context_t,
    ) -> Self {
        let sampler = llama_cpp_sys_4::llama_sampler_init(iface, ctx);
        Self {
            sampler: NonNull::new(sampler).expect("sampler_init returned null"),
        }
    }
884
885 /// Creates a new instance of `LlamaSampler` with common sampling parameters.
886 ///
887 /// This function initializes a `LlamaSampler` using default values from `common_sampler_params`
888 /// and configures it with common settings such as `top_k`, `top_p`, `temperature`, and `seed` values.
889 ///
890 /// # Panics
891 ///
892 /// Panics if llama.cpp returns a null pointer.
893 ///
894 /// # Returns
895 /// A `LlamaSampler` instance configured with the common sampling parameters.
896 #[must_use]
897 pub fn common() -> Self {
898 let params = common_sampler_params::default();
899
900 let sampler = unsafe {
901 let mut sparams = llama_sampler_chain_default_params();
902 sparams.no_perf = false;
903
904 let smpl = llama_sampler_chain_init(sparams);
905
906 llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.top_k));
907 llama_sampler_chain_add(
908 smpl,
909 #[allow(clippy::cast_sign_loss)]
910 llama_sampler_init_top_p(params.top_p, params.min_keep as usize),
911 );
912 llama_sampler_chain_add(smpl, llama_sampler_init_temp(params.temp));
913 #[allow(clippy::cast_sign_loss)]
914 llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.seed));
915
916 smpl
917 };
918
919 Self {
920 sampler: NonNull::new(sampler).unwrap(),
921 }
922 }
923}
924
impl Drop for LlamaSampler {
    /// Frees the underlying C sampler.
    fn drop(&mut self) {
        // SAFETY: `self.sampler` is a valid pointer produced by one of the
        // `llama_sampler_init_*` / chain constructors, this wrapper owns it,
        // and it is not used again after this call.
        unsafe {
            llama_sampler_free(self.sampler.as_ptr());
        }
    }
}