Skip to main content

oxibonsai_runtime/
token_healing.rs

1//! Token healing for partial-token prompt repair.
2//!
//! When a prompt ends partway through a word — so that its final token is only
//! a fragment of what the model would naturally produce — the model is biased
//! toward completing that fragment rather than exploring alternatives. Token
//! healing backs up `lookback` tokens, re-scores the next token from the
//! remaining prefix, and splices the result back in, producing more natural
//! continuations.
//!
//! ## Algorithm
//!
//! 1. Strip the last `lookback` tokens from the prompt to form a *prefix*.
//! 2. Call a user-supplied `get_logits` closure on the prefix.
//! 3. Select `t*`, the highest-scoring token in the returned logits.
//! 4. If the softmax probability of `t*` is below `min_prob`, leave the prompt unchanged.
//! 5. If `t*` equals the first stripped token, no change is needed.
//! 6. Otherwise replace the stripped tokens with `[t*]` — the healed sequence.
15//!
16//! ## Example
17//!
18//! ```rust
19//! use oxibonsai_runtime::token_healing::{TokenHealer, TokenHealingConfig};
20//!
21//! let healer = TokenHealer::new(TokenHealingConfig::default());
22//! let tokens = vec![10u32, 20, 99]; // 99 might be a mid-word continuation
23//!
24//! let result = healer.heal(&tokens, 128, |prefix| {
25//!     // Mock: always prefer token 42 as the next token
26//!     let mut logits = vec![0.0f32; 128];
27//!     logits[42] = 10.0;
28//!     logits
29//! });
30//!
31//! // token 42 != 99, so healing changed the sequence
32//! assert!(result.was_healed());
33//! assert_eq!(result.healed_tokens.last().copied(), Some(42));
34//! ```
35
36// ─────────────────────────────────────────────────────────────────────────────
37// Config
38// ─────────────────────────────────────────────────────────────────────────────
39
/// Configuration for the token healing pass.
///
/// The [`Default`] configuration enables healing with `lookback = 1` and no
/// probability threshold (`min_prob = 0.0`).
#[derive(Debug, Clone, PartialEq)]
pub struct TokenHealingConfig {
    /// Number of trailing tokens to back up and re-score.
    ///
    /// A value of `1` (the default) is sufficient for the vast majority of
    /// tokenisation schemes. When healing fires, all `lookback` stripped
    /// tokens are replaced by a single healed token, so larger values discard
    /// more of the prompt's tail.
    pub lookback: usize,

    /// Minimum softmax probability the best candidate token must have for the
    /// healed sequence to be accepted.
    ///
    /// If the best candidate falls below `min_prob`, healing is skipped and
    /// the original sequence is returned unchanged. `0.0` (the default)
    /// accepts any candidate.
    pub min_prob: f32,

    /// Master switch. When `false` the healer is a no-op.
    pub enabled: bool,
}
58
59impl Default for TokenHealingConfig {
60    fn default() -> Self {
61        Self {
62            lookback: 1,
63            min_prob: 0.0,
64            enabled: true,
65        }
66    }
67}
68
69// ─────────────────────────────────────────────────────────────────────────────
70// HealingResult
71// ─────────────────────────────────────────────────────────────────────────────
72
/// Result returned by [`TokenHealer::heal`].
///
/// Implements `PartialEq`/`Eq` so results can be compared directly, e.g. in
/// tests or when caching healed prompts.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HealingResult {
    /// The token sequence supplied to [`TokenHealer::heal`] (before any change).
    pub original_tokens: Vec<u32>,
    /// The token sequence after healing. Equal to `original_tokens` when unchanged.
    pub healed_tokens: Vec<u32>,
    /// How many trailing tokens were backed up and re-scored (`0` whenever
    /// healing was skipped).
    pub tokens_healed: usize,
    /// `true` iff the healed sequence differs from the original.
    pub changed: bool,
}
85
86impl HealingResult {
87    /// Construct a result that records no change.
88    pub fn unchanged(tokens: Vec<u32>) -> Self {
89        Self {
90            healed_tokens: tokens.clone(),
91            original_tokens: tokens,
92            tokens_healed: 0,
93            changed: false,
94        }
95    }
96
97    /// Returns `true` when the healer actually changed the sequence.
98    pub fn was_healed(&self) -> bool {
99        self.changed
100    }
101}
102
103// ─────────────────────────────────────────────────────────────────────────────
104// TokenHealer
105// ─────────────────────────────────────────────────────────────────────────────
106
107/// Backs up `lookback` tokens and re-scores from the prefix using the
108/// caller-supplied logit function.
109pub struct TokenHealer {
110    config: TokenHealingConfig,
111}
112
113impl TokenHealer {
114    /// Create a new healer with the supplied configuration.
115    pub fn new(config: TokenHealingConfig) -> Self {
116        Self { config }
117    }
118
119    /// Convenience constructor — use all defaults but override `lookback`.
120    pub fn with_lookback(lookback: usize) -> Self {
121        Self::new(TokenHealingConfig {
122            lookback,
123            ..TokenHealingConfig::default()
124        })
125    }
126
127    /// Apply token healing to `tokens`.
128    ///
129    /// `get_logits` receives a prefix slice and returns raw (unnormalized) logits
130    /// over the vocabulary.  The closure is called at most once.
131    ///
132    /// Returns a [`HealingResult`] describing what (if anything) changed.
133    pub fn heal<F>(&self, tokens: &[u32], vocab_size: usize, mut get_logits: F) -> HealingResult
134    where
135        F: FnMut(&[u32]) -> Vec<f32>,
136    {
137        // Short-circuit: disabled or not enough tokens to back up.
138        if !self.config.enabled || tokens.len() <= self.config.lookback {
139            return HealingResult::unchanged(tokens.to_vec());
140        }
141
142        let split = tokens.len() - self.config.lookback;
143        let prefix = &tokens[..split];
144        let logits = get_logits(prefix);
145
146        if logits.is_empty() || logits.len() < vocab_size {
147            // Cannot score — return unchanged rather than panicking.
148            return HealingResult::unchanged(tokens.to_vec());
149        }
150
151        // Find the highest-scoring token.
152        let best_token = argmax_f32(&logits) as u32;
153
154        // Check min_prob gate.
155        let prob = Self::token_prob(&logits, best_token);
156        if prob < self.config.min_prob {
157            return HealingResult::unchanged(tokens.to_vec());
158        }
159
160        // If best token already matches what was there, no change needed.
161        if best_token == tokens[split] {
162            return HealingResult {
163                original_tokens: tokens.to_vec(),
164                healed_tokens: tokens.to_vec(),
165                tokens_healed: self.config.lookback,
166                changed: false,
167            };
168        }
169
170        // Build the healed sequence: prefix + [best_token]
171        let mut healed = prefix.to_vec();
172        healed.push(best_token);
173
174        HealingResult {
175            original_tokens: tokens.to_vec(),
176            healed_tokens: healed,
177            tokens_healed: self.config.lookback,
178            changed: true,
179        }
180    }
181
182    /// Heuristic: returns `true` when `token_text` looks like a continuation
183    /// of `prev_token_text` (i.e., no leading whitespace and `prev_token_text`
184    /// ends mid-word).
185    ///
186    /// This is a lightweight signal used to decide whether healing is semantically
187    /// meaningful.  It does not affect the heal algorithm itself.
188    pub fn is_continuation_token(prev_token_text: &str, token_text: &str) -> bool {
189        if token_text.is_empty() || prev_token_text.is_empty() {
190            return false;
191        }
192        // The next token is a continuation if it does NOT start with whitespace.
193        let next_starts_clean = !token_text.starts_with(' ');
194        // The previous token ends mid-word (last char is alphanumeric).
195        let prev_ends_mid_word = prev_token_text
196            .chars()
197            .next_back()
198            .map(|c| c.is_alphanumeric())
199            .unwrap_or(false);
200        prev_ends_mid_word && next_starts_clean
201    }
202
203    /// Compute the probability of `token_id` under the softmax of `logits`.
204    ///
205    /// Returns `0.0` when `token_id` is out of range or `logits` is empty.
206    pub fn token_prob(logits: &[f32], token_id: u32) -> f32 {
207        let idx = token_id as usize;
208        if logits.is_empty() || idx >= logits.len() {
209            return 0.0;
210        }
211        // Numerically stable softmax.
212        let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
213        let exps: Vec<f32> = logits.iter().map(|&v| (v - max).exp()).collect();
214        let sum: f32 = exps.iter().sum();
215        if sum == 0.0 {
216            return 0.0;
217        }
218        exps[idx] / sum
219    }
220}
221
222// ─────────────────────────────────────────────────────────────────────────────
223// HealingDecoder
224// ─────────────────────────────────────────────────────────────────────────────
225
226/// Combines token healing with a simple token-by-token generation loop.
227///
228/// Healing is applied once to the prompt; then `max_tokens` additional tokens
229/// are drawn using the `sample` closure.
230pub struct HealingDecoder {
231    /// The inner healer driving the healing step.
232    pub healer: TokenHealer,
233}
234
235impl HealingDecoder {
236    /// Create a new decoder with the supplied healing configuration.
237    pub fn new(config: TokenHealingConfig) -> Self {
238        Self {
239            healer: TokenHealer::new(config),
240        }
241    }
242
243    /// Apply token healing to `prompt_tokens`, then generate up to `max_tokens`
244    /// additional tokens.
245    ///
246    /// # Parameters
247    ///
248    /// - `get_logits` — called with the current token sequence; returns logits.
249    /// - `sample`     — called with the raw logits; returns the next token id.
250    ///
251    /// # Returns
252    ///
253    /// A pair `(HealingResult, generated_tokens)`.
254    pub fn generate<F, G>(
255        &self,
256        prompt_tokens: Vec<u32>,
257        vocab_size: usize,
258        max_tokens: usize,
259        mut get_logits: F,
260        mut sample: G,
261    ) -> (HealingResult, Vec<u32>)
262    where
263        F: FnMut(&[u32]) -> Vec<f32>,
264        G: FnMut(Vec<f32>) -> u32,
265    {
266        // Phase 1: heal the prompt.
267        let healing = self
268            .healer
269            .heal(&prompt_tokens, vocab_size, &mut get_logits);
270        let healed_prompt = healing.healed_tokens.clone();
271
272        // Phase 2: generate up to max_tokens from the (possibly healed) prompt.
273        let mut context = healed_prompt.clone();
274        let mut generated = Vec::with_capacity(max_tokens);
275
276        for _ in 0..max_tokens {
277            let logits = get_logits(&context);
278            if logits.is_empty() {
279                break;
280            }
281            let next_token = sample(logits);
282            context.push(next_token);
283            generated.push(next_token);
284        }
285
286        (healing, generated)
287    }
288}
289
290// ─────────────────────────────────────────────────────────────────────────────
291// Internal helpers
292// ─────────────────────────────────────────────────────────────────────────────
293
/// Return the index of the largest non-NaN value in `values`.
///
/// Ties resolve to the lowest index, which matches the usual `argmax`
/// convention and makes the result independent of how equal scores are laid
/// out. NaN entries are skipped.
///
/// Returns `0` for empty slices and for slices containing only NaN (a safe
/// default). Callers that index with the result must ensure the slice is
/// non-empty.
fn argmax_f32(values: &[f32]) -> usize {
    let mut best: Option<(usize, f32)> = None;
    for (idx, &value) in values.iter().enumerate() {
        if value.is_nan() {
            continue;
        }
        match best {
            // Only a strictly greater value replaces the current best, so the
            // first of several equal maxima wins.
            Some((_, current)) if value <= current => {}
            _ => best = Some((idx, value)),
        }
    }
    best.map_or(0, |(idx, _)| idx)
}
304
305// ─────────────────────────────────────────────────────────────────────────────
306// Tests
307// ─────────────────────────────────────────────────────────────────────────────
308
#[cfg(test)]
mod tests {
    use super::*;

    // Helper: build a logit vector where `winner` has a high score.
    fn logits_prefer(vocab_size: usize, winner: usize) -> Vec<f32> {
        let mut v = vec![0.0f32; vocab_size];
        v[winner] = 100.0;
        v
    }

    #[test]
    fn test_token_healing_disabled_returns_unchanged() {
        let config = TokenHealingConfig {
            enabled: false,
            ..TokenHealingConfig::default()
        };
        let healer = TokenHealer::new(config);
        let tokens = vec![1u32, 2, 3, 4];
        let result = healer.heal(&tokens, 10, |_| logits_prefer(10, 7));
        assert!(!result.changed);
        assert_eq!(result.healed_tokens, tokens);
        assert_eq!(result.original_tokens, tokens);
    }

    #[test]
    fn test_token_healing_empty_input_unchanged() {
        let healer = TokenHealer::new(TokenHealingConfig::default());
        let result = healer.heal(&[], 10, |_| logits_prefer(10, 0));
        assert!(!result.changed);
        assert!(result.healed_tokens.is_empty());
    }

    #[test]
    fn test_token_healing_lookback_1_no_change_when_correct() {
        // The best logit token IS the last token in the sequence → no change.
        let healer = TokenHealer::new(TokenHealingConfig::default());
        let tokens = vec![10u32, 20, 5]; // last token = 5
        let result = healer.heal(&tokens, 30, |_| logits_prefer(30, 5));
        assert!(
            !result.changed,
            "no change expected when prediction matches"
        );
        assert_eq!(result.healed_tokens, tokens);
        assert_eq!(result.tokens_healed, 1);
    }

    #[test]
    fn test_token_healing_lookback_1_changes_wrong_token() {
        // Best logit token (7) differs from last token (99) → healing fires.
        let healer = TokenHealer::new(TokenHealingConfig::default());
        let tokens = vec![10u32, 20, 99];
        let result = healer.heal(&tokens, 128, |_| logits_prefer(128, 7));
        assert!(result.changed);
        assert!(result.was_healed());
        // Healed sequence = prefix [10, 20] + [7]
        assert_eq!(result.healed_tokens, vec![10u32, 20, 7]);
        assert_eq!(result.original_tokens, tokens);
        assert_eq!(result.tokens_healed, 1);
    }

    #[test]
    fn test_token_healing_lookback_2_replaces_whole_tail() {
        let healer = TokenHealer::with_lookback(2);
        let tokens = vec![1u32, 2, 3, 4];
        let mut seen_prefix: Vec<u32> = Vec::new();
        let result = healer.heal(&tokens, 10, |prefix| {
            seen_prefix = prefix.to_vec();
            logits_prefer(10, 7)
        });
        // The closure sees everything except the last `lookback` tokens.
        assert_eq!(seen_prefix, vec![1u32, 2]);
        assert!(result.changed);
        // Both stripped tokens are replaced by the single healed token.
        assert_eq!(result.healed_tokens, vec![1u32, 2, 7]);
        assert_eq!(result.tokens_healed, 2);
    }

    #[test]
    fn test_token_healing_short_logits_unchanged() {
        // Fewer logits than `vocab_size` cannot be scored safely.
        let healer = TokenHealer::new(TokenHealingConfig::default());
        let tokens = vec![1u32, 2, 3];
        let result = healer.heal(&tokens, 10, |_| vec![0.0f32; 5]);
        assert!(!result.changed);
        assert_eq!(result.healed_tokens, tokens);
        assert_eq!(result.tokens_healed, 0);
    }

    #[test]
    fn test_token_healing_min_prob_gate() {
        let healer = TokenHealer::new(TokenHealingConfig {
            min_prob: 0.9,
            ..TokenHealingConfig::default()
        });
        let tokens = vec![0u32, 1];
        // Token 3 wins, but with probability ≈ 0.27 — below the 0.9 threshold.
        let result = healer.heal(&tokens, 4, |_| vec![1.0f32, 1.0, 1.0, 1.1]);
        assert!(!result.changed);
        assert_eq!(result.healed_tokens, tokens);
        assert_eq!(result.tokens_healed, 0);
    }

    #[test]
    fn test_token_prob_correct() {
        // With one dominant logit the probability of that token should be ≈ 1.
        let mut logits = vec![0.0f32; 10];
        logits[3] = 100.0;
        let p = TokenHealer::token_prob(&logits, 3);
        assert!(
            (p - 1.0).abs() < 1e-5,
            "dominant token should have prob ≈ 1"
        );

        // Uniform logits → all tokens should have prob ≈ 1/n.
        let uniform = vec![0.0f32; 4];
        let p_uniform = TokenHealer::token_prob(&uniform, 2);
        assert!(
            (p_uniform - 0.25).abs() < 1e-5,
            "uniform prob should be 0.25"
        );
    }

    #[test]
    fn test_token_prob_out_of_range_and_empty() {
        assert_eq!(TokenHealer::token_prob(&[1.0f32, 2.0], 5), 0.0);
        let empty: Vec<f32> = Vec::new();
        assert_eq!(TokenHealer::token_prob(&empty, 0), 0.0);
    }

    #[test]
    fn test_healing_result_unchanged() {
        let tokens = vec![1u32, 2, 3];
        let result = HealingResult::unchanged(tokens.clone());
        assert!(!result.changed);
        assert!(!result.was_healed());
        assert_eq!(result.original_tokens, tokens);
        assert_eq!(result.healed_tokens, tokens);
        assert_eq!(result.tokens_healed, 0);
    }

    #[test]
    fn test_healing_decoder_runs() {
        let decoder = HealingDecoder::new(TokenHealingConfig::default());
        let prompt = vec![1u32, 2, 3]; // last token = 3; best = 9 → healing fires
        let vocab_size = 20;
        let max_tokens = 5;

        let call_count = std::cell::Cell::new(0usize);
        let get_logits = |_prefix: &[u32]| {
            call_count.set(call_count.get() + 1);
            logits_prefer(vocab_size, 9)
        };
        // sample always returns token 1
        let sample = |_logits: Vec<f32>| 1u32;

        let (healing, generated) =
            decoder.generate(prompt, vocab_size, max_tokens, get_logits, sample);
        // Healing should have fired (best=9, last was 3).
        assert!(healing.changed);
        assert_eq!(healing.healed_tokens, vec![1u32, 2, 9]);
        // Exactly max_tokens tokens generated.
        assert_eq!(generated.len(), max_tokens);
        // All generated tokens are 1 (from our mock sampler).
        assert!(generated.iter().all(|&t| t == 1));
        // One logits call for healing plus one per generated token.
        assert_eq!(call_count.get(), 1 + max_tokens);
    }

    #[test]
    fn test_is_continuation_token() {
        // "ing" follows "call" — mid-word continuation.
        assert!(
            TokenHealer::is_continuation_token("call", "ing"),
            "\"calling\" split should be a continuation"
        );
        // " the" after "call" — new word, NOT a continuation.
        assert!(
            !TokenHealer::is_continuation_token("call", " the"),
            "space-prefixed token is not a continuation"
        );
        // Empty inputs → not a continuation.
        assert!(!TokenHealer::is_continuation_token("", "ing"));
        assert!(!TokenHealer::is_continuation_token("call", ""));
        // Punctuation ending the previous token → not mid-word.
        assert!(
            !TokenHealer::is_continuation_token("call.", "ing"),
            "period-ended token is not mid-word"
        );
    }
}