1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
//! N-gram cache for cheap speculative-decoding draft generation (no draft model needed).
//!
//! Maintains frequency counts of token patterns observed during generation.
//! When a trigram pattern (a, b) → c has been seen before, the cache predicts
//! c as the likely next token after seeing (a, b); if no trigram matches, it
//! falls back to the most frequent bigram continuation of the last token.
use std::collections::HashMap;
/// Token-level n-gram cache for speculative draft generation.
///
/// Records bigram and trigram patterns from observed token sequences and
/// predicts likely next tokens from their observed frequencies. Each key keeps
/// at most `max_entries_per_key` distinct continuations; once a key is full,
/// continuations that are not already tracked for it are ignored.
#[derive(Debug, Clone)]
pub struct NgramCache {
    /// Bigram: single token → `(next_token, count)` pairs, sorted by count
    /// descending (ties favour the continuation that reached that count first).
    bigrams: HashMap<u32, Vec<(u32, u32)>>,
    /// Trigram: `(token_a, token_b)` → `(next_token, count)` pairs, sorted by
    /// count descending (ties favour the continuation that reached that count first).
    trigrams: HashMap<(u32, u32), Vec<(u32, u32)>>,
    /// Maximum distinct continuations tracked per n-gram key (prevents
    /// unbounded growth).
    max_entries_per_key: usize,
}

impl NgramCache {
    /// Default cap on distinct continuations stored per n-gram key.
    const DEFAULT_MAX_ENTRIES_PER_KEY: usize = 8;

    /// Create a new empty n-gram cache.
    pub fn new() -> Self {
        Self {
            bigrams: HashMap::new(),
            trigrams: HashMap::new(),
            max_entries_per_key: Self::DEFAULT_MAX_ENTRIES_PER_KEY,
        }
    }

    /// Record a sequence of tokens into the cache.
    ///
    /// Updates both the bigram and trigram frequency tables from every
    /// overlapping window of `tokens`. Only windows lying entirely inside this
    /// slice are recorded, so n-grams that would span two separate `record`
    /// calls are not counted. Slices shorter than 2 tokens record nothing.
    pub fn record(&mut self, tokens: &[u32]) {
        for window in tokens.windows(2) {
            self.record_bigram(window[0], window[1]);
        }
        for window in tokens.windows(3) {
            self.record_trigram(window[0], window[1], window[2]);
        }
    }

    /// Record a single bigram observation `a → next`.
    fn record_bigram(&mut self, a: u32, next: u32) {
        let max = self.max_entries_per_key;
        Self::observe(self.bigrams.entry(a).or_default(), next, max);
    }

    /// Record a single trigram observation `(a, b) → next`.
    fn record_trigram(&mut self, a: u32, b: u32, next: u32) {
        let max = self.max_entries_per_key;
        Self::observe(self.trigrams.entry((a, b)).or_default(), next, max);
    }

    /// Count one occurrence of `next` in a continuation list that is kept
    /// sorted by count descending.
    ///
    /// If `next` is already tracked, its count is incremented (saturating, so
    /// a hot entry can never wrap back to zero) and the entry is moved forward
    /// past every entry it now strictly outranks. Otherwise it is appended with
    /// count 1 when the list holds fewer than `max_entries` items, and dropped
    /// when the list is full. Because the list is sorted before the update,
    /// this single O(len) pass keeps it sorted without a full re-sort.
    fn observe(entries: &mut Vec<(u32, u32)>, next: u32, max_entries: usize) {
        match entries.iter().position(|&(tok, _)| tok == next) {
            Some(mut i) => {
                entries[i].1 = entries[i].1.saturating_add(1);
                // Strict `<` keeps ties in the order the counts were reached.
                while i > 0 && entries[i - 1].1 < entries[i].1 {
                    entries.swap(i - 1, i);
                    i -= 1;
                }
            }
            // A new continuation starts at count 1, which cannot outrank any
            // existing entry (all counts are >= 1), so appending keeps the order.
            None if entries.len() < max_entries => entries.push((next, 1)),
            None => {}
        }
    }

    /// Predict the most likely next token given the context.
    ///
    /// Only the last two tokens of `context` are consulted. The trigram table
    /// keyed on those two tokens is tried first (more specific); if it has no
    /// match, the bigram table keyed on the last token is used. Returns `None`
    /// if `context` is empty or no matching pattern is found.
    pub fn predict_one(&self, context: &[u32]) -> Option<u32> {
        let trigram_hit = if context.len() >= 2 {
            let key = (context[context.len() - 2], context[context.len() - 1]);
            self.trigrams
                .get(&key)
                .and_then(|entries| entries.first())
                .map(|&(tok, _)| tok)
        } else {
            None
        };

        trigram_hit.or_else(|| {
            let last = context.last()?;
            self.bigrams
                .get(last)
                .and_then(|entries| entries.first())
                .map(|&(tok, _)| tok)
        })
    }

    /// Predict up to `lookahead` tokens by chaining predictions.
    ///
    /// Each predicted token is treated as appended to the context for the next
    /// prediction. Stops early as soon as no prediction is available. Note that
    /// repeating patterns can make predictions cycle, so `lookahead` is what
    /// bounds the draft length.
    pub fn draft(&self, context: &[u32], lookahead: usize) -> Vec<u32> {
        let mut draft = Vec::with_capacity(lookahead);

        // `predict_one` only reads the last two tokens, so keep a right-aligned
        // two-token window instead of cloning the entire (possibly very long)
        // context on every call.
        let filled_from_ctx = context.len().min(2);
        let mut window = [0u32; 2];
        window[2 - filled_from_ctx..].copy_from_slice(&context[context.len() - filled_from_ctx..]);
        let mut filled = filled_from_ctx;

        for _ in 0..lookahead {
            match self.predict_one(&window[2 - filled..]) {
                Some(token) => {
                    draft.push(token);
                    window = [window[1], token];
                    filled = (filled + 1).min(2);
                }
                None => break,
            }
        }
        draft
    }

    /// Number of unique trigram keys stored.
    pub fn trigram_count(&self) -> usize {
        self.trigrams.len()
    }

    /// Number of unique bigram keys stored.
    pub fn bigram_count(&self) -> usize {
        self.bigrams.len()
    }

    /// Returns true if no n-gram has been recorded yet.
    pub fn is_empty(&self) -> bool {
        self.bigrams.is_empty() && self.trigrams.is_empty()
    }
}

impl Default for NgramCache {
    /// Equivalent to [`NgramCache::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_cache_returns_no_prediction() {
        let cache = NgramCache::new();
        assert_eq!(cache.predict_one(&[1, 2, 3]), None);
        assert_eq!(cache.predict_one(&[]), None);
        assert!(cache.is_empty());
    }

    #[test]
    fn short_sequences_record_nothing() {
        let mut cache = NgramCache::new();
        cache.record(&[]);
        cache.record(&[7]);
        assert!(cache.is_empty());
        assert_eq!(cache.bigram_count(), 0);
        assert_eq!(cache.trigram_count(), 0);
    }

    #[test]
    fn bigram_prediction() {
        let mut cache = NgramCache::new();
        cache.record(&[10, 20, 30]);
        // Bigram: 10→20, 20→30
        assert_eq!(cache.predict_one(&[10]), Some(20));
        assert_eq!(cache.predict_one(&[20]), Some(30));
    }

    #[test]
    fn trigram_preferred_over_bigram() {
        let mut cache = NgramCache::new();
        cache.record(&[1, 2, 3]);
        cache.record(&[5, 2, 4]);
        cache.record(&[5, 2, 4]);
        // The bigram table alone would answer 4 (2→4 seen twice, 2→3 once),
        // so predicting 3 proves the trigram (1,2)→3 was consulted first.
        assert_eq!(cache.predict_one(&[1, 2]), Some(3));
        // The pair (7,2) was never seen, so the bigram fallback answers.
        assert_eq!(cache.predict_one(&[7, 2]), Some(4));
    }

    #[test]
    fn most_frequent_trigram_continuation_wins() {
        let mut cache = NgramCache::new();
        cache.record(&[10, 20, 30]);
        cache.record(&[10, 20, 40]); // second trigram (10,20)→40
        cache.record(&[10, 20, 40]); // now (10,20)→40 has count=2 > (10,20)→30 count=1
        assert_eq!(cache.predict_one(&[10, 20]), Some(40));
    }

    #[test]
    fn draft_chains_predictions() {
        let mut cache = NgramCache::new();
        // Record a repeating pattern: 1, 2, 3, 1, 2, 3, 1, 2, 3
        cache.record(&[1, 2, 3, 1, 2, 3, 1, 2, 3]);
        let draft = cache.draft(&[1, 2], 4);
        // Should predict: 3, 1, 2, 3 (repeating pattern)
        assert_eq!(draft, vec![3, 1, 2, 3]);
        // Only the tail of a long context matters.
        assert_eq!(cache.draft(&[9, 9, 9, 1, 2], 4), vec![3, 1, 2, 3]);
    }

    #[test]
    fn draft_from_single_token_context() {
        let mut cache = NgramCache::new();
        cache.record(&[1, 2, 3, 1, 2, 3]);
        // Starts from the bigram 3→1, then trigrams take over.
        assert_eq!(cache.draft(&[3], 3), vec![1, 2, 3]);
    }

    #[test]
    fn draft_stops_on_no_prediction() {
        let mut cache = NgramCache::new();
        cache.record(&[1, 2, 3]);
        // Context [99] has no match
        let draft = cache.draft(&[99], 4);
        assert!(draft.is_empty());
        // Zero lookahead never predicts anything.
        assert!(cache.draft(&[1, 2], 0).is_empty());
    }

    #[test]
    fn frequency_tracking() {
        let mut cache = NgramCache::new();
        cache.record(&[1, 2, 3]);
        cache.record(&[1, 2, 3]);
        cache.record(&[1, 2, 3]);
        cache.record(&[1, 2, 5]);
        // (1,2)→3 has count 3, (1,2)→5 has count 1
        assert_eq!(cache.predict_one(&[1, 2]), Some(3));
    }

    #[test]
    fn full_key_ignores_new_continuations() {
        let mut cache = NgramCache::new();
        // Fill key 1 with eight distinct continuations (the per-key cap).
        for next in 100..108u32 {
            cache.record(&[1, next]);
        }
        // A ninth continuation is dropped no matter how often it is seen.
        for _ in 0..5 {
            cache.record(&[1, 999]);
        }
        assert_eq!(cache.predict_one(&[1]), Some(100));
    }

    #[test]
    fn key_counts() {
        let mut cache = NgramCache::new();
        cache.record(&[1, 2, 3, 1, 2, 3]);
        // Bigram keys: 1, 2, 3. Trigram keys: (1,2), (2,3), (3,1).
        assert_eq!(cache.bigram_count(), 3);
        assert_eq!(cache.trigram_count(), 3);
        assert!(!cache.is_empty());
    }
}