Skip to main content

oxillama_runtime/sampling/
chain.rs

1//! Composable sampler chain — trait-based pipeline for token selection.
2//!
3//! Each [`SamplerStage`] transforms a logit vector in-place. Stages are
4//! composed into a [`SamplerChain`] that runs them in order before final
5//! token selection.
6//!
//! Built-in stages: [`LogitBias`], [`RepetitionPenalty`], `GrammarMask`,
//! [`TemperatureScale`], [`TopK`], [`TopP`], [`MinP`], [`GreedySelect`].
9//!
10//! # Example
11//!
12//! ```ignore
13//! use oxillama_runtime::sampling::chain::*;
14//!
15//! let chain = SamplerChain::new()
16//!     .push(RepetitionPenalty::new(1.1, 64))
17//!     .push(TemperatureScale::new(0.8))
18//!     .push(TopK::new(40))
19//!     .push(TopP::new(0.9));
20//!
21//! let logits = vec![1.0, 2.0, 3.0, 0.5];
22//! let token = chain.sample(&logits, &recent_tokens);
23//! ```
24
25use std::collections::HashSet;
26
/// One transformation step in the sampling pipeline.
///
/// A stage is handed the mutable logit vector together with the recent token
/// history and rewrites the logits in place — for example by penalising,
/// masking, or rescaling them. Implementors must be `Send + Sync`.
pub trait SamplerStage: Send + Sync {
    /// Transform `logits` in place; `recent_tokens` is the recent token history.
    fn apply(&self, logits: &mut Vec<f32>, recent_tokens: &[u32]);

    /// Short static label for this stage, intended for logging and debugging.
    fn name(&self) -> &'static str;
}
39
40/// A composable pipeline of [`SamplerStage`]s followed by a final selection step.
41pub struct SamplerChain {
42    stages: Vec<Box<dyn SamplerStage>>,
43    /// Seed for the final selection RNG.
44    seed: u64,
45}
46
47impl Default for SamplerChain {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53impl SamplerChain {
54    /// Create an empty chain (no stages, default seed).
55    pub fn new() -> Self {
56        Self {
57            stages: Vec::new(),
58            seed: 0xDEAD_BEEF_CAFE_BABE,
59        }
60    }
61
62    /// Set the RNG seed for the final selection step.
63    pub fn with_seed(mut self, seed: u64) -> Self {
64        self.seed = seed;
65        self
66    }
67
68    /// Append a stage to the pipeline. Returns self for chaining.
69    pub fn push(mut self, stage: impl SamplerStage + 'static) -> Self {
70        self.stages.push(Box::new(stage));
71        self
72    }
73
74    /// Run all stages on a copy of the logits and select a token.
75    ///
76    /// The original logit slice is not modified.
77    pub fn sample(&self, logits: &[f32], recent_tokens: &[u32]) -> u32 {
78        if logits.is_empty() {
79            return 0;
80        }
81
82        let mut processed = logits.to_vec();
83
84        for stage in &self.stages {
85            stage.apply(&mut processed, recent_tokens);
86        }
87
88        // Final selection: softmax + weighted random
89        select_token(&processed, self.seed)
90    }
91
92    /// Return the number of stages in the chain.
93    pub fn len(&self) -> usize {
94        self.stages.len()
95    }
96
97    /// Check if the chain is empty.
98    pub fn is_empty(&self) -> bool {
99        self.stages.is_empty()
100    }
101
102    /// List the names of all stages in order.
103    pub fn stage_names(&self) -> Vec<&'static str> {
104        self.stages.iter().map(|s| s.name()).collect()
105    }
106
107    /// Build a chain from a `SamplerConfig`, replicating the standard pipeline.
108    ///
109    /// Pipeline order:
110    /// logit-bias → repetition penalty → DRY → XTC → TypicalP → TopA → Eta
111    ///   → temperature → top-K → min-P → top-P.
112    ///
113    /// Logit-bias must come first so that bans and boosts are visible to all
114    /// downstream filtering stages. The five advanced stages are inserted after
115    /// repetition penalty (they work on logit-scale values) but before temperature
116    /// scaling (so they see the pre-temperature distribution shape).
117    pub fn from_config(config: &super::SamplerConfig) -> Self {
118        use super::advanced::{DryStage, EtaStage, TopAStage, TypicalPStage, XtcStage};
119
120        let mut chain = Self::new();
121
122        if let Some(seed) = config.seed {
123            chain = chain.with_seed(seed);
124        }
125
126        // Insert logit-bias / banned-tokens stage first (before everything else).
127        if !config.logit_bias.is_empty() || !config.banned_tokens.is_empty() {
128            chain = chain.push(LogitBias::new(
129                config.logit_bias.clone(),
130                config.banned_tokens.clone(),
131            ));
132        }
133
134        if config.repetition_penalty != 1.0 {
135            chain = chain.push(RepetitionPenalty::new(
136                config.repetition_penalty,
137                config.repetition_penalty_window,
138            ));
139        }
140
141        // ── Advanced stages (Track B, v0.1.7) ────────────────────────────────
142        // Order: DRY → XTC → TypicalP → TopA → Eta
143        if config.dry_multiplier != 0.0 {
144            chain = chain.push(DryStage::new(
145                config.dry_multiplier,
146                config.dry_base,
147                config.dry_allowed_length,
148                Vec::new(), // sequence_breakers — not yet in SamplerConfig; extend later
149            ));
150        }
151
152        if config.xtc_threshold < 1.0 && config.xtc_probability > 0.0 {
153            let seed = config.seed.unwrap_or(0xDEAD_BEEF_CAFE_BABE);
154            chain = chain.push(XtcStage::new(
155                config.xtc_threshold,
156                config.xtc_probability,
157                seed,
158            ));
159        }
160
161        if config.typical_p < 1.0 {
162            chain = chain.push(TypicalPStage::new(config.typical_p));
163        }
164
165        if config.top_a != 0.0 {
166            chain = chain.push(TopAStage::new(config.top_a));
167        }
168
169        if config.eta_cutoff != 0.0 || config.epsilon_cutoff != 0.0 {
170            chain = chain.push(EtaStage::new(config.eta_cutoff, config.epsilon_cutoff));
171        }
172        // ─────────────────────────────────────────────────────────────────────
173
174        if config.temperature <= 0.0 {
175            // Greedy: just push the greedy selector
176            chain = chain.push(GreedySelect);
177            return chain;
178        }
179
180        if config.temperature != 1.0 {
181            chain = chain.push(TemperatureScale::new(config.temperature));
182        }
183
184        if config.top_k > 0 {
185            chain = chain.push(TopK::new(config.top_k));
186        }
187
188        if config.min_p > 0.0 {
189            chain = chain.push(MinP::new(config.min_p));
190        }
191
192        if config.top_p < 1.0 {
193            chain = chain.push(TopP::new(config.top_p));
194        }
195
196        chain
197    }
198}
199
200// ── Built-in stages ──────────────────────────────────────────────────────────
201
202/// Repetition penalty stage — penalizes recently generated tokens.
203pub struct RepetitionPenalty {
204    penalty: f32,
205    window: usize,
206}
207
208impl RepetitionPenalty {
209    /// Create a new repetition penalty stage.
210    ///
211    /// `penalty` of 1.0 = no effect. Values > 1.0 penalize repetition.
212    /// `window` is the number of recent tokens to consider.
213    pub fn new(penalty: f32, window: usize) -> Self {
214        Self { penalty, window }
215    }
216}
217
218impl SamplerStage for RepetitionPenalty {
219    fn apply(&self, logits: &mut Vec<f32>, recent_tokens: &[u32]) {
220        if self.penalty == 1.0 || recent_tokens.is_empty() {
221            return;
222        }
223        let start = recent_tokens.len().saturating_sub(self.window);
224        for &token in &recent_tokens[start..] {
225            let idx = token as usize;
226            if idx < logits.len() {
227                if logits[idx] > 0.0 {
228                    logits[idx] /= self.penalty;
229                } else {
230                    logits[idx] *= self.penalty;
231                }
232            }
233        }
234    }
235
236    fn name(&self) -> &'static str {
237        "repetition_penalty"
238    }
239}
240
241/// Temperature scaling stage.
242pub struct TemperatureScale {
243    temperature: f32,
244}
245
246impl TemperatureScale {
247    /// Create a new temperature scaling stage.
248    pub fn new(temperature: f32) -> Self {
249        Self { temperature }
250    }
251}
252
253impl SamplerStage for TemperatureScale {
254    fn apply(&self, logits: &mut Vec<f32>, _recent_tokens: &[u32]) {
255        if self.temperature <= 0.0 || self.temperature == 1.0 {
256            return;
257        }
258        let inv = 1.0 / self.temperature;
259        for v in logits.iter_mut() {
260            *v *= inv;
261        }
262    }
263
264    fn name(&self) -> &'static str {
265        "temperature"
266    }
267}
268
269/// Top-K filtering stage — keeps only the K highest logits; sets rest to -inf.
270pub struct TopK {
271    k: usize,
272}
273
274impl TopK {
275    /// Create a new top-K stage with the given `k`.
276    pub fn new(k: usize) -> Self {
277        Self { k }
278    }
279}
280
281impl SamplerStage for TopK {
282    fn apply(&self, logits: &mut Vec<f32>, _recent_tokens: &[u32]) {
283        if self.k == 0 || self.k >= logits.len() {
284            return;
285        }
286        // Find the k-th largest value
287        let mut sorted: Vec<f32> = logits.clone();
288        sorted.sort_unstable_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
289        let threshold = sorted[self.k - 1];
290        // Keep tokens at or above threshold, up to k
291        let mut kept = 0usize;
292        for v in logits.iter_mut() {
293            if *v >= threshold && kept < self.k {
294                kept += 1;
295            } else if *v < threshold {
296                *v = f32::NEG_INFINITY;
297            }
298        }
299    }
300
301    fn name(&self) -> &'static str {
302        "top_k"
303    }
304}
305
306/// Top-P (nucleus) filtering stage — keeps smallest set with cumulative prob >= p.
307pub struct TopP {
308    p: f32,
309}
310
311impl TopP {
312    /// Create a new top-P (nucleus) stage with the given probability threshold.
313    pub fn new(p: f32) -> Self {
314        Self { p }
315    }
316}
317
318impl SamplerStage for TopP {
319    fn apply(&self, logits: &mut Vec<f32>, _recent_tokens: &[u32]) {
320        if self.p >= 1.0 {
321            return;
322        }
323        // Softmax first
324        let max_val = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
325        let probs: Vec<f32> = logits.iter().map(|&v| (v - max_val).exp()).collect();
326        let sum: f32 = probs.iter().sum();
327        if sum <= 0.0 {
328            return;
329        }
330        let probs: Vec<f32> = probs.iter().map(|&p| p / sum).collect();
331
332        // Sort indices by probability descending
333        let mut indices: Vec<usize> = (0..probs.len()).collect();
334        indices.sort_unstable_by(|&a, &b| {
335            probs[b]
336                .partial_cmp(&probs[a])
337                .unwrap_or(std::cmp::Ordering::Equal)
338        });
339
340        // Find cutoff
341        let mut cumulative = 0.0f32;
342        let mut cutoff_idx = indices.len();
343        for (i, &idx) in indices.iter().enumerate() {
344            cumulative += probs[idx];
345            if cumulative >= self.p {
346                cutoff_idx = i + 1;
347                break;
348            }
349        }
350
351        // Mask everything beyond cutoff
352        let kept: HashSet<usize> = indices[..cutoff_idx].iter().copied().collect();
353        for (i, v) in logits.iter_mut().enumerate() {
354            if !kept.contains(&i) {
355                *v = f32::NEG_INFINITY;
356            }
357        }
358    }
359
360    fn name(&self) -> &'static str {
361        "top_p"
362    }
363}
364
365/// Min-P filtering stage — removes tokens with prob < min_p * max_prob.
366pub struct MinP {
367    min_p: f32,
368}
369
370impl MinP {
371    /// Create a new min-P stage with the given minimum probability ratio.
372    pub fn new(min_p: f32) -> Self {
373        Self { min_p }
374    }
375}
376
377impl SamplerStage for MinP {
378    fn apply(&self, logits: &mut Vec<f32>, _recent_tokens: &[u32]) {
379        if self.min_p <= 0.0 {
380            return;
381        }
382        let max_val = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
383        let probs: Vec<f32> = logits.iter().map(|&v| (v - max_val).exp()).collect();
384        let sum: f32 = probs.iter().sum();
385        if sum <= 0.0 {
386            return;
387        }
388        let max_prob = probs.iter().fold(0.0f32, |a, &b| a.max(b)) / sum;
389        let threshold = self.min_p * max_prob;
390        for (i, v) in logits.iter_mut().enumerate() {
391            if probs[i] / sum < threshold {
392                *v = f32::NEG_INFINITY;
393            }
394        }
395    }
396
397    fn name(&self) -> &'static str {
398        "min_p"
399    }
400}
401
402/// Logit-bias stage — applies per-token additive biases and hard bans.
403///
404/// This stage must be positioned **before** temperature scaling, repetition
405/// penalty, and any filtering stages so that bans and biases influence all
406/// downstream steps uniformly.
407///
408/// Processing order (matches `sampling::mod::apply_logit_bias_and_banned_tokens`):
409/// 1. Banned tokens → `f32::NEG_INFINITY` (hard ban, cannot be overridden by bias).
410/// 2. Logit biases are added to surviving logits.
411pub struct LogitBias {
412    /// Per-token additive biases.
413    biases: std::collections::HashMap<u32, f32>,
414    /// Tokens that must never be sampled.
415    banned: Vec<u32>,
416}
417
418impl LogitBias {
419    /// Create a new logit-bias stage.
420    ///
421    /// `biases` maps token IDs to additive values (positive = boost,
422    /// negative = suppress).  `banned` is the list of tokens to hard-ban.
423    pub fn new(biases: std::collections::HashMap<u32, f32>, banned: Vec<u32>) -> Self {
424        Self { biases, banned }
425    }
426
427    /// Create a stage with only hard-banned tokens and no biases.
428    pub fn banned_only(banned: Vec<u32>) -> Self {
429        Self {
430            biases: std::collections::HashMap::new(),
431            banned,
432        }
433    }
434
435    /// Create a stage with only biases and no bans.
436    pub fn biases_only(biases: std::collections::HashMap<u32, f32>) -> Self {
437        Self {
438            biases,
439            banned: Vec::new(),
440        }
441    }
442}
443
444impl SamplerStage for LogitBias {
445    fn apply(&self, logits: &mut Vec<f32>, _recent_tokens: &[u32]) {
446        // Step 1: hard ban.
447        for &token in &self.banned {
448            let idx = token as usize;
449            if idx < logits.len() {
450                logits[idx] = f32::NEG_INFINITY;
451            }
452        }
453        // Step 2: additive bias (skip already-banned slots).
454        for (&token, &bias) in &self.biases {
455            let idx = token as usize;
456            if idx < logits.len() && logits[idx].is_finite() {
457                logits[idx] += bias;
458            }
459        }
460    }
461
462    fn name(&self) -> &'static str {
463        "logit_bias"
464    }
465}
466
467/// Greedy selection stage — sets all logits except the max to -inf.
468/// Use this as the final stage for deterministic (argmax) output.
469pub struct GreedySelect;
470
471impl SamplerStage for GreedySelect {
472    fn apply(&self, logits: &mut Vec<f32>, _recent_tokens: &[u32]) {
473        if logits.is_empty() {
474            return;
475        }
476        let max_idx = logits
477            .iter()
478            .enumerate()
479            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
480            .map(|(i, _)| i)
481            .unwrap_or(0);
482        for (i, v) in logits.iter_mut().enumerate() {
483            if i != max_idx {
484                *v = f32::NEG_INFINITY;
485            }
486        }
487    }
488
489    fn name(&self) -> &'static str {
490        "greedy"
491    }
492}
493
494// ── Internal helpers ─────────────────────────────────────────────────────────
495
/// Final token selection: softmax + weighted random using xorshift64.
///
/// Returns the index of the chosen token. Only tokens with a positive
/// softmax weight can be chosen; logits that are `-inf` or NaN carry zero
/// weight. When no token has positive weight (empty input, everything
/// masked, or an infinite maximum) the index of the first logit equal to the
/// maximum is returned, or `0` if every logit is `-inf` / NaN.
///
/// The random draw is a single xorshift64 step on `seed` (a zero seed is
/// replaced by a fixed non-zero constant), so the result is fully determined
/// by `logits` and `seed`.
fn select_token(logits: &[f32], seed: u64) -> u32 {
    if logits.is_empty() {
        return 0;
    }

    // Softmax numerators, shifted by the largest logit. `(-inf) - (-inf)` and
    // NaN logits produce NaN here; they are mapped to zero weight so that they
    // cannot poison the sum and skip the fallback below.
    let max_val = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
    let exps: Vec<f32> = logits
        .iter()
        .map(|&v| {
            let e = (v - max_val).exp();
            if e.is_nan() {
                0.0
            } else {
                e
            }
        })
        .collect();
    let sum: f32 = exps.iter().sum();

    if sum <= 0.0 {
        // Nothing carries weight: take the first logit that equals a
        // non-masked maximum, or token 0 when everything is masked.
        return logits
            .iter()
            .position(|&v| v > f32::NEG_INFINITY && v == max_val)
            .map_or(0, |i| i as u32);
    }

    // Indices of tokens that can actually be sampled.
    let mut survivors = exps
        .iter()
        .enumerate()
        .filter(|(_, &e)| e > 0.0)
        .map(|(i, _)| i);
    let first = match survivors.next() {
        Some(i) => i,
        // Cannot happen while `sum > 0`; kept as a safe default.
        None => return 0,
    };
    let last = match survivors.last() {
        Some(i) => i,
        // Exactly one survivor (common after greedy / aggressive filtering).
        None => return first as u32,
    };

    // Weighted random selection via xorshift64.
    let mut state = if seed == 0 {
        0x517c_c1b7_2722_0a95_u64
    } else {
        seed
    };
    state ^= state << 13;
    state ^= state >> 7;
    state ^= state << 17;
    // Top 24 bits mapped to a uniform value in [0, 1).
    let r = (state >> 40) as f32 / (1u64 << 24) as f32;

    let mut cumulative = 0.0f32;
    for (i, &e) in exps.iter().enumerate() {
        cumulative += e / sum;
        if r < cumulative {
            return i as u32;
        }
    }

    // Rounding left the cumulative mass just below `r`: fall back to the last
    // token that has weight, never to a masked one.
    last as u32
}
554
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sampling::SamplerConfig;
    use std::collections::HashMap;

    // ── Chain basics ──────────────────────────────────────────────────────────

    /// A chain without stages must still return an in-range token.
    #[test]
    fn test_empty_chain_selects_token() {
        let logits = [1.0f32, 2.0, 3.0];
        let picked = SamplerChain::new().with_seed(42).sample(&logits, &[]);
        assert!((picked as usize) < logits.len());
    }

    /// A lone greedy stage yields the argmax.
    #[test]
    fn test_greedy_chain() {
        let logits = [1.0f32, 5.0, 3.0, 0.5];
        let picked = SamplerChain::new().push(GreedySelect).sample(&logits, &[]);
        assert_eq!(picked, 1); // argmax
    }

    /// A very cold temperature should always pick the top logit.
    #[test]
    fn test_temperature_affects_distribution() {
        let cold = SamplerChain::new()
            .with_seed(42)
            .push(TemperatureScale::new(0.01));

        let logits = [3.0f32, 2.0, 1.0, 0.5];
        assert_eq!(cold.sample(&logits, &[]), 0);
    }

    /// With k = 1 only the best token can be selected.
    #[test]
    fn test_top_k_limits_candidates() {
        let logits = [1.0f32, 5.0, 3.0];
        let picked = SamplerChain::new()
            .push(TopK::new(1))
            .with_seed(42)
            .sample(&logits, &[]);
        assert_eq!(picked, 1); // only top-1 survives
    }

    /// Without the penalty token 1 wins; penalising token 1 lets token 2 win.
    #[test]
    fn test_repetition_penalty_reduces_repeated() {
        let logits = [1.0f32, 5.0, 4.9, 1.0];
        let history = [1u32];
        let picked = SamplerChain::new()
            .push(RepetitionPenalty::new(100.0, 64))
            .push(GreedySelect)
            .sample(&logits, &history);
        assert_eq!(picked, 2);
    }

    // ── from_config ───────────────────────────────────────────────────────────

    /// The greedy preset produces an argmax chain.
    #[test]
    fn test_chain_from_config_greedy() {
        let chain = SamplerChain::from_config(&SamplerConfig::greedy());
        let logits = [1.0f32, 5.0, 3.0];
        assert_eq!(chain.sample(&logits, &[]), 1);
    }

    /// The default preset includes repetition penalty and temperature stages.
    #[test]
    fn test_chain_from_config_default() {
        let chain = SamplerChain::from_config(&SamplerConfig::default());
        assert!(!chain.is_empty());
        let stage_list = chain.stage_names();
        assert!(stage_list.contains(&"repetition_penalty"));
        assert!(stage_list.contains(&"temperature"));
    }

    /// Stage names come back in insertion order.
    #[test]
    fn test_stage_names() {
        let chain = SamplerChain::new()
            .push(RepetitionPenalty::new(1.1, 64))
            .push(TemperatureScale::new(0.8))
            .push(TopK::new(40))
            .push(TopP::new(0.9))
            .push(MinP::new(0.05));
        let expected = vec![
            "repetition_penalty",
            "temperature",
            "top_k",
            "top_p",
            "min_p",
        ];
        assert_eq!(chain.stage_names(), expected);
    }

    /// Empty logits yield token 0.
    #[test]
    fn test_empty_logits() {
        let chain = SamplerChain::new().push(GreedySelect);
        let no_logits: [f32; 0] = [];
        assert_eq!(chain.sample(&no_logits, &[]), 0);
    }

    /// One dominant token survives min-P filtering.
    #[test]
    fn test_min_p_filters_low_prob() {
        let logits = [10.0f32, -10.0, -10.0, -10.0];
        let picked = SamplerChain::new()
            .push(MinP::new(0.1))
            .push(GreedySelect)
            .sample(&logits, &[]);
        assert_eq!(picked, 0);
    }

    /// One very dominant token forms the whole nucleus.
    #[test]
    fn test_top_p_nucleus() {
        let logits = [100.0f32, 0.0, 0.0, 0.0];
        let picked = SamplerChain::new()
            .push(TopP::new(0.5))
            .with_seed(42)
            .sample(&logits, &[]);
        assert_eq!(picked, 0);
    }

    /// `len` and `is_empty` track the number of pushed stages.
    #[test]
    fn test_chain_len_and_is_empty() {
        let empty = SamplerChain::new();
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);

        let one_stage = empty.push(GreedySelect);
        assert!(!one_stage.is_empty());
        assert_eq!(one_stage.len(), 1);
    }

    // ── LogitBias stage tests ─────────────────────────────────────────────────

    /// Token 1 would normally win (logit 5.0) but is banned.
    #[test]
    fn test_logit_bias_bans_token() {
        let logits = [1.0f32, 5.0, 3.0];
        let picked = SamplerChain::new()
            .push(LogitBias::banned_only(vec![1]))
            .push(GreedySelect)
            .sample(&logits, &[]);
        assert_eq!(
            picked, 2,
            "banned token 1 should never win; token 2 (3.0) should"
        );
    }

    /// Token 2 has the lowest logit before the bias is applied.
    #[test]
    fn test_logit_bias_boosts_token() {
        let mut boost = HashMap::new();
        boost.insert(2u32, 100.0f32);
        let logits = [10.0f32, 10.0, 0.0];
        let picked = SamplerChain::new()
            .push(LogitBias::biases_only(boost))
            .push(GreedySelect)
            .sample(&logits, &[]);
        assert_eq!(picked, 2, "large positive bias should make token 2 win");
    }

    /// A banned token stays at -inf even if it also has a positive bias.
    #[test]
    fn test_logit_bias_ban_wins_over_positive_bias() {
        let mut boost = HashMap::new();
        boost.insert(0u32, 999.0f32); // very large positive bias on token 0
        let banned = vec![0]; // ...which is also banned
        let logits = [10.0f32, 1.0, 1.0];
        let picked = SamplerChain::new()
            .push(LogitBias::new(boost, banned))
            .push(GreedySelect)
            .sample(&logits, &[]);
        // Token 0 is banned — the positive bias must NOT override the ban.
        assert_ne!(picked, 0, "ban must override positive bias");
    }

    /// A non-empty bias map makes `from_config` add the logit-bias stage.
    #[test]
    fn test_from_config_includes_logit_bias_stage() {
        let mut suppress = HashMap::new();
        suppress.insert(0u32, -100.0f32);
        let config = SamplerConfig {
            temperature: 0.0,
            logit_bias: suppress,
            ..SamplerConfig::greedy()
        };
        let stage_list = SamplerChain::from_config(&config).stage_names();
        assert!(
            stage_list.contains(&"logit_bias"),
            "from_config should add logit_bias stage when bias map is non-empty"
        );
    }

    /// No bias map and no ban list means no logit-bias stage.
    #[test]
    fn test_from_config_no_logit_bias_stage_when_empty() {
        let stage_list = SamplerChain::from_config(&SamplerConfig::greedy()).stage_names();
        assert!(
            !stage_list.contains(&"logit_bias"),
            "from_config should NOT add logit_bias stage when both bias map and banned list are empty"
        );
    }
}