1use core::{fmt::Display, time::Duration};
6use std::{hash::Hash, time::Instant};
7
8use moka::future::Cache as LocalCache;
9use redis::{AsyncCommands, RedisError, ToRedisArgs};
10use serde::{de::DeserializeOwned, Serialize};
11use siphasher::sip128::{Hasher128, SipHasher24};
12use snafu::Snafu;
13
14type RedisConnection = redis::aio::ConnectionManager;
15
16pub struct Cache<K, V> {
18 local: LocalCache<K, LocalEntry<V>>,
19 redis: Option<RedisConfig>,
20}
21
22struct RedisConfig {
23 redis: RedisConnection,
24 prefix: String,
25 ttl: Duration,
26 hash_key: bool,
27}
28
29#[derive(Debug, Snafu)]
30pub enum CacheError {
31 #[snafu(display("Redis error: {}", source), context(false))]
32 Redis { source: RedisError },
33 #[snafu(display("Serde error: {}", source), context(false))]
34 Serde { source: bincode::Error },
35}
36
37impl<K, V> Cache<K, V>
38where
39 K: Display + Hash + Eq + Send + Sync + 'static,
40 V: Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
41{
42 pub fn new(ttl: Duration) -> Self {
43 Self {
44 local: LocalCache::builder().time_to_live(ttl).build(),
45 redis: None,
46 }
47 }
48
49 pub fn with_redis(
50 self,
51 redis: RedisConnection,
52 prefix: impl Into<String>,
53 ttl: Duration,
54 hash_key: bool,
55 ) -> Self {
56 Self {
57 redis: Some(RedisConfig {
58 redis,
59 prefix: prefix.into(),
60 ttl,
61 hash_key,
62 }),
63 ..self
64 }
65 }
66
67 pub fn longest_ttl(&self) -> Duration {
69 let local_ttl = self
70 .local
71 .policy()
72 .time_to_live()
73 .expect("local always has a ttl");
74
75 if let Some(redis) = &self.redis {
76 redis.ttl.max(local_ttl)
77 } else {
78 local_ttl
79 }
80 }
81
82 pub async fn get(&self, key: &K) -> Result<Option<V>, CacheError> {
83 if let Some(entry) = self
84 .local
85 .get(key)
86 .await
87 .filter(|entry| entry.still_valid())
88 {
89 Ok(Some(entry.value))
90 } else if let Some(RedisConfig {
91 redis,
92 prefix,
93 hash_key,
94 ..
95 }) = &self.redis
96 {
97 let v: Option<Vec<u8>> = redis
98 .clone()
99 .get(RedisCacheKey {
100 prefix,
101 key,
102 hash_key: *hash_key,
103 })
104 .await?;
105
106 if let Some(v) = v {
107 let v = bincode::deserialize(&v)?;
108
109 Ok(Some(v))
110 } else {
111 Ok(None)
112 }
113 } else {
114 Ok(None)
115 }
116 }
117
118 pub async fn insert(&self, key: K, value: V) -> Result<(), CacheError> {
120 if let Some(RedisConfig {
121 redis,
122 prefix,
123 ttl,
124 hash_key,
125 }) = &self.redis
126 {
127 redis
128 .clone()
129 .set_ex::<_, _, ()>(
130 RedisCacheKey {
131 prefix,
132 key: &key,
133 hash_key: *hash_key,
134 },
135 bincode::serialize(&value)?,
136 ttl.as_secs(),
137 )
138 .await?;
139 }
140
141 self.local
142 .insert(
143 key,
144 LocalEntry {
145 value,
146 expires_at: None,
147 },
148 )
149 .await;
150
151 Ok(())
152 }
153
154 pub async fn insert_with_ttl(&self, key: K, value: V, ttl: Duration) -> Result<(), CacheError> {
158 if ttl >= self.longest_ttl() {
159 return self.insert(key, value).await;
160 }
161
162 if let Some(RedisConfig {
163 redis,
164 prefix,
165 hash_key,
166 ..
167 }) = &self.redis
168 {
169 redis
170 .clone()
171 .set_ex::<_, _, ()>(
172 RedisCacheKey {
173 prefix,
174 key: &key,
175 hash_key: *hash_key,
176 },
177 bincode::serialize(&value)?,
178 ttl.as_secs(),
179 )
180 .await?;
181 }
182
183 self.local
184 .insert(
185 key,
186 LocalEntry {
187 value,
188 expires_at: Some(Instant::now() + ttl),
189 },
190 )
191 .await;
192
193 Ok(())
194 }
195}
196
/// Borrowed view of a cache key, formatted into its full Redis key name.
struct RedisCacheKey<'a, K> {
    /// Replace the key's `Display` text with a SipHash digest.
    hash_key: bool,
    /// Namespace segment taken from the cache's Redis configuration.
    prefix: &'a str,
    /// The caller-supplied cache key.
    key: &'a K,
}
204
205impl<K: Display + Hash> Display for RedisCacheKey<'_, K> {
206 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207 if self.hash_key {
208 let mut h = SipHasher24::new_with_keys(!0x113, 0x311);
209 self.key.hash(&mut h);
210 let hash = h.finish128().as_u128();
211
212 write!(f, "opentalk-cache:{}:{:x}", self.prefix, hash)
213 } else {
214 write!(f, "opentalk-cache:{}:{}", self.prefix, self.key)
215 }
216 }
217}
218
219impl<D: Display + Hash> ToRedisArgs for RedisCacheKey<'_, D> {
220 fn write_redis_args<W>(&self, out: &mut W)
221 where
222 W: ?Sized + redis::RedisWrite,
223 {
224 out.write_arg_fmt(self)
225 }
226}
227
/// Value held in the local layer, with an optional per-entry deadline.
#[derive(Clone, Copy, Debug)]
struct LocalEntry<V> {
    /// The cached value.
    value: V,
    /// Instant from which the entry is treated as missing. `None` means the entry is bounded
    /// only by the local cache's own time-to-live.
    expires_at: Option<Instant>,
}
234
235impl<V> LocalEntry<V> {
236 fn still_valid(&self) -> bool {
238 if let Some(exp) = self.expires_at {
239 exp.saturating_duration_since(Instant::now()) > Duration::ZERO
240 } else {
241 true
242 }
243 }
244}