1pub mod advanced;
7pub mod chain;
8pub mod grammar;
9
10use std::sync::Arc;
11
12use serde::{Deserialize, Serialize};
13
14use grammar::{apply_grammar_mask, Grammar, GrammarState};
15
/// Parameters controlling how [`Sampler`] and the free function [`sample`]
/// pick a token from a logits vector.
///
/// Fields without a `#[serde(default…)]` attribute must be present when
/// deserialising. `grammar` and `token_vocab` are `#[serde(skip)]`: they are
/// never serialised and are `None` after deserialisation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamplerConfig {
    /// Logits are divided by this before the softmax. In the standard pipeline
    /// a value `<= 0.0` means greedy (argmax); Mirostat v2 instead leaves the
    /// logits unscaled when it is `<= 0.0`.
    pub temperature: f32,
    /// Keep only the `top_k` highest-scoring tokens. `0` disables the cut and
    /// `1` short-circuits to greedy decoding.
    pub top_k: usize,
    /// Nucleus (top-p) cut-off: keep the smallest prefix of the sorted
    /// distribution whose cumulative probability reaches `top_p`.
    /// Values `>= 1.0` disable it.
    pub top_p: f32,
    /// Drop tokens whose probability is below `min_p` times the most likely
    /// token's probability. `0.0` disables it.
    pub min_p: f32,
    /// Penalty for recently seen tokens: positive logits are divided by it,
    /// non-positive logits are multiplied by it. `1.0` disables it.
    pub repetition_penalty: f32,
    /// Number of most recent entries of the token history that the repetition
    /// penalty looks at; `0` disables the penalty.
    pub repetition_penalty_window: usize,
    /// RNG seed. When `None`, [`Sampler::new`] derives a seed itself, while the
    /// free function [`sample`] falls back to a fixed constant.
    pub seed: Option<u64>,
    /// Mirostat mode. `2` makes [`Sampler`] use Mirostat v2; every other value
    /// (including `1`) uses the standard top-k / min-p / top-p pipeline.
    pub mirostat: u8,
    /// Mirostat target surprise, in bits (surprise is computed with `log2`).
    /// The adaptive threshold `mu` starts at `2 * mirostat_tau`.
    pub mirostat_tau: f32,
    /// Mirostat learning rate: after each draw `mu -= eta * (surprise - tau)`.
    pub mirostat_eta: f32,

    /// Optional grammar constraint. Logit masking and grammar-state advancement
    /// only happen when `token_vocab` is also set.
    #[serde(skip)]
    pub grammar: Option<Arc<Grammar>>,

    /// `(token id, token bytes)` table used for grammar masking and for
    /// advancing the grammar state. [`Sampler`] looks tokens up by id with a
    /// binary search, so keep the table sorted by ascending id.
    #[serde(skip)]
    // The nested Arc/Vec/tuple type trips clippy's complexity heuristic.
    #[allow(clippy::type_complexity)]
    pub token_vocab: Option<Arc<Vec<(u32, Vec<u8>)>>>,

    /// Additive per-token logit bias keyed by token id. It is applied after
    /// `banned_tokens` and only to finite logits, so it cannot revive a banned
    /// token; ids outside the logits range are ignored.
    #[serde(default)]
    pub logit_bias: std::collections::HashMap<u32, f32>,

    /// Token ids whose logits are forced to `-inf` before any other processing;
    /// ids outside the logits range are ignored.
    #[serde(default)]
    pub banned_tokens: Vec<u32>,

    // NOTE(review): the DRY / XTC / typical / top-a / eta / epsilon fields
    // below are not read anywhere in this file; presumably they are consumed by
    // the `advanced` or `chain` modules — confirm their exact semantics there.

    /// DRY sampler multiplier (serde default `0.0`).
    #[serde(default)]
    pub dry_multiplier: f32,

    /// DRY sampler base (serde default `1.75`).
    #[serde(default = "dry_base_default")]
    pub dry_base: f32,

    /// DRY sampler allowed length (serde default `2`).
    #[serde(default = "dry_allowed_length_default")]
    pub dry_allowed_length: usize,

    /// XTC sampler threshold (serde default `0.0`).
    #[serde(default)]
    pub xtc_threshold: f32,

    /// XTC sampler probability (serde default `0.5`).
    #[serde(default = "xtc_probability_default")]
    pub xtc_probability: f32,

    /// Typical-p threshold (serde default `1.0`).
    #[serde(default = "typical_p_default")]
    pub typical_p: f32,

    /// Top-a parameter (serde default `0.0`).
    #[serde(default)]
    pub top_a: f32,

    /// Eta cutoff (serde default `0.0`).
    #[serde(default)]
    pub eta_cutoff: f32,

    /// Epsilon cutoff (serde default `0.0`).
    #[serde(default)]
    pub epsilon_cutoff: f32,
}
124
/// Serde default for `SamplerConfig::dry_base`.
fn dry_base_default() -> f32 {
    const DEFAULT_DRY_BASE: f32 = 1.75;
    DEFAULT_DRY_BASE
}
/// Serde default for `SamplerConfig::dry_allowed_length`.
fn dry_allowed_length_default() -> usize {
    const DEFAULT_DRY_ALLOWED_LENGTH: usize = 2;
    DEFAULT_DRY_ALLOWED_LENGTH
}
/// Serde default for `SamplerConfig::xtc_probability`.
fn xtc_probability_default() -> f32 {
    const DEFAULT_XTC_PROBABILITY: f32 = 0.5;
    DEFAULT_XTC_PROBABILITY
}
/// Serde default for `SamplerConfig::typical_p`.
fn typical_p_default() -> f32 {
    const DEFAULT_TYPICAL_P: f32 = 1.0;
    DEFAULT_TYPICAL_P
}
138
139impl Default for SamplerConfig {
140 fn default() -> Self {
141 Self {
142 temperature: 0.7,
143 top_k: 40,
144 top_p: 0.9,
145 min_p: 0.0,
146 repetition_penalty: 1.1,
147 repetition_penalty_window: 64,
148 seed: None,
149 mirostat: 0,
150 mirostat_tau: 5.0,
151 mirostat_eta: 0.1,
152 grammar: None,
153 token_vocab: None,
154 logit_bias: std::collections::HashMap::new(),
155 banned_tokens: Vec::new(),
156 dry_multiplier: 0.0,
158 dry_base: 1.75,
159 dry_allowed_length: 2,
160 xtc_threshold: 0.0,
161 xtc_probability: 0.5,
162 typical_p: 1.0,
163 top_a: 0.0,
164 eta_cutoff: 0.0,
165 epsilon_cutoff: 0.0,
166 }
167 }
168}
169
170impl SamplerConfig {
171 pub fn greedy() -> Self {
173 Self {
174 temperature: 0.0,
175 top_k: 1,
176 top_p: 1.0,
177 min_p: 0.0,
178 repetition_penalty: 1.0,
179 repetition_penalty_window: 0,
180 seed: None,
181 mirostat: 0,
182 mirostat_tau: 5.0,
183 mirostat_eta: 0.1,
184 grammar: None,
185 token_vocab: None,
186 logit_bias: std::collections::HashMap::new(),
187 banned_tokens: Vec::new(),
188 dry_multiplier: 0.0,
189 dry_base: 1.75,
190 dry_allowed_length: 2,
191 xtc_threshold: 0.0,
192 xtc_probability: 0.5,
193 typical_p: 1.0,
194 top_a: 0.0,
195 eta_cutoff: 0.0,
196 epsilon_cutoff: 0.0,
197 }
198 }
199
200 pub fn mirostat_v2(tau: f32, eta: f32) -> Self {
202 Self {
203 temperature: 1.0,
204 mirostat: 2,
205 mirostat_tau: tau,
206 mirostat_eta: eta,
207 top_k: 0,
208 top_p: 1.0,
209 min_p: 0.0,
210 repetition_penalty: 1.0,
211 repetition_penalty_window: 0,
212 seed: None,
213 grammar: None,
214 token_vocab: None,
215 logit_bias: std::collections::HashMap::new(),
216 banned_tokens: Vec::new(),
217 dry_multiplier: 0.0,
218 dry_base: 1.75,
219 dry_allowed_length: 2,
220 xtc_threshold: 0.0,
221 xtc_probability: 0.5,
222 typical_p: 1.0,
223 top_a: 0.0,
224 eta_cutoff: 0.0,
225 epsilon_cutoff: 0.0,
226 }
227 }
228}
229
230pub struct Sampler {
232 config: SamplerConfig,
233 rng: Xorshift64,
234 mirostat_mu: f32,
237 grammar_state: Option<GrammarState>,
239}
240
241impl Sampler {
242 pub fn new(config: SamplerConfig) -> Self {
244 let seed = config.seed.unwrap_or_else(|| {
245 let mut s = 0x517cc1b727220a95u64;
248 s ^= (&s as *const u64 as u64).wrapping_mul(0x9e3779b97f4a7c15);
250 s ^ s.wrapping_shr(33)
251 });
252 let mirostat_mu = 2.0 * config.mirostat_tau;
253 let grammar_state = config.grammar.as_ref().map(|g| g.initial_state());
254 Self {
255 config,
256 rng: Xorshift64::new(seed),
257 mirostat_mu,
258 grammar_state,
259 }
260 }
261
262 pub fn sample(&mut self, logits: &[f32], recent_tokens: &[u32]) -> u32 {
264 let token = if self.config.mirostat == 2 {
265 self.sample_mirostat_v2(logits, recent_tokens)
266 } else {
267 sample_with_rng(
268 logits,
269 &self.config,
270 recent_tokens,
271 &mut self.rng,
272 self.grammar_state.as_ref(),
273 )
274 };
275
276 if let Some(state) = &mut self.grammar_state {
279 if let Some(vocab) = &self.config.token_vocab {
280 if let Ok(idx) = vocab.binary_search_by_key(&token, |&(id, _)| id) {
281 let bytes = vocab[idx].1.clone();
282 let _ = state.advance(&bytes);
285 }
286 }
287 }
288
289 token
290 }
291
292 pub fn reset_grammar(&mut self) {
294 self.grammar_state = self.config.grammar.as_ref().map(|g| g.initial_state());
295 }
296
297 pub fn grammar_complete(&self) -> bool {
299 self.grammar_state
300 .as_ref()
301 .is_none_or(GrammarState::is_complete)
302 }
303
304 fn sample_mirostat_v2(&mut self, logits: &[f32], recent_tokens: &[u32]) -> u32 {
310 if logits.is_empty() {
311 return 0;
312 }
313
314 let mut processed = logits.to_vec();
315
316 apply_logit_bias_and_banned_tokens(&mut processed, &self.config);
319
320 apply_repetition_penalty(&mut processed, &self.config, recent_tokens);
322
323 if let (Some(state), Some(vocab)) = (&self.grammar_state, &self.config.token_vocab) {
327 apply_grammar_mask(&mut processed, state, vocab.as_ref());
328 }
329
330 if self.config.temperature > 0.0 && self.config.temperature != 1.0 {
332 let inv_temp = 1.0 / self.config.temperature;
333 for val in &mut processed {
334 *val *= inv_temp;
335 }
336 }
337
338 let mut candidates: Vec<(u32, f32)> = processed
340 .iter()
341 .enumerate()
342 .map(|(i, &v)| (i as u32, v))
343 .collect();
344 candidates
345 .sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
346
347 softmax_candidates(&mut candidates);
349
350 let mu = self.mirostat_mu;
354 candidates.retain(|&(_, p)| {
355 if p <= 0.0 {
356 return false;
357 }
358 let surprise = -p.log2();
359 surprise <= mu
360 });
361
362 if candidates.is_empty() {
364 let token = argmax(&processed);
365 let top_prob = softmax_single_max(&processed);
367 let surprise = if top_prob > 0.0 {
368 -top_prob.log2()
369 } else {
370 self.config.mirostat_tau
371 };
372 self.mirostat_mu =
373 mu - self.config.mirostat_eta * (surprise - self.config.mirostat_tau);
374 return token;
375 }
376
377 let total: f32 = candidates.iter().map(|(_, p)| p).sum();
379 if total > 0.0 && total != 1.0 {
380 for (_, p) in &mut candidates {
381 *p /= total;
382 }
383 }
384
385 let r = self.rng.next_f32();
387 let mut cumulative = 0.0f32;
388 let mut selected_idx = candidates[0].0;
389 let mut selected_prob = candidates[0].1 * total; for &(idx, prob) in &candidates {
391 cumulative += prob;
392 if r < cumulative {
393 selected_idx = idx;
394 selected_prob = prob * total;
395 break;
396 }
397 }
398
399 let surprise = if selected_prob > 0.0 {
401 -selected_prob.log2()
402 } else {
403 self.config.mirostat_tau
404 };
405 self.mirostat_mu = mu - self.config.mirostat_eta * (surprise - self.config.mirostat_tau);
406
407 selected_idx
408 }
409
410 pub fn config(&self) -> &SamplerConfig {
412 &self.config
413 }
414
415 pub fn rng_state(&self) -> u64 {
417 self.rng.state_value()
418 }
419
420 pub fn mirostat_mu_value(&self) -> f32 {
422 self.mirostat_mu
423 }
424
425 pub fn restore_rng_state(&mut self, state: u64, mu: f32) {
427 self.rng = Xorshift64::from_state_value(state);
428 self.mirostat_mu = mu;
429 }
430}
431
432pub fn sample(logits: &[f32], config: &SamplerConfig, recent_tokens: &[u32]) -> u32 {
446 if logits.is_empty() {
447 return 0;
448 }
449
450 let seed = config.seed.unwrap_or(0xDEADBEEF_CAFEBABE);
453 let mut rng = Xorshift64::new(seed);
454 sample_with_rng(logits, config, recent_tokens, &mut rng, None)
455}
456
457fn sample_with_rng(
459 logits: &[f32],
460 config: &SamplerConfig,
461 recent_tokens: &[u32],
462 rng: &mut Xorshift64,
463 grammar_state: Option<&GrammarState>,
464) -> u32 {
465 if logits.is_empty() {
466 return 0;
467 }
468
469 let mut processed = logits.to_vec();
470
471 apply_logit_bias_and_banned_tokens(&mut processed, config);
475
476 apply_repetition_penalty(&mut processed, config, recent_tokens);
478
479 if let (Some(state), Some(vocab)) = (grammar_state, &config.token_vocab) {
482 apply_grammar_mask(&mut processed, state, vocab.as_ref());
483 }
484
485 if config.temperature <= 0.0 || config.top_k == 1 {
487 return argmax(&processed);
488 }
489
490 if config.temperature != 1.0 {
492 let inv_temp = 1.0 / config.temperature;
493 for val in &mut processed {
494 *val *= inv_temp;
495 }
496 }
497
498 let mut candidates: Vec<(u32, f32)> = processed
500 .iter()
501 .enumerate()
502 .map(|(i, &v)| (i as u32, v))
503 .collect();
504 candidates.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
505
506 if config.top_k > 0 && config.top_k < candidates.len() {
508 candidates.truncate(config.top_k);
509 }
510
511 softmax_candidates(&mut candidates);
513
514 if config.min_p > 0.0 && !candidates.is_empty() {
516 let max_prob = candidates[0].1; let threshold = config.min_p * max_prob;
518 candidates.retain(|&(_, p)| p >= threshold);
519 }
520
521 if config.top_p < 1.0 && !candidates.is_empty() {
523 let mut cumulative = 0.0f32;
524 let mut cutoff = candidates.len();
525 for (i, &(_, prob)) in candidates.iter().enumerate() {
526 cumulative += prob;
527 if cumulative >= config.top_p {
528 cutoff = i + 1;
529 break;
530 }
531 }
532 candidates.truncate(cutoff);
533 }
534
535 let total: f32 = candidates.iter().map(|(_, p)| p).sum();
537 if total > 0.0 && total != 1.0 {
538 for (_, p) in &mut candidates {
539 *p /= total;
540 }
541 }
542
543 if candidates.is_empty() {
545 return argmax(&processed);
546 }
547 if candidates.len() == 1 {
548 return candidates[0].0;
549 }
550
551 let r = rng.next_f32();
552 let mut cumulative = 0.0f32;
553 for &(idx, prob) in &candidates {
554 cumulative += prob;
555 if r < cumulative {
556 return idx;
557 }
558 }
559
560 candidates.last().map(|&(idx, _)| idx).unwrap_or(0)
562}
563
564fn apply_logit_bias_and_banned_tokens(processed: &mut [f32], config: &SamplerConfig) {
574 for &token in &config.banned_tokens {
576 let idx = token as usize;
577 if idx < processed.len() {
578 processed[idx] = f32::NEG_INFINITY;
579 }
580 }
581
582 for (&token, &bias) in &config.logit_bias {
584 let idx = token as usize;
585 if idx < processed.len() {
586 if processed[idx].is_finite() {
589 processed[idx] += bias;
590 }
591 }
592 }
593}
594
595fn apply_repetition_penalty(processed: &mut [f32], config: &SamplerConfig, recent_tokens: &[u32]) {
597 if config.repetition_penalty == 1.0 || recent_tokens.is_empty() {
598 return;
599 }
600
601 let window_start = recent_tokens
602 .len()
603 .saturating_sub(config.repetition_penalty_window);
604 for &token in &recent_tokens[window_start..] {
605 let idx = token as usize;
606 if idx < processed.len() {
607 if processed[idx] > 0.0 {
608 processed[idx] /= config.repetition_penalty;
609 } else {
610 processed[idx] *= config.repetition_penalty;
611 }
612 }
613 }
614}
615
/// In-place softmax over the scores of `candidates`.
///
/// The maximum is subtracted before exponentiating, for numerical stability.
/// When the normaliser is not positive, the exponentials are left unnormalised.
/// An all-`-inf` input therefore ends up as NaN.
fn softmax_candidates(candidates: &mut [(u32, f32)]) {
    let peak = candidates
        .iter()
        .map(|&(_, score)| score)
        .fold(f32::NEG_INFINITY, f32::max);

    for (_, score) in candidates.iter_mut() {
        *score = (*score - peak).exp();
    }

    let total: f32 = candidates.iter().map(|&(_, weight)| weight).sum();
    if total > 0.0 {
        candidates
            .iter_mut()
            .for_each(|(_, weight)| *weight /= total);
    }
}
639
/// Softmax probability of the most likely entry of `logits`, computed as
/// `1 / Σ exp(v - max)`.
///
/// Returns `0.0` when that normaliser is not positive: for empty input, and
/// for all-`-inf` input where the sum is NaN.
fn softmax_single_max(logits: &[f32]) -> f32 {
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let normaliser: f32 = logits.iter().map(|&v| (v - peak).exp()).sum();
    if normaliser > 0.0 {
        normaliser.recip()
    } else {
        0.0
    }
}
650
/// Index of the largest value, with the earliest index winning ties.
///
/// NaNs never win. Returns `0` when nothing beats `-inf`, i.e. for empty input
/// or input that is all `-inf`/NaN.
fn argmax(values: &[f32]) -> u32 {
    values
        .iter()
        .enumerate()
        .fold((0u32, f32::NEG_INFINITY), |(best_i, best_v), (i, &v)| {
            if v > best_v {
                (i as u32, v)
            } else {
                (best_i, best_v)
            }
        })
        .0
}
663
/// Minimal xorshift64 PRNG (shift triple 13/7/17), used only to draw sampling
/// thresholds. It is not cryptographically secure.
///
/// The state must never be zero, because zero maps to itself under the
/// xorshift step.
struct Xorshift64 {
    state: u64,
}

impl Xorshift64 {
    /// State used by [`Self::new`] in place of the forbidden all-zero seed.
    const ZERO_SEED_REPLACEMENT: u64 = 0x517cc1b727220a95;

    /// Seeds the generator; a zero seed is swapped for a fixed non-zero constant.
    fn new(seed: u64) -> Self {
        let state = match seed {
            0 => Self::ZERO_SEED_REPLACEMENT,
            nonzero => nonzero,
        };
        Self { state }
    }

    /// Advances the generator and returns the new state.
    fn next_u64(&mut self) -> u64 {
        let a = self.state ^ (self.state << 13);
        let b = a ^ (a >> 7);
        let c = b ^ (b << 17);
        self.state = c;
        c
    }

    /// Uniform float in `[0, 1)`.
    ///
    /// Only the top 24 bits are kept, which is exactly the f32 mantissa width.
    /// Every result is therefore exact and strictly below 1.0.
    fn next_f32(&mut self) -> f32 {
        const TWO_POW_24: f32 = 16_777_216.0;
        let top24 = self.next_u64() >> 40;
        top24 as f32 / TWO_POW_24
    }

    /// Raw state, for checkpointing.
    pub(crate) fn state_value(&self) -> u64 {
        self.state
    }

    /// Rebuilds a generator from a checkpointed state; zero is mapped to 1.
    pub(crate) fn from_state_value(state: u64) -> Self {
        Self {
            state: state.max(1),
        }
    }
}
703
#[cfg(test)]
mod tests {
    use super::*;

    // Greedy config: the index of the largest logit wins.
    #[test]
    fn test_greedy_sampling() {
        let logits = vec![0.1, 0.5, 0.3, 0.8, 0.2];
        let config = SamplerConfig::greedy();
        let token = sample(&logits, &config, &[]);
        assert_eq!(token, 3); // 0.8 at index 3 is the largest logit
    }

    // Empty input returns token 0 rather than panicking.
    #[test]
    fn test_empty_logits() {
        let logits: Vec<f32> = vec![];
        let config = SamplerConfig::greedy();
        let token = sample(&logits, &config, &[]);
        assert_eq!(token, 0);
    }

    // temperature <= 0 short-circuits to argmax even with the default top-k/top-p.
    #[test]
    fn test_temperature_zero_is_greedy() {
        let logits = vec![1.0, 5.0, 3.0, 2.0];
        let config = SamplerConfig {
            temperature: 0.0,
            ..SamplerConfig::default()
        };
        let token = sample(&logits, &config, &[]);
        assert_eq!(token, 1); // 5.0 at index 1 is the largest logit
    }

    // top_k == 1 short-circuits to argmax even at temperature 1.0.
    #[test]
    fn test_top_k_1_is_greedy() {
        let logits = vec![1.0, 5.0, 3.0, 2.0];
        let config = SamplerConfig {
            temperature: 1.0,
            top_k: 1,
            ..SamplerConfig::default()
        };
        let token = sample(&logits, &config, &[]);
        assert_eq!(token, 1);
    }

    // Two samplers built from the same seeded config must draw identical sequences.
    #[test]
    fn test_seeded_determinism() {
        let logits = vec![1.0, 2.0, 3.0, 2.0, 1.0];
        let config = SamplerConfig {
            temperature: 1.0,
            top_k: 0,
            top_p: 1.0,
            min_p: 0.0,
            seed: Some(42),
            ..SamplerConfig::default()
        };

        let mut sampler1 = Sampler::new(config.clone());
        let mut sampler2 = Sampler::new(config);

        for _ in 0..10 {
            // Same seed => identical RNG streams => identical picks.
            let t1 = sampler1.sample(&logits, &[]);
            let t2 = sampler2.sample(&logits, &[]);
            assert_eq!(t1, t2, "seeded samplers should produce identical results");
        }
    }

    #[test]
    fn test_top_p_filters_low_prob() {
        let logits = vec![100.0, 0.0, 0.0, 0.0, 0.0];
        // Token 0 holds essentially all of the probability mass, so a 0.5
        // nucleus contains only token 0.
        let config = SamplerConfig {
            temperature: 1.0,
            top_k: 0,
            top_p: 0.5,
            min_p: 0.0,
            seed: Some(123),
            ..SamplerConfig::default()
        };

        let token = sample(&logits, &config, &[]);
        // Only token 0 survives the top-p cut.
        assert_eq!(token, 0);
    }

    #[test]
    fn test_repetition_penalty() {
        let logits = vec![1.0, 5.0, 4.9, 1.0];
        // Greedy decoding with a huge penalty on recently seen tokens.
        let config = SamplerConfig {
            temperature: 0.0, // greedy
            repetition_penalty: 100.0,
            repetition_penalty_window: 64,
            ..SamplerConfig::greedy()
        };

        let token_no_penalty = sample(&logits, &SamplerConfig::greedy(), &[]);
        // Without a penalty, token 1 (5.0) wins.
        assert_eq!(token_no_penalty, 1);

        let token_with_penalty = sample(&logits, &config, &[1]);
        // Token 1's logit becomes 5.0 / 100 = 0.05, so token 2 (4.9) wins.
        assert_eq!(token_with_penalty, 2);
    }

    #[test]
    fn test_sampling_distribution() {
        let logits = vec![2.0, 2.0, 2.0, 2.0]; // equal logits => uniform distribution
        let config = SamplerConfig {
            temperature: 1.0,
            top_k: 0,
            top_p: 1.0,
            min_p: 0.0,
            seed: Some(999),
            ..SamplerConfig::default()
        };

        let mut sampler = Sampler::new(config);
        let mut counts = [0u32; 4];
        for _ in 0..1000 {
            let t = sampler.sample(&logits, &[]);
            counts[t as usize] += 1;
        }

        for (i, &count) in counts.iter().enumerate() {
            // Loose bounds around the expected 250: this only catches a token
            // being starved or dominating, not subtle bias.
            assert!(
                count > 100 && count < 400,
                "token {i} got {count} hits (expected ~250 for uniform distribution)"
            );
        }
    }

    #[test]
    fn test_min_p_filtering() {
        let logits = vec![10.0, -10.0, -10.0, -10.0];
        // The other tokens have probability ~e^-20 relative to token 0, far
        // below the 0.1 ratio.
        let config = SamplerConfig {
            temperature: 1.0,
            top_k: 0,
            top_p: 1.0,
            min_p: 0.1, // threshold = 0.1 * p(token 0)
            seed: Some(42),
            ..SamplerConfig::default()
        };

        let mut sampler = Sampler::new(config);
        // min-p leaves only token 0, so every draw must return it.
        for _ in 0..100 {
            assert_eq!(sampler.sample(&logits, &[]), 0);
        }
    }

    // next_f32 must stay inside the half-open unit interval.
    #[test]
    fn test_xorshift_range() {
        let mut rng = Xorshift64::new(12345);
        for _ in 0..10000 {
            let v = rng.next_f32();
            assert!((0.0..1.0).contains(&v), "RNG produced {v} outside [0, 1)");
        }
    }

    #[test]
    fn test_mirostat_v2_basic() {
        let logits = vec![3.0, 2.0, 1.0, 0.5, 0.1, -1.0, -2.0, -5.0];
        // Smoke test: Mirostat v2 must always return an in-range token id.
        let config = SamplerConfig {
            seed: Some(42),
            ..SamplerConfig::mirostat_v2(5.0, 0.1)
        };
        let mut sampler = Sampler::new(config);

        for _ in 0..50 {
            let token = sampler.sample(&logits, &[]);
            assert!((token as usize) < logits.len());
        }
    }

    #[test]
    fn test_mirostat_v2_adapts_mu() {
        let logits = vec![5.0, 0.0, 0.0, 0.0];
        let config = SamplerConfig {
            seed: Some(123),
            ..SamplerConfig::mirostat_v2(3.0, 0.1)
        };
        let mut sampler = Sampler::new(config);
        let initial_mu = sampler.mirostat_mu;

        sampler.sample(&logits, &[]);
        // mu moves by eta * (surprise - tau) after every draw.
        assert!(
            (sampler.mirostat_mu - initial_mu).abs() > 1e-6,
            "mu should adapt after sampling"
        );
    }

    #[test]
    fn test_mirostat_v2_low_tau_prefers_top() {
        let logits = vec![10.0, 0.0, 0.0, 0.0, 0.0];
        // tau = 0.5 gives an initial mu of 1 bit, which excludes the
        // low-probability tokens.
        let config = SamplerConfig {
            seed: Some(42),
            ..SamplerConfig::mirostat_v2(0.5, 0.1) // low tau
        };
        let mut sampler = Sampler::new(config);

        let mut top_count = 0;
        for _ in 0..100 {
            if sampler.sample(&logits, &[]) == 0 {
                top_count += 1;
            }
        }
        // Allow a little slack rather than demanding 100/100.
        assert!(
            top_count > 90,
            "low tau should strongly prefer top token, got {top_count}/100"
        );
    }

    #[test]
    fn test_mirostat_v2_deterministic_with_seed() {
        let logits = vec![2.0, 1.5, 1.0, 0.5];
        let config = SamplerConfig {
            seed: Some(777),
            ..SamplerConfig::mirostat_v2(5.0, 0.1)
        };

        let mut sampler1 = Sampler::new(config.clone());
        let mut sampler2 = Sampler::new(config);

        for _ in 0..20 {
            assert_eq!(
                sampler1.sample(&logits, &[]),
                sampler2.sample(&logits, &[]),
                "same seed should produce same sequence"
            );
        }
    }

    // Equal scores give a uniform distribution.
    #[test]
    fn test_softmax_candidates_basic() {
        let mut candidates = vec![(0, 0.0f32), (1, 0.0), (2, 0.0)];
        softmax_candidates(&mut candidates);
        for &(_, p) in &candidates {
            // Each of the three tokens should get ~1/3.
            assert!((p - 1.0 / 3.0).abs() < 0.01, "expected ~0.333, got {p}");
        }
    }

    // Banning every token except id 3 must force id 3 on every draw.
    #[test]
    fn banned_tokens_never_sampled() {
        let vocab_size = 5usize;
        // Logits increase with id, so without the ban token 4 would be the most likely.
        let logits: Vec<f32> = (0..vocab_size).map(|i| i as f32).collect();

        let mut banned = Vec::new();
        for i in 0u32..vocab_size as u32 {
            if i != 3 {
                banned.push(i);
            }
        }
        let config = SamplerConfig {
            temperature: 1.0,
            top_k: 0,
            top_p: 1.0,
            min_p: 0.0,
            seed: Some(42),
            banned_tokens: banned,
            ..SamplerConfig::default()
        };
        let mut sampler = Sampler::new(config);
        for _ in 0..50 {
            let tok = sampler.sample(&logits, &[]);
            assert_eq!(
                tok, 3,
                "only token 3 should ever be sampled when all others are banned"
            );
        }
    }

    #[test]
    fn positive_bias_increases_token_probability() {
        let logits = vec![10.0f32, -20.0, -20.0, -20.0];
        // Token 1 starts 30 below token 0; a +100 bias puts it 70 above.
        let mut bias = std::collections::HashMap::new();
        bias.insert(1u32, 100.0f32);
        let config = SamplerConfig {
            temperature: 1.0,
            top_k: 0,
            top_p: 1.0,
            min_p: 0.0,
            seed: Some(7),
            logit_bias: bias,
            ..SamplerConfig::default()
        };
        let mut sampler = Sampler::new(config);
        let tok = sampler.sample(&logits, &[]);
        // Token 1 now carries essentially all of the probability mass.
        assert_eq!(tok, 1, "large positive bias should make token 1 dominate");
    }

    #[test]
    fn negative_bias_decreases() {
        let logits = vec![100.0f32, 1.0, 0.5, 0.1];
        // Token 0's logit drops from 100 to -100, below token 1's 1.0.
        let mut bias = std::collections::HashMap::new();
        bias.insert(0u32, -200.0f32);
        let config = SamplerConfig {
            temperature: 0.0, // greedy
            logit_bias: bias,
            ..SamplerConfig::greedy()
        };
        let tok = sample(&logits, &config, &[]);
        assert_eq!(
            tok, 1,
            "after large negative bias on token 0, token 1 should win"
        );
    }

    #[test]
    fn logit_bias_empty_config_no_op() {
        let logits = vec![1.0f32, 2.0, 3.0, 0.5];
        // Greedy with no bias and no bans is a plain argmax (3.0 at index 2).
        let config_empty = SamplerConfig {
            temperature: 0.0,
            logit_bias: std::collections::HashMap::new(),
            banned_tokens: Vec::new(),
            ..SamplerConfig::greedy()
        };
        let tok = sample(&logits, &config_empty, &[]);
        assert_eq!(tok, 2, "empty logit_bias / banned_tokens should be a no-op");
    }

    // The grammar's initial state accepts exactly the two alternatives.
    #[test]
    fn test_grammar_constrained_yes_no() {
        let g = Grammar::parse(r#"root ::= "yes" | "no""#).unwrap();
        let state = g.initial_state();
        assert!(state.allows_token(b"yes"));
        assert!(state.allows_token(b"no"));
        assert!(!state.allows_token(b"maybe"));
    }

    #[test]
    fn test_grammar_sampler_masks_logits() {
        // Sorted by id, as Sampler's lookup expects.
        let vocab: Vec<(u32, Vec<u8>)> = vec![
            (0, b"maybe".to_vec()),
            (1, b"yes".to_vec()),
            (2, b"no".to_vec()),
        ];
        let g = Arc::new(Grammar::parse(r#"root ::= "yes" | "no""#).unwrap());
        let config = SamplerConfig {
            temperature: 0.0, // greedy
            grammar: Some(g),
            token_vocab: Some(Arc::new(vocab)),
            ..SamplerConfig::default()
        };

        // "maybe" has by far the highest logit; the grammar mask must override it.
        let logits = vec![100.0f32, 1.0, 1.0];
        let mut sampler = Sampler::new(config);
        let tok = sampler.sample(&logits, &[]);
        assert!(tok == 1 || tok == 2, "expected yes(1) or no(2), got {tok}");
    }

    #[test]
    fn test_grammar_state_advances_through_sequence() {
        let vocab: Vec<(u32, Vec<u8>)> =
            vec![(0, b"a".to_vec()), (1, b"b".to_vec()), (2, b"c".to_vec())];
        let g = Arc::new(Grammar::parse(r#"root ::= "a" "b""#).unwrap());
        let config = SamplerConfig {
            temperature: 0.0,
            grammar: Some(g),
            token_vocab: Some(Arc::new(vocab)),
            ..SamplerConfig::default()
        };

        // Logits favour "a"; "b" and "c" tie at 0.5.
        let logits = vec![1.0f32, 0.5, 0.5];
        let mut sampler = Sampler::new(config);

        let tok1 = sampler.sample(&logits, &[]);
        // The grammar only allows "a" first.
        assert_eq!(tok1, 0, "first token must be 'a' (id=0)");

        let tok2 = sampler.sample(&logits, &[0]);
        // After "a" only "b" is allowed, despite "a" having the higher logit.
        assert_eq!(tok2, 1, "second token must be 'b' (id=1)");

        assert!(
            sampler.grammar_complete(),
            "grammar should be complete after 'a' + 'b'"
        );
    }

    // The parser produces a non-empty rule set rooted at "root".
    #[test]
    fn test_grammar_parse_roundtrip() {
        let g = Grammar::parse("root ::= [a-z]+ \":\" [0-9]+").unwrap();
        assert!(!g.rules.is_empty());
        assert_eq!(g.root, "root");
    }

    // Only checks that `advance` rejects bytes the grammar cannot accept.
    // NOTE(review): the name suggests logit masking is covered too; it is not.
    #[test]
    fn test_grammar_stuck_state_masks_all() {
        let g = Arc::new(Grammar::parse(r#"root ::= "x""#).unwrap());
        let mut state = g.initial_state();
        let result = state.advance(b"y");
        assert!(result.is_err(), "advancing with wrong bytes should error");
    }
}