/// Inverse golden ratio, 1/φ = φ − 1 ≈ 0.618034, used as the exponent that
/// shapes normalised surprise into a dephasing rate.
///
/// Written with the digits an `f32` can actually hold; this rounds to the same
/// `f32` value as the longer literal 0.618033988.
const PHI_INV: f32 = 0.618_034;
/// Lookup table of per-token dephasing rates derived from token frequencies.
///
/// Build one with [`GoldenRatioConverter::new`] and query it with
/// [`GoldenRatioConverter::dephasing_rate`].
#[derive(Debug, Clone)]
pub struct GoldenRatioConverter {
    /// Dephasing rate for each token id, indexed by the id itself.
    dephasing_rates: Vec<f32>,
    /// Vocabulary size passed at construction.
    ///
    /// NOTE(review): this field is public and writable, but changing it does not
    /// resize `dephasing_rates`; the two can drift apart if callers mutate it.
    pub vocab_size: usize,
}
impl GoldenRatioConverter {
    /// Precomputes a dephasing rate for every token id in `0..vocab_size`.
    ///
    /// For each token:
    /// 1. Its frequency `count / total_tokens` is floored at `1 / (total_tokens + 1)`.
    ///    This keeps the logarithm finite for tokens that were never seen.
    /// 2. The surprise `-ln(freq)` is divided by `ln(vocab_size)` and clamped to `[0, 1]`.
    /// 3. The rate is `1 - normalised^PHI_INV`.
    ///
    /// Frequent (low-surprise) tokens therefore get rates close to 1, and rare
    /// (high-surprise) tokens get rates close to 0.
    ///
    /// * `token_counts` – occurrence count per token id; ids past the end of the
    ///   slice are treated as count 0.
    /// * `total_tokens` – total number of counted tokens; 0 is treated as 1.
    /// * `vocab_size` – number of ids to precompute. The normaliser uses
    ///   `max(vocab_size, 2)`, so it is never `ln(1) = 0`.
    pub fn new(token_counts: &[usize], total_tokens: usize, vocab_size: usize) -> Self {
        let total = total_tokens.max(1) as f32;
        let max_surprise = (vocab_size.max(2) as f32).ln();
        // Smallest frequency any token may have; avoids ln(0) for unseen ids.
        let min_freq = 1.0 / (total + 1.0);

        let rate_for = |count: usize| -> f32 {
            let freq = (count as f32 / total).max(min_freq);
            let normalised = (-freq.ln() / max_surprise).clamp(0.0, 1.0);
            1.0 - normalised.powf(PHI_INV)
        };

        let mut dephasing_rates = Vec::with_capacity(vocab_size);
        for token in 0..vocab_size {
            let count = token_counts.get(token).copied().unwrap_or(0);
            dephasing_rates.push(rate_for(count));
        }

        Self {
            vocab_size,
            dephasing_rates,
        }
    }

    /// Returns the precomputed dephasing rate for `token`.
    ///
    /// Ids that have no precomputed entry (e.g. `token >= vocab_size` at
    /// construction time) yield `0.0`.
    #[inline]
    pub fn dephasing_rate(&self, token: usize) -> f32 {
        match self.dephasing_rates.get(token) {
            Some(&rate) => rate,
            None => 0.0,
        }
    }
}
/// Builds a histogram of token ids.
///
/// Returns a vector of length `vocab_size` whose entry `t` is the number of
/// times id `t` occurs in `tokens`. Ids `>= vocab_size` are silently skipped.
pub fn compute_token_counts(tokens: &[usize], vocab_size: usize) -> Vec<usize> {
    let mut histogram = vec![0usize; vocab_size];
    for &id in tokens {
        // `get_mut` returns None for out-of-range ids, which are ignored.
        if let Some(slot) = histogram.get_mut(id) {
            *slot += 1;
        }
    }
    histogram
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn common_tokens_higher_dephasing() {
        let counts = [100, 50, 10, 1];
        let conv = GoldenRatioConverter::new(&counts, 161, 4);
        assert!(
            conv.dephasing_rate(0) > conv.dephasing_rate(3),
            "common token should have higher dephasing: ε₀={} ε₃={}",
            conv.dephasing_rate(0),
            conv.dephasing_rate(3)
        );
        // Counts are sorted descending, so rates should be (near-)monotone
        // non-increasing; a small tolerance absorbs float noise.
        for i in 0..3 {
            let j = i + 1;
            let (cur, next) = (conv.dephasing_rate(i), conv.dephasing_rate(j));
            assert!(
                cur >= next - 0.01,
                "dephasing should decrease with rarity: ε{i}={cur} ε{j}={next}"
            );
        }
    }

    #[test]
    fn rare_token_near_zero_dephasing() {
        let counts = [100, 1];
        let conv = GoldenRatioConverter::new(&counts, 101, 2);
        let rare_eps = conv.dephasing_rate(1);
        assert!(rare_eps < 0.5, "rare token should have low dephasing: ε={rare_eps}");
    }

    #[test]
    fn dephasing_bounded_zero_one() {
        let counts = [1000, 500, 100, 10, 1, 0];
        let conv = GoldenRatioConverter::new(&counts, 1611, 6);
        for i in 0..6 {
            let eps = conv.dephasing_rate(i);
            assert!((0.0..=1.0).contains(&eps), "token {i}: ε={eps} out of [0,1]");
        }
    }

    #[test]
    fn out_of_vocab_token_has_zero_dephasing() {
        let conv = GoldenRatioConverter::new(&[3, 1], 4, 2);
        assert_eq!(conv.dephasing_rate(2), 0.0);
        assert_eq!(conv.dephasing_rate(usize::MAX), 0.0);
    }

    #[test]
    fn empty_vocab_is_handled() {
        let conv = GoldenRatioConverter::new(&[], 0, 0);
        assert_eq!(conv.vocab_size, 0);
        assert_eq!(conv.dephasing_rate(0), 0.0);
    }

    #[test]
    fn compute_counts_correct() {
        let tokens = [0, 1, 1, 2, 2, 2, 3, 3, 3, 3];
        let counts = compute_token_counts(&tokens, 5);
        assert_eq!(counts, vec![1, 2, 3, 4, 0]);
    }

    #[test]
    fn compute_counts_skips_out_of_range_tokens() {
        let counts = compute_token_counts(&[0, 7, 2, 3, 2], 3);
        assert_eq!(counts, vec![1, 0, 2]);
        assert!(compute_token_counts(&[], 0).is_empty());
    }
}