1use crate::{Decode, Encode, EncodedSize, Error, Result, VarInt};
2
/// A growable set of bits backed by a vector of 64-bit words.
///
/// Bit `i` lives in word `i / 64` at bit position `i % 64`, counting from the
/// least-significant bit. Words are stored as `i64` so a word vector can be
/// wrapped with [`BitSet::from_longs`] without conversion.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BitSet {
    data: Vec<i64>,
}

impl BitSet {
    /// Creates a bit set with no backing words.
    pub fn new() -> Self {
        Self::from_longs(Vec::new())
    }

    /// Wraps an existing word vector as-is (no copying, no trimming).
    pub fn from_longs(data: Vec<i64>) -> Self {
        Self { data }
    }

    /// Returns the number of addressable bits: the word count times 64.
    ///
    /// This is a capacity, not a population count; cleared bits are included.
    pub fn len(&self) -> usize {
        64 * self.data.len()
    }

    /// Returns `true` when no backing words are stored.
    ///
    /// Clearing bits never shrinks the storage, so a set whose bits were all
    /// cleared through [`BitSet::set`] still reports `false` here.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Returns whether bit `index` is set.
    ///
    /// Indices beyond the stored words read as `false`.
    pub fn get(&self, index: usize) -> bool {
        let (word, bit) = (index / 64, index % 64);
        match self.data.get(word) {
            Some(&w) => w & (1i64 << bit) != 0,
            None => false,
        }
    }

    /// Sets bit `index` to `value`.
    ///
    /// Setting a bit past the current storage grows it with zero words.
    /// Clearing a bit past the current storage is a no-op and does not allocate.
    pub fn set(&mut self, index: usize, value: bool) {
        let word = index / 64;
        let mask = 1i64 << (index % 64);

        if !value {
            // Bits outside the storage are already clear; nothing to do.
            if let Some(slot) = self.data.get_mut(word) {
                *slot &= !mask;
            }
            return;
        }

        if self.data.len() <= word {
            self.data.resize(word + 1, 0);
        }
        self.data[word] |= mask;
    }
}
76
77impl Default for BitSet {
78 fn default() -> Self {
79 Self::new()
80 }
81}
82
83impl Encode for BitSet {
89 fn encode(&self, buf: &mut Vec<u8>) -> Result<()> {
91 VarInt(self.data.len() as i32).encode(buf)?;
92 for &word in &self.data {
93 word.encode(buf)?;
94 }
95 Ok(())
96 }
97}
98
99impl Decode for BitSet {
105 fn decode(buf: &mut &[u8]) -> Result<Self> {
110 let len = VarInt::decode(buf)?.0;
111 if len < 0 {
112 return Err(Error::InvalidData(format!("negative BitSet length: {len}")));
113 }
114 let len = len as usize;
115 let mut data = Vec::with_capacity(len);
116 for _ in 0..len {
117 data.push(i64::decode(buf)?);
118 }
119 Ok(Self { data })
120 }
121}
122
123impl EncodedSize for BitSet {
128 fn encoded_size(&self) -> usize {
130 VarInt(self.data.len() as i32).encoded_size() + self.data.len() * 8
131 }
132}
133
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `bs`, checks the byte count against `encoded_size`, then
    /// decodes the bytes and checks that all input was consumed and the value
    /// came back unchanged.
    fn roundtrip(bs: &BitSet) {
        let expected_len = bs.encoded_size();
        let mut encoded = Vec::with_capacity(expected_len);
        bs.encode(&mut encoded).unwrap();
        assert_eq!(encoded.len(), expected_len);

        let mut rest: &[u8] = &encoded;
        let back = BitSet::decode(&mut rest).unwrap();
        assert!(rest.is_empty());
        assert_eq!(back, *bs);
    }

    // ---- construction ----------------------------------------------------

    #[test]
    fn new_is_empty() {
        let bs = BitSet::new();
        assert!(bs.is_empty());
        assert_eq!(bs.len(), 0);
    }

    #[test]
    fn default_is_empty() {
        assert!(BitSet::default().is_empty());
    }

    #[test]
    fn from_longs() {
        let bs = BitSet::from_longs(vec![0xFF, 0x00]);
        assert!(!bs.is_empty());
        assert_eq!(bs.len(), 128);
    }

    // ---- bit access ------------------------------------------------------

    #[test]
    fn get_out_of_range() {
        let bs = BitSet::new();
        for index in [0usize, 1000] {
            assert!(!bs.get(index), "bit {index}");
        }
    }

    #[test]
    fn set_and_get() {
        let mut bs = BitSet::new();
        bs.set(0, true);
        assert!(bs.get(0));
        assert!(!bs.get(1));
    }

    #[test]
    fn set_high_bit() {
        let mut bs = BitSet::new();
        bs.set(200, true);
        assert!(bs.get(200));
        for neighbour in [199usize, 201] {
            assert!(!bs.get(neighbour), "bit {neighbour}");
        }
        // Reaching bit 200 grows the storage to four 64-bit words.
        assert_eq!(bs.len(), 256);
    }

    #[test]
    fn clear_bit() {
        let mut bs = BitSet::new();
        bs.set(5, true);
        assert!(bs.get(5));
        bs.set(5, false);
        assert!(!bs.get(5));
    }

    #[test]
    fn clear_out_of_range_is_noop() {
        let mut bs = BitSet::new();
        bs.set(1000, false);
        assert!(bs.is_empty());
    }

    #[test]
    fn word_boundary() {
        let mut bs = BitSet::new();
        bs.set(63, true);
        bs.set(64, true);
        for (index, expected) in [(62usize, false), (63, true), (64, true), (65, false)] {
            assert_eq!(bs.get(index), expected, "bit {index}");
        }
    }

    // ---- wire round-trips ------------------------------------------------

    #[test]
    fn roundtrip_empty() {
        roundtrip(&BitSet::new());
    }

    #[test]
    fn roundtrip_single_word() {
        let mut bs = BitSet::new();
        for index in [0usize, 7, 63] {
            bs.set(index, true);
        }
        roundtrip(&bs);
    }

    #[test]
    fn roundtrip_multiple_words() {
        let mut bs = BitSet::new();
        for index in [0usize, 64, 128] {
            bs.set(index, true);
        }
        roundtrip(&bs);
    }

    #[test]
    fn roundtrip_from_longs() {
        roundtrip(&BitSet::from_longs(vec![0x0102030405060708, -1]));
    }

    // ---- encoding details ------------------------------------------------

    #[test]
    fn empty_encodes_as_varint_zero() {
        let mut buf = Vec::new();
        BitSet::new().encode(&mut buf).unwrap();
        assert_eq!(buf, [0x00]);
    }

    #[test]
    fn encoded_size_empty() {
        assert_eq!(BitSet::new().encoded_size(), 1);
    }

    #[test]
    fn encoded_size_one_word() {
        // One-byte VarInt prefix plus a single 8-byte word.
        assert_eq!(BitSet::from_longs(vec![1]).encoded_size(), 9);
    }

    // ---- malformed input -------------------------------------------------

    #[test]
    fn negative_length_decode() {
        let mut buf = Vec::new();
        VarInt(-1).encode(&mut buf).unwrap();
        let mut cursor = buf.as_slice();
        let outcome = BitSet::decode(&mut cursor);
        assert!(matches!(outcome, Err(Error::InvalidData(_))));
    }

    #[test]
    fn truncated_buffer() {
        // The prefix announces two words, but only one word's worth of bytes follows.
        let mut buf = Vec::new();
        VarInt(2).encode(&mut buf).unwrap();
        buf.extend_from_slice(&[0u8; 8]);
        let mut cursor = buf.as_slice();
        let outcome = BitSet::decode(&mut cursor);
        assert!(matches!(outcome, Err(Error::BufferUnderflow { .. })));
    }

    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn bitset_roundtrip(data in proptest::collection::vec(any::<i64>(), 0..10)) {
                roundtrip(&BitSet::from_longs(data));
            }
        }
    }
}