// arrow_buffer/util/bit_mask.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Utils for working with packed bit masks
19
20use crate::bit_util::ceil;
21
22/// Util function to set bits in a slice of bytes.
23///
24/// This will sets all bits on `write_data` in the range `[offset_write..offset_write+len]`
25/// to be equal to the bits in `data` in the range `[offset_read..offset_read+len]`
26/// returns the number of `0` bits `data[offset_read..offset_read+len]`
27/// `offset_write`, `offset_read`, and `len` are in terms of bits
28pub fn set_bits(
29    write_data: &mut [u8],
30    data: &[u8],
31    offset_write: usize,
32    offset_read: usize,
33    len: usize,
34) -> usize {
35    assert!(
36        offset_write
37            .checked_add(len)
38            .expect("operation will overflow write buffer")
39            <= write_data.len() * 8
40    );
41    assert!(
42        offset_read
43            .checked_add(len)
44            .expect("operation will overflow read buffer")
45            <= data.len() * 8
46    );
47    let mut null_count = 0;
48    let mut acc = 0;
49    while len > acc {
50        // SAFETY: the arguments to `set_upto_64bits` are within the valid range because
51        // (offset_write + acc) + (len - acc) == offset_write + len <= write_data.len() * 8
52        // (offset_read + acc) + (len - acc) == offset_read + len <= data.len() * 8
53        let (n, len_set) = unsafe {
54            set_upto_64bits(
55                write_data,
56                data,
57                offset_write + acc,
58                offset_read + acc,
59                len - acc,
60            )
61        };
62        null_count += n;
63        acc += len_set;
64    }
65
66    null_count
67}
68
/// Similar to `set_bits`, but sets at most 64 bits; the actual number of bits
/// set may be smaller and is returned to the caller.
///
/// Returns a pair `(null_count, bits_set)`: the number of `0` bits among the
/// copied bits, and the number of bits actually transferred from `data` into
/// `write_data`.
///
/// NOTE(review): the misaligned-write paths only OR bits into `write_data`
/// (via `or_write_u64_bytes` and the byte-wise OR loop below), so they appear
/// to assume the destination bits are already zero — the fuzz test in this
/// file fails if the destination is pre-filled with random bytes. Confirm
/// before relying on this with a non-zeroed destination.
///
/// # Safety
/// The caller must ensure all arguments are within the valid range, i.e.
/// `offset_read + len <= data.len() * 8` and
/// `offset_write + len <= write_data.len() * 8`.
#[inline]
unsafe fn set_upto_64bits(
    write_data: &mut [u8],
    data: &[u8],
    offset_write: usize,
    offset_read: usize,
    len: usize,
) -> (usize, usize) {
    // Decompose each bit offset into a byte index and a bit-within-byte shift.
    let read_byte = offset_read / 8;
    let read_shift = offset_read % 8;
    let write_byte = offset_write / 8;
    let write_shift = offset_write % 8;

    if len >= 64 {
        // A full 64-bit chunk is available starting at `read_byte`.
        // SAFETY (per caller contract): offset_read + len <= data.len() * 8 and
        // len >= 64 imply read_byte + 8 <= data.len().
        let chunk = unsafe { (data.as_ptr().add(read_byte) as *const u64).read_unaligned() };
        if read_shift == 0 {
            if write_shift == 0 {
                // no shifting necessary
                let len = 64;
                let null_count = chunk.count_zeros() as usize;
                unsafe { write_u64_bytes(write_data, write_byte, chunk) };
                (null_count, len)
            } else {
                // only write shifting necessary; the low `write_shift` bits of
                // the destination byte are preserved by the OR-write
                let len = 64 - write_shift;
                let chunk = chunk << write_shift;
                let null_count = len - chunk.count_ones() as usize;
                unsafe { or_write_u64_bytes(write_data, write_byte, chunk) };
                (null_count, len)
            }
        } else if write_shift == 0 {
            // only read shifting necessary
            let len = 64 - 8; // 56 bits so the next set_upto_64bits call will see write_shift == 0
            let chunk = (chunk >> read_shift) & 0x00FFFFFFFFFFFFFF; // 56 bits mask
            let null_count = len - chunk.count_ones() as usize;
            unsafe { write_u64_bytes(write_data, write_byte, chunk) };
            (null_count, len)
        } else {
            // both offsets misaligned: shift out the read offset, then shift
            // into the write offset; only 64 - max(shift) bits survive
            let len = 64 - std::cmp::max(read_shift, write_shift);
            let chunk = (chunk >> read_shift) << write_shift;
            let null_count = len - chunk.count_ones() as usize;
            unsafe { or_write_u64_bytes(write_data, write_byte, chunk) };
            (null_count, len)
        }
    } else if len == 1 {
        // single-bit fast path: extract one bit and OR it into place
        let byte_chunk = (unsafe { data.get_unchecked(read_byte) } >> read_shift) & 1;
        unsafe { *write_data.get_unchecked_mut(write_byte) |= byte_chunk << write_shift };
        // the copied bit was 0 iff byte_chunk == 0
        ((byte_chunk ^ 1) as usize, 1)
    } else {
        // tail path: fewer than 64 bits remain; cap `len` so the shifted
        // chunk still fits in a single u64
        let len = std::cmp::min(len, 64 - std::cmp::max(read_shift, write_shift));
        let bytes = ceil(len + read_shift, 8);
        // SAFETY: the args of `read_bytes_to_u64` are valid as read_byte + bytes <= data.len()
        let chunk = unsafe { read_bytes_to_u64(data, read_byte, bytes) };
        let mask = u64::MAX >> (64 - len);
        let chunk = (chunk >> read_shift) & mask; // masking to read `len` bits only
        let chunk = chunk << write_shift; // shifting back to align with `write_data`
        let null_count = len - chunk.count_ones() as usize;
        let bytes = ceil(len + write_shift, 8);
        // OR the chunk in byte-by-byte so bytes past the affected range are
        // never touched
        for (i, c) in chunk.to_le_bytes().iter().enumerate().take(bytes) {
            unsafe { *write_data.get_unchecked_mut(write_byte + i) |= c };
        }
        (null_count, len)
    }
}
138
/// Reads `count` bytes from `data` starting at byte `offset` into the
/// low-address bytes of a `u64` (native byte order), zero-padding the rest.
///
/// # Safety
/// The caller must ensure `data` contains the range `offset..(offset + count)`,
/// and `count <= 8`.
#[inline]
unsafe fn read_bytes_to_u64(data: &[u8], offset: usize, count: usize) -> u64 {
    debug_assert!(count <= 8);
    let mut buf = [0u8; 8];
    // SAFETY: the caller guarantees `offset..offset + count` is in bounds of
    // `data`, and `count <= 8` keeps the write inside `buf`.
    unsafe {
        std::ptr::copy_nonoverlapping(data.as_ptr().add(offset), buf.as_mut_ptr(), count);
    }
    u64::from_ne_bytes(buf)
}
149
/// Stores the 8 bytes of `chunk` (native byte order) into `data` at byte
/// `offset`, overwriting whatever was there.
///
/// # Safety
/// The caller must ensure `data` has `offset..(offset + 8)` range
#[inline]
unsafe fn write_u64_bytes(data: &mut [u8], offset: usize, chunk: u64) {
    let src = chunk.to_ne_bytes();
    // SAFETY: the caller guarantees 8 writable bytes starting at `offset`;
    // copying the native-order bytes is equivalent to an unaligned u64 store.
    unsafe {
        std::ptr::copy_nonoverlapping(src.as_ptr(), data.as_mut_ptr().add(offset), 8);
    }
}
157
/// Similar to `write_u64_bytes`, but first ORs the single byte at `offset`
/// into the low-order byte of `chunk` before storing all 8 bytes.
///
/// NOTE(review): on little-endian targets this preserves bits already set in
/// the byte at `offset`; the following 7 bytes are overwritten — confirm if
/// big-endian support matters here.
///
/// # Safety
/// The caller must ensure `data` has `offset..(offset + 8)` range
#[inline]
unsafe fn or_write_u64_bytes(data: &mut [u8], offset: usize, chunk: u64) {
    // SAFETY: the caller guarantees `offset..offset + 8` is in bounds.
    unsafe {
        let dst = data.as_mut_ptr().add(offset);
        let existing = u64::from(*dst);
        (dst as *mut u64).write_unaligned(chunk | existing);
    }
}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bit_util::{get_bit, set_bit, unset_bit};
    use rand::prelude::StdRng;
    use rand::{Rng, SeedableRng, TryRngCore};
    use std::fmt::Display;

    // Both offsets byte-aligned: exercises the fastest (no-shift) path.
    #[test]
    fn test_set_bits_aligned() {
        SetBitsTest {
            write_data: vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            data: vec![
                0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111,
                0b10100101,
            ],
            offset_write: 8,
            offset_read: 0,
            len: 64,
            expected_data: vec![
                0, 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011,
                0b11100111, 0b10100101, 0,
            ],
            expected_null_count: 24,
        }
        .verify();
    }

    // Write offset not byte-aligned: exercises the write-shift path.
    #[test]
    fn test_set_bits_unaligned_destination_start() {
        SetBitsTest {
            write_data: vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            data: vec![
                0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111,
                0b10100101,
            ],
            offset_write: 3,
            offset_read: 0,
            len: 64,
            expected_data: vec![
                0b00111000, 0b00101111, 0b11001101, 0b11011100, 0b01011110, 0b00011111, 0b00111110,
                0b00101111, 0b00000101, 0b00000000,
            ],
            expected_null_count: 24,
        }
        .verify();
    }

    // Copy ends mid-byte: exercises the sub-64-bit tail path.
    #[test]
    fn test_set_bits_unaligned_destination_end() {
        SetBitsTest {
            write_data: vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            data: vec![
                0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111,
                0b10100101,
            ],
            offset_write: 8,
            offset_read: 0,
            len: 62,
            expected_data: vec![
                0, 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011,
                0b11100111, 0b00100101, 0,
            ],
            expected_null_count: 23,
        }
        .verify();
    }

    // Both offsets misaligned and len not a multiple of 64: exercises the
    // combined read-shift + write-shift path plus the tail.
    #[test]
    fn test_set_bits_unaligned() {
        SetBitsTest {
            write_data: vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            data: vec![
                0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111,
                0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111, 0b10100101,
                0b10011001, 0b11011011, 0b11101011, 0b11000011,
            ],
            offset_write: 3,
            offset_read: 5,
            len: 95,
            expected_data: vec![
                0b01111000, 0b01101001, 0b11100110, 0b11110110, 0b11111010, 0b11110000, 0b01111001,
                0b01101001, 0b11100110, 0b11110110, 0b11111010, 0b11110000, 0b00000001,
            ],
            expected_null_count: 35,
        }
        .verify();
    }

    // Randomized test: 100 rounds with a fixed seed so failures reproduce.
    #[test]
    fn set_bits_fuzz() {
        let mut rng = StdRng::seed_from_u64(42);
        let mut data = SetBitsTest::new();
        for _ in 0..100 {
            data.regen(&mut rng);
            data.verify();
        }
    }

    /// A single `set_bits` scenario: inputs plus the independently-computed
    /// expected output against which `set_bits` is checked.
    #[derive(Debug, Default)]
    struct SetBitsTest {
        /// target write data
        write_data: Vec<u8>,
        /// source data
        data: Vec<u8>,
        offset_write: usize,
        offset_read: usize,
        len: usize,
        /// the expected contents of write_data after the test
        expected_data: Vec<u8>,
        /// the expected number of nulls copied at the end of the test
        expected_null_count: usize,
    }

    /// prints a byte slice as a binary string like "01010101 10101010"
    struct BinaryFormatter<'a>(&'a [u8]);
    impl Display for BinaryFormatter<'_> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            for byte in self.0 {
                write!(f, "{byte:08b} ")?;
            }
            write!(f, " ")?;
            Ok(())
        }
    }

    impl Display for SetBitsTest {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            writeln!(f, "SetBitsTest {{")?;
            writeln!(f, "  write_data:    {}", BinaryFormatter(&self.write_data))?;
            writeln!(f, "  data:          {}", BinaryFormatter(&self.data))?;
            writeln!(
                f,
                "  expected_data: {}",
                BinaryFormatter(&self.expected_data)
            )?;
            writeln!(f, "  offset_write: {}", self.offset_write)?;
            writeln!(f, "  offset_read: {}", self.offset_read)?;
            writeln!(f, "  len: {}", self.len)?;
            writeln!(f, "  expected_null_count: {}", self.expected_null_count)?;
            writeln!(f, "}}")
        }
    }

    impl SetBitsTest {
        /// create a new instance of FuzzData
        fn new() -> Self {
            Self::default()
        }

        /// Update this instance's fields with randomly selected values and expected data
        fn regen(&mut self, rng: &mut StdRng) {
            //  (read) data
            // ------------------+-----------------+-------
            // .. offset_read .. | data            | ...
            // ------------------+-----------------+-------

            // Write data
            // -------------------+-----------------+-------
            // .. offset_write .. | (data to write) | ...
            // -------------------+-----------------+-------

            // length of data to copy
            let len = rng.random_range(0..=200);

            // randomly pick where we will write to
            let offset_write_bits = rng.random_range(0..=200);
            // round the bit offset up to whole bytes so the buffer is big enough
            let offset_write_bytes = if offset_write_bits % 8 == 0 {
                offset_write_bits / 8
            } else {
                (offset_write_bits / 8) + 1
            };
            let extra_write_data_bytes = rng.random_range(0..=5); // ensure 0 shows up often

            // randomly decide where we will read from
            let extra_read_data_bytes = rng.random_range(0..=5); // make sure 0 shows up often
            let offset_read_bits = rng.random_range(0..=200);
            let offset_read_bytes = if offset_read_bits % 8 != 0 {
                (offset_read_bits / 8) + 1
            } else {
                offset_read_bits / 8
            };

            // create space for writing
            self.write_data.clear();
            self.write_data
                .resize(offset_write_bytes + len + extra_write_data_bytes, 0);

            // interestingly set_bits seems to assume the output is already zeroed
            // the fuzz tests fail when this is uncommented
            //self.write_data.try_fill(rng).unwrap();
            self.offset_write = offset_write_bits;

            // make source data
            self.data
                .resize(offset_read_bytes + len + extra_read_data_bytes, 0);
            // fill source data with random bytes
            rng.try_fill_bytes(self.data.as_mut_slice()).unwrap();
            self.offset_read = offset_read_bits;

            self.len = len;

            // generate expected output bit-by-bit (not efficient, but an
            // independent reference implementation)
            self.expected_data.resize(self.write_data.len(), 0);
            self.expected_data.copy_from_slice(&self.write_data);

            self.expected_null_count = 0;
            for i in 0..self.len {
                let bit = get_bit(&self.data, self.offset_read + i);
                if bit {
                    set_bit(&mut self.expected_data, self.offset_write + i);
                } else {
                    unset_bit(&mut self.expected_data, self.offset_write + i);
                    self.expected_null_count += 1;
                }
            }
        }

        /// call set_bits with the given parameters and compare with the expected output
        fn verify(&self) {
            // call set_bits and compare
            let mut actual = self.write_data.to_vec();
            let null_count = set_bits(
                &mut actual,
                &self.data,
                self.offset_write,
                self.offset_read,
                self.len,
            );

            assert_eq!(actual, self.expected_data, "self: {self}");
            assert_eq!(null_count, self.expected_null_count, "self: {self}");
        }
    }

    // Direct unit test of the internal chunked helper: checks both the
    // len >= 64 write-shift path and the single-bit path.
    #[test]
    fn test_set_upto_64bits() {
        // len >= 64
        let write_data: &mut [u8] = &mut [0; 9];
        let data: &[u8] = &[
            0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001,
            0b00000001, 0b00000001,
        ];
        let offset_write = 1;
        let offset_read = 0;
        let len = 65;
        let (n, len_set) =
            unsafe { set_upto_64bits(write_data, data, offset_write, offset_read, len) };
        assert_eq!(n, 55);
        assert_eq!(len_set, 63);
        assert_eq!(
            write_data,
            &[
                0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010,
                0b00000010, 0b00000000
            ]
        );

        // len = 1
        let write_data: &mut [u8] = &mut [0b00000000];
        let data: &[u8] = &[0b00000001];
        let offset_write = 1;
        let offset_read = 0;
        let len = 1;
        let (n, len_set) =
            unsafe { set_upto_64bits(write_data, data, offset_write, offset_read, len) };
        assert_eq!(n, 0);
        assert_eq!(len_set, 1);
        assert_eq!(write_data, &[0b00000010]);
    }

    #[test]
    #[should_panic(expected = "operation will overflow read buffer")]
    fn test_overflow_read_buffer_bounds() {
        // Tiny buffers so any huge computed index is out-of-bounds.
        let data = [0u8; 1];
        let mut write_data = [0u8; 1];

        // Choose values so (offset_read + len) wraps to a small number in release builds.
        // offset_read = usize::MAX - 7, len = 8 => wraps to 0.
        // This can bypass `assert!(offset_read + len <= data.len() * 8)`.
        let offset_write: usize = 0;
        let offset_read: usize = usize::MAX - 7;
        let len: usize = 8;

        // should panic on bounds check overflow
        let _nulls = set_bits(&mut write_data, &data, offset_write, offset_read, len);
    }

    #[test]
    #[should_panic(expected = "operation will overflow write buffer")]
    fn test_overflow_write_buffer_bounds() {
        // Tiny buffers so any huge computed index is out-of-bounds.
        let data = [0u8; 1];
        let mut write_data = [0u8; 1];

        // Choose values so (offset_write + len) wraps to a small number in release builds.
        // offset_write = usize::MAX - 7, len = 8 => wraps to 0.
        // This can bypass `assert!(offset_write + len <= write_data.len() * 8)`.
        let offset_write: usize = usize::MAX - 7;
        let offset_read: usize = 0;
        let len: usize = 8;

        // should panic on bounds check overflow
        let _nulls = set_bits(&mut write_data, &data, offset_write, offset_read, len);
    }
}