1use std::io;
21
22use serde::de::{self, Deserialize, Deserializer, MapAccess, SeqAccess, Visitor};
23
/// A growable sequence of bits packed into 64-bit words.
///
/// Bit `i` lives in word `i / 64` at bit position `i % 64` (least-significant bit first).
/// The storage holds `ceil(len / 64)` words; the hand-written `Deserialize` impl checks
/// that relationship on input.
///
/// NOTE: `PartialEq`/`Eq` are derived, so equality compares the raw storage words. Two
/// vectors with the same logical bits compare equal only if the unused (padding) bits past
/// `len` in the final word also match.
///
/// Field order matters: the derived `Serialize` emits `data` then `len`, and the sequence
/// form of `Deserialize` reads them back in that order.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize)]
pub struct BitVector {
    // Packed storage words, least-significant bit first within each word.
    data: Vec<u64>,
    // Number of logical bits stored (not the number of words).
    len: usize,
}
35
36impl<'de> Deserialize<'de> for BitVector {
37 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
38 where
39 D: Deserializer<'de>,
40 {
41 #[derive(serde::Deserialize)]
42 #[serde(field_identifier, rename_all = "lowercase")]
43 enum Field {
44 Data,
45 Len,
46 }
47
48 struct BitVectorVisitor;
49
50 impl<'de> Visitor<'de> for BitVectorVisitor {
51 type Value = BitVector;
52
53 fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 formatter.write_str("struct BitVector with consistent data and len fields")
55 }
56
57 fn visit_seq<V>(self, mut seq: V) -> Result<BitVector, V::Error>
58 where
59 V: SeqAccess<'de>,
60 {
61 let data: Vec<u64> = seq
62 .next_element()?
63 .ok_or_else(|| de::Error::invalid_length(0, &self))?;
64 let len: usize = seq
65 .next_element()?
66 .ok_or_else(|| de::Error::invalid_length(1, &self))?;
67 validate_bitvec(len, &data).map_err(de::Error::custom)
68 }
69
70 fn visit_map<V>(self, mut map: V) -> Result<BitVector, V::Error>
71 where
72 V: MapAccess<'de>,
73 {
74 let mut data: Option<Vec<u64>> = None;
75 let mut len: Option<usize> = None;
76
77 while let Some(key) = map.next_key()? {
78 match key {
79 Field::Data => {
80 if data.is_some() {
81 return Err(de::Error::duplicate_field("data"));
82 }
83 data = Some(map.next_value()?);
84 }
85 Field::Len => {
86 if len.is_some() {
87 return Err(de::Error::duplicate_field("len"));
88 }
89 len = Some(map.next_value()?);
90 }
91 }
92 }
93
94 let data = data.ok_or_else(|| de::Error::missing_field("data"))?;
95 let len = len.ok_or_else(|| de::Error::missing_field("len"))?;
96 validate_bitvec(len, &data).map_err(de::Error::custom)
97 }
98 }
99
100 const FIELDS: &[&str] = &["data", "len"];
101 deserializer.deserialize_struct("BitVector", FIELDS, BitVectorVisitor)
102 }
103}
104
105fn validate_bitvec(len: usize, data: &[u64]) -> Result<BitVector, String> {
108 let expected_words = (len + 63) / 64;
109 if data.len() != expected_words {
110 return Err(format!(
111 "BitVector invariant violated: len={len} requires {expected_words} words, but data contains {} words",
112 data.len()
113 ));
114 }
115 Ok(BitVector {
116 data: data.to_vec(),
117 len,
118 })
119}
120
121impl BitVector {
122 #[must_use]
124 pub fn new() -> Self {
125 Self {
126 data: Vec::new(),
127 len: 0,
128 }
129 }
130
131 #[must_use]
133 pub fn with_capacity(bits: usize) -> Self {
134 let words = (bits + 63) / 64;
135 Self {
136 data: Vec::with_capacity(words),
137 len: 0,
138 }
139 }
140
141 #[must_use]
143 pub fn from_bools(bools: &[bool]) -> Self {
144 let num_words = (bools.len() + 63) / 64;
145 let mut data = vec![0u64; num_words];
146
147 for (i, &b) in bools.iter().enumerate() {
148 if b {
149 let word_idx = i / 64;
150 let bit_idx = i % 64;
151 data[word_idx] |= 1 << bit_idx;
152 }
153 }
154
155 Self {
156 data,
157 len: bools.len(),
158 }
159 }
160
161 #[must_use]
163 pub fn filled(len: usize, value: bool) -> Self {
164 let num_words = (len + 63) / 64;
165 let fill = if value { u64::MAX } else { 0 };
166 let data = vec![fill; num_words];
167
168 Self { data, len }
169 }
170
171 #[must_use]
173 pub fn zeros(len: usize) -> Self {
174 Self::filled(len, false)
175 }
176
177 #[must_use]
179 pub fn ones(len: usize) -> Self {
180 Self::filled(len, true)
181 }
182
183 #[must_use]
185 pub fn len(&self) -> usize {
186 self.len
187 }
188
189 #[must_use]
191 pub fn is_empty(&self) -> bool {
192 self.len == 0
193 }
194
195 #[must_use]
197 pub fn get(&self, index: usize) -> Option<bool> {
198 if index >= self.len {
199 return None;
200 }
201
202 let word_idx = index / 64;
203 let bit_idx = index % 64;
204 Some((self.data[word_idx] & (1 << bit_idx)) != 0)
205 }
206
207 pub fn set(&mut self, index: usize, value: bool) {
213 assert!(index < self.len, "Index out of bounds");
214
215 let word_idx = index / 64;
216 let bit_idx = index % 64;
217
218 if value {
219 self.data[word_idx] |= 1 << bit_idx;
220 } else {
221 self.data[word_idx] &= !(1 << bit_idx);
222 }
223 }
224
225 pub fn push(&mut self, value: bool) {
227 let word_idx = self.len / 64;
228 let bit_idx = self.len % 64;
229
230 if word_idx >= self.data.len() {
231 self.data.push(0);
232 }
233
234 if value {
235 self.data[word_idx] |= 1 << bit_idx;
236 }
237
238 self.len += 1;
239 }
240
241 #[must_use]
243 pub fn count_ones(&self) -> usize {
244 if self.is_empty() {
245 return 0;
246 }
247
248 let full_words = self.len / 64;
249 let remaining_bits = self.len % 64;
250
251 let mut count: usize = self.data[..full_words]
252 .iter()
253 .map(|&w| w.count_ones() as usize)
254 .sum();
255
256 if remaining_bits > 0 && full_words < self.data.len() {
257 let mask = (1u64 << remaining_bits) - 1;
258 count += (self.data[full_words] & mask).count_ones() as usize;
259 }
260
261 count
262 }
263
264 #[must_use]
266 pub fn count_zeros(&self) -> usize {
267 self.len - self.count_ones()
268 }
269
270 #[must_use]
276 pub fn to_bools(&self) -> Vec<bool> {
277 (0..self.len)
278 .map(|i| self.get(i).expect("index within len"))
279 .collect()
280 }
281
282 pub fn iter(&self) -> impl Iterator<Item = bool> + '_ {
288 (0..self.len).map(move |i| self.get(i).expect("index within len"))
289 }
290
291 pub fn ones_iter(&self) -> impl Iterator<Item = usize> + '_ {
297 (0..self.len).filter(move |&i| self.get(i).expect("index within len"))
298 }
299
300 pub fn zeros_iter(&self) -> impl Iterator<Item = usize> + '_ {
306 (0..self.len).filter(move |&i| !self.get(i).expect("index within len"))
307 }
308
309 #[must_use]
311 pub fn data(&self) -> &[u64] {
312 &self.data
313 }
314
315 #[must_use]
317 pub fn compression_ratio(&self) -> f64 {
318 if self.is_empty() {
319 return 1.0;
320 }
321
322 let original_size = self.len;
324 let compressed_size = self.data.len() * 8;
326
327 if compressed_size == 0 {
328 return 1.0;
329 }
330
331 original_size as f64 / compressed_size as f64
332 }
333
334 #[must_use]
338 pub fn and(&self, other: &Self) -> Self {
339 let len = self.len.min(other.len);
340 let num_words = (len + 63) / 64;
341
342 let data: Vec<u64> = self
343 .data
344 .iter()
345 .zip(&other.data)
346 .take(num_words)
347 .map(|(&a, &b)| a & b)
348 .collect();
349
350 Self { data, len }
351 }
352
353 #[must_use]
357 pub fn or(&self, other: &Self) -> Self {
358 let len = self.len.min(other.len);
359 let num_words = (len + 63) / 64;
360
361 let data: Vec<u64> = self
362 .data
363 .iter()
364 .zip(&other.data)
365 .take(num_words)
366 .map(|(&a, &b)| a | b)
367 .collect();
368
369 Self { data, len }
370 }
371
372 #[must_use]
374 pub fn not(&self) -> Self {
375 let data: Vec<u64> = self.data.iter().map(|&w| !w).collect();
376 Self {
377 data,
378 len: self.len,
379 }
380 }
381
382 #[must_use]
384 pub fn xor(&self, other: &Self) -> Self {
385 let len = self.len.min(other.len);
386 let num_words = (len + 63) / 64;
387
388 let data: Vec<u64> = self
389 .data
390 .iter()
391 .zip(&other.data)
392 .take(num_words)
393 .map(|(&a, &b)| a ^ b)
394 .collect();
395
396 Self { data, len }
397 }
398
399 pub fn to_bytes(&self) -> Vec<u8> {
401 let mut buf = Vec::with_capacity(4 + self.data.len() * 8);
402 buf.extend_from_slice(&(self.len as u32).to_le_bytes());
403 for &word in &self.data {
404 buf.extend_from_slice(&word.to_le_bytes());
405 }
406 buf
407 }
408
409 pub fn from_bytes(bytes: &[u8]) -> io::Result<Self> {
415 if bytes.len() < 4 {
416 return Err(io::Error::new(
417 io::ErrorKind::InvalidData,
418 "BitVector too short",
419 ));
420 }
421
422 let len = u32::from_le_bytes(
423 bytes[0..4]
424 .try_into()
425 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?,
426 ) as usize;
427 let num_words = (len + 63) / 64;
428
429 if bytes.len() < 4 + num_words * 8 {
430 return Err(io::Error::new(
431 io::ErrorKind::InvalidData,
432 "BitVector truncated",
433 ));
434 }
435
436 let mut data = Vec::with_capacity(num_words);
437 for i in 0..num_words {
438 let offset = 4 + i * 8;
439 let word = u64::from_le_bytes(
440 bytes[offset..offset + 8]
441 .try_into()
442 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?,
443 );
444 data.push(word);
445 }
446
447 Ok(Self { data, len })
448 }
449}
450
451impl Default for BitVector {
452 fn default() -> Self {
453 Self::new()
454 }
455}
456
457impl FromIterator<bool> for BitVector {
458 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
459 let mut bitvec = BitVector::new();
460 for b in iter {
461 bitvec.push(b);
462 }
463 bitvec
464 }
465}
466
#[cfg(test)]
mod tests {
    use super::*;

    /// Eight-bit pattern shared by several tests.
    const PATTERN_8: [bool; 8] = [true, false, true, true, false, false, true, false];

    /// Five-bit pattern shared by the round-trip and index-iterator tests.
    const PATTERN_5: [bool; 5] = [true, false, true, true, false];

    /// Parses `json` as a `BitVector` and asserts it fails with an invariant-violation error.
    fn assert_invariant_error(json: &str) {
        let result: Result<BitVector, _> = serde_json::from_str(json);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("invariant violated"),
            "expected invariant error, got: {err_msg}"
        );
    }

    #[test]
    fn test_bitvec_basic() {
        let bv = BitVector::from_bools(&PATTERN_8);

        assert_eq!(bv.len(), 8);
        for (idx, want) in PATTERN_8.iter().copied().enumerate() {
            assert_eq!(bv.get(idx), Some(want));
        }
    }

    #[test]
    fn test_bitvec_empty() {
        let bv = BitVector::new();
        assert!(bv.is_empty());
        assert_eq!(bv.get(0), None);
    }

    #[test]
    fn test_bitvec_push() {
        let mut bv = BitVector::new();
        for bit in [true, false, true].iter().copied() {
            bv.push(bit);
        }

        assert_eq!(bv.len(), 3);
        assert_eq!(bv.get(0), Some(true));
        assert_eq!(bv.get(1), Some(false));
        assert_eq!(bv.get(2), Some(true));
    }

    #[test]
    fn test_bitvec_set() {
        let mut bv = BitVector::zeros(8);
        for idx in [0, 3, 7].iter().copied() {
            bv.set(idx, true);
        }

        assert_eq!(bv.get(0), Some(true));
        assert_eq!(bv.get(1), Some(false));
        assert_eq!(bv.get(3), Some(true));
        assert_eq!(bv.get(7), Some(true));
    }

    #[test]
    fn test_bitvec_count() {
        let bv = BitVector::from_bools(&PATTERN_8);

        assert_eq!(bv.count_ones(), 4);
        assert_eq!(bv.count_zeros(), 4);
    }

    #[test]
    fn test_bitvec_filled() {
        let all_zero = BitVector::zeros(100);
        assert_eq!(all_zero.count_ones(), 0);
        assert_eq!(all_zero.count_zeros(), 100);

        let all_one = BitVector::ones(100);
        assert_eq!(all_one.count_ones(), 100);
        assert_eq!(all_one.count_zeros(), 0);
    }

    #[test]
    fn test_bitvec_to_bools() {
        let restored = BitVector::from_bools(&PATTERN_5).to_bools();
        assert_eq!(PATTERN_5.to_vec(), restored);
    }

    #[test]
    fn test_bitvec_large() {
        // Every third bit set, spanning several storage words.
        let expected: Vec<bool> = (0..200).map(|i| i % 3 == 0).collect();
        let bv = BitVector::from_bools(&expected);

        assert_eq!(bv.len(), 200);
        for (i, want) in expected.iter().copied().enumerate() {
            assert_eq!(bv.get(i), Some(want), "Mismatch at index {}", i);
        }
    }

    #[test]
    fn test_bitvec_and() {
        let lhs = BitVector::from_bools(&[true, true, false, false]);
        let rhs = BitVector::from_bools(&[true, false, true, false]);

        assert_eq!(lhs.and(&rhs).to_bools(), vec![true, false, false, false]);
    }

    #[test]
    fn test_bitvec_or() {
        let lhs = BitVector::from_bools(&[true, true, false, false]);
        let rhs = BitVector::from_bools(&[true, false, true, false]);

        assert_eq!(lhs.or(&rhs).to_bools(), vec![true, true, true, false]);
    }

    #[test]
    fn test_bitvec_not() {
        let inverted = BitVector::from_bools(&[true, false, true, false]).not();

        for (i, want) in [false, true, false, true].iter().copied().enumerate() {
            assert_eq!(inverted.get(i), Some(want));
        }
    }

    #[test]
    fn test_bitvec_xor() {
        let lhs = BitVector::from_bools(&[true, true, false, false]);
        let rhs = BitVector::from_bools(&[true, false, true, false]);

        assert_eq!(lhs.xor(&rhs).to_bools(), vec![false, true, true, false]);
    }

    #[test]
    fn test_bitvec_serialization() {
        let bv = BitVector::from_bools(&PATTERN_8);
        let restored = BitVector::from_bytes(&bv.to_bytes()).unwrap();
        assert_eq!(bv, restored);
    }

    #[test]
    fn test_bitvec_compression_ratio() {
        // 64 one-byte bools packed into a single 8-byte word.
        let ratio = BitVector::zeros(64).compression_ratio();
        assert!((ratio - 8.0).abs() < 0.1);
    }

    #[test]
    fn test_bitvec_ones_iter() {
        let bv = BitVector::from_bools(&PATTERN_5);
        assert_eq!(bv.ones_iter().collect::<Vec<usize>>(), vec![0, 2, 3]);
    }

    #[test]
    fn test_bitvec_zeros_iter() {
        let bv = BitVector::from_bools(&PATTERN_5);
        assert_eq!(bv.zeros_iter().collect::<Vec<usize>>(), vec![1, 4]);
    }

    #[test]
    fn test_bitvec_from_iter() {
        let bv: BitVector = [true, false, true].iter().copied().collect();

        assert_eq!(bv.len(), 3);
        assert_eq!(bv.get(0), Some(true));
        assert_eq!(bv.get(1), Some(false));
        assert_eq!(bv.get(2), Some(true));
    }

    #[test]
    fn test_bitvec_deserialize_roundtrip() {
        let original = BitVector::from_bools(&PATTERN_8);
        let json = serde_json::to_string(&original).unwrap();
        let restored: BitVector = serde_json::from_str(&json).unwrap();
        assert_eq!(original, restored);
    }

    #[test]
    fn test_bitvec_deserialize_invalid_len_too_large() {
        // len=200 needs 4 words, but only one is supplied.
        assert_invariant_error(r#"{"data":[42],"len":200}"#);
    }

    #[test]
    fn test_bitvec_deserialize_invalid_len_data_mismatch() {
        // len=10 needs 1 word, but three are supplied.
        assert_invariant_error(r#"{"data":[1,2,3],"len":10}"#);
    }

    #[test]
    fn test_bitvec_deserialize_valid_edge_cases() {
        let empty: BitVector = serde_json::from_str(r#"{"data":[],"len":0}"#).unwrap();
        assert_eq!(empty.len(), 0);
        assert!(empty.is_empty());

        let single: BitVector = serde_json::from_str(r#"{"data":[1],"len":1}"#).unwrap();
        assert_eq!(single.len(), 1);
        assert_eq!(single.get(0), Some(true));

        let full_word: BitVector =
            serde_json::from_str(r#"{"data":[18446744073709551615],"len":64}"#).unwrap();
        assert_eq!(full_word.len(), 64);
        assert_eq!(full_word.count_ones(), 64);
    }
}