llama_cpp_2/sampling.rs
//! Safe wrapper around `llama_sampler`.

use std::borrow::Borrow;
use std::ffi::{c_char, CString};
use std::fmt::{Debug, Formatter};

use crate::context::LlamaContext;
use crate::model::LlamaModel;
use crate::token::data_array::LlamaTokenDataArray;
use crate::token::logit_bias::LlamaLogitBias;
use crate::token::LlamaToken;

/// A safe wrapper around `llama_sampler`.
pub struct LlamaSampler {
    pub(crate) sampler: *mut llama_cpp_sys_2::llama_sampler,
}

impl Debug for LlamaSampler {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LlamaSampler").finish()
    }
}

impl LlamaSampler {
    /// Samples and accepts a token from the `idx`-th output of the last evaluation.
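    ///
    /// # Example
    /// A minimal sketch; the model path and prompt are placeholders:
    /// ```no_run
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// use llama_cpp_2::llama_backend::LlamaBackend;
    /// use llama_cpp_2::llama_batch::LlamaBatch;
    /// use llama_cpp_2::model::params::LlamaModelParams;
    /// use llama_cpp_2::model::{AddBos, LlamaModel};
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "model.gguf", &LlamaModelParams::default())?;
    /// let mut ctx = model.new_context(&backend, LlamaContextParams::default())?;
    ///
    /// // Evaluate a prompt, requesting logits only for its last token.
    /// let tokens = model.str_to_token("Hello", AddBos::Always)?;
    /// let last_index = tokens.len() as i32 - 1;
    /// let mut batch = LlamaBatch::new(512, 1);
    /// for (i, token) in tokens.into_iter().enumerate() {
    ///     batch.add(token, i as i32, &[0], i as i32 == last_index)?;
    /// }
    /// ctx.decode(&mut batch)?;
    ///
    /// // Sample (and accept) a token from the logits of the last evaluated token.
    /// let mut sampler = LlamaSampler::greedy();
    /// let token = sampler.sample(&ctx, batch.n_tokens() - 1);
    /// # Ok::<(), Box<dyn std::error::Error>>(())
    /// ```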
    #[must_use]
    pub fn sample(&mut self, ctx: &LlamaContext, idx: i32) -> LlamaToken {
        let token = unsafe {
            llama_cpp_sys_2::llama_sampler_sample(self.sampler, ctx.context.as_ptr(), idx)
        };

        LlamaToken(token)
    }

    /// Applies this sampler to a [`LlamaTokenDataArray`].
    pub fn apply(&self, data_array: &mut LlamaTokenDataArray) {
        data_array.apply_sampler(self);
    }

    /// Accepts a token into the sampler, possibly updating the internal state of certain
    /// samplers (e.g. grammar, repetition penalties).
    pub fn accept(&mut self, token: LlamaToken) {
        unsafe { llama_cpp_sys_2::llama_sampler_accept(self.sampler, token.0) }
    }

    /// Accepts several tokens into the sampler, possibly updating the internal state of
    /// certain samplers (e.g. grammar, repetition penalties).
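    ///
    /// # Example
    /// A small sketch; the penalty values are arbitrary:
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    /// use llama_cpp_2::token::LlamaToken;
    ///
    /// // Feed recently generated tokens to a repetition-penalty sampler so that
    /// // later applications penalize them.
    /// let mut sampler = LlamaSampler::penalties(64, 1.1, 0.0, 0.0);
    /// sampler.accept_many(&[LlamaToken(0), LlamaToken(1), LlamaToken(2)]);
    /// ```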
    pub fn accept_many(&mut self, tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>) {
        for token in tokens {
            unsafe { llama_cpp_sys_2::llama_sampler_accept(self.sampler, token.borrow().0) }
        }
    }

    /// Consuming variant of [`Self::accept_many`]: accepts the given tokens, then returns
    /// the sampler, which is convenient when constructing a sampler inline.
    #[must_use]
    pub fn with_tokens(
        mut self,
        tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
    ) -> Self {
        self.accept_many(tokens);
        self
    }

    /// Resets the internal state of the sampler.
    ///
    /// This can be useful when you want to start fresh with a sampler without creating a new instance.
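    ///
    /// # Example
    /// A small sketch:
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    /// use llama_cpp_2::token::LlamaToken;
    ///
    /// let mut sampler = LlamaSampler::penalties(64, 1.1, 0.0, 0.0);
    /// sampler.accept(LlamaToken(0));
    ///
    /// // Clear the accepted-token history before reusing the sampler on a new sequence.
    /// sampler.reset();
    /// ```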
    pub fn reset(&mut self) {
        unsafe {
            llama_cpp_sys_2::llama_sampler_reset(self.sampler);
        }
    }

    /// Gets the random seed used by this sampler.
    ///
    /// Returns:
    /// - For random samplers (dist, mirostat, mirostat_v2): returns their current seed
    /// - For sampler chains: returns the first non-default seed found in reverse order
    /// - For all other samplers: returns 0xFFFFFFFF
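    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// // A random sampler reports the seed it was constructed with.
    /// let sampler = LlamaSampler::dist(1234);
    /// assert_eq!(sampler.get_seed(), 1234);
    ///
    /// // A sampler without random state reports the default-seed sentinel.
    /// let sampler = LlamaSampler::greedy();
    /// assert_eq!(sampler.get_seed(), 0xFFFFFFFF);
    /// ```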
    #[must_use]
    pub fn get_seed(&self) -> u32 {
        unsafe { llama_cpp_sys_2::llama_sampler_get_seed(self.sampler) }
    }

    /// Combines a list of samplers into a single sampler that applies each component sampler one
    /// after another.
    ///
    /// If you are using a chain to select a token, the chain should always end with one of
    /// [`LlamaSampler::greedy`], [`LlamaSampler::dist`], [`LlamaSampler::mirostat`], or
    /// [`LlamaSampler::mirostat_v2`].
    #[must_use]
    pub fn chain(samplers: impl IntoIterator<Item = Self>, no_perf: bool) -> Self {
        unsafe {
            let chain = llama_cpp_sys_2::llama_sampler_chain_init(
                llama_cpp_sys_2::llama_sampler_chain_params { no_perf },
            );

            for sampler in samplers {
                llama_cpp_sys_2::llama_sampler_chain_add(chain, sampler.sampler);

                // Do not call `llama_sampler_free` on the sampler, as the internal sampler is now
                // owned by the chain
                std::mem::forget(sampler);
            }

            Self { sampler: chain }
        }
    }

    /// Same as [`Self::chain`] with `no_perf = false`.
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    /// use llama_cpp_2::llama_backend::LlamaBackend;
    /// let backend = LlamaBackend::init().unwrap();
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::chain_simple([
    ///     LlamaSampler::temp(0.5),
    ///     LlamaSampler::greedy(),
    /// ]));
    ///
    /// assert_eq!(data_array.data[0].logit(), 0.);
    /// assert_eq!(data_array.data[1].logit(), 2.);
    /// assert_eq!(data_array.data[2].logit(), 4.);
    ///
    /// assert_eq!(data_array.data.len(), 3);
    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(2)));
    /// ```
    #[must_use]
    pub fn chain_simple(samplers: impl IntoIterator<Item = Self>) -> Self {
        Self::chain(samplers, false)
    }

    #[allow(clippy::doc_markdown)]
    /// Updates the logits l_i' = l_i/t. When t <= 0.0, the maximum logit is kept at its original
    /// value and the rest are set to -inf.
    ///
    /// # Example:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::temp(0.5));
    ///
    /// assert_eq!(data_array.data[0].logit(), 0.);
    /// assert_eq!(data_array.data[1].logit(), 2.);
    /// assert_eq!(data_array.data[2].logit(), 4.);
    /// ```
    #[must_use]
    pub fn temp(t: f32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_temp(t) };
        Self { sampler }
    }

    /// Dynamic temperature implementation (a.k.a. entropy) described in the paper
    /// <https://arxiv.org/abs/2309.02772>.
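    ///
    /// # Example
    /// A usage sketch; the parameter values are arbitrary:
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// // Scale temperature within [0.3, 1.7] (1.0 ± 0.7) based on the entropy of
    /// // the candidate distribution, with an exponent of 1.0.
    /// let sampler = LlamaSampler::temp_ext(1.0, 0.7, 1.0);
    /// ```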
    #[must_use]
    pub fn temp_ext(t: f32, delta: f32, exponent: f32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_temp_ext(t, delta, exponent) };
        Self { sampler }
    }

    /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration"
    /// <https://arxiv.org/abs/1904.09751>
    ///
    /// # Example:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    ///     LlamaTokenData::new(LlamaToken(3), 3., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::top_k(2));
    ///
    /// assert_eq!(data_array.data.len(), 2);
    /// assert_eq!(data_array.data[0].id(), LlamaToken(3));
    /// assert_eq!(data_array.data[1].id(), LlamaToken(2));
    /// ```
    #[must_use]
    pub fn top_k(k: i32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_k(k) };
        Self { sampler }
    }

    /// Top-nσ sampling as described in academic paper "Top-nσ: Not All Logits Are You Need"
    /// <https://arxiv.org/pdf/2411.07641>
    ///
    /// This method filters logits by selecting only those within *n* standard deviations of the mean.
    ///
    /// # Parameters
    /// - `n`: Number of standard deviations from the mean to include in sampling
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0.0, 0.0),
    ///     LlamaTokenData::new(LlamaToken(1), 1.0, 0.0),
    ///     LlamaTokenData::new(LlamaToken(2), 2.0, 0.0),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::top_n_sigma(2.0));
    /// ```
    #[must_use]
    pub fn top_n_sigma(n: f32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_n_sigma(n) };
        Self { sampler }
    }

    /// Locally Typical Sampling implementation described in the paper <https://arxiv.org/abs/2202.00666>.
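    ///
    /// # Example
    /// A usage sketch; which tokens survive depends on the whole distribution:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// // Keep tokens whose surprisal is close to the expected surprisal of the
    /// // distribution, retaining at least one candidate.
    /// data_array.apply_sampler(&mut LlamaSampler::typical(0.9, 1));
    /// assert!(!data_array.data.is_empty());
    /// ```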
    #[must_use]
    pub fn typical(p: f32, min_keep: usize) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_typical(p, min_keep) };
        Self { sampler }
    }

    /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration"
    /// <https://arxiv.org/abs/1904.09751>
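    ///
    /// # Example
    /// A usage sketch; the cutoff depends on the softmax of the logits:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// // Keep the smallest set of tokens whose cumulative probability reaches 0.9,
    /// // but never fewer than one token.
    /// data_array.apply_sampler(&mut LlamaSampler::top_p(0.9, 1));
    /// assert!(!data_array.data.is_empty());
    /// ```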
    #[must_use]
    pub fn top_p(p: f32, min_keep: usize) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_p(p, min_keep) };
        Self { sampler }
    }

    /// Minimum P sampling as described in <https://github.com/ggerganov/llama.cpp/pull/3841>
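    ///
    /// # Example
    /// A usage sketch; tokens far less likely than the top token are dropped:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 4., 0.),
    /// ], false);
    ///
    /// // Drop tokens whose probability is below 20% of the most likely token's
    /// // probability, keeping at least one candidate.
    /// data_array.apply_sampler(&mut LlamaSampler::min_p(0.2, 1));
    /// assert!(!data_array.data.is_empty());
    /// ```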
    #[must_use]
    pub fn min_p(p: f32, min_keep: usize) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_min_p(p, min_keep) };
        Self { sampler }
    }

    /// XTC sampler as described in <https://github.com/oobabooga/text-generation-webui/pull/6335>
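    ///
    /// # Example
    /// A usage sketch; the parameter values are arbitrary:
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// // With probability 0.5, remove every token above the 0.1 probability
    /// // threshold except the least likely of them, keeping at least one token.
    /// let sampler = LlamaSampler::xtc(0.5, 0.1, 1, 1234);
    /// ```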
    #[must_use]
    pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_xtc(p, t, min_keep, seed) };
        Self { sampler }
    }

    /// Grammar sampler
    ///
    /// Returns `None` if llama.cpp fails to parse the grammar.
    ///
    /// # Panics
    /// If either of ``grammar_str`` or ``grammar_root`` contains null bytes.
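    ///
    /// # Example
    /// A minimal sketch, assuming a GGUF model at a placeholder path:
    /// ```no_run
    /// use llama_cpp_2::llama_backend::LlamaBackend;
    /// use llama_cpp_2::model::{params::LlamaModelParams, LlamaModel};
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "model.gguf", &LlamaModelParams::default())?;
    ///
    /// // Constrain sampling to the strings "yes" or "no" using a GBNF grammar.
    /// let sampler = LlamaSampler::grammar(&model, r#"root ::= "yes" | "no""#, "root")
    ///     .expect("grammar failed to parse");
    /// # Ok::<(), Box<dyn std::error::Error>>(())
    /// ```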
    #[must_use]
    pub fn grammar(model: &LlamaModel, grammar_str: &str, grammar_root: &str) -> Option<Self> {
        let grammar_str = CString::new(grammar_str).unwrap();
        let grammar_root = CString::new(grammar_root).unwrap();

        let sampler = unsafe {
            llama_cpp_sys_2::llama_sampler_init_grammar(
                model.vocab_ptr(),
                grammar_str.as_ptr(),
                grammar_root.as_ptr(),
            )
        };

        if sampler.is_null() {
            None
        } else {
            Some(Self { sampler })
        }
    }

    /// Lazy grammar sampler, introduced in <https://github.com/ggerganov/llama.cpp/pull/9639>
    ///
    /// This sampler enforces grammar rules only when specific trigger words or tokens are encountered.
    ///
    /// Returns `None` if llama.cpp fails to parse the grammar.
    ///
    /// # Panics
    /// - If `grammar_str` or `grammar_root` contain null bytes
    /// - If any trigger word contains null bytes
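    ///
    /// # Example
    /// A minimal sketch, assuming a GGUF model at a placeholder path and a
    /// hypothetical trigger word:
    /// ```no_run
    /// use llama_cpp_2::llama_backend::LlamaBackend;
    /// use llama_cpp_2::model::{params::LlamaModelParams, LlamaModel};
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "model.gguf", &LlamaModelParams::default())?;
    ///
    /// // Enforce the grammar only after the model emits the trigger word "Answer:".
    /// let sampler = LlamaSampler::grammar_lazy(
    ///     &model,
    ///     r#"root ::= "yes" | "no""#,
    ///     "root",
    ///     ["Answer:"],
    ///     &[],
    /// ).expect("grammar failed to parse");
    /// # Ok::<(), Box<dyn std::error::Error>>(())
    /// ```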
    #[must_use]
    pub fn grammar_lazy(
        model: &LlamaModel,
        grammar_str: &str,
        grammar_root: &str,
        trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
        trigger_tokens: &[LlamaToken],
    ) -> Option<Self> {
        let grammar_str = CString::new(grammar_str).unwrap();
        let grammar_root = CString::new(grammar_root).unwrap();

        let trigger_word_cstrings: Vec<CString> = trigger_words
            .into_iter()
            .map(|word| CString::new(word.as_ref()).unwrap())
            .collect();

        let mut trigger_word_ptrs: Vec<*const c_char> =
            trigger_word_cstrings.iter().map(|cs| cs.as_ptr()).collect();

        let sampler = unsafe {
            llama_cpp_sys_2::llama_sampler_init_grammar_lazy(
                model.vocab_ptr(),
                grammar_str.as_ptr(),
                grammar_root.as_ptr(),
                trigger_word_ptrs.as_mut_ptr(),
                trigger_word_ptrs.len(),
                trigger_tokens.as_ptr().cast(),
                trigger_tokens.len(),
            )
        };

        if sampler.is_null() {
            None
        } else {
            Some(Self { sampler })
        }
    }

    /// DRY sampler, designed by p-e-w, as described in:
    /// <https://github.com/oobabooga/text-generation-webui/pull/5677>, porting Koboldcpp
    /// implementation authored by pi6am: <https://github.com/LostRuins/koboldcpp/pull/982>
    ///
    /// # Panics
    /// If any string in ``seq_breakers`` contains null bytes.
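    ///
    /// # Example
    /// A minimal sketch, assuming a GGUF model at a placeholder path; the parameter
    /// values mirror commonly cited defaults but are not prescriptive:
    /// ```no_run
    /// use llama_cpp_2::llama_backend::LlamaBackend;
    /// use llama_cpp_2::model::{params::LlamaModelParams, LlamaModel};
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "model.gguf", &LlamaModelParams::default())?;
    ///
    /// // multiplier 0.8, base 1.75, allowed_length 2, penalty_last_n -1 (whole
    /// // context), with newline as a sequence breaker.
    /// let sampler = LlamaSampler::dry(&model, 0.8, 1.75, 2, -1, ["\n"]);
    /// # Ok::<(), Box<dyn std::error::Error>>(())
    /// ```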
    #[allow(missing_docs)]
    #[must_use]
    pub fn dry(
        model: &LlamaModel,
        multiplier: f32,
        base: f32,
        allowed_length: i32,
        penalty_last_n: i32,
        seq_breakers: impl IntoIterator<Item = impl AsRef<[u8]>>,
    ) -> Self {
        let seq_breakers: Vec<CString> = seq_breakers
            .into_iter()
            .map(|s| CString::new(s.as_ref()).expect("A sequence breaker contains null bytes"))
            .collect();
        let mut seq_breaker_pointers: Vec<*const c_char> =
            seq_breakers.iter().map(|s| s.as_ptr()).collect();

        let sampler = unsafe {
            llama_cpp_sys_2::llama_sampler_init_dry(
                model.vocab_ptr(),
                model
                    .n_ctx_train()
                    .try_into()
                    .expect("n_ctx_train exceeds i32::MAX"),
                multiplier,
                base,
                allowed_length,
                penalty_last_n,
                seq_breaker_pointers.as_mut_ptr(),
                seq_breaker_pointers.len(),
            )
        };
        Self { sampler }
    }

    /// Penalizes tokens for being present in the context.
    ///
    /// Parameters:
    /// - ``penalty_last_n``: last n tokens to penalize (0 = disable penalty, -1 = context size)
    /// - ``penalty_repeat``: 1.0 = disabled
    /// - ``penalty_freq``: 0.0 = disabled
    /// - ``penalty_present``: 0.0 = disabled
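    ///
    /// # Example
    /// A small sketch: after accepting token 1, a repeat penalty of 2.0 halves its
    /// positive logit when the sampler is applied:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 2., 0.),
    /// ], false);
    ///
    /// let mut sampler = LlamaSampler::penalties(64, 2.0, 0.0, 0.0);
    /// sampler.accept(LlamaToken(1));
    /// data_array.apply_sampler(&mut sampler);
    ///
    /// assert_eq!(data_array.data[1].logit(), 1.);
    /// ```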
    #[allow(clippy::too_many_arguments)]
    #[must_use]
    pub fn penalties(
        penalty_last_n: i32,
        penalty_repeat: f32,
        penalty_freq: f32,
        penalty_present: f32,
    ) -> Self {
        let sampler = unsafe {
            llama_cpp_sys_2::llama_sampler_init_penalties(
                penalty_last_n,
                penalty_repeat,
                penalty_freq,
                penalty_present,
            )
        };
        Self { sampler }
    }

    /// Mirostat 1.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
    ///
    /// # Parameters:
    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
    /// - ``seed``: Seed to initialize random generation with.
    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
    ///   generated text. A higher value corresponds to more surprising or less predictable text,
    ///   while a lower value corresponds to less surprising or more predictable text.
    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
    ///   updated more quickly, while a smaller learning rate will result in slower updates.
    /// - ``m``: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary
    ///   value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`.
    ///   In the paper, they use `m = 100`, but you can experiment with different values to see how
    ///   it affects the performance of the algorithm.
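    ///
    /// # Example
    /// A small sketch on a toy vocabulary of 3 tokens; mirostat computes
    /// probabilities internally and always selects a token:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::mirostat(3, 1234, 5.0, 0.1, 100));
    /// assert!(data_array.selected_token().is_some());
    /// ```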
    #[must_use]
    pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self {
        let sampler =
            unsafe { llama_cpp_sys_2::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m) };
        Self { sampler }
    }

    /// Mirostat 2.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
    ///
    /// # Parameters:
    /// - ``seed``: Seed to initialize random generation with.
    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
    ///   generated text. A higher value corresponds to more surprising or less predictable text,
    ///   while a lower value corresponds to less surprising or more predictable text.
    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
    ///   updated more quickly, while a smaller learning rate will result in slower updates.
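    ///
    /// # Example
    /// A small sketch; mirostat always selects a token from the candidates:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::mirostat_v2(1234, 5.0, 0.1));
    /// assert!(data_array.selected_token().is_some());
    /// ```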
    #[must_use]
    pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_mirostat_v2(seed, tau, eta) };
        Self { sampler }
    }

    /// Selects a token at random based on each token's probabilities
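    ///
    /// # Example
    /// A small sketch; `dist` makes the final random choice, so a chain typically
    /// ends with it:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::dist(1234));
    /// assert!(data_array.selected_token().is_some());
    /// ```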
    #[must_use]
    pub fn dist(seed: u32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_dist(seed) };
        Self { sampler }
    }

    /// Selects the most likely token
    ///
    /// # Example:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::greedy());
    ///
    /// assert_eq!(data_array.data.len(), 2);
    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(1)));
    /// ```
    #[must_use]
    pub fn greedy() -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_greedy() };
        Self { sampler }
    }

    /// Creates a sampler that applies bias values to specific tokens during sampling.
    ///
    /// # Parameters
    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
    /// - ``biases``: Slice of [`LlamaLogitBias`] values specifying token-bias pairs
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{LlamaToken, logit_bias::LlamaLogitBias};
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let biases = vec![
    ///     LlamaLogitBias::new(LlamaToken(1), 1.5),  // Increase probability of token 1
    ///     LlamaLogitBias::new(LlamaToken(2), -1.0), // Decrease probability of token 2
    /// ];
    ///
    /// // Assuming vocab_size of 32000
    /// let sampler = LlamaSampler::logit_bias(32000, &biases);
    /// ```
    #[must_use]
    pub fn logit_bias(n_vocab: i32, biases: &[LlamaLogitBias]) -> Self {
        let data = biases.as_ptr().cast::<llama_cpp_sys_2::llama_logit_bias>();

        let sampler = unsafe {
            llama_cpp_sys_2::llama_sampler_init_logit_bias(n_vocab, biases.len() as i32, data)
        };

        Self { sampler }
    }
}

impl Drop for LlamaSampler {
    fn drop(&mut self) {
        unsafe {
            llama_cpp_sys_2::llama_sampler_free(self.sampler);
        }
    }
}