Skip to main content

grafeo_core/codec/
bitvec.rs

//! Stores booleans as individual bits - 8x smaller than `Vec<bool>`.
//!
//! Use this when you're tracking lots of boolean flags (like "visited" markers
//! in graph traversals, or null bitmaps). Backed by `Vec<u64>`, so bitwise
//! operations like AND/OR/XOR work on 64 flags at a time and stay
//! cache-friendly.
//!
//! # Example
//!
//! ```no_run
//! # use grafeo_core::codec::bitvec::BitVector;
//! let bools = vec![true, false, true, true, false, false, true, false];
//! let bitvec = BitVector::from_bools(&bools);
//! // Bit i is stored at bit position i of the first u64 word: 0b01001101.
//! // (The 8x saving shows on longer vectors; 8 flags still occupy one 8-byte word.)
//!
//! assert_eq!(bitvec.get(0), Some(true));
//! assert_eq!(bitvec.get(1), Some(false));
//! assert_eq!(bitvec.count_ones(), 4);
//! ```
19
20use std::io;
21
22use serde::de::{self, Deserialize, Deserializer, MapAccess, SeqAccess, Visitor};
23
24/// Stores booleans as individual bits - 8x smaller than `Vec<bool>`.
25///
26/// Supports bitwise operations ([`and`](Self::and), [`or`](Self::or),
27/// [`not`](Self::not)) for combining filter results efficiently.
28#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize)]
29pub struct BitVector {
30    /// Packed bits (little-endian within each word).
31    data: Vec<u64>,
32    /// Number of bits stored.
33    len: usize,
34}
35
36impl<'de> Deserialize<'de> for BitVector {
37    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
38    where
39        D: Deserializer<'de>,
40    {
41        #[derive(serde::Deserialize)]
42        #[serde(field_identifier, rename_all = "lowercase")]
43        enum Field {
44            Data,
45            Len,
46        }
47
48        struct BitVectorVisitor;
49
50        impl<'de> Visitor<'de> for BitVectorVisitor {
51            type Value = BitVector;
52
53            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54                formatter.write_str("struct BitVector with consistent data and len fields")
55            }
56
57            fn visit_seq<V>(self, mut seq: V) -> Result<BitVector, V::Error>
58            where
59                V: SeqAccess<'de>,
60            {
61                let data: Vec<u64> = seq
62                    .next_element()?
63                    .ok_or_else(|| de::Error::invalid_length(0, &self))?;
64                let len: usize = seq
65                    .next_element()?
66                    .ok_or_else(|| de::Error::invalid_length(1, &self))?;
67                validate_bitvec(len, &data).map_err(de::Error::custom)
68            }
69
70            fn visit_map<V>(self, mut map: V) -> Result<BitVector, V::Error>
71            where
72                V: MapAccess<'de>,
73            {
74                let mut data: Option<Vec<u64>> = None;
75                let mut len: Option<usize> = None;
76
77                while let Some(key) = map.next_key()? {
78                    match key {
79                        Field::Data => {
80                            if data.is_some() {
81                                return Err(de::Error::duplicate_field("data"));
82                            }
83                            data = Some(map.next_value()?);
84                        }
85                        Field::Len => {
86                            if len.is_some() {
87                                return Err(de::Error::duplicate_field("len"));
88                            }
89                            len = Some(map.next_value()?);
90                        }
91                    }
92                }
93
94                let data = data.ok_or_else(|| de::Error::missing_field("data"))?;
95                let len = len.ok_or_else(|| de::Error::missing_field("len"))?;
96                validate_bitvec(len, &data).map_err(de::Error::custom)
97            }
98        }
99
100        const FIELDS: &[&str] = &["data", "len"];
101        deserializer.deserialize_struct("BitVector", FIELDS, BitVectorVisitor)
102    }
103}
104
105/// Validates that `len` and `data` are consistent, returning a valid
106/// `BitVector` or an error message.
107fn validate_bitvec(len: usize, data: &[u64]) -> Result<BitVector, String> {
108    let expected_words = (len + 63) / 64;
109    if data.len() != expected_words {
110        return Err(format!(
111            "BitVector invariant violated: len={len} requires {expected_words} words, but data contains {} words",
112            data.len()
113        ));
114    }
115    Ok(BitVector {
116        data: data.to_vec(),
117        len,
118    })
119}
120
121impl BitVector {
122    /// Reconstructs from pre-packed raw parts.
123    ///
124    /// Used by section deserialization. The caller ensures data consistency.
125    #[must_use]
126    pub fn from_raw_parts(data: Vec<u64>, len: usize) -> Self {
127        Self { data, len }
128    }
129
130    /// Creates an empty bit vector.
131    #[must_use]
132    pub fn new() -> Self {
133        Self {
134            data: Vec::new(),
135            len: 0,
136        }
137    }
138
139    /// Creates a bit vector with the specified capacity (in bits).
140    #[must_use]
141    pub fn with_capacity(bits: usize) -> Self {
142        let words = (bits + 63) / 64;
143        Self {
144            data: Vec::with_capacity(words),
145            len: 0,
146        }
147    }
148
149    /// Creates a bit vector from a slice of booleans.
150    #[must_use]
151    pub fn from_bools(bools: &[bool]) -> Self {
152        let num_words = (bools.len() + 63) / 64;
153        let mut data = vec![0u64; num_words];
154
155        for (i, &b) in bools.iter().enumerate() {
156            if b {
157                let word_idx = i / 64;
158                let bit_idx = i % 64;
159                data[word_idx] |= 1 << bit_idx;
160            }
161        }
162
163        Self {
164            data,
165            len: bools.len(),
166        }
167    }
168
169    /// Creates a bit vector with all bits set to the same value.
170    #[must_use]
171    pub fn filled(len: usize, value: bool) -> Self {
172        let num_words = (len + 63) / 64;
173        let fill = if value { u64::MAX } else { 0 };
174        let data = vec![fill; num_words];
175
176        Self { data, len }
177    }
178
179    /// Creates a bit vector with all bits set to false (0).
180    #[must_use]
181    pub fn zeros(len: usize) -> Self {
182        Self::filled(len, false)
183    }
184
185    /// Creates a bit vector with all bits set to true (1).
186    #[must_use]
187    pub fn ones(len: usize) -> Self {
188        Self::filled(len, true)
189    }
190
191    /// Returns the number of bits.
192    #[must_use]
193    pub fn len(&self) -> usize {
194        self.len
195    }
196
197    /// Returns whether the bit vector is empty.
198    #[must_use]
199    pub fn is_empty(&self) -> bool {
200        self.len == 0
201    }
202
203    /// Gets the bit at the given index.
204    #[must_use]
205    pub fn get(&self, index: usize) -> Option<bool> {
206        if index >= self.len {
207            return None;
208        }
209
210        let word_idx = index / 64;
211        let bit_idx = index % 64;
212        Some((self.data[word_idx] & (1 << bit_idx)) != 0)
213    }
214
215    /// Sets the bit at the given index.
216    ///
217    /// # Panics
218    ///
219    /// Panics if index >= len.
220    pub fn set(&mut self, index: usize, value: bool) {
221        assert!(index < self.len, "Index out of bounds");
222
223        let word_idx = index / 64;
224        let bit_idx = index % 64;
225
226        if value {
227            self.data[word_idx] |= 1 << bit_idx;
228        } else {
229            self.data[word_idx] &= !(1 << bit_idx);
230        }
231    }
232
233    /// Appends a bit to the end.
234    pub fn push(&mut self, value: bool) {
235        let word_idx = self.len / 64;
236        let bit_idx = self.len % 64;
237
238        if word_idx >= self.data.len() {
239            self.data.push(0);
240        }
241
242        if value {
243            self.data[word_idx] |= 1 << bit_idx;
244        }
245
246        self.len += 1;
247    }
248
249    /// Returns the number of bits set to true.
250    #[must_use]
251    pub fn count_ones(&self) -> usize {
252        if self.is_empty() {
253            return 0;
254        }
255
256        let full_words = self.len / 64;
257        let remaining_bits = self.len % 64;
258
259        let mut count: usize = self.data[..full_words]
260            .iter()
261            .map(|&w| w.count_ones() as usize)
262            .sum();
263
264        if remaining_bits > 0 && full_words < self.data.len() {
265            let mask = (1u64 << remaining_bits) - 1;
266            count += (self.data[full_words] & mask).count_ones() as usize;
267        }
268
269        count
270    }
271
272    /// Returns the number of bits set to false.
273    #[must_use]
274    pub fn count_zeros(&self) -> usize {
275        self.len - self.count_ones()
276    }
277
278    /// Converts back to a `Vec<bool>`.
279    ///
280    /// # Panics
281    ///
282    /// Panics if an internal index is out of bounds (invariant violation).
283    #[must_use]
284    pub fn to_bools(&self) -> Vec<bool> {
285        (0..self.len)
286            .map(|i| self.get(i).expect("index within len"))
287            .collect()
288    }
289
290    /// Returns an iterator over the bits.
291    ///
292    /// # Panics
293    ///
294    /// Panics if an internal index is out of bounds (invariant violation).
295    pub fn iter(&self) -> impl Iterator<Item = bool> + '_ {
296        (0..self.len).map(move |i| self.get(i).expect("index within len"))
297    }
298
299    /// Returns an iterator over indices where bits are true.
300    ///
301    /// # Panics
302    ///
303    /// Panics if an internal index is out of bounds (invariant violation).
304    pub fn ones_iter(&self) -> impl Iterator<Item = usize> + '_ {
305        (0..self.len).filter(move |&i| self.get(i).expect("index within len"))
306    }
307
308    /// Returns an iterator over indices where bits are false.
309    ///
310    /// # Panics
311    ///
312    /// Panics if an internal index is out of bounds (invariant violation).
313    pub fn zeros_iter(&self) -> impl Iterator<Item = usize> + '_ {
314        (0..self.len).filter(move |&i| !self.get(i).expect("index within len"))
315    }
316
317    /// Returns the raw data.
318    #[must_use]
319    pub fn data(&self) -> &[u64] {
320        &self.data
321    }
322
323    /// Returns the compression ratio (original bytes / compressed bytes).
324    #[must_use]
325    pub fn compression_ratio(&self) -> f64 {
326        if self.is_empty() {
327            return 1.0;
328        }
329
330        // Original: 1 byte per bool
331        let original_size = self.len;
332        // Compressed: ceil(len / 8) bytes
333        let compressed_size = self.data.len() * 8;
334
335        if compressed_size == 0 {
336            return 1.0;
337        }
338
339        original_size as f64 / compressed_size as f64
340    }
341
342    /// Performs bitwise AND with another bit vector.
343    ///
344    /// The result has the length of the shorter vector.
345    #[must_use]
346    pub fn and(&self, other: &Self) -> Self {
347        let len = self.len.min(other.len);
348        let num_words = (len + 63) / 64;
349
350        let data: Vec<u64> = self
351            .data
352            .iter()
353            .zip(&other.data)
354            .take(num_words)
355            .map(|(&a, &b)| a & b)
356            .collect();
357
358        Self { data, len }
359    }
360
361    /// Performs bitwise OR with another bit vector.
362    ///
363    /// The result has the length of the shorter vector.
364    #[must_use]
365    pub fn or(&self, other: &Self) -> Self {
366        let len = self.len.min(other.len);
367        let num_words = (len + 63) / 64;
368
369        let data: Vec<u64> = self
370            .data
371            .iter()
372            .zip(&other.data)
373            .take(num_words)
374            .map(|(&a, &b)| a | b)
375            .collect();
376
377        Self { data, len }
378    }
379
380    /// Performs bitwise NOT.
381    #[must_use]
382    pub fn not(&self) -> Self {
383        let data: Vec<u64> = self.data.iter().map(|&w| !w).collect();
384        Self {
385            data,
386            len: self.len,
387        }
388    }
389
390    /// Performs bitwise XOR with another bit vector.
391    #[must_use]
392    pub fn xor(&self, other: &Self) -> Self {
393        let len = self.len.min(other.len);
394        let num_words = (len + 63) / 64;
395
396        let data: Vec<u64> = self
397            .data
398            .iter()
399            .zip(&other.data)
400            .take(num_words)
401            .map(|(&a, &b)| a ^ b)
402            .collect();
403
404        Self { data, len }
405    }
406
407    /// Serializes to bytes.
408    ///
409    /// # Errors
410    ///
411    /// Returns `Err` if the bit-vector length exceeds `u32::MAX`.
412    pub fn to_bytes(&self) -> io::Result<Vec<u8>> {
413        let len_u32 = u32::try_from(self.len).map_err(|_| {
414            io::Error::new(
415                io::ErrorKind::InvalidInput,
416                format!(
417                    "BitVector length {} exceeds u32::MAX, cannot serialize",
418                    self.len
419                ),
420            )
421        })?;
422        let mut buf = Vec::with_capacity(4 + self.data.len() * 8);
423        buf.extend_from_slice(&len_u32.to_le_bytes());
424        for &word in &self.data {
425            buf.extend_from_slice(&word.to_le_bytes());
426        }
427        Ok(buf)
428    }
429
430    /// Deserializes from bytes.
431    ///
432    /// # Errors
433    ///
434    /// Returns `Err` if the byte slice is too short or contains invalid data.
435    pub fn from_bytes(bytes: &[u8]) -> io::Result<Self> {
436        if bytes.len() < 4 {
437            return Err(io::Error::new(
438                io::ErrorKind::InvalidData,
439                "BitVector too short",
440            ));
441        }
442
443        let len = u32::from_le_bytes(
444            bytes[0..4]
445                .try_into()
446                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?,
447        ) as usize;
448        let num_words = (len + 63) / 64;
449
450        if bytes.len() < 4 + num_words * 8 {
451            return Err(io::Error::new(
452                io::ErrorKind::InvalidData,
453                "BitVector truncated",
454            ));
455        }
456
457        let mut data = Vec::with_capacity(num_words);
458        for i in 0..num_words {
459            let offset = 4 + i * 8;
460            let word = u64::from_le_bytes(
461                bytes[offset..offset + 8]
462                    .try_into()
463                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?,
464            );
465            data.push(word);
466        }
467
468        Ok(Self { data, len })
469    }
470}
471
472impl Default for BitVector {
473    fn default() -> Self {
474        Self::new()
475    }
476}
477
478impl FromIterator<bool> for BitVector {
479    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
480        let mut bitvec = BitVector::new();
481        for b in iter {
482            bitvec.push(b);
483        }
484        bitvec
485    }
486}
487
#[cfg(test)]
mod tests {
    use super::*;

    /// Eight flags that pack into the low byte `0b0100_1101`.
    const SAMPLE: [bool; 8] = [true, false, true, true, false, false, true, false];

    /// Operands shared by the AND/OR/XOR tests (index 0 first): `1100` and `1010`.
    fn operands() -> (BitVector, BitVector) {
        (
            BitVector::from_bools(&[true, true, false, false]),
            BitVector::from_bools(&[true, false, true, false]),
        )
    }

    /// Parses `json` as a `BitVector`, failing the test if parsing succeeds,
    /// and returns the rendered deserialization error.
    fn deserialize_error(json: &str) -> String {
        match serde_json::from_str::<BitVector>(json) {
            Ok(bv) => panic!("expected {json} to be rejected, got {bv:?}"),
            Err(err) => err.to_string(),
        }
    }

    #[test]
    fn test_bitvec_basic() {
        let bitvec = BitVector::from_bools(&SAMPLE);

        assert_eq!(bitvec.len(), 8);
        for (idx, want) in SAMPLE.iter().copied().enumerate() {
            assert_eq!(bitvec.get(idx), Some(want));
        }
    }

    #[test]
    fn test_bitvec_empty() {
        let empty = BitVector::new();
        assert!(empty.is_empty());
        assert_eq!(empty.get(0), None);
    }

    #[test]
    fn test_bitvec_push() {
        let mut bitvec = BitVector::new();
        for value in [true, false, true] {
            bitvec.push(value);
        }

        assert_eq!(bitvec.len(), 3);
        for (idx, want) in [(0, true), (1, false), (2, true)] {
            assert_eq!(bitvec.get(idx), Some(want));
        }
    }

    #[test]
    fn test_bitvec_set() {
        let mut bitvec = BitVector::zeros(8);
        for idx in [0, 3, 7] {
            bitvec.set(idx, true);
        }

        for (idx, want) in [(0, true), (1, false), (3, true), (7, true)] {
            assert_eq!(bitvec.get(idx), Some(want));
        }
    }

    #[test]
    fn test_bitvec_count() {
        let bitvec = BitVector::from_bools(&SAMPLE);
        assert_eq!(bitvec.count_ones(), 4);
        assert_eq!(bitvec.count_zeros(), 4);
    }

    #[test]
    fn test_bitvec_filled() {
        let all_clear = BitVector::zeros(100);
        assert_eq!((all_clear.count_ones(), all_clear.count_zeros()), (0, 100));

        let all_set = BitVector::ones(100);
        assert_eq!((all_set.count_ones(), all_set.count_zeros()), (100, 0));
    }

    #[test]
    fn test_bitvec_to_bools() {
        let original = [true, false, true, true, false];
        let restored = BitVector::from_bools(&original).to_bools();
        assert_eq!(original.to_vec(), restored);
    }

    #[test]
    fn test_bitvec_large() {
        // 200 bits span four words, exercising indexing across word boundaries.
        let expected: Vec<bool> = (0..200).map(|i| i % 3 == 0).collect();
        let bitvec = BitVector::from_bools(&expected);

        assert_eq!(bitvec.len(), 200);
        for (i, want) in expected.iter().copied().enumerate() {
            assert_eq!(bitvec.get(i), Some(want), "Mismatch at index {}", i);
        }
    }

    #[test]
    fn test_bitvec_and() {
        let (lhs, rhs) = operands();
        assert_eq!(lhs.and(&rhs).to_bools(), vec![true, false, false, false]);
    }

    #[test]
    fn test_bitvec_or() {
        let (lhs, rhs) = operands();
        assert_eq!(lhs.or(&rhs).to_bools(), vec![true, true, true, false]);
    }

    #[test]
    fn test_bitvec_not() {
        let inverted = BitVector::from_bools(&[true, false, true, false]).not();

        // `not` works on whole words; only the positions inside `len` are checked.
        for (idx, want) in [(0, false), (1, true), (2, false), (3, true)] {
            assert_eq!(inverted.get(idx), Some(want));
        }
    }

    #[test]
    fn test_bitvec_xor() {
        let (lhs, rhs) = operands();
        assert_eq!(lhs.xor(&rhs).to_bools(), vec![false, true, true, false]);
    }

    #[test]
    fn test_bitvec_serialization() {
        let bitvec = BitVector::from_bools(&SAMPLE);
        let restored = bitvec
            .to_bytes()
            .and_then(|bytes| BitVector::from_bytes(&bytes))
            .unwrap();
        assert_eq!(bitvec, restored);
    }

    #[test]
    fn test_bitvec_compression_ratio() {
        // 64 flags: 64 bytes as Vec<bool> vs one 8-byte word packed, so 8x.
        let ratio = BitVector::zeros(64).compression_ratio();
        assert!((ratio - 8.0).abs() < 0.1);
    }

    #[test]
    fn test_bitvec_ones_iter() {
        let bitvec = BitVector::from_bools(&[true, false, true, true, false]);
        assert_eq!(bitvec.ones_iter().collect::<Vec<_>>(), vec![0, 2, 3]);
    }

    #[test]
    fn test_bitvec_zeros_iter() {
        let bitvec = BitVector::from_bools(&[true, false, true, true, false]);
        assert_eq!(bitvec.zeros_iter().collect::<Vec<_>>(), vec![1, 4]);
    }

    #[test]
    fn test_bitvec_from_iter() {
        let bitvec: BitVector = [true, false, true].iter().copied().collect();

        assert_eq!(bitvec.len(), 3);
        for (idx, want) in [(0, true), (1, false), (2, true)] {
            assert_eq!(bitvec.get(idx), Some(want));
        }
    }

    #[test]
    fn test_bitvec_deserialize_roundtrip() {
        let original = BitVector::from_bools(&SAMPLE);
        let json = serde_json::to_string(&original).unwrap();
        let restored = serde_json::from_str::<BitVector>(&json).unwrap();
        assert_eq!(original, restored);
    }

    #[test]
    fn test_bitvec_deserialize_invalid_len_too_large() {
        // len=200 needs ceil(200/64) = 4 words, but only 1 is supplied.
        let err_msg = deserialize_error(r#"{"data":[42],"len":200}"#);
        assert!(
            err_msg.contains("invariant violated"),
            "expected invariant error, got: {err_msg}"
        );
    }

    #[test]
    fn test_bitvec_deserialize_invalid_len_data_mismatch() {
        // len=10 needs ceil(10/64) = 1 word, but 3 are supplied.
        let err_msg = deserialize_error(r#"{"data":[1,2,3],"len":10}"#);
        assert!(
            err_msg.contains("invariant violated"),
            "expected invariant error, got: {err_msg}"
        );
    }

    #[test]
    fn test_bitvec_deserialize_valid_edge_cases() {
        // Zero bits, zero words.
        let empty: BitVector = serde_json::from_str(r#"{"data":[],"len":0}"#).unwrap();
        assert_eq!(empty.len(), 0);
        assert!(empty.is_empty());

        // A single bit in a single word.
        let single: BitVector = serde_json::from_str(r#"{"data":[1],"len":1}"#).unwrap();
        assert_eq!(single.len(), 1);
        assert_eq!(single.get(0), Some(true));

        // Exactly one completely full word.
        let full: BitVector =
            serde_json::from_str(r#"{"data":[18446744073709551615],"len":64}"#).unwrap();
        assert_eq!(full.len(), 64);
        assert_eq!(full.count_ones(), 64);
    }
}