opentalk_cache/
lib.rs

1// SPDX-FileCopyrightText: OpenTalk GmbH <mail@opentalk.eu>
2//
3// SPDX-License-Identifier: EUPL-1.2
4
5use core::{fmt::Display, time::Duration};
6use std::{hash::Hash, time::Instant};
7
8use moka::future::Cache as LocalCache;
9use redis::{AsyncCommands, RedisError, ToRedisArgs};
10use serde::{de::DeserializeOwned, Serialize};
11use siphasher::sip128::{Hasher128, SipHasher24};
12use snafu::Snafu;
13
14type RedisConnection = redis::aio::ConnectionManager;
15
16/// Application level cache which can store entries both in a locally and distributed using redis
17pub struct Cache<K, V> {
18    local: LocalCache<K, LocalEntry<V>>,
19    redis: Option<RedisConfig>,
20}
21
22struct RedisConfig {
23    redis: RedisConnection,
24    prefix: String,
25    ttl: Duration,
26    hash_key: bool,
27}
28
29#[derive(Debug, Snafu)]
30pub enum CacheError {
31    #[snafu(display("Redis error: {}", source), context(false))]
32    Redis { source: RedisError },
33    #[snafu(display("Serde error: {}", source), context(false))]
34    Serde { source: bincode::Error },
35}
36
37impl<K, V> Cache<K, V>
38where
39    K: Display + Hash + Eq + Send + Sync + 'static,
40    V: Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
41{
42    pub fn new(ttl: Duration) -> Self {
43        Self {
44            local: LocalCache::builder().time_to_live(ttl).build(),
45            redis: None,
46        }
47    }
48
49    pub fn with_redis(
50        self,
51        redis: RedisConnection,
52        prefix: impl Into<String>,
53        ttl: Duration,
54        hash_key: bool,
55    ) -> Self {
56        Self {
57            redis: Some(RedisConfig {
58                redis,
59                prefix: prefix.into(),
60                ttl,
61                hash_key,
62            }),
63            ..self
64        }
65    }
66
67    /// Return the longest duration an entry might live for
68    pub fn longest_ttl(&self) -> Duration {
69        let local_ttl = self
70            .local
71            .policy()
72            .time_to_live()
73            .expect("local always has a ttl");
74
75        if let Some(redis) = &self.redis {
76            redis.ttl.max(local_ttl)
77        } else {
78            local_ttl
79        }
80    }
81
82    pub async fn get(&self, key: &K) -> Result<Option<V>, CacheError> {
83        if let Some(entry) = self
84            .local
85            .get(key)
86            .await
87            .filter(|entry| entry.still_valid())
88        {
89            Ok(Some(entry.value))
90        } else if let Some(RedisConfig {
91            redis,
92            prefix,
93            hash_key,
94            ..
95        }) = &self.redis
96        {
97            let v: Option<Vec<u8>> = redis
98                .clone()
99                .get(RedisCacheKey {
100                    prefix,
101                    key,
102                    hash_key: *hash_key,
103                })
104                .await?;
105
106            if let Some(v) = v {
107                let v = bincode::deserialize(&v)?;
108
109                Ok(Some(v))
110            } else {
111                Ok(None)
112            }
113        } else {
114            Ok(None)
115        }
116    }
117
118    /// Insert a key-value pair with the cache's default TTL
119    pub async fn insert(&self, key: K, value: V) -> Result<(), CacheError> {
120        if let Some(RedisConfig {
121            redis,
122            prefix,
123            ttl,
124            hash_key,
125        }) = &self.redis
126        {
127            redis
128                .clone()
129                .set_ex::<_, _, ()>(
130                    RedisCacheKey {
131                        prefix,
132                        key: &key,
133                        hash_key: *hash_key,
134                    },
135                    bincode::serialize(&value)?,
136                    ttl.as_secs(),
137                )
138                .await?;
139        }
140
141        self.local
142            .insert(
143                key,
144                LocalEntry {
145                    value,
146                    expires_at: None,
147                },
148            )
149            .await;
150
151        Ok(())
152    }
153
154    /// Insert an entry with a custom TTL
155    ///
156    /// Note that TTLs larger than the configured one will be ignored
157    pub async fn insert_with_ttl(&self, key: K, value: V, ttl: Duration) -> Result<(), CacheError> {
158        if ttl >= self.longest_ttl() {
159            return self.insert(key, value).await;
160        }
161
162        if let Some(RedisConfig {
163            redis,
164            prefix,
165            hash_key,
166            ..
167        }) = &self.redis
168        {
169            redis
170                .clone()
171                .set_ex::<_, _, ()>(
172                    RedisCacheKey {
173                        prefix,
174                        key: &key,
175                        hash_key: *hash_key,
176                    },
177                    bincode::serialize(&value)?,
178                    ttl.as_secs(),
179                )
180                .await?;
181        }
182
183        self.local
184            .insert(
185                key,
186                LocalEntry {
187                    value,
188                    expires_at: Some(Instant::now() + ttl),
189                },
190            )
191            .await;
192
193        Ok(())
194    }
195}
196
197/// [`ToRedisArgs`] implementation for the cache-key
198/// Takes the prefix and cache-key to turn them into a redis-key
199struct RedisCacheKey<'a, K> {
200    hash_key: bool,
201    prefix: &'a str,
202    key: &'a K,
203}
204
205impl<K: Display + Hash> Display for RedisCacheKey<'_, K> {
206    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207        if self.hash_key {
208            let mut h = SipHasher24::new_with_keys(!0x113, 0x311);
209            self.key.hash(&mut h);
210            let hash = h.finish128().as_u128();
211
212            write!(f, "opentalk-cache:{}:{:x}", self.prefix, hash)
213        } else {
214            write!(f, "opentalk-cache:{}:{}", self.prefix, self.key)
215        }
216    }
217}
218
219impl<D: Display + Hash> ToRedisArgs for RedisCacheKey<'_, D> {
220    fn write_redis_args<W>(&self, out: &mut W)
221    where
222        W: ?Sized + redis::RedisWrite,
223    {
224        out.write_arg_fmt(self)
225    }
226}
227
/// Value stored in the local (moka) layer of the cache.
#[derive(Debug, Clone, Copy)]
struct LocalEntry<V> {
    /// The cached value.
    value: V,
    /// Custom expiration value to work around moka's limitation to set a custom ttl for an entry.
    ///
    /// `None` means the entry only expires through the local cache's own TTL.
    expires_at: Option<Instant>,
}

impl<V> LocalEntry<V> {
    /// Check whether the entry's custom TTL, if any, has not expired yet.
    ///
    /// An entry with a deadline is valid strictly before `expires_at` and expired from then on.
    fn still_valid(&self) -> bool {
        match self.expires_at {
            Some(deadline) => deadline > Instant::now(),
            None => true,
        }
    }
}