1use std::{env, marker::PhantomData, sync::Arc, time::Duration};
2
3use chrono::Utc;
4use futures::StreamExt;
5use redis::{AsyncCommands, Client, RedisError, SetExpiry, SetOptions, aio::MultiplexedConnection};
6use serde::{Deserialize, Serialize, de::DeserializeOwned};
7use std::fmt::Debug;
8use tokio::{
9 sync::{Mutex, mpsc},
10 task::JoinHandle,
11};
12use tokio_retry::{
13 Retry,
14 strategy::{FibonacciBackoff, jitter},
15};
16use tracing::{debug, error, info};
17
18use crate::{CacheValue, FusionCacheError, LOG_TARGET};
19
20impl From<RedisError> for FusionCacheError {
21 fn from(error: RedisError) -> Self {
22 FusionCacheError::RedisError(error.to_string())
23 }
24}
25
26#[derive(Serialize, Deserialize)]
27pub struct DistributedCacheValue<TValue> {
28 value: TValue,
29 entry_ttl: Option<i64>,
30 entry_tti: Option<i64>,
31 last_write: i64,
32}
33
34impl<TValue: Clone + Send + Sync + Serialize + DeserializeOwned + 'static> From<CacheValue<TValue>>
35 for DistributedCacheValue<TValue>
36{
37 fn from(value: CacheValue<TValue>) -> Self {
38 DistributedCacheValue {
39 value: value.value,
40 entry_ttl: value.time_to_live.map(|d| d.as_secs() as i64),
41 entry_tti: value.time_to_idle.map(|d| d.as_secs() as i64),
42 last_write: Utc::now().timestamp_millis(),
43 }
44 }
45}
46
47impl<TValue: Clone + Send + Sync + Serialize + DeserializeOwned + 'static>
48 From<DistributedCacheValue<TValue>> for CacheValue<TValue>
49{
50 fn from(value: DistributedCacheValue<TValue>) -> Self {
51 CacheValue {
52 value: value.value,
53 time_to_live: value.entry_ttl.map(|d| Duration::from_secs(d as u64)),
54 time_to_idle: value.entry_tti.map(|d| Duration::from_secs(d as u64)),
55 }
56 }
57}
58
59#[derive(Serialize, Deserialize)]
60pub struct CacheSynchronizationPayload {
61 node_id: String,
62 key: String,
63}
64
65#[derive(Clone, Debug)]
66pub struct RedisConnection {
67 redis_connection: MultiplexedConnection,
68 should_fail: bool,
69}
70
71impl RedisConnection {
72 pub fn new(redis_connection: MultiplexedConnection) -> Self {
73 Self {
74 redis_connection,
75 should_fail: false,
78 }
79 }
80}
81
82impl RedisConnection {
83 async fn get(
84 &mut self,
85 key: &str,
86 application_name: &str,
87 ) -> Result<Option<String>, FusionCacheError> {
88 if self.should_fail {
89 return Err(FusionCacheError::RedisError(
90 "Failed to get value".to_string(),
91 ));
92 }
93 self.redis_connection
94 .get(&format!("{}:{}", application_name, key))
95 .await
96 .map_err(FusionCacheError::from)
97 }
98 async fn set(
99 &mut self,
100 key: &str,
101 value: &str,
102 application_name: &str,
103 entry_ttl: Option<Duration>,
104 ) -> Result<(), FusionCacheError> {
105 if self.should_fail {
106 return Err(FusionCacheError::RedisError(
107 "Failed to set value".to_string(),
108 ));
109 }
110 let namespaced_key = format!("{}:{}", application_name, key);
111 let mut set_options = SetOptions::default();
112 if let Some(entry_ttl) = entry_ttl {
113 set_options = set_options.with_expiration(SetExpiry::EX(entry_ttl.as_secs()));
114 }
115 self.redis_connection
116 .set_options(&namespaced_key, value, set_options)
117 .await
118 .map_err(FusionCacheError::from)
119 }
120 async fn del(&mut self, key: &str, application_name: &str) -> Result<bool, FusionCacheError> {
121 if self.should_fail {
122 return Err(FusionCacheError::RedisError(
123 "Failed to delete value".to_string(),
124 ));
125 }
126 self.redis_connection
127 .del(&format!("{}:{}", application_name, key))
128 .await
129 .map_err(FusionCacheError::from)
130 }
131 async fn publish(&mut self, channel: &str, message: &str) -> Result<(), FusionCacheError> {
132 if self.should_fail {
133 return Err(FusionCacheError::RedisError(
134 "Failed to publish message".to_string(),
135 ));
136 }
137 self.redis_connection
138 .publish(channel, message)
139 .await
140 .map_err(FusionCacheError::from)
141 }
142}
143
144#[derive(Clone, Debug)]
145pub struct DistributedCache<
146 TKey: Eq + Send + Sync + Clone + Serialize + DeserializeOwned + 'static,
147 TValue: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
148> {
149 redis_client: Client,
150 eviction_event_sender: mpsc::Sender<TKey>,
151 auto_recovery_event_sender: mpsc::Sender<(TKey, i64)>,
152 synchronization_task: Option<Arc<Mutex<JoinHandle<()>>>>,
153 _auto_recovery_task: Arc<Mutex<JoinHandle<()>>>,
154 application_name: String,
155 _tkey: PhantomData<TKey>,
156 _tvalue: PhantomData<TValue>,
157 redis_connection: RedisConnection,
158 entry_ttl: Option<Duration>,
159 node_id: String,
160}
161
162impl<
163 TKey: Eq + Send + Sync + Clone + Serialize + DeserializeOwned + Debug + 'static,
164 TValue: Clone + Send + Sync + Serialize + DeserializeOwned + Debug + 'static,
165> DistributedCache<TKey, TValue>
166{
167 pub fn new(
168 redis_connection: RedisConnection,
169 redis_client: Client,
170 eviction_event_sender: mpsc::Sender<TKey>,
171 application_name: String,
172 entry_ttl: Option<Duration>,
173 ) -> Self {
174 let node_id = if env::var("KUBERNETES_SERVICE_HOST").is_ok() {
175 env::var("HOSTNAME").unwrap()
176 } else {
177 uuid::Uuid::new_v4().to_string()
178 };
179
180 let (auto_recovery_event_sender, mut auto_recovery_event_receiver) =
181 mpsc::channel::<(TKey, i64)>(1000);
182 let _event_sender = eviction_event_sender.clone();
183 let _redis_connection = redis_connection.clone();
184 let _node_id = node_id.clone();
185 let _application_name = application_name.clone();
186 let auto_recovery_task = tokio::spawn(async move {
187 let retry_strategy = FibonacciBackoff::from_millis(1000).map(jitter);
188 while let Some((key, failure_timestamp)) = auto_recovery_event_receiver.recv().await {
189 let _ = Retry::spawn(retry_strategy.clone(), async || {
190 let serialized_key = serde_json::to_string(&key).unwrap();
191 let mut connection = _redis_connection.clone();
192 let value = connection.get(&serialized_key, &_application_name).await?;
193 if let Some(value) = value {
194 let distributed_cache_value: DistributedCacheValue<TValue> =
195 serde_json::from_str(&value).unwrap();
196 if distributed_cache_value.last_write > failure_timestamp {
197 return Ok(());
198 } else {
199 connection.del(&serialized_key, &_application_name).await?;
200 let json_payload =
201 serde_json::to_string(&CacheSynchronizationPayload {
202 node_id: _node_id.clone(),
203 key: serialized_key,
204 })
205 .unwrap();
206 return connection
207 .publish(_application_name.as_str(), &json_payload)
208 .await;
209 }
210 }
211 Ok(())
212 })
213 .await;
214 }
215 });
216
217 Self {
218 redis_connection,
219 redis_client,
220 eviction_event_sender,
221 application_name,
222 synchronization_task: None,
223 _auto_recovery_task: Arc::new(Mutex::new(auto_recovery_task)),
224 auto_recovery_event_sender,
225 node_id,
226 _tkey: PhantomData,
227 _tvalue: PhantomData,
228 entry_ttl,
229 }
230 }
231
232 pub async fn start_synchronization(&mut self) -> Result<(), RedisError> {
233 let mut pubsub = self.redis_client.get_async_pubsub().await?;
234 pubsub.subscribe(self.application_name.as_str()).await?;
235 let eviction_event_sender = self.eviction_event_sender.clone();
236 let node_id = self.node_id.clone();
237 self.synchronization_task = Some(Arc::new(Mutex::new(tokio::spawn(async move {
238 while let Some(message) = pubsub.on_message().next().await {
239 let json_message_payload = message.get_payload::<String>().unwrap();
240 let payload: CacheSynchronizationPayload =
241 serde_json::from_str(&json_message_payload).unwrap();
242 let deserialized_key: TKey = serde_json::from_str(&payload.key).unwrap();
243 if payload.node_id != node_id {
244 eviction_event_sender.send(deserialized_key).await.unwrap();
245 }
246 }
247 }))));
248 Ok(())
249 }
250
251 #[tracing::instrument(name = "DistributedCache::get", skip(self))]
252 pub async fn get(
253 &mut self,
254 key: &TKey,
255 ) -> Result<Option<CacheValue<TValue>>, FusionCacheError> {
256 let key_str = serde_json::to_string(key).unwrap();
257 let value: Option<String> = self
258 .redis_connection
259 .get(&key_str, &self.application_name)
260 .await?;
261 let distributed_cache_value: Option<DistributedCacheValue<TValue>> =
262 if let Some(value) = value {
263 Some(serde_json::from_str(&value).unwrap())
264 } else {
265 None
266 };
267 Ok(distributed_cache_value.map(|v| v.into()))
268 }
269
270 #[tracing::instrument(name = "DistributedCache::set", skip(self))]
271 pub async fn set(
272 &mut self,
273 key: &TKey,
274 value: &CacheValue<TValue>,
275 ) -> Result<(), FusionCacheError> {
276 let key_str = serde_json::to_string(key).unwrap();
277 let value_str = serde_json::to_string(&DistributedCacheValue::from(value.clone())).unwrap();
278
279 let cache_synchronization_payload = CacheSynchronizationPayload {
280 node_id: self.node_id.clone(),
281 key: key_str.clone(),
282 };
283 let json_cache_synchronization_payload =
284 serde_json::to_string(&cache_synchronization_payload).unwrap();
285
286 match self
287 .redis_connection
288 .set(&key_str, &value_str, &self.application_name, self.entry_ttl)
289 .await
290 {
291 Ok(_) => {
292 debug!(target: LOG_TARGET, "Successfully set value in distributed cache for key: {:?}. Publishing synchronization payload: {:?}", key, json_cache_synchronization_payload);
293 let publish_result = self
294 .redis_connection
295 .publish(
296 self.application_name.as_str(),
297 &json_cache_synchronization_payload,
298 )
299 .await;
300 if let Err(e) = publish_result {
301 error!(target: LOG_TARGET, "Failed to publish synchronization payload: {:?}. Kicking off auto-recovery for key: {:?}", e, key);
302 let failure_timestamp = Utc::now().timestamp_millis();
303 let key: TKey = serde_json::from_str(&key_str).unwrap();
304 self.auto_recovery_event_sender
305 .send((key.clone(), failure_timestamp))
306 .await
307 .unwrap();
308 self.eviction_event_sender.send(key).await.unwrap();
309 Err(e)
310 } else {
311 Ok(())
312 }
313 }
314 Err(e) => {
315 self.eviction_event_sender.send(key.clone()).await.unwrap();
316 let failure_timestamp = Utc::now().timestamp_millis();
317 let key: TKey = serde_json::from_str(&key_str).unwrap();
318 self.auto_recovery_event_sender
319 .send((key, failure_timestamp))
320 .await
321 .unwrap();
322 Err(e)
323 }
324 }
325 }
326
327 pub async fn evict(&mut self, key: &TKey) {
328 let key_str = serde_json::to_string(key).unwrap();
329 let cache_synchronization_payload = CacheSynchronizationPayload {
330 node_id: self.node_id.clone(),
331 key: key_str.clone(),
332 };
333 let json_cache_synchronization_payload =
334 serde_json::to_string(&cache_synchronization_payload).unwrap();
335 let publish_result = self
336 .redis_connection
337 .publish(
338 self.application_name.as_str(),
339 &json_cache_synchronization_payload,
340 )
341 .await
342 .map_err(FusionCacheError::from);
343 if let Err(e) = publish_result {
344 error!(target: LOG_TARGET, "Failed to publish synchronization payload: {:?}. Kicking off auto-recovery for key: {:?}", e, key);
345 let failure_timestamp = Utc::now().timestamp_millis();
346 let key: TKey = serde_json::from_str(&key_str).unwrap();
347 self.auto_recovery_event_sender
348 .send((key, failure_timestamp))
349 .await
350 .unwrap();
351 }
352 }
353
354 pub(crate) fn break_connection(&mut self) {
355 self.redis_connection.should_fail = true;
356 }
357
358 pub(crate) fn restore_connection(&mut self) {
359 self.redis_connection.should_fail = false;
360 }
361}
/// Integration tests; they expect a Redis server at `redis://127.0.0.1/`.
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    #[tokio::test]
    async fn test_basic_set_get() {
        let redis_client = Client::open("redis://127.0.0.1/").unwrap();
        let inner_redis_connection = redis_client
            .get_multiplexed_async_connection()
            .await
            .unwrap();
        let redis_connection = RedisConnection::new(inner_redis_connection);
        // Keep the receiver alive so eviction sends do not hit a closed channel.
        let (eviction_sender, _eviction_receiver) = mpsc::channel(100);

        let mut cache = DistributedCache::<String, String>::new(
            redis_connection,
            redis_client,
            eviction_sender,
            "test_app".to_string(),
            None,
        );

        let key = "test_key".to_string();
        let value = "test_value".to_string();
        let cache_value = CacheValue {
            value: value.clone(),
            time_to_live: None,
            time_to_idle: None,
        };

        cache.set(&key, &cache_value).await.unwrap();
        let retrieved_value = cache.get(&key).await.unwrap();

        assert_eq!(retrieved_value, Some(cache_value));
    }

    #[tokio::test]
    async fn test_synchronization() {
        let redis_client1 = Client::open("redis://127.0.0.1/").unwrap();
        let inner_redis_connection1 = redis_client1
            .get_multiplexed_async_connection()
            .await
            .unwrap();
        let (eviction_sender1, eviction_receiver1) = mpsc::channel(100);
        let _eviction_receiver1 = eviction_receiver1;
        let redis_connection1 = RedisConnection::new(inner_redis_connection1);

        let mut cache1 = DistributedCache::<String, String>::new(
            redis_connection1,
            redis_client1,
            eviction_sender1,
            "test_synchronization".to_string(),
            None,
        );

        let redis_client2 = Client::open("redis://127.0.0.1/").unwrap();
        let inner_redis_connection2 = redis_client2
            .get_multiplexed_async_connection()
            .await
            .unwrap();
        let redis_connection2 = RedisConnection::new(inner_redis_connection2);
        let (eviction_sender2, mut eviction_receiver2) = mpsc::channel(100);

        let mut cache2 = DistributedCache::<String, String>::new(
            redis_connection2,
            redis_client2,
            eviction_sender2,
            "test_synchronization".to_string(),
            None,
        );

        cache1.start_synchronization().await.unwrap();
        cache2.start_synchronization().await.unwrap();

        let key = "sync_test_key".to_string();
        let value = "sync_test_value".to_string();
        let cache_value = CacheValue {
            value: value.clone(),
            time_to_live: None,
            time_to_idle: None,
        };
        cache1.set(&key, &cache_value).await.unwrap();

        // Bound the wait so a lost message fails the test instead of hanging it.
        let evicted_key = tokio::time::timeout(Duration::from_secs(5), eviction_receiver2.recv())
            .await
            .expect("timed out waiting for the synchronization eviction event")
            .expect("eviction channel closed before receiving an event");
        assert_eq!(evicted_key, key);
    }

    #[tokio::test]
    async fn test_concurrent_writes() {
        let redis_client = Client::open("redis://127.0.0.1/").unwrap();
        let inner_redis_connection = redis_client
            .get_multiplexed_async_connection()
            .await
            .unwrap();
        // Keep the receiver alive so eviction sends do not hit a closed channel.
        let (eviction_sender, _eviction_receiver) = mpsc::channel(100);
        let redis_connection = RedisConnection::new(inner_redis_connection);

        let mut cache = DistributedCache::<String, String>::new(
            redis_connection,
            redis_client,
            eviction_sender,
            "test_concurrent_writes".to_string(),
            None,
        );

        let key = "concurrent_key".to_string();
        let value1 = "value1".to_string();
        let cache_value1 = CacheValue {
            value: value1.clone(),
            time_to_live: None,
            time_to_idle: None,
        };
        let value2 = "value2".to_string();
        let cache_value2 = CacheValue {
            value: value2.clone(),
            time_to_live: None,
            time_to_idle: None,
        };
        assert!(cache.set(&key, &cache_value1).await.is_ok());

        assert!(cache.set(&key, &cache_value2).await.is_ok());

        let retrieved_value = cache.get(&key).await.unwrap();
        assert_eq!(retrieved_value, Some(cache_value2));
    }

    #[tokio::test]
    async fn test_auto_recovery() {
        let redis_client = Client::open("redis://127.0.0.1/").unwrap();
        let internal_redis_connection = redis_client
            .get_multiplexed_async_connection()
            .await
            .unwrap();
        let working_redis_connection = RedisConnection::new(internal_redis_connection);

        let (eviction_sender, eviction_receiver) = mpsc::channel(100);
        let _eviction_receiver = eviction_receiver;

        let mut cache = DistributedCache::<String, String>::new(
            working_redis_connection,
            redis_client,
            eviction_sender,
            "test_auto_recovery".to_string(),
            None,
        );

        let key = "key".to_string();
        let value = "value".to_string();
        let cache_value = CacheValue {
            value: value.clone(),
            time_to_live: None,
            time_to_idle: None,
        };

        cache.set(&key, &cache_value).await.unwrap();

        cache.break_connection();

        let value2 = "value2".to_string();
        let cache_value2 = CacheValue {
            value: value2.clone(),
            time_to_live: None,
            time_to_idle: None,
        };
        let set_result = cache.set(&key, &cache_value2).await;
        assert!(set_result.is_err());

        tokio::time::sleep(Duration::from_secs(2)).await;

        cache.restore_connection();

        tokio::time::sleep(Duration::from_secs(2)).await;

        let retrieved_value = cache.get(&key).await.unwrap();
        assert_eq!(retrieved_value, None);
    }
}