
// nodedb_query/simd_filter/bitmask.rs

// ---------------------------------------------------------------------------
// Bitmask helpers
// ---------------------------------------------------------------------------

/// Count total set bits across a bitmask.
#[inline]
pub fn popcount(mask: &[u64]) -> u64 {
    mask.iter().map(|w| w.count_ones() as u64).sum()
}
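
// A minimal self-check, added for illustration (not part of the original
// module): `popcount` sums the per-word counts, so an empty mask is 0 and a
// full word contributes 64.
#[cfg(test)]
#[test]
fn popcount_sums_per_word() {
    assert_eq!(popcount(&[]), 0);
    assert_eq!(popcount(&[0b1011, u64::MAX]), 3 + 64);
}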

/// Bitwise AND of two bitmasks, SIMD-accelerated.
/// If the slices differ in length, the result is truncated to the shorter one.
pub fn bitmask_and(a: &[u64], b: &[u64]) -> Vec<u64> {
    let len = a.len().min(b.len());
    if len == 0 {
        return Vec::new();
    }
    #[cfg(target_arch = "x86_64")]
    {
        if std::is_x86_feature_detected!("avx512f") {
            return unsafe { bitmask_and_avx512(&a[..len], &b[..len]) };
        }
        if std::is_x86_feature_detected!("avx2") {
            return unsafe { bitmask_and_avx2(&a[..len], &b[..len]) };
        }
    }
    bitmask_and_scalar(&a[..len], &b[..len])
}
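
// Illustrative self-check (added, not in the original file): the dispatching
// `bitmask_and` must agree with the scalar reference path regardless of which
// SIMD level the host supports, and unequal-length inputs are truncated.
// 37 words exercises both the vector body and the scalar tail.
#[cfg(test)]
#[test]
fn and_matches_scalar_and_truncates() {
    let a: Vec<u64> = (0..37u64).map(|i| i.wrapping_mul(0x9E37_79B9_7F4A_7C15)).collect();
    let b: Vec<u64> = (0..37u64).map(|i| !i).collect();
    assert_eq!(bitmask_and(&a, &b), bitmask_and_scalar(&a, &b));
    assert_eq!(bitmask_and(&[0b1100, 1], &[0b1010]), vec![0b1000]);
}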

#[inline]
fn bitmask_and_scalar(a: &[u64], b: &[u64]) -> Vec<u64> {
    let mut out = vec![0u64; a.len()];
    for i in 0..a.len() {
        out[i] = a[i] & b[i];
    }
    out
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn bitmask_and_avx512(a: &[u64], b: &[u64]) -> Vec<u64> {
    use std::arch::x86_64::*;
    unsafe {
        let len = a.len();
        let mut out = vec![0u64; len];
        // One 512-bit vector holds eight u64 words.
        let chunks = len / 8;
        let a_ptr = a.as_ptr() as *const i64;
        let b_ptr = b.as_ptr() as *const i64;
        let o_ptr = out.as_mut_ptr() as *mut i64;

        for i in 0..chunks {
            let va = _mm512_loadu_si512(a_ptr.add(i * 8).cast());
            let vb = _mm512_loadu_si512(b_ptr.add(i * 8).cast());
            let vc = _mm512_and_si512(va, vb);
            _mm512_storeu_si512(o_ptr.add(i * 8).cast(), vc);
        }

        // Scalar tail.
        for i in (chunks * 8)..len {
            out[i] = a[i] & b[i];
        }

        out
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn bitmask_and_avx2(a: &[u64], b: &[u64]) -> Vec<u64> {
    use std::arch::x86_64::*;
    unsafe {
        let len = a.len();
        let mut out = vec![0u64; len];
        // One 256-bit vector holds four u64 words.
        let chunks = len / 4;
        let a_ptr = a.as_ptr() as *const i64;
        let b_ptr = b.as_ptr() as *const i64;

        for i in 0..chunks {
            let va = _mm256_loadu_si256(a_ptr.add(i * 4).cast());
            let vb = _mm256_loadu_si256(b_ptr.add(i * 4).cast());
            let vc = _mm256_and_si256(va, vb);
            _mm256_storeu_si256(out.as_mut_ptr().add(i * 4).cast(), vc);
        }

        // Scalar tail.
        for i in (chunks * 4)..len {
            out[i] = a[i] & b[i];
        }

        out
    }
}

/// Bitwise OR of two bitmasks, SIMD-accelerated.
/// If the slices differ in length, the longer tail is copied as-is.
pub fn bitmask_or(a: &[u64], b: &[u64]) -> Vec<u64> {
    let max_len = a.len().max(b.len());
    let min_len = a.len().min(b.len());
    if max_len == 0 {
        return Vec::new();
    }

    // OR the overlapping prefix.
    let prefix = if min_len == 0 {
        Vec::new()
    } else {
        #[cfg(target_arch = "x86_64")]
        {
            if std::is_x86_feature_detected!("avx512f") {
                unsafe { bitmask_or_avx512(&a[..min_len], &b[..min_len]) }
            } else if std::is_x86_feature_detected!("avx2") {
                unsafe { bitmask_or_avx2(&a[..min_len], &b[..min_len]) }
            } else {
                bitmask_or_scalar(&a[..min_len], &b[..min_len])
            }
        }
        // Wrapped in a block so the `#[cfg]` attribute sits on a statement,
        // which is accepted on stable (bare attributed tail expressions are not).
        #[cfg(not(target_arch = "x86_64"))]
        {
            bitmask_or_scalar(&a[..min_len], &b[..min_len])
        }
    };

    // Extend with the tail of the longer slice.
    let mut out = prefix;
    out.resize(max_len, 0u64);
    if a.len() > min_len {
        out[min_len..].copy_from_slice(&a[min_len..]);
    } else if b.len() > min_len {
        out[min_len..].copy_from_slice(&b[min_len..]);
    }
    out
}
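
// Illustrative self-check (added, not in the original file): the overlapping
// prefix is ORed and the longer slice's tail is copied verbatim, in either
// argument order.
#[cfg(test)]
#[test]
fn or_copies_longer_tail() {
    let a = vec![0b0011u64];
    let b = vec![0b0101u64, 0xDEAD];
    assert_eq!(bitmask_or(&a, &b), vec![0b0111, 0xDEAD]);
    assert_eq!(bitmask_or(&b, &a), vec![0b0111, 0xDEAD]);
    assert_eq!(bitmask_or(&[], &[]), Vec::<u64>::new());
}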

#[inline]
fn bitmask_or_scalar(a: &[u64], b: &[u64]) -> Vec<u64> {
    let mut out = vec![0u64; a.len()];
    for i in 0..a.len() {
        out[i] = a[i] | b[i];
    }
    out
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn bitmask_or_avx512(a: &[u64], b: &[u64]) -> Vec<u64> {
    use std::arch::x86_64::*;
    unsafe {
        let len = a.len();
        let mut out = vec![0u64; len];
        // One 512-bit vector holds eight u64 words.
        let chunks = len / 8;
        let a_ptr = a.as_ptr() as *const i64;
        let b_ptr = b.as_ptr() as *const i64;
        let o_ptr = out.as_mut_ptr() as *mut i64;

        for i in 0..chunks {
            let va = _mm512_loadu_si512(a_ptr.add(i * 8).cast());
            let vb = _mm512_loadu_si512(b_ptr.add(i * 8).cast());
            let vc = _mm512_or_si512(va, vb);
            _mm512_storeu_si512(o_ptr.add(i * 8).cast(), vc);
        }

        // Scalar tail.
        for i in (chunks * 8)..len {
            out[i] = a[i] | b[i];
        }

        out
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn bitmask_or_avx2(a: &[u64], b: &[u64]) -> Vec<u64> {
    use std::arch::x86_64::*;
    unsafe {
        let len = a.len();
        let mut out = vec![0u64; len];
        // One 256-bit vector holds four u64 words.
        let chunks = len / 4;
        let a_ptr = a.as_ptr() as *const i64;
        let b_ptr = b.as_ptr() as *const i64;

        for i in 0..chunks {
            let va = _mm256_loadu_si256(a_ptr.add(i * 4).cast());
            let vb = _mm256_loadu_si256(b_ptr.add(i * 4).cast());
            let vc = _mm256_or_si256(va, vb);
            _mm256_storeu_si256(out.as_mut_ptr().add(i * 4).cast(), vc);
        }

        // Scalar tail.
        for i in (chunks * 4)..len {
            out[i] = a[i] | b[i];
        }

        out
    }
}

/// Bitwise NOT of a bitmask (flips all bits up to `row_count`).
pub fn bitmask_not(mask: &[u64], row_count: usize) -> Vec<u64> {
    let words = row_count.div_ceil(64);
    let mut out = vec![0u64; words];
    let covered = mask.len().min(words);
    for i in 0..covered {
        out[i] = !mask[i];
    }
    // Words beyond the input mask stand for implicitly-unset rows, so their
    // complement is all-ones (previously they were left as zero, which
    // contradicted "flips all bits up to row_count").
    for w in out.iter_mut().skip(covered) {
        *w = u64::MAX;
    }
    // Clear bits beyond row_count in the last word.
    let tail = row_count % 64;
    if tail > 0 && !out.is_empty() {
        let last = out.len() - 1;
        out[last] &= (1u64 << tail) - 1;
    }
    out
}
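
// Self-check (added for illustration): NOT clears bits beyond `row_count`,
// double negation is the identity within `row_count` bits, and, assuming
// missing input words stand for all-zero rows, a short mask complements to
// all-ones.
#[cfg(test)]
#[test]
fn not_flips_and_clears_tail() {
    assert_eq!(bitmask_not(&[0b1], 10), vec![0b11_1111_1110]);
    let m = vec![0xF0F0u64];
    assert_eq!(bitmask_not(&bitmask_not(&m, 16), 16), m);
    assert_eq!(bitmask_not(&[], 3), vec![0b111]);
}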

/// Create an all-ones bitmask for `row_count` rows.
pub fn bitmask_all(row_count: usize) -> Vec<u64> {
    let words = row_count.div_ceil(64);
    let mut out = vec![u64::MAX; words];
    let tail = row_count % 64;
    if tail > 0 && !out.is_empty() {
        let last = out.len() - 1;
        out[last] = (1u64 << tail) - 1;
    }
    out
}
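
// Self-check (added, not in the original): exactly `row_count` bits are set,
// with the partial last word truncated.
#[cfg(test)]
#[test]
fn all_sets_exactly_row_count_bits() {
    assert_eq!(bitmask_all(0), Vec::<u64>::new());
    assert_eq!(bitmask_all(64), vec![u64::MAX]);
    assert_eq!(bitmask_all(65), vec![u64::MAX, 1]);
    assert_eq!(popcount(&bitmask_all(1000)), 1000);
}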

/// Expand bitmask to a selection vector of row indices.
pub fn bitmask_to_indices(mask: &[u64]) -> Vec<u32> {
    let ones: u64 = popcount(mask);
    let mut out = Vec::with_capacity(ones as usize);
    for (word_idx, &word) in mask.iter().enumerate() {
        if word == 0 {
            continue;
        }
        let base = (word_idx as u32) * 64;
        let mut w = word;
        while w != 0 {
            let bit = w.trailing_zeros();
            out.push(base + bit);
            w &= w - 1; // clear lowest set bit
        }
    }
    out
}
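
// Self-check (added for illustration): indices come out in ascending row
// order, with bit `b` of word `w` mapping to row `w * 64 + b`.
#[cfg(test)]
#[test]
fn indices_expand_set_bits_in_order() {
    assert_eq!(bitmask_to_indices(&[0b101, 0b10]), vec![0, 2, 65]);
    assert_eq!(bitmask_to_indices(&[0, 0]), Vec::<u32>::new());
}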

/// Number of u64 words needed for `row_count` bits.
#[inline]
pub fn words_for(row_count: usize) -> usize {
    row_count.div_ceil(64)
}
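
// A small end-to-end sketch (added, not part of the original file): for a
// mask whose length covers `row_count` exactly, popcount(m) + popcount(NOT m)
// == row_count, and `words_for` matches the allocation made by `bitmask_all`.
#[cfg(test)]
#[test]
fn complement_identity() {
    let m = vec![0xF0F0_F0F0_F0F0_F0F0u64, 0x1234];
    let rows = 128;
    assert_eq!(popcount(&m) + popcount(&bitmask_not(&m, rows)), rows as u64);
    assert_eq!(words_for(131), bitmask_all(131).len());
}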