// stretto/policy/async.rs

1use crate::axync::{select, stop_channel, unbounded, Receiver, RecvError, Sender};
2use crate::policy::PolicyInner;
3use crate::{CacheError, MetricType, Metrics};
4use futures::future::{BoxFuture, FutureExt};
5use parking_lot::Mutex;
6use std::collections::hash_map::RandomState;
7use std::hash::BuildHasher;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::Arc;
10
/// Async LFU policy front-end.
///
/// Batches of keys passed to `push` are forwarded over `items_tx` to a background
/// `PolicyProcessor` task (started in `with_hasher`), which applies them to the shared
/// `inner` state.
///
/// NOTE: fields are dropped in declaration order, so when the policy is dropped without
/// `close`, `items_tx` is closed before `stop_tx`.
pub(crate) struct AsyncLFUPolicy<S = RandomState> {
    /// Policy state, shared via `Arc` with the background processor task.
    pub(crate) inner: Arc<Mutex<PolicyInner<S>>>,
    /// Forwards batches of keys from `push` to the processor task.
    pub(crate) items_tx: Sender<Vec<u64>>,
    /// Stop signal for the processor task, sent by `close`.
    pub(crate) stop_tx: Sender<()>,
    /// Closed flag: checked by `push` (which returns `Ok(false)` once it is set) and by `close`.
    pub(crate) is_closed: AtomicBool,
    /// Counters updated by `push` (`KeepGets` / `DropGets`); created fresh in `with_hasher`.
    pub(crate) metrics: Arc<Metrics>,
}
18
19impl AsyncLFUPolicy {
20    #[inline]
21    pub(crate) fn new<SP, R>(ctrs: usize, max_cost: i64, spawner: SP) -> Result<Self, CacheError>
22    where
23        SP: Fn(BoxFuture<'static, ()>) -> R + Send + Sync + 'static,
24    {
25        Self::with_hasher(ctrs, max_cost, RandomState::new(), spawner)
26    }
27}
28
29impl<S: BuildHasher + Clone + 'static + Send> AsyncLFUPolicy<S> {
30    #[inline]
31    pub fn with_hasher<SP, R>(
32        ctrs: usize,
33        max_cost: i64,
34        hasher: S,
35        spawner: SP,
36    ) -> Result<Self, CacheError>
37    where
38        SP: Fn(BoxFuture<'static, ()>) -> R + Send + Sync + 'static,
39    {
40        let inner = PolicyInner::with_hasher(ctrs, max_cost, hasher)?;
41
42        let (items_tx, items_rx) = unbounded();
43        let (stop_tx, stop_rx) = stop_channel();
44
45        PolicyProcessor::new(inner.clone(), items_rx, stop_rx).spawn(Box::new(move |fut| {
46            spawner(fut);
47        }));
48
49        let this = Self {
50            inner,
51            items_tx,
52            stop_tx,
53            is_closed: AtomicBool::new(false),
54            metrics: Arc::new(Metrics::new()),
55        };
56
57        Ok(this)
58    }
59
60    pub async fn push(&self, keys: Vec<u64>) -> Result<bool, CacheError> {
61        if self.is_closed.load(Ordering::SeqCst) {
62            return Ok(false);
63        }
64        let num_of_keys = keys.len() as u64;
65        if num_of_keys == 0 {
66            return Ok(true);
67        }
68        let first = keys[0];
69
70        select! {
71            rst =  self.items_tx.send(keys).fuse() => rst.map(|_| {
72                self.metrics.add(MetricType::KeepGets, first, num_of_keys);
73                true
74            })
75            .map_err(|e| {
76                self.metrics.add(MetricType::DropGets, first, num_of_keys);
77                CacheError::SendError(format!("sending on a disconnected channel, msg: {:?}", e))
78            }),
79            default => {
80                self.metrics.add(MetricType::DropGets, first, num_of_keys);
81                Ok(false)
82            }
83        }
84    }
85
86    #[inline]
87    pub async fn close(&self) -> Result<(), CacheError> {
88        if self.is_closed.load(Ordering::SeqCst) {
89            return Ok(());
90        }
91
92        // block until the Processor thread returns.
93        self.stop_tx
94            .send(())
95            .await
96            .map_err(|e| CacheError::SendError(format!("{}", e)))?;
97        self.is_closed.store(true, Ordering::SeqCst);
98        Ok(())
99    }
100}
101
/// Background worker for `AsyncLFUPolicy`.
///
/// Receives key batches on `items_rx` and applies them to the shared policy state. It also
/// listens for a stop signal on `stop_rx`. The loop itself lives in `spawn`.
pub(crate) struct PolicyProcessor<S> {
    /// Shared policy state; the same `Arc` is held by the owning `AsyncLFUPolicy`.
    inner: Arc<Mutex<PolicyInner<S>>>,
    /// Receives batches of keys sent by `AsyncLFUPolicy::push`.
    items_rx: Receiver<Vec<u64>>,
    /// Receives the stop signal sent by `AsyncLFUPolicy::close`.
    stop_rx: Receiver<()>,
}
107
108impl<S: BuildHasher + Clone + 'static + Send> PolicyProcessor<S> {
109    #[inline]
110    fn new(
111        inner: Arc<Mutex<PolicyInner<S>>>,
112        items_rx: Receiver<Vec<u64>>,
113        stop_rx: Receiver<()>,
114    ) -> Self {
115        Self {
116            inner,
117            items_rx,
118            stop_rx,
119        }
120    }
121
122    #[inline]
123    fn spawn(self, spawner: Box<dyn Fn(BoxFuture<'static, ()>) + Send + Sync>) {
124        (spawner)(Box::pin(async move {
125            loop {
126                select! {
127                    items = self.items_rx.recv().fuse() => self.handle_items(items),
128                    _ = self.stop_rx.recv().fuse() => {
129                        drop(self);
130                        return;
131                    },
132                }
133            }
134        }))
135    }
136
137    // TODO: None handle
138    #[inline]
139    fn handle_items(&self, items: Result<Vec<u64>, RecvError>) {
140        match items {
141            Ok(items) => {
142                let mut inner = self.inner.lock();
143                inner.admit.increments(items);
144            }
145            Err(_) => {
146                // error!("policy processor error")
147            }
148        }
149    }
150}
151
152unsafe impl<S: BuildHasher + Clone + 'static + Send> Send for PolicyProcessor<S> {}
153unsafe impl<S: BuildHasher + Clone + 'static + Send + Sync> Sync for PolicyProcessor<S> {}
154
155impl_policy!(AsyncLFUPolicy);