// kyu_storage/null_mask.rs

//! Packed `u64` bitset for null tracking.
//!
//! Matches Kuzu's `NullMask`: bit=1 means NULL, bit=0 means non-null.
//! Each `u64` entry tracks 64 values.

/// log2 of the number of values tracked by one backing `u64` entry.
const BITS_PER_ENTRY_LOG2: u32 = 6;
/// Number of values (bits) tracked by one backing `u64` entry.
const BITS_PER_ENTRY: u64 = 1u64 << BITS_PER_ENTRY_LOG2;
/// An entry in which every tracked value is non-null.
const NO_NULL_ENTRY: u64 = u64::MIN;
/// An entry in which every tracked value is null.
const ALL_NULL_ENTRY: u64 = u64::MAX;

/// Packed null bitmask: a set bit marks a NULL value, a clear bit a non-null one.
#[derive(Debug, Clone)]
pub struct NullMask {
    /// Backing words; bit `pos % 64` of `data[pos / 64]` describes value `pos`.
    data: Vec<u64>,
    /// Conservative hint: when `false`, every word in `data` is guaranteed zero.
    may_contain_nulls: bool,
}

18impl NullMask {
19    /// Create a new null mask with all values set to non-null.
20    pub fn new(capacity: u64) -> Self {
21        let num_entries = num_entries_for(capacity);
22        Self {
23            data: vec![NO_NULL_ENTRY; num_entries],
24            may_contain_nulls: false,
25        }
26    }
27
28    /// Create a null mask with all values set to null.
29    pub fn all_null(capacity: u64) -> Self {
30        let num_entries = num_entries_for(capacity);
31        Self {
32            data: vec![ALL_NULL_ENTRY; num_entries],
33            may_contain_nulls: true,
34        }
35    }
36
37    /// Fast-path: returns `true` if no nulls are guaranteed absent.
38    #[inline]
39    pub fn has_no_nulls_guarantee(&self) -> bool {
40        !self.may_contain_nulls
41    }
42
43    /// Check whether a specific position is null.
44    #[inline]
45    pub fn is_null(&self, pos: u64) -> bool {
46        let (entry_idx, bit_idx) = entry_and_bit(pos);
47        (self.data[entry_idx] >> bit_idx) & 1 != 0
48    }
49
50    /// Set a specific position to null or non-null.
51    #[inline]
52    pub fn set_null(&mut self, pos: u64, is_null: bool) {
53        let (entry_idx, bit_idx) = entry_and_bit(pos);
54        if is_null {
55            self.data[entry_idx] |= 1u64 << bit_idx;
56            self.may_contain_nulls = true;
57        } else {
58            self.data[entry_idx] &= !(1u64 << bit_idx);
59        }
60    }
61
62    /// Set all positions to non-null.
63    pub fn set_all_non_null(&mut self) {
64        if !self.may_contain_nulls {
65            return;
66        }
67        self.data.fill(NO_NULL_ENTRY);
68        self.may_contain_nulls = false;
69    }
70
71    /// Set all positions to null.
72    pub fn set_all_null(&mut self) {
73        self.data.fill(ALL_NULL_ENTRY);
74        self.may_contain_nulls = true;
75    }
76
77    /// Set a range of bits to null or non-null.
78    pub fn set_null_range(&mut self, offset: u64, count: u64, is_null: bool) {
79        if count == 0 {
80            return;
81        }
82        if is_null {
83            self.may_contain_nulls = true;
84        }
85        let fill = if is_null {
86            ALL_NULL_ENTRY
87        } else {
88            NO_NULL_ENTRY
89        };
90        let end = offset + count;
91        let (start_entry, start_bit) = entry_and_bit(offset);
92        let (end_entry, end_bit) = entry_and_bit(end);
93
94        if start_entry == end_entry {
95            let mask = lower_mask(count) << start_bit;
96            if is_null {
97                self.data[start_entry] |= mask;
98            } else {
99                self.data[start_entry] &= !mask;
100            }
101            return;
102        }
103
104        // First partial entry
105        if start_bit != 0 {
106            let mask = ALL_NULL_ENTRY << start_bit;
107            if is_null {
108                self.data[start_entry] |= mask;
109            } else {
110                self.data[start_entry] &= !mask;
111            }
112        } else {
113            self.data[start_entry] = fill;
114        }
115
116        // Full entries in the middle
117        for entry in &mut self.data[start_entry + if start_bit != 0 { 1 } else { 0 }..end_entry] {
118            *entry = fill;
119        }
120
121        // Last partial entry
122        if end_bit != 0 {
123            let mask = lower_mask(end_bit as u64);
124            if is_null {
125                self.data[end_entry] |= mask;
126            } else {
127                self.data[end_entry] &= !mask;
128            }
129        }
130    }
131
132    /// Count the number of null bits set in the mask.
133    pub fn count_nulls(&self) -> u64 {
134        if !self.may_contain_nulls {
135            return 0;
136        }
137        self.data.iter().map(|e| e.count_ones() as u64).sum()
138    }
139
140    /// Copy null bits from another mask.
141    /// Returns `true` if any null bit was copied.
142    pub fn copy_from(
143        &mut self,
144        src: &NullMask,
145        src_offset: u64,
146        dst_offset: u64,
147        count: u64,
148    ) -> bool {
149        if count == 0 {
150            return false;
151        }
152        if src.has_no_nulls_guarantee() {
153            self.set_null_range(dst_offset, count, false);
154            return false;
155        }
156
157        let mut any_null = false;
158        for i in 0..count {
159            let is_null = src.is_null(src_offset + i);
160            self.set_null(dst_offset + i, is_null);
161            any_null |= is_null;
162        }
163        any_null
164    }
165
166    /// Resize the mask to a new capacity. New bits are set to non-null.
167    pub fn resize(&mut self, new_capacity: u64) {
168        let new_num_entries = num_entries_for(new_capacity);
169        self.data.resize(new_num_entries, NO_NULL_ENTRY);
170    }
171
172    /// Number of u64 entries in the backing buffer.
173    pub fn num_entries(&self) -> usize {
174        self.data.len()
175    }
176
177    /// Total number of bits the mask can track.
178    pub fn capacity(&self) -> u64 {
179        self.data.len() as u64 * BITS_PER_ENTRY
180    }
181
182    /// Access the raw data slice (read-only).
183    #[inline]
184    pub fn data(&self) -> &[u64] {
185        &self.data
186    }
187
188    /// Construct from a raw u64 vec (used by JIT output).
189    pub fn from_raw(data: Vec<u64>, capacity: u64) -> Self {
190        let may_contain_nulls = data.iter().any(|&w| w != 0);
191        let mut mask = Self {
192            data,
193            may_contain_nulls,
194        };
195        let needed = num_entries_for(capacity);
196        mask.data.resize(needed, NO_NULL_ENTRY);
197        mask
198    }
199}
200
201#[inline]
202fn entry_and_bit(pos: u64) -> (usize, u32) {
203    let entry = (pos >> BITS_PER_ENTRY_LOG2) as usize;
204    let bit = (pos & (BITS_PER_ENTRY - 1)) as u32;
205    (entry, bit)
206}
207
208#[inline]
209fn num_entries_for(capacity: u64) -> usize {
210    capacity.div_ceil(BITS_PER_ENTRY) as usize
211}
212
213#[inline]
214fn lower_mask(count: u64) -> u64 {
215    if count >= 64 {
216        ALL_NULL_ENTRY
217    } else {
218        (1u64 << count) - 1
219    }
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_mask_all_non_null() {
        let mask = NullMask::new(100);
        assert!(mask.has_no_nulls_guarantee());
        for i in 0..100 {
            assert!(!mask.is_null(i));
        }
    }

    #[test]
    fn all_null_mask() {
        let mask = NullMask::all_null(100);
        assert!(!mask.has_no_nulls_guarantee());
        for i in 0..100 {
            assert!(mask.is_null(i));
        }
    }

    #[test]
    fn set_and_check_null() {
        let mut mask = NullMask::new(128);
        mask.set_null(0, true);
        mask.set_null(63, true);
        mask.set_null(64, true);
        mask.set_null(127, true);

        assert!(mask.is_null(0));
        assert!(mask.is_null(63));
        assert!(mask.is_null(64));
        assert!(mask.is_null(127));
        assert!(!mask.is_null(1));
        assert!(!mask.is_null(62));
        assert!(!mask.is_null(65));
    }

    #[test]
    fn set_null_then_clear() {
        let mut mask = NullMask::new(64);
        mask.set_null(10, true);
        assert!(mask.is_null(10));
        mask.set_null(10, false);
        assert!(!mask.is_null(10));
    }

    #[test]
    fn clearing_keeps_conservative_hint() {
        // Clearing a bit does not reset the "may contain nulls" hint.
        let mut mask = NullMask::new(64);
        mask.set_null(3, true);
        mask.set_null(3, false);
        assert!(!mask.has_no_nulls_guarantee());
        assert_eq!(mask.count_nulls(), 0);
    }

    #[test]
    fn count_nulls_empty() {
        let mask = NullMask::new(256);
        assert_eq!(mask.count_nulls(), 0);
    }

    #[test]
    fn count_nulls_some() {
        let mut mask = NullMask::new(256);
        mask.set_null(0, true);
        mask.set_null(100, true);
        mask.set_null(255, true);
        assert_eq!(mask.count_nulls(), 3);
    }

    #[test]
    fn capacity_rounds_up_to_whole_entries() {
        assert_eq!(NullMask::new(0).num_entries(), 0);
        assert_eq!(NullMask::new(1).capacity(), 64);
        assert_eq!(NullMask::new(100).capacity(), 128);
        assert_eq!(NullMask::new(128).num_entries(), 2);
    }

    #[test]
    fn set_null_range_within_single_entry() {
        let mut mask = NullMask::new(64);
        mask.set_null_range(4, 8, true);
        for i in 0..64 {
            assert_eq!(mask.is_null(i), (4..12).contains(&i), "pos {i}");
        }
    }

    #[test]
    fn set_null_range_across_entries() {
        let mut mask = NullMask::new(256);
        mask.set_null_range(60, 10, true);
        for i in 0..256 {
            assert_eq!(mask.is_null(i), (60..70).contains(&i), "pos {i}");
        }
    }

    #[test]
    fn set_null_range_full_entries() {
        let mut mask = NullMask::new(256);
        mask.set_null_range(0, 256, true);
        assert_eq!(mask.count_nulls(), 256);
    }

    #[test]
    fn set_null_range_ending_on_entry_boundary() {
        // The range ends exactly at the buffer end; the next entry must not be touched.
        let mut mask = NullMask::new(64);
        mask.set_null_range(60, 4, true);
        for i in 0..64 {
            assert_eq!(mask.is_null(i), i >= 60, "pos {i}");
        }
        assert_eq!(mask.count_nulls(), 4);
    }

    #[test]
    fn set_null_range_whole_single_entry() {
        let mut mask = NullMask::new(64);
        mask.set_null_range(0, 64, true);
        assert_eq!(mask.count_nulls(), 64);
    }

    #[test]
    fn set_null_range_zero_count_is_noop() {
        let mut mask = NullMask::new(64);
        mask.set_null_range(10, 0, true);
        assert!(mask.has_no_nulls_guarantee());
        assert_eq!(mask.count_nulls(), 0);
    }

    #[test]
    fn set_null_range_clear() {
        let mut mask = NullMask::all_null(128);
        mask.set_null_range(10, 20, false);
        for i in 0..128 {
            assert_eq!(mask.is_null(i), !(10..30).contains(&i), "pos {i}");
        }
    }

    #[test]
    fn set_null_range_clear_many_entries() {
        let mut mask = NullMask::all_null(320);
        mask.set_null_range(5, 300, false);
        for i in 0..320 {
            assert_eq!(mask.is_null(i), !(5..305).contains(&i), "pos {i}");
        }
    }

    #[test]
    fn copy_from_basic() {
        let mut src = NullMask::new(64);
        src.set_null(5, true);
        src.set_null(10, true);

        let mut dst = NullMask::new(64);
        let any_null = dst.copy_from(&src, 0, 0, 64);
        assert!(any_null);
        assert!(dst.is_null(5));
        assert!(dst.is_null(10));
        assert!(!dst.is_null(0));
    }

    #[test]
    fn copy_from_with_offset() {
        let mut src = NullMask::new(64);
        src.set_null(0, true);

        let mut dst = NullMask::new(128);
        dst.copy_from(&src, 0, 64, 1);
        assert!(dst.is_null(64));
        assert!(!dst.is_null(0));
    }

    #[test]
    fn copy_from_no_nulls_source() {
        let src = NullMask::new(64);
        let mut dst = NullMask::all_null(64);
        let any_null = dst.copy_from(&src, 0, 0, 64);
        assert!(!any_null);
        for i in 0..64 {
            assert!(!dst.is_null(i));
        }
    }

    #[test]
    fn copy_from_window_without_nulls() {
        // The source has a null, but not inside the copied window.
        let mut src = NullMask::new(64);
        src.set_null(0, true);
        let mut dst = NullMask::all_null(64);
        assert!(!dst.copy_from(&src, 1, 0, 10));
        for i in 0..10 {
            assert!(!dst.is_null(i), "pos {i}");
        }
        assert!(dst.is_null(10));
    }

    #[test]
    fn copy_from_zero_count() {
        let mut src = NullMask::new(64);
        src.set_null(0, true);
        let mut dst = NullMask::all_null(64);
        assert!(!dst.copy_from(&src, 0, 0, 0));
        assert_eq!(dst.count_nulls(), 64);
    }

    #[test]
    fn resize_grow() {
        let mut mask = NullMask::new(64);
        mask.set_null(0, true);
        mask.resize(256);
        assert!(mask.is_null(0));
        assert!(!mask.is_null(128));
        assert_eq!(mask.num_entries(), 4);
    }

    #[test]
    fn resize_shrink() {
        let mut mask = NullMask::new(256);
        mask.set_null(0, true);
        mask.resize(64);
        assert!(mask.is_null(0));
        assert_eq!(mask.num_entries(), 1);
    }

    #[test]
    fn set_all_non_null() {
        let mut mask = NullMask::all_null(128);
        assert_eq!(mask.count_nulls(), 128);
        mask.set_all_non_null();
        assert!(mask.has_no_nulls_guarantee());
        assert_eq!(mask.count_nulls(), 0);
    }

    #[test]
    fn set_all_null() {
        let mut mask = NullMask::new(128);
        mask.set_all_null();
        assert_eq!(mask.count_nulls(), 128);
    }

    #[test]
    fn from_raw_detects_nulls_and_pads() {
        let mask = NullMask::from_raw(vec![0, 1 << 3], 128);
        assert!(!mask.has_no_nulls_guarantee());
        assert!(mask.is_null(67));
        assert_eq!(mask.count_nulls(), 1);

        let clean = NullMask::from_raw(vec![0, 0], 128);
        assert!(clean.has_no_nulls_guarantee());

        let padded = NullMask::from_raw(Vec::new(), 100);
        assert_eq!(padded.num_entries(), 2);
        assert!(!padded.is_null(99));
    }
}