llama_cpp_2/sampling.rs
//! Safe wrapper around `llama_sampler`.

use std::borrow::Borrow;
use std::ffi::{c_char, CString};
use std::fmt::{Debug, Formatter};

use crate::context::LlamaContext;
use crate::model::LlamaModel;
use crate::token::data_array::LlamaTokenDataArray;
use crate::token::logit_bias::LlamaLogitBias;
use crate::token::LlamaToken;
use crate::GrammarError;

/// A safe wrapper around `llama_sampler`.
pub struct LlamaSampler {
    pub(crate) sampler: *mut llama_cpp_sys_2::llama_sampler,
}

impl Debug for LlamaSampler {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LlamaSampler").finish()
    }
}

impl LlamaSampler {
    /// Samples and accepts a token from the idx-th output of the last evaluation.
    #[must_use]
    pub fn sample(&mut self, ctx: &LlamaContext, idx: i32) -> LlamaToken {
        let token = unsafe {
            llama_cpp_sys_2::llama_sampler_sample(self.sampler, ctx.context.as_ptr(), idx)
        };

        LlamaToken(token)
    }

    /// Applies this sampler to a [`LlamaTokenDataArray`].
    pub fn apply(&self, data_array: &mut LlamaTokenDataArray) {
        data_array.apply_sampler(self);
    }

    /// Accepts a token, possibly updating the internal state of certain samplers
    /// (e.g. grammar, repetition penalties).
    pub fn accept(&mut self, token: LlamaToken) {
        unsafe { llama_cpp_sys_2::llama_sampler_accept(self.sampler, token.0) }
    }

    /// Accepts several tokens, possibly updating the internal state of certain samplers
    /// (e.g. grammar, repetition penalties).
    pub fn accept_many(&mut self, tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>) {
        for token in tokens {
            unsafe { llama_cpp_sys_2::llama_sampler_accept(self.sampler, token.borrow().0) }
        }
    }

    /// Accepts several tokens, possibly updating the internal state of certain samplers
    /// (e.g. grammar, repetition penalties), and returns the sampler for chaining.
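    ///
    /// A minimal sketch of builder-style use: priming a repetition penalty with tokens
    /// assumed to already be in the context (the token ids here are arbitrary).
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    /// use llama_cpp_2::token::LlamaToken;
    ///
    /// let sampler = LlamaSampler::penalties(64, 1.5, 0.0, 0.0)
    ///     .with_tokens([LlamaToken(0), LlamaToken(1), LlamaToken(2)]);
    /// ```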
    #[must_use]
    pub fn with_tokens(
        mut self,
        tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
    ) -> Self {
        self.accept_many(tokens);
        self
    }

    /// Resets the internal state of the sampler.
    ///
    /// This can be useful when you want to start fresh with a sampler without creating a new instance.
    pub fn reset(&mut self) {
        unsafe {
            llama_cpp_sys_2::llama_sampler_reset(self.sampler);
        }
    }

    /// Gets the random seed used by this sampler.
    ///
    /// Returns:
    /// - For random samplers (dist, mirostat, mirostat_v2): their current seed
    /// - For sampler chains: the first non-default seed found, in reverse order
    /// - For all other samplers: `0xFFFFFFFF`
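    ///
    /// A sketch: a `dist` sampler constructed with an explicit (non-default) seed is
    /// expected to report that same seed back.
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let sampler = LlamaSampler::dist(1234);
    /// assert_eq!(sampler.get_seed(), 1234);
    /// ```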
    #[must_use]
    pub fn get_seed(&self) -> u32 {
        unsafe { llama_cpp_sys_2::llama_sampler_get_seed(self.sampler) }
    }

    /// Combines a list of samplers into a single sampler that applies each component sampler one
    /// after another.
    ///
    /// If you are using a chain to select a token, the chain should always end with one of
    /// [`LlamaSampler::greedy`], [`LlamaSampler::dist`], [`LlamaSampler::mirostat`], or
    /// [`LlamaSampler::mirostat_v2`].
    #[must_use]
    pub fn chain(samplers: impl IntoIterator<Item = Self>, no_perf: bool) -> Self {
        unsafe {
            let chain = llama_cpp_sys_2::llama_sampler_chain_init(
                llama_cpp_sys_2::llama_sampler_chain_params { no_perf },
            );

            for sampler in samplers {
                llama_cpp_sys_2::llama_sampler_chain_add(chain, sampler.sampler);

                // Do not call `llama_sampler_free` on the sampler, as the internal sampler is now
                // owned by the chain.
                std::mem::forget(sampler);
            }

            Self { sampler: chain }
        }
    }

    /// Same as [`Self::chain`] with `no_perf = false`.
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    /// use llama_cpp_2::llama_backend::LlamaBackend;
    /// let backend = LlamaBackend::init().unwrap();
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::chain_simple([
    ///     LlamaSampler::temp(0.5),
    ///     LlamaSampler::greedy(),
    /// ]));
    ///
    /// assert_eq!(data_array.data[0].logit(), 0.);
    /// assert_eq!(data_array.data[1].logit(), 2.);
    /// assert_eq!(data_array.data[2].logit(), 4.);
    ///
    /// assert_eq!(data_array.data.len(), 3);
    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(2)));
    /// ```
    #[must_use]
    pub fn chain_simple(samplers: impl IntoIterator<Item = Self>) -> Self {
        Self::chain(samplers, false)
    }

    #[allow(clippy::doc_markdown)]
    /// Updates the logits l_i' = l_i/t. When t <= 0.0, the maximum logit is kept at its original
    /// value and the rest are set to -inf.
    ///
    /// # Example:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::temp(0.5));
    ///
    /// assert_eq!(data_array.data[0].logit(), 0.);
    /// assert_eq!(data_array.data[1].logit(), 2.);
    /// assert_eq!(data_array.data[2].logit(), 4.);
    /// ```
    #[must_use]
    pub fn temp(t: f32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_temp(t) };
        Self { sampler }
    }

    /// Dynamic temperature implementation (a.k.a. entropy) described in the paper
    /// <https://arxiv.org/abs/2309.02772>.
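    ///
    /// A minimal sketch: with `delta = 0.0` the sampler is assumed to fall back to plain
    /// temperature scaling, so this mirrors the [`Self::temp`] example.
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::temp_ext(0.5, 0.0, 1.0));
    ///
    /// assert_eq!(data_array.data[1].logit(), 2.);
    /// ```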
    #[must_use]
    pub fn temp_ext(t: f32, delta: f32, exponent: f32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_temp_ext(t, delta, exponent) };
        Self { sampler }
    }

    /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration"
    /// <https://arxiv.org/abs/1904.09751>
    ///
    /// # Example:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    ///     LlamaTokenData::new(LlamaToken(3), 3., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::top_k(2));
    ///
    /// assert_eq!(data_array.data.len(), 2);
    /// assert_eq!(data_array.data[0].id(), LlamaToken(3));
    /// assert_eq!(data_array.data[1].id(), LlamaToken(2));
    /// ```
    #[must_use]
    pub fn top_k(k: i32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_k(k) };
        Self { sampler }
    }

    /// Top-nσ sampling as described in academic paper "Top-nσ: Not All Logits Are You Need"
    /// <https://arxiv.org/pdf/2411.07641>
    ///
    /// This method filters logits by selecting only those within *n* standard deviations of the mean.
    ///
    /// # Parameters
    /// - `n`: Number of standard deviations from the mean to include in sampling
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0.0, 0.0),
    ///     LlamaTokenData::new(LlamaToken(1), 1.0, 0.0),
    ///     LlamaTokenData::new(LlamaToken(2), 2.0, 0.0),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::top_n_sigma(2.0));
    /// ```
    #[must_use]
    pub fn top_n_sigma(n: f32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_n_sigma(n) };
        Self { sampler }
    }

    /// Locally Typical Sampling implementation described in the paper <https://arxiv.org/abs/2202.00666>.
    #[must_use]
    pub fn typical(p: f32, min_keep: usize) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_typical(p, min_keep) };
        Self { sampler }
    }

    /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration"
    /// <https://arxiv.org/abs/1904.09751>
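    ///
    /// A sketch of the truncation: with logits `[0, 1, 2]` the softmax probabilities are
    /// roughly `[0.09, 0.24, 0.67]`, so `p = 0.9` is assumed to keep the two most likely
    /// tokens (the cumulative probability first reaches 0.9 there).
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::top_p(0.9, 1));
    ///
    /// assert_eq!(data_array.data.len(), 2);
    /// assert_eq!(data_array.data[0].id(), LlamaToken(2));
    /// ```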
    #[must_use]
    pub fn top_p(p: f32, min_keep: usize) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_p(p, min_keep) };
        Self { sampler }
    }

    /// Minimum P sampling as described in <https://github.com/ggerganov/llama.cpp/pull/3841>
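    ///
    /// A sketch of the cutoff: relative to the most likely token, `p = 0.5` is assumed to
    /// drop every token whose probability is below half of the maximum.
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::min_p(0.5, 1));
    ///
    /// // Only token 2 is within 0.5x of the maximum probability.
    /// assert_eq!(data_array.data.len(), 1);
    /// ```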
    #[must_use]
    pub fn min_p(p: f32, min_keep: usize) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_min_p(p, min_keep) };
        Self { sampler }
    }

    /// XTC sampler as described in <https://github.com/oobabooga/text-generation-webui/pull/6335>
    #[must_use]
    pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_xtc(p, t, min_keep, seed) };
        Self { sampler }
    }

    /// Grammar sampler that constrains output to match a GBNF grammar.
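    ///
    /// A minimal sketch (not run as a test): the model path here is hypothetical, and the
    /// grammar constrains output to the literal strings "yes" or "no".
    ///
    /// # Example
    /// ```rust,no_run
    /// use llama_cpp_2::llama_backend::LlamaBackend;
    /// use llama_cpp_2::model::{params::LlamaModelParams, LlamaModel};
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let backend = LlamaBackend::init().unwrap();
    /// let model = LlamaModel::load_from_file(
    ///     &backend,
    ///     "path/to/model.gguf",
    ///     &LlamaModelParams::default(),
    /// ).unwrap();
    ///
    /// let sampler = LlamaSampler::grammar(&model, r#"root ::= "yes" | "no""#, "root").unwrap();
    /// ```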
    #[must_use]
    pub fn grammar(
        model: &LlamaModel,
        grammar_str: &str,
        grammar_root: &str,
    ) -> Result<Self, GrammarError> {
        let (grammar_str, grammar_root) =
            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;

        let sampler = unsafe {
            llama_cpp_sys_2::llama_sampler_init_grammar(
                model.vocab_ptr(),
                grammar_str.as_ptr(),
                grammar_root.as_ptr(),
            )
        };

        if sampler.is_null() {
            Err(GrammarError::NullGrammar)
        } else {
            Ok(Self { sampler })
        }
    }

    /// Lazy grammar sampler, introduced in <https://github.com/ggerganov/llama.cpp/pull/9639>
    ///
    /// This sampler enforces grammar rules only when specific trigger words or tokens are encountered.
    #[must_use]
    pub fn grammar_lazy(
        model: &LlamaModel,
        grammar_str: &str,
        grammar_root: &str,
        trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
        trigger_tokens: &[LlamaToken],
    ) -> Result<Self, GrammarError> {
        let (grammar_str, grammar_root) =
            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
        let trigger_words = Self::sanitize_trigger_words(trigger_words)?;

        let mut trigger_word_ptrs: Vec<*const c_char> =
            trigger_words.iter().map(|cs| cs.as_ptr()).collect();

        let sampler = unsafe {
            llama_cpp_sys_2::llama_sampler_init_grammar_lazy(
                model.vocab_ptr(),
                grammar_str.as_ptr(),
                grammar_root.as_ptr(),
                trigger_word_ptrs.as_mut_ptr(),
                trigger_word_ptrs.len(),
                trigger_tokens.as_ptr().cast(),
                trigger_tokens.len(),
            )
        };

        if sampler.is_null() {
            Err(GrammarError::NullGrammar)
        } else {
            Ok(Self { sampler })
        }
    }

    fn sanitize_grammar_strings(
        grammar_str: &str,
        grammar_root: &str,
    ) -> Result<(CString, CString), GrammarError> {
        if !grammar_str.contains(grammar_root) {
            return Err(GrammarError::RootNotFound);
        }

        if grammar_str.contains('\0') || grammar_root.contains('\0') {
            return Err(GrammarError::GrammarNullBytes);
        }

        Ok((
            CString::new(grammar_str).unwrap(),
            CString::new(grammar_root).unwrap(),
        ))
    }

    fn sanitize_trigger_words(
        trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
    ) -> Result<Vec<CString>, GrammarError> {
        let trigger_words: Vec<_> = trigger_words.into_iter().collect();
        if trigger_words
            .iter()
            .any(|word| word.as_ref().contains(&b'\0'))
        {
            return Err(GrammarError::TriggerWordNullBytes);
        }
        Ok(trigger_words
            .into_iter()
            .map(|word| CString::new(word.as_ref()).unwrap())
            .collect())
    }

    /// DRY sampler, designed by p-e-w, as described in
    /// <https://github.com/oobabooga/text-generation-webui/pull/5677>, porting the Koboldcpp
    /// implementation authored by pi6am: <https://github.com/LostRuins/koboldcpp/pull/982>
    ///
    /// # Panics
    /// If any string in ``seq_breakers`` contains null bytes.
    #[allow(missing_docs)]
    #[must_use]
    pub fn dry(
        model: &LlamaModel,
        multiplier: f32,
        base: f32,
        allowed_length: i32,
        penalty_last_n: i32,
        seq_breakers: impl IntoIterator<Item = impl AsRef<[u8]>>,
    ) -> Self {
        let seq_breakers: Vec<CString> = seq_breakers
            .into_iter()
            .map(|s| CString::new(s.as_ref()).expect("A sequence breaker contains null bytes"))
            .collect();
        let mut seq_breaker_pointers: Vec<*const c_char> =
            seq_breakers.iter().map(|s| s.as_ptr()).collect();

        let sampler = unsafe {
            llama_cpp_sys_2::llama_sampler_init_dry(
                model.vocab_ptr(),
                model
                    .n_ctx_train()
                    .try_into()
                    .expect("n_ctx_train exceeds i32::MAX"),
                multiplier,
                base,
                allowed_length,
                penalty_last_n,
                seq_breaker_pointers.as_mut_ptr(),
                seq_breaker_pointers.len(),
            )
        };
        Self { sampler }
    }

    /// Penalizes tokens for being present in the context.
    ///
    /// Parameters:
    /// - ``penalty_last_n``: last n tokens to penalize (0 = disable penalty, -1 = context size)
    /// - ``penalty_repeat``: 1.0 = disabled
    /// - ``penalty_freq``: 0.0 = disabled
    /// - ``penalty_present``: 0.0 = disabled
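    ///
    /// A minimal sketch: after the sampler has accepted token 0, a repeat penalty greater
    /// than 1.0 is assumed to lower that token's logit relative to an unseen token.
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut sampler = LlamaSampler::penalties(64, 1.5, 0.0, 0.0);
    /// sampler.accept(LlamaToken(0));
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut sampler);
    ///
    /// // The recently seen token 0 is penalized; token 1 is untouched.
    /// assert!(data_array.data[0].logit() < data_array.data[1].logit());
    /// ```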
    #[must_use]
    pub fn penalties(
        penalty_last_n: i32,
        penalty_repeat: f32,
        penalty_freq: f32,
        penalty_present: f32,
    ) -> Self {
        let sampler = unsafe {
            llama_cpp_sys_2::llama_sampler_init_penalties(
                penalty_last_n,
                penalty_repeat,
                penalty_freq,
                penalty_present,
            )
        };
        Self { sampler }
    }

    /// Mirostat 1.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
    ///
    /// # Parameters:
    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
    /// - ``seed``: Seed to initialize random generation with.
    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
    ///   generated text. A higher value corresponds to more surprising or less predictable text,
    ///   while a lower value corresponds to less surprising or more predictable text.
    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
    ///   updated more quickly, while a smaller learning rate will result in slower updates.
    /// - ``m``: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary
    ///   value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`.
    ///   In the paper, they use `m = 100`, but you can experiment with different values to see how
    ///   it affects the performance of the algorithm.
    #[must_use]
    pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self {
        let sampler =
            unsafe { llama_cpp_sys_2::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m) };
        Self { sampler }
    }

    /// Mirostat 2.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
    ///
    /// # Parameters:
    /// - ``seed``: Seed to initialize random generation with.
    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
    ///   generated text. A higher value corresponds to more surprising or less predictable text,
    ///   while a lower value corresponds to less surprising or more predictable text.
    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
    ///   updated more quickly, while a smaller learning rate will result in slower updates.
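    ///
    /// A minimal sketch: like [`Self::dist`], mirostat selects a token, so after applying it
    /// the array is assumed to have a selection (which token depends on the seed).
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::mirostat_v2(4321, 5.0, 0.1));
    ///
    /// assert!(data_array.selected_token().is_some());
    /// ```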
    #[must_use]
    pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_mirostat_v2(seed, tau, eta) };
        Self { sampler }
    }

    /// Selects a token at random based on each token's probabilities
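    ///
    /// A minimal sketch: applying `dist` picks one of the candidates (which one depends on
    /// the seed), so the array ends up with a selected token.
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::dist(1234));
    ///
    /// assert!(data_array.selected_token().is_some());
    /// ```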
    #[must_use]
    pub fn dist(seed: u32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_dist(seed) };
        Self { sampler }
    }

    /// Selects the most likely token
    ///
    /// # Example:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::greedy());
    ///
    /// assert_eq!(data_array.data.len(), 2);
    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(1)));
    /// ```
    #[must_use]
    pub fn greedy() -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_greedy() };
        Self { sampler }
    }

    /// Creates a sampler that applies bias values to specific tokens during sampling.
    ///
    /// # Parameters
    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
    /// - ``biases``: Slice of [`LlamaLogitBias`] values specifying token-bias pairs
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{LlamaToken, logit_bias::LlamaLogitBias};
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let biases = vec![
    ///     LlamaLogitBias::new(LlamaToken(1), 1.5),  // Increase probability of token 1
    ///     LlamaLogitBias::new(LlamaToken(2), -1.0), // Decrease probability of token 2
    /// ];
    ///
    /// // Assuming vocab_size of 32000
    /// let sampler = LlamaSampler::logit_bias(32000, &biases);
    /// ```
    #[must_use]
    pub fn logit_bias(n_vocab: i32, biases: &[LlamaLogitBias]) -> Self {
        let data = biases.as_ptr().cast::<llama_cpp_sys_2::llama_logit_bias>();

        let sampler = unsafe {
            llama_cpp_sys_2::llama_sampler_init_logit_bias(n_vocab, biases.len() as i32, data)
        };

        Self { sampler }
    }
}

impl Drop for LlamaSampler {
    fn drop(&mut self) {
        unsafe {
            llama_cpp_sys_2::llama_sampler_free(self.sampler);
        }
    }
}