// hibit_tree/ops/_multi_union.rs

1use std::marker::PhantomData;
2use std::slice;
3use arrayvec::ArrayVec;
4use crate::{BitBlock, LazyHibitTree, RegularHibitTree, MultiHibitTree, MultiHibitTreeTypes, HibitTree, HibitTreeData, HibitTreeCursor, HibitTreeCursorTypes, HibitTreeTypes, HierarchyIndex};
5use crate::const_utils::{ArrayOf, ConstBool, ConstFalse, ConstInteger, ConstTrue, IsConstTrue};
6use crate::utils::{Array, Borrowable, Ref};
7
8pub struct MultiUnion<Iter, D=ConstFalse> {
9    iter: Iter,
10    phantom: PhantomData<D>,
11}
12
13type IterItem<Iter> = <<Iter as Iterator>::Item as Ref>::Type;
14type IterItemCursor<'item, Iter> = <IterItem<Iter> as HibitTreeTypes<'item>>::Cursor;
15
16impl<'item, 'this, Iter, T, D> HibitTreeTypes<'this> for MultiUnion<Iter, D>
17where
18    Iter: Iterator<Item = &'item T> + Clone,
19    T: HibitTree + 'item,
20    D: ConstBool,
21{
22    type Data  = Data<'item, Iter>;
23    type DataUnchecked = DataUnchecked<Iter>;
24    type DataOrDefault = DataOrDefault<Iter>;
25    type Cursor = Cursor<'this, 'item, Iter, D>;
26}
27
28impl<'i, Iter, T, D> HibitTree for MultiUnion<Iter, D>
29where
30    Iter: Iterator<Item = &'i T> + Clone,
31    T: HibitTree + 'i,
32    D: ConstBool
33{
34    const EXACT_HIERARCHY: bool = T::EXACT_HIERARCHY;
35    type DefaultData = T::DefaultData;
36    
37    type LevelCount = T::LevelCount;
38    type LevelMask  = T::LevelMask;
39
40    #[inline]
41    fn data(&self, index: &HierarchyIndex<Self::LevelMask, Self::LevelCount>)
42        -> Option<<Self as HibitTreeTypes<'_>>::Data> 
43    {
44        // TODO: we can have custom iterator here, without gathering everything into array.
45        // Gather items - then return as iter.
46        let mut datas: ArrayVec<_, N> = Default::default();
47        for array in self.iter.clone(){
48            let data = array.borrow().data(index);
49            if let Some(data) = data {
50                datas.push(data);
51            }
52        }
53        if datas.is_empty(){
54            return None;
55        }
56        
57        Some(datas.into_iter())
58    }
59
60    #[inline]
61    unsafe fn data_unchecked(&self, index: &HierarchyIndex<Self::LevelMask, Self::LevelCount>)
62        -> <Self as HibitTreeTypes<'_>>::DataUnchecked 
63    {
64        DataUnchecked {
65            iter: self.iter.clone(),
66            hi_index: index.clone(),
67        }
68    }
69    
70    #[inline]
71    unsafe fn data_or_default(&self, index: &HierarchyIndex<Self::LevelMask, Self::LevelCount>)
72        -> <Self as HibitTreeTypes<'_>>::DataOrDefault 
73    {
74        DataOrDefault {
75            iter: self.iter.clone(),
76            hi_index: index.clone(),
77        }
78    }
79}
80
81pub type Data<'item, Iter> = arrayvec::IntoIter<<IterItem<Iter> as HibitTreeTypes<'item>>::Data, N>;
82
83pub struct DataUnchecked<Iter>
84where
85    Iter: Iterator<Item: Ref<Type: HibitTree>>,
86{
87    iter: Iter,
88    hi_index: HierarchyIndex<
89        <IterItem<Iter> as HibitTree>::LevelMask,
90        <IterItem<Iter> as HibitTree>::LevelCount,
91    >,
92}
93impl<'item, Iter, T> Iterator for DataUnchecked<Iter>
94where
95    Iter: Iterator<Item = &'item T> + Clone,
96    T: HibitTree + 'item,
97{
98    type Item = <T as HibitTreeTypes<'item>>::Data;
99
100    #[inline]
101    fn next(&mut self) -> Option<Self::Item> {
102        self.iter.find_map(|array| array.data(&self.hi_index))
103    }
104
105    #[inline]
106    fn fold<B, F>(self, mut init: B, mut f: F) -> B
107    where
108        Self: Sized,
109        F: FnMut(B, Self::Item) -> B,
110    {
111        for array in self.iter {
112            if let Some(item) = array.data(&self.hi_index) {
113                init = f(init, item)    
114            }
115        }
116        init
117    }
118    
119    #[inline]
120    fn size_hint(&self) -> (usize, Option<usize>) {
121        (0, self.iter.size_hint().1)
122    }
123}
124
125pub struct DataOrDefault<Iter>
126where
127    Iter: Iterator<Item: Ref<Type: HibitTree>>,
128{
129    iter: Iter,
130    hi_index: HierarchyIndex<
131        <IterItem<Iter> as HibitTree>::LevelMask,
132        <IterItem<Iter> as HibitTree>::LevelCount,
133    >,
134}
135impl<'item, Iter, T> Iterator for DataOrDefault<Iter>
136where
137    Iter: Iterator<Item = &'item T> + Clone,
138    T: HibitTree + 'item,
139{
140    type Item = <T as HibitTreeTypes<'item>>::DataOrDefault;
141
142    #[inline]
143    fn next(&mut self) -> Option<Self::Item> {
144        self.iter.next().map(|array| unsafe{
145            array.data_or_default(&self.hi_index)
146        })
147    }
148
149    #[inline]
150    fn fold<B, F>(self, mut init: B, mut f: F) -> B
151    where
152        Self: Sized,
153        F: FnMut(B, Self::Item) -> B,
154    {
155        for array in self.iter {
156            let item = unsafe{array.data_or_default(&self.hi_index)};
157            init = f(init, item);    
158        }
159        init
160    }
161    
162    #[inline]
163    fn size_hint(&self) -> (usize, Option<usize>) {
164        self.iter.size_hint()
165    }
166}
167impl<'item, Iter, T> ExactSizeIterator for DataOrDefault<Iter>
168where
169    Iter: Iterator<Item = &'item T> + Clone,
170    T: HibitTree + 'item
171{}
172
173// --- CURSOR ---
174
/// Maximum number of source trees a [MultiUnion] can handle.
///
/// Per-tree cursors, per-level "non-empty" lists and the buffer used by the
/// checked `data` are fixed-capacity `ArrayVec`s of this size. Overflowing
/// them panics.
const N: usize = 32;

/// Index of a source tree's cursor inside `Cursor::cursors`.
type CursorIndex = u8;

// Cursor indices in `0..N`, and the cursor count itself
// (`cursors.len() as CursorIndex` in `select_level_node`), are stored as
// `CursorIndex`. Reject an `N` that would silently truncate.
const _: () = assert!(N <= CursorIndex::MAX as usize, "N must fit into CursorIndex");
177type CursorsItem<'item, Iter> = (<Iter as Iterator>::Item, IterItemCursor<'item, Iter>);
178
179impl<'this, 'src, 'item, Iter, D> HibitTreeCursorTypes<'this> for Cursor<'src, 'item, Iter, D>
180where
181    Iter: Iterator<Item: Ref<Type: HibitTree>> + Clone,
182    D: ConstBool,
183{
184    type Data = CursorData<'this, 'item, Iter, ConstFalse>;
185    type DataUnchecked = CursorData<'this, 'item, Iter, D>;
186    type DataOrDefault = CursorData<'this, 'item, Iter, ConstTrue>;
187}
188
189pub struct Cursor<'src, 'item, Iter, D>
190where
191    Iter: Iterator<Item: Ref<Type: HibitTree>> + Clone,
192    D: ConstBool
193{
194    cursors: ArrayVec<CursorsItem<'item, Iter>, N>,
195    
196    /// [ArrayVec<usize, N>; Array::LevelCount - 1]
197    /// 
198    /// Root level skipped.
199    lvls_non_empty_states: ArrayOf<
200        ArrayVec<CursorIndex, N>,
201        <<IterItem<Iter> as HibitTree>::LevelCount as ConstInteger>::Dec,
202    >,
203    
204    phantom_data: PhantomData<&'src MultiUnion<Iter, D>>
205}
206
207impl<'src, 'item, Iter, T, D> Cursor<'src, 'item, Iter, D>
208where
209    Iter: Iterator<Item = &'item T> + Clone,
210    T: HibitTree + 'item,
211    D: ConstBool
212{
213    #[inline]
214    unsafe fn make_cursor_data<Def: ConstBool>(&self, level_index: usize) 
215        -> CursorData<'_, 'item, Iter, Def> 
216    {
217        if <<<Self as HibitTreeCursor>::Tree as HibitTree>::LevelCount as ConstInteger>::VALUE == 1 {
218            todo!("TODO: compile-time special case for 1-level SparseHierarchy");
219        }
220        
221        let lvl_non_empty_states = self.lvls_non_empty_states.as_ref()
222                                   .last().unwrap_unchecked();
223        
224        CursorData {
225            lvl_non_empty_states: lvl_non_empty_states.iter(),
226            cursors: &self.cursors,
227            level_index,
228            phantom_data: PhantomData,
229        }        
230    }
231}
232
233impl<'src, 'item, Iter, T, D> HibitTreeCursor<'src> for Cursor<'src, 'item, Iter, D>
234where
235    Iter: Iterator<Item = &'item T> + Clone,
236    T: HibitTree + 'item,
237    D: ConstBool
238{
239    type Tree = MultiUnion<Iter, D>;
240
241    #[inline]
242    fn new(src: &'src Self::Tree) -> Self {
243        let states = ArrayVec::from_iter(
244            src.iter.clone()
245                .map(|array|{
246                    let state = HibitTreeCursor::new(array.borrow()); 
247                    (array, state)
248                })
249        );
250        
251        Self {
252            cursors: states,
253            lvls_non_empty_states: Array::from_fn(|_|ArrayVec::new()),
254            phantom_data: PhantomData,
255        }
256    }
257
258    #[inline]
259    unsafe fn select_level_node<N: ConstInteger>(&mut self, _: &'src Self::Tree, level_n: N, level_index: usize) 
260        -> <Self::Tree as HibitTree>::LevelMask 
261    {
262        let mut acc_mask = BitBlock::zero();
263        
264        if N::VALUE == 0 {
265            for (array, array_cursor) in self.cursors.iter_mut() {
266                let mask = array_cursor.select_level_node(array, level_n, level_index);
267                acc_mask |= mask;
268            }            
269            return acc_mask;
270        }
271        
272        // Work with pointers for `get_many`-like access. 
273        let lvls_non_empty_states = self.lvls_non_empty_states.as_mut().as_mut_ptr();
274        let lvl_non_empty_states  = &mut*lvls_non_empty_states.add(level_n.value()-1);
275        lvl_non_empty_states.clear();
276        
277        let len = self.cursors.len() as u8;
278        
279        let mut foreach = |i: CursorIndex| {
280            let (array, array_cursor) = self.cursors.get_unchecked_mut(i as usize);
281            let mask = array_cursor.select_level_node(array, level_n, level_index);
282            if !mask.is_zero() {
283                lvl_non_empty_states.push_unchecked(i);
284            }
285            acc_mask |= mask;            
286        };
287        
288        if N::VALUE == 1 {
289            // Prev level is root. Since we don't store root - 
290            // just iterate all states.
291            for i in 0..len { foreach(i) }    
292        } else {
293            let prev_lvl_non_empty_states = &*lvls_non_empty_states.add(level_n.value()-2);
294            for i in prev_lvl_non_empty_states { foreach(*i) }
295        }
296        
297        acc_mask
298    }
299
300    #[inline]
301    unsafe fn select_level_node_unchecked<N: ConstInteger>(&mut self, src: &'src Self::Tree, level_n: N, level_index: usize) 
302        -> <Self::Tree as HibitTree>::LevelMask 
303    {
304        // There is actually no unchecked version for union.
305        self.select_level_node(src, level_n, level_index)
306    }
307
308    #[inline]
309    unsafe fn data<'a>(&'a self, _: &'src Self::Tree, level_index: usize) 
310        -> Option<<Self as HibitTreeCursorTypes<'a>>::Data> 
311    {
312        if <<Self::Tree as HibitTree>::LevelCount as ConstInteger>::VALUE == 1 {
313            todo!("TODO: compile-time special case for 1-level SparseHierarchy");
314        }
315        
316        let lvl_non_empty_states = self.lvls_non_empty_states.as_ref()
317                                   .last().unwrap_unchecked();
318        if lvl_non_empty_states.is_empty(){
319            return None;
320        }
321        
322        Some(CursorData {
323            lvl_non_empty_states: lvl_non_empty_states.iter(),
324            cursors: &self.cursors,
325            level_index,
326            phantom_data: PhantomData,
327        })
328    }
329
330    #[inline]
331    unsafe fn data_unchecked<'a>(&'a self, _: &'src Self::Tree, level_index: usize) 
332        -> <Self as HibitTreeCursorTypes<'a>>::DataUnchecked
333    {
334        self.make_cursor_data(level_index)
335    }
336    
337    #[inline]
338    unsafe fn data_or_default<'a>(&'a self, _: &'src Self::Tree, level_index: usize) 
339        -> <Self as HibitTreeCursorTypes<'a>>::DataOrDefault 
340    {
341        self.make_cursor_data(level_index)
342    }    
343}
344
345// `D=true` will use [data_or_default] to return values. 
346// Otherwise, [data] will be used.
347pub struct CursorData<'cursor, 'item, I, D>
348where
349    I: Iterator<Item: Ref<Type: HibitTree>>
350{
351    lvl_non_empty_states: slice::Iter<'cursor, CursorIndex>,
352    cursors: &'cursor [CursorsItem<'item, I>],
353    level_index: usize,
354    phantom_data: PhantomData<D>
355}
356
357impl<'cursor, 'item, I, T, D> Iterator for CursorData<'cursor, 'item, I, D>
358where
359    I: Iterator<Item = &'item T> + Clone,
360    T: RegularHibitTree + 'item,
361    D: ConstBool
362{
363    type Item = <IterItemCursor<'item, I> as HibitTreeCursorTypes<'cursor>>::Data;
364
365    #[inline]
366    fn next(&mut self) -> Option<Self::Item> {
367        if D::VALUE {
368            self.lvl_non_empty_states.next().map(|&i| unsafe{
369                let (array, array_cursor) = self.cursors.get_unchecked(i as usize);
370                array_cursor.data_or_default(array, self.level_index)
371            })
372        } else {
373            self.lvl_non_empty_states
374                .find_map(|&i| unsafe {
375                    let (array, array_cursor) = self.cursors.get_unchecked(i as usize);
376                    if let Some(data) = array_cursor.data(array, self.level_index) {
377                        Some(data)
378                    } else {
379                        None
380                    }
381                })
382        }
383    }
384
385    #[inline]
386    fn fold<B, F>(self, mut init: B, mut f: F) -> B
387    where
388        Self: Sized,
389        F: FnMut(B, Self::Item) -> B,
390    {
391        let level_index = self.level_index;
392        for &i in self.lvl_non_empty_states {
393            let (array, array_cursor) = unsafe{ self.cursors.get_unchecked(i as usize) };
394            if D::VALUE {
395                let data = unsafe{ array_cursor.data_or_default(array, self.level_index) };
396                init = f(init, data);
397            } else {
398                if let Some(data) = unsafe{ array_cursor.data(array, level_index) } {
399                    init = f(init, data);
400                }
401            }
402        }
403        init
404    }
405
406    #[inline]
407    fn size_hint(&self) -> (usize, Option<usize>) {
408        let len = self.lvl_non_empty_states.len();
409        if D::VALUE{
410            (len, Some(len))
411        } else {
412            (0, Some(len))
413        }
414    }
415}
416
417impl<'cursor, 'item, I, T, D> ExactSizeIterator for CursorData<'cursor, 'item, I, D>
418where
419    I: Iterator<Item = &'item T> + Clone,
420    T: RegularHibitTree + 'item,
421    D: IsConstTrue
422{}
423
424impl<'item, Iter, T, D> LazyHibitTree for MultiUnion<Iter, D>
425where
426    Iter: Iterator<Item = &'item T> + Clone,
427    T: RegularHibitTree + 'item,
428    D: ConstBool
429{}
430
431impl<'item, 'this, Iter, T, D> MultiHibitTreeTypes<'this> for MultiUnion<Iter, D>
432where
433    Iter: Iterator<Item = &'item T> + Clone,
434    T: RegularHibitTree + 'item,
435    D: ConstBool
436{ 
437    type IterItem = HibitTreeData<'item, T>; 
438}
439
440impl<'item, Iter, T, D> MultiHibitTree for MultiUnion<Iter, D>
441where
442    Iter: Iterator<Item = &'item T> + Clone,
443    T: RegularHibitTree + 'item,
444    D: ConstBool
445{}
446
447impl<Iter, D> Borrowable for MultiUnion<Iter, D>{ type Borrowed = Self; }
448
449/// Union between multiple &[RegularHibitTree]s.
450/// 
451/// `iter` will be cloned and iterated multiple times.
452/// Pass something like [slice::Iter].
453#[inline]
454pub fn multi_union<Iter>(iter: Iter) 
455    -> MultiUnion<Iter>
456where
457    Iter: Iterator<Item: Ref<Type: RegularHibitTree>> + Clone,
458{
459    MultiUnion{ iter, phantom: Default::default() }
460}
461
462/// Same as [multi_union] but iterator will use [data_or_default].
463/// 
464/// This can lead to a faster code, since iterator does not have to 
465/// skip values during iteration.
466/// 
467/// Default values MAY appear in iterator output.
468/// 
469/// [data_or_default]: crate::HibitTree::data_or_default
470#[inline]
471pub fn multi_union_w_default<Iter>(iter: Iter) 
472    -> MultiUnion<Iter, ConstTrue>
473where
474    Iter: Iterator<Item: Ref<Type: RegularHibitTree<DefaultData: IsConstTrue>>> + Clone,
475{
476    MultiUnion{ iter, phantom: Default::default() }
477}
478
#[cfg(test)]
mod tests {
    use super::*;
    use itertools::assert_equal;
    use crate::hibit_tree::HibitTree;
    use crate::ReqDefault;
    use crate::config::_64bit;
    use crate::utils::LendingIterator;

    type Array = crate::tree::Tree<usize, _64bit<3>, ReqDefault>;

    #[test]
    fn multi_union_test() {
        let mut a1 = Array::default();
        let mut a2 = Array::default();
        let mut a3 = Array::default();

        a1.insert(10, 10);
        a1.insert(15, 15);
        a1.insert(200, 200);

        a2.insert(100, 100);
        a2.insert(15, 15);
        a2.insert(200, 200);

        a3.insert(300, 300);
        a3.insert(15, 15);

        let arrays = [a1, a2, a3];

        let union = multi_union(arrays.iter());

        // iter test
        let mut v = Vec::new();
        let mut iter = union.iter();
        // `LendingIterator`: a `for` loop is not available here.
        while let Some((_index, values)) = iter.next() {
            let values: Vec<&usize> = values.collect();
            v.push(values);
        }
        assert_equal(v, vec![
            vec![arrays[0].get(10).unwrap()],
            vec![
                arrays[0].get(15).unwrap(),
                arrays[1].get(15).unwrap(),
                arrays[2].get(15).unwrap(),
            ],
            vec![arrays[1].get(100).unwrap()],
            vec![
                arrays[0].get(200).unwrap(),
                arrays[1].get(200).unwrap(),
            ],
            vec![arrays[2].get(300).unwrap()],
        ]);

        // get test
        assert_equal(
            union.get(10).unwrap(),
            vec![arrays[0].get(10).unwrap()],
        );
        assert_equal(
            union.get(15).unwrap(),
            vec![arrays[0].get(15).unwrap(), arrays[1].get(15).unwrap(), arrays[2].get(15).unwrap()],
        );
        // 25 shares a terminal block with 10 and 15, but no tree contains it.
        assert!(union.get(25).is_none());

        // get_unchecked test
        assert_equal(unsafe { union.get_unchecked(10) }, union.get(10).unwrap());
        assert_equal(unsafe { union.get_unchecked(15) }, union.get(15).unwrap());

        // get_or_default test
        assert_equal(
            union.get_or_default(10),
            vec![arrays[0].get(10).unwrap(), &0, &0],
        );
    }

    #[test]
    fn multi_union_w_default_test() {
        let mut a1 = Array::default();
        let mut a2 = Array::default();
        let mut a3 = Array::default();

        // All indices must be in the same terminal block for this test.
        // Otherwise they will not appear as default values in the iterator output.
        a1.insert(1, 1);
        a1.insert(15, 15);
        a1.insert(20, 20);

        a2.insert(10, 10);
        a2.insert(15, 15);
        a2.insert(20, 20);

        a3.insert(30, 30);
        a3.insert(15, 15);

        let arrays = [a1, a2, a3];

        let union = multi_union_w_default(arrays.iter());

        // iter test
        let mut v = Vec::new();
        let mut iter = union.iter();
        while let Some((_index, values)) = iter.next() {
            let values: Vec<&usize> = values.collect();
            v.push(values);
        }
        assert_equal(v, vec![
            vec![arrays[0].get(1).unwrap(), &0, &0],
            vec![&0, arrays[1].get(10).unwrap(), &0],
            vec![
                arrays[0].get(15).unwrap(),
                arrays[1].get(15).unwrap(),
                arrays[2].get(15).unwrap(),
            ],
            vec![
                arrays[0].get(20).unwrap(),
                arrays[1].get(20).unwrap(),
                &0,
            ],
            vec![&0, &0, arrays[2].get(30).unwrap()],
        ]);
    }
}