epoch_db/db/
mod.rs

1//! The `db` module contains the core logic for the TransientDB database.
2//! It includes the `DB` struct and its implementation, which provides the
3//! primary API for interacting with the database.
4
5pub mod errors;
6pub mod iter;
7pub mod transaction;
8
9use std::fs::File;
10use std::io::{
11    ErrorKind,
12    Read,
13    Write
14};
15use std::path::Path;
16use std::str::from_utf8;
17use std::sync::Arc;
18use std::sync::atomic::AtomicBool;
19use std::thread::{
20    self,
21    JoinHandle
22};
23use std::time::{
24    Duration,
25    SystemTime,
26    UNIX_EPOCH
27};
28
29use chrono::Local;
30use errors::TransientError;
31use sled::Config;
32use sled::transaction::{
33    ConflictableTransactionError,
34    TransactionError,
35    Transactional
36};
37use zip::write::SimpleFileOptions;
38use zip::{
39    ZipArchive,
40    ZipWriter
41};
42
43use crate::metrics::Metrics;
44use crate::{
45    DB,
46    Metadata
47};
48
49impl DB {
50    /// Creates a new `DB` instance or opens an existing one at the specified
51    /// path.
52    ///
53    /// This function initializes the underlying `sled` database, opens the
54    /// required data trees (`data_tree`, `meta_tree`, `ttl_tree`), and
55    /// spawns a background thread to handle TTL expirations.
56    ///
57    /// # Errors
58    ///
59    /// Returns a `sled::Error` if the database cannot be opened at the given
60    /// path.
61    pub fn new(path: &Path) -> Result<DB, TransientError> {
62        let db = Config::new()
63            .path(path)
64            .cache_capacity(512 * 1024 * 1024)
65            .open()
66            .map_err(|e| {
67                TransientError::SledError {
68                    error: e
69                }
70            })?;
71
72        let data_tree = Arc::new(db.open_tree("data_tree").map_err(|e| {
73            TransientError::SledError {
74                error: e
75            }
76        })?);
77        let meta_tree = Arc::new(db.open_tree("freq_tree").map_err(|e| {
78            TransientError::SledError {
79                error: e
80            }
81        })?);
82        let ttl_tree = Arc::new(db.open_tree("ttl_tree").map_err(|e| {
83            TransientError::SledError {
84                error: e
85            }
86        })?);
87
88        let ttl_tree_clone = Arc::clone(&ttl_tree);
89        let meta_tree_clone = Arc::clone(&meta_tree);
90        let data_tree_clone = Arc::clone(&data_tree);
91
92        let shutdown: Arc<AtomicBool> = Arc::new(AtomicBool::new(false));
93        let shutdown_clone_ttl_thread = Arc::clone(&shutdown);
94        let shutdown_clone_size_thread = Arc::clone(&shutdown);
95
96        // Convert to pathbuf to gain ownership
97        let path_buf = path.to_path_buf();
98
99        // TODO: Later have a clean up thread that checks if the following thread is
100        // fine and spawn it back and join the thread lol
101
102        let ttl_thread: JoinHandle<Result<(), TransientError>> = thread::spawn(move || {
103            loop {
104                thread::sleep(Duration::new(0, 100000000));
105
106                if shutdown_clone_ttl_thread.load(std::sync::atomic::Ordering::SeqCst) {
107                    break;
108                }
109
110                let keys = ttl_tree_clone.iter();
111
112                for i in keys {
113                    let full_key = i.map_err(|e| {
114                        TransientError::SledError {
115                            error: e
116                        }
117                    })?;
118
119                    // NOTE: The reason time is 14 u8s long is because it is being stored like
120                    // this ([time,key], key) not ((time,key), key)
121                    let key = full_key.0;
122                    let key_byte = full_key.1;
123
124                    if key.len() < 8 {
125                        Err(TransientError::ParsingToU64ByteFailed)?
126                    }
127
128                    let time_byte: [u8; 8] = (&key[..8])
129                        .try_into()
130                        .map_err(|_| TransientError::ParsingToByteError)?;
131
132                    let time = u64::from_be_bytes(time_byte);
133                    let curr_time = SystemTime::now()
134                        .duration_since(UNIX_EPOCH)
135                        .expect("Cant get SystemTime")
136                        .as_secs();
137
138                    if curr_time >= time {
139                        let l: Result<(), TransactionError<()>> =
140                            (&*data_tree_clone, &*meta_tree_clone, &*ttl_tree_clone).transaction(
141                                |(data, freq, ttl_tree_clone)| {
142                                    let byte = &key_byte;
143                                    data.remove(byte)?;
144                                    freq.remove(byte)?;
145
146                                    let _ = ttl_tree_clone.remove([&time_byte, &byte[..]].concat());
147
148                                    // Prometheus Metrics
149                                    Metrics::dec_keys_total("data");
150                                    Metrics::dec_keys_total("meta");
151                                    Metrics::dec_keys_total("ttl");
152                                    Metrics::increment_ttl_expired_keys();
153
154                                    Ok(())
155                                }
156                            );
157                        l.map_err(|_| TransientError::SledTransactionError)?;
158                    } else {
159                        break;
160                    }
161                }
162            }
163            Ok(())
164        });
165
166        let size_thread: JoinHandle<Result<(), TransientError>> = thread::spawn(move || {
167            loop {
168                thread::sleep(Duration::new(0, 100000000));
169
170                if shutdown_clone_size_thread.load(std::sync::atomic::Ordering::SeqCst) {
171                    break;
172                }
173
174                let metadata = path_buf
175                    .metadata()
176                    .map_err(|_| TransientError::DBMetadataNotFound)?;
177                Metrics::set_disk_size((metadata.len() as f64) / 1024.0 / 1024.0);
178            }
179            Ok(())
180        });
181
182        Ok(DB {
183            data_tree,
184            meta_tree,
185            ttl_tree,
186            ttl_thread: Some(ttl_thread),
187            size_thread: Some(size_thread),
188            shutdown,
189            path: path.to_path_buf()
190        })
191    }
192
193    /// Sets a key-value pair with an optional Time-To-Live (TTL).
194    ///
195    /// If the key already exists, its value and TTL will be updated.
196    /// If `ttl` is `None`, the key will be persistent.
197    ///
198    /// # Errors
199    ///
200    /// This function can return an error if there's an issue with the
201    /// underlying
202    pub fn set(&self, key: &str, val: &str, ttl: Option<Duration>) -> Result<(), TransientError> {
203        let data_tree = &self.data_tree;
204        let freq_tree = &self.meta_tree;
205        let ttl_tree = &self.ttl_tree;
206        let byte = key.as_bytes();
207        let ttl_sec = match ttl {
208            Some(t) => {
209                let systime = SystemTime::now()
210                    .duration_since(UNIX_EPOCH)
211                    .expect("Cant get SystemTime");
212                Some((t + systime).as_secs())
213            },
214            None => None
215        };
216
217        let l: Result<(), TransactionError<()>> = (&**data_tree, &**freq_tree, &**ttl_tree)
218            .transaction(|(data, freq, ttl_tree)| {
219                match freq.get(byte)? {
220                    Some(m) => {
221                        let mut meta = Metadata::from_u8(&m)
222                            .map_err(|_| ConflictableTransactionError::Abort(()))?;
223                        if let Some(t) = meta.ttl {
224                            let _ = ttl_tree.remove([&t.to_be_bytes()[..], byte].concat());
225                        }
226                        meta.ttl = ttl_sec;
227                        freq.insert(
228                            byte,
229                            meta.to_u8()
230                                .map_err(|_| ConflictableTransactionError::Abort(()))?
231                        )?;
232                    },
233                    None => {
234                        freq.insert(
235                            byte,
236                            Metadata::new(ttl_sec)
237                                .to_u8()
238                                .map_err(|_| ConflictableTransactionError::Abort(()))?
239                        )?;
240                    }
241                }
242
243                data.insert(byte, val.as_bytes())?;
244
245                if let Some(d) = ttl_sec {
246                    ttl_tree.insert([&d.to_be_bytes()[..], byte].concat(), byte)?;
247                    Metrics::inc_keys_total("ttl");
248                };
249
250                Ok(())
251            });
252        l.map_err(|_| TransientError::SledTransactionError)?;
253
254        // Prometheus metrics
255        Metrics::increment_operations("set");
256        Metrics::inc_keys_total("data");
257        Metrics::inc_keys_total("meta");
258
259        Ok(())
260    }
261
262    /// Retrieves the value for a given key.
263    ///
264    /// # Errors
265    ///
266    /// Returns an error if the value cannot be retrieved from the database or
267    /// if the value is not valid UTF-8.
268    pub fn get(&self, key: &str) -> Result<Option<String>, TransientError> {
269        let data_tree = &self.data_tree;
270        let byte = key.as_bytes();
271        let val = data_tree.get(byte).map_err(|e| {
272            TransientError::SledError {
273                error: e
274            }
275        })?;
276
277        Metrics::increment_operations("get");
278
279        match val {
280            Some(val) => {
281                Ok(Some(
282                    from_utf8(&val)
283                        .map_err(|_| TransientError::ParsingToUTF8Error)?
284                        .to_string()
285                ))
286            },
287            None => Ok(None)
288        }
289    }
290
291    /// Atomically increments the frequency counter for a given key.
292    ///
293    /// # Errors
294    ///
295    /// This function can return an error if the key does not exist or if there
296    /// is an issue with the compare-and-swap operation.
297    pub fn increment_frequency(&self, key: &str) -> Result<(), TransientError> {
298        let freq_tree = &self.meta_tree;
299        let byte = &key.as_bytes();
300
301        loop {
302            let metadata = freq_tree
303                .get(byte)
304                .map_err(|e| {
305                    TransientError::SledError {
306                        error: e
307                    }
308                })?
309                .ok_or(TransientError::IncretmentError)?;
310            let meta =
311                Metadata::from_u8(&metadata).map_err(|_| TransientError::ParsingFromByteError)?;
312            let s = freq_tree.compare_and_swap(
313                byte,
314                Some(metadata),
315                Some(
316                    meta.freq_incretement()
317                        .to_u8()
318                        .map_err(|_| TransientError::ParsingToByteError)?
319                )
320            );
321            if let Ok(ss) = s
322                && ss.is_ok()
323            {
324                break;
325            }
326        }
327        Metrics::increment_operations("increment_frequency");
328
329        Ok(())
330    }
331
332    /// Removes a key-value pair and its associated metadata from the database.
333    ///
334    /// # Errors
335    ///
336    /// Can return an error if the transaction to remove the data fails.
337    pub fn remove(&self, key: &str) -> Result<(), TransientError> {
338        let data_tree = &self.data_tree;
339        let freq_tree = &self.meta_tree;
340        let ttl_tree = &self.ttl_tree;
341        let byte = &key.as_bytes();
342        let l: Result<(), TransactionError<()>> = (&**data_tree, &**freq_tree, &**ttl_tree)
343            .transaction(|(data, freq, ttl_tree)| {
344                data.remove(*byte)?;
345                let meta = freq
346                    .get(byte)?
347                    .ok_or(ConflictableTransactionError::Abort(()))?;
348                let time = Metadata::from_u8(&meta)
349                    .map_err(|_| ConflictableTransactionError::Abort(()))?
350                    .ttl;
351                freq.remove(*byte)?;
352
353                Metrics::dec_keys_total("data");
354                Metrics::dec_keys_total("meta");
355
356                if let Some(t) = time {
357                    Metrics::dec_keys_total("ttl");
358
359                    let _ = ttl_tree.remove([&t.to_be_bytes()[..], &byte[..]].concat());
360                }
361
362                Ok(())
363            });
364        l.map_err(|_| TransientError::SledTransactionError)?;
365
366        Metrics::increment_operations("rm");
367
368        Ok(())
369    }
370
371    /// Retrieves the metadata for a given key.
372    ///
373    /// # Errors
374    ///
375    /// Returns an error if the metadata cannot be retrieved or deserialized.
376    pub fn get_metadata(&self, key: &str) -> Result<Option<Metadata>, TransientError> {
377        let freq_tree = &self.meta_tree;
378        let byte = key.as_bytes();
379        let meta = freq_tree.get(byte).map_err(|e| {
380            TransientError::SledError {
381                error: e
382            }
383        })?;
384        match meta {
385            Some(val) => {
386                Ok(Some(
387                    Metadata::from_u8(&val).map_err(|_| TransientError::ParsingFromByteError)?
388                ))
389            },
390            None => Ok(None)
391        }
392    }
393
394    /// Flushes all the trees in the database.
395    ///
396    /// # Errors
397    ///
398    /// Returns an error if sled fails to flush the trees.
399    pub fn flush(&self) -> Result<(), TransientError> {
400        self.data_tree.flush().map_err(|e| {
401            TransientError::SledError {
402                error: e
403            }
404        })?;
405        self.meta_tree.flush().map_err(|e| {
406            TransientError::SledError {
407                error: e
408            }
409        })?;
410        self.ttl_tree.flush().map_err(|e| {
411            TransientError::SledError {
412                error: e
413            }
414        })?;
415
416        Ok(())
417    }
418
419    /// Backup the database to the corresponding path.
420    ///
421    /// # Errors
422    ///
423    /// This function returns an Error if the following occurs:
424    /// - Any corresponding folder in the path is not found
425    /// - Zip or sled fails because of any reason
426    /// - IOError when the file is being access by the OS for something else
427    /// - Failing to parse any data to a [u8]
428    pub fn backup_to(&self, path: &Path) -> Result<(), TransientError> {
429        self.flush()?;
430
431        if !path.is_dir() {
432            Err(TransientError::FolderNotFound {
433                path: path.to_path_buf()
434            })?;
435        }
436
437        let options =
438            SimpleFileOptions::default().compression_method(zip::CompressionMethod::Bzip2);
439
440        let backup_name = format!("backup-{}.zip", Local::now().format("%Y-%m-%d_%H-%M-%S"));
441
442        let zip_file = File::create(path.join(&backup_name)).map_err(|_| {
443            TransientError::FolderNotFound {
444                path: path.to_path_buf()
445            }
446        })?;
447
448        let mut zipw = ZipWriter::new(zip_file);
449
450        zipw.start_file("data.epoch", options).map_err(|e| {
451            TransientError::ZipError {
452                error: e
453            }
454        })?;
455
456        for i in self.data_tree.iter() {
457            let iu = i.map_err(|e| {
458                TransientError::SledError {
459                    error: e
460                }
461            })?;
462
463            let key = &iu.0;
464            let value = &iu.1;
465            let meta = self
466                .meta_tree
467                .get(key)
468                .map_err(|e| {
469                    TransientError::SledError {
470                        error: e
471                    }
472                })?
473                .ok_or(TransientError::MetadataNotFound)?;
474
475            // NOTE: A usize is diffrent on diffrent machines
476            // and a usize will never exceed a u64 in length on paper lol
477            let kl: u64 = key
478                .len()
479                .try_into()
480                .map_err(|_| TransientError::ParsingToU64ByteFailed)?;
481            let vl: u64 = value
482                .len()
483                .try_into()
484                .map_err(|_| TransientError::ParsingToU64ByteFailed)?;
485            let ml: u64 = meta
486                .len()
487                .try_into()
488                .map_err(|_| TransientError::ParsingToU64ByteFailed)?;
489
490            zipw.write_all(&kl.to_be_bytes()).map_err(|e| {
491                TransientError::IOError {
492                    error: e
493                }
494            })?;
495            zipw.write_all(key).map_err(|e| {
496                TransientError::IOError {
497                    error: e
498                }
499            })?;
500            zipw.write_all(&vl.to_be_bytes()).map_err(|e| {
501                TransientError::IOError {
502                    error: e
503                }
504            })?;
505            zipw.write_all(value).map_err(|e| {
506                TransientError::IOError {
507                    error: e
508                }
509            })?;
510            zipw.write_all(&ml.to_be_bytes()).map_err(|e| {
511                TransientError::IOError {
512                    error: e
513                }
514            })?;
515            zipw.write_all(&meta).map_err(|e| {
516                TransientError::IOError {
517                    error: e
518                }
519            })?;
520        }
521
522        zipw.finish().map_err(|e| {
523            TransientError::ZipError {
524                error: e
525            }
526        })?;
527
528        let zip_file = File::open(path.join(backup_name)).map_err(|_| {
529            TransientError::FolderNotFound {
530                path: path.to_path_buf()
531            }
532        })?;
533        let size = zip_file
534            .metadata()
535            .map_err(|e| {
536                TransientError::IOError {
537                    error: e
538                }
539            })?
540            .len();
541        Metrics::set_backup_size((size as f64) / 1024.0 / 1024.0);
542
543        Ok(())
544    }
545
546    // WARN: Add a transactional batching algorithm to ensure safety incase of a
547    // power outage
548
549    /// This function loads the backup archive from the path given and loads the
550    /// database in the db_path
551    ///
552    /// # Errors
553    ///
554    /// This Function will fail if the following happens:
555    /// - Any corresponding folder in the path is not found
556    /// - Zip or sled fails because of any reason
557    /// - IOError when the file is being access by the OS for something else
558    /// - It fails to parse the .epoch file which may occur due to data
559    ///   corruption or wrong formatting.
560    pub fn load_from(path: &Path, db_path: &Path) -> Result<DB, TransientError> {
561        if !path.is_file() {
562            Err(TransientError::FolderNotFound {
563                path: path.to_path_buf()
564            })?;
565        }
566
567        let db = DB::new(db_path)?;
568
569        let file = File::open(path).map_err(|_| {
570            TransientError::FolderNotFound {
571                path: path.to_path_buf()
572            }
573        })?;
574
575        let mut archive = ZipArchive::new(file).map_err(|e| {
576            TransientError::ZipError {
577                error: e
578            }
579        })?;
580
581        // The error is not only is the archive is not found but also a few other
582        // errors, so it is prefered to not laced it with  a full on
583        // TransientError but a wrapper
584        let mut data = archive.by_name("data.epoch").map_err(|e| {
585            TransientError::ZipError {
586                error: e
587            }
588        })?;
589        loop {
590            let mut len: [u8; 8] = [0u8; 8];
591            if let Err(e) = data.read_exact(&mut len)
592                && let ErrorKind::UnexpectedEof = e.kind()
593            {
594                break;
595            }
596
597            let mut key = vec![
598                0;
599                u64::from_be_bytes(len)
600                    .try_into()
601                    .map_err(|_| TransientError::ParsingToU64ByteFailed)?
602            ];
603
604            // Since it contains both error, I figure that It would be better If I map it to
605            // a Transient Wrap of std::io::Error
606            data.read_exact(&mut key).map_err(|e| {
607                TransientError::IOError {
608                    error: e
609                }
610            })?;
611
612            data.read_exact(&mut len).map_err(|e| {
613                TransientError::IOError {
614                    error: e
615                }
616            })?;
617
618            let mut val = vec![
619                0;
620                u64::from_be_bytes(len)
621                    .try_into()
622                    .map_err(|_| TransientError::ParsingToU64ByteFailed)?
623            ];
624            data.read_exact(&mut val).map_err(|e| {
625                TransientError::IOError {
626                    error: e
627                }
628            })?;
629
630            data.read_exact(&mut len).map_err(|e| {
631                TransientError::IOError {
632                    error: e
633                }
634            })?;
635
636            let mut meta_byte = vec![
637                0;
638                u64::from_be_bytes(len)
639                    .try_into()
640                    .map_err(|_| TransientError::ParsingToU64ByteFailed)?
641            ];
642            data.read_exact(&mut meta_byte).map_err(|e| {
643                TransientError::IOError {
644                    error: e
645                }
646            })?;
647
648            let meta =
649                Metadata::from_u8(&meta_byte).map_err(|_| TransientError::ParsingFromByteError)?;
650
651            db.meta_tree
652                .insert(
653                    &key,
654                    meta.to_u8()
655                        .map_err(|_| TransientError::ParsingToByteError)?
656                )
657                .map_err(|e| {
658                    TransientError::SledError {
659                        error: e
660                    }
661                })?;
662
663            db.data_tree.insert(&key, val).map_err(|e| {
664                TransientError::SledError {
665                    error: e
666                }
667            })?;
668
669            if let Some(d) = meta.ttl {
670                db.ttl_tree
671                    .insert([&d.to_be_bytes()[..], &key].concat(), key)
672                    .map_err(|e| {
673                        TransientError::SledError {
674                            error: e
675                        }
676                    })?;
677            };
678        }
679
680        Ok(db)
681    }
682}
683
684impl Drop for DB {
685    /// Gracefully shuts down the TTL background thread when the `DB` instance
686    /// goes out of scope.
687    fn drop(&mut self) {
688        self.shutdown
689            .store(true, std::sync::atomic::Ordering::SeqCst);
690
691        let _ = self
692            .ttl_thread
693            .take()
694            .expect("Fail to take ownership of ttl_thread")
695            .join()
696            .expect("Joining failed");
697
698        let _ = self
699            .size_thread
700            .take()
701            .expect("Fail to take ownership of ttl_thread")
702            .join()
703            .expect("Joining failed");
704    }
705}