fast_slice_utils/
lib.rs

1#![doc = include_str!("../README.md")]
2#![no_std]
3#![cfg_attr(feature = "nightly", allow(internal_features), feature(core_intrinsics))]
4#![cfg_attr(feature = "nightly", feature(portable_simd))]
5
6// =======================================================================================
7// Various implementations of `find_prefix_overlap`
8// =======================================================================================
9
10#[cfg(not(feature = "nightly"))]
11#[allow(unused)]
12pub(crate) use core::convert::{identity as likely, identity as unlikely};
13#[cfg(feature = "nightly")]
14#[allow(unused)]
15pub(crate) use core::intrinsics::{likely, unlikely};
16
// GOAT!  UGH!  It turned out the scalar paths aren't fast enough to justify having them,
// probably on account of the extra branching causing branch mispredictions.
// This code should be deleted eventually, but maybe keep it around for a while as we discuss.
20//
21// /// Returns the number of characters shared between two slices
22// #[inline]
23// pub fn find_prefix_overlap(a: &[u8], b: &[u8]) -> usize {
24//     let len = a.len().min(b.len());
25
26//     match len {
27//         0 => 0,
28//         1 => (unsafe{ a.get_unchecked(0) == b.get_unchecked(0) } as usize),
29//         2 => { 
30//             let a_word = unsafe{ core::ptr::read_unaligned(a.as_ptr() as *const u16) };
31//             let b_word = unsafe{ core::ptr::read_unaligned(b.as_ptr() as *const u16) };
32//             let cmp = !(a_word ^ b_word); // equal bytes will be 0xFF
33//             let cnt = cmp.trailing_ones();
34//             cnt as usize / 8
35//         },
36//         3 | 4 | 5 | 6 | 7 | 8 => {
37//             //GOAT, we need to do a check to make sure we don't over-read a page
38//             let a_word = unsafe{ core::ptr::read_unaligned(a.as_ptr() as *const u64) };
39//             let b_word = unsafe{ core::ptr::read_unaligned(b.as_ptr() as *const u64) };
40//             let cmp = !(a_word ^ b_word); // equal bytes will be 0xFF
41//             let cnt = cmp.trailing_ones();
42//             let result = cnt as usize / 8;
43//             result.min(len)
44//         },
45//         _ => count_shared_neon(a, b),
46//     }
47// }
48
49// GOAT!  AGH!! Even this is much slower, even on the zipfian distribution where 70% of the pairs have 0 overlap!!!
50//
51// /// Returns the number of characters shared between two slices
52// #[inline]
53// pub fn find_prefix_overlap(a: &[u8], b: &[u8]) -> usize {
54//     if a.len() != 0 && b.len() != 0 && unsafe{ a.get_unchecked(0) == b.get_unchecked(0) } {
55//         count_shared_neon(a, b)
56//     } else {
57//         0
58//     }
59// }
60
/// Page size assumed by the over-read guard in the SIMD paths.
///
/// 4 KiB is the smallest common page size. Larger power-of-two page sizes that are
/// multiples of 4 KiB also start on 4 KiB boundaries, so guarding against crossing a
/// 4 KiB boundary is conservative for them too.
const PAGE_SIZE: usize = 4096;

/// Returns `true` if reading `VECTOR_SIZE` bytes starting at `slice.as_ptr()` stays within a
/// single `PAGE_SIZE`-aligned page.
///
/// The SIMD paths use this to decide whether a full-width load that runs past the end of a
/// short slice can be issued without touching the following page. Only the start address is
/// inspected; the slice length is ignored.
#[inline(always)]
fn same_page<const VECTOR_SIZE: usize>(slice: &[u8]) -> bool {
    // Keep only the low bits of the address: the offset of the first byte within its page.
    let offset_within_page = (slice.as_ptr() as usize) & (PAGE_SIZE - 1);
    // The last byte read lives at `offset + VECTOR_SIZE - 1`, which must be `< PAGE_SIZE`,
    // i.e. `offset <= PAGE_SIZE - VECTOR_SIZE`.
    offset_within_page <= PAGE_SIZE - VECTOR_SIZE
}
71
/// A simple reference implementation of `find_prefix_overlap` with no fanciness
///
/// Walks both slices in lockstep and returns the index of the first differing byte. When
/// one slice is a prefix of the other, it returns the length of the shorter slice.
fn count_shared_reference(p: &[u8], q: &[u8]) -> usize {
    let shorter_len = p.len().min(q.len());
    p.iter()
        .zip(q.iter())
        .position(|(left, right)| left != right)
        .unwrap_or(shorter_len)
}
77
78#[cold]
79fn count_shared_cold(a: &[u8], b: &[u8]) -> usize {
80    count_shared_reference(a, b)
81}
82
/// AVX-512 implementation of [`find_prefix_overlap`], comparing up to 64 bytes per iteration.
///
/// Each iteration uses masked loads that enable only the lanes still inside both slices.
/// Unlike the other SIMD paths, this one never reads past the end of `p` or `q` and needs
/// no page check.
///
/// The byte-granular intrinsics used here (`_mm512_mask_loadu_epi8`,
/// `_mm512_cmpeq_epi8_mask`) belong to AVX-512BW, not just AVX-512F. Only call this when
/// AVX-512BW is available.
#[cfg(target_feature = "avx512f")]
#[inline(always)]
fn count_shared_avx512(p: &[u8], q: &[u8]) -> usize {
    use core::arch::x86_64::*;
    // Bytes compared per iteration (one 512-bit register).
    const CHUNK: usize = 64;

    let max_shared = p.len().min(q.len());
    let mut offset = 0;
    // A loop rather than recursion, so long equal inputs cannot exhaust the stack
    // (e.g. in unoptimised builds where the recursive call is not turned into a loop).
    while offset < max_shared {
        let remaining = max_shared - offset;
        // `lanes >= 1` here, so the shift below is at most 63 and cannot overflow.
        let lanes = remaining.min(CHUNK);
        let load_mask: __mmask64 = u64::MAX >> (CHUNK - lanes);
        // SAFETY: `offset < max_shared <= min(p.len(), q.len())`, and the mask enables only
        // the `lanes` bytes that lie inside both slices. Masked-off lanes are not accessed
        // and are zeroed in both vectors, so they always compare equal.
        let ne_mask = unsafe {
            let pv = _mm512_mask_loadu_epi8(_mm512_setzero_si512(), load_mask, p.as_ptr().add(offset).cast());
            let qv = _mm512_mask_loadu_epi8(_mm512_setzero_si512(), load_mask, q.as_ptr().add(offset).cast());
            !_mm512_cmpeq_epi8_mask(pv, qv)
        };
        // Index of the first differing lane; 64 when every lane matched.
        // (`trailing_zeros` avoids `_tzcnt_u64`, which needs the separate BMI1 feature.)
        let count = ne_mask.trailing_zeros() as usize;
        if count < lanes {
            return offset + count;
        }
        offset += lanes;
    }
    max_shared
}
105
/// AVX2 implementation of [`find_prefix_overlap`], comparing 32 bytes per iteration.
///
/// While at least 32 bytes of both slices remain, every load is fully in bounds. For a
/// shorter tail, a full-width load (which reads past the end of the slices) is issued only
/// when [`same_page`] reports that it cannot reach the next page. Otherwise the tail is
/// handed to [`count_shared_cold`].
///
/// NOTE(review): such an over-read is undefined behaviour under Rust's memory model even
/// when it cannot fault. This is presumably why this path is excluded under miri.
#[cfg(all(target_feature="avx2", not(miri)))]
#[inline(always)]
fn count_shared_avx2(p: &[u8], q: &[u8]) -> usize {
    use core::arch::x86_64::*;
    // Bytes compared per iteration (one 256-bit register).
    const CHUNK: usize = 32;

    let max_shared = p.len().min(q.len());
    let mut offset = 0;
    unsafe {
        if unlikely(max_shared == 0) { return 0 }
        // A loop rather than recursion, so long equal inputs cannot exhaust the stack
        // (e.g. in unoptimised builds where the recursive call is not turned into a loop).
        loop {
            // SAFETY: `offset < max_shared <= min(p.len(), q.len())` holds on every
            // iteration, so both ranges are in bounds and non-empty.
            let p_rest = p.get_unchecked(offset..max_shared);
            let q_rest = q.get_unchecked(offset..max_shared);
            let remaining = max_shared - offset;

            if likely(remaining >= CHUNK || (same_page::<CHUNK>(p_rest) && same_page::<CHUNK>(q_rest))) {
                let pv = _mm256_loadu_si256(p_rest.as_ptr().cast());
                let qv = _mm256_loadu_si256(q_rest.as_ptr().cast());
                // Bit `i` is set when byte lane `i` compared equal.
                let eq_mask = _mm256_movemask_epi8(_mm256_cmpeq_epi8(pv, qv)) as u32;
                // Index of the first differing lane; 32 when every lane matched.
                // (`trailing_zeros` avoids `_tzcnt_u32`, which needs the separate BMI1 feature.)
                let count = (!eq_mask).trailing_zeros() as usize;
                if count < CHUNK || remaining <= CHUNK {
                    // Lanes past `remaining` were read beyond the slices, so clamp.
                    return offset + count.min(remaining);
                }
                offset += CHUNK;
            } else {
                return offset + count_shared_cold(p_rest, q_rest);
            }
        }
    }
}
132
/// NEON implementation of [`find_prefix_overlap`] for aarch64 without the `nightly` feature,
/// comparing 16 bytes per iteration.
///
/// While at least 16 bytes of both slices remain, every load is fully in bounds. For a
/// shorter tail, a full-width load (which reads past the end of the slices) is issued only
/// when [`same_page`] reports that it cannot reach the next page. Otherwise the tail is
/// handed to [`count_shared_cold`].
///
/// NOTE(review): such an over-read is undefined behaviour under Rust's memory model even
/// when it cannot fault. This is presumably why this path is excluded under miri.
#[cfg(all(not(feature = "nightly"), target_arch = "aarch64", target_feature = "neon", not(miri)))]
#[inline(always)]
fn count_shared_neon(p: &[u8], q: &[u8]) -> usize {
    use core::arch::aarch64::*;
    // Bytes compared per iteration (one 128-bit register).
    const CHUNK: usize = 16;

    let max_shared = p.len().min(q.len());
    let mut offset = 0;
    unsafe {
        if unlikely(max_shared == 0) { return 0 }
        // A loop rather than recursion, so long equal inputs cannot exhaust the stack
        // (e.g. in unoptimised builds where the recursive call is not turned into a loop).
        loop {
            // SAFETY: `offset < max_shared <= min(p.len(), q.len())` holds on every
            // iteration, so both ranges are in bounds and non-empty.
            let p_rest = p.get_unchecked(offset..max_shared);
            let q_rest = q.get_unchecked(offset..max_shared);
            let remaining = max_shared - offset;

            if remaining >= CHUNK || (same_page::<CHUNK>(p_rest) && same_page::<CHUNK>(q_rest)) {
                // Each byte lane becomes 0xFF where equal and 0x00 where different.
                let eq = vceqq_u8(vld1q_u8(p_rest.as_ptr()), vld1q_u8(q_rest.as_ptr()));

                // Spill the lane mask and view it as a little-endian integer, so byte lane
                // `i` occupies bits `8*i..8*i+8`. Earlier measurements found this faster
                // than extracting the two 64-bit lanes.
                // NOTE(review): the `vshrn_n_u16` nibble-mask trick may be worth benchmarking.
                let mut lanes = [core::mem::MaybeUninit::<u8>::uninit(); CHUNK];
                vst1q_u8(lanes.as_mut_ptr().cast(), eq);
                let eq_bits = u128::from_le_bytes(core::mem::transmute(lanes));
                // Whole equal bytes before the first difference; 16 when all matched.
                let count = (eq_bits.trailing_ones() / 8) as usize;

                if count < CHUNK || remaining <= CHUNK {
                    // Lanes past `remaining` were read beyond the slices, so clamp.
                    return offset + count.min(remaining);
                }
                offset += CHUNK;
            } else {
                return offset + count_shared_cold(p_rest, q_rest);
            }
        }
    }
}
176
/// Portable-SIMD (`core::simd`, nightly) implementation of [`find_prefix_overlap`],
/// comparing 32 bytes per iteration.
///
/// While at least 32 bytes of both slices remain, every load is fully in bounds. For a
/// shorter tail, a full-width load (which reads past the end of the slices) is issued only
/// when [`same_page`] reports that it cannot reach the next page. Otherwise the tail is
/// handed to [`count_shared_cold`].
///
/// NOTE(review): such an over-read is undefined behaviour under Rust's memory model even
/// when it cannot fault. This is presumably why this path is excluded under miri.
#[cfg(all(feature = "nightly", not(miri)))]
#[inline(always)]
fn count_shared_simd(p: &[u8], q: &[u8]) -> usize {
    use core::simd::{u8x32, cmp::SimdPartialEq};
    // Bytes compared per iteration (one `u8x32`).
    const CHUNK: usize = 32;

    let max_shared = p.len().min(q.len());
    let mut offset = 0;
    unsafe {
        if unlikely(max_shared == 0) { return 0 }
        // A loop rather than recursion, so long equal inputs cannot exhaust the stack
        // (e.g. in unoptimised builds where the recursive call is not turned into a loop).
        loop {
            // SAFETY: `offset < max_shared <= min(p.len(), q.len())` holds on every
            // iteration, so both ranges are in bounds and non-empty.
            let p_rest = p.get_unchecked(offset..max_shared);
            let q_rest = q.get_unchecked(offset..max_shared);
            let remaining = max_shared - offset;

            if remaining >= CHUNK || (same_page::<CHUNK>(p_rest) && same_page::<CHUNK>(q_rest)) {
                let pv = u8x32::from_array(core::ptr::read_unaligned(p_rest.as_ptr().cast::<[u8; CHUNK]>()));
                let qv = u8x32::from_array(core::ptr::read_unaligned(q_rest.as_ptr().cast::<[u8; CHUNK]>()));
                // Bit `i` of the bitmask is set when lane `i` compared equal.
                let count = pv.simd_eq(qv).to_bitmask().trailing_ones() as usize;
                if count < CHUNK || remaining <= CHUNK {
                    // Lanes past `remaining` were read beyond the slices, so clamp.
                    return offset + count.min(remaining);
                }
                offset += CHUNK;
            } else {
                return offset + count_shared_cold(p_rest, q_rest);
            }
        }
    }
}
207
208/// Returns the number of initial characters shared between two slices
209///
210/// The fastest (as measured by us) implementation is exported based on the platform and features.
211///
212/// - **AVX-512**: AVX-512 intrinsics (x86_64, requires nightly)
213/// - **AVX2**: AVX2 intrinsics (x86_64)
214/// - **NEON**: NEON intrinsics (aarch64)
215/// - **Portable SIMD**: Portable SIMD (requires nightly)
216/// - **Reference**: Reference scalar implementation
217///
218/// | AVX-512 | AVX2 | NEON | nightly | miri | Implementation    |
219/// |---------|------|------|---------|------|-------------------|
220/// | ✓       | -    | ✗    | -       | ✗    | **AVX-512**       |
221/// | ✗       | ✓    | ✗    | -       | ✗    | **AVX2**          |
222/// | ✗       | ✗    | ✓    | ✗       | ✗    | **NEON**          |
223/// | ✗       | ✗    | ✓    | ✓       | ✗    | **Portable SIMD** |
224/// | -       | -    | -    | -       | ✓    | **Reference**     |
225///
226#[inline]
227pub fn find_prefix_overlap(a: &[u8], b: &[u8]) -> usize {
228    #[cfg(all(target_feature="avx512f", not(miri)))]
229    {
230        count_shared_avx512(a, b)
231    }
232    #[cfg(all(target_feature="avx2", not(target_feature="avx512f"), not(miri)))]
233    {
234        count_shared_avx2(a, b)
235    }
236    #[cfg(all(not(feature = "nightly"), target_arch = "aarch64", target_feature = "neon", not(miri)))]
237    {
238        count_shared_neon(a, b)
239    }
240    #[cfg(all(feature = "nightly", target_arch = "aarch64", target_feature = "neon", not(miri)))]
241    {
242        count_shared_simd(a, b)
243    }
244    #[cfg(any(all(not(target_feature="avx2"), not(target_feature="neon")), miri))]
245    {
246        count_shared_reference(a, b)
247    }
248}
249
#[test]
fn find_prefix_overlap_test() {
    let tests = [
        ("12345", "67890", 0),
        ("", "12300", 0),
        ("12345", "", 0),
        ("12345", "12300", 3),
        ("123", "123000000", 3),
        ("123456789012345678901234567890xxxx", "123456789012345678901234567890yy", 30),
        ("123456789012345678901234567890123456789012345678901234567890xxxx", "123456789012345678901234567890123456789012345678901234567890yy", 60),
        ("1234567890123456xxxx", "1234567890123456yyyyyyy", 16),
        ("123456789012345xxxx", "123456789012345yyyyyyy", 15),
        ("12345678901234567xxxx", "12345678901234567yyyyyyy", 17),
        ("1234567890123456789012345678901xxxx", "1234567890123456789012345678901yy", 31),
        ("12345678901234567890123456789012xxxx", "12345678901234567890123456789012yy", 32),
        ("123456789012345678901234567890123xxxx", "123456789012345678901234567890123yy", 33),
        ("123456789012345678901234567890123456789012345678901234567890123xxxx", "123456789012345678901234567890123456789012345678901234567890123yy", 63),
        ("1234567890123456789012345678901234567890123456789012345678901234xxxx", "1234567890123456789012345678901234567890123456789012345678901234yy", 64),
        ("12345678901234567890123456789012345678901234567890123456789012345xxxx", "12345678901234567890123456789012345678901234567890123456789012345yy", 65),
    ];

    for (a, b, expected) in tests {
        let overlap = find_prefix_overlap(a.as_bytes(), b.as_bytes());
        assert_eq!(overlap, expected, "find_prefix_overlap({:?}, {:?})", a, b);
    }
}

/// Checks every mismatch position for lengths spanning the SIMD chunk sizes (16/32/64).
/// The slices are placed so they end just before, at, and after a 4 KiB page boundary,
/// which exercises both the full-width path and the near-page-end fallback.
#[test]
fn find_prefix_overlap_sweep_test() {
    // Two 4 KiB pages in one allocation, aligned so that index 4096 is a page boundary.
    #[repr(align(4096))]
    struct TwoPages([u8; 8192]);

    let a = TwoPages([b'a'; 8192]);
    let mut b = TwoPages([b'a'; 8192]);
    // Miri only runs the reference implementation and is slow, so keep its sweep small.
    let max_len = if cfg!(miri) { 20 } else { 140 };

    for len in 0..=max_len {
        for end in [4095, 4096, 4097, 4096 + 70] {
            let start = end - len;
            // `mismatch == len` means the slices are identical.
            for mismatch in 0..=len {
                if mismatch < len {
                    b.0[start + mismatch] = b'b';
                }
                let overlap = find_prefix_overlap(&a.0[start..end], &b.0[start..end]);
                assert_eq!(overlap, mismatch, "len={} end={} mismatch={}", len, end, mismatch);
                if mismatch < len {
                    b.0[start + mismatch] = b'a';
                }
            }
        }
    }

    // Unequal lengths: the overlap is capped by the shorter slice...
    assert_eq!(find_prefix_overlap(&a.0[..10], &b.0[..300]), 10);
    assert_eq!(find_prefix_overlap(&a.0[..300], &b.0[..10]), 10);
    // ...unless a mismatch comes first.
    b.0[250] = b'b';
    assert_eq!(find_prefix_overlap(&a.0[..300], &b.0[..400]), 250);
    b.0[250] = b'a';

    // Long, fully equal inputs need many vector iterations.
    assert_eq!(find_prefix_overlap(&a.0, &b.0), 8192);
}
276
277/// A faster replacement for the stdlib version of [`starts_with`](slice::starts_with)
278#[inline(always)]
279pub fn starts_with(x: &[u8], y: &[u8]) -> bool {
280    if y.len() == 0 { return true }
281    if x.len() == 0 { return false }
282    if y.len() > x.len() { return false }
283    find_prefix_overlap(x, y) == y.len()
284}