1use std::collections::HashMap;
2use std::hash::Hash;
3use std::fmt::Debug;
4use std::sync::Arc;
5use std::sync::mpsc::{self, TryRecvError};
6use std::time::Duration;
7
8use futures::{self, Async, Fuse, Future, IntoFuture, Poll};
9use futures::sync::oneshot::{self, Sender, Receiver};
10
11use linked_hash_map::LinkedHashMap;
12
13use mio::timer::{Builder as TimerBuilder, Timer};
14
15use tokio_core::reactor::{Handle, PollEvented};
16
/// Reported to a load on a miss when it is the first requester and must run its own loader.
const LEADER: bool = true;
/// Reported to a load on a miss when another request is already loading the same key.
const WAITER: bool = !LEADER;
20
21type Checker<V> = Sender<Result<Arc<V>, bool>>;
22type Loader<V, E> = Receiver<Result<Arc<V>, E>>;
23type Waiter<V, E> = Sender<Result<Option<Arc<V>>, E>>;
24
25enum Message<K, V, E> {
26 Stats(Sender<CacheStats>),
27 Get(K, bool, Waiter<V, E>),
28 Load(K, Checker<V>, Loader<V, E>, Waiter<V, E>),
29 Evict(K, Sender<()>),
30}
31
/// Values that can report how much of a cache's capacity they consume.
pub trait Weighted {
    /// Cost of this value against the cache's total capacity.
    fn weight(&self) -> usize;
}
35
36pub struct ReactorCache<K, V, E> {
37 tx: mpsc::Sender<Message<K, V, E>>,
38}
39
/// Snapshot of cache occupancy, as returned by `ReactorCache::stats`.
#[derive(Debug)]
pub struct CacheStats {
    /// Number of resident entries.
    pub entries: usize,
    /// Weight budget not currently used by resident entries.
    pub remaining: usize,
    /// Total weight budget the cache was created with.
    pub capacity: usize,
}
46
47pub struct GetHandle<V, E> {
48 rx: Receiver<Result<Option<Arc<V>>, E>>,
49}
50
51enum LoadState<F: Future, V, E> {
52 Empty, Checking(Receiver<Result<Arc<V>, bool>>, Fuse<F>, Sender<Result<Arc<V>, E>>, GetHandle<V, E>),
54 Loading(Fuse<F>, Sender<Result<Arc<V>, E>>, GetHandle<V, E>),
55 Waiting(GetHandle<V, E>),
56}
57
58pub struct LoadHandle<F: Future, V, E> {
59 state: LoadState<F, V, E>,
60}
61
/// A resident value together with its eviction bookkeeping.
struct CacheEntry<V> {
    /// Shared value handed out to readers.
    inner: Arc<V>,
    /// Weight charged against capacity, captured once at insertion.
    weight: usize,
    /// Set when eviction passes over the entry; cleared whenever it is accessed.
    /// A marked entry is removed the next time eviction reaches it.
    marked: bool,
}
67
68struct Inner<K, V, E> {
69 rx: mpsc::Receiver<Message<K, V, E>>,
70 timer: PollEvented<Timer<()>>,
71 fetch_map: HashMap<K, (Loader<V, E>, Vec<Waiter<V, E>>)>,
72 cache_map: LinkedHashMap<K, CacheEntry<V>>,
73 usage: (usize, usize), }
75
76impl<V, E> Future for GetHandle<V, E> {
77 type Item = Option<Arc<V>>;
78 type Error = E;
79 fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
80 match self.rx.poll().expect("get canceled") {
81 Async::Ready(Ok(res)) => Ok(res.into()),
82 Async::Ready(Err(e)) => Err(e),
83 Async::NotReady => Ok(Async::NotReady),
84 }
85 }
86}
87
88impl<F, V, E: Clone> Future for LoadHandle<F, V, E>
89 where F: Future<Item = V, Error = E>
90{
91 type Item = Arc<V>;
92 type Error = E;
93 fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
94 trace!("loadhandle - start");
95
96 let mut state = ::std::mem::replace(&mut self.state, LoadState::Empty);
97
98 if let LoadState::Checking(mut checker, loader, resolver, waiter) = state {
99 trace!("loadhandle - checking");
100 match checker.poll().expect("check canceled") {
101 Async::Ready(Ok(res)) => {
102 trace!("loadhandle - hit");
103 return Ok(res.into());
104 }
105 Async::Ready(Err(LEADER)) => {
106 trace!("loadhandle - miss:leader");
107 state = LoadState::Loading(loader, resolver, waiter);
108 }
109 Async::Ready(Err(WAITER)) => {
110 trace!("loadhandle - miss:waiter");
111 state = LoadState::Waiting(waiter);
112 }
113 Async::NotReady => {
114 self.state = LoadState::Checking(checker, loader, resolver, waiter);
115 return Ok(Async::NotReady);
116 }
117 }
118 }
119
120 if let LoadState::Loading(mut loader, resolver, waiter) = state {
121 trace!("loadhandle - loading");
122 match loader.poll() {
123 Ok(Async::Ready(res)) => {
124 trace!("loadhandle - success");
125 resolver.complete(Ok(Arc::new(res)));
126 state = LoadState::Waiting(waiter);
127 }
128 Err(e) => {
129 trace!("loadhandle - failure");
130 resolver.complete(Err(e));
131 state = LoadState::Waiting(waiter);
132 }
133 Ok(Async::NotReady) => {
134 self.state = LoadState::Loading(loader, resolver, waiter);
135 return Ok(Async::NotReady);
136 }
137 }
138 }
139
140 if let LoadState::Waiting(mut waiter) = state {
141 trace!("loadhandle - waiting");
142 return match waiter.poll() {
143 Ok(Async::Ready(Some(res))) => {
144 trace!("loadhandle - ok");
145 Ok(res.into())
146 }
147 Ok(Async::Ready(None)) => unreachable!(),
148 Ok(Async::NotReady) => {
149 self.state = LoadState::Waiting(waiter);
150 Ok(Async::NotReady)
151 }
152 Err(e) => {
153 trace!("loadhandle - err");
154 Err(e)
155 }
156 };
157 }
158
159 unreachable!("invalid state transition")
160 }
161}
162
163impl<K: Clone + Eq + Hash, V: Weighted, E: Clone + Debug> ReactorCache<K, V, E> {
165 pub fn new(capacity: usize, handle: Handle) -> Self
167 where K: 'static,
168 V: 'static,
169 E: 'static
170 {
171 let (tx, rx) = mpsc::channel();
172 let mut timer = TimerBuilder::default().tick_duration(Duration::from_millis(10)).build();
173 timer.set_timeout(Duration::from_millis(10), ()).unwrap();
174 let poll = PollEvented::new(timer, &handle).unwrap();
175 handle.spawn_fn(move || Inner::new(capacity, rx, poll).map_err(|e| panic!("{:?}", e)));
176 ReactorCache { tx: tx }
177 }
178
179 pub fn stats(&self) -> Receiver<CacheStats> {
204 let (tx, rx) = oneshot::channel();
205 self.tx.send(Message::Stats(tx)).unwrap();
206 rx
207 }
208
209 pub fn get(&self, k: K) -> GetHandle<V, E> {
235 let (tx, rx) = oneshot::channel();
236 self.tx.send(Message::Get(k, true, tx)).unwrap();
237 GetHandle { rx: rx }
238 }
239
240 pub fn get_if_resident(&self, k: K) -> GetHandle<V, E> {
241 let (tx, rx) = oneshot::channel();
242 self.tx.send(Message::Get(k, false, tx)).unwrap();
243 GetHandle { rx: rx }
244 }
245
246 pub fn load_fn<F, T>(&self, k: K, f: F) -> LoadHandle<futures::Lazy<F, T>, V, E>
247 where F: 'static + Send + FnOnce() -> T,
248 T: 'static + IntoFuture<Item = V, Error = E>,
249 T::Future: 'static + Send
250 {
251 self.load(k, futures::lazy(f))
252 }
253
254 pub fn load<F>(&self, k: K, f: F) -> LoadHandle<F, V, E>
255 where F: Future<Item = V, Error = E>
256 {
257 let (check_tx, check_rx) = oneshot::channel();
258 let (load_tx, load_rx) = oneshot::channel();
259 let (get_tx, get_rx) = oneshot::channel();
260 self.tx.send(Message::Load(k, check_tx, load_rx, get_tx)).unwrap();
261
262 let state = LoadState::Checking(check_rx, f.fuse(), load_tx, GetHandle { rx: get_rx });
263 LoadHandle { state: state }
264 }
265
266 pub fn evict(&self, k: K) -> Receiver<()> {
267 let (tx, rx) = oneshot::channel();
268 self.tx.send(Message::Evict(k, tx)).unwrap();
269 rx
270 }
271}
272
273impl<V: Weighted> CacheEntry<V> {
274 fn new(v: Arc<V>) -> Self {
275 CacheEntry {
276 weight: v.weight(),
277 inner: v,
278 marked: false,
279 }
280 }
281}
282
283impl<K: Clone + Eq + Hash, V: Weighted, E: Clone> Inner<K, V, E> {
285 fn new(capacity: usize,
286 rx: mpsc::Receiver<Message<K, V, E>>,
287 timer: PollEvented<Timer<()>>)
288 -> Self {
289 Inner {
290 rx: rx,
291 timer: timer,
292 fetch_map: HashMap::new(),
293 cache_map: LinkedHashMap::new(),
294 usage: (capacity, capacity),
295 }
296 }
297
298 fn upgrade_fetches(&mut self) -> Result<(), ()> {
299 trace!("upgrade -- start");
300 if self.fetch_map.is_empty() {
301 trace!("upgrade -- empty");
302 return Ok(());
303 }
304
305 let mut to_upgrade = vec![];
306 for (k, &mut (ref mut f, _)) in self.fetch_map.iter_mut() {
307 match f.poll() {
308 Ok(Async::Ready(r)) => to_upgrade.push((k.clone(), Some(r))),
309 Ok(Async::NotReady) => continue,
310 Err(_) => to_upgrade.push((k.clone(), None)),
311 };
312 }
313
314 for (k, r_opt) in to_upgrade.into_iter() {
315 let (_, waiters) = self.fetch_map.remove(&k).unwrap();
316 if let Some(r) = r_opt {
317 for waiter in waiters.into_iter() {
318 trace!("upgrade -- waiter");
319 waiter.complete(r.clone().map(Some));
320 }
321
322 if let Ok(v) = r {
323 self.try_cache(k, v);
324 }
325 }
326 }
327
328 trace!("upgrade -- end");
329 Ok(())
330 }
331
332 fn try_cache(&mut self, k: K, v: Arc<V>) {
333 trace!("trycache -- start");
334 let (ref mut remaining, capacity) = self.usage;
335
336 let entry = CacheEntry::new(v);
337 if entry.weight >= capacity {
338 trace!("trycache -- toobig");
339 return;
340 }
341
342 loop {
343 if self.cache_map.is_empty() || *remaining >= entry.weight {
344 *remaining -= entry.weight;
345 self.cache_map.insert(k, entry);
346 break;
347 }
348
349 let (k2, mut v2) = self.cache_map.pop_front().expect("cache should be non-empty");
350 if v2.marked {
351 *remaining += v2.weight;
352 } else {
353 v2.marked = true;
354 self.cache_map.insert(k2, v2);
355 }
356 }
357 trace!("trycache -- end");
358 }
359
360 fn handle(&mut self, msg: Message<K, V, E>) -> Result<(), ()> {
361 trace!("handle -- start");
362 match msg {
363 Message::Stats(tx) => self.stats(tx),
364 Message::Get(k, w, tx) => self.get(k, w, tx),
365 Message::Load(k, ck, rx, tx) => self.load(k, ck, rx, tx),
366 Message::Evict(k, tx) => self.evict(k, tx),
367 };
368 trace!("handle -- end");
369 Ok(())
370 }
371
372 fn stats(&mut self, tx: Sender<CacheStats>) {
373 trace!("stats -- start");
374 let (remaining, capacity) = self.usage;
375 tx.complete(CacheStats {
376 entries: self.cache_map.len(),
377 remaining: remaining,
378 capacity: capacity,
379 });
380 trace!("stats -- end");
381 }
382
383 fn get(&mut self, k: K, wait: bool, tx: Waiter<V, E>) {
384 trace!("get -- start");
385 if let Some(mut entry) = self.cache_map.get_refresh(&k) {
386 entry.marked = false;
387 trace!("get -- hit");
388 return tx.complete(Ok(Some(entry.inner.clone())));
389 }
390
391 if wait {
392 if let Some(&mut (_, ref mut waiters)) = self.fetch_map.get_mut(&k) {
393 trace!("get -- wait");
394 return waiters.push(tx);
395 }
396 }
397
398 trace!("get -- miss");
399 tx.complete(Ok(None));
400 }
401
402 fn load(&mut self, k: K, checker: Checker<V>, f: Loader<V, E>, tx: Waiter<V, E>) {
403 trace!("load -- start");
404
405 if let Some(mut entry) = self.cache_map.get_refresh(&k) {
406 trace!("load -- hit");
407 entry.marked = false;
408 return checker.complete(Ok(entry.inner.clone()));
409 }
410 trace!("load -- miss");
411
412 let &mut (_, ref mut waiters) = self.fetch_map.entry(k).or_insert((f, vec![]));
413 checker.complete(Err(waiters.is_empty())); waiters.push(tx);
415
416 trace!("load -- end");
417 }
418
419 fn evict(&mut self, k: K, tx: Sender<()>) {
420 trace!("evict -- start");
421 self.fetch_map.remove(&k);
422 if let Some(entry) = self.cache_map.remove(&k) {
423 self.usage.0 += entry.weight;
424 }
425 tx.complete(());
426 trace!("evict -- end");
427 }
428}
429
430impl<K: Clone + Eq + Hash, V: Weighted, E: Clone> Future for Inner<K, V, E> {
431 type Item = ();
432 type Error = (); fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
435 trace!("poll -- start");
436 if let Async::NotReady = self.timer.poll_read() {
437 trace!("poll -- not ready");
438 return Ok(Async::NotReady);
439 }
440
441 while let Some(_) = self.timer.get_mut().poll() {}
443 self.timer.need_read();
444
445 self.upgrade_fetches()?;
446
447 loop {
448 match self.rx.try_recv() {
449 Ok(msg) => self.handle(msg)?,
450 Err(TryRecvError::Empty) => {
451 trace!("poll -- end");
452 self.timer.get_mut().set_timeout(Duration::from_millis(10), ()).unwrap();
453 return Ok(Async::NotReady);
454 }
455 Err(TryRecvError::Disconnected) => {
456 trace!("poll -- terminate");
457 return Ok(().into());
458 }
459 }
460 }
461 }
462}
463
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use tokio_core::reactor::Core;

    // Every i64 costs 8 units, so a 16-unit cache holds exactly two of them.
    impl Weighted for i64 {
        fn weight(&self) -> usize {
            8
        }
    }

    /// Test value whose weight is the wrapped number.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct Scale(usize);

    impl Weighted for Scale {
        fn weight(&self) -> usize {
            self.0
        }
    }

    /// Loads fill the cache, a third load displaces the oldest key, and
    /// `evict` returns the entry's weight to the budget.
    #[test]
    fn basic_cache_ops() {
        let mut core = Core::new().unwrap();
        let cache = ReactorCache::<i64, i64, ()>::new(16, core.handle());

        let first = core.run(cache.load_fn(1, || Ok(10))).unwrap();
        assert_eq!(10, *first);
        let second = core.run(cache.load_fn(2, || Ok(20))).unwrap();
        assert_eq!(20, *second);

        let full = core.run(cache.stats()).unwrap();
        assert_eq!(2, full.entries);
        assert_eq!(0, full.remaining);

        // Key 3 pushes out key 1, the oldest entry.
        let third = core.run(cache.load_fn(3, || Ok(30))).unwrap();
        assert_eq!(30, *third);
        assert_eq!(None, core.run(cache.get(1)).unwrap());
        let survivor = core.run(cache.get(2)).unwrap().unwrap();
        assert_eq!(20, *survivor);

        let after_swap = core.run(cache.stats()).unwrap();
        assert_eq!(2, after_swap.entries);
        assert_eq!(0, after_swap.remaining);

        assert_eq!((), core.run(cache.evict(3)).unwrap());
        assert_eq!(None, core.run(cache.get(1)).unwrap());

        let after_evict = core.run(cache.stats()).unwrap();
        assert_eq!(1, after_evict.entries);
        assert_eq!(8, after_evict.remaining);
    }

    /// Concurrent requests for one key share a single loader invocation.
    #[test]
    fn waiters() {
        let mut core = Core::new().unwrap();
        let cache = ReactorCache::<i64, i64, i64>::new(16, core.handle());

        let counter = Arc::new(AtomicUsize::new(10));
        let (c1, c2) = (counter.clone(), counter.clone());

        // All three requests are queued before the reactor runs, so the first
        // load leads while the get and the second load wait on it.
        let l1 = cache.load_fn(1, move || Ok(c1.fetch_add(1, Ordering::SeqCst) as i64));
        let g1 = cache.get(1);
        let l2 = cache.load_fn(1, move || Ok(c2.fetch_add(1, Ordering::SeqCst) as i64));

        assert_eq!(10, *core.run(l1).unwrap());
        assert_eq!(10, *core.run(g1).unwrap().unwrap());
        assert_eq!(10, *core.run(l2).unwrap());
        assert_eq!(11, counter.load(Ordering::SeqCst));
    }

    /// `get_if_resident` never waits for an in-flight load.
    #[test]
    fn get_if_resident() {
        let mut core = Core::new().unwrap();
        let cache = ReactorCache::<i64, i64, i64>::new(16, core.handle());

        // Both probes are queued behind the load while it is still in flight.
        let l1 = cache.load_fn(1, || Ok(10));
        let g1 = cache.get_if_resident(1);
        let g2 = cache.get_if_resident(1);

        assert_eq!(None, core.run(g1).unwrap());
        assert_eq!(10, *core.run(l1).unwrap());
        assert_eq!(None, core.run(g2).unwrap());

        // Once the loaded value is cached, a fresh probe hits.
        let g3 = cache.get_if_resident(1);
        assert_eq!(10, *core.run(g3).unwrap().unwrap());
    }

    /// A failed load reaches the caller and leaves nothing cached.
    #[test]
    fn errors() {
        let mut core = Core::new().unwrap();
        let cache = ReactorCache::<i64, i64, i64>::new(16, core.handle());

        let failed = core.run(cache.load_fn(1, || Err(10)));
        assert!(failed.is_err());
        assert_eq!(None, core.run(cache.get(1)).unwrap());

        let succeeded = core.run(cache.load_fn(1, || Ok(10)));
        assert!(succeeded.is_ok());
        assert_eq!(10, *core.run(cache.get(1)).unwrap().unwrap());
    }

    /// A panicking loader propagates its panic to whoever runs the load.
    #[test]
    #[should_panic]
    #[allow(unreachable_code)]
    fn panic() {
        let mut core = Core::new().unwrap();
        let cache = ReactorCache::<i64, i64, i64>::new(16, core.handle());
        assert!(core.run(cache.load_fn(1, || Ok(panic!()))).is_err());
    }

    /// A recent `get` protects key 1, so loading key 3 evicts key 2 instead.
    #[test]
    fn lru_and_marking() {
        let mut core = Core::new().unwrap();
        let cache: ReactorCache<i64, Scale, ()> = ReactorCache::new(16, core.handle());

        assert_eq!(Scale(8), *core.run(cache.load_fn(1, || Ok(Scale(8)))).unwrap());
        assert_eq!(Scale(8), *core.run(cache.load_fn(2, || Ok(Scale(8)))).unwrap());
        // Touching key 1 moves it behind key 2 in eviction order.
        assert_eq!(Scale(8), *core.run(cache.get(1)).unwrap().unwrap());

        assert_eq!(Scale(8), *core.run(cache.load_fn(3, || Ok(Scale(8)))).unwrap());
        assert_eq!(None, core.run(cache.get(2)).unwrap());
    }
}
595}