llama_cpp_bindings/sampling.rs
1//! Safe wrapper around `llama_sampler`.
2
3use std::borrow::Borrow;
4use std::ffi::{CString, c_char};
5use std::fmt::{Debug, Formatter};
6
7use crate::context::LlamaContext;
8use crate::model::LlamaModel;
9use crate::token::LlamaToken;
10use crate::token::data_array::LlamaTokenDataArray;
11use crate::token::logit_bias::LlamaLogitBias;
12use crate::{GrammarError, SamplerAcceptError, status_is_ok, status_to_i32};
13
/// A safe wrapper around `llama_sampler`.
pub struct LlamaSampler {
    /// Raw pointer to the underlying `llama_sampler`.
    ///
    /// Owned by this wrapper: it is freed via `llama_sampler_free` in the
    /// `Drop` impl, unless ownership is transferred to a sampler chain (see
    /// `LlamaSampler::chain`, which `mem::forget`s the moved-in samplers).
    pub sampler: *mut llama_cpp_bindings_sys::llama_sampler,
}
19
20impl Debug for LlamaSampler {
21 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
22 f.debug_struct("LlamaSamplerChain").finish()
23 }
24}
25
impl LlamaSampler {
    /// Sample and accept a token from the idx-th output of the last evaluation
    #[must_use]
    pub fn sample(&mut self, ctx: &LlamaContext, idx: i32) -> LlamaToken {
        // SAFETY: `self.sampler` is a live sampler owned by this wrapper, and
        // `ctx.context` is a valid context pointer for the duration of the call.
        let token = unsafe {
            llama_cpp_bindings_sys::llama_sampler_sample(self.sampler, ctx.context.as_ptr(), idx)
        };

        LlamaToken(token)
    }

    /// Applies this sampler to a [`LlamaTokenDataArray`].
    pub fn apply(&self, data_array: &mut LlamaTokenDataArray) {
        data_array.apply_sampler(self);
    }

    /// Accepts a token from the sampler, possibly updating the internal state of certain samplers
    /// (e.g. grammar, repetition, etc.)
    pub fn accept(&mut self, token: LlamaToken) {
        // Best-effort variant: any error from the FFI accept call is
        // intentionally discarded; use `try_accept` to observe failures.
        let _ = self.try_accept(token);
    }

    /// Accepts several tokens from the sampler or context, possibly updating the internal state of
    /// certain samplers (e.g. grammar, repetition, etc.)
    pub fn accept_many(&mut self, tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>) {
        for token in tokens {
            // Errors are deliberately ignored, matching `accept`'s best-effort
            // semantics.
            let _ = self.try_accept(*token.borrow());
        }
    }

    /// Accepts several tokens from the sampler or context, possibly updating the internal state of
    /// certain samplers (e.g. grammar, repetition, etc.)
    #[must_use]
    pub fn with_tokens(
        mut self,
        tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
    ) -> Self {
        self.accept_many(tokens);
        self
    }

    /// Try accepting a token from the sampler. Returns an error if the sampler throws.
    ///
    /// # Errors
    /// Returns an error if the underlying sampler rejects the token.
    pub fn try_accept(&mut self, token: LlamaToken) -> Result<(), SamplerAcceptError> {
        // SAFETY: `self.sampler` is a valid sampler pointer owned by this
        // wrapper; the shim returns a status code rather than unwinding.
        let sampler_result =
            unsafe { llama_cpp_bindings_sys::llama_rs_sampler_accept(self.sampler, token.0) };
        if status_is_ok(sampler_result) {
            Ok(())
        } else {
            Err(SamplerAcceptError::FfiError(status_to_i32(sampler_result)))
        }
    }

    /// Resets the internal state of the sampler.
    ///
    /// This can be useful when you want to start fresh with a sampler without creating a new instance.
    pub fn reset(&mut self) {
        // SAFETY: `self.sampler` is a valid sampler pointer owned by this wrapper.
        unsafe {
            llama_cpp_bindings_sys::llama_sampler_reset(self.sampler);
        }
    }

    /// Gets the random seed used by this sampler.
    ///
    /// Returns:
    /// - For random samplers (dist, mirostat, `mirostat_v2`): returns their current seed
    /// - For sampler chains: returns the first non-default seed found in reverse order
    /// - For all other samplers: returns 0xFFFFFFFF
    #[must_use]
    pub fn get_seed(&self) -> u32 {
        // SAFETY: `self.sampler` is a valid sampler pointer; this is a read-only query.
        unsafe { llama_cpp_bindings_sys::llama_sampler_get_seed(self.sampler) }
    }

    /// Combines a list of samplers into a single sampler that applies each component sampler one
    /// after another.
    ///
    /// If you are using a chain to select a token, the chain should always end with one of
    /// [`LlamaSampler::greedy`], [`LlamaSampler::dist`], [`LlamaSampler::mirostat`], and
    /// [`LlamaSampler::mirostat_v2`].
    #[must_use]
    pub fn chain(samplers: impl IntoIterator<Item = Self>, no_perf: bool) -> Self {
        // SAFETY: `chain` is a freshly initialized sampler chain; each component
        // sampler pointer is valid and ownership is handed to the chain below.
        unsafe {
            let chain = llama_cpp_bindings_sys::llama_sampler_chain_init(
                llama_cpp_bindings_sys::llama_sampler_chain_params { no_perf },
            );

            for sampler in samplers {
                llama_cpp_bindings_sys::llama_sampler_chain_add(chain, sampler.sampler);

                // Do not call `llama_sampler_free` on the sampler, as the internal sampler is now
                // owned by the chain
                std::mem::forget(sampler);
            }

            Self { sampler: chain }
        }
    }

    /// Same as [`Self::chain`] with `no_perf = false`.
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_bindings::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_bindings::sampling::LlamaSampler;
    /// use llama_cpp_bindings::llama_backend::LlamaBackend;
    /// let backend = LlamaBackend::init().unwrap();
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::chain_simple([
    ///     LlamaSampler::temp(0.5),
    ///     LlamaSampler::greedy(),
    /// ]));
    ///
    /// assert_eq!(data_array.data[0].logit(), 0.);
    /// assert_eq!(data_array.data[1].logit(), 2.);
    /// assert_eq!(data_array.data[2].logit(), 4.);
    ///
    /// assert_eq!(data_array.data.len(), 3);
    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(2)));
    /// ```
    #[must_use]
    pub fn chain_simple(samplers: impl IntoIterator<Item = Self>) -> Self {
        Self::chain(samplers, false)
    }

    /// Updates the logits `l_i' = l_i/t`. When `t <= 0.0`, the maximum logit is kept at its original
    /// value, the rest are set to -inf
    ///
    /// # Example:
    /// ```rust
    /// use llama_cpp_bindings::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_bindings::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::temp(0.5));
    ///
    /// assert_eq!(data_array.data[0].logit(), 0.);
    /// assert_eq!(data_array.data[1].logit(), 2.);
    /// assert_eq!(data_array.data[2].logit(), 4.);
    /// ```
    #[must_use]
    pub fn temp(t: f32) -> Self {
        // SAFETY: FFI constructor taking only scalar arguments.
        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_temp(t) };
        Self { sampler }
    }

    /// Dynamic temperature implementation (a.k.a. entropy) described in the paper
    /// <https://arxiv.org/abs/2309.02772>.
    #[must_use]
    pub fn temp_ext(t: f32, delta: f32, exponent: f32) -> Self {
        // SAFETY: FFI constructor taking only scalar arguments.
        let sampler =
            unsafe { llama_cpp_bindings_sys::llama_sampler_init_temp_ext(t, delta, exponent) };
        Self { sampler }
    }

    /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration"
    /// <https://arxiv.org/abs/1904.09751>
    ///
    /// # Example:
    /// ```rust
    /// use llama_cpp_bindings::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_bindings::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    ///     LlamaTokenData::new(LlamaToken(3), 3., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::top_k(2));
    ///
    /// assert_eq!(data_array.data.len(), 2);
    /// assert_eq!(data_array.data[0].id(), LlamaToken(3));
    /// assert_eq!(data_array.data[1].id(), LlamaToken(2));
    /// ```
    #[must_use]
    pub fn top_k(k: i32) -> Self {
        // SAFETY: FFI constructor taking only scalar arguments.
        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_k(k) };
        Self { sampler }
    }

    /// Top-nσ sampling as described in academic paper "Top-nσ: Not All Logits Are You Need"
    /// <https://arxiv.org/pdf/2411.07641>
    ///
    /// This method filters logits by selecting only those within *n* standard deviations of the mean.
    ///
    /// # Parameters
    /// - `n`: Number of standard deviations from the mean to include in sampling
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_bindings::sampling::LlamaSampler;
    /// use llama_cpp_bindings::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0.0, 0.0),
    ///     LlamaTokenData::new(LlamaToken(1), 1.0, 0.0),
    ///     LlamaTokenData::new(LlamaToken(2), 2.0, 0.0),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::top_n_sigma(2.0));
    /// ```
    #[must_use]
    pub fn top_n_sigma(n: f32) -> Self {
        // SAFETY: FFI constructor taking only scalar arguments.
        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_n_sigma(n) };
        Self { sampler }
    }

    /// Locally Typical Sampling implementation described in the paper <https://arxiv.org/abs/2202.00666>.
    #[must_use]
    pub fn typical(p: f32, min_keep: usize) -> Self {
        // SAFETY: FFI constructor taking only scalar arguments.
        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_typical(p, min_keep) };
        Self { sampler }
    }

    /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration"
    /// <https://arxiv.org/abs/1904.09751>
    #[must_use]
    pub fn top_p(p: f32, min_keep: usize) -> Self {
        // SAFETY: FFI constructor taking only scalar arguments.
        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_p(p, min_keep) };
        Self { sampler }
    }

    /// Minimum P sampling as described in <https://github.com/ggerganov/llama.cpp/pull/3841>
    #[must_use]
    pub fn min_p(p: f32, min_keep: usize) -> Self {
        // SAFETY: FFI constructor taking only scalar arguments.
        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_min_p(p, min_keep) };
        Self { sampler }
    }

    /// XTC sampler as described in <https://github.com/oobabooga/text-generation-webui/pull/6335>
    #[must_use]
    pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Self {
        // SAFETY: FFI constructor taking only scalar arguments.
        let sampler =
            unsafe { llama_cpp_bindings_sys::llama_sampler_init_xtc(p, t, min_keep, seed) };
        Self { sampler }
    }

    /// Grammar sampler
    ///
    /// # Errors
    /// Returns an error if the grammar is invalid or the sampler cannot be initialized.
    pub fn grammar(
        model: &LlamaModel,
        grammar_str: &str,
        grammar_root: &str,
    ) -> Result<Self, GrammarError> {
        // Validate and convert to NUL-terminated strings before crossing FFI.
        let (grammar_str, grammar_root) =
            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;

        // SAFETY: `model.vocab_ptr()` is valid for the model's lifetime, and
        // the two CStrings live (in the locals above) until the call returns.
        let sampler = unsafe {
            llama_cpp_bindings_sys::llama_rs_sampler_init_grammar(
                model.vocab_ptr(),
                grammar_str.as_ptr(),
                grammar_root.as_ptr(),
            )
        };

        if sampler.is_null() {
            Err(GrammarError::NullGrammar)
        } else {
            Ok(Self { sampler })
        }
    }

    /// Lazy grammar sampler, introduced in <https://github.com/ggerganov/llama.cpp/pull/9639>
    ///
    /// This sampler enforces grammar rules only when specific trigger words or tokens are encountered.
    ///
    /// # Errors
    /// Returns an error if the grammar or trigger words are invalid.
    pub fn grammar_lazy(
        model: &LlamaModel,
        grammar_str: &str,
        grammar_root: &str,
        trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
        trigger_tokens: &[LlamaToken],
    ) -> Result<Self, GrammarError> {
        // Validate inputs and convert them to NUL-terminated strings up front;
        // the CStrings must outlive the raw pointers collected below.
        let (grammar_str, grammar_root) =
            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
        let trigger_words = Self::sanitize_trigger_words(trigger_words)?;

        // Pointer array into `trigger_words`; valid while `trigger_words` is alive.
        let mut trigger_word_ptrs: Vec<*const c_char> =
            trigger_words.iter().map(|cs| cs.as_ptr()).collect();

        // SAFETY: all pointers (vocab, grammar CStrings, trigger-word pointer
        // array, token slice) are valid for the duration of this call; lengths
        // match the corresponding buffers. `LlamaToken` is cast to the FFI
        // token type — assumes it is a transparent wrapper over that type.
        let sampler = unsafe {
            llama_cpp_bindings_sys::llama_rs_sampler_init_grammar_lazy(
                model.vocab_ptr(),
                grammar_str.as_ptr(),
                grammar_root.as_ptr(),
                trigger_word_ptrs.as_mut_ptr(),
                trigger_word_ptrs.len(),
                trigger_tokens.as_ptr().cast(),
                trigger_tokens.len(),
            )
        };

        if sampler.is_null() {
            Err(GrammarError::NullGrammar)
        } else {
            Ok(Self { sampler })
        }
    }

    /// Lazy grammar sampler using regex trigger patterns.
    ///
    /// Trigger patterns are regular expressions matched from the start of the
    /// generation output. The grammar sampler will be fed content starting from
    /// the first match group.
    ///
    /// # Errors
    /// Returns an error if the grammar or trigger patterns are invalid.
    pub fn grammar_lazy_patterns(
        model: &LlamaModel,
        grammar_str: &str,
        grammar_root: &str,
        trigger_patterns: &[String],
        trigger_tokens: &[LlamaToken],
    ) -> Result<Self, GrammarError> {
        // Validate inputs and convert them to NUL-terminated strings up front;
        // the CStrings must outlive the raw pointers collected below.
        let (grammar_str, grammar_root) =
            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
        let trigger_patterns = Self::sanitize_trigger_patterns(trigger_patterns)?;

        // Pointer array into `trigger_patterns`; valid while `trigger_patterns` is alive.
        let mut trigger_pattern_ptrs: Vec<*const c_char> =
            trigger_patterns.iter().map(|cs| cs.as_ptr()).collect();

        // SAFETY: all pointers and lengths refer to live buffers for the
        // duration of this call; see `grammar_lazy` for the token-slice cast.
        let sampler = unsafe {
            llama_cpp_bindings_sys::llama_rs_sampler_init_grammar_lazy_patterns(
                model.vocab_ptr(),
                grammar_str.as_ptr(),
                grammar_root.as_ptr(),
                trigger_pattern_ptrs.as_mut_ptr(),
                trigger_pattern_ptrs.len(),
                trigger_tokens.as_ptr().cast(),
                trigger_tokens.len(),
            )
        };

        if sampler.is_null() {
            Err(GrammarError::NullGrammar)
        } else {
            Ok(Self { sampler })
        }
    }

    /// `LLGuidance` sampler for constrained decoding.
    ///
    /// Uses the `llguidance` and `toktrie` Rust crates to enforce grammar constraints
    /// during token sampling. Supports JSON schema, regex, Lark, and other grammar types.
    ///
    /// # Errors
    ///
    /// Returns [`GrammarError`] if the grammar is invalid or the sampler cannot be initialized.
    #[cfg(feature = "llguidance")]
    pub fn llguidance(
        model: &LlamaModel,
        grammar_kind: &str,
        grammar_data: &str,
    ) -> Result<Self, GrammarError> {
        crate::llguidance_sampler::create_llg_sampler(model, grammar_kind, grammar_data)
    }

    /// Validates a grammar and its root rule and converts both to C strings.
    ///
    /// Errors when the root rule does not occur in the grammar text or when
    /// either string contains interior NUL bytes (which C strings cannot hold).
    fn sanitize_grammar_strings(
        grammar_str: &str,
        grammar_root: &str,
    ) -> Result<(CString, CString), GrammarError> {
        if !grammar_str.contains(grammar_root) {
            return Err(GrammarError::RootNotFound);
        }

        if grammar_str.contains('\0') || grammar_root.contains('\0') {
            return Err(GrammarError::GrammarNullBytes);
        }

        // `unwrap` cannot fail: interior NUL bytes were rejected above.
        Ok((
            CString::new(grammar_str).unwrap(),
            CString::new(grammar_root).unwrap(),
        ))
    }

    /// Converts trigger words to C strings, rejecting any containing NUL bytes.
    fn sanitize_trigger_words(
        trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
    ) -> Result<Vec<CString>, GrammarError> {
        let trigger_words: Vec<_> = trigger_words.into_iter().collect();
        if trigger_words
            .iter()
            .any(|word| word.as_ref().contains(&b'\0'))
        {
            return Err(GrammarError::TriggerWordNullBytes);
        }
        // `unwrap` cannot fail: interior NUL bytes were rejected above.
        Ok(trigger_words
            .into_iter()
            .map(|word| CString::new(word.as_ref()).unwrap())
            .collect())
    }

    /// Converts trigger patterns to C strings, rejecting any containing NUL bytes.
    fn sanitize_trigger_patterns(
        trigger_patterns: &[String],
    ) -> Result<Vec<CString>, GrammarError> {
        let mut patterns = Vec::with_capacity(trigger_patterns.len());
        for pattern in trigger_patterns {
            if pattern.contains('\0') {
                return Err(GrammarError::GrammarNullBytes);
            }
            // `unwrap` cannot fail: interior NUL bytes were rejected above.
            patterns.push(CString::new(pattern.as_str()).unwrap());
        }
        Ok(patterns)
    }

    /// DRY sampler, designed by p-e-w, as described in:
    /// <https://github.com/oobabooga/text-generation-webui/pull/5677>, porting Koboldcpp
    /// implementation authored by pi6am: <https://github.com/LostRuins/koboldcpp/pull/982>
    ///
    /// # Panics
    /// If any string in ``seq_breakers`` contains null bytes.
    #[allow(missing_docs)]
    #[must_use]
    pub fn dry(
        model: &LlamaModel,
        multiplier: f32,
        base: f32,
        allowed_length: i32,
        penalty_last_n: i32,
        seq_breakers: impl IntoIterator<Item = impl AsRef<[u8]>>,
    ) -> Self {
        // The CStrings must outlive the raw pointer array built from them below.
        let seq_breakers: Vec<CString> = seq_breakers
            .into_iter()
            .map(|s| CString::new(s.as_ref()).expect("A sequence breaker contains null bytes"))
            .collect();
        let mut seq_breaker_pointers: Vec<*const c_char> =
            seq_breakers.iter().map(|s| s.as_ptr()).collect();

        // SAFETY: `model.vocab_ptr()` is valid for the model's lifetime and the
        // sequence-breaker pointers reference live CStrings for the duration of
        // the call; the length matches the pointer array.
        let sampler = unsafe {
            llama_cpp_bindings_sys::llama_sampler_init_dry(
                model.vocab_ptr(),
                model
                    .n_ctx_train()
                    .try_into()
                    .expect("n_ctx_train exceeds i32::MAX"),
                multiplier,
                base,
                allowed_length,
                penalty_last_n,
                seq_breaker_pointers.as_mut_ptr(),
                seq_breaker_pointers.len(),
            )
        };
        Self { sampler }
    }

    /// Penalizes tokens for being present in the context.
    ///
    /// Parameters:
    /// - ``penalty_last_n``: last n tokens to penalize (0 = disable penalty, -1 = context size)
    /// - ``penalty_repeat``: 1.0 = disabled
    /// - ``penalty_freq``: 0.0 = disabled
    /// - ``penalty_present``: 0.0 = disabled
    #[must_use]
    pub fn penalties(
        penalty_last_n: i32,
        penalty_repeat: f32,
        penalty_freq: f32,
        penalty_present: f32,
    ) -> Self {
        // SAFETY: FFI constructor taking only scalar arguments.
        let sampler = unsafe {
            llama_cpp_bindings_sys::llama_sampler_init_penalties(
                penalty_last_n,
                penalty_repeat,
                penalty_freq,
                penalty_present,
            )
        };
        Self { sampler }
    }

    /// Mirostat 1.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
    ///
    /// # Parameters:
    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
    /// - ``seed``: Seed to initialize random generation with.
    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
    ///   generated text. A higher value corresponds to more surprising or less predictable text,
    ///   while a lower value corresponds to less surprising or more predictable text.
    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
    ///   updated more quickly, while a smaller learning rate will result in slower updates.
    /// - ``m``: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary
    ///   value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`.
    ///   In the paper, they use `m = 100`, but you can experiment with different values to see how
    ///   it affects the performance of the algorithm.
    #[must_use]
    pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self {
        // SAFETY: FFI constructor taking only scalar arguments.
        let sampler = unsafe {
            llama_cpp_bindings_sys::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m)
        };
        Self { sampler }
    }

    /// Mirostat 2.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
    ///
    /// # Parameters:
    /// - ``seed``: Seed to initialize random generation with.
    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
    ///   generated text. A higher value corresponds to more surprising or less predictable text,
    ///   while a lower value corresponds to less surprising or more predictable text.
    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
    ///   updated more quickly, while a smaller learning rate will result in slower updates.
    #[must_use]
    pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self {
        // SAFETY: FFI constructor taking only scalar arguments.
        let sampler =
            unsafe { llama_cpp_bindings_sys::llama_sampler_init_mirostat_v2(seed, tau, eta) };
        Self { sampler }
    }

    /// Selects a token at random based on each token's probabilities
    #[must_use]
    pub fn dist(seed: u32) -> Self {
        // SAFETY: FFI constructor taking only scalar arguments.
        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_dist(seed) };
        Self { sampler }
    }

    /// Selects the most likely token
    ///
    /// # Example:
    /// ```rust
    /// use llama_cpp_bindings::token::{
    ///    LlamaToken,
    ///    data::LlamaTokenData,
    ///    data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_bindings::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::greedy());
    ///
    /// assert_eq!(data_array.data.len(), 2);
    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(1)));
    /// ```
    #[must_use]
    pub fn greedy() -> Self {
        // SAFETY: FFI constructor taking no arguments.
        let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_greedy() };
        Self { sampler }
    }

    /// Creates a sampler that applies bias values to specific tokens during sampling.
    ///
    /// # Parameters
    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
    /// - ``biases``: Slice of [`LlamaLogitBias`] values specifying token-bias pairs
    ///
    /// # Panics
    ///
    /// Panics if `biases.len()` exceeds `i32::MAX`.
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_bindings::token::{LlamaToken, logit_bias::LlamaLogitBias};
    /// use llama_cpp_bindings::sampling::LlamaSampler;
    ///
    /// let biases = vec![
    ///     LlamaLogitBias::new(LlamaToken(1), 1.5),  // Increase probability of token 1
    ///     LlamaLogitBias::new(LlamaToken(2), -1.0), // Decrease probability of token 2
    /// ];
    ///
    /// // Assuming vocab_size of 32000
    /// let sampler = LlamaSampler::logit_bias(32000, &biases);
    /// ```
    #[must_use]
    pub fn logit_bias(n_vocab: i32, biases: &[LlamaLogitBias]) -> Self {
        // Reinterpret the slice as the FFI bias type — assumes `LlamaLogitBias`
        // is layout-compatible with `llama_logit_bias` (e.g. #[repr(C)] or
        // #[repr(transparent)]); verify at the type's definition.
        let data = biases
            .as_ptr()
            .cast::<llama_cpp_bindings_sys::llama_logit_bias>();

        // SAFETY: `data` points to `biases.len()` elements that stay alive for
        // the duration of the call; the count is checked to fit in an i32.
        let sampler = unsafe {
            llama_cpp_bindings_sys::llama_sampler_init_logit_bias(
                n_vocab,
                i32::try_from(biases.len()).expect("bias count exceeds i32::MAX"),
                data,
            )
        };

        Self { sampler }
    }
}
644
impl Drop for LlamaSampler {
    /// Frees the underlying `llama_sampler`.
    fn drop(&mut self) {
        // SAFETY: `self.sampler` was produced by a `llama_sampler_init_*`
        // constructor and this wrapper still owns it — samplers moved into a
        // chain are `mem::forget`-ten in `LlamaSampler::chain`, so their `Drop`
        // never runs and no double-free can occur here.
        unsafe {
            llama_cpp_bindings_sys::llama_sampler_free(self.sampler);
        }
    }
}