file_mst/
lib.rs

#![doc = include_str!("../README.md")]
// The unstable `test` crate (used by the benchmarks) requires a nightly toolchain.
#![feature(test)]

extern crate test;

#[cfg(test)]
mod benches;
#[cfg(test)]
mod tests;

use std::borrow::Cow;
use std::collections::HashMap;
use std::fs::{File, OpenOptions};
use std::io::{self, BufWriter, Read, Seek, SeekFrom, Write};
use std::path::Path;
use std::sync::{Arc, RwLock};

use serde::{Deserialize, Serialize};

/// Records no larger than a page are page-aligned on write, so they never
/// straddle a page boundary (see `Store::write_node`).
const PAGE_SIZE: u64 = 4096;

/// A trait for types that can serve as keys in a Merkle Search Tree.
///
/// The encoding must be stable and deterministic: the bytes feed both the
/// node hashes and the per-key level calculation, so equal keys must always
/// encode to identical bytes.
pub trait MerkleKey: Ord + Clone + std::fmt::Debug + Serialize + for<'a> Deserialize<'a> {
    fn encode(&self) -> Cow<'_, [u8]>;
}

impl MerkleKey for String {
    #[inline]
    fn encode(&self) -> Cow<'_, [u8]> {
        self.as_bytes().into()
    }
}

impl MerkleKey for Vec<u8> {
    #[inline]
    fn encode(&self) -> Cow<'_, [u8]> {
        self.as_slice().into()
    }
}
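
// A sketch of a key impl for a numeric type (illustrative; this impl is an
// addition, not part of the original API). Big-endian bytes sort the same way
// as the integers themselves, which keeps the byte encoding consistent with
// the `Ord` bound at the cost of one small allocation.
impl MerkleKey for u64 {
    #[inline]
    fn encode(&self) -> Cow<'_, [u8]> {
        Cow::Owned(self.to_be_bytes().to_vec())
    }
}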

pub type Hash = [u8; 32];
type NodeId = u64;

pub struct MerkleSearchTree<K: MerkleKey> {
    root: Link<K>,
    store: Arc<Store<K>>,
}

impl<K: MerkleKey> MerkleSearchTree<K> {
    /// Opens or creates a file-backed Merkle Search Tree at the given path.
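    ///
    /// A minimal usage sketch (marked `no_run` because it touches the
    /// filesystem; the path is illustrative):
    ///
    /// ```no_run
    /// use file_mst::MerkleSearchTree;
    ///
    /// # fn main() -> std::io::Result<()> {
    /// let mut tree: MerkleSearchTree<String> = MerkleSearchTree::open("data/tree.mst")?;
    /// tree.insert("hello".to_string())?;
    /// # Ok(())
    /// # }
    /// ```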
    pub fn open<P: AsRef<Path>>(path: P) -> io::Result<Self> {
        let store = Store::open(path)?;
        Ok(Self {
            root: Link::Loaded(Arc::new(Node::empty(0))),
            store,
        })
    }

    /// Creates a new MST backed by a temporary file.
    /// The backing file is removed automatically once the last handle to it
    /// is closed (on Unix it is unlinked as soon as it is created).
    pub fn new_temporary() -> io::Result<Self> {
        // tempfile::tempfile() creates an anonymous file in the OS temp directory.
        let file = tempfile::tempfile()?;
        let store = Store::new(file);

        Ok(Self {
            root: Link::Loaded(Arc::new(Node::empty(0))),
            store,
        })
    }

    /// Loads a tree from a known root offset and hash.
    pub fn load_from_root<P: AsRef<Path>>(
        path: P,
        root_offset: u64,
        root_hash: Hash,
    ) -> io::Result<Self> {
        let store = Store::open(path)?;
        Ok(Self {
            root: Link::Disk {
                offset: root_offset,
                hash: root_hash,
            },
            store,
        })
    }

    /// Inserts a key into the tree, modifying it in-place.
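    ///
    /// A round-trip sketch using a temporary tree:
    ///
    /// ```
    /// use file_mst::MerkleSearchTree;
    ///
    /// # fn main() -> std::io::Result<()> {
    /// let mut tree = MerkleSearchTree::new_temporary()?;
    /// tree.insert("a".to_string())?;
    /// assert!(tree.contains(&"a".to_string())?);
    /// tree.remove(&"a".to_string())?;
    /// assert!(!tree.contains(&"a".to_string())?);
    /// # Ok(())
    /// # }
    /// ```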
    pub fn insert(&mut self, key: K) -> io::Result<()> {
        let key_arc = Arc::new(key);
        // Resolve the root link; this touches disk (via the cache) only when
        // the root is not already resident in memory.
        let root_node = self.resolve_link(&self.root)?;

        let target_level = Node::calc_level(key_arc.as_ref());
        let new_root_node = root_node.put(key_arc, target_level, &self.store)?;

        // Update the root pointer.
        self.root = Link::Loaded(new_root_node);
        Ok(())
    }

    /// Checks if a key exists in the tree.
    pub fn contains(&self, key: &K) -> io::Result<bool> {
        let root = self.resolve_link(&self.root)?;
        root.contains(key, &self.store)
    }

    /// Removes a key from the tree.
    pub fn remove(&mut self, key: &K) -> io::Result<()> {
        let root = self.resolve_link(&self.root)?;

        // Attempt recursive deletion.
        let (new_root, deleted) = root.delete(key, &self.store)?;

        if !deleted {
            return Ok(()); // Key not found, nothing changed.
        }

        // Check if the root needs collapsing (e.g., if we deleted the only key in it).
        if new_root.keys.is_empty() && !new_root.children.is_empty() {
            // The root is empty but has children. In a valid MST/B-Tree,
            // an empty node implies it has exactly one merged child.
            // We promote that child to be the new root.
            self.root = new_root.children[0].clone();
        } else {
            self.root = Link::Loaded(new_root);
        }

        Ok(())
    }

    /// Persists any dirty nodes to disk and updates the root to point to the disk location.
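    ///
    /// Returns the root's file offset and hash, which together are enough to
    /// reopen the tree later. A sketch of the round trip (`no_run`; the path
    /// is illustrative):
    ///
    /// ```no_run
    /// use file_mst::MerkleSearchTree;
    ///
    /// # fn main() -> std::io::Result<()> {
    /// let mut tree: MerkleSearchTree<String> = MerkleSearchTree::open("data/tree.mst")?;
    /// tree.insert("hello".to_string())?;
    /// let (offset, hash) = tree.flush()?;
    ///
    /// // Later: reopen directly at the persisted root.
    /// let reopened: MerkleSearchTree<String> =
    ///     MerkleSearchTree::load_from_root("data/tree.mst", offset, hash)?;
    /// assert!(reopened.contains(&"hello".to_string())?);
    /// # Ok(())
    /// # }
    /// ```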
    pub fn flush(&mut self) -> io::Result<(u64, Hash)> {
        let (offset, hash) = self.flush_recursive(&self.root)?;

        // Ensure all buffered writes are pushed to the underlying OS file
        // before we return the offset.
        self.store.flush()?;

        self.root = Link::Disk { offset, hash };

        Ok((offset, hash))
    }

    pub fn root_hash(&self) -> Hash {
        self.root.hash()
    }

    fn resolve_link(&self, link: &Link<K>) -> io::Result<Arc<Node<K>>> {
        match link {
            Link::Loaded(node) => Ok(node.clone()),
            Link::Disk { offset, .. } => self.store.load_node(*offset),
        }
    }

    fn flush_recursive(&self, link: &Link<K>) -> io::Result<(NodeId, Hash)> {
        match link {
            Link::Disk { offset, hash } => Ok((*offset, *hash)),
            Link::Loaded(node) => {
                // A child still held in memory is dirty and must be flushed first.
                let dirty_children = node
                    .children
                    .iter()
                    .any(|child| matches!(child, Link::Loaded(_)));

                if !dirty_children {
                    let offset = self.store.write_node(node)?;
                    return Ok((offset, node.hash));
                }

                // Flush children bottom-up, then rewrite this node with
                // on-disk links only.
                let mut new_children = Vec::new();
                for child in &node.children {
                    let (child_offset, child_hash) = self.flush_recursive(child)?;
                    new_children.push(Link::Disk {
                        offset: child_offset,
                        hash: child_hash,
                    });
                }

                let mut new_node = (**node).clone();
                new_node.children = new_children;
                let offset = self.store.write_node(&new_node)?;
                Ok((offset, new_node.hash))
            }
        }
    }
}

/// A child pointer: either a clean record already on disk (offset + hash) or
/// a dirty in-memory node that has not been flushed yet.
#[derive(Debug, Clone)]
enum Link<K: MerkleKey> {
    Disk { offset: NodeId, hash: Hash },
    Loaded(Arc<Node<K>>),
}

impl<K: MerkleKey> Link<K> {
    fn hash(&self) -> Hash {
        match self {
            Link::Disk { hash, .. } => *hash,
            Link::Loaded(node) => node.hash,
        }
    }
}

/// Append-only node storage: a shared buffered file plus an in-memory cache
/// of deserialized nodes keyed by file offset.
struct Store<K: MerkleKey> {
    file: RwLock<BufWriter<File>>,
    cache: RwLock<HashMap<NodeId, Arc<Node<K>>>>,
}

impl<K: MerkleKey> Store<K> {
    /// Creates a store from an existing open file handle.
    fn new(file: File) -> Arc<Self> {
        Arc::new(Self {
            // Use a generous buffer (64 KiB) to batch writes effectively.
            file: RwLock::new(BufWriter::with_capacity(64 * 1024, file)),
            cache: RwLock::new(HashMap::new()),
        })
    }

    fn open<P: AsRef<Path>>(path: P) -> io::Result<Arc<Self>> {
        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .open(path)?;

        Ok(Self::new(file))
    }

    /// Flushes the write buffer to the underlying file.
    fn flush(&self) -> io::Result<()> {
        let mut writer = self.file.write().unwrap();
        writer.flush()
    }

    fn load_node(&self, offset: NodeId) -> io::Result<Arc<Node<K>>> {
        {
            let cache = self.cache.read().unwrap();
            if let Some(node) = cache.get(&offset) {
                return Ok(node.clone());
            }
        }

        // Even reads take the write lock: seeking moves the shared file
        // cursor, and `Seek` on a `BufWriter` flushes pending writes first,
        // so any node written earlier is visible to the read below.
        let mut writer_guard = self.file.write().unwrap();

        // Seek to the record's location.
        writer_guard.seek(SeekFrom::Start(offset))?;

        // Access the underlying File to read.
        let file = writer_guard.get_mut();

        // Record layout: 4-byte little-endian length header, then the
        // postcard-encoded node.
        let mut len_buf = [0u8; 4];
        file.read_exact(&mut len_buf)?;
        let len = u32::from_le_bytes(len_buf) as usize;

        let mut buf = vec![0u8; len];
        file.read_exact(&mut buf)?;

        let disk_node: DiskNode<K> = postcard::from_bytes(&buf)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?;

        let node = Arc::new(Node::from_disk(disk_node));
        self.cache.write().unwrap().insert(offset, node.clone());
        Ok(node)
    }

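    /// Appends a node record at the end of the file and returns its starting
    /// offset. Worked example for the padding rule below: a 304-byte record
    /// (4-byte header + 300 data bytes) arriving at offset 4000 has only 96
    /// bytes left in its 4096-byte page, so 96 zero bytes of padding are
    /// written first and the record starts page-aligned at offset 4096.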
    fn write_node(&self, node: &Node<K>) -> io::Result<NodeId> {
        let disk_node = node.to_disk();
        let data = postcard::to_extend(&disk_node, Vec::with_capacity(4096))
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?;

        // Total size = 4-byte length header + data bytes.
        let node_total_len = (data.len() + 4) as u64;

        let mut writer = self.file.write().unwrap();

        // Get the current end-of-file offset.
        let mut current_pos = writer.seek(SeekFrom::End(0))?;

        // If the node fits in a page but would straddle a boundary, insert padding.
        if node_total_len <= PAGE_SIZE {
            let offset_in_page = current_pos % PAGE_SIZE;
            let space_remaining = PAGE_SIZE - offset_in_page;

            if node_total_len > space_remaining {
                // Pad with zeros to fill the rest of the current page.
                let padding_len = space_remaining as usize;
                let padding = vec![0u8; padding_len];
                writer.write_all(&padding)?;

                // Update position to the start of the next page.
                current_pos += space_remaining;
            }
        }

        let start_offset = current_pos;

        // Write the length header, then the data.
        writer.write_all(&(data.len() as u32).to_le_bytes())?;
        writer.write_all(&data)?;

        Ok(start_offset)
    }
}

/// The serialized form of a node: children are referenced purely by
/// (offset, hash) pairs, so a node must have no dirty children when written.
#[derive(Serialize, Deserialize)]
struct DiskNode<K> {
    level: u32,
    keys: Vec<K>,
    children: Vec<(NodeId, Hash)>,
    hash: Hash,
}

/// The in-memory form of a node. A node holding at least one key always has
/// exactly `keys.len() + 1` children.
#[derive(Debug, Clone)]
struct Node<K: MerkleKey> {
    level: u32,
    keys: Vec<Arc<K>>,
    children: Vec<Link<K>>,
    hash: Hash,
}

impl<K: MerkleKey> Node<K> {
    fn empty(level: u32) -> Self {
        let mut node = Self {
            level,
            keys: Vec::new(),
            children: Vec::new(),
            hash: [0u8; 32],
        };
        node.rehash();
        node
    }

    fn to_disk(&self) -> DiskNode<K> {
        let children_meta = self
            .children
            .iter()
            .map(|c| match c {
                Link::Disk { offset, hash } => (*offset, *hash),
                Link::Loaded(_) => {
                    panic!("Cannot serialize a node with dirty children! Flush children first.")
                }
            })
            .collect();

        DiskNode {
            level: self.level,
            keys: self.keys.iter().map(|k| k.as_ref().clone()).collect(),
            children: children_meta,
            hash: self.hash,
        }
    }

    fn from_disk(disk: DiskNode<K>) -> Self {
        let children = disk
            .children
            .into_iter()
            .map(|(offset, hash)| Link::Disk { offset, hash })
            .collect();

        let keys = disk.keys.into_iter().map(Arc::new).collect();

        Self {
            level: disk.level,
            keys,
            children,
            hash: disk.hash,
        }
    }

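    /// Derives a key's level by counting leading zero nibbles of its blake3
    /// hash. Each nibble is zero with probability 1/16, so levels follow a
    /// geometric distribution and the structure behaves like a probabilistic
    /// B-tree with an expected fanout of 16.
    ///
    /// Worked example: a hash beginning `00 0a ...` yields level 3, because
    /// the zero byte counts for two nibbles and the high nibble of `0a` adds
    /// one more before the scan stops.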
    fn calc_level(key: &K) -> u32 {
        let mut h = blake3::Hasher::new();
        h.update(&key.encode());
        let hash = h.finalize();
        let bytes = hash.as_bytes();
        let mut level = 0;
        for byte in bytes {
            if *byte == 0 {
                // A zero byte contributes two leading zero nibbles.
                level += 2;
            } else {
                // A zero high nibble adds one more; either way the scan ends here.
                if *byte & 0xF0 == 0 {
                    level += 1;
                }
                break;
            }
        }
        level
    }

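    /// Recomputes this node's Merkle hash over its level, its key count, and
    /// the interleaved sequence of child hashes and length-prefixed key
    /// bytes. The length prefix keeps adjacent keys from colliding by
    /// concatenation; the empty node hashes to all zeros by convention.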
    fn rehash(&mut self) {
        if self.keys.is_empty() && self.children.is_empty() {
            self.hash = [0u8; 32];
            return;
        }

        let mut h = blake3::Hasher::new();
        h.update(&self.level.to_le_bytes());
        h.update(&(self.keys.len() as u64).to_le_bytes());

        for (i, child) in self.children.iter().enumerate() {
            h.update(&child.hash());
            if i < self.keys.len() {
                let k_bytes = self.keys[i].encode();
                h.update(&(k_bytes.len() as u64).to_le_bytes());
                h.update(&k_bytes);
            }
        }
        self.hash = *h.finalize().as_bytes();
    }

    fn contains(&self, key: &K, store: &Store<K>) -> io::Result<bool> {
        match self.keys.binary_search_by(|probe| probe.as_ref().cmp(key)) {
            Ok(_) => Ok(true),
            Err(idx) => {
                // Not in this node: descend into the child between the
                // neighboring keys; reaching a leaf means the key is absent.
                if self.children.is_empty() {
                    return Ok(false);
                }
                let child = match &self.children[idx] {
                    Link::Loaded(n) => n.clone(),
                    Link::Disk { offset, .. } => store.load_node(*offset)?,
                };
                child.contains(key, store)
            }
        }
    }

    fn put(&self, key: Arc<K>, key_level: u32, store: &Arc<Store<K>>) -> io::Result<Arc<Node<K>>> {
        // The key lives above this node: split this subtree around the key
        // and make the key the sole separator of a new, higher root.
        if key_level > self.level {
            let (left_child, right_child) = self.split(&key, store)?;
            let mut new_node = Node {
                level: key_level,
                keys: vec![key],
                children: vec![Link::Loaded(left_child), Link::Loaded(right_child)],
                hash: [0u8; 32],
            };
            new_node.rehash();
            return Ok(Arc::new(new_node));
        }

        // The key belongs in this node: split the child it would descend
        // into and wire the two halves in around the new key.
        if key_level == self.level {
            let mut new_node = self.clone();
            match new_node
                .keys
                .binary_search_by(|probe| probe.as_ref().cmp(&key))
            {
                Ok(_) => return Ok(Arc::new(new_node)), // Already present.
                Err(idx) => {
                    let child_to_split = if !new_node.children.is_empty() {
                        match &new_node.children[idx] {
                            Link::Loaded(n) => n.clone(),
                            Link::Disk { offset, .. } => store.load_node(*offset)?,
                        }
                    } else {
                        Arc::new(Node::empty(self.level.saturating_sub(1)))
                    };

                    let (left_sub, right_sub) = child_to_split.split(&key, store)?;
                    new_node.keys.insert(idx, key);

                    if new_node.children.is_empty() {
                        new_node.children.push(Link::Loaded(left_sub));
                        new_node.children.push(Link::Loaded(right_sub));
                    } else {
                        new_node.children[idx] = Link::Loaded(left_sub);
                        new_node.children.insert(idx + 1, Link::Loaded(right_sub));
                    }
                    new_node.rehash();
                    return Ok(Arc::new(new_node));
                }
            }
        }

        // Empty tree: create a fresh root at the key's level.
        if self.keys.is_empty() && self.children.is_empty() {
            let mut new_node = Node {
                level: key_level,
                keys: vec![key],
                children: vec![
                    Link::Loaded(Arc::new(Node::empty(0))),
                    Link::Loaded(Arc::new(Node::empty(0))),
                ],
                hash: [0u8; 32],
            };
            new_node.rehash();
            return Ok(Arc::new(new_node));
        }

        // The key lives below this node: recurse into the appropriate child.
        let mut new_node = self.clone();
        let idx = match new_node
            .keys
            .binary_search_by(|probe| probe.as_ref().cmp(&key))
        {
            Ok(_) => return Ok(Arc::new(new_node)), // Already present.
            Err(i) => i,
        };

        let child_node = match &new_node.children[idx] {
            Link::Loaded(n) => n.clone(),
            Link::Disk { offset, .. } => store.load_node(*offset)?,
        };

        let new_child = child_node.put(key, key_level, store)?;
        new_node.children[idx] = Link::Loaded(new_child);
        new_node.rehash();
        Ok(Arc::new(new_node))
    }

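    /// Splits this subtree into the parts strictly left and strictly right
    /// of `split_key`, recursing into the child the key would descend into.
    ///
    /// Worked example (illustrative keys): splitting a node holding [b, d, f]
    /// at `c` yields a left node with [b] and a right node with [d, f]; the
    /// child that sat between b and d is itself split at `c`, and its halves
    /// become the rightmost child of the left node and the leftmost child of
    /// the right node.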
    fn split(
        &self,
        split_key: &K,
        store: &Arc<Store<K>>,
    ) -> io::Result<(Arc<Node<K>>, Arc<Node<K>>)> {
        if self.keys.is_empty() && self.children.is_empty() {
            return Ok((
                Arc::new(Node::empty(self.level)),
                Arc::new(Node::empty(self.level)),
            ));
        }

        let idx = match self
            .keys
            .binary_search_by(|probe| probe.as_ref().cmp(split_key))
        {
            Ok(i) => i,
            Err(i) => i,
        };

        // If the split key itself is present, it goes to neither side.
        let left_keys = self.keys[..idx].to_vec();
        let right_start = if idx < self.keys.len() && self.keys[idx].as_ref() == split_key {
            idx + 1
        } else {
            idx
        };
        let right_keys = self.keys[right_start..].to_vec();

        // Recursively split the child straddling the split point.
        let (mid_left, mid_right) = if idx < self.children.len() {
            let child = match &self.children[idx] {
                Link::Loaded(n) => n.clone(),
                Link::Disk { offset, .. } => store.load_node(*offset)?,
            };
            child.split(split_key, store)?
        } else {
            (Arc::new(Node::empty(0)), Arc::new(Node::empty(0)))
        };

        let mut left_children = self.children[..idx].to_vec();
        left_children.push(Link::Loaded(mid_left));
        let mut left_node = Node {
            level: self.level,
            keys: left_keys,
            children: left_children,
            hash: [0u8; 32],
        };
        left_node.rehash();

        let mut right_children = vec![Link::Loaded(mid_right)];
        if idx + 1 < self.children.len() {
            right_children.extend_from_slice(&self.children[idx + 1..]);
        }
        let mut right_node = Node {
            level: self.level,
            keys: right_keys,
            children: right_children,
            hash: [0u8; 32],
        };
        right_node.rehash();

        Ok((Arc::new(left_node), Arc::new(right_node)))
    }

    fn delete(&self, key: &K, store: &Arc<Store<K>>) -> io::Result<(Arc<Node<K>>, bool)> {
        match self.keys.binary_search_by(|probe| probe.as_ref().cmp(key)) {
            Ok(idx) => {
                // Key found! Remove it.
                let mut new_node = self.clone();
                new_node.keys.remove(idx);

                // We have removed a separator. We must merge the left and right children
                // that this key previously separated.
                let left_child = new_node.children.remove(idx);
                // Note: after remove(idx), the element at idx is now the "right" child.
                let right_child = new_node.children.remove(idx);

                // Merge the two disjoint subtrees.
                let merged_child = Node::merge(left_child, right_child, store)?;

                // Insert the merged result back.
                new_node.children.insert(idx, merged_child);

                new_node.rehash();
                Ok((Arc::new(new_node), true))
            }
            Err(idx) => {
                // Key not found in this node. Recurse into the child.
                if self.children.is_empty() {
                    // Leaf node: the key is not in the tree.
                    return Ok((Arc::new(self.clone()), false));
                }

                let child_link = &self.children[idx];
                let child_node = match child_link {
                    Link::Loaded(n) => n.clone(),
                    Link::Disk { offset, .. } => store.load_node(*offset)?,
                };

                let (new_child, deleted) = child_node.delete(key, store)?;

                if !deleted {
                    return Ok((Arc::new(self.clone()), false));
                }

                let mut new_node = self.clone();
                new_node.children[idx] = Link::Loaded(new_child);
                new_node.rehash();
                Ok((Arc::new(new_node), true))
            }
        }
    }

    /// Merges two disjoint subtrees (left keys < right keys) into a single link.
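    ///
    /// Worked example (illustrative keys): deleting a separator `m` leaves
    /// its two flanking subtrees adjacent. If the left root sits at a higher
    /// level, the right subtree is folded into the left root's rightmost
    /// child (and symmetrically when the right root is higher); at equal
    /// levels the key lists are concatenated and the two boundary children
    /// are merged recursively to restore the "N keys, N + 1 children" shape.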
    fn merge(left: Link<K>, right: Link<K>, store: &Arc<Store<K>>) -> io::Result<Link<K>> {
        // Resolve both links to nodes.
        let left_node = match &left {
            Link::Loaded(n) => n.clone(),
            Link::Disk { offset, .. } => store.load_node(*offset)?,
        };

        let right_node = match &right {
            Link::Loaded(n) => n.clone(),
            Link::Disk { offset, .. } => store.load_node(*offset)?,
        };

        // Handle empty nodes first, to avoid indexing into empty child lists below.
        if left_node.keys.is_empty() && left_node.children.is_empty() {
            return Ok(Link::Loaded(right_node));
        }
        if right_node.keys.is_empty() && right_node.children.is_empty() {
            return Ok(Link::Loaded(left_node));
        }

        // Case 1: Left is higher (right belongs inside left).
        if left_node.level > right_node.level {
            let mut new_left = (*left_node).clone();

            // The merge point is the right-most child of left.
            let last_idx = new_left.children.len() - 1;
            let last_child = new_left.children.remove(last_idx);

            let merged = Node::merge(last_child, right, store)?;
            new_left.children.push(merged);
            new_left.rehash();

            return Ok(Link::Loaded(Arc::new(new_left)));
        }

        // Case 2: Right is higher (left belongs inside right).
        if right_node.level > left_node.level {
            let mut new_right = (*right_node).clone();

            // The merge point is the left-most child of right.
            let first_child = new_right.children.remove(0);

            let merged = Node::merge(left, first_child, store)?;
            new_right.children.insert(0, merged);
            new_right.rehash();

            return Ok(Link::Loaded(Arc::new(new_right)));
        }

        // Case 3: Levels are equal. Concatenate them.
        // Since `left` keys are all strictly less than `right` keys, we connect them.
        // However, we just removed the key separating `left` and `right`.
        // This implies the right-most child of `left` and left-most child of `right`
        // are now adjacent and must be merged to maintain "N keys -> N+1 children".

        let mut new_node = (*left_node).clone();
        let mut right_clone = (*right_node).clone();

        // Pop the boundary children.
        let left_boundary_child = new_node.children.pop().expect("node should have children");
        let right_boundary_child = right_clone.children.remove(0);

        // Merge the boundary.
        let merged_boundary = Node::merge(left_boundary_child, right_boundary_child, store)?;

        // Construct the final node.
        new_node.keys.extend(right_clone.keys);
        new_node.children.push(merged_boundary);
        new_node.children.extend(right_clone.children);
        new_node.rehash();

        Ok(Link::Loaded(Arc::new(new_node)))
    }
}