// llama_cpp_4/sampling.rs
1//! Safe wrapper around `llama_sampler`.
2
3use std::borrow::Borrow;
4use std::ffi::{c_char, CString};
5use std::fmt::{Debug, Formatter};
6use std::ptr::NonNull;
7
8use llama_cpp_sys_4::{
9 common::common_sampler_params, llama_sampler, llama_sampler_accept, llama_sampler_chain_add,
10 llama_sampler_chain_default_params, llama_sampler_chain_init, llama_sampler_free,
11 llama_sampler_init_dist, llama_sampler_init_dry, llama_sampler_init_grammar,
12 llama_sampler_init_greedy, llama_sampler_init_min_p, llama_sampler_init_mirostat,
13 llama_sampler_init_mirostat_v2, llama_sampler_init_penalties, llama_sampler_init_temp,
14 llama_sampler_init_temp_ext, llama_sampler_init_top_k, llama_sampler_init_top_p,
15 llama_sampler_init_typical, llama_sampler_init_xtc, llama_sampler_sample,
16};
17
18use crate::context::LlamaContext;
19use crate::model::LlamaModel;
20use crate::token::data_array::LlamaTokenDataArray;
21use crate::token::LlamaToken;
22
/// A safe wrapper around `llama_sampler`.
pub struct LlamaSampler {
    /// Owned, non-null pointer to the underlying C `llama_sampler`.
    /// It is released with `llama_sampler_free` in this type's `Drop` impl,
    /// unless ownership is transferred to a chain (see `LlamaSampler::chain`).
    pub(crate) sampler: NonNull<llama_sampler>,
}
27
28impl Debug for LlamaSampler {
29 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
30 f.debug_struct("LlamaSamplerChain").finish()
31 }
32}
/// Basic sampling parameters (top-k, top-p, temperature, RNG seed).
#[derive(Debug, Clone)]
#[allow(
    missing_docs,
    clippy::struct_excessive_bools,
    clippy::module_name_repetitions,
    dead_code
)]
pub struct LlamaSamplerParams {
    // Keep only the `top_k` most likely tokens.
    top_k: i32,
    // Nucleus sampling threshold: keep the smallest token set whose
    // cumulative probability exceeds `top_p`.
    top_p: f32,
    // Sampling temperature; logits are scaled by `1 / temp`.
    temp: f32,
    // Seed for the random number generator used when sampling.
    seed: u32,
}
46
impl LlamaSamplerParams {
    /// Set the seed of the context
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::sampling::LlamaSamplerParams;
    /// let params = LlamaSamplerParams::default();
    /// let params = params.with_seed(1234);
    /// assert_eq!(params.seed(), 1234);
    /// ```
    #[must_use]
    pub fn with_seed(mut self, seed: u32) -> Self {
        self.seed = seed;
        self
    }

    /// Get the seed of the context
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::sampling::LlamaSamplerParams;
    /// let params = LlamaSamplerParams::default()
    ///     .with_seed(1234);
    /// assert_eq!(params.seed(), 1234);
    /// ```
    #[must_use]
    pub fn seed(&self) -> u32 {
        self.seed
    }
}
79
80impl Default for LlamaSamplerParams {
81 fn default() -> Self {
82 Self {
83 top_k: 50,
84 top_p: 0.9,
85 temp: 0.8,
86 seed: 1234,
87 }
88 }
89}
90
91impl Default for LlamaSampler {
92 fn default() -> Self {
93 Self::new()
94 }
95}
96
impl LlamaSampler {
    /// Create new sampler with default params.
    ///
    /// # Panics
    ///
    /// Panics if llama.cpp returns a null pointer.
    #[must_use]
    pub fn new() -> Self {
        // SAFETY: returns a plain params struct by value; no preconditions.
        let sparams = unsafe { llama_sampler_chain_default_params() };

        Self {
            // SAFETY: `sparams` is a valid params value obtained above; a null
            // result is caught by the `NonNull::new(..).unwrap()`.
            sampler: NonNull::new(unsafe { llama_sampler_chain_init(sparams) }).unwrap(),
        }
    }

    /// Sample and accept a token from the idx-th output of the last evaluation
    ///
    /// NOTE(review): this takes `&self`, but the underlying C call may advance
    /// internal sampler state (RNG, grammar, penalties) through the raw
    /// pointer — confirm whether `&mut self` would better reflect the semantics.
    #[must_use]
    pub fn sample(&self, ctx: &LlamaContext, idx: i32) -> LlamaToken {
        // SAFETY: `self.sampler` and `ctx.context` are both live, valid
        // pointers for the duration of this call.
        let token =
            unsafe { llama_sampler_sample(self.sampler.as_ptr(), ctx.context.as_ptr(), idx) };

        LlamaToken(token)
    }

    /// Applies this sampler to a [`LlamaTokenDataArray`].
    pub fn apply(&mut self, data_array: &mut LlamaTokenDataArray) {
        data_array.apply_sampler(self);
    }

    /// Accepts a token from the sampler, possibly updating the internal state of certain samplers
    /// (e.g. grammar, repetition, etc.)
    pub fn accept(&mut self, token: LlamaToken) {
        // SAFETY: `self.sampler` is a valid, owned sampler pointer.
        unsafe { llama_sampler_accept(self.sampler.as_ptr(), token.0) }
    }

    /// Accepts several tokens from the sampler or context, possibly updating the internal state of
    /// certain samplers (e.g. grammar, repetition, etc.)
    pub fn accept_many(&mut self, tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>) {
        for token in tokens {
            // SAFETY: `self.sampler` is a valid, owned sampler pointer.
            unsafe { llama_sampler_accept(self.sampler.as_ptr(), token.borrow().0) }
        }
    }

    /// Accepts several tokens from the sampler or context, possibly updating the internal state of
    /// certain samplers (e.g. grammar, repetition, etc.)
    ///
    /// Builder-style variant of [`Self::accept_many`] that returns `self`.
    #[must_use]
    pub fn with_tokens(
        mut self,
        tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
    ) -> Self {
        self.accept_many(tokens);
        self
    }

    /// Combines a list of samplers into a single sampler that applies each component sampler one
    /// after another.
    ///
    /// If you are using a chain to select a token, the chain should always end with one of
    /// [`LlamaSampler::greedy`], [`LlamaSampler::dist`], [`LlamaSampler::mirostat`], and
    /// [`LlamaSampler::mirostat_v2`].
    ///
    /// # Panics
    ///
    /// Panics if llama.cpp returns a null pointer.
    #[must_use]
    pub fn chain(samplers: impl IntoIterator<Item = Self>, no_perf: bool) -> Self {
        // SAFETY: the chain is created from valid default params; every sampler
        // added is a valid pointer whose ownership is transferred to the chain
        // (the Rust wrapper is `mem::forget`-ed so its `Drop` never runs).
        unsafe {
            let mut params = llama_sampler_chain_default_params();
            params.no_perf = no_perf;
            let chain = llama_sampler_chain_init(params);

            for sampler in samplers {
                llama_sampler_chain_add(chain, sampler.sampler.as_ptr());

                // Do not call `llama_sampler_free` on the sampler, as the internal sampler is now
                // owned by the chain
                std::mem::forget(sampler);
            }

            Self {
                sampler: NonNull::new(chain).unwrap(),
            }
        }
    }

    /// Same as [`Self::chain`] with `no_perf = false`.
    ///
    /// # Panics
    ///
    /// Panics if llama.cpp returns a null pointer.
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_4::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_4::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::chain_simple([
    ///     LlamaSampler::temp(0.5),
    ///     LlamaSampler::greedy(),
    /// ]));
    ///
    /// assert_eq!(data_array.data[0].logit(), 0.);
    /// assert_eq!(data_array.data[1].logit(), 2.);
    /// assert_eq!(data_array.data[2].logit(), 4.);
    ///
    /// assert_eq!(data_array.data.len(), 3);
    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(2)));
    /// ```
    #[must_use]
    pub fn chain_simple(samplers: impl IntoIterator<Item = Self>) -> Self {
        Self::chain(samplers, false)
    }

    /// Updates the logits `l_i`' = `l_i/t`. When `t <= 0.0`, the maximum logit is kept at its original
    /// value, the rest are set to -inf.
    ///
    /// # Panics
    ///
    /// Panics if llama.cpp returns a null pointer.
    ///
    /// # Example:
    /// ```rust
    /// use llama_cpp_4::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_4::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::temp(0.5));
    ///
    /// assert_eq!(data_array.data[0].logit(), 0.);
    /// assert_eq!(data_array.data[1].logit(), 2.);
    /// assert_eq!(data_array.data[2].logit(), 4.);
    /// ```
    #[must_use]
    pub fn temp(t: f32) -> Self {
        // SAFETY: takes only plain values; a null result is caught below.
        let sampler = unsafe { llama_sampler_init_temp(t) };
        Self {
            sampler: NonNull::new(sampler).unwrap(),
        }
    }

    /// Dynamic temperature implementation (a.k.a. entropy) described in the paper
    /// <https://arxiv.org/abs/2309.02772>.
    ///
    /// # Panics
    ///
    /// Panics if llama.cpp returns a null pointer.
    #[must_use]
    pub fn temp_ext(t: f32, delta: f32, exponent: f32) -> Self {
        // SAFETY: takes only plain values; a null result is caught below.
        let sampler = unsafe { llama_sampler_init_temp_ext(t, delta, exponent) };
        Self {
            sampler: NonNull::new(sampler).unwrap(),
        }
    }

    /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration"
    /// <https://arxiv.org/abs/1904.09751>.
    ///
    /// # Panics
    ///
    /// Panics if llama.cpp returns a null pointer.
    ///
    /// # Example:
    /// ```rust
    /// use llama_cpp_4::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_4::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    ///     LlamaTokenData::new(LlamaToken(3), 3., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::top_k(2));
    ///
    /// assert_eq!(data_array.data.len(), 2);
    /// assert_eq!(data_array.data[0].id(), LlamaToken(3));
    /// assert_eq!(data_array.data[1].id(), LlamaToken(2));
    /// ```
    #[must_use]
    pub fn top_k(k: i32) -> Self {
        // SAFETY: takes only plain values; a null result is caught below.
        let sampler = unsafe { llama_sampler_init_top_k(k) };
        Self {
            sampler: NonNull::new(sampler).unwrap(),
        }
    }

    /// Locally Typical Sampling implementation described in the paper <https://arxiv.org/abs/2202.00666>.
    ///
    /// # Panics
    ///
    /// Panics if llama.cpp returns a null pointer.
    #[must_use]
    pub fn typical(p: f32, min_keep: usize) -> Self {
        // SAFETY: takes only plain values; a null result is caught below.
        let sampler = unsafe { llama_sampler_init_typical(p, min_keep) };
        Self {
            sampler: NonNull::new(sampler).unwrap(),
        }
    }

    /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration"
    /// <https://arxiv.org/abs/1904.09751>.
    ///
    /// # Panics
    ///
    /// Panics if llama.cpp returns a null pointer.
    #[must_use]
    pub fn top_p(p: f32, min_keep: usize) -> Self {
        // SAFETY: takes only plain values; a null result is caught below.
        let sampler = unsafe { llama_sampler_init_top_p(p, min_keep) };
        Self {
            sampler: NonNull::new(sampler).unwrap(),
        }
    }

    /// Minimum P sampling as described in <https://github.com/ggerganov/llama.cpp/pull/3841>.
    ///
    /// # Panics
    ///
    /// Panics if llama.cpp returns a null pointer.
    #[must_use]
    pub fn min_p(p: f32, min_keep: usize) -> Self {
        // SAFETY: takes only plain values; a null result is caught below.
        let sampler = unsafe { llama_sampler_init_min_p(p, min_keep) };
        Self {
            sampler: NonNull::new(sampler).unwrap(),
        }
    }

    /// XTC sampler as described in <https://github.com/oobabooga/text-generation-webui/pull/6335>.
    ///
    /// # Panics
    ///
    /// Panics if llama.cpp returns a null pointer.
    #[must_use]
    pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Self {
        // SAFETY: takes only plain values; a null result is caught below.
        let sampler = unsafe { llama_sampler_init_xtc(p, t, min_keep, seed) };
        Self {
            sampler: NonNull::new(sampler).unwrap(),
        }
    }

    /// Grammar sampler
    ///
    /// # Panics
    /// - If either of `grammar_str` or `grammar_root` contain null bytes.
    /// - If llama.cpp returns a null pointer.
    #[must_use]
    pub fn grammar(model: &LlamaModel, grammar_str: &str, grammar_root: &str) -> Self {
        let grammar_str = CString::new(grammar_str).unwrap();
        let grammar_root = CString::new(grammar_root).unwrap();

        // SAFETY: the vocab reference comes from a live model, and both
        // strings are NUL-terminated `CString`s that outlive this call.
        let sampler = unsafe {
            llama_sampler_init_grammar(
                model.get_vocab().vocab.as_ref(),
                grammar_str.as_ptr(),
                grammar_root.as_ptr(),
            )
        };
        Self {
            sampler: NonNull::new(sampler).unwrap(),
        }
    }

    /// DRY sampler, designed by p-e-w, as described in:
    /// <https://github.com/oobabooga/text-generation-webui/pull/5677>, porting Koboldcpp
    /// implementation authored by pi6am: <https://github.com/LostRuins/koboldcpp/pull/982>
    ///
    /// NOTE(review): the `&self` receiver is never used — this behaves like the
    /// other associated constructors (`top_k`, `temp`, ...). Confirm whether it
    /// should be an associated function instead (removing `&self` would be a
    /// breaking change for existing callers).
    ///
    /// # Panics
    /// - If any string in `seq_breakers` contains null bytes.
    /// - If llama.cpp returns a null pointer.
    #[allow(clippy::too_many_arguments)]
    #[must_use]
    pub fn dry(
        &self,
        model: &LlamaModel,
        n_ctx_train: i32,
        multiplier: f32,
        base: f32,
        allowed_length: i32,
        penalty_last_n: i32,
        seq_breakers: impl IntoIterator<Item = impl AsRef<[u8]>>,
    ) -> Self {
        let seq_breakers: Vec<CString> = seq_breakers
            .into_iter()
            .map(|s| CString::new(s.as_ref()).unwrap())
            .collect();
        // CString::as_ptr() returns *const c_char, which matches what the binding
        // expects on every platform (signed on macOS/x86 Linux, unsigned on musl ARM).
        let mut seq_breaker_pointers: Vec<*const c_char> =
            seq_breakers.iter().map(|s| s.as_ptr()).collect();

        // SAFETY: `seq_breakers` (and thus every pointer in
        // `seq_breaker_pointers`) stays alive until after this call returns;
        // the pointer/length pair describes a valid array.
        let sampler = unsafe {
            llama_sampler_init_dry(
                model.get_vocab().vocab.as_ref(),
                n_ctx_train,
                multiplier,
                base,
                allowed_length,
                penalty_last_n,
                seq_breaker_pointers.as_mut_ptr(),
                seq_breaker_pointers.len(),
            )
        };

        Self {
            sampler: NonNull::new(sampler).unwrap(),
        }
    }

    /// Penalizes tokens for being present in the context.
    ///
    /// Parameters:
    /// - `n_vocab`: [`LlamaModel::n_vocab`]
    /// - `special_eos_id`: [`LlamaModel::token_eos`]
    /// - `linefeed_id`: [`LlamaModel::token_nl`]
    /// - `penalty_last_n`: last n tokens to penalize (0 = disable penalty, -1 = context size)
    ///
    /// NOTE(review): `special_eos_id`, `linefeed_id`, and `penalty_last_n` are
    /// `f32` here even though they are conceptually integers (token ids /
    /// counts); this presumably mirrors the current sys-binding signature of
    /// `llama_sampler_init_penalties` — verify against `llama_cpp_sys_4`.
    ///
    /// # Panics
    ///
    /// Panics if llama.cpp returns a null pointer.
    #[allow(clippy::too_many_arguments)]
    #[must_use]
    pub fn penalties(
        n_vocab: i32,
        special_eos_id: f32,
        linefeed_id: f32,
        penalty_last_n: f32,
        // penalty_repeat: f32,
        // penalty_freq: f32,
        // penalty_present: f32,
        // penalize_nl: bool,
        // ignore_eos: bool,
    ) -> Self {
        // SAFETY: takes only plain values; a null result is caught below.
        let sampler = unsafe {
            llama_sampler_init_penalties(
                n_vocab,
                special_eos_id,
                linefeed_id,
                penalty_last_n,
                // penalty_repeat,
                // penalty_freq,
                // penalty_present,
                // penalize_nl,
                // ignore_eos,
            )
        };
        Self {
            sampler: NonNull::new(sampler).unwrap(),
        }
    }

    /// Same as [`Self::penalties`], but with `n_vocab`, `special_eos_id`, and `linefeed_id`
    /// initialized from `model`, `penalize_nl = false`, and `ignore_eos = true`.
    ///
    /// Parameters:
    /// - `model`: The model's tokenizer to use to initialize the sampler.
    /// - `penalty_last_n`: last n tokens to penalize (0 = disable penalty, -1 = context size)
    ///
    /// # Panics
    ///
    /// Panics if llama.cpp returns a null pointer.
    #[must_use]
    pub fn penalties_simple(
        model: &LlamaModel,
        penalty_last_n: i32,
        // penalty_repeat: f32,
        // penalty_freq: f32,
        // penalty_present: f32,
    ) -> Self {
        Self::penalties(
            model.n_vocab(),
            // The `as f32` casts below match `Self::penalties`' (binding-dictated)
            // `f32` parameters; token ids fit in f32's 24-bit mantissa for
            // typical vocab sizes, hence the precision-loss allows.
            #[allow(clippy::cast_precision_loss)]
            {
                model.token_eos().0 as f32
            },
            #[allow(clippy::cast_precision_loss)]
            {
                model.token_nl().0 as f32
            },
            #[allow(clippy::cast_precision_loss)]
            {
                penalty_last_n as f32
            },
            // penalty_repeat,
            // penalty_freq,
            // penalty_present,
            // false,
            // true,
        )
    }

    /// Mirostat 1.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
    ///
    /// # Panics
    ///
    /// Panics if llama.cpp returns a null pointer.
    ///
    /// # Parameters:
    /// - `n_vocab`: [`LlamaModel::n_vocab`]
    /// - `seed`: Seed to initialize random generation with.
    /// - `tau`: The target cross-entropy (or surprise) value you want to achieve for the
    ///   generated text. A higher value corresponds to more surprising or less predictable text,
    ///   while a lower value corresponds to less surprising or more predictable text.
    /// - `eta`: The learning rate used to update `mu` based on the error between the target and
    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
    ///   updated more quickly, while a smaller learning rate will result in slower updates.
    /// - `m`: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary
    ///   value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`.
    ///   In the paper, they use `m = 100`, but you can experiment with different values to see how
    ///   it affects the performance of the algorithm.
    #[must_use]
    pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self {
        // SAFETY: takes only plain values; a null result is caught below.
        let sampler = unsafe { llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m) };
        Self {
            sampler: NonNull::new(sampler).unwrap(),
        }
    }

    /// Mirostat 2.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
    ///
    /// # Panics
    ///
    /// Panics if llama.cpp returns a null pointer.
    ///
    /// # Parameters:
    /// - `seed`: Seed to initialize random generation with.
    /// - `tau`: The target cross-entropy (or surprise) value you want to achieve for the
    ///   generated text. A higher value corresponds to more surprising or less predictable text,
    ///   while a lower value corresponds to less surprising or more predictable text.
    /// - `eta`: The learning rate used to update `mu` based on the error between the target and
    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
    ///   updated more quickly, while a smaller learning rate will result in slower updates.
    #[must_use]
    pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self {
        // SAFETY: takes only plain values; a null result is caught below.
        let sampler = unsafe { llama_sampler_init_mirostat_v2(seed, tau, eta) };
        Self {
            sampler: NonNull::new(sampler).unwrap(),
        }
    }

    /// Selects a token at random based on each token's probabilities.
    ///
    /// # Panics
    ///
    /// Panics if llama.cpp returns a null pointer.
    #[must_use]
    pub fn dist(seed: u32) -> Self {
        // SAFETY: takes only plain values; a null result is caught below.
        let sampler = unsafe { llama_sampler_init_dist(seed) };
        Self {
            sampler: NonNull::new(sampler).unwrap(),
        }
    }

    /// Selects the most likely token.
    ///
    /// # Panics
    ///
    /// Panics if llama.cpp returns a null pointer.
    ///
    /// # Example:
    /// ```rust
    /// use llama_cpp_4::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_4::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::greedy());
    ///
    /// assert_eq!(data_array.data.len(), 2);
    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(1)));
    /// ```
    #[must_use]
    pub fn greedy() -> Self {
        // SAFETY: takes no arguments; a null result is caught below.
        let sampler = unsafe { llama_sampler_init_greedy() };
        Self {
            sampler: NonNull::new(sampler).unwrap(),
        }
    }

    /// Creates a new instance of `LlamaSampler` with common sampling parameters.
    ///
    /// This function initializes a `LlamaSampler` using default values from `common_sampler_params`
    /// and configures it with common settings such as `top_k`, `top_p`, `temperature`, and `seed` values.
    ///
    /// # Panics
    ///
    /// Panics if llama.cpp returns a null pointer.
    ///
    /// # Returns
    /// A `LlamaSampler` instance configured with the common sampling parameters.
    #[must_use]
    pub fn common() -> Self {
        let params = common_sampler_params::default();

        // SAFETY: the chain is created from valid default params, and every
        // sampler added below is a freshly constructed, chain-owned pointer.
        let sampler = unsafe {
            let mut sparams = llama_sampler_chain_default_params();
            sparams.no_perf = false;

            let smpl = llama_sampler_chain_init(sparams);

            llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.top_k));
            llama_sampler_chain_add(
                smpl,
                #[allow(clippy::cast_sign_loss)]
                llama_sampler_init_top_p(params.top_p, params.min_keep as usize),
            );
            llama_sampler_chain_add(smpl, llama_sampler_init_temp(params.temp));
            #[allow(clippy::cast_sign_loss)]
            llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.seed));

            smpl
        };

        Self {
            sampler: NonNull::new(sampler).unwrap(),
        }
    }
}
644
impl Drop for LlamaSampler {
    /// Releases the underlying C sampler.
    fn drop(&mut self) {
        // SAFETY: `self.sampler` is a valid, owned pointer obtained from one of
        // the `llama_sampler_init_*` / `llama_sampler_chain_init` constructors;
        // wrappers whose ownership moved into a chain were `mem::forget`-ed,
        // so this free runs at most once per underlying sampler.
        unsafe {
            llama_sampler_free(self.sampler.as_ptr());
        }
    }
}