1use circular_buffer::CircularBuffer;
2use graft_core::PageIdx;
3
/// Predicts which pages are likely to be read next, based on the reads it
/// has been told about.
pub trait Oracle {
    /// Tells the oracle that `pageidx` was read and was already cached.
    ///
    /// NOTE(review): the tests in this file call this only for pages found
    /// in the cache; confirm that production callers do the same.
    fn observe_cache_hit(&mut self, pageidx: PageIdx);

    /// Tells the oracle that `pageidx` was read and returns the pages it
    /// predicts will be read next.
    ///
    /// NOTE(review): the tests in this file call this only on a cache miss
    /// and add every returned page to the cache; presumably real callers
    /// prefetch the returned pages — verify against the caller.
    fn predict_next(&mut self, pageidx: PageIdx) -> impl Iterator<Item = PageIdx>;
}
15
16pub struct NoopOracle;
17
18impl Oracle for NoopOracle {
19 fn observe_cache_hit(&mut self, _pageidx: PageIdx) {
20 }
22
23 fn predict_next(&mut self, _pageidx: PageIdx) -> impl Iterator<Item = PageIdx> {
24 std::iter::empty()
26 }
27}
28
29#[derive(Debug, Default, Clone)]
41pub struct LeapOracle {
42 last_read: PageIdx,
44 history: CircularBuffer<32, isize>,
46 prediction: Vec<PageIdx>,
48 prediction_hits: usize,
50}
51
52impl LeapOracle {
53 fn find_trend(&self) -> Option<isize> {
56 const N_SPLIT: usize = 4;
57 let mut window_size = (self.history.len() / N_SPLIT).max(1);
58 while window_size <= self.history.len() {
59 let window = self.history.range(0..window_size);
60 if let Some(trend) = boyer_moore_strict_majority(window.copied()) {
61 return Some(trend);
62 }
63 window_size *= 2;
64 }
65 None
66 }
67
68 fn record_read(&mut self, pageidx: PageIdx) {
69 let delta = pageidx.to_u32() as isize - self.last_read.to_u32() as isize;
71 self.history.push_front(delta);
72
73 self.last_read = pageidx;
75 }
76}
77
78impl Oracle for LeapOracle {
79 fn observe_cache_hit(&mut self, pageidx: PageIdx) {
80 if pageidx == self.last_read {
82 return;
83 }
84
85 if self.prediction.contains(&pageidx) {
87 self.prediction_hits += 1;
88 }
89
90 self.record_read(pageidx);
91 }
92
93 fn predict_next(&mut self, pageidx: PageIdx) -> impl Iterator<Item = PageIdx> {
94 const MAX_LOOKAHEAD: usize = 8;
95
96 let trend = self.find_trend();
98
99 let lookahead = if self.prediction_hits == 0 {
101 if TrendIter::once(self.last_read, trend.unwrap_or(1)) == Some(pageidx) {
104 1
105 } else {
106 0
107 }
108 } else {
109 (self.prediction_hits + 1)
111 .checked_next_power_of_two()
112 .unwrap_or(MAX_LOOKAHEAD)
113 }
114 .min(MAX_LOOKAHEAD)
116 .max(self.prediction.len() / 2);
118
119 self.prediction_hits = 0;
121 self.prediction.clear();
122
123 if lookahead != 0 {
125 if let Some(trend) = trend {
126 self.prediction
128 .extend(TrendIter::new(pageidx, trend).take(lookahead));
129 } else {
130 for i in 1..=(lookahead / 2) {
132 self.prediction.push(pageidx.saturating_add(i as u32));
133 self.prediction.push(pageidx.saturating_sub(i as u32));
134 }
135 }
136 } else {
137 }
139
140 self.record_read(pageidx);
141
142 self.prediction.iter().copied()
143 }
144}
145
146struct TrendIter {
147 cursor: isize,
148 trend: isize,
149}
150
151impl TrendIter {
152 fn new(pageidx: PageIdx, trend: isize) -> Self {
153 Self { cursor: pageidx.to_u32() as isize, trend }
154 }
155
156 fn once(pageidx: PageIdx, trend: isize) -> Option<PageIdx> {
157 Self::new(pageidx, trend).next()
158 }
159}
160
161impl Iterator for TrendIter {
162 type Item = PageIdx;
163
164 fn next(&mut self) -> Option<Self::Item> {
165 self.cursor += self.trend;
166 PageIdx::try_new(self.cursor as u32)
167 }
168}
169
/// Returns the value that makes up strictly more than half of `iter`, if any.
///
/// Uses the Boyer-Moore majority vote: one pass over a clone of the iterator
/// picks a candidate, and a second pass counts its occurrences to confirm it.
/// An empty input, or one without a strict majority, yields `None`.
fn boyer_moore_strict_majority<I>(iter: I) -> Option<isize>
where
    I: Iterator<Item = isize> + Clone,
{
    // Voting pass: track the current candidate, its vote balance and the
    // number of items seen.
    let (candidate, _balance, len) = iter.clone().fold(
        (0isize, 0usize, 0usize),
        |(leader, balance, seen), value| {
            if balance == 0 {
                (value, 1, seen + 1)
            } else if value == leader {
                (leader, balance + 1, seen + 1)
            } else {
                (leader, balance - 1, seen + 1)
            }
        },
    );

    // Verification pass: the surviving candidate is only a majority if it
    // really occurs in more than half of the items.
    let matches = iter.filter(|&value| value == candidate).count();

    if matches > len / 2 {
        Some(candidate)
    } else {
        None
    }
}
207
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Table-driven check of the majority helper, covering empty input,
    /// ties, and mixed positive/negative values.
    #[test]
    fn test_boyer_moore_strict_majority() {
        let table: &[(&[isize], Option<isize>)] = &[
            (&[], None),
            (&[1], Some(1)),
            (&[1, 0], None),
            (&[0, 0, 0, 0], Some(0)),
            (&[0, 1, 0, 0], Some(0)),
            (&[0, 1, 1, 0], None),
            (&[0, 1, 1, 1], Some(1)),
            (&[0, 1, 1, 1, 0], Some(1)),
            (&[72, -3, -3, -3], Some(-3)),
            (&[-3, -58, 2, 2], None),
            (&[72, -3, -3, -3, -3, -58, 2, 2], None),
            (&[2, -58, 2, 2], Some(2)),
            (&[2, 2, 2, 4, -41, -39, 2, 2], Some(2)),
        ];

        for &(values, want) in table {
            let got = boyer_moore_strict_majority(values.iter().copied());
            assert_eq!(got, want);
        }
    }

    /// Replays several access patterns against a fresh `LeapOracle` backed by
    /// an unbounded cache and checks how many reads miss the cache.
    #[test]
    fn test_leap_oracle() {
        /// Reads with no discernible pattern, shared by two cases.
        const RANDOM_READS: [u32; 15] = [
            1, 56, 12, 100, 124, 15550, 51, 10, 7, 4101, 23, 1, 154, 1856, 15,
        ];

        struct Case {
            name: &'static str,
            reads: Vec<u32>,
            expected_misses: usize,
        }

        #[derive(Default)]
        struct State {
            oracle: LeapOracle,
            cache: HashSet<PageIdx>,
        }

        fn run_test(state: &mut State, case: Case) {
            let Case {
                name,
                reads,
                expected_misses,
            } = case;

            let mut misses = 0;
            for raw in reads {
                let page = PageIdx::new(raw);
                // `insert` returns true only when the page was not cached yet.
                if state.cache.insert(page) {
                    let prefetched = state.oracle.predict_next(page);
                    state.cache.extend(prefetched);
                    misses += 1;
                } else {
                    state.oracle.observe_cache_hit(page);
                }
            }

            assert_eq!(
                misses, expected_misses,
                "{} failed: unexpected miss count",
                name
            );
        }

        let cases = vec![
            Case {
                name: "sequential",
                reads: (1..=100).collect(),
                expected_misses: 15,
            },
            Case {
                name: "random",
                reads: RANDOM_READS.to_vec(),
                expected_misses: 14,
            },
            Case {
                name: "interrupted-scan",
                // Every 15th read (counting from the first) jumps far away.
                reads: (1..=100u32)
                    .map(|p| if (p - 1) % 15 == 0 { p + 116589 } else { p })
                    .collect(),
                expected_misses: 25,
            },
            Case {
                name: "stride-2",
                reads: (1..=200).step_by(2).collect(),
                expected_misses: 15,
            },
            Case {
                name: "reverse",
                reads: (1..=100).rev().collect(),
                expected_misses: 15,
            },
            Case {
                name: "reverse-stride-2",
                reads: (1..=200).rev().step_by(2).collect(),
                expected_misses: 15,
            },
            Case {
                name: "multi-pattern",
                reads: (1..=100)
                    .chain((101..=300).step_by(2))
                    .chain((301..=500).rev().step_by(2))
                    .chain((501..=600).rev())
                    .collect(),
                expected_misses: 59,
            },
            Case {
                name: "multi-pattern-random-middle",
                reads: (1..=100)
                    .chain((101..=300).step_by(2))
                    .chain(RANDOM_READS)
                    .chain((301..=700).rev().step_by(4))
                    .chain((701..=800).rev())
                    .collect(),
                expected_misses: 68,
            },
        ];

        // Each case starts from a fresh oracle and an empty cache.
        for case in cases {
            run_test(&mut State::default(), case);
        }
    }
}