/// Configuration for token healing at a prompt boundary.
#[derive(Debug, Clone)]
pub struct TokenHealingConfig {
    /// How many trailing tokens `TokenHealer::heal` rolls back and re-predicts.
    pub lookback: usize,
    /// Minimum softmax probability the predicted replacement must reach;
    /// predictions below this leave the input unchanged.
    pub min_prob: f32,
    /// Master switch: when false, `heal` always returns the input unchanged.
    pub enabled: bool,
}
impl Default for TokenHealingConfig {
fn default() -> Self {
Self {
lookback: 1,
min_prob: 0.0,
enabled: true,
}
}
}
/// Outcome of a healing attempt.
#[derive(Debug, Clone)]
pub struct HealingResult {
    /// The input token sequence, untouched.
    pub original_tokens: Vec<u32>,
    /// The sequence to continue from; equals `original_tokens` when no healing occurred.
    pub healed_tokens: Vec<u32>,
    /// Size of the lookback window that was examined (0 when healing was skipped
    /// outright — see `HealingResult::unchanged`).
    pub tokens_healed: usize,
    /// Whether `healed_tokens` differs from `original_tokens`.
    pub changed: bool,
}
impl HealingResult {
    /// Builds a result that reports `tokens` as untouched: both token lists are
    /// identical and no tokens are counted as healed.
    pub fn unchanged(tokens: Vec<u32>) -> Self {
        let healed_tokens = tokens.clone();
        Self {
            original_tokens: tokens,
            healed_tokens,
            tokens_healed: 0,
            changed: false,
        }
    }

    /// True when healing actually replaced the boundary token(s).
    pub fn was_healed(&self) -> bool {
        self.changed
    }
}
/// Re-predicts the last `lookback` token(s) of a sequence from model logits,
/// replacing them when the model's top choice disagrees.
pub struct TokenHealer {
    config: TokenHealingConfig,
}
impl TokenHealer {
    /// Creates a healer with the given configuration.
    pub fn new(config: TokenHealingConfig) -> Self {
        Self { config }
    }

    /// Convenience constructor: default config with a custom lookback window.
    pub fn with_lookback(lookback: usize) -> Self {
        Self::new(TokenHealingConfig {
            lookback,
            ..TokenHealingConfig::default()
        })
    }

    /// Attempts to heal the trailing `lookback` tokens of `tokens`.
    ///
    /// The last `lookback` tokens are dropped, `get_logits` is called on the
    /// remaining prefix, and the argmax token is compared against the token that
    /// previously followed the prefix. If they differ (and the prediction clears
    /// `min_prob`), the healed sequence is `prefix + [predicted]` — note it may be
    /// shorter than the input when `lookback > 1`.
    ///
    /// Returns an unchanged result when healing is disabled, `lookback` is 0,
    /// the input is no longer than the lookback window, the logits are empty or
    /// shorter than `vocab_size`, or the prediction is below `min_prob`.
    pub fn heal<F>(&self, tokens: &[u32], vocab_size: usize, mut get_logits: F) -> HealingResult
    where
        F: FnMut(&[u32]) -> Vec<f32>,
    {
        // BUG FIX: a lookback of 0 previously fell through to `tokens[split]`
        // with `split == tokens.len()`, panicking on any non-empty input
        // (reachable via `with_lookback(0)`). Nothing to re-predict => unchanged.
        if !self.config.enabled
            || self.config.lookback == 0
            || tokens.len() <= self.config.lookback
        {
            return HealingResult::unchanged(tokens.to_vec());
        }
        let split = tokens.len() - self.config.lookback;
        let prefix = &tokens[..split];
        let logits = get_logits(prefix);
        // Reject empty or truncated logit vectors rather than argmax-ing garbage.
        if logits.is_empty() || logits.len() < vocab_size {
            return HealingResult::unchanged(tokens.to_vec());
        }
        let best_token = argmax_f32(&logits) as u32;
        let prob = Self::token_prob(&logits, best_token);
        if prob < self.config.min_prob {
            return HealingResult::unchanged(tokens.to_vec());
        }
        if best_token == tokens[split] {
            // Model agrees with the existing boundary token: report the window as
            // examined (`tokens_healed`) but flag no change.
            return HealingResult {
                original_tokens: tokens.to_vec(),
                healed_tokens: tokens.to_vec(),
                tokens_healed: self.config.lookback,
                changed: false,
            };
        }
        // Replace the entire lookback window with the single predicted token.
        let mut healed = prefix.to_vec();
        healed.push(best_token);
        HealingResult {
            original_tokens: tokens.to_vec(),
            healed_tokens: healed,
            tokens_healed: self.config.lookback,
            changed: true,
        }
    }

    /// Heuristic: does `token_text` look like it continues the word ended by
    /// `prev_token_text`? True when the previous text ends in an alphanumeric
    /// character and the next text does not start with a space.
    pub fn is_continuation_token(prev_token_text: &str, token_text: &str) -> bool {
        if token_text.is_empty() || prev_token_text.is_empty() {
            return false;
        }
        let next_starts_clean = !token_text.starts_with(' ');
        let prev_ends_mid_word = prev_token_text
            .chars()
            .next_back()
            .map(|c| c.is_alphanumeric())
            .unwrap_or(false);
        prev_ends_mid_word && next_starts_clean
    }

    /// Softmax probability of `token_id` under `logits`.
    ///
    /// Uses the max-subtraction trick for numerical stability. Returns 0.0 for an
    /// out-of-range id, empty logits, or a degenerate (zero) softmax denominator.
    pub fn token_prob(logits: &[f32], token_id: u32) -> f32 {
        let idx = token_id as usize;
        if logits.is_empty() || idx >= logits.len() {
            return 0.0;
        }
        let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = logits.iter().map(|&v| (v - max).exp()).collect();
        let sum: f32 = exps.iter().sum();
        if sum == 0.0 {
            return 0.0;
        }
        exps[idx] / sum
    }
}
/// Couples a `TokenHealer` with a simple generation loop: the prompt is healed
/// first, then tokens are sampled from the healed context.
pub struct HealingDecoder {
    /// The healer applied to the prompt before generation begins.
    pub healer: TokenHealer,
}
impl HealingDecoder {
    /// Creates a decoder whose healer uses `config`.
    pub fn new(config: TokenHealingConfig) -> Self {
        Self {
            healer: TokenHealer::new(config),
        }
    }

    /// Heals `prompt_tokens`, then samples up to `max_tokens` continuation tokens.
    ///
    /// Each step calls `get_logits` on the full context (healed prompt plus tokens
    /// generated so far) and `sample` to pick the next token; generation stops
    /// early if `get_logits` returns an empty vector. Returns the healing outcome
    /// together with the generated tokens (prompt not included).
    pub fn generate<F, G>(
        &self,
        prompt_tokens: Vec<u32>,
        vocab_size: usize,
        max_tokens: usize,
        mut get_logits: F,
        mut sample: G,
    ) -> (HealingResult, Vec<u32>)
    where
        F: FnMut(&[u32]) -> Vec<f32>,
        G: FnMut(Vec<f32>) -> u32,
    {
        let healing = self
            .healer
            .heal(&prompt_tokens, vocab_size, &mut get_logits);
        // Previously the healed tokens were cloned twice through an unused
        // intermediate (`healed_prompt`); one clone into the mutable context is enough.
        let mut context = healing.healed_tokens.clone();
        let mut generated = Vec::with_capacity(max_tokens);
        for _ in 0..max_tokens {
            let logits = get_logits(&context);
            if logits.is_empty() {
                break;
            }
            let next_token = sample(logits);
            context.push(next_token);
            generated.push(next_token);
        }
        (healing, generated)
    }
}
/// Index of the largest value in `values`, or 0 for an empty slice.
///
/// Mirrors `Iterator::max_by` tie-breaking: among equal maxima the LAST index
/// wins, and incomparable pairs (NaN) are treated as equal.
fn argmax_f32(values: &[f32]) -> usize {
    let mut winner = 0;
    for candidate in 1..values.len() {
        // Keep the current winner only when it is strictly greater.
        let keep = values[winner]
            .partial_cmp(&values[candidate])
            .unwrap_or(std::cmp::Ordering::Equal)
            == std::cmp::Ordering::Greater;
        if !keep {
            winner = candidate;
        }
    }
    winner
}
#[cfg(test)]
mod tests {
    use super::*;

    // Logit vector where `winner` dominates every other entry.
    fn logits_prefer(vocab_size: usize, winner: usize) -> Vec<f32> {
        let mut out = vec![0.0f32; vocab_size];
        out[winner] = 100.0;
        out
    }

    #[test]
    fn test_token_healing_disabled_returns_unchanged() {
        let healer = TokenHealer::new(TokenHealingConfig {
            enabled: false,
            ..TokenHealingConfig::default()
        });
        let input = vec![1u32, 2, 3, 4];
        let outcome = healer.heal(&input, 10, |_| logits_prefer(10, 7));
        assert!(!outcome.changed);
        assert_eq!(outcome.healed_tokens, input);
        assert_eq!(outcome.original_tokens, input);
    }

    #[test]
    fn test_token_healing_empty_input_unchanged() {
        let healer = TokenHealer::new(TokenHealingConfig::default());
        let outcome = healer.heal(&[], 10, |_| logits_prefer(10, 0));
        assert!(!outcome.changed);
        assert!(outcome.healed_tokens.is_empty());
    }

    #[test]
    fn test_token_healing_lookback_1_no_change_when_correct() {
        let healer = TokenHealer::new(TokenHealingConfig::default());
        let input = vec![10u32, 20, 5];
        let outcome = healer.heal(&input, 30, |_| logits_prefer(30, 5));
        assert!(
            !outcome.changed,
            "no change expected when prediction matches"
        );
        assert_eq!(outcome.healed_tokens, input);
        assert_eq!(outcome.tokens_healed, 1);
    }

    #[test]
    fn test_token_healing_lookback_1_changes_wrong_token() {
        let healer = TokenHealer::new(TokenHealingConfig::default());
        let input = vec![10u32, 20, 99];
        let outcome = healer.heal(&input, 128, |_| logits_prefer(128, 7));
        assert!(outcome.changed);
        assert!(outcome.was_healed());
        assert_eq!(outcome.healed_tokens, vec![10u32, 20, 7]);
        assert_eq!(outcome.original_tokens, input);
        assert_eq!(outcome.tokens_healed, 1);
    }

    #[test]
    fn test_token_prob_correct() {
        let mut dominant = vec![0.0f32; 10];
        dominant[3] = 100.0;
        let p = TokenHealer::token_prob(&dominant, 3);
        assert!(
            (p - 1.0).abs() < 1e-5,
            "dominant token should have prob ≈ 1"
        );

        let uniform = vec![0.0f32; 4];
        let q = TokenHealer::token_prob(&uniform, 2);
        assert!(
            (q - 0.25).abs() < 1e-5,
            "uniform prob should be 0.25"
        );
    }

    #[test]
    fn test_healing_result_unchanged() {
        let input = vec![1u32, 2, 3];
        let outcome = HealingResult::unchanged(input.clone());
        assert!(!outcome.changed);
        assert!(!outcome.was_healed());
        assert_eq!(outcome.original_tokens, input);
        assert_eq!(outcome.healed_tokens, input);
        assert_eq!(outcome.tokens_healed, 0);
    }

    #[test]
    fn test_healing_decoder_runs() {
        let decoder = HealingDecoder::new(TokenHealingConfig::default());
        let prompt = vec![1u32, 2, 3];
        let vocab_size = 20;
        let max_tokens = 5;
        let calls = std::cell::Cell::new(0usize);
        let get_logits = |_ctx: &[u32]| {
            calls.set(calls.get() + 1);
            logits_prefer(vocab_size, 9)
        };
        let sample = |_logits: Vec<f32>| 1u32;
        let (healing, generated) =
            decoder.generate(prompt, vocab_size, max_tokens, get_logits, sample);
        assert!(healing.changed);
        assert_eq!(generated.len(), max_tokens);
        assert!(generated.iter().all(|&t| t == 1));
    }

    #[test]
    fn test_is_continuation_token() {
        assert!(
            TokenHealer::is_continuation_token("call", "ing"),
            "\"calling\" split should be a continuation"
        );
        assert!(
            !TokenHealer::is_continuation_token("call", " the"),
            "space-prefixed token is not a continuation"
        );
        assert!(!TokenHealer::is_continuation_token("", "ing"));
        assert!(!TokenHealer::is_continuation_token("call", ""));
        assert!(
            !TokenHealer::is_continuation_token("call.", "ing"),
            "period-ended token is not mid-word"
        );
    }
}