kade_proto/pkg/
db.rs

1use tokio::sync::{broadcast, Notify};
2use tokio::time::{self, Duration, Instant};
3
4use bytes::Bytes;
5use serde::{Deserialize, Serialize};
6use std::collections::{BTreeSet, HashMap};
7use std::path::PathBuf;
8use std::sync::{Arc, Mutex};
9use tracing::debug;
10
11#[derive(Debug)]
12pub struct DbDropGuard {
13    db: Db,
14}
15
16#[derive(Debug, Clone)]
17pub struct Db {
18    shared: Arc<Shared>,
19}
20
21#[derive(Debug, Serialize, Deserialize)]
22pub struct SerializableState {
23    entries: HashMap<String, SerializableEntry>,
24    expirations: Vec<(u64, String)>,
25}
26
27#[derive(Debug, Serialize, Deserialize)]
28pub struct SerializableEntry {
29    data: Vec<u8>,
30    expires_at: Option<u64>,
31}
32
33#[derive(Debug)]
34struct Shared {
35    state: Mutex<State>,
36    background_task: Notify,
37}
38
39#[derive(Debug)]
40struct State {
41    entries: HashMap<String, Entry>,
42    pub_sub: HashMap<String, broadcast::Sender<Bytes>>,
43    expirations: BTreeSet<(Instant, String)>,
44    shutdown: bool,
45}
46
47#[derive(Debug)]
48struct Entry {
49    data: Bytes,
50    expires_at: Option<Instant>,
51}
52
53impl DbDropGuard {
54    pub fn new() -> DbDropGuard { DbDropGuard { db: Db::new() } }
55    pub fn db(&self) -> Db { self.db.clone() }
56}
57
58impl Drop for DbDropGuard {
59    fn drop(&mut self) { self.db.shutdown_purge_task(); }
60}
61
62impl Db {
63    pub fn new() -> Db {
64        let shared = Arc::new(Shared {
65            state: Mutex::new(State {
66                entries: HashMap::new(),
67                pub_sub: HashMap::new(),
68                expirations: BTreeSet::new(),
69                shutdown: false,
70            }),
71            background_task: Notify::new(),
72        });
73
74        tokio::spawn(purge_expired_tasks(shared.clone()));
75
76        Db { shared }
77    }
78
79    pub fn get(&self, key: &str) -> Option<Bytes> {
80        let state = self.shared.state.lock().unwrap();
81        state.entries.get(key).map(|entry| entry.data.clone())
82    }
83
84    pub fn set(&self, key: String, value: Bytes, expire: Option<Duration>) {
85        let mut state = self.shared.state.lock().unwrap();
86        let mut notify = false;
87
88        let expires_at = expire.map(|duration| {
89            let when = Instant::now() + duration;
90            notify = state.next_expiration().map(|expiration| expiration > when).unwrap_or(true);
91
92            when
93        });
94
95        let prev = state.entries.insert(key.clone(), Entry { data: value, expires_at });
96
97        if let Some(prev) = prev {
98            if let Some(when) = prev.expires_at {
99                state.expirations.remove(&(when, key.clone()));
100            }
101        }
102
103        if let Some(when) = expires_at {
104            state.expirations.insert((when, key));
105        }
106
107        drop(state);
108
109        if notify {
110            self.shared.background_task.notify_one();
111        }
112    }
113
114    pub fn subscribe(&self, key: String) -> broadcast::Receiver<Bytes> {
115        use std::collections::hash_map::Entry;
116        let mut state = self.shared.state.lock().unwrap();
117
118        match state.pub_sub.entry(key) {
119            Entry::Occupied(e) => e.get().subscribe(),
120            Entry::Vacant(e) => {
121                let (tx, rx) = broadcast::channel(1024);
122                e.insert(tx);
123                rx
124            }
125        }
126    }
127
128    pub fn publish(&self, key: &str, value: Bytes) -> usize {
129        let state = self.shared.state.lock().unwrap();
130        state.pub_sub.get(key).map(|tx| tx.send(value).unwrap_or(0)).unwrap_or(0)
131    }
132
133    pub fn dump(&self) -> SerializableState {
134        let state = self.shared.state.lock().unwrap();
135        let now = Instant::now();
136
137        SerializableState {
138            entries: state
139                .entries
140                .iter()
141                .map(|(k, v)| {
142                    (
143                        k.clone(),
144                        SerializableEntry {
145                            data: v.data.to_vec(),
146                            expires_at: v.expires_at.map(|instant| instant.duration_since(now).as_secs()),
147                        },
148                    )
149                })
150                .collect(),
151            expirations: state.expirations.iter().map(|(instant, key)| (instant.duration_since(now).as_secs(), key.clone())).collect(),
152        }
153    }
154
155    pub fn load(&self, serializable_state: SerializableState) {
156        let mut state = self.shared.state.lock().unwrap();
157        let now = Instant::now();
158
159        state.entries = serializable_state
160            .entries
161            .into_iter()
162            .map(|(k, v)| {
163                (
164                    k.clone(),
165                    Entry {
166                        data: Bytes::from(v.data),
167                        expires_at: v.expires_at.map(|secs| now + Duration::from_secs(secs)),
168                    },
169                )
170            })
171            .collect();
172
173        state.expirations = serializable_state.expirations.into_iter().map(|(secs, key)| (now + Duration::from_secs(secs), key)).collect();
174    }
175
176    pub async fn dump_to(&self, path: &PathBuf) -> crate::Result<()> {
177        let serializable_state = self.dump();
178        let serialized = bincode::serialize(&serializable_state)?;
179        tokio::fs::write(path, serialized).await?;
180        Ok(())
181    }
182
183    pub async fn load_from(&self, path: &PathBuf) -> crate::Result<()> {
184        let serialized = tokio::fs::read(path).await?;
185        let serializable_state: SerializableState = bincode::deserialize(&serialized)?;
186        self.load(serializable_state);
187        Ok(())
188    }
189
190    fn shutdown_purge_task(&self) {
191        let mut state = self.shared.state.lock().unwrap();
192        state.shutdown = true;
193
194        drop(state);
195        self.shared.background_task.notify_one();
196    }
197}
198
199impl Shared {
200    fn purge_expired_keys(&self) -> Option<Instant> {
201        let mut state = self.state.lock().unwrap();
202
203        if state.shutdown {
204            return None;
205        }
206
207        let state = &mut *state;
208        let now = Instant::now();
209
210        while let Some(&(when, ref key)) = state.expirations.iter().next() {
211            if when > now {
212                return Some(when);
213            }
214
215            state.entries.remove(key);
216            state.expirations.remove(&(when, key.clone()));
217        }
218
219        None
220    }
221
222    fn is_shutdown(&self) -> bool { self.state.lock().unwrap().shutdown }
223}
224
225impl State {
226    fn next_expiration(&self) -> Option<Instant> { self.expirations.iter().next().map(|expiration| expiration.0) }
227}
228
229async fn purge_expired_tasks(shared: Arc<Shared>) {
230    while !shared.is_shutdown() {
231        if let Some(when) = shared.purge_expired_keys() {
232            tokio::select! {
233                _ = time::sleep_until(when) => {}
234                _ = shared.background_task.notified() => {}
235            }
236        } else {
237            shared.background_task.notified().await;
238        }
239    }
240
241    debug!("Purge background task shut down")
242}