byodb_rust/api/
mod.rs

1//! Provides the public API for interacting with the database.
2//!
3//! This module defines the main `DB` struct for database instantiation and
4//! transaction management, as well as the `Txn` struct for performing
5//! read (and possibly write) operations within a transaction.
6
7pub mod error;
8
9use std::{marker::PhantomData, ops::RangeBounds, path::Path, sync::Arc};
10
11use error::TxnError;
12use seize::Collector;
13
14pub use crate::core::consts;
15use crate::core::{
16    mmap::{self, Guard, ImmutablePage, Mmap, Reader, ReaderPage, Store, Writer, WriterPage},
17    tree::Tree,
18};
19
20/// A specialized `Result` type for database operations within this module.
21pub type Result<T> = std::result::Result<T, TxnError>;
22/// A read-write transaction. It must live as long as the [`DB`] that created
23/// it (via [`DB::rw_txn`]).
24pub type RWTxn<'t, 'd> = Txn<'t, WriterPage<'t, 'd>, Writer<'d>>;
25/// A read-only transaction. It must live as long as the [`DB`] that created
26/// it (via [`DB::r_txn`]).
27pub type RTxn<'t, 'd> = Txn<'t, ReaderPage<'t>, Reader<'d>>;
28
29/// Represents the main database instance.
30///
31/// `DB` is the entry point for all database interactions. It handles the
32/// underlying storage and provides methods to create read-only or read-write
33/// transactions.
34#[derive(Clone)]
35pub struct DB {
36    store: Arc<Store>,
37}
38
39impl DB {
40    /// Opens an existing database file or creates a new one if it doesn't exist.
41    /// Uses default settings. If you need customization, use
42    /// [`DBBuilder::new`] instead.
43    ///
44    /// # Parameters
45    /// - `path`: A path to the database file.
46    ///
47    /// # Errors
48    /// Returns [`TxnError`] if the database file cannot be opened or created,
49    /// or if there's an issue initializing the memory map.
50    pub fn open_or_create<P: AsRef<Path>>(path: P) -> Result<Self> {
51        Ok(DB {
52            store: Arc::new(Store::new(
53                Mmap::open_or_create(path, mmap::DEFAULT_MIN_FILE_GROWTH_SIZE)?,
54                Collector::new(),
55            )),
56        })
57    }
58
59    /// Begins a new read-only transaction.
60    ///
61    /// Read-only transactions allow for concurrent reads without blocking
62    /// other readers or writers.
63    ///
64    /// # Returns
65    /// A [`RTxn`] instance for performing read operations.
66    pub fn r_txn(&self) -> RTxn<'_, '_> {
67        let reader = self.store.reader();
68        let root_page = reader.root_page();
69        Txn {
70            _phantom: PhantomData,
71            guard: reader,
72            root_page,
73        }
74    }
75
76    /// Begins a new read-write transaction.
77    ///
78    /// Read-write transactions provide exclusive write access. Changes made
79    /// within this transaction are isolated until `commit` is called.
80    ///
81    /// # Returns
82    /// A [`RWTxn`] instance for performing read and write operations.
83    pub fn rw_txn(&self) -> RWTxn<'_, '_> {
84        let writer = self.store.writer();
85        let root_page = writer.root_page();
86        Txn {
87            _phantom: PhantomData,
88            guard: writer,
89            root_page,
90        }
91    }
92}
93
94/// Builder of a [`DB`].
95pub struct DBBuilder<P: AsRef<Path>> {
96    db_path: P,
97    pub free_batch_size: Option<usize>,
98    min_file_growth_size: usize,
99}
100
101impl<P: AsRef<Path>> DBBuilder<P> {
102    /// Creates a new [`DB`] from `db_path`.
103    pub fn new(db_path: P) -> Self {
104        let free_batch_size = if cfg!(test) { Some(1) } else { None };
105        DBBuilder {
106            db_path,
107            free_batch_size,
108            min_file_growth_size: mmap::DEFAULT_MIN_FILE_GROWTH_SIZE,
109        }
110    }
111
112    /// Sets the number of free pages before reclamation can be attempted.
113    /// The default is 32, though this can change in the future.
114    pub fn free_batch_size(self, val: usize) -> Self {
115        DBBuilder {
116            db_path: self.db_path,
117            free_batch_size: Some(val),
118            min_file_growth_size: self.min_file_growth_size,
119        }
120    }
121
122    /// Sets the minimum file growth size.
123    /// The default is 64MB in release, though this can change in the future.
124    pub fn min_file_growth_size(self, val: usize) -> Self {
125        DBBuilder {
126            db_path: self.db_path,
127            free_batch_size: self.free_batch_size,
128            min_file_growth_size: val,
129        }
130    }
131
132    /// Builds a DB.
133    pub fn build(self) -> Result<DB> {
134        let collector = match &self.free_batch_size {
135            &Some(size) => Collector::new().batch_size(size),
136            None => Collector::new(),
137        };
138        Ok(DB {
139            store: Arc::new(Store::new(
140                Mmap::open_or_create(self.db_path, self.min_file_growth_size)?,
141                collector,
142            )),
143        })
144    }
145}
146
147/// Represents a database transaction.
148///
149/// A transaction provides a consistent view of the database. It can be
150/// either read-only ([`RTxn`]) or read-write ([`RWTxn`]).
151/// All operations on the database are performed within a transaction.
152pub struct Txn<'g, P: ImmutablePage<'g>, G: Guard<'g, P>> {
153    _phantom: PhantomData<&'g P>,
154    guard: G,
155    root_page: usize,
156}
157
158impl<'g, P: ImmutablePage<'g>, G: Guard<'g, P>> Txn<'g, P, G> {
159    /// Retrieves the value associated with the given key.
160    ///
161    /// # Parameters
162    /// - `key`: The key to search for.
163    ///
164    /// # Returns
165    /// - `Ok(Some(value))` if the key is found.
166    /// - `Ok(None)` if the key is not found.
167    /// - `Err(TxnError)` if an error occurs during the tree traversal.
168    ///
169    /// # Safety
170    /// The returned slice `&[u8]` is valid as long as the transaction [`Txn`]
171    /// is alive, as it points directly into the memory-mapped region.
172    pub fn get(&'g self, key: &[u8]) -> Result<Option<&'g [u8]>> {
173        let tree = Tree::new(&self.guard, self.root_page);
174        let val = tree.get(key)?.map(|val| {
175            // Safety: The underlying data is in the mmap, which
176            // self.guard has access to. So long as self.guard exists,
177            // so too does the mmap.
178            unsafe { &*std::ptr::slice_from_raw_parts(val.as_ptr(), val.len()) }
179        });
180        Ok(val)
181    }
182
183    /// Returns an iterator over all key-value pairs in the database, in key order.
184    ///
185    /// The iterator yields tuples of `(&[u8], &[u8])` representing key-value pairs.
186    pub fn in_order_iter(&'g self) -> impl Iterator<Item = (&'g [u8], &'g [u8])> {
187        Tree::new(&self.guard, self.root_page).in_order_iter()
188    }
189
190    /// Returns an iterator over key-value pairs within the specified range, in key order.
191    ///
192    /// # Parameters
193    /// - `range`: A range bound (e.g., `start_key..end_key`, `..end_key`, `start_key..`)
194    ///   that defines the keys to include in the iteration.
195    ///
196    /// The iterator yields tuples of `(&[u8], &[u8])` representing key-value pairs.
197    pub fn in_order_range_iter<R: RangeBounds<[u8]>>(
198        &'g self,
199        range: &R,
200    ) -> impl Iterator<Item = (&'g [u8], &'g [u8])> {
201        Tree::new(&self.guard, self.root_page).in_order_range_iter(range)
202    }
203}
204
205impl<'t, 'd> Txn<'t, WriterPage<'t, 'd>, Writer<'d>> {
206    /// Inserts a new key-value pair into the database.
207    ///
208    /// # Parameters
209    /// - `key`: The key to insert.
210    /// - `val`: The value to associate with the key.
211    ///
212    /// # Errors
213    /// Returns [`TxnError`] if the key already exists or if an error occurs
214    /// during the insertion process.
215    pub fn insert(&mut self, key: &[u8], val: &[u8]) -> Result<()> {
216        self.root_page = Tree::new(&self.guard, self.root_page)
217            .insert(key, val)?
218            .page_num();
219        Ok(())
220    }
221
222    /// Updates the value associated with an existing key.
223    ///
224    /// # Parameters
225    /// - `key`: The key whose value is to be updated.
226    /// - `val`: The new value to associate with the key.
227    ///
228    /// # Errors
229    /// Returns [`TxnError`] if the key does not exist or if an error occurs
230    /// during the update process.
231    pub fn update(&mut self, key: &[u8], val: &[u8]) -> Result<()> {
232        self.root_page = Tree::new(&self.guard, self.root_page)
233            .update(key, val)?
234            .page_num();
235        Ok(())
236    }
237
238    /// Deletes a key-value pair from the database.
239    ///
240    /// # Parameters
241    /// - `key`: The key to delete.
242    ///
243    /// # Errors
244    /// Returns [`TxnError`] if the key does not exist or if an error occurs
245    /// during the deletion process.
246    pub fn delete(&mut self, key: &[u8]) -> Result<()> {
247        self.root_page = Tree::new(&self.guard, self.root_page)
248            .delete(key)?
249            .page_num();
250        Ok(())
251    }
252
253    /// Commits the transaction, making all changes permanent and visible
254    /// to subsequent transactions.
255    ///
256    /// If `commit` is not called, the transaction will be automatically
257    /// aborted when it goes out of scope.
258    #[inline]
259    pub fn commit(self) {
260        self.guard.flush(self.root_page);
261    }
262
263    /// Aborts the transaction, discarding all changes made within it.
264    ///
265    /// This is automatically called if the transaction goes out of scope
266    /// without [`Txn::commit`] being called.
267    #[inline]
268    pub fn abort(self) {
269        self.guard.abort();
270    }
271}
272
#[cfg(test)]
mod tests {
    use core::str;
    use std::collections::{HashMap, HashSet};
    use std::ops::{Bound, Range};

    use anyhow::{Context, Result};
    use rand::distr::{Alphabetic, SampleString as _};
    use rand::prelude::*;
    use rand_chacha::ChaCha8Rng;
    use tempfile::NamedTempFile;

    use super::{error::TreeError, *};

    const DEFAULT_SEED: u64 = 1;
    const DEFAULT_NUM_SEEDED_KEY_VALS: usize = 1000;

    /// Creates a DB backed by a fresh temporary file. The file handle is
    /// returned so the caller keeps the file alive for the test's duration.
    fn new_test_db() -> (DB, NamedTempFile) {
        let temp_file = NamedTempFile::new().unwrap();
        let path = temp_file.path();
        let db = DBBuilder::new(path).free_batch_size(1).build().unwrap();
        (db, temp_file)
    }

    /// Deterministic generator of `n` random alphabetic key-value pairs.
    struct Seeder {
        n: usize,
        rng: ChaCha8Rng,
    }

    impl Seeder {
        fn new(n: usize, seed: u64) -> Self {
            Seeder {
                n,
                rng: ChaCha8Rng::seed_from_u64(seed),
            }
        }

        /// Inserts every generated pair in one committed transaction. Keys the
        /// generator repeats are skipped, so the first value for a key wins.
        fn seed_db(self, db: &DB) -> Result<()> {
            let mut t = db.rw_txn();
            for (i, (k, v)) in self.enumerate() {
                let result = t.insert(k.as_bytes(), v.as_bytes());
                if matches!(result, Err(TxnError::Tree(TreeError::AlreadyExists))) {
                    // Skip
                    continue;
                }
                result.with_context(|| format!("failed to insert {i}th ({k}, {v})"))?;
            }
            t.commit();
            Ok(())
        }
    }

    impl Iterator for Seeder {
        type Item = (String, String);
        fn next(&mut self) -> Option<Self::Item> {
            if self.n == 0 {
                return None;
            }
            self.n -= 1;
            let key_len = self.rng.random_range(1..=consts::MAX_KEY_SIZE);
            let val_len = self.rng.random_range(1..=consts::MAX_VALUE_SIZE);
            let key: String = Alphabetic.sample_string(&mut self.rng, key_len);
            let val: String = Alphabetic.sample_string(&mut self.rng, val_len);
            Some((key, val))
        }
    }

    /// Encodes `i` big-endian into the first 8 bytes of a zero-padded,
    /// max-size key.
    fn u64_to_key(i: u64) -> [u8; consts::MAX_KEY_SIZE] {
        let mut key = [0u8; consts::MAX_KEY_SIZE];
        key[0..8].copy_from_slice(&i.to_be_bytes());
        key
    }

    /// Inverse of [`u64_to_key`]: decodes the big-endian prefix of `key`.
    fn key_to_u64(key: &[u8]) -> u64 {
        let mut buf = [0u8; 8];
        buf.copy_from_slice(&key[..8]);
        u64::from_be_bytes(buf)
    }

    /// Maps a bound over integers to a bound over owned keys built by
    /// [`u64_to_key`].
    fn key_bound(b: Bound<u64>) -> Bound<[u8; consts::MAX_KEY_SIZE]> {
        match b {
            Bound::Included(i) => Bound::Included(u64_to_key(i)),
            Bound::Excluded(i) => Bound::Excluded(u64_to_key(i)),
            Bound::Unbounded => Bound::Unbounded,
        }
    }

    /// Borrows an owned key bound as a slice bound; a pair of these implements
    /// `RangeBounds<[u8]>`.
    fn slice_bound(b: &Bound<[u8; consts::MAX_KEY_SIZE]>) -> Bound<&[u8]> {
        match b {
            Bound::Included(k) => Bound::Included(k.as_slice()),
            Bound::Excluded(k) => Bound::Excluded(k.as_slice()),
            Bound::Unbounded => Bound::Unbounded,
        }
    }

    #[test]
    fn test_insert() {
        let (db, _temp_file) = new_test_db();
        Seeder::new(DEFAULT_NUM_SEEDED_KEY_VALS, DEFAULT_SEED)
            .seed_db(&db)
            .unwrap();
        // `seed_db` skips repeated keys, so only the first value generated for
        // a key is stored. Mirror that here: a set of (key, value) pairs would
        // also contain the later values that were never inserted.
        let mut kvs = HashMap::new();
        for (k, v) in Seeder::new(DEFAULT_NUM_SEEDED_KEY_VALS, DEFAULT_SEED) {
            kvs.entry(k).or_insert(v);
        }
        let t = db.r_txn();
        for (k, v) in kvs {
            match t.get(k.as_bytes()) {
                Err(err) => panic!("get({k}) unexpectedly got err {err}"),
                Ok(None) => panic!("get({k}) unexpectedly got None"),
                Ok(Some(got)) => {
                    let got = str::from_utf8(got).unwrap_or_else(|err| {
                        panic!("get({k}) should be an alphabetic string: {err}")
                    });
                    assert_eq!(got, v.as_str(), "get({k}) got = {got}, want = {v}");
                }
            }
        }
        // Verify equal height invariant.
        Tree::new(&t.guard, t.root_page).check_height().unwrap();
    }

    #[test]
    fn test_update() {
        let (db, _temp_file) = new_test_db();
        Seeder::new(DEFAULT_NUM_SEEDED_KEY_VALS, DEFAULT_SEED)
            .seed_db(&db)
            .unwrap();
        let ks = Seeder::new(DEFAULT_NUM_SEEDED_KEY_VALS, DEFAULT_SEED)
            .map(|(k, _)| k)
            .collect::<HashSet<_>>();
        let updated_val = [1u8; consts::MAX_VALUE_SIZE];
        {
            let mut t = db.rw_txn();
            for k in ks.iter() {
                t.update(k.as_bytes(), &updated_val)
                    .unwrap_or_else(|_| panic!("update({k}, &updated_val) should succeed"));
            }
            t.commit();
        }
        {
            let t = db.r_txn();
            for k in ks.iter() {
                match t.get(k.as_bytes()) {
                    Err(err) => panic!("get({k}) unexpectedly got err {err}"),
                    Ok(None) => panic!("get({k}) unexpectedly got None"),
                    Ok(Some(got)) => {
                        assert_eq!(
                            got, &updated_val,
                            "get({k}) got = {got:?}, want = {updated_val:?}"
                        );
                    }
                }
            }
            // Verify equal height invariant.
            Tree::new(&t.guard, t.root_page).check_height().unwrap();
        }
    }

    #[test]
    fn test_delete() {
        let (db, _temp_file) = new_test_db();
        Seeder::new(DEFAULT_NUM_SEEDED_KEY_VALS, DEFAULT_SEED)
            .seed_db(&db)
            .unwrap();
        let ks = Seeder::new(DEFAULT_NUM_SEEDED_KEY_VALS, DEFAULT_SEED)
            .map(|(k, _)| k)
            .collect::<HashSet<_>>();
        let mut t = db.rw_txn();
        for k in ks.iter() {
            if let Err(err) = t.delete(k.as_bytes()) {
                panic!("delete({k}) unexpectedly got err {err}");
            }
            match t.get(k.as_bytes()) {
                Err(err) => panic!("get({k}) after delete() unexpectedly got err {err}"),
                Ok(Some(v)) => {
                    panic!("get({k}) after delete() unexpectedly got = Some({v:?}), want = None")
                }
                Ok(None) => {}
            }
        }
        t.commit();
        let t = db.r_txn();
        for k in ks.iter() {
            match t.get(k.as_bytes()) {
                Err(err) => panic!("get({k}) after delete() unexpectedly got err {err}"),
                Ok(Some(v)) => {
                    panic!("get({k}) after delete() unexpectedly got = Some({v:?}), want = None")
                }
                Ok(None) => {}
            }
        }
        // Verify equal height invariant.
        Tree::new(&t.guard, t.root_page).check_height().unwrap();
    }

    #[test]
    fn test_in_order_range_iter() {
        let (db, _temp_file) = new_test_db();
        // Setup: insert keys 1..=100 in a seeded (hence reproducible) random order.
        {
            let mut t = db.rw_txn();
            let mut inds = (1..=100).collect::<Vec<_>>();
            inds.shuffle(&mut ChaCha8Rng::seed_from_u64(DEFAULT_SEED));
            for i in inds {
                let x = u64_to_key(i);
                t.insert(&x, &x).unwrap();
            }
            t.commit();
        }

        let t = db.r_txn();

        // Golang style table-driven tests. Bounds are stored as integers and
        // turned into owned keys per case, so nothing has to be leaked or
        // manually freed.
        struct TestCase {
            name: &'static str,
            range: (Bound<u64>, Bound<u64>),
            want: Range<u64>,
        }
        let tests = [
            TestCase {
                name: "unbounded unbounded",
                range: (Bound::Unbounded, Bound::Unbounded),
                want: 1..101,
            },
            TestCase {
                name: "included included",
                range: (Bound::Included(5), Bound::Included(98)),
                want: 5..99,
            },
            TestCase {
                name: "excluded included",
                range: (Bound::Excluded(5), Bound::Included(98)),
                want: 6..99,
            },
            TestCase {
                name: "excluded excluded",
                range: (Bound::Excluded(5), Bound::Excluded(98)),
                want: 6..98,
            },
            TestCase {
                name: "unbounded included",
                range: (Bound::Unbounded, Bound::Included(98)),
                want: 1..99,
            },
            TestCase {
                name: "unbounded excluded",
                range: (Bound::Unbounded, Bound::Excluded(98)),
                want: 1..98,
            },
            TestCase {
                name: "included unbounded",
                range: (Bound::Included(5), Bound::Unbounded),
                want: 5..101,
            },
            TestCase {
                name: "excluded unbounded",
                range: (Bound::Excluded(5), Bound::Unbounded),
                want: 6..101,
            },
            TestCase {
                name: "no overlap",
                range: (Bound::Excluded(200), Bound::Unbounded),
                want: 0..0,
            },
        ];
        for test in tests {
            let start = key_bound(test.range.0);
            let end = key_bound(test.range.1);
            let range = (slice_bound(&start), slice_bound(&end));
            let got = t
                .in_order_range_iter(&range)
                .map(|(k, _)| key_to_u64(k))
                .collect::<Vec<_>>();
            let want = test.want.collect::<Vec<_>>();
            assert_eq!(got, want, "Test case \"{}\" failed", test.name);
        }
    }
}