use std::cmp::Ordering;

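/// SIMD-accelerated lexicographic comparison helpers for byte slices,
/// with portable scalar fallbacks for CPUs lacking the detected features.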
pub struct SIMDCompare;

impl SIMDCompare {
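    /// Compares two byte slices lexicographically, dispatching to a
    /// vectorized implementation when the running CPU supports one.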
    #[inline]
    pub fn compare_bytes_simd(a: &[u8], b: &[u8]) -> Ordering {
        // Short inputs are cheaper to compare via the scalar path.
        if a.len() <= 16 || b.len() <= 16 {
            return a.cmp(b);
        }

        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx2") {
                // SAFETY: AVX2 support was verified at runtime above.
                return unsafe { Self::compare_avx2(a, b) };
            } else if is_x86_feature_detected!("sse4.2") {
                // SAFETY: SSE4.2 support was verified at runtime above.
                return unsafe { Self::compare_sse42(a, b) };
            }
        }

        #[cfg(target_arch = "aarch64")]
        {
            if std::arch::is_aarch64_feature_detected!("neon") {
                // SAFETY: NEON support was verified at runtime above.
                return unsafe { Self::compare_neon(a, b) };
            }
        }

        // Scalar fallback when no vector extension is available.
        a.cmp(b)
    }

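    /// Compares two byte slices using ASCII case-insensitive ordering,
    /// with ties between equal prefixes broken by length.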
    #[inline]
    pub fn compare_case_insensitive_simd(a: &[u8], b: &[u8]) -> Ordering {
        let min_len = a.len().min(b.len());

        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx2") && min_len >= 32 {
                // SAFETY: AVX2 support was verified at runtime above.
                return unsafe { Self::compare_case_insensitive_avx2(a, b) };
            }
        }

        // Scalar fallback: fold ASCII case byte by byte.
        for i in 0..min_len {
            let a_char = a[i].to_ascii_lowercase();
            let b_char = b[i].to_ascii_lowercase();
            match a_char.cmp(&b_char) {
                Ordering::Equal => continue,
                other => return other,
            }
        }
        a.len().cmp(&b.len())
    }

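    /// AVX2 path: scans 32-byte chunks for the first mismatching byte.
    ///
    /// # Safety
    /// The caller must ensure the running CPU supports AVX2.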
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    #[inline]
    unsafe fn compare_avx2(a: &[u8], b: &[u8]) -> Ordering {
        use std::arch::x86_64::*;

        let min_len = a.len().min(b.len());
        let chunk_size = 32;
        let chunks = min_len / chunk_size;

        if chunks == 0 {
            return a.cmp(b);
        }

        unsafe {
            for i in 0..chunks {
                let offset = i * chunk_size;

                // Unaligned 32-byte loads from both slices.
                let va = _mm256_loadu_si256(a.as_ptr().add(offset) as *const __m256i);
                let vb = _mm256_loadu_si256(b.as_ptr().add(offset) as *const __m256i);

                // Byte-wise equality; each equal lane sets one mask bit.
                let cmp = _mm256_cmpeq_epi8(va, vb);
                let mask = _mm256_movemask_epi8(cmp) as u32;

                if mask != 0xFFFFFFFF {
                    // Locate the first unequal byte and compare it directly.
                    let diff_pos = (!mask).trailing_zeros() as usize;
                    let abs_pos = offset + diff_pos;
                    return a[abs_pos].cmp(&b[abs_pos]);
                }
            }
        }

        // Compare the tail that does not fill a whole chunk.
        let remaining_start = chunks * chunk_size;
        for i in remaining_start..min_len {
            match a[i].cmp(&b[i]) {
                Ordering::Equal => continue,
                other => return other,
            }
        }

        a.len().cmp(&b.len())
    }

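    /// SSE path: scans 16-byte chunks for the first mismatching byte.
    ///
    /// # Safety
    /// The caller must ensure the running CPU supports SSE4.2.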
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "sse4.2")]
    #[inline]
    unsafe fn compare_sse42(a: &[u8], b: &[u8]) -> Ordering {
        use std::arch::x86_64::*;

        let min_len = a.len().min(b.len());
        let chunk_size = 16;
        let chunks = min_len / chunk_size;

        unsafe {
            for i in 0..chunks {
                let offset = i * chunk_size;

                // Unaligned 16-byte loads from both slices.
                let va = _mm_loadu_si128(a.as_ptr().add(offset) as *const __m128i);
                let vb = _mm_loadu_si128(b.as_ptr().add(offset) as *const __m128i);

                let cmp = _mm_cmpeq_epi8(va, vb);
                let mask = _mm_movemask_epi8(cmp) as u16;

                if mask != 0xFFFF {
                    let diff_pos = (!mask).trailing_zeros() as usize;
                    let abs_pos = offset + diff_pos;
                    return a[abs_pos].cmp(&b[abs_pos]);
                }
            }
        }

        let remaining_start = chunks * chunk_size;
        for i in remaining_start..min_len {
            match a[i].cmp(&b[i]) {
                Ordering::Equal => continue,
                other => return other,
            }
        }

        a.len().cmp(&b.len())
    }

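    /// NEON path: scans 16-byte chunks, then rescans a mismatching chunk
    /// byte by byte (NEON has no direct movemask equivalent).
    ///
    /// # Safety
    /// The caller must ensure the running CPU supports NEON.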
    #[cfg(target_arch = "aarch64")]
    #[target_feature(enable = "neon")]
    #[inline]
    unsafe fn compare_neon(a: &[u8], b: &[u8]) -> Ordering {
        use std::arch::aarch64::*;

        let min_len = a.len().min(b.len());
        let chunk_size = 16;
        let chunks = min_len / chunk_size;

        unsafe {
            for i in 0..chunks {
                let offset = i * chunk_size;

                let va = vld1q_u8(a.as_ptr().add(offset));
                let vb = vld1q_u8(b.as_ptr().add(offset));

                // Equal lanes become 0xFF; the minimum across lanes is 0xFF
                // only when every byte in the chunk matched.
                let cmp = vceqq_u8(va, vb);
                let all_equal = vminvq_u8(cmp) == 0xFF;
                if !all_equal {
                    for j in 0..16 {
                        let pos = offset + j;
                        if a[pos] != b[pos] {
                            return a[pos].cmp(&b[pos]);
                        }
                    }
                }
            }
        }

        let remaining_start = chunks * chunk_size;
        for i in remaining_start..min_len {
            match a[i].cmp(&b[i]) {
                Ordering::Equal => continue,
                other => return other,
            }
        }

        a.len().cmp(&b.len())
    }

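    /// AVX2 case-insensitive path: folds ASCII uppercase bytes to lowercase
    /// in-register before the chunk comparison.
    ///
    /// # Safety
    /// The caller must ensure the running CPU supports AVX2.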
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    #[inline]
    unsafe fn compare_case_insensitive_avx2(a: &[u8], b: &[u8]) -> Ordering {
        use std::arch::x86_64::*;

        let min_len = a.len().min(b.len());
        let chunk_size = 32;
        let chunks = min_len / chunk_size;

        unsafe {
            let upper_a = _mm256_set1_epi8(b'A' as i8);
            let upper_z = _mm256_set1_epi8(b'Z' as i8);
            let case_diff = _mm256_set1_epi8(32);

            for i in 0..chunks {
                let offset = i * chunk_size;

                let mut va = _mm256_loadu_si256(a.as_ptr().add(offset) as *const __m256i);
                let mut vb = _mm256_loadu_si256(b.as_ptr().add(offset) as *const __m256i);

                // Signed compares are safe here: 'A'..='Z' sits below 0x80,
                // so bytes >= 0x80 read as negative and fail the range test.
                let a_is_upper = _mm256_and_si256(
                    _mm256_cmpgt_epi8(va, _mm256_sub_epi8(upper_a, _mm256_set1_epi8(1))),
                    _mm256_cmpgt_epi8(_mm256_add_epi8(upper_z, _mm256_set1_epi8(1)), va),
                );
                let b_is_upper = _mm256_and_si256(
                    _mm256_cmpgt_epi8(vb, _mm256_sub_epi8(upper_a, _mm256_set1_epi8(1))),
                    _mm256_cmpgt_epi8(_mm256_add_epi8(upper_z, _mm256_set1_epi8(1)), vb),
                );

                // Add 0x20 to the uppercase lanes only, folding them to lowercase.
                va = _mm256_add_epi8(va, _mm256_and_si256(a_is_upper, case_diff));
                vb = _mm256_add_epi8(vb, _mm256_and_si256(b_is_upper, case_diff));

                let cmp = _mm256_cmpeq_epi8(va, vb);
                let mask = _mm256_movemask_epi8(cmp) as u32;

                if mask != 0xFFFFFFFF {
                    let diff_pos = (!mask).trailing_zeros() as usize;
                    let abs_pos = offset + diff_pos;
                    return a[abs_pos]
                        .to_ascii_lowercase()
                        .cmp(&b[abs_pos].to_ascii_lowercase());
                }
            }
        }

        let remaining_start = chunks * chunk_size;
        for i in remaining_start..min_len {
            let a_char = a[i].to_ascii_lowercase();
            let b_char = b[i].to_ascii_lowercase();
            match a_char.cmp(&b_char) {
                Ordering::Equal => continue,
                other => return other,
            }
        }

        a.len().cmp(&b.len())
    }

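    /// Returns `true` if every byte is an ASCII digit (`b'0'..=b'9'`).
    /// The empty slice is considered all-digits.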
    #[inline]
    pub fn is_all_digits_simd(bytes: &[u8]) -> bool {
        if bytes.is_empty() {
            return true;
        }

        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx2") && bytes.len() >= 32 {
                // SAFETY: AVX2 support was verified at runtime above.
                return unsafe { Self::is_all_digits_avx2(bytes) };
            }
        }

        bytes.iter().all(|&b| b.is_ascii_digit())
    }

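    /// AVX2 digit check: range-tests 32 bytes per iteration.
    ///
    /// # Safety
    /// The caller must ensure the running CPU supports AVX2.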
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    #[inline]
    unsafe fn is_all_digits_avx2(bytes: &[u8]) -> bool {
        use std::arch::x86_64::*;

        let chunk_size = 32;
        let chunks = bytes.len() / chunk_size;

        unsafe {
            let min_digit = _mm256_set1_epi8(b'0' as i8);
            let max_digit = _mm256_set1_epi8(b'9' as i8);

            for i in 0..chunks {
                let offset = i * chunk_size;
                let v = _mm256_loadu_si256(bytes.as_ptr().add(offset) as *const __m256i);

                // v >= '0' and v <= '9', expressed with strict signed compares
                // against '0' - 1 and '9' + 1.
                let ge_min = _mm256_cmpgt_epi8(v, _mm256_sub_epi8(min_digit, _mm256_set1_epi8(1)));
                let le_max = _mm256_cmpgt_epi8(_mm256_add_epi8(max_digit, _mm256_set1_epi8(1)), v);
                let is_digit = _mm256_and_si256(ge_min, le_max);

                let mask = _mm256_movemask_epi8(is_digit) as u32;
                if mask != 0xFFFFFFFF {
                    return false;
                }
            }
        }

        // Check the tail that does not fill a whole chunk.
        let remaining_start = chunks * chunk_size;
        bytes[remaining_start..].iter().all(|&b| b.is_ascii_digit())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simd_comparison() {
        let str_a = b"hello world this is a test";
        let str_b = b"hello world this is a different test";

        let result = SIMDCompare::compare_bytes_simd(str_a, str_b);
        let expected = str_a[..].cmp(&str_b[..]);

        assert_eq!(result, expected);
    }

    #[test]
    fn test_simd_case_insensitive() {
        let a = b"Hello World";
        let b = b"HELLO WORLD";

        let result = SIMDCompare::compare_case_insensitive_simd(a, b);
        assert_eq!(result, Ordering::Equal);
    }

    #[test]
    fn test_simd_digit_detection() {
        assert!(SIMDCompare::is_all_digits_simd(b"123456789"));
        assert!(!SIMDCompare::is_all_digits_simd(b"123a456"));
        assert!(SIMDCompare::is_all_digits_simd(b""));
    }
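
    // Supplementary coverage (an added sketch, not part of the original
    // suite): inputs long enough to cross chunk boundaries exercise both
    // the vectorized loops and the scalar remainder handling, and must
    // agree with the standard library's scalar ordering.
    #[test]
    fn test_simd_comparison_long_inputs() {
        let a: Vec<u8> = (0u8..100).collect();
        let mut b = a.clone();
        b[70] ^= 1; // first difference lands past two full 32-byte chunks

        assert_eq!(SIMDCompare::compare_bytes_simd(&a, &b), a.cmp(&b));
        assert_eq!(SIMDCompare::compare_bytes_simd(&a, &a), Ordering::Equal);
        // Equal prefixes of unequal length fall back to length ordering.
        assert_eq!(SIMDCompare::compare_bytes_simd(&a[..64], &a), Ordering::Less);
    }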
}