imbl/nodes/
btree.rs

1// This Source Code Form is subject to the terms of the Mozilla Public
2// License, v. 2.0. If a copy of the MPL was not distributed with this
3// file, You can obtain one at http://mozilla.org/MPL/2.0/.
4
5use std::borrow::Borrow;
6use std::collections::VecDeque;
7use std::iter::FromIterator;
8use std::mem;
9use std::num::NonZeroUsize;
10use std::ops::{Bound, RangeBounds};
11
12use archery::{SharedPointer, SharedPointerKind};
13use imbl_sized_chunks::Chunk;
14
15pub(crate) use crate::config::ORD_CHUNK_SIZE as NODE_SIZE;
16
17const MEDIAN: usize = NODE_SIZE / 2;
18const THIRD: usize = NODE_SIZE / 3;
19const NUM_CHILDREN: usize = NODE_SIZE + 1;
20
/// A node in a `B+Tree`.
///
/// The main tree representation uses [`Branch`] and [`Leaf`]; this is only used
/// in places that want to handle either a branch or a leaf.
#[derive(Debug)]
pub(crate) enum Node<K, V, P: SharedPointerKind> {
    /// An interior node, held behind a shared pointer.
    Branch(SharedPointer<Branch<K, V, P>, P>),
    /// A leaf holding key/value entries, held behind a shared pointer.
    Leaf(SharedPointer<Leaf<K, V>, P>),
}
30
31impl<K: Ord + std::fmt::Debug, V: std::fmt::Debug, P: SharedPointerKind> Branch<K, V, P> {
32    #[cfg(any(test, fuzzing))]
33    pub(crate) fn check_sane(&self, is_root: bool) -> usize {
34        assert!(self.keys.len() >= if is_root { 1 } else { MEDIAN - 1 });
35        assert_eq!(self.keys.len() + 1, self.children.len());
36        assert!(self.keys.windows(2).all(|w| w[0] < w[1]));
37        match &self.children {
38            Children::Leaves { leaves } => {
39                for i in 0..self.keys.len() {
40                    let left = &leaves[i];
41                    let right = &leaves[i + 1];
42                    assert!(left.keys.last().unwrap().0 < right.keys.first().unwrap().0);
43                }
44                leaves.iter().map(|child| child.check_sane(false)).sum()
45            }
46            Children::Branches { branches, level } => {
47                for i in 0..self.keys.len() {
48                    let left = &branches[i];
49                    let right = &branches[i + 1];
50                    assert!(left.level() == level.get() - 1);
51                    assert!(right.level() == level.get() - 1);
52                }
53                branches.iter().map(|child| child.check_sane(false)).sum()
54            }
55        }
56    }
57}
impl<K: Ord + std::fmt::Debug, V: std::fmt::Debug> Leaf<K, V> {
    /// Validates this leaf, panicking if its keys are not strictly increasing or if a
    /// non-root leaf holds fewer than `THIRD` entries (the root leaf may be empty).
    ///
    /// Returns the number of entries in the leaf.
    #[cfg(any(test, fuzzing))]
    pub(crate) fn check_sane(&self, is_root: bool) -> usize {
        assert!(self.keys.windows(2).all(|w| w[0].0 < w[1].0));
        assert!(self.keys.len() >= if is_root { 0 } else { THIRD });
        self.keys.len()
    }
}
66impl<K: Ord + std::fmt::Debug, V: std::fmt::Debug, P: SharedPointerKind> Node<K, V, P> {
67    /// Check invariants
68    #[cfg(any(test, fuzzing))]
69    pub(crate) fn check_sane(&self, is_root: bool) -> usize {
70        match self {
71            Node::Branch(branch) => branch.check_sane(is_root),
72            Node::Leaf(leaf) => leaf.check_sane(is_root),
73        }
74    }
75}
76
77impl<K, V, P: SharedPointerKind> Node<K, V, P> {
78    pub(crate) fn unit(key: K, value: V) -> Self {
79        Node::Leaf(SharedPointer::new(Leaf {
80            keys: Chunk::unit((key, value)),
81        }))
82    }
83
84    fn level(&self) -> usize {
85        match self {
86            Node::Branch(branch) => branch.level(),
87            Node::Leaf(_) => 0,
88        }
89    }
90
91    pub(crate) fn ptr_eq(&self, other: &Self) -> bool {
92        match (self, other) {
93            (Node::Branch(a), Node::Branch(b)) => SharedPointer::ptr_eq(a, b),
94            (Node::Leaf(a), Node::Leaf(b)) => SharedPointer::ptr_eq(a, b),
95            _ => false,
96        }
97    }
98}
99
/// A branch (interior) node in a `B+Tree`.
///
/// Invariants:
/// * keys are ordered and unique
/// * keys.len() + 1 == children.len()
/// * all children have level = level - 1 (or level is 1 and all children are leaves)
/// * all keys in the subtree at children[i] are between keys[i - 1] (if i > 0) and keys[i]
///   (if i < keys.len()). Lookups route a key equal to keys[i] into children[i + 1].
/// * root branch must have at least 1 key, whereas non-root branches must have at least MEDIAN - 1 keys
///
/// NOTE(review): the sibling rebalancing policy only merges or rebalances child branches
/// when they fit into one node or one of them drops below THIRD keys, so the
/// `MEDIAN - 1` lower bound above may not actually be maintained for larger NODE_SIZE
/// values; confirm the intended bound.
#[derive(Debug)]
pub(crate) struct Branch<K, V, P: SharedPointerKind> {
    /// Separator keys in ascending order.
    keys: Chunk<K, NODE_SIZE>,
    /// Child subtrees; always exactly one more than `keys`.
    children: Children<K, V, P>,
}
112
/// The child list of a [`Branch`]: either all leaves or all branches, never a mix.
#[derive(Debug)]
pub(crate) enum Children<K, V, P: SharedPointerKind> {
    /// implicitly level 1: the owning branch sits directly above the leaves.
    Leaves {
        leaves: Chunk<SharedPointer<Leaf<K, V>, P>, NUM_CHILDREN>,
    },
    /// level >= 2: every child is a branch exactly one level lower.
    Branches {
        branches: Chunk<SharedPointer<Branch<K, V, P>, P>, NUM_CHILDREN>,
        /// The level of the tree node that contains these children.
        ///
        /// Leaves have level zero, so branches have level at least one. Since this is the
        /// level of something containing branches, it is at least two.
        level: NonZeroUsize,
    },
}
129
130impl<K, V, P: SharedPointerKind> Children<K, V, P> {
131    fn len(&self) -> usize {
132        match self {
133            Children::Leaves { leaves } => leaves.len(),
134            Children::Branches { branches, .. } => branches.len(),
135        }
136    }
137    fn drain_from_front(&mut self, other: &mut Self, count: usize) {
138        match (self, other) {
139            (
140                Children::Leaves { leaves },
141                Children::Leaves {
142                    leaves: other_leaves,
143                },
144            ) => leaves.drain_from_front(other_leaves, count),
145            (
146                Children::Branches { branches, .. },
147                Children::Branches {
148                    branches: other_branches,
149                    ..
150                },
151            ) => branches.drain_from_front(other_branches, count),
152            _ => panic!("mismatched drain_from_front"),
153        }
154    }
155    fn drain_from_back(&mut self, other: &mut Self, count: usize) {
156        match (self, other) {
157            (
158                Children::Leaves { leaves },
159                Children::Leaves {
160                    leaves: other_leaves,
161                },
162            ) => leaves.drain_from_back(other_leaves, count),
163            (
164                Children::Branches { branches, .. },
165                Children::Branches {
166                    branches: other_branches,
167                    ..
168                },
169            ) => branches.drain_from_back(other_branches, count),
170            _ => panic!("mismatched drain_from_back"),
171        }
172    }
173    fn extend(&mut self, other: &Self) {
174        match (self, other) {
175            (
176                Children::Leaves { leaves },
177                Children::Leaves {
178                    leaves: other_leaves,
179                },
180            ) => leaves.extend(other_leaves.iter().cloned()),
181            (
182                Children::Branches { branches, .. },
183                Children::Branches {
184                    branches: other_branches,
185                    ..
186                },
187            ) => branches.extend(other_branches.iter().cloned()),
188            _ => panic!("mismatched extend"),
189        }
190    }
191    fn insert_front(&mut self, other: &Self) {
192        match (self, other) {
193            (
194                Children::Leaves { leaves },
195                Children::Leaves {
196                    leaves: other_leaves,
197                },
198            ) => leaves.insert_from(0, other_leaves.iter().cloned()),
199            (
200                Children::Branches { branches, .. },
201                Children::Branches {
202                    branches: other_branches,
203                    ..
204                },
205            ) => branches.insert_from(0, other_branches.iter().cloned()),
206            _ => panic!("mismatched insert_front"),
207        }
208    }
209    fn insert(&mut self, index: usize, node: Node<K, V, P>) {
210        match (self, node) {
211            (Children::Leaves { leaves }, Node::Leaf(node)) => leaves.insert(index, node),
212            (Children::Branches { branches, .. }, Node::Branch(node)) => {
213                branches.insert(index, node)
214            }
215            _ => panic!("mismatched insert"),
216        }
217    }
218    fn split_off(&mut self, at: usize) -> Self {
219        match self {
220            Children::Leaves { leaves } => Children::Leaves {
221                leaves: leaves.split_off(at),
222            },
223            Children::Branches { branches, level } => Children::Branches {
224                branches: branches.split_off(at),
225                level: *level,
226            },
227        }
228    }
229}
230
231impl<K, V, P: SharedPointerKind> Branch<K, V, P> {
232    pub(crate) fn pop_single_child(&mut self) -> Option<Node<K, V, P>> {
233        if self.children.len() == 1 {
234            debug_assert_eq!(self.keys.len(), 0);
235            Some(match &mut self.children {
236                Children::Leaves { leaves } => Node::Leaf(leaves.pop_back()),
237                Children::Branches { branches, .. } => Node::Branch(branches.pop_back()),
238            })
239        } else {
240            None
241        }
242    }
243
244    fn level(&self) -> usize {
245        match &self.children {
246            Children::Leaves { .. } => 1,
247            Children::Branches { level, .. } => level.get(),
248        }
249    }
250}
251
/// A leaf node in a `B+Tree`.
///
/// Invariants:
/// * keys are ordered and unique
/// * leaf is the lowest level in the tree (level 0)
/// * non-root leaves must have at least THIRD keys
#[derive(Debug)]
pub(crate) struct Leaf<K, V> {
    /// The key/value entries stored in this leaf, sorted by key.
    keys: Chunk<(K, V), NODE_SIZE>,
}
262
263impl<K: Ord + Clone, V: Clone, P: SharedPointerKind> Node<K, V, P> {
264    /// Removes a key from the node or its children.
265    /// Returns `true` if the node is underflowed and should be rebalanced.
266    pub(crate) fn remove<BK>(&mut self, key: &BK, removed: &mut Option<(K, V)>) -> bool
267    where
268        BK: Ord + ?Sized,
269        K: Borrow<BK>,
270    {
271        match self {
272            Node::Branch(branch) => SharedPointer::make_mut(branch).remove(key, removed),
273            Node::Leaf(leaf) => SharedPointer::make_mut(leaf).remove(key, removed),
274        }
275    }
276}
277
278impl<K: Ord + Clone, V: Clone, P: SharedPointerKind> Branch<K, V, P> {
279    pub(crate) fn remove<BK>(&mut self, key: &BK, removed: &mut Option<(K, V)>) -> bool
280    where
281        BK: Ord + ?Sized,
282        K: Borrow<BK>,
283    {
284        let i = self
285            .keys
286            .binary_search_by(|k| k.borrow().cmp(key))
287            .map(|x| x + 1)
288            .unwrap_or_else(|x| x);
289        let rebalance = match &mut self.children {
290            Children::Leaves { leaves } => {
291                SharedPointer::make_mut(&mut leaves[i]).remove(key, removed)
292            }
293            Children::Branches { branches, .. } => {
294                SharedPointer::make_mut(&mut branches[i]).remove(key, removed)
295            }
296        };
297        if rebalance {
298            self.branch_rebalance_children(i);
299        }
300        // Underflow if the branch is < 1/2 full. Since the branches are relatively
301        // rarely rebalanced (given relaxed leaf underflow), we can afford to be
302        // a bit more conservative here.
303        self.keys.len() < MEDIAN
304    }
305}
306
307impl<K: Ord + Clone, V: Clone> Leaf<K, V> {
308    pub(crate) fn remove<BK>(&mut self, key: &BK, removed: &mut Option<(K, V)>) -> bool
309    where
310        BK: Ord + ?Sized,
311        K: Borrow<BK>,
312    {
313        if let Ok(i) = self.keys.binary_search_by(|(k, _)| k.borrow().cmp(key)) {
314            *removed = Some(self.keys.remove(i));
315        }
316        // Underflow if the leaf is < 1/3 full. This relaxed underflow (vs. 1/2 full) is
317        // useful to prevent degenerate cases where a random insert/remove workload will
318        // constantly merge/split a leaf.
319        self.keys.len() < THIRD
320    }
321}
322
impl<K: Ord + Clone, V: Clone, P: SharedPointerKind> Branch<K, V, P> {
    /// Restores occupancy after the child at `underflow_idx` reported an underflow.
    ///
    /// Considers a window of up to three consecutive children starting just left of
    /// the underflowing child, and either merges two adjacent children or moves
    /// entries between two adjacent children. If fewer than two children are
    /// available from that window's start, nothing is done.
    #[cold]
    pub(crate) fn branch_rebalance_children(&mut self, underflow_idx: usize) {
        // When `underflow_idx == 0` the window starts at the underflowing child itself,
        // so it is `left` rather than `mid` below.
        let left_idx = underflow_idx.saturating_sub(1);
        match &mut self.children {
            Children::Leaves { leaves } => {
                let (left, mid, right) = match &leaves[left_idx..] {
                    [left, mid, right, ..] => (&**left, &**mid, Some(&**right)),
                    [left, mid, ..] => (&**left, &**mid, None),
                    _ => return,
                };
                // Prefer merging two sibling children if we can fit them into a single node.
                // But also try to rebalance if the smallest child is small (< 1/3), to amortize the cost of rebalancing.
                // Since we prefer merging, for rebalancing to apply the largest child will be at least 2/3 full,
                // which results in two at least half full nodes after rebalancing.
                // Both merge arms keep `mid` as the surviving node.
                match (left, mid, right) {
                    (left, mid, _) if left.keys.len() + mid.keys.len() <= NODE_SIZE => {
                        Self::merge_leaves(leaves, &mut self.keys, left_idx, false);
                    }
                    (_, mid, Some(right)) if mid.keys.len() + right.keys.len() <= NODE_SIZE => {
                        Self::merge_leaves(leaves, &mut self.keys, left_idx + 1, true);
                    }
                    (left, mid, _) if mid.keys.len().min(left.keys.len()) < THIRD => {
                        Self::rebalance_leaves(leaves, &mut self.keys, left_idx);
                    }
                    (_, mid, Some(right)) if mid.keys.len().min(right.keys.len()) < THIRD => {
                        Self::rebalance_leaves(leaves, &mut self.keys, left_idx + 1);
                    }
                    _ => (),
                }
            }
            Children::Branches { branches, .. } => {
                let (left, mid, right) = match &branches[left_idx..] {
                    [left, mid, right, ..] => (&**left, &**mid, Some(&**right)),
                    [left, mid, ..] => (&**left, &**mid, None),
                    _ => return,
                };
                // Same policy as for leaves, except that merging two branches also pulls
                // the separator key down from this branch, hence the strict `<`.
                match (left, mid, right) {
                    (left, mid, _) if left.keys.len() + mid.keys.len() < NODE_SIZE => {
                        Self::merge_branches(branches, &mut self.keys, left_idx, false);
                    }
                    (_, mid, Some(right)) if mid.keys.len() + right.keys.len() < NODE_SIZE => {
                        Self::merge_branches(branches, &mut self.keys, left_idx + 1, true);
                    }
                    (left, mid, _) if mid.keys.len().min(left.keys.len()) < THIRD => {
                        Self::rebalance_branches(branches, &mut self.keys, left_idx);
                    }
                    (_, mid, Some(right)) if mid.keys.len().min(right.keys.len()) < THIRD => {
                        Self::rebalance_branches(branches, &mut self.keys, left_idx + 1);
                    }
                    _ => (),
                }
            }
        }
    }

    /// Merges the adjacent child leaves at `left_idx` and `left_idx + 1` of this branch.
    ///
    /// `keep_left` selects the surviving node: `true` appends the right leaf's entries
    /// to the left one, `false` prepends the left leaf's entries to the right one. The
    /// separator `keys[left_idx]` and the discarded child pointer are removed.
    ///
    /// Assumes that the two children can fit in a single leaf, panicking if not.
    fn merge_leaves(
        children: &mut Chunk<SharedPointer<Leaf<K, V>, P>, NUM_CHILDREN>,
        keys: &mut Chunk<K, NODE_SIZE>,
        left_idx: usize,
        keep_left: bool,
    ) {
        let [left, right, ..] = &mut children[left_idx..] else {
            unreachable!()
        };
        if keep_left {
            let left = SharedPointer::make_mut(left);
            let (left, right) = (left, &**right);
            left.keys.extend(right.keys.iter().cloned());
        } else {
            let right = SharedPointer::make_mut(right);
            let (left, right) = (&**left, right);
            right.keys.insert_from(0, left.keys.iter().cloned());
        }
        keys.remove(left_idx);
        // Drop whichever child pointer was not kept.
        children.remove(left_idx + (keep_left as usize));
        debug_assert_eq!(keys.len() + 1, children.len());
    }

    /// Rebalances two adjacent leaves so that they have the same
    /// number of keys (or differ by at most 1), then sets the separator
    /// `keys[left_idx]` to the right leaf's new first key.
    fn rebalance_leaves(
        children: &mut Chunk<SharedPointer<Leaf<K, V>, P>, NUM_CHILDREN>,
        keys: &mut Chunk<K, NODE_SIZE>,
        left_idx: usize,
    ) {
        let [left, right, ..] = &mut children[left_idx..] else {
            unreachable!()
        };
        let (left, right) = (
            SharedPointer::make_mut(left),
            SharedPointer::make_mut(right),
        );
        let num_to_move = left.keys.len().abs_diff(right.keys.len()) / 2;
        if num_to_move == 0 {
            return;
        }
        if left.keys.len() > right.keys.len() {
            right.keys.drain_from_back(&mut left.keys, num_to_move);
        } else {
            left.keys.drain_from_front(&mut right.keys, num_to_move);
        }
        keys[left_idx] = right.keys.first().unwrap().0.clone();
        debug_assert_ne!(left.keys.len(), 0);
        debug_assert_ne!(right.keys.len(), 0);
    }

    /// Rebalances two adjacent child branches so that they have the same number of keys
    /// (or differ by at most 1). The separator key in the parent is rotated through the
    /// two branches to keep the invariants of the parent branch.
    fn rebalance_branches(
        children: &mut Chunk<SharedPointer<Branch<K, V, P>, P>, NUM_CHILDREN>,
        keys: &mut Chunk<K, NODE_SIZE>,
        left_idx: usize,
    ) {
        let [left, right, ..] = &mut children[left_idx..] else {
            unreachable!()
        };
        let (left, right) = (
            SharedPointer::make_mut(left),
            SharedPointer::make_mut(right),
        );
        let num_to_move = left.keys.len().abs_diff(right.keys.len()) / 2;
        if num_to_move == 0 {
            return;
        }
        let separator = &mut keys[left_idx];
        if left.keys.len() > right.keys.len() {
            // The old separator moves down into `right`, `num_to_move - 1` keys follow it
            // from `left`, and `left`'s last remaining moved key becomes the new separator.
            right.keys.push_front(separator.clone());
            right.keys.drain_from_back(&mut left.keys, num_to_move - 1);
            *separator = left.keys.pop_back();
            right
                .children
                .drain_from_back(&mut left.children, num_to_move);
        } else {
            // Mirror image: rotate keys and children from `right` into `left`.
            left.keys.push_back(separator.clone());
            left.keys.drain_from_front(&mut right.keys, num_to_move - 1);
            *separator = right.keys.pop_front();
            left.children
                .drain_from_front(&mut right.children, num_to_move);
        }
        debug_assert_ne!(left.keys.len(), 0);
        debug_assert_eq!(left.children.len(), left.keys.len() + 1);
        debug_assert_ne!(right.keys.len(), 0);
        debug_assert_eq!(right.children.len(), right.keys.len() + 1);
    }

    /// Merges the adjacent child branches at `left_idx` and `left_idx + 1` of this branch.
    ///
    /// The separator `keys[left_idx]` is pulled down between the two halves. `keep_left`
    /// selects the surviving node exactly as in `merge_leaves`.
    ///
    /// Assumes that the two children can fit in a single branch, panicking if not.
    fn merge_branches(
        children: &mut Chunk<SharedPointer<Branch<K, V, P>, P>, NUM_CHILDREN>,
        keys: &mut Chunk<K, NODE_SIZE>,
        left_idx: usize,
        keep_left: bool,
    ) {
        let [left, right, ..] = &mut children[left_idx..] else {
            unreachable!()
        };
        let separator = keys.remove(left_idx);
        if keep_left {
            let left = SharedPointer::make_mut(left);
            let (left, right) = (left, &**right);
            left.keys.push_back(separator);
            left.keys.extend(right.keys.iter().cloned());
            left.children.extend(&right.children);
        } else {
            let right = SharedPointer::make_mut(right);
            let (left, right) = (&**left, right);
            right.keys.push_front(separator);
            right.keys.insert_from(0, left.keys.iter().cloned());
            right.children.insert_front(&left.children);
        }
        // Drop whichever child pointer was not kept.
        children.remove(left_idx + (keep_left as usize));
        debug_assert_eq!(keys.len() + 1, children.len());
    }
}
503
504impl<K: Ord + Clone, V: Clone, P: SharedPointerKind> Branch<K, V, P> {
505    pub(crate) fn insert(&mut self, key: K, value: V) -> InsertAction<K, V, P> {
506        let i = self
507            .keys
508            .binary_search(&key)
509            .map(|x| x + 1)
510            .unwrap_or_else(|x| x);
511        let insert_action = match &mut self.children {
512            Children::Leaves { leaves } => {
513                SharedPointer::make_mut(&mut leaves[i]).insert(key, value)
514            }
515            Children::Branches { branches, .. } => {
516                SharedPointer::make_mut(&mut branches[i]).insert(key, value)
517            }
518        };
519        match insert_action {
520            InsertAction::Split(new_key, new_node) if self.keys.len() >= NODE_SIZE => {
521                self.split_branch_insert(i, new_key, new_node)
522            }
523            InsertAction::Split(separator, new_node) => {
524                self.keys.insert(i, separator);
525                self.children.insert(i + 1, new_node);
526                InsertAction::Inserted
527            }
528            action => action,
529        }
530    }
531}
532impl<K: Ord + Clone, V: Clone> Leaf<K, V> {
533    pub(crate) fn insert<P: SharedPointerKind>(
534        &mut self,
535        key: K,
536        value: V,
537    ) -> InsertAction<K, V, P> {
538        match self.keys.binary_search_by(|(k, _)| k.cmp(&key)) {
539            Ok(i) => {
540                let (k, v) = mem::replace(&mut self.keys[i], (key, value));
541                InsertAction::Replaced(k, v)
542            }
543            Err(i) if self.keys.len() >= NODE_SIZE => self.split_leaf_insert(i, key, value),
544            Err(i) => {
545                self.keys.insert(i, (key, value));
546                InsertAction::Inserted
547            }
548        }
549    }
550}
551impl<K: Ord + Clone, V: Clone, P: SharedPointerKind> Node<K, V, P> {
552    pub(crate) fn insert(&mut self, key: K, value: V) -> InsertAction<K, V, P> {
553        match self {
554            Node::Branch(branch) => SharedPointer::make_mut(branch).insert(key, value),
555            Node::Leaf(leaf) => SharedPointer::make_mut(leaf).insert(key, value),
556        }
557    }
558}
559impl<K: Ord + Clone, V: Clone, P: SharedPointerKind> Branch<K, V, P> {
560    #[cold]
561    fn split_branch_insert(
562        &mut self,
563        i: usize,
564        new_key: K,
565        new_node: Node<K, V, P>,
566    ) -> InsertAction<K, V, P> {
567        let split_idx = MEDIAN + (i > MEDIAN) as usize;
568        let mut right_keys = self.keys.split_off(split_idx);
569        let split_idx = MEDIAN + (i >= MEDIAN) as usize;
570        let mut right_children = self.children.split_off(split_idx);
571        let separator = if i == MEDIAN {
572            right_children.insert(0, new_node.clone());
573            new_key
574        } else {
575            if i < MEDIAN {
576                self.keys.insert(i, new_key);
577                self.children.insert(i + 1, new_node);
578            } else {
579                right_keys.insert(i - (MEDIAN + 1), new_key);
580                right_children.insert(i - (MEDIAN + 1) + 1, new_node);
581            }
582            self.keys.pop_back()
583        };
584        debug_assert_eq!(self.keys.len(), right_keys.len());
585        debug_assert_eq!(self.keys.len() + 1, self.children.len());
586        debug_assert_eq!(right_keys.len() + 1, right_children.len());
587        InsertAction::Split(
588            separator,
589            Node::Branch(SharedPointer::new(Branch {
590                keys: right_keys,
591                children: right_children,
592            })),
593        )
594    }
595}
596
597impl<K: Ord + Clone, V: Clone> Leaf<K, V> {
598    #[inline]
599    fn split_leaf_insert<P: SharedPointerKind>(
600        &mut self,
601        i: usize,
602        key: K,
603        value: V,
604    ) -> InsertAction<K, V, P> {
605        let mut right_keys = self.keys.split_off(MEDIAN);
606        if i < MEDIAN {
607            self.keys.insert(i, (key, value));
608        } else {
609            right_keys.insert(i - MEDIAN, (key, value));
610        }
611        InsertAction::Split(
612            right_keys.first().unwrap().0.clone(),
613            Node::Leaf(SharedPointer::new(Leaf { keys: right_keys })),
614        )
615    }
616}
617
618impl<K: Ord + Clone, V: Clone, P: SharedPointerKind> Branch<K, V, P> {
619    pub(crate) fn lookup_mut<BK>(&mut self, key: &BK) -> Option<(&K, &mut V)>
620    where
621        BK: Ord + ?Sized,
622        K: Borrow<BK>,
623    {
624        let i = self
625            .keys
626            .binary_search_by(|k| k.borrow().cmp(key))
627            .map(|x| x + 1)
628            .unwrap_or_else(|x| x);
629        match &mut self.children {
630            Children::Leaves { leaves } => SharedPointer::make_mut(&mut leaves[i]).lookup_mut(key),
631            Children::Branches { branches, .. } => {
632                SharedPointer::make_mut(&mut branches[i]).lookup_mut(key)
633            }
634        }
635    }
636}
637
638impl<K: Ord + Clone, V: Clone> Leaf<K, V> {
639    pub(crate) fn lookup_mut<BK>(&mut self, key: &BK) -> Option<(&K, &mut V)>
640    where
641        BK: Ord + ?Sized,
642        K: Borrow<BK>,
643    {
644        let keys = &mut self.keys;
645        let i = keys.binary_search_by(|(k, _)| k.borrow().cmp(key)).ok()?;
646        keys.get_mut(i).map(|(k, v)| (&*k, v))
647    }
648}
649
650impl<K: Ord + Clone, V: Clone, P: SharedPointerKind> Node<K, V, P> {
651    pub(crate) fn lookup_mut<BK>(&mut self, key: &BK) -> Option<(&K, &mut V)>
652    where
653        BK: Ord + ?Sized,
654        K: Borrow<BK>,
655    {
656        match self {
657            Node::Branch(branch) => SharedPointer::make_mut(branch).lookup_mut(key),
658            Node::Leaf(leaf) => SharedPointer::make_mut(leaf).lookup_mut(key),
659        }
660    }
661
662    pub(crate) fn new_from_split(left: Self, separator: K, right: Self) -> Self {
663        Node::Branch(SharedPointer::new(Branch {
664            keys: Chunk::unit(separator),
665            children: match (left, right) {
666                (Node::Branch(left), Node::Branch(right)) => Children::Branches {
667                    level: NonZeroUsize::new(left.level() + 1).unwrap(),
668                    branches: Chunk::from_iter([left, right]),
669                },
670                (Node::Leaf(left), Node::Leaf(right)) => Children::Leaves {
671                    leaves: Chunk::from_iter([left, right]),
672                },
673                _ => panic!("mismatched split"),
674            },
675        }))
676    }
677}
678
679impl<K: Ord, V, P: SharedPointerKind> Branch<K, V, P> {
680    fn min(&self) -> Option<&(K, V)> {
681        let mut node = self;
682        loop {
683            match &node.children {
684                Children::Leaves { leaves } => return leaves.first()?.min(),
685                Children::Branches { branches, .. } => node = branches.first()?,
686            }
687        }
688    }
689    fn max(&self) -> Option<&(K, V)> {
690        let mut node = self;
691        loop {
692            match &node.children {
693                Children::Leaves { leaves } => return leaves.last()?.max(),
694                Children::Branches { branches, .. } => node = branches.last()?,
695            }
696        }
697    }
698    pub(crate) fn lookup<BK>(&self, key: &BK) -> Option<&(K, V)>
699    where
700        BK: Ord + ?Sized,
701        K: Borrow<BK>,
702    {
703        let mut node = self;
704        loop {
705            let i = node
706                .keys
707                .binary_search_by(|k| k.borrow().cmp(key))
708                .map(|x| x + 1)
709                .unwrap_or_else(|x| x);
710            match &node.children {
711                Children::Leaves { leaves } => return leaves[i].lookup(key),
712                Children::Branches { branches, .. } => node = &branches[i],
713            }
714        }
715    }
716}
717
718impl<K: Ord, V> Leaf<K, V> {
719    fn min(&self) -> Option<&(K, V)> {
720        self.keys.first()
721    }
722    fn max(&self) -> Option<&(K, V)> {
723        self.keys.last()
724    }
725    fn lookup<BK>(&self, key: &BK) -> Option<&(K, V)>
726    where
727        BK: Ord + ?Sized,
728        K: Borrow<BK>,
729    {
730        let keys = &self.keys;
731        let i = keys.binary_search_by(|(k, _)| k.borrow().cmp(key)).ok()?;
732        keys.get(i)
733    }
734}
735
736impl<K: Ord, V, P: SharedPointerKind> Node<K, V, P> {
737    pub(crate) fn min(&self) -> Option<&(K, V)> {
738        match self {
739            Node::Branch(branch) => branch.min(),
740            Node::Leaf(leaf) => leaf.min(),
741        }
742    }
743
744    pub(crate) fn max(&self) -> Option<&(K, V)> {
745        match self {
746            Node::Branch(branch) => branch.max(),
747            Node::Leaf(leaf) => leaf.max(),
748        }
749    }
750
751    pub(crate) fn lookup<BK>(&self, key: &BK) -> Option<&(K, V)>
752    where
753        BK: Ord + ?Sized,
754        K: Borrow<BK>,
755    {
756        match self {
757            Node::Branch(branch) => branch.lookup(key),
758            Node::Leaf(leaf) => leaf.lookup(key),
759        }
760    }
761}
762
763impl<K: Clone, V: Clone> Clone for Leaf<K, V> {
764    fn clone(&self) -> Self {
765        Self {
766            keys: self.keys.clone(),
767        }
768    }
769}
770
771impl<K: Clone, V: Clone, P: SharedPointerKind> Clone for Branch<K, V, P> {
772    fn clone(&self) -> Self {
773        Self {
774            keys: self.keys.clone(),
775            children: self.children.clone(),
776        }
777    }
778}
779
780impl<K: Clone, V: Clone, P: SharedPointerKind> Clone for Children<K, V, P> {
781    fn clone(&self) -> Self {
782        match self {
783            Children::Leaves { leaves } => Children::Leaves {
784                leaves: leaves.clone(),
785            },
786            Children::Branches { branches, level } => Children::Branches {
787                branches: branches.clone(),
788                level: *level,
789            },
790        }
791    }
792}
793
794impl<K, V, P: SharedPointerKind> Clone for Node<K, V, P> {
795    fn clone(&self) -> Self {
796        match self {
797            Node::Branch(branch) => Node::Branch(branch.clone()),
798            Node::Leaf(leaf) => Node::Leaf(leaf.clone()),
799        }
800    }
801}
802
/// Outcome of inserting an entry into a node, reported back to the caller.
pub(crate) enum InsertAction<K, V, P: SharedPointerKind> {
    /// The entry was added without the node needing to split.
    Inserted,
    /// The key was already present; carries the previously stored key and value,
    /// which were overwritten by the new ones.
    Replaced(K, V),
    /// The node was full and split in two. Carries the separator key and the new
    /// right-hand sibling, which the parent inserts immediately after the node that split.
    Split(K, Node<K, V, P>),
}
808
809impl<K, V, P: SharedPointerKind> Default for Node<K, V, P> {
810    fn default() -> Self {
811        Node::Leaf(SharedPointer::new(Leaf { keys: Chunk::new() }))
812    }
813}
814
/// An owning iterator over a tree's entries, yielding them in key order.
#[derive(Debug)]
pub(crate) struct ConsumingIter<K, V, P: SharedPointerKind> {
    /// The leaves of the tree, in order. They are kept as shared pointers, which keeps
    /// the VecDeque allocation small and avoids eagerly cloning the leaves (that would
    /// defeat the purpose of this iterator).
    /// Leaves present in the VecDeque are guaranteed to be non-empty.
    leaves: VecDeque<SharedPointer<Leaf<K, V>, P>>,
    /// Number of entries not yet yielded; reported by `size_hint`.
    remaining: usize,
}
824
825impl<K, V, P: SharedPointerKind> ConsumingIter<K, V, P> {
826    pub(crate) fn new(node: Option<Node<K, V, P>>, size: usize) -> Self {
827        fn push<K, V, P: SharedPointerKind>(
828            out: &mut VecDeque<SharedPointer<Leaf<K, V>, P>>,
829            node: SharedPointer<Branch<K, V, P>, P>,
830        ) {
831            match &node.children {
832                Children::Leaves { leaves } => {
833                    out.extend(leaves.iter().filter(|leaf| !leaf.keys.is_empty()).cloned())
834                }
835                Children::Branches { branches, .. } => {
836                    for child in branches.iter() {
837                        push(out, child.clone());
838                    }
839                }
840            }
841        }
842        // preallocate the VecDeque assuming each leaf is half full
843        let mut leaves = VecDeque::with_capacity(size.div_ceil(NODE_SIZE / 2));
844        match node {
845            Some(Node::Branch(b)) => push(&mut leaves, b),
846            Some(Node::Leaf(l)) => {
847                if !l.keys.is_empty() {
848                    leaves.push_back(l)
849                }
850            }
851            None => (),
852        }
853        Self {
854            leaves,
855            remaining: size,
856        }
857    }
858}
859
860impl<K: Clone, V: Clone, P: SharedPointerKind> Iterator for ConsumingIter<K, V, P> {
861    type Item = (K, V);
862
863    fn next(&mut self) -> Option<Self::Item> {
864        let node = self.leaves.front_mut()?;
865        let leaf = SharedPointer::make_mut(node);
866        self.remaining -= 1;
867        let item = leaf.keys.pop_front();
868        if leaf.keys.is_empty() {
869            self.leaves.pop_front();
870        }
871        Some(item)
872    }
873
874    fn size_hint(&self) -> (usize, Option<usize>) {
875        (self.remaining, Some(self.remaining))
876    }
877}
878
879impl<K: Clone, V: Clone, P: SharedPointerKind> DoubleEndedIterator for ConsumingIter<K, V, P> {
880    fn next_back(&mut self) -> Option<Self::Item> {
881        let node = self.leaves.back_mut()?;
882        let leaf = SharedPointer::make_mut(node);
883        self.remaining -= 1;
884        let item = leaf.keys.pop_back();
885        if leaf.keys.is_empty() {
886            self.leaves.pop_back();
887        }
888        Some(item)
889    }
890}
891
#[derive(Debug)]
pub(crate) struct Iter<'a, K, V, P: SharedPointerKind> {
    /// The forward and backward cursors.
    /// A cursor whose end of the range is unbounded starts empty. It is initialized (and
    /// seeked to the first/last entry) lazily, on its first use.
    fwd: Cursor<'a, K, V, P>,
    bwd: Cursor<'a, K, V, P>,
    /// Whether `next` has been called at least once. Before that, the entry already under
    /// `fwd` is returned as-is instead of advancing first.
    fwd_yielded: bool,
    /// Same as `fwd_yielded`, for `bwd` and `next_back`.
    bwd_yielded: bool,
    /// Set once the cursors have met or crossed, or one of them ran out of entries. After
    /// that, both directions return `None`.
    exhausted: bool,
    /// True when the range is unbounded on both ends. Only then is `remaining` reported as an
    /// exact count rather than just an upper bound.
    exact: bool,
    /// Upper bound on the entries left; decremented for every yielded entry.
    remaining: usize,
    /// Root of the tree, kept so that an empty cursor can be initialized lazily.
    root: Option<&'a Node<K, V, P>>,
}
905
impl<'a, K, V, P: SharedPointerKind> Iter<'a, K, V, P> {
    /// Creates a double-ended iterator over the entries of the tree rooted at `root` whose keys
    /// lie within `range`.
    ///
    /// `len` seeds the `remaining` counter. `size_hint` reports it as exact when `range` is
    /// unbounded on both ends, so callers presumably pass the tree's total entry count.
    ///
    /// Bounded ends are resolved eagerly: the matching cursor is seeked to the first (front) or
    /// last (back) entry inside the range, and the iterator starts out exhausted if the range
    /// holds no entries. An unbounded end leaves its cursor empty until `peek_initial`
    /// initializes it on first use.
    pub(crate) fn new<R, BK>(root: Option<&'a Node<K, V, P>>, len: usize, range: R) -> Self
    where
        R: RangeBounds<BK>,
        K: Borrow<BK>,
        BK: Ord + ?Sized,
    {
        let mut fwd = Cursor::empty();
        let mut bwd = Cursor::empty();
        // Put `fwd` on the first entry >= key (> key for an excluded bound). If there is no
        // such entry, the range is empty.
        let mut exhausted = match range.start_bound() {
            Bound::Included(key) | Bound::Excluded(key) => {
                fwd.init(root);
                // An exact hit on an excluded start must be stepped over.
                if fwd.seek_to_key(key, false) && matches!(range.start_bound(), Bound::Excluded(_))
                {
                    fwd.next().is_none()
                } else {
                    fwd.is_empty()
                }
            }
            Bound::Unbounded => false,
        };

        // Put `bwd` on the last entry <= key (< key for an excluded bound). This is skipped
        // when the range is already known to be empty.
        exhausted = match (exhausted, range.end_bound()) {
            (false, Bound::Included(key) | Bound::Excluded(key)) => {
                bwd.init(root);
                // An exact hit on an excluded end must be stepped back from.
                if bwd.seek_to_key(key, true) && matches!(range.end_bound(), Bound::Excluded(_)) {
                    bwd.prev().is_none()
                } else {
                    bwd.is_empty()
                }
            }
            (exhausted, _) => exhausted,
        };

        // Returns true if `fwd` is positioned strictly after `bwd`, i.e. the range is empty.
        // Both cursors descend the same tree, so their paths are walked in lockstep. The first
        // shared node where their child indices differ decides the order. If the nodes
        // themselves differ, an earlier level already placed `fwd` before `bwd`.
        // A cursor that is still uninitialized has an empty stack and no leaf: `zip` yields
        // nothing and the leaf comparison is skipped, so the result is false.
        fn cursors_exhausted<K, V, P: SharedPointerKind>(
            fwd: &Cursor<'_, K, V, P>,
            bwd: &Cursor<'_, K, V, P>,
        ) -> bool {
            for (&(fi, f), &(bi, b)) in fwd.stack.iter().zip(bwd.stack.iter()) {
                if !std::ptr::eq(f, b) {
                    return false;
                }
                if fi > bi {
                    return true;
                }
            }
            if let (Some((fi, f)), Some((bi, b))) = (fwd.leaf, bwd.leaf) {
                if !std::ptr::eq(f, b) {
                    return false;
                }
                if fi > bi {
                    return true;
                }
            }
            false
        }
        exhausted = exhausted || cursors_exhausted(&fwd, &bwd);

        let exact = matches!(range.start_bound(), Bound::Unbounded)
            && matches!(range.end_bound(), Bound::Unbounded);

        Self {
            fwd,
            bwd,
            remaining: len,
            exact,
            exhausted,
            fwd_yielded: false,
            bwd_yielded: false,
            root,
        }
    }

    /// Updates the exhausted state after one side has produced a candidate entry.
    ///
    /// `has_next` says whether the moving cursor produced a candidate. `other_side_yielded`
    /// says whether the opposite cursor's current entry has already been returned.
    ///
    /// Returns true if the candidate must be discarded: either there is none, or it is the
    /// entry the other side already returned. When the cursors meet or cross inside the same
    /// leaf, the iterator is marked exhausted. If they meet exactly and the other side has not
    /// returned that entry yet, it is still returned here, once, as the final item.
    #[inline]
    fn update_exhausted(&mut self, has_next: bool, other_side_yielded: bool) -> bool {
        debug_assert!(!self.exhausted);
        if !has_next {
            self.exhausted = true;
            return true;
        }
        // Only the leaves are compared. After construction each cursor moves one entry at a
        // time, so the two must land in the same leaf before they can cross. Crossings that
        // already exist at construction are caught by `cursors_exhausted` in `new`.
        // An uninitialized (empty) cursor has no leaf, so the check is simply skipped. A cursor
        // that is empty because it ran out of entries cannot reach here: `exhausted` is set then.
        if let (Some((fi, f)), Some((bi, b))) = (self.fwd.leaf, self.bwd.leaf) {
            if std::ptr::eq(f, b) && fi >= bi {
                self.exhausted = true;
                return fi == bi && other_side_yielded;
            }
        }
        false
    }

    /// Returns the first candidate for one side (`fwd == true` for the front) and marks that
    /// side as having yielded, so this runs at most once per side (hence `#[cold]`).
    ///
    /// A bounded side's cursor was already positioned by `new` and is just peeked. An
    /// unbounded side's cursor is still empty; it is initialized from the root and seeked to
    /// the first or last entry.
    #[cold]
    fn peek_initial(&mut self, fwd: bool) -> Option<&'a (K, V)> {
        debug_assert!(!self.exhausted);
        let cursor = if fwd {
            self.fwd_yielded = true;
            &mut self.fwd
        } else {
            self.bwd_yielded = true;
            &mut self.bwd
        };
        // An empty cursor here was never initialized, so initialize it and seek to the
        // first/last entry. A cursor left empty by an out-of-range seek in `new` cannot reach
        // here, because `exhausted` was set then.
        if cursor.is_empty() {
            cursor.init(self.root);
            if fwd {
                cursor.seek_to_first();
            } else {
                cursor.seek_to_last();
            }
        }
        cursor.peek()
    }
}
1026
1027impl<'a, K, V, P: SharedPointerKind> Iterator for Iter<'a, K, V, P> {
1028    type Item = (&'a K, &'a V);
1029
1030    fn next(&mut self) -> Option<Self::Item> {
1031        if self.exhausted {
1032            return None;
1033        }
1034        let next = if self.fwd_yielded {
1035            self.fwd.next()
1036        } else {
1037            self.peek_initial(true)
1038        }
1039        .map(|(k, v)| (k, v));
1040        if self.update_exhausted(next.is_some(), self.bwd_yielded) {
1041            return None;
1042        }
1043        self.remaining -= 1;
1044        next
1045    }
1046
1047    fn size_hint(&self) -> (usize, Option<usize>) {
1048        if self.exhausted {
1049            return (0, Some(0));
1050        }
1051        let lb = if self.exact { self.remaining } else { 0 };
1052        (lb, Some(self.remaining))
1053    }
1054}
1055
1056impl<'a, K, V, P: SharedPointerKind> DoubleEndedIterator for Iter<'a, K, V, P> {
1057    fn next_back(&mut self) -> Option<Self::Item> {
1058        if self.exhausted {
1059            return None;
1060        }
1061        let next = if self.bwd_yielded {
1062            self.bwd.prev()
1063        } else {
1064            self.peek_initial(false)
1065        }
1066        .map(|(k, v)| (k, v));
1067        if self.update_exhausted(next.is_some(), self.fwd_yielded) {
1068            return None;
1069        }
1070        self.remaining -= 1;
1071        next
1072    }
1073}
1074
1075impl<'a, K, V, P: SharedPointerKind> Clone for Iter<'a, K, V, P> {
1076    fn clone(&self) -> Self {
1077        Self {
1078            fwd: self.fwd.clone(),
1079            bwd: self.bwd.clone(),
1080            exact: self.exact,
1081            fwd_yielded: self.fwd_yielded,
1082            bwd_yielded: self.bwd_yielded,
1083            exhausted: self.exhausted,
1084            remaining: self.remaining,
1085            root: self.root,
1086        }
1087    }
1088}
1089
#[derive(Debug)]
pub(crate) struct Cursor<'a, K, V, P: SharedPointerKind> {
    /// Path of branch nodes from the root down to the current position. Each element is the
    /// index of the child currently being visited, paired with the branch itself.
    stack: Vec<(usize, &'a Branch<K, V, P>)>,
    /// The current leaf and the index of the current entry inside it. `None` when the cursor
    /// has not descended to a leaf, e.g. before seeking or after running out of entries.
    leaf: Option<(usize, &'a Leaf<K, V>)>,
}
1096
1097impl<'a, K, V, P: SharedPointerKind> Clone for Cursor<'a, K, V, P> {
1098    fn clone(&self) -> Self {
1099        Self {
1100            stack: self.stack.clone(),
1101            leaf: self.leaf.clone(),
1102        }
1103    }
1104}
1105
1106impl<'a, K, V, P: SharedPointerKind> Cursor<'a, K, V, P> {
1107    /// Creates a new empty cursor.
1108    /// The variety of methods is to allow for a more efficient initialization
1109    /// in all cases.
1110    pub(crate) fn empty() -> Self {
1111        Self {
1112            stack: Vec::new(),
1113            leaf: None,
1114        }
1115    }
1116
1117    fn is_empty(&self) -> bool {
1118        self.stack.is_empty() && self.leaf.is_none()
1119    }
1120
1121    pub(crate) fn init(&mut self, node: Option<&'a Node<K, V, P>>) {
1122        if let Some(node) = node {
1123            self.stack.reserve_exact(node.level());
1124            match node {
1125                Node::Branch(branch) => self.stack.push((0, &*branch)),
1126                Node::Leaf(leaf) => {
1127                    debug_assert!(self.leaf.is_none());
1128                    self.leaf = Some((0, &*leaf))
1129                }
1130            }
1131        }
1132    }
1133
1134    // pushes the `ix`th child of `branch` onto the stack, whether it's a leaf
1135    // or a branch
1136    fn push_child(&mut self, branch: &'a Branch<K, V, P>, ix: usize) {
1137        debug_assert!(
1138            self.leaf.is_none(),
1139            "it doesn't make sense to push when we're already at a leaf"
1140        );
1141        match &branch.children {
1142            Children::Leaves { leaves } => self.leaf = Some((0, &leaves[ix])),
1143            Children::Branches { branches, .. } => self.stack.push((0, &branches[ix])),
1144        }
1145    }
1146
1147    pub(crate) fn seek_to_first(&mut self) -> Option<&'a (K, V)> {
1148        loop {
1149            if let Some((i, leaf)) = &self.leaf {
1150                debug_assert_eq!(i, &0);
1151                return leaf.keys.first();
1152            }
1153            let Some((i, branch)) = self.stack.last() else {
1154                return None;
1155            };
1156            debug_assert_eq!(i, &0);
1157            self.push_child(branch, 0);
1158        }
1159    }
1160
1161    fn seek_to_last(&mut self) -> Option<&'a (K, V)> {
1162        loop {
1163            if let Some((i, leaf)) = &mut self.leaf {
1164                debug_assert_eq!(i, &0);
1165                *i = leaf.keys.len().saturating_sub(1);
1166                return leaf.keys.last();
1167            }
1168            let Some((i, branch)) = self.stack.last_mut() else {
1169                return None;
1170            };
1171            debug_assert_eq!(i, &0);
1172            *i = branch.children.len() - 1;
1173            let (i, branch) = (*i, *branch);
1174            self.push_child(branch, i);
1175        }
1176    }
1177
1178    fn seek_to_key<BK>(&mut self, key: &BK, for_prev: bool) -> bool
1179    where
1180        BK: Ord + ?Sized,
1181        K: Borrow<BK>,
1182    {
1183        loop {
1184            if let Some((i, leaf)) = &mut self.leaf {
1185                let search = leaf.keys.binary_search_by(|(k, _)| k.borrow().cmp(key));
1186                *i = search.unwrap_or_else(|x| x);
1187                if for_prev {
1188                    if search.is_err() {
1189                        self.prev();
1190                    }
1191                } else if search == Err(leaf.keys.len()) {
1192                    self.next();
1193                }
1194                return search.is_ok();
1195            }
1196            let Some((i, branch)) = self.stack.last_mut() else {
1197                return false;
1198            };
1199            *i = branch
1200                .keys
1201                .binary_search_by(|k| k.borrow().cmp(key))
1202                .map(|x| x + 1)
1203                .unwrap_or_else(|x| x);
1204            let (i, branch) = (*i, *branch);
1205            self.push_child(branch, i);
1206        }
1207    }
1208
1209    /// Advances this and another cursor to their next position.
1210    /// While doing so skip all shared nodes between them.
1211    pub(crate) fn advance_skipping_shared<'b>(&mut self, other: &mut Cursor<'b, K, V, P>) {
1212        // The current implementation is not optimal as it will still visit many nodes unnecessarily
1213        // before skipping them. But it requires very little additional code.
1214        // Nevertheless it will still improve performance when there are shared nodes.
1215        loop {
1216            let mut skipped_any = false;
1217            debug_assert!(self.leaf.is_some());
1218            debug_assert!(other.leaf.is_some());
1219            if let (Some(this), Some(that)) = (self.leaf, other.leaf) {
1220                if std::ptr::eq(this.1, that.1) {
1221                    self.leaf = None;
1222                    other.leaf = None;
1223                    skipped_any = true;
1224                    let shared_levels = self
1225                        .stack
1226                        .iter()
1227                        .rev()
1228                        .zip(other.stack.iter().rev())
1229                        .take_while(|(this, that)| std::ptr::eq(this.1, that.1))
1230                        .count();
1231                    if shared_levels != 0 {
1232                        self.stack.drain(self.stack.len() - shared_levels..);
1233                        other.stack.drain(other.stack.len() - shared_levels..);
1234                    }
1235                }
1236            }
1237            self.next();
1238            other.next();
1239            if !skipped_any || self.leaf.is_none() {
1240                break;
1241            }
1242        }
1243    }
1244
1245    pub(crate) fn next(&mut self) -> Option<&'a (K, V)> {
1246        loop {
1247            if let Some((i, leaf)) = &mut self.leaf {
1248                if *i + 1 < leaf.keys.len() {
1249                    *i += 1;
1250                    return leaf.keys.get(*i);
1251                }
1252                self.leaf = None;
1253            }
1254            let Some((i, branch)) = self.stack.last_mut() else {
1255                break;
1256            };
1257            if *i + 1 < branch.children.len() {
1258                *i += 1;
1259                let (i, branch) = (*i, *branch);
1260                self.push_child(branch, i);
1261                break;
1262            }
1263            self.stack.pop();
1264        }
1265        self.seek_to_first()
1266    }
1267
1268    fn prev(&mut self) -> Option<&'a (K, V)> {
1269        loop {
1270            if let Some((i, leaf)) = &mut self.leaf {
1271                if *i > 0 {
1272                    *i -= 1;
1273                    return leaf.keys.get(*i);
1274                }
1275                self.leaf = None;
1276            }
1277            let Some((i, branch)) = self.stack.last_mut() else {
1278                break;
1279            };
1280            if *i > 0 {
1281                *i -= 1;
1282                let (i, branch) = (*i, *branch);
1283                self.push_child(branch, i);
1284                break;
1285            }
1286            self.stack.pop();
1287        }
1288        self.seek_to_last()
1289    }
1290
1291    pub(crate) fn peek(&self) -> Option<&'a (K, V)> {
1292        if let Some((i, leaf)) = &self.leaf {
1293            leaf.keys.get(*i)
1294        } else {
1295            None
1296        }
1297    }
1298}