
nodedb_query/simd_filter/bitmask.rs

// SPDX-License-Identifier: Apache-2.0

// ---------------------------------------------------------------------------
// Bitmask helpers
// ---------------------------------------------------------------------------

/// Count total set bits across a bitmask.
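///
/// # Examples
///
/// An illustrative doctest; the `use` path assumes this file is reachable at
/// `nodedb_query::simd_filter::bitmask`:
///
/// ```
/// use nodedb_query::simd_filter::bitmask::popcount; // module path assumed
///
/// // 0b1011 contributes three set bits, u64::MAX another 64.
/// assert_eq!(popcount(&[0b1011, u64::MAX]), 67);
/// ```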
#[inline]
pub fn popcount(mask: &[u64]) -> u64 {
    mask.iter().map(|w| w.count_ones() as u64).sum()
}

/// Bitwise AND of two bitmasks, SIMD-accelerated. If the slices differ in
/// length, the result is truncated to the shorter one (the missing words
/// would AND to zero anyway).
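///
/// # Examples
///
/// A sketch of the truncation behaviour (module path assumed as above):
///
/// ```
/// use nodedb_query::simd_filter::bitmask::bitmask_and; // module path assumed
///
/// // The second word of `a` has no counterpart in `b`, so it is dropped.
/// assert_eq!(bitmask_and(&[0b1100, u64::MAX], &[0b1010]), vec![0b1000]);
/// ```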
pub fn bitmask_and(a: &[u64], b: &[u64]) -> Vec<u64> {
    let len = a.len().min(b.len());
    if len == 0 {
        return Vec::new();
    }
    #[cfg(target_arch = "x86_64")]
    {
        if std::is_x86_feature_detected!("avx512f") {
            return unsafe { bitmask_and_avx512(&a[..len], &b[..len]) };
        }
        if std::is_x86_feature_detected!("avx2") {
            return unsafe { bitmask_and_avx2(&a[..len], &b[..len]) };
        }
    }
    bitmask_and_scalar(&a[..len], &b[..len])
}

#[inline]
fn bitmask_and_scalar(a: &[u64], b: &[u64]) -> Vec<u64> {
    let mut out = vec![0u64; a.len()];
    for i in 0..a.len() {
        out[i] = a[i] & b[i];
    }
    out
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn bitmask_and_avx512(a: &[u64], b: &[u64]) -> Vec<u64> {
    use std::arch::x86_64::*;
    unsafe {
        let len = a.len();
        let mut out = vec![0u64; len];
        let chunks = len / 8;
        let a_ptr = a.as_ptr() as *const i64;
        let b_ptr = b.as_ptr() as *const i64;
        let o_ptr = out.as_mut_ptr() as *mut i64;

        for i in 0..chunks {
            let va = _mm512_loadu_si512(a_ptr.add(i * 8).cast());
            let vb = _mm512_loadu_si512(b_ptr.add(i * 8).cast());
            let vc = _mm512_and_si512(va, vb);
            _mm512_storeu_si512(o_ptr.add(i * 8).cast(), vc);
        }

        // Scalar tail.
        for i in (chunks * 8)..len {
            out[i] = a[i] & b[i];
        }

        out
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn bitmask_and_avx2(a: &[u64], b: &[u64]) -> Vec<u64> {
    use std::arch::x86_64::*;
    unsafe {
        let len = a.len();
        let mut out = vec![0u64; len];
        let chunks = len / 4;
        let a_ptr = a.as_ptr() as *const i64;
        let b_ptr = b.as_ptr() as *const i64;
        let o_ptr = out.as_mut_ptr() as *mut i64;

        for i in 0..chunks {
            let va = _mm256_loadu_si256(a_ptr.add(i * 4).cast());
            let vb = _mm256_loadu_si256(b_ptr.add(i * 4).cast());
            let vc = _mm256_and_si256(va, vb);
            _mm256_storeu_si256(o_ptr.add(i * 4).cast(), vc);
        }

        // Scalar tail.
        for i in (chunks * 4)..len {
            out[i] = a[i] & b[i];
        }

        out
    }
}

/// Bitwise OR of two bitmasks, SIMD-accelerated.
/// If the slices differ in length, the longer tail is copied as-is.
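///
/// # Examples
///
/// A sketch of the tail-copy behaviour (module path assumed as above):
///
/// ```
/// use nodedb_query::simd_filter::bitmask::bitmask_or; // module path assumed
///
/// // The overlapping word is ORed; `b`'s extra word is copied through.
/// assert_eq!(bitmask_or(&[0b0011], &[0b0101, 0xFF]), vec![0b0111, 0xFF]);
/// ```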
pub fn bitmask_or(a: &[u64], b: &[u64]) -> Vec<u64> {
    let max_len = a.len().max(b.len());
    let min_len = a.len().min(b.len());
    if max_len == 0 {
        return Vec::new();
    }

    // OR the overlapping prefix.
    let prefix = if min_len == 0 {
        Vec::new()
    } else {
        #[cfg(target_arch = "x86_64")]
        {
            if std::is_x86_feature_detected!("avx512f") {
                unsafe { bitmask_or_avx512(&a[..min_len], &b[..min_len]) }
            } else if std::is_x86_feature_detected!("avx2") {
                unsafe { bitmask_or_avx2(&a[..min_len], &b[..min_len]) }
            } else {
                bitmask_or_scalar(&a[..min_len], &b[..min_len])
            }
        }
        // Braced so the cfg'd arm is a block statement; a bare tail
        // expression cannot carry a #[cfg] attribute on stable Rust.
        #[cfg(not(target_arch = "x86_64"))]
        {
            bitmask_or_scalar(&a[..min_len], &b[..min_len])
        }
    };

    // Extend with the tail of the longer slice.
    let mut out = prefix;
    out.resize(max_len, 0u64);
    if a.len() > min_len {
        out[min_len..].copy_from_slice(&a[min_len..]);
    } else if b.len() > min_len {
        out[min_len..].copy_from_slice(&b[min_len..]);
    }
    out
}

#[inline]
fn bitmask_or_scalar(a: &[u64], b: &[u64]) -> Vec<u64> {
    let mut out = vec![0u64; a.len()];
    for i in 0..a.len() {
        out[i] = a[i] | b[i];
    }
    out
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn bitmask_or_avx512(a: &[u64], b: &[u64]) -> Vec<u64> {
    use std::arch::x86_64::*;
    unsafe {
        let len = a.len();
        let mut out = vec![0u64; len];
        let chunks = len / 8;
        let a_ptr = a.as_ptr() as *const i64;
        let b_ptr = b.as_ptr() as *const i64;
        let o_ptr = out.as_mut_ptr() as *mut i64;

        for i in 0..chunks {
            let va = _mm512_loadu_si512(a_ptr.add(i * 8).cast());
            let vb = _mm512_loadu_si512(b_ptr.add(i * 8).cast());
            let vc = _mm512_or_si512(va, vb);
            _mm512_storeu_si512(o_ptr.add(i * 8).cast(), vc);
        }

        // Scalar tail.
        for i in (chunks * 8)..len {
            out[i] = a[i] | b[i];
        }

        out
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn bitmask_or_avx2(a: &[u64], b: &[u64]) -> Vec<u64> {
    use std::arch::x86_64::*;
    unsafe {
        let len = a.len();
        let mut out = vec![0u64; len];
        let chunks = len / 4;
        let a_ptr = a.as_ptr() as *const i64;
        let b_ptr = b.as_ptr() as *const i64;
        let o_ptr = out.as_mut_ptr() as *mut i64;

        for i in 0..chunks {
            let va = _mm256_loadu_si256(a_ptr.add(i * 4).cast());
            let vb = _mm256_loadu_si256(b_ptr.add(i * 4).cast());
            let vc = _mm256_or_si256(va, vb);
            _mm256_storeu_si256(o_ptr.add(i * 4).cast(), vc);
        }

        // Scalar tail.
        for i in (chunks * 4)..len {
            out[i] = a[i] | b[i];
        }

        out
    }
}

/// Bitwise NOT of a bitmask (flips all bits up to `row_count`).
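///
/// # Examples
///
/// An illustrative doctest (module path assumed as above):
///
/// ```
/// use nodedb_query::simd_filter::bitmask::bitmask_not; // module path assumed
///
/// // Rows 1 and 3 were set; rows 0 and 2 are set in the complement.
/// assert_eq!(bitmask_not(&[0b1010], 4), vec![0b0101]);
/// ```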
pub fn bitmask_not(mask: &[u64], row_count: usize) -> Vec<u64> {
    let words = row_count.div_ceil(64);
    // Words missing from a short `mask` are implicit zeros, so their
    // complement must be all ones.
    let mut out = vec![u64::MAX; words];
    for i in 0..mask.len().min(words) {
        out[i] = !mask[i];
    }
    // Clear bits beyond row_count in the last word.
    let tail = row_count % 64;
    if tail > 0 && !out.is_empty() {
        let last = out.len() - 1;
        out[last] &= (1u64 << tail) - 1;
    }
    out
}

/// Create an all-ones bitmask for `row_count` rows.
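///
/// # Examples
///
/// An illustrative doctest (module path assumed as above):
///
/// ```
/// use nodedb_query::simd_filter::bitmask::bitmask_all; // module path assumed
///
/// // 70 rows span two words: one full word plus six bits.
/// assert_eq!(bitmask_all(70), vec![u64::MAX, 0b11_1111]);
/// ```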
pub fn bitmask_all(row_count: usize) -> Vec<u64> {
    let words = row_count.div_ceil(64);
    let mut out = vec![u64::MAX; words];
    // Clear bits beyond row_count in the last word.
    let tail = row_count % 64;
    if tail > 0 && !out.is_empty() {
        let last = out.len() - 1;
        out[last] = (1u64 << tail) - 1;
    }
    out
}

/// Expand a bitmask into a selection vector of row indices.
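///
/// # Examples
///
/// An illustrative doctest (module path assumed as above):
///
/// ```
/// use nodedb_query::simd_filter::bitmask::bitmask_to_indices; // module path assumed
///
/// // Bits 0 and 3 of word 0, plus bit 0 of word 1 (row 64).
/// assert_eq!(bitmask_to_indices(&[0b1001, 0b1]), vec![0, 3, 64]);
/// ```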
pub fn bitmask_to_indices(mask: &[u64]) -> Vec<u32> {
    let ones: u64 = popcount(mask);
    let mut out = Vec::with_capacity(ones as usize);
    for (word_idx, &word) in mask.iter().enumerate() {
        if word == 0 {
            continue;
        }
        let base = (word_idx as u32) * 64;
        let mut w = word;
        while w != 0 {
            let bit = w.trailing_zeros();
            out.push(base + bit);
            w &= w - 1; // clear lowest set bit
        }
    }
    out
}

/// Number of u64 words needed for `row_count` bits.
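///
/// # Examples
///
/// An illustrative doctest (module path assumed as above):
///
/// ```
/// use nodedb_query::simd_filter::bitmask::words_for; // module path assumed
///
/// assert_eq!(words_for(0), 0);
/// assert_eq!(words_for(64), 1);
/// assert_eq!(words_for(65), 2);
/// ```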
#[inline]
pub fn words_for(row_count: usize) -> usize {
    row_count.div_ceil(64)
}
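// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

// A minimal sanity-check sketch, not an exhaustive suite: it verifies that the
// dispatching entry points agree with the scalar reference on a length that
// exercises both the vector body and the ragged scalar tail.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn and_or_agree_with_scalar_reference() {
        // 67 words: full 8-lane (AVX-512) and 4-lane (AVX2) chunks plus a tail.
        let a: Vec<u64> = (0..67u64)
            .map(|i| i.wrapping_mul(0x9E37_79B9_7F4A_7C15))
            .collect();
        let b: Vec<u64> = (0..67u64).map(|i| !i ^ 0xDEAD_BEEF).collect();
        assert_eq!(bitmask_and(&a, &b), bitmask_and_scalar(&a, &b));
        assert_eq!(bitmask_or(&a, &b), bitmask_or_scalar(&a, &b));
    }

    #[test]
    fn all_not_roundtrip() {
        let all = bitmask_all(100);
        assert_eq!(popcount(&all), 100);
        assert_eq!(popcount(&bitmask_not(&all, 100)), 0);
    }
}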