nodedb_query/simd_filter/bitmask.rs

/// Total number of set bits across all words of `mask`.
#[inline]
pub fn popcount(mask: &[u64]) -> u64 {
    mask.iter().map(|w| w.count_ones() as u64).sum()
}

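/// Bitwise AND of two bitmasks. The result is truncated to the shorter
/// input, since a word missing from either side ANDs to zero. Dispatches
/// to AVX-512 or AVX2 at runtime when available.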
pub fn bitmask_and(a: &[u64], b: &[u64]) -> Vec<u64> {
    // AND only the overlapping prefix; the longer input's excess words
    // would AND against implicit zeros anyway.
    let len = a.len().min(b.len());
    if len == 0 {
        return Vec::new();
    }
    // Runtime dispatch: prefer AVX-512, then AVX2, else scalar.
    #[cfg(target_arch = "x86_64")]
    {
        if std::is_x86_feature_detected!("avx512f") {
            return unsafe { bitmask_and_avx512(&a[..len], &b[..len]) };
        }
        if std::is_x86_feature_detected!("avx2") {
            return unsafe { bitmask_and_avx2(&a[..len], &b[..len]) };
        }
    }
    bitmask_and_scalar(&a[..len], &b[..len])
}

#[inline]
fn bitmask_and_scalar(a: &[u64], b: &[u64]) -> Vec<u64> {
    // `zip` stops at the shorter slice; callers pass equal-length slices.
    a.iter().zip(b).map(|(x, y)| x & y).collect()
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn bitmask_and_avx512(a: &[u64], b: &[u64]) -> Vec<u64> {
    use std::arch::x86_64::*;
    unsafe {
        let len = a.len();
        let mut out = vec![0u64; len];
        let chunks = len / 8;
        let a_ptr = a.as_ptr() as *const i64;
        let b_ptr = b.as_ptr() as *const i64;
        let o_ptr = out.as_mut_ptr() as *mut i64;

        // Eight u64 words (512 bits) per iteration; unaligned loads and stores.
        for i in 0..chunks {
            let va = _mm512_loadu_si512(a_ptr.add(i * 8).cast());
            let vb = _mm512_loadu_si512(b_ptr.add(i * 8).cast());
            let vc = _mm512_and_si512(va, vb);
            _mm512_storeu_si512(o_ptr.add(i * 8).cast(), vc);
        }

        // Scalar tail for the words that do not fill a full vector.
        for i in (chunks * 8)..len {
            out[i] = a[i] & b[i];
        }

        out
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn bitmask_and_avx2(a: &[u64], b: &[u64]) -> Vec<u64> {
    use std::arch::x86_64::*;
    unsafe {
        let len = a.len();
        let mut out = vec![0u64; len];
        let chunks = len / 4;
        let a_ptr = a.as_ptr() as *const i64;
        let b_ptr = b.as_ptr() as *const i64;

        // Four u64 words (256 bits) per iteration.
        for i in 0..chunks {
            let va = _mm256_loadu_si256(a_ptr.add(i * 4).cast());
            let vb = _mm256_loadu_si256(b_ptr.add(i * 4).cast());
            let vc = _mm256_and_si256(va, vb);
            _mm256_storeu_si256(out.as_mut_ptr().add(i * 4).cast(), vc);
        }

        // Scalar tail.
        for i in (chunks * 4)..len {
            out[i] = a[i] & b[i];
        }

        out
    }
}

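/// Bitwise OR of two bitmasks. The overlap is OR-ed (with SIMD when
/// available) and the longer input's remaining words are copied through
/// unchanged, since OR with an implicit zero word is the identity.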
pub fn bitmask_or(a: &[u64], b: &[u64]) -> Vec<u64> {
    let max_len = a.len().max(b.len());
    let min_len = a.len().min(b.len());
    if max_len == 0 {
        return Vec::new();
    }

    // OR the overlapping prefix, dispatching on available SIMD features.
    let prefix = if min_len == 0 {
        Vec::new()
    } else {
        #[cfg(target_arch = "x86_64")]
        {
            if std::is_x86_feature_detected!("avx512f") {
                unsafe { bitmask_or_avx512(&a[..min_len], &b[..min_len]) }
            } else if std::is_x86_feature_detected!("avx2") {
                unsafe { bitmask_or_avx2(&a[..min_len], &b[..min_len]) }
            } else {
                bitmask_or_scalar(&a[..min_len], &b[..min_len])
            }
        }
        #[cfg(not(target_arch = "x86_64"))]
        bitmask_or_scalar(&a[..min_len], &b[..min_len])
    };

    // Append the longer input's excess words unchanged.
    let mut out = prefix;
    out.resize(max_len, 0u64);
    if a.len() > min_len {
        out[min_len..].copy_from_slice(&a[min_len..]);
    } else if b.len() > min_len {
        out[min_len..].copy_from_slice(&b[min_len..]);
    }
    out
}

#[inline]
fn bitmask_or_scalar(a: &[u64], b: &[u64]) -> Vec<u64> {
    // `zip` stops at the shorter slice; callers pass equal-length slices.
    a.iter().zip(b).map(|(x, y)| x | y).collect()
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn bitmask_or_avx512(a: &[u64], b: &[u64]) -> Vec<u64> {
    use std::arch::x86_64::*;
    unsafe {
        let len = a.len();
        let mut out = vec![0u64; len];
        let chunks = len / 8;
        let a_ptr = a.as_ptr() as *const i64;
        let b_ptr = b.as_ptr() as *const i64;
        let o_ptr = out.as_mut_ptr() as *mut i64;

        // Eight u64 words (512 bits) per iteration.
        for i in 0..chunks {
            let va = _mm512_loadu_si512(a_ptr.add(i * 8).cast());
            let vb = _mm512_loadu_si512(b_ptr.add(i * 8).cast());
            let vc = _mm512_or_si512(va, vb);
            _mm512_storeu_si512(o_ptr.add(i * 8).cast(), vc);
        }

        // Scalar tail.
        for i in (chunks * 8)..len {
            out[i] = a[i] | b[i];
        }

        out
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn bitmask_or_avx2(a: &[u64], b: &[u64]) -> Vec<u64> {
    use std::arch::x86_64::*;
    unsafe {
        let len = a.len();
        let mut out = vec![0u64; len];
        let chunks = len / 4;
        let a_ptr = a.as_ptr() as *const i64;
        let b_ptr = b.as_ptr() as *const i64;

        // Four u64 words (256 bits) per iteration.
        for i in 0..chunks {
            let va = _mm256_loadu_si256(a_ptr.add(i * 4).cast());
            let vb = _mm256_loadu_si256(b_ptr.add(i * 4).cast());
            let vc = _mm256_or_si256(va, vb);
            _mm256_storeu_si256(out.as_mut_ptr().add(i * 4).cast(), vc);
        }

        // Scalar tail.
        for i in (chunks * 4)..len {
            out[i] = a[i] | b[i];
        }

        out
    }
}

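/// Bitwise NOT of `mask` over `row_count` rows. Words missing from `mask`
/// are treated as all-zero rows, and padding bits past `row_count` in the
/// last word are cleared.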
pub fn bitmask_not(mask: &[u64], row_count: usize) -> Vec<u64> {
    let words = row_count.div_ceil(64);
    // Start from all ones: words past the end of `mask` cover rows that are
    // zero in the input, and zero rows invert to set rows.
    let mut out = vec![u64::MAX; words];
    for i in 0..mask.len().min(words) {
        out[i] = !mask[i];
    }
    // Clear padding bits past `row_count` in the last word.
    let tail = row_count % 64;
    if tail > 0 && !out.is_empty() {
        let last = out.len() - 1;
        out[last] &= (1u64 << tail) - 1;
    }
    out
}

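/// A mask with every one of `row_count` row bits set; padding bits in the
/// last word stay clear.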
pub fn bitmask_all(row_count: usize) -> Vec<u64> {
    let words = row_count.div_ceil(64);
    let mut out = vec![u64::MAX; words];
    // Clear padding bits past `row_count` in the last word.
    let tail = row_count % 64;
    if tail > 0 && !out.is_empty() {
        let last = out.len() - 1;
        out[last] = (1u64 << tail) - 1;
    }
    out
}

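/// Expands a bitmask into the sorted list of set-bit positions, clearing
/// the lowest set bit on each step (`w &= w - 1`).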
pub fn bitmask_to_indices(mask: &[u64]) -> Vec<u32> {
    let ones = popcount(mask);
    let mut out = Vec::with_capacity(ones as usize);
    for (word_idx, &word) in mask.iter().enumerate() {
        if word == 0 {
            continue;
        }
        let base = (word_idx as u32) * 64;
        let mut w = word;
        while w != 0 {
            let bit = w.trailing_zeros();
            out.push(base + bit);
            // Clear the lowest set bit.
            w &= w - 1;
        }
    }
    out
}

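/// Number of 64-bit words needed to hold `row_count` row bits.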
#[inline]
pub fn words_for(row_count: usize) -> usize {
    row_count.div_ceil(64)
}
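
// Minimal sanity checks sketching the intended semantics of the helpers
// above (truncating AND, extending OR, and a NOT that treats words missing
// from a short input mask as all-zero rows). These tests are illustrative
// additions, not part of the original file.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn and_truncates_or_extends() {
        let a = vec![0b1100u64, u64::MAX];
        let b = vec![0b1010u64];
        // AND truncates to the shorter input: missing words count as zero.
        assert_eq!(bitmask_and(&a, &b), vec![0b1000u64]);
        // OR keeps the longer input's extra words unchanged.
        assert_eq!(bitmask_or(&a, &b), vec![0b1110u64, u64::MAX]);
    }

    #[test]
    fn not_and_all_respect_tail_bits() {
        // 70 rows span two words; only 6 bits of the second word are live.
        let mask = bitmask_all(70);
        assert_eq!(popcount(&mask), 70);
        assert_eq!(popcount(&bitmask_not(&mask, 70)), 0);
        // A one-word input over 70 rows: the missing word inverts to ones.
        assert_eq!(popcount(&bitmask_not(&[0u64], 70)), 70);
    }

    #[test]
    fn indices_round_trip() {
        let mut mask = vec![0u64; words_for(130)];
        for idx in [0u32, 63, 64, 129] {
            mask[(idx / 64) as usize] |= 1u64 << (idx % 64);
        }
        assert_eq!(bitmask_to_indices(&mask), vec![0, 63, 64, 129]);
    }
}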