generic_btree/generic_impl/
ord.rs

1use core::fmt::Debug;
2use std::cmp::Ordering;
3use std::ops::Range;
4
5use crate::rle::{CanRemove, HasLength, Mergeable, Sliceable, TryInsert};
6use crate::{BTree, BTreeTrait, FindResult, Query, SplitInfo};
7
/// Zero-sized marker that carries the `Key`/`Value` types of an ordered tree.
///
/// It holds no data of its own; the only field is a `PhantomData` so that the
/// two type parameters count as used.
#[repr(transparent)]
#[derive(Debug)]
struct OrdTrait<Key, Value> {
    /// Ties `Key` and `Value` to this type without storing either.
    _phantom: core::marker::PhantomData<(Key, Value)>,
}
13
14#[derive(Debug)]
15pub struct OrdTreeMap<Key: Clone + Ord + Debug + 'static, Value: Clone + Debug> {
16    tree: BTree<OrdTrait<Key, Value>>,
17    len: usize,
18}
19
20#[derive(Debug)]
21pub struct OrdTreeSet<Key: Clone + Ord + Debug + 'static>(OrdTreeMap<Key, ()>);
22
23impl<Key: Clone + Ord + Debug + 'static, Value: Clone + Debug + 'static> OrdTreeMap<Key, Value> {
24    #[inline(always)]
25    pub fn new() -> Self {
26        Self {
27            tree: BTree::new(),
28            len: 0,
29        }
30    }
31
32    #[inline(always)]
33    pub fn insert(&mut self, key: Key, value: Value) {
34        let Some(result) = self.tree.query::<OrdTrait<Key, Value>>(&key) else {
35            self.len += 1;
36            self.tree.push(Unmergeable((key, value)));
37            return;
38        };
39
40        if !result.found {
41            self.len += 1;
42            let tree = &mut self.tree;
43            let data = Unmergeable((key, value));
44            let index = result.leaf();
45            let leaf = tree.leaf_nodes.get_mut(index.0).unwrap();
46            let parent = leaf.parent();
47
48            let mut is_full = false;
49            // Try to merge
50            if result.cursor.offset == 0 && data.can_merge(&leaf.elem) {
51                leaf.elem.merge_left(&data);
52            } else if result.cursor.offset == leaf.elem.rle_len() && leaf.elem.can_merge(&data) {
53                leaf.elem.merge_right(&data);
54            } else {
55                // Insert new leaf node
56                let child = tree.alloc_leaf_child(data, parent.unwrap_internal());
57                let SplitInfo {
58                    parent_idx: parent_index,
59                    insert_slot: insert_index,
60                    ..
61                } = tree.split_leaf_if_needed(result.cursor);
62                let parent = tree.in_nodes.get_mut(parent_index).unwrap();
63                parent.children.insert(insert_index, child).unwrap();
64                is_full = parent.is_full();
65            }
66
67            tree.recursive_update_cache(parent, false, None);
68            if is_full {
69                tree.split(parent);
70            }
71        } else {
72            let leaf = self.tree.get_elem_mut(result.leaf()).unwrap();
73            leaf.0 .1 = value;
74        }
75    }
76
77    #[inline(always)]
78    pub fn delete(&mut self, key: &Key) -> Option<(Key, Value)> {
79        let q = self.tree.query::<OrdTrait<Key, Value>>(key)?;
80        match self.tree.remove_leaf(q.cursor) {
81            Some(v) => {
82                self.len -= 1;
83                Some(v.0)
84            }
85            None => None,
86        }
87    }
88
89    #[inline(always)]
90    pub fn iter(&self) -> impl Iterator<Item = &(Key, Value)> {
91        self.tree.iter().map(|x| &x.0)
92    }
93
94    #[inline(always)]
95    pub fn iter_key(&self) -> impl Iterator<Item = &Key> {
96        self.tree.iter().map(|x| &x.0 .0)
97    }
98
99    #[inline(always)]
100    pub fn len(&self) -> usize {
101        self.len
102    }
103
104    #[inline(always)]
105    pub fn is_empty(&self) -> bool {
106        self.len == 0
107    }
108
109    #[allow(unused)]
110    pub(crate) fn check(&self) {
111        self.tree.check()
112    }
113}
114
115impl<Key: Clone + Ord + Debug + 'static> OrdTreeSet<Key> {
116    #[inline(always)]
117    pub fn new() -> Self {
118        Self(OrdTreeMap::new())
119    }
120
121    #[inline(always)]
122    pub fn insert(&mut self, key: Key) {
123        self.0.insert(key, ());
124    }
125
126    #[inline(always)]
127    pub fn delete(&mut self, key: &Key) -> bool {
128        self.0.delete(key).is_some()
129    }
130
131    #[inline(always)]
132    pub fn iter(&self) -> impl Iterator<Item = &Key> {
133        self.0.iter_key()
134    }
135
136    pub fn len(&self) -> usize {
137        self.0.len
138    }
139
140    pub fn is_empty(&self) -> bool {
141        self.0.len == 0
142    }
143
144    #[allow(unused)]
145    fn check(&self) {
146        self.0.check()
147    }
148}
149
150impl<Key: Clone + Ord + Debug + 'static> Default for OrdTreeSet<Key> {
151    #[inline(always)]
152    fn default() -> Self {
153        Self::new()
154    }
155}
156
157impl<Key: Clone + Ord + Debug + 'static, Value: Clone + Debug + 'static> Default
158    for OrdTreeMap<Key, Value>
159{
160    #[inline(always)]
161    fn default() -> Self {
162        Self::new()
163    }
164}
165
166impl<Key, Value> Default for OrdTrait<Key, Value> {
167    #[inline(always)]
168    fn default() -> Self {
169        Self {
170            _phantom: Default::default(),
171        }
172    }
173}
174
/// Transparent wrapper around a single value of type `T`.
///
/// The trait impls in this file give it a fixed length of one and refuse
/// merging, in-place insertion and removal.
#[derive(Clone, Debug)]
#[repr(transparent)]
pub struct Unmergeable<T>(T);
178
179impl<T> HasLength for Unmergeable<T> {
180    fn rle_len(&self) -> usize {
181        1
182    }
183}
184
185impl<T: Clone> Sliceable for Unmergeable<T> {
186    fn _slice(&self, range: Range<usize>) -> Self {
187        if range.end - range.start != 1 {
188            panic!("Invalid range");
189        }
190
191        self.clone()
192    }
193}
194
195impl<T> Mergeable for Unmergeable<T> {
196    fn can_merge(&self, _rhs: &Self) -> bool {
197        false
198    }
199
200    fn merge_right(&mut self, _rhs: &Self) {
201        unreachable!()
202    }
203
204    fn merge_left(&mut self, _left: &Self) {
205        unreachable!()
206    }
207}
208
209impl<T> TryInsert for Unmergeable<T> {
210    fn try_insert(&mut self, _pos: usize, elem: Self) -> Result<(), Self> {
211        Err(elem)
212    }
213}
214
215impl<T> CanRemove for Unmergeable<T> {
216    fn can_remove(&self) -> bool {
217        false
218    }
219}
220
221impl<Key: Clone + Ord + Debug + 'static, Value: Clone + Debug> BTreeTrait for OrdTrait<Key, Value> {
222    type Elem = Unmergeable<(Key, Value)>;
223    type Cache = Option<(Key, Key)>;
224    type CacheDiff = ();
225    const USE_DIFF: bool = false;
226
227    #[inline(always)]
228    fn calc_cache_internal(cache: &mut Self::Cache, caches: &[crate::Child<Self>]) {
229        if caches.is_empty() {
230            return;
231        }
232
233        *cache = Some((
234            caches[0].cache.as_ref().unwrap().0.clone(),
235            caches[caches.len() - 1].cache.as_ref().unwrap().1.clone(),
236        ));
237    }
238
239    #[inline(always)]
240    fn apply_cache_diff(_: &mut Self::Cache, _: &Self::CacheDiff) {
241        unreachable!()
242    }
243
244    #[inline(always)]
245    fn merge_cache_diff(_: &mut Self::CacheDiff, _: &Self::CacheDiff) {}
246
247    #[inline(always)]
248    fn get_elem_cache(elem: &Self::Elem) -> Self::Cache {
249        Some((elem.0 .0.clone(), elem.0 .0.clone()))
250    }
251
252    #[inline(always)]
253    fn new_cache_to_diff(_: &Self::Cache) -> Self::CacheDiff {}
254
255    fn sub_cache(_: &Self::Cache, _: &Self::Cache) -> Self::CacheDiff {}
256}
257
258impl<Key: Ord + Clone + Debug + 'static, Value: Clone + Debug + 'static> Query<OrdTrait<Key, Value>>
259    for OrdTrait<Key, Value>
260{
261    type QueryArg = Key;
262
263    #[inline(always)]
264    fn init(_target: &Self::QueryArg) -> Self {
265        Self::default()
266    }
267
268    #[inline]
269    fn find_node(
270        &mut self,
271        target: &Self::QueryArg,
272        child_caches: &[crate::Child<OrdTrait<Key, Value>>],
273    ) -> crate::FindResult {
274        let result = child_caches.binary_search_by(|x| {
275            let (min, max) = x.cache.as_ref().unwrap();
276            if target < min {
277                core::cmp::Ordering::Greater
278            } else if target > max {
279                core::cmp::Ordering::Less
280            } else {
281                core::cmp::Ordering::Equal
282            }
283        });
284        match result {
285            Ok(i) => FindResult::new_found(i, 0),
286            Err(i) => FindResult::new_missing(
287                i.min(child_caches.len() - 1),
288                if i == child_caches.len() { 1 } else { 0 },
289            ),
290        }
291    }
292
293    #[inline(always)]
294    fn confirm_elem(
295        &mut self,
296        q: &Self::QueryArg,
297        elem: &<OrdTrait<Key, Value> as BTreeTrait>::Elem,
298    ) -> (usize, bool) {
299        match q.cmp(&elem.0 .0) {
300            Ordering::Less => (0, false),
301            Ordering::Equal => (0, true),
302            Ordering::Greater => (1, false),
303        }
304    }
305}
306
#[cfg(test)]
mod test {
    use std::cmp::Ordering;

    use rand::{Rng, SeedableRng};

    use crate::HeapVec;

    use super::*;

    /// Keys inserted in random order must come back out sorted.
    #[test]
    fn test() {
        let mut set: OrdTreeSet<u64> = OrdTreeSet::new();
        let mut rng = rand::rngs::StdRng::seed_from_u64(123);
        let mut expected: HeapVec<u64> = (0..1000).map(|_| rng.gen()).collect();
        for &key in expected.iter() {
            set.insert(key);
        }
        expected.sort_unstable();

        let actual = set.iter().copied().collect::<HeapVec<_>>();
        assert_eq!(actual, expected);
        set.check();
    }

    /// Deleting the only key leaves an empty set.
    #[test]
    fn test_delete() {
        let mut set: OrdTreeSet<u64> = OrdTreeSet::new();
        set.insert(12);
        set.delete(&12);
        assert_eq!(set.len(), 0);
    }

    /// Cursor comparison must agree with key order for every pair of keys.
    #[test]
    fn test_compare_pos() {
        let mut set: OrdTreeSet<u64> = OrdTreeSet::new();
        for key in 0..100 {
            set.insert(key);
        }

        for low in 0..99 {
            let first = set.0.tree.query::<OrdTrait<u64, ()>>(&low).unwrap();
            assert_eq!(
                set.0.tree.compare_pos(first.cursor(), first.cursor()),
                Ordering::Equal
            );
            for high in low + 1..100 {
                let second = set.0.tree.query::<OrdTrait<u64, ()>>(&high).unwrap();
                assert_eq!(
                    set.0.tree.compare_pos(first.cursor(), second.cursor()),
                    Ordering::Less
                );
                assert_eq!(
                    set.0.tree.compare_pos(second.cursor(), first.cursor()),
                    Ordering::Greater
                );
            }
        }
    }

    mod move_event_test {

        use super::*;

        /// Deletes every inserted key in three batches, checking the tree
        /// after the second batch.
        #[test]
        fn test() {
            let mut map: OrdTreeMap<u64, usize> = OrdTreeMap::new();
            let mut rng = rand::rngs::StdRng::seed_from_u64(123);
            let mut keys: HeapVec<u64> = (0..1000).map(|_| rng.gen()).collect();
            for &key in keys.iter() {
                map.insert(key, 0);
            }

            for key in keys.drain(0..100) {
                map.delete(&key);
            }
            for key in keys.drain(0..800) {
                map.delete(&key);
            }
            map.tree.check();

            // The remaining 100 keys are removed from the back of the list.
            for _ in 0..100 {
                let key = keys.pop().unwrap();
                map.delete(&key);
            }
        }
    }

    /// Prints depth and average fan-out while inserting about two million
    /// ascending keys. Ignored by default.
    #[test]
    #[ignore]
    fn depth_test() {
        let mut set: OrdTreeSet<u64> = OrdTreeSet::new();
        for i in 0..2_100_000 {
            set.insert(i as u64);
            // `(!i) + 1` is the two's-complement negation of `i`, so `i & m`
            // keeps only the lowest set bit. It equals `i` when `i` is zero or
            // a power of two, so reporting happens at those points.
            let m = (!i) + 1;
            if (i & m) == i {
                eprintln!(
                    "i={}, Depth={}, Avg Children={}",
                    i,
                    set.0.tree.depth(),
                    set.0.tree.internal_avg_children_num()
                );
            }
        }
        set.check();
    }
}