aoc_companion/
bitset.rs

use std::{fmt::Display, iter::FusedIterator, mem};

/// A growable, densely packed sequence of bits.
///
/// Bit `i` is stored in bucket `data[i / BUCKET_SIZE]` at bit position `i % BUCKET_SIZE`
/// (least-significant bit first). `len` counts the logical bits. Storage is allocated a whole
/// bucket at a time, so `data` may have room for more bit positions than `len`.
///
/// Equality and hashing are derived, so they compare the raw buckets as well as `len`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct BitSet {
    /// Packed bucket storage.
    data: Vec<usize>,
    /// Number of logical bits in the set.
    len: usize,
}

9impl BitSet {
10    const BUCKET_SIZE: usize = mem::size_of::<usize>() * 8; // in bits
11
12    #[must_use]
13    pub fn new() -> Self {
14        Self {
15            data: Vec::new(),
16            len: 0,
17        }
18    }
19
20    #[must_use]
21    pub fn with_capacity(cap: usize) -> Self {
22        Self {
23            data: Vec::with_capacity(cap / Self::BUCKET_SIZE),
24            len: 0,
25        }
26    }
27
28    #[must_use]
29    pub fn all_false(cap: usize) -> Self {
30        Self {
31            data: vec![0; cap / Self::BUCKET_SIZE + 1],
32            len: cap,
33        }
34    }
35
36    #[must_use]
37    pub fn all_true(cap: usize) -> Self {
38        Self {
39            data: vec![usize::MAX; cap / Self::BUCKET_SIZE],
40            len: cap,
41        }
42    }
43
44    pub fn insert(&mut self, value: bool) {
45        if self.len % Self::BUCKET_SIZE == 0 {
46            self.data.push(0);
47        }
48
49        if value {
50            self.data[self.len / Self::BUCKET_SIZE] |= 1 << (self.len % Self::BUCKET_SIZE);
51        }
52
53        self.len += 1;
54    }
55
56    pub fn toggle(&mut self, index: usize) {
57        self.data[index / Self::BUCKET_SIZE] ^= 1 << (index % Self::BUCKET_SIZE);
58    }
59
60    #[must_use]
61    pub fn get(&self, index: usize) -> bool {
62        self.data[index / Self::BUCKET_SIZE] & (1 << (index % Self::BUCKET_SIZE)) != 0
63    }
64
65    #[must_use]
66    pub fn len(&self) -> usize {
67        self.len
68    }
69
70    #[must_use]
71    pub fn is_empty(&self) -> bool {
72        self.len == 0
73    }
74
75    #[must_use]
76    pub fn iter(&self) -> BitSetIter {
77        BitSetIter {
78            bitset: self,
79            idx: 0,
80        }
81    }
82
83    #[must_use]
84    pub fn iter_true(&self) -> BitSetTrueIter {
85        BitSetTrueIter {
86            bitset: self,
87            idx: 0,
88        }
89    }
90}
91
92impl Default for BitSet {
93    fn default() -> Self {
94        Self::new()
95    }
96}
97
98pub struct BitSetIter<'a> {
99    bitset: &'a BitSet,
100    idx: usize,
101}
102
103impl<'a> Iterator for BitSetIter<'a> {
104    type Item = bool;
105
106    fn next(&mut self) -> Option<Self::Item> {
107        if self.idx >= self.bitset.len {
108            return None;
109        }
110
111        let result = self.bitset.get(self.idx);
112        self.idx += 1;
113        Some(result)
114    }
115}
116
117impl Display for BitSet {
118    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119        for i in 0..self.len {
120            if self.get(i) {
121                write!(f, "1")?;
122            } else {
123                write!(f, "0")?;
124            }
125        }
126
127        Ok(())
128    }
129}
130
131pub struct BitSetTrueIter<'a> {
132    bitset: &'a BitSet,
133    idx: usize,
134}
135
136impl<'a> Iterator for BitSetTrueIter<'a> {
137    type Item = usize;
138
139    fn next(&mut self) -> Option<Self::Item> {
140        while self.idx < self.bitset.len {
141            if self.bitset.get(self.idx) {
142                let result = self.idx;
143                self.idx += 1;
144                return Some(result);
145            }
146
147            self.idx += 1;
148        }
149
150        None
151    }
152}
153
#[cfg(test)]
mod test {
    use super::*;

    /// Length used by the round-trip tests; larger than one 64-bit bucket so the bucket
    /// boundary is exercised.
    const N: usize = 100;

    #[test]
    fn test_bitset() {
        let mut bs = BitSet::all_false(N);
        assert_eq!(bs.len(), N);
        assert!(bs.iter().all(|b| !b));

        // Set every bit, checking each was clear beforehand.
        for i in 0..N {
            assert!(!bs.get(i));
            bs.toggle(i);
        }
        assert!(bs.iter().all(|b| b));

        // Clear every bit again, checking each was set.
        for i in 0..N {
            assert!(bs.get(i));
            bs.toggle(i);
        }
        assert!(bs.iter().all(|b| !b));
        assert_eq!(bs.iter().count(), N);
    }

    #[test]
    fn test_bitset_true() {
        let mut bs = BitSet::all_false(N);
        assert!(bs.iter_true().next().is_none());

        for i in 0..N {
            assert!(!bs.get(i));
            bs.toggle(i);
        }
        assert!(bs.iter_true().eq(0..N));

        for i in 0..N {
            assert!(bs.get(i));
            bs.toggle(i);
        }
        assert!(bs.iter_true().next().is_none());
    }

    #[test]
    fn test_bitset_true_sparse() {
        let mut bs = BitSet::all_false(N);
        for i in [0, 5, 63, 64, 99] {
            bs.toggle(i);
        }
        assert_eq!(bs.iter_true().collect::<Vec<_>>(), vec![0, 5, 63, 64, 99]);
    }

    #[test]
    fn test_insert_and_display() {
        let mut bs = BitSet::new();
        assert!(bs.is_empty());
        for b in [true, false, true, true] {
            bs.insert(b);
        }
        assert_eq!(bs.len(), 4);
        assert_eq!(bs.to_string(), "1011");
    }
}