oxillama_runtime/sampling/
chain.rs1use std::collections::HashSet;
26
/// One step of a [`SamplerChain`]: a transformation applied to the logits
/// before the final token draw.
///
/// Implementors must be `Send + Sync`; the chain stores them as boxed trait
/// objects.
pub trait SamplerStage: Send + Sync {
    /// Returns a short, static identifier for the stage. The chain reports
    /// these through [`SamplerChain::stage_names`].
    fn name(&self) -> &'static str;

    /// Rewrites `logits` in place.
    ///
    /// `recent_tokens` is the token history the caller handed to
    /// [`SamplerChain::sample`]; stages that do not need it ignore it.
    fn apply(&self, logits: &mut Vec<f32>, recent_tokens: &[u32]);
}
39
40pub struct SamplerChain {
42 stages: Vec<Box<dyn SamplerStage>>,
43 seed: u64,
45}
46
47impl Default for SamplerChain {
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
53impl SamplerChain {
54 pub fn new() -> Self {
56 Self {
57 stages: Vec::new(),
58 seed: 0xDEAD_BEEF_CAFE_BABE,
59 }
60 }
61
62 pub fn with_seed(mut self, seed: u64) -> Self {
64 self.seed = seed;
65 self
66 }
67
68 pub fn push(mut self, stage: impl SamplerStage + 'static) -> Self {
70 self.stages.push(Box::new(stage));
71 self
72 }
73
74 pub fn sample(&self, logits: &[f32], recent_tokens: &[u32]) -> u32 {
78 if logits.is_empty() {
79 return 0;
80 }
81
82 let mut processed = logits.to_vec();
83
84 for stage in &self.stages {
85 stage.apply(&mut processed, recent_tokens);
86 }
87
88 select_token(&processed, self.seed)
90 }
91
92 pub fn len(&self) -> usize {
94 self.stages.len()
95 }
96
97 pub fn is_empty(&self) -> bool {
99 self.stages.is_empty()
100 }
101
102 pub fn stage_names(&self) -> Vec<&'static str> {
104 self.stages.iter().map(|s| s.name()).collect()
105 }
106
107 pub fn from_config(config: &super::SamplerConfig) -> Self {
118 use super::advanced::{DryStage, EtaStage, TopAStage, TypicalPStage, XtcStage};
119
120 let mut chain = Self::new();
121
122 if let Some(seed) = config.seed {
123 chain = chain.with_seed(seed);
124 }
125
126 if !config.logit_bias.is_empty() || !config.banned_tokens.is_empty() {
128 chain = chain.push(LogitBias::new(
129 config.logit_bias.clone(),
130 config.banned_tokens.clone(),
131 ));
132 }
133
134 if config.repetition_penalty != 1.0 {
135 chain = chain.push(RepetitionPenalty::new(
136 config.repetition_penalty,
137 config.repetition_penalty_window,
138 ));
139 }
140
141 if config.dry_multiplier != 0.0 {
144 chain = chain.push(DryStage::new(
145 config.dry_multiplier,
146 config.dry_base,
147 config.dry_allowed_length,
148 Vec::new(), ));
150 }
151
152 if config.xtc_threshold < 1.0 && config.xtc_probability > 0.0 {
153 let seed = config.seed.unwrap_or(0xDEAD_BEEF_CAFE_BABE);
154 chain = chain.push(XtcStage::new(
155 config.xtc_threshold,
156 config.xtc_probability,
157 seed,
158 ));
159 }
160
161 if config.typical_p < 1.0 {
162 chain = chain.push(TypicalPStage::new(config.typical_p));
163 }
164
165 if config.top_a != 0.0 {
166 chain = chain.push(TopAStage::new(config.top_a));
167 }
168
169 if config.eta_cutoff != 0.0 || config.epsilon_cutoff != 0.0 {
170 chain = chain.push(EtaStage::new(config.eta_cutoff, config.epsilon_cutoff));
171 }
172 if config.temperature <= 0.0 {
175 chain = chain.push(GreedySelect);
177 return chain;
178 }
179
180 if config.temperature != 1.0 {
181 chain = chain.push(TemperatureScale::new(config.temperature));
182 }
183
184 if config.top_k > 0 {
185 chain = chain.push(TopK::new(config.top_k));
186 }
187
188 if config.min_p > 0.0 {
189 chain = chain.push(MinP::new(config.min_p));
190 }
191
192 if config.top_p < 1.0 {
193 chain = chain.push(TopP::new(config.top_p));
194 }
195
196 chain
197 }
198}
199
200pub struct RepetitionPenalty {
204 penalty: f32,
205 window: usize,
206}
207
208impl RepetitionPenalty {
209 pub fn new(penalty: f32, window: usize) -> Self {
214 Self { penalty, window }
215 }
216}
217
218impl SamplerStage for RepetitionPenalty {
219 fn apply(&self, logits: &mut Vec<f32>, recent_tokens: &[u32]) {
220 if self.penalty == 1.0 || recent_tokens.is_empty() {
221 return;
222 }
223 let start = recent_tokens.len().saturating_sub(self.window);
224 for &token in &recent_tokens[start..] {
225 let idx = token as usize;
226 if idx < logits.len() {
227 if logits[idx] > 0.0 {
228 logits[idx] /= self.penalty;
229 } else {
230 logits[idx] *= self.penalty;
231 }
232 }
233 }
234 }
235
236 fn name(&self) -> &'static str {
237 "repetition_penalty"
238 }
239}
240
241pub struct TemperatureScale {
243 temperature: f32,
244}
245
246impl TemperatureScale {
247 pub fn new(temperature: f32) -> Self {
249 Self { temperature }
250 }
251}
252
253impl SamplerStage for TemperatureScale {
254 fn apply(&self, logits: &mut Vec<f32>, _recent_tokens: &[u32]) {
255 if self.temperature <= 0.0 || self.temperature == 1.0 {
256 return;
257 }
258 let inv = 1.0 / self.temperature;
259 for v in logits.iter_mut() {
260 *v *= inv;
261 }
262 }
263
264 fn name(&self) -> &'static str {
265 "temperature"
266 }
267}
268
269pub struct TopK {
271 k: usize,
272}
273
274impl TopK {
275 pub fn new(k: usize) -> Self {
277 Self { k }
278 }
279}
280
281impl SamplerStage for TopK {
282 fn apply(&self, logits: &mut Vec<f32>, _recent_tokens: &[u32]) {
283 if self.k == 0 || self.k >= logits.len() {
284 return;
285 }
286 let mut sorted: Vec<f32> = logits.clone();
288 sorted.sort_unstable_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
289 let threshold = sorted[self.k - 1];
290 let mut kept = 0usize;
292 for v in logits.iter_mut() {
293 if *v >= threshold && kept < self.k {
294 kept += 1;
295 } else if *v < threshold {
296 *v = f32::NEG_INFINITY;
297 }
298 }
299 }
300
301 fn name(&self) -> &'static str {
302 "top_k"
303 }
304}
305
306pub struct TopP {
308 p: f32,
309}
310
311impl TopP {
312 pub fn new(p: f32) -> Self {
314 Self { p }
315 }
316}
317
318impl SamplerStage for TopP {
319 fn apply(&self, logits: &mut Vec<f32>, _recent_tokens: &[u32]) {
320 if self.p >= 1.0 {
321 return;
322 }
323 let max_val = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
325 let probs: Vec<f32> = logits.iter().map(|&v| (v - max_val).exp()).collect();
326 let sum: f32 = probs.iter().sum();
327 if sum <= 0.0 {
328 return;
329 }
330 let probs: Vec<f32> = probs.iter().map(|&p| p / sum).collect();
331
332 let mut indices: Vec<usize> = (0..probs.len()).collect();
334 indices.sort_unstable_by(|&a, &b| {
335 probs[b]
336 .partial_cmp(&probs[a])
337 .unwrap_or(std::cmp::Ordering::Equal)
338 });
339
340 let mut cumulative = 0.0f32;
342 let mut cutoff_idx = indices.len();
343 for (i, &idx) in indices.iter().enumerate() {
344 cumulative += probs[idx];
345 if cumulative >= self.p {
346 cutoff_idx = i + 1;
347 break;
348 }
349 }
350
351 let kept: HashSet<usize> = indices[..cutoff_idx].iter().copied().collect();
353 for (i, v) in logits.iter_mut().enumerate() {
354 if !kept.contains(&i) {
355 *v = f32::NEG_INFINITY;
356 }
357 }
358 }
359
360 fn name(&self) -> &'static str {
361 "top_p"
362 }
363}
364
365pub struct MinP {
367 min_p: f32,
368}
369
370impl MinP {
371 pub fn new(min_p: f32) -> Self {
373 Self { min_p }
374 }
375}
376
377impl SamplerStage for MinP {
378 fn apply(&self, logits: &mut Vec<f32>, _recent_tokens: &[u32]) {
379 if self.min_p <= 0.0 {
380 return;
381 }
382 let max_val = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
383 let probs: Vec<f32> = logits.iter().map(|&v| (v - max_val).exp()).collect();
384 let sum: f32 = probs.iter().sum();
385 if sum <= 0.0 {
386 return;
387 }
388 let max_prob = probs.iter().fold(0.0f32, |a, &b| a.max(b)) / sum;
389 let threshold = self.min_p * max_prob;
390 for (i, v) in logits.iter_mut().enumerate() {
391 if probs[i] / sum < threshold {
392 *v = f32::NEG_INFINITY;
393 }
394 }
395 }
396
397 fn name(&self) -> &'static str {
398 "min_p"
399 }
400}
401
402pub struct LogitBias {
412 biases: std::collections::HashMap<u32, f32>,
414 banned: Vec<u32>,
416}
417
418impl LogitBias {
419 pub fn new(biases: std::collections::HashMap<u32, f32>, banned: Vec<u32>) -> Self {
424 Self { biases, banned }
425 }
426
427 pub fn banned_only(banned: Vec<u32>) -> Self {
429 Self {
430 biases: std::collections::HashMap::new(),
431 banned,
432 }
433 }
434
435 pub fn biases_only(biases: std::collections::HashMap<u32, f32>) -> Self {
437 Self {
438 biases,
439 banned: Vec::new(),
440 }
441 }
442}
443
444impl SamplerStage for LogitBias {
445 fn apply(&self, logits: &mut Vec<f32>, _recent_tokens: &[u32]) {
446 for &token in &self.banned {
448 let idx = token as usize;
449 if idx < logits.len() {
450 logits[idx] = f32::NEG_INFINITY;
451 }
452 }
453 for (&token, &bias) in &self.biases {
455 let idx = token as usize;
456 if idx < logits.len() && logits[idx].is_finite() {
457 logits[idx] += bias;
458 }
459 }
460 }
461
462 fn name(&self) -> &'static str {
463 "logit_bias"
464 }
465}
466
467pub struct GreedySelect;
470
471impl SamplerStage for GreedySelect {
472 fn apply(&self, logits: &mut Vec<f32>, _recent_tokens: &[u32]) {
473 if logits.is_empty() {
474 return;
475 }
476 let max_idx = logits
477 .iter()
478 .enumerate()
479 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
480 .map(|(i, _)| i)
481 .unwrap_or(0);
482 for (i, v) in logits.iter_mut().enumerate() {
483 if i != max_idx {
484 *v = f32::NEG_INFINITY;
485 }
486 }
487 }
488
489 fn name(&self) -> &'static str {
490 "greedy"
491 }
492}
493
/// Draws a token index from the softmax of `logits` using a value derived
/// deterministically from `seed`.
///
/// Behaviour at the edges:
/// - empty input returns `0`;
/// - if no logit is above negative infinity (all masked and/or NaN), returns
///   `0`;
/// - if any logit is `+inf`, returns the index of the first such logit;
/// - NaN logits are given zero probability and are never returned while any
///   other candidate exists;
/// - if exactly one token has non-zero probability it is returned without
///   consuming the random value.
///
/// Otherwise one xorshift step is applied to `seed` (a fixed non-zero
/// constant replaces a zero seed), the top 24 bits become a uniform value in
/// `[0, 1)`, and the first token whose cumulative probability exceeds it is
/// returned. Should rounding leave the cumulative total below that value,
/// the last token with non-zero probability is returned.
fn select_token(logits: &[f32], seed: u64) -> u32 {
    if logits.is_empty() {
        return 0;
    }

    // `f32::max` ignores NaN operands, so NaN logits never become the max.
    let max_val = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
    if max_val == f32::NEG_INFINITY {
        // Nothing is selectable; `v - max_val` would be NaN for every entry.
        return 0;
    }
    if max_val == f32::INFINITY {
        // `inf - inf` is NaN, so softmax is undefined; the infinite logit
        // dominates by definition.
        return logits
            .iter()
            .position(|&v| v == f32::INFINITY)
            .map_or(0, |i| i as u32);
    }

    let exps: Vec<f32> = logits
        .iter()
        .map(|&v| if v.is_nan() { 0.0 } else { (v - max_val).exp() })
        .collect();
    // The maximum contributes exp(0) = 1, so the sum is at least 1.
    let sum: f32 = exps.iter().sum();

    let mut survivors = exps
        .iter()
        .enumerate()
        .filter(|&(_, &e)| e > 0.0)
        .map(|(i, _)| i);
    let first = match survivors.next() {
        Some(i) => i,
        None => return 0,
    };
    if survivors.next().is_none() {
        return first as u32;
    }

    // xorshift cannot leave the all-zero state, so substitute a constant.
    let mut state = if seed == 0 {
        0x517c_c1b7_2722_0a95_u64
    } else {
        seed
    };
    state ^= state << 13;
    state ^= state >> 7;
    state ^= state << 17;
    // Top 24 bits fit an f32 mantissa exactly, giving r in [0, 1).
    let r = (state >> 40) as f32 / (1u64 << 24) as f32;

    let mut cumulative = 0.0f32;
    let mut last_survivor = first;
    for (i, &e) in exps.iter().enumerate() {
        if e > 0.0 {
            cumulative += e / sum;
            if r < cumulative {
                return i as u32;
            }
            last_survivor = i;
        }
    }

    // Rounding left the total just below r: fall back to a token that has
    // non-zero probability rather than blindly to the last index.
    last_survivor as u32
}
554
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sampling::SamplerConfig;

    /// A chain without stages still yields an in-range token id.
    #[test]
    fn test_empty_chain_selects_token() {
        let logits = vec![1.0, 2.0, 3.0];
        let picked = SamplerChain::new().with_seed(42).sample(&logits, &[]);
        assert!((picked as usize) < logits.len());
    }

    /// Greedy selection returns the index of the largest logit.
    #[test]
    fn test_greedy_chain() {
        let logits = vec![1.0, 5.0, 3.0, 0.5];
        let picked = SamplerChain::new().push(GreedySelect).sample(&logits, &[]);
        assert_eq!(picked, 1);
    }

    /// A very low temperature makes the top logit win the draw.
    #[test]
    fn test_temperature_affects_distribution() {
        let logits = vec![3.0, 2.0, 1.0, 0.5];
        let cold = SamplerChain::new()
            .with_seed(42)
            .push(TemperatureScale::new(0.01));
        assert_eq!(cold.sample(&logits, &[]), 0);
    }

    /// With k = 1 only the best token can be drawn.
    #[test]
    fn test_top_k_limits_candidates() {
        let logits = vec![1.0, 5.0, 3.0];
        let picked = SamplerChain::new()
            .push(TopK::new(1))
            .with_seed(42)
            .sample(&logits, &[]);
        assert_eq!(picked, 1);
    }

    /// A heavy penalty on the leading token lets the runner-up win.
    #[test]
    fn test_repetition_penalty_reduces_repeated() {
        let logits = vec![1.0, 5.0, 4.9, 1.0];
        let history = [1];
        let picked = SamplerChain::new()
            .push(RepetitionPenalty::new(100.0, 64))
            .push(GreedySelect)
            .sample(&logits, &history);
        assert_eq!(picked, 2);
    }

    /// The greedy preset produces a chain that picks the arg-max.
    #[test]
    fn test_chain_from_config_greedy() {
        let chain = SamplerChain::from_config(&SamplerConfig::greedy());
        let logits = vec![1.0, 5.0, 3.0];
        assert_eq!(chain.sample(&logits, &[]), 1);
    }

    /// The default config enables at least repetition penalty and temperature.
    #[test]
    fn test_chain_from_config_default() {
        let chain = SamplerChain::from_config(&SamplerConfig::default());
        assert!(!chain.is_empty());
        let stage_names = chain.stage_names();
        assert!(stage_names.contains(&"repetition_penalty"));
        assert!(stage_names.contains(&"temperature"));
    }

    /// Stage names are reported in insertion order.
    #[test]
    fn test_stage_names() {
        let expected = vec![
            "repetition_penalty",
            "temperature",
            "top_k",
            "top_p",
            "min_p",
        ];
        let chain = SamplerChain::new()
            .push(RepetitionPenalty::new(1.1, 64))
            .push(TemperatureScale::new(0.8))
            .push(TopK::new(40))
            .push(TopP::new(0.9))
            .push(MinP::new(0.05));
        assert_eq!(chain.stage_names(), expected);
    }

    /// Sampling from empty logits returns token 0.
    #[test]
    fn test_empty_logits() {
        let chain = SamplerChain::new().push(GreedySelect);
        assert_eq!(chain.sample(&[], &[]), 0);
    }

    /// Min-p leaves the dominant token in place.
    #[test]
    fn test_min_p_filters_low_prob() {
        let logits = vec![10.0, -10.0, -10.0, -10.0];
        let picked = SamplerChain::new()
            .push(MinP::new(0.1))
            .push(GreedySelect)
            .sample(&logits, &[]);
        assert_eq!(picked, 0);
    }

    /// A single overwhelming token forms the whole nucleus.
    #[test]
    fn test_top_p_nucleus() {
        let logits = vec![100.0, 0.0, 0.0, 0.0];
        let picked = SamplerChain::new()
            .push(TopP::new(0.5))
            .with_seed(42)
            .sample(&logits, &[]);
        assert_eq!(picked, 0);
    }

    /// `len` / `is_empty` track pushed stages.
    #[test]
    fn test_chain_len_and_is_empty() {
        let empty = SamplerChain::new();
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);

        let one_stage = empty.push(GreedySelect);
        assert!(!one_stage.is_empty());
        assert_eq!(one_stage.len(), 1);
    }

    /// A banned token cannot win even when it has the highest logit.
    #[test]
    fn test_logit_bias_bans_token() {
        let logits = vec![1.0f32, 5.0, 3.0];
        let tok = SamplerChain::new()
            .push(LogitBias::banned_only(vec![1]))
            .push(GreedySelect)
            .sample(&logits, &[]);
        assert_eq!(
            tok, 2,
            "banned token 1 should never win; token 2 (3.0) should"
        );
    }

    /// A large positive bias promotes an otherwise weak token.
    #[test]
    fn test_logit_bias_boosts_token() {
        let biases = std::collections::HashMap::from([(2u32, 100.0f32)]);
        let logits = vec![10.0f32, 10.0, 0.0];
        let tok = SamplerChain::new()
            .push(LogitBias::biases_only(biases))
            .push(GreedySelect)
            .sample(&logits, &[]);
        assert_eq!(tok, 2, "large positive bias should make token 2 win");
    }

    /// A ban takes precedence over a positive bias on the same token.
    #[test]
    fn test_logit_bias_ban_wins_over_positive_bias() {
        let biases = std::collections::HashMap::from([(0u32, 999.0f32)]);
        let logits = vec![10.0f32, 1.0, 1.0];
        let tok = SamplerChain::new()
            .push(LogitBias::new(biases, vec![0]))
            .push(GreedySelect)
            .sample(&logits, &[]);
        assert_ne!(tok, 0, "ban must override positive bias");
    }

    /// A non-empty bias map adds the logit-bias stage.
    #[test]
    fn test_from_config_includes_logit_bias_stage() {
        let biases = std::collections::HashMap::from([(0u32, -100.0f32)]);
        let config = SamplerConfig {
            temperature: 0.0,
            logit_bias: biases,
            ..SamplerConfig::greedy()
        };
        let stage_names = SamplerChain::from_config(&config).stage_names();
        assert!(
            stage_names.contains(&"logit_bias"),
            "from_config should add logit_bias stage when bias map is non-empty"
        );
    }

    /// With no biases and no bans the logit-bias stage is omitted.
    #[test]
    fn test_from_config_no_logit_bias_stage_when_empty() {
        let stage_names = SamplerChain::from_config(&SamplerConfig::greedy()).stage_names();
        assert!(
            !stage_names.contains(&"logit_bias"),
            "from_config should NOT add logit_bias stage when both bias map and banned list are empty"
        );
    }
}