#![doc = include_str!("../README.md")]
#![feature(test)]

extern crate test;

#[cfg(test)]
mod benches;
#[cfg(test)]
mod tests;

use std::borrow::Cow;
use std::collections::HashMap;
use std::fs::{File, OpenOptions};
use std::io::{self, BufWriter, Read, Seek, SeekFrom, Write};
use std::path::Path;
use std::sync::{Arc, RwLock};

use serde::{Deserialize, Serialize};

const PAGE_SIZE: u64 = 4096;

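/// Keys that can be stored in the tree. `encode` must be deterministic: the
/// byte representation feeds both the key's level calculation and the node
/// hashes, so equal keys must always encode identically.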
pub trait MerkleKey: Ord + Clone + std::fmt::Debug + Serialize + for<'a> Deserialize<'a> {
    fn encode(&self) -> Cow<'_, [u8]>;
}

impl MerkleKey for String {
    #[inline]
    fn encode(&self) -> Cow<'_, [u8]> {
        self.as_bytes().into()
    }
}

impl MerkleKey for Vec<u8> {
    #[inline]
    fn encode(&self) -> Cow<'_, [u8]> {
        self.as_slice().into()
    }
}

/// A 32-byte blake3 digest.
pub type Hash = [u8; 32];
/// File offset of a serialized node.
type NodeId = u64;

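/// A persistent, disk-backed Merkle Search Tree over keys of type `K`.
///
/// Mutations are copy-on-write: `insert` and `remove` build new nodes and
/// share untouched subtrees, and nothing reaches disk until `flush`.
///
/// A minimal usage sketch (not compiled as a doctest):
///
/// ```ignore
/// let mut tree = MerkleSearchTree::new_temporary()?;
/// tree.insert("alpha".to_string())?;
/// assert!(tree.contains(&"alpha".to_string())?);
/// let (root_offset, root_hash) = tree.flush()?;
/// ```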
pub struct MerkleSearchTree<K: MerkleKey> {
    root: Link<K>,
    store: Arc<Store<K>>,
}

impl<K: MerkleKey> MerkleSearchTree<K> {
    /// Opens (or creates) a tree file at `path`, starting from an empty root.
    pub fn open<P: AsRef<Path>>(path: P) -> io::Result<Self> {
        let store = Store::open(path)?;
        Ok(Self {
            root: Link::Loaded(Arc::new(Node::empty(0))),
            store,
        })
    }

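    /// Creates a tree backed by an anonymous temporary file; handy for tests,
    /// since the file is cleaned up once the handle is dropped.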
    pub fn new_temporary() -> io::Result<Self> {
        let file = tempfile::tempfile()?;
        let store = Store::new(file);

        Ok(Self {
            root: Link::Loaded(Arc::new(Node::empty(0))),
            store,
        })
    }

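    /// Reopens a tree from a `(root_offset, root_hash)` pair previously
    /// returned by [`flush`](Self::flush). The root node is loaded lazily on
    /// first access.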
    pub fn load_from_root<P: AsRef<Path>>(
        path: P,
        root_offset: u64,
        root_hash: Hash,
    ) -> io::Result<Self> {
        let store = Store::open(path)?;
        Ok(Self {
            root: Link::Disk {
                offset: root_offset,
                hash: root_hash,
            },
            store,
        })
    }

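    /// Inserts `key` into the tree. The key's level is derived from its own
    /// hash (see `Node::calc_level`), so the resulting tree shape depends only
    /// on the set of keys, not on insertion order.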
    pub fn insert(&mut self, key: K) -> io::Result<()> {
        let key_arc = Arc::new(key);
        let root_node = self.resolve_link(&self.root)?;

        let target_level = Node::calc_level(key_arc.as_ref());
        let new_root_node = root_node.put(key_arc, target_level, &self.store)?;

        self.root = Link::Loaded(new_root_node);
        Ok(())
    }

    pub fn contains(&self, key: &K) -> io::Result<bool> {
        let root = self.resolve_link(&self.root)?;
        root.contains(key, &self.store)
    }

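    /// Removes `key` if present; removing an absent key is a no-op.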
    pub fn remove(&mut self, key: &K) -> io::Result<()> {
        let root = self.resolve_link(&self.root)?;

        let (new_root, deleted) = root.delete(key, &self.store)?;

        if !deleted {
            return Ok(());
        }

        // Collapse a root that has lost all of its keys down to its sole child.
        if new_root.keys.is_empty() && !new_root.children.is_empty() {
            self.root = new_root.children[0].clone();
        } else {
            self.root = Link::Loaded(new_root);
        }

        Ok(())
    }

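    /// Persists every in-memory node to disk, bottom-up, and returns the root
    /// node's file offset together with its hash. The pair can be handed to
    /// [`load_from_root`](Self::load_from_root) to reopen the tree later.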
    pub fn flush(&mut self) -> io::Result<(u64, Hash)> {
        let (offset, hash) = self.flush_recursive(&self.root)?;

        self.store.flush()?;

        self.root = Link::Disk { offset, hash };

        Ok((offset, hash))
    }

    /// Returns the Merkle root hash of the current tree contents.
    pub fn root_hash(&self) -> Hash {
        self.root.hash()
    }

    fn resolve_link(&self, link: &Link<K>) -> io::Result<Arc<Node<K>>> {
        match link {
            Link::Loaded(node) => Ok(node.clone()),
            Link::Disk { offset, .. } => self.store.load_node(*offset),
        }
    }

    fn flush_recursive(&self, link: &Link<K>) -> io::Result<(NodeId, Hash)> {
        match link {
            Link::Disk { offset, hash } => Ok((*offset, *hash)),
            Link::Loaded(node) => {
                // A node is writable as-is only if every child is already on disk.
                let mut dirty_children = false;
                for child in &node.children {
                    if let Link::Loaded(_) = child {
                        dirty_children = true;
                        break;
                    }
                }

                if !dirty_children {
                    let offset = self.store.write_node(node)?;
                    return Ok((offset, node.hash));
                }

                // Flush children first, then rewrite this node with disk links.
                let mut new_children = Vec::new();
                for child in &node.children {
                    let (child_offset, child_hash) = self.flush_recursive(child)?;
                    new_children.push(Link::Disk {
                        offset: child_offset,
                        hash: child_hash,
                    });
                }

                let mut new_node = (**node).clone();
                new_node.children = new_children;
                let offset = self.store.write_node(&new_node)?;
                Ok((offset, new_node.hash))
            }
        }
    }
}

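/// A child pointer: either a `(file offset, hash)` pair for a node on disk,
/// or a node already resident in memory (and possibly not yet persisted).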
#[derive(Debug, Clone)]
enum Link<K: MerkleKey> {
    Disk { offset: NodeId, hash: Hash },
    Loaded(Arc<Node<K>>),
}

impl<K: MerkleKey> Link<K> {
    fn hash(&self) -> Hash {
        match self {
            Link::Disk { hash, .. } => *hash,
            Link::Loaded(node) => node.hash,
        }
    }
}

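/// Append-only node storage: an on-disk log of length-prefixed,
/// postcard-encoded nodes plus an in-memory cache keyed by file offset.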
struct Store<K: MerkleKey> {
    file: RwLock<BufWriter<File>>,
    cache: RwLock<HashMap<NodeId, Arc<Node<K>>>>,
}

impl<K: MerkleKey> Store<K> {
    fn new(file: File) -> Arc<Self> {
        Arc::new(Self {
            file: RwLock::new(BufWriter::with_capacity(64 * 1024, file)),
            cache: RwLock::new(HashMap::new()),
        })
    }

    fn open<P: AsRef<Path>>(path: P) -> io::Result<Arc<Self>> {
        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .open(path)?;

        Ok(Self::new(file))
    }

    fn flush(&self) -> io::Result<()> {
        let mut writer = self.file.write().unwrap();
        writer.flush()
    }

    fn load_node(&self, offset: NodeId) -> io::Result<Arc<Node<K>>> {
        {
            let cache = self.cache.read().unwrap();
            if let Some(node) = cache.get(&offset) {
                return Ok(node.clone());
            }
        }

        let mut writer_guard = self.file.write().unwrap();

        // Seeking through the BufWriter flushes any buffered writes first,
        // so reading from the underlying file below sees a consistent view.
        writer_guard.seek(SeekFrom::Start(offset))?;

        let file = writer_guard.get_mut();

        // Each node is framed as a little-endian u32 length followed by its
        // postcard-encoded body.
        let mut len_buf = [0u8; 4];
        file.read_exact(&mut len_buf)?;
        let len = u32::from_le_bytes(len_buf) as usize;

        let mut buf = vec![0u8; len];
        file.read_exact(&mut buf)?;

        let disk_node: DiskNode<K> = postcard::from_bytes(&buf)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?;

        let node = Arc::new(Node::from_disk(disk_node));
        self.cache.write().unwrap().insert(offset, node.clone());
        Ok(node)
    }

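    /// Appends a node and returns its offset. The record layout is a
    /// little-endian `u32` length followed by the postcard body; records that
    /// fit in one page are padded forward so they never straddle a 4 KiB
    /// boundary. For example, a 100-byte record at position 4050 has only 46
    /// bytes left in its page, so 46 zero bytes are written first and the
    /// record starts at offset 4096.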
    fn write_node(&self, node: &Node<K>) -> io::Result<NodeId> {
        let disk_node = node.to_disk();
        let data = postcard::to_extend(&disk_node, Vec::with_capacity(4096))
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?;

        // Total footprint on disk: 4-byte length prefix + encoded body.
        let node_total_len = (data.len() + 4) as u64;

        let mut writer = self.file.write().unwrap();

        let mut current_pos = writer.seek(SeekFrom::End(0))?;

        // Keep small nodes within a single page: if the record would cross a
        // page boundary, zero-pad up to the next page and write it there.
        if node_total_len <= PAGE_SIZE {
            let offset_in_page = current_pos % PAGE_SIZE;
            let space_remaining = PAGE_SIZE - offset_in_page;

            if node_total_len > space_remaining {
                let padding_len = space_remaining as usize;
                let padding = vec![0u8; padding_len];
                writer.write_all(&padding)?;

                current_pos += space_remaining;
            }
        }

        let start_offset = current_pos;

        writer.write_all(&(data.len() as u32).to_le_bytes())?;
        writer.write_all(&data)?;

        Ok(start_offset)
    }
}

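/// The serialized form of a node. Children appear only as `(offset, hash)`
/// pairs, so a node must have all of its children flushed before it can be
/// written (see `Node::to_disk`).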
#[derive(Serialize, Deserialize)]
struct DiskNode<K> {
    level: u32,
    keys: Vec<K>,
    children: Vec<(NodeId, Hash)>,
    hash: Hash,
}

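/// An in-memory node. Non-empty nodes keep the B-tree-style invariant
/// `children.len() == keys.len() + 1`; the empty node has neither keys nor
/// children and hashes to all zeros.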
#[derive(Debug, Clone)]
struct Node<K: MerkleKey> {
    level: u32,
    keys: Vec<Arc<K>>,
    children: Vec<Link<K>>,
    hash: Hash,
}

impl<K: MerkleKey> Node<K> {
    fn empty(level: u32) -> Self {
        let mut node = Self {
            level,
            keys: Vec::new(),
            children: Vec::new(),
            hash: [0u8; 32],
        };
        node.rehash();
        node
    }

    fn to_disk(&self) -> DiskNode<K> {
        let children_meta = self
            .children
            .iter()
            .map(|c| match c {
                Link::Disk { offset, hash } => (*offset, *hash),
                Link::Loaded(_) => {
                    panic!("Cannot serialize a node with dirty children! Flush children first.")
                }
            })
            .collect();

        DiskNode {
            level: self.level,
            keys: self.keys.iter().map(|k| k.as_ref().clone()).collect(),
            children: children_meta,
            hash: self.hash,
        }
    }

    fn from_disk(disk: DiskNode<K>) -> Self {
        let children = disk
            .children
            .into_iter()
            .map(|(offset, hash)| Link::Disk { offset, hash })
            .collect();

        let keys = disk.keys.into_iter().map(Arc::new).collect();

        Self {
            level: disk.level,
            keys,
            children,
            hash: disk.hash,
        }
    }

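    /// Derives a key's level as the number of leading zero nibbles in the
    /// blake3 hash of its encoding, giving a geometric distribution: roughly
    /// 15 of 16 keys land on level 0, 1 in 16 on level 1, and so on. For
    /// example, a hash starting `0x00 0x0a ...` has three leading zero
    /// nibbles, so the key gets level 3.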
    fn calc_level(key: &K) -> u32 {
        let mut h = blake3::Hasher::new();
        h.update(&key.encode());
        let hash = h.finalize();
        let bytes = hash.as_bytes();
        // Count leading zero nibbles: each zero byte contributes two, and a
        // byte with only its high nibble zero contributes one more.
        let mut level = 0;
        for byte in bytes {
            if *byte == 0 {
                level += 2;
            } else {
                if *byte & 0xF0 == 0 {
                    level += 1;
                }
                break;
            }
        }
        level
    }

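    /// Recomputes this node's Merkle hash from its level, key count, child
    /// hashes, and length-prefixed key encodings, in tree order.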
    fn rehash(&mut self) {
        // The empty node hashes to all zeros so that absence is canonical.
        if self.keys.is_empty() && self.children.is_empty() {
            self.hash = [0u8; 32];
            return;
        }

        let mut h = blake3::Hasher::new();
        h.update(&self.level.to_le_bytes());
        h.update(&(self.keys.len() as u64).to_le_bytes());

        // Interleave child hashes with length-prefixed key bytes so the
        // digest is unambiguous for any key encoding.
        for (i, child) in self.children.iter().enumerate() {
            h.update(&child.hash());
            if i < self.keys.len() {
                let k_bytes = self.keys[i].encode();
                h.update(&(k_bytes.len() as u64).to_le_bytes());
                h.update(&k_bytes);
            }
        }
        self.hash = *h.finalize().as_bytes();
    }

    fn contains(&self, key: &K, store: &Store<K>) -> io::Result<bool> {
        match self.keys.binary_search_by(|probe| probe.as_ref().cmp(key)) {
            Ok(_) => Ok(true),
            Err(idx) => {
                if self.children.is_empty() {
                    return Ok(false);
                }
                let child = match &self.children[idx] {
                    Link::Loaded(n) => n.clone(),
                    Link::Disk { offset, .. } => store.load_node(*offset)?,
                };
                child.contains(key, store)
            }
        }
    }

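    /// Inserts `key` at `key_level`, returning a new node and leaving `self`
    /// untouched. Three cases: the key's level is above this node (the key
    /// becomes a new subtree root and this subtree is split around it), equal
    /// to it (the key joins this node's key list and the covering child is
    /// split), or below it (recurse into the covering child).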
    fn put(&self, key: Arc<K>, key_level: u32, store: &Arc<Store<K>>) -> io::Result<Arc<Node<K>>> {
        // Case 1: the key lives above this node, so it becomes a new root for
        // this subtree, with the current subtree split around it.
        if key_level > self.level {
            let (left_child, right_child) = self.split(&key, store)?;
            let mut new_node = Node {
                level: key_level,
                keys: vec![key],
                children: vec![Link::Loaded(left_child), Link::Loaded(right_child)],
                hash: [0u8; 32],
            };
            new_node.rehash();
            return Ok(Arc::new(new_node));
        }

        // Case 2: the key belongs in this node; split the covering child so
        // both halves stay correctly ordered around the new key.
        if key_level == self.level {
            let mut new_node = self.clone();
            match new_node
                .keys
                .binary_search_by(|probe| probe.as_ref().cmp(&key))
            {
                Ok(_) => return Ok(Arc::new(new_node)),
                Err(idx) => {
                    let child_to_split = if !new_node.children.is_empty() {
                        match &new_node.children[idx] {
                            Link::Loaded(n) => n.clone(),
                            Link::Disk { offset, .. } => store.load_node(*offset)?,
                        }
                    } else {
                        Arc::new(Node::empty(self.level.saturating_sub(1)))
                    };

                    let (left_sub, right_sub) = child_to_split.split(&key, store)?;
                    new_node.keys.insert(idx, key);

                    if new_node.children.is_empty() {
                        new_node.children.push(Link::Loaded(left_sub));
                        new_node.children.push(Link::Loaded(right_sub));
                    } else {
                        new_node.children[idx] = Link::Loaded(left_sub);
                        new_node.children.insert(idx + 1, Link::Loaded(right_sub));
                    }
                    new_node.rehash();
                    return Ok(Arc::new(new_node));
                }
            }
        }

        // An empty node adopts the key's level directly.
        if self.keys.is_empty() && self.children.is_empty() {
            let mut new_node = Node {
                level: key_level,
                keys: vec![key],
                children: vec![
                    Link::Loaded(Arc::new(Node::empty(0))),
                    Link::Loaded(Arc::new(Node::empty(0))),
                ],
                hash: [0u8; 32],
            };
            new_node.rehash();
            return Ok(Arc::new(new_node));
        }

        // Case 3: the key lives below this node; recurse into the covering child.
        let mut new_node = self.clone();
        let idx = match new_node
            .keys
            .binary_search_by(|probe| probe.as_ref().cmp(&key))
        {
            Ok(_) => return Ok(Arc::new(new_node)),
            Err(i) => i,
        };

        let child_node = match &new_node.children[idx] {
            Link::Loaded(n) => n.clone(),
            Link::Disk { offset, .. } => store.load_node(*offset)?,
        };

        let new_child = child_node.put(key, key_level, store)?;
        new_node.children[idx] = Link::Loaded(new_child);
        new_node.rehash();
        Ok(Arc::new(new_node))
    }

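    /// Splits this subtree around `split_key` into two same-level subtrees:
    /// one holding everything ordered before the key and one holding
    /// everything after it. An occurrence of `split_key` itself is dropped.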
    fn split(
        &self,
        split_key: &K,
        store: &Arc<Store<K>>,
    ) -> io::Result<(Arc<Node<K>>, Arc<Node<K>>)> {
        if self.keys.is_empty() && self.children.is_empty() {
            return Ok((
                Arc::new(Node::empty(self.level)),
                Arc::new(Node::empty(self.level)),
            ));
        }

        let idx = match self
            .keys
            .binary_search_by(|probe| probe.as_ref().cmp(split_key))
        {
            Ok(i) => i,
            Err(i) => i,
        };

        let left_keys = self.keys[..idx].to_vec();
        // If the split key itself is present, it is excluded from both halves.
        let right_start = if idx < self.keys.len() && self.keys[idx].as_ref() == split_key {
            idx + 1
        } else {
            idx
        };
        let right_keys = self.keys[right_start..].to_vec();

        // The child straddling the split point is split recursively.
        let (mid_left, mid_right) = if idx < self.children.len() {
            let child = match &self.children[idx] {
                Link::Loaded(n) => n.clone(),
                Link::Disk { offset, .. } => store.load_node(*offset)?,
            };
            child.split(split_key, store)?
        } else {
            (Arc::new(Node::empty(0)), Arc::new(Node::empty(0)))
        };

        let mut left_children = self.children[..idx].to_vec();
        left_children.push(Link::Loaded(mid_left));
        let mut left_node = Node {
            level: self.level,
            keys: left_keys,
            children: left_children,
            hash: [0u8; 32],
        };
        left_node.rehash();

        let mut right_children = vec![Link::Loaded(mid_right)];
        if idx + 1 < self.children.len() {
            right_children.extend_from_slice(&self.children[idx + 1..]);
        }
        let mut right_node = Node {
            level: self.level,
            keys: right_keys,
            children: right_children,
            hash: [0u8; 32],
        };
        right_node.rehash();

        Ok((Arc::new(left_node), Arc::new(right_node)))
    }

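    /// Removes `key` from this subtree, returning the new subtree and whether
    /// anything was deleted. Removing a key from an interior node merges the
    /// two children that flanked it.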
    fn delete(&self, key: &K, store: &Arc<Store<K>>) -> io::Result<(Arc<Node<K>>, bool)> {
        match self.keys.binary_search_by(|probe| probe.as_ref().cmp(key)) {
            Ok(idx) => {
                let mut new_node = self.clone();
                new_node.keys.remove(idx);

                // Removing children[idx] shifts the right sibling into idx,
                // so two removes at the same index take both neighbors.
                let left_child = new_node.children.remove(idx);
                let right_child = new_node.children.remove(idx);

                let merged_child = Node::merge(left_child, right_child, store)?;

                new_node.children.insert(idx, merged_child);

                new_node.rehash();
                Ok((Arc::new(new_node), true))
            }
            Err(idx) => {
                if self.children.is_empty() {
                    return Ok((Arc::new(self.clone()), false));
                }

                let child_link = &self.children[idx];
                let child_node = match child_link {
                    Link::Loaded(n) => n.clone(),
                    Link::Disk { offset, .. } => store.load_node(*offset)?,
                };

                let (new_child, deleted) = child_node.delete(key, store)?;

                if !deleted {
                    return Ok((Arc::new(self.clone()), false));
                }

                let mut new_node = self.clone();
                new_node.children[idx] = Link::Loaded(new_child);
                new_node.rehash();
                Ok((Arc::new(new_node), true))
            }
        }
    }

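    /// Merges two adjacent subtrees (everything in `left` orders before
    /// everything in `right`). Unequal levels recurse into the taller side's
    /// boundary child; equal levels concatenate keys and children, merging
    /// the pair of boundary children that meet in the middle.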
    fn merge(left: Link<K>, right: Link<K>, store: &Arc<Store<K>>) -> io::Result<Link<K>> {
        let left_node = match &left {
            Link::Loaded(n) => n.clone(),
            Link::Disk { offset, .. } => store.load_node(*offset)?,
        };

        let right_node = match &right {
            Link::Loaded(n) => n.clone(),
            Link::Disk { offset, .. } => store.load_node(*offset)?,
        };

        // An empty side merges to the other side unchanged.
        if left_node.keys.is_empty() && left_node.children.is_empty() {
            return Ok(Link::Loaded(right_node));
        }
        if right_node.keys.is_empty() && right_node.children.is_empty() {
            return Ok(Link::Loaded(left_node));
        }

        // Unequal levels: fold the lower subtree into the taller one's
        // boundary child (rightmost on the left side, leftmost on the right).
        if left_node.level > right_node.level {
            let mut new_left = (*left_node).clone();

            let last_idx = new_left.children.len() - 1;
            let last_child = new_left.children.remove(last_idx);

            let merged = Node::merge(last_child, right, store)?;
            new_left.children.push(merged);
            new_left.rehash();

            return Ok(Link::Loaded(Arc::new(new_left)));
        }

        if right_node.level > left_node.level {
            let mut new_right = (*right_node).clone();

            let first_child = new_right.children.remove(0);

            let merged = Node::merge(left, first_child, store)?;
            new_right.children.insert(0, merged);
            new_right.rehash();

            return Ok(Link::Loaded(Arc::new(new_right)));
        }

        // Equal levels: concatenate keys and children, merging the two
        // boundary children that meet in the middle.
        let mut new_node = (*left_node).clone();
        let mut right_clone = (*right_node).clone();

        let left_boundary_child = new_node.children.pop().expect("Node should have children");
        let right_boundary_child = right_clone.children.remove(0);

        let merged_boundary = Node::merge(left_boundary_child, right_boundary_child, store)?;

        new_node.keys.extend(right_clone.keys);
        new_node.children.push(merged_boundary);
        new_node.children.extend(right_clone.children);
        new_node.rehash();

        Ok(Link::Loaded(Arc::new(new_node)))
    }
}