Skip to main content

oxibonsai_runtime/grammar/
cache.rs

//! Memoization cache for Earley `allowed_tokens` results.
//!
//! Caches per-state token masks keyed by a 64-bit Earley chart hash.
//! Capacity is bounded; eviction is LRU.
//!
//! # Design notes
//!
//! * The cache is keyed by `EarleyRecognizer::state_hash()` — a 64-bit hash of
//!   the current chart set at `input_pos`.  Hash collisions are theoretically
//!   possible but extremely unlikely in practice; the worst case is a spurious
//!   cache hit that returns the mask computed for a different state, which
//!   degrades output quality without crashing.
//!
//! * Eviction is strict LRU maintained by a `VecDeque<u64>` (key order list).
//!   On a cache hit the key is moved to the back.  On eviction the front is
//!   removed.  The hit path is O(n) because the key is located with a linear
//!   scan (`Iterator::position`) and then removed with `VecDeque::remove`, but
//!   grammar parse states are typically O(|grammar|²) bounded and capacity is
//!   small (default 256), so the scan cost is negligible.
//!
//! * All allocation is on the Rust heap; no unsafe code.
21
22use std::collections::{HashMap, VecDeque};
23use std::sync::Arc;
24
25// ─────────────────────────────────────────────────────────────────────────────
26// Internal entry
27// ─────────────────────────────────────────────────────────────────────────────
28
/// A single cached allowed-tokens mask for one Earley state.
///
/// The mask lives behind an `Arc<[bool]>`, so handing it to a caller (or
/// cloning this entry) only bumps a reference count instead of copying one
/// `bool` per vocabulary token.
#[derive(Debug, Clone)]
struct CachedMask {
    /// Shared, immutable token mask (`true` = token allowed).
    mask: Arc<[bool]>,
}
37
38// ─────────────────────────────────────────────────────────────────────────────
39// AllowedTokensCache
40// ─────────────────────────────────────────────────────────────────────────────
41
42/// LRU cache mapping Earley state hashes → token masks.
43///
44/// Default capacity: 256 entries.  At 150 k tokens (Qwen3 vocab) × 1 byte each,
45/// that is ≈38 MB worst-case; typical grammars cycle through far fewer states.
46///
47/// # Thread safety
48///
49/// `AllowedTokensCache` is **not** `Sync` on its own.  In `GrammarConstraint`
50/// it is wrapped in a `std::sync::Mutex` to satisfy the `&self` signature of
51/// `TokenConstraint::allowed_tokens`.
52pub struct AllowedTokensCache {
53    capacity: usize,
54    inner: HashMap<u64, CachedMask>,
55    /// LRU order: front = least recently used, back = most recently used.
56    lru: VecDeque<u64>,
57    /// Total cache hits (for testing / metrics).
58    hits: u64,
59    /// Total cache misses (for testing / metrics).
60    misses: u64,
61}
62
63impl AllowedTokensCache {
64    /// Create a cache with the given capacity.
65    ///
66    /// The capacity is clamped to a minimum of 1 to keep invariants simple.
67    pub fn with_capacity(capacity: usize) -> Self {
68        let capacity = capacity.max(1);
69        Self {
70            capacity,
71            inner: HashMap::with_capacity(capacity),
72            lru: VecDeque::with_capacity(capacity),
73            hits: 0,
74            misses: 0,
75        }
76    }
77
78    /// Try to get a cached mask for the given state hash.
79    ///
80    /// On a hit the key is promoted to the back of the LRU queue (most recently
81    /// used) and the hit counter is incremented.  On a miss the miss counter is
82    /// incremented.
83    pub fn get(&mut self, state_hash: u64) -> Option<Arc<[bool]>> {
84        if let Some(entry) = self.inner.get(&state_hash) {
85            // Promote to back of LRU (most recently used).
86            if let Some(pos) = self.lru.iter().position(|&k| k == state_hash) {
87                self.lru.remove(pos);
88            }
89            self.lru.push_back(state_hash);
90            self.hits += 1;
91            Some(Arc::clone(&entry.mask))
92        } else {
93            self.misses += 1;
94            None
95        }
96    }
97
98    /// Insert a mask for the given state hash, evicting LRU if at capacity.
99    ///
100    /// If the hash is already present the call is a no-op (the existing entry
101    /// is kept; this is safe because a single-threaded `Mutex` prevents races).
102    pub fn insert(&mut self, state_hash: u64, mask: Vec<bool>) {
103        if self.inner.contains_key(&state_hash) {
104            return;
105        }
106        if self.inner.len() >= self.capacity {
107            // Evict least-recently-used entry.
108            if let Some(oldest) = self.lru.pop_front() {
109                self.inner.remove(&oldest);
110            }
111        }
112        let mask: Arc<[bool]> = Arc::from(mask.into_boxed_slice());
113        self.inner.insert(state_hash, CachedMask { mask });
114        self.lru.push_back(state_hash);
115    }
116
117    /// Number of cache hits since creation (for testing / observability).
118    pub fn hits(&self) -> u64 {
119        self.hits
120    }
121
122    /// Number of cache misses since creation (for testing / observability).
123    pub fn misses(&self) -> u64 {
124        self.misses
125    }
126
127    /// Number of entries currently held in the cache.
128    pub fn len(&self) -> usize {
129        self.inner.len()
130    }
131
132    /// True if the cache holds no entries.
133    pub fn is_empty(&self) -> bool {
134        self.inner.is_empty()
135    }
136}
137
138// ─────────────────────────────────────────────────────────────────────────────
139// Unit tests
140// ─────────────────────────────────────────────────────────────────────────────
141
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an owned mask from a slice literal.
    fn make_mask(v: &[bool]) -> Vec<bool> {
        v.to_vec()
    }

    #[test]
    fn cache_empty_initially() {
        let cache = AllowedTokensCache::with_capacity(4);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn cache_miss_on_empty() {
        let mut cache = AllowedTokensCache::with_capacity(4);
        assert!(cache.get(42).is_none());
        assert_eq!(cache.misses(), 1);
        assert_eq!(cache.hits(), 0);
    }

    #[test]
    fn cache_insert_and_hit() {
        let mut cache = AllowedTokensCache::with_capacity(4);
        cache.insert(1, make_mask(&[true, false, true]));
        let result = cache.get(1).expect("should be present");
        assert_eq!(&*result, &[true, false, true]);
        assert_eq!(cache.hits(), 1);
        assert_eq!(cache.misses(), 0);
    }

    #[test]
    fn cache_duplicate_insert_is_noop() {
        let mut cache = AllowedTokensCache::with_capacity(4);
        cache.insert(7, make_mask(&[true]));
        cache.insert(7, make_mask(&[false])); // should be ignored
        let result = cache.get(7).expect("present");
        // Original value should survive.
        assert_eq!(&*result, &[true]);
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn cache_duplicate_insert_does_not_refresh_lru() {
        // A duplicate insert is a pure no-op: it must not promote the key.
        let mut cache = AllowedTokensCache::with_capacity(2);
        cache.insert(1, make_mask(&[true]));
        cache.insert(2, make_mask(&[true]));
        cache.insert(1, make_mask(&[false])); // ignored, 1 stays LRU
        cache.insert(3, make_mask(&[true])); // evicts 1
        assert!(cache.get(1).is_none(), "1 stayed LRU, should be evicted");
        assert!(cache.get(2).is_some(), "2 should still be present");
        assert!(cache.get(3).is_some(), "3 was just inserted");
    }

    #[test]
    fn cache_evicts_lru_at_capacity() {
        let mut cache = AllowedTokensCache::with_capacity(2);
        cache.insert(10, make_mask(&[true]));
        cache.insert(20, make_mask(&[true]));
        // Access 20 to make 10 the LRU.
        let _ = cache.get(20);
        // Insert third entry — 10 should be evicted.
        cache.insert(30, make_mask(&[true]));
        assert_eq!(cache.len(), 2);
        assert!(cache.get(10).is_none(), "10 should have been evicted");
        assert!(cache.get(20).is_some(), "20 should still be present");
        assert!(cache.get(30).is_some(), "30 should be present");
    }

    #[test]
    fn cache_capacity_one_always_evicts() {
        let mut cache = AllowedTokensCache::with_capacity(1);
        cache.insert(1, make_mask(&[true]));
        cache.insert(2, make_mask(&[false]));
        assert_eq!(cache.len(), 1);
        assert!(cache.get(1).is_none());
        assert!(cache.get(2).is_some());
    }

    #[test]
    fn cache_zero_capacity_is_clamped_to_one() {
        // `with_capacity(0)` must behave exactly like a capacity of 1.
        let mut cache = AllowedTokensCache::with_capacity(0);
        cache.insert(1, make_mask(&[true]));
        assert_eq!(cache.len(), 1);
        cache.insert(2, make_mask(&[false]));
        assert_eq!(cache.len(), 1);
        assert!(cache.get(1).is_none());
        assert!(cache.get(2).is_some());
    }

    #[test]
    fn cache_stats_track_correctly() {
        let mut cache = AllowedTokensCache::with_capacity(8);
        let _ = cache.get(99); // miss
        let _ = cache.get(99); // miss again
        cache.insert(99, make_mask(&[true, true]));
        let _ = cache.get(99); // hit
        let _ = cache.get(99); // hit
        assert_eq!(cache.misses(), 2);
        assert_eq!(cache.hits(), 2);
    }

    #[test]
    fn cache_lru_promotes_on_hit() {
        // Insert A, B; hit A; insert C → B should be evicted (not A).
        let mut cache = AllowedTokensCache::with_capacity(2);
        cache.insert(1, make_mask(&[true]));
        cache.insert(2, make_mask(&[true]));
        let _ = cache.get(1); // promote 1 to MRU
        cache.insert(3, make_mask(&[true])); // should evict 2
        assert!(cache.get(1).is_some(), "1 was promoted, should survive");
        assert!(cache.get(2).is_none(), "2 was LRU, should be evicted");
        assert!(cache.get(3).is_some(), "3 was just inserted");
    }
}