1use rlx_ir::Philox4x32;
50
/// Mutable view of one decoding step's logits, indexed by token id; samplers rewrite it in place.
pub type Logits<'a> = &'a mut [f32];
53
/// Mutable per-sequence state threaded through every sampler call.
///
/// `Default` and [`SamplerState::new`] produce the same value. Previously `Default` was derived,
/// which set `mirostat_mu` to `0.0`. That silently skipped Mirostat's `mu = 2 * tau`
/// initialisation, which only runs while `mu` is non-finite.
#[derive(Debug, Clone)]
pub struct SamplerState {
    /// Running Mirostat target surprise `mu`. `NaN` means "not yet initialised"; the Mirostat
    /// samplers replace any non-finite value with `2 * tau` on their first call.
    pub mirostat_mu: f32,
}

impl SamplerState {
    /// Creates a fresh state with `mirostat_mu` uninitialised (`NaN`).
    pub fn new() -> Self {
        Self {
            mirostat_mu: f32::NAN,
        }
    }
}

impl Default for SamplerState {
    /// Same as [`SamplerState::new`]: `mirostat_mu` starts uninitialised (`NaN`).
    fn default() -> Self {
        Self::new()
    }
}
68
69pub trait Sampler: std::fmt::Debug + Send + Sync {
72 fn apply(
76 &self,
77 logits: Logits<'_>,
78 history: &[u32],
79 state: &mut SamplerState,
80 rng: &mut Philox4x32,
81 );
82
83 fn name(&self) -> &'static str {
85 std::any::type_name::<Self>()
86 }
87}
88
89#[derive(Debug)]
92pub struct SamplerChain {
93 pub steps: Vec<Box<dyn Sampler>>,
94}
95
96impl SamplerChain {
97 pub fn new() -> Self {
98 Self { steps: Vec::new() }
99 }
100
101 pub fn builder() -> SamplerChainBuilder {
102 SamplerChainBuilder::default()
103 }
104
105 pub fn sample(
107 &self,
108 logits: Logits<'_>,
109 history: &[u32],
110 state: &mut SamplerState,
111 rng: &mut Philox4x32,
112 ) -> u32 {
113 for step in &self.steps {
114 step.apply(logits, history, state, rng);
115 }
116 sample_from_logits(logits, rng)
117 }
118}
119
120impl Default for SamplerChain {
121 fn default() -> Self {
122 Self::new()
123 }
124}
125
126#[derive(Debug, Default)]
127pub struct SamplerChainBuilder {
128 steps: Vec<Box<dyn Sampler>>,
129}
130
131impl SamplerChainBuilder {
132 pub fn push<S: Sampler + 'static>(mut self, s: S) -> Self {
133 self.steps.push(Box::new(s));
134 self
135 }
136
137 pub fn push_boxed(mut self, s: Box<dyn Sampler>) -> Self {
138 self.steps.push(s);
139 self
140 }
141
142 pub fn build(self) -> SamplerChain {
143 SamplerChain { steps: self.steps }
144 }
145}
146
/// Converts `logits` into a probability distribution in place (max-subtracted softmax).
///
/// Special cases:
/// - If any entry is `+inf`, all probability mass is split evenly among the `+inf` entries and
///   every other entry becomes `0.0`. Samplers such as Mirostat mark their chosen token with
///   `+inf` and everything else with `-inf`. Previously any `+inf` fell into the uniform branch
///   below, so the chain sampled a uniformly random token instead of the chosen one.
/// - If no entry exceeds `-inf` (all masked, all NaN, or empty), the result is uniform.
/// - NaN entries otherwise receive probability `0.0` instead of poisoning the normaliser.
pub fn softmax_inplace(logits: &mut [f32]) {
    // `f32::max` ignores NaN, so `maxv` is never NaN.
    let maxv = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    if maxv == f32::INFINITY {
        let n_inf = logits.iter().filter(|&&x| x == f32::INFINITY).count();
        let share = 1.0 / n_inf as f32;
        for x in logits.iter_mut() {
            *x = if *x == f32::INFINITY { share } else { 0.0 };
        }
        return;
    }

    if maxv == f32::NEG_INFINITY {
        let inv = 1.0 / logits.len() as f32;
        for x in logits.iter_mut() {
            *x = inv;
        }
        return;
    }

    // Finite maximum: the max element contributes exp(0) = 1, so `s >= 1` below.
    let mut s = 0.0f32;
    for x in logits.iter_mut() {
        let e = if x.is_nan() { 0.0 } else { (*x - maxv).exp() };
        *x = e;
        s += e;
    }
    let inv = if s > 0.0 { 1.0 / s } else { 0.0 };
    for x in logits.iter_mut() {
        *x *= inv;
    }
}
175
176pub fn sample_from_probs(probs: &[f32], rng: &mut Philox4x32) -> u32 {
179 let r = rng.next_f32();
180 let mut acc = 0.0f32;
181 for (i, &p) in probs.iter().enumerate() {
182 acc += p;
183 if r <= acc {
184 return i as u32;
185 }
186 }
187 (probs.len() - 1) as u32
188}
189
190pub fn sample_from_logits(logits: &mut [f32], rng: &mut Philox4x32) -> u32 {
192 softmax_inplace(logits);
193 sample_from_probs(logits, rng)
194}
195
/// Returns `(index, value)` for every entry of `logits`, ordered from largest to smallest value.
///
/// - The sort is stable, so ties keep ascending index order.
/// - NaN entries are ranked as if they were `-inf`, i.e. last.
/// - The comparator is a genuine total order (`f32::total_cmp` on a NaN-free key). The previous
///   `partial_cmp(..).unwrap_or(Equal)` made NaN "equal" to everything, which is not transitive;
///   the standard library's sort is allowed to panic on such comparators.
fn sorted_desc(logits: &[f32]) -> Vec<(usize, f32)> {
    let rank_key = |x: f32| if x.is_nan() { f32::NEG_INFINITY } else { x };
    let mut v: Vec<(usize, f32)> = logits.iter().copied().enumerate().collect();
    v.sort_by(|a, b| rank_key(b.1).total_cmp(&rank_key(a.1)));
    v
}
202
203#[derive(Debug, Clone, Copy)]
206pub struct Temperature {
207 pub t: f32,
208}
209
210impl Sampler for Temperature {
211 fn apply(&self, logits: Logits<'_>, _h: &[u32], _s: &mut SamplerState, _r: &mut Philox4x32) {
212 let t = self.t.max(1e-6);
213 for x in logits.iter_mut() {
214 *x /= t;
215 }
216 }
217}
218
219#[derive(Debug, Clone, Copy)]
226pub struct DynamicTemperature {
227 pub min: f32,
228 pub max: f32,
229 pub exponent: f32,
230}
231
232impl Sampler for DynamicTemperature {
233 fn apply(&self, logits: Logits<'_>, _h: &[u32], _s: &mut SamplerState, _r: &mut Philox4x32) {
234 let v = logits.len();
235 if v == 0 {
236 return;
237 }
238 let mut tmp: Vec<f32> = logits.to_vec();
239 softmax_inplace(&mut tmp);
240 let mut h = 0.0f32;
242 for &p in tmp.iter() {
243 if p > 0.0 {
244 h -= p * p.ln();
245 }
246 }
247 let hmax = (v as f32).ln().max(1e-6);
248 let norm = (h / hmax).clamp(0.0, 1.0);
249 let t = self.min + (self.max - self.min) * norm.powf(self.exponent);
250 let t = t.max(1e-6);
251 for x in logits.iter_mut() {
252 *x /= t;
253 }
254 }
255}
256
257#[derive(Debug, Clone, Copy)]
260pub struct TopK {
261 pub k: usize,
262}
263
264impl Sampler for TopK {
265 fn apply(&self, logits: Logits<'_>, _h: &[u32], _s: &mut SamplerState, _r: &mut Philox4x32) {
266 let v = logits.len();
267 if self.k == 0 || self.k >= v {
268 return;
269 }
270 let sorted = sorted_desc(logits);
271 let cutoff = sorted[self.k - 1].1;
272 for x in logits.iter_mut() {
273 if *x < cutoff {
274 *x = f32::NEG_INFINITY;
275 }
276 }
277 }
278}
279
280#[derive(Debug, Clone, Copy)]
283pub struct TopP {
284 pub p: f32,
285 pub min_keep: usize,
288}
289
290impl Sampler for TopP {
291 fn apply(&self, logits: Logits<'_>, _h: &[u32], _s: &mut SamplerState, _r: &mut Philox4x32) {
292 if self.p >= 1.0 {
293 return;
294 }
295 let v = logits.len();
296 if v == 0 {
297 return;
298 }
299 let mut probs: Vec<f32> = logits.to_vec();
300 softmax_inplace(&mut probs);
301 let sorted = sorted_desc(&probs);
302 let mut keep = vec![false; v];
303 let mut cum = 0.0f32;
304 for (rank, (idx, p)) in sorted.iter().enumerate() {
305 keep[*idx] = true;
306 cum += *p;
307 if cum >= self.p && rank + 1 >= self.min_keep {
308 break;
309 }
310 }
311 for (i, x) in logits.iter_mut().enumerate() {
312 if !keep[i] {
313 *x = f32::NEG_INFINITY;
314 }
315 }
316 }
317}
318
319#[derive(Debug, Clone, Copy)]
326pub struct TopNSigma {
327 pub n: f32,
328}
329
330impl Sampler for TopNSigma {
331 fn apply(&self, logits: Logits<'_>, _h: &[u32], _s: &mut SamplerState, _r: &mut Philox4x32) {
332 let v = logits.len();
333 if v == 0 || !self.n.is_finite() || self.n <= 0.0 {
334 return;
335 }
336 let mut maxv = f32::NEG_INFINITY;
337 let mut count = 0usize;
338 let mut sum = 0.0f32;
339 for &x in logits.iter() {
340 if x.is_finite() {
341 if x > maxv {
342 maxv = x;
343 }
344 sum += x;
345 count += 1;
346 }
347 }
348 if count == 0 || !maxv.is_finite() {
349 return;
350 }
351 let mean = sum / count as f32;
352 let mut var = 0.0f32;
353 for &x in logits.iter() {
354 if x.is_finite() {
355 let d = x - mean;
356 var += d * d;
357 }
358 }
359 let sigma = (var / count as f32).sqrt();
360 let cutoff = maxv - self.n * sigma;
361 for x in logits.iter_mut() {
362 if *x < cutoff {
363 *x = f32::NEG_INFINITY;
364 }
365 }
366 }
367}
368
369#[derive(Debug, Clone, Copy)]
376pub struct TypicalP {
377 pub p: f32,
378 pub min_keep: usize,
379}
380
381impl Sampler for TypicalP {
382 fn apply(&self, logits: Logits<'_>, _h: &[u32], _s: &mut SamplerState, _r: &mut Philox4x32) {
383 if self.p >= 1.0 {
384 return;
385 }
386 let v = logits.len();
387 if v == 0 {
388 return;
389 }
390 let mut probs: Vec<f32> = logits.to_vec();
391 softmax_inplace(&mut probs);
392 let mut h = 0.0f32;
393 for &p in probs.iter() {
394 if p > 0.0 {
395 h -= p * p.ln();
396 }
397 }
398 let mut scored: Vec<(usize, f32, f32)> = probs
400 .iter()
401 .enumerate()
402 .map(|(i, &p)| {
403 let neg_log = if p > 0.0 { -p.ln() } else { f32::INFINITY };
404 let dev = (neg_log - h).abs();
405 (i, p, dev)
406 })
407 .collect();
408 scored.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal));
409 let mut keep = vec![false; v];
410 let mut cum = 0.0f32;
411 for (rank, (idx, p, _)) in scored.iter().enumerate() {
412 keep[*idx] = true;
413 cum += *p;
414 if cum >= self.p && rank + 1 >= self.min_keep {
415 break;
416 }
417 }
418 for (i, x) in logits.iter_mut().enumerate() {
419 if !keep[i] {
420 *x = f32::NEG_INFINITY;
421 }
422 }
423 }
424}
425
426#[derive(Debug, Clone, Copy)]
433pub struct MirostatV1 {
434 pub tau: f32,
435 pub eta: f32,
436 pub m: usize,
438}
439
440impl Default for MirostatV1 {
441 fn default() -> Self {
442 Self {
443 tau: 5.0,
444 eta: 0.1,
445 m: 100,
446 }
447 }
448}
449
450impl Sampler for MirostatV1 {
451 fn apply(
452 &self,
453 logits: Logits<'_>,
454 _h: &[u32],
455 state: &mut SamplerState,
456 rng: &mut Philox4x32,
457 ) {
458 let v = logits.len();
459 if v == 0 {
460 return;
461 }
462 if !state.mirostat_mu.is_finite() {
463 state.mirostat_mu = 2.0 * self.tau;
464 }
465 let mu = state.mirostat_mu.max(1e-6);
466 let mut probs = logits.to_vec();
468 softmax_inplace(&mut probs);
469 let sorted = sorted_desc(&probs);
470 let m = self.m.min(sorted.len()).max(2);
472 let mut num = 0.0f32;
473 let mut den = 0.0f32;
474 for i in 0..(m - 1) {
475 let t = ((i + 2) as f32 / (i + 1) as f32).ln();
476 let b = (sorted[i].1 / sorted[i + 1].1).ln().max(1e-9);
477 num += t * b;
478 den += t * t;
479 }
480 let s_hat = if den > 0.0 { num / den } else { 1.0 };
481 let eps = (s_hat - 1.0).abs().max(1e-3);
483 let k_real = ((eps * (2.0f32.powf(mu))) / (1.0 - (v as f32).powf(-eps)))
484 .powf(1.0 / s_hat)
485 .clamp(1.0, v as f32);
486 let k = k_real as usize;
487 if k < sorted.len() {
488 let cutoff = sorted[k - 1].1;
489 for (i, p) in probs.iter_mut().enumerate() {
490 if *p < cutoff {
491 *p = 0.0;
492 }
493 let _ = i;
494 }
495 let s: f32 = probs.iter().sum();
496 if s > 0.0 {
497 for p in probs.iter_mut() {
498 *p /= s;
499 }
500 }
501 }
502 let tok = sample_from_probs(&probs, rng) as usize;
505 let surprise = if probs[tok] > 0.0 {
506 -probs[tok].ln() / 2.0f32.ln()
507 } else {
508 mu
509 };
510 state.mirostat_mu = (mu - self.eta * (surprise - self.tau)).max(0.0);
511 for (i, x) in logits.iter_mut().enumerate() {
512 *x = if i == tok {
513 f32::INFINITY
514 } else {
515 f32::NEG_INFINITY
516 };
517 }
518 }
519}
520
521#[derive(Debug, Clone, Copy)]
527pub struct MirostatV2 {
528 pub tau: f32,
529 pub eta: f32,
530}
531
532impl Default for MirostatV2 {
533 fn default() -> Self {
534 Self { tau: 5.0, eta: 0.1 }
535 }
536}
537
538impl Sampler for MirostatV2 {
539 fn apply(
540 &self,
541 logits: Logits<'_>,
542 _h: &[u32],
543 state: &mut SamplerState,
544 rng: &mut Philox4x32,
545 ) {
546 let v = logits.len();
547 if v == 0 {
548 return;
549 }
550 if !state.mirostat_mu.is_finite() {
551 state.mirostat_mu = 2.0 * self.tau;
552 }
553 let mu = state.mirostat_mu;
554 let mut probs = logits.to_vec();
555 softmax_inplace(&mut probs);
556 let mut sorted = sorted_desc(&probs);
558 let ln2 = 2.0f32.ln();
559 let mut keep_n = 0usize;
560 for (i, (_, p)) in sorted.iter().enumerate() {
561 let s = if *p > 0.0 {
562 -p.ln() / ln2
563 } else {
564 f32::INFINITY
565 };
566 if s > mu {
567 break;
568 }
569 keep_n = i + 1;
570 }
571 if keep_n == 0 {
572 keep_n = 1;
573 }
574 let kept: std::collections::HashSet<usize> =
575 sorted.drain(..keep_n).map(|(i, _)| i).collect();
576 for (i, p) in probs.iter_mut().enumerate() {
577 if !kept.contains(&i) {
578 *p = 0.0;
579 }
580 }
581 let s: f32 = probs.iter().sum();
582 if s > 0.0 {
583 for p in probs.iter_mut() {
584 *p /= s;
585 }
586 }
587 let tok = sample_from_probs(&probs, rng) as usize;
588 let surprise = if probs[tok] > 0.0 {
589 -probs[tok].ln() / ln2
590 } else {
591 mu
592 };
593 state.mirostat_mu = (mu - self.eta * (surprise - self.tau)).max(0.0);
594 for (i, x) in logits.iter_mut().enumerate() {
595 *x = if i == tok {
596 f32::INFINITY
597 } else {
598 f32::NEG_INFINITY
599 };
600 }
601 }
602}
603
604#[derive(Debug, Clone, Copy)]
611pub struct Xtc {
612 pub threshold: f32,
613 pub prob: f32,
614 pub min_keep: usize,
616}
617
618impl Sampler for Xtc {
619 fn apply(&self, logits: Logits<'_>, _h: &[u32], _s: &mut SamplerState, rng: &mut Philox4x32) {
620 if self.prob <= 0.0 {
621 return;
622 }
623 if rng.next_f32() > self.prob {
624 return;
625 }
626 let v = logits.len();
627 if v == 0 {
628 return;
629 }
630 let mut probs = logits.to_vec();
631 softmax_inplace(&mut probs);
632 let sorted = sorted_desc(&probs);
633 let n_above = sorted.iter().filter(|(_, p)| *p > self.threshold).count();
635 if n_above < 2 {
636 return; }
638 let to_kill = n_above.saturating_sub(self.min_keep.max(1));
641 for (idx, _) in sorted.iter().take(to_kill) {
642 logits[*idx] = f32::NEG_INFINITY;
643 }
644 }
645}
646
647#[derive(Debug, Clone)]
655pub struct Dry {
656 pub multiplier: f32,
657 pub base: f32,
658 pub allowed_length: usize,
659 pub max_ngram: usize,
660 pub sequence_breakers: Vec<u32>,
662}
663
664impl Default for Dry {
665 fn default() -> Self {
666 Self {
667 multiplier: 0.8,
668 base: 1.75,
669 allowed_length: 2,
670 max_ngram: 32,
671 sequence_breakers: Vec::new(),
672 }
673 }
674}
675
676impl Sampler for Dry {
677 fn apply(
678 &self,
679 logits: Logits<'_>,
680 history: &[u32],
681 _s: &mut SamplerState,
682 _r: &mut Philox4x32,
683 ) {
684 if self.multiplier <= 0.0 || history.is_empty() {
685 return;
686 }
687 let n = history.len();
688 let max_ngram = self.max_ngram.min(n);
689 let breakers: std::collections::HashSet<u32> =
690 self.sequence_breakers.iter().copied().collect();
691 let mut longest: std::collections::HashMap<u32, usize> = std::collections::HashMap::new();
696 for i in 0..n.saturating_sub(1) {
697 if breakers.contains(&history[i]) {
698 continue;
699 }
700 let mut l = 0usize;
703 while l < max_ngram && i >= l && n > l && history[i - l] == history[n - 1 - l] {
704 l += 1;
705 }
706 if l >= self.allowed_length && i + 1 < n {
707 let next = history[i + 1];
708 let cur = longest.entry(next).or_insert(0);
709 if l > *cur {
710 *cur = l;
711 }
712 }
713 }
714 for (tok, l) in longest {
715 let pen = self.multiplier * self.base.powi((l - self.allowed_length) as i32);
716 let idx = tok as usize;
717 if idx < logits.len() {
718 logits[idx] -= pen;
719 }
720 }
721 }
722}
723
724#[derive(Debug, Clone, Copy)]
727pub struct RepetitionPenalty {
728 pub penalty: f32,
729 pub frequency: f32,
730 pub presence: f32,
731 pub last_n: usize,
733}
734
735impl Default for RepetitionPenalty {
736 fn default() -> Self {
737 Self {
738 penalty: 1.0,
739 frequency: 0.0,
740 presence: 0.0,
741 last_n: 64,
742 }
743 }
744}
745
746impl Sampler for RepetitionPenalty {
747 fn apply(
748 &self,
749 logits: Logits<'_>,
750 history: &[u32],
751 _s: &mut SamplerState,
752 _r: &mut Philox4x32,
753 ) {
754 if history.is_empty() {
755 return;
756 }
757 let start = history.len().saturating_sub(self.last_n);
758 let window = &history[start..];
759 let mut counts: std::collections::HashMap<u32, u32> = std::collections::HashMap::new();
760 for &t in window {
761 *counts.entry(t).or_insert(0) += 1;
762 }
763 for (tok, c) in counts {
764 let idx = tok as usize;
765 if idx >= logits.len() {
766 continue;
767 }
768 logits[idx] -= self.presence + self.frequency * c as f32;
770 if (self.penalty - 1.0).abs() > 1e-6 {
773 if logits[idx] > 0.0 {
774 logits[idx] /= self.penalty;
775 } else {
776 logits[idx] *= self.penalty;
777 }
778 }
779 }
780 }
781}
782
#[cfg(test)]
mod tests {
    use super::*;

    /// Seeded generator so every test is reproducible.
    fn rng() -> Philox4x32 {
        Philox4x32::new(0xDEAD_BEEF)
    }

    /// A fresh per-sequence state plus a freshly seeded generator.
    fn fixtures() -> (SamplerState, Philox4x32) {
        (SamplerState::new(), rng())
    }

    /// True for a logit that a sampler masked out (`-inf`).
    fn masked(x: f32) -> bool {
        x.is_infinite() && x < 0.0
    }

    #[test]
    fn temperature_zero_is_greedy_after_chain() {
        let chain = SamplerChain::builder().push(Temperature { t: 1e-6 }).build();
        let (mut state, mut r) = fixtures();
        let mut logits = vec![1.0, 5.0, 2.0, 3.0];
        let picked = chain.sample(&mut logits, &[], &mut state, &mut r);
        assert_eq!(picked, 1);
    }

    #[test]
    fn top_k_masks_below_kth() {
        let (mut s, mut r) = fixtures();
        let mut logits = vec![1.0, 5.0, 2.0, 3.0];
        TopK { k: 2 }.apply(&mut logits, &[], &mut s, &mut r);
        assert_eq!(logits[1], 5.0);
        assert_eq!(logits[3], 3.0);
        assert!(masked(logits[0]));
        assert!(masked(logits[2]));
    }

    #[test]
    fn top_p_keeps_nucleus() {
        let mut logits = vec![10.0f32, 5.0, 0.0, 0.0];
        let (mut s, mut r) = fixtures();
        TopP {
            p: 0.5,
            min_keep: 1,
        }
        .apply(&mut logits, &[], &mut s, &mut r);
        assert!(logits[0].is_finite());
        assert!(masked(logits[2]));
        assert!(masked(logits[3]));
    }

    #[test]
    fn top_n_sigma_keeps_top_logits() {
        let mut logits = vec![0.0f32; 32];
        logits[..2].copy_from_slice(&[10.0, 9.5]);
        let (mut s, mut r) = fixtures();
        TopNSigma { n: 1.0 }.apply(&mut logits, &[], &mut s, &mut r);
        assert!(logits[0].is_finite());
        assert!(masked(logits[5]));
    }

    #[test]
    fn dynamic_temperature_scales_with_entropy() {
        // Uniform logits have maximal entropy, so the temperature should reach `max`.
        let mut logits = vec![1.0f32; 16];
        let before = logits.clone();
        let (mut s, mut r) = fixtures();
        DynamicTemperature {
            min: 0.5,
            max: 2.0,
            exponent: 1.0,
        }
        .apply(&mut logits, &[], &mut s, &mut r);
        let expected = before[0] / 2.0;
        assert!((logits[0] - expected).abs() < 1e-5);
    }

    #[test]
    fn typical_p_keeps_typical_token() {
        let mut logits = vec![5.0, 4.0, 0.0, -10.0];
        let (mut s, mut r) = fixtures();
        TypicalP {
            p: 0.5,
            min_keep: 1,
        }
        .apply(&mut logits, &[], &mut s, &mut r);
        assert!(logits.iter().any(|x| x.is_finite()));
    }

    #[test]
    fn mirostat_v2_keeps_at_least_one() {
        let mut logits = vec![1.0, 2.0, 3.0, 4.0];
        let (mut s, mut r) = fixtures();
        MirostatV2 { tau: 5.0, eta: 0.1 }.apply(&mut logits, &[], &mut s, &mut r);
        let chosen = logits.iter().filter(|&&x| x == f32::INFINITY).count();
        assert_eq!(chosen, 1);
    }

    #[test]
    fn xtc_disabled_when_prob_zero() {
        let mut logits = vec![10.0, 5.0, 1.0];
        let before = logits.clone();
        let (mut s, mut r) = fixtures();
        Xtc {
            threshold: 0.5,
            prob: 0.0,
            min_keep: 1,
        }
        .apply(&mut logits, &[], &mut s, &mut r);
        assert_eq!(logits, before);
    }

    #[test]
    fn dry_penalises_repeat_continuation() {
        // "A B A B A": generating B next would extend the repeated "A B A" pattern.
        let history = vec![0u32, 1, 0, 1, 0];
        let mut logits = vec![0.0, 0.0];
        let (mut s, mut r) = fixtures();
        Dry {
            multiplier: 1.0,
            base: 2.0,
            allowed_length: 2,
            max_ngram: 8,
            sequence_breakers: vec![],
        }
        .apply(&mut logits, &history, &mut s, &mut r);
        assert!(logits[1] < 0.0, "B should be penalised; got {}", logits[1]);
    }

    #[test]
    fn repetition_penalty_lowers_repeated_token() {
        let history = vec![0u32; 8];
        let mut logits = vec![1.0, 1.0];
        let (mut s, mut r) = fixtures();
        RepetitionPenalty {
            penalty: 2.0,
            frequency: 0.0,
            presence: 0.0,
            last_n: 64,
        }
        .apply(&mut logits, &history, &mut s, &mut r);
        assert!(logits[0] < logits[1]);
    }
}