llama_cpp_2/sampling.rs
1//! Safe wrapper around `llama_sampler`.
2
3use std::borrow::Borrow;
4use std::ffi::{c_char, CString};
5use std::fmt::{Debug, Formatter};
6
7use crate::context::LlamaContext;
8use crate::model::LlamaModel;
9#[cfg(feature = "common")]
10use crate::status_is_ok;
11use crate::token::data_array::LlamaTokenDataArray;
12use crate::token::logit_bias::LlamaLogitBias;
13use crate::token::LlamaToken;
14use crate::{GrammarError, SamplerAcceptError};
15
/// A safe wrapper around `llama_sampler`.
///
/// The wrapper holds the raw sampler pointer returned by the llama.cpp
/// `llama_sampler_init_*` / `llama_sampler_chain_init` functions and releases it with
/// `llama_sampler_free` when dropped.
pub struct LlamaSampler {
    /// Raw handle passed to every `llama_sampler_*` FFI call made by this type.
    pub(crate) sampler: *mut llama_cpp_sys_2::llama_sampler,
}
20
21impl Debug for LlamaSampler {
22 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
23 f.debug_struct("LlamaSamplerChain").finish()
24 }
25}
26
27impl LlamaSampler {
28 /// Sample and accept a token from the idx-th output of the last evaluation
29 #[must_use]
30 pub fn sample(&mut self, ctx: &LlamaContext, idx: i32) -> LlamaToken {
31 let token = unsafe {
32 llama_cpp_sys_2::llama_sampler_sample(self.sampler, ctx.context.as_ptr(), idx)
33 };
34
35 LlamaToken(token)
36 }
37
38 /// Applies this sampler to a [`LlamaTokenDataArray`].
39 pub fn apply(&self, data_array: &mut LlamaTokenDataArray) {
40 data_array.apply_sampler(self);
41 }
42
43 /// Accepts a token from the sampler, possibly updating the internal state of certain samplers
44 /// (e.g. grammar, repetition, etc.)
45 pub fn accept(&mut self, token: LlamaToken) {
46 #[cfg(feature = "common")]
47 {
48 let _ = self.try_accept(token);
49 }
50 #[cfg(not(feature = "common"))]
51 unsafe {
52 llama_cpp_sys_2::llama_sampler_accept(self.sampler, token.0);
53 }
54 }
55
56 /// Accepts several tokens from the sampler or context, possibly updating the internal state of
57 /// certain samplers (e.g. grammar, repetition, etc.)
58 pub fn accept_many(&mut self, tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>) {
59 for token in tokens {
60 self.accept(*token.borrow());
61 }
62 }
63
64 /// Accepts several tokens from the sampler or context, possibly updating the internal state of
65 /// certain samplers (e.g. grammar, repetition, etc.)
66 #[must_use]
67 pub fn with_tokens(
68 mut self,
69 tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
70 ) -> Self {
71 self.accept_many(tokens);
72 self
73 }
74
75 /// Try accepting a token from the sampler. Returns an error if the sampler throws.
76 #[cfg(feature = "common")]
77 pub fn try_accept(&mut self, token: LlamaToken) -> Result<(), SamplerAcceptError> {
78 let sampler_result =
79 unsafe { llama_cpp_sys_2::llama_rs_sampler_accept(self.sampler, token.0) };
80 if status_is_ok(sampler_result) {
81 Ok(())
82 } else {
83 Err(SamplerAcceptError::FfiError(sampler_result))
84 }
85 }
86
87 /// Resets the internal state of the sampler.
88 ///
89 /// This can be useful when you want to start fresh with a sampler without creating a new instance.
90 pub fn reset(&mut self) {
91 unsafe {
92 llama_cpp_sys_2::llama_sampler_reset(self.sampler);
93 }
94 }
95
96 /// Gets the random seed used by this sampler.
97 ///
98 /// Returns:
99 /// - For random samplers (dist, mirostat, mirostat_v2): returns their current seed
100 /// - For sampler chains: returns the first non-default seed found in reverse order
101 /// - For all other samplers: returns 0xFFFFFFFF
102 #[must_use]
103 pub fn get_seed(&self) -> u32 {
104 unsafe { llama_cpp_sys_2::llama_sampler_get_seed(self.sampler) }
105 }
106
107 /// Combines a list of samplers into a single sampler that applies each component sampler one
108 /// after another.
109 ///
110 /// If you are using a chain to select a token, the chain should always end with one of
111 /// [`LlamaSampler::greedy`], [`LlamaSampler::dist`], [`LlamaSampler::mirostat`], and
112 /// [`LlamaSampler::mirostat_v2`].
113 #[must_use]
114 pub fn chain(samplers: impl IntoIterator<Item = Self>, no_perf: bool) -> Self {
115 unsafe {
116 let chain = llama_cpp_sys_2::llama_sampler_chain_init(
117 llama_cpp_sys_2::llama_sampler_chain_params { no_perf },
118 );
119
120 for sampler in samplers {
121 llama_cpp_sys_2::llama_sampler_chain_add(chain, sampler.sampler);
122
123 // Do not call `llama_sampler_free` on the sampler, as the internal sampler is now
124 // owned by the chain
125 std::mem::forget(sampler);
126 }
127
128 Self { sampler: chain }
129 }
130 }
131
132 /// Same as [`Self::chain`] with `no_perf = false`.
133 ///
134 /// # Example
135 /// ```rust
136 /// use llama_cpp_2::token::{
137 /// LlamaToken,
138 /// data::LlamaTokenData,
139 /// data_array::LlamaTokenDataArray
140 /// };
141 /// use llama_cpp_2::sampling::LlamaSampler;
142 /// use llama_cpp_2::llama_backend::LlamaBackend;
143 /// let backend = LlamaBackend::init().unwrap();
144 ///
145 /// let mut data_array = LlamaTokenDataArray::new(vec![
146 /// LlamaTokenData::new(LlamaToken(0), 0., 0.),
147 /// LlamaTokenData::new(LlamaToken(1), 1., 0.),
148 /// LlamaTokenData::new(LlamaToken(2), 2., 0.),
149 /// ], false);
150 ///
151 /// data_array.apply_sampler(&mut LlamaSampler::chain_simple([
152 /// LlamaSampler::temp(0.5),
153 /// LlamaSampler::greedy(),
154 /// ]));
155 ///
156 /// assert_eq!(data_array.data[0].logit(), 0.);
157 /// assert_eq!(data_array.data[1].logit(), 2.);
158 /// assert_eq!(data_array.data[2].logit(), 4.);
159 ///
160 /// assert_eq!(data_array.data.len(), 3);
161 /// assert_eq!(data_array.selected_token(), Some(LlamaToken(2)));
162 /// ```
163 #[must_use]
164 pub fn chain_simple(samplers: impl IntoIterator<Item = Self>) -> Self {
165 Self::chain(samplers, false)
166 }
167
168 #[allow(clippy::doc_markdown)]
169 /// Updates the logits l_i' = l_i/t. When t <= 0.0f, the maximum logit is kept at it's original
170 /// value, the rest are set to -inf
171 ///
172 /// # Example:
173 /// ```rust
174 /// use llama_cpp_2::token::{
175 /// LlamaToken,
176 /// data::LlamaTokenData,
177 /// data_array::LlamaTokenDataArray
178 /// };
179 /// use llama_cpp_2::sampling::LlamaSampler;
180 ///
181 /// let mut data_array = LlamaTokenDataArray::new(vec![
182 /// LlamaTokenData::new(LlamaToken(0), 0., 0.),
183 /// LlamaTokenData::new(LlamaToken(1), 1., 0.),
184 /// LlamaTokenData::new(LlamaToken(2), 2., 0.),
185 /// ], false);
186 ///
187 /// data_array.apply_sampler(&mut LlamaSampler::temp(0.5));
188 ///
189 /// assert_eq!(data_array.data[0].logit(), 0.);
190 /// assert_eq!(data_array.data[1].logit(), 2.);
191 /// assert_eq!(data_array.data[2].logit(), 4.);
192 /// ```
193 #[must_use]
194 pub fn temp(t: f32) -> Self {
195 let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_temp(t) };
196 Self { sampler }
197 }
198
199 /// Dynamic temperature implementation (a.k.a. entropy) described in the paper
200 /// <https://arxiv.org/abs/2309.02772>.
201 #[must_use]
202 pub fn temp_ext(t: f32, delta: f32, exponent: f32) -> Self {
203 let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_temp_ext(t, delta, exponent) };
204 Self { sampler }
205 }
206
207 /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration"
208 /// <https://arxiv.org/abs/1904.09751>
209 ///
210 /// # Example:
211 /// ```rust
212 /// use llama_cpp_2::token::{
213 /// LlamaToken,
214 /// data::LlamaTokenData,
215 /// data_array::LlamaTokenDataArray
216 /// };
217 /// use llama_cpp_2::sampling::LlamaSampler;
218 ///
219 /// let mut data_array = LlamaTokenDataArray::new(vec![
220 /// LlamaTokenData::new(LlamaToken(0), 0., 0.),
221 /// LlamaTokenData::new(LlamaToken(1), 1., 0.),
222 /// LlamaTokenData::new(LlamaToken(2), 2., 0.),
223 /// LlamaTokenData::new(LlamaToken(3), 3., 0.),
224 /// ], false);
225 ///
226 /// data_array.apply_sampler(&mut LlamaSampler::top_k(2));
227 ///
228 /// assert_eq!(data_array.data.len(), 2);
229 /// assert_eq!(data_array.data[0].id(), LlamaToken(3));
230 /// assert_eq!(data_array.data[1].id(), LlamaToken(2));
231 /// ```
232 #[must_use]
233 pub fn top_k(k: i32) -> Self {
234 let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_k(k) };
235 Self { sampler }
236 }
237
238 /// Top-nσ sampling as described in academic paper "Top-nσ: Not All Logits Are You Need"
239 /// <https://arxiv.org/pdf/2411.07641>
240 ///
241 /// This method filters logits by selecting only those within *n* standard deviations of the mean.
242 ///
243 /// # Parameters
244 /// - `n`: Number of standard deviations from the mean to include in sampling
245 ///
246 /// # Example
247 /// ```rust
248 /// use llama_cpp_2::sampling::LlamaSampler;
249 /// use llama_cpp_2::token::{
250 /// LlamaToken,
251 /// data::LlamaTokenData,
252 /// data_array::LlamaTokenDataArray
253 /// };
254 ///
255 /// let mut data_array = LlamaTokenDataArray::new(vec![
256 /// LlamaTokenData::new(LlamaToken(0), 0.0, 0.0),
257 /// LlamaTokenData::new(LlamaToken(1), 1.0, 0.0),
258 /// LlamaTokenData::new(LlamaToken(2), 2.0, 0.0),
259 /// ], false);
260 ///
261 /// data_array.apply_sampler(&mut LlamaSampler::top_n_sigma(2.0));
262 /// ```
263 #[must_use]
264 pub fn top_n_sigma(n: f32) -> Self {
265 let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_n_sigma(n) };
266 Self { sampler }
267 }
268
269 /// Locally Typical Sampling implementation described in the paper <https://arxiv.org/abs/2202.00666>.
270 #[must_use]
271 pub fn typical(p: f32, min_keep: usize) -> Self {
272 let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_typical(p, min_keep) };
273 Self { sampler }
274 }
275
276 /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration"
277 /// <https://arxiv.org/abs/1904.09751>
278 #[must_use]
279 pub fn top_p(p: f32, min_keep: usize) -> Self {
280 let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_p(p, min_keep) };
281 Self { sampler }
282 }
283
284 /// Minimum P sampling as described in <https://github.com/ggerganov/llama.cpp/pull/3841>
285 #[must_use]
286 pub fn min_p(p: f32, min_keep: usize) -> Self {
287 let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_min_p(p, min_keep) };
288 Self { sampler }
289 }
290
291 /// XTC sampler as described in <https://github.com/oobabooga/text-generation-webui/pull/6335>
292 #[must_use]
293 pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Self {
294 let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_xtc(p, t, min_keep, seed) };
295 Self { sampler }
296 }
297
298 /// Grammar sampler
299 #[cfg(feature = "common")]
300 #[must_use]
301 pub fn grammar(
302 model: &LlamaModel,
303 grammar_str: &str,
304 grammar_root: &str,
305 ) -> Result<Self, GrammarError> {
306 let (grammar_str, grammar_root) =
307 Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
308
309 let sampler = unsafe {
310 llama_cpp_sys_2::llama_rs_sampler_init_grammar(
311 model.vocab_ptr(),
312 grammar_str.as_ptr(),
313 grammar_root.as_ptr(),
314 )
315 };
316
317 if sampler.is_null() {
318 Err(GrammarError::NullGrammar)
319 } else {
320 Ok(Self { sampler })
321 }
322 }
323
324 /// Lazy grammar sampler, introduced in <https://github.com/ggerganov/llama.cpp/pull/9639>
325 ///
326 /// This sampler enforces grammar rules only when specific trigger words or tokens are encountered.
327 #[cfg(feature = "common")]
328 #[must_use]
329 pub fn grammar_lazy(
330 model: &LlamaModel,
331 grammar_str: &str,
332 grammar_root: &str,
333 trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
334 trigger_tokens: &[LlamaToken],
335 ) -> Result<Self, GrammarError> {
336 let (grammar_str, grammar_root) =
337 Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
338 let trigger_words = Self::sanitize_trigger_words(trigger_words)?;
339
340 let mut trigger_word_ptrs: Vec<*const c_char> =
341 trigger_words.iter().map(|cs| cs.as_ptr()).collect();
342
343 let sampler = unsafe {
344 llama_cpp_sys_2::llama_rs_sampler_init_grammar_lazy(
345 model.vocab_ptr(),
346 grammar_str.as_ptr(),
347 grammar_root.as_ptr(),
348 trigger_word_ptrs.as_mut_ptr(),
349 trigger_word_ptrs.len(),
350 trigger_tokens.as_ptr().cast(),
351 trigger_tokens.len(),
352 )
353 };
354
355 if sampler.is_null() {
356 Err(GrammarError::NullGrammar)
357 } else {
358 Ok(Self { sampler })
359 }
360 }
361
362 /// Lazy grammar sampler using regex trigger patterns.
363 ///
364 /// Trigger patterns are regular expressions matched from the start of the
365 /// generation output. The grammar sampler will be fed content starting from
366 /// the first match group.
367 #[cfg(feature = "common")]
368 #[must_use]
369 pub fn grammar_lazy_patterns(
370 model: &LlamaModel,
371 grammar_str: &str,
372 grammar_root: &str,
373 trigger_patterns: &[String],
374 trigger_tokens: &[LlamaToken],
375 ) -> Result<Self, GrammarError> {
376 let (grammar_str, grammar_root) =
377 Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
378 let trigger_patterns = Self::sanitize_trigger_patterns(trigger_patterns)?;
379
380 let mut trigger_pattern_ptrs: Vec<*const c_char> =
381 trigger_patterns.iter().map(|cs| cs.as_ptr()).collect();
382
383 let sampler = unsafe {
384 llama_cpp_sys_2::llama_rs_sampler_init_grammar_lazy_patterns(
385 model.vocab_ptr(),
386 grammar_str.as_ptr(),
387 grammar_root.as_ptr(),
388 trigger_pattern_ptrs.as_mut_ptr(),
389 trigger_pattern_ptrs.len(),
390 trigger_tokens.as_ptr().cast(),
391 trigger_tokens.len(),
392 )
393 };
394
395 if sampler.is_null() {
396 Err(GrammarError::NullGrammar)
397 } else {
398 Ok(Self { sampler })
399 }
400 }
401
402 /// `LLGuidance` sampler for constrained decoding.
403 ///
404 /// Uses the `llguidance` and `toktrie` Rust crates to enforce grammar constraints
405 /// during token sampling. Supports JSON schema, regex, Lark, and other grammar types.
406 ///
407 /// # Errors
408 ///
409 /// Returns [`GrammarError`] if the grammar is invalid or the sampler cannot be initialized.
410 #[cfg(feature = "llguidance")]
411 pub fn llguidance(
412 model: &LlamaModel,
413 grammar_kind: &str,
414 grammar_data: &str,
415 ) -> Result<Self, GrammarError> {
416 crate::llguidance_sampler::create_llg_sampler(model, grammar_kind, grammar_data)
417 }
418
419 fn sanitize_grammar_strings(
420 grammar_str: &str,
421 grammar_root: &str,
422 ) -> Result<(CString, CString), GrammarError> {
423 if !grammar_str.contains(grammar_root) {
424 return Err(GrammarError::RootNotFound);
425 }
426
427 if grammar_str.contains('\0') || grammar_root.contains('\0') {
428 return Err(GrammarError::GrammarNullBytes);
429 }
430
431 Ok((
432 CString::new(grammar_str).unwrap(),
433 CString::new(grammar_root).unwrap(),
434 ))
435 }
436
437 fn sanitize_trigger_words(
438 trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
439 ) -> Result<Vec<CString>, GrammarError> {
440 let trigger_words: Vec<_> = trigger_words.into_iter().collect();
441 if trigger_words
442 .iter()
443 .any(|word| word.as_ref().contains(&b'\0'))
444 {
445 return Err(GrammarError::TriggerWordNullBytes);
446 }
447 Ok(trigger_words
448 .into_iter()
449 .map(|word| CString::new(word.as_ref()).unwrap())
450 .collect())
451 }
452
453 fn sanitize_trigger_patterns(
454 trigger_patterns: &[String],
455 ) -> Result<Vec<CString>, GrammarError> {
456 let mut patterns = Vec::with_capacity(trigger_patterns.len());
457 for pattern in trigger_patterns {
458 if pattern.contains('\0') {
459 return Err(GrammarError::GrammarNullBytes);
460 }
461 patterns.push(CString::new(pattern.as_str()).unwrap());
462 }
463 Ok(patterns)
464 }
465
466 /// DRY sampler, designed by p-e-w, as described in:
467 /// <https://github.com/oobabooga/text-generation-webui/pull/5677>, porting Koboldcpp
468 /// implementation authored by pi6am: <https://github.com/LostRuins/koboldcpp/pull/982>
469 ///
470 /// # Panics
471 /// If any string in ``seq_breakers`` contains null bytes.
472 #[allow(missing_docs)]
473 #[must_use]
474 pub fn dry(
475 model: &LlamaModel,
476 multiplier: f32,
477 base: f32,
478 allowed_length: i32,
479 penalty_last_n: i32,
480 seq_breakers: impl IntoIterator<Item = impl AsRef<[u8]>>,
481 ) -> Self {
482 let seq_breakers: Vec<CString> = seq_breakers
483 .into_iter()
484 .map(|s| CString::new(s.as_ref()).expect("A sequence breaker contains null bytes"))
485 .collect();
486 let mut seq_breaker_pointers: Vec<*const c_char> =
487 seq_breakers.iter().map(|s| s.as_ptr()).collect();
488
489 let sampler = unsafe {
490 llama_cpp_sys_2::llama_sampler_init_dry(
491 model.vocab_ptr(),
492 model
493 .n_ctx_train()
494 .try_into()
495 .expect("n_ctx_train exceeds i32::MAX"),
496 multiplier,
497 base,
498 allowed_length,
499 penalty_last_n,
500 seq_breaker_pointers.as_mut_ptr(),
501 seq_breaker_pointers.len(),
502 )
503 };
504 Self { sampler }
505 }
506
507 /// Penalizes tokens for being present in the context.
508 ///
509 /// Parameters:
510 /// - ``penalty_last_n``: last n tokens to penalize (0 = disable penalty, -1 = context size)
511 /// - ``penalty_repeat``: 1.0 = disabled
512 /// - ``penalty_freq``: 0.0 = disabled
513 /// - ``penalty_present``: 0.0 = disabled
514 #[allow(clippy::too_many_arguments)]
515 #[must_use]
516 pub fn penalties(
517 penalty_last_n: i32,
518 penalty_repeat: f32,
519 penalty_freq: f32,
520 penalty_present: f32,
521 ) -> Self {
522 let sampler = unsafe {
523 llama_cpp_sys_2::llama_sampler_init_penalties(
524 penalty_last_n,
525 penalty_repeat,
526 penalty_freq,
527 penalty_present,
528 )
529 };
530 Self { sampler }
531 }
532
533 /// Mirostat 1.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
534 ///
535 /// # Parameters:
536 /// - ``n_vocab``: [`LlamaModel::n_vocab`]
537 /// - ``seed``: Seed to initialize random generation with.
538 /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
539 /// generated text. A higher value corresponds to more surprising or less predictable text,
540 /// while a lower value corresponds to less surprising or more predictable text.
541 /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
542 /// observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
543 /// updated more quickly, while a smaller learning rate will result in slower updates.
544 /// - ``m``: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary
545 /// value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`.
546 /// In the paper, they use `m = 100`, but you can experiment with different values to see how
547 /// it affects the performance of the algorithm.
548 #[must_use]
549 pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self {
550 let sampler =
551 unsafe { llama_cpp_sys_2::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m) };
552 Self { sampler }
553 }
554
555 /// Mirostat 2.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
556 ///
557 /// # Parameters:
558 /// - ``seed``: Seed to initialize random generation with.
559 /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
560 /// generated text. A higher value corresponds to more surprising or less predictable text,
561 /// while a lower value corresponds to less surprising or more predictable text.
562 /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
563 /// observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
564 /// updated more quickly, while a smaller learning rate will result in slower updates.
565 #[must_use]
566 pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self {
567 let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_mirostat_v2(seed, tau, eta) };
568 Self { sampler }
569 }
570
571 /// Selects a token at random based on each token's probabilities
572 #[must_use]
573 pub fn dist(seed: u32) -> Self {
574 let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_dist(seed) };
575 Self { sampler }
576 }
577
578 /// Selects the most likely token
579 ///
580 /// # Example:
581 /// ```rust
582 /// use llama_cpp_2::token::{
583 /// LlamaToken,
584 /// data::LlamaTokenData,
585 /// data_array::LlamaTokenDataArray
586 /// };
587 /// use llama_cpp_2::sampling::LlamaSampler;
588 ///
589 /// let mut data_array = LlamaTokenDataArray::new(vec![
590 /// LlamaTokenData::new(LlamaToken(0), 0., 0.),
591 /// LlamaTokenData::new(LlamaToken(1), 1., 0.),
592 /// ], false);
593 ///
594 /// data_array.apply_sampler(&mut LlamaSampler::greedy());
595 ///
596 /// assert_eq!(data_array.data.len(), 2);
597 /// assert_eq!(data_array.selected_token(), Some(LlamaToken(1)));
598 /// ```
599 #[must_use]
600 pub fn greedy() -> Self {
601 let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_greedy() };
602 Self { sampler }
603 }
604
605 /// Creates a sampler that applies bias values to specific tokens during sampling.
606 ///
607 /// # Parameters
608 /// - ``n_vocab``: [`LlamaModel::n_vocab`]
609 /// - ``biases``: Slice of [`LlamaLogitBias`] values specifying token-bias pairs
610 ///
611 /// # Example
612 /// ```rust
613 /// use llama_cpp_2::token::{LlamaToken, logit_bias::LlamaLogitBias};
614 /// use llama_cpp_2::sampling::LlamaSampler;
615 ///
616 /// let biases = vec![
617 /// LlamaLogitBias::new(LlamaToken(1), 1.5), // Increase probability of token 1
618 /// LlamaLogitBias::new(LlamaToken(2), -1.0), // Decrease probability of token 2
619 /// ];
620 ///
621 /// // Assuming vocab_size of 32000
622 /// let sampler = LlamaSampler::logit_bias(32000, &biases);
623 /// ```
624 #[must_use]
625 pub fn logit_bias(n_vocab: i32, biases: &[LlamaLogitBias]) -> Self {
626 let data = biases.as_ptr().cast::<llama_cpp_sys_2::llama_logit_bias>();
627
628 let sampler = unsafe {
629 llama_cpp_sys_2::llama_sampler_init_logit_bias(n_vocab, biases.len() as i32, data)
630 };
631
632 Self { sampler }
633 }
634}
635
636impl Drop for LlamaSampler {
637 fn drop(&mut self) {
638 unsafe {
639 llama_cpp_sys_2::llama_sampler_free(self.sampler);
640 }
641 }
642}