oxibonsai_runtime/grammar/cache.rs
1//! Memoization cache for Earley `allowed_tokens` results.
2//!
3//! Caches per-state token masks keyed by a 64-bit Earley chart hash.
4//! Capacity is bounded; eviction is LRU.
5//!
6//! # Design notes
7//!
//! * The cache is keyed by `EarleyRecognizer::state_hash()` — a 64-bit hash of
//!   the current chart set at `input_pos`. Hash collisions are theoretically
//!   possible but extremely unlikely in practice; the worst case is a spurious
//!   cache hit that returns the mask computed for a *different* parse state,
//!   so some tokens may be wrongly allowed or forbidden for that step. This
//!   degrades output quality without crashing.
13//!
//! * Eviction is strict LRU maintained by a `VecDeque<u64>` (key order list).
//!   On a cache hit the key is moved to the back. On eviction the front is
//!   removed. The hit path is O(n): the key is located by a linear scan of the
//!   queue and then taken out with `VecDeque::remove`. The number of distinct
//!   grammar parse states is typically bounded by O(|grammar|²) and the
//!   capacity is small (default 256), so the scan cost is negligible.
19//!
20//! * All allocation is on the Rust heap; no unsafe code.
21
22use std::collections::{HashMap, VecDeque};
23use std::sync::Arc;
24
25// ─────────────────────────────────────────────────────────────────────────────
26// Internal entry
27// ─────────────────────────────────────────────────────────────────────────────
28
/// A single cached allowed-tokens mask for one Earley state.
///
/// The mask lives behind an [`Arc`] so that a cache hit hands out a shared
/// handle (a reference-count bump) instead of copying the whole mask.
struct CachedMask {
    /// Shared, immutable token mask (`true` = token allowed).
    mask: Arc<[bool]>,
}

// ─────────────────────────────────────────────────────────────────────────────
// AllowedTokensCache
// ─────────────────────────────────────────────────────────────────────────────

/// LRU cache mapping Earley state hashes → token masks.
///
/// The default capacity is [`AllowedTokensCache::DEFAULT_CAPACITY`] (256)
/// entries. At 150 k tokens (Qwen3 vocab) × 1 byte each, that is ≈38 MB
/// worst-case; typical grammars cycle through far fewer states.
///
/// # Thread safety
///
/// Every field is `Send + Sync`, so the cache itself is too. However, even a
/// lookup mutates state (LRU order and hit/miss counters), so [`get`] and
/// [`insert`] take `&mut self`. To use the cache behind a `&self` API — as
/// `GrammarConstraint` does for `TokenConstraint::allowed_tokens` — wrap it
/// in a `std::sync::Mutex`.
///
/// [`get`]: AllowedTokensCache::get
/// [`insert`]: AllowedTokensCache::insert
pub struct AllowedTokensCache {
    /// Maximum number of entries held at once (always ≥ 1).
    capacity: usize,
    /// State hash → cached mask.
    inner: HashMap<u64, CachedMask>,
    /// LRU order: front = least recently used, back = most recently used.
    /// Holds exactly the keys of `inner`, each once.
    lru: VecDeque<u64>,
    /// Total cache hits (for testing / metrics).
    hits: u64,
    /// Total cache misses (for testing / metrics).
    misses: u64,
}

impl AllowedTokensCache {
    /// Capacity used by the [`Default`] implementation.
    pub const DEFAULT_CAPACITY: usize = 256;

    /// Upper bound on the number of slots reserved eagerly at construction.
    ///
    /// The cache fills lazily, so reserving the full logical capacity up
    /// front is only an optimisation. Capping it keeps very large capacities
    /// (up to `usize::MAX`, i.e. "effectively unbounded") from panicking or
    /// exhausting memory inside `with_capacity`.
    const MAX_PREALLOCATED_ENTRIES: usize = 1024;

    /// Create a cache with the given capacity.
    ///
    /// The capacity is clamped to a minimum of 1 to keep invariants simple.
    /// At most [`Self::MAX_PREALLOCATED_ENTRIES`] slots are reserved up
    /// front; larger caches grow on demand.
    #[must_use]
    pub fn with_capacity(capacity: usize) -> Self {
        let capacity = capacity.max(1);
        let prealloc = capacity.min(Self::MAX_PREALLOCATED_ENTRIES);
        Self {
            capacity,
            inner: HashMap::with_capacity(prealloc),
            lru: VecDeque::with_capacity(prealloc),
            hits: 0,
            misses: 0,
        }
    }

    /// Try to get a cached mask for the given state hash.
    ///
    /// On a hit the key is promoted to the back of the LRU queue (most
    /// recently used), the hit counter is incremented, and a shared handle to
    /// the mask is returned. On a miss the miss counter is incremented and
    /// `None` is returned.
    pub fn get(&mut self, state_hash: u64) -> Option<Arc<[bool]>> {
        // Clone the handle first so the borrow of `inner` ends before the
        // LRU queue and counters are updated.
        let cached = self
            .inner
            .get(&state_hash)
            .map(|entry| Arc::clone(&entry.mask));
        if cached.is_some() {
            self.promote(state_hash);
            self.hits += 1;
        } else {
            self.misses += 1;
        }
        cached
    }

    /// Insert a mask for the given state hash, evicting LRU if at capacity.
    ///
    /// If the hash is already present the call is a no-op: the existing entry
    /// is kept and its position in the LRU order is not changed. No race is
    /// possible here because the method requires exclusive (`&mut`) access.
    pub fn insert(&mut self, state_hash: u64, mask: Vec<bool>) {
        if self.inner.contains_key(&state_hash) {
            return;
        }
        if self.inner.len() >= self.capacity {
            // Evict least-recently-used entry.
            if let Some(oldest) = self.lru.pop_front() {
                self.inner.remove(&oldest);
            }
        }
        // `Arc::from(Vec<T>)` copies straight into the shared allocation,
        // skipping the shrink-to-fit step of `into_boxed_slice`.
        let entry = CachedMask {
            mask: Arc::from(mask),
        };
        self.inner.insert(state_hash, entry);
        self.lru.push_back(state_hash);
    }

    /// Maximum number of entries the cache will hold (always ≥ 1).
    #[must_use]
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Number of cache hits since creation (for testing / observability).
    #[must_use]
    pub fn hits(&self) -> u64 {
        self.hits
    }

    /// Number of cache misses since creation (for testing / observability).
    #[must_use]
    pub fn misses(&self) -> u64 {
        self.misses
    }

    /// Number of entries currently held in the cache.
    #[must_use]
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// True if the cache holds no entries.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Move `state_hash` to the most-recently-used end of the LRU queue.
    fn promote(&mut self, state_hash: u64) {
        // Fast path: repeated hits on the same state need no reshuffling.
        if self.lru.back() == Some(&state_hash) {
            return;
        }
        // Keys are unique in the queue, so searching from the back finds the
        // same slot as a front-to-back scan while favouring hot keys.
        if let Some(pos) = self.lru.iter().rposition(|&key| key == state_hash) {
            self.lru.remove(pos);
        }
        self.lru.push_back(state_hash);
    }
}

impl Default for AllowedTokensCache {
    /// Create a cache holding up to [`AllowedTokensCache::DEFAULT_CAPACITY`]
    /// entries.
    fn default() -> Self {
        Self::with_capacity(Self::DEFAULT_CAPACITY)
    }
}

impl std::fmt::Debug for AllowedTokensCache {
    /// Prints a summary (capacity, length, counters) rather than the masks
    /// themselves, which can each be as long as the vocabulary.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AllowedTokensCache")
            .field("capacity", &self.capacity)
            .field("len", &self.inner.len())
            .field("hits", &self.hits)
            .field("misses", &self.misses)
            .finish()
    }
}
137
138// ─────────────────────────────────────────────────────────────────────────────
139// Unit tests
140// ─────────────────────────────────────────────────────────────────────────────
141
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an owned mask from a slice literal.
    fn make_mask(v: &[bool]) -> Vec<bool> {
        v.to_vec()
    }

    #[test]
    fn cache_empty_initially() {
        let cache = AllowedTokensCache::with_capacity(4);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn cache_miss_on_empty() {
        let mut cache = AllowedTokensCache::with_capacity(4);
        assert!(cache.get(42).is_none());
        assert_eq!(cache.misses(), 1);
        assert_eq!(cache.hits(), 0);
    }

    #[test]
    fn cache_insert_and_hit() {
        let mut cache = AllowedTokensCache::with_capacity(4);
        cache.insert(1, make_mask(&[true, false, true]));
        let result = cache.get(1).expect("should be present");
        assert_eq!(&*result, &[true, false, true]);
        assert_eq!(cache.hits(), 1);
        assert_eq!(cache.misses(), 0);
    }

    #[test]
    fn cache_duplicate_insert_is_noop() {
        let mut cache = AllowedTokensCache::with_capacity(4);
        cache.insert(7, make_mask(&[true]));
        cache.insert(7, make_mask(&[false])); // should be ignored
        let result = cache.get(7).expect("present");
        // Original value should survive.
        assert_eq!(&*result, &[true]);
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn cache_duplicate_insert_does_not_promote() {
        // A duplicate insert is a no-op, so it must not refresh the entry's
        // position in the LRU order either.
        let mut cache = AllowedTokensCache::with_capacity(2);
        cache.insert(1, make_mask(&[true]));
        cache.insert(2, make_mask(&[true]));
        cache.insert(1, make_mask(&[false])); // ignored; 1 stays LRU
        cache.insert(3, make_mask(&[true])); // evicts 1
        assert!(cache.get(1).is_none(), "1 was still LRU, should be evicted");
        assert!(cache.get(2).is_some(), "2 should still be present");
        assert!(cache.get(3).is_some(), "3 was just inserted");
    }

    #[test]
    fn cache_evicts_lru_at_capacity() {
        let mut cache = AllowedTokensCache::with_capacity(2);
        cache.insert(10, make_mask(&[true]));
        cache.insert(20, make_mask(&[true]));
        // Access 20 to make 10 the LRU.
        let _ = cache.get(20);
        // Insert third entry — 10 should be evicted.
        cache.insert(30, make_mask(&[true]));
        assert_eq!(cache.len(), 2);
        assert!(cache.get(10).is_none(), "10 should have been evicted");
        assert!(cache.get(20).is_some(), "20 should still be present");
        assert!(cache.get(30).is_some(), "30 should be present");
    }

    #[test]
    fn cache_capacity_one_always_evicts() {
        let mut cache = AllowedTokensCache::with_capacity(1);
        cache.insert(1, make_mask(&[true]));
        cache.insert(2, make_mask(&[false]));
        assert_eq!(cache.len(), 1);
        assert!(cache.get(1).is_none());
        assert!(cache.get(2).is_some());
    }

    #[test]
    fn cache_zero_capacity_is_clamped_to_one() {
        // `with_capacity(0)` is documented to behave like capacity 1.
        let mut cache = AllowedTokensCache::with_capacity(0);
        cache.insert(1, make_mask(&[true]));
        assert_eq!(cache.len(), 1);
        cache.insert(2, make_mask(&[false]));
        assert_eq!(cache.len(), 1);
        assert!(cache.get(1).is_none());
        assert!(cache.get(2).is_some());
    }

    #[test]
    fn cache_len_never_exceeds_capacity() {
        let mut cache = AllowedTokensCache::with_capacity(3);
        for key in 0..10u64 {
            cache.insert(key, make_mask(&[true]));
            assert!(cache.len() <= 3);
        }
        assert_eq!(cache.len(), 3);
        // Only the three most recently inserted keys survive.
        assert!(cache.get(6).is_none());
        assert!(cache.get(7).is_some());
        assert!(cache.get(8).is_some());
        assert!(cache.get(9).is_some());
    }

    #[test]
    fn cache_stats_track_correctly() {
        let mut cache = AllowedTokensCache::with_capacity(8);
        let _ = cache.get(99); // miss
        let _ = cache.get(99); // miss again
        cache.insert(99, make_mask(&[true, true]));
        let _ = cache.get(99); // hit
        let _ = cache.get(99); // hit
        assert_eq!(cache.misses(), 2);
        assert_eq!(cache.hits(), 2);
    }

    #[test]
    fn cache_hits_share_the_same_allocation() {
        // A hit must hand out a handle to the stored mask, not a copy.
        let mut cache = AllowedTokensCache::with_capacity(4);
        cache.insert(5, make_mask(&[true, false]));
        let first = cache.get(5).expect("present");
        let second = cache.get(5).expect("present");
        assert!(std::sync::Arc::ptr_eq(&first, &second));
    }

    #[test]
    fn cache_lru_promotes_on_hit() {
        // Insert A, B; hit A; insert C → B should be evicted (not A).
        let mut cache = AllowedTokensCache::with_capacity(2);
        cache.insert(1, make_mask(&[true]));
        cache.insert(2, make_mask(&[true]));
        let _ = cache.get(1); // promote 1 to MRU
        cache.insert(3, make_mask(&[true])); // should evict 2
        assert!(cache.get(1).is_some(), "1 was promoted, should survive");
        assert!(cache.get(2).is_none(), "2 was LRU, should be evicted");
        assert!(cache.get(3).is_some(), "3 was just inserted");
    }
}