// oxibonsai_runtime/sampling_advanced.rs
//! Advanced sampling algorithms for text generation.
//!
//! This module extends the basic [`crate::sampling`] module with state-of-the-art
//! sampling strategies used in modern LLM inference:
//!
//! - **[`MirostatV1Sampler`]** — feedback-controlled perplexity targeting (Baktash et al. 2020)
//! - **[`MirostatV2Sampler`]** — simplified, more stable mirostat variant
//! - **[`TypicalSampler`]** — locally typical sampling (Meister et al. 2023)
//! - **[`EtaSampler`]** — entropy-adaptive cutoff sampling
//! - **[`MinPSampler`]** — probabilistic nucleus based on min fraction of top token
//! - **[`SamplerChain`]** — composable sampling pipeline with named presets
//! - **[`LcgRng`]** — deterministic LCG pseudo-random number generator (no external deps)
//!
//! ## Helper functions
//!
//! Module-level helpers: [`softmax_inplace`], [`log_softmax`], [`entropy`],
//! [`perplexity`], [`top_k_indices`], [`apply_temperature`], [`apply_repetition_penalty`].

19// ─────────────────────────────────────────────────────────────────────────────
20// LCG RNG
21// ─────────────────────────────────────────────────────────────────────────────
22
/// Linear Congruential Generator — deterministic pseudo-random number generator.
///
/// Uses the multiplier and increment from Knuth's MMIX:
/// `state = state * 6364136223846793005 + 1442695040888963407`
///
/// No external crate dependencies; suitable for reproducible sampling.
#[derive(Debug, Clone)]
pub struct LcgRng {
    state: u64,
}

impl LcgRng {
    /// Knuth's MMIX multiplier.
    const MULTIPLIER: u64 = 6364136223846793005;
    /// Knuth's MMIX increment.
    const INCREMENT: u64 = 1442695040888963407;

    /// Create a new LCG seeded with `seed`. Identical seeds produce identical streams.
    pub fn new(seed: u64) -> Self {
        // Pre-mix the seed so a zero seed does not start the stream near zero.
        let mixed = seed
            .wrapping_add(Self::INCREMENT)
            .wrapping_mul(Self::MULTIPLIER);
        Self { state: mixed }
    }

    /// Advance the generator and return the next raw 64-bit value.
    pub fn next_u64(&mut self) -> u64 {
        self.state = self
            .state
            .wrapping_mul(Self::MULTIPLIER)
            .wrapping_add(Self::INCREMENT);
        self.state
    }

    /// Return a sample in `[0.0, 1.0)`.
    pub fn next_f32(&mut self) -> f32 {
        // Keep only the top 24 bits — exactly the mantissa precision of an f32.
        let hi24 = (self.next_u64() >> 40) as u32;
        hi24 as f32 / 16_777_216.0
    }

    /// Return a sample in `0..n` (exclusive). Panics if `n == 0`.
    ///
    /// NOTE: plain modulo reduction carries a tiny bias when `n` does not divide
    /// 2^64 — negligible for vocabulary-sized `n`.
    pub fn next_usize_below(&mut self, n: usize) -> usize {
        assert!(n > 0, "n must be greater than zero");
        let raw = self.next_u64();
        (raw % n as u64) as usize
    }
}
66
67// ─────────────────────────────────────────────────────────────────────────────
68// Helper functions
69// ─────────────────────────────────────────────────────────────────────────────
70
/// Apply softmax in-place, subtracting the max for numerical stability.
pub fn softmax_inplace(logits: &mut [f32]) {
    if logits.is_empty() {
        return;
    }
    // Shift by the maximum so exp() cannot overflow.
    let shift = logits.iter().fold(f32::NEG_INFINITY, |m, &v| m.max(v));
    let mut total = 0.0_f32;
    for x in logits.iter_mut() {
        *x = (*x - shift).exp();
        total += *x;
    }
    // Guard against an all-zero row (e.g. every logit was -inf).
    if total > 0.0 {
        logits.iter_mut().for_each(|x| *x /= total);
    }
}
88
/// Compute log-softmax for a slice of logits (numerically stable).
pub fn log_softmax(logits: &[f32]) -> Vec<f32> {
    if logits.is_empty() {
        return Vec::new();
    }
    // log-sum-exp with a max shift to avoid overflow in exp().
    let shift = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let sum_exp: f32 = logits.iter().map(|&v| (v - shift).exp()).sum();
    let lse = sum_exp.ln() + shift;
    logits.iter().map(|&v| v - lse).collect()
}
98
/// Compute the Shannon entropy (in nats) of a probability distribution.
///
/// Assumes `probs` sums to 1. Skips zero entries to avoid `ln(0) = -inf`.
pub fn entropy(probs: &[f32]) -> f32 {
    let mut h = 0.0_f32;
    for &p in probs {
        if p > 0.0 {
            h -= p * p.ln();
        }
    }
    h
}
109
/// Compute perplexity from a slice of log-probabilities (natural log).
///
/// `perplexity = exp(mean(-log_prob))`; an empty slice yields `1.0`.
pub fn perplexity(log_probs: &[f32]) -> f32 {
    match log_probs.len() {
        0 => 1.0,
        n => {
            let neg_sum: f32 = log_probs.iter().map(|&lp| -lp).sum();
            (neg_sum / n as f32).exp()
        }
    }
}
120
/// Return the indices of the top-k highest logit values, sorted descending.
pub fn top_k_indices(logits: &[f32], k: usize) -> Vec<usize> {
    if k == 0 || logits.is_empty() {
        return Vec::new();
    }
    let mut order: Vec<usize> = (0..logits.len()).collect();
    // Stable descending sort by value; ties keep their original index order.
    order.sort_by(|&a, &b| {
        logits[b]
            .partial_cmp(&logits[a])
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    order.truncate(k.min(logits.len()));
    order
}
132
/// Divide all logits by `temp`. If `temp <= 0`, this is a no-op (caller should handle greedy).
pub fn apply_temperature(logits: &mut [f32], temp: f32) {
    // Non-positive (or NaN) temperatures are ignored rather than producing inf/NaN logits.
    if !(temp > 0.0) {
        return;
    }
    logits.iter_mut().for_each(|v| *v /= temp);
}
141
/// Apply repetition penalty to logits for previously-seen token ids.
///
/// Tokens with positive logits are divided by `penalty`; negative logits are multiplied.
/// `penalty` should be > 1.0 to discourage repetition. A penalty of exactly 1.0 is a no-op,
/// and ids outside the vocabulary are silently skipped. Duplicate ids are penalised once
/// per occurrence.
pub fn apply_repetition_penalty(logits: &mut [f32], token_ids: &[u32], penalty: f32) {
    if penalty == 1.0 || token_ids.is_empty() {
        return;
    }
    for &token in token_ids {
        if let Some(logit) = logits.get_mut(token as usize) {
            if *logit >= 0.0 {
                *logit /= penalty;
            } else {
                *logit *= penalty;
            }
        }
    }
}
161
162// ─────────────────────────────────────────────────────────────────────────────
163// Weighted categorical draw from a probability slice
164// ─────────────────────────────────────────────────────────────────────────────
165
/// Draw an index from `probs` — `(token_index, probability)` pairs whose
/// probabilities should sum to 1 — via inverse-CDF sampling with the given RNG.
///
/// If accumulated floating-point error keeps the cumulative sum below the drawn
/// value, falls back to the FIRST entry of `probs` (this is the
/// highest-probability token only when the caller pre-sorted descending; an
/// empty slice yields index 0).
fn categorical_sample(probs: &[(usize, f32)], rng: &mut LcgRng) -> usize {
    let u = rng.next_f32();
    let mut cumsum = 0.0_f32;
    for &(idx, p) in probs {
        cumsum += p;
        if u < cumsum {
            return idx;
        }
    }
    // Fallback — first entry (not necessarily the most probable token unless
    // the caller sorted descending).
    probs.first().map(|&(i, _)| i).unwrap_or(0)
}
180
181// ─────────────────────────────────────────────────────────────────────────────
182// Mirostat v1
183// ─────────────────────────────────────────────────────────────────────────────
184
185/// Mirostat v1 sampling — maintains target perplexity via feedback control.
186///
187/// Reference: Baktash et al., "Mirostat: A Neural Text Decoding Algorithm that
188/// Directly Controls Perplexity" (2020), <https://arxiv.org/abs/2007.14966>.
189///
190/// The algorithm:
191/// 1. Truncates the vocabulary to the top-`m` tokens.
192/// 2. Estimates the cross-entropy of the chosen token.
193/// 3. Updates `mu` (current estimate of target surprise) via `eta`.
194#[derive(Debug, Clone)]
195pub struct MirostatV1Sampler {
196    /// Target surprise level (bits). Default: `5.0`.
197    pub tau: f32,
198    /// Learning rate for the feedback loop. Default: `0.1`.
199    pub eta: f32,
200    /// Number of top candidates to consider. Typically `vocab_size / 2`.
201    pub m: usize,
202    /// Running estimate of the surprise level (initialised to `2 * tau`).
203    mu: f32,
204}
205
206impl MirostatV1Sampler {
207    /// Create a new v1 sampler.
208    pub fn new(tau: f32, eta: f32, m: usize) -> Self {
209        Self {
210            tau,
211            eta,
212            m,
213            mu: 2.0 * tau,
214        }
215    }
216
217    /// Sample a token index from raw logits, updating internal state.
218    pub fn sample(&mut self, logits: &[f32], rng: &mut LcgRng) -> usize {
219        if logits.is_empty() {
220            return 0;
221        }
222
223        // Collect (index, logit) and sort descending.
224        let mut candidates: Vec<(usize, f32)> = logits.iter().cloned().enumerate().collect();
225        candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
226
227        // Truncate to top-m.
228        let m = self.m.min(candidates.len()).max(1);
229        candidates.truncate(m);
230
231        // Softmax over the truncated set.
232        let max_v = candidates[0].1;
233        let mut sum = 0.0_f32;
234        for (_, v) in candidates.iter_mut() {
235            *v = (*v - max_v).exp();
236            sum += *v;
237        }
238        if sum > 0.0 {
239            for (_, v) in candidates.iter_mut() {
240                *v /= sum;
241            }
242        }
243
244        // Filter to tokens whose estimated surprise <= mu.
245        // Surprise of token i: -log2(p_i).
246        let filtered: Vec<(usize, f32)> = candidates
247            .iter()
248            .cloned()
249            .filter(|&(_, p)| p > 0.0 && (-p.log2()) <= self.mu)
250            .collect();
251
252        let pool = if filtered.is_empty() {
253            &candidates
254        } else {
255            &filtered
256        };
257
258        // Re-normalise the pool.
259        let pool_sum: f32 = pool.iter().map(|(_, p)| p).sum();
260        let normalised: Vec<(usize, f32)> = if pool_sum > 0.0 {
261            pool.iter().map(|&(i, p)| (i, p / pool_sum)).collect()
262        } else {
263            pool.to_vec()
264        };
265
266        // Sample.
267        let chosen = categorical_sample(&normalised, rng);
268
269        // Compute observed surprise and update mu.
270        if let Some(&(_, p)) = normalised.iter().find(|&&(i, _)| i == chosen) {
271            if p > 0.0 {
272                let surprise = -p.log2();
273                self.mu -= self.eta * (surprise - self.tau);
274            }
275        }
276
277        chosen
278    }
279
280    /// Reset the internal state to the initial value.
281    pub fn reset(&mut self) {
282        self.mu = 2.0 * self.tau;
283    }
284}
285
286// ─────────────────────────────────────────────────────────────────────────────
287// Mirostat v2
288// ─────────────────────────────────────────────────────────────────────────────
289
290/// Mirostat v2 sampling — simpler and more stable than v1.
291///
292/// Rather than pre-truncating to top-m, v2 dynamically computes a probability
293/// threshold from `mu`, discards tokens below it, then samples from the rest.
294#[derive(Debug, Clone)]
295pub struct MirostatV2Sampler {
296    /// Target surprise level (bits). Default: `5.0`.
297    pub tau: f32,
298    /// Learning rate for the feedback loop. Default: `0.1`.
299    pub eta: f32,
300    /// Running surprise estimate (initialised to `2 * tau`).
301    mu: f32,
302}
303
304impl MirostatV2Sampler {
305    /// Create a new v2 sampler.
306    pub fn new(tau: f32, eta: f32) -> Self {
307        Self {
308            tau,
309            eta,
310            mu: 2.0 * tau,
311        }
312    }
313
314    /// Sample a token index from raw logits, updating internal state.
315    pub fn sample(&mut self, logits: &[f32], rng: &mut LcgRng) -> usize {
316        if logits.is_empty() {
317            return 0;
318        }
319
320        // Full softmax.
321        let mut probs: Vec<(usize, f32)> = logits.iter().cloned().enumerate().collect();
322        {
323            let max_v = probs
324                .iter()
325                .map(|(_, v)| *v)
326                .fold(f32::NEG_INFINITY, f32::max);
327            let mut sum = 0.0_f32;
328            for (_, v) in probs.iter_mut() {
329                *v = (*v - max_v).exp();
330                sum += *v;
331            }
332            if sum > 0.0 {
333                for (_, v) in probs.iter_mut() {
334                    *v /= sum;
335                }
336            }
337        }
338
339        // The threshold probability corresponding to self.mu bits of surprise:
340        // p_threshold = 2^{-mu}
341        let threshold = (-self.mu * std::f32::consts::LN_2).exp();
342
343        let mut pool: Vec<(usize, f32)> = probs
344            .iter()
345            .cloned()
346            .filter(|&(_, p)| p >= threshold)
347            .collect();
348
349        if pool.is_empty() {
350            // Fallback: keep top-1 token.
351            probs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
352            pool.push(probs[0]);
353        }
354
355        // Re-normalise pool.
356        let pool_sum: f32 = pool.iter().map(|(_, p)| p).sum();
357        if pool_sum > 0.0 {
358            for (_, p) in pool.iter_mut() {
359                *p /= pool_sum;
360            }
361        }
362
363        let chosen = categorical_sample(&pool, rng);
364
365        // Update mu from observed surprise.
366        if let Some(&(_, p)) = pool.iter().find(|&&(i, _)| i == chosen) {
367            if p > 0.0 {
368                let surprise = -p.log2();
369                self.mu -= self.eta * (surprise - self.tau);
370            }
371        }
372
373        chosen
374    }
375
376    /// Reset the internal state to the initial value.
377    pub fn reset(&mut self) {
378        self.mu = 2.0 * self.tau;
379    }
380
381    /// Current mu value (for diagnostics / tests).
382    pub fn mu(&self) -> f32 {
383        self.mu
384    }
385}
386
387// ─────────────────────────────────────────────────────────────────────────────
388// Locally Typical Sampling
389// ─────────────────────────────────────────────────────────────────────────────
390
391/// Locally Typical sampling (Meister et al., "Locally Typical Sampling", 2023).
392///
393/// Keeps the smallest set of tokens whose information content is closest to the
394/// conditional entropy of the distribution, summing to at least `p` probability mass.
395#[derive(Debug, Clone)]
396pub struct TypicalSampler {
397    /// Cumulative probability mass to retain. Default: `0.9`.
398    pub p: f32,
399    /// Minimum number of candidates to keep regardless of `p`. Default: `1`.
400    pub min_keep: usize,
401}
402
403impl TypicalSampler {
404    /// Create a new typical sampler.
405    pub fn new(p: f32, min_keep: usize) -> Self {
406        Self {
407            p: p.clamp(0.0, 1.0),
408            min_keep: min_keep.max(1),
409        }
410    }
411
412    /// Sample a token index from raw logits.
413    pub fn sample(&self, logits: &[f32], rng: &mut LcgRng) -> usize {
414        if logits.is_empty() {
415            return 0;
416        }
417
418        // Compute log-softmax → log-probs and probs.
419        let log_probs = log_softmax(logits);
420        let probs: Vec<f32> = log_probs.iter().map(|&lp| lp.exp()).collect();
421
422        // Conditional entropy H = -sum_i p_i * log(p_i).
423        let h = entropy(&probs);
424
425        // Compute |log(p_i) - H| for each token — how "typical" it is.
426        let mut candidates: Vec<(usize, f32, f32)> = log_probs
427            .iter()
428            .cloned()
429            .zip(probs.iter().cloned())
430            .enumerate()
431            .map(|(i, (lp, p))| {
432                let typicality = (-lp - h).abs();
433                (i, p, typicality)
434            })
435            .collect();
436
437        // Sort ascending by typicality (most typical first).
438        candidates.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal));
439
440        // Keep tokens until we accumulate >= p probability mass.
441        let mut cumsum = 0.0_f32;
442        let mut keep = 0;
443        for (k, &(_, p, _)) in candidates.iter().enumerate() {
444            cumsum += p;
445            keep = k + 1;
446            if cumsum >= self.p && keep >= self.min_keep {
447                break;
448            }
449        }
450        keep = keep.max(self.min_keep).min(candidates.len());
451        candidates.truncate(keep);
452
453        // Re-normalise and sample.
454        let total: f32 = candidates.iter().map(|(_, p, _)| p).sum();
455        let normalised: Vec<(usize, f32)> = candidates
456            .iter()
457            .map(|&(i, p, _)| (i, if total > 0.0 { p / total } else { p }))
458            .collect();
459
460        categorical_sample(&normalised, rng)
461    }
462}
463
464// ─────────────────────────────────────────────────────────────────────────────
465// Eta Sampling
466// ─────────────────────────────────────────────────────────────────────────────
467
468/// Eta sampling — adaptively selects a probability cutoff based on distribution entropy.
469///
470/// The cutoff is `max(epsilon, sqrt(exp(-H(p))) * delta)` where `H` is the entropy.
471/// Tokens below the cutoff are discarded.
472#[derive(Debug, Clone)]
473pub struct EtaSampler {
474    /// Minimum token probability (floor). Default: `0.0009`.
475    pub epsilon: f32,
476    /// Entropy scaling factor for adaptive threshold. Default: `0.07`.
477    pub delta: f32,
478}
479
480impl EtaSampler {
481    /// Create a new eta sampler.
482    pub fn new(epsilon: f32, delta: f32) -> Self {
483        Self { epsilon, delta }
484    }
485
486    /// Sample a token index from raw logits.
487    pub fn sample(&self, logits: &[f32], rng: &mut LcgRng) -> usize {
488        if logits.is_empty() {
489            return 0;
490        }
491
492        let mut probs: Vec<f32> = logits.to_vec();
493        softmax_inplace(&mut probs);
494
495        // Adaptive threshold.
496        let h = entropy(&probs);
497        let eta_threshold = (self.epsilon).max((-h).exp().sqrt() * self.delta);
498
499        let mut candidates: Vec<(usize, f32)> = probs
500            .iter()
501            .cloned()
502            .enumerate()
503            .filter(|&(_, p)| p >= eta_threshold)
504            .collect();
505
506        if candidates.is_empty() {
507            // Fallback: take argmax.
508            let best = probs
509                .iter()
510                .cloned()
511                .enumerate()
512                .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
513                .map(|(i, _)| i)
514                .unwrap_or(0);
515            return best;
516        }
517
518        // Re-normalise.
519        let total: f32 = candidates.iter().map(|(_, p)| p).sum();
520        if total > 0.0 {
521            for (_, p) in candidates.iter_mut() {
522                *p /= total;
523            }
524        }
525
526        categorical_sample(&candidates, rng)
527    }
528}
529
530// ─────────────────────────────────────────────────────────────────────────────
531// Min-P Sampling
532// ─────────────────────────────────────────────────────────────────────────────
533
534/// Min-P sampling — probabilistic nucleus based on a minimum fraction of the top-token probability.
535///
536/// Keeps all tokens `i` where `p_i >= min_p * max(p)`.
537#[derive(Debug, Clone)]
538pub struct MinPSampler {
539    /// Minimum fraction of the maximum probability. Default: `0.05`.
540    pub min_p: f32,
541    /// Minimum candidates to keep regardless of the threshold. Default: `1`.
542    pub min_keep: usize,
543}
544
545impl MinPSampler {
546    /// Create a new Min-P sampler.
547    pub fn new(min_p: f32, min_keep: usize) -> Self {
548        Self {
549            min_p: min_p.clamp(0.0, 1.0),
550            min_keep: min_keep.max(1),
551        }
552    }
553
554    /// Sample a token index from raw logits.
555    pub fn sample(&self, logits: &[f32], rng: &mut LcgRng) -> usize {
556        if logits.is_empty() {
557            return 0;
558        }
559
560        let mut probs: Vec<f32> = logits.to_vec();
561        softmax_inplace(&mut probs);
562
563        let max_p = probs.iter().cloned().fold(0.0_f32, f32::max);
564        let threshold = self.min_p * max_p;
565
566        let mut candidates: Vec<(usize, f32)> = probs
567            .iter()
568            .cloned()
569            .enumerate()
570            .filter(|&(_, p)| p >= threshold)
571            .collect();
572
573        // Ensure min_keep.
574        if candidates.len() < self.min_keep {
575            // Sort all probs descending and take top min_keep.
576            let mut all: Vec<(usize, f32)> = probs.iter().cloned().enumerate().collect();
577            all.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
578            candidates = all.into_iter().take(self.min_keep).collect();
579        }
580
581        // Re-normalise.
582        let total: f32 = candidates.iter().map(|(_, p)| p).sum();
583        if total > 0.0 {
584            for (_, p) in candidates.iter_mut() {
585                *p /= total;
586            }
587        }
588
589        categorical_sample(&candidates, rng)
590    }
591}
592
593// ─────────────────────────────────────────────────────────────────────────────
594// Sampler Chain
595// ─────────────────────────────────────────────────────────────────────────────
596
/// A single step in a [`SamplerChain`] pipeline.
///
/// Logit-transforming steps (`Temperature`, `RepetitionPenalty`, `TopK`, `TopP`)
/// modify the logit vector and let the chain continue; token-producing steps
/// (`MinP`, `Typical`, `Mirostat2`, `Greedy`) end the pipeline by returning a token.
#[derive(Debug, Clone)]
pub enum SamplerStep {
    /// Divide logits by temperature. Values near 0 produce near-greedy output
    /// (the chain short-circuits to argmax below `1e-6`).
    Temperature(f32),
    /// Penalise previously-seen tokens to reduce repetition.
    RepetitionPenalty {
        /// Penalty multiplier (>1.0 discourages repetition).
        penalty: f32,
        /// Number of recent tokens to consider (window); `0` means the whole history.
        last_n: usize,
        /// The recent token ids to penalise.
        tokens: Vec<u32>,
    },
    /// Keep only the top-k highest-logit candidates (others masked to `-inf`).
    TopK(usize),
    /// Nucleus (top-p) filtering (tokens outside the nucleus masked to `-inf`).
    TopP(f32),
    /// Min-P filtering (min fraction of top token probability). Terminates the chain.
    MinP(f32),
    /// Locally typical sampling with probability mass `p`. Terminates the chain.
    Typical(f32),
    /// Mirostat v2 with given tau and eta. Terminates the chain.
    Mirostat2 {
        /// Target surprise (bits).
        tau: f32,
        /// Learning rate.
        eta: f32,
    },
    /// Always pick the argmax (no randomness). Terminates the chain.
    Greedy,
}
631
/// Composable sampling pipeline.
///
/// Steps are applied sequentially to the logit vector. Logit-transforming steps
/// (`Temperature`, `RepetitionPenalty`, `TopK`, `TopP`) modify the vector in
/// place; the first token-producing step (`MinP`, `Typical`, `Mirostat2`,
/// `Greedy` — or a `Temperature` below `1e-6`, which short-circuits to greedy)
/// terminates the pipeline. If no step produces a token, the chain ends with a
/// plain softmax + weighted draw.
///
/// # Example
/// ```rust
/// use oxibonsai_runtime::sampling_advanced::{SamplerChain, SamplerStep};
///
/// let mut chain = SamplerChain::default_chat(42);
/// let mut logits = vec![1.0_f32, 5.0, 2.0, 3.0];
/// let token = chain.sample(&mut logits);
/// assert!(token < 4);
/// ```
#[derive(Debug, Clone)]
pub struct SamplerChain {
    // Steps applied in insertion order.
    steps: Vec<SamplerStep>,
    // Deterministic RNG shared by every stochastic step.
    rng: LcgRng,
    /// Persistent Mirostat v2 state (one per chain; survives across `sample` calls).
    mirostat2: Option<MirostatV2Sampler>,
}

impl SamplerChain {
    /// Create an empty chain with the given RNG seed.
    pub fn new(seed: u64) -> Self {
        Self {
            steps: Vec::new(),
            rng: LcgRng::new(seed),
            mirostat2: None,
        }
    }

    /// Append a step to the chain (builder pattern).
    #[allow(clippy::should_implement_trait)]
    pub fn add(mut self, step: SamplerStep) -> Self {
        // A Mirostat2 step carries feedback state; initialise (or re-initialise)
        // the persistent sampler when one is added.
        if let SamplerStep::Mirostat2 { tau, eta } = step {
            self.mirostat2 = Some(MirostatV2Sampler::new(tau, eta));
        }
        self.steps.push(step);
        self
    }

    /// Sample from the given logits, applying all steps in order.
    ///
    /// `logits` is consumed/mutated during processing. Returns `0` for an
    /// empty logit vector.
    pub fn sample(&mut self, logits: &mut Vec<f32>) -> usize {
        if logits.is_empty() {
            return 0;
        }

        for step in &self.steps {
            match step {
                SamplerStep::Temperature(temp) => {
                    if *temp < 1e-6 {
                        // A (near-)zero temperature degenerates to greedy decoding;
                        // pick the argmax immediately instead of dividing by ~0.
                        return argmax_slice(logits);
                    }
                    apply_temperature(logits, *temp);
                }

                SamplerStep::RepetitionPenalty {
                    penalty,
                    last_n,
                    tokens,
                } => {
                    // `last_n == 0` means "penalise the whole history".
                    let window = if *last_n == 0 {
                        tokens.as_slice()
                    } else {
                        let start = tokens.len().saturating_sub(*last_n);
                        &tokens[start..]
                    };
                    apply_repetition_penalty(logits, window, *penalty);
                }

                SamplerStep::TopK(k) => {
                    // `k == 0` and `k >= len` both mean "keep everything".
                    if *k > 0 && *k < logits.len() {
                        let indices = top_k_indices(logits, *k);
                        let mut mask = vec![f32::NEG_INFINITY; logits.len()];
                        for i in indices {
                            mask[i] = logits[i];
                        }
                        *logits = mask;
                    }
                }

                SamplerStep::TopP(p) => {
                    if *p < 1.0 {
                        apply_top_p(logits, *p, &mut self.rng);
                        // Non-nucleus tokens are now -inf; a later step or the
                        // final softmax + draw below does the actual sampling.
                    }
                }

                SamplerStep::MinP(min_p) => {
                    // Terminal step: min-p softmaxes the (possibly masked) logits
                    // itself and returns a token.
                    let sampler = MinPSampler::new(*min_p, 1);
                    return sampler.sample(logits, &mut self.rng);
                }

                SamplerStep::Typical(p) => {
                    // Terminal step: typical sampling returns a token directly.
                    let sampler = TypicalSampler::new(*p, 1);
                    return sampler.sample(logits, &mut self.rng);
                }

                SamplerStep::Mirostat2 { .. } => {
                    // Terminal step. Uses the persistent state initialised in
                    // `add` so the mu feedback loop spans calls; if the state is
                    // somehow missing, the step is skipped.
                    if let Some(ref mut ms) = self.mirostat2 {
                        return ms.sample(logits, &mut self.rng);
                    }
                }

                SamplerStep::Greedy => {
                    return argmax_slice(logits);
                }
            }
        }

        // No terminal step fired: softmax the surviving logits and draw.
        softmax_inplace(logits);
        let probs: Vec<(usize, f32)> = logits.iter().cloned().enumerate().collect();
        categorical_sample(&probs, &mut self.rng)
    }

    // ── Presets ──────────────────────────────────────────────────────────────

    /// Greedy decoding — always picks the token with the highest logit.
    pub fn greedy() -> Self {
        Self::new(0).add(SamplerStep::Greedy)
    }

    /// Default chat preset: temperature(0.7) → top_p(0.9) → min_p(0.05).
    pub fn default_chat(seed: u64) -> Self {
        Self::new(seed)
            .add(SamplerStep::Temperature(0.7))
            .add(SamplerStep::TopP(0.9))
            .add(SamplerStep::MinP(0.05))
    }

    /// Creative preset: temperature(1.0) → mirostat_v2(tau=5.0, eta=0.1).
    pub fn creative(seed: u64) -> Self {
        Self::new(seed)
            .add(SamplerStep::Temperature(1.0))
            .add(SamplerStep::Mirostat2 { tau: 5.0, eta: 0.1 })
    }

    /// Precise preset: temperature(0.3) → top_k(40) → top_p(0.9).
    pub fn precise(seed: u64) -> Self {
        Self::new(seed)
            .add(SamplerStep::Temperature(0.3))
            .add(SamplerStep::TopK(40))
            .add(SamplerStep::TopP(0.9))
    }
}
785
786// ─────────────────────────────────────────────────────────────────────────────
787// Internal helpers
788// ─────────────────────────────────────────────────────────────────────────────
789
/// Return the index of the maximum element (ties broken by lowest index).
///
/// Returns `0` for an empty slice. NaN values are never selected (unless every
/// element is NaN, in which case `0` is returned).
///
/// The original implementation used `Iterator::max_by`, which returns the LAST
/// of equal maxima — contradicting the documented "lowest index" tie-break.
/// This version scans once and keeps the first maximum, matching the contract.
fn argmax_slice(values: &[f32]) -> usize {
    let mut best_idx = 0;
    let mut best_val = f32::NEG_INFINITY;
    for (i, &v) in values.iter().enumerate() {
        // Strictly-greater keeps the first of equal maxima; NaN compares false.
        if v > best_val {
            best_idx = i;
            best_val = v;
        }
    }
    best_idx
}
800
801/// Apply top-p (nucleus) filtering to a logit vector in-place.
802///
803/// Tokens outside the nucleus are set to `NEG_INFINITY` so they are excluded
804/// by a subsequent softmax + sample step.
805fn apply_top_p(logits: &mut [f32], p: f32, _rng: &mut LcgRng) {
806    // Compute softmax probabilities.
807    let max_v = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
808    let mut probs: Vec<(usize, f32)> = logits
809        .iter()
810        .enumerate()
811        .map(|(i, &v)| (i, (v - max_v).exp()))
812        .collect();
813    let total: f32 = probs.iter().map(|(_, v)| v).sum();
814    if total > 0.0 {
815        for (_, v) in probs.iter_mut() {
816            *v /= total;
817        }
818    }
819
820    // Sort descending by probability.
821    probs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
822
823    // Find nucleus boundary.
824    let mut cumsum = 0.0_f32;
825    let mut nucleus_end = 0;
826    for (k, &(_, prob)) in probs.iter().enumerate() {
827        cumsum += prob;
828        nucleus_end = k;
829        if cumsum >= p {
830            break;
831        }
832    }
833
834    // Collect nucleus indices.
835    let nucleus_indices: std::collections::HashSet<usize> =
836        probs[..=nucleus_end].iter().map(|&(i, _)| i).collect();
837
838    // Mask out non-nucleus tokens.
839    for (i, v) in logits.iter_mut().enumerate() {
840        if !nucleus_indices.contains(&i) {
841            *v = f32::NEG_INFINITY;
842        }
843    }
844}
845
846// ─────────────────────────────────────────────────────────────────────────────
847// Unit tests (module-internal)
848// ─────────────────────────────────────────────────────────────────────────────
849
#[cfg(test)]
mod tests {
    use super::*;

    // The LCG yields f32 samples inside the documented [0, 1) range.
    #[test]
    fn lcg_rng_produces_values() {
        let mut rng = LcgRng::new(1);
        let v = rng.next_f32();
        assert!((0.0..1.0).contains(&v), "f32 out of range: {v}");
    }

    // softmax_inplace produces a proper probability distribution.
    #[test]
    fn softmax_sums_to_one() {
        let mut logits = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0];
        softmax_inplace(&mut logits);
        let sum: f32 = logits.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5, "sum={sum}");
    }

    // Mirostat v2 always returns an in-vocabulary token index.
    #[test]
    fn mirostat_v2_returns_valid_index() {
        let logits = vec![1.0_f32, 5.0, 2.0, 3.0];
        let mut sampler = MirostatV2Sampler::new(5.0, 0.1);
        let mut rng = LcgRng::new(99);
        let idx = sampler.sample(&logits, &mut rng);
        assert!(idx < logits.len());
    }

    // The greedy preset ignores randomness and returns the argmax.
    #[test]
    fn sampler_chain_greedy_preset() {
        let mut chain = SamplerChain::greedy();
        let mut logits = vec![0.1_f32, 5.0, 0.2, 0.3];
        let tok = chain.sample(&mut logits);
        assert_eq!(tok, 1); // index of the 5.0 logit
    }
}