epoch_db/db/
mod.rs

1//! The `db` module contains the core logic for the TransientDB database.
2//! It includes the `DB` struct and its implementation, which provides the
3//! primary API for interacting with the database.
4
5pub mod errors;
6
7use errors::TransientError;
8use sled::{
9    Config,
10    transaction::{ConflictableTransactionError, TransactionError, Transactional},
11};
12use std::{
13    error::Error,
14    path::Path,
15    str::from_utf8,
16    sync::{Arc, atomic::AtomicBool},
17    thread::{self, JoinHandle},
18    time::{Duration, SystemTime, UNIX_EPOCH},
19};
20
21use crate::{DB, Metadata};
22
23impl DB {
24    /// Creates a new `DB` instance or opens an existing one at the specified path.
25    ///
26    /// This function initializes the underlying `sled` database, opens the required
27    /// data trees (`data_tree`, `meta_tree`, `ttl_tree`), and spawns a background
28    /// thread to handle TTL expirations.
29    ///
30    /// # Errors
31    ///
32    /// Returns a `sled::Error` if the database cannot be opened at the given path.
33    pub fn new(path: &Path) -> Result<DB, sled::Error> {
34        let db = Config::new()
35            .path(path)
36            .cache_capacity(512 * 1024 * 1024)
37            .open()?;
38
39        let data_tree = Arc::new(db.open_tree("data_tree")?);
40        let meta_tree = Arc::new(db.open_tree("freq_tree")?);
41        let ttl_tree = Arc::new(db.open_tree("ttl_tree")?);
42
43        let ttl_tree_clone = Arc::clone(&ttl_tree);
44        let meta_tree_clone = Arc::clone(&meta_tree);
45        let data_tree_clone = Arc::clone(&data_tree);
46
47        let shutdown: Arc<AtomicBool> = Arc::new(AtomicBool::new(false));
48        let shutdown_clone = Arc::clone(&shutdown);
49
50        // TODO: Later have a clean up thread that checks if the following thread is fine and spawn
51        // it back and join the thread lol
52
53        let thread: JoinHandle<Result<(), TransientError>> = thread::spawn(move || {
54            loop {
55                thread::sleep(Duration::new(0, 100000000));
56
57                if shutdown_clone.load(std::sync::atomic::Ordering::SeqCst) {
58                    break;
59                }
60
61                let keys = ttl_tree_clone.iter();
62
63                for i in keys {
64                    let full_key = i.map_err(|e| TransientError::SledError { error: e })?;
65
66                    // NOTE: The reason time is 14 u8s long is because it is being stored like
67                    // this ([time,key], key) not ((time,key), key)
68                    let key = full_key.0;
69                    let key_byte = full_key.1;
70
71                    if key.len() < 8 {
72                        Err(TransientError::ParsingToU64ByteFailed)?
73                    }
74
75                    let time_byte: [u8; 8] = (&key[..8])
76                        .try_into()
77                        .map_err(|_| TransientError::ParsingToByteError)?;
78
79                    let time = u64::from_be_bytes(time_byte);
80                    let curr_time = SystemTime::now()
81                        .duration_since(UNIX_EPOCH)
82                        .expect("Cant get SystemTime")
83                        .as_secs();
84
85                    if curr_time >= time {
86                        let l: Result<(), TransactionError<()>> =
87                            (&*data_tree_clone, &*meta_tree_clone, &*ttl_tree_clone).transaction(
88                                |(data, freq, ttl_tree_clone)| {
89                                    let byte = &key_byte;
90                                    data.remove(byte)?;
91                                    freq.remove(byte)?;
92
93                                    let _ = ttl_tree_clone.remove([&time_byte, &byte[..]].concat());
94
95                                    Ok(())
96                                },
97                            );
98                        l.map_err(|_| TransientError::SledTransactionError)?;
99                    } else {
100                        continue;
101                    }
102                }
103            }
104            Ok(())
105        });
106        Ok(DB {
107            data_tree,
108            meta_tree,
109            ttl_tree,
110            ttl_thread: Some(thread),
111            shutdown,
112        })
113    }
114
115    /// Sets a key-value pair with an optional Time-To-Live (TTL).
116    ///
117    /// If the key already exists, its value and TTL will be updated.
118    /// If `ttl` is `None`, the key will be persistent.
119    ///
120    /// # Errors
121    ///
122    /// This function can return an error if there's an issue with the underlying
123    pub fn set(&self, key: &str, val: &str, ttl: Option<Duration>) -> Result<(), Box<dyn Error>> {
124        let data_tree = &self.data_tree;
125        let freq_tree = &self.meta_tree;
126        let ttl_tree = &self.ttl_tree;
127        let byte = key.as_bytes();
128        let ttl_sec = match ttl {
129            Some(t) => {
130                let systime = SystemTime::now()
131                    .duration_since(UNIX_EPOCH)
132                    .expect("Cant get SystemTime");
133                Some((t + systime).as_secs())
134            }
135            None => None,
136        };
137
138        let l: Result<(), TransactionError<()>> = (&**data_tree, &**freq_tree, &**ttl_tree)
139            .transaction(|(data, freq, ttl_tree)| {
140                match freq.get(byte)? {
141                    Some(m) => {
142                        let mut meta = Metadata::from_u8(&m.to_vec())
143                            .map_err(|_| ConflictableTransactionError::Abort(()))?;
144                        if let Some(t) = meta.ttl {
145                            let _ = ttl_tree.remove([&t.to_be_bytes()[..], &byte[..]].concat());
146                        }
147                        meta.ttl = ttl_sec;
148                        freq.insert(
149                            byte,
150                            meta.to_u8()
151                                .map_err(|_| ConflictableTransactionError::Abort(()))?,
152                        )?;
153                    }
154                    None => {
155                        freq.insert(
156                            byte,
157                            Metadata::new(ttl_sec)
158                                .to_u8()
159                                .map_err(|_| ConflictableTransactionError::Abort(()))?,
160                        )?;
161                    }
162                }
163
164                data.insert(byte, val.as_bytes())?;
165
166                match ttl_sec {
167                    Some(d) => {
168                        ttl_tree.insert([&d.to_be_bytes()[..], &byte[..]].concat(), byte)?;
169                    }
170                    None => (),
171                };
172
173                Ok(())
174            });
175        let _ = l.map_err(|_| TransientError::SledTransactionError)?;
176
177        Ok(())
178    }
179
180    /// Retrieves the value for a given key.
181    ///
182    /// # Errors
183    ///
184    /// Returns an error if the value cannot be retrieved from the database or if
185    /// the value is not valid UTF-8.
186    pub fn get(&self, key: &str) -> Result<Option<String>, Box<dyn Error>> {
187        let data_tree = &self.data_tree;
188        let byte = key.as_bytes();
189        let val = data_tree.get(byte)?;
190        match val {
191            Some(val) => Ok(Some(from_utf8(&val.to_vec())?.to_string())),
192            None => Ok(None),
193        }
194    }
195
196    /// Atomically increments the frequency counter for a given key.
197    ///
198    /// # Errors
199    ///
200    /// This function can return an error if the key does not exist or if there
201    /// is an issue with the compare-and-swap operation.
202    pub fn increment_frequency(&self, key: &str) -> Result<(), Box<dyn Error>> {
203        let freq_tree = &self.meta_tree;
204        let byte = &key.as_bytes();
205
206        loop {
207            let metadata = freq_tree
208                .get(byte)?
209                .ok_or(TransientError::IncretmentError)?;
210            let meta = Metadata::from_u8(&metadata.to_vec())?;
211            let s = freq_tree.compare_and_swap(
212                byte,
213                Some(metadata),
214                Some(meta.freq_incretement().to_u8()?),
215            );
216            match s {
217                Ok(ss) => match ss {
218                    Ok(_) => break,
219                    Err(_) => (),
220                },
221                Err(_) => (),
222            }
223        }
224
225        Ok(())
226    }
227
228    /// Removes a key-value pair and its associated metadata from the database.
229    ///
230    /// # Errors
231    ///
232    /// Can return an error if the transaction to remove the data fails.
233    pub fn remove(&self, key: &str) -> Result<(), Box<dyn Error>> {
234        let data_tree = &self.data_tree;
235        let freq_tree = &self.meta_tree;
236        let ttl_tree = &self.ttl_tree;
237        let byte = &key.as_bytes();
238        let l: Result<(), TransactionError<()>> = (&**data_tree, &**freq_tree, &**ttl_tree)
239            .transaction(|(data, freq, ttl_tree)| {
240                data.remove(*byte)?;
241                let meta = freq
242                    .get(byte)?
243                    .ok_or(ConflictableTransactionError::Abort(()))?;
244                let time = Metadata::from_u8(&meta.to_vec())
245                    .map_err(|_| ConflictableTransactionError::Abort(()))?
246                    .ttl;
247                freq.remove(*byte)?;
248
249                match time {
250                    Some(t) => {
251                        let _ = ttl_tree.remove([&t.to_be_bytes()[..], &byte[..]].concat());
252                    }
253                    None => (),
254                }
255
256                Ok(())
257            });
258        l.map_err(|_| TransientError::SledTransactionError)?;
259        Ok(())
260    }
261
262    /// Retrieves the metadata for a given key.
263    ///
264    /// # Errors
265    ///
266    /// Returns an error if the metadata cannot be retrieved or deserialized.
267    pub fn get_metadata(&self, key: &str) -> Result<Option<Metadata>, Box<dyn Error>> {
268        let freq_tree = &self.meta_tree;
269        let byte = key.as_bytes();
270        let meta = freq_tree.get(byte)?;
271        match meta {
272            Some(val) => Ok(Some(Metadata::from_u8(&val.to_vec())?)),
273            None => Ok(None),
274        }
275    }
276}
277
278impl Drop for DB {
279    /// Gracefully shuts down the TTL background thread when the `DB` instance
280    /// goes out of scope.
281    fn drop(&mut self) {
282        self.shutdown
283            .store(true, std::sync::atomic::Ordering::SeqCst);
284
285        let _ = self
286            .ttl_thread
287            .take()
288            .expect("Fail to take ownership of ttl_thread")
289            .join()
290            .expect("Joining failed");
291    }
292}