reactor_cache/
cache.rs

1use std::collections::HashMap;
2use std::hash::Hash;
3use std::fmt::Debug;
4use std::sync::Arc;
5use std::sync::mpsc::{self, TryRecvError};
6use std::time::Duration;
7
8use futures::{self, Async, Fuse, Future, IntoFuture, Poll};
9use futures::sync::oneshot::{self, Sender, Receiver};
10
11use linked_hash_map::LinkedHashMap;
12
13use mio::timer::{Builder as TimerBuilder, Timer};
14
15use tokio_core::reactor::{Handle, PollEvented};
16
// Sentinel values carried in the `Checker` channel's error slot on a cache
// miss: the first caller to register a load for a key becomes the LEADER
// (it must run the user future), every later caller is a WAITER (it just
// awaits the leader's result).
const LEADER: bool = true;
const WAITER: bool = !LEADER;

// Oneshot endpoints used in the cache protocol:
// `Checker` — cache replies `Ok(value)` on a hit, `Err(is_leader)` on a miss.
type Checker<V> = Sender<Result<Arc<V>, bool>>;
// `Loader` — receiving end (held by the cache) of the value produced by the leader.
type Loader<V, E> = Receiver<Result<Arc<V>, E>>;
// `Waiter` — delivers the final result to a caller; `Ok(None)` means a plain-get miss.
type Waiter<V, E> = Sender<Result<Option<Arc<V>>, E>>;
24
// Requests sent from `ReactorCache` handles to the `Inner` task over the
// std mpsc channel; each variant carries the oneshot sender(s) used to reply.
enum Message<K, V, E> {
    // Request a snapshot of cache statistics.
    Stats(Sender<CacheStats>),
    // Look up a key. The bool selects whether to wait for an in-flight load
    // on a miss (true for `get`, false for `get_if_resident`).
    Get(K, bool, Waiter<V, E>),
    // Begin (or join) a load for a key; see the `Checker`/`Loader`/`Waiter`
    // type aliases for the roles of the three channel endpoints.
    Load(K, Checker<V>, Loader<V, E>, Waiter<V, E>),
    // Drop the key from the cache and any in-flight fetch; ack via the sender.
    Evict(K, Sender<()>),
}
31
/// A value whose storage cost in the cache can be measured.
pub trait Weighted {
    /// Returns this value's weight in the same units as the cache capacity.
    fn weight(&self) -> usize;
}
35
/// Handle to a weight-bounded cache running as a task on a tokio-core
/// event loop. Cloning values out is cheap (`Arc`); every operation is a
/// message to the `Inner` task spawned by `new`.
pub struct ReactorCache<K, V, E> {
    // Channel to the `Inner` task driving the cache on the reactor.
    tx: mpsc::Sender<Message<K, V, E>>,
}
39
/// A point-in-time snapshot of cache occupancy.
#[derive(Debug)]
pub struct CacheStats {
    /// Number of entries currently resident.
    pub entries: usize,
    /// Unused capacity, in weight units.
    pub remaining: usize,
    /// Total capacity, in weight units.
    pub capacity: usize,
}
46
/// Future returned by `get`/`get_if_resident`: resolves to `Some(value)` on
/// a hit (or once an awaited in-flight load finishes) and `None` on a miss.
pub struct GetHandle<V, E> {
    rx: Receiver<Result<Option<Arc<V>>, E>>,
}
50
// State machine backing `LoadHandle`:
//   Checking -> hit: resolve | miss as leader: Loading | miss as waiter: Waiting
//   Loading  -> drives the user future, sends its result to the cache -> Waiting
//   Waiting  -> awaits the final value echoed back by the cache task
enum LoadState<F: Future, V, E> {
    Empty, // used for state transitions
    Checking(Receiver<Result<Arc<V>, bool>>, Fuse<F>, Sender<Result<Arc<V>, E>>, GetHandle<V, E>),
    Loading(Fuse<F>, Sender<Result<Arc<V>, E>>, GetHandle<V, E>),
    Waiting(GetHandle<V, E>),
}
57
/// Future returned by `load`/`load_fn`: resolves to the cached (or freshly
/// loaded) value. Only the "leader" handle for a key actually polls the
/// user-supplied future; concurrent loads for the same key share its result.
pub struct LoadHandle<F: Future, V, E> {
    state: LoadState<F, V, E>,
}
61
// A resident value plus its memoized weight and a second-chance eviction
// mark (`marked == true` means "reclaim on the next eviction pass").
struct CacheEntry<V> {
    inner: Arc<V>,
    weight: usize,
    marked: bool,
}
67
// The cache's state, driven as a future on the reactor (see `Future for Inner`).
struct Inner<K, V, E> {
    // Incoming requests from `ReactorCache` handles.
    rx: mpsc::Receiver<Message<K, V, E>>,
    // 10ms tick used to re-poll `rx`, since std mpsc has no async readiness.
    timer: PollEvented<Timer<()>>,
    // In-flight loads: result receiver plus the waiters to notify on completion.
    fetch_map: HashMap<K, (Loader<V, E>, Vec<Waiter<V, E>>)>,
    // Resident entries in recency order (front = coldest, back = hottest).
    cache_map: LinkedHashMap<K, CacheEntry<V>>,
    usage: (usize, usize), // (remaining, capacity)
}
75
76impl<V, E> Future for GetHandle<V, E> {
77    type Item = Option<Arc<V>>;
78    type Error = E;
79    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
80        match self.rx.poll().expect("get canceled") {
81            Async::Ready(Ok(res)) => Ok(res.into()),
82            Async::Ready(Err(e)) => Err(e),
83            Async::NotReady => Ok(Async::NotReady),
84        }
85    }
86}
87
impl<F, V, E: Clone> Future for LoadHandle<F, V, E>
    where F: Future<Item = V, Error = E>
{
    type Item = Arc<V>;
    type Error = E;
    // Advances the Checking -> Loading/Waiting -> resolved state machine.
    // Each `if let` consumes `state` and either returns, stores the state
    // back into `self` for the next poll, or falls through to the next stage
    // within this same poll.
    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
        trace!("loadhandle - start");

        // Take ownership of the state; `Empty` is only observable if we were
        // to return mid-transition, which every path below avoids.
        let mut state = ::std::mem::replace(&mut self.state, LoadState::Empty);

        if let LoadState::Checking(mut checker, loader, resolver, waiter) = state {
            trace!("loadhandle - checking");
            // `Err(bool)` encodes a miss; the bool says whether we lead the load.
            match checker.poll().expect("check canceled") {
                Async::Ready(Ok(res)) => {
                    // Cache hit: resolve immediately with the shared value.
                    trace!("loadhandle - hit");
                    return Ok(res.into());
                }
                Async::Ready(Err(LEADER)) => {
                    trace!("loadhandle - miss:leader");
                    state = LoadState::Loading(loader, resolver, waiter);
                }
                Async::Ready(Err(WAITER)) => {
                    trace!("loadhandle - miss:waiter");
                    state = LoadState::Waiting(waiter);
                }
                Async::NotReady => {
                    self.state = LoadState::Checking(checker, loader, resolver, waiter);
                    return Ok(Async::NotReady);
                }
            }
        }

        if let LoadState::Loading(mut loader, resolver, waiter) = state {
            trace!("loadhandle - loading");
            // We are the leader: drive the user's future and hand its result
            // (success or failure) to the cache task via `resolver`, then wait
            // for the cache to echo the value back like any other waiter.
            match loader.poll() {
                Ok(Async::Ready(res)) => {
                    trace!("loadhandle - success");
                    resolver.complete(Ok(Arc::new(res)));
                    state = LoadState::Waiting(waiter);
                }
                Err(e) => {
                    trace!("loadhandle - failure");
                    resolver.complete(Err(e));
                    state = LoadState::Waiting(waiter);
                }
                Ok(Async::NotReady) => {
                    self.state = LoadState::Loading(loader, resolver, waiter);
                    return Ok(Async::NotReady);
                }
            }
        }

        if let LoadState::Waiting(mut waiter) = state {
            trace!("loadhandle - waiting");
            return match waiter.poll() {
                Ok(Async::Ready(Some(res))) => {
                    trace!("loadhandle - ok");
                    Ok(res.into())
                }
                // Load waiters are only completed with `Some` (see
                // `Inner::upgrade_fetches`); `None` is reserved for plain gets.
                Ok(Async::Ready(None)) => unreachable!(),
                Ok(Async::NotReady) => {
                    self.state = LoadState::Waiting(waiter);
                    Ok(Async::NotReady)
                }
                Err(e) => {
                    trace!("loadhandle - err");
                    Err(e)
                }
            };
        }

        unreachable!("invalid state transition")
    }
}
162
/// Core methods for interacting with the cache
impl<K: Clone + Eq + Hash, V: Weighted, E: Clone + Debug> ReactorCache<K, V, E> {
    /// Creates a new Reactor Cache with the given capacity, which runs in `handle`'s event-loop
    pub fn new(capacity: usize, handle: Handle) -> Self
        where K: 'static,
              V: 'static,
              E: 'static
    {
        let (tx, rx) = mpsc::channel();
        // std mpsc has no reactor integration, so a 10ms tick timer is used
        // to periodically wake the Inner task and drain the channel.
        let mut timer = TimerBuilder::default().tick_duration(Duration::from_millis(10)).build();
        timer.set_timeout(Duration::from_millis(10), ()).unwrap();
        let poll = PollEvented::new(timer, &handle).unwrap();
        handle.spawn_fn(move || Inner::new(capacity, rx, poll).map_err(|e| panic!("{:?}", e)));
        ReactorCache { tx: tx }
    }

    /// Returns a future with a snapshot of the cache's stats. No guarantees are made about
    /// when the snapshot is taken.
    ///
    /// # Example
    ///
    /// Run a future that logs the cache stats:
    ///
    /// ```
    /// # extern crate futures;
    /// # extern crate reactor_cache;
    /// # extern crate tokio_core;
    /// #
    /// # use futures::Future;
    /// # use reactor_cache::*;
    /// # use tokio_core::reactor::Core;
    /// #
    /// # #[derive(Clone, Eq, Hash, PartialEq)] struct Int(i64);
    /// # impl Weighted for Int { fn weight(&self) -> usize { 8 } }
    /// #
    /// # fn main() {
    ///     let mut core = Core::new().expect("meltdown");
    ///     let cache = ReactorCache::<Int, Int, ()>::new(10, core.handle());
    ///     core.run(cache.stats().map(|s| println!("{:?}",s))).unwrap();
    /// # }
    /// ```
    pub fn stats(&self) -> Receiver<CacheStats> {
        let (tx, rx) = oneshot::channel();
        self.tx.send(Message::Stats(tx)).unwrap();
        rx
    }

    /// Returns a future resolving to the cached value for `k`, or `None` on
    /// a miss. If a load for `k` is in flight, the returned future waits for
    /// it to finish and yields that load's value (or error).
    pub fn get(&self, k: K) -> GetHandle<V, E> {
        let (tx, rx) = oneshot::channel();
        self.tx.send(Message::Get(k, true, tx)).unwrap();
        GetHandle { rx: rx }
    }

    /// Like `get`, but never waits on an in-flight load: resolves to `None`
    /// unless the value is already resident in the cache.
    pub fn get_if_resident(&self, k: K) -> GetHandle<V, E> {
        let (tx, rx) = oneshot::channel();
        self.tx.send(Message::Get(k, false, tx)).unwrap();
        GetHandle { rx: rx }
    }

    /// Convenience wrapper around `load` that lazily runs `f` to produce the
    /// value; `f` is only invoked if this handle ends up leading the load.
    pub fn load_fn<F, T>(&self, k: K, f: F) -> LoadHandle<futures::Lazy<F, T>, V, E>
        where F: 'static + Send + FnOnce() -> T,
              T: 'static + IntoFuture<Item = V, Error = E>,
              T::Future: 'static + Send
    {
        self.load(k, futures::lazy(f))
    }

    /// Returns a future resolving to the value for `k`, loading it with `f`
    /// on a miss. Concurrent loads for the same key are coalesced: only the
    /// first ("leader") runs its future; the rest share the leader's result.
    /// Errors are delivered to all pending loads and are not cached.
    pub fn load<F>(&self, k: K, f: F) -> LoadHandle<F, V, E>
        where F: Future<Item = V, Error = E>
    {
        // Three oneshots: hit/miss check, load result delivery to the cache,
        // and the final value echoed back to this handle.
        let (check_tx, check_rx) = oneshot::channel();
        let (load_tx, load_rx) = oneshot::channel();
        let (get_tx, get_rx) = oneshot::channel();
        self.tx.send(Message::Load(k, check_tx, load_rx, get_tx)).unwrap();

        let state = LoadState::Checking(check_rx, f.fuse(), load_tx, GetHandle { rx: get_rx });
        LoadHandle { state: state }
    }

    /// Removes `k` from the cache (and abandons any in-flight fetch for it).
    /// The returned future resolves once the eviction has been applied.
    pub fn evict(&self, k: K) -> Receiver<()> {
        let (tx, rx) = oneshot::channel();
        self.tx.send(Message::Evict(k, tx)).unwrap();
        rx
    }
}
272
273impl<V: Weighted> CacheEntry<V> {
274    fn new(v: Arc<V>) -> Self {
275        CacheEntry {
276            weight: v.weight(),
277            inner: v,
278            marked: false,
279        }
280    }
281}
282
283// TODO - remove req of K: Clone
// TODO - remove req of K: Clone
impl<K: Clone + Eq + Hash, V: Weighted, E: Clone> Inner<K, V, E> {
    // Builds an empty cache; `usage` starts at (capacity, capacity), i.e.
    // everything remaining.
    fn new(capacity: usize,
           rx: mpsc::Receiver<Message<K, V, E>>,
           timer: PollEvented<Timer<()>>)
           -> Self {
        Inner {
            rx: rx,
            timer: timer,
            fetch_map: HashMap::new(),
            cache_map: LinkedHashMap::new(),
            usage: (capacity, capacity),
        }
    }

    // Polls every in-flight load; finished ones are removed from `fetch_map`,
    // their waiters are notified, and successful values are inserted into the
    // cache. An `Err` from the receiver means the leader's handle was dropped
    // (oneshot canceled) — its waiters are silently dropped too.
    fn upgrade_fetches(&mut self) -> Result<(), ()> {
        trace!("upgrade -- start");
        if self.fetch_map.is_empty() {
            trace!("upgrade -- empty");
            return Ok(());
        }

        // Collect finished keys first; we can't mutate the map mid-iteration.
        let mut to_upgrade = vec![];
        for (k, &mut (ref mut f, _)) in self.fetch_map.iter_mut() {
            match f.poll() {
                Ok(Async::Ready(r)) => to_upgrade.push((k.clone(), Some(r))),
                Ok(Async::NotReady) => continue,
                Err(_) => to_upgrade.push((k.clone(), None)),
            };
        }

        for (k, r_opt) in to_upgrade.into_iter() {
            let (_, waiters) = self.fetch_map.remove(&k).unwrap();
            if let Some(r) = r_opt {
                for waiter in waiters.into_iter() {
                    trace!("upgrade -- waiter");
                    waiter.complete(r.clone().map(Some));
                }

                // Only successful loads are cached; errors are not retained.
                if let Ok(v) = r {
                    self.try_cache(k, v);
                }
            }
        }

        trace!("upgrade -- end");
        Ok(())
    }

    // Inserts `v` into the cache, evicting older entries with a second-chance
    // (CLOCK-like) sweep until enough weight is free. Entries at least as
    // heavy as the whole capacity are rejected outright.
    fn try_cache(&mut self, k: K, v: Arc<V>) {
        trace!("trycache -- start");
        let (ref mut remaining, capacity) = self.usage;

        let entry = CacheEntry::new(v);
        if entry.weight >= capacity {
            trace!("trycache -- toobig");
            return;
        }

        loop {
            // NOTE(review): the empty-map arm relies on the invariant that an
            // empty cache implies `remaining == capacity`, so the subtraction
            // below cannot underflow — every eviction path restores weight.
            if self.cache_map.is_empty() || *remaining >= entry.weight {
                *remaining -= entry.weight;
                self.cache_map.insert(k, entry);
                break;
            }

            // Second-chance sweep: pop the coldest entry; if it was already
            // marked, reclaim its weight; otherwise mark it and send it to
            // the back (hot end) for one more round.
            let (k2, mut v2) = self.cache_map.pop_front().expect("cache should be non-empty");
            if v2.marked {
                *remaining += v2.weight;
            } else {
                v2.marked = true;
                self.cache_map.insert(k2, v2);
            }
        }
        trace!("trycache -- end");
    }

    // Dispatches one request message to the matching handler.
    fn handle(&mut self, msg: Message<K, V, E>) -> Result<(), ()> {
        trace!("handle -- start");
        match msg {
            Message::Stats(tx) => self.stats(tx),
            Message::Get(k, w, tx) => self.get(k, w, tx),
            Message::Load(k, ck, rx, tx) => self.load(k, ck, rx, tx),
            Message::Evict(k, tx) => self.evict(k, tx),
        };
        trace!("handle -- end");
        Ok(())
    }

    // Replies with a snapshot of the current occupancy counters.
    fn stats(&mut self, tx: Sender<CacheStats>) {
        trace!("stats -- start");
        let (remaining, capacity) = self.usage;
        tx.complete(CacheStats {
            entries: self.cache_map.len(),
            remaining: remaining,
            capacity: capacity,
        });
        trace!("stats -- end");
    }

    // Resolves a lookup: a hit refreshes the entry to the hot end and clears
    // its eviction mark; on a miss, `wait == true` parks the caller on any
    // in-flight load, otherwise it gets `Ok(None)` right away.
    fn get(&mut self, k: K, wait: bool, tx: Waiter<V, E>) {
        trace!("get -- start");
        if let Some(mut entry) = self.cache_map.get_refresh(&k) {
            entry.marked = false;
            trace!("get -- hit");
            return tx.complete(Ok(Some(entry.inner.clone())));
        }

        if wait {
            if let Some(&mut (_, ref mut waiters)) = self.fetch_map.get_mut(&k) {
                trace!("get -- wait");
                return waiters.push(tx);
            }
        }

        trace!("get -- miss");
        tx.complete(Ok(None));
    }

    // Answers a load request: a hit short-circuits through the checker; on a
    // miss the caller is registered in `fetch_map`, and the checker tells it
    // whether it is the leader (first registrant) or a waiter.
    fn load(&mut self, k: K, checker: Checker<V>, f: Loader<V, E>, tx: Waiter<V, E>) {
        trace!("load -- start");

        if let Some(mut entry) = self.cache_map.get_refresh(&k) {
            trace!("load -- hit");
            entry.marked = false;
            return checker.complete(Ok(entry.inner.clone()));
        }
        trace!("load -- miss");

        let &mut (_, ref mut waiters) = self.fetch_map.entry(k).or_insert((f, vec![]));
        checker.complete(Err(waiters.is_empty())); // if there are no waiters, we're the leader
        waiters.push(tx);

        trace!("load -- end");
    }

    // Drops any in-flight fetch and resident entry for `k`, returning the
    // entry's weight to the remaining budget, then acks the caller.
    fn evict(&mut self, k: K, tx: Sender<()>) {
        trace!("evict -- start");
        self.fetch_map.remove(&k);
        if let Some(entry) = self.cache_map.remove(&k) {
            self.usage.0 += entry.weight;
        }
        tx.complete(());
        trace!("evict -- end");
    }
}
429
// The cache task itself: woken by the 10ms timer, it promotes finished
// loads, then drains all pending request messages. It completes only when
// every `ReactorCache` handle (mpsc sender) has been dropped.
impl<K: Clone + Eq + Hash, V: Weighted, E: Clone> Future for Inner<K, V, E> {
    type Item = ();
    type Error = (); // TODO - E doesn't work out of the box...

    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
        trace!("poll -- start");
        if let Async::NotReady = self.timer.poll_read() {
            trace!("poll -- not ready");
            return Ok(Async::NotReady);
        }

        // remove scheduled bits and schedule Inner for the next loop
        while let Some(_) = self.timer.get_mut().poll() {}
        self.timer.need_read();

        self.upgrade_fetches()?;

        loop {
            match self.rx.try_recv() {
                Ok(msg) => self.handle(msg)?,
                Err(TryRecvError::Empty) => {
                    trace!("poll -- end");
                    // Nothing queued: arm the next 10ms wakeup and yield.
                    self.timer.get_mut().set_timeout(Duration::from_millis(10), ()).unwrap();
                    return Ok(Async::NotReady);
                }
                Err(TryRecvError::Disconnected) => {
                    // All handles dropped: the cache task finishes cleanly.
                    trace!("poll -- terminate");
                    return Ok(().into());
                }
            }
        }
    }
}
463
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use tokio_core::reactor::Core;

    // Fixed-weight test value: every i64 costs 8 units.
    impl Weighted for i64 {
        fn weight(&self) -> usize {
            8
        }
    }

    // Test value whose weight equals its payload, for eviction-order tests.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct Scale(usize);

    impl Weighted for Scale {
        fn weight(&self) -> usize {
            self.0
        }
    }

    #[test]
    fn basic_cache_ops() {
        let mut core = Core::new().unwrap();
        let cache = ReactorCache::<i64, i64, ()>::new(16, core.handle());

        // insert two entries to fill the cache
        assert_eq!(10, *core.run(cache.load_fn(1, || Ok(10))).unwrap());
        assert_eq!(20, *core.run(cache.load_fn(2, || Ok(20))).unwrap());

        let stats = core.run(cache.stats()).unwrap();
        assert_eq!(2, stats.entries);
        assert_eq!(0, stats.remaining);

        // evict (1, 10) and insert (3, 30)
        assert_eq!(30, *core.run(cache.load_fn(3, || Ok(30))).unwrap());
        assert_eq!(None, core.run(cache.get(1)).unwrap());
        assert_eq!(20, *core.run(cache.get(2)).unwrap().unwrap());

        let stats = core.run(cache.stats()).unwrap();
        assert_eq!(2, stats.entries);
        assert_eq!(0, stats.remaining);

        // evict (3, 30)
        assert_eq!((), core.run(cache.evict(3)).unwrap());
        assert_eq!(None, core.run(cache.get(1)).unwrap());

        let stats = core.run(cache.stats()).unwrap();
        assert_eq!(1, stats.entries);
        assert_eq!(8, stats.remaining);
    }

    #[test]
    fn waiters() {
        let mut core = Core::new().unwrap();
        let cache = ReactorCache::<i64, i64, i64>::new(16, core.handle());

        let counter = Arc::new(AtomicUsize::new(10));
        let c1 = counter.clone();
        let c2 = counter.clone();

        // Two loads plus a get for the same key: only the leader's closure
        // should run, so the counter is bumped exactly once (10 -> 11).
        let l1 = cache.load_fn(1, move || Ok(c1.fetch_add(1, Ordering::SeqCst) as i64));
        let g1 = cache.get(1);
        let l2 = cache.load_fn(1, move || Ok(c2.fetch_add(1, Ordering::SeqCst) as i64));

        assert_eq!(10, *core.run(l1).unwrap());
        assert_eq!(10, *core.run(g1).unwrap().unwrap());
        assert_eq!(10, *core.run(l2).unwrap());
        assert_eq!(11, counter.load(Ordering::SeqCst));
    }

    #[test]
    fn get_if_resident() {
        let mut core = Core::new().unwrap();
        let cache = ReactorCache::<i64, i64, i64>::new(16, core.handle());

        let l1 = cache.load_fn(1, || Ok(10));
        let g1 = cache.get_if_resident(1);
        let g2 = cache.get_if_resident(1);

        // because of how events are processed, both gets are misses
        assert_eq!(None, core.run(g1).unwrap());
        assert_eq!(10, *core.run(l1).unwrap());
        assert_eq!(None, core.run(g2).unwrap());

        // since the load has resolved, this will be a cache hit
        let g3 = cache.get_if_resident(1);
        assert_eq!(10, *core.run(g3).unwrap().unwrap());
    }

    #[test]
    fn errors() {
        let mut core = Core::new().unwrap();
        let cache = ReactorCache::<i64, i64, i64>::new(16, core.handle());

        // errors should not be cached
        assert!(core.run(cache.load_fn(1, || Err(10))).is_err());
        assert_eq!(None, core.run(cache.get(1)).unwrap());
        assert!(core.run(cache.load_fn(1, || Ok(10))).is_ok());
        assert_eq!(10, *core.run(cache.get(1)).unwrap().unwrap());
    }

    #[test]
    #[should_panic]
    #[allow(unreachable_code)]
    fn panic() {
        // A panicking loader should propagate as a panic, not hang waiters.
        let mut core = Core::new().unwrap();
        let cache = ReactorCache::<i64, i64, i64>::new(16, core.handle());
        assert!(core.run(cache.load_fn(1, || Ok(panic!()))).is_err());
    }

    #[test]
    fn lru_and_marking() {
        let mut core = Core::new().unwrap();
        let cache: ReactorCache<i64, Scale, ()> = ReactorCache::new(16, core.handle());

        // load two entries and touch the first one
        assert_eq!(Scale(8),
                   *core.run(cache.load_fn(1, || Ok(Scale(8)))).unwrap());
        assert_eq!(Scale(8),
                   *core.run(cache.load_fn(2, || Ok(Scale(8)))).unwrap());
        assert_eq!(Scale(8), *core.run(cache.get(1)).unwrap().unwrap());

        // loading a third entry pushes out key 2 — the least recently used,
        // since the get above refreshed key 1
        assert_eq!(Scale(8),
                   *core.run(cache.load_fn(3, || Ok(Scale(8)))).unwrap());
        assert_eq!(None, core.run(cache.get(2)).unwrap());
    }
}