1pub mod error;
8
9use std::{marker::PhantomData, ops::RangeBounds, path::Path, sync::Arc};
10
11use error::TxnError;
12use seize::Collector;
13
14pub use crate::core::consts;
15use crate::core::{
16 mmap::{self, Guard, ImmutablePage, Mmap, Reader, ReaderPage, Store, Writer, WriterPage},
17 tree::Tree,
18};
19
/// Convenience alias for results whose error type is [`TxnError`].
pub type Result<T> = std::result::Result<T, TxnError>;
/// A read-write transaction: a [`Txn`] holding the store's writer guard.
pub type RWTxn<'t, 'd> = Txn<'t, WriterPage<'t, 'd>, Writer<'d>>;
/// A read-only transaction: a [`Txn`] holding a reader guard.
pub type RTxn<'t, 'd> = Txn<'t, ReaderPage<'t>, Reader<'d>>;
28
/// Handle to an on-disk key-value store backed by a memory-mapped file.
///
/// Cloning is cheap: every clone shares the same underlying [`Store`]
/// through the `Arc`.
#[derive(Clone)]
pub struct DB {
    // Shared page store; owns the mmap and the seize collector used for
    // deferred reclamation of freed pages.
    store: Arc<Store>,
}
38
39impl DB {
40 pub fn open_or_create<P: AsRef<Path>>(path: P) -> Result<Self> {
51 Ok(DB {
52 store: Arc::new(Store::new(
53 Mmap::open_or_create(path, mmap::DEFAULT_MIN_FILE_GROWTH_SIZE)?,
54 Collector::new(),
55 )),
56 })
57 }
58
59 pub fn r_txn(&self) -> RTxn<'_, '_> {
67 let reader = self.store.reader();
68 let root_page = reader.root_page();
69 Txn {
70 _phantom: PhantomData,
71 guard: reader,
72 root_page,
73 }
74 }
75
76 pub fn rw_txn(&self) -> RWTxn<'_, '_> {
84 let writer = self.store.writer();
85 let root_page = writer.root_page();
86 Txn {
87 _phantom: PhantomData,
88 guard: writer,
89 root_page,
90 }
91 }
92}
93
/// Builder for configuring and opening a [`DB`].
pub struct DBBuilder<P: AsRef<Path>> {
    // Path of the database file to open or create.
    db_path: P,
    /// Batch size for the seize collector's deferred reclamation; `None`
    /// uses the collector's default. Defaults to `Some(1)` in test builds
    /// so reclamation happens eagerly.
    pub free_batch_size: Option<usize>,
    // Minimum growth increment passed to `Mmap::open_or_create` when the
    // backing file needs to expand (units defined by the `mmap` module).
    min_file_growth_size: usize,
}
100
101impl<P: AsRef<Path>> DBBuilder<P> {
102 pub fn new(db_path: P) -> Self {
104 let free_batch_size = if cfg!(test) { Some(1) } else { None };
105 DBBuilder {
106 db_path,
107 free_batch_size,
108 min_file_growth_size: mmap::DEFAULT_MIN_FILE_GROWTH_SIZE,
109 }
110 }
111
112 pub fn free_batch_size(self, val: usize) -> Self {
115 DBBuilder {
116 db_path: self.db_path,
117 free_batch_size: Some(val),
118 min_file_growth_size: self.min_file_growth_size,
119 }
120 }
121
122 pub fn min_file_growth_size(self, val: usize) -> Self {
125 DBBuilder {
126 db_path: self.db_path,
127 free_batch_size: self.free_batch_size,
128 min_file_growth_size: val,
129 }
130 }
131
132 pub fn build(self) -> Result<DB> {
134 let collector = match &self.free_batch_size {
135 &Some(size) => Collector::new().batch_size(size),
136 None => Collector::new(),
137 };
138 Ok(DB {
139 store: Arc::new(Store::new(
140 Mmap::open_or_create(self.db_path, self.min_file_growth_size)?,
141 collector,
142 )),
143 })
144 }
145}
146
/// A transaction over the tree rooted at `root_page`.
///
/// `G` is the guard type (reader or writer) that keeps the underlying store
/// pinned, and `P` is the page type that guard hands out.
pub struct Txn<'g, P: ImmutablePage<'g>, G: Guard<'g, P>> {
    // Ties lifetime `'g` and page type `P` to the struct without actually
    // storing a page.
    _phantom: PhantomData<&'g P>,
    // Guard that keeps the store's pages alive for the transaction's duration.
    guard: G,
    // Page number of the tree root this transaction reads from / writes to.
    root_page: usize,
}
157
impl<'g, P: ImmutablePage<'g>, G: Guard<'g, P>> Txn<'g, P, G> {
    /// Looks up `key` in the tree, returning its value if present.
    ///
    /// The returned slice borrows from pages pinned by this transaction's
    /// guard, so it remains valid for the transaction lifetime `'g`.
    pub fn get(&'g self, key: &[u8]) -> Result<Option<&'g [u8]>> {
        let tree = Tree::new(&self.guard, self.root_page);
        let val = tree.get(key)?.map(|val| {
            // SAFETY: `val` points into a page kept alive by `self.guard`
            // for `'g`. Rebuilding the slice from its raw parts only detaches
            // the borrow from the temporary `tree`; it never outlives the
            // guard that pins the page.
            unsafe { &*std::ptr::slice_from_raw_parts(val.as_ptr(), val.len()) }
        });
        Ok(val)
    }

    /// Returns an iterator over all `(key, value)` pairs in key order.
    pub fn in_order_iter(&'g self) -> impl Iterator<Item = (&'g [u8], &'g [u8])> {
        Tree::new(&self.guard, self.root_page).in_order_iter()
    }

    /// Returns an iterator over the `(key, value)` pairs whose keys fall
    /// within `range`, in key order.
    pub fn in_order_range_iter<R: RangeBounds<[u8]>>(
        &'g self,
        range: &R,
    ) -> impl Iterator<Item = (&'g [u8], &'g [u8])> {
        Tree::new(&self.guard, self.root_page).in_order_range_iter(range)
    }
}
204
205impl<'t, 'd> Txn<'t, WriterPage<'t, 'd>, Writer<'d>> {
206 pub fn insert(&mut self, key: &[u8], val: &[u8]) -> Result<()> {
216 self.root_page = Tree::new(&self.guard, self.root_page)
217 .insert(key, val)?
218 .page_num();
219 Ok(())
220 }
221
222 pub fn update(&mut self, key: &[u8], val: &[u8]) -> Result<()> {
232 self.root_page = Tree::new(&self.guard, self.root_page)
233 .update(key, val)?
234 .page_num();
235 Ok(())
236 }
237
238 pub fn delete(&mut self, key: &[u8]) -> Result<()> {
247 self.root_page = Tree::new(&self.guard, self.root_page)
248 .delete(key)?
249 .page_num();
250 Ok(())
251 }
252
253 #[inline]
259 pub fn commit(self) {
260 self.guard.flush(self.root_page);
261 }
262
263 #[inline]
268 pub fn abort(self) {
269 self.guard.abort();
270 }
271}
272
#[cfg(test)]
mod tests {
    use core::str;
    use std::collections::HashSet;
    use std::ops::{Bound, Range};

    use anyhow::{Context, Result};
    use rand::distr::{Alphabetic, SampleString as _};
    use rand::prelude::*;
    use rand_chacha::ChaCha8Rng;
    use tempfile::NamedTempFile;

    use super::{
        error::{NodeError, TreeError},
        *,
    };

    const DEFAULT_SEED: u64 = 1;
    const DEFAULT_NUM_SEEDED_KEY_VALS: usize = 1000;

    /// Opens a fresh DB backed by a temp file. The caller must keep the
    /// `NamedTempFile` alive for as long as the DB is in use.
    fn new_test_db() -> (DB, NamedTempFile) {
        let temp_file = NamedTempFile::new().unwrap();
        let path = temp_file.path();
        let db = DBBuilder::new(path).free_batch_size(1).build().unwrap();
        (db, temp_file)
    }

    /// Deterministic generator of `n` random alphabetic key/value pairs.
    struct Seeder {
        n: usize,
        rng: ChaCha8Rng,
    }

    impl Seeder {
        fn new(n: usize, seed: u64) -> Self {
            Seeder {
                n,
                rng: ChaCha8Rng::seed_from_u64(seed),
            }
        }

        /// Inserts every generated pair in one committed transaction.
        /// Randomly generated duplicate keys are skipped.
        fn seed_db(self, db: &DB) -> Result<()> {
            let mut t = db.rw_txn();
            for (i, (k, v)) in self.enumerate() {
                let result = t.insert(k.as_bytes(), v.as_bytes());
                if matches!(
                    result,
                    Err(TxnError::Tree(TreeError::Node(NodeError::AlreadyExists)))
                ) {
                    continue;
                }
                result.with_context(|| format!("failed to insert {i}th ({k}, {v})"))?;
            }
            t.commit();
            Ok(())
        }
    }

    impl Iterator for Seeder {
        type Item = (String, String);
        fn next(&mut self) -> Option<Self::Item> {
            if self.n == 0 {
                return None;
            }
            self.n -= 1;
            let key_len = self.rng.random_range(1..=consts::MAX_KEY_SIZE);
            let val_len = self.rng.random_range(1..=consts::MAX_VALUE_SIZE);
            let key: String = Alphabetic.sample_string(&mut self.rng, key_len);
            let val: String = Alphabetic.sample_string(&mut self.rng, val_len);
            Some((key, val))
        }
    }

    /// Encodes `i` big-endian into the first 8 bytes of a max-size key so
    /// numeric order matches lexicographic byte order.
    fn u64_to_key(i: u64) -> [u8; consts::MAX_KEY_SIZE] {
        let mut key = [0u8; consts::MAX_KEY_SIZE];
        key[0..8].copy_from_slice(&i.to_be_bytes());
        key
    }

    #[test]
    fn test_insert() {
        let (db, _temp_file) = new_test_db();
        Seeder::new(DEFAULT_NUM_SEEDED_KEY_VALS, DEFAULT_SEED)
            .seed_db(&db)
            .unwrap();
        let kvs = Seeder::new(DEFAULT_NUM_SEEDED_KEY_VALS, DEFAULT_SEED)
            .collect::<HashSet<(String, String)>>();
        let t = db.r_txn();
        for (k, v) in kvs {
            match t.get(k.as_bytes()) {
                Err(err) => panic!("get({k}) unexpectedly got err {err}"),
                Ok(None) => panic!("get({k}) unexpectedly got None"),
                Ok(Some(got)) => {
                    let got = str::from_utf8(got).expect("get({k}) is a alphabetic string");
                    assert_eq!(got, v.as_str(), "get({k}) got = {got}, want = {v}");
                }
            }
        }
        Tree::new(&t.guard, t.root_page).check_height().unwrap();
    }

    #[test]
    fn test_update() {
        let (db, _temp_file) = new_test_db();
        Seeder::new(DEFAULT_NUM_SEEDED_KEY_VALS, DEFAULT_SEED)
            .seed_db(&db)
            .unwrap();
        let ks = Seeder::new(DEFAULT_NUM_SEEDED_KEY_VALS, DEFAULT_SEED)
            .map(|(k, _)| k)
            .collect::<HashSet<_>>();
        let updated_val = [1u8; consts::MAX_VALUE_SIZE];
        {
            let mut t = db.rw_txn();
            for k in ks.iter() {
                t.update(k.as_bytes(), &updated_val)
                    .unwrap_or_else(|_| panic!("update({k}, &updated_val) should succeed"));
            }
            t.commit();
        }
        {
            let t = db.r_txn();
            for k in ks.iter() {
                match t.get(k.as_bytes()) {
                    Err(err) => panic!("get({k}) unexpectedly got err {err}"),
                    Ok(None) => panic!("get({k}) unexpectedly got None"),
                    Ok(Some(got)) => {
                        assert_eq!(
                            got, &updated_val,
                            "get({k}) got = {got:?}, want = {updated_val:?}"
                        );
                    }
                }
            }
            Tree::new(&t.guard, t.root_page).check_height().unwrap();
        }
    }

    #[test]
    fn test_delete() {
        let (db, _temp_file) = new_test_db();
        Seeder::new(DEFAULT_NUM_SEEDED_KEY_VALS, DEFAULT_SEED)
            .seed_db(&db)
            .unwrap();
        let ks = Seeder::new(DEFAULT_NUM_SEEDED_KEY_VALS, DEFAULT_SEED)
            .map(|(k, _)| k)
            .collect::<HashSet<_>>();
        let mut t = db.rw_txn();
        for k in ks.iter() {
            if let Err(err) = t.delete(k.as_bytes()) {
                panic!("delete({k}) unexpectedly got err {err}");
            }
            // Deleted keys must be invisible within the same transaction.
            match t.get(k.as_bytes()) {
                Err(err) => panic!("get({k}) after delete() unexpectedly got err {err}"),
                Ok(Some(v)) => {
                    panic!("get({k}) after delete() unexpectedly got = Some({v:?}), want = None")
                }
                _ => {}
            };
        }
        t.commit();
        // ...and invisible to readers after commit.
        let t = db.r_txn();
        for k in ks.iter() {
            match t.get(k.as_bytes()) {
                Err(err) => panic!("get({k}) after delete() unexpectedly got err {err}"),
                Ok(Some(v)) => {
                    panic!("get({k}) after delete() unexpectedly got = Some({v:?}), want = None")
                }
                _ => {}
            };
        }
        Tree::new(&t.guard, t.root_page).check_height().unwrap();
    }

    #[test]
    fn test_in_order_range_iter() {
        let (db, _temp_file) = new_test_db();
        {
            let mut t = db.rw_txn();
            let mut inds = (1..=100).collect::<Vec<_>>();
            inds.shuffle(&mut rand::rng());
            for i in inds {
                let x = u64_to_key(i);
                t.insert(&x, &x).unwrap();
            }
            t.commit();
        }

        let t = db.r_txn();

        struct TestCase {
            name: &'static str,
            range: (Bound<&'static [u8]>, Bound<&'static [u8]>),
            want: Range<u64>,
        }
        impl Drop for TestCase {
            fn drop(&mut self) {
                // Reclaim the bound keys leaked via `Box::leak` below. Each
                // leaked allocation is a `Box<[u8; consts::MAX_KEY_SIZE]>`,
                // so it must be reconstructed with that exact type: calling
                // `Box::from_raw(ptr as *mut u8)` would deallocate with the
                // layout of a single `u8` (wrong size/alignment), which is
                // undefined behavior per `Box::from_raw`'s contract.
                for b in [self.range.0, self.range.1] {
                    match b {
                        Bound::Included(b) | Bound::Excluded(b) => {
                            drop(unsafe {
                                Box::from_raw(b.as_ptr() as *mut [u8; consts::MAX_KEY_SIZE])
                            });
                        }
                        _ => {}
                    }
                }
            }
        }
        let tests = [
            TestCase {
                name: "unbounded unbounded",
                range: (Bound::Unbounded, Bound::Unbounded),
                want: 1..101,
            },
            TestCase {
                name: "included included",
                range: (
                    Bound::Included(Box::leak(Box::new(u64_to_key(5)))),
                    Bound::Included(Box::leak(Box::new(u64_to_key(98)))),
                ),
                want: 5..99,
            },
            TestCase {
                name: "excluded included",
                range: (
                    Bound::Excluded(Box::leak(Box::new(u64_to_key(5)))),
                    Bound::Included(Box::leak(Box::new(u64_to_key(98)))),
                ),
                want: 6..99,
            },
            TestCase {
                name: "excluded excluded",
                range: (
                    Bound::Excluded(Box::leak(Box::new(u64_to_key(5)))),
                    Bound::Excluded(Box::leak(Box::new(u64_to_key(98)))),
                ),
                want: 6..98,
            },
            TestCase {
                name: "unbounded included",
                range: (
                    Bound::Unbounded,
                    Bound::Included(Box::leak(Box::new(u64_to_key(98)))),
                ),
                want: 1..99,
            },
            TestCase {
                name: "unbounded excluded",
                range: (
                    Bound::Unbounded,
                    Bound::Excluded(Box::leak(Box::new(u64_to_key(98)))),
                ),
                want: 1..98,
            },
            TestCase {
                name: "included unbounded",
                range: (
                    Bound::Included(Box::leak(Box::new(u64_to_key(5)))),
                    Bound::Unbounded,
                ),
                want: 5..101,
            },
            TestCase {
                name: "excluded unbounded",
                range: (
                    Bound::Excluded(Box::leak(Box::new(u64_to_key(5)))),
                    Bound::Unbounded,
                ),
                want: 6..101,
            },
            TestCase {
                name: "no overlap",
                range: (
                    Bound::Excluded(Box::leak(Box::new(u64_to_key(200)))),
                    Bound::Unbounded,
                ),
                want: 0..0,
            },
        ];
        for test in tests {
            let got = t
                .in_order_range_iter(&test.range)
                .map(|(k, _)| u64::from_be_bytes([k[0], k[1], k[2], k[3], k[4], k[5], k[6], k[7]]))
                .collect::<Vec<_>>();
            let want = test.want.clone().collect::<Vec<_>>();
            assert_eq!(got, want, "Test case \"{}\" failed", test.name);
        }
    }
}