Skip to main content

oxillama_runtime/sampling/
mod.rs

1//! Sampling strategies for next-token selection.
2//!
3//! Supports greedy, top-k, top-p (nucleus), min-p, temperature scaling,
4//! repetition penalty, Mirostat v2, and GBNF grammar-constrained sampling.
5
6pub mod advanced;
7pub mod chain;
8pub mod grammar;
9
10use std::sync::Arc;
11
12use serde::{Deserialize, Serialize};
13
14use grammar::{apply_grammar_mask, Grammar, GrammarState};
15
16/// Configuration for the sampling strategy.
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct SamplerConfig {
19    /// Temperature for logit scaling (1.0 = no scaling, 0.0 = greedy).
20    pub temperature: f32,
21    /// Top-K: only consider the K most likely tokens (0 = disabled).
22    pub top_k: usize,
23    /// Top-P (nucleus): only consider tokens with cumulative probability <= p.
24    pub top_p: f32,
25    /// Min-P: minimum probability threshold relative to the top token.
26    pub min_p: f32,
27    /// Repetition penalty factor (1.0 = no penalty).
28    pub repetition_penalty: f32,
29    /// Number of recent tokens to consider for repetition penalty.
30    pub repetition_penalty_window: usize,
31    /// Random seed for reproducible sampling (None = random).
32    pub seed: Option<u64>,
33    /// Mirostat mode: 0 = disabled, 2 = Mirostat v2.
34    pub mirostat: u8,
35    /// Mirostat target surprise (tau). Controls coherence vs diversity.
36    /// Lower = more coherent, higher = more diverse. Default: 5.0.
37    pub mirostat_tau: f32,
38    /// Mirostat learning rate (eta). How fast the algorithm adapts. Default: 0.1.
39    pub mirostat_eta: f32,
40
41    /// Optional GBNF grammar for constrained sampling.
42    /// Logits for tokens that cannot advance the grammar are set to -∞.
43    /// Skipped during serialization (not representable as JSON directly).
44    #[serde(skip)]
45    pub grammar: Option<Arc<Grammar>>,
46
47    /// Pre-computed vocabulary `(token_id, byte_repr)` table used for grammar masking.
48    /// Must be set when `grammar` is `Some`. Build via `TokenizerBridge::vocab_bytes()`.
49    #[serde(skip)]
50    #[allow(clippy::type_complexity)]
51    pub token_vocab: Option<Arc<Vec<(u32, Vec<u8>)>>>,
52
53    /// Per-token logit biases applied before top-k/top-p.
54    ///
55    /// Positive values increase a token's probability; negative values decrease it.
56    /// For example, `logit_bias[token_id] = 5.0` strongly encourages that token,
57    /// while `-100.0` effectively bans it (use `banned_tokens` for strict banning).
58    ///
59    /// Applied as: `logits[token_id] += bias` before the greedy / sampling steps.
60    #[serde(default)]
61    pub logit_bias: std::collections::HashMap<u32, f32>,
62
63    /// Tokens that must never be generated.
64    ///
65    /// Their logits are set to `f32::NEG_INFINITY` before any other sampling
66    /// step, including top-k/p filtering. This is a hard constraint — unlike
67    /// a large negative `logit_bias`, a banned token will never be selected
68    /// even if it is the only remaining candidate.
69    #[serde(default)]
70    pub banned_tokens: Vec<u32>,
71
72    // ── Advanced sampler stages (v0.1.7 Track B) ─────────────────────────────
73    /// DRY penalty multiplier (0.0 = disabled).
74    ///
75    /// Penalises tokens that would continue an n-gram already present in the
76    /// recent context. Higher values apply stronger penalties.
77    #[serde(default)]
78    pub dry_multiplier: f32,
79
80    /// DRY exponential base for match-length amplification (default = 1.75).
81    ///
82    /// Longer n-gram matches receive penalty `dry_multiplier * dry_base^(match_len - dry_allowed_length)`.
83    #[serde(default = "dry_base_default")]
84    pub dry_base: f32,
85
86    /// Minimum match length (in tokens) before DRY applies any penalty (default = 2).
87    #[serde(default = "dry_allowed_length_default")]
88    pub dry_allowed_length: usize,
89
90    /// XTC cumulative-probability threshold (0.0 = disabled; use ≥ 1.0 to disable).
91    ///
92    /// The "top set" is defined as the smallest set of tokens whose cumulative
93    /// probability exceeds this threshold.
94    #[serde(default)]
95    pub xtc_threshold: f32,
96
97    /// XTC exclusion probability — how often the top-set exclusion fires (default = 0.5).
98    #[serde(default = "xtc_probability_default")]
99    pub xtc_probability: f32,
100
101    /// Locally-typical sampling budget (1.0 = disabled / passthrough).
102    ///
103    /// Keeps only tokens whose information content is closest to the distribution
104    /// entropy until cumulative probability ≥ p.
105    #[serde(default = "typical_p_default")]
106    pub typical_p: f32,
107
108    /// Top-A adaptive threshold multiplier (0.0 = disabled).
109    ///
110    /// Keeps tokens with `prob >= top_a * max_prob²`.
111    #[serde(default)]
112    pub top_a: f32,
113
114    /// Eta-cutoff entropy-adaptive threshold (0.0 = disabled).
115    ///
116    /// Dynamic floor = `max(epsilon_cutoff, eta_cutoff / perplexity)`.
117    #[serde(default)]
118    pub eta_cutoff: f32,
119
120    /// Epsilon hard-floor probability used together with `eta_cutoff` (0.0 = no floor).
121    #[serde(default)]
122    pub epsilon_cutoff: f32,
123}
124
// Fallback providers for the `#[serde(default = "...")]` attributes on
// `SamplerConfig`. Each returns the value a field takes when it is absent from
// the serialized input.

/// Serde fallback for `SamplerConfig::dry_base`.
const fn dry_base_default() -> f32 {
    1.75
}

/// Serde fallback for `SamplerConfig::dry_allowed_length`.
const fn dry_allowed_length_default() -> usize {
    2
}

/// Serde fallback for `SamplerConfig::xtc_probability`.
const fn xtc_probability_default() -> f32 {
    0.5
}

/// Serde fallback for `SamplerConfig::typical_p` (1.0 = stage disabled).
const fn typical_p_default() -> f32 {
    1.0
}
138
139impl Default for SamplerConfig {
140    fn default() -> Self {
141        Self {
142            temperature: 0.7,
143            top_k: 40,
144            top_p: 0.9,
145            min_p: 0.0,
146            repetition_penalty: 1.1,
147            repetition_penalty_window: 64,
148            seed: None,
149            mirostat: 0,
150            mirostat_tau: 5.0,
151            mirostat_eta: 0.1,
152            grammar: None,
153            token_vocab: None,
154            logit_bias: std::collections::HashMap::new(),
155            banned_tokens: Vec::new(),
156            // Advanced stages (disabled by default)
157            dry_multiplier: 0.0,
158            dry_base: 1.75,
159            dry_allowed_length: 2,
160            xtc_threshold: 0.0,
161            xtc_probability: 0.5,
162            typical_p: 1.0,
163            top_a: 0.0,
164            eta_cutoff: 0.0,
165            epsilon_cutoff: 0.0,
166        }
167    }
168}
169
170impl SamplerConfig {
171    /// Create a greedy sampling config (always pick the most likely token).
172    pub fn greedy() -> Self {
173        Self {
174            temperature: 0.0,
175            top_k: 1,
176            top_p: 1.0,
177            min_p: 0.0,
178            repetition_penalty: 1.0,
179            repetition_penalty_window: 0,
180            seed: None,
181            mirostat: 0,
182            mirostat_tau: 5.0,
183            mirostat_eta: 0.1,
184            grammar: None,
185            token_vocab: None,
186            logit_bias: std::collections::HashMap::new(),
187            banned_tokens: Vec::new(),
188            dry_multiplier: 0.0,
189            dry_base: 1.75,
190            dry_allowed_length: 2,
191            xtc_threshold: 0.0,
192            xtc_probability: 0.5,
193            typical_p: 1.0,
194            top_a: 0.0,
195            eta_cutoff: 0.0,
196            epsilon_cutoff: 0.0,
197        }
198    }
199
200    /// Create a Mirostat v2 config with the given target surprise.
201    pub fn mirostat_v2(tau: f32, eta: f32) -> Self {
202        Self {
203            temperature: 1.0,
204            mirostat: 2,
205            mirostat_tau: tau,
206            mirostat_eta: eta,
207            top_k: 0,
208            top_p: 1.0,
209            min_p: 0.0,
210            repetition_penalty: 1.0,
211            repetition_penalty_window: 0,
212            seed: None,
213            grammar: None,
214            token_vocab: None,
215            logit_bias: std::collections::HashMap::new(),
216            banned_tokens: Vec::new(),
217            dry_multiplier: 0.0,
218            dry_base: 1.75,
219            dry_allowed_length: 2,
220            xtc_threshold: 0.0,
221            xtc_probability: 0.5,
222            typical_p: 1.0,
223            top_a: 0.0,
224            eta_cutoff: 0.0,
225            epsilon_cutoff: 0.0,
226        }
227    }
228}
229
230/// Stateful sampler that maintains PRNG state across calls.
231pub struct Sampler {
232    config: SamplerConfig,
233    rng: Xorshift64,
234    /// Mirostat v2 running estimate of surprise (mu).
235    /// Initialized to 2 * tau, updated after each sample.
236    mirostat_mu: f32,
237    /// Current grammar parse state (None when no grammar is configured).
238    grammar_state: Option<GrammarState>,
239}
240
241impl Sampler {
242    /// Create a new sampler with the given config.
243    pub fn new(config: SamplerConfig) -> Self {
244        let seed = config.seed.unwrap_or_else(|| {
245            // Use a time-based seed when no explicit seed is provided.
246            // This is deterministic enough for inference; not for crypto.
247            let mut s = 0x517cc1b727220a95u64;
248            // Mix in some bits from the stack address for entropy
249            s ^= (&s as *const u64 as u64).wrapping_mul(0x9e3779b97f4a7c15);
250            s ^ s.wrapping_shr(33)
251        });
252        let mirostat_mu = 2.0 * config.mirostat_tau;
253        let grammar_state = config.grammar.as_ref().map(|g| g.initial_state());
254        Self {
255            config,
256            rng: Xorshift64::new(seed),
257            mirostat_mu,
258            grammar_state,
259        }
260    }
261
262    /// Sample a token ID from logits.
263    pub fn sample(&mut self, logits: &[f32], recent_tokens: &[u32]) -> u32 {
264        let token = if self.config.mirostat == 2 {
265            self.sample_mirostat_v2(logits, recent_tokens)
266        } else {
267            sample_with_rng(
268                logits,
269                &self.config,
270                recent_tokens,
271                &mut self.rng,
272                self.grammar_state.as_ref(),
273            )
274        };
275
276        // Advance grammar state after token selection.
277        // We look up the token bytes in the pre-built vocab table (binary search by id).
278        if let Some(state) = &mut self.grammar_state {
279            if let Some(vocab) = &self.config.token_vocab {
280                if let Ok(idx) = vocab.binary_search_by_key(&token, |&(id, _)| id) {
281                    let bytes = vocab[idx].1.clone();
282                    // Silently ignore advance errors — the mask will catch a stuck state
283                    // on the next step and -inf all invalid tokens.
284                    let _ = state.advance(&bytes);
285                }
286            }
287        }
288
289        token
290    }
291
292    /// Reset the grammar state to the beginning (use for a new generation).
293    pub fn reset_grammar(&mut self) {
294        self.grammar_state = self.config.grammar.as_ref().map(|g| g.initial_state());
295    }
296
297    /// Returns true when the grammar (if any) is in a valid accepting state.
298    pub fn grammar_complete(&self) -> bool {
299        self.grammar_state
300            .as_ref()
301            .is_none_or(GrammarState::is_complete)
302    }
303
304    /// Mirostat v2 sampling.
305    ///
306    /// Adaptively controls the "surprise" of generated tokens to maintain
307    /// a target perplexity level (tau). This produces more coherent text
308    /// than fixed top-k/top-p by dynamically adjusting the token pool.
309    fn sample_mirostat_v2(&mut self, logits: &[f32], recent_tokens: &[u32]) -> u32 {
310        if logits.is_empty() {
311            return 0;
312        }
313
314        let mut processed = logits.to_vec();
315
316        // Step 0: Apply logit bias and banned tokens — same order as
317        // sample_with_rng so both code paths behave identically.
318        apply_logit_bias_and_banned_tokens(&mut processed, &self.config);
319
320        // Step 1: Apply repetition penalty
321        apply_repetition_penalty(&mut processed, &self.config, recent_tokens);
322
323        // Step 2: Apply grammar mask — BEFORE temperature and sorting.
324        // Grammar masking must happen before any filtering so the constraint
325        // is respected even in the greedy case.
326        if let (Some(state), Some(vocab)) = (&self.grammar_state, &self.config.token_vocab) {
327            apply_grammar_mask(&mut processed, state, vocab.as_ref());
328        }
329
330        // Step 3: Apply temperature
331        if self.config.temperature > 0.0 && self.config.temperature != 1.0 {
332            let inv_temp = 1.0 / self.config.temperature;
333            for val in &mut processed {
334                *val *= inv_temp;
335            }
336        }
337
338        // Build sorted candidates with probabilities
339        let mut candidates: Vec<(u32, f32)> = processed
340            .iter()
341            .enumerate()
342            .map(|(i, &v)| (i as u32, v))
343            .collect();
344        candidates
345            .sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
346
347        // Softmax to get probabilities
348        softmax_candidates(&mut candidates);
349
350        // Mirostat v2: filter tokens by surprise threshold
351        // surprise(token) = -log2(prob)
352        // Keep tokens where surprise <= mu
353        let mu = self.mirostat_mu;
354        candidates.retain(|&(_, p)| {
355            if p <= 0.0 {
356                return false;
357            }
358            let surprise = -p.log2();
359            surprise <= mu
360        });
361
362        // Fallback: if all tokens filtered, keep the top one
363        if candidates.is_empty() {
364            let token = argmax(&processed);
365            // Still update mu
366            let top_prob = softmax_single_max(&processed);
367            let surprise = if top_prob > 0.0 {
368                -top_prob.log2()
369            } else {
370                self.config.mirostat_tau
371            };
372            self.mirostat_mu =
373                mu - self.config.mirostat_eta * (surprise - self.config.mirostat_tau);
374            return token;
375        }
376
377        // Re-normalize
378        let total: f32 = candidates.iter().map(|(_, p)| p).sum();
379        if total > 0.0 && total != 1.0 {
380            for (_, p) in &mut candidates {
381                *p /= total;
382            }
383        }
384
385        // Sample from filtered candidates
386        let r = self.rng.next_f32();
387        let mut cumulative = 0.0f32;
388        let mut selected_idx = candidates[0].0;
389        let mut selected_prob = candidates[0].1 * total; // original probability
390        for &(idx, prob) in &candidates {
391            cumulative += prob;
392            if r < cumulative {
393                selected_idx = idx;
394                selected_prob = prob * total;
395                break;
396            }
397        }
398
399        // Update mu: mu' = mu - eta * (surprise - tau)
400        let surprise = if selected_prob > 0.0 {
401            -selected_prob.log2()
402        } else {
403            self.config.mirostat_tau
404        };
405        self.mirostat_mu = mu - self.config.mirostat_eta * (surprise - self.config.mirostat_tau);
406
407        selected_idx
408    }
409
410    /// Get a reference to the config.
411    pub fn config(&self) -> &SamplerConfig {
412        &self.config
413    }
414
415    /// Return the raw RNG state for snapshot/resume.
416    pub fn rng_state(&self) -> u64 {
417        self.rng.state_value()
418    }
419
420    /// Return the current mirostat mu value for snapshot/resume.
421    pub fn mirostat_mu_value(&self) -> f32 {
422        self.mirostat_mu
423    }
424
425    /// Restore the RNG state and mirostat mu (for resume).
426    pub fn restore_rng_state(&mut self, state: u64, mu: f32) {
427        self.rng = Xorshift64::from_state_value(state);
428        self.mirostat_mu = mu;
429    }
430}
431
432/// Sample a token ID from logits using the given configuration.
433///
434/// This is the stateless variant. Grammar state (if any in config) is ignored
435/// because there is no place to persist it between calls. Use [`Sampler`] for
436/// grammar-constrained generation.
437///
438/// # Arguments
439/// * `logits` - Raw logits from the model (length = vocab_size).
440/// * `config` - Sampling configuration.
441/// * `recent_tokens` - Recent token history for repetition penalty.
442///
443/// # Returns
444/// The selected token ID.
445pub fn sample(logits: &[f32], config: &SamplerConfig, recent_tokens: &[u32]) -> u32 {
446    if logits.is_empty() {
447        return 0;
448    }
449
450    // For stateless API, create a one-shot RNG. Grammar state is not threaded
451    // here — callers needing grammar must use `Sampler`.
452    let seed = config.seed.unwrap_or(0xDEADBEEF_CAFEBABE);
453    let mut rng = Xorshift64::new(seed);
454    sample_with_rng(logits, config, recent_tokens, &mut rng, None)
455}
456
457/// Core sampling implementation with explicit RNG and optional grammar state.
458fn sample_with_rng(
459    logits: &[f32],
460    config: &SamplerConfig,
461    recent_tokens: &[u32],
462    rng: &mut Xorshift64,
463    grammar_state: Option<&GrammarState>,
464) -> u32 {
465    if logits.is_empty() {
466        return 0;
467    }
468
469    let mut processed = logits.to_vec();
470
471    // Step 0: Apply logit bias and banned tokens FIRST — before any other
472    // transformation so that bans are absolute and biases influence all
473    // downstream filtering steps (top-k, top-p, grammar masking, etc.).
474    apply_logit_bias_and_banned_tokens(&mut processed, config);
475
476    // Step 1: Apply repetition penalty
477    apply_repetition_penalty(&mut processed, config, recent_tokens);
478
479    // Step 2: Apply grammar mask — BEFORE the greedy shortcut.
480    // This ensures grammar constraints are enforced even at temperature=0.
481    if let (Some(state), Some(vocab)) = (grammar_state, &config.token_vocab) {
482        apply_grammar_mask(&mut processed, state, vocab.as_ref());
483    }
484
485    // Step 3: Greedy shortcut (after grammar mask)
486    if config.temperature <= 0.0 || config.top_k == 1 {
487        return argmax(&processed);
488    }
489
490    // Step 4: Temperature scaling
491    if config.temperature != 1.0 {
492        let inv_temp = 1.0 / config.temperature;
493        for val in &mut processed {
494            *val *= inv_temp;
495        }
496    }
497
498    // Step 5: Build sorted (index, logit) candidates
499    let mut candidates: Vec<(u32, f32)> = processed
500        .iter()
501        .enumerate()
502        .map(|(i, &v)| (i as u32, v))
503        .collect();
504    candidates.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
505
506    // Step 6: Top-K filtering
507    if config.top_k > 0 && config.top_k < candidates.len() {
508        candidates.truncate(config.top_k);
509    }
510
511    // Step 7: Softmax over remaining candidates
512    softmax_candidates(&mut candidates);
513
514    // Step 8: Min-P filtering (remove tokens with prob < min_p * max_prob)
515    if config.min_p > 0.0 && !candidates.is_empty() {
516        let max_prob = candidates[0].1; // already sorted descending by probability
517        let threshold = config.min_p * max_prob;
518        candidates.retain(|&(_, p)| p >= threshold);
519    }
520
521    // Step 9: Top-P (nucleus) filtering
522    if config.top_p < 1.0 && !candidates.is_empty() {
523        let mut cumulative = 0.0f32;
524        let mut cutoff = candidates.len();
525        for (i, &(_, prob)) in candidates.iter().enumerate() {
526            cumulative += prob;
527            if cumulative >= config.top_p {
528                cutoff = i + 1;
529                break;
530            }
531        }
532        candidates.truncate(cutoff);
533    }
534
535    // Step 10: Re-normalize after filtering
536    let total: f32 = candidates.iter().map(|(_, p)| p).sum();
537    if total > 0.0 && total != 1.0 {
538        for (_, p) in &mut candidates {
539            *p /= total;
540        }
541    }
542
543    // Step 11: Weighted random selection
544    if candidates.is_empty() {
545        return argmax(&processed);
546    }
547    if candidates.len() == 1 {
548        return candidates[0].0;
549    }
550
551    let r = rng.next_f32();
552    let mut cumulative = 0.0f32;
553    for &(idx, prob) in &candidates {
554        cumulative += prob;
555        if r < cumulative {
556            return idx;
557        }
558    }
559
560    // Fallback: return last candidate (rounding issues)
561    candidates.last().map(|&(idx, _)| idx).unwrap_or(0)
562}
563
564/// Apply logit bias and banned-token masking to logits in-place.
565///
566/// Processing order:
567/// 1. Banned tokens are set to `f32::NEG_INFINITY` unconditionally.
568/// 2. Logit biases are added to the surviving logits.
569///
570/// Both operations are applied before repetition penalty, grammar masking,
571/// and temperature / top-k / top-p filtering, so they influence all
572/// downstream steps.
573fn apply_logit_bias_and_banned_tokens(processed: &mut [f32], config: &SamplerConfig) {
574    // Step A: hard-ban tokens.
575    for &token in &config.banned_tokens {
576        let idx = token as usize;
577        if idx < processed.len() {
578            processed[idx] = f32::NEG_INFINITY;
579        }
580    }
581
582    // Step B: additive bias.
583    for (&token, &bias) in &config.logit_bias {
584        let idx = token as usize;
585        if idx < processed.len() {
586            // Do not modify already-banned tokens — a banned token must
587            // remain at -inf even if a positive bias is also specified.
588            if processed[idx].is_finite() {
589                processed[idx] += bias;
590            }
591        }
592    }
593}
594
595/// Apply repetition penalty to logits in-place.
596fn apply_repetition_penalty(processed: &mut [f32], config: &SamplerConfig, recent_tokens: &[u32]) {
597    if config.repetition_penalty == 1.0 || recent_tokens.is_empty() {
598        return;
599    }
600
601    let window_start = recent_tokens
602        .len()
603        .saturating_sub(config.repetition_penalty_window);
604    for &token in &recent_tokens[window_start..] {
605        let idx = token as usize;
606        if idx < processed.len() {
607            if processed[idx] > 0.0 {
608                processed[idx] /= config.repetition_penalty;
609            } else {
610                processed[idx] *= config.repetition_penalty;
611            }
612        }
613    }
614}
615
/// Turn the logits stored in `candidates` into softmax probabilities, in place.
///
/// Token ids are left untouched. The maximum logit is subtracted before
/// exponentiating, for numerical stability. If the resulting sum is not
/// positive (e.g. every logit was -inf, producing NaN), the values are left
/// unnormalised.
fn softmax_candidates(candidates: &mut [(u32, f32)]) {
    if candidates.is_empty() {
        return;
    }

    let peak = candidates
        .iter()
        .fold(f32::NEG_INFINITY, |acc, &(_, logit)| acc.max(logit));

    let total = candidates.iter_mut().fold(0.0f32, |running, (_, value)| {
        *value = (*value - peak).exp();
        running + *value
    });

    if total > 0.0 {
        candidates.iter_mut().for_each(|(_, value)| *value /= total);
    }
}
639
/// Softmax probability of the single largest logit in `logits`.
///
/// Equal to `1 / Σ exp(v - max)`. Returns 0.0 when that sum is not positive
/// (an empty slice, or all logits -inf).
fn softmax_single_max(logits: &[f32]) -> f32 {
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let partition: f32 = logits.iter().map(|&v| (v - peak).exp()).sum();
    if partition > 0.0 {
        partition.recip()
    } else {
        0.0
    }
}
650
/// Index of the largest value in `values`.
///
/// Ties resolve to the earliest index and NaN entries never win. Returns 0
/// for an empty slice, or when no value exceeds -inf.
fn argmax(values: &[f32]) -> u32 {
    values
        .iter()
        .enumerate()
        .fold((0u32, f32::NEG_INFINITY), |(best_idx, best_val), (i, &v)| {
            if v > best_val {
                (i as u32, v)
            } else {
                (best_idx, best_val)
            }
        })
        .0
}
663
/// Minimal xorshift64 pseudo-random generator (shift triple 13 / 7 / 17).
///
/// Tiny, fast, seedable and dependency-free; not cryptographically secure.
/// The all-zero state is a fixed point of the xorshift step, so neither
/// constructor ever stores zero.
struct Xorshift64 {
    state: u64,
}

impl Xorshift64 {
    /// Seed substituted by [`Xorshift64::new`] when it is handed zero.
    const ZERO_SEED_SUBSTITUTE: u64 = 0x517c_c1b7_2722_0a95;

    /// Build a generator from `seed`, replacing a zero seed with a fixed constant.
    fn new(seed: u64) -> Self {
        let state = match seed {
            0 => Self::ZERO_SEED_SUBSTITUTE,
            nonzero => nonzero,
        };
        Self { state }
    }

    /// Advance one step and return the new 64-bit state.
    fn next_u64(&mut self) -> u64 {
        let s = self.state;
        let s = s ^ (s << 13);
        let s = s ^ (s >> 7);
        let s = s ^ (s << 17);
        self.state = s;
        s
    }

    /// Uniform `f32` in `[0, 1)` from the top 24 bits of the next output.
    /// 24 bits match the f32 significand, so every result is exact.
    fn next_f32(&mut self) -> f32 {
        const BITS: u32 = 24;
        let top = self.next_u64() >> (64 - BITS);
        top as f32 / (1u64 << BITS) as f32
    }

    /// Return the raw internal state for snapshot/resume.
    pub(crate) fn state_value(&self) -> u64 {
        self.state
    }

    /// Rebuild a generator from a raw snapshot value (for resume).
    ///
    /// A zero state would lock the generator at zero, so it is bumped to 1.
    /// This differs from `new`, which substitutes `ZERO_SEED_SUBSTITUTE`.
    pub(crate) fn from_state_value(state: u64) -> Self {
        Self {
            state: state.max(1),
        }
    }
}
703
704#[cfg(test)]
705mod tests {
706    use super::*;
707
708    #[test]
709    fn test_greedy_sampling() {
710        let logits = vec![0.1, 0.5, 0.3, 0.8, 0.2];
711        let config = SamplerConfig::greedy();
712        let token = sample(&logits, &config, &[]);
713        assert_eq!(token, 3); // index of 0.8
714    }
715
716    #[test]
717    fn test_empty_logits() {
718        let logits: Vec<f32> = vec![];
719        let config = SamplerConfig::greedy();
720        let token = sample(&logits, &config, &[]);
721        assert_eq!(token, 0);
722    }
723
724    #[test]
725    fn test_temperature_zero_is_greedy() {
726        let logits = vec![1.0, 5.0, 3.0, 2.0];
727        let config = SamplerConfig {
728            temperature: 0.0,
729            ..SamplerConfig::default()
730        };
731        let token = sample(&logits, &config, &[]);
732        assert_eq!(token, 1); // argmax
733    }
734
735    #[test]
736    fn test_top_k_1_is_greedy() {
737        let logits = vec![1.0, 5.0, 3.0, 2.0];
738        let config = SamplerConfig {
739            temperature: 1.0,
740            top_k: 1,
741            ..SamplerConfig::default()
742        };
743        let token = sample(&logits, &config, &[]);
744        assert_eq!(token, 1);
745    }
746
747    #[test]
748    fn test_seeded_determinism() {
749        let logits = vec![1.0, 2.0, 3.0, 2.0, 1.0];
750        let config = SamplerConfig {
751            temperature: 1.0,
752            top_k: 0,
753            top_p: 1.0,
754            min_p: 0.0,
755            seed: Some(42),
756            ..SamplerConfig::default()
757        };
758
759        let mut sampler1 = Sampler::new(config.clone());
760        let mut sampler2 = Sampler::new(config);
761
762        // Same seed should produce same sequence
763        for _ in 0..10 {
764            let t1 = sampler1.sample(&logits, &[]);
765            let t2 = sampler2.sample(&logits, &[]);
766            assert_eq!(t1, t2, "seeded samplers should produce identical results");
767        }
768    }
769
770    #[test]
771    fn test_top_p_filters_low_prob() {
772        // One token has overwhelming probability
773        let logits = vec![100.0, 0.0, 0.0, 0.0, 0.0];
774        let config = SamplerConfig {
775            temperature: 1.0,
776            top_k: 0,
777            top_p: 0.5,
778            min_p: 0.0,
779            seed: Some(123),
780            ..SamplerConfig::default()
781        };
782
783        // With top_p=0.5, only the dominant token should remain
784        let token = sample(&logits, &config, &[]);
785        assert_eq!(token, 0);
786    }
787
788    #[test]
789    fn test_repetition_penalty() {
790        // Token 1 has highest logit but is in recent history
791        let logits = vec![1.0, 5.0, 4.9, 1.0];
792        let config = SamplerConfig {
793            temperature: 0.0,          // greedy after penalty
794            repetition_penalty: 100.0, // severe penalty
795            repetition_penalty_window: 64,
796            ..SamplerConfig::greedy()
797        };
798
799        // Without penalty, token 1 wins
800        let token_no_penalty = sample(&logits, &SamplerConfig::greedy(), &[]);
801        assert_eq!(token_no_penalty, 1);
802
803        // With penalty on token 1, token 2 (4.9) should win
804        let token_with_penalty = sample(&logits, &config, &[1]);
805        assert_eq!(token_with_penalty, 2);
806    }
807
808    #[test]
809    fn test_sampling_distribution() {
810        // Verify that with temperature sampling, we don't always pick argmax
811        let logits = vec![2.0, 2.0, 2.0, 2.0]; // equal logits
812        let config = SamplerConfig {
813            temperature: 1.0,
814            top_k: 0,
815            top_p: 1.0,
816            min_p: 0.0,
817            seed: Some(999),
818            ..SamplerConfig::default()
819        };
820
821        let mut sampler = Sampler::new(config);
822        let mut counts = [0u32; 4];
823        for _ in 0..1000 {
824            let t = sampler.sample(&logits, &[]);
825            counts[t as usize] += 1;
826        }
827
828        // With equal logits, each token should get ~250 hits.
829        // Allow generous margin (100-400).
830        for (i, &count) in counts.iter().enumerate() {
831            assert!(
832                count > 100 && count < 400,
833                "token {i} got {count} hits (expected ~250 for uniform distribution)"
834            );
835        }
836    }
837
838    #[test]
839    fn test_min_p_filtering() {
840        // One very likely token and several very unlikely ones
841        let logits = vec![10.0, -10.0, -10.0, -10.0];
842        let config = SamplerConfig {
843            temperature: 1.0,
844            top_k: 0,
845            top_p: 1.0,
846            min_p: 0.1, // require at least 10% of max prob
847            seed: Some(42),
848            ..SamplerConfig::default()
849        };
850
851        // The dominant token should always win after min_p filtering
852        let mut sampler = Sampler::new(config);
853        for _ in 0..100 {
854            assert_eq!(sampler.sample(&logits, &[]), 0);
855        }
856    }
857
858    #[test]
859    fn test_xorshift_range() {
860        let mut rng = Xorshift64::new(12345);
861        for _ in 0..10000 {
862            let v = rng.next_f32();
863            assert!((0.0..1.0).contains(&v), "RNG produced {v} outside [0, 1)");
864        }
865    }
866
867    #[test]
868    fn test_mirostat_v2_basic() {
869        // Mirostat v2 should produce valid tokens
870        let logits = vec![3.0, 2.0, 1.0, 0.5, 0.1, -1.0, -2.0, -5.0];
871        let config = SamplerConfig {
872            seed: Some(42),
873            ..SamplerConfig::mirostat_v2(5.0, 0.1)
874        };
875        let mut sampler = Sampler::new(config);
876
877        for _ in 0..50 {
878            let token = sampler.sample(&logits, &[]);
879            assert!((token as usize) < logits.len());
880        }
881    }
882
883    #[test]
884    fn test_mirostat_v2_adapts_mu() {
885        let logits = vec![5.0, 0.0, 0.0, 0.0];
886        let config = SamplerConfig {
887            seed: Some(123),
888            ..SamplerConfig::mirostat_v2(3.0, 0.1)
889        };
890        let mut sampler = Sampler::new(config);
891        let initial_mu = sampler.mirostat_mu;
892
893        // After sampling, mu should change
894        sampler.sample(&logits, &[]);
895        assert!(
896            (sampler.mirostat_mu - initial_mu).abs() > 1e-6,
897            "mu should adapt after sampling"
898        );
899    }
900
901    #[test]
902    fn test_mirostat_v2_low_tau_prefers_top() {
903        // Very low tau = very low target surprise = prefer high-probability tokens
904        let logits = vec![10.0, 0.0, 0.0, 0.0, 0.0];
905        let config = SamplerConfig {
906            seed: Some(42),
907            ..SamplerConfig::mirostat_v2(0.5, 0.1) // very low tau
908        };
909        let mut sampler = Sampler::new(config);
910
911        let mut top_count = 0;
912        for _ in 0..100 {
913            if sampler.sample(&logits, &[]) == 0 {
914                top_count += 1;
915            }
916        }
917        // With tau=0.5, should almost always pick the top token
918        assert!(
919            top_count > 90,
920            "low tau should strongly prefer top token, got {top_count}/100"
921        );
922    }
923
924    #[test]
925    fn test_mirostat_v2_deterministic_with_seed() {
926        let logits = vec![2.0, 1.5, 1.0, 0.5];
927        let config = SamplerConfig {
928            seed: Some(777),
929            ..SamplerConfig::mirostat_v2(5.0, 0.1)
930        };
931
932        let mut sampler1 = Sampler::new(config.clone());
933        let mut sampler2 = Sampler::new(config);
934
935        for _ in 0..20 {
936            assert_eq!(
937                sampler1.sample(&logits, &[]),
938                sampler2.sample(&logits, &[]),
939                "same seed should produce same sequence"
940            );
941        }
942    }
943
944    #[test]
945    fn test_softmax_candidates_basic() {
946        let mut candidates = vec![(0, 0.0f32), (1, 0.0), (2, 0.0)];
947        softmax_candidates(&mut candidates);
948        // Equal logits → equal probabilities
949        for &(_, p) in &candidates {
950            assert!((p - 1.0 / 3.0).abs() < 0.01, "expected ~0.333, got {p}");
951        }
952    }
953
954    // ── Logit-bias / banned-tokens tests ──────────────────────────────────────
955
956    #[test]
957    fn banned_tokens_never_sampled() {
958        // Only token 3 is allowed; all others are banned.
959        let vocab_size = 5usize;
960        let logits: Vec<f32> = (0..vocab_size).map(|i| i as f32).collect();
961
962        let mut banned = Vec::new();
963        for i in 0u32..vocab_size as u32 {
964            if i != 3 {
965                banned.push(i);
966            }
967        }
968        let config = SamplerConfig {
969            temperature: 1.0,
970            top_k: 0,
971            top_p: 1.0,
972            min_p: 0.0,
973            seed: Some(42),
974            banned_tokens: banned,
975            ..SamplerConfig::default()
976        };
977        let mut sampler = Sampler::new(config);
978        for _ in 0..50 {
979            let tok = sampler.sample(&logits, &[]);
980            assert_eq!(
981                tok, 3,
982                "only token 3 should ever be sampled when all others are banned"
983            );
984        }
985    }
986
987    #[test]
988    fn positive_bias_increases_token_probability() {
989        // Token 1 starts with a very low logit; add a large positive bias.
990        // After bias, token 1 should dominate and be selected nearly always.
991        let logits = vec![10.0f32, -20.0, -20.0, -20.0];
992        let mut bias = std::collections::HashMap::new();
993        bias.insert(1u32, 100.0f32); // huge positive bias on token 1
994
995        let config = SamplerConfig {
996            temperature: 1.0,
997            top_k: 0,
998            top_p: 1.0,
999            min_p: 0.0,
1000            seed: Some(7),
1001            logit_bias: bias,
1002            ..SamplerConfig::default()
1003        };
1004        let mut sampler = Sampler::new(config);
1005        // With a +100 bias, token 1's effective logit = 80, far above token 0's 10.
1006        let tok = sampler.sample(&logits, &[]);
1007        assert_eq!(tok, 1, "large positive bias should make token 1 dominate");
1008    }
1009
1010    #[test]
1011    fn negative_bias_decreases() {
1012        // Token 0 has the highest raw logit; apply a strongly negative bias.
1013        // Token 1 should win after bias.
1014        let logits = vec![100.0f32, 1.0, 0.5, 0.1];
1015        let mut bias = std::collections::HashMap::new();
1016        bias.insert(0u32, -200.0f32); // strong negative on the top token
1017
1018        let config = SamplerConfig {
1019            temperature: 0.0, // greedy — picks strictly by highest logit after bias
1020            logit_bias: bias,
1021            ..SamplerConfig::greedy()
1022        };
1023        let tok = sample(&logits, &config, &[]);
1024        assert_eq!(
1025            tok, 1,
1026            "after large negative bias on token 0, token 1 should win"
1027        );
1028    }
1029
1030    #[test]
1031    fn logit_bias_empty_config_no_op() {
1032        // Empty logit_bias and empty banned_tokens must not change sampling behaviour.
1033        let logits = vec![1.0f32, 2.0, 3.0, 0.5];
1034        let config_empty = SamplerConfig {
1035            temperature: 0.0,
1036            logit_bias: std::collections::HashMap::new(),
1037            banned_tokens: Vec::new(),
1038            ..SamplerConfig::greedy()
1039        };
1040        let tok = sample(&logits, &config_empty, &[]);
1041        // Greedy with no bias should still pick index 2 (value 3.0).
1042        assert_eq!(tok, 2, "empty logit_bias / banned_tokens should be a no-op");
1043    }
1044
1045    // ── Grammar-constrained sampling tests ────────────────────────────────────
1046
1047    #[test]
1048    fn test_grammar_constrained_yes_no() {
1049        let g = Grammar::parse(r#"root ::= "yes" | "no""#).unwrap();
1050        let state = g.initial_state();
1051        assert!(state.allows_token(b"yes"));
1052        assert!(state.allows_token(b"no"));
1053        assert!(!state.allows_token(b"maybe"));
1054    }
1055
1056    #[test]
1057    fn test_grammar_sampler_masks_logits() {
1058        // Vocab: 0="maybe", 1="yes", 2="no"
1059        let vocab: Vec<(u32, Vec<u8>)> = vec![
1060            (0, b"maybe".to_vec()),
1061            (1, b"yes".to_vec()),
1062            (2, b"no".to_vec()),
1063        ];
1064        let g = Arc::new(Grammar::parse(r#"root ::= "yes" | "no""#).unwrap());
1065        let config = SamplerConfig {
1066            temperature: 0.0, // greedy — must pick grammar-compliant token
1067            grammar: Some(g),
1068            token_vocab: Some(Arc::new(vocab)),
1069            ..SamplerConfig::default()
1070        };
1071
1072        // Give "maybe" the highest logit — grammar must mask it away
1073        let logits = vec![100.0f32, 1.0, 1.0];
1074        let mut sampler = Sampler::new(config);
1075        let tok = sampler.sample(&logits, &[]);
1076        // After masking, only "yes"(1) or "no"(2) remain
1077        assert!(tok == 1 || tok == 2, "expected yes(1) or no(2), got {tok}");
1078    }
1079
1080    #[test]
1081    fn test_grammar_state_advances_through_sequence() {
1082        let vocab: Vec<(u32, Vec<u8>)> =
1083            vec![(0, b"a".to_vec()), (1, b"b".to_vec()), (2, b"c".to_vec())];
1084        let g = Arc::new(Grammar::parse(r#"root ::= "a" "b""#).unwrap());
1085        let config = SamplerConfig {
1086            temperature: 0.0,
1087            grammar: Some(g),
1088            token_vocab: Some(Arc::new(vocab)),
1089            ..SamplerConfig::default()
1090        };
1091
1092        // Equal logits — grammar drives selection
1093        let logits = vec![1.0f32, 0.5, 0.5];
1094        let mut sampler = Sampler::new(config);
1095
1096        // First step: only "a" is valid
1097        let tok1 = sampler.sample(&logits, &[]);
1098        assert_eq!(tok1, 0, "first token must be 'a' (id=0)");
1099
1100        // Second step: only "b" is valid
1101        let tok2 = sampler.sample(&logits, &[0]);
1102        assert_eq!(tok2, 1, "second token must be 'b' (id=1)");
1103
1104        assert!(
1105            sampler.grammar_complete(),
1106            "grammar should be complete after 'a' + 'b'"
1107        );
1108    }
1109
1110    #[test]
1111    fn test_grammar_parse_roundtrip() {
1112        let g = Grammar::parse("root ::= [a-z]+ \":\" [0-9]+").unwrap();
1113        assert!(!g.rules.is_empty());
1114        assert_eq!(g.root, "root");
1115    }
1116
1117    #[test]
1118    fn test_grammar_stuck_state_masks_all() {
1119        // A grammar that requires "x" — advancing with "y" must produce an error
1120        let g = Arc::new(Grammar::parse(r#"root ::= "x""#).unwrap());
1121        let mut state = g.initial_state();
1122        let result = state.advance(b"y");
1123        assert!(result.is_err(), "advancing with wrong bytes should error");
1124    }
1125}