// file_mst/lib.rs

1#![doc = include_str!("../README.md")]
2#![feature(test)]
3
4extern crate test;
5
6#[cfg(test)]
7mod benches;
8#[cfg(test)]
9mod tests;
10
11use std::borrow::Cow;
12use std::collections::HashMap;
13use std::fs::{File, OpenOptions};
14use std::io::{self, BufWriter, Read, Seek, SeekFrom, Write};
15use std::path::Path;
16use std::sync::{Arc, RwLock};
17
18use serde::{Deserialize, Serialize};
19
/// Alignment unit for node placement: nodes that fit in a single page are
/// padded forward so they never straddle a 4 KiB boundary (see `Store::write_node`).
const PAGE_SIZE: u64 = 4096;
21
/// A trait for types that can serve as keys in a Merkle Search Tree.
///
/// Keys need a total order (`Ord`) for in-node binary search, and a stable
/// byte encoding, which is fed to blake3 both for level derivation
/// (`Node::calc_level`) and for node hashing (`Node::rehash`).
pub trait MerkleKey: Ord + Clone + std::fmt::Debug + Serialize + for<'a> Deserialize<'a> {
    /// Returns a canonical byte representation of the key.
    /// Implementations should borrow where possible to avoid allocating.
    fn encode(&self) -> Cow<'_, [u8]>;
}
26
/// A trait for types that can serve as values. They must be cloneable and serializable.
pub trait MerkleValue: Clone + std::fmt::Debug + Serialize + for<'a> Deserialize<'a> {}
// Blanket impl: any type satisfying the bounds is automatically a MerkleValue,
// so callers never implement this trait by hand.
impl<T> MerkleValue for T where T: Clone + std::fmt::Debug + Serialize + for<'a> Deserialize<'a> {}
30
31impl MerkleKey for String {
32    #[inline]
33    fn encode(&self) -> Cow<'_, [u8]> {
34        self.as_bytes().into()
35    }
36}
37
38impl MerkleKey for Vec<u8> {
39    #[inline]
40    fn encode(&self) -> Cow<'_, [u8]> {
41        self.as_slice().into()
42    }
43}
44
/// 32-byte blake3 digest identifying a subtree's contents.
pub type Hash = [u8; 32];
/// Byte offset of a serialized node within the backing file.
type NodeId = u64;

/// A persistent, file-backed Merkle Search Tree.
pub struct MerkleSearchTree<K: MerkleKey, V: MerkleValue> {
    // Current root: either an in-memory (possibly unflushed) node, or an
    // on-disk (offset, hash) pointer resolved lazily through `store`.
    root: Link<K, V>,
    // Shared node storage: the backing file plus a read cache.
    store: Arc<Store<K, V>>,
}
52
impl<K: MerkleKey, V: MerkleValue> MerkleSearchTree<K, V> {
    /// Opens or creates a file-backed Merkle Search Tree at the given path.
    ///
    /// NOTE: this always starts from a fresh empty root, even if the file
    /// already holds data; use [`Self::load_from_root`] with a previously
    /// flushed `(offset, hash)` pair to resume an existing tree.
    pub fn open<P: AsRef<Path>>(path: P) -> io::Result<Self> {
        let store = Store::open(path)?;
        Ok(Self {
            root: Link::Loaded(Arc::new(Node::empty(0))),
            store,
        })
    }

    /// Creates a new MST backed by a temporary file.
    pub fn new_temporary() -> io::Result<Self> {
        let file = tempfile::tempfile()?;
        let store = Store::new(file);

        Ok(Self {
            root: Link::Loaded(Arc::new(Node::empty(0))),
            store,
        })
    }

    /// Loads a tree from a known root offset and hash.
    ///
    /// `root_offset` / `root_hash` are the values returned by a previous
    /// [`Self::flush`]; the root node itself is read lazily on first access.
    pub fn load_from_root<P: AsRef<Path>>(
        path: P,
        root_offset: u64,
        root_hash: Hash,
    ) -> io::Result<Self> {
        let store = Store::open(path)?;
        Ok(Self {
            root: Link::Disk {
                offset: root_offset,
                hash: root_hash,
            },
            store,
        })
    }

    /// Inserts a key-value pair into the tree, modifying it in-place.
    ///
    /// Internally the tree is persistent: `Node::put` builds a new root by
    /// path-copying, so untouched subtrees stay shared (possibly still on
    /// disk). Nothing is written until [`Self::flush`].
    pub fn insert(&mut self, key: K, value: V) -> io::Result<()> {
        let key_arc = Arc::new(key);
        let root_node = self.resolve_link(&self.root)?;

        // The key's level is a pure function of its hash, which is what
        // makes the tree shape independent of insertion order.
        let target_level = Node::<K, V>::calc_level(key_arc.as_ref());
        let new_root_node = root_node.put(key_arc, value, target_level, &self.store)?;

        self.root = Link::Loaded(new_root_node);
        Ok(())
    }

    /// Checks if a key exists in the tree.
    pub fn contains(&self, key: &K) -> io::Result<bool> {
        let root = self.resolve_link(&self.root)?;
        root.contains(key, &self.store)
    }

    /// Retrieves a value by key. Returns None if the key does not exist.
    pub fn get(&self, key: &K) -> io::Result<Option<V>> {
        let root = self.resolve_link(&self.root)?;
        root.get(key, &self.store)
    }

    /// Removes a key from the tree.
    ///
    /// A missing key is not an error; the tree is left unchanged.
    pub fn remove(&mut self, key: &K) -> io::Result<()> {
        let root = self.resolve_link(&self.root)?;

        let (new_root, deleted) = root.delete(key, &self.store)?;

        if !deleted {
            return Ok(());
        }

        // If the root lost its last key, collapse into its remaining child
        // so the tree height shrinks back down.
        if new_root.keys.is_empty() && !new_root.children.is_empty() {
            self.root = new_root.children[0].clone();
        } else {
            self.root = Link::Loaded(new_root);
        }

        Ok(())
    }

    /// Persists any dirty nodes to disk and updates the root pointer.
    ///
    /// Returns the `(offset, hash)` pair that can later be handed to
    /// [`Self::load_from_root`] to reopen this exact tree state.
    pub fn flush(&mut self) -> io::Result<(u64, Hash)> {
        let (offset, hash) = self.flush_recursive(&self.root)?;
        self.store.flush()?;
        // Replace the in-memory root with its disk pointer so flushed
        // nodes can be dropped and re-read on demand.
        self.root = Link::Disk { offset, hash };
        Ok((offset, hash))
    }

    /// Returns the Merkle hash of the current root (all-zero for an empty tree).
    pub fn root_hash(&self) -> Hash {
        self.root.hash()
    }

    /// Returns the node behind `link`, reading it from the store when it is
    /// only present on disk.
    fn resolve_link(&self, link: &Link<K, V>) -> io::Result<Arc<Node<K, V>>> {
        match link {
            Link::Loaded(node) => Ok(node.clone()),
            Link::Disk { offset, .. } => self.store.load_node(*offset),
        }
    }

    /// Writes the subtree under `link` bottom-up and returns the subtree
    /// root's `(offset, hash)`. Links already on disk are returned as-is
    /// without being rewritten.
    fn flush_recursive(&self, link: &Link<K, V>) -> io::Result<(NodeId, Hash)> {
        match link {
            Link::Disk { offset, hash } => Ok((*offset, *hash)),
            Link::Loaded(node) => {
                // Fast path: if every child is already a disk pointer the
                // node can be serialized directly, no copy needed.
                let mut dirty_children = false;
                for child in &node.children {
                    if let Link::Loaded(_) = child {
                        dirty_children = true;
                        break;
                    }
                }

                if !dirty_children {
                    let offset = self.store.write_node(node)?;
                    return Ok((offset, node.hash));
                }

                // Otherwise flush the children first to learn their offsets,
                // then write a copy of this node pointing at them.
                let mut new_children = Vec::new();
                for child in &node.children {
                    let (child_offset, child_hash) = self.flush_recursive(child)?;
                    new_children.push(Link::Disk {
                        offset: child_offset,
                        hash: child_hash,
                    });
                }

                let mut new_node = (**node).clone();
                new_node.children = new_children;
                let offset = self.store.write_node(&new_node)?;
                Ok((offset, new_node.hash))
            }
        }
    }
}
186
/// A reference to a node: either an unresolved on-disk pointer (offset plus
/// the subtree hash, so `hash()` never touches disk) or a node materialized
/// in memory. `Loaded` links mark subtrees that may not be flushed yet.
#[derive(Debug, Clone)]
enum Link<K: MerkleKey, V: MerkleValue> {
    Disk { offset: NodeId, hash: Hash },
    Loaded(Arc<Node<K, V>>),
}
192
193impl<K: MerkleKey, V: MerkleValue> Link<K, V> {
194    fn hash(&self) -> Hash {
195        match self {
196            Link::Disk { hash, .. } => *hash,
197            Link::Loaded(node) => node.hash,
198        }
199    }
200}
201
/// Append-only node storage: a single shared file handle plus a read cache
/// of deserialized nodes keyed by their file offset.
struct Store<K: MerkleKey, V: MerkleValue> {
    // One handle serves both reads and writes; reads seek through the
    // BufWriter, which flushes pending writes before moving the cursor.
    file: RwLock<BufWriter<File>>,
    // Nodes already read back from disk, keyed by offset.
    cache: RwLock<HashMap<NodeId, Arc<Node<K, V>>>>,
}
206
impl<K: MerkleKey, V: MerkleValue> Store<K, V> {
    /// Wraps an already-open file in a new store with an empty cache.
    fn new(file: File) -> Arc<Self> {
        Arc::new(Self {
            // 64 KiB write buffer. Note that seeking (done by both reads
            // and appends) flushes this buffer first.
            file: RwLock::new(BufWriter::with_capacity(64 * 1024, file)),
            cache: RwLock::new(HashMap::new()),
        })
    }

    /// Opens (or creates) the backing file at `path` for read/write access.
    /// Existing contents are preserved, not truncated.
    fn open<P: AsRef<Path>>(path: P) -> io::Result<Arc<Self>> {
        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .open(path)?;

        Ok(Self::new(file))
    }

    /// Flushes buffered writes through to the underlying file.
    fn flush(&self) -> io::Result<()> {
        let mut writer = self.file.write().unwrap();
        writer.flush()
    }

    /// Reads and deserializes the node stored at `offset`, consulting the
    /// in-memory cache first; freshly read nodes are inserted into it.
    fn load_node(&self, offset: NodeId) -> io::Result<Arc<Node<K, V>>> {
        {
            // Cache probe under the read lock only; dropped before I/O.
            let cache = self.cache.read().unwrap();
            if let Some(node) = cache.get(&offset) {
                return Ok(node.clone());
            }
        }

        // Take the write lock even for a read: there is a single file
        // handle, and seeking the BufWriter flushes any buffered writes
        // before we read raw bytes from the underlying File.
        let mut writer_guard = self.file.write().unwrap();
        writer_guard.seek(SeekFrom::Start(offset))?;
        let file = writer_guard.get_mut();

        // On-disk record layout: 4-byte little-endian length prefix,
        // followed by the postcard-encoded DiskNode.
        let mut len_buf = [0u8; 4];
        file.read_exact(&mut len_buf)?;
        let len = u32::from_le_bytes(len_buf) as usize;

        let mut buf = vec![0u8; len];
        file.read_exact(&mut buf)?;

        let disk_node: DiskNode<K, V> = postcard::from_bytes(&buf)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?;

        let node = Arc::new(Node::from_disk(disk_node));
        self.cache.write().unwrap().insert(offset, node.clone());
        Ok(node)
    }

    /// Appends `node` to the end of the file and returns its offset.
    ///
    /// Storage is append-only: nodes are never overwritten. A node that
    /// fits within one page is padded forward so its record never straddles
    /// a PAGE_SIZE boundary; oversized nodes are written wherever the file
    /// ends. Requires all children to be `Link::Disk` (enforced by
    /// `Node::to_disk`, which panics otherwise).
    fn write_node(&self, node: &Node<K, V>) -> io::Result<NodeId> {
        let disk_node = node.to_disk();
        let data = postcard::to_extend(&disk_node, Vec::with_capacity(4096))
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?;

        // Total on-disk footprint includes the 4-byte length prefix.
        let node_total_len = (data.len() + 4) as u64;
        let mut writer = self.file.write().unwrap();
        let mut current_pos = writer.seek(SeekFrom::End(0))?;

        if node_total_len <= PAGE_SIZE {
            let offset_in_page = current_pos % PAGE_SIZE;
            let space_remaining = PAGE_SIZE - offset_in_page;

            // Zero-pad up to the next page boundary when the record would
            // not fit in the remainder of the current page.
            if node_total_len > space_remaining {
                let padding_len = space_remaining as usize;
                let padding = vec![0u8; padding_len];
                writer.write_all(&padding)?;
                current_pos += space_remaining;
            }
        }

        let start_offset = current_pos;
        writer.write_all(&(data.len() as u32).to_le_bytes())?;
        writer.write_all(&data)?;

        Ok(start_offset)
    }
}
285
/// Serializable on-disk form of a node (postcard-encoded by the store).
/// Children are stored as `(offset, hash)` pairs rather than live links,
/// and the node's own Merkle hash is persisted so reloads skip rehashing.
#[derive(Serialize, Deserialize)]
struct DiskNode<K, V> {
    level: u32,
    keys: Vec<K>,
    values: Vec<V>,
    children: Vec<(NodeId, Hash)>,
    hash: Hash,
}
294
/// An in-memory MST node.
///
/// Invariant maintained by `put`/`split`/`delete`/`merge` (and relied on by
/// `rehash` and `delete`): a non-empty node has exactly `keys.len() + 1`
/// children, with `children[i]` covering the range below `keys[i]`. Keys
/// are Arc-shared so path-copied clones are cheap.
#[derive(Debug, Clone)]
struct Node<K: MerkleKey, V: MerkleValue> {
    level: u32,
    keys: Vec<Arc<K>>,
    values: Vec<V>,
    children: Vec<Link<K, V>>,
    // Merkle hash of this subtree; kept current by `rehash`.
    hash: Hash,
}
303
impl<K: MerkleKey, V: MerkleValue> Node<K, V> {
    /// Creates an empty node at `level` (no keys, no children) with the
    /// canonical all-zero empty hash (set by `rehash`).
    fn empty(level: u32) -> Self {
        let mut node = Self {
            level,
            keys: Vec::new(),
            values: Vec::new(),
            children: Vec::new(),
            hash: [0u8; 32],
        };
        node.rehash();
        node
    }

    /// Converts this node into its serializable on-disk form.
    ///
    /// # Panics
    /// Panics if any child is still `Link::Loaded`: children must have
    /// been flushed into disk pointers before their parent is written
    /// (guaranteed by `MerkleSearchTree::flush_recursive`).
    fn to_disk(&self) -> DiskNode<K, V> {
        let children_meta = self
            .children
            .iter()
            .map(|c| match c {
                Link::Disk { offset, hash } => (*offset, *hash),
                Link::Loaded(_) => {
                    panic!("Cannot serialize a node with dirty children! Flush children first.")
                }
            })
            .collect();

        DiskNode {
            level: self.level,
            keys: self.keys.iter().map(|k| k.as_ref().clone()).collect(),
            values: self.values.clone(),
            children: children_meta,
            hash: self.hash,
        }
    }

    /// Rebuilds an in-memory node from its on-disk form. All children come
    /// back as lazy `Link::Disk` pointers, and the stored hash is trusted
    /// rather than recomputed.
    fn from_disk(disk: DiskNode<K, V>) -> Self {
        let children = disk
            .children
            .into_iter()
            .map(|(offset, hash)| Link::Disk { offset, hash })
            .collect();

        let keys = disk.keys.into_iter().map(Arc::new).collect();

        Self {
            level: disk.level,
            keys,
            values: disk.values,
            children,
            hash: disk.hash,
        }
    }

    /// Deterministically derives a key's tree level: the number of leading
    /// zero nibbles (4-bit groups) in the blake3 hash of its encoding.
    /// Because the level is a pure function of the key, the tree's shape is
    /// history-independent.
    fn calc_level(key: &K) -> u32 {
        let mut h = blake3::Hasher::new();
        h.update(&key.encode());
        let hash = h.finalize();
        let bytes = hash.as_bytes();
        let mut level = 0;
        for byte in bytes {
            if *byte == 0 {
                // A zero byte contributes two zero nibbles; keep scanning.
                level += 2;
            } else {
                // First nonzero byte may still have a zero high nibble.
                if *byte & 0xF0 == 0 {
                    level += 1;
                }
                break;
            }
        }
        level
    }

    /// Recomputes this node's Merkle hash from its level, key count, and the
    /// interleaved sequence of child hashes and key/value entries. Must be
    /// called after every structural mutation.
    ///
    /// The loop walks children and hashes the key/value that follows each
    /// child, relying on the `children.len() == keys.len() + 1` invariant;
    /// keys and values are length-prefixed to keep the byte stream
    /// unambiguous across variable-length fields.
    fn rehash(&mut self) {
        // Canonical digest for a completely empty node.
        if self.keys.is_empty() && self.children.is_empty() {
            self.hash = [0u8; 32];
            return;
        }

        let mut h = blake3::Hasher::new();
        h.update(&self.level.to_le_bytes());
        h.update(&(self.keys.len() as u64).to_le_bytes());

        for (i, child) in self.children.iter().enumerate() {
            h.update(&child.hash());
            if i < self.keys.len() {
                // Hash Key
                let k_bytes = self.keys[i].encode();
                h.update(&(k_bytes.len() as u64).to_le_bytes());
                h.update(&k_bytes);

                // Hash Value (postcard-serialized, same codec as storage).
                let v_bytes = postcard::to_extend(&self.values[i], Vec::with_capacity(4096))
                    .expect("Failed to serialize value for hashing");
                h.update(&(v_bytes.len() as u64).to_le_bytes());
                h.update(&v_bytes);
            }
        }
        self.hash = *h.finalize().as_bytes();
    }

    /// Returns whether `key` exists in this subtree, descending through the
    /// child whose range brackets the key.
    fn contains(&self, key: &K, store: &Store<K, V>) -> io::Result<bool> {
        match self.keys.binary_search_by(|probe| probe.as_ref().cmp(key)) {
            Ok(_) => Ok(true),
            Err(idx) => {
                // No children means nowhere further to look.
                if self.children.is_empty() {
                    return Ok(false);
                }
                let child = match &self.children[idx] {
                    Link::Loaded(n) => n.clone(),
                    Link::Disk { offset, .. } => store.load_node(*offset)?,
                };
                child.contains(key, store)
            }
        }
    }

    /// Looks up `key` in this subtree, cloning its value if found.
    fn get(&self, key: &K, store: &Store<K, V>) -> io::Result<Option<V>> {
        match self.keys.binary_search_by(|probe| probe.as_ref().cmp(key)) {
            Ok(idx) => Ok(Some(self.values[idx].clone())),
            Err(idx) => {
                if self.children.is_empty() {
                    return Ok(None);
                }
                let child = match &self.children[idx] {
                    Link::Loaded(n) => n.clone(),
                    Link::Disk { offset, .. } => store.load_node(*offset)?,
                };
                child.get(key, store)
            }
        }
    }

    /// Inserts `key` at its target level, returning a new subtree root by
    /// path-copying (`self` is never mutated).
    ///
    /// Cases:
    /// 1. `key_level > self.level`: the key becomes a new root above this
    ///    node, which is split around it into the two flanking children.
    /// 2. `key_level == self.level`: the key belongs in this node; the
    ///    child straddling its slot is split into its two new neighbors.
    /// 3. Empty node: seed a fresh node at `key_level`.
    /// 4. `key_level < self.level`: recurse into the covering child.
    fn put(
        &self,
        key: Arc<K>,
        value: V,
        key_level: u32,
        store: &Arc<Store<K, V>>,
    ) -> io::Result<Arc<Node<K, V>>> {
        if key_level > self.level {
            // Case 1: split this entire subtree around the new root key.
            let (left_child, right_child) = self.split(&key, store)?;
            let mut new_node = Node {
                level: key_level,
                keys: vec![key],
                values: vec![value],
                children: vec![Link::Loaded(left_child), Link::Loaded(right_child)],
                hash: [0u8; 32],
            };
            new_node.rehash();
            return Ok(Arc::new(new_node));
        }

        if key_level == self.level {
            // Case 2: the key lives at this node's level.
            let mut new_node = self.clone();
            match new_node
                .keys
                .binary_search_by(|probe| probe.as_ref().cmp(&key))
            {
                Ok(idx) => {
                    // Update existing value
                    new_node.values[idx] = value;
                    new_node.rehash();
                    return Ok(Arc::new(new_node));
                }
                Err(idx) => {
                    // The child at `idx` spans the key's position and must
                    // be split into the key's new left/right neighbors.
                    let child_to_split = if !new_node.children.is_empty() {
                        match &new_node.children[idx] {
                            Link::Loaded(n) => n.clone(),
                            Link::Disk { offset, .. } => store.load_node(*offset)?,
                        }
                    } else {
                        Arc::new(Node::empty(self.level.saturating_sub(1)))
                    };

                    let (left_sub, right_sub) = child_to_split.split(&key, store)?;
                    new_node.keys.insert(idx, key);
                    new_node.values.insert(idx, value);

                    if new_node.children.is_empty() {
                        new_node.children.push(Link::Loaded(left_sub));
                        new_node.children.push(Link::Loaded(right_sub));
                    } else {
                        new_node.children[idx] = Link::Loaded(left_sub);
                        new_node.children.insert(idx + 1, Link::Loaded(right_sub));
                    }
                    new_node.rehash();
                    return Ok(Arc::new(new_node));
                }
            }
        }

        // Case 3: inserting a lower-level key into an empty node — seed a
        // node at the key's own level with two empty leaf children.
        if self.keys.is_empty() && self.children.is_empty() {
            let mut new_node = Node {
                level: key_level,
                keys: vec![key],
                values: vec![value],
                children: vec![
                    Link::Loaded(Arc::new(Node::empty(0))),
                    Link::Loaded(Arc::new(Node::empty(0))),
                ],
                hash: [0u8; 32],
            };
            new_node.rehash();
            return Ok(Arc::new(new_node));
        }

        // Case 4: key_level < self.level — descend.
        let mut new_node = self.clone();
        let idx = match new_node
            .keys
            .binary_search_by(|probe| probe.as_ref().cmp(&key))
        {
            Ok(i) => {
                // Defensive: the key already sits at this (higher) level;
                // update in place rather than descending.
                new_node.values[i] = value;
                new_node.rehash();
                return Ok(Arc::new(new_node));
            }
            Err(i) => i,
        };

        let child_node = match &new_node.children[idx] {
            Link::Loaded(n) => n.clone(),
            Link::Disk { offset, .. } => store.load_node(*offset)?,
        };

        let new_child = child_node.put(key, value, key_level, store)?;
        new_node.children[idx] = Link::Loaded(new_child);
        new_node.rehash();
        Ok(Arc::new(new_node))
    }

    /// Splits this subtree around `split_key`: everything ordered below the
    /// key goes left, everything above goes right, recursing into the one
    /// child whose range straddles the key. An exact match of `split_key`
    /// in this node is dropped — the caller (`put`) re-inserts it with the
    /// new value, giving update semantics.
    fn split(
        &self,
        split_key: &K,
        store: &Arc<Store<K, V>>,
    ) -> io::Result<(Arc<Node<K, V>>, Arc<Node<K, V>>)> {
        // Splitting an empty node yields two empty nodes at the same level.
        if self.keys.is_empty() && self.children.is_empty() {
            return Ok((
                Arc::new(Node::empty(self.level)),
                Arc::new(Node::empty(self.level)),
            ));
        }

        let idx = match self
            .keys
            .binary_search_by(|probe| probe.as_ref().cmp(split_key))
        {
            Ok(i) => i,
            Err(i) => i,
        };

        let left_keys = self.keys[..idx].to_vec();
        let left_values = self.values[..idx].to_vec();

        // Skip over an exact match so the split drops it (see doc above).
        let right_start = if idx < self.keys.len() && self.keys[idx].as_ref() == split_key {
            idx + 1
        } else {
            idx
        };
        let right_keys = self.keys[right_start..].to_vec();
        let right_values = self.values[right_start..].to_vec();

        // The child at `idx` straddles the split point and is itself split;
        // its halves become the rightmost/leftmost children of the results.
        let (mid_left, mid_right) = if idx < self.children.len() {
            let child = match &self.children[idx] {
                Link::Loaded(n) => n.clone(),
                Link::Disk { offset, .. } => store.load_node(*offset)?,
            };
            child.split(split_key, store)?
        } else {
            (Arc::new(Node::empty(0)), Arc::new(Node::empty(0)))
        };

        let mut left_children = self.children[..idx].to_vec();
        left_children.push(Link::Loaded(mid_left));
        let mut left_node = Node {
            level: self.level,
            keys: left_keys,
            values: left_values,
            children: left_children,
            hash: [0u8; 32],
        };
        left_node.rehash();

        let mut right_children = vec![Link::Loaded(mid_right)];
        if idx + 1 < self.children.len() {
            right_children.extend_from_slice(&self.children[idx + 1..]);
        }
        let mut right_node = Node {
            level: self.level,
            keys: right_keys,
            values: right_values,
            children: right_children,
            hash: [0u8; 32],
        };
        right_node.rehash();

        Ok((Arc::new(left_node), Arc::new(right_node)))
    }

    /// Removes `key` from this subtree (path-copying). Returns the new
    /// subtree root and whether a key was actually removed; when nothing
    /// was deleted the returned node is an unchanged copy of `self`.
    fn delete(&self, key: &K, store: &Arc<Store<K, V>>) -> io::Result<(Arc<Node<K, V>>, bool)> {
        match self.keys.binary_search_by(|probe| probe.as_ref().cmp(key)) {
            Ok(idx) => {
                // Key found here: remove it and merge the two children it
                // used to separate back into a single child.
                let mut new_node = self.clone();
                new_node.keys.remove(idx);
                new_node.values.remove(idx);

                // After the first remove the former right child shifts down
                // to index `idx`, so `remove(idx)` twice takes both.
                let left_child = new_node.children.remove(idx);
                let right_child = new_node.children.remove(idx);

                let merged_child = Node::merge(left_child, right_child, store)?;

                new_node.children.insert(idx, merged_child);

                new_node.rehash();
                Ok((Arc::new(new_node), true))
            }
            Err(idx) => {
                if self.children.is_empty() {
                    return Ok((Arc::new(self.clone()), false));
                }

                let child_link = &self.children[idx];
                let child_node = match child_link {
                    Link::Loaded(n) => n.clone(),
                    Link::Disk { offset, .. } => store.load_node(*offset)?,
                };

                let (new_child, deleted) = child_node.delete(key, store)?;

                // Nothing removed below: avoid cloning/rehashing the spine.
                if !deleted {
                    return Ok((Arc::new(self.clone()), false));
                }

                let mut new_node = self.clone();
                new_node.children[idx] = Link::Loaded(new_child);
                new_node.rehash();
                Ok((Arc::new(new_node), true))
            }
        }
    }

    /// Merges two adjacent subtrees (every key in `left` orders before every
    /// key in `right`) into a single link, used after `delete` removes the
    /// separator key between them. Recurses along the seam: the taller
    /// side absorbs the shorter at its boundary child until levels match,
    /// then same-level nodes are concatenated.
    fn merge(
        left: Link<K, V>,
        right: Link<K, V>,
        store: &Arc<Store<K, V>>,
    ) -> io::Result<Link<K, V>> {
        let left_node = match &left {
            Link::Loaded(n) => n.clone(),
            Link::Disk { offset, .. } => store.load_node(*offset)?,
        };

        let right_node = match &right {
            Link::Loaded(n) => n.clone(),
            Link::Disk { offset, .. } => store.load_node(*offset)?,
        };

        // An empty side contributes nothing: return the other side as-is.
        if left_node.keys.is_empty() && left_node.children.is_empty() {
            return Ok(Link::Loaded(right_node));
        }
        if right_node.keys.is_empty() && right_node.children.is_empty() {
            return Ok(Link::Loaded(left_node));
        }

        // Taller left: fold `right` into left's rightmost child.
        if left_node.level > right_node.level {
            let mut new_left = (*left_node).clone();
            let last_idx = new_left.children.len() - 1;
            let last_child = new_left.children.remove(last_idx);

            let merged = Node::merge(last_child, right, store)?;
            new_left.children.push(merged);
            new_left.rehash();

            return Ok(Link::Loaded(Arc::new(new_left)));
        }

        // Taller right: fold `left` into right's leftmost child.
        if right_node.level > left_node.level {
            let mut new_right = (*right_node).clone();
            let first_child = new_right.children.remove(0);

            let merged = Node::merge(left, first_child, store)?;
            new_right.children.insert(0, merged);
            new_right.rehash();

            return Ok(Link::Loaded(Arc::new(new_right)));
        }

        // Same level: concatenate keys/values and splice the two boundary
        // children (left's last, right's first) into one merged child.
        let mut new_node = (*left_node).clone();
        let mut right_clone = (*right_node).clone();

        let left_boundary_child = new_node.children.pop().expect("Node should have children");
        let right_boundary_child = right_clone.children.remove(0);

        let merged_boundary = Node::merge(left_boundary_child, right_boundary_child, store)?;

        new_node.keys.extend(right_clone.keys.into_iter());
        new_node.values.extend(right_clone.values.into_iter());
        new_node.children.push(merged_boundary);
        new_node.children.extend(right_clone.children.into_iter());
        new_node.rehash();

        Ok(Link::Loaded(Arc::new(new_node)))
    }
}