Skip to main content

fast_slice_utils/
lib.rs

1#![doc = include_str!("../README.md")]
2#![no_std]
3#![cfg_attr(feature = "nightly", allow(internal_features), feature(core_intrinsics))]
4#![cfg_attr(feature = "nightly", feature(portable_simd))]
5
6// =======================================================================================
7// Various implementations of `find_prefix_overlap`
8// =======================================================================================
9
10#[cfg(not(feature = "nightly"))]
11#[allow(unused)]
12pub(crate) use core::convert::{identity as likely, identity as unlikely};
13#[cfg(feature = "nightly")]
14#[allow(unused)]
15pub(crate) use core::intrinsics::{likely, unlikely};
16
// GOAT!  UGH!  It turned out the scalar fast paths below aren't enough faster to justify having
// them, probably because the extra branching causes branch mispredictions.
// This code should be deleted eventually, but we may keep it for a while as we discuss it.
20//
21// /// Returns the number of characters shared between two slices
22// #[inline]
23// pub fn find_prefix_overlap(a: &[u8], b: &[u8]) -> usize {
24//     let len = a.len().min(b.len());
25
26//     match len {
27//         0 => 0,
28//         1 => (unsafe{ a.get_unchecked(0) == b.get_unchecked(0) } as usize),
29//         2 => { 
30//             let a_word = unsafe{ core::ptr::read_unaligned(a.as_ptr() as *const u16) };
31//             let b_word = unsafe{ core::ptr::read_unaligned(b.as_ptr() as *const u16) };
32//             let cmp = !(a_word ^ b_word); // equal bytes will be 0xFF
33//             let cnt = cmp.trailing_ones();
34//             cnt as usize / 8
35//         },
36//         3 | 4 | 5 | 6 | 7 | 8 => {
37//             //GOAT, we need to do a check to make sure we don't over-read a page
38//             let a_word = unsafe{ core::ptr::read_unaligned(a.as_ptr() as *const u64) };
39//             let b_word = unsafe{ core::ptr::read_unaligned(b.as_ptr() as *const u64) };
40//             let cmp = !(a_word ^ b_word); // equal bytes will be 0xFF
41//             let cnt = cmp.trailing_ones();
42//             let result = cnt as usize / 8;
43//             result.min(len)
44//         },
45//         _ => count_shared_neon(a, b),
46//     }
47// }
48
// GOAT!  AGH!! Even this first-byte early exit is much slower than always calling the SIMD path,
// even on the Zipfian distribution where 70% of the pairs have 0 overlap!!!
50//
51// /// Returns the number of characters shared between two slices
52// #[inline]
53// pub fn find_prefix_overlap(a: &[u8], b: &[u8]) -> usize {
54//     if a.len() != 0 && b.len() != 0 && unsafe{ a.get_unchecked(0) == b.get_unchecked(0) } {
55//         count_shared_neon(a, b)
56//     } else {
57//         0
58//     }
59// }
60
/// Page size, in bytes, assumed by [`same_page`].
///
/// Pages that are a multiple of 4 KiB only make the check more conservative, because every
/// boundary of such a page is also a 4 KiB boundary.
#[allow(unused)]
const PAGE_SIZE: usize = 4096;

/// Returns `true` when reading `VECTOR_SIZE` bytes starting at `slice.as_ptr()` stays inside a
/// single `PAGE_SIZE`-aligned page, i.e. the read cannot touch the following page.
///
/// Only the pointer value is inspected; the slice length is ignored. A `true` result therefore
/// does not mean the `VECTOR_SIZE` bytes belong to `slice`.
///
/// # Safety
///
/// The function itself performs no unsafe operation. Callers use a `true` result to justify
/// vector loads that may extend past the end of `slice`.
#[allow(unused)]
#[inline(always)]
unsafe fn same_page<const VECTOR_SIZE: usize>(slice: &[u8]) -> bool {
    let address = slice.as_ptr() as usize;
    // Offset of `address` within its page (the low 12 bits for 4 KiB pages).
    let offset_within_page = address & (PAGE_SIZE - 1);
    // The last byte read sits at `offset_within_page + VECTOR_SIZE - 1`, which must stay below
    // `PAGE_SIZE`. A read that ends exactly on the page's last byte is still within the page,
    // hence `<=` rather than `<`.
    offset_within_page <= PAGE_SIZE - VECTOR_SIZE
}
73
/// A simple reference implementation of `find_prefix_overlap` with no fanciness
///
/// Walks both slices in lockstep. Returns the index of the first position where they differ,
/// or the length of the shorter slice if no difference appears before it ends.
fn count_shared_reference(p: &[u8], q: &[u8]) -> usize {
    let shorter = p.len().min(q.len());
    p.iter()
        .zip(q)
        .position(|(x, y)| x != y)
        .unwrap_or(shorter)
}
79
/// Scalar fallback used by the SIMD implementations when they decline to issue a vector load.
///
/// `#[cold]` hints that this path is rarely taken, keeping it out of the hot SIMD code.
/// The body is the same lockstep scan as `count_shared_reference`, written inline: it counts
/// leading byte pairs that are equal.
#[allow(unused)]
#[cold]
fn count_shared_cold(a: &[u8], b: &[u8]) -> usize {
    a.iter()
        .zip(b)
        .take_while(|(lhs, rhs)| lhs == rhs)
        .count()
}
85
/// AVX-512 (F + BW) implementation of `find_prefix_overlap`.
///
/// Compares the common prefix of `p` and `q` 64 bytes at a time. Every load is masked so that
/// only lanes inside `min(p.len(), q.len())` are read. Masked-off lanes are zeroed in both
/// vectors and so compare equal, which is why each chunk's count is clamped to the number of
/// bytes actually remaining.
///
/// The chunks are processed in a loop rather than by recursing once per chunk. Stack usage
/// therefore stays constant for any input length. The recursive form relied on the optimizer
/// turning `64 + f(..)` into a loop, which is not guaranteed, notably in unoptimized builds.
#[cfg(all(target_feature = "avx512f", target_feature = "avx512bw"))]
#[inline(always)]
fn count_shared_avx512(p: &[u8], q: &[u8]) -> usize {
    use core::arch::x86_64::*;
    let max_shared = p.len().min(q.len());
    if unlikely(max_shared == 0) { return 0 }

    let mut offset = 0;
    loop {
        // Loop invariant: `offset < max_shared`, so `remaining >= 1`.
        let remaining = max_shared - offset;
        // Enable the low `min(remaining, 64)` lanes; the shift amount stays within `0..=63`.
        let m = (!(0u64 as __mmask64)) >> (64 - remaining.min(64));
        // SAFETY: `offset < max_shared` keeps both pointers inside their slices. The masked
        // loads only access the lanes enabled in `m`, all of which lie within
        // `p[offset..max_shared]` and `q[offset..max_shared]`. The required target features are
        // guaranteed by the `cfg` on this function.
        let ne = unsafe {
            let pv = _mm512_mask_loadu_epi8(_mm512_setzero_si512(), m, p.as_ptr().add(offset) as _);
            let qv = _mm512_mask_loadu_epi8(_mm512_setzero_si512(), m, q.as_ptr().add(offset) as _);
            !_mm512_cmpeq_epi8_mask(pv, qv)
        };
        // Index of the first differing lane, or 64 if the whole chunk matched.
        let count = ne.trailing_zeros() as usize;
        if count != 64 || remaining <= 64 {
            return offset + count.min(remaining);
        }
        offset += 64;
    }
}
108
/// AVX2 implementation of `find_prefix_overlap`.
///
/// Compares the common prefix of `p` and `q` 32 bytes at a time using unaligned 256-bit loads.
/// A load may extend past the end of a slice, so each chunk is gated first:
/// - With the `miri_safe` feature, a chunk is loaded only when at least 32 bytes remain.
/// - Otherwise, it is loaded whenever the 32-byte read stays inside one page (see `same_page`).
///
/// Lanes past the end are discarded by clamping the count. When a chunk cannot be loaded, the
/// rest is handed to the scalar `count_shared_cold`.
///
/// The chunks are processed in a loop rather than by recursing once per chunk, so stack usage
/// does not grow with input length.
///
/// Presumably `#[allow(unused)]` is here because, when AVX-512 is also enabled,
/// `find_prefix_overlap` dispatches to `count_shared_avx512` and never calls this function.
#[allow(unused)]
#[cfg(target_feature = "avx2")]
#[inline(always)]
fn count_shared_avx2(p: &[u8], q: &[u8]) -> usize {
    use core::arch::x86_64::*;
    let max_shared = p.len().min(q.len());
    if unlikely(max_shared == 0) { return 0 }

    let mut offset = 0;
    loop {
        // Loop invariant: `offset < max_shared`, so `remaining >= 1`.
        let remaining = max_shared - offset;
        // SAFETY: `offset < max_shared <= len` for both slices.
        let (p_rest, q_rest) = unsafe { (p.get_unchecked(offset..), q.get_unchecked(offset..)) };

        let use_simd = if cfg!(feature = "miri_safe") {
            remaining >= 32
        } else {
            // SAFETY: `same_page` only inspects the pointer values.
            unsafe { same_page::<32>(p_rest) && same_page::<32>(q_rest) }
        };

        if likely(use_simd) {
            // SAFETY: AVX2 is enabled (see the `cfg` above). Each 32-byte read either lies
            // inside both slices (`miri_safe`) or does not cross a page boundary, so it cannot
            // fault. In the latter case it still reads past the slice, which `miri_safe`
            // avoids. Lanes past `remaining` are ignored via the `min` below.
            let count = unsafe {
                let pv = _mm256_loadu_si256(p_rest.as_ptr() as _);
                let qv = _mm256_loadu_si256(q_rest.as_ptr() as _);
                let eq = _mm256_cmpeq_epi8(pv, qv);
                // Set bits in `ne` mark differing bytes; 32 trailing zeros means a full match.
                let ne = !(_mm256_movemask_epi8(eq) as u32);
                ne.trailing_zeros() as usize
            };
            if count != 32 || remaining <= 32 {
                return offset + count.min(remaining);
            }
            offset += 32;
        } else {
            return offset + count_shared_cold(p_rest, q_rest);
        }
    }
}
143
/// NEON implementation of `find_prefix_overlap` (aarch64, without the `nightly` feature).
///
/// Compares the common prefix of `p` and `q` 16 bytes at a time. A load may extend past the
/// end of a slice, so each chunk is gated first:
/// - With the `miri_safe` feature, a chunk is loaded only when at least 16 bytes remain.
/// - Otherwise, it is loaded whenever the 16-byte read stays inside one page (see `same_page`).
///
/// Lanes past the end are discarded by clamping the count. When a chunk cannot be loaded, the
/// rest is handed to the scalar `count_shared_cold`.
///
/// The chunks are processed in a loop rather than by recursing once per chunk, so stack usage
/// does not grow with input length.
#[cfg(all(not(feature = "nightly"), target_arch = "aarch64", target_feature = "neon"))]
#[inline(always)]
fn count_shared_neon(p: &[u8], q: &[u8]) -> usize {
    use core::arch::aarch64::*;
    let max_shared = p.len().min(q.len());
    if unlikely(max_shared == 0) { return 0 }

    let mut offset = 0;
    loop {
        // Loop invariant: `offset < max_shared`, so `remaining >= 1`.
        let remaining = max_shared - offset;
        // SAFETY: `offset < max_shared <= len` for both slices.
        let (p_rest, q_rest) = unsafe { (p.get_unchecked(offset..), q.get_unchecked(offset..)) };

        let use_simd = if cfg!(feature = "miri_safe") {
            remaining >= 16
        } else {
            // SAFETY: `same_page` only inspects the pointer values.
            unsafe { same_page::<16>(p_rest) && same_page::<16>(q_rest) }
        };
        if !use_simd {
            return offset + count_shared_cold(p_rest, q_rest);
        }

        // SAFETY: NEON is enabled (see the `cfg` above). Each 16-byte read either lies inside
        // both slices (`miri_safe`) or does not cross a page boundary, so it cannot fault.
        // Lanes past `remaining` are ignored via the `min` below.
        let count = unsafe {
            let pv = vld1q_u8(p_rest.as_ptr());
            let qv = vld1q_u8(q_rest.as_ptr());
            // 0xFF in every lane whose bytes are equal, 0x00 otherwise.
            let eq = vceqq_u8(pv, qv);
            // Spill the comparison and reread it as one little-endian u128. This measured a bit
            // faster than extracting the two 64-bit halves with `vgetq_lane_u64`.
            // NOTE(review): a `vshrn_n_u16`-based 64-bit nibble mask is a common alternative;
            // worth benchmarking.
            let mut bytes = [core::mem::MaybeUninit::<u8>::uninit(); 16];
            vst1q_u8(bytes.as_mut_ptr().cast(), eq);
            let lanes = u128::from_le_bytes(core::mem::transmute(bytes));
            // Each matching leading byte contributes 8 trailing one bits.
            (lanes.trailing_ones() / 8) as usize
        };
        if count != 16 || remaining <= 16 {
            return offset + count.min(remaining);
        }
        offset += 16;
    }
}
193
/// Portable-SIMD implementation of `find_prefix_overlap` (requires the `nightly` feature).
///
/// Compares the common prefix of `p` and `q` 32 bytes at a time. A load may extend past the
/// end of a slice, so each chunk is gated first:
/// - With the `miri_safe` feature, a chunk is loaded only when at least 32 bytes remain.
/// - Otherwise, it is loaded whenever the 32-byte read stays inside one page (see `same_page`).
///
/// Lanes past the end are discarded by clamping the count. When a chunk cannot be loaded, the
/// rest is handed to the scalar `count_shared_cold`.
///
/// The chunks are processed in a loop rather than by recursing once per chunk, so stack usage
/// does not grow with input length.
///
/// `#[allow(unused)]`: this function is compiled whenever `nightly` is enabled, but
/// `find_prefix_overlap` only calls it on aarch64 with NEON.
#[allow(unused)]
#[cfg(feature = "nightly")]
#[inline(always)]
fn count_shared_simd(p: &[u8], q: &[u8]) -> usize {
    use core::simd::{u8x32, cmp::SimdPartialEq};
    let max_shared = p.len().min(q.len());
    if unlikely(max_shared == 0) { return 0 }

    let mut offset = 0;
    loop {
        // Loop invariant: `offset < max_shared`, so `remaining >= 1`.
        let remaining = max_shared - offset;
        // SAFETY: `offset < max_shared <= len` for both slices.
        let (p_rest, q_rest) = unsafe { (p.get_unchecked(offset..), q.get_unchecked(offset..)) };

        let use_simd = if cfg!(feature = "miri_safe") {
            remaining >= 32
        } else {
            // SAFETY: `same_page` only inspects the pointer values.
            unsafe { same_page::<32>(p_rest) && same_page::<32>(q_rest) }
        };
        if !use_simd {
            return offset + count_shared_cold(p_rest, q_rest);
        }

        // SAFETY: each 32-byte read either lies inside both slices (`miri_safe`) or does not
        // cross a page boundary, so it cannot fault. `[u8; 32]` has alignment 1, and
        // `read_unaligned` makes no alignment assumption anyway. Lanes past `remaining` are
        // ignored via the `min` below.
        let (pv, qv) = unsafe {
            (
                u8x32::from_array(core::ptr::read_unaligned(p_rest.as_ptr().cast::<[u8; 32]>())),
                u8x32::from_array(core::ptr::read_unaligned(q_rest.as_ptr().cast::<[u8; 32]>())),
            )
        };
        // Bit `i` of the bitmask is set when byte `i` matches. The trailing ones are therefore
        // this chunk's shared prefix (32 when the whole chunk matched).
        let count = pv.simd_eq(qv).to_bitmask().trailing_ones() as usize;
        if count != 32 || remaining <= 32 {
            return offset + count.min(remaining);
        }
        offset += 32;
    }
}
231
232/// Returns the number of initial characters shared between two slices
233///
234/// The fastest (as measured by us) implementation is exported based on the platform and features.
235///
236/// - **AVX-512**: AVX-512 (F + BW) intrinsics (x86_64, requires nightly)
237/// - **AVX2**: AVX2 intrinsics (x86_64)
238/// - **NEON**: NEON intrinsics (aarch64)
239/// - **Portable SIMD**: Portable SIMD (requires nightly)
240/// - **Reference**: Reference scalar implementation
241///
242/// | AVX-512 | AVX2 | NEON | nightly | miri_safe | Implementation    |
243/// |---------|------|------|---------|-----------|-------------------|
244/// | ✓       | -    | ✗    | -       | ✓         | **AVX-512**       |
245/// | ✗       | ✓    | ✗    | -       | ✓         | **AVX2**          |
246/// | ✗       | ✗    | ✓    | ✗       | ✓         | **NEON**          |
247/// | ✗       | ✗    | ✓    | ✓       | ✓         | **Portable SIMD** |
248/// | -       | -    | -    | -       | ✓         | **Reference**     |
249///
250#[inline]
251pub fn find_prefix_overlap(a: &[u8], b: &[u8]) -> usize {
252    #[cfg(all(target_feature="avx512f", target_feature="avx512bw"))]
253    {
254        count_shared_avx512(a, b)
255    }
256    #[cfg(all(target_feature="avx2", not(all(target_feature="avx512f", target_feature="avx512bw"))))]
257    {
258        count_shared_avx2(a, b)
259    }
260    #[cfg(all(not(feature = "nightly"), target_arch = "aarch64", target_feature = "neon"))]
261    {
262        count_shared_neon(a, b)
263    }
264    #[cfg(all(feature = "nightly", target_arch = "aarch64", target_feature = "neon"))]
265    {
266        count_shared_simd(a, b)
267    }
268    #[cfg(all(not(target_feature="avx2"), not(target_feature="neon")))]
269    {
270        count_shared_reference(a, b)
271    }
272}
273
#[test]
fn find_prefix_overlap_test() {
    // (a, b, expected overlap). The lengths straddle the 16-, 32- and 64-byte chunk widths
    // used by the NEON, AVX2/portable-SIMD and AVX-512 implementations.
    let tests = [
        ("12345", "67890", 0),
        ("", "12300", 0),
        ("12345", "", 0),
        ("12345", "12300", 3),
        ("123", "123000000", 3),
        ("123456789012345678901234567890xxxx", "123456789012345678901234567890yy", 30),
        ("123456789012345678901234567890123456789012345678901234567890xxxx", "123456789012345678901234567890123456789012345678901234567890yy", 60),
        ("1234567890123456xxxx", "1234567890123456yyyyyyy", 16),
        ("123456789012345xxxx", "123456789012345yyyyyyy", 15),
        ("12345678901234567xxxx", "12345678901234567yyyyyyy", 17),
        ("1234567890123456789012345678901xxxx", "1234567890123456789012345678901yy", 31),
        ("12345678901234567890123456789012xxxx", "12345678901234567890123456789012yy", 32),
        ("123456789012345678901234567890123xxxx", "123456789012345678901234567890123yy", 33),
        ("123456789012345678901234567890123456789012345678901234567890123xxxx", "123456789012345678901234567890123456789012345678901234567890123yy", 63),
        ("1234567890123456789012345678901234567890123456789012345678901234xxxx", "1234567890123456789012345678901234567890123456789012345678901234yy", 64),
        ("12345678901234567890123456789012345678901234567890123456789012345xxxx", "12345678901234567890123456789012345678901234567890123456789012345yy", 65),
        ("same", "same", 4),
        ("x", "x", 1),
        ("x", "y", 0),
    ];

    for (a, b, expected) in tests {
        // The overlap does not depend on argument order, so check both orders. The messages
        // identify the failing case.
        assert_eq!(find_prefix_overlap(a.as_bytes(), b.as_bytes()), expected, "a = {:?}, b = {:?}", a, b);
        assert_eq!(find_prefix_overlap(b.as_bytes(), a.as_bytes()), expected, "swapped: a = {:?}, b = {:?}", b, a);
    }
}
300
#[test]
fn find_prefix_overlap_long_test() {
    let a = [b'A'; 70];
    let mut b = [b'A'; 70];
    assert_eq!(find_prefix_overlap(&a, &b), 70);
    b[69] = b'B';
    assert_eq!(find_prefix_overlap(&a, &b), 69);
    b[69] = b'A';
    b[64] = b'B';
    assert_eq!(find_prefix_overlap(&a, &b), 64);
    b[64] = b'A';
    b[0] = b'B';
    assert_eq!(find_prefix_overlap(&a, &b), 0);

    // Place a single mismatch at every position of a buffer long enough to need several SIMD
    // iterations, so every 16/32/64-byte chunk boundary is crossed.
    const N: usize = 300;
    let x = [b'A'; N];
    let mut y = [b'A'; N];
    for pos in 0..N {
        y[pos] = b'B';
        assert_eq!(find_prefix_overlap(&x, &y), pos, "mismatch at {}", pos);
        y[pos] = b'A';
    }

    // Identical contents, with every possible length for the shorter side.
    for len in 0..=N {
        assert_eq!(find_prefix_overlap(&x[..len], &y), len, "shorter length {}", len);
        assert_eq!(find_prefix_overlap(&y, &x[..len]), len, "shorter length {} (swapped)", len);
    }
}
315
316/// A faster replacement for the stdlib version of [`starts_with`](slice::starts_with)
317#[inline(always)]
318pub fn starts_with(x: &[u8], y: &[u8]) -> bool {
319    if y.len() == 0 { return true }
320    if x.len() == 0 { return false }
321    if y.len() > x.len() { return false }
322    find_prefix_overlap(x, y) == y.len()
323}