// gnu_sort/simd_compare.rs
//! SIMD-accelerated comparison functions for ultra-fast string operations.
//! Uses vectorized instructions to process 16-32 bytes at a time, with
//! runtime CPU-feature detection and portable scalar fallbacks.
3use std::cmp::Ordering;
4
/// SIMD-accelerated string comparison.
///
/// The public entry points are safe: each performs runtime CPU-feature
/// detection before dispatching to an `unsafe` vectorized kernel, and falls
/// back to portable scalar code when no suitable SIMD extension is present.
pub struct SIMDCompare;

impl SIMDCompare {
    /// Vectorized byte comparison using SIMD when available.
    ///
    /// Semantically identical to `a.cmp(b)` (lexicographic byte order);
    /// only the implementation strategy varies by input size and CPU.
    #[inline]
    pub fn compare_bytes_simd(a: &[u8], b: &[u8]) -> Ordering {
        // For small inputs the SIMD setup cost outweighs the benefit.
        if a.len() <= 16 || b.len() <= 16 {
            return a.cmp(b);
        }

        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx2") {
                // SAFETY: AVX2 availability was verified at runtime above.
                return unsafe { Self::compare_avx2(a, b) };
            }
            // SSE2 is part of the x86_64 baseline ABI, so the 128-bit kernel
            // is always usable without feature detection. (The previous
            // `sse4.2` gate was spurious: that kernel uses only SSE2
            // instructions.)
            return Self::compare_sse2(a, b);
        }

        #[cfg(target_arch = "aarch64")]
        {
            if std::arch::is_aarch64_feature_detected!("neon") {
                // SAFETY: NEON availability was verified at runtime above.
                return unsafe { Self::compare_neon(a, b) };
            }
            return a.cmp(b);
        }

        // Portable fallback for all other architectures.
        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            a.cmp(b)
        }
    }

    /// Vectorized case-insensitive (ASCII) comparison.
    ///
    /// Equivalent to comparing both slices byte-by-byte after
    /// `u8::to_ascii_lowercase`, with shorter-is-less tie-breaking.
    #[inline]
    pub fn compare_case_insensitive_simd(a: &[u8], b: &[u8]) -> Ordering {
        #[cfg(target_arch = "x86_64")]
        {
            let min_len = a.len().min(b.len());
            // Only worthwhile when at least one full 32-byte chunk exists.
            if min_len >= 32 && is_x86_feature_detected!("avx2") {
                // SAFETY: AVX2 availability was verified at runtime above.
                return unsafe { Self::compare_case_insensitive_avx2(a, b) };
            }
        }

        // Scalar fallback: `zip` stops at the shorter slice.
        for (&x, &y) in a.iter().zip(b.iter()) {
            match x.to_ascii_lowercase().cmp(&y.to_ascii_lowercase()) {
                Ordering::Equal => {}
                other => return other,
            }
        }
        a.len().cmp(&b.len())
    }

    /// AVX2 comparison kernel (32 bytes per iteration).
    ///
    /// # Safety
    /// The caller must ensure the running CPU supports AVX2.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn compare_avx2(a: &[u8], b: &[u8]) -> Ordering {
        use std::arch::x86_64::*;

        const CHUNK: usize = 32; // one __m256i register
        let min_len = a.len().min(b.len());
        let chunks = min_len / CHUNK;

        for i in 0..chunks {
            let offset = i * CHUNK;

            // SAFETY: offset + 32 <= min_len <= len of both slices, and the
            // unaligned loads (`loadu`) accept any pointer alignment.
            let (va, vb) = unsafe {
                (
                    _mm256_loadu_si256(a.as_ptr().add(offset) as *const __m256i),
                    _mm256_loadu_si256(b.as_ptr().add(offset) as *const __m256i),
                )
            };

            // One mask bit per byte lane: 1 = lanes equal.
            let eq = unsafe { _mm256_movemask_epi8(_mm256_cmpeq_epi8(va, vb)) } as u32;
            if eq != u32::MAX {
                // Lowest cleared bit = first differing byte of this chunk.
                let pos = offset + (!eq).trailing_zeros() as usize;
                return a[pos].cmp(&b[pos]);
            }
        }

        // Scalar tail (< 32 leftover bytes), then the length tiebreak.
        let tail = chunks * CHUNK;
        match a[tail..min_len].cmp(&b[tail..min_len]) {
            Ordering::Equal => a.len().cmp(&b.len()),
            other => other,
        }
    }

    /// SSE2 comparison kernel (16 bytes per iteration).
    ///
    /// SSE2 is guaranteed on every x86_64 CPU, so this function is safe and
    /// needs no runtime feature detection.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    fn compare_sse2(a: &[u8], b: &[u8]) -> Ordering {
        use std::arch::x86_64::*;

        const CHUNK: usize = 16; // one __m128i register
        let min_len = a.len().min(b.len());
        let chunks = min_len / CHUNK;

        for i in 0..chunks {
            let offset = i * CHUNK;

            // SAFETY: offset + 16 <= min_len <= len of both slices; `loadu`
            // accepts any alignment, and SSE2 is baseline on x86_64.
            let eq = unsafe {
                let va = _mm_loadu_si128(a.as_ptr().add(offset) as *const __m128i);
                let vb = _mm_loadu_si128(b.as_ptr().add(offset) as *const __m128i);
                _mm_movemask_epi8(_mm_cmpeq_epi8(va, vb)) as u16
            };

            if eq != u16::MAX {
                // Lowest cleared bit = first differing byte of this chunk.
                let pos = offset + (!eq).trailing_zeros() as usize;
                return a[pos].cmp(&b[pos]);
            }
        }

        // Scalar tail (< 16 leftover bytes), then the length tiebreak.
        let tail = chunks * CHUNK;
        match a[tail..min_len].cmp(&b[tail..min_len]) {
            Ordering::Equal => a.len().cmp(&b.len()),
            other => other,
        }
    }

    /// ARM NEON comparison kernel (16 bytes per iteration).
    ///
    /// # Safety
    /// The caller must ensure the running CPU supports NEON.
    #[cfg(target_arch = "aarch64")]
    #[target_feature(enable = "neon")]
    unsafe fn compare_neon(a: &[u8], b: &[u8]) -> Ordering {
        use std::arch::aarch64::*;

        const CHUNK: usize = 16; // one 128-bit NEON register
        let min_len = a.len().min(b.len());
        let chunks = min_len / CHUNK;

        for i in 0..chunks {
            let offset = i * CHUNK;

            // SAFETY: offset + 16 <= min_len <= len of both slices.
            let all_equal = unsafe {
                let va = vld1q_u8(a.as_ptr().add(offset));
                let vb = vld1q_u8(b.as_ptr().add(offset));
                // vceqq_u8 yields 0xFF per equal lane; the horizontal min is
                // 0xFF iff every lane matched.
                vminvq_u8(vceqq_u8(va, vb)) == 0xFF
            };

            if !all_equal {
                // Locate the first mismatch with a scalar scan of the chunk.
                for pos in offset..offset + CHUNK {
                    if a[pos] != b[pos] {
                        return a[pos].cmp(&b[pos]);
                    }
                }
                // NEON reported a differing lane, so the scan must find one.
                unreachable!("NEON reported a mismatch but none was found");
            }
        }

        // Scalar tail (< 16 leftover bytes), then the length tiebreak.
        let tail = chunks * CHUNK;
        match a[tail..min_len].cmp(&b[tail..min_len]) {
            Ordering::Equal => a.len().cmp(&b.len()),
            other => other,
        }
    }

    /// AVX2 case-insensitive comparison kernel (32 bytes per iteration).
    ///
    /// ASCII-only, matching `u8::to_ascii_lowercase`: bytes outside
    /// `b'A'..=b'Z'` are compared unchanged.
    ///
    /// # Safety
    /// The caller must ensure the running CPU supports AVX2.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn compare_case_insensitive_avx2(a: &[u8], b: &[u8]) -> Ordering {
        use std::arch::x86_64::*;

        const CHUNK: usize = 32;
        let min_len = a.len().min(b.len());
        let chunks = min_len / CHUNK;

        unsafe {
            // Exclusive range bounds for the signed per-byte comparisons.
            // All relevant values are < 0x80, so signed semantics reject
            // bytes >= 0x80 (which is what we want: they are not uppercase).
            let below_upper = _mm256_set1_epi8((b'A' - 1) as i8); // 0x40
            let above_upper = _mm256_set1_epi8((b'Z' + 1) as i8); // 0x5B
            let case_delta = _mm256_set1_epi8(0x20); // 'a' - 'A'

            for i in 0..chunks {
                let offset = i * CHUNK;

                // SAFETY: offset + 32 <= min_len <= len of both slices;
                // unaligned loads accept any alignment.
                let mut va = _mm256_loadu_si256(a.as_ptr().add(offset) as *const __m256i);
                let mut vb = _mm256_loadu_si256(b.as_ptr().add(offset) as *const __m256i);

                // Mask of lanes strictly between the bounds, i.e. b'A'..=b'Z'.
                let a_upper = _mm256_and_si256(
                    _mm256_cmpgt_epi8(va, below_upper),
                    _mm256_cmpgt_epi8(above_upper, va),
                );
                let b_upper = _mm256_and_si256(
                    _mm256_cmpgt_epi8(vb, below_upper),
                    _mm256_cmpgt_epi8(above_upper, vb),
                );

                // Add 0x20 to the uppercase lanes only -> ASCII lowercase.
                va = _mm256_add_epi8(va, _mm256_and_si256(a_upper, case_delta));
                vb = _mm256_add_epi8(vb, _mm256_and_si256(b_upper, case_delta));

                let eq = _mm256_movemask_epi8(_mm256_cmpeq_epi8(va, vb)) as u32;
                if eq != u32::MAX {
                    // First differing byte after case folding.
                    let pos = offset + (!eq).trailing_zeros() as usize;
                    return a[pos]
                        .to_ascii_lowercase()
                        .cmp(&b[pos].to_ascii_lowercase());
                }
            }
        }

        // Scalar tail, then the length tiebreak.
        for i in (chunks * CHUNK)..min_len {
            match a[i].to_ascii_lowercase().cmp(&b[i].to_ascii_lowercase()) {
                Ordering::Equal => {}
                other => return other,
            }
        }
        a.len().cmp(&b.len())
    }

    /// Returns `true` when every byte is an ASCII digit (`b'0'..=b'9'`).
    ///
    /// Note: the empty slice vacuously returns `true`, mirroring
    /// `Iterator::all` on an empty iterator.
    #[inline]
    pub fn is_all_digits_simd(bytes: &[u8]) -> bool {
        if bytes.is_empty() {
            return true;
        }

        #[cfg(target_arch = "x86_64")]
        {
            if bytes.len() >= 32 && is_x86_feature_detected!("avx2") {
                // SAFETY: AVX2 availability was verified at runtime above.
                return unsafe { Self::is_all_digits_avx2(bytes) };
            }
        }

        bytes.iter().all(|&b| b.is_ascii_digit())
    }

    /// AVX2 digit-detection kernel (32 bytes per iteration).
    ///
    /// # Safety
    /// The caller must ensure the running CPU supports AVX2.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn is_all_digits_avx2(bytes: &[u8]) -> bool {
        use std::arch::x86_64::*;

        const CHUNK: usize = 32;
        let chunks = bytes.len() / CHUNK;

        unsafe {
            // Exclusive bounds around b'0'..=b'9'; all values are < 0x80, so
            // the signed comparisons reject bytes >= 0x80 as intended.
            let below_digits = _mm256_set1_epi8((b'0' - 1) as i8); // 0x2F
            let above_digits = _mm256_set1_epi8((b'9' + 1) as i8); // 0x3A

            for i in 0..chunks {
                // SAFETY: i * 32 + 32 <= bytes.len(); `loadu` accepts any
                // alignment.
                let v = _mm256_loadu_si256(bytes.as_ptr().add(i * CHUNK) as *const __m256i);

                let is_digit = _mm256_and_si256(
                    _mm256_cmpgt_epi8(v, below_digits),
                    _mm256_cmpgt_epi8(above_digits, v),
                );
                if _mm256_movemask_epi8(is_digit) as u32 != u32::MAX {
                    return false;
                }
            }
        }

        // Scalar check for the < 32 leftover bytes.
        bytes[chunks * CHUNK..].iter().all(|&b| b.is_ascii_digit())
    }
}
320
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simd_comparison() {
        // SIMD result must agree with the standard lexicographic ordering.
        let lhs: &[u8] = b"hello world this is a test";
        let rhs: &[u8] = b"hello world this is a different test";
        assert_eq!(SIMDCompare::compare_bytes_simd(lhs, rhs), lhs.cmp(rhs));
    }

    #[test]
    fn test_simd_case_insensitive() {
        // Differing only in case compares as equal.
        assert_eq!(
            SIMDCompare::compare_case_insensitive_simd(b"Hello World", b"HELLO WORLD"),
            Ordering::Equal
        );
    }

    #[test]
    fn test_simd_digit_detection() {
        // Table-driven: (input, expected all-digits?).
        let cases: [(&[u8], bool); 3] = [
            (b"123456789", true),
            (b"123a456", false),
            (b"", true),
        ];
        for (input, expected) in cases {
            assert_eq!(SIMDCompare::is_all_digits_simd(input), expected);
        }
    }
}