fusioncache_rs/
distributed_cache.rs

1use std::{env, marker::PhantomData, sync::Arc, time::Duration};
2
3use chrono::Utc;
4use futures::StreamExt;
5use redis::{AsyncCommands, Client, RedisError, SetExpiry, SetOptions, aio::MultiplexedConnection};
6use serde::{Deserialize, Serialize, de::DeserializeOwned};
7use std::fmt::Debug;
8use tokio::{
9    sync::{Mutex, mpsc},
10    task::JoinHandle,
11};
12use tokio_retry::{
13    Retry,
14    strategy::{FibonacciBackoff, jitter},
15};
16use tracing::{debug, error, info};
17
18use crate::{CacheValue, FusionCacheError, LOG_TARGET};
19
20impl From<RedisError> for FusionCacheError {
21    fn from(error: RedisError) -> Self {
22        FusionCacheError::RedisError(error.to_string())
23    }
24}
25
26#[derive(Serialize, Deserialize)]
27pub struct DistributedCacheValue<TValue> {
28    value: TValue,
29    entry_ttl: Option<i64>,
30    entry_tti: Option<i64>,
31    last_write: i64,
32}
33
34impl<TValue: Clone + Send + Sync + Serialize + DeserializeOwned + 'static> From<CacheValue<TValue>>
35    for DistributedCacheValue<TValue>
36{
37    fn from(value: CacheValue<TValue>) -> Self {
38        DistributedCacheValue {
39            value: value.value,
40            entry_ttl: value.time_to_live.map(|d| d.as_secs() as i64),
41            entry_tti: value.time_to_idle.map(|d| d.as_secs() as i64),
42            last_write: Utc::now().timestamp_millis(),
43        }
44    }
45}
46
47impl<TValue: Clone + Send + Sync + Serialize + DeserializeOwned + 'static>
48    From<DistributedCacheValue<TValue>> for CacheValue<TValue>
49{
50    fn from(value: DistributedCacheValue<TValue>) -> Self {
51        CacheValue {
52            value: value.value,
53            time_to_live: value.entry_ttl.map(|d| Duration::from_secs(d as u64)),
54            time_to_idle: value.entry_tti.map(|d| Duration::from_secs(d as u64)),
55        }
56    }
57}
58
59#[derive(Serialize, Deserialize)]
60pub struct CacheSynchronizationPayload {
61    node_id: String,
62    key: String,
63}
64
65#[derive(Clone, Debug)]
66pub struct RedisConnection {
67    redis_connection: MultiplexedConnection,
68    should_fail: bool,
69}
70
71impl RedisConnection {
72    pub fn new(redis_connection: MultiplexedConnection) -> Self {
73        Self {
74            redis_connection,
75            // I know, someone might say that I should use dependency injection here, but that would necessarily require boxing the connection,
76            // and I don't think it's worth it considering that all I want to do is return an error.
77            should_fail: false,
78        }
79    }
80}
81
82impl RedisConnection {
83    async fn get(
84        &mut self,
85        key: &str,
86        application_name: &str,
87    ) -> Result<Option<String>, FusionCacheError> {
88        if self.should_fail {
89            return Err(FusionCacheError::RedisError(
90                "Failed to get value".to_string(),
91            ));
92        }
93        self.redis_connection
94            .get(&format!("{}:{}", application_name, key))
95            .await
96            .map_err(FusionCacheError::from)
97    }
98    async fn set(
99        &mut self,
100        key: &str,
101        value: &str,
102        application_name: &str,
103        entry_ttl: Option<Duration>,
104    ) -> Result<(), FusionCacheError> {
105        if self.should_fail {
106            return Err(FusionCacheError::RedisError(
107                "Failed to set value".to_string(),
108            ));
109        }
110        let namespaced_key = format!("{}:{}", application_name, key);
111        let mut set_options = SetOptions::default();
112        if let Some(entry_ttl) = entry_ttl {
113            set_options = set_options.with_expiration(SetExpiry::EX(entry_ttl.as_secs()));
114        }
115        self.redis_connection
116            .set_options(&namespaced_key, value, set_options)
117            .await
118            .map_err(FusionCacheError::from)
119    }
120    async fn del(&mut self, key: &str, application_name: &str) -> Result<bool, FusionCacheError> {
121        if self.should_fail {
122            return Err(FusionCacheError::RedisError(
123                "Failed to delete value".to_string(),
124            ));
125        }
126        self.redis_connection
127            .del(&format!("{}:{}", application_name, key))
128            .await
129            .map_err(FusionCacheError::from)
130    }
131    async fn publish(&mut self, channel: &str, message: &str) -> Result<(), FusionCacheError> {
132        if self.should_fail {
133            return Err(FusionCacheError::RedisError(
134                "Failed to publish message".to_string(),
135            ));
136        }
137        self.redis_connection
138            .publish(channel, message)
139            .await
140            .map_err(FusionCacheError::from)
141    }
142}
143
144#[derive(Clone, Debug)]
145pub struct DistributedCache<
146    TKey: Eq + Send + Sync + Clone + Serialize + DeserializeOwned + 'static,
147    TValue: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
148> {
149    redis_client: Client,
150    eviction_event_sender: mpsc::Sender<TKey>,
151    auto_recovery_event_sender: mpsc::Sender<(TKey, i64)>,
152    synchronization_task: Option<Arc<Mutex<JoinHandle<()>>>>,
153    _auto_recovery_task: Arc<Mutex<JoinHandle<()>>>,
154    application_name: String,
155    _tkey: PhantomData<TKey>,
156    _tvalue: PhantomData<TValue>,
157    redis_connection: RedisConnection,
158    entry_ttl: Option<Duration>,
159    node_id: String,
160}
161
162impl<
163    TKey: Eq + Send + Sync + Clone + Serialize + DeserializeOwned + Debug + 'static,
164    TValue: Clone + Send + Sync + Serialize + DeserializeOwned + Debug + 'static,
165> DistributedCache<TKey, TValue>
166{
167    pub fn new(
168        redis_connection: RedisConnection,
169        redis_client: Client,
170        eviction_event_sender: mpsc::Sender<TKey>,
171        application_name: String,
172        entry_ttl: Option<Duration>,
173    ) -> Self {
174        let node_id = if env::var("KUBERNETES_SERVICE_HOST").is_ok() {
175            env::var("HOSTNAME").unwrap()
176        } else {
177            uuid::Uuid::new_v4().to_string()
178        };
179
180        let (auto_recovery_event_sender, mut auto_recovery_event_receiver) =
181            mpsc::channel::<(TKey, i64)>(1000);
182        let _event_sender = eviction_event_sender.clone();
183        let _redis_connection = redis_connection.clone();
184        let _node_id = node_id.clone();
185        let _application_name = application_name.clone();
186        let auto_recovery_task = tokio::spawn(async move {
187            let retry_strategy = FibonacciBackoff::from_millis(1000).map(jitter);
188            while let Some((key, failure_timestamp)) = auto_recovery_event_receiver.recv().await {
189                let _ = Retry::spawn(retry_strategy.clone(), async || {
190                    let serialized_key = serde_json::to_string(&key).unwrap();
191                    let mut connection = _redis_connection.clone();
192                    let value = connection.get(&serialized_key, &_application_name).await?;
193                    if let Some(value) = value {
194                        let distributed_cache_value: DistributedCacheValue<TValue> =
195                            serde_json::from_str(&value).unwrap();
196                        if distributed_cache_value.last_write > failure_timestamp {
197                            return Ok(());
198                        } else {
199                            connection.del(&serialized_key, &_application_name).await?;
200                            let json_payload =
201                                serde_json::to_string(&CacheSynchronizationPayload {
202                                    node_id: _node_id.clone(),
203                                    key: serialized_key,
204                                })
205                                .unwrap();
206                            return connection
207                                .publish(_application_name.as_str(), &json_payload)
208                                .await;
209                        }
210                    }
211                    Ok(())
212                })
213                .await;
214            }
215        });
216
217        Self {
218            redis_connection,
219            redis_client,
220            eviction_event_sender,
221            application_name,
222            synchronization_task: None,
223            _auto_recovery_task: Arc::new(Mutex::new(auto_recovery_task)),
224            auto_recovery_event_sender,
225            node_id,
226            _tkey: PhantomData,
227            _tvalue: PhantomData,
228            entry_ttl,
229        }
230    }
231
232    pub async fn start_synchronization(&mut self) -> Result<(), RedisError> {
233        let mut pubsub = self.redis_client.get_async_pubsub().await?;
234        pubsub.subscribe(self.application_name.as_str()).await?;
235        let eviction_event_sender = self.eviction_event_sender.clone();
236        let node_id = self.node_id.clone();
237        self.synchronization_task = Some(Arc::new(Mutex::new(tokio::spawn(async move {
238            while let Some(message) = pubsub.on_message().next().await {
239                let json_message_payload = message.get_payload::<String>().unwrap();
240                let payload: CacheSynchronizationPayload =
241                    serde_json::from_str(&json_message_payload).unwrap();
242                let deserialized_key: TKey = serde_json::from_str(&payload.key).unwrap();
243                if payload.node_id != node_id {
244                    eviction_event_sender.send(deserialized_key).await.unwrap();
245                }
246            }
247        }))));
248        Ok(())
249    }
250
251    #[tracing::instrument(name = "DistributedCache::get", skip(self))]
252    pub async fn get(
253        &mut self,
254        key: &TKey,
255    ) -> Result<Option<CacheValue<TValue>>, FusionCacheError> {
256        let key_str = serde_json::to_string(key).unwrap();
257        let value: Option<String> = self
258            .redis_connection
259            .get(&key_str, &self.application_name)
260            .await?;
261        let distributed_cache_value: Option<DistributedCacheValue<TValue>> =
262            if let Some(value) = value {
263                Some(serde_json::from_str(&value).unwrap())
264            } else {
265                None
266            };
267        Ok(distributed_cache_value.map(|v| v.into()))
268    }
269
270    #[tracing::instrument(name = "DistributedCache::set", skip(self))]
271    pub async fn set(
272        &mut self,
273        key: &TKey,
274        value: &CacheValue<TValue>,
275    ) -> Result<(), FusionCacheError> {
276        let key_str = serde_json::to_string(key).unwrap();
277        let value_str = serde_json::to_string(&DistributedCacheValue::from(value.clone())).unwrap();
278
279        let cache_synchronization_payload = CacheSynchronizationPayload {
280            node_id: self.node_id.clone(),
281            key: key_str.clone(),
282        };
283        let json_cache_synchronization_payload =
284            serde_json::to_string(&cache_synchronization_payload).unwrap();
285
286        match self
287            .redis_connection
288            .set(&key_str, &value_str, &self.application_name, self.entry_ttl)
289            .await
290        {
291            Ok(_) => {
292                debug!(target: LOG_TARGET, "Successfully set value in distributed cache for key: {:?}. Publishing synchronization payload: {:?}", key, json_cache_synchronization_payload);
293                let publish_result = self
294                    .redis_connection
295                    .publish(
296                        self.application_name.as_str(),
297                        &json_cache_synchronization_payload,
298                    )
299                    .await;
300                if let Err(e) = publish_result {
301                    error!(target: LOG_TARGET, "Failed to publish synchronization payload: {:?}. Kicking off auto-recovery for key: {:?}", e, key);
302                    let failure_timestamp = Utc::now().timestamp_millis();
303                    let key: TKey = serde_json::from_str(&key_str).unwrap();
304                    self.auto_recovery_event_sender
305                        .send((key.clone(), failure_timestamp))
306                        .await
307                        .unwrap();
308                    self.eviction_event_sender.send(key).await.unwrap();
309                    Err(e)
310                } else {
311                    Ok(())
312                }
313            }
314            Err(e) => {
315                self.eviction_event_sender.send(key.clone()).await.unwrap();
316                let failure_timestamp = Utc::now().timestamp_millis();
317                let key: TKey = serde_json::from_str(&key_str).unwrap();
318                self.auto_recovery_event_sender
319                    .send((key, failure_timestamp))
320                    .await
321                    .unwrap();
322                Err(e)
323            }
324        }
325    }
326
327    pub async fn evict(&mut self, key: &TKey) {
328        let key_str = serde_json::to_string(key).unwrap();
329        let cache_synchronization_payload = CacheSynchronizationPayload {
330            node_id: self.node_id.clone(),
331            key: key_str.clone(),
332        };
333        let json_cache_synchronization_payload =
334            serde_json::to_string(&cache_synchronization_payload).unwrap();
335        let publish_result = self
336            .redis_connection
337            .publish(
338                self.application_name.as_str(),
339                &json_cache_synchronization_payload,
340            )
341            .await
342            .map_err(FusionCacheError::from);
343        if let Err(e) = publish_result {
344            error!(target: LOG_TARGET, "Failed to publish synchronization payload: {:?}. Kicking off auto-recovery for key: {:?}", e, key);
345            let failure_timestamp = Utc::now().timestamp_millis();
346            let key: TKey = serde_json::from_str(&key_str).unwrap();
347            self.auto_recovery_event_sender
348                .send((key, failure_timestamp))
349                .await
350                .unwrap();
351        }
352    }
353
354    pub(crate) fn break_connection(&mut self) {
355        self.redis_connection.should_fail = true;
356    }
357
358    pub(crate) fn restore_connection(&mut self) {
359        self.redis_connection.should_fail = false;
360    }
361}
/// Integration tests. They connect to a Redis server at `redis://127.0.0.1/`.
#[cfg(test)]
mod tests {

    use std::time::Duration;

    use super::*;

    /// Builds a cache for `application_name` together with the receiving end of its
    /// eviction channel. The receiver is returned so that callers keep it alive.
    async fn connect(
        application_name: &str,
    ) -> (DistributedCache<String, String>, mpsc::Receiver<String>) {
        let redis_client = Client::open("redis://127.0.0.1/").unwrap();
        let inner_redis_connection = redis_client
            .get_multiplexed_async_connection()
            .await
            .unwrap();
        let redis_connection = RedisConnection::new(inner_redis_connection);
        let (eviction_sender, eviction_receiver) = mpsc::channel(100);

        let cache = DistributedCache::<String, String>::new(
            redis_connection,
            redis_client,
            eviction_sender,
            application_name.to_string(),
            None,
        );
        (cache, eviction_receiver)
    }

    /// Cache entry holding `value`, without time-to-live or time-to-idle.
    fn entry(value: &str) -> CacheValue<String> {
        CacheValue {
            value: value.to_string(),
            time_to_live: None,
            time_to_idle: None,
        }
    }

    #[tokio::test]
    async fn test_basic_set_get() {
        let (mut cache, _eviction_receiver) = connect("test_app").await;

        // Test setting and getting a value
        let key = "test_key".to_string();
        let cache_value = entry("test_value");

        cache.set(&key, &cache_value).await.unwrap();
        let retrieved_value = cache.get(&key).await.unwrap();

        assert_eq!(retrieved_value, Some(cache_value));
    }

    #[tokio::test]
    async fn test_synchronization() {
        let (mut cache1, _eviction_receiver1) = connect("test_synchronization").await;
        let (mut cache2, mut eviction_receiver2) = connect("test_synchronization").await;

        // Start synchronization for both caches
        cache1.start_synchronization().await.unwrap();
        cache2.start_synchronization().await.unwrap();

        // Set a value in cache1
        let key = "sync_test_key".to_string();
        let cache_value = entry("sync_test_value");
        cache1.set(&key, &cache_value).await.unwrap();

        // Wait for synchronization
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Verify cache2 received the eviction event
        let evicted_key = eviction_receiver2.recv().await.unwrap();
        assert_eq!(evicted_key, key);
    }

    #[tokio::test]
    async fn test_concurrent_writes() {
        let (mut cache, _eviction_receiver) = connect("test_concurrent_writes").await;

        let key = "concurrent_key".to_string();
        let cache_value1 = entry("value1");
        let cache_value2 = entry("value2");

        // First write should succeed
        assert!(cache.set(&key, &cache_value1).await.is_ok());

        assert!(cache.set(&key, &cache_value2).await.is_ok());

        // Verify the value was updated
        let retrieved_value = cache.get(&key).await.unwrap();
        assert_eq!(retrieved_value, Some(cache_value2));
    }

    #[tokio::test]
    async fn test_auto_recovery() {
        let (mut cache, _eviction_receiver) = connect("test_auto_recovery").await;

        let key = "key".to_string();
        let cache_value = entry("value");

        // Set initial value
        cache.set(&key, &cache_value).await.unwrap();

        cache.break_connection();

        let cache_value2 = entry("value2");
        let set_result = cache.set(&key, &cache_value2).await;
        assert!(set_result.is_err());

        tokio::time::sleep(Duration::from_secs(2)).await;

        cache.restore_connection();

        // Wait for auto-recovery to kick in
        tokio::time::sleep(Duration::from_secs(2)).await;

        // Verify that the value has been evicted
        let retrieved_value = cache.get(&key).await.unwrap();
        assert_eq!(retrieved_value, None);
    }
}