use super::chain::SamplerStage;
/// Converts raw logits into probabilities with a numerically stable softmax.
///
/// The largest logit is subtracted from every entry before exponentiation so
/// that `exp` cannot overflow for large inputs.
///
/// Edge cases:
/// - An empty slice yields an empty vector.
/// - If no entry is greater than negative infinity (for example, a fully
///   masked distribution), every probability is `0.0`. Without this guard the
///   subtraction `-inf - -inf` would produce NaN for every entry.
/// - If the exponentials do not sum to a positive number (which includes a NaN
///   sum), the un-normalised exponentials are returned unchanged.
fn softmax(logits: &[f32]) -> Vec<f32> {
    if logits.is_empty() {
        return Vec::new();
    }

    // `f32::max` ignores NaN operands, so this is the largest non-NaN logit,
    // or negative infinity when there is none.
    let max_val = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    if max_val == f32::NEG_INFINITY {
        // Nothing carries any probability mass; report that explicitly
        // instead of letting `-inf - -inf` poison the result with NaN.
        return vec![0.0; logits.len()];
    }

    let mut probs: Vec<f32> = logits.iter().map(|&v| (v - max_val).exp()).collect();
    let sum: f32 = probs.iter().sum();
    if sum > 0.0 {
        for p in &mut probs {
            *p /= sum;
        }
    }
    probs
}
/// Shannon entropy (natural log) of a probability vector.
///
/// Only strictly positive entries contribute; zeros, negatives and NaN are
/// skipped, so the result for an empty or all-zero slice is zero.
fn entropy(probs: &[f32]) -> f32 {
    let contributions = probs
        .iter()
        .filter_map(|&p| (p > 0.0).then(|| -p * p.ln()));
    contributions.sum()
}
/// Configuration for the DRY repetition-penalty sampler stage.
///
/// The stage lowers the logit of any token that would extend a sequence
/// already present in the recent token history.
#[derive(Debug, Clone, PartialEq)]
pub struct DryStage {
    /// Scale of the penalty. A value of `0.0` disables the stage.
    pub multiplier: f32,
    /// Base of the exponential growth: the penalty is
    /// `multiplier * base^(match_len - allowed_length)`.
    pub base: f32,
    /// Match lengths at or above this value are penalised.
    pub allowed_length: usize,
    /// Token ids that stop a match from extending and are never penalised
    /// themselves.
    pub sequence_breakers: Vec<u32>,
}

impl DryStage {
    /// Builds a stage from its four parameters; no validation is performed.
    #[must_use]
    pub fn new(
        multiplier: f32,
        base: f32,
        allowed_length: usize,
        sequence_breakers: Vec<u32>,
    ) -> Self {
        Self {
            multiplier,
            base,
            allowed_length,
            sequence_breakers,
        }
    }
}
impl SamplerStage for DryStage {
    /// Applies the DRY repetition penalty to `logits` in place.
    ///
    /// For every occurrence of a token in `recent_tokens`, the match length is
    /// one (the occurrence itself) plus the number of tokens immediately
    /// before it that agree, walking backwards, with the tail of
    /// `recent_tokens`. A sequence breaker on either side stops the walk. The
    /// longest match per token decides the penalty: when it is at least
    /// `allowed_length`, `multiplier * base^(match_len - allowed_length)` is
    /// subtracted from that token's logit.
    ///
    /// Non-finite logits and sequence-breaker tokens are never modified. The
    /// stage does nothing when `multiplier` is zero or either input is empty.
    fn apply(&self, logits: &mut Vec<f32>, recent_tokens: &[u32]) {
        if self.multiplier == 0.0 || logits.is_empty() || recent_tokens.is_empty() {
            return;
        }

        let breaker_set: std::collections::HashSet<u32> =
            self.sequence_breakers.iter().copied().collect();
        let hist_len = recent_tokens.len();

        // The match length at a history position depends only on that
        // position, never on the candidate token. Compute it once per
        // position and keep the maximum per token, instead of rescanning the
        // whole history for every vocabulary entry.
        let mut longest_match: std::collections::HashMap<u32, usize> =
            std::collections::HashMap::new();
        for (pos, &token) in recent_tokens.iter().enumerate() {
            if breaker_set.contains(&token) {
                // Breakers are skipped in the penalty pass, so their match
                // length is never consulted.
                continue;
            }
            let mut match_len = 1usize;
            // `k <= pos` keeps `pos - k` in bounds; `pos < hist_len` keeps
            // `hist_len - k` at least 1.
            for k in 1..=pos {
                let hist_token = recent_tokens[pos - k];
                let ctx_token = recent_tokens[hist_len - k];
                if breaker_set.contains(&hist_token)
                    || breaker_set.contains(&ctx_token)
                    || hist_token != ctx_token
                {
                    break;
                }
                match_len += 1;
            }
            let best = longest_match.entry(token).or_insert(0);
            *best = (*best).max(match_len);
        }

        for (t_idx, logit) in logits.iter_mut().enumerate() {
            if !logit.is_finite() {
                continue;
            }
            // Vocabulary indices are treated as `u32` token ids.
            let t = t_idx as u32;
            if breaker_set.contains(&t) {
                continue;
            }
            // Tokens absent from the history have a match length of zero,
            // which still qualifies when `allowed_length` is zero.
            let best_match = longest_match.get(&t).copied().unwrap_or(0);
            if best_match >= self.allowed_length {
                let excess = (best_match - self.allowed_length) as f32;
                *logit -= self.multiplier * self.base.powf(excess);
            }
        }
    }

    /// Returns the stage identifier, `"dry"`.
    fn name(&self) -> &'static str {
        "dry"
    }
}
/// Configuration for the XTC sampler stage.
#[derive(Debug, Clone, PartialEq)]
pub struct XtcStage {
    /// Cumulative probability mass that delimits the "top set" of tokens.
    /// Values of `1.0` or more disable the stage.
    pub threshold: f32,
    /// The stage only acts when its pseudo-random draw is below this value.
    /// A value of `0.0` disables the stage.
    pub probability: f32,
    /// Seed for the pseudo-random draw.
    pub seed: u64,
}

impl XtcStage {
    /// Builds a stage from its three parameters; no validation is performed.
    #[must_use]
    pub fn new(threshold: f32, probability: f32, seed: u64) -> Self {
        Self {
            threshold,
            probability,
            seed,
        }
    }
}
impl SamplerStage for XtcStage {
    /// Masks the runner-up tokens of the top set, keeping only the single
    /// most probable token from it.
    ///
    /// Tokens are ordered by descending probability and collected into a top
    /// set until their cumulative probability reaches `threshold`. If that set
    /// has fewer than two members, or the pseudo-random draw is not below
    /// `probability`, nothing changes. Otherwise every member except the most
    /// probable one is set to negative infinity; tokens outside the top set
    /// are left alone.
    ///
    /// The stage is a no-op when `threshold >= 1.0`, `probability == 0.0`,
    /// `logits` is empty, or any probability is NaN.
    ///
    /// NOTE(review): the draw is a pure function of `self.seed`, so a given
    /// stage instance makes the same fire/skip decision on every call.
    /// Confirm whether callers are expected to vary the seed per step.
    fn apply(&self, logits: &mut Vec<f32>, _recent_tokens: &[u32]) {
        if self.threshold >= 1.0 || self.probability == 0.0 || logits.is_empty() {
            return;
        }

        let probs = softmax(logits);
        // A NaN probability makes both the ordering and the cumulative sum
        // meaningless; leave the logits untouched rather than guess.
        if probs.iter().any(|p| p.is_nan()) {
            return;
        }

        // Descending by probability. `total_cmp` is a total order, which the
        // sort requires.
        let mut indices: Vec<usize> = (0..probs.len()).collect();
        indices.sort_unstable_by(|&a, &b| probs[b].total_cmp(&probs[a]));

        let mut cumulative = 0.0f32;
        let mut top_end = 0usize;
        for &idx in &indices {
            cumulative += probs[idx];
            top_end += 1;
            if cumulative >= self.threshold {
                break;
            }
        }
        if top_end < 2 {
            return;
        }

        let rand_val = xorshift64_f32(self.seed);
        if rand_val >= self.probability {
            return;
        }

        let best_idx = indices[0];
        for &idx in &indices[1..top_end] {
            logits[idx] = f32::NEG_INFINITY;
        }
        // `best_idx` is never masked above, so this only fires when its logit
        // was already non-finite on entry.
        if !logits[best_idx].is_finite() {
            logits[best_idx] = 1.0;
        }
    }

    /// Returns the stage identifier, `"xtc"`.
    fn name(&self) -> &'static str {
        "xtc"
    }
}
/// Configuration for the typical-p sampler stage.
#[derive(Debug, Clone, PartialEq)]
pub struct TypicalPStage {
    /// Cumulative probability mass to keep, counted over tokens ordered by
    /// how close their surprisal is to the distribution's entropy. Values of
    /// `1.0` or more disable the stage.
    pub p: f32,
}

impl TypicalPStage {
    /// Builds a stage with the given mass; no validation is performed.
    #[must_use]
    pub fn new(p: f32) -> Self {
        Self { p }
    }
}
impl SamplerStage for TypicalPStage {
    /// Keeps the most "typical" tokens and masks the rest.
    ///
    /// Each token is scored by `|ln(p) + H|`, the distance between its
    /// surprisal and the entropy `H` of the distribution. Tokens with no
    /// positive probability score infinity. Tokens are taken in ascending
    /// score order until their cumulative probability reaches `self.p`; every
    /// other token is set to negative infinity. At least one token is always
    /// kept.
    ///
    /// The stage is a no-op when `self.p >= 1.0` or `logits` is empty.
    fn apply(&self, logits: &mut Vec<f32>, _recent_tokens: &[u32]) {
        if self.p >= 1.0 || logits.is_empty() {
            return;
        }

        let probs = softmax(logits);
        let h = entropy(&probs);
        let deviations: Vec<f32> = probs
            .iter()
            .map(|&p| {
                if p > 0.0 {
                    (p.ln() + h).abs()
                } else {
                    f32::INFINITY
                }
            })
            .collect();

        // Ascending by deviation; `total_cmp` gives the sort a total order.
        let mut indices: Vec<usize> = (0..probs.len()).collect();
        indices.sort_unstable_by(|&a, &b| deviations[a].total_cmp(&deviations[b]));

        let mut cumulative = 0.0f32;
        let mut cutoff = 0usize;
        for &idx in &indices {
            cumulative += probs[idx];
            cutoff += 1;
            if cumulative >= self.p {
                break;
            }
        }
        let cutoff = cutoff.max(1);

        // `indices` is a permutation, so everything past the cutoff is
        // exactly the set of tokens that were not kept.
        for &idx in &indices[cutoff..] {
            logits[idx] = f32::NEG_INFINITY;
        }
        ensure_at_least_one_finite(logits);
    }

    /// Returns the stage identifier, `"typical_p"`.
    fn name(&self) -> &'static str {
        "typical_p"
    }
}
/// Configuration for the top-a sampler stage.
#[derive(Debug, Clone, PartialEq)]
pub struct TopAStage {
    /// Scale factor of the cutoff `a * max_prob^2`. A value of `0.0` disables
    /// the stage.
    pub a: f32,
}

impl TopAStage {
    /// Builds a stage with the given scale factor; no validation is performed.
    #[must_use]
    pub fn new(a: f32) -> Self {
        Self { a }
    }
}
impl SamplerStage for TopAStage {
    /// Masks every token whose probability is below `a * max_prob^2`.
    ///
    /// The most probable token is never masked. When `a` is large enough that
    /// the cutoff exceeds `max_prob`, that token survives with its original
    /// logit, so the generic fallback does not revive token 0 with an
    /// arbitrary logit.
    ///
    /// The stage is a no-op when `a == 0.0`, `logits` is empty, or no token
    /// has a positive probability.
    fn apply(&self, logits: &mut Vec<f32>, _recent_tokens: &[u32]) {
        if self.a == 0.0 || logits.is_empty() {
            return;
        }

        let probs = softmax(logits);

        // First index holding the largest probability; NaN never wins the
        // `>` comparison and is therefore ignored.
        let mut best_idx = 0usize;
        let mut max_prob = 0.0f32;
        for (i, &p) in probs.iter().enumerate() {
            if p > max_prob {
                max_prob = p;
                best_idx = i;
            }
        }
        if max_prob <= 0.0 {
            return;
        }

        let threshold = self.a * max_prob * max_prob;
        for (i, (v, &p)) in logits.iter_mut().zip(&probs).enumerate() {
            if i != best_idx && p < threshold {
                *v = f32::NEG_INFINITY;
            }
        }
        ensure_at_least_one_finite(logits);
    }

    /// Returns the stage identifier, `"top_a"`.
    fn name(&self) -> &'static str {
        "top_a"
    }
}
/// Configuration for the eta sampler stage.
///
/// The probability cutoff is `max(epsilon, eta / exp(H))`, where `H` is the
/// entropy of the distribution. The stage is disabled when both fields are
/// `0.0`.
#[derive(Debug, Clone, PartialEq)]
pub struct EtaStage {
    /// Numerator of the entropy-dependent part of the cutoff.
    pub eta: f32,
    /// Entropy-independent lower bound on the cutoff.
    pub epsilon: f32,
}

impl EtaStage {
    /// Builds a stage from its two parameters; no validation is performed.
    #[must_use]
    pub fn new(eta: f32, epsilon: f32) -> Self {
        Self { eta, epsilon }
    }
}
impl SamplerStage for EtaStage {
    /// Masks every token whose probability is below an entropy-dependent
    /// cutoff.
    ///
    /// The cutoff is `max(epsilon, eta / exp(H))`, with `H` the entropy of the
    /// softmax distribution. The most probable token is never masked: when
    /// the cutoff exceeds every probability, that token survives with its
    /// original logit, so the generic fallback does not revive token 0 with
    /// an arbitrary logit.
    ///
    /// The stage is a no-op when both `eta` and `epsilon` are zero, when
    /// `logits` is empty, or when the cutoff is not positive.
    fn apply(&self, logits: &mut Vec<f32>, _recent_tokens: &[u32]) {
        if (self.eta == 0.0 && self.epsilon == 0.0) || logits.is_empty() {
            return;
        }

        let probs = softmax(logits);
        let h = entropy(&probs);
        let s = h.exp();
        let dynamic = if s > 0.0 { self.eta / s } else { self.eta };
        let cutoff = self.epsilon.max(dynamic);
        if cutoff <= 0.0 {
            return;
        }

        // First index holding the largest positive probability, if any.
        let mut best_idx: Option<usize> = None;
        let mut best_prob = 0.0f32;
        for (i, &p) in probs.iter().enumerate() {
            if p > best_prob {
                best_prob = p;
                best_idx = Some(i);
            }
        }

        for (i, (v, &p)) in logits.iter_mut().zip(&probs).enumerate() {
            if Some(i) != best_idx && p < cutoff {
                *v = f32::NEG_INFINITY;
            }
        }
        ensure_at_least_one_finite(logits);
    }

    /// Returns the stage identifier, `"eta"`.
    fn name(&self) -> &'static str {
        "eta"
    }
}
/// Last-resort guard against a distribution with no usable entry.
///
/// When no element of `logits` is finite, the first element is overwritten
/// with `0.0`. A slice that already holds a finite value, or that is empty,
/// is left unchanged.
fn ensure_at_least_one_finite(logits: &mut [f32]) {
    let has_finite = logits.iter().any(|v| v.is_finite());
    if has_finite {
        return;
    }
    if let Some(first) = logits.first_mut() {
        *first = 0.0;
    }
}
/// Maps a seed to a deterministic value in `[0, 1)`.
///
/// The seed goes through one shift/xor round (left 13, right 7, left 17).
/// The upper 24 bits of the result are then divided by `2^24`. A zero seed is
/// replaced by a fixed non-zero constant first, since zero would stay zero
/// through the shift/xor round.
fn xorshift64_f32(seed: u64) -> f32 {
    const ZERO_SEED_REPLACEMENT: u64 = 0x517c_c1b7_2722_0a95;
    const OUTPUT_BITS: u32 = 24;

    let mut state = match seed {
        0 => ZERO_SEED_REPLACEMENT,
        nonzero => nonzero,
    };
    state ^= state << 13;
    state ^= state >> 7;
    state ^= state << 17;

    let top_bits = state >> (64 - OUTPUT_BITS);
    top_bits as f32 / (1u64 << OUTPUT_BITS) as f32
}
#[cfg(test)]
mod tests {
    use super::*;

    // Disabled configurations: every stage must leave the logits untouched.

    /// A zero multiplier switches DRY off.
    #[test]
    fn dry_disabled_passthrough() {
        let original = vec![1.0f32, 2.0, 3.0, 0.5];
        let mut logits = original.clone();
        let history = [0u32, 1, 0, 1];
        DryStage::new(0.0, 1.75, 2, vec![]).apply(&mut logits, &history);
        assert_eq!(logits, original, "DRY(multiplier=0) must be a no-op");
    }

    /// XTC is off when the threshold is at least 1.0 or the probability is 0.
    #[test]
    fn xtc_disabled_passthrough() {
        let original = vec![1.0f32, 2.0, 3.0, 0.5];

        let mut logits = original.clone();
        XtcStage::new(1.0, 0.5, 42).apply(&mut logits, &[]);
        assert_eq!(logits, original, "XTC(threshold>=1.0) must be a no-op");

        let mut logits2 = original.clone();
        XtcStage::new(0.5, 0.0, 42).apply(&mut logits2, &[]);
        assert_eq!(logits2, original, "XTC(probability=0) must be a no-op");
    }

    /// Keeping the full mass means nothing is filtered.
    #[test]
    fn typical_p_disabled_passthrough() {
        let original = vec![1.0f32, 2.0, 3.0, 0.5];
        let mut logits = original.clone();
        TypicalPStage::new(1.0).apply(&mut logits, &[]);
        assert_eq!(logits, original, "TypicalP(p=1.0) must be a no-op");
    }

    /// A zero scale factor switches top-a off.
    #[test]
    fn top_a_disabled_passthrough() {
        let original = vec![1.0f32, 2.0, 3.0, 0.5];
        let mut logits = original.clone();
        TopAStage::new(0.0).apply(&mut logits, &[]);
        assert_eq!(logits, original, "TopA(a=0.0) must be a no-op");
    }

    /// Eta is off when both of its parameters are zero.
    #[test]
    fn eta_disabled_passthrough() {
        let original = vec![1.0f32, 2.0, 3.0, 0.5];
        let mut logits = original.clone();
        EtaStage::new(0.0, 0.0).apply(&mut logits, &[]);
        assert_eq!(logits, original, "Eta(eta=0, epsilon=0) must be a no-op");
    }

    // Active configurations: each stage must visibly reshape the logits.

    /// With history `0 1 2 0 1`, token 2 would repeat the earlier `0 1 2`
    /// run and must be penalised; token 3 never occurred and must not be.
    #[test]
    fn dry_active_penalises_repeated_tokens() {
        let recent = vec![0u32, 1, 2, 0, 1];
        let mut logits = vec![0.0f32, 0.0, 5.0, 5.0];
        let original_c = logits[2];
        let original_d = logits[3];

        DryStage::new(2.0, 1.75, 2, vec![]).apply(&mut logits, &recent);

        assert!(
            logits[2] < original_c,
            "token C should be penalised by DRY; was {original_c}, now {}",
            logits[2]
        );
        assert!(
            (logits[3] - original_d).abs() < 1e-6,
            "token D should NOT be penalised by DRY; was {original_d}, now {}",
            logits[3]
        );
    }

    /// With probability 1.0 the stage always fires: the best token stays and
    /// the runner-up inside the top set is masked.
    #[test]
    fn xtc_active_excludes_top_tokens() {
        let mut logits = vec![3.0f32, 2.0, 1.0, 0.0];
        XtcStage::new(0.7, 1.0, 1).apply(&mut logits, &[]);

        assert!(
            logits[0].is_finite(),
            "XTC must preserve the top-1 token (token 0)"
        );
        assert_eq!(
            logits[1],
            f32::NEG_INFINITY,
            "XTC should exclude token 1 (second-best in top set)"
        );
    }

    /// A uniform distribution over eight tokens must shrink at p = 0.3 but
    /// never to nothing.
    #[test]
    fn typical_p_active_reduces_distribution() {
        let mut logits = vec![0.0f32; 8];
        TypicalPStage::new(0.3).apply(&mut logits, &[]);

        let finite_count = logits.iter().filter(|&&v| v.is_finite()).count();
        assert!(
            finite_count < 8,
            "TypicalP(p=0.3) should reduce the number of active tokens; got {finite_count} finite"
        );
        assert!(
            finite_count >= 1,
            "TypicalP must preserve at least 1 token; got {finite_count}"
        );
    }

    /// One dominant token: everything else falls under the top-a cutoff.
    #[test]
    fn top_a_active_keeps_only_near_max() {
        let mut logits = vec![10.0f32, -10.0, -10.0, -10.0];
        TopAStage::new(1.0).apply(&mut logits, &[]);

        assert!(logits[0].is_finite(), "dominant token must survive TopA");
        for (i, &v) in logits[1..].iter().enumerate() {
            assert_eq!(
                v,
                f32::NEG_INFINITY,
                "token {} should be excluded by TopA",
                i + 1
            );
        }
    }

    /// One dominant token: the eta cutoff must remove low-probability tokens
    /// while keeping the dominant one.
    #[test]
    fn eta_active_cuts_low_prob_tokens() {
        let mut logits = vec![10.0f32, -10.0, -10.0, -10.0];
        EtaStage::new(0.1, 0.0).apply(&mut logits, &[]);

        assert!(
            logits[0].is_finite(),
            "dominant token must survive Eta cutoff"
        );
        let any_cut = logits[1..].iter().any(|&v| v == f32::NEG_INFINITY);
        assert!(any_cut, "Eta should cut at least some low-prob tokens");
    }

    // Robustness: degenerate inputs and extreme parameters.

    /// No stage may panic on, or add entries to, an empty logit vector.
    #[test]
    fn all_stages_handle_empty_logits() {
        let mut logits: Vec<f32> = Vec::new();
        DryStage::new(1.0, 1.75, 2, vec![]).apply(&mut logits, &[]);
        XtcStage::new(0.5, 1.0, 42).apply(&mut logits, &[]);
        TypicalPStage::new(0.5).apply(&mut logits, &[]);
        TopAStage::new(1.0).apply(&mut logits, &[]);
        EtaStage::new(0.1, 0.0).apply(&mut logits, &[]);
        assert!(logits.is_empty());
    }

    /// Even with extreme parameters a stage must leave at least one finite
    /// logit behind.
    #[test]
    fn stages_never_empty_distribution() {
        let mut logits = vec![5.0f32, 4.9, 4.8, 4.7];
        TopAStage::new(100.0).apply(&mut logits, &[]);
        let finite = logits.iter().filter(|&&v| v.is_finite()).count();
        assert!(
            finite >= 1,
            "TopA with extreme a must still preserve at least 1 token"
        );

        let mut logits2 = vec![5.0f32, 4.9, 4.8, 4.7];
        TypicalPStage::new(0.0).apply(&mut logits2, &[]);
        let finite2 = logits2.iter().filter(|&&v| v.is_finite()).count();
        assert!(
            finite2 >= 1,
            "TypicalP with p=0 must still preserve at least 1 token"
        );
    }
}