llama_cpp_2/sampling.rs
//! Safe wrapper around `llama_sampler`.

use std::borrow::Borrow;
use std::ffi::{c_char, CString};
use std::fmt::{Debug, Formatter};

use crate::context::LlamaContext;
use crate::model::LlamaModel;
use crate::token::data_array::LlamaTokenDataArray;
use crate::token::logit_bias::LlamaLogitBias;
use crate::token::LlamaToken;
use crate::{status_is_ok, status_to_i32, GrammarError, SamplerAcceptError};

/// A safe wrapper around `llama_sampler`.
pub struct LlamaSampler {
    pub(crate) sampler: *mut llama_cpp_sys_2::llama_sampler,
}

impl Debug for LlamaSampler {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LlamaSampler").finish_non_exhaustive()
    }
}

impl LlamaSampler {
    /// Samples and accepts a token from the idx-th output of the last evaluation.
    #[must_use]
    pub fn sample(&mut self, ctx: &LlamaContext, idx: i32) -> LlamaToken {
        let token = unsafe {
            llama_cpp_sys_2::llama_sampler_sample(self.sampler, ctx.context.as_ptr(), idx)
        };

        LlamaToken(token)
    }

    /// Applies this sampler to a [`LlamaTokenDataArray`].
    pub fn apply(&self, data_array: &mut LlamaTokenDataArray) {
        data_array.apply_sampler(self);
    }

    /// Accepts a token from the sampler, possibly updating the internal state of certain samplers
    /// (e.g. grammar, repetition, etc.).
    ///
    /// Any error reported by the underlying sampler is silently ignored; use
    /// [`Self::try_accept`] if you need to observe it.
    pub fn accept(&mut self, token: LlamaToken) {
        let _ = self.try_accept(token);
    }

    /// Accepts several tokens from the sampler or context, possibly updating the internal state
    /// of certain samplers (e.g. grammar, repetition, etc.).
    ///
    /// As with [`Self::accept`], any error reported by the underlying sampler is silently
    /// ignored.
    pub fn accept_many(&mut self, tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>) {
        for token in tokens {
            let _ = self.try_accept(*token.borrow());
        }
    }

    /// Accepts several tokens from the sampler or context, possibly updating the internal state
    /// of certain samplers (e.g. grammar, repetition, etc.), and returns the sampler.
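    ///
    /// # Example
    /// A minimal sketch; the token IDs are illustrative:
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    /// use llama_cpp_2::token::LlamaToken;
    ///
    /// // Prime a repetition penalty with tokens that are already in the context.
    /// let sampler = LlamaSampler::penalties(64, 1.1, 0.0, 0.0)
    ///     .with_tokens([LlamaToken(1), LlamaToken(2)]);
    /// ```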
    #[must_use]
    pub fn with_tokens(
        mut self,
        tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
    ) -> Self {
        self.accept_many(tokens);
        self
    }

    /// Tries to accept a token from the sampler, returning an error if the underlying sampler
    /// reports a failure (e.g. when a grammar sampler rejects the token).
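    ///
    /// # Example
    /// A minimal sketch; stateless samplers such as [`Self::greedy`] accept any token:
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    /// use llama_cpp_2::token::LlamaToken;
    ///
    /// let mut sampler = LlamaSampler::greedy();
    /// assert!(sampler.try_accept(LlamaToken(0)).is_ok());
    /// ```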
    pub fn try_accept(&mut self, token: LlamaToken) -> Result<(), SamplerAcceptError> {
        let sampler_result =
            unsafe { llama_cpp_sys_2::llama_rs_sampler_accept(self.sampler, token.0) };
        if status_is_ok(sampler_result) {
            Ok(())
        } else {
            Err(SamplerAcceptError::FfiError(status_to_i32(sampler_result)))
        }
    }

    /// Resets the internal state of the sampler.
    ///
    /// This can be useful when you want to start fresh with a sampler without creating a new
    /// instance.
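    ///
    /// # Example
    /// A minimal sketch; the token ID is illustrative:
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    /// use llama_cpp_2::token::LlamaToken;
    ///
    /// let mut sampler = LlamaSampler::penalties(64, 1.1, 0.0, 0.0);
    /// sampler.accept(LlamaToken(1));
    /// // Clear the accepted-token history so the sampler starts fresh.
    /// sampler.reset();
    /// ```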
    pub fn reset(&mut self) {
        unsafe {
            llama_cpp_sys_2::llama_sampler_reset(self.sampler);
        }
    }

    /// Gets the random seed used by this sampler.
    ///
    /// Returns:
    /// - For random samplers (dist, mirostat, mirostat_v2): their current seed
    /// - For sampler chains: the first non-default seed found in reverse order
    /// - For all other samplers: 0xFFFFFFFF
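    ///
    /// # Example
    /// A minimal sketch of the first and last cases above:
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// assert_eq!(LlamaSampler::dist(1234).get_seed(), 1234);
    /// assert_eq!(LlamaSampler::greedy().get_seed(), 0xFFFF_FFFF);
    /// ```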
    #[must_use]
    pub fn get_seed(&self) -> u32 {
        unsafe { llama_cpp_sys_2::llama_sampler_get_seed(self.sampler) }
    }

    /// Combines a list of samplers into a single sampler that applies each component sampler one
    /// after another.
    ///
    /// If you are using a chain to select a token, the chain should always end with one of
    /// [`LlamaSampler::greedy`], [`LlamaSampler::dist`], [`LlamaSampler::mirostat`], or
    /// [`LlamaSampler::mirostat_v2`].
    #[must_use]
    pub fn chain(samplers: impl IntoIterator<Item = Self>, no_perf: bool) -> Self {
        unsafe {
            let chain = llama_cpp_sys_2::llama_sampler_chain_init(
                llama_cpp_sys_2::llama_sampler_chain_params { no_perf },
            );

            for sampler in samplers {
                llama_cpp_sys_2::llama_sampler_chain_add(chain, sampler.sampler);

                // Do not call `llama_sampler_free` on the sampler, as the internal sampler is now
                // owned by the chain.
                std::mem::forget(sampler);
            }

            Self { sampler: chain }
        }
    }

    /// Same as [`Self::chain`] with `no_perf = false`.
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    /// use llama_cpp_2::llama_backend::LlamaBackend;
    /// let backend = LlamaBackend::init().unwrap();
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::chain_simple([
    ///     LlamaSampler::temp(0.5),
    ///     LlamaSampler::greedy(),
    /// ]));
    ///
    /// assert_eq!(data_array.data[0].logit(), 0.);
    /// assert_eq!(data_array.data[1].logit(), 2.);
    /// assert_eq!(data_array.data[2].logit(), 4.);
    ///
    /// assert_eq!(data_array.data.len(), 3);
    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(2)));
    /// ```
    #[must_use]
    pub fn chain_simple(samplers: impl IntoIterator<Item = Self>) -> Self {
        Self::chain(samplers, false)
    }

    #[allow(clippy::doc_markdown)]
    /// Updates the logits as l_i' = l_i/t. When t <= 0.0, the maximum logit is kept at its
    /// original value and the rest are set to -inf.
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::temp(0.5));
    ///
    /// assert_eq!(data_array.data[0].logit(), 0.);
    /// assert_eq!(data_array.data[1].logit(), 2.);
    /// assert_eq!(data_array.data[2].logit(), 4.);
    /// ```
    #[must_use]
    pub fn temp(t: f32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_temp(t) };
        Self { sampler }
    }

    /// Dynamic temperature implementation (a.k.a. entropy sampling) described in the paper
    /// <https://arxiv.org/abs/2309.02772>. With `delta <= 0.0` this reduces to plain temperature
    /// scaling, as in [`Self::temp`].
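    ///
    /// # Example
    /// A minimal sketch; the parameter values are illustrative:
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// // delta == 0.0 disables the dynamic range, leaving plain temperature scaling.
    /// let sampler = LlamaSampler::temp_ext(0.8, 0.0, 1.0);
    /// ```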
    #[must_use]
    pub fn temp_ext(t: f32, delta: f32, exponent: f32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_temp_ext(t, delta, exponent) };
        Self { sampler }
    }

    /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration"
    /// <https://arxiv.org/abs/1904.09751>
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    ///     LlamaTokenData::new(LlamaToken(3), 3., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::top_k(2));
    ///
    /// assert_eq!(data_array.data.len(), 2);
    /// assert_eq!(data_array.data[0].id(), LlamaToken(3));
    /// assert_eq!(data_array.data[1].id(), LlamaToken(2));
    /// ```
    #[must_use]
    pub fn top_k(k: i32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_k(k) };
        Self { sampler }
    }

    /// Top-nσ sampling as described in academic paper "Top-nσ: Not All Logits Are You Need"
    /// <https://arxiv.org/pdf/2411.07641>
    ///
    /// This method filters logits by selecting only those within *n* standard deviations of the
    /// mean.
    ///
    /// # Parameters
    /// - `n`: Number of standard deviations from the mean to include in sampling
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0.0, 0.0),
    ///     LlamaTokenData::new(LlamaToken(1), 1.0, 0.0),
    ///     LlamaTokenData::new(LlamaToken(2), 2.0, 0.0),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::top_n_sigma(2.0));
    /// ```
    #[must_use]
    pub fn top_n_sigma(n: f32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_n_sigma(n) };
        Self { sampler }
    }

    /// Locally Typical Sampling implementation described in the paper
    /// <https://arxiv.org/abs/2202.00666>.
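    ///
    /// # Example
    /// A minimal sketch; the logits are illustrative, and `min_keep` guarantees a lower bound
    /// on the number of surviving candidates:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::typical(0.9, 2));
    ///
    /// // At least `min_keep` candidates survive.
    /// assert!(data_array.data.len() >= 2);
    /// ```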
    #[must_use]
    pub fn typical(p: f32, min_keep: usize) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_typical(p, min_keep) };
        Self { sampler }
    }

    /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration"
    /// <https://arxiv.org/abs/1904.09751>
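    ///
    /// # Example
    /// A minimal sketch; the logits are illustrative. With these values the top token holds
    /// roughly 0.67 of the probability mass, so `p = 0.5` keeps only that token:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::top_p(0.5, 1));
    ///
    /// assert_eq!(data_array.data.len(), 1);
    /// assert_eq!(data_array.data[0].id(), LlamaToken(2));
    /// ```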
    #[must_use]
    pub fn top_p(p: f32, min_keep: usize) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_p(p, min_keep) };
        Self { sampler }
    }

    /// Minimum P sampling as described in <https://github.com/ggerganov/llama.cpp/pull/3841>
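    ///
    /// # Example
    /// A minimal sketch; the logits are illustrative. Tokens whose probability falls below
    /// `p` times the top token's probability are dropped:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::min_p(0.5, 1));
    ///
    /// // Only LlamaToken(2) is within 0.5x of the top token's probability.
    /// assert_eq!(data_array.data.len(), 1);
    /// ```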
    #[must_use]
    pub fn min_p(p: f32, min_keep: usize) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_min_p(p, min_keep) };
        Self { sampler }
    }

    /// XTC sampler as described in <https://github.com/oobabooga/text-generation-webui/pull/6335>.
    /// With probability `p`, removes every candidate above the probability threshold `t` except
    /// the least likely of them.
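    ///
    /// # Example
    /// A minimal sketch constructing an XTC sampler; the parameter values are illustrative:
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// // Apply the cut half the time, with a probability threshold of 0.1.
    /// let sampler = LlamaSampler::xtc(0.5, 0.1, 1, 1234);
    /// ```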
    #[must_use]
    pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_xtc(p, t, min_keep, seed) };
        Self { sampler }
    }

    /// Grammar sampler that constrains sampling to tokens permitted by a GBNF grammar.
    ///
    /// # Errors
    /// Returns a [`GrammarError`] if `grammar_root` does not occur in `grammar_str`, if either
    /// string contains null bytes, or if the underlying sampler could not be created.
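    ///
    /// # Example
    /// A minimal sketch (not compiled as a doctest); the model path is illustrative:
    /// ```ignore
    /// use llama_cpp_2::llama_backend::LlamaBackend;
    /// use llama_cpp_2::model::{params::LlamaModelParams, LlamaModel};
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "model.gguf", &LlamaModelParams::default())?;
    ///
    /// // Constrain generation to "yes" or "no" with a GBNF grammar.
    /// let sampler = LlamaSampler::grammar(&model, r#"root ::= "yes" | "no""#, "root")?;
    /// ```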
    pub fn grammar(
        model: &LlamaModel,
        grammar_str: &str,
        grammar_root: &str,
    ) -> Result<Self, GrammarError> {
        let (grammar_str, grammar_root) =
            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;

        let sampler = unsafe {
            llama_cpp_sys_2::llama_rs_sampler_init_grammar(
                model.vocab_ptr(),
                grammar_str.as_ptr(),
                grammar_root.as_ptr(),
            )
        };

        if sampler.is_null() {
            Err(GrammarError::NullGrammar)
        } else {
            Ok(Self { sampler })
        }
    }

    /// Lazy grammar sampler, introduced in <https://github.com/ggerganov/llama.cpp/pull/9639>
    ///
    /// This sampler enforces grammar rules only once a trigger word or trigger token is
    /// encountered.
    ///
    /// # Errors
    /// Returns a [`GrammarError`] if `grammar_root` does not occur in `grammar_str`, if any
    /// string contains null bytes, or if the underlying sampler could not be created.
    pub fn grammar_lazy(
        model: &LlamaModel,
        grammar_str: &str,
        grammar_root: &str,
        trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
        trigger_tokens: &[LlamaToken],
    ) -> Result<Self, GrammarError> {
        let (grammar_str, grammar_root) =
            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
        let trigger_words = Self::sanitize_trigger_words(trigger_words)?;

        let mut trigger_word_ptrs: Vec<*const c_char> =
            trigger_words.iter().map(|cs| cs.as_ptr()).collect();

        let sampler = unsafe {
            llama_cpp_sys_2::llama_rs_sampler_init_grammar_lazy(
                model.vocab_ptr(),
                grammar_str.as_ptr(),
                grammar_root.as_ptr(),
                trigger_word_ptrs.as_mut_ptr(),
                trigger_word_ptrs.len(),
                trigger_tokens.as_ptr().cast(),
                trigger_tokens.len(),
            )
        };

        if sampler.is_null() {
            Err(GrammarError::NullGrammar)
        } else {
            Ok(Self { sampler })
        }
    }

    /// Lazy grammar sampler using regex trigger patterns.
    ///
    /// Trigger patterns are regular expressions matched from the start of the generation
    /// output. The grammar sampler will be fed content starting from the first match group.
    ///
    /// # Errors
    /// Returns a [`GrammarError`] if `grammar_root` does not occur in `grammar_str`, if any
    /// string contains null bytes, or if the underlying sampler could not be created.
    pub fn grammar_lazy_patterns(
        model: &LlamaModel,
        grammar_str: &str,
        grammar_root: &str,
        trigger_patterns: &[String],
        trigger_tokens: &[LlamaToken],
    ) -> Result<Self, GrammarError> {
        let (grammar_str, grammar_root) =
            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
        let trigger_patterns = Self::sanitize_trigger_patterns(trigger_patterns)?;

        let mut trigger_pattern_ptrs: Vec<*const c_char> =
            trigger_patterns.iter().map(|cs| cs.as_ptr()).collect();

        let sampler = unsafe {
            llama_cpp_sys_2::llama_rs_sampler_init_grammar_lazy_patterns(
                model.vocab_ptr(),
                grammar_str.as_ptr(),
                grammar_root.as_ptr(),
                trigger_pattern_ptrs.as_mut_ptr(),
                trigger_pattern_ptrs.len(),
                trigger_tokens.as_ptr().cast(),
                trigger_tokens.len(),
            )
        };

        if sampler.is_null() {
            Err(GrammarError::NullGrammar)
        } else {
            Ok(Self { sampler })
        }
    }

    fn sanitize_grammar_strings(
        grammar_str: &str,
        grammar_root: &str,
    ) -> Result<(CString, CString), GrammarError> {
        // Heuristic check: the root rule name must at least appear in the grammar text.
        if !grammar_str.contains(grammar_root) {
            return Err(GrammarError::RootNotFound);
        }

        if grammar_str.contains('\0') || grammar_root.contains('\0') {
            return Err(GrammarError::GrammarNullBytes);
        }

        // The null-byte checks above guarantee these conversions cannot fail.
        Ok((
            CString::new(grammar_str).unwrap(),
            CString::new(grammar_root).unwrap(),
        ))
    }

    fn sanitize_trigger_words(
        trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
    ) -> Result<Vec<CString>, GrammarError> {
        let trigger_words: Vec<_> = trigger_words.into_iter().collect();
        if trigger_words
            .iter()
            .any(|word| word.as_ref().contains(&b'\0'))
        {
            return Err(GrammarError::TriggerWordNullBytes);
        }
        Ok(trigger_words
            .into_iter()
            .map(|word| CString::new(word.as_ref()).unwrap())
            .collect())
    }

    fn sanitize_trigger_patterns(
        trigger_patterns: &[String],
    ) -> Result<Vec<CString>, GrammarError> {
        let mut patterns = Vec::with_capacity(trigger_patterns.len());
        for pattern in trigger_patterns {
            if pattern.contains('\0') {
                return Err(GrammarError::GrammarNullBytes);
            }
            patterns.push(CString::new(pattern.as_str()).unwrap());
        }
        Ok(patterns)
    }

    /// DRY sampler, designed by p-e-w, as described in
    /// <https://github.com/oobabooga/text-generation-webui/pull/5677>, porting the Koboldcpp
    /// implementation authored by pi6am: <https://github.com/LostRuins/koboldcpp/pull/982>
    ///
    /// # Panics
    /// If any string in ``seq_breakers`` contains null bytes, or if the model's
    /// ``n_ctx_train`` exceeds `i32::MAX`.
    #[allow(missing_docs)]
    #[must_use]
    pub fn dry(
        model: &LlamaModel,
        multiplier: f32,
        base: f32,
        allowed_length: i32,
        penalty_last_n: i32,
        seq_breakers: impl IntoIterator<Item = impl AsRef<[u8]>>,
    ) -> Self {
        let seq_breakers: Vec<CString> = seq_breakers
            .into_iter()
            .map(|s| CString::new(s.as_ref()).expect("A sequence breaker contains null bytes"))
            .collect();
        let mut seq_breaker_pointers: Vec<*const c_char> =
            seq_breakers.iter().map(|s| s.as_ptr()).collect();

        let sampler = unsafe {
            llama_cpp_sys_2::llama_sampler_init_dry(
                model.vocab_ptr(),
                model
                    .n_ctx_train()
                    .try_into()
                    .expect("n_ctx_train exceeds i32::MAX"),
                multiplier,
                base,
                allowed_length,
                penalty_last_n,
                seq_breaker_pointers.as_mut_ptr(),
                seq_breaker_pointers.len(),
            )
        };
        Self { sampler }
    }

    /// Penalizes tokens for being present in the context.
    ///
    /// Parameters:
    /// - ``penalty_last_n``: last n tokens to penalize (0 = disable penalty, -1 = context size)
    /// - ``penalty_repeat``: 1.0 = disabled
    /// - ``penalty_freq``: 0.0 = disabled
    /// - ``penalty_present``: 0.0 = disabled
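    ///
    /// # Example
    /// A minimal sketch; token IDs and logits are illustrative:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    /// ], false);
    ///
    /// let mut sampler = LlamaSampler::penalties(64, 1.1, 0.0, 0.0);
    /// // Tell the sampler that token 1 is already present in the context.
    /// sampler.accept(LlamaToken(1));
    /// data_array.apply_sampler(&mut sampler);
    ///
    /// // The repeated token's positive logit was divided by `penalty_repeat`.
    /// assert!(data_array.data[1].logit() < 1.0);
    /// ```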
    #[allow(clippy::too_many_arguments)]
    #[must_use]
    pub fn penalties(
        penalty_last_n: i32,
        penalty_repeat: f32,
        penalty_freq: f32,
        penalty_present: f32,
    ) -> Self {
        let sampler = unsafe {
            llama_cpp_sys_2::llama_sampler_init_penalties(
                penalty_last_n,
                penalty_repeat,
                penalty_freq,
                penalty_present,
            )
        };
        Self { sampler }
    }

    /// Mirostat 1.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses
    /// tokens instead of words.
    ///
    /// # Parameters:
    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
    /// - ``seed``: Seed to initialize random generation with.
    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
    ///   generated text. A higher value corresponds to more surprising or less predictable text,
    ///   while a lower value corresponds to less surprising or more predictable text.
    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
    ///   updated more quickly, while a smaller learning rate will result in slower updates.
    /// - ``m``: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary
    ///   value that is used to calculate `s_hat`, which in turn helps to calculate the value of
    ///   `k`. In the paper, they use `m = 100`, but you can experiment with different values to
    ///   see how it affects the performance of the algorithm.
    #[must_use]
    pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self {
        let sampler =
            unsafe { llama_cpp_sys_2::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m) };
        Self { sampler }
    }

    /// Mirostat 2.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses
    /// tokens instead of words.
    ///
    /// # Parameters:
    /// - ``seed``: Seed to initialize random generation with.
    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
    ///   generated text. A higher value corresponds to more surprising or less predictable text,
    ///   while a lower value corresponds to less surprising or more predictable text.
    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
    ///   updated more quickly, while a smaller learning rate will result in slower updates.
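    ///
    /// # Example
    /// A minimal sketch of a chain ending in a mirostat_v2 selector; the parameter values are
    /// illustrative:
    /// ```rust
    /// use llama_cpp_2::sampling::LlamaSampler;
    /// use llama_cpp_2::llama_backend::LlamaBackend;
    /// let backend = LlamaBackend::init().unwrap();
    ///
    /// let chain = LlamaSampler::chain_simple([
    ///     LlamaSampler::temp(0.8),
    ///     LlamaSampler::mirostat_v2(1234, 5.0, 0.1),
    /// ]);
    /// ```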
    #[must_use]
    pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_mirostat_v2(seed, tau, eta) };
        Self { sampler }
    }

    /// Selects a token at random based on each token's probabilities.
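    ///
    /// # Example
    /// A minimal sketch; the logits are illustrative, and which token gets selected depends on
    /// the seed:
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::dist(1234));
    ///
    /// assert!(data_array.selected_token().is_some());
    /// ```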
    #[must_use]
    pub fn dist(seed: u32) -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_dist(seed) };
        Self { sampler }
    }

    /// Selects the most likely token.
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{
    ///     LlamaToken,
    ///     data::LlamaTokenData,
    ///     data_array::LlamaTokenDataArray
    /// };
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let mut data_array = LlamaTokenDataArray::new(vec![
    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
    /// ], false);
    ///
    /// data_array.apply_sampler(&mut LlamaSampler::greedy());
    ///
    /// assert_eq!(data_array.data.len(), 2);
    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(1)));
    /// ```
    #[must_use]
    pub fn greedy() -> Self {
        let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_greedy() };
        Self { sampler }
    }

    /// Creates a sampler that applies bias values to specific tokens during sampling.
    ///
    /// # Parameters
    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
    /// - ``biases``: Slice of [`LlamaLogitBias`] values specifying token-bias pairs
    ///
    /// # Example
    /// ```rust
    /// use llama_cpp_2::token::{LlamaToken, logit_bias::LlamaLogitBias};
    /// use llama_cpp_2::sampling::LlamaSampler;
    ///
    /// let biases = vec![
    ///     LlamaLogitBias::new(LlamaToken(1), 1.5),  // Increase probability of token 1
    ///     LlamaLogitBias::new(LlamaToken(2), -1.0), // Decrease probability of token 2
    /// ];
    ///
    /// // Assuming a vocabulary size of 32000
    /// let sampler = LlamaSampler::logit_bias(32000, &biases);
    /// ```
    #[must_use]
    pub fn logit_bias(n_vocab: i32, biases: &[LlamaLogitBias]) -> Self {
        let data = biases.as_ptr().cast::<llama_cpp_sys_2::llama_logit_bias>();

        let sampler = unsafe {
            llama_cpp_sys_2::llama_sampler_init_logit_bias(n_vocab, biases.len() as i32, data)
        };

        Self { sampler }
    }
}

impl Drop for LlamaSampler {
    fn drop(&mut self) {
        unsafe {
            llama_cpp_sys_2::llama_sampler_free(self.sampler);
        }
    }
}