groupcache/
groupcache_inner.rs

//! Core groupcache logic: cache lookups, deduplicated local/remote loading,
//! key removal and peer management.
2
3use crate::errors::InternalGroupcacheError::Anyhow;
4use crate::errors::{DedupedGroupcacheError, GroupcacheError, InternalGroupcacheError};
5use crate::groupcache::{GroupcachePeer, GroupcachePeerClient, ValueBounds, ValueLoader};
6use crate::metrics::{
7    METRIC_GET_TOTAL, METRIC_LOCAL_CACHE_HIT_TOTAL, METRIC_LOCAL_LOAD_ERROR_TOTAL,
8    METRIC_LOCAL_LOAD_TOTAL, METRIC_REMOTE_LOAD_ERROR, METRIC_REMOTE_LOAD_TOTAL,
9};
10use crate::options::Options;
11use crate::routing::{GroupcachePeerWithClient, RoutingState};
12use anyhow::{Context, Result};
13use groupcache_pb::GroupcacheClient;
14use groupcache_pb::{GetRequest, RemoveRequest};
15use metrics::counter;
16use moka::future::Cache;
17use singleflight_async::SingleFlight;
18use std::collections::HashSet;
19use std::net::SocketAddr;
20use std::sync::{Arc, RwLock};
21use tokio::task::JoinSet;
22use tonic::transport::Endpoint;
23use tonic::IntoRequest;
24
25/// Core implementation of groupcache API.
26pub struct GroupcacheInner<Value: ValueBounds> {
27    routing_state: Arc<RwLock<RoutingState>>,
28    single_flight_group: SingleFlight<String, Result<Value, DedupedGroupcacheError>>,
29    main_cache: Cache<String, Value>,
30    hot_cache: Cache<String, Value>,
31    loader: Box<dyn ValueLoader<Value = Value>>,
32    config: Config,
33    me: GroupcachePeer,
34}
35
36struct Config {
37    https: bool,
38    grpc_endpoint_builder: Arc<Box<dyn Fn(Endpoint) -> Endpoint + Send + Sync + 'static>>,
39}
40
41impl<Value: ValueBounds> GroupcacheInner<Value> {
42    pub(crate) fn new(
43        me: GroupcachePeer,
44        loader: Box<dyn ValueLoader<Value = Value>>,
45        options: Options<Value>,
46    ) -> Self {
47        let routing_state = Arc::new(RwLock::new(RoutingState::with_local_peer(me)));
48
49        let main_cache = options.main_cache;
50        let hot_cache = options.hot_cache;
51
52        let single_flight_group = SingleFlight::default();
53
54        let config = Config {
55            https: options.https,
56            grpc_endpoint_builder: Arc::new(options.grpc_endpoint_builder),
57        };
58
59        Self {
60            routing_state,
61            single_flight_group,
62            main_cache,
63            hot_cache,
64            loader,
65            me,
66            config,
67        }
68    }
69
70    pub(crate) async fn get(&self, key: &str) -> core::result::Result<Value, GroupcacheError> {
71        Ok(self.get_internal(key).await?)
72    }
73
74    pub(crate) async fn remove(&self, key: &str) -> core::result::Result<(), GroupcacheError> {
75        Ok(self.remove_internal(key).await?)
76    }
77
78    async fn get_internal(&self, key: &str) -> Result<Value, InternalGroupcacheError> {
79        counter!(METRIC_GET_TOTAL).increment(1);
80        if let Some(value) = self.main_cache.get(key).await {
81            counter!(METRIC_LOCAL_CACHE_HIT_TOTAL).increment(1);
82            return Ok(value);
83        }
84
85        if let Some(value) = self.hot_cache.get(key).await {
86            counter!(METRIC_LOCAL_CACHE_HIT_TOTAL).increment(1);
87            return Ok(value);
88        }
89
90        let peer = {
91            let lock = self.routing_state.read().unwrap();
92            lock.lookup_peer(key)
93        }?;
94
95        let value = self.get_deduped_instrumented(key, peer).await?;
96        Ok(value)
97    }
98
99    async fn get_deduped_instrumented(
100        &self,
101        key: &str,
102        peer: GroupcachePeerWithClient,
103    ) -> Result<Value, InternalGroupcacheError> {
104        self.single_flight_group
105            .work(key.to_owned(), || async {
106                self.get_deduped(key, peer)
107                    .await
108                    .map_err(|e| DedupedGroupcacheError(Arc::new(e)))
109            })
110            .await
111            .map_err(InternalGroupcacheError::Deduped)
112    }
113
114    async fn get_deduped(
115        &self,
116        key: &str,
117        peer: GroupcachePeerWithClient,
118    ) -> Result<Value, InternalGroupcacheError> {
119        if peer.peer == self.me {
120            let value = self.load_locally_instrumented(key).await?;
121            self.main_cache.insert(key.to_string(), value.clone()).await;
122            return Ok(value);
123        }
124
125        let mut client = peer
126            .client
127            .context("unreachable: cannot be empty since it's a remote peer")?;
128        let res = self.load_remotely_instrumented(key, &mut client).await;
129        match res {
130            Ok(value) => {
131                self.hot_cache.insert(key.to_string(), value.clone()).await;
132                Ok(value)
133            }
134            Err(_) => {
135                let value = self.load_locally_instrumented(key).await?;
136                Ok(value)
137            }
138        }
139    }
140
141    async fn load_locally_instrumented(&self, key: &str) -> Result<Value, InternalGroupcacheError> {
142        counter!(METRIC_LOCAL_LOAD_TOTAL).increment(1);
143        self.loader
144            .load(key)
145            .await
146            .inspect_err(|_| {
147                counter!(METRIC_LOCAL_LOAD_ERROR_TOTAL).increment(1);
148            })
149            .map_err(InternalGroupcacheError::LocalLoader)
150    }
151
152    async fn load_remotely_instrumented(
153        &self,
154        key: &str,
155        client: &mut GroupcachePeerClient,
156    ) -> Result<Value, InternalGroupcacheError> {
157        counter!(METRIC_REMOTE_LOAD_TOTAL).increment(1);
158        self.load_remotely(key, client)
159            .await
160            .inspect_err(|_| counter!(METRIC_REMOTE_LOAD_ERROR).increment(1))
161    }
162
163    async fn load_remotely(
164        &self,
165        key: &str,
166        client: &mut GroupcachePeerClient,
167    ) -> Result<Value, InternalGroupcacheError> {
168        let response = client
169            .get(
170                GetRequest {
171                    key: key.to_string(),
172                }
173                .into_request(),
174            )
175            .await?;
176
177        let get_response = response.into_inner();
178        let bytes = get_response.value.unwrap();
179        let value = rmp_serde::from_read(bytes.as_slice())?;
180
181        Ok(value)
182    }
183
184    async fn remove_internal(
185        &self,
186        key: &str,
187    ) -> core::result::Result<(), InternalGroupcacheError> {
188        self.hot_cache.remove(key).await;
189
190        let peer = {
191            let lock = self.routing_state.read().unwrap();
192            lock.lookup_peer(key)
193        }?;
194
195        if peer.peer == self.me {
196            self.main_cache.remove(key).await;
197        } else {
198            let mut client = peer
199                .client
200                .context("unreachable: cannot be empty since it's a remote peer")?;
201            self.remove_remotely(key, &mut client).await?;
202        }
203
204        Ok(())
205    }
206
207    async fn remove_remotely(
208        &self,
209        key: &str,
210        client: &mut GroupcachePeerClient,
211    ) -> core::result::Result<(), InternalGroupcacheError> {
212        let _ = client
213            .remove(
214                RemoveRequest {
215                    key: key.to_string(),
216                }
217                .into_request(),
218            )
219            .await?;
220
221        Ok(())
222    }
223
224    pub(crate) async fn add_peer(&self, peer: GroupcachePeer) -> Result<(), GroupcacheError> {
225        let contains_peer = {
226            let read_lock = self.routing_state.read().unwrap();
227            read_lock.contains_peer(&peer)
228        };
229
230        if contains_peer {
231            return Ok(());
232        }
233
234        let (_, client) = self.connect(peer).await?;
235        let mut write_lock = self.routing_state.write().unwrap();
236        write_lock.add_peer(peer, client);
237
238        Ok(())
239    }
240
241    pub(crate) async fn set_peers(
242        &self,
243        updated_peers: HashSet<GroupcachePeer>,
244    ) -> Result<(), GroupcacheError> {
245        let current_peers: HashSet<GroupcachePeer> = {
246            let read_lock = self.routing_state.read().unwrap();
247            read_lock.peers()
248        };
249
250        let new_connections_results = self
251            .connect_to_new_peers(&updated_peers, &current_peers)
252            .await?;
253        let peers_to_remove = current_peers.difference(&updated_peers).collect::<Vec<_>>();
254
255        let no_updates = peers_to_remove.is_empty() && new_connections_results.is_empty();
256        if no_updates {
257            return Ok(());
258        }
259
260        let conn_errors = self.update_routing_table(new_connections_results, peers_to_remove);
261        conn_errors.is_empty().then(|| Ok(())).unwrap_or_else(|| {
262            Err(GroupcacheError::from(
263                InternalGroupcacheError::ConnectionErrors(conn_errors),
264            ))
265        })
266    }
267
268    /// Updates routing table by adding new successful connections and removing old peers
269    fn update_routing_table(
270        &self,
271        connection_results: Vec<
272            Result<(GroupcachePeer, GroupcachePeerClient), InternalGroupcacheError>,
273        >,
274        peers_to_remove: Vec<&GroupcachePeer>,
275    ) -> Vec<InternalGroupcacheError> {
276        let mut write_lock = self.routing_state.write().unwrap();
277
278        let mut connection_errors = Vec::new();
279        for result in connection_results {
280            match result {
281                Ok((peer, client)) => {
282                    write_lock.add_peer(peer, client);
283                }
284                Err(e) => {
285                    connection_errors.push(e);
286                }
287            }
288        }
289
290        for removed_peer in peers_to_remove {
291            write_lock.remove_peer(*removed_peer);
292        }
293
294        connection_errors
295    }
296
297    /// Connects to new peers that were not previously connected in parallel.
298    async fn connect_to_new_peers(
299        &self,
300        updated_peers: &HashSet<GroupcachePeer>,
301        current_peers: &HashSet<GroupcachePeer>,
302    ) -> Result<
303        Vec<Result<(GroupcachePeer, GroupcachePeerClient), InternalGroupcacheError>>,
304        GroupcacheError,
305    > {
306        let peers_to_connect = updated_peers.difference(current_peers);
307        let mut connection_task = JoinSet::<
308            Result<(GroupcachePeer, GroupcachePeerClient), InternalGroupcacheError>,
309        >::new();
310        for new_peer in peers_to_connect {
311            let moved_peer = *new_peer;
312            let https = self.config.https;
313            let grpc_endpoint_builder = self.config.grpc_endpoint_builder.clone();
314            connection_task.spawn(async move {
315                GroupcacheInner::<Value>::connect_static(moved_peer, https, grpc_endpoint_builder)
316                    .await
317            });
318        }
319
320        let mut results = Vec::with_capacity(connection_task.len());
321        while let Some(res) = connection_task.join_next().await {
322            let conn_result = res
323                .context("unexpected JoinError when awaiting peer connection")
324                .map_err(Anyhow)?;
325
326            results.push(conn_result);
327        }
328
329        Ok(results)
330    }
331
332    async fn connect(
333        &self,
334        peer: GroupcachePeer,
335    ) -> Result<(GroupcachePeer, GroupcachePeerClient), InternalGroupcacheError> {
336        GroupcacheInner::<Value>::connect_static(
337            peer,
338            self.config.https,
339            self.config.grpc_endpoint_builder.clone(),
340        )
341        .await
342    }
343
344    async fn connect_static(
345        peer: GroupcachePeer,
346        https: bool,
347        grpc_endpoint_builder: Arc<Box<dyn Fn(Endpoint) -> Endpoint + Send + Sync + 'static>>,
348    ) -> Result<(GroupcachePeer, GroupcachePeerClient), InternalGroupcacheError> {
349        let socket = peer.socket;
350        let peer_addr = if https {
351            format!("https://{}", socket)
352        } else {
353            format!("http://{}", socket)
354        };
355
356        let endpoint: Endpoint = peer_addr.try_into()?;
357        let endpoint = grpc_endpoint_builder.as_ref()(endpoint);
358        let client = GroupcacheClient::connect(endpoint).await?;
359        Ok((peer, client))
360    }
361
362    pub(crate) async fn remove_peer(&self, peer: GroupcachePeer) -> Result<(), GroupcacheError> {
363        let contains_peer = {
364            let read_lock = self.routing_state.read().unwrap();
365            read_lock.contains_peer(&peer)
366        };
367
368        if !contains_peer {
369            return Ok(());
370        }
371
372        let mut write_lock = self.routing_state.write().unwrap();
373        write_lock.remove_peer(peer);
374
375        Ok(())
376    }
377
378    pub(crate) fn addr(&self) -> SocketAddr {
379        self.me.socket
380    }
381}