oxibonsai_runtime/token_healing.rs
1//! Token healing for partial-token prompt repair.
2//!
//! When a prompt ends in the middle of a token, the model is biased toward
//! completing that partial token rather than exploring alternatives. Token
//! healing backs up `lookback` tokens, re-scores from the remaining prefix, and
//! splices in the result — producing more natural continuations.
7//!
8//! ## Algorithm
9//!
10//! 1. Strip the last `lookback` tokens from the prompt to form a *prefix*.
11//! 2. Call a user-supplied `get_logits` closure on the prefix.
12//! 3. Select `t*` = argmax of the returned logit vector.
13//! 4. If `t*` equals the original next token, no change is needed.
14//! 5. Otherwise replace those `lookback` tokens with `[t*]` — the healed sequence.
15//!
16//! ## Example
17//!
18//! ```rust
19//! use oxibonsai_runtime::token_healing::{TokenHealer, TokenHealingConfig};
20//!
21//! let healer = TokenHealer::new(TokenHealingConfig::default());
22//! let tokens = vec![10u32, 20, 99]; // 99 might be a mid-word continuation
23//!
24//! let result = healer.heal(&tokens, 128, |prefix| {
25//! // Mock: always prefer token 42 as the next token
26//! let mut logits = vec![0.0f32; 128];
27//! logits[42] = 10.0;
28//! logits
29//! });
30//!
31//! // token 42 != 99, so healing changed the sequence
32//! assert!(result.was_healed());
33//! assert_eq!(result.healed_tokens.last().copied(), Some(42));
34//! ```
35
36// ─────────────────────────────────────────────────────────────────────────────
37// Config
38// ─────────────────────────────────────────────────────────────────────────────
39
/// Configuration for the token healing pass.
///
/// Implements [`PartialEq`] so configurations can be compared directly, for
/// example in tests or when checking whether settings have changed.
#[derive(Debug, Clone, PartialEq)]
pub struct TokenHealingConfig {
    /// Number of trailing tokens to back up and re-score.
    ///
    /// A value of `1` (the default) is sufficient for the vast majority of
    /// tokenisation schemes. All backed-up tokens are replaced by a single
    /// re-scored token, so larger values discard more of the prompt tail.
    pub lookback: usize,

    /// Minimum softmax probability that a healed token must have to be accepted.
    ///
    /// If the best candidate falls below `min_prob`, healing is skipped and
    /// the original sequence is returned unchanged.
    pub min_prob: f32,

    /// Master switch. When `false` the healer is a no-op.
    pub enabled: bool,
}

impl Default for TokenHealingConfig {
    /// Healing enabled, one token of lookback, and no probability gate
    /// (`min_prob == 0.0`).
    fn default() -> Self {
        Self {
            lookback: 1,
            min_prob: 0.0,
            enabled: true,
        }
    }
}
68
69// ─────────────────────────────────────────────────────────────────────────────
70// HealingResult
71// ─────────────────────────────────────────────────────────────────────────────
72
/// Result returned by [`TokenHealer::heal`].
///
/// Implements [`PartialEq`] and [`Eq`] so a whole result can be compared in
/// one assertion instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HealingResult {
    /// The token sequence supplied to [`TokenHealer::heal`] (before any change).
    pub original_tokens: Vec<u32>,
    /// The token sequence after healing. Equal to `original_tokens` when unchanged.
    pub healed_tokens: Vec<u32>,
    /// How many trailing tokens were backed up and re-scored.
    ///
    /// `0` when the result was built with [`HealingResult::unchanged`].
    pub tokens_healed: usize,
    /// `true` iff the healed sequence differs from the original.
    pub changed: bool,
}

impl HealingResult {
    /// Construct a result that records no change.
    ///
    /// Both `original_tokens` and `healed_tokens` hold `tokens`;
    /// `tokens_healed` is `0` and `changed` is `false`.
    pub fn unchanged(tokens: Vec<u32>) -> Self {
        let healed_tokens = tokens.clone();
        Self {
            original_tokens: tokens,
            healed_tokens,
            tokens_healed: 0,
            changed: false,
        }
    }

    /// Returns `true` when the healer actually changed the sequence.
    pub fn was_healed(&self) -> bool {
        self.changed
    }
}
102
103// ─────────────────────────────────────────────────────────────────────────────
104// TokenHealer
105// ─────────────────────────────────────────────────────────────────────────────
106
107/// Backs up `lookback` tokens and re-scores from the prefix using the
108/// caller-supplied logit function.
109pub struct TokenHealer {
110 config: TokenHealingConfig,
111}
112
113impl TokenHealer {
114 /// Create a new healer with the supplied configuration.
115 pub fn new(config: TokenHealingConfig) -> Self {
116 Self { config }
117 }
118
119 /// Convenience constructor — use all defaults but override `lookback`.
120 pub fn with_lookback(lookback: usize) -> Self {
121 Self::new(TokenHealingConfig {
122 lookback,
123 ..TokenHealingConfig::default()
124 })
125 }
126
127 /// Apply token healing to `tokens`.
128 ///
129 /// `get_logits` receives a prefix slice and returns raw (unnormalized) logits
130 /// over the vocabulary. The closure is called at most once.
131 ///
132 /// Returns a [`HealingResult`] describing what (if anything) changed.
133 pub fn heal<F>(&self, tokens: &[u32], vocab_size: usize, mut get_logits: F) -> HealingResult
134 where
135 F: FnMut(&[u32]) -> Vec<f32>,
136 {
137 // Short-circuit: disabled or not enough tokens to back up.
138 if !self.config.enabled || tokens.len() <= self.config.lookback {
139 return HealingResult::unchanged(tokens.to_vec());
140 }
141
142 let split = tokens.len() - self.config.lookback;
143 let prefix = &tokens[..split];
144 let logits = get_logits(prefix);
145
146 if logits.is_empty() || logits.len() < vocab_size {
147 // Cannot score — return unchanged rather than panicking.
148 return HealingResult::unchanged(tokens.to_vec());
149 }
150
151 // Find the highest-scoring token.
152 let best_token = argmax_f32(&logits) as u32;
153
154 // Check min_prob gate.
155 let prob = Self::token_prob(&logits, best_token);
156 if prob < self.config.min_prob {
157 return HealingResult::unchanged(tokens.to_vec());
158 }
159
160 // If best token already matches what was there, no change needed.
161 if best_token == tokens[split] {
162 return HealingResult {
163 original_tokens: tokens.to_vec(),
164 healed_tokens: tokens.to_vec(),
165 tokens_healed: self.config.lookback,
166 changed: false,
167 };
168 }
169
170 // Build the healed sequence: prefix + [best_token]
171 let mut healed = prefix.to_vec();
172 healed.push(best_token);
173
174 HealingResult {
175 original_tokens: tokens.to_vec(),
176 healed_tokens: healed,
177 tokens_healed: self.config.lookback,
178 changed: true,
179 }
180 }
181
182 /// Heuristic: returns `true` when `token_text` looks like a continuation
183 /// of `prev_token_text` (i.e., no leading whitespace and `prev_token_text`
184 /// ends mid-word).
185 ///
186 /// This is a lightweight signal used to decide whether healing is semantically
187 /// meaningful. It does not affect the heal algorithm itself.
188 pub fn is_continuation_token(prev_token_text: &str, token_text: &str) -> bool {
189 if token_text.is_empty() || prev_token_text.is_empty() {
190 return false;
191 }
192 // The next token is a continuation if it does NOT start with whitespace.
193 let next_starts_clean = !token_text.starts_with(' ');
194 // The previous token ends mid-word (last char is alphanumeric).
195 let prev_ends_mid_word = prev_token_text
196 .chars()
197 .next_back()
198 .map(|c| c.is_alphanumeric())
199 .unwrap_or(false);
200 prev_ends_mid_word && next_starts_clean
201 }
202
203 /// Compute the probability of `token_id` under the softmax of `logits`.
204 ///
205 /// Returns `0.0` when `token_id` is out of range or `logits` is empty.
206 pub fn token_prob(logits: &[f32], token_id: u32) -> f32 {
207 let idx = token_id as usize;
208 if logits.is_empty() || idx >= logits.len() {
209 return 0.0;
210 }
211 // Numerically stable softmax.
212 let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
213 let exps: Vec<f32> = logits.iter().map(|&v| (v - max).exp()).collect();
214 let sum: f32 = exps.iter().sum();
215 if sum == 0.0 {
216 return 0.0;
217 }
218 exps[idx] / sum
219 }
220}
221
222// ─────────────────────────────────────────────────────────────────────────────
223// HealingDecoder
224// ─────────────────────────────────────────────────────────────────────────────
225
226/// Combines token healing with a simple token-by-token generation loop.
227///
228/// Healing is applied once to the prompt; then `max_tokens` additional tokens
229/// are drawn using the `sample` closure.
230pub struct HealingDecoder {
231 /// The inner healer driving the healing step.
232 pub healer: TokenHealer,
233}
234
235impl HealingDecoder {
236 /// Create a new decoder with the supplied healing configuration.
237 pub fn new(config: TokenHealingConfig) -> Self {
238 Self {
239 healer: TokenHealer::new(config),
240 }
241 }
242
243 /// Apply token healing to `prompt_tokens`, then generate up to `max_tokens`
244 /// additional tokens.
245 ///
246 /// # Parameters
247 ///
248 /// - `get_logits` — called with the current token sequence; returns logits.
249 /// - `sample` — called with the raw logits; returns the next token id.
250 ///
251 /// # Returns
252 ///
253 /// A pair `(HealingResult, generated_tokens)`.
254 pub fn generate<F, G>(
255 &self,
256 prompt_tokens: Vec<u32>,
257 vocab_size: usize,
258 max_tokens: usize,
259 mut get_logits: F,
260 mut sample: G,
261 ) -> (HealingResult, Vec<u32>)
262 where
263 F: FnMut(&[u32]) -> Vec<f32>,
264 G: FnMut(Vec<f32>) -> u32,
265 {
266 // Phase 1: heal the prompt.
267 let healing = self
268 .healer
269 .heal(&prompt_tokens, vocab_size, &mut get_logits);
270 let healed_prompt = healing.healed_tokens.clone();
271
272 // Phase 2: generate up to max_tokens from the (possibly healed) prompt.
273 let mut context = healed_prompt.clone();
274 let mut generated = Vec::with_capacity(max_tokens);
275
276 for _ in 0..max_tokens {
277 let logits = get_logits(&context);
278 if logits.is_empty() {
279 break;
280 }
281 let next_token = sample(logits);
282 context.push(next_token);
283 generated.push(next_token);
284 }
285
286 (healing, generated)
287 }
288}
289
290// ─────────────────────────────────────────────────────────────────────────────
291// Internal helpers
292// ─────────────────────────────────────────────────────────────────────────────
293
/// Return the index of the maximum value in `values`.
///
/// NaN entries are ignored: a NaN compares as neither greater nor less than
/// any value, so letting it take part in the scan would let it displace the
/// true maximum. When several entries share the maximum, the index of the
/// last one is returned.
///
/// Returns `0` for empty slices and for slices that contain only NaN
/// (safe default).
fn argmax_f32(values: &[f32]) -> usize {
    values
        .iter()
        .enumerate()
        .filter(|(_, v)| !v.is_nan())
        // NaNs are filtered out above, so `partial_cmp` always yields an ordering.
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(i, _)| i)
        .unwrap_or(0)
}
304
305// ─────────────────────────────────────────────────────────────────────────────
306// Tests
307// ─────────────────────────────────────────────────────────────────────────────
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Logits of length `vocab_size` that are flat except for a large score
    /// at index `winner`.
    fn logits_prefer(vocab_size: usize, winner: usize) -> Vec<f32> {
        let mut scores = vec![0.0f32; vocab_size];
        scores[winner] = 100.0;
        scores
    }

    /// Healer built from the default configuration.
    fn default_healer() -> TokenHealer {
        TokenHealer::new(TokenHealingConfig::default())
    }

    #[test]
    fn test_token_healing_disabled_returns_unchanged() {
        let healer = TokenHealer::new(TokenHealingConfig {
            enabled: false,
            ..TokenHealingConfig::default()
        });
        let input = vec![1u32, 2, 3, 4];

        let outcome = healer.heal(&input, 10, |_| logits_prefer(10, 7));

        assert!(!outcome.changed);
        assert_eq!(outcome.healed_tokens, input);
        assert_eq!(outcome.original_tokens, input);
    }

    #[test]
    fn test_token_healing_empty_input_unchanged() {
        let outcome = default_healer().heal(&[], 10, |_| logits_prefer(10, 0));

        assert!(!outcome.changed);
        assert!(outcome.healed_tokens.is_empty());
    }

    #[test]
    fn test_token_healing_lookback_1_no_change_when_correct() {
        // The model's top pick (5) is already the final token, so nothing moves.
        let input = vec![10u32, 20, 5];

        let outcome = default_healer().heal(&input, 30, |_| logits_prefer(30, 5));

        assert!(
            !outcome.changed,
            "no change expected when prediction matches"
        );
        assert_eq!(outcome.healed_tokens, input);
        assert_eq!(outcome.tokens_healed, 1);
    }

    #[test]
    fn test_token_healing_lookback_1_changes_wrong_token() {
        // The model's top pick (7) disagrees with the final token (99).
        let input = vec![10u32, 20, 99];

        let outcome = default_healer().heal(&input, 128, |_| logits_prefer(128, 7));

        assert!(outcome.changed);
        assert!(outcome.was_healed());
        // Expected shape: untouched prefix [10, 20] followed by the new token 7.
        assert_eq!(outcome.healed_tokens, vec![10u32, 20, 7]);
        assert_eq!(outcome.original_tokens, input);
        assert_eq!(outcome.tokens_healed, 1);
    }

    #[test]
    fn test_token_prob_correct() {
        // A single dominant logit should soak up essentially all the mass.
        let peaked = logits_prefer(10, 3);
        let p_peaked = TokenHealer::token_prob(&peaked, 3);
        assert!(
            (p_peaked - 1.0).abs() < 1e-5,
            "dominant token should have prob ≈ 1"
        );

        // Flat logits spread the mass evenly: 1/n per token.
        let flat = vec![0.0f32; 4];
        let p_flat = TokenHealer::token_prob(&flat, 2);
        assert!(
            (p_flat - 0.25).abs() < 1e-5,
            "uniform prob should be 0.25"
        );
    }

    #[test]
    fn test_healing_result_unchanged() {
        let input = vec![1u32, 2, 3];

        let outcome = HealingResult::unchanged(input.clone());

        assert!(!outcome.changed);
        assert!(!outcome.was_healed());
        assert_eq!(outcome.original_tokens, input);
        assert_eq!(outcome.healed_tokens, input);
        assert_eq!(outcome.tokens_healed, 0);
    }

    #[test]
    fn test_healing_decoder_runs() {
        let vocab_size = 20;
        let max_tokens = 5;
        // Final prompt token is 3 while the mock model prefers 9, so healing fires.
        let prompt = vec![1u32, 2, 3];
        let decoder = HealingDecoder::new(TokenHealingConfig::default());

        let invocations = std::cell::Cell::new(0usize);
        let get_logits = |_prefix: &[u32]| {
            invocations.set(invocations.get() + 1);
            logits_prefer(vocab_size, 9)
        };
        // The mock sampler ignores the logits and always yields token 1.
        let sample = |_logits: Vec<f32>| 1u32;

        let (healing, generated) =
            decoder.generate(prompt, vocab_size, max_tokens, get_logits, sample);

        // Healing replaced the trailing 3 with 9.
        assert!(healing.changed);
        // The loop ran to its limit.
        assert_eq!(generated.len(), max_tokens);
        // Every generated token came from the mock sampler.
        assert!(generated.iter().all(|&t| t == 1));
    }

    #[test]
    fn test_is_continuation_token() {
        // "call" + "ing": the second piece continues the word.
        assert!(
            TokenHealer::is_continuation_token("call", "ing"),
            "\"calling\" split should be a continuation"
        );
        // "call" + " the": the leading space starts a new word.
        assert!(
            !TokenHealer::is_continuation_token("call", " the"),
            "space-prefixed token is not a continuation"
        );
        // Either side empty: never a continuation.
        assert!(!TokenHealer::is_continuation_token("", "ing"));
        assert!(!TokenHealer::is_continuation_token("call", ""));
        // Previous token ends in punctuation, so it is not mid-word.
        assert!(
            !TokenHealer::is_continuation_token("call.", "ing"),
            "period-ended token is not mid-word"
        );
    }
}
447}