1pub mod error;
8
9use std::{marker::PhantomData, ops::RangeBounds, path::Path, sync::Arc};
10
11use error::TxnError;
12use seize::Collector;
13
14pub use crate::core::consts;
15use crate::core::{
16 mmap::{self, Guard, ImmutablePage, Mmap, Reader, ReaderPage, Store, Writer, WriterPage},
17 tree::Tree,
18};
19
/// Crate-level result alias: fallible operations here return [`TxnError`].
pub type Result<T> = std::result::Result<T, TxnError>;
/// A read-write transaction: a [`Txn`] driven by the store's [`Writer`] guard.
pub type RWTxn<'t, 'd> = Txn<'t, WriterPage<'t, 'd>, Writer<'d>>;
/// A read-only transaction: a [`Txn`] driven by a [`Reader`] guard.
pub type RTxn<'t, 'd> = Txn<'t, ReaderPage<'t>, Reader<'d>>;
28
/// Handle to a database. Cloning is shallow: all clones share the same
/// underlying [`Store`] (and thus the same mmap and collector) via the `Arc`.
#[derive(Clone)]
pub struct DB {
    // Shared page store; constructed from an `Mmap` plus a seize `Collector`.
    store: Arc<Store>,
}
38
39impl DB {
40 pub fn open_or_create<P: AsRef<Path>>(path: P) -> Result<Self> {
51 Ok(DB {
52 store: Arc::new(Store::new(
53 Mmap::open_or_create(path, mmap::DEFAULT_MIN_FILE_GROWTH_SIZE)?,
54 Collector::new(),
55 )),
56 })
57 }
58
59 pub fn r_txn(&self) -> RTxn<'_, '_> {
67 let reader = self.store.reader();
68 let root_page = reader.root_page();
69 Txn {
70 _phantom: PhantomData,
71 guard: reader,
72 root_page,
73 }
74 }
75
76 pub fn rw_txn(&self) -> RWTxn<'_, '_> {
84 let writer = self.store.writer();
85 let root_page = writer.root_page();
86 Txn {
87 _phantom: PhantomData,
88 guard: writer,
89 root_page,
90 }
91 }
92}
93
/// Builder for [`DB`] that exposes tuning knobs not available through
/// `DB::open_or_create`.
pub struct DBBuilder<P: AsRef<Path>> {
    db_path: P,
    // Batch size handed to the seize `Collector`; `None` keeps seize's default.
    // NOTE(review): this field is `pub` while the others are private, even
    // though a builder method exists for it — confirm direct field access is
    // intentional.
    pub free_batch_size: Option<usize>,
    // Minimum amount the backing file grows by when it needs to expand.
    min_file_growth_size: usize,
}
100
101impl<P: AsRef<Path>> DBBuilder<P> {
102 pub fn new(db_path: P) -> Self {
104 let free_batch_size = if cfg!(test) { Some(1) } else { None };
105 DBBuilder {
106 db_path,
107 free_batch_size,
108 min_file_growth_size: mmap::DEFAULT_MIN_FILE_GROWTH_SIZE,
109 }
110 }
111
112 pub fn free_batch_size(self, val: usize) -> Self {
115 DBBuilder {
116 db_path: self.db_path,
117 free_batch_size: Some(val),
118 min_file_growth_size: self.min_file_growth_size,
119 }
120 }
121
122 pub fn min_file_growth_size(self, val: usize) -> Self {
125 DBBuilder {
126 db_path: self.db_path,
127 free_batch_size: self.free_batch_size,
128 min_file_growth_size: val,
129 }
130 }
131
132 pub fn build(self) -> Result<DB> {
134 let collector = match &self.free_batch_size {
135 &Some(size) => Collector::new().batch_size(size),
136 None => Collector::new(),
137 };
138 Ok(DB {
139 store: Arc::new(Store::new(
140 Mmap::open_or_create(self.db_path, self.min_file_growth_size)?,
141 collector,
142 )),
143 })
144 }
145}
146
/// A transaction over the store, generic over the guard type `G` (reader or
/// writer) and the page type `P` it hands out. See the [`RTxn`] / [`RWTxn`]
/// aliases for the concrete instantiations.
pub struct Txn<'g, P: ImmutablePage<'g>, G: Guard<'g, P>> {
    // Ties `P` and the lifetime `'g` to the struct without storing a page.
    _phantom: PhantomData<&'g P>,
    guard: G,
    // Page number of the tree root this transaction reads from / writes to.
    root_page: usize,
}
157
impl<'g, P: ImmutablePage<'g>, G: Guard<'g, P>> Txn<'g, P, G> {
    /// Looks up `key`, returning the value bytes if the key is present.
    ///
    /// The returned slice borrows from the transaction for `'g`, not from the
    /// temporary `Tree` built below — hence the raw-pointer round-trip.
    pub fn get(&'g self, key: &[u8]) -> Result<Option<&'g [u8]>> {
        let tree = Tree::new(&self.guard, self.root_page);
        let val = tree.get(key)?.map(|val| {
            // SAFETY(review): re-borrows the value bytes with lifetime 'g,
            // detaching them from the local `tree`. Sound only if the page
            // data stays valid for as long as `self.guard` is held —
            // presumably guaranteed by the `Guard`/`ImmutablePage` contracts;
            // confirm against their definitions in `core::mmap`.
            unsafe { &*std::ptr::slice_from_raw_parts(val.as_ptr(), val.len()) }
        });
        Ok(val)
    }

    /// Iterates over all key/value pairs in key order.
    pub fn in_order_iter(&'g self) -> impl Iterator<Item = (&'g [u8], &'g [u8])> {
        Tree::new(&self.guard, self.root_page).in_order_iter()
    }

    /// Iterates in key order over the pairs whose keys fall within `range`.
    pub fn in_order_range_iter<R: RangeBounds<[u8]>>(
        &'g self,
        range: &R,
    ) -> impl Iterator<Item = (&'g [u8], &'g [u8])> {
        Tree::new(&self.guard, self.root_page).in_order_range_iter(range)
    }
}
204
205impl<'t, 'd> Txn<'t, WriterPage<'t, 'd>, Writer<'d>> {
206 pub fn insert(&mut self, key: &[u8], val: &[u8]) -> Result<()> {
216 self.root_page = Tree::new(&self.guard, self.root_page)
217 .insert(key, val)?
218 .page_num();
219 Ok(())
220 }
221
222 pub fn update(&mut self, key: &[u8], val: &[u8]) -> Result<()> {
232 self.root_page = Tree::new(&self.guard, self.root_page)
233 .update(key, val)?
234 .page_num();
235 Ok(())
236 }
237
238 pub fn delete(&mut self, key: &[u8]) -> Result<()> {
247 self.root_page = Tree::new(&self.guard, self.root_page)
248 .delete(key)?
249 .page_num();
250 Ok(())
251 }
252
253 #[inline]
259 pub fn commit(self) {
260 self.guard.flush(self.root_page);
261 }
262
263 #[inline]
268 pub fn abort(self) {
269 self.guard.abort();
270 }
271}
272
#[cfg(test)]
mod tests {
    use core::str;
    use std::collections::HashSet;
    use std::ops::{Bound, Range};

    use anyhow::{Context, Result};
    use rand::distr::{Alphabetic, SampleString as _};
    use rand::prelude::*;
    use rand_chacha::ChaCha8Rng;
    use tempfile::NamedTempFile;

    use super::{error::TreeError, *};

    const DEFAULT_SEED: u64 = 1;
    const DEFAULT_NUM_SEEDED_KEY_VALS: usize = 1000;

    /// Builds a DB over a fresh temp file. The `NamedTempFile` is returned
    /// too so the backing file outlives the `DB`.
    fn new_test_db() -> (DB, NamedTempFile) {
        let temp_file = NamedTempFile::new().unwrap();
        let path = temp_file.path();
        let db = DBBuilder::new(path).free_batch_size(1).build().unwrap();
        (db, temp_file)
    }

    /// Deterministic generator of `n` random alphabetic key/value pairs.
    struct Seeder {
        n: usize,
        rng: ChaCha8Rng,
    }

    impl Seeder {
        fn new(n: usize, seed: u64) -> Self {
            Seeder {
                n,
                rng: ChaCha8Rng::seed_from_u64(seed),
            }
        }

        /// Inserts every generated pair in one committed transaction.
        /// Randomly colliding keys are skipped rather than treated as errors.
        fn seed_db(self, db: &DB) -> Result<()> {
            let mut t = db.rw_txn();
            for (i, (k, v)) in self.enumerate() {
                let result = t.insert(k.as_bytes(), v.as_bytes());
                if matches!(result, Err(TxnError::Tree(TreeError::AlreadyExists))) {
                    continue;
                }
                result.with_context(|| format!("failed to insert {i}th ({k}, {v})"))?;
            }
            t.commit();
            Ok(())
        }
    }

    impl Iterator for Seeder {
        type Item = (String, String);
        fn next(&mut self) -> Option<Self::Item> {
            if self.n == 0 {
                return None;
            }
            self.n -= 1;
            let key_len = self.rng.random_range(1..=consts::MAX_KEY_SIZE);
            let val_len = self.rng.random_range(1..=consts::MAX_VALUE_SIZE);
            let key: String = Alphabetic.sample_string(&mut self.rng, key_len);
            let val: String = Alphabetic.sample_string(&mut self.rng, val_len);
            Some((key, val))
        }
    }

    /// Encodes `i` big-endian into the first 8 bytes of a zeroed max-size
    /// key, so numeric order matches lexicographic byte order.
    fn u64_to_key(i: u64) -> [u8; consts::MAX_KEY_SIZE] {
        let mut key = [0u8; consts::MAX_KEY_SIZE];
        key[0..8].copy_from_slice(&i.to_be_bytes());
        key
    }

    #[test]
    fn test_insert() {
        let (db, _temp_file) = new_test_db();
        Seeder::new(DEFAULT_NUM_SEEDED_KEY_VALS, DEFAULT_SEED)
            .seed_db(&db)
            .unwrap();
        let kvs = Seeder::new(DEFAULT_NUM_SEEDED_KEY_VALS, DEFAULT_SEED)
            .collect::<HashSet<(String, String)>>();
        let t = db.r_txn();
        for (k, v) in kvs {
            match t.get(k.as_bytes()) {
                Err(err) => panic!("get({k}) unexpectedly got err {err}"),
                Ok(None) => panic!("get({k}) unexpectedly got None"),
                Ok(Some(got)) => {
                    let got = str::from_utf8(got).expect("get({k}) is a alphabetic string");
                    assert_eq!(got, v.as_str(), "get({k}) got = {got}, want = {v}");
                }
            }
        }
        Tree::new(&t.guard, t.root_page).check_height().unwrap();
    }

    #[test]
    fn test_update() {
        let (db, _temp_file) = new_test_db();
        Seeder::new(DEFAULT_NUM_SEEDED_KEY_VALS, DEFAULT_SEED)
            .seed_db(&db)
            .unwrap();
        let ks = Seeder::new(DEFAULT_NUM_SEEDED_KEY_VALS, DEFAULT_SEED)
            .map(|(k, _)| k)
            .collect::<HashSet<_>>();
        let updated_val = [1u8; consts::MAX_VALUE_SIZE];
        {
            let mut t = db.rw_txn();
            for k in ks.iter() {
                t.update(k.as_bytes(), &updated_val)
                    .unwrap_or_else(|_| panic!("update({k}, &updated_val) should succeed"));
            }
            t.commit();
        }
        {
            let t = db.r_txn();
            for k in ks.iter() {
                match t.get(k.as_bytes()) {
                    Err(err) => panic!("get({k}) unexpectedly got err {err}"),
                    Ok(None) => panic!("get({k}) unexpectedly got None"),
                    Ok(Some(got)) => {
                        assert_eq!(
                            got, &updated_val,
                            "get({k}) got = {got:?}, want = {updated_val:?}"
                        );
                    }
                }
            }
            Tree::new(&t.guard, t.root_page).check_height().unwrap();
        }
    }

    #[test]
    fn test_delete() {
        let (db, _temp_file) = new_test_db();
        Seeder::new(DEFAULT_NUM_SEEDED_KEY_VALS, DEFAULT_SEED)
            .seed_db(&db)
            .unwrap();
        let ks = Seeder::new(DEFAULT_NUM_SEEDED_KEY_VALS, DEFAULT_SEED)
            .map(|(k, _)| k)
            .collect::<HashSet<_>>();
        let mut t = db.rw_txn();
        for k in ks.iter() {
            if let Err(err) = t.delete(k.as_bytes()) {
                panic!("delete({k}) unexpectedly got err {err}");
            }
            // Deletes must be visible within the same (uncommitted) txn.
            match t.get(k.as_bytes()) {
                Err(err) => panic!("get({k}) after delete() unexpectedly got err {err}"),
                Ok(Some(v)) => {
                    panic!("get({k}) after delete() unexpectedly got = Some({v:?}), want = None")
                }
                _ => {}
            };
        }
        t.commit();
        // And after commit, deletes must be visible to a fresh reader.
        let t = db.r_txn();
        for k in ks.iter() {
            match t.get(k.as_bytes()) {
                Err(err) => panic!("get({k}) after delete() unexpectedly got err {err}"),
                Ok(Some(v)) => {
                    panic!("get({k}) after delete() unexpectedly got = Some({v:?}), want = None")
                }
                _ => {}
            };
        }
        Tree::new(&t.guard, t.root_page).check_height().unwrap();
    }

    #[test]
    fn test_in_order_range_iter() {
        let (db, _temp_file) = new_test_db();
        {
            let mut t = db.rw_txn();
            let mut inds = (1..=100).collect::<Vec<_>>();
            inds.shuffle(&mut rand::rng());
            for i in inds {
                let x = u64_to_key(i);
                t.insert(&x, &x).unwrap();
            }
            t.commit();
        }

        let t = db.r_txn();

        struct TestCase {
            name: &'static str,
            range: (Bound<&'static [u8]>, Bound<&'static [u8]>),
            want: Range<u64>,
        }
        impl Drop for TestCase {
            // Reclaims the bound keys leaked via `Box::leak` below so the
            // test is leak-free.
            fn drop(&mut self) {
                for b in [self.range.0, self.range.1] {
                    match b {
                        Bound::Included(b) | Bound::Excluded(b) => {
                            // SAFETY: every bounded endpoint was produced by
                            // leaking a `Box<[u8; consts::MAX_KEY_SIZE]>`, so
                            // the pointer must be rebuilt as a box of exactly
                            // that type. (Rebuilding it as `Box<u8>`, as this
                            // code previously did, deallocates with a
                            // mismatched layout, which is undefined behavior.)
                            drop(unsafe {
                                Box::from_raw(b.as_ptr() as *mut [u8; consts::MAX_KEY_SIZE])
                            });
                        }
                        Bound::Unbounded => {}
                    }
                }
            }
        }
        let tests = [
            TestCase {
                name: "unbounded unbounded",
                range: (Bound::Unbounded, Bound::Unbounded),
                want: 1..101,
            },
            TestCase {
                name: "included included",
                range: (
                    Bound::Included(Box::leak(Box::new(u64_to_key(5)))),
                    Bound::Included(Box::leak(Box::new(u64_to_key(98)))),
                ),
                want: 5..99,
            },
            TestCase {
                name: "excluded included",
                range: (
                    Bound::Excluded(Box::leak(Box::new(u64_to_key(5)))),
                    Bound::Included(Box::leak(Box::new(u64_to_key(98)))),
                ),
                want: 6..99,
            },
            TestCase {
                name: "excluded excluded",
                range: (
                    Bound::Excluded(Box::leak(Box::new(u64_to_key(5)))),
                    Bound::Excluded(Box::leak(Box::new(u64_to_key(98)))),
                ),
                want: 6..98,
            },
            TestCase {
                name: "unbounded included",
                range: (
                    Bound::Unbounded,
                    Bound::Included(Box::leak(Box::new(u64_to_key(98)))),
                ),
                want: 1..99,
            },
            TestCase {
                name: "unbounded excluded",
                range: (
                    Bound::Unbounded,
                    Bound::Excluded(Box::leak(Box::new(u64_to_key(98)))),
                ),
                want: 1..98,
            },
            TestCase {
                name: "included unbounded",
                range: (
                    Bound::Included(Box::leak(Box::new(u64_to_key(5)))),
                    Bound::Unbounded,
                ),
                want: 5..101,
            },
            TestCase {
                name: "excluded unbounded",
                range: (
                    Bound::Excluded(Box::leak(Box::new(u64_to_key(5)))),
                    Bound::Unbounded,
                ),
                want: 6..101,
            },
            TestCase {
                name: "no overlap",
                range: (
                    Bound::Excluded(Box::leak(Box::new(u64_to_key(200)))),
                    Bound::Unbounded,
                ),
                want: 0..0,
            },
        ];
        for test in tests {
            let got = t
                .in_order_range_iter(&test.range)
                .map(|(k, _)| u64::from_be_bytes([k[0], k[1], k[2], k[3], k[4], k[5], k[6], k[7]]))
                .collect::<Vec<_>>();
            let want = test.want.clone().collect::<Vec<_>>();
            assert_eq!(got, want, "Test case \"{}\" failed", test.name);
        }
    }
}