mini_redis/db.rs
1use tokio::sync::{broadcast, Notify};
2use tokio::time::{self, Duration, Instant};
3
4use bytes::Bytes;
5use std::collections::{BTreeMap, HashMap};
6use std::sync::{Arc, Mutex};
7
8/// Server state shared across all connections.
9///
10/// `Db` contains a `HashMap` storing the key/value data and all
11/// `broadcast::Sender` values for active pub/sub channels.
12///
13/// A `Db` instance is a handle to shared state. Cloning `Db` is shallow and
14/// only incurs an atomic ref count increment.
15///
16/// When a `Db` value is created, a background task is spawned. This task is
17/// used to expire values after the requested duration has elapsed. The task
18/// runs until all instances of `Db` are dropped, at which point the task
19/// terminates.
20#[derive(Debug, Clone)]
21pub(crate) struct Db {
22 /// Handle to shared state. The background task will also have an
23 /// `Arc<Shared>`.
24 shared: Arc<Shared>,
25}
26
27#[derive(Debug)]
28struct Shared {
29 /// The shared state is guarded by a mutex. This is a `std::sync::Mutex` and
30 /// not a Tokio mutex. This is because there are no asynchronous operations
31 /// being performed while holding the mutex. Additionally, the critical
32 /// sections are very small.
33 ///
34 /// A Tokio mutex is mostly intended to be used when locks need to be held
35 /// across `.await` yield points. All other cases are **usually** best
36 /// served by a std mutex. If the critical section does not include any
37 /// async operations but is long (CPU intensive or performing blocking
38 /// operations), then the entire operation, including waiting for the mutex,
39 /// is considered a "blocking" operation and `tokio::task::spawn_blocking`
40 /// should be used.
41 state: Mutex<State>,
42
43 /// Notifies the background task handling entry expiration. The background
44 /// task waits on this to be notified, then checks for expired values or the
45 /// shutdown signal.
46 background_task: Notify,
47}
48
49#[derive(Debug)]
50struct State {
51 /// The key-value data. We are not trying to do anything fancy so a
52 /// `std::collections::HashMap` works fine.
53 entries: HashMap<String, Entry>,
54
55 /// The pub/sub key-space. Redis uses a **separate** key space for key-value
56 /// and pub/sub. `mini-redis` handles this by using a separate `HashMap`.
57 pub_sub: HashMap<String, broadcast::Sender<Bytes>>,
58
59 /// Tracks key TTLs.
60 ///
61 /// A `BTreeMap` is used to maintain expirations sorted by when they expire.
62 /// This allows the background task to iterate this map to find the value
63 /// expiring next.
64 ///
65 /// While highly unlikely, it is possible for more than one expiration to be
66 /// created for the same instant. Because of this, the `Instant` is
67 /// insufficient for the key. A unique expiration identifier (`u64`) is used
68 /// to break these ties.
69 expirations: BTreeMap<(Instant, u64), String>,
70
71 /// Identifier to use for the next expiration. Each expiration is associated
72 /// with a unique identifier. See above for why.
73 next_id: u64,
74
75 /// True when the Db instance is shutting down. This happens when all `Db`
76 /// values drop. Setting this to `true` signals to the background task to
77 /// exit.
78 shutdown: bool,
79}
80
81/// Entry in the key-value store
82#[derive(Debug)]
83struct Entry {
84 /// Uniquely identifies this entry.
85 id: u64,
86
87 /// Stored data
88 data: Bytes,
89
90 /// Instant at which the entry expires and should be removed from the
91 /// database.
92 expires_at: Option<Instant>,
93}
94
95impl Db {
96 /// Create a new, empty, `Db` instance. Allocates shared state and spawns a
97 /// background task to manage key expiration.
98 pub(crate) fn new() -> Db {
99 let shared = Arc::new(Shared {
100 state: Mutex::new(State {
101 entries: HashMap::new(),
102 pub_sub: HashMap::new(),
103 expirations: BTreeMap::new(),
104 next_id: 0,
105 shutdown: false,
106 }),
107 background_task: Notify::new(),
108 });
109
110 // Start the background task.
111 tokio::spawn(purge_expired_tasks(shared.clone()));
112
113 Db { shared }
114 }
115
116 /// Get the value associated with a key.
117 ///
118 /// Returns `None` if there is no value associated with the key. This may be
119 /// due to never having assigned a value to the key or a previously assigned
120 /// value expired.
121 pub(crate) fn get(&self, key: &str) -> Option<Bytes> {
122 // Acquire the lock, get the entry and clone the value.
123 //
124 // Because data is stored using `Bytes`, a clone here is a shallow
125 // clone. Data is not copied.
126 let state = self.shared.state.lock().unwrap();
127 state.entries.get(key).map(|entry| entry.data.clone())
128 }
129
130 /// Set the value associated with a key along with an optional expiration
131 /// Duration.
132 ///
133 /// If a value is already associated with the key, it is removed.
134 pub(crate) fn set(&self, key: String, value: Bytes, expire: Option<Duration>) {
135 let mut state = self.shared.state.lock().unwrap();
136
137 // Get and increment the next insertion ID. Guarded by the lock, this
138 // ensures a unique identifier is associated with each `set` operation.
139 let id = state.next_id;
140 state.next_id += 1;
141
142 // If this `set` becomes the key that expires **next**, the background
143 // task needs to be notified so it can update its state.
144 //
145 // Whether or not the task needs to be notified is computed during the
146 // `set` routine.
147 let mut notify = false;
148
149 let expires_at = expire.map(|duration| {
150 // `Instant` at which the key expires.
151 let when = Instant::now() + duration;
152
153 // Only notify the worker task if the newly inserted expiration is the
154 // **next** key to evict. In this case, the worker needs to be woken up
155 // to update its state.
156 notify = state
157 .next_expiration()
158 .map(|expiration| expiration > when)
159 .unwrap_or(true);
160
161 // Track the expiration.
162 state.expirations.insert((when, id), key.clone());
163 when
164 });
165
166 // Insert the entry into the `HashMap`.
167 let prev = state.entries.insert(
168 key,
169 Entry {
170 id,
171 data: value,
172 expires_at,
173 },
174 );
175
176 // If there was a value previously associated with the key **and** it
177 // had an expiration time. The associated entry in the `expirations` map
178 // must also be removed. This avoids leaking data.
179 if let Some(prev) = prev {
180 if let Some(when) = prev.expires_at {
181 // clear expiration
182 state.expirations.remove(&(when, prev.id));
183 }
184 }
185
186 // Release the mutex before notifying the background task. This helps
187 // reduce contention by avoiding the background task waking up only to
188 // be unable to acquire the mutex due to this function still holding it.
189 drop(state);
190
191 if notify {
192 // Finally, only notify the background task if it needs to update
193 // its state to reflect a new expiration.
194 self.shared.background_task.notify_one();
195 }
196 }
197
198 /// Returns a `Receiver` for the requested channel.
199 ///
200 /// The returned `Receiver` is used to receive values broadcast by `PUBLISH`
201 /// commands.
202 pub(crate) fn subscribe(&self, key: String) -> broadcast::Receiver<Bytes> {
203 use std::collections::hash_map::Entry;
204
205 // Acquire the mutex
206 let mut state = self.shared.state.lock().unwrap();
207
208 // If there is no entry for the requested channel, then create a new
209 // broadcast channel and associate it with the key. If one already
210 // exists, return an associated receiver.
211 match state.pub_sub.entry(key) {
212 Entry::Occupied(e) => e.get().subscribe(),
213 Entry::Vacant(e) => {
214 // No broadcast channel exists yet, so create one.
215 //
216 // The channel is created with a capacity of `1024` messages. A
217 // message is stored in the channel until **all** subscribers
218 // have seen it. This means that a slow subscriber could result
219 // in messages being held indefinitely.
220 //
221 // When the channel's capacity fills up, publishing will result
222 // in old messages being dropped. This prevents slow consumers
223 // from blocking the entire system.
224 let (tx, rx) = broadcast::channel(1024);
225 e.insert(tx);
226 rx
227 }
228 }
229 }
230
231 /// Publish a message to the channel. Returns the number of subscribers
232 /// listening on the channel.
233 pub(crate) fn publish(&self, key: &str, value: Bytes) -> usize {
234 let state = self.shared.state.lock().unwrap();
235
236 state
237 .pub_sub
238 .get(key)
239 // On a successful message send on the broadcast channel, the number
240 // of subscribers is returned. An error indicates there are no
241 // receivers, in which case, `0` should be returned.
242 .map(|tx| tx.send(value).unwrap_or(0))
243 // If there is no entry for the channel key, then there are no
244 // subscribers. In this case, return `0`.
245 .unwrap_or(0)
246 }
247}
248
249impl Drop for Db {
250 fn drop(&mut self) {
251 // If this is the last active `Db` instance, the background task must be
252 // notified to shut down.
253 //
254 // First, determine if this is the last `Db` instance. This is done by
255 // checking `strong_count`. The count will be 2. One for this `Db`
256 // instance and one for the handle held by the background task.
257 if Arc::strong_count(&self.shared) == 2 {
258 // The background task must be signaled to shutdown. This is done by
259 // setting `State::shutdown` to `true` and signalling the task.
260 let mut state = self.shared.state.lock().unwrap();
261 state.shutdown = true;
262
263 // Drop the lock before signalling the background task. This helps
264 // reduce lock contention by ensuring the background task doesn't
265 // wake up only to be unable to acquire the mutex.
266 drop(state);
267 self.shared.background_task.notify_one();
268 }
269 }
270}
271
272impl Shared {
273 /// Purge all expired keys and return the `Instant` at which the **next**
274 /// key will expire. The background task will sleep until this instant.
275 fn purge_expired_keys(&self) -> Option<Instant> {
276 let mut state = self.state.lock().unwrap();
277
278 if state.shutdown {
279 // The database is shutting down. All handles to the shared state
280 // have dropped. The background task should exit.
281 return None;
282 }
283
284 // This is needed to make the borrow checker happy. In short, `lock()`
285 // returns a `MutexGuard` and not a `&mut State`. The borrow checker is
286 // not able to see "through" the mutex guard and determine that it is
287 // safe to access both `state.expirations` and `state.entries` mutably,
288 // so we get a "real" mutable reference to `State` outside of the loop.
289 let state = &mut *state;
290
291 // Find all keys scheduled to expire **before** now.
292 let now = Instant::now();
293
294 while let Some((&(when, id), key)) = state.expirations.iter().next() {
295 if when > now {
296 // Done purging, `when` is the instant at which the next key
297 // expires. The worker task will wait until this instant.
298 return Some(when);
299 }
300
301 // The key expired, remove it
302 state.entries.remove(key);
303 state.expirations.remove(&(when, id));
304 }
305
306 None
307 }
308
309 /// Returns `true` if the database is shutting down
310 ///
311 /// The `shutdown` flag is set when all `Db` values have dropped, indicating
312 /// that the shared state can no longer be accessed.
313 fn is_shutdown(&self) -> bool {
314 self.state.lock().unwrap().shutdown
315 }
316}
317
318impl State {
319 fn next_expiration(&self) -> Option<Instant> {
320 self.expirations
321 .keys()
322 .next()
323 .map(|expiration| expiration.0)
324 }
325}
326
327/// Routine executed by the background task.
328///
329/// Wait to be notified. On notification, purge any expired keys from the shared
330/// state handle. If `shutdown` is set, terminate the task.
331async fn purge_expired_tasks(shared: Arc<Shared>) {
332 // If the shutdown flag is set, then the task should exit.
333 while !shared.is_shutdown() {
334 // Purge all keys that are expired. The function returns the instant at
335 // which the **next** key will expire. The worker should wait until the
336 // instant has passed then purge again.
337 if let Some(when) = shared.purge_expired_keys() {
338 // Wait until the next key expires **or** until the background task
339 // is notified. If the task is notified, then it must reload its
340 // state as new keys have been set to expire early. This is done by
341 // looping.
342 tokio::select! {
343 _ = time::sleep_until(when) => {}
344 _ = shared.background_task.notified() => {}
345 }
346 } else {
347 // There are no keys expiring in the future. Wait until the task is
348 // notified.
349 shared.background_task.notified().await;
350 }
351 }
352}