mini_redis/
db.rs

1use tokio::sync::{broadcast, Notify};
2use tokio::time::{self, Duration, Instant};
3
4use bytes::Bytes;
5use std::collections::{BTreeMap, HashMap};
6use std::sync::{Arc, Mutex};
7
8/// Server state shared across all connections.
9///
10/// `Db` contains a `HashMap` storing the key/value data and all
11/// `broadcast::Sender` values for active pub/sub channels.
12///
13/// A `Db` instance is a handle to shared state. Cloning `Db` is shallow and
14/// only incurs an atomic ref count increment.
15///
16/// When a `Db` value is created, a background task is spawned. This task is
17/// used to expire values after the requested duration has elapsed. The task
18/// runs until all instances of `Db` are dropped, at which point the task
19/// terminates.
20#[derive(Debug, Clone)]
21pub(crate) struct Db {
22    /// Handle to shared state. The background task will also have an
23    /// `Arc<Shared>`.
24    shared: Arc<Shared>,
25}
26
27#[derive(Debug)]
28struct Shared {
29    /// The shared state is guarded by a mutex. This is a `std::sync::Mutex` and
30    /// not a Tokio mutex. This is because there are no asynchronous operations
31    /// being performed while holding the mutex. Additionally, the critical
32    /// sections are very small.
33    ///
34    /// A Tokio mutex is mostly intended to be used when locks need to be held
35    /// across `.await` yield points. All other cases are **usually** best
36    /// served by a std mutex. If the critical section does not include any
37    /// async operations but is long (CPU intensive or performing blocking
38    /// operations), then the entire operation, including waiting for the mutex,
39    /// is considered a "blocking" operation and `tokio::task::spawn_blocking`
40    /// should be used.
41    state: Mutex<State>,
42
43    /// Notifies the background task handling entry expiration. The background
44    /// task waits on this to be notified, then checks for expired values or the
45    /// shutdown signal.
46    background_task: Notify,
47}
48
49#[derive(Debug)]
50struct State {
51    /// The key-value data. We are not trying to do anything fancy so a
52    /// `std::collections::HashMap` works fine.
53    entries: HashMap<String, Entry>,
54
55    /// The pub/sub key-space. Redis uses a **separate** key space for key-value
56    /// and pub/sub. `mini-redis` handles this by using a separate `HashMap`.
57    pub_sub: HashMap<String, broadcast::Sender<Bytes>>,
58
59    /// Tracks key TTLs.
60    ///
61    /// A `BTreeMap` is used to maintain expirations sorted by when they expire.
62    /// This allows the background task to iterate this map to find the value
63    /// expiring next.
64    ///
65    /// While highly unlikely, it is possible for more than one expiration to be
66    /// created for the same instant. Because of this, the `Instant` is
67    /// insufficient for the key. A unique expiration identifier (`u64`) is used
68    /// to break these ties.
69    expirations: BTreeMap<(Instant, u64), String>,
70
71    /// Identifier to use for the next expiration. Each expiration is associated
72    /// with a unique identifier. See above for why.
73    next_id: u64,
74
75    /// True when the Db instance is shutting down. This happens when all `Db`
76    /// values drop. Setting this to `true` signals to the background task to
77    /// exit.
78    shutdown: bool,
79}
80
81/// Entry in the key-value store
82#[derive(Debug)]
83struct Entry {
84    /// Uniquely identifies this entry.
85    id: u64,
86
87    /// Stored data
88    data: Bytes,
89
90    /// Instant at which the entry expires and should be removed from the
91    /// database.
92    expires_at: Option<Instant>,
93}
94
95impl Db {
96    /// Create a new, empty, `Db` instance. Allocates shared state and spawns a
97    /// background task to manage key expiration.
98    pub(crate) fn new() -> Db {
99        let shared = Arc::new(Shared {
100            state: Mutex::new(State {
101                entries: HashMap::new(),
102                pub_sub: HashMap::new(),
103                expirations: BTreeMap::new(),
104                next_id: 0,
105                shutdown: false,
106            }),
107            background_task: Notify::new(),
108        });
109
110        // Start the background task.
111        tokio::spawn(purge_expired_tasks(shared.clone()));
112
113        Db { shared }
114    }
115
116    /// Get the value associated with a key.
117    ///
118    /// Returns `None` if there is no value associated with the key. This may be
119    /// due to never having assigned a value to the key or a previously assigned
120    /// value expired.
121    pub(crate) fn get(&self, key: &str) -> Option<Bytes> {
122        // Acquire the lock, get the entry and clone the value.
123        //
124        // Because data is stored using `Bytes`, a clone here is a shallow
125        // clone. Data is not copied.
126        let state = self.shared.state.lock().unwrap();
127        state.entries.get(key).map(|entry| entry.data.clone())
128    }
129
130    /// Set the value associated with a key along with an optional expiration
131    /// Duration.
132    ///
133    /// If a value is already associated with the key, it is removed.
134    pub(crate) fn set(&self, key: String, value: Bytes, expire: Option<Duration>) {
135        let mut state = self.shared.state.lock().unwrap();
136
137        // Get and increment the next insertion ID. Guarded by the lock, this
138        // ensures a unique identifier is associated with each `set` operation.
139        let id = state.next_id;
140        state.next_id += 1;
141
142        // If this `set` becomes the key that expires **next**, the background
143        // task needs to be notified so it can update its state.
144        //
145        // Whether or not the task needs to be notified is computed during the
146        // `set` routine.
147        let mut notify = false;
148
149        let expires_at = expire.map(|duration| {
150            // `Instant` at which the key expires.
151            let when = Instant::now() + duration;
152
153            // Only notify the worker task if the newly inserted expiration is the
154            // **next** key to evict. In this case, the worker needs to be woken up
155            // to update its state.
156            notify = state
157                .next_expiration()
158                .map(|expiration| expiration > when)
159                .unwrap_or(true);
160
161            // Track the expiration.
162            state.expirations.insert((when, id), key.clone());
163            when
164        });
165
166        // Insert the entry into the `HashMap`.
167        let prev = state.entries.insert(
168            key,
169            Entry {
170                id,
171                data: value,
172                expires_at,
173            },
174        );
175
176        // If there was a value previously associated with the key **and** it
177        // had an expiration time. The associated entry in the `expirations` map
178        // must also be removed. This avoids leaking data.
179        if let Some(prev) = prev {
180            if let Some(when) = prev.expires_at {
181                // clear expiration
182                state.expirations.remove(&(when, prev.id));
183            }
184        }
185
186        // Release the mutex before notifying the background task. This helps
187        // reduce contention by avoiding the background task waking up only to
188        // be unable to acquire the mutex due to this function still holding it.
189        drop(state);
190
191        if notify {
192            // Finally, only notify the background task if it needs to update
193            // its state to reflect a new expiration.
194            self.shared.background_task.notify_one();
195        }
196    }
197
198    /// Returns a `Receiver` for the requested channel.
199    ///
200    /// The returned `Receiver` is used to receive values broadcast by `PUBLISH`
201    /// commands.
202    pub(crate) fn subscribe(&self, key: String) -> broadcast::Receiver<Bytes> {
203        use std::collections::hash_map::Entry;
204
205        // Acquire the mutex
206        let mut state = self.shared.state.lock().unwrap();
207
208        // If there is no entry for the requested channel, then create a new
209        // broadcast channel and associate it with the key. If one already
210        // exists, return an associated receiver.
211        match state.pub_sub.entry(key) {
212            Entry::Occupied(e) => e.get().subscribe(),
213            Entry::Vacant(e) => {
214                // No broadcast channel exists yet, so create one.
215                //
216                // The channel is created with a capacity of `1024` messages. A
217                // message is stored in the channel until **all** subscribers
218                // have seen it. This means that a slow subscriber could result
219                // in messages being held indefinitely.
220                //
221                // When the channel's capacity fills up, publishing will result
222                // in old messages being dropped. This prevents slow consumers
223                // from blocking the entire system.
224                let (tx, rx) = broadcast::channel(1024);
225                e.insert(tx);
226                rx
227            }
228        }
229    }
230
231    /// Publish a message to the channel. Returns the number of subscribers
232    /// listening on the channel.
233    pub(crate) fn publish(&self, key: &str, value: Bytes) -> usize {
234        let state = self.shared.state.lock().unwrap();
235
236        state
237            .pub_sub
238            .get(key)
239            // On a successful message send on the broadcast channel, the number
240            // of subscribers is returned. An error indicates there are no
241            // receivers, in which case, `0` should be returned.
242            .map(|tx| tx.send(value).unwrap_or(0))
243            // If there is no entry for the channel key, then there are no
244            // subscribers. In this case, return `0`.
245            .unwrap_or(0)
246    }
247}
248
249impl Drop for Db {
250    fn drop(&mut self) {
251        // If this is the last active `Db` instance, the background task must be
252        // notified to shut down.
253        //
254        // First, determine if this is the last `Db` instance. This is done by
255        // checking `strong_count`. The count will be 2. One for this `Db`
256        // instance and one for the handle held by the background task.
257        if Arc::strong_count(&self.shared) == 2 {
258            // The background task must be signaled to shutdown. This is done by
259            // setting `State::shutdown` to `true` and signalling the task.
260            let mut state = self.shared.state.lock().unwrap();
261            state.shutdown = true;
262
263            // Drop the lock before signalling the background task. This helps
264            // reduce lock contention by ensuring the background task doesn't
265            // wake up only to be unable to acquire the mutex.
266            drop(state);
267            self.shared.background_task.notify_one();
268        }
269    }
270}
271
272impl Shared {
273    /// Purge all expired keys and return the `Instant` at which the **next**
274    /// key will expire. The background task will sleep until this instant.
275    fn purge_expired_keys(&self) -> Option<Instant> {
276        let mut state = self.state.lock().unwrap();
277
278        if state.shutdown {
279            // The database is shutting down. All handles to the shared state
280            // have dropped. The background task should exit.
281            return None;
282        }
283
284        // This is needed to make the borrow checker happy. In short, `lock()`
285        // returns a `MutexGuard` and not a `&mut State`. The borrow checker is
286        // not able to see "through" the mutex guard and determine that it is
287        // safe to access both `state.expirations` and `state.entries` mutably,
288        // so we get a "real" mutable reference to `State` outside of the loop.
289        let state = &mut *state;
290
291        // Find all keys scheduled to expire **before** now.
292        let now = Instant::now();
293
294        while let Some((&(when, id), key)) = state.expirations.iter().next() {
295            if when > now {
296                // Done purging, `when` is the instant at which the next key
297                // expires. The worker task will wait until this instant.
298                return Some(when);
299            }
300
301            // The key expired, remove it
302            state.entries.remove(key);
303            state.expirations.remove(&(when, id));
304        }
305
306        None
307    }
308
309    /// Returns `true` if the database is shutting down
310    ///
311    /// The `shutdown` flag is set when all `Db` values have dropped, indicating
312    /// that the shared state can no longer be accessed.
313    fn is_shutdown(&self) -> bool {
314        self.state.lock().unwrap().shutdown
315    }
316}
317
318impl State {
319    fn next_expiration(&self) -> Option<Instant> {
320        self.expirations
321            .keys()
322            .next()
323            .map(|expiration| expiration.0)
324    }
325}
326
327/// Routine executed by the background task.
328///
329/// Wait to be notified. On notification, purge any expired keys from the shared
330/// state handle. If `shutdown` is set, terminate the task.
331async fn purge_expired_tasks(shared: Arc<Shared>) {
332    // If the shutdown flag is set, then the task should exit.
333    while !shared.is_shutdown() {
334        // Purge all keys that are expired. The function returns the instant at
335        // which the **next** key will expire. The worker should wait until the
336        // instant has passed then purge again.
337        if let Some(when) = shared.purge_expired_keys() {
338            // Wait until the next key expires **or** until the background task
339            // is notified. If the task is notified, then it must reload its
340            // state as new keys have been set to expire early. This is done by
341            // looping.
342            tokio::select! {
343                _ = time::sleep_until(when) => {}
344                _ = shared.background_task.notified() => {}
345            }
346        } else {
347            // There are no keys expiring in the future. Wait until the task is
348            // notified.
349            shared.background_task.notified().await;
350        }
351    }
352}