llama_cpp_bindings/sampling.rs
1//! Safe wrapper around `llama_sampler`.
2
3use std::borrow::Borrow;
4use std::ffi::{CString, c_char};
5use std::fmt::{Debug, Formatter};
6
7use crate::context::LlamaContext;
8use crate::model::LlamaModel;
9use crate::token::LlamaToken;
10use crate::token::data_array::LlamaTokenDataArray;
11use crate::token::logit_bias::LlamaLogitBias;
12use crate::{GrammarError, SamplerAcceptError, SamplingError, status_is_ok, status_to_i32};
13
14/// A safe wrapper around `llama_sampler`.
/// A safe wrapper around `llama_sampler`.
pub struct LlamaSampler {
    /// Raw pointer to the underlying `llama_sampler`.
    ///
    /// Owned by this wrapper: it is released with `llama_sampler_free` in the
    /// `Drop` impl, unless ownership is transferred to a sampler chain (see
    /// [`LlamaSampler::chain`], which uses `std::mem::forget` on components).
    pub sampler: *mut llama_cpp_bindings_sys::llama_sampler,
}
19
20impl Debug for LlamaSampler {
21 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
22 f.debug_struct("LlamaSamplerChain").finish()
23 }
24}
25
26impl LlamaSampler {
27 /// Sample and accept a token from the idx-th output of the last evaluation
28 #[must_use]
29 pub fn sample(&mut self, ctx: &LlamaContext, idx: i32) -> LlamaToken {
30 let token = unsafe {
31 llama_cpp_bindings_sys::llama_sampler_sample(self.sampler, ctx.context.as_ptr(), idx)
32 };
33
34 LlamaToken(token)
35 }
36
    /// Applies this sampler to a [`LlamaTokenDataArray`].
    ///
    /// Delegates to [`LlamaTokenDataArray::apply_sampler`], which mutates the
    /// array in place according to this sampler.
    pub fn apply(&self, data_array: &mut LlamaTokenDataArray) {
        data_array.apply_sampler(self);
    }
41
    /// Accepts a token from the sampler, possibly updating the internal state of certain samplers
    /// (e.g. grammar, repetition, etc.)
    ///
    /// Thin alias for [`Self::try_accept`].
    ///
    /// # Errors
    /// Returns [`SamplerAcceptError`] if the underlying sampler rejects the token.
    pub fn accept(&mut self, token: LlamaToken) -> Result<(), SamplerAcceptError> {
        self.try_accept(token)
    }
50
51 /// Accepts several tokens from the sampler or context, possibly updating the internal state of
52 /// certain samplers (e.g. grammar, repetition, etc.)
53 ///
54 /// # Errors
55 /// Returns [`SamplerAcceptError`] if the underlying sampler rejects any token.
56 pub fn accept_many(
57 &mut self,
58 tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
59 ) -> Result<(), SamplerAcceptError> {
60 for token in tokens {
61 self.try_accept(*token.borrow())?;
62 }
63
64 Ok(())
65 }
66
67 /// Accepts several tokens from the sampler or context, possibly updating the internal state of
68 /// certain samplers (e.g. grammar, repetition, etc.)
69 ///
70 /// # Errors
71 /// Returns [`SamplerAcceptError`] if the underlying sampler rejects any token.
72 pub fn with_tokens(
73 mut self,
74 tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
75 ) -> Result<Self, SamplerAcceptError> {
76 self.accept_many(tokens)?;
77
78 Ok(self)
79 }
80
81 /// Try accepting a token from the sampler. Returns an error if the sampler throws.
82 ///
83 /// # Errors
84 /// Returns an error if the underlying sampler rejects the token.
85 pub fn try_accept(&mut self, token: LlamaToken) -> Result<(), SamplerAcceptError> {
86 let sampler_result =
87 unsafe { llama_cpp_bindings_sys::llama_rs_sampler_accept(self.sampler, token.0) };
88 if status_is_ok(sampler_result) {
89 Ok(())
90 } else {
91 Err(SamplerAcceptError::FfiError(status_to_i32(sampler_result)))
92 }
93 }
94
95 /// Resets the internal state of the sampler.
96 ///
97 /// This can be useful when you want to start fresh with a sampler without creating a new instance.
98 pub fn reset(&mut self) {
99 unsafe {
100 llama_cpp_bindings_sys::llama_sampler_reset(self.sampler);
101 }
102 }
103
    /// Gets the random seed used by this sampler.
    ///
    /// Returns:
    /// - For random samplers (dist, mirostat, `mirostat_v2`): returns their current seed
    /// - For sampler chains: returns the first non-default seed found in reverse order
    /// - For all other samplers: returns 0xFFFFFFFF
    #[must_use]
    pub fn get_seed(&self) -> u32 {
        // SAFETY: `self.sampler` is a valid pointer for the lifetime of `self`.
        unsafe { llama_cpp_bindings_sys::llama_sampler_get_seed(self.sampler) }
    }
114
115 /// Combines a list of samplers into a single sampler that applies each component sampler one
116 /// after another.
117 ///
118 /// If you are using a chain to select a token, the chain should always end with one of
119 /// [`LlamaSampler::greedy`], [`LlamaSampler::dist`], [`LlamaSampler::mirostat`], and
120 /// [`LlamaSampler::mirostat_v2`].
121 #[must_use]
122 pub fn chain(samplers: impl IntoIterator<Item = Self>, no_perf: bool) -> Self {
123 unsafe {
124 let chain = llama_cpp_bindings_sys::llama_sampler_chain_init(
125 llama_cpp_bindings_sys::llama_sampler_chain_params { no_perf },
126 );
127
128 for sampler in samplers {
129 llama_cpp_bindings_sys::llama_sampler_chain_add(chain, sampler.sampler);
130
131 // Do not call `llama_sampler_free` on the sampler, as the internal sampler is now
132 // owned by the chain
133 std::mem::forget(sampler);
134 }
135
136 Self { sampler: chain }
137 }
138 }
139
    /// Same as [`Self::chain`] with `no_perf = false`.
    ///
    /// Delegates directly to [`Self::chain`]; performance measurement stays enabled.
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_bindings::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_bindings::sampling::LlamaSampler;
    /// use llama_cpp_bindings::llama_backend::LlamaBackend;
    /// let backend = LlamaBackend::init().unwrap();
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::chain_simple([
    ///     LlamaSampler::temp(0.5),
    ///     LlamaSampler::greedy(),
    /// ]));
    ///
    /// assert_eq!(data_array.data[0].logit(), 0.);
    /// assert_eq!(data_array.data[1].logit(), 2.);
    /// assert_eq!(data_array.data[2].logit(), 4.);
    ///
    /// assert_eq!(data_array.data.len(), 3);
    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(2)));
    /// ```
    #[must_use]
    pub fn chain_simple(samplers: impl IntoIterator<Item = Self>) -> Self {
        Self::chain(samplers, false)
    }
175
176 /// Updates the logits `l_i' = l_i/t`. When `t <= 0.0`, the maximum logit is kept at its original
177 /// value, the rest are set to -inf
178 ///
179 /// # Example:
180 /// ```rust
181 /// use llama_cpp_bindings::token::{
182 /// LlamaToken,
183 /// data::LlamaTokenData,
184 /// data_array::LlamaTokenDataArray
185 /// };
186 /// use llama_cpp_bindings::sampling::LlamaSampler;
187 ///
188 /// let mut data_array = LlamaTokenDataArray::new(vec![
189 /// LlamaTokenData::new(LlamaToken(0), 0., 0.),
190 /// LlamaTokenData::new(LlamaToken(1), 1., 0.),
191 /// LlamaTokenData::new(LlamaToken(2), 2., 0.),
192 /// ], false);
193 ///
194 /// data_array.apply_sampler(&mut LlamaSampler::temp(0.5));
195 ///
196 /// assert_eq!(data_array.data[0].logit(), 0.);
197 /// assert_eq!(data_array.data[1].logit(), 2.);
198 /// assert_eq!(data_array.data[2].logit(), 4.);
199 /// ```
200 #[must_use]
201 pub fn temp(t: f32) -> Self {
202 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_temp(t) };
203 Self { sampler }
204 }
205
206 /// Dynamic temperature implementation (a.k.a. entropy) described in the paper
207 /// <https://arxiv.org/abs/2309.02772>.
208 #[must_use]
209 pub fn temp_ext(t: f32, delta: f32, exponent: f32) -> Self {
210 let sampler =
211 unsafe { llama_cpp_bindings_sys::llama_sampler_init_temp_ext(t, delta, exponent) };
212 Self { sampler }
213 }
214
215 /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration"
216 /// <https://arxiv.org/abs/1904.09751>
217 ///
218 /// # Example:
219 /// ```rust
220 /// use llama_cpp_bindings::token::{
221 /// LlamaToken,
222 /// data::LlamaTokenData,
223 /// data_array::LlamaTokenDataArray
224 /// };
225 /// use llama_cpp_bindings::sampling::LlamaSampler;
226 ///
227 /// let mut data_array = LlamaTokenDataArray::new(vec![
228 /// LlamaTokenData::new(LlamaToken(0), 0., 0.),
229 /// LlamaTokenData::new(LlamaToken(1), 1., 0.),
230 /// LlamaTokenData::new(LlamaToken(2), 2., 0.),
231 /// LlamaTokenData::new(LlamaToken(3), 3., 0.),
232 /// ], false);
233 ///
234 /// data_array.apply_sampler(&mut LlamaSampler::top_k(2));
235 ///
236 /// assert_eq!(data_array.data.len(), 2);
237 /// assert_eq!(data_array.data[0].id(), LlamaToken(3));
238 /// assert_eq!(data_array.data[1].id(), LlamaToken(2));
239 /// ```
240 #[must_use]
241 pub fn top_k(k: i32) -> Self {
242 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_k(k) };
243 Self { sampler }
244 }
245
246 /// Top-nσ sampling as described in academic paper "Top-nσ: Not All Logits Are You Need"
247 /// <https://arxiv.org/pdf/2411.07641>
248 ///
249 /// This method filters logits by selecting only those within *n* standard deviations of the mean.
250 ///
251 /// # Parameters
252 /// - `n`: Number of standard deviations from the mean to include in sampling
253 ///
254 /// # Example
255 /// ```rust
256 /// use llama_cpp_bindings::sampling::LlamaSampler;
257 /// use llama_cpp_bindings::token::{
258 /// LlamaToken,
259 /// data::LlamaTokenData,
260 /// data_array::LlamaTokenDataArray
261 /// };
262 ///
263 /// let mut data_array = LlamaTokenDataArray::new(vec![
264 /// LlamaTokenData::new(LlamaToken(0), 0.0, 0.0),
265 /// LlamaTokenData::new(LlamaToken(1), 1.0, 0.0),
266 /// LlamaTokenData::new(LlamaToken(2), 2.0, 0.0),
267 /// ], false);
268 ///
269 /// data_array.apply_sampler(&mut LlamaSampler::top_n_sigma(2.0));
270 /// ```
271 #[must_use]
272 pub fn top_n_sigma(n: f32) -> Self {
273 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_n_sigma(n) };
274 Self { sampler }
275 }
276
277 /// Locally Typical Sampling implementation described in the paper <https://arxiv.org/abs/2202.00666>.
278 #[must_use]
279 pub fn typical(p: f32, min_keep: usize) -> Self {
280 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_typical(p, min_keep) };
281 Self { sampler }
282 }
283
284 /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration"
285 /// <https://arxiv.org/abs/1904.09751>
286 #[must_use]
287 pub fn top_p(p: f32, min_keep: usize) -> Self {
288 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_p(p, min_keep) };
289 Self { sampler }
290 }
291
292 /// Minimum P sampling as described in <https://github.com/ggerganov/llama.cpp/pull/3841>
293 #[must_use]
294 pub fn min_p(p: f32, min_keep: usize) -> Self {
295 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_min_p(p, min_keep) };
296 Self { sampler }
297 }
298
299 /// XTC sampler as described in <https://github.com/oobabooga/text-generation-webui/pull/6335>
300 #[must_use]
301 pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Self {
302 let sampler =
303 unsafe { llama_cpp_bindings_sys::llama_sampler_init_xtc(p, t, min_keep, seed) };
304 Self { sampler }
305 }
306
307 /// Grammar sampler
308 ///
309 /// # Errors
310 /// Returns an error if the grammar is invalid or the sampler cannot be initialized.
311 pub fn grammar(
312 model: &LlamaModel,
313 grammar_str: &str,
314 grammar_root: &str,
315 ) -> Result<Self, GrammarError> {
316 let (grammar_str, grammar_root) =
317 Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
318
319 let sampler = unsafe {
320 llama_cpp_bindings_sys::llama_rs_sampler_init_grammar(
321 model.vocab_ptr(),
322 grammar_str.as_ptr(),
323 grammar_root.as_ptr(),
324 )
325 };
326
327 if sampler.is_null() {
328 Err(GrammarError::NullGrammar)
329 } else {
330 Ok(Self { sampler })
331 }
332 }
333
334 /// Lazy grammar sampler, introduced in <https://github.com/ggerganov/llama.cpp/pull/9639>
335 ///
336 /// This sampler enforces grammar rules only when specific trigger words or tokens are encountered.
337 ///
338 /// # Errors
339 /// Returns an error if the grammar or trigger words are invalid.
340 pub fn grammar_lazy(
341 model: &LlamaModel,
342 grammar_str: &str,
343 grammar_root: &str,
344 trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
345 trigger_tokens: &[LlamaToken],
346 ) -> Result<Self, GrammarError> {
347 let (grammar_str, grammar_root) =
348 Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
349 let trigger_words = Self::sanitize_trigger_words(trigger_words)?;
350
351 let mut trigger_word_ptrs: Vec<*const c_char> =
352 trigger_words.iter().map(|cs| cs.as_ptr()).collect();
353
354 let sampler = unsafe {
355 llama_cpp_bindings_sys::llama_rs_sampler_init_grammar_lazy(
356 model.vocab_ptr(),
357 grammar_str.as_ptr(),
358 grammar_root.as_ptr(),
359 trigger_word_ptrs.as_mut_ptr(),
360 trigger_word_ptrs.len(),
361 trigger_tokens.as_ptr().cast(),
362 trigger_tokens.len(),
363 )
364 };
365
366 if sampler.is_null() {
367 Err(GrammarError::NullGrammar)
368 } else {
369 Ok(Self { sampler })
370 }
371 }
372
373 /// Lazy grammar sampler using regex trigger patterns.
374 ///
375 /// Trigger patterns are regular expressions matched from the start of the
376 /// generation output. The grammar sampler will be fed content starting from
377 /// the first match group.
378 ///
379 /// # Errors
380 /// Returns an error if the grammar or trigger patterns are invalid.
381 pub fn grammar_lazy_patterns(
382 model: &LlamaModel,
383 grammar_str: &str,
384 grammar_root: &str,
385 trigger_patterns: &[String],
386 trigger_tokens: &[LlamaToken],
387 ) -> Result<Self, GrammarError> {
388 let (grammar_str, grammar_root) =
389 Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
390 let trigger_patterns = Self::sanitize_trigger_patterns(trigger_patterns)?;
391
392 let mut trigger_pattern_ptrs: Vec<*const c_char> =
393 trigger_patterns.iter().map(|cs| cs.as_ptr()).collect();
394
395 let sampler = unsafe {
396 llama_cpp_bindings_sys::llama_rs_sampler_init_grammar_lazy_patterns(
397 model.vocab_ptr(),
398 grammar_str.as_ptr(),
399 grammar_root.as_ptr(),
400 trigger_pattern_ptrs.as_mut_ptr(),
401 trigger_pattern_ptrs.len(),
402 trigger_tokens.as_ptr().cast(),
403 trigger_tokens.len(),
404 )
405 };
406
407 if sampler.is_null() {
408 Err(GrammarError::NullGrammar)
409 } else {
410 Ok(Self { sampler })
411 }
412 }
413
    /// `LLGuidance` sampler for constrained decoding.
    ///
    /// Uses the `llguidance` and `toktrie` Rust crates to enforce grammar constraints
    /// during token sampling. Supports JSON schema, regex, Lark, and other grammar types.
    ///
    /// Thin delegation to `crate::llguidance_sampler::create_llg_sampler`; only
    /// compiled when the `llguidance` feature is enabled.
    ///
    /// # Errors
    ///
    /// Returns [`GrammarError`] if the grammar is invalid or the sampler cannot be initialized.
    #[cfg(feature = "llguidance")]
    pub fn llguidance(
        model: &LlamaModel,
        grammar_kind: &str,
        grammar_data: &str,
    ) -> Result<Self, GrammarError> {
        crate::llguidance_sampler::create_llg_sampler(model, grammar_kind, grammar_data)
    }
430
431 fn sanitize_grammar_strings(
432 grammar_str: &str,
433 grammar_root: &str,
434 ) -> Result<(CString, CString), GrammarError> {
435 if !grammar_str.contains(grammar_root) {
436 return Err(GrammarError::RootNotFound);
437 }
438
439 if grammar_str.contains('\0') || grammar_root.contains('\0') {
440 return Err(GrammarError::GrammarNullBytes);
441 }
442
443 Ok((CString::new(grammar_str)?, CString::new(grammar_root)?))
444 }
445
446 fn sanitize_trigger_words(
447 trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
448 ) -> Result<Vec<CString>, GrammarError> {
449 let trigger_words: Vec<_> = trigger_words.into_iter().collect();
450 if trigger_words
451 .iter()
452 .any(|word| word.as_ref().contains(&b'\0'))
453 {
454 return Err(GrammarError::TriggerWordNullBytes);
455 }
456 trigger_words
457 .into_iter()
458 .map(|word| Ok(CString::new(word.as_ref())?))
459 .collect()
460 }
461
462 fn sanitize_trigger_patterns(
463 trigger_patterns: &[String],
464 ) -> Result<Vec<CString>, GrammarError> {
465 let mut patterns = Vec::with_capacity(trigger_patterns.len());
466 for pattern in trigger_patterns {
467 if pattern.contains('\0') {
468 return Err(GrammarError::GrammarNullBytes);
469 }
470 patterns.push(CString::new(pattern.as_str())?);
471 }
472 Ok(patterns)
473 }
474
475 /// DRY sampler, designed by p-e-w, as described in:
476 /// <https://github.com/oobabooga/text-generation-webui/pull/5677>, porting Koboldcpp
477 /// implementation authored by pi6am: <https://github.com/LostRuins/koboldcpp/pull/982>
478 ///
479 /// # Errors
480 /// Returns an error if any string in `seq_breakers` contains null bytes.
481 #[allow(missing_docs)]
482 pub fn dry(
483 model: &LlamaModel,
484 multiplier: f32,
485 base: f32,
486 allowed_length: i32,
487 penalty_last_n: i32,
488 seq_breakers: impl IntoIterator<Item = impl AsRef<[u8]>>,
489 ) -> Result<Self, GrammarError> {
490 let seq_breakers: Vec<CString> = seq_breakers
491 .into_iter()
492 .map(|s| CString::new(s.as_ref()))
493 .collect::<Result<Vec<_>, _>>()?;
494 let mut seq_breaker_pointers: Vec<*const c_char> =
495 seq_breakers.iter().map(|s| s.as_ptr()).collect();
496
497 let n_ctx_train = model.n_ctx_train().try_into().map_err(|convert_error| {
498 GrammarError::IntegerOverflow(format!("n_ctx_train exceeds i32::MAX: {convert_error}"))
499 })?;
500 let sampler = unsafe {
501 llama_cpp_bindings_sys::llama_sampler_init_dry(
502 model.vocab_ptr(),
503 n_ctx_train,
504 multiplier,
505 base,
506 allowed_length,
507 penalty_last_n,
508 seq_breaker_pointers.as_mut_ptr(),
509 seq_breaker_pointers.len(),
510 )
511 };
512
513 Ok(Self { sampler })
514 }
515
516 /// Penalizes tokens for being present in the context.
517 ///
518 /// Parameters:
519 /// - ``penalty_last_n``: last n tokens to penalize (0 = disable penalty, -1 = context size)
520 /// - ``penalty_repeat``: 1.0 = disabled
521 /// - ``penalty_freq``: 0.0 = disabled
522 /// - ``penalty_present``: 0.0 = disabled
523 #[must_use]
524 pub fn penalties(
525 penalty_last_n: i32,
526 penalty_repeat: f32,
527 penalty_freq: f32,
528 penalty_present: f32,
529 ) -> Self {
530 let sampler = unsafe {
531 llama_cpp_bindings_sys::llama_sampler_init_penalties(
532 penalty_last_n,
533 penalty_repeat,
534 penalty_freq,
535 penalty_present,
536 )
537 };
538 Self { sampler }
539 }
540
541 /// Mirostat 1.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
542 ///
543 /// # Parameters:
544 /// - ``n_vocab``: [`LlamaModel::n_vocab`]
545 /// - ``seed``: Seed to initialize random generation with.
546 /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
547 /// generated text. A higher value corresponds to more surprising or less predictable text,
548 /// while a lower value corresponds to less surprising or more predictable text.
549 /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
550 /// observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
551 /// updated more quickly, while a smaller learning rate will result in slower updates.
552 /// - ``m``: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary
553 /// value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`.
554 /// In the paper, they use `m = 100`, but you can experiment with different values to see how
555 /// it affects the performance of the algorithm.
556 #[must_use]
557 pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self {
558 let sampler = unsafe {
559 llama_cpp_bindings_sys::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m)
560 };
561 Self { sampler }
562 }
563
564 /// Mirostat 2.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
565 ///
566 /// # Parameters:
567 /// - ``seed``: Seed to initialize random generation with.
568 /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
569 /// generated text. A higher value corresponds to more surprising or less predictable text,
570 /// while a lower value corresponds to less surprising or more predictable text.
571 /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
572 /// observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
573 /// updated more quickly, while a smaller learning rate will result in slower updates.
574 #[must_use]
575 pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self {
576 let sampler =
577 unsafe { llama_cpp_bindings_sys::llama_sampler_init_mirostat_v2(seed, tau, eta) };
578 Self { sampler }
579 }
580
581 /// Selects a token at random based on each token's probabilities
582 #[must_use]
583 pub fn dist(seed: u32) -> Self {
584 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_dist(seed) };
585 Self { sampler }
586 }
587
588 /// Selects the most likely token
589 ///
590 /// # Example:
591 /// ```rust
592 /// use llama_cpp_bindings::token::{
593 /// LlamaToken,
594 /// data::LlamaTokenData,
595 /// data_array::LlamaTokenDataArray
596 /// };
597 /// use llama_cpp_bindings::sampling::LlamaSampler;
598 ///
599 /// let mut data_array = LlamaTokenDataArray::new(vec![
600 /// LlamaTokenData::new(LlamaToken(0), 0., 0.),
601 /// LlamaTokenData::new(LlamaToken(1), 1., 0.),
602 /// ], false);
603 ///
604 /// data_array.apply_sampler(&mut LlamaSampler::greedy());
605 ///
606 /// assert_eq!(data_array.data.len(), 2);
607 /// assert_eq!(data_array.selected_token(), Some(LlamaToken(1)));
608 /// ```
609 #[must_use]
610 pub fn greedy() -> Self {
611 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_greedy() };
612 Self { sampler }
613 }
614
615 /// Creates a sampler that applies bias values to specific tokens during sampling.
616 ///
617 /// # Parameters
618 /// - ``n_vocab``: [`LlamaModel::n_vocab`]
619 /// - ``biases``: Slice of [`LlamaLogitBias`] values specifying token-bias pairs
620 ///
621 /// # Errors
622 /// Returns [`SamplingError::IntegerOverflow`] if `biases.len()` exceeds `i32::MAX`.
623 ///
624 /// # Example
625 /// ```rust
626 /// use llama_cpp_bindings::token::{LlamaToken, logit_bias::LlamaLogitBias};
627 /// use llama_cpp_bindings::sampling::LlamaSampler;
628 ///
629 /// let biases = vec![
630 /// LlamaLogitBias::new(LlamaToken(1), 1.5), // Increase probability of token 1
631 /// LlamaLogitBias::new(LlamaToken(2), -1.0), // Decrease probability of token 2
632 /// ];
633 ///
634 /// // Assuming vocab_size of 32000
635 /// let sampler = LlamaSampler::logit_bias(32000, &biases).unwrap();
636 /// ```
637 pub fn logit_bias(n_vocab: i32, biases: &[LlamaLogitBias]) -> Result<Self, SamplingError> {
638 let bias_count = i32::try_from(biases.len()).map_err(|convert_error| {
639 SamplingError::IntegerOverflow(format!("bias count exceeds i32::MAX: {convert_error}"))
640 })?;
641 let data = biases
642 .as_ptr()
643 .cast::<llama_cpp_bindings_sys::llama_logit_bias>();
644
645 let sampler = unsafe {
646 llama_cpp_bindings_sys::llama_sampler_init_logit_bias(n_vocab, bias_count, data)
647 };
648
649 Ok(Self { sampler })
650 }
651}
652
impl Drop for LlamaSampler {
    fn drop(&mut self) {
        // SAFETY: `self.sampler` was produced by an FFI `llama_sampler_init_*`
        // or chain constructor and has not been freed; samplers whose ownership
        // moved into a chain are `std::mem::forget`-ed in `LlamaSampler::chain`,
        // so they never reach this Drop (no double-free).
        unsafe {
            llama_cpp_bindings_sys::llama_sampler_free(self.sampler);
        }
    }
}
660
#[cfg(test)]
mod tests {
    use super::LlamaSampler;
    use crate::GrammarError;

    // --- sanitize_grammar_strings: validation and error-variant mapping ---

    #[test]
    fn sanitize_grammar_strings_valid() {
        let result = LlamaSampler::sanitize_grammar_strings("root ::= \"hello\"", "root");

        assert!(result.is_ok());
    }

    #[test]
    fn sanitize_grammar_strings_root_not_found() {
        // The root symbol must occur somewhere in the grammar text.
        let result = LlamaSampler::sanitize_grammar_strings("expr ::= \"hello\"", "root");

        assert_eq!(result.err(), Some(GrammarError::RootNotFound));
    }

    #[test]
    fn sanitize_grammar_strings_null_byte_in_grammar() {
        let result = LlamaSampler::sanitize_grammar_strings("root ::= \"\0\"", "root");

        assert_eq!(result.err(), Some(GrammarError::GrammarNullBytes));
    }

    #[test]
    fn sanitize_grammar_strings_null_byte_in_root() {
        let result = LlamaSampler::sanitize_grammar_strings("ro\0ot ::= \"hello\"", "ro\0ot");

        assert_eq!(result.err(), Some(GrammarError::GrammarNullBytes));
    }

    // --- sanitize_trigger_words ---

    #[test]
    fn sanitize_trigger_words_valid() {
        let words: Vec<&[u8]> = vec![b"hello", b"world"];
        let result = LlamaSampler::sanitize_trigger_words(words);

        assert!(result.is_ok());
        assert_eq!(result.expect("valid trigger words").len(), 2);
    }

    #[test]
    fn sanitize_trigger_words_empty_list() {
        let words: Vec<&[u8]> = vec![];
        let result = LlamaSampler::sanitize_trigger_words(words);

        assert!(result.is_ok());
        assert!(result.expect("valid trigger words").is_empty());
    }

    #[test]
    fn sanitize_trigger_words_null_byte() {
        let words: Vec<&[u8]> = vec![b"hel\0lo"];
        let result = LlamaSampler::sanitize_trigger_words(words);

        assert_eq!(result.err(), Some(GrammarError::TriggerWordNullBytes));
    }

    // --- sanitize_trigger_patterns ---

    #[test]
    fn sanitize_trigger_patterns_valid() {
        let patterns = vec!["^hello$".to_string(), "world.*".to_string()];
        let result = LlamaSampler::sanitize_trigger_patterns(&patterns);

        assert!(result.is_ok());
        assert_eq!(result.expect("valid trigger patterns").len(), 2);
    }

    #[test]
    fn sanitize_trigger_patterns_empty_list() {
        let patterns: Vec<String> = vec![];
        let result = LlamaSampler::sanitize_trigger_patterns(&patterns);

        assert!(result.is_ok());
        assert!(result.expect("valid trigger patterns").is_empty());
    }

    #[test]
    fn sanitize_trigger_patterns_null_byte() {
        // Patterns with interior NULs map to GrammarNullBytes (not TriggerWordNullBytes).
        let patterns = vec!["hel\0lo".to_string()];
        let result = LlamaSampler::sanitize_trigger_patterns(&patterns);

        assert_eq!(result.err(), Some(GrammarError::GrammarNullBytes));
    }

    // --- sampler behavior that needs no model ---

    #[test]
    fn apply_modifies_data_array() {
        use crate::token::LlamaToken;
        use crate::token::data::LlamaTokenData;
        use crate::token::data_array::LlamaTokenDataArray;

        let sampler = LlamaSampler::greedy();
        let mut data_array = LlamaTokenDataArray::new(
            vec![
                LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
                LlamaTokenData::new(LlamaToken::new(1), 5.0, 0.0),
            ],
            false,
        );

        sampler.apply(&mut data_array);

        // Greedy selects the highest-logit token.
        assert_eq!(data_array.selected_token(), Some(LlamaToken::new(1)));
    }

    #[test]
    fn accept_succeeds() {
        let mut sampler = LlamaSampler::chain_simple([
            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
            LlamaSampler::greedy(),
        ]);

        sampler
            .accept(crate::token::LlamaToken::new(1))
            .expect("test: accept should succeed");
    }

    #[test]
    fn try_accept_succeeds_on_penalties_sampler() {
        let mut sampler = LlamaSampler::chain_simple([
            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
            LlamaSampler::greedy(),
        ]);

        let result = sampler.try_accept(crate::token::LlamaToken::new(42));

        assert!(result.is_ok());
    }

    #[test]
    fn accept_many_multiple_tokens() {
        use crate::token::LlamaToken;

        let mut sampler = LlamaSampler::chain_simple([
            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
            LlamaSampler::greedy(),
        ]);

        sampler
            .accept_many([LlamaToken::new(1), LlamaToken::new(2), LlamaToken::new(3)])
            .expect("test: accept_many should succeed");
    }

    #[test]
    fn with_tokens_builder_pattern() {
        use crate::token::LlamaToken;

        let _sampler = LlamaSampler::chain_simple([
            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
            LlamaSampler::greedy(),
        ])
        .with_tokens([LlamaToken::new(10), LlamaToken::new(20)])
        .expect("test: with_tokens should succeed");
    }
}