Skip to main content

rlx_runtime/
samplers.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Advanced token samplers — the long-tail of llama.cpp samplers ported to
17//! a backend-agnostic `Sampler` chain.
18//!
19//! Why backend-agnostic: all samplers operate on a logits row (`&mut [f32]`)
20//! plus a token history slice; nothing about the kernel cares which device
21//! produced the logits. This file owns the algorithms; backend kernels just
22//! emit logits and (for paths like Metal/CUDA) may call into a thin shim
23//! that runs the chain on the CPU after a single device→host copy.
24//!
25//! Available samplers
26//! ------------------
27//!
28//! - [`Temperature`] — uniform `logits[i] /= T` scaling (T → 0 ⇒ greedy).
29//! - [`DynamicTemperature`] — entropy-aware temperature: T scales between
30//!   `min` and `max` based on `entropy(p) / log(vocab)`.
31//! - [`TopK`] — keep the K largest logits; mask the rest to `-inf`.
32//! - [`TopP`] — nucleus sampling: smallest set whose cumulative prob ≥ p.
33//! - [`TopNSigma`] — keep tokens whose logit ≥ `max - n * σ(logits)`.
34//! - [`TypicalP`] — locally-typical sampling (Meister et al. 2022): keep
35//!   tokens whose `|−log p − H(p)|` is smallest until cumulative prob ≥ p.
//! - [`MirostatV1`] — Mirostat v1: adaptive top-k driven by a μ/τ surprise control loop.
37//! - [`MirostatV2`] — simpler logit-based variant: keep tokens with surprise ≤ μ.
//! - [`Xtc`] — eXclude Top Choices: with probability `prob`, mask the most
//!   probable tokens above `threshold`, leaving only the least probable of them
//!   (or `min_keep` of them) so sampling is pushed toward less obvious choices.
40//! - [`Dry`] — DRY (Don't Repeat Yourself) repetition penalty. Penalises
41//!   tokens whose continuation matches an earlier n-gram in the history.
42//! - [`RepetitionPenalty`] — classic frequency/presence penalty (kept here
43//!   so the chain is self-contained).
44//!
45//! Composition: a `SamplerChain` runs samplers in order, each mutating the
46//! logits in place, then the final `Sampler::sample` draws one token. All
47//! samplers are deterministic given the same RNG state.
48
49use rlx_ir::Philox4x32;
50
/// Mutable view of a single logits row: one `f32` per vocabulary entry.
pub type Logits<'row> = &'row mut [f32];
53
/// State that some samplers update each step (e.g. Mirostat μ).
///
/// Both [`SamplerState::new`] and [`Default::default`] start with
/// `mirostat_mu = NaN`. The Mirostat samplers treat a non-finite μ as "not
/// yet initialised" and seed it with `2 * tau` on their first call.
#[derive(Debug, Clone)]
pub struct SamplerState {
    /// Mirostat estimator. NaN until the first apply.
    pub mirostat_mu: f32,
}

impl SamplerState {
    /// Fresh state with no Mirostat estimate yet (`mirostat_mu` is NaN).
    pub fn new() -> Self {
        Self {
            mirostat_mu: f32::NAN,
        }
    }
}

impl Default for SamplerState {
    // A derived `Default` would set `mirostat_mu` to 0.0. That value is finite,
    // so the Mirostat samplers would skip their `2τ` initialisation and start
    // from μ = 0, i.e. near-greedy truncation. Delegate to `new` so both
    // constructors agree.
    fn default() -> Self {
        Self::new()
    }
}
68
69/// A logit transform + sample step. Samplers mutate the logits row in place;
70/// `sample` is called at the end of the chain.
71pub trait Sampler: std::fmt::Debug + Send + Sync {
72    /// Apply this sampler's transformation to `logits`. `history` is the
73    /// already-emitted token sequence (newest last). Some samplers (DRY,
74    /// repetition penalty) need it; pure logit transforms ignore it.
75    fn apply(
76        &self,
77        logits: Logits<'_>,
78        history: &[u32],
79        state: &mut SamplerState,
80        rng: &mut Philox4x32,
81    );
82
83    /// Optional name for tracing.
84    fn name(&self) -> &'static str {
85        std::any::type_name::<Self>()
86    }
87}
88
89/// A linear chain of samplers terminated by a `softmax + multinomial`
90/// draw. Build one with [`SamplerChain::builder`] or directly from a Vec.
91#[derive(Debug)]
92pub struct SamplerChain {
93    pub steps: Vec<Box<dyn Sampler>>,
94}
95
96impl SamplerChain {
97    pub fn new() -> Self {
98        Self { steps: Vec::new() }
99    }
100
101    pub fn builder() -> SamplerChainBuilder {
102        SamplerChainBuilder::default()
103    }
104
105    /// Run every sampler in order over `logits`, then draw one token.
106    pub fn sample(
107        &self,
108        logits: Logits<'_>,
109        history: &[u32],
110        state: &mut SamplerState,
111        rng: &mut Philox4x32,
112    ) -> u32 {
113        for step in &self.steps {
114            step.apply(logits, history, state, rng);
115        }
116        sample_from_logits(logits, rng)
117    }
118}
119
120impl Default for SamplerChain {
121    fn default() -> Self {
122        Self::new()
123    }
124}
125
126#[derive(Debug, Default)]
127pub struct SamplerChainBuilder {
128    steps: Vec<Box<dyn Sampler>>,
129}
130
131impl SamplerChainBuilder {
132    pub fn push<S: Sampler + 'static>(mut self, s: S) -> Self {
133        self.steps.push(Box::new(s));
134        self
135    }
136
137    pub fn push_boxed(mut self, s: Box<dyn Sampler>) -> Self {
138        self.steps.push(s);
139        self
140    }
141
142    pub fn build(self) -> SamplerChain {
143        SamplerChain { steps: self.steps }
144    }
145}
146
147// ─── shared utilities ─────────────────────────────────────────────────
148
149/// Stable softmax in place. Tokens at `-inf` get probability 0.
150pub fn softmax_inplace(logits: &mut [f32]) {
151    let mut maxv = f32::NEG_INFINITY;
152    for &x in logits.iter() {
153        if x > maxv {
154            maxv = x;
155        }
156    }
157    if !maxv.is_finite() {
158        let inv = 1.0 / logits.len() as f32;
159        for x in logits.iter_mut() {
160            *x = inv;
161        }
162        return;
163    }
164    let mut s = 0.0f32;
165    for x in logits.iter_mut() {
166        let v = (*x - maxv).exp();
167        *x = v;
168        s += v;
169    }
170    let inv = if s > 0.0 { 1.0 / s } else { 0.0 };
171    for x in logits.iter_mut() {
172        *x *= inv;
173    }
174}
175
176/// Multinomial draw from a probability row (assumed already softmax-ed).
177/// Robust against floating-point drift in the tail.
178pub fn sample_from_probs(probs: &[f32], rng: &mut Philox4x32) -> u32 {
179    let r = rng.next_f32();
180    let mut acc = 0.0f32;
181    for (i, &p) in probs.iter().enumerate() {
182        acc += p;
183        if r <= acc {
184            return i as u32;
185        }
186    }
187    (probs.len() - 1) as u32
188}
189
190/// Apply softmax to logits in place then multinomial-sample one index.
191pub fn sample_from_logits(logits: &mut [f32], rng: &mut Philox4x32) -> u32 {
192    softmax_inplace(logits);
193    sample_from_probs(logits, rng)
194}
195
196/// Sorted-descending index/value pairs. Used by every top-something sampler.
197fn sorted_desc(logits: &[f32]) -> Vec<(usize, f32)> {
198    let mut v: Vec<(usize, f32)> = logits.iter().copied().enumerate().collect();
199    v.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
200    v
201}
202
203// ─── Temperature ─────────────────────────────────────────────────────
204
205#[derive(Debug, Clone, Copy)]
206pub struct Temperature {
207    pub t: f32,
208}
209
210impl Sampler for Temperature {
211    fn apply(&self, logits: Logits<'_>, _h: &[u32], _s: &mut SamplerState, _r: &mut Philox4x32) {
212        let t = self.t.max(1e-6);
213        for x in logits.iter_mut() {
214            *x /= t;
215        }
216    }
217}
218
219// ─── DynamicTemperature ──────────────────────────────────────────────
220//
221// Scale temperature between [min, max] based on the entropy of the
222// current logits' softmax: T = min + (max - min) * (H / Hmax)^exponent.
223// `exponent=1.0` matches the original llama.cpp `dynatemp` behavior.
224
225#[derive(Debug, Clone, Copy)]
226pub struct DynamicTemperature {
227    pub min: f32,
228    pub max: f32,
229    pub exponent: f32,
230}
231
232impl Sampler for DynamicTemperature {
233    fn apply(&self, logits: Logits<'_>, _h: &[u32], _s: &mut SamplerState, _r: &mut Philox4x32) {
234        let v = logits.len();
235        if v == 0 {
236            return;
237        }
238        let mut tmp: Vec<f32> = logits.to_vec();
239        softmax_inplace(&mut tmp);
240        // Shannon entropy in nats.
241        let mut h = 0.0f32;
242        for &p in tmp.iter() {
243            if p > 0.0 {
244                h -= p * p.ln();
245            }
246        }
247        let hmax = (v as f32).ln().max(1e-6);
248        let norm = (h / hmax).clamp(0.0, 1.0);
249        let t = self.min + (self.max - self.min) * norm.powf(self.exponent);
250        let t = t.max(1e-6);
251        for x in logits.iter_mut() {
252            *x /= t;
253        }
254    }
255}
256
257// ─── TopK ────────────────────────────────────────────────────────────
258
259#[derive(Debug, Clone, Copy)]
260pub struct TopK {
261    pub k: usize,
262}
263
264impl Sampler for TopK {
265    fn apply(&self, logits: Logits<'_>, _h: &[u32], _s: &mut SamplerState, _r: &mut Philox4x32) {
266        let v = logits.len();
267        if self.k == 0 || self.k >= v {
268            return;
269        }
270        let sorted = sorted_desc(logits);
271        let cutoff = sorted[self.k - 1].1;
272        for x in logits.iter_mut() {
273            if *x < cutoff {
274                *x = f32::NEG_INFINITY;
275            }
276        }
277    }
278}
279
280// ─── TopP (nucleus) ──────────────────────────────────────────────────
281
282#[derive(Debug, Clone, Copy)]
283pub struct TopP {
284    pub p: f32,
285    /// Always keep at least this many tokens. Avoids degenerate single-token
286    /// nucleus when one logit dominates.
287    pub min_keep: usize,
288}
289
290impl Sampler for TopP {
291    fn apply(&self, logits: Logits<'_>, _h: &[u32], _s: &mut SamplerState, _r: &mut Philox4x32) {
292        if self.p >= 1.0 {
293            return;
294        }
295        let v = logits.len();
296        if v == 0 {
297            return;
298        }
299        let mut probs: Vec<f32> = logits.to_vec();
300        softmax_inplace(&mut probs);
301        let sorted = sorted_desc(&probs);
302        let mut keep = vec![false; v];
303        let mut cum = 0.0f32;
304        for (rank, (idx, p)) in sorted.iter().enumerate() {
305            keep[*idx] = true;
306            cum += *p;
307            if cum >= self.p && rank + 1 >= self.min_keep {
308                break;
309            }
310        }
311        for (i, x) in logits.iter_mut().enumerate() {
312            if !keep[i] {
313                *x = f32::NEG_INFINITY;
314            }
315        }
316    }
317}
318
319// ─── TopNSigma ───────────────────────────────────────────────────────
320//
321// Keep tokens whose logit ≥ max_logit − n × σ(logits). Works directly on
322// the logit space, so no softmax is required to mask. (Ref: Hewitt et al.
323// "Top-n-σ" 2024.)
324
325#[derive(Debug, Clone, Copy)]
326pub struct TopNSigma {
327    pub n: f32,
328}
329
330impl Sampler for TopNSigma {
331    fn apply(&self, logits: Logits<'_>, _h: &[u32], _s: &mut SamplerState, _r: &mut Philox4x32) {
332        let v = logits.len();
333        if v == 0 || !self.n.is_finite() || self.n <= 0.0 {
334            return;
335        }
336        let mut maxv = f32::NEG_INFINITY;
337        let mut count = 0usize;
338        let mut sum = 0.0f32;
339        for &x in logits.iter() {
340            if x.is_finite() {
341                if x > maxv {
342                    maxv = x;
343                }
344                sum += x;
345                count += 1;
346            }
347        }
348        if count == 0 || !maxv.is_finite() {
349            return;
350        }
351        let mean = sum / count as f32;
352        let mut var = 0.0f32;
353        for &x in logits.iter() {
354            if x.is_finite() {
355                let d = x - mean;
356                var += d * d;
357            }
358        }
359        let sigma = (var / count as f32).sqrt();
360        let cutoff = maxv - self.n * sigma;
361        for x in logits.iter_mut() {
362            if *x < cutoff {
363                *x = f32::NEG_INFINITY;
364            }
365        }
366    }
367}
368
369// ─── TypicalP (locally typical sampling) ────────────────────────────
370//
371// Meister et al. 2022: rank tokens by deviation of their surprisal
372// (−log p) from the distribution's entropy H, ascending. Keep the
373// smallest set whose cumulative prob ≥ p.
374
375#[derive(Debug, Clone, Copy)]
376pub struct TypicalP {
377    pub p: f32,
378    pub min_keep: usize,
379}
380
381impl Sampler for TypicalP {
382    fn apply(&self, logits: Logits<'_>, _h: &[u32], _s: &mut SamplerState, _r: &mut Philox4x32) {
383        if self.p >= 1.0 {
384            return;
385        }
386        let v = logits.len();
387        if v == 0 {
388            return;
389        }
390        let mut probs: Vec<f32> = logits.to_vec();
391        softmax_inplace(&mut probs);
392        let mut h = 0.0f32;
393        for &p in probs.iter() {
394            if p > 0.0 {
395                h -= p * p.ln();
396            }
397        }
398        // Score each token by |−log p − H|.
399        let mut scored: Vec<(usize, f32, f32)> = probs
400            .iter()
401            .enumerate()
402            .map(|(i, &p)| {
403                let neg_log = if p > 0.0 { -p.ln() } else { f32::INFINITY };
404                let dev = (neg_log - h).abs();
405                (i, p, dev)
406            })
407            .collect();
408        scored.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal));
409        let mut keep = vec![false; v];
410        let mut cum = 0.0f32;
411        for (rank, (idx, p, _)) in scored.iter().enumerate() {
412            keep[*idx] = true;
413            cum += *p;
414            if cum >= self.p && rank + 1 >= self.min_keep {
415                break;
416            }
417        }
418        for (i, x) in logits.iter_mut().enumerate() {
419            if !keep[i] {
420                *x = f32::NEG_INFINITY;
421            }
422        }
423    }
424}
425
426// ─── Mirostat v1 ─────────────────────────────────────────────────────
427//
428// Original (NeurIPS 2021) paper: maintain a running μ that is the target
429// surprise. Top-k is set so the average surprise of the truncated head
430// matches μ, then μ is updated by the error against τ.
431
432#[derive(Debug, Clone, Copy)]
433pub struct MirostatV1 {
434    pub tau: f32,
435    pub eta: f32,
436    /// Pareto‐tail estimator window. Original paper uses 100.
437    pub m: usize,
438}
439
440impl Default for MirostatV1 {
441    fn default() -> Self {
442        Self {
443            tau: 5.0,
444            eta: 0.1,
445            m: 100,
446        }
447    }
448}
449
450impl Sampler for MirostatV1 {
451    fn apply(
452        &self,
453        logits: Logits<'_>,
454        _h: &[u32],
455        state: &mut SamplerState,
456        rng: &mut Philox4x32,
457    ) {
458        let v = logits.len();
459        if v == 0 {
460            return;
461        }
462        if !state.mirostat_mu.is_finite() {
463            state.mirostat_mu = 2.0 * self.tau;
464        }
465        let mu = state.mirostat_mu.max(1e-6);
466        // Softmax to get probabilities and sort descending.
467        let mut probs = logits.to_vec();
468        softmax_inplace(&mut probs);
469        let sorted = sorted_desc(&probs);
470        // Pareto-tail s-hat from top-m. Surprise[i] = −log p[i].
471        let m = self.m.min(sorted.len()).max(2);
472        let mut num = 0.0f32;
473        let mut den = 0.0f32;
474        for i in 0..(m - 1) {
475            let t = ((i + 2) as f32 / (i + 1) as f32).ln();
476            let b = (sorted[i].1 / sorted[i + 1].1).ln().max(1e-9);
477            num += t * b;
478            den += t * t;
479        }
480        let s_hat = if den > 0.0 { num / den } else { 1.0 };
481        // k = ((eps * 2^μ) / (1 − N^(−eps))) ^ (1 / s_hat)
482        let eps = (s_hat - 1.0).abs().max(1e-3);
483        let k_real = ((eps * (2.0f32.powf(mu))) / (1.0 - (v as f32).powf(-eps)))
484            .powf(1.0 / s_hat)
485            .clamp(1.0, v as f32);
486        let k = k_real as usize;
487        if k < sorted.len() {
488            let cutoff = sorted[k - 1].1;
489            for (i, p) in probs.iter_mut().enumerate() {
490                if *p < cutoff {
491                    *p = 0.0;
492                }
493                let _ = i;
494            }
495            let s: f32 = probs.iter().sum();
496            if s > 0.0 {
497                for p in probs.iter_mut() {
498                    *p /= s;
499                }
500            }
501        }
502        // Draw the token now; convert back into logits by setting one-hot
503        // (downstream samplers softmax+sample again, which is OK since one-hot).
504        let tok = sample_from_probs(&probs, rng) as usize;
505        let surprise = if probs[tok] > 0.0 {
506            -probs[tok].ln() / 2.0f32.ln()
507        } else {
508            mu
509        };
510        state.mirostat_mu = (mu - self.eta * (surprise - self.tau)).max(0.0);
511        for (i, x) in logits.iter_mut().enumerate() {
512            *x = if i == tok {
513                f32::INFINITY
514            } else {
515                f32::NEG_INFINITY
516            };
517        }
518    }
519}
520
521// ─── Mirostat v2 ─────────────────────────────────────────────────────
522//
523// Simpler variant (commonly shipped by llama.cpp): keep tokens whose
524// surprise (in bits) is ≤ μ; sample, then update μ ← μ − η(s − τ).
525
526#[derive(Debug, Clone, Copy)]
527pub struct MirostatV2 {
528    pub tau: f32,
529    pub eta: f32,
530}
531
532impl Default for MirostatV2 {
533    fn default() -> Self {
534        Self { tau: 5.0, eta: 0.1 }
535    }
536}
537
538impl Sampler for MirostatV2 {
539    fn apply(
540        &self,
541        logits: Logits<'_>,
542        _h: &[u32],
543        state: &mut SamplerState,
544        rng: &mut Philox4x32,
545    ) {
546        let v = logits.len();
547        if v == 0 {
548            return;
549        }
550        if !state.mirostat_mu.is_finite() {
551            state.mirostat_mu = 2.0 * self.tau;
552        }
553        let mu = state.mirostat_mu;
554        let mut probs = logits.to_vec();
555        softmax_inplace(&mut probs);
556        // Sort descending; keep tokens with surprise ≤ μ.
557        let mut sorted = sorted_desc(&probs);
558        let ln2 = 2.0f32.ln();
559        let mut keep_n = 0usize;
560        for (i, (_, p)) in sorted.iter().enumerate() {
561            let s = if *p > 0.0 {
562                -p.ln() / ln2
563            } else {
564                f32::INFINITY
565            };
566            if s > mu {
567                break;
568            }
569            keep_n = i + 1;
570        }
571        if keep_n == 0 {
572            keep_n = 1;
573        }
574        let kept: std::collections::HashSet<usize> =
575            sorted.drain(..keep_n).map(|(i, _)| i).collect();
576        for (i, p) in probs.iter_mut().enumerate() {
577            if !kept.contains(&i) {
578                *p = 0.0;
579            }
580        }
581        let s: f32 = probs.iter().sum();
582        if s > 0.0 {
583            for p in probs.iter_mut() {
584                *p /= s;
585            }
586        }
587        let tok = sample_from_probs(&probs, rng) as usize;
588        let surprise = if probs[tok] > 0.0 {
589            -probs[tok].ln() / ln2
590        } else {
591            mu
592        };
593        state.mirostat_mu = (mu - self.eta * (surprise - self.tau)).max(0.0);
594        for (i, x) in logits.iter_mut().enumerate() {
595            *x = if i == tok {
596                f32::INFINITY
597            } else {
598                f32::NEG_INFINITY
599            };
600        }
601    }
602}
603
604// ─── XTC (eXclude Top Choices) ───────────────────────────────────────
605//
606// Drops top tokens whose prob > `threshold`, with probability `prob`, as
607// long as at least one such "high-confidence" token would remain. Forces
608// the model into long-tail tokens to break repetition / mode collapse.
609
610#[derive(Debug, Clone, Copy)]
611pub struct Xtc {
612    pub threshold: f32,
613    pub prob: f32,
614    /// Keep this many top tokens minimum even after exclusion.
615    pub min_keep: usize,
616}
617
618impl Sampler for Xtc {
619    fn apply(&self, logits: Logits<'_>, _h: &[u32], _s: &mut SamplerState, rng: &mut Philox4x32) {
620        if self.prob <= 0.0 {
621            return;
622        }
623        if rng.next_f32() > self.prob {
624            return;
625        }
626        let v = logits.len();
627        if v == 0 {
628            return;
629        }
630        let mut probs = logits.to_vec();
631        softmax_inplace(&mut probs);
632        let sorted = sorted_desc(&probs);
633        // Count tokens above threshold.
634        let n_above = sorted.iter().filter(|(_, p)| *p > self.threshold).count();
635        if n_above < 2 {
636            return; // not enough to exclude
637        }
638        // Exclude all but the lowest-ranked one above threshold so
639        // exactly one survives.
640        let to_kill = n_above.saturating_sub(self.min_keep.max(1));
641        for (idx, _) in sorted.iter().take(to_kill) {
642            logits[*idx] = f32::NEG_INFINITY;
643        }
644    }
645}
646
647// ─── DRY (Don't Repeat Yourself) ─────────────────────────────────────
648//
649// For each suffix of `history` of length n ≥ `allowed_length`, if the
650// next token would extend the suffix into a known n-gram match, scale
651// its logit down by `multiplier * base^(n − allowed_length)`. Caps at
652// `max_ngram` to bound work.
653
654#[derive(Debug, Clone)]
655pub struct Dry {
656    pub multiplier: f32,
657    pub base: f32,
658    pub allowed_length: usize,
659    pub max_ngram: usize,
660    /// Tokens that break a repetition run (e.g. newlines for prose, EOS).
661    pub sequence_breakers: Vec<u32>,
662}
663
664impl Default for Dry {
665    fn default() -> Self {
666        Self {
667            multiplier: 0.8,
668            base: 1.75,
669            allowed_length: 2,
670            max_ngram: 32,
671            sequence_breakers: Vec::new(),
672        }
673    }
674}
675
676impl Sampler for Dry {
677    fn apply(
678        &self,
679        logits: Logits<'_>,
680        history: &[u32],
681        _s: &mut SamplerState,
682        _r: &mut Philox4x32,
683    ) {
684        if self.multiplier <= 0.0 || history.is_empty() {
685            return;
686        }
687        let n = history.len();
688        let max_ngram = self.max_ngram.min(n);
689        let breakers: std::collections::HashSet<u32> =
690            self.sequence_breakers.iter().copied().collect();
691        // For each candidate next token t (only tokens that actually appear
692        // in history can match — bail early on the rest), find the longest
693        // suffix of history that matches a prefix ending at some earlier i.
694        // Track the longest match per token.
695        let mut longest: std::collections::HashMap<u32, usize> = std::collections::HashMap::new();
696        for i in 0..n.saturating_sub(1) {
697            if breakers.contains(&history[i]) {
698                continue;
699            }
700            // Length of common tail starting at history[..=i] (matching i with
701            // n-1, i-1 with n-2, …).
702            let mut l = 0usize;
703            while l < max_ngram && i >= l && n > l && history[i - l] == history[n - 1 - l] {
704                l += 1;
705            }
706            if l >= self.allowed_length && i + 1 < n {
707                let next = history[i + 1];
708                let cur = longest.entry(next).or_insert(0);
709                if l > *cur {
710                    *cur = l;
711                }
712            }
713        }
714        for (tok, l) in longest {
715            let pen = self.multiplier * self.base.powi((l - self.allowed_length) as i32);
716            let idx = tok as usize;
717            if idx < logits.len() {
718                logits[idx] -= pen;
719            }
720        }
721    }
722}
723
724// ─── Repetition penalty ──────────────────────────────────────────────
725
726#[derive(Debug, Clone, Copy)]
727pub struct RepetitionPenalty {
728    pub penalty: f32,
729    pub frequency: f32,
730    pub presence: f32,
731    /// Last N tokens of history to consider.
732    pub last_n: usize,
733}
734
735impl Default for RepetitionPenalty {
736    fn default() -> Self {
737        Self {
738            penalty: 1.0,
739            frequency: 0.0,
740            presence: 0.0,
741            last_n: 64,
742        }
743    }
744}
745
746impl Sampler for RepetitionPenalty {
747    fn apply(
748        &self,
749        logits: Logits<'_>,
750        history: &[u32],
751        _s: &mut SamplerState,
752        _r: &mut Philox4x32,
753    ) {
754        if history.is_empty() {
755            return;
756        }
757        let start = history.len().saturating_sub(self.last_n);
758        let window = &history[start..];
759        let mut counts: std::collections::HashMap<u32, u32> = std::collections::HashMap::new();
760        for &t in window {
761            *counts.entry(t).or_insert(0) += 1;
762        }
763        for (tok, c) in counts {
764            let idx = tok as usize;
765            if idx >= logits.len() {
766                continue;
767            }
768            // OpenAI-style: subtract presence + frequency * count.
769            logits[idx] -= self.presence + self.frequency * c as f32;
770            // Anthropic/llama.cpp-style: divide (for positive logits) or
771            // multiply (negative) by `penalty`.
772            if (self.penalty - 1.0).abs() > 1e-6 {
773                if logits[idx] > 0.0 {
774                    logits[idx] /= self.penalty;
775                } else {
776                    logits[idx] *= self.penalty;
777                }
778            }
779        }
780    }
781}
782
783// ─── tests ───────────────────────────────────────────────────────────
784
785#[cfg(test)]
786mod tests {
787    use super::*;
788
789    fn rng() -> Philox4x32 {
790        Philox4x32::new(0xDEAD_BEEF)
791    }
792
793    #[test]
794    fn temperature_zero_is_greedy_after_chain() {
795        let chain = SamplerChain::builder()
796            .push(Temperature { t: 1e-6 })
797            .build();
798        let mut state = SamplerState::new();
799        let mut r = rng();
800        let mut logits = vec![1.0, 5.0, 2.0, 3.0];
801        let tok = chain.sample(&mut logits, &[], &mut state, &mut r);
802        assert_eq!(tok, 1);
803    }
804
805    #[test]
806    fn top_k_masks_below_kth() {
807        let mut logits = vec![1.0, 5.0, 2.0, 3.0];
808        let mut s = SamplerState::new();
809        let mut r = rng();
810        TopK { k: 2 }.apply(&mut logits, &[], &mut s, &mut r);
811        assert_eq!(logits[1], 5.0);
812        assert_eq!(logits[3], 3.0);
813        assert!(logits[0].is_infinite() && logits[0] < 0.0);
814        assert!(logits[2].is_infinite() && logits[2] < 0.0);
815    }
816
817    #[test]
818    fn top_p_keeps_nucleus() {
819        let mut logits = vec![0.0f32; 4];
820        logits[0] = 10.0;
821        logits[1] = 5.0;
822        let mut s = SamplerState::new();
823        let mut r = rng();
824        TopP {
825            p: 0.5,
826            min_keep: 1,
827        }
828        .apply(&mut logits, &[], &mut s, &mut r);
829        assert!(logits[0].is_finite());
830        // Lowest-probability tokens should be masked.
831        assert!(logits[2].is_infinite() && logits[2] < 0.0);
832        assert!(logits[3].is_infinite() && logits[3] < 0.0);
833    }
834
835    #[test]
836    fn top_n_sigma_keeps_top_logits() {
837        // One peak, rest noise.
838        let mut logits = vec![0.0f32; 32];
839        logits[0] = 10.0;
840        logits[1] = 9.5;
841        let mut s = SamplerState::new();
842        let mut r = rng();
843        TopNSigma { n: 1.0 }.apply(&mut logits, &[], &mut s, &mut r);
844        // Peak survives, the long flat tail is killed.
845        assert!(logits[0].is_finite());
846        assert!(logits[5].is_infinite() && logits[5] < 0.0);
847    }
848
849    #[test]
850    fn dynamic_temperature_scales_with_entropy() {
851        // Uniform input → maximum entropy → T = max.
852        let mut logits = vec![1.0f32; 16];
853        let before = logits.clone();
854        let mut s = SamplerState::new();
855        let mut r = rng();
856        DynamicTemperature {
857            min: 0.5,
858            max: 2.0,
859            exponent: 1.0,
860        }
861        .apply(&mut logits, &[], &mut s, &mut r);
862        assert!((logits[0] - before[0] / 2.0).abs() < 1e-5);
863    }
864
865    #[test]
866    fn typical_p_keeps_typical_token() {
867        let mut logits = vec![5.0, 4.0, 0.0, -10.0];
868        let mut s = SamplerState::new();
869        let mut r = rng();
870        TypicalP {
871            p: 0.5,
872            min_keep: 1,
873        }
874        .apply(&mut logits, &[], &mut s, &mut r);
875        // At least one token survives.
876        assert!(logits.iter().any(|x| x.is_finite()));
877    }
878
879    #[test]
880    fn mirostat_v2_keeps_at_least_one() {
881        let mut logits = vec![1.0, 2.0, 3.0, 4.0];
882        let mut s = SamplerState::new();
883        let mut r = rng();
884        MirostatV2 { tau: 5.0, eta: 0.1 }.apply(&mut logits, &[], &mut s, &mut r);
885        // Result is a one-hot encoding of the drawn token.
886        let n_inf = logits
887            .iter()
888            .filter(|x| x.is_infinite() && **x > 0.0)
889            .count();
890        assert_eq!(n_inf, 1);
891    }
892
893    #[test]
894    fn xtc_disabled_when_prob_zero() {
895        let mut logits = vec![10.0, 5.0, 1.0];
896        let before = logits.clone();
897        let mut s = SamplerState::new();
898        let mut r = rng();
899        Xtc {
900            threshold: 0.5,
901            prob: 0.0,
902            min_keep: 1,
903        }
904        .apply(&mut logits, &[], &mut s, &mut r);
905        assert_eq!(logits, before);
906    }
907
908    #[test]
909    fn dry_penalises_repeat_continuation() {
910        // History "A B A B A" → continuing the "B" pattern after "A" should be penalised.
911        let history = vec![0u32, 1, 0, 1, 0];
912        let mut logits = vec![0.0, 0.0];
913        let mut s = SamplerState::new();
914        let mut r = rng();
915        Dry {
916            multiplier: 1.0,
917            base: 2.0,
918            allowed_length: 2,
919            max_ngram: 8,
920            sequence_breakers: vec![],
921        }
922        .apply(&mut logits, &history, &mut s, &mut r);
923        assert!(logits[1] < 0.0, "B should be penalised; got {}", logits[1]);
924    }
925
926    #[test]
927    fn repetition_penalty_lowers_repeated_token() {
928        let history = vec![0u32; 8];
929        let mut logits = vec![1.0, 1.0];
930        let mut s = SamplerState::new();
931        let mut r = rng();
932        RepetitionPenalty {
933            penalty: 2.0,
934            frequency: 0.0,
935            presence: 0.0,
936            last_n: 64,
937        }
938        .apply(&mut logits, &history, &mut s, &mut r);
939        assert!(logits[0] < logits[1]);
940    }
941}