// competitive_programming_rs/data_structure/bitset.rs

/// A growable bitset that packs 63 bits into each `u64` word.
pub mod bitset {
    /// Number of bits stored in each `u64` word; the top bit of every word is kept at zero.
    const ONE_VALUE_LENGTH: usize = 63;
    /// Mask of the usable bits of a word. Every stored word must satisfy `word <= MAXIMUM`.
    const MAXIMUM: u64 = (1u64 << ONE_VALUE_LENGTH as u64) - 1;

    /// Maps a global bit index to `(word index, bit index inside that word)`.
    pub fn get_bit_position(index: usize) -> (usize, usize) {
        let data_index = index / ONE_VALUE_LENGTH;
        let bit_index = index % ONE_VALUE_LENGTH;
        (data_index, bit_index)
    }

    /// A set of bits stored 63 per `u64` word, lowest index in the lowest bit of `data[0]`.
    ///
    /// Equality is structural: it compares the word vectors, so two sets holding the same bits
    /// but a different number of words compare unequal.
    #[derive(PartialEq, Clone, Debug)]
    pub struct BitSet {
        data: Vec<u64>,
    }

    impl std::ops::BitOrAssign for BitSet {
        /// In-place union. `self` is first grown (with zero words) to at least `rhs`'s length.
        ///
        /// # Panics
        /// If a combined word of either operand exceeds `MAXIMUM` (broken invariant).
        fn bitor_assign(&mut self, rhs: Self) {
            if self.data.len() < rhs.data.len() {
                self.data.resize(rhs.data.len(), 0);
            }
            // After the resize `self` is at least as long as `rhs`, so `zip` visits every word of `rhs`.
            for (lhs_word, &rhs_word) in self.data.iter_mut().zip(rhs.data.iter()) {
                assert!(*lhs_word <= MAXIMUM);
                assert!(rhs_word <= MAXIMUM);
                *lhs_word |= rhs_word;
            }
        }
    }

    impl std::ops::Shl<usize> for BitSet {
        type Output = Self;
        /// Moves every bit `rhs` positions toward higher indices; see [`BitSet::shift_left`].
        fn shl(self, rhs: usize) -> Self {
            self.shift_left(rhs)
        }
    }

    impl BitSet {
        /// Creates a bitset with all bits cleared and enough words to address indices `0..n`.
        pub fn new(n: usize) -> Self {
            let size = (n + ONE_VALUE_LENGTH - 1) / ONE_VALUE_LENGTH;
            BitSet {
                data: vec![0; size],
            }
        }

        /// Creates a bitset whose bit `i` is set exactly when bit `i` of `value` is set.
        ///
        /// A word only holds 63 bits, so bit 63 of `value` (if set) is placed in a second word;
        /// this keeps every stored word within `MAXIMUM`.
        pub fn new_from(value: u64) -> Self {
            let low = value & MAXIMUM;
            let high = value >> ONE_VALUE_LENGTH;
            let data = if high == 0 { vec![low] } else { vec![low, high] };
            BitSet { data }
        }

        /// Sets bit `index` to `value`.
        ///
        /// # Panics
        /// If `index` lies beyond the allocated words.
        pub fn set(&mut self, index: usize, value: bool) {
            let (data_index, bit_index) = get_bit_position(index);
            assert!(self.data.len() > data_index);
            let mask = 1u64 << bit_index;
            if value {
                self.data[data_index] |= mask;
            } else {
                // Clearing through `MAXIMUM ^ mask` also keeps the unused top bit at zero.
                self.data[data_index] &= MAXIMUM ^ mask;
            }
        }

        /// Returns whether bit `index` is set. Reading never mutates, so this borrows `&self`.
        ///
        /// # Panics
        /// If `index` lies beyond the allocated words.
        pub fn get(&self, index: usize) -> bool {
            let (data_index, bit_index) = get_bit_position(index);
            assert!(self.data.len() > data_index);
            self.data[data_index] & (1u64 << bit_index) != 0
        }

        /// Returns a new bitset with every bit moved `shift` positions toward higher indices.
        ///
        /// Whole-word shifts become leading zero words. The remaining sub-word shift carries the
        /// top `shift % 63` bits of each word into the next word. One extra word is appended only
        /// when the carry out of the last word is non-zero.
        ///
        /// # Panics
        /// If `self` contains a word above `MAXIMUM` that makes a shifted word overflow the
        /// 63-bit range (broken invariant).
        pub fn shift_left(&self, shift: usize) -> Self {
            let prefix_empty_count = shift / ONE_VALUE_LENGTH;
            let shift_count = (shift % ONE_VALUE_LENGTH) as u64;
            // Low bits of each word that stay in that word; always in 1..=63, so shifts are in range.
            let room = ONE_VALUE_LENGTH as u64 - shift_count;

            let mut next_data = Vec::with_capacity(prefix_empty_count + self.data.len() + 1);
            next_data.resize(prefix_empty_count, 0);

            let mut from_previous = 0;
            for &data in self.data.iter() {
                let overflow = (data >> room) << room;
                let rest = data - overflow;
                let value = (rest << shift_count) + from_previous;
                assert!(value <= MAXIMUM);
                next_data.push(value);
                from_previous = overflow >> room;
            }
            if from_previous > 0 {
                next_data.push(from_previous);
            }
            BitSet { data: next_data }
        }
    }
}
95
#[cfg(test)]
mod test {
    use super::bitset::*;

    /// Every bit written with `set` reads back through `get`; untouched bits stay cleared.
    #[test]
    fn test_set_bit() {
        let n = 10;
        let pattern = 717;
        let mut bits = BitSet::new(n);
        (0..n)
            .filter(|&i| pattern & (1 << i) != 0)
            .for_each(|i| bits.set(i, true));

        for i in 0..n {
            let expected = pattern & (1 << i) != 0;
            assert_eq!(bits.get(i), expected);
        }
    }

    /// `|=` on two single-word sets matches `|` on the underlying integers.
    #[test]
    fn test_bitset_or() {
        let left_value = 717;
        let right_value = 127;
        let mut merged = BitSet::new_from(left_value);
        merged |= BitSet::new_from(right_value);
        let expected_value = left_value | right_value;

        for i in 0..50usize {
            assert_eq!(merged.get(i), expected_value & (1u64 << i as u64) != 0);
        }
    }

    /// One shift by `a + b` equals an integer shift by `a` followed by a bitset shift by `b`.
    #[test]
    fn test_bitset_shift_left() {
        let seed = 717;
        let (first, second) = (30, 40);
        let shifted_once = BitSet::new_from(seed) << (first + second);
        let shifted_twice = BitSet::new_from(seed << first as u64) << second;
        assert_eq!(shifted_once, shifted_twice);
    }
}
149}