1use std::io;
21
22use serde::de::{self, Deserialize, Deserializer, MapAccess, SeqAccess, Visitor};
23
24#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize)]
29pub struct BitVector {
30 data: Vec<u64>,
32 len: usize,
34}
35
36impl<'de> Deserialize<'de> for BitVector {
37 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
38 where
39 D: Deserializer<'de>,
40 {
41 #[derive(serde::Deserialize)]
42 #[serde(field_identifier, rename_all = "lowercase")]
43 enum Field {
44 Data,
45 Len,
46 }
47
48 struct BitVectorVisitor;
49
50 impl<'de> Visitor<'de> for BitVectorVisitor {
51 type Value = BitVector;
52
53 fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 formatter.write_str("struct BitVector with consistent data and len fields")
55 }
56
57 fn visit_seq<V>(self, mut seq: V) -> Result<BitVector, V::Error>
58 where
59 V: SeqAccess<'de>,
60 {
61 let data: Vec<u64> = seq
62 .next_element()?
63 .ok_or_else(|| de::Error::invalid_length(0, &self))?;
64 let len: usize = seq
65 .next_element()?
66 .ok_or_else(|| de::Error::invalid_length(1, &self))?;
67 validate_bitvec(len, &data).map_err(de::Error::custom)
68 }
69
70 fn visit_map<V>(self, mut map: V) -> Result<BitVector, V::Error>
71 where
72 V: MapAccess<'de>,
73 {
74 let mut data: Option<Vec<u64>> = None;
75 let mut len: Option<usize> = None;
76
77 while let Some(key) = map.next_key()? {
78 match key {
79 Field::Data => {
80 if data.is_some() {
81 return Err(de::Error::duplicate_field("data"));
82 }
83 data = Some(map.next_value()?);
84 }
85 Field::Len => {
86 if len.is_some() {
87 return Err(de::Error::duplicate_field("len"));
88 }
89 len = Some(map.next_value()?);
90 }
91 }
92 }
93
94 let data = data.ok_or_else(|| de::Error::missing_field("data"))?;
95 let len = len.ok_or_else(|| de::Error::missing_field("len"))?;
96 validate_bitvec(len, &data).map_err(de::Error::custom)
97 }
98 }
99
100 const FIELDS: &[&str] = &["data", "len"];
101 deserializer.deserialize_struct("BitVector", FIELDS, BitVectorVisitor)
102 }
103}
104
105fn validate_bitvec(len: usize, data: &[u64]) -> Result<BitVector, String> {
108 let expected_words = (len + 63) / 64;
109 if data.len() != expected_words {
110 return Err(format!(
111 "BitVector invariant violated: len={len} requires {expected_words} words, but data contains {} words",
112 data.len()
113 ));
114 }
115 Ok(BitVector {
116 data: data.to_vec(),
117 len,
118 })
119}
120
121impl BitVector {
122 #[must_use]
126 pub fn from_raw_parts(data: Vec<u64>, len: usize) -> Self {
127 Self { data, len }
128 }
129
130 #[must_use]
132 pub fn new() -> Self {
133 Self {
134 data: Vec::new(),
135 len: 0,
136 }
137 }
138
139 #[must_use]
141 pub fn with_capacity(bits: usize) -> Self {
142 let words = (bits + 63) / 64;
143 Self {
144 data: Vec::with_capacity(words),
145 len: 0,
146 }
147 }
148
149 #[must_use]
151 pub fn from_bools(bools: &[bool]) -> Self {
152 let num_words = (bools.len() + 63) / 64;
153 let mut data = vec![0u64; num_words];
154
155 for (i, &b) in bools.iter().enumerate() {
156 if b {
157 let word_idx = i / 64;
158 let bit_idx = i % 64;
159 data[word_idx] |= 1 << bit_idx;
160 }
161 }
162
163 Self {
164 data,
165 len: bools.len(),
166 }
167 }
168
169 #[must_use]
171 pub fn filled(len: usize, value: bool) -> Self {
172 let num_words = (len + 63) / 64;
173 let fill = if value { u64::MAX } else { 0 };
174 let data = vec![fill; num_words];
175
176 Self { data, len }
177 }
178
179 #[must_use]
181 pub fn zeros(len: usize) -> Self {
182 Self::filled(len, false)
183 }
184
185 #[must_use]
187 pub fn ones(len: usize) -> Self {
188 Self::filled(len, true)
189 }
190
191 #[must_use]
193 pub fn len(&self) -> usize {
194 self.len
195 }
196
197 #[must_use]
199 pub fn is_empty(&self) -> bool {
200 self.len == 0
201 }
202
203 #[must_use]
205 pub fn get(&self, index: usize) -> Option<bool> {
206 if index >= self.len {
207 return None;
208 }
209
210 let word_idx = index / 64;
211 let bit_idx = index % 64;
212 Some((self.data[word_idx] & (1 << bit_idx)) != 0)
213 }
214
215 pub fn set(&mut self, index: usize, value: bool) {
221 assert!(index < self.len, "Index out of bounds");
222
223 let word_idx = index / 64;
224 let bit_idx = index % 64;
225
226 if value {
227 self.data[word_idx] |= 1 << bit_idx;
228 } else {
229 self.data[word_idx] &= !(1 << bit_idx);
230 }
231 }
232
233 pub fn push(&mut self, value: bool) {
235 let word_idx = self.len / 64;
236 let bit_idx = self.len % 64;
237
238 if word_idx >= self.data.len() {
239 self.data.push(0);
240 }
241
242 if value {
243 self.data[word_idx] |= 1 << bit_idx;
244 }
245
246 self.len += 1;
247 }
248
249 #[must_use]
251 pub fn count_ones(&self) -> usize {
252 if self.is_empty() {
253 return 0;
254 }
255
256 let full_words = self.len / 64;
257 let remaining_bits = self.len % 64;
258
259 let mut count: usize = self.data[..full_words]
260 .iter()
261 .map(|&w| w.count_ones() as usize)
262 .sum();
263
264 if remaining_bits > 0 && full_words < self.data.len() {
265 let mask = (1u64 << remaining_bits) - 1;
266 count += (self.data[full_words] & mask).count_ones() as usize;
267 }
268
269 count
270 }
271
272 #[must_use]
274 pub fn count_zeros(&self) -> usize {
275 self.len - self.count_ones()
276 }
277
278 #[must_use]
284 pub fn to_bools(&self) -> Vec<bool> {
285 (0..self.len)
286 .map(|i| self.get(i).expect("index within len"))
287 .collect()
288 }
289
290 pub fn iter(&self) -> impl Iterator<Item = bool> + '_ {
296 (0..self.len).map(move |i| self.get(i).expect("index within len"))
297 }
298
299 pub fn ones_iter(&self) -> impl Iterator<Item = usize> + '_ {
305 (0..self.len).filter(move |&i| self.get(i).expect("index within len"))
306 }
307
308 pub fn zeros_iter(&self) -> impl Iterator<Item = usize> + '_ {
314 (0..self.len).filter(move |&i| !self.get(i).expect("index within len"))
315 }
316
317 #[must_use]
319 pub fn data(&self) -> &[u64] {
320 &self.data
321 }
322
323 #[must_use]
325 pub fn compression_ratio(&self) -> f64 {
326 if self.is_empty() {
327 return 1.0;
328 }
329
330 let original_size = self.len;
332 let compressed_size = self.data.len() * 8;
334
335 if compressed_size == 0 {
336 return 1.0;
337 }
338
339 original_size as f64 / compressed_size as f64
340 }
341
342 #[must_use]
346 pub fn and(&self, other: &Self) -> Self {
347 let len = self.len.min(other.len);
348 let num_words = (len + 63) / 64;
349
350 let data: Vec<u64> = self
351 .data
352 .iter()
353 .zip(&other.data)
354 .take(num_words)
355 .map(|(&a, &b)| a & b)
356 .collect();
357
358 Self { data, len }
359 }
360
361 #[must_use]
365 pub fn or(&self, other: &Self) -> Self {
366 let len = self.len.min(other.len);
367 let num_words = (len + 63) / 64;
368
369 let data: Vec<u64> = self
370 .data
371 .iter()
372 .zip(&other.data)
373 .take(num_words)
374 .map(|(&a, &b)| a | b)
375 .collect();
376
377 Self { data, len }
378 }
379
380 #[must_use]
382 pub fn not(&self) -> Self {
383 let data: Vec<u64> = self.data.iter().map(|&w| !w).collect();
384 Self {
385 data,
386 len: self.len,
387 }
388 }
389
390 #[must_use]
392 pub fn xor(&self, other: &Self) -> Self {
393 let len = self.len.min(other.len);
394 let num_words = (len + 63) / 64;
395
396 let data: Vec<u64> = self
397 .data
398 .iter()
399 .zip(&other.data)
400 .take(num_words)
401 .map(|(&a, &b)| a ^ b)
402 .collect();
403
404 Self { data, len }
405 }
406
407 pub fn to_bytes(&self) -> io::Result<Vec<u8>> {
413 let len_u32 = u32::try_from(self.len).map_err(|_| {
414 io::Error::new(
415 io::ErrorKind::InvalidInput,
416 format!(
417 "BitVector length {} exceeds u32::MAX, cannot serialize",
418 self.len
419 ),
420 )
421 })?;
422 let mut buf = Vec::with_capacity(4 + self.data.len() * 8);
423 buf.extend_from_slice(&len_u32.to_le_bytes());
424 for &word in &self.data {
425 buf.extend_from_slice(&word.to_le_bytes());
426 }
427 Ok(buf)
428 }
429
430 pub fn from_bytes(bytes: &[u8]) -> io::Result<Self> {
436 if bytes.len() < 4 {
437 return Err(io::Error::new(
438 io::ErrorKind::InvalidData,
439 "BitVector too short",
440 ));
441 }
442
443 let len = u32::from_le_bytes(
444 bytes[0..4]
445 .try_into()
446 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?,
447 ) as usize;
448 let num_words = (len + 63) / 64;
449
450 if bytes.len() < 4 + num_words * 8 {
451 return Err(io::Error::new(
452 io::ErrorKind::InvalidData,
453 "BitVector truncated",
454 ));
455 }
456
457 let mut data = Vec::with_capacity(num_words);
458 for i in 0..num_words {
459 let offset = 4 + i * 8;
460 let word = u64::from_le_bytes(
461 bytes[offset..offset + 8]
462 .try_into()
463 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?,
464 );
465 data.push(word);
466 }
467
468 Ok(Self { data, len })
469 }
470}
471
472impl Default for BitVector {
473 fn default() -> Self {
474 Self::new()
475 }
476}
477
478impl FromIterator<bool> for BitVector {
479 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
480 let mut bitvec = BitVector::new();
481 for b in iter {
482 bitvec.push(b);
483 }
484 bitvec
485 }
486}
487
#[cfg(test)]
mod tests {
    use super::*;

    /// Eight-bit sample (four set, four clear) reused by several tests.
    const SAMPLE: [bool; 8] = [true, false, true, true, false, false, true, false];

    /// Packing a slice keeps the length and every individual bit.
    #[test]
    fn test_bitvec_basic() {
        let bv = BitVector::from_bools(&SAMPLE);

        assert_eq!(bv.len(), 8);
        for (idx, expected) in SAMPLE.iter().copied().enumerate() {
            assert_eq!(bv.get(idx), Some(expected));
        }
    }

    /// A fresh vector is empty and has no bit at index 0.
    #[test]
    fn test_bitvec_empty() {
        let bv = BitVector::new();
        assert!(bv.is_empty());
        assert_eq!(bv.get(0), None);
    }

    /// Pushed bits are readable at consecutive indices.
    #[test]
    fn test_bitvec_push() {
        let mut bv = BitVector::new();
        for &bit in &[true, false, true] {
            bv.push(bit);
        }

        assert_eq!(bv.len(), 3);
        assert_eq!(bv.get(0), Some(true));
        assert_eq!(bv.get(1), Some(false));
        assert_eq!(bv.get(2), Some(true));
    }

    /// `set` changes only the addressed bits.
    #[test]
    fn test_bitvec_set() {
        let mut bv = BitVector::zeros(8);

        for &idx in &[0usize, 3, 7] {
            bv.set(idx, true);
        }

        assert_eq!(bv.get(0), Some(true));
        assert_eq!(bv.get(1), Some(false));
        assert_eq!(bv.get(3), Some(true));
        assert_eq!(bv.get(7), Some(true));
    }

    /// Ones and zeros are counted over the logical length.
    #[test]
    fn test_bitvec_count() {
        let bv = BitVector::from_bools(&SAMPLE);

        assert_eq!(bv.count_ones(), 4);
        assert_eq!(bv.count_zeros(), 4);
    }

    /// Filled vectors spanning a partial last word count correctly.
    #[test]
    fn test_bitvec_filled() {
        let all_clear = BitVector::zeros(100);
        assert_eq!(all_clear.count_ones(), 0);
        assert_eq!(all_clear.count_zeros(), 100);

        let all_set = BitVector::ones(100);
        assert_eq!(all_set.count_ones(), 100);
        assert_eq!(all_set.count_zeros(), 0);
    }

    /// `to_bools` inverts `from_bools`.
    #[test]
    fn test_bitvec_to_bools() {
        let input = vec![true, false, true, true, false];
        let unpacked = BitVector::from_bools(&input).to_bools();
        assert_eq!(input, unpacked);
    }

    /// Bits survive packing across several backing words.
    #[test]
    fn test_bitvec_large() {
        let pattern: Vec<bool> = (0..200).map(|n| n % 3 == 0).collect();
        let bv = BitVector::from_bools(&pattern);

        assert_eq!(bv.len(), 200);
        for (i, expected) in pattern.iter().copied().enumerate() {
            assert_eq!(bv.get(i), Some(expected), "Mismatch at index {}", i);
        }
    }

    #[test]
    fn test_bitvec_and() {
        let lhs = BitVector::from_bools(&[true, true, false, false]);
        let rhs = BitVector::from_bools(&[true, false, true, false]);

        assert_eq!(lhs.and(&rhs).to_bools(), vec![true, false, false, false]);
    }

    #[test]
    fn test_bitvec_or() {
        let lhs = BitVector::from_bools(&[true, true, false, false]);
        let rhs = BitVector::from_bools(&[true, false, true, false]);

        assert_eq!(lhs.or(&rhs).to_bools(), vec![true, true, true, false]);
    }

    #[test]
    fn test_bitvec_not() {
        let inverted = BitVector::from_bools(&[true, false, true, false]).not();

        assert_eq!(inverted.get(0), Some(false));
        assert_eq!(inverted.get(1), Some(true));
        assert_eq!(inverted.get(2), Some(false));
        assert_eq!(inverted.get(3), Some(true));
    }

    #[test]
    fn test_bitvec_xor() {
        let lhs = BitVector::from_bools(&[true, true, false, false]);
        let rhs = BitVector::from_bools(&[true, false, true, false]);

        assert_eq!(lhs.xor(&rhs).to_bools(), vec![false, true, true, false]);
    }

    /// The byte encoding round-trips.
    #[test]
    fn test_bitvec_serialization() {
        let source = BitVector::from_bools(&SAMPLE);
        let encoded = source.to_bytes().unwrap();
        let decoded = BitVector::from_bytes(&encoded).unwrap();
        assert_eq!(source, decoded);
    }

    /// 64 bits occupy one 8-byte word, giving a ratio of 8.
    #[test]
    fn test_bitvec_compression_ratio() {
        let ratio = BitVector::zeros(64).compression_ratio();
        assert!((ratio - 8.0).abs() < 0.1);
    }

    #[test]
    fn test_bitvec_ones_iter() {
        let bv = BitVector::from_bools(&[true, false, true, true, false]);
        let set_indices: Vec<usize> = bv.ones_iter().collect();
        assert_eq!(set_indices, vec![0, 2, 3]);
    }

    #[test]
    fn test_bitvec_zeros_iter() {
        let bv = BitVector::from_bools(&[true, false, true, true, false]);
        let clear_indices: Vec<usize> = bv.zeros_iter().collect();
        assert_eq!(clear_indices, vec![1, 4]);
    }

    /// Collecting an iterator of booleans builds the same bits in order.
    #[test]
    fn test_bitvec_from_iter() {
        let bv: BitVector = vec![true, false, true].into_iter().collect();
        assert_eq!(bv.len(), 3);
        assert_eq!(bv.get(0), Some(true));
        assert_eq!(bv.get(1), Some(false));
        assert_eq!(bv.get(2), Some(true));
    }

    /// The serde encoding round-trips through JSON.
    #[test]
    fn test_bitvec_deserialize_roundtrip() {
        let source = BitVector::from_bools(&SAMPLE);
        let encoded = serde_json::to_string(&source).unwrap();
        let decoded: BitVector = serde_json::from_str(&encoded).unwrap();
        assert_eq!(source, decoded);
    }

    /// A `len` needing more words than `data` provides is rejected.
    #[test]
    fn test_bitvec_deserialize_invalid_len_too_large() {
        let outcome: Result<BitVector, _> = serde_json::from_str(r#"{"data":[42],"len":200}"#);
        assert!(outcome.is_err());
        let err_msg = outcome.unwrap_err().to_string();
        assert!(
            err_msg.contains("invariant violated"),
            "expected invariant error, got: {err_msg}"
        );
    }

    /// A `data` array with more words than `len` needs is rejected.
    #[test]
    fn test_bitvec_deserialize_invalid_len_data_mismatch() {
        let outcome: Result<BitVector, _> = serde_json::from_str(r#"{"data":[1,2,3],"len":10}"#);
        assert!(outcome.is_err());
        let err_msg = outcome.unwrap_err().to_string();
        assert!(
            err_msg.contains("invariant violated"),
            "expected invariant error, got: {err_msg}"
        );
    }

    /// Empty, single-bit and exactly-one-word inputs are accepted.
    #[test]
    fn test_bitvec_deserialize_valid_edge_cases() {
        let empty: BitVector = serde_json::from_str(r#"{"data":[],"len":0}"#).unwrap();
        assert_eq!(empty.len(), 0);
        assert!(empty.is_empty());

        let single: BitVector = serde_json::from_str(r#"{"data":[1],"len":1}"#).unwrap();
        assert_eq!(single.len(), 1);
        assert_eq!(single.get(0), Some(true));

        let full_word: BitVector =
            serde_json::from_str(r#"{"data":[18446744073709551615],"len":64}"#).unwrap();
        assert_eq!(full_word.len(), 64);
        assert_eq!(full_word.count_ones(), 64);
    }
}