1use tokio::sync::Notify;
2use tokio::time::{self, Duration, Instant};
3
4use std::collections::{BTreeMap, HashMap};
5use std::sync::{Arc, Mutex};
6use std::hash::Hash;
7
8
9#[derive(Debug)]
10pub struct DbDropGuard<K, Doc>
11where
12 Doc: Clone + Send + Sync + 'static,
13 K: Clone
14 + PartialOrd
15 + Ord
16 + PartialEq
17 + Eq
18 + Hash
19 + Send
20 + 'static
21{
22 storage: RedisStorage<K, Doc>,
23}
24
25#[derive(Debug, Clone)]
26pub struct RedisStorage<K, Doc>
27where
28 Doc: Clone + Send + Sync + 'static,
29 K: Clone
30 + PartialOrd
31 + Ord
32 + PartialEq
33 + Eq
34 + Hash
35 + Send
36 + 'static
37{
38 shared: Arc<Shared<K, Doc>>,
39}
40
41#[derive(Debug)]
42struct Shared<K, Doc>
43where
44 Doc: Clone + Send + Sync + 'static,
45 K: Clone
46 + PartialOrd
47 + Ord
48 + PartialEq
49 + Eq
50 + Hash
51 + Send
52 + 'static
53{
54 state: Mutex<State<K, Doc>>,
55 background_task: Notify,
56}
57
58#[derive(Debug)]
59struct State<K, Doc>
60where
61 Doc: Clone + Send + Sync + 'static,
62 K: Clone
63 + PartialOrd
64 + Ord
65 + PartialEq
66 + Eq
67 + Hash
68 + Send
69 + 'static
70{
71
72 entries: HashMap<K, Entry<Doc>>,
73 expirations: BTreeMap<(Instant, u64), K>,
74 next_id: u64,
75 shutdown: bool,
76}
77
78
/// A single stored value together with its bookkeeping.
#[derive(Debug)]
struct Entry<Doc> {
    /// Id assigned at insertion time; pairs this entry with its record in
    /// the expiration index.
    id: u64,
    /// The stored value, reference-counted so lookups can hand out cheap
    /// clones.
    data: Arc<Doc>,
    /// Deadline after which the entry should be purged, or `None` when the
    /// entry never expires.
    expires_at: Option<Instant>,
}
85
86impl<K, Doc> DbDropGuard<K, Doc>
87where
88 Doc: Clone + Send + Sync + 'static,
89 K: Clone
90 + PartialOrd
91 + Ord
92 + PartialEq
93 + Eq
94 + Hash
95 + Send
96 + 'static
97{
98 pub(crate) fn new() -> DbDropGuard<K, Doc> {
99 DbDropGuard { storage: RedisStorage::new() }
100 }
101
102
103 pub(crate) fn storage(&self) -> RedisStorage<K, Doc> {
104 self.storage.clone()
105 }
106}
107
108impl<K, Doc> Drop for DbDropGuard<K, Doc>
109where
110 Doc: Clone + Send + Sync + 'static,
111 K: Clone
112 + PartialOrd
113 + Ord
114 + PartialEq
115 + Eq
116 + Hash
117 + Send
118 + 'static
119{
120 fn drop(&mut self) {
121 self.storage.shutdown_purge_task();
122 }
123}
124
125impl<K, Doc> RedisStorage<K, Doc>
126where
127 Doc: Clone + Send + Sync + 'static,
128 K: Clone
129 + PartialOrd
130 + Ord
131 + PartialEq
132 + Eq
133 + Hash
134 + Send
135 + 'static
136{
137
138 pub fn new() -> RedisStorage<K, Doc> {
139 let shared = Arc::new(Shared {
140 state: Mutex::new(State {
141 entries: HashMap::new(),
142 expirations: BTreeMap::new(),
143 next_id: 0,
144 shutdown: false,
145 }),
146 background_task: Notify::new(),
147 });
148
149 tokio::spawn(purge_expired_tasks(shared.clone()));
150
151 RedisStorage { shared }
152 }
153
154 pub fn get(&self, key: &K) -> Option<Arc<Doc>> {
155
156 let state = self.shared.state.lock().unwrap();
157 state.entries.get(key).map(|entry| entry.data.clone())
158 }
159
160 pub fn set(&self, key: K, value: Doc, expire: Option<Duration>) {
161 let mut state: std::sync::MutexGuard<'_, State<K, Doc>> = self.shared.state.lock().unwrap();
162
163
164 let id = state.next_id;
165 state.next_id += 1;
166
167
168 let mut notify = false;
169 let expires_at = expire.map(|duration| {
170
171 let when = Instant::now() + duration;
172 notify = state
173 .next_expiration()
174 .map(|expiration| expiration > when)
175 .unwrap_or(true);
176
177
178 state.expirations.insert((when, id), key.clone());
179 when
180 });
181
182 let prev = state.entries.insert(
183 key,
184 Entry {
185 id,
186 data: Arc::new(value),
187 expires_at,
188 },
189 );
190
191 if let Some(prev) = prev {
192 if let Some(when) = prev.expires_at {
193 state.expirations.remove(&(when, prev.id));
194 }
195 }
196
197 drop(state);
198
199 if notify {
200
201 self.shared.background_task.notify_one();
202 }
203 }
204
205 pub fn set_nx(&self, key: K, value: Doc, expire: Option<Duration>) -> bool {
206 let mut state = self.shared.state.lock().unwrap();
207
208 if state.entries.contains_key(&key) {
209 return false
210 }
211
212 let id = state.next_id;
213 state.next_id += 1;
214
215
216 let mut notify = false;
217
218 let expires_at = expire.map(|duration| {
219 let when = Instant::now() + duration;
220 notify = state
221 .next_expiration()
222 .map(|expiration| expiration > when)
223 .unwrap_or(true);
224
225 state.expirations.insert((when, id), key.clone());
226 when
227 });
228
229 let prev = state.entries.insert(
230 key,
231 Entry {
232 id,
233 data: Arc::new(value),
234 expires_at,
235 },
236 );
237
238 if let Some(prev) = prev {
239 if let Some(when) = prev.expires_at {
240
241 state.expirations.remove(&(when, prev.id));
242 }
243 }
244
245 drop(state);
246
247 if notify {
248 self.shared.background_task.notify_one();
249 }
250
251 return true;
252 }
253
254 pub fn del(&self, key: &K) {
255 let mut state = self.shared.state.lock().unwrap();
256 state.entries.remove(key);
257 }
258
259
260 pub fn len(&self) -> usize {
261 let mut state = self.shared.state.lock().unwrap();
262 return state.entries.len()
263 }
264
265 fn shutdown_purge_task(&self) {
266
267 let mut state = self.shared.state.lock().unwrap();
268 state.shutdown = true;
269
270 drop(state);
271 self.shared.background_task.notify_one();
272 }
273}
274
275impl<K, Doc> Shared<K, Doc>
276where
277 Doc: Clone + Send + Sync + 'static,
278 K: Clone
279 + PartialOrd
280 + Ord
281 + PartialEq
282 + Eq
283 + Hash
284 + Send
285 + 'static
286{
287
288 fn purge_expired_keys(&self) -> Option<Instant> {
289 let mut state = self.state.lock().unwrap();
290
291 if state.shutdown {
292 return None;
293 }
294
295 let state = &mut *state;
296 let now = Instant::now();
297
298 while let Some((&(when, id), key)) = state.expirations.iter().next() {
299 if when > now {
300 return Some(when);
301 }
302
303 state.entries.remove(key);
304 state.expirations.remove(&(when, id));
305 }
306
307 None
308 }
309
310 fn is_shutdown(&self) -> bool {
311 self.state.lock().unwrap().shutdown
312 }
313}
314
315impl<K, Doc> State<K, Doc>
316where
317 Doc: Clone + Send + Sync + 'static,
318 K: Clone
319 + PartialOrd
320 + Ord
321 + PartialEq
322 + Eq
323 + Hash
324 + Send
325 + 'static
326{
327 fn next_expiration(&self) -> Option<Instant> {
328 self.expirations
329 .keys()
330 .next()
331 .map(|expiration| expiration.0)
332 }
333}
334
335
336async fn purge_expired_tasks<K, Doc>(shared: Arc<Shared<K, Doc>>)
337where
338 Doc: Clone + Send + Sync + 'static,
339 K: Clone
340 + PartialOrd
341 + Ord
342 + PartialEq
343 + Eq
344 + Hash
345 + Send
346 + 'static
347{
348 while !shared.is_shutdown() {
349 if let Some(when) = shared.purge_expired_keys() {
350
351 tokio::select! {
352 _ = time::sleep_until(when) => {}
353 _ = shared.background_task.notified() => {}
354 }
355 } else {
356
357 shared.background_task.notified().await;
358 }
359 }
360}