byodb_rust/api/
mod.rs

//! Provides the public API for interacting with the database.
//!
//! This module defines the main `DB` struct for database instantiation and
//! transaction management, as well as the `Txn` struct for performing
//! read (and possibly write) operations within a transaction.
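//!
//! # Example
//!
//! A minimal sketch of typical usage (the file path and keys are illustrative,
//! and `use` statements plus error handling are elided):
//!
//! ```ignore
//! let db = DB::open_or_create("example.db")?;
//!
//! // Changes made in a read-write transaction become visible only on commit.
//! let mut txn = db.rw_txn();
//! txn.insert(b"key", b"value")?;
//! txn.commit();
//!
//! // Read the value back in a read-only transaction.
//! let txn = db.r_txn();
//! assert_eq!(txn.get(b"key")?, Some(&b"value"[..]));
//! ```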

pub mod error;

use std::{marker::PhantomData, ops::RangeBounds, path::Path, sync::Arc};

use error::TxnError;
use seize::Collector;

pub use crate::core::consts;
use crate::core::{
    mmap::{self, Guard, ImmutablePage, Mmap, Reader, ReaderPage, Store, Writer, WriterPage},
    tree::Tree,
};

/// A specialized `Result` type for database operations within this module.
pub type Result<T> = std::result::Result<T, TxnError>;
/// A read-write transaction. It cannot outlive the [`DB`] that created it
/// (via [`DB::rw_txn`]).
pub type RWTxn<'t, 'd> = Txn<'t, WriterPage<'t, 'd>, Writer<'d>>;
/// A read-only transaction. It cannot outlive the [`DB`] that created it
/// (via [`DB::r_txn`]).
pub type RTxn<'t, 'd> = Txn<'t, ReaderPage<'t>, Reader<'d>>;

/// Represents the main database instance.
///
/// `DB` is the entry point for all database interactions. It handles the
/// underlying storage and provides methods to create read-only or read-write
/// transactions.
#[derive(Clone)]
pub struct DB {
    store: Arc<Store>,
}

impl DB {
    /// Opens an existing database file or creates a new one if it doesn't exist.
    /// Uses default settings. If you need customization, use
    /// [`DBBuilder::new`] instead.
    ///
    /// # Parameters
    /// - `path`: A path to the database file.
    ///
    /// # Errors
    /// Returns [`TxnError`] if the database file cannot be opened or created,
    /// or if there's an issue initializing the memory map.
    pub fn open_or_create<P: AsRef<Path>>(path: P) -> Result<Self> {
        Ok(DB {
            store: Arc::new(Store::new(
                Mmap::open_or_create(path, mmap::DEFAULT_MIN_FILE_GROWTH_SIZE)?,
                Collector::new(),
            )),
        })
    }

    /// Begins a new read-only transaction.
    ///
    /// Read-only transactions allow for concurrent reads without blocking
    /// other readers or writers.
    ///
    /// # Returns
    /// A [`RTxn`] instance for performing read operations.
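    ///
    /// # Example
    ///
    /// A minimal sketch (assumes an open [`DB`] named `db`; error handling
    /// elided). Several read transactions may be open at once:
    ///
    /// ```ignore
    /// let r1 = db.r_txn();
    /// let r2 = db.r_txn(); // concurrent readers do not block each other
    /// assert_eq!(r1.get(b"key")?, r2.get(b"key")?);
    /// ```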
    pub fn r_txn(&self) -> RTxn<'_, '_> {
        let reader = self.store.reader();
        let root_page = reader.root_page();
        Txn {
            _phantom: PhantomData,
            guard: reader,
            root_page,
        }
    }

    /// Begins a new read-write transaction.
    ///
    /// Read-write transactions provide exclusive write access. Changes made
    /// within this transaction are isolated until `commit` is called.
    ///
    /// # Returns
    /// A [`RWTxn`] instance for performing read and write operations.
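    ///
    /// # Example
    ///
    /// A minimal sketch (assumes an open [`DB`] named `db`; error handling
    /// elided):
    ///
    /// ```ignore
    /// let mut txn = db.rw_txn();
    /// txn.insert(b"key", b"value")?;
    /// // The insert is not visible to other transactions yet.
    /// txn.commit();
    /// // After commit, new transactions observe the change.
    /// assert_eq!(db.r_txn().get(b"key")?, Some(&b"value"[..]));
    /// ```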
    pub fn rw_txn(&self) -> RWTxn<'_, '_> {
        let writer = self.store.writer();
        let root_page = writer.root_page();
        Txn {
            _phantom: PhantomData,
            guard: writer,
            root_page,
        }
    }
}

/// Builder of a [`DB`].
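///
/// # Example
///
/// A minimal sketch; the path and the values passed to the setters are
/// illustrative, not recommended defaults:
///
/// ```ignore
/// let db = DBBuilder::new("example.db")
///     .free_batch_size(64)
///     .min_file_growth_size(16 * 1024 * 1024)
///     .build()?;
/// ```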
pub struct DBBuilder<P: AsRef<Path>> {
    db_path: P,
    pub free_batch_size: Option<usize>,
    min_file_growth_size: usize,
}

impl<P: AsRef<Path>> DBBuilder<P> {
    /// Creates a new [`DBBuilder`] for the database at `db_path`.
    pub fn new(db_path: P) -> Self {
        let free_batch_size = if cfg!(test) { Some(1) } else { None };
        DBBuilder {
            db_path,
            free_batch_size,
            min_file_growth_size: mmap::DEFAULT_MIN_FILE_GROWTH_SIZE,
        }
    }

    /// Sets the number of free pages before reclamation can be attempted.
    /// The default is 32, though this can change in the future.
    pub fn free_batch_size(self, val: usize) -> Self {
        DBBuilder {
            db_path: self.db_path,
            free_batch_size: Some(val),
            min_file_growth_size: self.min_file_growth_size,
        }
    }

    /// Sets the minimum file growth size.
    /// The default is 64MB in release, though this can change in the future.
    pub fn min_file_growth_size(self, val: usize) -> Self {
        DBBuilder {
            db_path: self.db_path,
            free_batch_size: self.free_batch_size,
            min_file_growth_size: val,
        }
    }

    /// Builds a DB.
    pub fn build(self) -> Result<DB> {
        let collector = match &self.free_batch_size {
            &Some(size) => Collector::new().batch_size(size),
            None => Collector::new(),
        };
        Ok(DB {
            store: Arc::new(Store::new(
                Mmap::open_or_create(self.db_path, self.min_file_growth_size)?,
                collector,
            )),
        })
    }
}

/// Represents a database transaction.
///
/// A transaction provides a consistent view of the database. It can be
/// either read-only ([`RTxn`]) or read-write ([`RWTxn`]).
/// All operations on the database are performed within a transaction.
pub struct Txn<'g, P: ImmutablePage<'g>, G: Guard<'g, P>> {
    _phantom: PhantomData<&'g P>,
    guard: G,
    root_page: usize,
}

impl<'g, P: ImmutablePage<'g>, G: Guard<'g, P>> Txn<'g, P, G> {
    /// Retrieves the value associated with the given key.
    ///
    /// # Parameters
    /// - `key`: The key to search for.
    ///
    /// # Returns
    /// - `Ok(Some(value))` if the key is found.
    /// - `Ok(None)` if the key is not found.
    /// - `Err(TxnError)` if an error occurs during the tree traversal.
    ///
    /// # Safety
    /// The returned slice `&[u8]` is valid as long as the transaction [`Txn`]
    /// is alive, as it points directly into the memory-mapped region.
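    ///
    /// # Example
    ///
    /// A minimal sketch (assumes a transaction `txn` obtained from
    /// [`DB::r_txn`] or [`DB::rw_txn`]; error handling elided):
    ///
    /// ```ignore
    /// match txn.get(b"some-key")? {
    ///     Some(val) => println!("found {} bytes", val.len()),
    ///     None => println!("key not present"),
    /// }
    /// ```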
    pub fn get(&'g self, key: &[u8]) -> Result<Option<&'g [u8]>> {
        let tree = Tree::new(&self.guard, self.root_page);
        let val = tree.get(key)?.map(|val| {
            // Safety: The underlying data is in the mmap, which
            // self.guard has access to. So long as self.guard exists,
            // so too does the mmap.
            unsafe { &*std::ptr::slice_from_raw_parts(val.as_ptr(), val.len()) }
        });
        Ok(val)
    }

    /// Returns an iterator over all key-value pairs in the database, in key order.
    ///
    /// The iterator yields tuples of `(&[u8], &[u8])` representing key-value pairs.
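    ///
    /// # Example
    ///
    /// A minimal sketch (assumes a transaction `txn`):
    ///
    /// ```ignore
    /// for (key, val) in txn.in_order_iter() {
    ///     // Pairs are yielded in ascending key order.
    ///     println!("{key:?} => {val:?}");
    /// }
    /// ```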
    pub fn in_order_iter(&'g self) -> impl Iterator<Item = (&'g [u8], &'g [u8])> {
        Tree::new(&self.guard, self.root_page).in_order_iter()
    }

    /// Returns an iterator over key-value pairs within the specified range, in key order.
    ///
    /// # Parameters
    /// - `range`: A range bound (e.g., `start_key..end_key`, `..end_key`, `start_key..`)
    ///   that defines the keys to include in the iteration.
    ///
    /// The iterator yields tuples of `(&[u8], &[u8])` representing key-value pairs.
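    ///
    /// # Example
    ///
    /// A minimal sketch (assumes a transaction `txn`); the `(Bound, Bound)`
    /// form of `RangeBounds` mirrors the one used in this module's tests:
    ///
    /// ```ignore
    /// use std::ops::Bound;
    ///
    /// let range = (Bound::Included(&b"a"[..]), Bound::Excluded(&b"m"[..]));
    /// for (key, val) in txn.in_order_range_iter(&range) {
    ///     // Only keys in ["a", "m") are yielded, in ascending order.
    ///     println!("{key:?} => {val:?}");
    /// }
    /// ```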
    pub fn in_order_range_iter<R: RangeBounds<[u8]>>(
        &'g self,
        range: &R,
    ) -> impl Iterator<Item = (&'g [u8], &'g [u8])> {
        Tree::new(&self.guard, self.root_page).in_order_range_iter(range)
    }
}

impl<'t, 'd> Txn<'t, WriterPage<'t, 'd>, Writer<'d>> {
    /// Inserts a new key-value pair into the database.
    ///
    /// # Parameters
    /// - `key`: The key to insert.
    /// - `val`: The value to associate with the key.
    ///
    /// # Errors
    /// Returns [`TxnError`] if the key already exists or if an error occurs
    /// during the insertion process.
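    ///
    /// # Example
    ///
    /// A minimal sketch (assumes a read-write transaction `txn` from
    /// [`DB::rw_txn`]; error handling elided):
    ///
    /// ```ignore
    /// txn.insert(b"key", b"value")?;
    /// // Inserting the same key again reports an error.
    /// assert!(txn.insert(b"key", b"other").is_err());
    /// ```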
    pub fn insert(&mut self, key: &[u8], val: &[u8]) -> Result<()> {
        self.root_page = Tree::new(&self.guard, self.root_page)
            .insert(key, val)?
            .page_num();
        Ok(())
    }

    /// Updates the value associated with an existing key.
    ///
    /// # Parameters
    /// - `key`: The key whose value is to be updated.
    /// - `val`: The new value to associate with the key.
    ///
    /// # Errors
    /// Returns [`TxnError`] if the key does not exist or if an error occurs
    /// during the update process.
    pub fn update(&mut self, key: &[u8], val: &[u8]) -> Result<()> {
        self.root_page = Tree::new(&self.guard, self.root_page)
            .update(key, val)?
            .page_num();
        Ok(())
    }

    /// Deletes a key-value pair from the database.
    ///
    /// # Parameters
    /// - `key`: The key to delete.
    ///
    /// # Errors
    /// Returns [`TxnError`] if the key does not exist or if an error occurs
    /// during the deletion process.
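    ///
    /// # Example
    ///
    /// A minimal sketch (assumes a read-write transaction `txn` that already
    /// holds the key `b"key"`; error handling elided):
    ///
    /// ```ignore
    /// txn.delete(b"key")?;
    /// assert_eq!(txn.get(b"key")?, None);
    /// txn.commit();
    /// ```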
    pub fn delete(&mut self, key: &[u8]) -> Result<()> {
        self.root_page = Tree::new(&self.guard, self.root_page)
            .delete(key)?
            .page_num();
        Ok(())
    }

    /// Commits the transaction, making all changes permanent and visible
    /// to subsequent transactions.
    ///
    /// If `commit` is not called, the transaction will be automatically
    /// aborted when it goes out of scope.
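    ///
    /// # Example
    ///
    /// A minimal sketch (assumes an open [`DB`] named `db`; error handling
    /// elided). A transaction dropped without `commit` discards its changes:
    ///
    /// ```ignore
    /// {
    ///     let mut txn = db.rw_txn();
    ///     txn.insert(b"temp", b"value")?;
    ///     // Dropped here without commit: the insert is discarded.
    /// }
    /// assert_eq!(db.r_txn().get(b"temp")?, None);
    /// ```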
    #[inline]
    pub fn commit(self) {
        self.guard.flush(self.root_page);
    }

    /// Aborts the transaction, discarding all changes made within it.
    ///
    /// This is automatically called if the transaction goes out of scope
    /// without [`Txn::commit`] being called.
    #[inline]
    pub fn abort(self) {
        self.guard.abort();
    }
}

#[cfg(test)]
mod tests {
    use core::str;
    use std::collections::HashSet;
    use std::ops::{Bound, Range};

    use anyhow::{Context, Result};
    use rand::distr::{Alphabetic, SampleString as _};
    use rand::prelude::*;
    use rand_chacha::ChaCha8Rng;
    use tempfile::NamedTempFile;

    use super::{
        error::{NodeError, TreeError},
        *,
    };

    const DEFAULT_SEED: u64 = 1;
    const DEFAULT_NUM_SEEDED_KEY_VALS: usize = 1000;

    fn new_test_db() -> (DB, NamedTempFile) {
        let temp_file = NamedTempFile::new().unwrap();
        let path = temp_file.path();
        let db = DBBuilder::new(path).free_batch_size(1).build().unwrap();
        (db, temp_file)
    }

    struct Seeder {
        n: usize,
        rng: ChaCha8Rng,
    }

    impl Seeder {
        fn new(n: usize, seed: u64) -> Self {
            Seeder {
                n,
                rng: ChaCha8Rng::seed_from_u64(seed),
            }
        }

        fn seed_db(self, db: &DB) -> Result<()> {
            let mut t = db.rw_txn();
            for (i, (k, v)) in self.enumerate() {
                let result = t.insert(k.as_bytes(), v.as_bytes());
                if matches!(
                    result,
                    Err(TxnError::Tree(TreeError::AlreadyExists))
                ) {
                    // Skip
                    continue;
                }
                result.with_context(|| format!("failed to insert {i}th ({k}, {v})"))?;
            }
            t.commit();
            Ok(())
        }
    }

    impl Iterator for Seeder {
        type Item = (String, String);
        fn next(&mut self) -> Option<Self::Item> {
            if self.n == 0 {
                return None;
            }
            self.n -= 1;
            let key_len = self.rng.random_range(1..=consts::MAX_KEY_SIZE);
            let val_len = self.rng.random_range(1..=consts::MAX_VALUE_SIZE);
            let key: String = Alphabetic.sample_string(&mut self.rng, key_len);
            let val: String = Alphabetic.sample_string(&mut self.rng, val_len);
            Some((key, val))
        }
    }

    fn u64_to_key(i: u64) -> [u8; consts::MAX_KEY_SIZE] {
        let mut key = [0u8; consts::MAX_KEY_SIZE];
        key[0..8].copy_from_slice(&i.to_be_bytes());
        key
    }

    #[test]
    fn test_insert() {
        let (db, _temp_file) = new_test_db();
        Seeder::new(DEFAULT_NUM_SEEDED_KEY_VALS, DEFAULT_SEED)
            .seed_db(&db)
            .unwrap();
        let kvs = Seeder::new(DEFAULT_NUM_SEEDED_KEY_VALS, DEFAULT_SEED)
            .collect::<HashSet<(String, String)>>();
        let t = db.r_txn();
        for (k, v) in kvs {
            match t.get(k.as_bytes()) {
                Err(err) => panic!("get({k}) unexpectedly got err {err}"),
                Ok(None) => panic!("get({k}) unexpectedly got None"),
                Ok(Some(got)) => {
                    let got = str::from_utf8(got).expect("get({k}) is an alphabetic string");
                    assert_eq!(got, v.as_str(), "get({k}) got = {got}, want = {v}");
                }
            }
        }
        // Verify equal height invariant.
        Tree::new(&t.guard, t.root_page).check_height().unwrap();
    }

    #[test]
    fn test_update() {
        let (db, _temp_file) = new_test_db();
        Seeder::new(DEFAULT_NUM_SEEDED_KEY_VALS, DEFAULT_SEED)
            .seed_db(&db)
            .unwrap();
        let ks = Seeder::new(DEFAULT_NUM_SEEDED_KEY_VALS, DEFAULT_SEED)
            .map(|(k, _)| k)
            .collect::<HashSet<_>>();
        let updated_val = [1u8; consts::MAX_VALUE_SIZE];
        {
            let mut t = db.rw_txn();
            for k in ks.iter() {
                t.update(k.as_bytes(), &updated_val)
                    .unwrap_or_else(|_| panic!("update({k}, &updated_val) should succeed"));
            }
            t.commit();
        }
        {
            let t = db.r_txn();
            for k in ks.iter() {
                match t.get(k.as_bytes()) {
                    Err(err) => panic!("get({k}) unexpectedly got err {err}"),
                    Ok(None) => panic!("get({k}) unexpectedly got None"),
                    Ok(Some(got)) => {
                        assert_eq!(
                            got, &updated_val,
                            "get({k}) got = {got:?}, want = {updated_val:?}"
                        );
                    }
                }
            }
            // Verify equal height invariant.
            Tree::new(&t.guard, t.root_page).check_height().unwrap();
        }
    }

    #[test]
    fn test_delete() {
        let (db, _temp_file) = new_test_db();
        Seeder::new(DEFAULT_NUM_SEEDED_KEY_VALS, DEFAULT_SEED)
            .seed_db(&db)
            .unwrap();
        let ks = Seeder::new(DEFAULT_NUM_SEEDED_KEY_VALS, DEFAULT_SEED)
            .map(|(k, _)| k)
            .collect::<HashSet<_>>();
        let mut t = db.rw_txn();
        for k in ks.iter() {
            if let Err(err) = t.delete(k.as_bytes()) {
                panic!("delete({k}) unexpectedly got err {err}");
            }
            match t.get(k.as_bytes()) {
                Err(err) => panic!("get({k}) after delete() unexpectedly got err {err}"),
                Ok(Some(v)) => {
                    panic!("get({k}) after delete() unexpectedly got = Some({v:?}), want = None")
                }
                _ => {}
            };
        }
        t.commit();
        let t = db.r_txn();
        for k in ks.iter() {
            match t.get(k.as_bytes()) {
                Err(err) => panic!("get({k}) after delete() unexpectedly got err {err}"),
                Ok(Some(v)) => {
                    panic!("get({k}) after delete() unexpectedly got = Some({v:?}), want = None")
                }
                _ => {}
            };
        }
        // Verify equal height invariant.
        Tree::new(&t.guard, t.root_page).check_height().unwrap();
    }

    #[test]
    fn test_in_order_range_iter() {
        let (db, _temp_file) = new_test_db();
        // Setup.
        {
            let mut t = db.rw_txn();
            let mut inds = (1..=100).collect::<Vec<_>>();
            inds.shuffle(&mut rand::rng());
            for i in inds {
                let x = u64_to_key(i);
                t.insert(&x, &x).unwrap();
            }
            t.commit();
        }

        let t = db.r_txn();

        // Golang style table-driven tests.
        struct TestCase {
            name: &'static str,
            range: (Bound<&'static [u8]>, Bound<&'static [u8]>),
            want: Range<u64>,
        }
        impl Drop for TestCase {
            fn drop(&mut self) {
                for b in [self.range.0, self.range.1] {
                    match b {
                        Bound::Included(b) | Bound::Excluded(b) => {
                            // Reclaim the key arrays leaked via `Box::leak` below,
                            // reconstructing the `Box` with its original type so the
                            // allocation is freed with the layout it was created with.
                            drop(unsafe {
                                Box::from_raw(b.as_ptr() as *mut [u8; consts::MAX_KEY_SIZE])
                            });
                        }
                        _ => {}
                    }
                }
            }
        }
        let tests = [
            TestCase {
                name: "unbounded unbounded",
                range: (Bound::Unbounded, Bound::Unbounded),
                want: 1..101,
            },
            TestCase {
                name: "included included",
                range: (
                    Bound::Included(Box::leak(Box::new(u64_to_key(5)))),
                    Bound::Included(Box::leak(Box::new(u64_to_key(98)))),
                ),
                want: 5..99,
            },
            TestCase {
                name: "excluded included",
                range: (
                    Bound::Excluded(Box::leak(Box::new(u64_to_key(5)))),
                    Bound::Included(Box::leak(Box::new(u64_to_key(98)))),
                ),
                want: 6..99,
            },
            TestCase {
                name: "excluded excluded",
                range: (
                    Bound::Excluded(Box::leak(Box::new(u64_to_key(5)))),
                    Bound::Excluded(Box::leak(Box::new(u64_to_key(98)))),
                ),
                want: 6..98,
            },
            TestCase {
                name: "unbounded included",
                range: (
                    Bound::Unbounded,
                    Bound::Included(Box::leak(Box::new(u64_to_key(98)))),
                ),
                want: 1..99,
            },
            TestCase {
                name: "unbounded excluded",
                range: (
                    Bound::Unbounded,
                    Bound::Excluded(Box::leak(Box::new(u64_to_key(98)))),
                ),
                want: 1..98,
            },
            TestCase {
                name: "included unbounded",
                range: (
                    Bound::Included(Box::leak(Box::new(u64_to_key(5)))),
                    Bound::Unbounded,
                ),
                want: 5..101,
            },
            TestCase {
                name: "excluded unbounded",
                range: (
                    Bound::Excluded(Box::leak(Box::new(u64_to_key(5)))),
                    Bound::Unbounded,
                ),
                want: 6..101,
            },
            TestCase {
                name: "no overlap",
                range: (
                    Bound::Excluded(Box::leak(Box::new(u64_to_key(200)))),
                    Bound::Unbounded,
                ),
                want: 0..0,
            },
        ];
        for test in tests {
            let got = t
                .in_order_range_iter(&test.range)
                .map(|(k, _)| u64::from_be_bytes([k[0], k[1], k[2], k[3], k[4], k[5], k[6], k[7]]))
                .collect::<Vec<_>>();
            let want = test.want.clone().collect::<Vec<_>>();
            assert_eq!(got, want, "Test case \"{}\" failed", test.name);
        }
    }
}