Skip to main content

malware_modeler/
bitarray.rs

1// SPDX-License-Identifier: Apache-2.0
2
//! The malware modeler uses the presence of n-grams as features for the malware detection models
//! it creates. Ultimately, each feature is just the presence or absence of an n-gram. So instead
//! of wasting a whole integer on a one or a zero, use each bit as a true or false, since the
//! presence or absence of a feature n-gram is a binary situation.
7//!
8//! The only question is: "what size integer to use?"
9//! If many n-grams are expected, then a larger integer size makes sense. But for smaller feature
10//! sets a smaller integer size might be better. So this [`BitArray`] type allows for using any
11//! unsigned integer type as a backend, since the user knows their data best. Unsigned is best since
12//! we'll never need a negative representation, especially since we don't use the integer variable
13//! as an integer anyway.
14//!
//! Imagine this bit array as a vector of booleans, but compressed to use as little memory as
//! possible, since for our purposes we're likely to have a lot of n-gram features (possibly in
//! the millions).
18
19use std::fmt::Display;
20use std::ops::{
21    Add, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, Mul, Not, Rem, Shl,
22    Shr, Sub,
23};
24use std::str::FromStr;
25
26use serde::{Deserialize, Deserializer, Serialize, Serializer};
27
/// Trait for integer types which can be used as bit storage.
///
/// The generic parameters mirror those of the `std::ops` traits: `Rhs` is the right-hand operand
/// type and `Output` the result type of the arithmetic and shift operators. Both default to
/// `Self`.
pub trait BitStorage<Rhs = Self, Output = Self>:
    Copy
    + Default
    + PartialEq
    + Eq
    + Display
    + std::fmt::Debug
    + std::hash::Hash
    + Add<Rhs, Output = Output>
    + Sub<Rhs, Output = Output>
    + Mul<Rhs, Output = Output>
    + Div<Rhs, Output = Output>
    + Rem<Rhs, Output = Output>
    + std::fmt::Binary
    + Shl<usize, Output = Output>
    + Shr<usize, Output = Output>
    // A single `BitAnd` bound suffices: constraining `Output = Self` also lets generic code
    // compare a masked value directly against another `Self`.
    + BitAnd<Output = Self>
    + BitAndAssign
    + BitOr
    + BitOrAssign
    + BitXor
    + BitXorAssign
    + Not<Output = Self>
    + Sized
    + Send
    + Sync
{
    /// Number of bits in this type, i.e. how many booleans one value can hold.
    const BITS: usize;

    /// The value one of this type, used to build single-bit masks.
    const ONE: Self;
}
63
64macro_rules! impl_bit_storage {
65    ($($t:ty),*) => {
66        $(impl BitStorage for $t {
67            const BITS: usize = <$t>::BITS as usize;
68            const ONE: Self = 1 as Self;
69        })*
70    };
71}
72
73impl_bit_storage!(u8, u16, u32, u64, u128, usize);
74
75/// Array of booleans as a large bit vector
76#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
77pub struct BitArray<T: BitStorage> {
78    /// Vector of integers holding the bits as integers.
79    data: Vec<T>,
80
81    /// Bits remaining in the last integer.
82    remaining: usize,
83}
84
85impl<T: BitStorage> Display for BitArray<T> {
86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87        if self.data.is_empty() {
88            write!(f, "")?;
89            return Ok(());
90        }
91
92        for bit in self {
93            write!(f, "{}", u8::from(bit))?;
94        }
95        Ok(())
96    }
97}
98
99impl<T: BitStorage> Serialize for BitArray<T> {
100    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
101        let string = format!("{self}");
102        serializer.serialize_str(&string)
103    }
104}
105
106impl<'de, T: BitStorage> Deserialize<'de> for BitArray<T> {
107    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
108    where
109        D: Deserializer<'de>,
110    {
111        use serde::de::Error;
112
113        let string = String::deserialize(deserializer)?;
114        BitArray::<T>::from_str(&string).map_err(D::Error::custom)
115    }
116}
117
118impl<T: BitStorage> IntoIterator for BitArray<T> {
119    type Item = bool;
120    type IntoIter = BitArrayOwnedIterator<T>;
121
122    fn into_iter(self) -> Self::IntoIter {
123        Self::IntoIter {
124            array: self,
125            pos: 0,
126        }
127    }
128}
129
130impl<'a, T: BitStorage + 'a> IntoIterator for &'a BitArray<T> {
131    type Item = bool;
132    type IntoIter = BitArrayIterator<'a, T>;
133
134    fn into_iter(self) -> Self::IntoIter {
135        Self::IntoIter {
136            array: self,
137            pos: 0,
138        }
139    }
140}
141
142impl<T: BitStorage> BitArray<T> {
143    /// Create a new bit array where the size is the expected number of bits.
144    #[must_use]
145    pub fn new(size: usize) -> Self {
146        let elements = size.div_ceil(T::BITS);
147        let remaining = if size < T::BITS {
148            T::BITS - size
149        } else {
150            size % T::BITS
151        };
152
153        Self {
154            data: vec![T::default(); elements],
155            remaining,
156        }
157    }
158
159    /// Empty vector with some bytes allocated, where the size of the memory allocated
160    /// depends on the underlying integer size used.
161    #[must_use]
162    pub fn with_capacity(capacity: usize) -> Self {
163        Self {
164            data: Vec::with_capacity(capacity),
165            remaining: 0,
166        }
167    }
168
169    // TODO: add `from_bytes(bytes: &[u8])`
170
171    /// Number of bits in use
172    #[inline]
173    #[must_use]
174    pub const fn len(&self) -> usize {
175        self.data.len() * T::BITS - self.remaining
176    }
177
178    /// Returns true if the array contains no elements.
179    #[inline]
180    #[must_use]
181    pub const fn is_empty(&self) -> bool {
182        self.data.is_empty()
183    }
184
185    /// Indicates if none of the bits are set
186    #[inline]
187    #[must_use]
188    pub fn all_zeroes(&self) -> bool {
189        self.data.iter().all(|x| *x == T::default())
190    }
191
192    /// Number of bits available
193    #[inline]
194    #[must_use]
195    pub const fn capacity(&self) -> usize {
196        self.data.len() * T::BITS
197    }
198
199    /// Iterate over the bits in the array
200    #[must_use]
201    pub fn iter(&self) -> BitArrayIterator<'_, T> {
202        BitArrayIterator {
203            array: self,
204            pos: 0,
205        }
206    }
207
208    /// Get the bit value from a given index
209    ///
210    /// # Panic
211    /// This will panic if `index` points to a value beyond the number of bits stored in this array.
212    #[must_use]
213    pub fn get(&self, index: usize) -> bool
214    where
215        <T as BitAnd>::Output: PartialEq<T>,
216    {
217        let (block, bit) = Self::bit_pos(index);
218        ((self.data[block] >> bit) & T::ONE) == T::ONE
219    }
220
221    /// Set the bit value at a given index
222    ///
223    /// # Panic
224    /// This will panic if `index` points to a value beyond the number of bits stored in this array.
225    pub fn set(&mut self, index: usize, value: bool) {
226        let (block, bit) = Self::bit_pos(index);
227        if value {
228            self.data[block] |= T::ONE << bit;
229        } else {
230            self.data[block] &= !(T::ONE << bit);
231        }
232    }
233
234    /// Unset the bit at a given index
235    ///
236    /// # Panic
237    /// This will panic if `index` points to a value beyond the number of bits stored in this array.
238    pub fn unset(&mut self, index: usize) {
239        self.set(index, false);
240    }
241
242    /// Push a new bit onto the end of the array.
243    pub fn push(&mut self, value: bool) {
244        // All integers are used, so add another
245        if self.remaining == 0 || self.data.is_empty() {
246            self.data.push(T::default());
247            self.remaining = T::BITS;
248        }
249
250        let index = self.len();
251        self.remaining -= 1;
252        self.set(index, value);
253    }
254
255    /// Clear all bits
256    #[inline]
257    pub fn clear(&mut self) {
258        self.data.iter_mut().for_each(|x| *x = T::default());
259    }
260
261    /// Remove all data and set the size to zero
262    #[inline]
263    pub fn reset(&mut self) {
264        self.data.clear();
265        self.remaining = 0;
266    }
267
268    #[inline]
269    const fn bit_pos(idx: usize) -> (usize, usize) {
270        let block = idx / T::BITS;
271        let bit = idx % T::BITS;
272        (block, bit)
273    }
274}
275
276impl<T: BitStorage> FromStr for BitArray<T> {
277    type Err = String;
278
279    fn from_str(s: &str) -> Result<Self, Self::Err> {
280        let mut array = BitArray::<T>::with_capacity(s.len());
281        for (index, bit) in s.chars().enumerate() {
282            match bit {
283                '0' => array.push(false),
284                '1' => array.push(true),
285                _ => return Err(format!("Invalid bit value {bit} at index {index}")),
286            }
287        }
288        Ok(array)
289    }
290}
291
292/// Iterator over the bits in a bit array holding a reference to the array
293#[derive(Clone)]
294pub struct BitArrayIterator<'a, T: BitStorage> {
295    array: &'a BitArray<T>,
296    pos: usize,
297}
298
299impl<T: BitStorage> Iterator for BitArrayIterator<'_, T> {
300    type Item = bool;
301
302    fn next(&mut self) -> Option<Self::Item> {
303        if self.pos >= self.array.len() {
304            None
305        } else {
306            let value = self.array.get(self.pos);
307            self.pos += 1;
308            Some(value)
309        }
310    }
311}
312
313impl<T: BitStorage> BitArrayIterator<'_, T> {
314    /// Reset the iterator to the beginning
315    pub fn reset(&mut self) {
316        self.pos = 0;
317    }
318}
319
320/// Iterator over the bits in a bit array owning the array
321pub struct BitArrayOwnedIterator<T: BitStorage> {
322    array: BitArray<T>,
323    pos: usize,
324}
325
326impl<T: BitStorage> Iterator for BitArrayOwnedIterator<T> {
327    type Item = bool;
328
329    fn next(&mut self) -> Option<Self::Item> {
330        if self.pos >= self.array.len() {
331            None
332        } else {
333            let value = self.array.get(self.pos);
334            self.pos += 1;
335            Some(value)
336        }
337    }
338}
339
340impl<T: BitStorage> BitArrayOwnedIterator<T> {
341    /// Reset the iterator to the beginning
342    pub fn reset(&mut self) {
343        self.pos = 0;
344    }
345}
346
347#[cfg(test)]
348mod tests {
349    use super::*;
350
351    #[test]
352    fn bit_positioning() {
353        assert_eq!(BitArray::<u8>::bit_pos(0), (0, 0));
354        assert_eq!(BitArray::<u8>::bit_pos(1), (0, 1));
355        assert_eq!(BitArray::<u8>::bit_pos(2), (0, 2));
356        assert_eq!(BitArray::<u8>::bit_pos(3), (0, 3));
357        assert_eq!(BitArray::<u8>::bit_pos(4), (0, 4));
358        assert_eq!(BitArray::<u8>::bit_pos(12), (1, 4));
359    }
360
361    #[test]
362    fn building() {
363        let mut array = BitArray::<u16>::default();
364        eprintln!("Array default(): {array:?}");
365        assert_eq!("", format!("{array}"));
366        array.push(true);
367        assert!(!array.is_empty());
368        eprintln!("Array pushed true: {array:?}");
369        assert_eq!("1", format!("{array}"));
370        array.push(false);
371        assert_eq!("10", format!("{array}"));
372        array.push(true);
373        assert_eq!("101", format!("{array}"));
374        array.clear();
375        assert_eq!("000", format!("{array}"));
376
377        array.reset();
378        assert_eq!("", format!("{array}"));
379    }
380
381    #[test]
382    fn empty() {
383        let default = BitArray::<u64>::default();
384        assert!(default.is_empty());
385        assert_eq!(default.len(), 0);
386        assert_eq!(default.capacity(), 0);
387        assert_eq!("", format!("{default}"));
388    }
389
390    #[test]
391    fn with_one_integer_u64() {
392        let mut array = BitArray::<u64>::new(10);
393        assert_eq!(array.data.len(), 1); // One integer
394        assert_eq!(array.remaining, 54); // 54 bits available
395        assert!(array.all_zeroes());
396        assert_eq!(format!("{array}"), "0000000000");
397        let original_len = array.len();
398        assert_eq!(original_len, 10);
399
400        array.set(5, true);
401        assert_eq!(array.len(), original_len);
402        assert!(array.get(5));
403        assert_eq!(format!("{array}"), "0000010000");
404
405        array.set(6, true);
406        assert_eq!(array.len(), original_len);
407        assert!(array.get(5));
408        assert_eq!(format!("{array}"), "0000011000");
409
410        array.set(5, false);
411        assert!(!array.get(5));
412
413        array.set(6, true);
414        assert_eq!(array.len(), original_len);
415        assert!(array.get(6));
416        assert_eq!(format!("{array}"), "0000001000");
417
418        array.push(true);
419        assert_eq!(format!("{array}"), "00000010001");
420
421        let array_string = format!("{array}");
422        let array_from_string = BitArray::<u64>::from_str(&array_string).unwrap();
423        assert_eq!(array, array_from_string);
424
425        array.clear();
426        assert!(array.all_zeroes());
427        assert_eq!(array.len(), original_len + 1);
428        assert_eq!(format!("{array}"), "00000000000");
429    }
430
431    #[test]
432    fn with_one_integer_u32() {
433        let mut array = BitArray::<u32>::new(10);
434        assert_eq!(array.data.len(), 1); // One integer
435        assert_eq!(array.remaining, 22); // 22 bits available
436        assert!(array.all_zeroes());
437        let original_len = array.len();
438
439        array.set(5, true);
440        assert_eq!(format!("{array}"), "0000010000");
441        assert_eq!(array.len(), original_len);
442        assert!(array.get(5));
443
444        let array_string = format!("{array}");
445        let array_from_string = BitArray::<u32>::from_str(&array_string).unwrap();
446        assert_eq!(array, array_from_string);
447
448        array.set(5, false);
449        assert!(!array.get(5));
450
451        array.clear();
452        assert!(array.all_zeroes());
453        assert_eq!(array.len(), original_len);
454    }
455
456    #[test]
457    fn with_one_integer_u16() {
458        let mut array = BitArray::<u16>::new(10);
459        assert_eq!(array.data.len(), 1); // One integer
460        assert_eq!(array.remaining, 6); // Six bits available
461        assert!(array.all_zeroes());
462        let original_len = array.len();
463
464        array.set(5, true);
465        assert_eq!(format!("{array}"), "0000010000");
466        assert_eq!(array.len(), original_len);
467        assert!(array.get(5));
468
469        let array_string = format!("{array}");
470        let array_from_string = BitArray::<u16>::from_str(&array_string).unwrap();
471        assert_eq!(array, array_from_string);
472
473        array.set(5, false);
474        assert!(!array.get(5));
475
476        array.clear();
477        assert!(array.all_zeroes());
478        assert_eq!(array.len(), original_len);
479    }
480
481    #[test]
482    fn with_one_integer_u8() {
483        let mut array = BitArray::<u8>::new(6);
484        assert_eq!(array.data.len(), 1); // One integer
485        assert_eq!(array.remaining, 2); // Two bits on the second byte available
486        assert!(array.all_zeroes());
487        let original_len = array.len();
488
489        array.set(5, true);
490        assert_eq!(format!("{array}"), "000001");
491        assert_eq!(array.len(), original_len);
492        assert!(array.get(5));
493
494        let array_string = format!("{array}");
495        let array_from_string = BitArray::<u8>::from_str(&array_string).unwrap();
496        assert_eq!(array, array_from_string);
497
498        array.set(5, false);
499        assert!(!array.get(5));
500
501        array.clear();
502        assert!(array.all_zeroes());
503        assert_eq!(array.len(), original_len);
504    }
505
506    #[test]
507    fn with_several_integers_u64() {
508        let mut array = BitArray::<u64>::new(2001);
509        assert_eq!(array.data.len(), 32);
510        assert_eq!(array.remaining, 17);
511        println!("{array}");
512        let original_len = array.len();
513
514        array.set(5, true);
515        assert_eq!(array.len(), original_len);
516        assert!(array.get(5));
517
518        array.set(6, true);
519        assert_eq!(array.len(), original_len);
520        assert!(array.get(5));
521
522        array.set(5, false);
523        assert!(!array.get(5));
524
525        array.set(6, true);
526        assert_eq!(array.len(), original_len);
527        assert!(array.get(6));
528
529        let array_string = format!("{array}");
530        let array_from_string = BitArray::<u64>::from_str(&array_string).unwrap();
531        assert_eq!(array, array_from_string);
532
533        array.clear();
534        assert_eq!(array.len(), original_len);
535    }
536}