llama_cpp_2/sampling.rs
//! Safe wrapper around `llama_sampler`.

use std::borrow::Borrow;
use std::ffi::{c_char, CString};
use std::fmt::{Debug, Formatter};

use crate::context::LlamaContext;
use crate::model::LlamaModel;
use crate::token::data_array::LlamaTokenDataArray;
use crate::token::logit_bias::LlamaLogitBias;
use crate::token::LlamaToken;
use crate::{status_is_ok, status_to_i32, GrammarError, SamplerAcceptError};

/// A safe wrapper around `llama_sampler`.
pub struct LlamaSampler {
    pub(crate) sampler: *mut llama_cpp_sys_2::llama_sampler,
}

impl Debug for LlamaSampler {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LlamaSampler").finish()
    }
}

impl LlamaSampler {
    /// Sample and accept a token from the idx-th output of the last evaluation
    #[must_use]
    pub fn sample(&mut self, ctx: &LlamaContext, idx: i32) -> LlamaToken {
        let token = unsafe {
            llama_cpp_sys_2::llama_sampler_sample(self.sampler, ctx.context.as_ptr(), idx)
        };

        LlamaToken(token)
    }

    /// Applies this sampler to a [`LlamaTokenDataArray`].
    pub fn apply(&self, data_array: &mut LlamaTokenDataArray) {
        data_array.apply_sampler(self);
    }

    /// Accepts a token from the sampler, possibly updating the internal state of certain samplers
    /// (e.g. grammar, repetition, etc.)
    ///
    /// Errors from the underlying sampler are silently ignored; use [`Self::try_accept`] if you
    /// need to handle them.
    pub fn accept(&mut self, token: LlamaToken) {
        let _ = self.try_accept(token);
    }

    /// Accepts several tokens from the sampler or context, possibly updating the internal state of
    /// certain samplers (e.g. grammar, repetition, etc.)
    ///
    /// Errors from the underlying sampler are silently ignored; use [`Self::try_accept`] if you
    /// need to handle them.
    pub fn accept_many(&mut self, tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>) {
        for token in tokens {
            let _ = self.try_accept(*token.borrow());
        }
    }

    /// Accepts several tokens from the sampler or context, possibly updating the internal state of
    /// certain samplers (e.g. grammar, repetition, etc.)
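    ///
    /// # Example
    ///
    /// A minimal sketch: seed a repetition-penalty sampler with a few previously generated
    /// tokens (the token ids here are made up for illustration).
    ///
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    /// use llama_cpp_2::token::LlamaToken;
    ///
    /// let sampler = LlamaSampler::penalties(64, 1.1, 0.0, 0.0)
    ///     .with_tokens([LlamaToken(1), LlamaToken(2), LlamaToken(2)]);
    /// ```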
    #[must_use]
    pub fn with_tokens(
        mut self,
        tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
    ) -> Self {
        self.accept_many(tokens);
        self
    }

    /// Tries to accept a token from the sampler. Returns an error if the underlying sampler
    /// throws (for example, a C++ exception raised while updating its internal state).
    ///
    /// # Errors
    ///
    /// Returns [`SamplerAcceptError::FfiError`] with the reported status code on failure.
    pub fn try_accept(&mut self, token: LlamaToken) -> Result<(), SamplerAcceptError> {
        let sampler_result =
            unsafe { llama_cpp_sys_2::llama_rs_sampler_accept(self.sampler, token.0) };
        if status_is_ok(sampler_result) {
            Ok(())
        } else {
            Err(SamplerAcceptError::FfiError(status_to_i32(sampler_result)))
        }
    }

    /// Resets the internal state of the sampler.
    ///
    /// This can be useful when you want to start fresh with a sampler without creating a new instance.
    pub fn reset(&mut self) {
        unsafe {
            llama_cpp_sys_2::llama_sampler_reset(self.sampler);
        }
    }

    /// Gets the random seed used by this sampler.
    ///
    /// Returns:
    /// - For random samplers (dist, mirostat, mirostat_v2): returns their current seed
    /// - For sampler chains: returns the first non-default seed found in reverse order
    /// - For all other samplers: returns 0xFFFFFFFF
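    ///
    /// # Example
    ///
    /// A small sketch relying on the behaviour described above for random samplers:
    ///
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let sampler = LlamaSampler::dist(1234);
    /// assert_eq!(sampler.get_seed(), 1234);
    /// ```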
    #[must_use]
    pub fn get_seed(&self) -> u32 {
        unsafe { llama_cpp_sys_2::llama_sampler_get_seed(self.sampler) }
    }

    /// Combines a list of samplers into a single sampler that applies each component sampler one
    /// after another.
    ///
    /// If you are using a chain to select a token, the chain should always end with one of
    /// [`LlamaSampler::greedy`], [`LlamaSampler::dist`], [`LlamaSampler::mirostat`], or
    /// [`LlamaSampler::mirostat_v2`].
    #[must_use]
    pub fn chain(samplers: impl IntoIterator<Item = Self>, no_perf: bool) -> Self {
        unsafe {
            let chain = llama_cpp_sys_2::llama_sampler_chain_init(
                llama_cpp_sys_2::llama_sampler_chain_params { no_perf },
            );

            for sampler in samplers {
                llama_cpp_sys_2::llama_sampler_chain_add(chain, sampler.sampler);

                // Do not call `llama_sampler_free` on the sampler, as the internal sampler is now
                // owned by the chain
                std::mem::forget(sampler);
            }

            Self { sampler: chain }
        }
    }

    /// Same as [`Self::chain`] with `no_perf = false`.
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    /// use llama_cpp_2::llama_backend::LlamaBackend;
    /// let backend = LlamaBackend::init().unwrap();
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::chain_simple([
    ///     LlamaSampler::temp(0.5),
    ///     LlamaSampler::greedy(),
    /// ]));
    ///
    /// assert_eq!(data_array.data[0].logit(), 0.);
    /// assert_eq!(data_array.data[1].logit(), 2.);
    /// assert_eq!(data_array.data[2].logit(), 4.);
    ///
    /// assert_eq!(data_array.data.len(), 3);
    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(2)));
    /// ```
    #[must_use]
    pub fn chain_simple(samplers: impl IntoIterator<Item = Self>) -> Self {
        Self::chain(samplers, false)
    }

    #[allow(clippy::doc_markdown)]
    /// Updates the logits l_i' = l_i/t. When t <= 0.0, the maximum logit is kept at its original
    /// value, the rest are set to -inf.
    ///
    /// # Example:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::temp(0.5));
    ///
    /// assert_eq!(data_array.data[0].logit(), 0.);
    /// assert_eq!(data_array.data[1].logit(), 2.);
    /// assert_eq!(data_array.data[2].logit(), 4.);
    /// ```
    #[must_use]
    pub fn temp(t: f32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_temp(t) };
        Self { sampler }
    }

    /// Dynamic temperature implementation (a.k.a. entropy) described in the paper
    /// <https://arxiv.org/abs/2309.02772>.
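    ///
    /// # Example
    ///
    /// A construction-only sketch; the parameter values are illustrative, not recommended
    /// defaults.
    ///
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// // Base temperature 0.8, dynamic range +/- 0.5, entropy exponent 1.0.
    /// let sampler = LlamaSampler::temp_ext(0.8, 0.5, 1.0);
    /// ```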
    #[must_use]
    pub fn temp_ext(t: f32, delta: f32, exponent: f32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_temp_ext(t, delta, exponent) };
        Self { sampler }
    }

    /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration"
    /// <https://arxiv.org/abs/1904.09751>
    ///
    /// # Example:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    ///     LlamaTokenData::new(LlamaToken(3), 3., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::top_k(2));
    ///
    /// assert_eq!(data_array.data.len(), 2);
    /// assert_eq!(data_array.data[0].id(), LlamaToken(3));
    /// assert_eq!(data_array.data[1].id(), LlamaToken(2));
    /// ```
    #[must_use]
    pub fn top_k(k: i32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_k(k) };
        Self { sampler }
    }

    /// Top-nσ sampling as described in academic paper "Top-nσ: Not All Logits Are You Need"
    /// <https://arxiv.org/pdf/2411.07641>
    ///
    /// This method filters logits by selecting only those within *n* standard deviations of the mean.
    ///
    /// # Parameters
    /// - `n`: Number of standard deviations from the mean to include in sampling
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0.0, 0.0),
    ///     LlamaTokenData::new(LlamaToken(1), 1.0, 0.0),
    ///     LlamaTokenData::new(LlamaToken(2), 2.0, 0.0),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::top_n_sigma(2.0));
    /// ```
    #[must_use]
    pub fn top_n_sigma(n: f32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_n_sigma(n) };
        Self { sampler }
    }

    /// Locally Typical Sampling implementation described in the paper <https://arxiv.org/abs/2202.00666>.
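    ///
    /// # Example
    ///
    /// A construction-only sketch with illustrative parameters:
    ///
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// // Keep locally typical tokens with p = 0.95, retaining at least one candidate.
    /// let sampler = LlamaSampler::typical(0.95, 1);
    /// ```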
    #[must_use]
    pub fn typical(p: f32, min_keep: usize) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_typical(p, min_keep) };
        Self { sampler }
    }

    /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration"
    /// <https://arxiv.org/abs/1904.09751>
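    ///
    /// # Example
    ///
    /// A small sketch: with logits `[0., 1., 2.]` the softmax probability of token 2 is
    /// roughly 0.67, so a cutoff of `p = 0.6` keeps only that token.
    ///
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::top_p(0.6, 1));
    ///
    /// assert_eq!(data_array.data.len(), 1);
    /// assert_eq!(data_array.data[0].id(), LlamaToken(2));
    /// ```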
    #[must_use]
    pub fn top_p(p: f32, min_keep: usize) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_p(p, min_keep) };
        Self { sampler }
    }

    /// Minimum P sampling as described in <https://github.com/ggerganov/llama.cpp/pull/3841>
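    ///
    /// # Example
    ///
    /// A small sketch: tokens whose probability falls below `p` times the top token's
    /// probability are dropped, so with logits `[0., 1., 2.]` and `p = 0.2` the lowest
    /// token is filtered out.
    ///
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::min_p(0.2, 1));
    ///
    /// assert_eq!(data_array.data.len(), 2);
    /// ```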
    #[must_use]
    pub fn min_p(p: f32, min_keep: usize) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_min_p(p, min_keep) };
        Self { sampler }
    }

    /// XTC sampler as described in <https://github.com/oobabooga/text-generation-webui/pull/6335>
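    ///
    /// # Example
    ///
    /// A construction-only sketch; XTC is probabilistic, so it is normally placed in a chain
    /// in front of a final selection sampler such as [`LlamaSampler::dist`].
    ///
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let sampler = LlamaSampler::chain_simple([
    ///     LlamaSampler::xtc(0.5, 0.1, 1, 1234),
    ///     LlamaSampler::dist(1234),
    /// ]);
    /// ```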
    #[must_use]
    pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_xtc(p, t, min_keep, seed) };
        Self { sampler }
    }

    /// Grammar sampler that constrains sampling to the given GBNF grammar.
    ///
    /// # Errors
    ///
    /// Returns [`GrammarError`] if the grammar or root contains null bytes, if the root rule
    /// does not appear in the grammar string, or if the underlying grammar sampler cannot be
    /// created.
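    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a model has already been loaded elsewhere (the example is
    /// compiled but never executed):
    ///
    /// ```rust
    /// use llama_cpp_2::model::LlamaModel;
    /// use llama_cpp_2::sampling::LlamaSampler;
    /// use llama_cpp_2::GrammarError;
    ///
    /// fn yes_no_sampler(model: &LlamaModel) -> Result<LlamaSampler, GrammarError> {
    ///     // A tiny GBNF grammar that only allows the strings "yes" or "no".
    ///     let grammar = r#"root ::= "yes" | "no""#;
    ///     LlamaSampler::grammar(model, grammar, "root")
    /// }
    /// ```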
    #[must_use]
    pub fn grammar(
        model: &LlamaModel,
        grammar_str: &str,
        grammar_root: &str,
    ) -> Result<Self, GrammarError> {
        let (grammar_str, grammar_root) =
            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;

        let sampler = unsafe {
            llama_cpp_sys_2::llama_rs_sampler_init_grammar(
                model.vocab_ptr(),
                grammar_str.as_ptr(),
                grammar_root.as_ptr(),
            )
        };

        if sampler.is_null() {
            Err(GrammarError::NullGrammar)
        } else {
            Ok(Self { sampler })
        }
    }

    /// Lazy grammar sampler, introduced in <https://github.com/ggerganov/llama.cpp/pull/9639>
    ///
    /// This sampler enforces grammar rules only when specific trigger words or tokens are encountered.
    ///
    /// # Errors
    ///
    /// Returns [`GrammarError`] if the grammar, root, or a trigger word contains null bytes, if
    /// the root rule does not appear in the grammar string, or if the underlying grammar sampler
    /// cannot be created.
    #[must_use]
    pub fn grammar_lazy(
        model: &LlamaModel,
        grammar_str: &str,
        grammar_root: &str,
        trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
        trigger_tokens: &[LlamaToken],
    ) -> Result<Self, GrammarError> {
        let (grammar_str, grammar_root) =
            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
        let trigger_words = Self::sanitize_trigger_words(trigger_words)?;

        let mut trigger_word_ptrs: Vec<*const c_char> =
            trigger_words.iter().map(|cs| cs.as_ptr()).collect();

        let sampler = unsafe {
            llama_cpp_sys_2::llama_rs_sampler_init_grammar_lazy(
                model.vocab_ptr(),
                grammar_str.as_ptr(),
                grammar_root.as_ptr(),
                trigger_word_ptrs.as_mut_ptr(),
                trigger_word_ptrs.len(),
                trigger_tokens.as_ptr().cast(),
                trigger_tokens.len(),
            )
        };

        if sampler.is_null() {
            Err(GrammarError::NullGrammar)
        } else {
            Ok(Self { sampler })
        }
    }

    /// Lazy grammar sampler using regex trigger patterns.
    ///
    /// Trigger patterns are regular expressions matched from the start of the
    /// generation output. The grammar sampler will be fed content starting from
    /// the first match group.
    ///
    /// # Errors
    ///
    /// Returns [`GrammarError`] if the grammar, root, or a trigger pattern contains null bytes,
    /// if the root rule does not appear in the grammar string, or if the underlying grammar
    /// sampler cannot be created.
    #[must_use]
    pub fn grammar_lazy_patterns(
        model: &LlamaModel,
        grammar_str: &str,
        grammar_root: &str,
        trigger_patterns: &[String],
        trigger_tokens: &[LlamaToken],
    ) -> Result<Self, GrammarError> {
        let (grammar_str, grammar_root) =
            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
        let trigger_patterns = Self::sanitize_trigger_patterns(trigger_patterns)?;

        let mut trigger_pattern_ptrs: Vec<*const c_char> =
            trigger_patterns.iter().map(|cs| cs.as_ptr()).collect();

        let sampler = unsafe {
            llama_cpp_sys_2::llama_rs_sampler_init_grammar_lazy_patterns(
                model.vocab_ptr(),
                grammar_str.as_ptr(),
                grammar_root.as_ptr(),
                trigger_pattern_ptrs.as_mut_ptr(),
                trigger_pattern_ptrs.len(),
                trigger_tokens.as_ptr().cast(),
                trigger_tokens.len(),
            )
        };

        if sampler.is_null() {
            Err(GrammarError::NullGrammar)
        } else {
            Ok(Self { sampler })
        }
    }

    /// `LLGuidance` sampler for constrained decoding.
    ///
    /// Uses the `llguidance` and `toktrie` Rust crates to enforce grammar constraints
    /// during token sampling. Supports JSON schema, regex, Lark, and other grammar types.
    ///
    /// # Errors
    ///
    /// Returns [`GrammarError`] if the grammar is invalid or the sampler cannot be initialized.
    #[cfg(feature = "llguidance")]
    pub fn llguidance(
        model: &LlamaModel,
        grammar_kind: &str,
        grammar_data: &str,
    ) -> Result<Self, GrammarError> {
        crate::llguidance_sampler::create_llg_sampler(model, grammar_kind, grammar_data)
    }

    fn sanitize_grammar_strings(
        grammar_str: &str,
        grammar_root: &str,
    ) -> Result<(CString, CString), GrammarError> {
        if !grammar_str.contains(grammar_root) {
            return Err(GrammarError::RootNotFound);
        }

        if grammar_str.contains('\0') || grammar_root.contains('\0') {
            return Err(GrammarError::GrammarNullBytes);
        }

        Ok((
            CString::new(grammar_str).unwrap(),
            CString::new(grammar_root).unwrap(),
        ))
    }

    fn sanitize_trigger_words(
        trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
    ) -> Result<Vec<CString>, GrammarError> {
        let trigger_words: Vec<_> = trigger_words.into_iter().collect();
        if trigger_words
            .iter()
            .any(|word| word.as_ref().contains(&b'\0'))
        {
            return Err(GrammarError::TriggerWordNullBytes);
        }
        Ok(trigger_words
            .into_iter()
            .map(|word| CString::new(word.as_ref()).unwrap())
            .collect())
    }

    fn sanitize_trigger_patterns(
        trigger_patterns: &[String],
    ) -> Result<Vec<CString>, GrammarError> {
        let mut patterns = Vec::with_capacity(trigger_patterns.len());
        for pattern in trigger_patterns {
            if pattern.contains('\0') {
                return Err(GrammarError::GrammarNullBytes);
            }
            patterns.push(CString::new(pattern.as_str()).unwrap());
        }
        Ok(patterns)
    }

    /// DRY sampler, designed by p-e-w, as described in
    /// <https://github.com/oobabooga/text-generation-webui/pull/5677>, porting the Koboldcpp
    /// implementation authored by pi6am: <https://github.com/LostRuins/koboldcpp/pull/982>
    ///
    /// # Panics
    /// If any string in ``seq_breakers`` contains null bytes.
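    ///
    /// # Example
    ///
    /// A minimal sketch with illustrative parameter values, assuming a model has already been
    /// loaded elsewhere (the example is compiled but never executed):
    ///
    /// ```rust
    /// use llama_cpp_2::model::LlamaModel;
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// fn dry_sampler(model: &LlamaModel) -> LlamaSampler {
    ///     // multiplier 0.8, base 1.75, allowed length 2, penalty window -1 (context size),
    ///     // with newlines and colons acting as sequence breakers.
    ///     LlamaSampler::dry(model, 0.8, 1.75, 2, -1, ["\n", ":"])
    /// }
    /// ```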
    #[allow(missing_docs)]
    #[must_use]
    pub fn dry(
        model: &LlamaModel,
        multiplier: f32,
        base: f32,
        allowed_length: i32,
        penalty_last_n: i32,
        seq_breakers: impl IntoIterator<Item = impl AsRef<[u8]>>,
    ) -> Self {
        let seq_breakers: Vec<CString> = seq_breakers
            .into_iter()
            .map(|s| CString::new(s.as_ref()).expect("A sequence breaker contains null bytes"))
            .collect();
        let mut seq_breaker_pointers: Vec<*const c_char> =
            seq_breakers.iter().map(|s| s.as_ptr()).collect();

        let sampler = unsafe {
            llama_cpp_sys_2::llama_sampler_init_dry(
                model.vocab_ptr(),
                model
                    .n_ctx_train()
                    .try_into()
                    .expect("n_ctx_train exceeds i32::MAX"),
                multiplier,
                base,
                allowed_length,
                penalty_last_n,
                seq_breaker_pointers.as_mut_ptr(),
                seq_breaker_pointers.len(),
            )
        };
        Self { sampler }
    }

    /// Penalizes tokens for being present in the context.
    ///
    /// # Parameters
    /// - ``penalty_last_n``: last n tokens to penalize (0 = disable penalty, -1 = context size)
    /// - ``penalty_repeat``: 1.0 = disabled
    /// - ``penalty_freq``: 0.0 = disabled
    /// - ``penalty_present``: 0.0 = disabled
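    ///
    /// # Example
    ///
    /// A small sketch: after accepting token 2, a repeat penalty lowers its logit while the
    /// other logits and the array length stay unchanged.
    ///
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// let mut sampler = LlamaSampler::penalties(64, 1.5, 0.0, 0.0)
    ///     .with_tokens([LlamaToken(2)]);
    ///
    /// data_array.apply_sampler(&mut sampler);
    ///
    /// assert_eq!(data_array.data.len(), 3);
    /// assert!(data_array.data[2].logit() < 2.0);
    /// ```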
    #[allow(clippy::too_many_arguments)]
    #[must_use]
    pub fn penalties(
        penalty_last_n: i32,
        penalty_repeat: f32,
        penalty_freq: f32,
        penalty_present: f32,
    ) -> Self {
        let sampler = unsafe {
            llama_cpp_sys_2::llama_sampler_init_penalties(
                penalty_last_n,
                penalty_repeat,
                penalty_freq,
                penalty_present,
            )
        };
        Self { sampler }
    }

    /// Mirostat 1.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
    ///
    /// # Parameters
    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
    /// - ``seed``: Seed to initialize random generation with.
    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
    ///   generated text. A higher value corresponds to more surprising or less predictable text,
    ///   while a lower value corresponds to less surprising or more predictable text.
    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
    ///   updated more quickly, while a smaller learning rate will result in slower updates.
    /// - ``m``: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary
    ///   value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`.
    ///   In the paper, they use `m = 100`, but you can experiment with different values to see how
    ///   it affects the performance of the algorithm.
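    ///
    /// # Example
    ///
    /// A construction-only sketch; the vocabulary size and tuning values are illustrative.
    ///
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// // Assuming a vocabulary of 32000 tokens; tau = 5.0, eta = 0.1, m = 100.
    /// let sampler = LlamaSampler::mirostat(32000, 1234, 5.0, 0.1, 100);
    /// ```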
    #[must_use]
    pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self {
        let sampler =
            unsafe { llama_cpp_sys_2::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m) };
        Self { sampler }
    }

    /// Mirostat 2.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
    ///
    /// # Parameters
    /// - ``seed``: Seed to initialize random generation with.
    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
    ///   generated text. A higher value corresponds to more surprising or less predictable text,
    ///   while a lower value corresponds to less surprising or more predictable text.
    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
    ///   updated more quickly, while a smaller learning rate will result in slower updates.
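    ///
    /// # Example
    ///
    /// A construction-only sketch with illustrative tuning values:
    ///
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// // tau = 5.0, eta = 0.1
    /// let sampler = LlamaSampler::mirostat_v2(1234, 5.0, 0.1);
    /// ```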
    #[must_use]
    pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_mirostat_v2(seed, tau, eta) };
        Self { sampler }
    }

    /// Selects a token at random based on each token's probabilities
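    ///
    /// # Example
    ///
    /// A small sketch: applying `dist` to a token data array selects one of the candidates.
    ///
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::dist(1234));
    ///
    /// assert!(data_array.selected_token().is_some());
    /// ```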
    #[must_use]
    pub fn dist(seed: u32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_dist(seed) };
        Self { sampler }
    }

    /// Selects the most likely token
    ///
    /// # Example:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::greedy());
    ///
    /// assert_eq!(data_array.data.len(), 2);
    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(1)));
    /// ```
    #[must_use]
    pub fn greedy() -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_greedy() };
        Self { sampler }
    }

    /// Creates a sampler that applies bias values to specific tokens during sampling.
    ///
    /// # Parameters
    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
    /// - ``biases``: Slice of [`LlamaLogitBias`] values specifying token-bias pairs
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{LlamaToken, logit_bias::LlamaLogitBias};
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let biases = vec![
    ///     LlamaLogitBias::new(LlamaToken(1), 1.5), // Increase probability of token 1
    ///     LlamaLogitBias::new(LlamaToken(2), -1.0), // Decrease probability of token 2
    /// ];
    ///
    /// // Assuming vocab_size of 32000
    /// let sampler = LlamaSampler::logit_bias(32000, &biases);
    /// ```
    #[must_use]
    pub fn logit_bias(n_vocab: i32, biases: &[LlamaLogitBias]) -> Self {
        let data = biases.as_ptr().cast::<llama_cpp_sys_2::llama_logit_bias>();

        let sampler = unsafe {
            llama_cpp_sys_2::llama_sampler_init_logit_bias(n_vocab, biases.len() as i32, data)
        };

        Self { sampler }
    }
}

impl Drop for LlamaSampler {
    fn drop(&mut self) {
        unsafe {
            llama_cpp_sys_2::llama_sampler_free(self.sampler);
        }
    }
}