darkbird/darkbird/storage_redis.rs

1use tokio::sync::Notify;
2use tokio::time::{self, Duration, Instant};
3
4use std::collections::{BTreeMap, HashMap};
5use std::sync::{Arc, Mutex};
6use std::hash::Hash;
7
8
9#[derive(Debug)]
10pub struct DbDropGuard<K, Doc> 
11where
12    Doc: Clone + Send + Sync + 'static,
13    K:  Clone
14        + PartialOrd
15        + Ord
16        + PartialEq
17        + Eq
18        + Hash
19        + Send
20        + 'static
21{
22    storage: RedisStorage<K, Doc>,
23}
24
25#[derive(Debug, Clone)]
26pub struct RedisStorage<K, Doc> 
27where
28    Doc: Clone + Send + Sync + 'static,
29    K:  Clone
30        + PartialOrd
31        + Ord
32        + PartialEq
33        + Eq
34        + Hash
35        + Send
36        + 'static
37{
38    shared: Arc<Shared<K, Doc>>,
39}
40
41#[derive(Debug)]
42struct Shared<K, Doc> 
43where
44    Doc: Clone + Send + Sync + 'static,
45    K:  Clone
46        + PartialOrd
47        + Ord
48        + PartialEq
49        + Eq
50        + Hash
51        + Send
52        + 'static
53{
54    state: Mutex<State<K, Doc>>,
55    background_task: Notify,
56}
57
58#[derive(Debug)]
59struct State<K, Doc> 
60where
61    Doc: Clone + Send + Sync + 'static,
62    K:  Clone
63        + PartialOrd
64        + Ord
65        + PartialEq
66        + Eq
67        + Hash
68        + Send
69        + 'static
70{
71    
72    entries: HashMap<K, Entry<Doc>>,
73    expirations: BTreeMap<(Instant, u64), K>,
74    next_id: u64,
75    shutdown: bool,
76}
77
78
/// A single stored value together with its bookkeeping.
#[derive(Debug)]
struct Entry<Doc> {
    /// Id assigned at insertion; together with `expires_at` it forms this
    /// entry's key in the expiration index.
    id: u64,
    /// The value, behind an `Arc` so readers receive a cheap shared handle.
    data: Arc<Doc>,
    /// Deadline after which the entry is purged; `None` means no TTL.
    expires_at: Option<Instant>,
}
85
86impl<K, Doc> DbDropGuard<K, Doc> 
87where
88    Doc: Clone + Send + Sync + 'static,
89    K:  Clone
90        + PartialOrd
91        + Ord
92        + PartialEq
93        + Eq
94        + Hash
95        + Send
96        + 'static
97{
98    pub(crate) fn new() -> DbDropGuard<K, Doc> {
99        DbDropGuard { storage: RedisStorage::new() }
100    }
101
102
103    pub(crate) fn storage(&self) -> RedisStorage<K, Doc> {
104        self.storage.clone()
105    }
106}
107
108impl<K, Doc> Drop for DbDropGuard<K, Doc> 
109where
110    Doc: Clone + Send + Sync + 'static,
111    K:  Clone
112        + PartialOrd
113        + Ord
114        + PartialEq
115        + Eq
116        + Hash
117        + Send
118        + 'static
119{
120    fn drop(&mut self) {
121        self.storage.shutdown_purge_task();
122    }
123}
124
125impl<K, Doc> RedisStorage<K, Doc> 
126where
127    Doc: Clone + Send + Sync + 'static,
128    K:  Clone
129        + PartialOrd
130        + Ord
131        + PartialEq
132        + Eq
133        + Hash
134        + Send
135        + 'static
136{
137
138    pub fn new() -> RedisStorage<K, Doc> {
139        let shared = Arc::new(Shared {
140            state: Mutex::new(State {
141                entries: HashMap::new(),
142                expirations: BTreeMap::new(),
143                next_id: 0,
144                shutdown: false,
145            }),
146            background_task: Notify::new(),
147        });
148
149        tokio::spawn(purge_expired_tasks(shared.clone()));
150
151        RedisStorage { shared }
152    }
153
154    pub fn get(&self, key: &K) -> Option<Arc<Doc>> {
155    
156        let state = self.shared.state.lock().unwrap();
157        state.entries.get(key).map(|entry| entry.data.clone())
158    }
159
160    pub fn set(&self, key: K, value: Doc, expire: Option<Duration>) {
161        let mut state: std::sync::MutexGuard<'_, State<K, Doc>> = self.shared.state.lock().unwrap();
162
163        
164        let id = state.next_id;
165        state.next_id += 1;
166
167    
168        let mut notify = false;
169        let expires_at = expire.map(|duration| {
170            
171            let when = Instant::now() + duration;
172            notify = state
173                .next_expiration()
174                .map(|expiration| expiration > when)
175                .unwrap_or(true);
176
177        
178            state.expirations.insert((when, id), key.clone());
179            when
180        });
181
182        let prev = state.entries.insert(
183            key,
184            Entry {
185                id,
186                data: Arc::new(value),
187                expires_at,
188            },
189        );
190
191        if let Some(prev) = prev {
192            if let Some(when) = prev.expires_at {
193                state.expirations.remove(&(when, prev.id));
194            }
195        }
196
197        drop(state);
198
199        if notify {
200            
201            self.shared.background_task.notify_one();
202        }
203    }
204
205    pub fn set_nx(&self, key: K, value: Doc, expire: Option<Duration>) -> bool {
206        let mut state = self.shared.state.lock().unwrap();
207
208        if state.entries.contains_key(&key) {
209            return false
210        }
211
212        let id = state.next_id;
213        state.next_id += 1;
214
215
216        let mut notify = false;
217
218        let expires_at = expire.map(|duration| {
219            let when = Instant::now() + duration;
220            notify = state
221                .next_expiration()
222                .map(|expiration| expiration > when)
223                .unwrap_or(true);
224
225            state.expirations.insert((when, id), key.clone());
226            when
227        });
228
229        let prev = state.entries.insert(
230            key,
231            Entry {
232                id,
233                data: Arc::new(value),
234                expires_at,
235            },
236        );
237
238        if let Some(prev) = prev {
239            if let Some(when) = prev.expires_at {
240                
241                state.expirations.remove(&(when, prev.id));
242            }
243        }
244
245        drop(state);
246
247        if notify {
248            self.shared.background_task.notify_one();
249        }
250
251        return true;
252    }
253
254    pub fn del(&self, key: &K) {
255        let mut state = self.shared.state.lock().unwrap();
256        state.entries.remove(key);
257    }
258
259 
260    pub fn len(&self) -> usize {
261        let mut state = self.shared.state.lock().unwrap();
262        return state.entries.len()
263    }
264
265    fn shutdown_purge_task(&self) {
266
267        let mut state = self.shared.state.lock().unwrap();
268        state.shutdown = true;
269
270        drop(state);
271        self.shared.background_task.notify_one();
272    }
273}
274
275impl<K, Doc> Shared<K, Doc> 
276where
277    Doc: Clone + Send + Sync + 'static,
278    K:  Clone
279        + PartialOrd
280        + Ord
281        + PartialEq
282        + Eq
283        + Hash
284        + Send
285        + 'static
286{
287    
288    fn purge_expired_keys(&self) -> Option<Instant> {
289        let mut state = self.state.lock().unwrap();
290
291        if state.shutdown {
292            return None;
293        }
294
295        let state = &mut *state;
296        let now = Instant::now();
297
298        while let Some((&(when, id), key)) = state.expirations.iter().next() {
299            if when > now {
300                return Some(when);
301            }
302
303            state.entries.remove(key);
304            state.expirations.remove(&(when, id));
305        }
306
307        None
308    }
309
310    fn is_shutdown(&self) -> bool {
311        self.state.lock().unwrap().shutdown
312    }
313}
314
315impl<K, Doc> State<K, Doc> 
316where
317    Doc: Clone + Send + Sync + 'static,
318    K:  Clone
319        + PartialOrd
320        + Ord
321        + PartialEq
322        + Eq
323        + Hash
324        + Send
325        + 'static
326{
327    fn next_expiration(&self) -> Option<Instant> {
328        self.expirations
329            .keys()
330            .next()
331            .map(|expiration| expiration.0)
332    }
333}
334
335
336async fn purge_expired_tasks<K, Doc>(shared: Arc<Shared<K, Doc>>) 
337where
338    Doc: Clone + Send + Sync + 'static,
339    K:  Clone
340        + PartialOrd
341        + Ord
342        + PartialEq
343        + Eq
344        + Hash
345        + Send
346        + 'static
347{
348    while !shared.is_shutdown() {
349        if let Some(when) = shared.purge_expired_keys() {
350            
351            tokio::select! {
352                _ = time::sleep_until(when) => {}
353                _ = shared.background_task.notified() => {}
354            }
355        } else {
356            
357            shared.background_task.notified().await;
358        }
359    }
360}