1use tokio::sync::{broadcast, Notify};
2use tokio::time::{self, Duration, Instant};
3
4use bytes::Bytes;
5use serde::{Deserialize, Serialize};
6use std::collections::{BTreeSet, HashMap};
7use std::path::PathBuf;
8use std::sync::{Arc, Mutex};
9use tracing::debug;
10
11#[derive(Debug)]
12pub struct DbDropGuard {
13 db: Db,
14}
15
16#[derive(Debug, Clone)]
17pub struct Db {
18 shared: Arc<Shared>,
19}
20
21#[derive(Debug, Serialize, Deserialize)]
22pub struct SerializableState {
23 entries: HashMap<String, SerializableEntry>,
24 expirations: Vec<(u64, String)>,
25}
26
27#[derive(Debug, Serialize, Deserialize)]
28pub struct SerializableEntry {
29 data: Vec<u8>,
30 expires_at: Option<u64>,
31}
32
33#[derive(Debug)]
34struct Shared {
35 state: Mutex<State>,
36 background_task: Notify,
37}
38
39#[derive(Debug)]
40struct State {
41 entries: HashMap<String, Entry>,
42 pub_sub: HashMap<String, broadcast::Sender<Bytes>>,
43 expirations: BTreeSet<(Instant, String)>,
44 shutdown: bool,
45}
46
47#[derive(Debug)]
48struct Entry {
49 data: Bytes,
50 expires_at: Option<Instant>,
51}
52
53impl DbDropGuard {
54 pub fn new() -> DbDropGuard { DbDropGuard { db: Db::new() } }
55 pub fn db(&self) -> Db { self.db.clone() }
56}
57
58impl Drop for DbDropGuard {
59 fn drop(&mut self) { self.db.shutdown_purge_task(); }
60}
61
62impl Db {
63 pub fn new() -> Db {
64 let shared = Arc::new(Shared {
65 state: Mutex::new(State {
66 entries: HashMap::new(),
67 pub_sub: HashMap::new(),
68 expirations: BTreeSet::new(),
69 shutdown: false,
70 }),
71 background_task: Notify::new(),
72 });
73
74 tokio::spawn(purge_expired_tasks(shared.clone()));
75
76 Db { shared }
77 }
78
79 pub fn get(&self, key: &str) -> Option<Bytes> {
80 let state = self.shared.state.lock().unwrap();
81 state.entries.get(key).map(|entry| entry.data.clone())
82 }
83
84 pub fn set(&self, key: String, value: Bytes, expire: Option<Duration>) {
85 let mut state = self.shared.state.lock().unwrap();
86 let mut notify = false;
87
88 let expires_at = expire.map(|duration| {
89 let when = Instant::now() + duration;
90 notify = state.next_expiration().map(|expiration| expiration > when).unwrap_or(true);
91
92 when
93 });
94
95 let prev = state.entries.insert(key.clone(), Entry { data: value, expires_at });
96
97 if let Some(prev) = prev {
98 if let Some(when) = prev.expires_at {
99 state.expirations.remove(&(when, key.clone()));
100 }
101 }
102
103 if let Some(when) = expires_at {
104 state.expirations.insert((when, key));
105 }
106
107 drop(state);
108
109 if notify {
110 self.shared.background_task.notify_one();
111 }
112 }
113
114 pub fn subscribe(&self, key: String) -> broadcast::Receiver<Bytes> {
115 use std::collections::hash_map::Entry;
116 let mut state = self.shared.state.lock().unwrap();
117
118 match state.pub_sub.entry(key) {
119 Entry::Occupied(e) => e.get().subscribe(),
120 Entry::Vacant(e) => {
121 let (tx, rx) = broadcast::channel(1024);
122 e.insert(tx);
123 rx
124 }
125 }
126 }
127
128 pub fn publish(&self, key: &str, value: Bytes) -> usize {
129 let state = self.shared.state.lock().unwrap();
130 state.pub_sub.get(key).map(|tx| tx.send(value).unwrap_or(0)).unwrap_or(0)
131 }
132
133 pub fn dump(&self) -> SerializableState {
134 let state = self.shared.state.lock().unwrap();
135 let now = Instant::now();
136
137 SerializableState {
138 entries: state
139 .entries
140 .iter()
141 .map(|(k, v)| {
142 (
143 k.clone(),
144 SerializableEntry {
145 data: v.data.to_vec(),
146 expires_at: v.expires_at.map(|instant| instant.duration_since(now).as_secs()),
147 },
148 )
149 })
150 .collect(),
151 expirations: state.expirations.iter().map(|(instant, key)| (instant.duration_since(now).as_secs(), key.clone())).collect(),
152 }
153 }
154
155 pub fn load(&self, serializable_state: SerializableState) {
156 let mut state = self.shared.state.lock().unwrap();
157 let now = Instant::now();
158
159 state.entries = serializable_state
160 .entries
161 .into_iter()
162 .map(|(k, v)| {
163 (
164 k.clone(),
165 Entry {
166 data: Bytes::from(v.data),
167 expires_at: v.expires_at.map(|secs| now + Duration::from_secs(secs)),
168 },
169 )
170 })
171 .collect();
172
173 state.expirations = serializable_state.expirations.into_iter().map(|(secs, key)| (now + Duration::from_secs(secs), key)).collect();
174 }
175
176 pub async fn dump_to(&self, path: &PathBuf) -> crate::Result<()> {
177 let serializable_state = self.dump();
178 let serialized = bincode::serialize(&serializable_state)?;
179 tokio::fs::write(path, serialized).await?;
180 Ok(())
181 }
182
183 pub async fn load_from(&self, path: &PathBuf) -> crate::Result<()> {
184 let serialized = tokio::fs::read(path).await?;
185 let serializable_state: SerializableState = bincode::deserialize(&serialized)?;
186 self.load(serializable_state);
187 Ok(())
188 }
189
190 fn shutdown_purge_task(&self) {
191 let mut state = self.shared.state.lock().unwrap();
192 state.shutdown = true;
193
194 drop(state);
195 self.shared.background_task.notify_one();
196 }
197}
198
199impl Shared {
200 fn purge_expired_keys(&self) -> Option<Instant> {
201 let mut state = self.state.lock().unwrap();
202
203 if state.shutdown {
204 return None;
205 }
206
207 let state = &mut *state;
208 let now = Instant::now();
209
210 while let Some(&(when, ref key)) = state.expirations.iter().next() {
211 if when > now {
212 return Some(when);
213 }
214
215 state.entries.remove(key);
216 state.expirations.remove(&(when, key.clone()));
217 }
218
219 None
220 }
221
222 fn is_shutdown(&self) -> bool { self.state.lock().unwrap().shutdown }
223}
224
225impl State {
226 fn next_expiration(&self) -> Option<Instant> { self.expirations.iter().next().map(|expiration| expiration.0) }
227}
228
229async fn purge_expired_tasks(shared: Arc<Shared>) {
230 while !shared.is_shutdown() {
231 if let Some(when) = shared.purge_expired_keys() {
232 tokio::select! {
233 _ = time::sleep_until(when) => {}
234 _ = shared.background_task.notified() => {}
235 }
236 } else {
237 shared.background_task.notified().await;
238 }
239 }
240
241 debug!("Purge background task shut down")
242}