oxibonsai_runtime/
ngram_cache.rs1use std::collections::HashMap;
8
/// Frequency-ranked bigram/trigram lookup table for guessing upcoming tokens.
///
/// Every observed token (bigram key) or adjacent token pair (trigram key)
/// maps to a short list of `(next_token, count)` candidates. Each list is
/// kept ordered by descending count, so its first element is the current
/// best guess for that key.
#[derive(Debug, Clone)]
pub struct NgramCache {
    /// `token -> [(next_token, count)]`, most frequent continuation first.
    bigrams: HashMap<u32, Vec<(u32, u32)>>,
    /// `(token_a, token_b) -> [(next_token, count)]`, most frequent first.
    trigrams: HashMap<(u32, u32), Vec<(u32, u32)>>,
    /// Upper bound on the number of candidates stored under a single key.
    max_entries_per_key: usize,
}

impl NgramCache {
    /// Candidate limit per key used by [`NgramCache::new`].
    const DEFAULT_MAX_ENTRIES_PER_KEY: usize = 8;

    /// Creates an empty cache that keeps at most 8 candidates per key.
    pub fn new() -> Self {
        Self::with_max_entries_per_key(Self::DEFAULT_MAX_ENTRIES_PER_KEY)
    }

    /// Creates an empty cache that keeps at most `max_entries_per_key`
    /// candidates under each bigram/trigram key.
    ///
    /// A limit of `0` is raised to `1`, so a recorded key always holds at
    /// least one candidate.
    pub fn with_max_entries_per_key(max_entries_per_key: usize) -> Self {
        Self {
            bigrams: HashMap::new(),
            trigrams: HashMap::new(),
            max_entries_per_key: max_entries_per_key.max(1),
        }
    }

    /// Records every adjacent pair and triple found in `tokens`.
    ///
    /// Slices shorter than two tokens record nothing; slices of exactly two
    /// tokens contribute a bigram only.
    pub fn record(&mut self, tokens: &[u32]) {
        for pair in tokens.windows(2) {
            self.record_bigram(pair[0], pair[1]);
        }
        for triple in tokens.windows(3) {
            self.record_trigram(triple[0], triple[1], triple[2]);
        }
    }

    /// Counts one observation of `next` following the single token `a`.
    fn record_bigram(&mut self, a: u32, next: u32) {
        let limit = self.max_entries_per_key;
        Self::bump_successor(self.bigrams.entry(a).or_default(), next, limit);
    }

    /// Counts one observation of `next` following the token pair `(a, b)`.
    fn record_trigram(&mut self, a: u32, b: u32, next: u32) {
        let limit = self.max_entries_per_key;
        Self::bump_successor(self.trigrams.entry((a, b)).or_default(), next, limit);
    }

    /// Shared update step for one candidate list.
    ///
    /// * If `next` is already listed, its count is incremented (saturating at
    ///   `u32::MAX` rather than overflowing) and the entry is moved ahead of
    ///   every entry with a strictly smaller count.
    /// * Otherwise it is appended with a count of 1 while the list holds
    ///   fewer than `max_entries` candidates; once the list is full, unseen
    ///   tokens are ignored.
    fn bump_successor(entries: &mut Vec<(u32, u32)>, next: u32, max_entries: usize) {
        match entries.iter().position(|&(tok, _)| tok == next) {
            Some(mut idx) => {
                entries[idx].1 = entries[idx].1.saturating_add(1);
                // Only this entry changed, so a single bubble-up restores the
                // descending order. Ties keep their existing relative order.
                while idx > 0 && entries[idx - 1].1 < entries[idx].1 {
                    entries.swap(idx - 1, idx);
                    idx -= 1;
                }
            }
            // A fresh candidate starts at count 1, the minimum possible, so
            // appending it keeps the list ordered.
            None if entries.len() < max_entries => entries.push((next, 1)),
            None => {}
        }
    }

    /// Returns the most frequently recorded continuation of `context`.
    ///
    /// The last two tokens are looked up in the trigram table first; if that
    /// yields nothing (or `context` has fewer than two tokens), the last
    /// token is looked up in the bigram table. Returns `None` when neither
    /// table has a candidate or `context` is empty.
    pub fn predict_one(&self, context: &[u32]) -> Option<u32> {
        let from_trigram = match context {
            [.., a, b] => self
                .trigrams
                .get(&(*a, *b))
                .and_then(|entries| entries.first()),
            _ => None,
        };

        from_trigram
            .or_else(|| {
                context
                    .last()
                    .and_then(|last| self.bigrams.get(last))
                    .and_then(|entries| entries.first())
            })
            .map(|&(next, _count)| next)
    }

    /// Greedily predicts up to `lookahead` tokens that follow `context`.
    ///
    /// Each predicted token is fed back in as context for the next
    /// prediction. Drafting stops early at the first position for which
    /// [`NgramCache::predict_one`] returns `None`, so the result may be
    /// shorter than `lookahead` (including empty).
    pub fn draft(&self, context: &[u32], lookahead: usize) -> Vec<u32> {
        let mut drafted = Vec::with_capacity(lookahead);

        // `predict_one` only inspects the last two tokens, so copying that
        // tail is enough; the rest of a long context is never read.
        let tail_start = context.len().saturating_sub(2);
        let mut recent: Vec<u32> = context[tail_start..].to_vec();

        while drafted.len() < lookahead {
            match self.predict_one(&recent) {
                Some(token) => {
                    drafted.push(token);
                    recent.push(token);
                }
                None => break,
            }
        }

        drafted
    }

    /// Number of distinct `(token_a, token_b)` keys in the trigram table.
    pub fn trigram_count(&self) -> usize {
        self.trigrams.len()
    }

    /// Number of distinct token keys in the bigram table.
    pub fn bigram_count(&self) -> usize {
        self.bigrams.len()
    }

    /// Returns `true` when neither table holds any key.
    pub fn is_empty(&self) -> bool {
        self.bigrams.is_empty() && self.trigrams.is_empty()
    }
}

impl Default for NgramCache {
    /// Equivalent to [`NgramCache::new`].
    fn default() -> Self {
        Self::new()
    }
}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh cache has no keys and therefore nothing to predict.
    #[test]
    fn empty_cache_returns_no_prediction() {
        let cache = NgramCache::new();
        assert_eq!(cache.predict_one(&[1, 2, 3]), None);
        assert!(cache.is_empty());
        assert!(NgramCache::default().is_empty());
    }

    /// With a single token of context only the bigram table can be consulted.
    #[test]
    fn bigram_prediction() {
        let mut cache = NgramCache::new();
        cache.record(&[10, 20, 30]);
        assert_eq!(cache.predict_one(&[10]), Some(20));
        assert_eq!(cache.predict_one(&[20]), Some(30));
    }

    /// The bigram and trigram tables are made to disagree, so the assertion
    /// can only pass if the trigram lookup wins.
    #[test]
    fn trigram_preferred_over_bigram() {
        let mut cache = NgramCache::new();
        // Bigram view of `20`: 30 seen twice, 40 seen once.
        cache.record(&[5, 20, 30]);
        cache.record(&[5, 20, 30]);
        // Trigram view of `(10, 20)`: only 40 has ever followed.
        cache.record(&[10, 20, 40]);

        assert_eq!(cache.predict_one(&[20]), Some(30));
        assert_eq!(cache.predict_one(&[10, 20]), Some(40));
    }

    /// An unknown token pair falls back to the bigram of the last token.
    #[test]
    fn bigram_fallback_when_trigram_unknown() {
        let mut cache = NgramCache::new();
        cache.record(&[5, 20, 30]);
        assert_eq!(cache.predict_one(&[99, 20]), Some(30));
    }

    /// Each drafted token becomes context for the next prediction.
    #[test]
    fn draft_chains_predictions() {
        let mut cache = NgramCache::new();
        cache.record(&[1, 2, 3, 1, 2, 3, 1, 2, 3]);

        let draft = cache.draft(&[1, 2], 4);
        assert_eq!(draft, vec![3, 1, 2, 3]);
    }

    /// Drafting ends as soon as no continuation is known.
    #[test]
    fn draft_stops_on_no_prediction() {
        let mut cache = NgramCache::new();
        cache.record(&[1, 2, 3]);

        let draft = cache.draft(&[99], 4);
        assert!(draft.is_empty());

        // 1 -> 2 -> 3 is known, but nothing was ever recorded after 3.
        assert_eq!(cache.draft(&[1], 5), vec![2, 3]);
    }

    /// A lookahead of zero never produces tokens, even with a warm cache.
    #[test]
    fn draft_with_zero_lookahead_is_empty() {
        let mut cache = NgramCache::new();
        cache.record(&[1, 2, 3]);
        assert!(cache.draft(&[1, 2], 0).is_empty());
    }

    /// The most frequently seen continuation wins over a rarer one.
    #[test]
    fn frequency_tracking() {
        let mut cache = NgramCache::new();
        cache.record(&[1, 2, 3]);
        cache.record(&[1, 2, 3]);
        cache.record(&[1, 2, 3]);
        cache.record(&[1, 2, 5]);

        assert_eq!(cache.predict_one(&[1, 2]), Some(3));
    }

    /// The count accessors report distinct keys, not total observations.
    #[test]
    fn counts_report_distinct_keys() {
        let mut cache = NgramCache::new();
        cache.record(&[7]);
        assert!(cache.is_empty());

        cache.record(&[1, 2, 3]);
        cache.record(&[1, 2, 3]);
        assert_eq!(cache.bigram_count(), 2);
        assert_eq!(cache.trigram_count(), 1);
        assert!(!cache.is_empty());
    }

    /// Pins the current policy: once a key holds 8 candidates, tokens not
    /// already listed are ignored no matter how often they recur.
    #[test]
    fn full_key_ignores_unseen_tokens() {
        let mut cache = NgramCache::new();
        cache.record(&[1, 100]);
        cache.record(&[1, 100]);
        for filler in 0..7 {
            cache.record(&[1, 200 + filler]);
        }
        for _ in 0..5 {
            cache.record(&[1, 999]);
        }

        assert_eq!(cache.predict_one(&[1]), Some(100));
        assert_eq!(cache.draft(&[1], 1), vec![100]);
    }
}