algorithm/map/
roaring_bitmap.rs

1use std::collections::HashMap;
2use std::fmt::{Debug, Display};
3use std::ops::{BitAnd, BitOr, BitXor};
4
5use crate::BitMap;
6
7
/// Number of low bits of a value that are stored inside a container (the "tail").
const TAIL_BIT: usize = 16;
/// Number of distinct tails a single container can hold: `1 << TAIL_BIT` (65536).
/// Derived from `TAIL_BIT` so the two constants cannot drift apart.
const TAIL_NUM: usize = 1 << TAIL_BIT;
10
11#[derive(Clone)]
12enum TailContainer {
13    Array(Vec<u16>),
14    Bitmap(BitMap),
15}
16
17impl TailContainer {
18    fn new() -> Self {
19        Self::Array(vec![])
20    }
21
22    fn try_move(&mut self) {
23        let val = match self {
24            TailContainer::Array(v) if v.len() >= 4096 => {
25                v.drain(..).collect::<Vec<_>>()
26            }
27            _ => {
28                return;
29            }
30
31        };
32        let mut map = BitMap::new(TAIL_NUM);
33        for v in val {
34            map.add(v as usize);
35        }
36        *self = TailContainer::Bitmap(map);
37    }
38
39    pub fn add(&mut self, val: u16) -> bool {
40        self.try_move();
41        match self {
42            TailContainer::Array(vec) => {
43                if let Err(s) = vec.binary_search(&val) {
44                    vec.insert(s, val);
45                    true
46                } else {
47                    false
48                }
49            },
50            TailContainer::Bitmap(hash) => hash.add(val as usize)
51        }
52    }
53
54    pub fn remove(&mut self, val: u16) -> bool {
55        match self {
56            TailContainer::Array(vec) => {
57                if let Ok(s) = vec.binary_search(&val) {
58                    vec.remove(s);
59                    true
60                } else {
61                    false
62                }
63            },
64            TailContainer::Bitmap(hash) => hash.remove(val as usize)
65        }
66    }
67
68    pub fn next(&self, val: u16) -> Option<u16> {
69        match self {
70            TailContainer::Array(vec) => {
71                match vec.binary_search(&val) {
72                    Ok(s) => { return Some(vec[s]) },
73                    Err(s) => {
74                        if s == vec.len() {
75                            return None;
76                        }
77                        return Some(vec[s]);
78                    }
79                }
80            },
81            TailContainer::Bitmap(hash) => {
82                for i in val..=65535u16 {
83                    if hash.contains(&(i as usize)) {
84                        return Some(i);
85                    }
86                }
87                return None;
88            }
89        }
90    }
91
92
93    pub fn next_back(&self, val: u16) -> Option<u16> {
94        match self {
95            TailContainer::Array(vec) => {
96                match vec.binary_search(&val) {
97                    Ok(s) => { return Some(vec[s]) },
98                    Err(s) => {
99                        if s == 0 {
100                            return None;
101                        }
102                        return Some(vec[s - 1]);
103                    }
104                }
105            },
106            TailContainer::Bitmap(hash) => {
107                for i in (0..=val).rev() {
108                    if hash.contains(&(i as usize)) {
109                        return Some(i);
110                    }
111                }
112                return None;
113            }
114        }
115    }
116
117    pub fn contains(&self, val: u16) -> bool {
118        match self {
119            TailContainer::Array(vec) => {
120                if let Ok(_) = vec.binary_search(&val) {
121                    true
122                } else {
123                    false
124                }
125            },
126            TailContainer::Bitmap(hash) => hash.contains(&(val as usize))
127        }
128    }
129}
130
131/// 位图类RoaringBitMap,根据访问的位看是否被占用
132/// 本质上是将大块的bitmap分成各个小块,其中每个小块在需要存储数据的时候才会存在
133/// 解决经典的是否被占用的问题,不会一次性分配大内存
134/// 头部以val / 65536做为索引键值, 尾部分为Array及HashSet结构
135/// 当元素个数小于4096时以有序array做为索引, 当>4096以HashSet做为存储
136///
137/// # Examples
138///
139/// ```
140/// use algorithm::RoaringBitMap;
141/// fn main() {
142///     let mut map = RoaringBitMap::new();
143///     map.add_many(&vec![1, 2, 3, 4, 10]);
144///     assert!(map.contains(&1));
145///     assert!(!map.contains(&5));
146///     assert!(map.contains(&10));
147///     map.add_range(7, 16);
148///     assert!(!map.contains(&6));
149///     assert!(map.contains(&7));
150///     assert!(map.contains(&16));
151///     assert!(!map.contains(&17));
152/// }
153/// ```
154pub struct RoaringBitMap {
155    map: HashMap<usize, TailContainer>,
156    len: usize,
157    max_key: usize,
158    min_key: usize,
159}
160
161impl RoaringBitMap {
162    pub fn new() -> Self {
163        Self {
164            map: HashMap::new(),
165            len: 0,
166            max_key: 0,
167            min_key: 0,
168        }
169    }
170
171    pub fn len(&self) -> usize {
172        self.len
173    }
174
175    pub fn is_empty(&self) -> bool { self.len == 0 }
176
177    pub fn clear(&mut self) {
178        self.map.clear();
179        self.len = 0;
180    }
181
182    /// 添加新的元素
183    /// # Examples
184    ///
185    /// ```
186    /// use algorithm::RoaringBitMap;
187    /// fn main() {
188    ///     let mut map = RoaringBitMap::new();
189    ///     map.add(1);
190    ///     assert!(map.contains(&1));
191    ///     assert!(map.len() == 1);
192    /// }
193    /// ```
194    pub fn add(&mut self, val: usize) {
195        let head = val >> 16;
196        let tail = (val % TAIL_NUM) as u16;
197        if self.map.entry(head).or_insert(TailContainer::new()).add(tail) {
198            self.len += 1;
199            self.min_key = self.min_key.min(val);
200            self.max_key = self.max_key.max(val);
201        }
202    }
203
204    /// 添加许多新的元素
205    /// # Examples
206    ///
207    /// ```
208    /// use algorithm::RoaringBitMap;
209    /// fn main() {
210    ///     let mut map = RoaringBitMap::new();
211    ///     map.add_many(&vec![1, 2, 3, 4, 10]);
212    ///     assert!(map.contains(&1));
213    ///     assert!(map.contains(&10));
214    ///     assert!(map.len() == 5);
215    /// }
216    /// ```
217    pub fn add_many(&mut self, val: &[usize]) {
218        for v in val {
219            self.add(*v);
220        }
221    }
222
223    /// 添加范围内的元素(包含头与结果),批量添加增加效率
224    /// # Examples
225    ///
226    /// ```
227    /// use algorithm::RoaringBitMap;
228    /// fn main() {
229    ///     let mut map = RoaringBitMap::new();
230    ///     map.add_range(7, 16);
231    ///     assert!(!map.contains(&6));
232    ///     assert!(map.contains(&7));
233    ///     assert!(map.contains(&16));
234    ///     assert!(!map.contains(&17));
235    ///     assert!(map.len() == 10);
236    /// }
237    /// ```
238    pub fn add_range(&mut self, start: usize, end: usize) {
239        for i in start..=end {
240            self.add(i);
241        }
242    }
243
244    /// 删除元素
245    /// # Examples
246    ///
247    /// ```
248    /// use algorithm::RoaringBitMap;
249    /// fn main() {
250    ///     let mut map = RoaringBitMap::new();
251    ///     map.add_range(7, 16);
252    ///     assert!(map.len() == 10);
253    ///     assert!(map.contains(&7));
254    ///     assert!(map.remove(7));
255    ///     assert!(!map.contains(&7));
256    ///     assert!(map.len() == 9);
257    /// }
258    /// ```
259    pub fn remove(&mut self, val: usize) -> bool {
260        let head = val >> 16;
261        let tail = (val % TAIL_NUM) as u16;
262        if let Some(map) = self.map.get_mut(&head) {
263            if map.remove(tail) {
264                self.len -= 1;
265                return true;
266            }
267        }
268        false
269    }
270
271    /// 删除列表中元素
272    /// # Examples
273    ///
274    /// ```
275    /// use algorithm::RoaringBitMap;
276    /// fn main() {
277    ///     let mut map = RoaringBitMap::new();
278    ///     map.add_range(7, 16);
279    ///     assert!(map.len() == 10);
280    ///     assert!(map.contains(&7));
281    ///     assert!(map.remove(7));
282    ///     assert!(!map.contains(&7));
283    ///     assert!(map.len() == 9);
284    /// }
285    /// ```
286    pub fn remove_many(&mut self, val: &[usize]) {
287        for v in val {
288            self.remove(*v);
289        }
290    }
291
292    /// 删除范围元素(包含头与尾)
293    /// # Examples
294    ///
295    /// ```
296    /// use algorithm::RoaringBitMap;
297    /// fn main() {
298    ///     let mut map = RoaringBitMap::new();
299    ///     map.add_range(7, 16);
300    ///     assert!(map.len() == 10);
301    ///     map.remove_range(7, 15);
302    ///     assert!(map.len() == 1);
303    ///     assert!(map.contains(&16));
304    /// }
305    /// ```
306    pub fn remove_range(&mut self, start: usize, end: usize) {
307        for i in start..=end {
308            self.remove(i);
309        }
310    }
311
312    /// 醒看是否包含
313    /// # Examples
314    ///
315    /// ```
316    /// use algorithm::RoaringBitMap;
317    /// fn main() {
318    ///     let mut map = RoaringBitMap::new();
319    ///     map.add(7);
320    ///     assert!(map.contains(&7));
321    /// }
322    /// ```
323    pub fn contains(&self, val: &usize) -> bool {
324        let head = val >> 16;
325        let tail = (val % TAIL_NUM) as u16;
326        if let Some(map) = self.map.get(&head) {
327            map.contains(tail)
328        } else {
329            false
330        }
331    }
332
333    /// 迭代器,通过遍历进行循环,如果位图的容量非常大,可能效率相当低
334    /// # Examples
335    ///
336    /// ```
337    /// use algorithm::RoaringBitMap;
338    /// fn main() {
339    ///     let mut map = RoaringBitMap::new();
340    ///     map.add(7);
341    ///     map.add_range(9, 12);
342    ///     map.add_many(&vec![20, 100, 300]);
343    ///     assert!(map.iter().collect::<Vec<_>>() == vec![7, 9, 10, 11, 12, 20, 100, 300]);
344    /// }
345    /// ```
346    pub fn iter(&self) -> Iter<'_> {
347        Iter {
348            base: self,
349            len: self.len,
350            min_val: self.min_key,
351            max_val: self.max_key,
352        }
353    }
354
355
356    /// 是否保留,通过遍历进行循环,如果位图的容量非常大,可能效率相当低
357    /// # Examples
358    ///
359    /// ```
360    /// use algorithm::RoaringBitMap;
361    /// fn main() {
362    ///     let mut map = RoaringBitMap::new();
363    ///     map.add_range(9, 16);
364    ///     map.retain(|v| v % 2 == 0);
365    ///     assert!(map.iter().collect::<Vec<_>>() == vec![10, 12, 14, 16]);
366    /// }
367    /// ```
368    pub fn retain<F>(&mut self, mut f: F)
369        where
370            F: FnMut(&usize) -> bool,
371    {
372        let mut oper = self.len;
373        for i in self.min_key..=self.max_key {
374            if oper == 0 {
375                break;
376            }
377            if self.contains(&i) {
378                oper -= 1;
379                if !f(&i) {
380                    self.remove(i);
381                }
382            }
383        }
384    }
385
386    /// 是否为子位图
387    /// # Examples
388    ///
389    /// ```
390    /// use algorithm::RoaringBitMap;
391    /// fn main() {
392    ///     let mut map = RoaringBitMap::new();
393    ///     map.add_range(9, 16);
394    ///     let mut sub_map = RoaringBitMap::new();
395    ///     sub_map.add_range(9, 12);
396    ///     assert!(map.contains_sub(&sub_map));
397    /// }
398    /// ```
399    pub fn contains_sub(&self, other: &RoaringBitMap) -> bool {
400        other.iter().all(|k| self.contains(&k))
401    }
402
403    /// 取两个位图间的交集
404    /// # Examples
405    ///
406    /// ```
407    /// use algorithm::RoaringBitMap;
408    /// fn main() {
409    ///     let mut map = RoaringBitMap::new();
410    ///     map.add_range(9, 16);
411    ///     let mut sub_map = RoaringBitMap::new();
412    ///     sub_map.add_range(7, 12);
413    ///     let map = map.intersect(&sub_map);
414    ///     assert!(map.iter().collect::<Vec<_>>() == vec![9, 10, 11, 12]);
415    /// }
416    /// ```
417    pub fn intersect(&self, other: &RoaringBitMap) -> RoaringBitMap {
418        let mut map = RoaringBitMap::new();
419        let mut from = self.min_key.max(other.min_key);
420        let end = self.max_key.min(other.max_key);
421        while from <= end {
422            let head_from = from >> TAIL_BIT;
423            let next_from = (head_from + 1) * TAIL_NUM;
424            if self.map.contains_key(&head_from) && other.map.contains_key(&head_from) {
425                for i in from..next_from.min(end+1) {
426                    if self.contains(&i) && other.contains(&i) {
427                        map.add(i);
428                    }
429                }
430            }
431            from = next_from;
432        }
433        map
434    }
435
436    /// 取两个位图间的并集
437    /// # Examples
438    ///
439    /// ```
440    /// use algorithm::RoaringBitMap;
441    /// fn main() {
442    ///     let mut map = RoaringBitMap::new();
443    ///     map.add_range(9, 16);
444    ///     let mut sub_map = RoaringBitMap::new();
445    ///     sub_map.add_range(7, 12);
446    ///     let map = map.union(&sub_map);
447    ///     assert!(map.iter().collect::<Vec<_>>() == vec![7, 8, 9, 10, 11, 12, 13, 14, 15, 16]);
448    /// }
449    /// ```
450    pub fn union(&self, other: &RoaringBitMap) -> RoaringBitMap {
451        let mut map = RoaringBitMap::new();
452        let mut from = self.min_key.min(other.min_key);
453        let end = self.max_key.max(other.max_key);
454        while from <= end {
455            let head_from = from >> TAIL_BIT;
456            let next_from = (head_from + 1) * TAIL_NUM;
457            if self.map.contains_key(&head_from) || other.map.contains_key(&head_from) {
458                for i in from..next_from.min(end+1) {
459                    if self.contains(&i) || other.contains(&i) {
460                        map.add(i);
461                    }
462                }
463            }
464            from = next_from;
465        }
466        map
467    }
468
469    /// 取两个位图间的异或并集
470    /// # Examples
471    ///
472    /// ```
473    /// use algorithm::RoaringBitMap;
474    /// fn main() {
475    ///     let mut map = RoaringBitMap::new();
476    ///     map.add_range(9, 16);
477    ///     let mut sub_map = RoaringBitMap::new();
478    ///     sub_map.add_range(7, 12);
479    ///     let map = map.union_xor(&sub_map);
480    ///     assert!(map.iter().collect::<Vec<_>>() == vec![7, 8, 13, 14, 15, 16]);
481    /// }
482    /// ```
483    pub fn union_xor(&self, other: &RoaringBitMap) -> RoaringBitMap {
484        let mut map = RoaringBitMap::new();
485        let mut from = self.min_key.min(other.min_key);
486        let end = self.max_key.max(other.max_key);
487        while from <= end {
488            let head_from = from >> TAIL_BIT;
489            let next_from = (head_from + 1) * TAIL_NUM;
490            if self.map.contains_key(&head_from) || other.map.contains_key(&head_from) {
491                for i in from..next_from.min(end+1) {
492                    if self.contains(&i) ^ other.contains(&i) {
493                        map.add(i);
494                    }
495                }
496            }
497            from = next_from;
498        }
499        map
500    }
501}
502
503impl BitAnd for &RoaringBitMap {
504    type Output=RoaringBitMap;
505    fn bitand(self, rhs: Self) -> Self::Output {
506        self.intersect(rhs)
507    }
508}
509
510impl BitOr for &RoaringBitMap {
511    type Output=RoaringBitMap;
512    fn bitor(self, rhs: Self) -> Self::Output {
513        self.union(rhs)
514    }
515}
516
517impl BitXor for &RoaringBitMap {
518    type Output=RoaringBitMap;
519
520    fn bitxor(self, rhs: Self) -> Self::Output {
521        self.union_xor(rhs)
522    }
523}
524
525impl Clone for RoaringBitMap {
526    fn clone(&self) -> Self {
527        Self {
528            map: self.map.clone(),
529            len: self.len,
530            max_key: self.max_key,
531            min_key: self.min_key,
532        }
533    }
534}
535
536
537impl FromIterator<usize> for RoaringBitMap {
538    fn from_iter<T: IntoIterator<Item=usize>>(iter: T) -> RoaringBitMap {
539        let vec = iter.into_iter().collect::<Vec<_>>();
540        let mut cap = 1024;
541        for v in &vec {
542            cap = cap.max(*v);
543        }
544        let mut map = RoaringBitMap::new();
545        map.extend(vec);
546        map
547    }
548}
549
550impl PartialEq for RoaringBitMap {
551    fn eq(&self, other: &Self) -> bool {
552        if self.len() != other.len() {
553            return false;
554        }
555        self.iter().all(|k| other.contains(&k))
556    }
557}
558
559impl Eq for RoaringBitMap {}
560
561impl Extend<usize> for RoaringBitMap {
562    fn extend<T: IntoIterator<Item=usize>>(&mut self, iter: T) {
563        let iter = iter.into_iter();
564        for v in iter {
565            self.add(v);
566        }
567    }
568}
569
570pub struct Iter<'a> {
571    base: &'a RoaringBitMap,
572    len: usize,
573    min_val: usize,
574    max_val: usize,
575}
576
577impl<'a> Iterator for Iter<'a> {
578    type Item = usize;
579
580    fn next(&mut self) -> Option<Self::Item> {
581        if self.len == 0 {
582            return None;
583        }
584
585        while self.min_val <= self.base.max_key {
586            let head = self.min_val >> 16;
587            if !self.base.map.contains_key(&head) {
588                self.min_val = (head + 1) * TAIL_NUM;
589                continue;
590            }
591            let tail = (self.min_val % TAIL_NUM) as u16;
592            let container = self.base.map.get(&head).expect("ok");
593            if let Some(i) = container.next(tail) {
594                self.min_val = head * TAIL_NUM + i as usize + 1;
595                self.len -= 1;
596                return Some(head * TAIL_NUM + i as usize);
597            } else {
598                self.min_val = (head + 1) * TAIL_NUM;
599                continue;
600            }
601        }
602        unreachable!()
603    }
604
605    fn size_hint(&self) -> (usize, Option<usize>) {
606        (self.len, Some(self.len))
607    }
608}
609
610impl<'a> DoubleEndedIterator for Iter<'a> {
611    fn next_back(&mut self) -> Option<Self::Item> {
612        if self.len == 0 {
613            return None;
614        }
615
616        loop {
617            let head = self.max_val >> 16;
618            if !self.base.map.contains_key(&head) {
619                self.max_val = (head * TAIL_NUM).saturating_sub(1);
620                continue;
621            }
622            let tail = (self.max_val % TAIL_NUM) as u16;
623            let container = self.base.map.get(&head).expect("ok");
624            if let Some(i) = container.next_back(tail) {
625                self.max_val = (head * TAIL_NUM + i as usize).saturating_sub(1);
626                self.len -= 1;
627                return Some(head * TAIL_NUM + i as usize);
628            } else {
629                self.max_val = (head * TAIL_NUM).saturating_sub(1);
630                continue;
631            }
632        }
633    }
634}
635
636impl Display for RoaringBitMap {
637    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
638        f.write_fmt(format_args!("len:{}-val:{{", self.len))?;
639        let mut iter = self.iter();
640        if let Some(v) = iter.next() {
641            f.write_str(&v.to_string())?;
642        }
643        let mut sum = 1;
644        while let Some(v) = iter.next() {
645            f.write_fmt(format_args!(",{}", v))?;
646            sum += 1;
647            if sum > 0x100000 {
648                break;
649            }
650        }
651        f.write_str("}")
652    }
653}
654
655impl Debug for RoaringBitMap {
656    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
657        f.write_str(&format!("{}", self))
658    }
659}
660
661
662
#[cfg(test)]
mod tests {
    use super::RoaringBitMap;

    /// Sample set spanning two containers: three small values and one far beyond 2^16.
    fn sample() -> RoaringBitMap {
        let mut bitmap = RoaringBitMap::new();
        bitmap.add_many(&[1, 3, 9, 10240000111]);
        bitmap
    }

    #[test]
    fn test_display() {
        let rendered = format!("{}", sample());
        assert_eq!(rendered, "len:4-val:{1,3,9,10240000111}".to_string());
    }

    #[test]
    fn test_nextback() {
        let reversed: Vec<usize> = sample().iter().rev().collect();
        assert_eq!(reversed, vec![10240000111, 9, 3, 1]);
    }
}