// oxibonsai_runtime/ngram_cache.rs

//! N-gram cache for zero-cost speculative decoding draft generation.
//!
//! Maintains a frequency-based cache of token patterns observed during
//! generation. When a trigram pattern (a, b) → c has been seen before,
//! it can predict c as the likely next token after seeing (a, b).

use std::collections::HashMap;

/// Token-level n-gram cache for speculative draft generation.
///
/// Records bigram and trigram patterns from generated text and predicts
/// likely next tokens based on observed frequencies.
///
/// Every candidate list is kept ordered by count, highest first, so the best
/// prediction for a key is always the first element of its list.
#[derive(Debug, Clone)]
pub struct NgramCache {
    /// Bigram table: previous token → `(next_token, count)` candidates,
    /// ordered by count descending.
    bigrams: HashMap<u32, Vec<(u32, u32)>>,
    /// Trigram table: `(token_a, token_b)` → `(next_token, count)` candidates,
    /// ordered by count descending.
    trigrams: HashMap<(u32, u32), Vec<(u32, u32)>>,
    /// Maximum number of candidates kept per n-gram key. This bounds the
    /// length of each candidate list; the number of keys is not limited.
    max_entries_per_key: usize,
}

impl NgramCache {
    /// Per-key candidate cap used by `new`.
    const DEFAULT_MAX_ENTRIES_PER_KEY: usize = 8;

    /// Upper bound on the capacity `draft` reserves up front, so that a very
    /// large `lookahead` cannot trigger a huge (or overflowing) allocation.
    const MAX_DRAFT_PREALLOC: usize = 64;

    /// Create a new empty n-gram cache that keeps up to 8 candidates per key.
    pub fn new() -> Self {
        Self::with_max_entries_per_key(Self::DEFAULT_MAX_ENTRIES_PER_KEY)
    }

    /// Create a new empty n-gram cache with a custom per-key candidate cap.
    ///
    /// A cap below 1 would make the cache unable to store anything, so such
    /// values are raised to 1.
    pub fn with_max_entries_per_key(max_entries_per_key: usize) -> Self {
        Self {
            bigrams: HashMap::new(),
            trigrams: HashMap::new(),
            max_entries_per_key: max_entries_per_key.max(1),
        }
    }

    /// Record a sequence of tokens into the cache.
    ///
    /// Updates both the bigram and the trigram frequency tables. Sequences
    /// shorter than two tokens contribute nothing.
    pub fn record(&mut self, tokens: &[u32]) {
        for pair in tokens.windows(2) {
            self.record_bigram(pair[0], pair[1]);
        }
        for triple in tokens.windows(3) {
            self.record_trigram(triple[0], triple[1], triple[2]);
        }
    }

    /// Record a single bigram observation `a → next`.
    fn record_bigram(&mut self, a: u32, next: u32) {
        let cap = self.max_entries_per_key;
        Self::observe(self.bigrams.entry(a).or_default(), next, cap);
    }

    /// Record a single trigram observation `(a, b) → next`.
    fn record_trigram(&mut self, a: u32, b: u32, next: u32) {
        let cap = self.max_entries_per_key;
        Self::observe(self.trigrams.entry((a, b)).or_default(), next, cap);
    }

    /// Count one observation of `next` in a candidate list.
    ///
    /// A known candidate has its count incremented (saturating at
    /// `u32::MAX`); an unknown one is appended with count 1 if the list holds
    /// fewer than `cap` candidates and is dropped otherwise.
    ///
    /// The list is assumed to be ordered by count descending on entry and is
    /// left in that order on return.
    fn observe(entries: &mut Vec<(u32, u32)>, next: u32, cap: usize) {
        match entries.iter().position(|&(tok, _)| tok == next) {
            Some(mut idx) => {
                entries[idx].1 = entries[idx].1.saturating_add(1);
                // Only this one count changed, and only by +1, so moving the
                // entry toward the front past strictly smaller counts restores
                // the ordering without a full sort. On ties the candidate that
                // reached the count first stays in front.
                while idx > 0 && entries[idx - 1].1 < entries[idx].1 {
                    entries.swap(idx - 1, idx);
                    idx -= 1;
                }
            }
            None => {
                // A new candidate starts at the minimum count, so appending
                // it keeps the list ordered.
                if entries.len() < cap {
                    entries.push((next, 1));
                }
            }
        }
    }

    /// Return the highest-count candidate of a looked-up list, if any.
    fn top_candidate(entries: Option<&Vec<(u32, u32)>>) -> Option<u32> {
        entries?.first().map(|&(next, _)| next)
    }

    /// Predict the most likely next token given the context.
    ///
    /// Looks up the last two context tokens in the trigram table first and
    /// falls back to the last token in the bigram table. Returns `None` if
    /// the context is empty or no matching pattern is found.
    pub fn predict_one(&self, context: &[u32]) -> Option<u32> {
        if let [.., a, b] = context {
            if let Some(next) = Self::top_candidate(self.trigrams.get(&(*a, *b))) {
                return Some(next);
            }
        }

        let last = context.last()?;
        Self::top_candidate(self.bigrams.get(last))
    }

    /// Predict up to `lookahead` tokens by chaining predictions.
    ///
    /// Each predicted token becomes part of the context for the next
    /// prediction. Stops early if no prediction is available, so the result
    /// may be shorter than `lookahead` (or empty).
    pub fn draft(&self, context: &[u32], lookahead: usize) -> Vec<u32> {
        let mut draft = Vec::with_capacity(lookahead.min(Self::MAX_DRAFT_PREALLOC));

        // `predict_one` only inspects the last two tokens, so a two-token
        // sliding window is enough; the full context is never copied.
        let tail_start = context.len().saturating_sub(2);
        let mut window: Vec<u32> = context[tail_start..].to_vec();

        for _ in 0..lookahead {
            match self.predict_one(&window) {
                Some(token) => {
                    draft.push(token);
                    if window.len() == 2 {
                        window[0] = window[1];
                        window[1] = token;
                    } else {
                        window.push(token);
                    }
                }
                None => break,
            }
        }

        draft
    }

    /// Number of unique trigram keys stored.
    pub fn trigram_count(&self) -> usize {
        self.trigrams.len()
    }

    /// Number of unique bigram keys stored.
    pub fn bigram_count(&self) -> usize {
        self.bigrams.len()
    }

    /// Returns true if the cache has no entries.
    pub fn is_empty(&self) -> bool {
        self.bigrams.is_empty() && self.trigrams.is_empty()
    }
}

impl Default for NgramCache {
    /// Equivalent to [`NgramCache::new`].
    fn default() -> Self {
        Self::new()
    }
}

/// Unit tests for `NgramCache`: recording, single-token prediction, and
/// chained draft generation.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_cache_returns_no_prediction() {
        let cache = NgramCache::new();
        assert_eq!(cache.predict_one(&[1, 2, 3]), None);
        assert!(cache.is_empty());
    }

    #[test]
    fn empty_context_returns_no_prediction() {
        let mut cache = NgramCache::new();
        cache.record(&[1, 2, 3]);
        assert_eq!(cache.predict_one(&[]), None);
        assert!(cache.draft(&[], 4).is_empty());
    }

    #[test]
    fn bigram_prediction() {
        let mut cache = NgramCache::new();
        cache.record(&[10, 20, 30]);
        // Bigram: 10→20, 20→30
        assert_eq!(cache.predict_one(&[10]), Some(20));
        assert_eq!(cache.predict_one(&[20]), Some(30));
    }

    #[test]
    fn bigram_fallback_when_trigram_unknown() {
        let mut cache = NgramCache::new();
        cache.record(&[10, 20, 30]);
        // No trigram is keyed on (99, 20), so the bigram 20→30 is used.
        assert_eq!(cache.predict_one(&[99, 20]), Some(30));
    }

    #[test]
    fn trigram_preferred_over_bigram() {
        let mut cache = NgramCache::new();
        cache.record(&[10, 20, 30]);
        cache.record(&[10, 20, 40]); // second trigram (10,20)→40
        cache.record(&[10, 20, 40]); // now (10,20)→40 has count=2 > (10,20)→30 count=1

        // Trigram (10,20) predicts 40 (higher count)
        assert_eq!(cache.predict_one(&[10, 20]), Some(40));
    }

    #[test]
    fn draft_chains_predictions() {
        let mut cache = NgramCache::new();
        // Record a repeating pattern: 1, 2, 3, 1, 2, 3, 1, 2, 3
        cache.record(&[1, 2, 3, 1, 2, 3, 1, 2, 3]);

        let draft = cache.draft(&[1, 2], 4);
        // Should predict: 3, 1, 2, 3 (repeating pattern)
        assert_eq!(draft, vec![3, 1, 2, 3]);
    }

    #[test]
    fn draft_stops_on_no_prediction() {
        let mut cache = NgramCache::new();
        cache.record(&[1, 2, 3]);

        // Context [99] has no match
        let draft = cache.draft(&[99], 4);
        assert!(draft.is_empty());
    }

    #[test]
    fn frequency_tracking() {
        let mut cache = NgramCache::new();
        cache.record(&[1, 2, 3]);
        cache.record(&[1, 2, 3]);
        cache.record(&[1, 2, 3]);
        cache.record(&[1, 2, 5]);

        // (1,2)→3 has count 3, (1,2)→5 has count 1
        assert_eq!(cache.predict_one(&[1, 2]), Some(3));
    }
}