epoch_db/db/
mod.rs

1//! The `db` module contains the core logic for the TransientDB database.
2//! It includes the `DB` struct and its implementation, which provides the
3//! primary API for interacting with the database.
4
5pub mod errors;
6
7use crate::{DB, Metadata, metrics::Metrics};
8use chrono::Local;
9use errors::TransientError;
10use sled::{
11    Config,
12    transaction::{ConflictableTransactionError, TransactionError, Transactional},
13};
14use std::{
15    error::Error,
16    fs::File,
17    io::{ErrorKind, Read, Write},
18    path::Path,
19    str::from_utf8,
20    sync::{Arc, atomic::AtomicBool},
21    thread::{self, JoinHandle},
22    time::{Duration, SystemTime, UNIX_EPOCH},
23};
24use zip::{ZipArchive, ZipWriter, write::SimpleFileOptions};
25
26impl DB {
27    /// Creates a new `DB` instance or opens an existing one at the specified path.
28    ///
29    /// This function initializes the underlying `sled` database, opens the required
30    /// data trees (`data_tree`, `meta_tree`, `ttl_tree`), and spawns a background
31    /// thread to handle TTL expirations.
32    ///
33    /// # Errors
34    ///
35    /// Returns a `sled::Error` if the database cannot be opened at the given path.
36    pub fn new(path: &Path) -> Result<DB, Box<dyn Error>> {
37        let db = Config::new()
38            .path(path)
39            .cache_capacity(512 * 1024 * 1024)
40            .open()?;
41
42        let data_tree = Arc::new(db.open_tree("data_tree")?);
43        let meta_tree = Arc::new(db.open_tree("freq_tree")?);
44        let ttl_tree = Arc::new(db.open_tree("ttl_tree")?);
45
46        let ttl_tree_clone = Arc::clone(&ttl_tree);
47        let meta_tree_clone = Arc::clone(&meta_tree);
48        let data_tree_clone = Arc::clone(&data_tree);
49
50        let shutdown: Arc<AtomicBool> = Arc::new(AtomicBool::new(false));
51        let shutdown_clone_ttl_thread = Arc::clone(&shutdown);
52        let shutdown_clone_size_thread = Arc::clone(&shutdown);
53
54        // Convert to pathbuf to gain ownership
55        let path_buf = path.to_path_buf();
56
57        // TODO: Later have a clean up thread that checks if the following thread is fine and spawn
58        // it back and join the thread lol
59
60        let ttl_thread: JoinHandle<Result<(), TransientError>> = thread::spawn(move || {
61            loop {
62                thread::sleep(Duration::new(0, 100000000));
63
64                if shutdown_clone_ttl_thread.load(std::sync::atomic::Ordering::SeqCst) {
65                    break;
66                }
67
68                let keys = ttl_tree_clone.iter();
69
70                for i in keys {
71                    let full_key = i.map_err(|e| TransientError::SledError { error: e })?;
72
73                    // NOTE: The reason time is 14 u8s long is because it is being stored like
74                    // this ([time,key], key) not ((time,key), key)
75                    let key = full_key.0;
76                    let key_byte = full_key.1;
77
78                    if key.len() < 8 {
79                        Err(TransientError::ParsingToU64ByteFailed)?
80                    }
81
82                    let time_byte: [u8; 8] = (&key[..8])
83                        .try_into()
84                        .map_err(|_| TransientError::ParsingToByteError)?;
85
86                    let time = u64::from_be_bytes(time_byte);
87                    let curr_time = SystemTime::now()
88                        .duration_since(UNIX_EPOCH)
89                        .expect("Cant get SystemTime")
90                        .as_secs();
91
92                    if curr_time >= time {
93                        let l: Result<(), TransactionError<()>> =
94                            (&*data_tree_clone, &*meta_tree_clone, &*ttl_tree_clone).transaction(
95                                |(data, freq, ttl_tree_clone)| {
96                                    let byte = &key_byte;
97                                    data.remove(byte)?;
98                                    freq.remove(byte)?;
99
100                                    let _ = ttl_tree_clone.remove([&time_byte, &byte[..]].concat());
101
102                                    // Prometheus Metrics
103                                    Metrics::dec_keys_total("data");
104                                    Metrics::dec_keys_total("meta");
105                                    Metrics::dec_keys_total("ttl");
106                                    Metrics::increment_ttl_expired_keys();
107
108                                    Ok(())
109                                },
110                            );
111                        l.map_err(|_| TransientError::SledTransactionError)?;
112                    } else {
113                        break;
114                    }
115                }
116            }
117            Ok(())
118        });
119
120        let size_thread: JoinHandle<Result<(), TransientError>> = thread::spawn(move || {
121            loop {
122                thread::sleep(Duration::new(0, 100000000));
123
124                if shutdown_clone_size_thread.load(std::sync::atomic::Ordering::SeqCst) {
125                    break;
126                }
127
128                let metadata = path_buf
129                    .metadata()
130                    .map_err(|_| TransientError::DBMetadataNotFound)?;
131                Metrics::set_disk_size((metadata.len() as f64) / 1024.0 / 1024.0);
132            }
133            Ok(())
134        });
135
136        Ok(DB {
137            data_tree,
138            meta_tree,
139            ttl_tree,
140            ttl_thread: Some(ttl_thread),
141            size_thread: Some(size_thread),
142            shutdown,
143            path: path.to_path_buf(),
144        })
145    }
146
147    /// Sets a key-value pair with an optional Time-To-Live (TTL).
148    ///
149    /// If the key already exists, its value and TTL will be updated.
150    /// If `ttl` is `None`, the key will be persistent.
151    ///
152    /// # Errors
153    ///
154    /// This function can return an error if there's an issue with the underlying
155    pub fn set(&self, key: &str, val: &str, ttl: Option<Duration>) -> Result<(), Box<dyn Error>> {
156        let data_tree = &self.data_tree;
157        let freq_tree = &self.meta_tree;
158        let ttl_tree = &self.ttl_tree;
159        let byte = key.as_bytes();
160        let ttl_sec = match ttl {
161            Some(t) => {
162                let systime = SystemTime::now()
163                    .duration_since(UNIX_EPOCH)
164                    .expect("Cant get SystemTime");
165                Some((t + systime).as_secs())
166            }
167            None => None,
168        };
169
170        let l: Result<(), TransactionError<()>> = (&**data_tree, &**freq_tree, &**ttl_tree)
171            .transaction(|(data, freq, ttl_tree)| {
172                match freq.get(byte)? {
173                    Some(m) => {
174                        let mut meta = Metadata::from_u8(&m)
175                            .map_err(|_| ConflictableTransactionError::Abort(()))?;
176                        if let Some(t) = meta.ttl {
177                            let _ = ttl_tree.remove([&t.to_be_bytes()[..], byte].concat());
178                        }
179                        meta.ttl = ttl_sec;
180                        freq.insert(
181                            byte,
182                            meta.to_u8()
183                                .map_err(|_| ConflictableTransactionError::Abort(()))?,
184                        )?;
185                    }
186                    None => {
187                        freq.insert(
188                            byte,
189                            Metadata::new(ttl_sec)
190                                .to_u8()
191                                .map_err(|_| ConflictableTransactionError::Abort(()))?,
192                        )?;
193                    }
194                }
195
196                data.insert(byte, val.as_bytes())?;
197
198                if let Some(d) = ttl_sec {
199                    ttl_tree.insert([&d.to_be_bytes()[..], byte].concat(), byte)?;
200                    Metrics::inc_keys_total("ttl");
201                };
202
203                Ok(())
204            });
205        l.map_err(|_| TransientError::SledTransactionError)?;
206
207        // Prometheus metrics
208        Metrics::increment_operations("set");
209        Metrics::inc_keys_total("data");
210        Metrics::inc_keys_total("meta");
211
212        Ok(())
213    }
214
215    /// Retrieves the value for a given key.
216    ///
217    /// # Errors
218    ///
219    /// Returns an error if the value cannot be retrieved from the database or if
220    /// the value is not valid UTF-8.
221    pub fn get(&self, key: &str) -> Result<Option<String>, Box<dyn Error>> {
222        let data_tree = &self.data_tree;
223        let byte = key.as_bytes();
224        let val = data_tree.get(byte)?;
225
226        Metrics::increment_operations("get");
227
228        match val {
229            Some(val) => Ok(Some(from_utf8(&val)?.to_string())),
230            None => Ok(None),
231        }
232    }
233
234    /// Atomically increments the frequency counter for a given key.
235    ///
236    /// # Errors
237    ///
238    /// This function can return an error if the key does not exist or if there
239    /// is an issue with the compare-and-swap operation.
240    pub fn increment_frequency(&self, key: &str) -> Result<(), Box<dyn Error>> {
241        let freq_tree = &self.meta_tree;
242        let byte = &key.as_bytes();
243
244        loop {
245            let metadata = freq_tree
246                .get(byte)?
247                .ok_or(TransientError::IncretmentError)?;
248            let meta = Metadata::from_u8(&metadata)?;
249            let s = freq_tree.compare_and_swap(
250                byte,
251                Some(metadata),
252                Some(meta.freq_incretement().to_u8()?),
253            );
254            if let Ok(ss) = s {
255                if ss.is_ok() {
256                    break;
257                }
258            }
259        }
260        Metrics::increment_operations("increment_frequency");
261
262        Ok(())
263    }
264
265    /// Removes a key-value pair and its associated metadata from the database.
266    ///
267    /// # Errors
268    ///
269    /// Can return an error if the transaction to remove the data fails.
270    pub fn remove(&self, key: &str) -> Result<(), Box<dyn Error>> {
271        let data_tree = &self.data_tree;
272        let freq_tree = &self.meta_tree;
273        let ttl_tree = &self.ttl_tree;
274        let byte = &key.as_bytes();
275        let l: Result<(), TransactionError<()>> = (&**data_tree, &**freq_tree, &**ttl_tree)
276            .transaction(|(data, freq, ttl_tree)| {
277                data.remove(*byte)?;
278                let meta = freq
279                    .get(byte)?
280                    .ok_or(ConflictableTransactionError::Abort(()))?;
281                let time = Metadata::from_u8(&meta)
282                    .map_err(|_| ConflictableTransactionError::Abort(()))?
283                    .ttl;
284                freq.remove(*byte)?;
285
286                Metrics::dec_keys_total("data");
287                Metrics::dec_keys_total("meta");
288
289                if let Some(t) = time {
290                    Metrics::dec_keys_total("ttl");
291
292                    let _ = ttl_tree.remove([&t.to_be_bytes()[..], &byte[..]].concat());
293                }
294
295                Ok(())
296            });
297        l.map_err(|_| TransientError::SledTransactionError)?;
298
299        Metrics::increment_operations("rm");
300
301        Ok(())
302    }
303
304    /// Retrieves the metadata for a given key.
305    ///
306    /// # Errors
307    ///
308    /// Returns an error if the metadata cannot be retrieved or deserialized.
309    pub fn get_metadata(&self, key: &str) -> Result<Option<Metadata>, Box<dyn Error>> {
310        let freq_tree = &self.meta_tree;
311        let byte = key.as_bytes();
312        let meta = freq_tree.get(byte)?;
313        match meta {
314            Some(val) => Ok(Some(Metadata::from_u8(&val)?)),
315            None => Ok(None),
316        }
317    }
318
319    pub fn flush(&self) -> Result<(), Box<dyn Error>> {
320        self.data_tree.flush()?;
321        self.meta_tree.flush()?;
322        self.ttl_tree.flush()?;
323
324        Ok(())
325    }
326
327    pub fn backup_to(&self, path: &Path) -> Result<(), Box<dyn Error>> {
328        self.flush()?;
329
330        if !path.is_dir() {
331            Err(TransientError::FolderNotFound {
332                path: path.to_path_buf(),
333            })?;
334        }
335
336        let options =
337            SimpleFileOptions::default().compression_method(zip::CompressionMethod::Bzip2);
338
339        let backup_name = format!("backup-{}.zip", Local::now().format("%Y-%m-%d_%H-%M-%S"));
340
341        let zip_file = File::create(path.join(&backup_name))?;
342
343        let mut zipw = ZipWriter::new(zip_file);
344
345        zipw.start_file("data.epoch", options)?;
346        for i in self.data_tree.iter() {
347            let iu = i?;
348
349            let key = &iu.0;
350            let value = &iu.1;
351            let meta = self
352                .meta_tree
353                .get(key)?
354                .ok_or(TransientError::MetadataNotFound)?;
355
356            // NOTE: A usize is diffrent on diffrent machines
357            // and a usize will never exceed a u64 in lenght lol
358            let kl: u64 = key.len().try_into()?;
359            let vl: u64 = value.len().try_into()?;
360            let ml: u64 = meta.len().try_into()?;
361
362            zipw.write_all(&kl.to_be_bytes())?;
363            zipw.write_all(key)?;
364            zipw.write_all(&vl.to_be_bytes())?;
365            zipw.write_all(value)?;
366            zipw.write_all(&ml.to_be_bytes())?;
367            zipw.write_all(&meta)?;
368        }
369
370        zipw.finish()?;
371
372        let zip_file = File::open(path.join(backup_name))?;
373        let size = zip_file.metadata()?.len();
374        Metrics::set_backup_size((size as f64) / 1024.0 / 1024.0);
375
376        Ok(())
377    }
378
379    // WARN: Add a transactional batching algorithm to ensure safety incase of a power outage
380    pub fn load_from(path: &Path, db_path: &Path) -> Result<DB, Box<dyn Error>> {
381        if !path.is_file() {
382            Err(TransientError::FolderNotFound {
383                path: path.to_path_buf(),
384            })?;
385        }
386
387        let db = DB::new(db_path)?;
388
389        let file = File::open(path)?;
390
391        let mut archive = ZipArchive::new(file)?;
392
393        let mut data = archive.by_name("data.epoch")?;
394
395        loop {
396            let mut len: [u8; 8] = [0u8; 8];
397            if let Err(e) = data.read_exact(&mut len) {
398                if let ErrorKind::UnexpectedEof = e.kind() {
399                    break;
400                }
401            }
402
403            let mut key = vec![0; u64::from_be_bytes(len).try_into()?];
404            data.read_exact(&mut key)?;
405
406            data.read_exact(&mut len)?;
407            let mut val = vec![0; u64::from_be_bytes(len).try_into()?];
408            data.read_exact(&mut val)?;
409
410            data.read_exact(&mut len)?;
411            let mut meta_byte = vec![0; u64::from_be_bytes(len).try_into()?];
412            data.read_exact(&mut meta_byte)?;
413
414            let meta = Metadata::from_u8(&meta_byte)?;
415
416            db.meta_tree.insert(&key, meta.to_u8()?)?;
417
418            db.data_tree.insert(&key, val)?;
419
420            if let Some(d) = meta.ttl {
421                db.ttl_tree
422                    .insert([&d.to_be_bytes()[..], &key].concat(), key)?;
423            };
424        }
425
426        Ok(db)
427    }
428
429    pub fn iter(&mut self) -> DataIter {
430        DataIter {
431            data: (self.data_tree.iter(), self.meta_tree.clone()),
432        }
433    }
434}
435
436impl Drop for DB {
437    /// Gracefully shuts down the TTL background thread when the `DB` instance
438    /// goes out of scope.
439    fn drop(&mut self) {
440        self.shutdown
441            .store(true, std::sync::atomic::Ordering::SeqCst);
442
443        let _ = self
444            .ttl_thread
445            .take()
446            .expect("Fail to take ownership of ttl_thread")
447            .join()
448            .expect("Joining failed");
449
450        let _ = self
451            .size_thread
452            .take()
453            .expect("Fail to take ownership of ttl_thread")
454            .join()
455            .expect("Joining failed");
456    }
457}
458
459pub struct DataIter {
460    pub data: (sled::Iter, Arc<sled::Tree>),
461}
462
463impl Iterator for DataIter {
464    type Item = Result<(String, String, Metadata), Box<dyn Error>>;
465
466    fn next(&mut self) -> Option<Self::Item> {
467        let data_iter = &mut self.data.0;
468
469        let data = match data_iter.next()? {
470            Ok(a) => a,
471            Err(e) => {
472                return Some(Err(Box::new(e)));
473            }
474        };
475
476        let (kb, vb) = data;
477
478        let meta_tree = &mut self.data.1;
479
480        let mb = match meta_tree.get(&kb) {
481            Ok(a) => a,
482            Err(e) => {
483                return Some(Err(Box::new(e)));
484            }
485        }?;
486
487        let key = match from_utf8(&kb) {
488            Ok(a) => a,
489            Err(e) => {
490                return Some(Err(Box::new(e)));
491            }
492        }
493        .to_string();
494
495        let value = match from_utf8(&vb) {
496            Ok(a) => a,
497            Err(e) => {
498                return Some(Err(Box::new(e)));
499            }
500        }
501        .to_string();
502
503        let meta = match Metadata::from_u8(&mb) {
504            Ok(a) => a,
505            Err(e) => {
506                return Some(Err(Box::new(e)));
507            }
508        };
509
510        Some(Ok((key, value, meta)))
511    }
512}