nodedb_query/simd_filter/bitmask.rs

//! Bitmask primitives for the SIMD filter path. A mask packs one bit per
//! row into `u64` words, least-significant bit first, so row `r` lives at
//! bit `r % 64` of word `r / 64`.

/// Counts the set bits (selected rows) across the whole mask.
#[inline]
pub fn popcount(mask: &[u64]) -> u64 {
    mask.iter().map(|w| w.count_ones() as u64).sum()
}

/// Intersects two masks word-wise, truncating to the shorter input; words
/// missing from either mask are implicitly zero, so truncation is exact.
/// Dispatches to AVX-512 or AVX2 at runtime when available.
pub fn bitmask_and(a: &[u64], b: &[u64]) -> Vec<u64> {
    let len = a.len().min(b.len());
    if len == 0 {
        return Vec::new();
    }
    #[cfg(target_arch = "x86_64")]
    {
        if std::is_x86_feature_detected!("avx512f") {
            return unsafe { bitmask_and_avx512(&a[..len], &b[..len]) };
        }
        if std::is_x86_feature_detected!("avx2") {
            return unsafe { bitmask_and_avx2(&a[..len], &b[..len]) };
        }
    }
    bitmask_and_scalar(&a[..len], &b[..len])
}

#[inline]
fn bitmask_and_scalar(a: &[u64], b: &[u64]) -> Vec<u64> {
    a.iter().zip(b).map(|(x, y)| x & y).collect()
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn bitmask_and_avx512(a: &[u64], b: &[u64]) -> Vec<u64> {
    use std::arch::x86_64::*;
    unsafe {
        let len = a.len();
        let mut out = vec![0u64; len];
        let chunks = len / 8;
        let a_ptr = a.as_ptr() as *const i64;
        let b_ptr = b.as_ptr() as *const i64;
        let o_ptr = out.as_mut_ptr() as *mut i64;

        // AND eight 64-bit words per 512-bit vector.
        for i in 0..chunks {
            let va = _mm512_loadu_si512(a_ptr.add(i * 8).cast());
            let vb = _mm512_loadu_si512(b_ptr.add(i * 8).cast());
            let vc = _mm512_and_si512(va, vb);
            _mm512_storeu_si512(o_ptr.add(i * 8).cast(), vc);
        }

        // Scalar remainder for the last `len % 8` words.
        for i in (chunks * 8)..len {
            out[i] = a[i] & b[i];
        }

        out
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn bitmask_and_avx2(a: &[u64], b: &[u64]) -> Vec<u64> {
    use std::arch::x86_64::*;
    unsafe {
        let len = a.len();
        let mut out = vec![0u64; len];
        let chunks = len / 4;
        let a_ptr = a.as_ptr() as *const i64;
        let b_ptr = b.as_ptr() as *const i64;
        let o_ptr = out.as_mut_ptr() as *mut i64;

        // AND four 64-bit words per 256-bit vector.
        for i in 0..chunks {
            let va = _mm256_loadu_si256(a_ptr.add(i * 4).cast());
            let vb = _mm256_loadu_si256(b_ptr.add(i * 4).cast());
            let vc = _mm256_and_si256(va, vb);
            _mm256_storeu_si256(o_ptr.add(i * 4).cast(), vc);
        }

        // Scalar remainder for the last `len % 4` words.
        for i in (chunks * 4)..len {
            out[i] = a[i] & b[i];
        }

        out
    }
}

/// Unions two masks word-wise. Unlike AND, OR must keep the longer input's
/// tail, since words missing from the shorter mask are implicitly zero.
pub fn bitmask_or(a: &[u64], b: &[u64]) -> Vec<u64> {
    let max_len = a.len().max(b.len());
    let min_len = a.len().min(b.len());
    if max_len == 0 {
        return Vec::new();
    }

    let prefix = if min_len == 0 {
        Vec::new()
    } else {
        #[cfg(target_arch = "x86_64")]
        {
            if std::is_x86_feature_detected!("avx512f") {
                unsafe { bitmask_or_avx512(&a[..min_len], &b[..min_len]) }
            } else if std::is_x86_feature_detected!("avx2") {
                unsafe { bitmask_or_avx2(&a[..min_len], &b[..min_len]) }
            } else {
                bitmask_or_scalar(&a[..min_len], &b[..min_len])
            }
        }
        #[cfg(not(target_arch = "x86_64"))]
        {
            bitmask_or_scalar(&a[..min_len], &b[..min_len])
        }
    };

    // Append the tail of whichever input extends past the common prefix.
    let mut out = prefix;
    out.resize(max_len, 0u64);
    if a.len() > min_len {
        out[min_len..].copy_from_slice(&a[min_len..]);
    } else if b.len() > min_len {
        out[min_len..].copy_from_slice(&b[min_len..]);
    }
    out
}

#[inline]
fn bitmask_or_scalar(a: &[u64], b: &[u64]) -> Vec<u64> {
    a.iter().zip(b).map(|(x, y)| x | y).collect()
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn bitmask_or_avx512(a: &[u64], b: &[u64]) -> Vec<u64> {
    use std::arch::x86_64::*;
    unsafe {
        let len = a.len();
        let mut out = vec![0u64; len];
        let chunks = len / 8;
        let a_ptr = a.as_ptr() as *const i64;
        let b_ptr = b.as_ptr() as *const i64;
        let o_ptr = out.as_mut_ptr() as *mut i64;

        // OR eight 64-bit words per 512-bit vector.
        for i in 0..chunks {
            let va = _mm512_loadu_si512(a_ptr.add(i * 8).cast());
            let vb = _mm512_loadu_si512(b_ptr.add(i * 8).cast());
            let vc = _mm512_or_si512(va, vb);
            _mm512_storeu_si512(o_ptr.add(i * 8).cast(), vc);
        }

        // Scalar remainder for the last `len % 8` words.
        for i in (chunks * 8)..len {
            out[i] = a[i] | b[i];
        }

        out
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn bitmask_or_avx2(a: &[u64], b: &[u64]) -> Vec<u64> {
    use std::arch::x86_64::*;
    unsafe {
        let len = a.len();
        let mut out = vec![0u64; len];
        let chunks = len / 4;
        let a_ptr = a.as_ptr() as *const i64;
        let b_ptr = b.as_ptr() as *const i64;
        let o_ptr = out.as_mut_ptr() as *mut i64;

        // OR four 64-bit words per 256-bit vector.
        for i in 0..chunks {
            let va = _mm256_loadu_si256(a_ptr.add(i * 4).cast());
            let vb = _mm256_loadu_si256(b_ptr.add(i * 4).cast());
            let vc = _mm256_or_si256(va, vb);
            _mm256_storeu_si256(o_ptr.add(i * 4).cast(), vc);
        }

        // Scalar remainder for the last `len % 4` words.
        for i in (chunks * 4)..len {
            out[i] = a[i] | b[i];
        }

        out
    }
}

/// Complements a mask over `row_count` rows. Words the input lacks are
/// implicitly zero, so they complement to all ones; bits at or above
/// `row_count` in the final word are cleared.
pub fn bitmask_not(mask: &[u64], row_count: usize) -> Vec<u64> {
    let words = row_count.div_ceil(64);
    // Start from all ones so any words beyond `mask.len()` stay complemented.
    let mut out = vec![u64::MAX; words];
    for i in 0..mask.len().min(words) {
        out[i] = !mask[i];
    }
    // Clear the unused high bits of the last word.
    let tail = row_count % 64;
    if tail > 0 && !out.is_empty() {
        let last = out.len() - 1;
        out[last] &= (1u64 << tail) - 1;
    }
    out
}

/// Builds a mask selecting every row in `[0, row_count)`; unused high bits
/// of the final word are left clear.
pub fn bitmask_all(row_count: usize) -> Vec<u64> {
    let words = row_count.div_ceil(64);
    let mut out = vec![u64::MAX; words];
    let tail = row_count % 64;
    if tail > 0 && !out.is_empty() {
        let last = out.len() - 1;
        out[last] = (1u64 << tail) - 1;
    }
    out
}

/// Expands a mask into the sorted list of set row indices.
pub fn bitmask_to_indices(mask: &[u64]) -> Vec<u32> {
    let ones = popcount(mask) as usize;
    let mut out = Vec::with_capacity(ones);
    for (word_idx, &word) in mask.iter().enumerate() {
        if word == 0 {
            continue;
        }
        let base = (word_idx as u32) * 64;
        let mut w = word;
        while w != 0 {
            let bit = w.trailing_zeros();
            out.push(base + bit);
            // Clear the lowest set bit and continue with the rest.
            w &= w - 1;
        }
    }
    out
}

/// Number of `u64` words needed to hold a mask over `row_count` rows.
#[inline]
pub fn words_for(row_count: usize) -> usize {
    row_count.div_ceil(64)
}
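
// A minimal test sketch exercising the public API above; the expected values
// follow from the LSB-first packing documented at the top of this file, not
// from any external spec.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn and_truncates_to_shorter_mask() {
        let a = [0b1111u64, u64::MAX];
        let b = [0b1010u64];
        assert_eq!(bitmask_and(&a, &b), vec![0b1010u64]);
    }

    #[test]
    fn or_keeps_longer_tail() {
        let a = [0b0011u64];
        let b = [0b0101u64, 0b1000u64];
        assert_eq!(bitmask_or(&a, &b), vec![0b0111u64, 0b1000u64]);
    }

    #[test]
    fn not_respects_row_count() {
        // 10 rows with rows 0 and 3 set: the complement selects the other 8.
        let not = bitmask_not(&[0b1001u64], 10);
        assert_eq!(not, vec![0b11_1111_0110u64]);
        assert_eq!(popcount(&not), 8);
    }

    #[test]
    fn all_then_indices_round_trips() {
        let mask = bitmask_all(70);
        assert_eq!(mask.len(), words_for(70));
        assert_eq!(popcount(&mask), 70);
        assert_eq!(bitmask_to_indices(&mask), (0..70u32).collect::<Vec<_>>());
    }
}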