// graft_client/oracle.rs

1use circular_buffer::CircularBuffer;
2use graft_core::PageIdx;
3
4pub trait Oracle {
5    /// `observe_cache_hit` is called whenever Graft satisfies a page read from
6    /// it's local cache. This function is not called on a cache miss.
7    fn observe_cache_hit(&mut self, pageidx: PageIdx);
8
9    /// `predict_next` is called when Graft has a cache miss, and can be used to
10    /// hint that Graft should fetch additional pages along with the requested
11    /// page. The returned iterator should be empty if no additional pages
12    /// should be fetched, and it does not need to include the requested page.
13    fn predict_next(&mut self, pageidx: PageIdx) -> impl Iterator<Item = PageIdx>;
14}
15
16pub struct NoopOracle;
17
18impl Oracle for NoopOracle {
19    fn observe_cache_hit(&mut self, _pageidx: PageIdx) {
20        // do nothing
21    }
22
23    fn predict_next(&mut self, _pageidx: PageIdx) -> impl Iterator<Item = PageIdx> {
24        // predict nothing
25        std::iter::empty()
26    }
27}
28
29/// `LeapOracle` is an implementation of the algorithm described by the paper
30/// "Effectively Prefetching Remote Memory with Leap". It provides an Oracle
31/// that attempts to predict future page requests based on trends found in
32/// recent history.
33///
34/// Hasan Al Maruf and Mosharaf Chowdhury. (2020). [_Effectively Prefetching
35/// Remote Memory with Leap_][1]. In _Proceedings of the 2020 USENIX Conference on
36/// Usenix Annual Technical Conference (USENIX ATC'20)_, Article 58, 843–857.
37/// USENIX Association, USA.
38///
39/// [1]: https://www.usenix.org/system/files/atc20-maruf.pdf
40#[derive(Debug, Default, Clone)]
41pub struct LeapOracle {
42    /// the last observed read
43    last_read: PageIdx,
44    /// history of page index deltas ordered from most recent to least recent
45    history: CircularBuffer<32, isize>,
46    /// the last prediction
47    prediction: Vec<PageIdx>,
48    /// cache hits since the last prediction
49    prediction_hits: usize,
50}
51
52impl LeapOracle {
53    /// Tries to find a trend in the data by searching for strict majorities in
54    /// the access history. Returns None if no trend can be found.
55    fn find_trend(&self) -> Option<isize> {
56        const N_SPLIT: usize = 4;
57        let mut window_size = (self.history.len() / N_SPLIT).max(1);
58        while window_size <= self.history.len() {
59            let window = self.history.range(0..window_size);
60            if let Some(trend) = boyer_moore_strict_majority(window.copied()) {
61                return Some(trend);
62            }
63            window_size *= 2;
64        }
65        None
66    }
67
68    fn record_read(&mut self, pageidx: PageIdx) {
69        // update history buffer
70        let delta = pageidx.to_u32() as isize - self.last_read.to_u32() as isize;
71        self.history.push_front(delta);
72
73        // update last read and whether or not we are following the current trend
74        self.last_read = pageidx;
75    }
76}
77
78impl Oracle for LeapOracle {
79    fn observe_cache_hit(&mut self, pageidx: PageIdx) {
80        // ignore duplicate reads
81        if pageidx == self.last_read {
82            return;
83        }
84
85        // update hits counter
86        if self.prediction.contains(&pageidx) {
87            self.prediction_hits += 1;
88        }
89
90        self.record_read(pageidx);
91    }
92
93    fn predict_next(&mut self, pageidx: PageIdx) -> impl Iterator<Item = PageIdx> {
94        const MAX_LOOKAHEAD: usize = 8;
95
96        // calculate the trend
97        let trend = self.find_trend();
98
99        // calculate the number of predictions to make
100        let lookahead = if self.prediction_hits == 0 {
101            // the last prediction wasn't great
102            // check to see if reads are starting to follow a trend
103            if TrendIter::once(self.last_read, trend.unwrap_or(1)) == Some(pageidx) {
104                1
105            } else {
106                0
107            }
108        } else {
109            // the last prediction had hits
110            (self.prediction_hits + 1)
111                .checked_next_power_of_two()
112                .unwrap_or(MAX_LOOKAHEAD)
113        }
114        // ensure lookhead doesn't grow too large
115        .min(MAX_LOOKAHEAD)
116        // shrink lookahead smoothly
117        .max(self.prediction.len() / 2);
118
119        // clear previous prediction state
120        self.prediction_hits = 0;
121        self.prediction.clear();
122
123        // construct next prediction
124        if lookahead != 0 {
125            if let Some(trend) = trend {
126                // trend found, prefetch along the trend
127                self.prediction
128                    .extend(TrendIter::new(pageidx, trend).take(lookahead));
129            } else {
130                // no trend found, prefetch around the current page index
131                for i in 1..=(lookahead / 2) {
132                    self.prediction.push(pageidx.saturating_add(i as u32));
133                    self.prediction.push(pageidx.saturating_sub(i as u32));
134                }
135            }
136        } else {
137            // predictions are disabled until a new trend is established
138        }
139
140        self.record_read(pageidx);
141
142        self.prediction.iter().copied()
143    }
144}
145
146struct TrendIter {
147    cursor: isize,
148    trend: isize,
149}
150
151impl TrendIter {
152    fn new(pageidx: PageIdx, trend: isize) -> Self {
153        Self { cursor: pageidx.to_u32() as isize, trend }
154    }
155
156    fn once(pageidx: PageIdx, trend: isize) -> Option<PageIdx> {
157        Self::new(pageidx, trend).next()
158    }
159}
160
161impl Iterator for TrendIter {
162    type Item = PageIdx;
163
164    fn next(&mut self) -> Option<Self::Item> {
165        self.cursor += self.trend;
166        PageIdx::try_new(self.cursor as u32)
167    }
168}
169
/// Finds the strict majority element of `iter`: the value that makes up more
/// than half of the items. Returns `None` when no value does, which includes
/// the empty iterator.
///
/// This is the Boyer-Moore majority vote. A first pass over a clone of the
/// iterator elects a candidate and counts the items; a second pass counts the
/// candidate's occurrences to confirm that it really is a strict majority.
fn boyer_moore_strict_majority<I>(iter: I) -> Option<isize>
where
    I: Iterator<Item = isize> + Clone,
{
    // voting pass: carries (current candidate, its open votes, items seen)
    let (winner, _votes, seen) = iter.clone().fold(
        (0isize, 0usize, 0usize),
        |(winner, votes, seen), value| {
            let seen = seen + 1;
            if votes == 0 {
                // no standing candidate, so this value takes over
                (value, 1, seen)
            } else if value == winner {
                (winner, votes + 1, seen)
            } else {
                (winner, votes - 1, seen)
            }
        },
    );

    // verification pass: the vote alone does not prove a majority
    let tally = iter.filter(|&value| value == winner).count();

    (tally > seen / 2).then_some(winner)
}
207
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Checks the majority vote against a table of inputs and expected
    /// results, covering empty, tied and clear-majority inputs.
    #[test]
    fn test_boyer_moore_strict_majority() {
        let table: Vec<(Vec<isize>, Option<isize>)> = vec![
            (vec![], None),
            (vec![1], Some(1)),
            (vec![1, 0], None),
            (vec![0, 0, 0, 0], Some(0)),
            (vec![0, 1, 0, 0], Some(0)),
            (vec![0, 1, 1, 0], None),
            (vec![0, 1, 1, 1], Some(1)),
            (vec![0, 1, 1, 1, 0], Some(1)),
            (vec![72, -3, -3, -3], Some(-3)),
            (vec![-3, -58, 2, 2], None),
            (vec![72, -3, -3, -3, -3, -58, 2, 2], None),
            (vec![2, -58, 2, 2], Some(2)),
            (vec![2, 2, 2, 4, -41, -39, 2, 2], Some(2)),
        ];

        for (input, expected) in &table {
            assert_eq!(boyer_moore_strict_majority(input.iter().copied()), *expected);
        }
    }

    /// Replays several read patterns against a fresh `LeapOracle` backed by a
    /// cache that never evicts, and checks how many reads miss the cache.
    #[test]
    fn test_leap_oracle() {
        struct Case {
            name: &'static str,
            reads: Vec<u32>,
            expected_misses: usize,
        }

        /// Replays `reads` with a new oracle and an empty cache. Every miss
        /// caches the requested page plus whatever the oracle predicts.
        /// Returns the number of misses.
        fn count_misses(reads: Vec<u32>) -> usize {
            let mut oracle = LeapOracle::default();
            let mut cache: HashSet<PageIdx> = HashSet::default();
            let mut misses = 0;

            for raw in reads {
                let pageidx = PageIdx::new(raw);
                if cache.contains(&pageidx) {
                    oracle.observe_cache_hit(pageidx);
                    continue;
                }
                cache.insert(pageidx);
                cache.extend(oracle.predict_next(pageidx));
                misses += 1;
            }

            misses
        }

        let cases = [
            Case {
                name: "sequential",
                reads: (1..=100).collect(),
                expected_misses: 15,
            },
            Case {
                name: "random",
                reads: vec![
                    1, 56, 12, 100, 124, 15550, 51, 10, 7, 4101, 23, 1, 154, 1856, 15,
                ],
                // page 1 is read twice, so its second read is served from the
                // cache; every other read is a miss
                expected_misses: 14,
            },
            Case {
                name: "interrupted-scan",
                reads: (1..=100)
                    .enumerate()
                    // every 15th read is moved far away from the scan to test
                    // how well the algorithm recovers
                    .map(
                        |(i, p): (usize, u32)| {
                            if i % 15 == 0 { p + 116589 } else { p }
                        },
                    )
                    .collect(),
                expected_misses: 25,
            },
            Case {
                name: "stride-2",
                reads: (1..=200).step_by(2).collect(),
                expected_misses: 15,
            },
            Case {
                name: "reverse",
                reads: (1..=100).rev().collect(),
                expected_misses: 15,
            },
            Case {
                name: "reverse-stride-2",
                reads: (1..=200).rev().step_by(2).collect(),
                expected_misses: 15,
            },
            Case {
                name: "multi-pattern",
                reads: (1..=100)
                    .chain((101..=300).step_by(2))
                    .chain((301..=500).rev().step_by(2))
                    .chain((501..=600).rev())
                    .collect(),
                expected_misses: 59,
            },
            Case {
                name: "multi-pattern-random-middle",
                reads: (1..=100)
                    .chain((101..=300).step_by(2))
                    .chain([
                        1, 56, 12, 100, 124, 15550, 51, 10, 7, 4101, 23, 1, 154, 1856, 15,
                    ])
                    .chain((301..=700).rev().step_by(4))
                    .chain((701..=800).rev())
                    .collect(),
                expected_misses: 68,
            },
        ];

        for Case { name, reads, expected_misses } in cases {
            let misses = count_misses(reads);
            assert_eq!(
                misses, expected_misses,
                "{} failed: unexpected miss count",
                name
            );
        }
    }
}