1use std::borrow::Borrow;
2use std::ffi::{CString, c_char};
3use std::fmt::{Debug, Formatter};
4
5use crate::context::LlamaContext;
6use crate::ffi_error_reader::read_and_free_cpp_error;
7use crate::model::LlamaModel;
8use crate::token::LlamaToken;
9use crate::token::data_array::LlamaTokenDataArray;
10use crate::token::logit_bias::LlamaLogitBias;
11use crate::{GrammarError, SampleError, SamplerAcceptError, SamplingError};
12
13fn check_sampler_accept_status(
14 status: llama_cpp_bindings_sys::llama_rs_sampler_accept_status,
15 error_ptr: *mut c_char,
16) -> Result<(), SamplerAcceptError> {
17 match status {
18 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_ACCEPT_OK => Ok(()),
19 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_ACCEPT_ERROR_STRING_ALLOCATION_FAILED => {
20 Err(SamplerAcceptError::NotEnoughMemory)
21 }
22 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_ACCEPT_VENDORED_THREW_CXX_EXCEPTION => {
23 let message = unsafe { read_and_free_cpp_error(error_ptr) };
24 Err(SamplerAcceptError::GrammarStateCorrupted { message })
25 }
26 other => unreachable!("llama_rs_sampler_accept returned unrecognized status {other}"),
27 }
28}
29
30fn checked_u32_as_i32(value: u32) -> Result<i32, GrammarError> {
31 i32::try_from(value).map_err(|convert_error| {
32 GrammarError::IntegerOverflow(format!("value exceeds i32::MAX: {convert_error}"))
33 })
34}
35
36fn checked_usize_as_i32_sampling(value: usize) -> Result<i32, SamplingError> {
37 i32::try_from(value).map_err(|convert_error| {
38 SamplingError::IntegerOverflow(format!("value exceeds i32::MAX: {convert_error}"))
39 })
40}
41
/// Owning handle to a native `llama_sampler`.
///
/// The wrapped pointer is released with `llama_sampler_free` when the value is dropped
/// (see the `Drop` implementation in this file).
pub struct LlamaSampler {
    /// Raw pointer to the native sampler object.
    ///
    /// NOTE(review): the field is public, so callers can read or replace the pointer;
    /// whatever is stored here at drop time is what gets freed.
    pub sampler: *mut llama_cpp_bindings_sys::llama_sampler,
}
45
46impl Debug for LlamaSampler {
47 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
48 f.debug_struct("LlamaSamplerChain").finish()
49 }
50}
51
52impl LlamaSampler {
53 pub fn sample(&mut self, ctx: &LlamaContext, idx: i32) -> Result<LlamaToken, SampleError> {
57 let mut token: i32 = -1;
58 let mut error_ptr: *mut c_char = std::ptr::null_mut();
59
60 let status = unsafe {
61 llama_cpp_bindings_sys::llama_rs_sampler_sample(
62 self.sampler,
63 ctx.context.as_ptr(),
64 idx,
65 &raw mut token,
66 &raw mut error_ptr,
67 )
68 };
69
70 match status {
71 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_SAMPLE_OK => Ok(LlamaToken(token)),
72 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_SAMPLE_ERROR_STRING_ALLOCATION_FAILED => {
73 Err(SampleError::NotEnoughMemory)
74 }
75 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_SAMPLE_VENDORED_THREW_CXX_EXCEPTION => {
76 let message = unsafe { read_and_free_cpp_error(error_ptr) };
77 Err(SampleError::Reported { message })
78 }
79 other => unreachable!("llama_rs_sampler_sample returned unrecognized status {other}"),
80 }
81 }
82
83 pub fn apply(&self, data_array: &mut LlamaTokenDataArray) {
84 data_array.apply_sampler(self);
85 }
86
87 pub fn accept(&mut self, token: LlamaToken) -> Result<(), SamplerAcceptError> {
90 self.try_accept(token)
91 }
92
93 pub fn accept_many(
96 &mut self,
97 tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
98 ) -> Result<(), SamplerAcceptError> {
99 for token in tokens {
100 self.try_accept(*token.borrow())?;
101 }
102
103 Ok(())
104 }
105
106 pub fn with_tokens(
109 mut self,
110 tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>,
111 ) -> Result<Self, SamplerAcceptError> {
112 self.accept_many(tokens)?;
113
114 Ok(self)
115 }
116
117 pub fn try_accept(&mut self, token: LlamaToken) -> Result<(), SamplerAcceptError> {
120 let mut error_ptr: *mut c_char = std::ptr::null_mut();
121
122 let status = unsafe {
123 llama_cpp_bindings_sys::llama_rs_sampler_accept(
124 self.sampler,
125 token.0,
126 &raw mut error_ptr,
127 )
128 };
129
130 check_sampler_accept_status(status, error_ptr)
131 }
132
133 pub fn reset(&mut self) {
134 unsafe {
135 llama_cpp_bindings_sys::llama_sampler_reset(self.sampler);
136 }
137 }
138
139 #[must_use]
140 pub fn get_seed(&self) -> u32 {
141 unsafe { llama_cpp_bindings_sys::llama_sampler_get_seed(self.sampler) }
142 }
143
144 #[must_use]
145 pub fn chain(samplers: impl IntoIterator<Item = Self>, no_perf: bool) -> Self {
146 unsafe {
147 let chain = llama_cpp_bindings_sys::llama_sampler_chain_init(
148 llama_cpp_bindings_sys::llama_sampler_chain_params { no_perf },
149 );
150
151 for sampler in samplers {
152 llama_cpp_bindings_sys::llama_sampler_chain_add(chain, sampler.sampler);
153 std::mem::forget(sampler);
154 }
155
156 Self { sampler: chain }
157 }
158 }
159
160 #[must_use]
161 pub fn chain_simple(samplers: impl IntoIterator<Item = Self>) -> Self {
162 Self::chain(samplers, false)
163 }
164
165 #[must_use]
166 pub fn temp(t: f32) -> Self {
167 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_temp(t) };
168 Self { sampler }
169 }
170
171 #[must_use]
172 pub fn temp_ext(t: f32, delta: f32, exponent: f32) -> Self {
173 let sampler =
174 unsafe { llama_cpp_bindings_sys::llama_sampler_init_temp_ext(t, delta, exponent) };
175 Self { sampler }
176 }
177
178 #[must_use]
179 pub fn top_k(k: i32) -> Self {
180 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_k(k) };
181 Self { sampler }
182 }
183
184 #[must_use]
185 pub fn top_n_sigma(n: f32) -> Self {
186 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_n_sigma(n) };
187 Self { sampler }
188 }
189
190 #[must_use]
191 pub fn typical(p: f32, min_keep: usize) -> Self {
192 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_typical(p, min_keep) };
193 Self { sampler }
194 }
195
196 #[must_use]
197 pub fn top_p(p: f32, min_keep: usize) -> Self {
198 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_top_p(p, min_keep) };
199 Self { sampler }
200 }
201
202 #[must_use]
203 pub fn min_p(p: f32, min_keep: usize) -> Self {
204 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_min_p(p, min_keep) };
205 Self { sampler }
206 }
207
208 #[must_use]
209 pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Self {
210 let sampler =
211 unsafe { llama_cpp_bindings_sys::llama_sampler_init_xtc(p, t, min_keep, seed) };
212 Self { sampler }
213 }
214
215 pub fn grammar(
218 model: &LlamaModel,
219 grammar_str: &str,
220 grammar_root: &str,
221 ) -> Result<Self, GrammarError> {
222 let (grammar_str, grammar_root) =
223 Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
224 let mut sampler: *mut llama_cpp_bindings_sys::llama_sampler = std::ptr::null_mut();
225 let mut error_ptr: *mut c_char = std::ptr::null_mut();
226
227 let status = unsafe {
228 llama_cpp_bindings_sys::llama_rs_sampler_init_grammar(
229 model.vocab_ptr(),
230 grammar_str.as_ptr(),
231 grammar_root.as_ptr(),
232 &raw mut sampler,
233 &raw mut error_ptr,
234 )
235 };
236
237 match status {
238 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_OK => {
239 Ok(Self { sampler })
240 }
241 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_VENDORED_RETURNED_NULL => {
242 Err(GrammarError::GrammarMalformed)
243 }
244 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_ERROR_STRING_ALLOCATION_FAILED => {
245 Err(GrammarError::NotEnoughMemory)
246 }
247 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_VENDORED_THREW_CXX_EXCEPTION => {
248 let message = unsafe { read_and_free_cpp_error(error_ptr) };
249 Err(GrammarError::Reported { message })
250 }
251 other => unreachable!(
252 "llama_rs_sampler_init_grammar returned unrecognized status {other}"
253 ),
254 }
255 }
256
257 pub fn grammar_lazy(
260 model: &LlamaModel,
261 grammar_str: &str,
262 grammar_root: &str,
263 trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
264 trigger_tokens: &[LlamaToken],
265 ) -> Result<Self, GrammarError> {
266 let (grammar_str, grammar_root) =
267 Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
268 let trigger_words = Self::sanitize_trigger_words(trigger_words)?;
269 let mut sampler: *mut llama_cpp_bindings_sys::llama_sampler = std::ptr::null_mut();
270 let mut error_ptr: *mut c_char = std::ptr::null_mut();
271
272 let mut trigger_word_ptrs: Vec<*const c_char> =
273 trigger_words.iter().map(|cs| cs.as_ptr()).collect();
274
275 let status = unsafe {
276 llama_cpp_bindings_sys::llama_rs_sampler_init_grammar_lazy(
277 model.vocab_ptr(),
278 grammar_str.as_ptr(),
279 grammar_root.as_ptr(),
280 trigger_word_ptrs.as_mut_ptr(),
281 trigger_word_ptrs.len(),
282 trigger_tokens.as_ptr().cast(),
283 trigger_tokens.len(),
284 &raw mut sampler,
285 &raw mut error_ptr,
286 )
287 };
288
289 match status {
290 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_OK => {
291 Ok(Self { sampler })
292 }
293 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_VENDORED_RETURNED_NULL => {
294 Err(GrammarError::LazyGrammarMalformed)
295 }
296 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_ERROR_STRING_ALLOCATION_FAILED => {
297 Err(GrammarError::NotEnoughMemory)
298 }
299 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_VENDORED_THREW_CXX_EXCEPTION => {
300 let message = unsafe { read_and_free_cpp_error(error_ptr) };
301 Err(GrammarError::Reported { message })
302 }
303 other => unreachable!(
304 "llama_rs_sampler_init_grammar_lazy returned unrecognized status {other}"
305 ),
306 }
307 }
308
309 pub fn grammar_lazy_patterns(
312 model: &LlamaModel,
313 grammar_str: &str,
314 grammar_root: &str,
315 trigger_patterns: &[String],
316 trigger_tokens: &[LlamaToken],
317 ) -> Result<Self, GrammarError> {
318 let (grammar_str, grammar_root) =
319 Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
320 let trigger_patterns = Self::sanitize_trigger_patterns(trigger_patterns)?;
321 let mut sampler: *mut llama_cpp_bindings_sys::llama_sampler = std::ptr::null_mut();
322 let mut error_ptr: *mut c_char = std::ptr::null_mut();
323
324 let mut trigger_pattern_ptrs: Vec<*const c_char> =
325 trigger_patterns.iter().map(|cs| cs.as_ptr()).collect();
326
327 let status = unsafe {
328 llama_cpp_bindings_sys::llama_rs_sampler_init_grammar_lazy_patterns(
329 model.vocab_ptr(),
330 grammar_str.as_ptr(),
331 grammar_root.as_ptr(),
332 trigger_pattern_ptrs.as_mut_ptr(),
333 trigger_pattern_ptrs.len(),
334 trigger_tokens.as_ptr().cast(),
335 trigger_tokens.len(),
336 &raw mut sampler,
337 &raw mut error_ptr,
338 )
339 };
340
341 match status {
342 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_PATTERNS_OK => {
343 Ok(Self { sampler })
344 }
345 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_PATTERNS_VENDORED_RETURNED_NULL => {
346 Err(GrammarError::LazyPatternsGrammarMalformed)
347 }
348 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_PATTERNS_ERROR_STRING_ALLOCATION_FAILED => {
349 Err(GrammarError::NotEnoughMemory)
350 }
351 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_PATTERNS_INVALID_TRIGGER_PATTERN => {
352 let message = unsafe { read_and_free_cpp_error(error_ptr) };
353 Err(GrammarError::InvalidTriggerPattern { message })
354 }
355 llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_INIT_GRAMMAR_LAZY_PATTERNS_VENDORED_THREW_CXX_EXCEPTION => {
356 let message = unsafe { read_and_free_cpp_error(error_ptr) };
357 Err(GrammarError::Reported { message })
358 }
359 other => unreachable!(
360 "llama_rs_sampler_init_grammar_lazy_patterns returned unrecognized status {other}"
361 ),
362 }
363 }
364
365 pub fn llguidance(
369 model: &LlamaModel,
370 grammar_kind: &str,
371 grammar_data: &str,
372 ) -> Result<Self, GrammarError> {
373 crate::llguidance_sampler::create_llg_sampler(model, grammar_kind, grammar_data)
374 }
375
376 fn sanitize_grammar_strings(
377 grammar_str: &str,
378 grammar_root: &str,
379 ) -> Result<(CString, CString), GrammarError> {
380 if !grammar_str.contains(grammar_root) {
381 return Err(GrammarError::RootNotFound);
382 }
383
384 let grammar = CString::new(grammar_str).map_err(GrammarError::GrammarNullBytes)?;
385 let root = CString::new(grammar_root).map_err(GrammarError::GrammarNullBytes)?;
386
387 Ok((grammar, root))
388 }
389
390 fn sanitize_trigger_words(
391 trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
392 ) -> Result<Vec<CString>, GrammarError> {
393 trigger_words
394 .into_iter()
395 .map(|word| CString::new(word.as_ref()).map_err(GrammarError::TriggerWordNullBytes))
396 .collect()
397 }
398
399 fn sanitize_trigger_patterns(
400 trigger_patterns: &[String],
401 ) -> Result<Vec<CString>, GrammarError> {
402 trigger_patterns
403 .iter()
404 .map(|pattern| CString::new(pattern.as_str()).map_err(GrammarError::GrammarNullBytes))
405 .collect()
406 }
407
408 pub fn dry(
411 model: &LlamaModel,
412 multiplier: f32,
413 base: f32,
414 allowed_length: i32,
415 penalty_last_n: i32,
416 seq_breakers: impl IntoIterator<Item = impl AsRef<[u8]>>,
417 ) -> Result<Self, GrammarError> {
418 let seq_breakers: Vec<CString> = seq_breakers
419 .into_iter()
420 .map(|seq_breaker| CString::new(seq_breaker.as_ref()))
421 .collect::<Result<Vec<_>, _>>()?;
422 let mut seq_breaker_pointers: Vec<*const c_char> = seq_breakers
423 .iter()
424 .map(|seq_breaker| seq_breaker.as_ptr())
425 .collect();
426
427 let n_ctx_train_value = model.n_ctx_train().map_err(|convert_error| {
428 GrammarError::IntegerOverflow(format!(
429 "n_ctx_train does not fit into u32: {convert_error}"
430 ))
431 })?;
432 let n_ctx_train = checked_u32_as_i32(n_ctx_train_value)?;
433 let sampler = unsafe {
434 llama_cpp_bindings_sys::llama_sampler_init_dry(
435 model.vocab_ptr(),
436 n_ctx_train,
437 multiplier,
438 base,
439 allowed_length,
440 penalty_last_n,
441 seq_breaker_pointers.as_mut_ptr(),
442 seq_breaker_pointers.len(),
443 )
444 };
445
446 Ok(Self { sampler })
447 }
448
449 #[must_use]
450 pub fn penalties(
451 penalty_last_n: i32,
452 penalty_repeat: f32,
453 penalty_freq: f32,
454 penalty_present: f32,
455 ) -> Self {
456 let sampler = unsafe {
457 llama_cpp_bindings_sys::llama_sampler_init_penalties(
458 penalty_last_n,
459 penalty_repeat,
460 penalty_freq,
461 penalty_present,
462 )
463 };
464 Self { sampler }
465 }
466
467 #[must_use]
468 pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self {
469 let sampler = unsafe {
470 llama_cpp_bindings_sys::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m)
471 };
472 Self { sampler }
473 }
474
475 #[must_use]
476 pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self {
477 let sampler =
478 unsafe { llama_cpp_bindings_sys::llama_sampler_init_mirostat_v2(seed, tau, eta) };
479 Self { sampler }
480 }
481
482 #[must_use]
483 pub fn dist(seed: u32) -> Self {
484 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_dist(seed) };
485 Self { sampler }
486 }
487
488 #[must_use]
489 pub fn greedy() -> Self {
490 let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_greedy() };
491 Self { sampler }
492 }
493
494 pub fn logit_bias(n_vocab: i32, biases: &[LlamaLogitBias]) -> Result<Self, SamplingError> {
498 let bias_count = checked_usize_as_i32_sampling(biases.len())?;
499 let data = biases
500 .as_ptr()
501 .cast::<llama_cpp_bindings_sys::llama_logit_bias>();
502
503 let sampler = unsafe {
504 llama_cpp_bindings_sys::llama_sampler_init_logit_bias(n_vocab, bias_count, data)
505 };
506
507 Ok(Self { sampler })
508 }
509}
510
511impl Drop for LlamaSampler {
512 fn drop(&mut self) {
513 unsafe {
514 llama_cpp_bindings_sys::llama_sampler_free(self.sampler);
515 }
516 }
517}
518
#[cfg(test)]
mod tests {
    //! Unit tests for the sampler wrapper: string sanitising, token acceptance,
    //! constructor smoke tests and status-code mapping.

    use super::LlamaSampler;
    use crate::token::LlamaToken;
    use crate::token::data::LlamaTokenData;
    use crate::token::data_array::LlamaTokenDataArray;
    use crate::token::logit_bias::LlamaLogitBias;
    use crate::{GrammarError, SamplerAcceptError};

    /// Builds the small penalties + greedy chain shared by the acceptance tests.
    fn penalties_then_greedy() -> LlamaSampler {
        LlamaSampler::chain_simple([
            LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
            LlamaSampler::greedy(),
        ])
    }

    #[test]
    fn sanitize_grammar_strings_valid() {
        assert!(LlamaSampler::sanitize_grammar_strings("root ::= \"hello\"", "root").is_ok());
    }

    #[test]
    fn sanitize_grammar_strings_root_not_found() {
        let err = LlamaSampler::sanitize_grammar_strings("expr ::= \"hello\"", "root").unwrap_err();

        assert!(matches!(err, GrammarError::RootNotFound));
    }

    #[test]
    fn sanitize_grammar_strings_null_byte_in_grammar() {
        let err = LlamaSampler::sanitize_grammar_strings("root ::= \"\0\"", "root").unwrap_err();

        assert!(matches!(err, GrammarError::GrammarNullBytes(..)));
    }

    #[test]
    fn sanitize_grammar_strings_null_byte_in_root() {
        let err =
            LlamaSampler::sanitize_grammar_strings("ro\0ot ::= \"hello\"", "ro\0ot").unwrap_err();

        assert!(matches!(err, GrammarError::GrammarNullBytes(..)));
    }

    #[test]
    fn sanitize_trigger_words_valid() {
        let words: Vec<&[u8]> = vec![b"hello", b"world"];
        let sanitized = LlamaSampler::sanitize_trigger_words(words).expect("valid trigger words");

        assert_eq!(sanitized.len(), 2);
    }

    #[test]
    fn sanitize_trigger_words_empty_list() {
        let words: Vec<&[u8]> = vec![];
        let sanitized = LlamaSampler::sanitize_trigger_words(words).expect("valid trigger words");

        assert!(sanitized.is_empty());
    }

    #[test]
    fn sanitize_trigger_words_null_byte() {
        let words: Vec<&[u8]> = vec![b"hel\0lo"];
        let err = LlamaSampler::sanitize_trigger_words(words).unwrap_err();

        assert!(matches!(err, GrammarError::TriggerWordNullBytes(..)));
    }

    #[test]
    fn sanitize_trigger_patterns_valid() {
        let patterns = vec!["^hello$".to_string(), "world.*".to_string()];
        let sanitized =
            LlamaSampler::sanitize_trigger_patterns(&patterns).expect("valid trigger patterns");

        assert_eq!(sanitized.len(), 2);
    }

    #[test]
    fn sanitize_trigger_patterns_empty_list() {
        let patterns: Vec<String> = vec![];
        let sanitized =
            LlamaSampler::sanitize_trigger_patterns(&patterns).expect("valid trigger patterns");

        assert!(sanitized.is_empty());
    }

    #[test]
    fn sanitize_trigger_patterns_null_byte() {
        let patterns = vec!["hel\0lo".to_string()];
        let err = LlamaSampler::sanitize_trigger_patterns(&patterns).unwrap_err();

        assert!(matches!(err, GrammarError::GrammarNullBytes(..)));
    }

    #[test]
    fn apply_modifies_data_array() {
        let sampler = LlamaSampler::greedy();
        let candidates = vec![
            LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
            LlamaTokenData::new(LlamaToken::new(1), 5.0, 0.0),
        ];
        let mut data_array = LlamaTokenDataArray::new(candidates, false);

        sampler.apply(&mut data_array);

        assert_eq!(data_array.selected_token(), Some(LlamaToken::new(1)));
    }

    #[test]
    fn accept_succeeds() {
        let mut sampler = penalties_then_greedy();

        sampler
            .accept(LlamaToken::new(1))
            .expect("test: accept should succeed");
    }

    #[test]
    fn try_accept_succeeds_on_penalties_sampler() {
        let mut sampler = penalties_then_greedy();

        assert!(sampler.try_accept(LlamaToken::new(42)).is_ok());
    }

    #[test]
    fn accept_many_multiple_tokens() {
        let mut sampler = penalties_then_greedy();

        sampler
            .accept_many([LlamaToken::new(1), LlamaToken::new(2), LlamaToken::new(3)])
            .expect("test: accept_many should succeed");
    }

    #[test]
    fn with_tokens_builder_pattern() {
        let _sampler = penalties_then_greedy()
            .with_tokens([LlamaToken::new(10), LlamaToken::new(20)])
            .expect("test: with_tokens should succeed");
    }

    #[test]
    fn all_sampler_constructors() {
        let _temp = LlamaSampler::temp(0.8);
        let _temp_ext = LlamaSampler::temp_ext(0.8, 0.1, 1.0);
        let _top_k = LlamaSampler::top_k(40);
        let _top_n_sigma = LlamaSampler::top_n_sigma(2.0);
        let _top_p = LlamaSampler::top_p(0.9, 1);
        let _min_p = LlamaSampler::min_p(0.05, 1);
        let _typical = LlamaSampler::typical(0.9, 1);
        let _xtc = LlamaSampler::xtc(0.1, 0.5, 1, 42);
        let _dist = LlamaSampler::dist(42);
        let _mirostat = LlamaSampler::mirostat(32000, 42, 5.0, 0.1, 100);
        let _mirostat_v2 = LlamaSampler::mirostat_v2(42, 5.0, 0.1);
        let biases = vec![LlamaLogitBias::new(LlamaToken::new(0), -100.0)];
        // The result is asserted: silently binding the `Result` would let a failing
        // constructor pass this smoke test.
        assert!(LlamaSampler::logit_bias(32000, &biases).is_ok());
        let _chain = LlamaSampler::chain([LlamaSampler::greedy()], true);
    }

    #[test]
    fn reset_and_get_seed() {
        let mut sampler = LlamaSampler::dist(42);

        sampler.reset();
        let _seed = sampler.get_seed();
    }

    #[test]
    fn debug_formatting() {
        let sampler = LlamaSampler::greedy();
        let debug_output = format!("{sampler:?}");

        assert!(debug_output.contains("LlamaSampler"));
    }

    #[test]
    fn checked_u32_as_i32_overflow() {
        assert!(super::checked_u32_as_i32(u32::MAX).is_err());
    }

    #[test]
    fn checked_usize_as_i32_sampling_overflow() {
        assert!(super::checked_usize_as_i32_sampling(usize::MAX).is_err());
    }

    #[test]
    fn check_sampler_accept_status_ok() {
        let result = super::check_sampler_accept_status(
            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_ACCEPT_OK,
            std::ptr::null_mut(),
        );

        assert!(result.is_ok());
    }

    #[test]
    fn check_sampler_accept_status_exception_maps_to_typed_variant() {
        let err = super::check_sampler_accept_status(
            llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_ACCEPT_VENDORED_THREW_CXX_EXCEPTION,
            std::ptr::null_mut(),
        )
        .unwrap_err();

        assert!(matches!(
            err,
            SamplerAcceptError::GrammarStateCorrupted { .. }
        ));
    }
}