// oxibonsai_runtime/sampling_advanced.rs
//! Advanced sampling algorithms for text generation.
//!
//! This module extends the basic [`crate::sampling`] module with state-of-the-art
//! sampling strategies used in modern LLM inference:
//!
//! - **[`MirostatV1Sampler`]** — feedback-controlled perplexity targeting (Basu et al. 2020)
//! - **[`MirostatV2Sampler`]** — simplified, more stable mirostat variant
//! - **[`TypicalSampler`]** — locally typical sampling (Meister et al. 2023)
//! - **[`EtaSampler`]** — entropy-adaptive cutoff sampling
//! - **[`MinPSampler`]** — probabilistic nucleus based on min fraction of top token
//! - **[`SamplerChain`]** — composable sampling pipeline with named presets
//! - **[`LcgRng`]** — deterministic LCG pseudo-random number generator (no external deps)
//!
//! ## Helper functions
//!
//! Module-level helpers: [`softmax_inplace`], [`log_softmax`], [`entropy`],
//! [`perplexity`], [`top_k_indices`], [`apply_temperature`], [`apply_repetition_penalty`].

// ─────────────────────────────────────────────────────────────────────────────
// LCG RNG
// ─────────────────────────────────────────────────────────────────────────────

/// Linear Congruential Generator — deterministic pseudo-random number generator.
///
/// Uses the multiplier and increment from Knuth's MMIX:
/// `state = state * 6364136223846793005 + 1442695040888963407`
///
/// No external crate dependencies; suitable for reproducible sampling.
#[derive(Debug, Clone)]
pub struct LcgRng {
    state: u64,
}

impl LcgRng {
    /// Knuth's MMIX multiplier.
    const MULTIPLIER: u64 = 6364136223846793005;
    /// Knuth's MMIX increment.
    const INCREMENT: u64 = 1442695040888963407;

    /// Create a new LCG seeded with `seed`. Identical seeds produce identical streams.
    pub fn new(seed: u64) -> Self {
        // Scramble the seed so that seed=0 doesn't start the stream near zero.
        Self {
            state: seed
                .wrapping_add(Self::INCREMENT)
                .wrapping_mul(Self::MULTIPLIER),
        }
    }

    /// Advance the generator and return the next raw 64-bit value.
    pub fn next_u64(&mut self) -> u64 {
        let advanced = self
            .state
            .wrapping_mul(Self::MULTIPLIER)
            .wrapping_add(Self::INCREMENT);
        self.state = advanced;
        advanced
    }

    /// Return a sample in `[0.0, 1.0)`.
    pub fn next_f32(&mut self) -> f32 {
        // Keep only the top 24 bits — the width of an f32 mantissa.
        let top = (self.next_u64() >> 40) as u32;
        top as f32 / (1u32 << 24) as f32
    }

    /// Return a sample in `0..n` (exclusive). Panics if `n == 0`.
    pub fn next_usize_below(&mut self, n: usize) -> usize {
        assert!(n > 0, "n must be greater than zero");
        (self.next_u64() % n as u64) as usize
    }
}

// ─────────────────────────────────────────────────────────────────────────────
// Helper functions
// ─────────────────────────────────────────────────────────────────────────────

/// Apply softmax in-place, subtracting the max for numerical stability.
pub fn softmax_inplace(logits: &mut [f32]) {
    if logits.is_empty() {
        return;
    }
    // Shift by the peak value so exp() never overflows.
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    for v in logits.iter_mut() {
        *v = (*v - peak).exp();
    }
    let total: f32 = logits.iter().sum();
    if total > 0.0 {
        logits.iter_mut().for_each(|v| *v /= total);
    }
}

/// Compute log-softmax for a slice of logits (numerically stable).
pub fn log_softmax(logits: &[f32]) -> Vec<f32> {
    if logits.is_empty() {
        return Vec::new();
    }
    // log-sum-exp with the usual max-shift trick for stability.
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let sum_exp: f32 = logits.iter().map(|&v| (v - peak).exp()).sum();
    let lse = sum_exp.ln() + peak;
    logits.iter().map(|&v| v - lse).collect()
}

/// Compute the Shannon entropy (in nats) of a probability distribution.
///
/// Assumes `probs` sums to 1. Skips zero entries to avoid `ln(0) = -inf`.
pub fn entropy(probs: &[f32]) -> f32 {
    let mut h = 0.0_f32;
    for &p in probs {
        if p > 0.0 {
            h -= p * p.ln();
        }
    }
    h
}

/// Compute perplexity from a slice of log-probabilities (natural log).
///
/// `perplexity = exp(mean(-log_prob))`. Empty input yields `1.0`.
pub fn perplexity(log_probs: &[f32]) -> f32 {
    if log_probs.is_empty() {
        return 1.0;
    }
    let neg_total: f32 = log_probs.iter().map(|&lp| -lp).sum();
    (neg_total / log_probs.len() as f32).exp()
}

/// Return the indices of the top-k highest logit values, sorted descending.
///
/// Ties keep their original (ascending-index) order thanks to the stable sort.
pub fn top_k_indices(logits: &[f32], k: usize) -> Vec<usize> {
    if k == 0 || logits.is_empty() {
        return Vec::new();
    }
    let mut order: Vec<usize> = (0..logits.len()).collect();
    // Stable descending sort by logit value; NaN comparisons collapse to Equal.
    order.sort_by(|&i, &j| {
        logits[j]
            .partial_cmp(&logits[i])
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    order.truncate(k.min(logits.len()));
    order
}

/// Divide all logits by `temp`. If `temp <= 0` (or NaN), this is a no-op
/// (caller should handle greedy decoding explicitly).
pub fn apply_temperature(logits: &mut [f32], temp: f32) {
    // Positive check doubles as a NaN guard: NaN fails `> 0.0`.
    if temp > 0.0 {
        logits.iter_mut().for_each(|v| *v /= temp);
    }
}

/// Apply repetition penalty to logits for previously-seen token ids.
///
/// Tokens with positive logits are divided by `penalty`; negative logits are
/// multiplied — both shrink the token's odds when `penalty > 1.0`.
/// Ids beyond the vocabulary and a `penalty` of exactly 1.0 are ignored.
pub fn apply_repetition_penalty(logits: &mut [f32], token_ids: &[u32], penalty: f32) {
    if penalty == 1.0 || token_ids.is_empty() {
        return;
    }
    for &id in token_ids {
        // Out-of-range ids fall into the None arm and are skipped.
        match logits.get_mut(id as usize) {
            Some(l) if *l >= 0.0 => *l /= penalty,
            Some(l) => *l *= penalty,
            None => {}
        }
    }
}

162// ─────────────────────────────────────────────────────────────────────────────
163// Weighted categorical draw from a probability slice
164// ─────────────────────────────────────────────────────────────────────────────
165
166/// Draw an index from `probs` (must sum to 1) using the given RNG.
167/// Falls back to index 0 if no threshold is crossed (floating-point edge case).
168fn categorical_sample(probs: &[(usize, f32)], rng: &mut LcgRng) -> usize {
169 let u = rng.next_f32();
170 let mut cumsum = 0.0_f32;
171 for &(idx, p) in probs {
172 cumsum += p;
173 if u < cumsum {
174 return idx;
175 }
176 }
177 // Fallback — return highest-probability token.
178 probs.first().map(|&(i, _)| i).unwrap_or(0)
179}
180
181// ─────────────────────────────────────────────────────────────────────────────
182// Mirostat v1
183// ─────────────────────────────────────────────────────────────────────────────
184
185/// Mirostat v1 sampling — maintains target perplexity via feedback control.
186///
187/// Reference: Baktash et al., "Mirostat: A Neural Text Decoding Algorithm that
188/// Directly Controls Perplexity" (2020), <https://arxiv.org/abs/2007.14966>.
189///
190/// The algorithm:
191/// 1. Truncates the vocabulary to the top-`m` tokens.
192/// 2. Estimates the cross-entropy of the chosen token.
193/// 3. Updates `mu` (current estimate of target surprise) via `eta`.
194#[derive(Debug, Clone)]
195pub struct MirostatV1Sampler {
196 /// Target surprise level (bits). Default: `5.0`.
197 pub tau: f32,
198 /// Learning rate for the feedback loop. Default: `0.1`.
199 pub eta: f32,
200 /// Number of top candidates to consider. Typically `vocab_size / 2`.
201 pub m: usize,
202 /// Running estimate of the surprise level (initialised to `2 * tau`).
203 mu: f32,
204}
205
206impl MirostatV1Sampler {
207 /// Create a new v1 sampler.
208 pub fn new(tau: f32, eta: f32, m: usize) -> Self {
209 Self {
210 tau,
211 eta,
212 m,
213 mu: 2.0 * tau,
214 }
215 }
216
217 /// Sample a token index from raw logits, updating internal state.
218 pub fn sample(&mut self, logits: &[f32], rng: &mut LcgRng) -> usize {
219 if logits.is_empty() {
220 return 0;
221 }
222
223 // Collect (index, logit) and sort descending.
224 let mut candidates: Vec<(usize, f32)> = logits.iter().cloned().enumerate().collect();
225 candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
226
227 // Truncate to top-m.
228 let m = self.m.min(candidates.len()).max(1);
229 candidates.truncate(m);
230
231 // Softmax over the truncated set.
232 let max_v = candidates[0].1;
233 let mut sum = 0.0_f32;
234 for (_, v) in candidates.iter_mut() {
235 *v = (*v - max_v).exp();
236 sum += *v;
237 }
238 if sum > 0.0 {
239 for (_, v) in candidates.iter_mut() {
240 *v /= sum;
241 }
242 }
243
244 // Filter to tokens whose estimated surprise <= mu.
245 // Surprise of token i: -log2(p_i).
246 let filtered: Vec<(usize, f32)> = candidates
247 .iter()
248 .cloned()
249 .filter(|&(_, p)| p > 0.0 && (-p.log2()) <= self.mu)
250 .collect();
251
252 let pool = if filtered.is_empty() {
253 &candidates
254 } else {
255 &filtered
256 };
257
258 // Re-normalise the pool.
259 let pool_sum: f32 = pool.iter().map(|(_, p)| p).sum();
260 let normalised: Vec<(usize, f32)> = if pool_sum > 0.0 {
261 pool.iter().map(|&(i, p)| (i, p / pool_sum)).collect()
262 } else {
263 pool.to_vec()
264 };
265
266 // Sample.
267 let chosen = categorical_sample(&normalised, rng);
268
269 // Compute observed surprise and update mu.
270 if let Some(&(_, p)) = normalised.iter().find(|&&(i, _)| i == chosen) {
271 if p > 0.0 {
272 let surprise = -p.log2();
273 self.mu -= self.eta * (surprise - self.tau);
274 }
275 }
276
277 chosen
278 }
279
280 /// Reset the internal state to the initial value.
281 pub fn reset(&mut self) {
282 self.mu = 2.0 * self.tau;
283 }
284}
285
286// ─────────────────────────────────────────────────────────────────────────────
287// Mirostat v2
288// ─────────────────────────────────────────────────────────────────────────────
289
290/// Mirostat v2 sampling — simpler and more stable than v1.
291///
292/// Rather than pre-truncating to top-m, v2 dynamically computes a probability
293/// threshold from `mu`, discards tokens below it, then samples from the rest.
294#[derive(Debug, Clone)]
295pub struct MirostatV2Sampler {
296 /// Target surprise level (bits). Default: `5.0`.
297 pub tau: f32,
298 /// Learning rate for the feedback loop. Default: `0.1`.
299 pub eta: f32,
300 /// Running surprise estimate (initialised to `2 * tau`).
301 mu: f32,
302}
303
304impl MirostatV2Sampler {
305 /// Create a new v2 sampler.
306 pub fn new(tau: f32, eta: f32) -> Self {
307 Self {
308 tau,
309 eta,
310 mu: 2.0 * tau,
311 }
312 }
313
314 /// Sample a token index from raw logits, updating internal state.
315 pub fn sample(&mut self, logits: &[f32], rng: &mut LcgRng) -> usize {
316 if logits.is_empty() {
317 return 0;
318 }
319
320 // Full softmax.
321 let mut probs: Vec<(usize, f32)> = logits.iter().cloned().enumerate().collect();
322 {
323 let max_v = probs
324 .iter()
325 .map(|(_, v)| *v)
326 .fold(f32::NEG_INFINITY, f32::max);
327 let mut sum = 0.0_f32;
328 for (_, v) in probs.iter_mut() {
329 *v = (*v - max_v).exp();
330 sum += *v;
331 }
332 if sum > 0.0 {
333 for (_, v) in probs.iter_mut() {
334 *v /= sum;
335 }
336 }
337 }
338
339 // The threshold probability corresponding to self.mu bits of surprise:
340 // p_threshold = 2^{-mu}
341 let threshold = (-self.mu * std::f32::consts::LN_2).exp();
342
343 let mut pool: Vec<(usize, f32)> = probs
344 .iter()
345 .cloned()
346 .filter(|&(_, p)| p >= threshold)
347 .collect();
348
349 if pool.is_empty() {
350 // Fallback: keep top-1 token.
351 probs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
352 pool.push(probs[0]);
353 }
354
355 // Re-normalise pool.
356 let pool_sum: f32 = pool.iter().map(|(_, p)| p).sum();
357 if pool_sum > 0.0 {
358 for (_, p) in pool.iter_mut() {
359 *p /= pool_sum;
360 }
361 }
362
363 let chosen = categorical_sample(&pool, rng);
364
365 // Update mu from observed surprise.
366 if let Some(&(_, p)) = pool.iter().find(|&&(i, _)| i == chosen) {
367 if p > 0.0 {
368 let surprise = -p.log2();
369 self.mu -= self.eta * (surprise - self.tau);
370 }
371 }
372
373 chosen
374 }
375
376 /// Reset the internal state to the initial value.
377 pub fn reset(&mut self) {
378 self.mu = 2.0 * self.tau;
379 }
380
381 /// Current mu value (for diagnostics / tests).
382 pub fn mu(&self) -> f32 {
383 self.mu
384 }
385}
386
387// ─────────────────────────────────────────────────────────────────────────────
388// Locally Typical Sampling
389// ─────────────────────────────────────────────────────────────────────────────
390
391/// Locally Typical sampling (Meister et al., "Locally Typical Sampling", 2023).
392///
393/// Keeps the smallest set of tokens whose information content is closest to the
394/// conditional entropy of the distribution, summing to at least `p` probability mass.
395#[derive(Debug, Clone)]
396pub struct TypicalSampler {
397 /// Cumulative probability mass to retain. Default: `0.9`.
398 pub p: f32,
399 /// Minimum number of candidates to keep regardless of `p`. Default: `1`.
400 pub min_keep: usize,
401}
402
403impl TypicalSampler {
404 /// Create a new typical sampler.
405 pub fn new(p: f32, min_keep: usize) -> Self {
406 Self {
407 p: p.clamp(0.0, 1.0),
408 min_keep: min_keep.max(1),
409 }
410 }
411
412 /// Sample a token index from raw logits.
413 pub fn sample(&self, logits: &[f32], rng: &mut LcgRng) -> usize {
414 if logits.is_empty() {
415 return 0;
416 }
417
418 // Compute log-softmax → log-probs and probs.
419 let log_probs = log_softmax(logits);
420 let probs: Vec<f32> = log_probs.iter().map(|&lp| lp.exp()).collect();
421
422 // Conditional entropy H = -sum_i p_i * log(p_i).
423 let h = entropy(&probs);
424
425 // Compute |log(p_i) - H| for each token — how "typical" it is.
426 let mut candidates: Vec<(usize, f32, f32)> = log_probs
427 .iter()
428 .cloned()
429 .zip(probs.iter().cloned())
430 .enumerate()
431 .map(|(i, (lp, p))| {
432 let typicality = (-lp - h).abs();
433 (i, p, typicality)
434 })
435 .collect();
436
437 // Sort ascending by typicality (most typical first).
438 candidates.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal));
439
440 // Keep tokens until we accumulate >= p probability mass.
441 let mut cumsum = 0.0_f32;
442 let mut keep = 0;
443 for (k, &(_, p, _)) in candidates.iter().enumerate() {
444 cumsum += p;
445 keep = k + 1;
446 if cumsum >= self.p && keep >= self.min_keep {
447 break;
448 }
449 }
450 keep = keep.max(self.min_keep).min(candidates.len());
451 candidates.truncate(keep);
452
453 // Re-normalise and sample.
454 let total: f32 = candidates.iter().map(|(_, p, _)| p).sum();
455 let normalised: Vec<(usize, f32)> = candidates
456 .iter()
457 .map(|&(i, p, _)| (i, if total > 0.0 { p / total } else { p }))
458 .collect();
459
460 categorical_sample(&normalised, rng)
461 }
462}
463
464// ─────────────────────────────────────────────────────────────────────────────
465// Eta Sampling
466// ─────────────────────────────────────────────────────────────────────────────
467
468/// Eta sampling — adaptively selects a probability cutoff based on distribution entropy.
469///
470/// The cutoff is `max(epsilon, sqrt(exp(-H(p))) * delta)` where `H` is the entropy.
471/// Tokens below the cutoff are discarded.
472#[derive(Debug, Clone)]
473pub struct EtaSampler {
474 /// Minimum token probability (floor). Default: `0.0009`.
475 pub epsilon: f32,
476 /// Entropy scaling factor for adaptive threshold. Default: `0.07`.
477 pub delta: f32,
478}
479
480impl EtaSampler {
481 /// Create a new eta sampler.
482 pub fn new(epsilon: f32, delta: f32) -> Self {
483 Self { epsilon, delta }
484 }
485
486 /// Sample a token index from raw logits.
487 pub fn sample(&self, logits: &[f32], rng: &mut LcgRng) -> usize {
488 if logits.is_empty() {
489 return 0;
490 }
491
492 let mut probs: Vec<f32> = logits.to_vec();
493 softmax_inplace(&mut probs);
494
495 // Adaptive threshold.
496 let h = entropy(&probs);
497 let eta_threshold = (self.epsilon).max((-h).exp().sqrt() * self.delta);
498
499 let mut candidates: Vec<(usize, f32)> = probs
500 .iter()
501 .cloned()
502 .enumerate()
503 .filter(|&(_, p)| p >= eta_threshold)
504 .collect();
505
506 if candidates.is_empty() {
507 // Fallback: take argmax.
508 let best = probs
509 .iter()
510 .cloned()
511 .enumerate()
512 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
513 .map(|(i, _)| i)
514 .unwrap_or(0);
515 return best;
516 }
517
518 // Re-normalise.
519 let total: f32 = candidates.iter().map(|(_, p)| p).sum();
520 if total > 0.0 {
521 for (_, p) in candidates.iter_mut() {
522 *p /= total;
523 }
524 }
525
526 categorical_sample(&candidates, rng)
527 }
528}
529
530// ─────────────────────────────────────────────────────────────────────────────
531// Min-P Sampling
532// ─────────────────────────────────────────────────────────────────────────────
533
534/// Min-P sampling — probabilistic nucleus based on a minimum fraction of the top-token probability.
535///
536/// Keeps all tokens `i` where `p_i >= min_p * max(p)`.
537#[derive(Debug, Clone)]
538pub struct MinPSampler {
539 /// Minimum fraction of the maximum probability. Default: `0.05`.
540 pub min_p: f32,
541 /// Minimum candidates to keep regardless of the threshold. Default: `1`.
542 pub min_keep: usize,
543}
544
545impl MinPSampler {
546 /// Create a new Min-P sampler.
547 pub fn new(min_p: f32, min_keep: usize) -> Self {
548 Self {
549 min_p: min_p.clamp(0.0, 1.0),
550 min_keep: min_keep.max(1),
551 }
552 }
553
554 /// Sample a token index from raw logits.
555 pub fn sample(&self, logits: &[f32], rng: &mut LcgRng) -> usize {
556 if logits.is_empty() {
557 return 0;
558 }
559
560 let mut probs: Vec<f32> = logits.to_vec();
561 softmax_inplace(&mut probs);
562
563 let max_p = probs.iter().cloned().fold(0.0_f32, f32::max);
564 let threshold = self.min_p * max_p;
565
566 let mut candidates: Vec<(usize, f32)> = probs
567 .iter()
568 .cloned()
569 .enumerate()
570 .filter(|&(_, p)| p >= threshold)
571 .collect();
572
573 // Ensure min_keep.
574 if candidates.len() < self.min_keep {
575 // Sort all probs descending and take top min_keep.
576 let mut all: Vec<(usize, f32)> = probs.iter().cloned().enumerate().collect();
577 all.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
578 candidates = all.into_iter().take(self.min_keep).collect();
579 }
580
581 // Re-normalise.
582 let total: f32 = candidates.iter().map(|(_, p)| p).sum();
583 if total > 0.0 {
584 for (_, p) in candidates.iter_mut() {
585 *p /= total;
586 }
587 }
588
589 categorical_sample(&candidates, rng)
590 }
591}
592
// ─────────────────────────────────────────────────────────────────────────────
// Sampler Chain
// ─────────────────────────────────────────────────────────────────────────────

/// A single step in a [`SamplerChain`] pipeline.
///
/// `Temperature`, `RepetitionPenalty`, `TopK`, and `TopP` transform the logit
/// vector in place; `MinP`, `Typical`, `Mirostat2`, and `Greedy` sample a token
/// immediately and terminate the pipeline when the chain reaches them.
#[derive(Debug, Clone)]
pub enum SamplerStep {
    /// Divide logits by temperature. Values near 0 produce near-greedy output
    /// (the chain treats temperatures below `1e-6` as fully greedy).
    Temperature(f32),
    /// Penalise previously-seen tokens to reduce repetition.
    RepetitionPenalty {
        /// Penalty multiplier (>1.0 discourages repetition).
        penalty: f32,
        /// Number of recent tokens to consider (window); `0` means all of `tokens`.
        last_n: usize,
        /// The recent token ids to penalise.
        tokens: Vec<u32>,
    },
    /// Keep only the top-k highest-logit candidates (the rest are masked to -inf).
    TopK(usize),
    /// Nucleus (top-p) filtering: mask tokens outside the smallest descending
    /// prefix whose probability mass reaches `p`.
    TopP(f32),
    /// Min-P filtering (min fraction of top token probability). Terminal step.
    MinP(f32),
    /// Locally typical sampling with probability mass `p`. Terminal step.
    Typical(f32),
    /// Mirostat v2 with given tau and eta. Terminal step with per-chain persistent state.
    Mirostat2 {
        /// Target surprise (bits).
        tau: f32,
        /// Learning rate.
        eta: f32,
    },
    /// Always pick the argmax (no randomness). Terminal step.
    Greedy,
}

/// Composable sampling pipeline.
///
/// Steps are applied sequentially to the logit vector. `Temperature`,
/// `RepetitionPenalty`, `TopK`, and `TopP` transform the logits in place;
/// `MinP`, `Typical`, `Mirostat2`, and `Greedy` immediately produce a token
/// and terminate the pipeline, skipping any later steps. If no terminal step
/// fires, the chain finishes with a softmax + weighted draw.
///
/// # Example
/// ```rust
/// use oxibonsai_runtime::sampling_advanced::{SamplerChain, SamplerStep};
///
/// let mut chain = SamplerChain::default_chat(42);
/// let mut logits = vec![1.0_f32, 5.0, 2.0, 3.0];
/// let token = chain.sample(&mut logits);
/// assert!(token < 4);
/// ```
#[derive(Debug, Clone)]
pub struct SamplerChain {
    // Pipeline steps, applied in insertion order.
    steps: Vec<SamplerStep>,
    // Deterministic RNG shared by all stochastic steps in this chain.
    rng: LcgRng,
    /// Persistent Mirostat v2 state (one per chain).
    mirostat2: Option<MirostatV2Sampler>,
}

impl SamplerChain {
    /// Create an empty chain with the given RNG seed.
    pub fn new(seed: u64) -> Self {
        Self {
            steps: Vec::new(),
            rng: LcgRng::new(seed),
            mirostat2: None,
        }
    }

    /// Append a step to the chain (builder pattern).
    ///
    /// Adding a `Mirostat2` step (re)initialises the chain's persistent
    /// Mirostat v2 state; a later `Mirostat2` step replaces the earlier state.
    #[allow(clippy::should_implement_trait)]
    pub fn add(mut self, step: SamplerStep) -> Self {
        // If Mirostat2 step is added, initialise persistent state.
        if let SamplerStep::Mirostat2 { tau, eta } = step {
            self.mirostat2 = Some(MirostatV2Sampler::new(tau, eta));
        }
        self.steps.push(step);
        self
    }

    /// Sample from the given logits, applying all steps in order.
    ///
    /// `logits` is consumed/mutated during processing. Transform steps
    /// (`Temperature`, `RepetitionPenalty`, `TopK`, `TopP`) rewrite the logit
    /// vector in place; terminal steps (`MinP`, `Typical`, `Mirostat2`,
    /// `Greedy`) return a token immediately, skipping any remaining steps.
    /// If no terminal step fires, the chain ends with a softmax + weighted
    /// draw. Returns `0` for an empty logit vector.
    pub fn sample(&mut self, logits: &mut Vec<f32>) -> usize {
        if logits.is_empty() {
            return 0;
        }

        for step in &self.steps {
            match step {
                SamplerStep::Temperature(temp) => {
                    if *temp < 1e-6 {
                        // Near-zero temperature: treat as greedy immediately.
                        return argmax_slice(logits);
                    }
                    apply_temperature(logits, *temp);
                }

                SamplerStep::RepetitionPenalty {
                    penalty,
                    last_n,
                    tokens,
                } => {
                    // last_n == 0 means "penalise the entire token history".
                    let window = if *last_n == 0 {
                        tokens.as_slice()
                    } else {
                        let start = tokens.len().saturating_sub(*last_n);
                        &tokens[start..]
                    };
                    apply_repetition_penalty(logits, window, *penalty);
                }

                SamplerStep::TopK(k) => {
                    // k == 0 or k >= len keeps every candidate (no-op).
                    if *k > 0 && *k < logits.len() {
                        let indices = top_k_indices(logits, *k);
                        // Mask everything outside the top-k to -inf so a later
                        // softmax assigns those tokens zero probability.
                        let mut mask = vec![f32::NEG_INFINITY; logits.len()];
                        for i in indices {
                            mask[i] = logits[i];
                        }
                        *logits = mask;
                    }
                }

                SamplerStep::TopP(p) => {
                    if *p < 1.0 {
                        // Masks non-nucleus logits in place; the actual draw
                        // happens in a later terminal step or the default
                        // softmax + sample at the end of this method.
                        apply_top_p(logits, *p, &mut self.rng);
                    }
                }

                SamplerStep::MinP(min_p) => {
                    // Terminal: sample now and skip any remaining steps.
                    let sampler = MinPSampler::new(*min_p, 1);
                    return sampler.sample(logits, &mut self.rng);
                }

                SamplerStep::Typical(p) => {
                    // Terminal: sample now and skip any remaining steps.
                    let sampler = TypicalSampler::new(*p, 1);
                    return sampler.sample(logits, &mut self.rng);
                }

                SamplerStep::Mirostat2 { .. } => {
                    // Terminal: uses the persistent state initialised in `add`;
                    // defensively falls through to later steps if it is absent.
                    if let Some(ref mut ms) = self.mirostat2 {
                        return ms.sample(logits, &mut self.rng);
                    }
                }

                SamplerStep::Greedy => {
                    return argmax_slice(logits);
                }
            }
        }

        // Default: softmax then weighted sample.
        softmax_inplace(logits);
        let probs: Vec<(usize, f32)> = logits.iter().cloned().enumerate().collect();
        categorical_sample(&probs, &mut self.rng)
    }

    // ── Presets ──────────────────────────────────────────────────────────────

    /// Greedy decoding — always picks the token with the highest logit.
    pub fn greedy() -> Self {
        Self::new(0).add(SamplerStep::Greedy)
    }

    /// Default chat preset: temperature(0.7) → top_p(0.9) → min_p(0.05).
    pub fn default_chat(seed: u64) -> Self {
        Self::new(seed)
            .add(SamplerStep::Temperature(0.7))
            .add(SamplerStep::TopP(0.9))
            .add(SamplerStep::MinP(0.05))
    }

    /// Creative preset: temperature(1.0) → mirostat_v2(tau=5.0, eta=0.1).
    pub fn creative(seed: u64) -> Self {
        Self::new(seed)
            .add(SamplerStep::Temperature(1.0))
            .add(SamplerStep::Mirostat2 { tau: 5.0, eta: 0.1 })
    }

    /// Precise preset: temperature(0.3) → top_k(40) → top_p(0.9).
    pub fn precise(seed: u64) -> Self {
        Self::new(seed)
            .add(SamplerStep::Temperature(0.3))
            .add(SamplerStep::TopK(40))
            .add(SamplerStep::TopP(0.9))
    }
}

// ─────────────────────────────────────────────────────────────────────────────
// Internal helpers
// ─────────────────────────────────────────────────────────────────────────────

/// Return the index of the maximum element (ties broken by lowest index).
///
/// Returns `0` for an empty slice. NaN elements never win a `>` comparison and
/// are effectively skipped.
///
/// Fix: the previous implementation used `Iterator::max_by`, which returns the
/// *last* maximal element on ties — contradicting the documented "lowest index"
/// contract. A strict `>` scan keeps the first maximum instead.
fn argmax_slice(values: &[f32]) -> usize {
    let mut best = 0;
    for i in 1..values.len() {
        // Strict comparison: equal values keep the earlier index.
        if values[i] > values[best] {
            best = i;
        }
    }
    best
}

801/// Apply top-p (nucleus) filtering to a logit vector in-place.
802///
803/// Tokens outside the nucleus are set to `NEG_INFINITY` so they are excluded
804/// by a subsequent softmax + sample step.
805fn apply_top_p(logits: &mut [f32], p: f32, _rng: &mut LcgRng) {
806 // Compute softmax probabilities.
807 let max_v = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
808 let mut probs: Vec<(usize, f32)> = logits
809 .iter()
810 .enumerate()
811 .map(|(i, &v)| (i, (v - max_v).exp()))
812 .collect();
813 let total: f32 = probs.iter().map(|(_, v)| v).sum();
814 if total > 0.0 {
815 for (_, v) in probs.iter_mut() {
816 *v /= total;
817 }
818 }
819
820 // Sort descending by probability.
821 probs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
822
823 // Find nucleus boundary.
824 let mut cumsum = 0.0_f32;
825 let mut nucleus_end = 0;
826 for (k, &(_, prob)) in probs.iter().enumerate() {
827 cumsum += prob;
828 nucleus_end = k;
829 if cumsum >= p {
830 break;
831 }
832 }
833
834 // Collect nucleus indices.
835 let nucleus_indices: std::collections::HashSet<usize> =
836 probs[..=nucleus_end].iter().map(|&(i, _)| i).collect();
837
838 // Mask out non-nucleus tokens.
839 for (i, v) in logits.iter_mut().enumerate() {
840 if !nucleus_indices.contains(&i) {
841 *v = f32::NEG_INFINITY;
842 }
843 }
844}
845
// ─────────────────────────────────────────────────────────────────────────────
// Unit tests (module-internal)
// ─────────────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn lcg_rng_produces_values() {
        let sample = LcgRng::new(1).next_f32();
        assert!(sample >= 0.0 && sample < 1.0, "f32 out of range: {sample}");
    }

    #[test]
    fn softmax_sums_to_one() {
        let mut weights = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0];
        softmax_inplace(&mut weights);
        let total = weights.iter().sum::<f32>();
        assert!((total - 1.0).abs() < 1e-5, "sum={total}");
    }

    #[test]
    fn mirostat_v2_returns_valid_index() {
        let scores = [1.0_f32, 5.0, 2.0, 3.0];
        let mut rng = LcgRng::new(99);
        let picked = MirostatV2Sampler::new(5.0, 0.1).sample(&scores, &mut rng);
        assert!(picked < scores.len());
    }

    #[test]
    fn sampler_chain_greedy_preset() {
        let mut logits = vec![0.1_f32, 5.0, 0.2, 0.3];
        let tok = SamplerChain::greedy().sample(&mut logits);
        assert_eq!(tok, 1); // 5.0 sits at index 1
    }
}