// rkyv/collections/btree/map/mod.rs

1//! [`Archive`](crate::Archive) implementation for B-tree maps.
2
3use core::{
4    borrow::Borrow,
5    cmp::Ordering,
6    fmt,
7    marker::PhantomData,
8    mem::{size_of, MaybeUninit},
9    ops::{ControlFlow, Index},
10    ptr::addr_of_mut,
11    slice,
12};
13
14use munge::munge;
15use rancor::{fail, Fallible, Source};
16
17use crate::{
18    collections::util::IteratorLengthMismatch,
19    primitive::{ArchivedUsize, FixedUsize},
20    seal::Seal,
21    ser::{Allocator, Writer, WriterExt as _},
22    traits::NoUndef,
23    util::{InlineVec, SerVec},
24    Place, Portable, RelPtr, Serialize,
25};
26
27// TODO(#515): Get Iterator APIs working without the `alloc` feature enabled
28#[cfg(feature = "alloc")]
29mod iter;
30
31#[cfg(feature = "alloc")]
32pub use self::iter::*;
33
34// B-trees are typically characterized as having a branching factor of B.
35// However, in this implementation our B-trees are characterized as having a
36// number of entries per node E where E = B - 1. This is done because it's
37// easier to add an additional node pointer to each inner node than it is to
38// store one less entry per inner node. Because generic const exprs are not
39// stable, we can't declare a field `entries: [Entry; { B - 1 }]`. But we can
40// declare `branches: [RelPtr; E]` and then add another `last: RelPtr`
41// field. When the branching factor B is needed, it will be calculated as E + 1.
42
/// Returns the number of nodes in level `i` of a B-tree with `E` entries per
/// node.
const fn nodes_in_level<const E: usize>(i: u32) -> usize {
    // With a branching factor of B = E + 1, the root level (I = 0) holds a
    // single node and every level below holds B times as many nodes as the
    // level above it, so level I holds exactly B^I nodes.
    let b = E + 1;
    b.pow(i)
}
50
/// Returns the total number of entries in a full B-tree of height `h` with
/// `E` entries per node.
const fn entries_in_full_tree<const E: usize>(h: u32) -> usize {
    // A B-tree with branching factor B = E + 1 holds B^I nodes in level I, so
    // a full tree of height H contains sum_{I=0}^{H-1} B^I nodes. That
    // geometric series sums to (B^H - 1) / (B - 1) nodes, and with B - 1 = E
    // entries per node the total entry count collapses to exactly B^H - 1 —
    // one less than the number of nodes the level just below the tree would
    // hold.
    (E + 1).pow(h) - 1
}
66
/// Returns the height of the smallest B-tree with `E` entries per node that
/// can hold `n` entries.
const fn entries_to_height<const E: usize>(n: usize) -> u32 {
    // A full tree of height H holds B^H - 1 entries (where B = E + 1), so
    // solving B^H - 1 = N for H gives H = log_B(N + 1). An integer logarithm
    // rounds down and would underestimate the height, so instead we evaluate
    // the logarithm at a corrected entry count C(N) chosen so that truncation
    // always lands on the correct answer.
    //
    // The entry counts that require height H satisfy:
    // => B^(H - 1) - 1 < N <= B^H - 1
    // or equivalently, since all quantities are integers:
    // => B^(H - 1) <= N < B^H
    // We want a C(N) satisfying B^H - 1 <= C(N) < B^(H + 1) - 1. Choosing
    // C(N) = B * N - 1 and substituting the bounds for N:
    // => B * B^(H - 1) - 1 <= B * N - 1 < B * B^H - 1
    // => B^H - 1 <= B * N - 1 < B^(H + 1) - 1
    // which is exactly the required range. Therefore:
    // => H = ilog_B(C(N) + 1) = ilog_B(B * N) = 1 + ilog_B(N)
    //
    // Note: `ilog` panics if `n` is zero; callers handle the empty map before
    // computing a height.
    n.ilog(E + 1) + 1
}
105
/// Returns the number of entries in the last level of a B-tree of height
/// `height` holding `n` entries total, with `E` entries per node.
const fn ll_entries<const E: usize>(height: u32, n: usize) -> usize {
    // Every entry that is not in the last level belongs to the full B-tree of
    // height H - 1 formed by the levels above it, which holds exactly
    // B^(H - 1) - 1 entries (B = E + 1). Whatever remains of the total sits
    // in the last level.
    n - ((E + 1).pow(height - 1) - 1)
}
113
/// Discriminates the two node layouts at runtime; stored as the first field
/// of every node so a node's concrete layout can be read before casting.
#[derive(Clone, Copy, Portable)]
#[cfg_attr(feature = "bytecheck", derive(bytecheck::CheckBytes))]
#[rkyv(crate)]
#[repr(u8)]
enum NodeKind {
    /// A terminal node with no child pointers; cast to `LeafNode`.
    Leaf,
    /// A node with child pointers between entries; cast to `InnerNode`.
    Inner,
}
122
// SAFETY: `NodeKind` is `repr(u8)` and so always consists of a single
// well-defined byte with no padding or uninitialized bits.
unsafe impl NoUndef for NodeKind {}
126
/// The fields common to both leaf and inner nodes.
///
/// Only a prefix of `keys` and `values` may be initialized: a leaf node
/// initializes its first `len` slots (see `LeafNode`), while inner nodes are
/// only ever written full (see `close_inner`).
#[derive(Portable)]
#[rkyv(crate)]
#[repr(C)]
struct Node<K, V, const E: usize> {
    // Read first to decide whether to cast to `LeafNode` or `InnerNode`.
    kind: NodeKind,
    keys: [MaybeUninit<K>; E],
    values: [MaybeUninit<V>; E],
}
135
/// A node at the bottom level of the tree, holding entries but no children.
#[derive(Portable)]
#[rkyv(crate)]
#[repr(C)]
struct LeafNode<K, V, const E: usize> {
    node: Node<K, V, E>,
    // Number of initialized entries in `node`; may be less than `E`.
    len: ArchivedUsize,
}
143
/// A node with children.
///
/// The subtree behind `lesser_nodes[i]` holds keys less than `node.keys[i]`,
/// and `greater_node` holds keys greater than every key in this node. Child
/// pointers may be emplaced invalid when no subtree exists in that position.
#[cfg_attr(feature = "bytecheck", derive(bytecheck::CheckBytes))]
#[derive(Portable)]
#[rkyv(crate)]
#[repr(C)]
struct InnerNode<K, V, const E: usize> {
    node: Node<K, V, E>,
    lesser_nodes: [RelPtr<Node<K, V, E>>; E],
    greater_node: RelPtr<Node<K, V, E>>,
}
153
/// An archived [`BTreeMap`](crate::alloc::collections::BTreeMap).
#[cfg_attr(
    feature = "bytecheck",
    derive(bytecheck::CheckBytes),
    bytecheck(verify)
)]
#[derive(Portable)]
#[rkyv(crate)]
#[repr(C)]
pub struct ArchivedBTreeMap<K, V, const E: usize = 5> {
    // The type of the root node is determined at runtime because it may point
    // to:
    // - Nothing if the length is zero
    // - A leaf node if there is only one node
    // - Or an inner node if there are multiple nodes
    root: RelPtr<Node<K, V, E>>,
    // Total number of entries stored across all nodes of the tree.
    len: ArchivedUsize,
    _phantom: PhantomData<(K, V)>,
}
173
174impl<K, V, const E: usize> ArchivedBTreeMap<K, V, E> {
175    /// Returns whether the B-tree map contains the given key.
176    pub fn contains_key<Q>(&self, key: &Q) -> bool
177    where
178        Q: Ord + ?Sized,
179        K: Borrow<Q> + Ord,
180    {
181        self.get_key_value(key).is_some()
182    }
183
184    /// Returns the value associated with the given key, or `None` if the key is
185    /// not present in the B-tree map.
186    pub fn get<Q>(&self, key: &Q) -> Option<&V>
187    where
188        Q: Ord + ?Sized,
189        K: Borrow<Q> + Ord,
190    {
191        Some(self.get_key_value(key)?.1)
192    }
193
194    /// Returns the mutable value associated with the given key, or `None` if
195    /// the key is not present in the B-tree map.
196    pub fn get_seal<'a, Q>(this: Seal<'a, Self>, key: &Q) -> Option<Seal<'a, V>>
197    where
198        Q: Ord + ?Sized,
199        K: Borrow<Q> + Ord,
200    {
201        Some(Self::get_key_value_seal(this, key)?.1)
202    }
203
204    /// Returns true if the B-tree map contains no entries.
205    pub fn is_empty(&self) -> bool {
206        self.len() == 0
207    }
208
    /// Returns the number of entries in the B-tree map.
    pub fn len(&self) -> usize {
        // The length is stored as a fixed-width archived integer; widen to
        // the native `usize`.
        self.len.to_native() as usize
    }
213
    /// Gets the key-value pair associated with the given key, or `None` if the
    /// key is not present in the B-tree map.
    pub fn get_key_value<Q>(&self, key: &Q) -> Option<(&K, &V)>
    where
        Q: Ord + ?Sized,
        K: Borrow<Q> + Ord,
    {
        // The raw lookup takes `*mut Self` so it can be shared with the
        // sealed accessor; the returned pointers are only reborrowed
        // immutably here, so no mutation occurs through the cast.
        let this = (self as *const Self).cast_mut();
        Self::get_key_value_raw(this, key)
            .map(|(k, v)| (unsafe { &*k }, unsafe { &*v }))
    }
225
    /// Gets the mutable key-value pair associated with the given key, or `None`
    /// if the key is not present in the B-tree map.
    ///
    /// The key is returned as a shared reference and only the value is
    /// sealed, so keys cannot be mutated out from under the tree's ordering.
    pub fn get_key_value_seal<'a, Q>(
        this: Seal<'a, Self>,
        key: &Q,
    ) -> Option<(&'a K, Seal<'a, V>)>
    where
        Q: Ord + ?Sized,
        K: Borrow<Q> + Ord,
    {
        let this = unsafe { Seal::unseal_unchecked(this) as *mut Self };
        Self::get_key_value_raw(this, key)
            .map(|(k, v)| (unsafe { &*k }, Seal::new(unsafe { &mut *v })))
    }
240
    /// Searches the tree for `key`, returning raw pointers to the matching
    /// key and value, or `None` if the key is not present.
    ///
    /// Working with raw pointers lets the shared (`&self`) and sealed
    /// (`Seal`) accessors share one traversal implementation.
    fn get_key_value_raw<Q>(
        this: *mut Self,
        key: &Q,
    ) -> Option<(*mut K, *mut V)>
    where
        Q: Ord + ?Sized,
        K: Borrow<Q> + Ord,
    {
        let len = unsafe { (*this).len.to_native() };
        if len == 0 {
            // An empty map has an invalid root pointer which must not be
            // dereferenced.
            return None;
        }

        let root_ptr = unsafe { addr_of_mut!((*this).root) };
        let mut current = unsafe { RelPtr::as_ptr_raw(root_ptr) };
        'outer: loop {
            let kind = unsafe { (*current).kind };

            match kind {
                NodeKind::Leaf => {
                    // Leaf nodes carry their own entry count in a trailing
                    // `len` field; only the first `len` slots are
                    // initialized.
                    let leaf = current.cast::<LeafNode<K, V, E>>();
                    let len = unsafe { (*leaf).len };

                    for i in 0..len.to_native() as usize {
                        let k = unsafe {
                            addr_of_mut!((*current).keys[i]).cast::<K>()
                        };
                        let ordering = key.cmp(unsafe { (*k).borrow() });

                        match ordering {
                            Ordering::Equal => {
                                let v = unsafe {
                                    addr_of_mut!((*current).values[i])
                                        .cast::<V>()
                                };
                                return Some((k, v));
                            }
                            // Keys within a node are in ascending order, so
                            // once the search key compares less than an
                            // entry it cannot appear later in this node.
                            Ordering::Less => return None,
                            Ordering::Greater => (),
                        }
                    }

                    return None;
                }
                NodeKind::Inner => {
                    // Inner nodes are only ever written full, so all `E`
                    // entries are initialized (see `close_inner`).
                    let inner = current.cast::<InnerNode<K, V, E>>();

                    for i in 0..E {
                        let k = unsafe {
                            addr_of_mut!((*current).keys[i]).cast::<K>()
                        };
                        let ordering = key.cmp(unsafe { (*k).borrow() });

                        match ordering {
                            Ordering::Equal => {
                                let v = unsafe {
                                    addr_of_mut!((*current).values[i])
                                        .cast::<V>()
                                };
                                return Some((k, v));
                            }
                            Ordering::Less => {
                                // The search key falls between the previous
                                // entry and this one; descend into the
                                // lesser subtree if one exists.
                                let lesser = unsafe {
                                    addr_of_mut!((*inner).lesser_nodes[i])
                                };
                                let lesser_is_invalid =
                                    unsafe { RelPtr::is_invalid_raw(lesser) };
                                if !lesser_is_invalid {
                                    current =
                                        unsafe { RelPtr::as_ptr_raw(lesser) };
                                    continue 'outer;
                                } else {
                                    return None;
                                }
                            }
                            Ordering::Greater => (),
                        }
                    }

                    // The search key is greater than every entry in this
                    // node; descend into the greater subtree if one exists.
                    let inner = current.cast::<InnerNode<K, V, E>>();
                    let greater =
                        unsafe { addr_of_mut!((*inner).greater_node) };
                    let greater_is_invalid =
                        unsafe { RelPtr::is_invalid_raw(greater) };
                    if !greater_is_invalid {
                        current = unsafe { RelPtr::as_ptr_raw(greater) };
                    } else {
                        return None;
                    }
                }
            }
        }
    }
334
    /// Resolves an `ArchivedBTreeMap` from the given length, resolver, and
    /// output place.
    ///
    /// `len` must be the same length passed when serializing. An empty map
    /// is emplaced with an invalid root pointer, which readers never follow.
    pub fn resolve_from_len(
        len: usize,
        resolver: BTreeMapResolver,
        out: Place<Self>,
    ) {
        munge!(let ArchivedBTreeMap { root, len: out_len, _phantom: _ } = out);

        if len == 0 {
            RelPtr::emplace_invalid(root);
        } else {
            // Point the root at the node written at `root_node_pos` during
            // serialization.
            RelPtr::emplace(resolver.root_node_pos as usize, root);
        }

        out_len.write(ArchivedUsize::from_native(len as FixedUsize));
    }
352
    /// Serializes an `ArchivedBTreeMap` from the given iterator and serializer.
    ///
    /// The iterator must yield entries in ascending key order (lookups rely
    /// on the keys within each node being sorted) and must yield exactly
    /// `iter.len()` items; otherwise serialization fails with an
    /// `IteratorLengthMismatch` error.
    pub fn serialize_from_ordered_iter<I, BKU, BVU, KU, VU, S>(
        mut iter: I,
        serializer: &mut S,
    ) -> Result<BTreeMapResolver, S::Error>
    where
        I: ExactSizeIterator<Item = (BKU, BVU)>,
        BKU: Borrow<KU>,
        BVU: Borrow<VU>,
        KU: Serialize<S, Archived = K>,
        VU: Serialize<S, Archived = V>,
        S: Fallible + Allocator + Writer + ?Sized,
        S::Error: Source,
    {
        let len = iter.len();

        if len == 0 {
            // Exhaust the iterator to verify that its reported length was
            // honest.
            let actual = iter.count();
            if actual != 0 {
                fail!(IteratorLengthMismatch {
                    expected: 0,
                    actual,
                });
            }
            // `resolve_from_len` emplaces an invalid root when the length is
            // zero, so this position is never used.
            return Ok(BTreeMapResolver { root_node_pos: 0 });
        }

        let height = entries_to_height::<E>(len);
        let ll_entries = ll_entries::<E>(height, len);

        // The tree is built bottom-up: leaves are written as soon as they
        // fill, and each completed child's position propagates into the
        // stack of currently-open inner nodes (one per level above the
        // leaves).
        SerVec::with_capacity(
            serializer,
            height as usize - 1,
            |open_inners, serializer| {
                for _ in 0..height - 1 {
                    open_inners
                        .push(InlineVec::<(BKU, BVU, Option<usize>), E>::new());
                }

                let mut open_leaf = InlineVec::<(BKU, BVU), E>::new();

                // Position of the most recently closed node, if it has not
                // yet been claimed as a child pointer of a parent entry.
                let mut child_node_pos = None;
                let mut leaf_entries = 0;
                while let Some((key, value)) = iter.next() {
                    open_leaf.push((key, value));
                    leaf_entries += 1;

                    if leaf_entries == ll_entries
                        || open_leaf.len() == open_leaf.capacity()
                    {
                        // Close open leaf
                        child_node_pos =
                            Some(Self::close_leaf(&open_leaf, serializer)?);
                        open_leaf.clear();

                        // If on the transition node, fill and close open inner
                        if leaf_entries == ll_entries {
                            if let Some(mut inner) = open_inners.pop() {
                                while inner.len() < inner.capacity() {
                                    if let Some((k, v)) = iter.next() {
                                        inner.push((k, v, child_node_pos));
                                        child_node_pos = None;
                                    } else {
                                        break;
                                    }
                                }

                                child_node_pos = Some(Self::close_inner(
                                    &inner,
                                    child_node_pos,
                                    serializer,
                                )?);
                            }
                        }

                        // Add closed node to open inner
                        let mut popped = 0;
                        while let Some(last_inner) = open_inners.last_mut() {
                            if last_inner.len() == last_inner.capacity() {
                                // Close open inner
                                child_node_pos = Some(Self::close_inner(
                                    last_inner,
                                    child_node_pos,
                                    serializer,
                                )?);
                                open_inners.pop();
                                popped += 1;
                            } else {
                                let (key, value) = iter.next().unwrap();
                                last_inner.push((key, value, child_node_pos));
                                child_node_pos = None;
                                // Reopen a fresh inner node for each level
                                // whose full node was just closed above.
                                for _ in 0..popped {
                                    open_inners.push(InlineVec::default());
                                }
                                break;
                            }
                        }
                    }
                }

                if !open_leaf.is_empty() {
                    // Close open leaf
                    child_node_pos =
                        Some(Self::close_leaf(&open_leaf, serializer)?);
                    open_leaf.clear();
                }

                // Close open inners
                while let Some(inner) = open_inners.pop() {
                    child_node_pos = Some(Self::close_inner(
                        &inner,
                        child_node_pos,
                        serializer,
                    )?);
                }

                debug_assert!(open_inners.is_empty());
                debug_assert!(open_leaf.is_empty());

                // The iterator claimed to hold exactly `len` entries; verify
                // that nothing is left over.
                let leftovers = iter.count();
                if leftovers != 0 {
                    fail!(IteratorLengthMismatch {
                        expected: len,
                        actual: len + leftovers,
                    });
                }

                // The last node closed is necessarily the root of the tree.
                Ok(BTreeMapResolver {
                    root_node_pos: child_node_pos.unwrap() as FixedUsize,
                })
            },
        )?
    }
486
    /// Serializes `items` as a single leaf node and returns the position the
    /// node was written at.
    fn close_leaf<BKU, BVU, KU, VU, S>(
        items: &[(BKU, BVU)],
        serializer: &mut S,
    ) -> Result<usize, S::Error>
    where
        BKU: Borrow<KU>,
        BVU: Borrow<VU>,
        KU: Serialize<S, Archived = K>,
        VU: Serialize<S, Archived = V>,
        S: Writer + Fallible + ?Sized,
    {
        // Serialize the out-of-line portions of all keys and values first,
        // keeping their resolvers for when the node itself is emplaced.
        let mut resolvers = InlineVec::<(KU::Resolver, VU::Resolver), E>::new();
        for (key, value) in items {
            resolvers.push((
                key.borrow().serialize(serializer)?,
                value.borrow().serialize(serializer)?,
            ));
        }

        let pos = serializer.align_for::<LeafNode<K, V, E>>()?;
        let mut node = MaybeUninit::<LeafNode<K, V, E>>::uninit();
        // SAFETY: `node` is properly aligned and valid for writes of
        // `size_of::<LeafNode<K, V, E>>()` bytes. Zeroing the whole struct
        // ensures padding and unused entry slots contain defined bytes.
        unsafe {
            node.as_mut_ptr().write_bytes(0, 1);
        }

        let node_place =
            unsafe { Place::new_unchecked(pos, node.as_mut_ptr()) };

        munge! {
            let LeafNode {
                node: Node {
                    kind,
                    keys,
                    values,
                },
                len,
            } = node_place;
        }
        kind.write(NodeKind::Leaf);
        len.write(ArchivedUsize::from_native(items.len() as FixedUsize));
        // Resolve each entry into its slot in the node.
        for (i, ((k, v), (kr, vr))) in
            items.iter().zip(resolvers.drain()).enumerate()
        {
            let out_key = unsafe { keys.index(i).cast_unchecked() };
            k.borrow().resolve(kr, out_key);
            let out_value = unsafe { values.index(i).cast_unchecked() };
            v.borrow().resolve(vr, out_value);
        }

        // Copy the now fully-written node into the output buffer.
        let bytes = unsafe {
            slice::from_raw_parts(
                node.as_ptr().cast::<u8>(),
                size_of::<LeafNode<K, V, E>>(),
            )
        };
        serializer.write(bytes)?;

        Ok(pos)
    }
548
    /// Serializes `items` as a single inner node and returns the position the
    /// node was written at.
    ///
    /// Each item carries the position of the already-written child node that
    /// precedes it (if any); `greater_node_pos` is the child that follows the
    /// final entry.
    fn close_inner<BKU, BVU, KU, VU, S>(
        items: &[(BKU, BVU, Option<usize>)],
        greater_node_pos: Option<usize>,
        serializer: &mut S,
    ) -> Result<usize, S::Error>
    where
        BKU: Borrow<KU>,
        BVU: Borrow<VU>,
        KU: Serialize<S, Archived = K>,
        VU: Serialize<S, Archived = V>,
        S: Writer + Fallible + ?Sized,
    {
        // Inner nodes are only ever closed when full, so they carry no
        // separate length field.
        debug_assert_eq!(items.len(), E);

        // Serialize the out-of-line portions of all keys and values first,
        // keeping their resolvers for when the node itself is emplaced.
        let mut resolvers = InlineVec::<(KU::Resolver, VU::Resolver), E>::new();
        for (key, value, _) in items {
            resolvers.push((
                key.borrow().serialize(serializer)?,
                value.borrow().serialize(serializer)?,
            ));
        }

        let pos = serializer.align_for::<InnerNode<K, V, E>>()?;
        let mut node = MaybeUninit::<InnerNode<K, V, E>>::uninit();
        // SAFETY: `node` is properly aligned and valid for writes of
        // `size_of::<InnerNode<K, V, E>>()` bytes. Zeroing the whole struct
        // ensures padding contains defined bytes.
        unsafe {
            node.as_mut_ptr().write_bytes(0, 1);
        }

        let node_place =
            unsafe { Place::new_unchecked(pos, node.as_mut_ptr()) };

        munge! {
            let InnerNode {
                node: Node {
                    kind,
                    keys,
                    values,
                },
                lesser_nodes,
                greater_node,
            } = node_place;
        }

        kind.write(NodeKind::Inner);
        // Resolve each entry and its preceding child pointer into the node.
        for (i, ((k, v, l), (kr, vr))) in
            items.iter().zip(resolvers.drain()).enumerate()
        {
            let out_key = unsafe { keys.index(i).cast_unchecked() };
            k.borrow().resolve(kr, out_key);
            let out_value = unsafe { values.index(i).cast_unchecked() };
            v.borrow().resolve(vr, out_value);

            let out_lesser_node = unsafe { lesser_nodes.index(i) };
            if let Some(lesser_node) = l {
                RelPtr::emplace(*lesser_node, out_lesser_node);
            } else {
                // No subtree precedes this entry.
                RelPtr::emplace_invalid(out_lesser_node);
            }
        }

        if let Some(greater_node_pos) = greater_node_pos {
            RelPtr::emplace(greater_node_pos, greater_node);
        } else {
            RelPtr::emplace_invalid(greater_node);
        }

        // Copy the now fully-written node into the output buffer.
        let bytes = unsafe {
            slice::from_raw_parts(
                node.as_ptr().cast::<u8>(),
                size_of::<InnerNode<K, V, E>>(),
            )
        };
        serializer.write(bytes)?;

        Ok(pos)
    }
627
    /// Visits every key-value pair in the B-tree with a function.
    ///
    /// If `f` returns `ControlFlow::Break`, `visit` will return `Some` with the
    /// broken value. If `f` returns `Continue` for every pair in the tree,
    /// `visit` will return `None`.
    pub fn visit<T>(
        &self,
        mut f: impl FnMut(&K, &V) -> ControlFlow<T>,
    ) -> Option<T> {
        if self.is_empty() {
            // The root pointer is invalid when the map is empty, so it must
            // not be followed.
            None
        } else {
            let root = &self.root;
            let root_ptr = unsafe { root.as_ptr().cast::<Node<K, V, E>>() };
            // Adapt the safe callback to the raw-pointer visitor shared with
            // `visit_seal`.
            let mut call_inner = |k: *mut K, v: *mut V| unsafe { f(&*k, &*v) };
            match Self::visit_raw(root_ptr.cast_mut(), &mut call_inner) {
                ControlFlow::Continue(()) => None,
                ControlFlow::Break(x) => Some(x),
            }
        }
    }
649
    /// Visits every mutable key-value pair in the B-tree with a function.
    ///
    /// Keys are passed as shared references so they cannot be mutated out of
    /// their sorted order; only values are sealed.
    ///
    /// If `f` returns `ControlFlow::Break`, `visit` will return `Some` with the
    /// broken value. If `f` returns `Continue` for every pair in the tree,
    /// `visit` will return `None`.
    pub fn visit_seal<T>(
        this: Seal<'_, Self>,
        mut f: impl FnMut(&K, Seal<'_, V>) -> ControlFlow<T>,
    ) -> Option<T> {
        if this.is_empty() {
            // The root pointer is invalid when the map is empty, so it must
            // not be followed.
            None
        } else {
            munge!(let Self { root, .. } = this);
            let root_ptr =
                unsafe { RelPtr::as_mut_ptr(root).cast::<Node<K, V, E>>() };
            // Adapt the sealed callback to the raw-pointer visitor shared
            // with `visit`.
            let mut call_inner =
                |k: *mut K, v: *mut V| unsafe { f(&*k, Seal::new(&mut *v)) };
            match Self::visit_raw(root_ptr, &mut call_inner) {
                ControlFlow::Continue(()) => None,
                ControlFlow::Break(x) => Some(x),
            }
        }
    }
673
    /// Performs an in-order traversal of the subtree rooted at `current`,
    /// invoking `f` with raw pointers to each key-value pair.
    fn visit_raw<T>(
        current: *mut Node<K, V, E>,
        f: &mut impl FnMut(*mut K, *mut V) -> ControlFlow<T>,
    ) -> ControlFlow<T> {
        let kind = unsafe { (*current).kind };

        match kind {
            NodeKind::Leaf => {
                // Only the first `len` entries of a leaf are initialized.
                let leaf = current.cast::<LeafNode<K, V, E>>();
                let len = unsafe { (*leaf).len };
                for i in 0..len.to_native() as usize {
                    Self::visit_key_value_raw(current, i, f)?;
                }
            }
            NodeKind::Inner => {
                let inner = current.cast::<InnerNode<K, V, E>>();

                // Visit lesser nodes and key-value pairs: each entry's
                // lesser subtree (when present) is visited before the entry
                // itself, yielding in-order traversal.
                for i in 0..E {
                    let lesser =
                        unsafe { addr_of_mut!((*inner).lesser_nodes[i]) };
                    let lesser_is_invalid =
                        unsafe { RelPtr::is_invalid_raw(lesser) };
                    if !lesser_is_invalid {
                        let lesser_ptr = unsafe { RelPtr::as_ptr_raw(lesser) };
                        Self::visit_raw(lesser_ptr, f)?;
                    }
                    Self::visit_key_value_raw(current, i, f)?;
                }

                // Visit greater node
                let greater = unsafe { addr_of_mut!((*inner).greater_node) };
                let greater_is_invalid =
                    unsafe { RelPtr::is_invalid_raw(greater) };
                if !greater_is_invalid {
                    let greater_ptr = unsafe {
                        RelPtr::as_ptr_raw(greater).cast::<Node<K, V, E>>()
                    };
                    Self::visit_raw(greater_ptr, f)?;
                }
            }
        }

        ControlFlow::Continue(())
    }
719
    /// Invokes `f` with raw pointers to the `i`-th key and value of `current`.
    ///
    /// Callers must only pass indices of initialized entries.
    fn visit_key_value_raw<T>(
        current: *mut Node<K, V, E>,
        i: usize,
        f: &mut impl FnMut(*mut K, *mut V) -> ControlFlow<T>,
    ) -> ControlFlow<T> {
        let key_ptr = unsafe { addr_of_mut!((*current).keys[i]).cast::<K>() };
        let value_ptr =
            unsafe { addr_of_mut!((*current).values[i]).cast::<V>() };
        f(key_ptr, value_ptr)
    }
730}
731
impl<K, V, const E: usize> fmt::Debug for ArchivedBTreeMap<K, V, E>
where
    K: fmt::Debug,
    V: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Render entries via the in-order traversal, matching the map-style
        // `{key: value, ...}` output of `BTreeMap`'s own `Debug` impl.
        let mut map = f.debug_map();
        self.visit(|k, v| {
            map.entry(k, v);
            ControlFlow::<()>::Continue(())
        });
        map.finish()
    }
}
746
// TODO(#515): ungate this impl
#[cfg(feature = "alloc")]
// NOTE(review): the bounds require only `PartialEq` (not `K: Eq`/`V: Eq`),
// mirroring the bounds on the `PartialEq` impl in this file.
impl<K, V, const E: usize> Eq for ArchivedBTreeMap<K, V, E>
where
    K: PartialEq,
    V: PartialEq,
{
}
755
impl<K, V, Q, const E: usize> Index<&Q> for ArchivedBTreeMap<K, V, E>
where
    Q: Ord + ?Sized,
    K: Borrow<Q> + Ord,
{
    type Output = V;

    // Returns a reference to the value for `key`. Panics if the key is not
    // present in the map.
    fn index(&self, key: &Q) -> &Self::Output {
        self.get(key).unwrap()
    }
}
767
// TODO(#515): ungate this impl
#[cfg(feature = "alloc")]
// Maps with different node widths (`E1` vs `E2`) are comparable because
// comparison walks entries in traversal order rather than by node layout.
impl<K, V, const E1: usize, const E2: usize>
    PartialEq<ArchivedBTreeMap<K, V, E2>> for ArchivedBTreeMap<K, V, E1>
where
    K: PartialEq,
    V: PartialEq,
{
    fn eq(&self, other: &ArchivedBTreeMap<K, V, E2>) -> bool {
        if self.len() != other.len() {
            return false;
        }
        // NOTE(review): assumes `iter` yields pairs in the same in-order
        // sequence as `visit`. The equal-length check above makes the
        // `unwrap` on `i.next()` safe.
        let mut i = other.iter();
        self.visit(|lk, lv| {
            let (rk, rv) = i.next().unwrap();
            if lk != rk || lv != rv {
                ControlFlow::Break(())
            } else {
                ControlFlow::Continue(())
            }
        })
        .is_none()
    }
}
792
impl<K: core::hash::Hash, V: core::hash::Hash, const E: usize> core::hash::Hash
    for ArchivedBTreeMap<K, V, E>
{
    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
        // Hash entries in traversal order so that maps with equal contents
        // produce equal hashes regardless of how entries are split across
        // nodes.
        self.visit(|k, v| {
            (*k).hash(state);
            (*v).hash(state);
            ControlFlow::<()>::Continue(())
        });
    }
}
804
/// The resolver for [`ArchivedBTreeMap`].
pub struct BTreeMapResolver {
    // Position of the root node written during serialization; zero (and
    // unused) when the map is empty.
    root_node_pos: FixedUsize,
}
809
810#[cfg(feature = "bytecheck")]
811mod verify {
812    use core::{alloc::Layout, error::Error, fmt, ptr::addr_of};
813
814    use bytecheck::{CheckBytes, Verify};
815    use rancor::{fail, Fallible, Source};
816
817    use super::{ArchivedBTreeMap, InnerNode, Node};
818    use crate::{
819        collections::btree_map::{LeafNode, NodeKind},
820        validation::{ArchiveContext, ArchiveContextExt as _},
821        RelPtr,
822    };
823
    /// Validation error raised when a node's stored length exceeds the
    /// node's capacity.
    #[derive(Debug)]
    struct InvalidLength {
        len: usize,
        maximum: usize,
    }
829
    impl fmt::Display for InvalidLength {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            // This text is surfaced through the crate's validation errors.
            write!(
                f,
                "Invalid length in B-tree node: len {} was greater than \
                 maximum {}",
                self.len, self.maximum
            )
        }
    }
840
841    impl Error for InvalidLength {}
842
    // SAFETY: `verify` validates the root node pointer and its referent
    // before reporting the map as valid; an empty map is valid without
    // following its (invalid) root pointer.
    unsafe impl<C, K, V, const E: usize> Verify<C> for ArchivedBTreeMap<K, V, E>
    where
        C: Fallible + ArchiveContext + ?Sized,
        C::Error: Source,
        K: CheckBytes<C>,
        V: CheckBytes<C>,
    {
        fn verify(&self, context: &mut C) -> Result<(), C::Error> {
            let len = self.len();

            if len == 0 {
                // An empty map stores an invalid root pointer, which must
                // not be checked or dereferenced.
                return Ok(());
            }

            check_node_rel_ptr::<C, K, V, E>(&self.root, context)
        }
    }
860
    /// Checks that `node_rel_ptr` points at a valid node within the archive,
    /// dispatching to the leaf or inner checker based on the node's kind.
    fn check_node_rel_ptr<C, K, V, const E: usize>(
        node_rel_ptr: &RelPtr<Node<K, V, E>>,
        context: &mut C,
    ) -> Result<(), C::Error>
    where
        C: Fallible + ArchiveContext + ?Sized,
        C::Error: Source,
        K: CheckBytes<C>,
        V: CheckBytes<C>,
    {
        let node_ptr = node_rel_ptr.as_ptr_wrapping().cast::<Node<K, V, E>>();
        // Verify the pointer lands inside the archive buffer with correct
        // alignment before reading any field of the node.
        context.check_subtree_ptr(
            node_ptr.cast::<u8>(),
            &Layout::new::<Node<K, V, E>>(),
        )?;

        // SAFETY: We checked to make sure that `node_ptr` is properly aligned
        // and dereferenceable by calling `check_subtree_ptr`.
        let kind_ptr = unsafe { addr_of!((*node_ptr).kind) };
        // SAFETY: `kind_ptr` is a pointer to a subfield of `node_ptr` and so is
        // also properly aligned and dereferenceable.
        unsafe {
            CheckBytes::check_bytes(kind_ptr, context)?;
        }
        // SAFETY: `kind_ptr` was always properly aligned and dereferenceable,
        // and we just checked to make sure it pointed to a valid `NodeKind`.
        let kind = unsafe { kind_ptr.read() };

        match kind {
            NodeKind::Leaf => {
                // SAFETY:
                // We checked to make sure that `node_ptr` is properly aligned,
                // dereferenceable, and contained entirely within `context`'s
                // buffer by calling `check_subtree_ptr`.
                unsafe {
                    check_leaf_node::<C, K, V, E>(node_ptr.cast(), context)?
                }
            }
            NodeKind::Inner => {
                // SAFETY:
                // We checked to make sure that `node_ptr` is properly aligned
                // and dereferenceable.
                unsafe {
                    check_inner_node::<C, K, V, E>(node_ptr.cast(), context)?
                }
            }
        }

        Ok(())
    }
911
912    /// # Safety
913    ///
914    /// `node_ptr` must be properly aligned, dereferenceable, and contained
915    /// within `context`'s buffer.
916    unsafe fn check_leaf_node<C, K, V, const E: usize>(
917        node_ptr: *const LeafNode<K, V, E>,
918        context: &mut C,
919    ) -> Result<(), C::Error>
920    where
921        C: Fallible + ArchiveContext + ?Sized,
922        C::Error: Source,
923        K: CheckBytes<C>,
924        V: CheckBytes<C>,
925    {
926        context.in_subtree(node_ptr, |context| {
927            // SAFETY: We checked to make sure that `node_ptr` is properly
928            // aligned and dereferenceable by calling
929            // `check_subtree_ptr`.
930            let len_ptr = unsafe { addr_of!((*node_ptr).len) };
931            // SAFETY: `len_ptr` is a pointer to a subfield of `node_ptr` and so
932            // is also properly aligned and dereferenceable.
933            unsafe {
934                CheckBytes::check_bytes(len_ptr, context)?;
935            }
936            // SAFETY: `len_ptr` was always properly aligned and
937            // dereferenceable, and we just checked to make sure it
938            // pointed to a valid `ArchivedUsize`.
939            let len = unsafe { &*len_ptr };
940            let len = len.to_native() as usize;
941            if len > E {
942                fail!(InvalidLength { len, maximum: E });
943            }
944
945            // SAFETY: We checked that `node_ptr` is properly-aligned and
946            // dereferenceable.
947            let node_ptr = unsafe { addr_of!((*node_ptr).node) };
948            // SAFETY:
949            // - We checked that `node_ptr` is properly aligned and
950            //   dereferenceable.
951            // - We checked that `len` is less than or equal to `E`.
952            unsafe {
953                check_node_entries(node_ptr, len, context)?;
954            }
955
956            Ok(())
957        })
958    }
959
960    /// # Safety
961    ///
962    /// - `node_ptr` must point to a valid `Node<K, V, E>`.
963    /// - `len` must be less than or equal to `E`.
964    unsafe fn check_node_entries<C, K, V, const E: usize>(
965        node_ptr: *const Node<K, V, E>,
966        len: usize,
967        context: &mut C,
968    ) -> Result<(), C::Error>
969    where
970        C: Fallible + ArchiveContext + ?Sized,
971        C::Error: Source,
972        K: CheckBytes<C>,
973        V: CheckBytes<C>,
974    {
975        for i in 0..len {
976            // SAFETY: The caller has guaranteed that `node_ptr` is properly
977            // aligned and dereferenceable.
978            let key_ptr = unsafe { addr_of!((*node_ptr).keys[i]).cast::<K>() };
979            // SAFETY: The caller has guaranteed that `node_ptr` is properly
980            // aligned and dereferenceable.
981            let value_ptr =
982                unsafe { addr_of!((*node_ptr).values[i]).cast::<V>() };
983            unsafe {
984                K::check_bytes(key_ptr, context)?;
985            }
986            // SAFETY: `value_ptr` is a subfield of a node, and so is guaranteed
987            // to be properly aligned and point to enough bytes for a `V`.
988            unsafe {
989                V::check_bytes(value_ptr, context)?;
990            }
991        }
992
993        Ok(())
994    }
995
996    /// # Safety
997    ///
998    /// - `node_ptr` must be properly aligned and dereferenceable.
999    /// - `len` must be less than or equal to `E`.
1000    unsafe fn check_inner_node<C, K, V, const E: usize>(
1001        node_ptr: *const InnerNode<K, V, E>,
1002        context: &mut C,
1003    ) -> Result<(), C::Error>
1004    where
1005        C: Fallible + ArchiveContext + ?Sized,
1006        C::Error: Source,
1007        K: CheckBytes<C>,
1008        V: CheckBytes<C>,
1009    {
1010        context.in_subtree(node_ptr, |context| {
1011            for i in 0..E {
1012                // SAFETY: `in_subtree` guarantees that `node_ptr` is properly
1013                // aligned and dereferenceable.
1014                let lesser_node_ptr =
1015                    unsafe { addr_of!((*node_ptr).lesser_nodes[i]) };
1016                // SAFETY: `lesser_node_ptr` is a subfield of an inner node, and
1017                // so is guaranteed to be properly aligned and point to enough
1018                // bytes for a `RelPtr`.
1019                unsafe {
1020                    RelPtr::check_bytes(lesser_node_ptr, context)?;
1021                }
1022                // SAFETY: We just checked the `lesser_node_ptr` and it
1023                // succeeded, so it's safe to dereference.
1024                let lesser_node = unsafe { &*lesser_node_ptr };
1025                if !lesser_node.is_invalid() {
1026                    check_node_rel_ptr::<C, K, V, E>(lesser_node, context)?;
1027                }
1028            }
1029            // SAFETY: We checked that `node_ptr` is properly aligned and
1030            // dereferenceable.
1031            let greater_node_ptr =
1032                unsafe { addr_of!((*node_ptr).greater_node) };
1033            // SAFETY: `greater_node_ptr` is a subfield of an inner node, and so
1034            // is guaranteed to be properly aligned and point to enough bytes
1035            // for a `RelPtr`.
1036            unsafe {
1037                RelPtr::check_bytes(greater_node_ptr, context)?;
1038            }
1039            // SAFETY: We just checked the `greater_node_ptr` and it succeeded,
1040            // so it's safe to dereference.
1041            let greater_node = unsafe { &*greater_node_ptr };
1042            if !greater_node.is_invalid() {
1043                check_node_rel_ptr::<C, K, V, E>(greater_node, context)?;
1044            }
1045
1046            // SAFETY: We checked that `node_ptr` is properly aligned and
1047            // dereferenceable.
1048            let node_ptr = unsafe { addr_of!((*node_ptr).node) };
1049            // SAFETY:
1050            // - The caller has guaranteed that `node_ptr` points to a valid
1051            //   `Node<K, V, E>`.
1052            // - All inner nodes have `E` items, and `E` is less than or equal
1053            //   to `E`.
1054            unsafe {
1055                check_node_entries::<C, K, V, E>(node_ptr, E, context)?;
1056            }
1057
1058            Ok(())
1059        })
1060    }
1061}
1062
#[cfg(all(test, feature = "alloc"))]
mod tests {
    use core::hash::{Hash, Hasher};

    use ahash::AHasher;

    use crate::{
        alloc::{collections::BTreeMap, string::ToString},
        api::test::to_archived,
    };

    /// Hashing the archived map must produce the same value as hashing the
    /// source map's keys and values in iteration order.
    #[test]
    fn test_hash() {
        let mut source = BTreeMap::new();
        source.insert("a".to_string(), 1);
        source.insert("b".to_string(), 2);

        to_archived(&source, |archived| {
            // Hash of the archived map.
            let mut actual = AHasher::default();
            archived.hash(&mut actual);

            // Reference hash: feed each entry of the source map in order.
            let mut expected = AHasher::default();
            source.iter().for_each(|(key, value)| {
                key.hash(&mut expected);
                value.hash(&mut expected);
            });

            assert_eq!(actual.finish(), expected.finish());
        });
    }
}