Skip to main content

grafeo_core/codec/
bitvec.rs

1//! Stores booleans as individual bits - 8x smaller than `Vec<bool>`.
2//!
3//! Use this when you're tracking lots of boolean flags (like "visited" markers
4//! in graph traversals, or null bitmaps). Backed by `Vec<u64>` so bitwise
5//! operations like AND/OR/XOR stay cache-friendly.
6//!
7//! # Example
8//!
9//! ```no_run
10//! # use grafeo_core::codec::bitvec::BitVector;
11//! let bools = vec![true, false, true, true, false, false, true, false];
12//! let bitvec = BitVector::from_bools(&bools);
//! // Packed into the low byte of a single u64 word: 0b01001101
14//!
15//! assert_eq!(bitvec.get(0), Some(true));
16//! assert_eq!(bitvec.get(1), Some(false));
17//! assert_eq!(bitvec.count_ones(), 4);
18//! ```
19
20use std::io;
21
22use serde::de::{self, Deserialize, Deserializer, MapAccess, SeqAccess, Visitor};
23
24/// Stores booleans as individual bits - 8x smaller than `Vec<bool>`.
25///
26/// Supports bitwise operations ([`and`](Self::and), [`or`](Self::or),
27/// [`not`](Self::not)) for combining filter results efficiently.
28#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize)]
29pub struct BitVector {
30    /// Packed bits (little-endian within each word).
31    data: Vec<u64>,
32    /// Number of bits stored.
33    len: usize,
34}
35
36impl<'de> Deserialize<'de> for BitVector {
37    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
38    where
39        D: Deserializer<'de>,
40    {
41        #[derive(serde::Deserialize)]
42        #[serde(field_identifier, rename_all = "lowercase")]
43        enum Field {
44            Data,
45            Len,
46        }
47
48        struct BitVectorVisitor;
49
50        impl<'de> Visitor<'de> for BitVectorVisitor {
51            type Value = BitVector;
52
53            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54                formatter.write_str("struct BitVector with consistent data and len fields")
55            }
56
57            fn visit_seq<V>(self, mut seq: V) -> Result<BitVector, V::Error>
58            where
59                V: SeqAccess<'de>,
60            {
61                let data: Vec<u64> = seq
62                    .next_element()?
63                    .ok_or_else(|| de::Error::invalid_length(0, &self))?;
64                let len: usize = seq
65                    .next_element()?
66                    .ok_or_else(|| de::Error::invalid_length(1, &self))?;
67                validate_bitvec(len, &data).map_err(de::Error::custom)
68            }
69
70            fn visit_map<V>(self, mut map: V) -> Result<BitVector, V::Error>
71            where
72                V: MapAccess<'de>,
73            {
74                let mut data: Option<Vec<u64>> = None;
75                let mut len: Option<usize> = None;
76
77                while let Some(key) = map.next_key()? {
78                    match key {
79                        Field::Data => {
80                            if data.is_some() {
81                                return Err(de::Error::duplicate_field("data"));
82                            }
83                            data = Some(map.next_value()?);
84                        }
85                        Field::Len => {
86                            if len.is_some() {
87                                return Err(de::Error::duplicate_field("len"));
88                            }
89                            len = Some(map.next_value()?);
90                        }
91                    }
92                }
93
94                let data = data.ok_or_else(|| de::Error::missing_field("data"))?;
95                let len = len.ok_or_else(|| de::Error::missing_field("len"))?;
96                validate_bitvec(len, &data).map_err(de::Error::custom)
97            }
98        }
99
100        const FIELDS: &[&str] = &["data", "len"];
101        deserializer.deserialize_struct("BitVector", FIELDS, BitVectorVisitor)
102    }
103}
104
105/// Validates that `len` and `data` are consistent, returning a valid
106/// `BitVector` or an error message.
107fn validate_bitvec(len: usize, data: &[u64]) -> Result<BitVector, String> {
108    let expected_words = (len + 63) / 64;
109    if data.len() != expected_words {
110        return Err(format!(
111            "BitVector invariant violated: len={len} requires {expected_words} words, but data contains {} words",
112            data.len()
113        ));
114    }
115    Ok(BitVector {
116        data: data.to_vec(),
117        len,
118    })
119}
120
121impl BitVector {
122    /// Creates an empty bit vector.
123    #[must_use]
124    pub fn new() -> Self {
125        Self {
126            data: Vec::new(),
127            len: 0,
128        }
129    }
130
131    /// Creates a bit vector with the specified capacity (in bits).
132    #[must_use]
133    pub fn with_capacity(bits: usize) -> Self {
134        let words = (bits + 63) / 64;
135        Self {
136            data: Vec::with_capacity(words),
137            len: 0,
138        }
139    }
140
141    /// Creates a bit vector from a slice of booleans.
142    #[must_use]
143    pub fn from_bools(bools: &[bool]) -> Self {
144        let num_words = (bools.len() + 63) / 64;
145        let mut data = vec![0u64; num_words];
146
147        for (i, &b) in bools.iter().enumerate() {
148            if b {
149                let word_idx = i / 64;
150                let bit_idx = i % 64;
151                data[word_idx] |= 1 << bit_idx;
152            }
153        }
154
155        Self {
156            data,
157            len: bools.len(),
158        }
159    }
160
161    /// Creates a bit vector with all bits set to the same value.
162    #[must_use]
163    pub fn filled(len: usize, value: bool) -> Self {
164        let num_words = (len + 63) / 64;
165        let fill = if value { u64::MAX } else { 0 };
166        let data = vec![fill; num_words];
167
168        Self { data, len }
169    }
170
171    /// Creates a bit vector with all bits set to false (0).
172    #[must_use]
173    pub fn zeros(len: usize) -> Self {
174        Self::filled(len, false)
175    }
176
177    /// Creates a bit vector with all bits set to true (1).
178    #[must_use]
179    pub fn ones(len: usize) -> Self {
180        Self::filled(len, true)
181    }
182
183    /// Returns the number of bits.
184    #[must_use]
185    pub fn len(&self) -> usize {
186        self.len
187    }
188
189    /// Returns whether the bit vector is empty.
190    #[must_use]
191    pub fn is_empty(&self) -> bool {
192        self.len == 0
193    }
194
195    /// Gets the bit at the given index.
196    #[must_use]
197    pub fn get(&self, index: usize) -> Option<bool> {
198        if index >= self.len {
199            return None;
200        }
201
202        let word_idx = index / 64;
203        let bit_idx = index % 64;
204        Some((self.data[word_idx] & (1 << bit_idx)) != 0)
205    }
206
207    /// Sets the bit at the given index.
208    ///
209    /// # Panics
210    ///
211    /// Panics if index >= len.
212    pub fn set(&mut self, index: usize, value: bool) {
213        assert!(index < self.len, "Index out of bounds");
214
215        let word_idx = index / 64;
216        let bit_idx = index % 64;
217
218        if value {
219            self.data[word_idx] |= 1 << bit_idx;
220        } else {
221            self.data[word_idx] &= !(1 << bit_idx);
222        }
223    }
224
225    /// Appends a bit to the end.
226    pub fn push(&mut self, value: bool) {
227        let word_idx = self.len / 64;
228        let bit_idx = self.len % 64;
229
230        if word_idx >= self.data.len() {
231            self.data.push(0);
232        }
233
234        if value {
235            self.data[word_idx] |= 1 << bit_idx;
236        }
237
238        self.len += 1;
239    }
240
241    /// Returns the number of bits set to true.
242    #[must_use]
243    pub fn count_ones(&self) -> usize {
244        if self.is_empty() {
245            return 0;
246        }
247
248        let full_words = self.len / 64;
249        let remaining_bits = self.len % 64;
250
251        let mut count: usize = self.data[..full_words]
252            .iter()
253            .map(|&w| w.count_ones() as usize)
254            .sum();
255
256        if remaining_bits > 0 && full_words < self.data.len() {
257            let mask = (1u64 << remaining_bits) - 1;
258            count += (self.data[full_words] & mask).count_ones() as usize;
259        }
260
261        count
262    }
263
264    /// Returns the number of bits set to false.
265    #[must_use]
266    pub fn count_zeros(&self) -> usize {
267        self.len - self.count_ones()
268    }
269
270    /// Converts back to a `Vec<bool>`.
271    ///
272    /// # Panics
273    ///
274    /// Panics if an internal index is out of bounds (invariant violation).
275    #[must_use]
276    pub fn to_bools(&self) -> Vec<bool> {
277        (0..self.len)
278            .map(|i| self.get(i).expect("index within len"))
279            .collect()
280    }
281
282    /// Returns an iterator over the bits.
283    ///
284    /// # Panics
285    ///
286    /// Panics if an internal index is out of bounds (invariant violation).
287    pub fn iter(&self) -> impl Iterator<Item = bool> + '_ {
288        (0..self.len).map(move |i| self.get(i).expect("index within len"))
289    }
290
291    /// Returns an iterator over indices where bits are true.
292    ///
293    /// # Panics
294    ///
295    /// Panics if an internal index is out of bounds (invariant violation).
296    pub fn ones_iter(&self) -> impl Iterator<Item = usize> + '_ {
297        (0..self.len).filter(move |&i| self.get(i).expect("index within len"))
298    }
299
300    /// Returns an iterator over indices where bits are false.
301    ///
302    /// # Panics
303    ///
304    /// Panics if an internal index is out of bounds (invariant violation).
305    pub fn zeros_iter(&self) -> impl Iterator<Item = usize> + '_ {
306        (0..self.len).filter(move |&i| !self.get(i).expect("index within len"))
307    }
308
309    /// Returns the raw data.
310    #[must_use]
311    pub fn data(&self) -> &[u64] {
312        &self.data
313    }
314
315    /// Returns the compression ratio (original bytes / compressed bytes).
316    #[must_use]
317    pub fn compression_ratio(&self) -> f64 {
318        if self.is_empty() {
319            return 1.0;
320        }
321
322        // Original: 1 byte per bool
323        let original_size = self.len;
324        // Compressed: ceil(len / 8) bytes
325        let compressed_size = self.data.len() * 8;
326
327        if compressed_size == 0 {
328            return 1.0;
329        }
330
331        original_size as f64 / compressed_size as f64
332    }
333
334    /// Performs bitwise AND with another bit vector.
335    ///
336    /// The result has the length of the shorter vector.
337    #[must_use]
338    pub fn and(&self, other: &Self) -> Self {
339        let len = self.len.min(other.len);
340        let num_words = (len + 63) / 64;
341
342        let data: Vec<u64> = self
343            .data
344            .iter()
345            .zip(&other.data)
346            .take(num_words)
347            .map(|(&a, &b)| a & b)
348            .collect();
349
350        Self { data, len }
351    }
352
353    /// Performs bitwise OR with another bit vector.
354    ///
355    /// The result has the length of the shorter vector.
356    #[must_use]
357    pub fn or(&self, other: &Self) -> Self {
358        let len = self.len.min(other.len);
359        let num_words = (len + 63) / 64;
360
361        let data: Vec<u64> = self
362            .data
363            .iter()
364            .zip(&other.data)
365            .take(num_words)
366            .map(|(&a, &b)| a | b)
367            .collect();
368
369        Self { data, len }
370    }
371
372    /// Performs bitwise NOT.
373    #[must_use]
374    pub fn not(&self) -> Self {
375        let data: Vec<u64> = self.data.iter().map(|&w| !w).collect();
376        Self {
377            data,
378            len: self.len,
379        }
380    }
381
382    /// Performs bitwise XOR with another bit vector.
383    #[must_use]
384    pub fn xor(&self, other: &Self) -> Self {
385        let len = self.len.min(other.len);
386        let num_words = (len + 63) / 64;
387
388        let data: Vec<u64> = self
389            .data
390            .iter()
391            .zip(&other.data)
392            .take(num_words)
393            .map(|(&a, &b)| a ^ b)
394            .collect();
395
396        Self { data, len }
397    }
398
399    /// Serializes to bytes.
400    pub fn to_bytes(&self) -> Vec<u8> {
401        let mut buf = Vec::with_capacity(4 + self.data.len() * 8);
402        buf.extend_from_slice(&(self.len as u32).to_le_bytes());
403        for &word in &self.data {
404            buf.extend_from_slice(&word.to_le_bytes());
405        }
406        buf
407    }
408
409    /// Deserializes from bytes.
410    ///
411    /// # Errors
412    ///
413    /// Returns `Err` if the byte slice is too short or contains invalid data.
414    pub fn from_bytes(bytes: &[u8]) -> io::Result<Self> {
415        if bytes.len() < 4 {
416            return Err(io::Error::new(
417                io::ErrorKind::InvalidData,
418                "BitVector too short",
419            ));
420        }
421
422        let len = u32::from_le_bytes(
423            bytes[0..4]
424                .try_into()
425                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?,
426        ) as usize;
427        let num_words = (len + 63) / 64;
428
429        if bytes.len() < 4 + num_words * 8 {
430            return Err(io::Error::new(
431                io::ErrorKind::InvalidData,
432                "BitVector truncated",
433            ));
434        }
435
436        let mut data = Vec::with_capacity(num_words);
437        for i in 0..num_words {
438            let offset = 4 + i * 8;
439            let word = u64::from_le_bytes(
440                bytes[offset..offset + 8]
441                    .try_into()
442                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?,
443            );
444            data.push(word);
445        }
446
447        Ok(Self { data, len })
448    }
449}
450
451impl Default for BitVector {
452    fn default() -> Self {
453        Self::new()
454    }
455}
456
457impl FromIterator<bool> for BitVector {
458    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
459        let mut bitvec = BitVector::new();
460        for b in iter {
461            bitvec.push(b);
462        }
463        bitvec
464    }
465}
466
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bitvec_basic() {
        let bools = [true, false, true, true, false, false, true, false];
        let bitvec = BitVector::from_bools(&bools);

        assert_eq!(bitvec.len(), 8);
        for (i, &expected) in bools.iter().enumerate() {
            assert_eq!(bitvec.get(i), Some(expected));
        }
    }

    #[test]
    fn test_bitvec_empty() {
        let bitvec = BitVector::new();
        assert!(bitvec.is_empty());
        assert_eq!(bitvec.get(0), None);
    }

    #[test]
    fn test_bitvec_push() {
        let mut bitvec = BitVector::new();
        bitvec.push(true);
        bitvec.push(false);
        bitvec.push(true);

        assert_eq!(bitvec.len(), 3);
        assert_eq!(bitvec.get(0), Some(true));
        assert_eq!(bitvec.get(1), Some(false));
        assert_eq!(bitvec.get(2), Some(true));
    }

    #[test]
    fn test_bitvec_set() {
        let mut bitvec = BitVector::zeros(8);

        bitvec.set(0, true);
        bitvec.set(3, true);
        bitvec.set(7, true);

        assert_eq!(bitvec.get(0), Some(true));
        assert_eq!(bitvec.get(1), Some(false));
        assert_eq!(bitvec.get(3), Some(true));
        assert_eq!(bitvec.get(7), Some(true));
    }

    #[test]
    fn test_bitvec_count() {
        let bools = [true, false, true, true, false, false, true, false];
        let bitvec = BitVector::from_bools(&bools);

        assert_eq!(bitvec.count_ones(), 4);
        assert_eq!(bitvec.count_zeros(), 4);
    }

    #[test]
    fn test_bitvec_filled() {
        let zeros = BitVector::zeros(100);
        assert_eq!(zeros.count_ones(), 0);
        assert_eq!(zeros.count_zeros(), 100);

        let ones = BitVector::ones(100);
        assert_eq!(ones.count_ones(), 100);
        assert_eq!(ones.count_zeros(), 0);
    }

    #[test]
    fn test_bitvec_to_bools() {
        let original = vec![true, false, true, true, false];
        let bitvec = BitVector::from_bools(&original);
        let restored = bitvec.to_bools();
        assert_eq!(original, restored);
    }

    #[test]
    fn test_bitvec_large() {
        // Test with more than 64 bits
        let bools: Vec<bool> = (0..200).map(|i| i % 3 == 0).collect();
        let bitvec = BitVector::from_bools(&bools);

        assert_eq!(bitvec.len(), 200);
        for (i, &expected) in bools.iter().enumerate() {
            assert_eq!(bitvec.get(i), Some(expected), "Mismatch at index {i}");
        }
    }

    #[test]
    fn test_bitvec_and() {
        let a = BitVector::from_bools(&[true, true, false, false]);
        let b = BitVector::from_bools(&[true, false, true, false]);
        let result = a.and(&b);

        assert_eq!(result.to_bools(), vec![true, false, false, false]);
    }

    #[test]
    fn test_bitvec_or() {
        let a = BitVector::from_bools(&[true, true, false, false]);
        let b = BitVector::from_bools(&[true, false, true, false]);
        let result = a.or(&b);

        assert_eq!(result.to_bools(), vec![true, true, true, false]);
    }

    #[test]
    fn test_bitvec_not() {
        let a = BitVector::from_bools(&[true, false, true, false]);
        let result = a.not();

        // Only the first `len` bits are meaningful; check each one individually.
        assert_eq!(result.len(), 4);
        assert_eq!(result.get(0), Some(false));
        assert_eq!(result.get(1), Some(true));
        assert_eq!(result.get(2), Some(false));
        assert_eq!(result.get(3), Some(true));
    }

    #[test]
    fn test_bitvec_xor() {
        let a = BitVector::from_bools(&[true, true, false, false]);
        let b = BitVector::from_bools(&[true, false, true, false]);
        let result = a.xor(&b);

        assert_eq!(result.to_bools(), vec![false, true, true, false]);
    }

    #[test]
    fn test_bitvec_binary_ops_use_shorter_length() {
        let long = BitVector::from_bools(&[true, true, false, true, true]);
        let short = BitVector::from_bools(&[true, false]);

        assert_eq!(long.and(&short).to_bools(), vec![true, false]);
        assert_eq!(long.or(&short).to_bools(), vec![true, true]);
        assert_eq!(long.xor(&short).to_bools(), vec![false, true]);
        assert_eq!(short.and(&long).len(), 2);
    }

    #[test]
    fn test_bitvec_serialization() {
        let bools = [true, false, true, true, false, false, true, false];
        let bitvec = BitVector::from_bools(&bools);
        let bytes = bitvec.to_bytes();
        let restored = BitVector::from_bytes(&bytes).unwrap();
        assert_eq!(bitvec, restored);
    }

    #[test]
    fn test_bitvec_compression_ratio() {
        let bitvec = BitVector::zeros(64);
        let ratio = bitvec.compression_ratio();
        // 64 bools = 64 bytes original, 8 bytes compressed = 8x
        assert!((ratio - 8.0).abs() < 0.1);
    }

    #[test]
    fn test_bitvec_ones_iter() {
        let bools = [true, false, true, true, false];
        let bitvec = BitVector::from_bools(&bools);
        let ones: Vec<usize> = bitvec.ones_iter().collect();
        assert_eq!(ones, vec![0, 2, 3]);
    }

    #[test]
    fn test_bitvec_zeros_iter() {
        let bools = [true, false, true, true, false];
        let bitvec = BitVector::from_bools(&bools);
        let zeros: Vec<usize> = bitvec.zeros_iter().collect();
        assert_eq!(zeros, vec![1, 4]);
    }

    #[test]
    fn test_bitvec_from_iter() {
        let bitvec: BitVector = vec![true, false, true].into_iter().collect();
        assert_eq!(bitvec.len(), 3);
        assert_eq!(bitvec.get(0), Some(true));
        assert_eq!(bitvec.get(1), Some(false));
        assert_eq!(bitvec.get(2), Some(true));
    }

    #[test]
    fn test_bitvec_deserialize_roundtrip() {
        let bools = [true, false, true, true, false, false, true, false];
        let original = BitVector::from_bools(&bools);
        let json = serde_json::to_string(&original).unwrap();
        let restored: BitVector = serde_json::from_str(&json).unwrap();
        assert_eq!(original, restored);
    }

    #[test]
    fn test_bitvec_deserialize_invalid_len_too_large() {
        // len=200 requires ceil(200/64) = 4 words, but we only provide 1
        let json = r#"{"data":[42],"len":200}"#;
        let result: Result<BitVector, _> = serde_json::from_str(json);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("invariant violated"),
            "expected invariant error, got: {err_msg}"
        );
    }

    #[test]
    fn test_bitvec_deserialize_invalid_len_data_mismatch() {
        // len=10 requires ceil(10/64) = 1 word, but we provide 3
        let json = r#"{"data":[1,2,3],"len":10}"#;
        let result: Result<BitVector, _> = serde_json::from_str(json);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("invariant violated"),
            "expected invariant error, got: {err_msg}"
        );
    }

    #[test]
    fn test_bitvec_deserialize_valid_edge_cases() {
        // len=0 with empty data
        let json = r#"{"data":[],"len":0}"#;
        let bv: BitVector = serde_json::from_str(json).unwrap();
        assert_eq!(bv.len(), 0);
        assert!(bv.is_empty());

        // len=1 with one u64
        let json = r#"{"data":[1],"len":1}"#;
        let bv: BitVector = serde_json::from_str(json).unwrap();
        assert_eq!(bv.len(), 1);
        assert_eq!(bv.get(0), Some(true));

        // len=64 with one u64 (exactly fills one word)
        let json = r#"{"data":[18446744073709551615],"len":64}"#;
        let bv: BitVector = serde_json::from_str(json).unwrap();
        assert_eq!(bv.len(), 64);
        assert_eq!(bv.count_ones(), 64);
    }
}