1use crate::errors::InternalGroupcacheError::Anyhow;
4use crate::errors::{DedupedGroupcacheError, GroupcacheError, InternalGroupcacheError};
5use crate::groupcache::{GroupcachePeer, GroupcachePeerClient, ValueBounds, ValueLoader};
6use crate::metrics::{
7 METRIC_GET_TOTAL, METRIC_LOCAL_CACHE_HIT_TOTAL, METRIC_LOCAL_LOAD_ERROR_TOTAL,
8 METRIC_LOCAL_LOAD_TOTAL, METRIC_REMOTE_LOAD_ERROR, METRIC_REMOTE_LOAD_TOTAL,
9};
10use crate::options::Options;
11use crate::routing::{GroupcachePeerWithClient, RoutingState};
12use anyhow::{Context, Result};
13use groupcache_pb::GroupcacheClient;
14use groupcache_pb::{GetRequest, RemoveRequest};
15use metrics::counter;
16use moka::future::Cache;
17use singleflight_async::SingleFlight;
18use std::collections::HashSet;
19use std::net::SocketAddr;
20use std::sync::{Arc, RwLock};
21use tokio::task::JoinSet;
22use tonic::transport::Endpoint;
23use tonic::IntoRequest;
24
/// Shared state of a groupcache node: peer routing, local caches, the value loader and the
/// connection settings used to reach remote peers.
// Field order is kept as-is: it fixes the order in which caches and loader are dropped.
pub struct GroupcacheInner<Value: ValueBounds> {
    /// Routing table used to find the peer responsible for a key (`lookup_peer`); guarded by a
    /// `std::sync::RwLock` whose guards are never held across an `.await`.
    routing_state: Arc<RwLock<RoutingState>>,
    /// Collapses concurrent loads of the same key into one load; the stored error type is
    /// `DedupedGroupcacheError` so a single failure can be returned to every waiting caller.
    single_flight_group: SingleFlight<String, Result<Value, DedupedGroupcacheError>>,
    /// Values loaded locally for keys routed to this peer (`me`).
    main_cache: Cache<String, Value>,
    /// Values successfully fetched from remote peers.
    hot_cache: Cache<String, Value>,
    /// Produces values for keys owned by this peer, and as a fallback when a remote load fails.
    loader: Box<dyn ValueLoader<Value = Value>>,
    /// Settings used when opening gRPC connections to remote peers.
    config: Config,
    /// This node's own peer identity.
    me: GroupcachePeer,
}
35
36struct Config {
37 https: bool,
38 grpc_endpoint_builder: Arc<Box<dyn Fn(Endpoint) -> Endpoint + Send + Sync + 'static>>,
39}
40
41impl<Value: ValueBounds> GroupcacheInner<Value> {
42 pub(crate) fn new(
43 me: GroupcachePeer,
44 loader: Box<dyn ValueLoader<Value = Value>>,
45 options: Options<Value>,
46 ) -> Self {
47 let routing_state = Arc::new(RwLock::new(RoutingState::with_local_peer(me)));
48
49 let main_cache = options.main_cache;
50 let hot_cache = options.hot_cache;
51
52 let single_flight_group = SingleFlight::default();
53
54 let config = Config {
55 https: options.https,
56 grpc_endpoint_builder: Arc::new(options.grpc_endpoint_builder),
57 };
58
59 Self {
60 routing_state,
61 single_flight_group,
62 main_cache,
63 hot_cache,
64 loader,
65 me,
66 config,
67 }
68 }
69
70 pub(crate) async fn get(&self, key: &str) -> core::result::Result<Value, GroupcacheError> {
71 Ok(self.get_internal(key).await?)
72 }
73
74 pub(crate) async fn remove(&self, key: &str) -> core::result::Result<(), GroupcacheError> {
75 Ok(self.remove_internal(key).await?)
76 }
77
78 async fn get_internal(&self, key: &str) -> Result<Value, InternalGroupcacheError> {
79 counter!(METRIC_GET_TOTAL).increment(1);
80 if let Some(value) = self.main_cache.get(key).await {
81 counter!(METRIC_LOCAL_CACHE_HIT_TOTAL).increment(1);
82 return Ok(value);
83 }
84
85 if let Some(value) = self.hot_cache.get(key).await {
86 counter!(METRIC_LOCAL_CACHE_HIT_TOTAL).increment(1);
87 return Ok(value);
88 }
89
90 let peer = {
91 let lock = self.routing_state.read().unwrap();
92 lock.lookup_peer(key)
93 }?;
94
95 let value = self.get_deduped_instrumented(key, peer).await?;
96 Ok(value)
97 }
98
99 async fn get_deduped_instrumented(
100 &self,
101 key: &str,
102 peer: GroupcachePeerWithClient,
103 ) -> Result<Value, InternalGroupcacheError> {
104 self.single_flight_group
105 .work(key.to_owned(), || async {
106 self.get_deduped(key, peer)
107 .await
108 .map_err(|e| DedupedGroupcacheError(Arc::new(e)))
109 })
110 .await
111 .map_err(InternalGroupcacheError::Deduped)
112 }
113
114 async fn get_deduped(
115 &self,
116 key: &str,
117 peer: GroupcachePeerWithClient,
118 ) -> Result<Value, InternalGroupcacheError> {
119 if peer.peer == self.me {
120 let value = self.load_locally_instrumented(key).await?;
121 self.main_cache.insert(key.to_string(), value.clone()).await;
122 return Ok(value);
123 }
124
125 let mut client = peer
126 .client
127 .context("unreachable: cannot be empty since it's a remote peer")?;
128 let res = self.load_remotely_instrumented(key, &mut client).await;
129 match res {
130 Ok(value) => {
131 self.hot_cache.insert(key.to_string(), value.clone()).await;
132 Ok(value)
133 }
134 Err(_) => {
135 let value = self.load_locally_instrumented(key).await?;
136 Ok(value)
137 }
138 }
139 }
140
141 async fn load_locally_instrumented(&self, key: &str) -> Result<Value, InternalGroupcacheError> {
142 counter!(METRIC_LOCAL_LOAD_TOTAL).increment(1);
143 self.loader
144 .load(key)
145 .await
146 .inspect_err(|_| {
147 counter!(METRIC_LOCAL_LOAD_ERROR_TOTAL).increment(1);
148 })
149 .map_err(InternalGroupcacheError::LocalLoader)
150 }
151
152 async fn load_remotely_instrumented(
153 &self,
154 key: &str,
155 client: &mut GroupcachePeerClient,
156 ) -> Result<Value, InternalGroupcacheError> {
157 counter!(METRIC_REMOTE_LOAD_TOTAL).increment(1);
158 self.load_remotely(key, client)
159 .await
160 .inspect_err(|_| counter!(METRIC_REMOTE_LOAD_ERROR).increment(1))
161 }
162
163 async fn load_remotely(
164 &self,
165 key: &str,
166 client: &mut GroupcachePeerClient,
167 ) -> Result<Value, InternalGroupcacheError> {
168 let response = client
169 .get(
170 GetRequest {
171 key: key.to_string(),
172 }
173 .into_request(),
174 )
175 .await?;
176
177 let get_response = response.into_inner();
178 let bytes = get_response.value.unwrap();
179 let value = rmp_serde::from_read(bytes.as_slice())?;
180
181 Ok(value)
182 }
183
184 async fn remove_internal(
185 &self,
186 key: &str,
187 ) -> core::result::Result<(), InternalGroupcacheError> {
188 self.hot_cache.remove(key).await;
189
190 let peer = {
191 let lock = self.routing_state.read().unwrap();
192 lock.lookup_peer(key)
193 }?;
194
195 if peer.peer == self.me {
196 self.main_cache.remove(key).await;
197 } else {
198 let mut client = peer
199 .client
200 .context("unreachable: cannot be empty since it's a remote peer")?;
201 self.remove_remotely(key, &mut client).await?;
202 }
203
204 Ok(())
205 }
206
207 async fn remove_remotely(
208 &self,
209 key: &str,
210 client: &mut GroupcachePeerClient,
211 ) -> core::result::Result<(), InternalGroupcacheError> {
212 let _ = client
213 .remove(
214 RemoveRequest {
215 key: key.to_string(),
216 }
217 .into_request(),
218 )
219 .await?;
220
221 Ok(())
222 }
223
224 pub(crate) async fn add_peer(&self, peer: GroupcachePeer) -> Result<(), GroupcacheError> {
225 let contains_peer = {
226 let read_lock = self.routing_state.read().unwrap();
227 read_lock.contains_peer(&peer)
228 };
229
230 if contains_peer {
231 return Ok(());
232 }
233
234 let (_, client) = self.connect(peer).await?;
235 let mut write_lock = self.routing_state.write().unwrap();
236 write_lock.add_peer(peer, client);
237
238 Ok(())
239 }
240
241 pub(crate) async fn set_peers(
242 &self,
243 updated_peers: HashSet<GroupcachePeer>,
244 ) -> Result<(), GroupcacheError> {
245 let current_peers: HashSet<GroupcachePeer> = {
246 let read_lock = self.routing_state.read().unwrap();
247 read_lock.peers()
248 };
249
250 let new_connections_results = self
251 .connect_to_new_peers(&updated_peers, ¤t_peers)
252 .await?;
253 let peers_to_remove = current_peers.difference(&updated_peers).collect::<Vec<_>>();
254
255 let no_updates = peers_to_remove.is_empty() && new_connections_results.is_empty();
256 if no_updates {
257 return Ok(());
258 }
259
260 let conn_errors = self.update_routing_table(new_connections_results, peers_to_remove);
261 conn_errors.is_empty().then(|| Ok(())).unwrap_or_else(|| {
262 Err(GroupcacheError::from(
263 InternalGroupcacheError::ConnectionErrors(conn_errors),
264 ))
265 })
266 }
267
268 fn update_routing_table(
270 &self,
271 connection_results: Vec<
272 Result<(GroupcachePeer, GroupcachePeerClient), InternalGroupcacheError>,
273 >,
274 peers_to_remove: Vec<&GroupcachePeer>,
275 ) -> Vec<InternalGroupcacheError> {
276 let mut write_lock = self.routing_state.write().unwrap();
277
278 let mut connection_errors = Vec::new();
279 for result in connection_results {
280 match result {
281 Ok((peer, client)) => {
282 write_lock.add_peer(peer, client);
283 }
284 Err(e) => {
285 connection_errors.push(e);
286 }
287 }
288 }
289
290 for removed_peer in peers_to_remove {
291 write_lock.remove_peer(*removed_peer);
292 }
293
294 connection_errors
295 }
296
297 async fn connect_to_new_peers(
299 &self,
300 updated_peers: &HashSet<GroupcachePeer>,
301 current_peers: &HashSet<GroupcachePeer>,
302 ) -> Result<
303 Vec<Result<(GroupcachePeer, GroupcachePeerClient), InternalGroupcacheError>>,
304 GroupcacheError,
305 > {
306 let peers_to_connect = updated_peers.difference(current_peers);
307 let mut connection_task = JoinSet::<
308 Result<(GroupcachePeer, GroupcachePeerClient), InternalGroupcacheError>,
309 >::new();
310 for new_peer in peers_to_connect {
311 let moved_peer = *new_peer;
312 let https = self.config.https;
313 let grpc_endpoint_builder = self.config.grpc_endpoint_builder.clone();
314 connection_task.spawn(async move {
315 GroupcacheInner::<Value>::connect_static(moved_peer, https, grpc_endpoint_builder)
316 .await
317 });
318 }
319
320 let mut results = Vec::with_capacity(connection_task.len());
321 while let Some(res) = connection_task.join_next().await {
322 let conn_result = res
323 .context("unexpected JoinError when awaiting peer connection")
324 .map_err(Anyhow)?;
325
326 results.push(conn_result);
327 }
328
329 Ok(results)
330 }
331
332 async fn connect(
333 &self,
334 peer: GroupcachePeer,
335 ) -> Result<(GroupcachePeer, GroupcachePeerClient), InternalGroupcacheError> {
336 GroupcacheInner::<Value>::connect_static(
337 peer,
338 self.config.https,
339 self.config.grpc_endpoint_builder.clone(),
340 )
341 .await
342 }
343
344 async fn connect_static(
345 peer: GroupcachePeer,
346 https: bool,
347 grpc_endpoint_builder: Arc<Box<dyn Fn(Endpoint) -> Endpoint + Send + Sync + 'static>>,
348 ) -> Result<(GroupcachePeer, GroupcachePeerClient), InternalGroupcacheError> {
349 let socket = peer.socket;
350 let peer_addr = if https {
351 format!("https://{}", socket)
352 } else {
353 format!("http://{}", socket)
354 };
355
356 let endpoint: Endpoint = peer_addr.try_into()?;
357 let endpoint = grpc_endpoint_builder.as_ref()(endpoint);
358 let client = GroupcacheClient::connect(endpoint).await?;
359 Ok((peer, client))
360 }
361
362 pub(crate) async fn remove_peer(&self, peer: GroupcachePeer) -> Result<(), GroupcacheError> {
363 let contains_peer = {
364 let read_lock = self.routing_state.read().unwrap();
365 read_lock.contains_peer(&peer)
366 };
367
368 if !contains_peer {
369 return Ok(());
370 }
371
372 let mut write_lock = self.routing_state.write().unwrap();
373 write_lock.remove_peer(peer);
374
375 Ok(())
376 }
377
378 pub(crate) fn addr(&self) -> SocketAddr {
379 self.me.socket
380 }
381}