fast_slice_utils/
lib.rs

1#![doc = include_str!("../README.md")]
2#![no_std]
3#![cfg_attr(feature = "nightly", allow(internal_features), feature(core_intrinsics))]
4#![cfg_attr(feature = "nightly", feature(portable_simd))]
5
6// =======================================================================================
7// Various implementations of `find_prefix_overlap`
8// =======================================================================================
9
10#[cfg(not(feature = "nightly"))]
11#[allow(unused)]
12pub(crate) use core::convert::{identity as likely, identity as unlikely};
13#[cfg(feature = "nightly")]
14#[allow(unused)]
15pub(crate) use core::intrinsics::{likely, unlikely};
16
// GOAT: It turned out the scalar fast paths below aren't fast enough to justify keeping them,
// probably because the extra branching causes branch mispredictions.
// This code should eventually be deleted, but we are keeping it for a while as a reference
// while we discuss it.
20//
21// /// Returns the number of characters shared between two slices
22// #[inline]
23// pub fn find_prefix_overlap(a: &[u8], b: &[u8]) -> usize {
24//     let len = a.len().min(b.len());
25
26//     match len {
27//         0 => 0,
28//         1 => (unsafe{ a.get_unchecked(0) == b.get_unchecked(0) } as usize),
29//         2 => { 
30//             let a_word = unsafe{ core::ptr::read_unaligned(a.as_ptr() as *const u16) };
31//             let b_word = unsafe{ core::ptr::read_unaligned(b.as_ptr() as *const u16) };
32//             let cmp = !(a_word ^ b_word); // equal bytes will be 0xFF
33//             let cnt = cmp.trailing_ones();
34//             cnt as usize / 8
35//         },
36//         3 | 4 | 5 | 6 | 7 | 8 => {
37//             //GOAT, we need to do a check to make sure we don't over-read a page
38//             let a_word = unsafe{ core::ptr::read_unaligned(a.as_ptr() as *const u64) };
39//             let b_word = unsafe{ core::ptr::read_unaligned(b.as_ptr() as *const u64) };
40//             let cmp = !(a_word ^ b_word); // equal bytes will be 0xFF
41//             let cnt = cmp.trailing_ones();
42//             let result = cnt as usize / 8;
43//             result.min(len)
44//         },
45//         _ => count_shared_neon(a, b),
46//     }
47// }
48
// GOAT: Even this first-byte early-out is much slower, even on the Zipfian distribution
// where 70% of the pairs have zero overlap!
50//
51// /// Returns the number of characters shared between two slices
52// #[inline]
53// pub fn find_prefix_overlap(a: &[u8], b: &[u8]) -> usize {
54//     if a.len() != 0 && b.len() != 0 && unsafe{ a.get_unchecked(0) == b.get_unchecked(0) } {
55//         count_shared_neon(a, b)
56//     } else {
57//         0
58//     }
59// }
60
/// Page granularity used to decide whether a vector over-read could cross into another page.
///
/// 4 KiB is conservative on systems with larger power-of-two pages, because every boundary
/// of a larger page is also a 4 KiB boundary.
const PAGE_SIZE: usize = 4096;

/// Returns `true` if a `VECTOR_SIZE`-byte load starting at `slice.as_ptr()` stays entirely
/// within one `PAGE_SIZE`-aligned page.
///
/// Only the start address of `slice` is inspected; its length is ignored. The SIMD kernels
/// use this to decide whether a full-width load, which may extend past the end of the slice,
/// is allowed. This is a pure address computation with no safety preconditions, so the
/// function is safe to call.
#[inline(always)]
fn same_page<const VECTOR_SIZE: usize>(slice: &[u8]) -> bool {
    // Keep only the low bits: the offset of the first byte inside its page.
    let offset_within_page = (slice.as_ptr() as usize) & (PAGE_SIZE - 1);
    // The load touches bytes `offset ..= offset + VECTOR_SIZE - 1`, so it stays in this page
    // exactly when `offset + VECTOR_SIZE <= PAGE_SIZE`. (`<=`, not `<`: the window ending on
    // the last byte of the page is still safe.)
    offset_within_page <= PAGE_SIZE - VECTOR_SIZE
}
71
/// Plain scalar implementation of `find_prefix_overlap`, without any SIMD tricks.
///
/// Returns the index of the first position where `p` and `q` differ. If no such position
/// exists within the shorter slice, it returns the shorter slice's length.
fn count_shared_reference(p: &[u8], q: &[u8]) -> usize {
    let common = p.len().min(q.len());
    p[..common]
        .iter()
        .zip(&q[..common])
        .position(|(lhs, rhs)| lhs != rhs)
        .unwrap_or(common)
}
77
78#[cold]
79fn count_shared_cold(a: &[u8], b: &[u8]) -> usize {
80    count_shared_reference(a, b)
81}
82
/// AVX2 kernel for `find_prefix_overlap`, comparing 32 bytes per step.
///
/// Each step loads a full 32-byte vector from both inputs, even when fewer bytes remain, and
/// clamps the result to the shorter slice's length.
///
/// To keep that over-read from touching another page, a step runs only when `same_page`
/// says both loads stay inside the current page. Otherwise the remainder is handed to the
/// scalar `count_shared_cold` path.
///
/// NOTE(review): reading past the end of a slice is outside Rust's memory model even when it
/// stays on a mapped page; miri builds use the reference implementation instead.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2", not(miri)))]
#[inline(always)]
fn count_shared_avx2(p: &[u8], q: &[u8]) -> usize {
    use core::arch::x86_64::*;
    const LANES: usize = 32;

    let max_shared = p.len().min(q.len());
    // SAFETY:
    // - `shared < max_shared` holds on every loop iteration (it starts at 0 after the empty
    //   check and only advances when more than `LANES` bytes remain). So `shared..` is in
    //   bounds for both slices.
    // - The 32-byte loads may extend past the ends of `ps`/`qs`, but only after `same_page`
    //   has confirmed that they stay within the current page.
    // - `avx2` is guaranteed by the `cfg` on this function.
    unsafe {
        if unlikely(max_shared == 0) { return 0 }

        // Iterate instead of recursing, so that very long equal prefixes cannot exhaust
        // the stack.
        let mut shared = 0;
        loop {
            let ps = p.get_unchecked(shared..);
            let qs = q.get_unchecked(shared..);
            if !likely(same_page::<LANES>(ps) && same_page::<LANES>(qs)) {
                return shared + count_shared_cold(ps, qs);
            }

            let pv = _mm256_loadu_si256(ps.as_ptr().cast());
            let qv = _mm256_loadu_si256(qs.as_ptr().cast());
            // One bit per byte lane, set where the bytes differ.
            let ne = !(_mm256_movemask_epi8(_mm256_cmpeq_epi8(pv, qv)) as u32);
            // `trailing_zeros` yields 32 for an all-equal chunk, like `tzcnt`, but does not
            // rely on the separate `bmi1` target feature.
            let count = ne.trailing_zeros() as usize;

            let remaining = max_shared - shared;
            if count != LANES || remaining <= LANES {
                return shared + count.min(remaining);
            }
            shared += LANES;
        }
    }
}
109
/// NEON kernel for `find_prefix_overlap`, comparing 16 bytes per step.
///
/// Each step loads a full 16-byte vector from both inputs, even when fewer bytes remain, and
/// clamps the result to the shorter slice's length.
///
/// A step runs only when `same_page` says both loads stay inside the current page.
/// Otherwise the remainder is handed to the scalar `count_shared_cold` path.
///
/// NOTE(review): reading past the end of a slice is outside Rust's memory model even when it
/// stays on a mapped page; miri builds use the reference implementation instead.
#[cfg(all(not(feature = "nightly"), target_arch = "aarch64", target_feature = "neon", not(miri)))]
#[inline(always)]
fn count_shared_neon(p: &[u8], q: &[u8]) -> usize {
    use core::arch::aarch64::*;
    const LANES: usize = 16;

    let max_shared = p.len().min(q.len());
    // SAFETY:
    // - `shared < max_shared` holds on every loop iteration, so `shared..` is in bounds for
    //   both slices.
    // - The 16-byte loads may extend past the ends of `ps`/`qs`, but only after `same_page`
    //   has confirmed that they stay within the current page.
    // - `neon` is guaranteed by the `cfg` on this function.
    unsafe {
        if unlikely(max_shared == 0) { return 0 }

        // Iterate instead of recursing, so that very long equal prefixes cannot exhaust
        // the stack.
        let mut shared = 0;
        loop {
            let ps = p.get_unchecked(shared..);
            let qs = q.get_unchecked(shared..);
            if !(same_page::<LANES>(ps) && same_page::<LANES>(qs)) {
                return shared + count_shared_cold(ps, qs);
            }

            let eq = vceqq_u8(vld1q_u8(ps.as_ptr()), vld1q_u8(qs.as_ptr()));
            // Spill the per-byte equality mask (0xFF = equal) and read it as one little-endian
            // integer, so the number of leading equal bytes is `trailing_ones() / 8`.
            // This store-based form was measured to be slightly faster than extracting the
            // two 64-bit lanes with `vgetq_lane_u64`.
            let mut lanes = [0u8; LANES];
            vst1q_u8(lanes.as_mut_ptr(), eq);
            let count = (u128::from_le_bytes(lanes).trailing_ones() / 8) as usize;

            let remaining = max_shared - shared;
            if count != LANES || remaining <= LANES {
                return shared + count.min(remaining);
            }
            shared += LANES;
        }
    }
}
153
/// Portable-SIMD (nightly) kernel for `find_prefix_overlap`, comparing 32 bytes per step.
///
/// Each step loads a full 32-byte vector from both inputs, even when fewer bytes remain, and
/// clamps the result to the shorter slice's length.
///
/// A step runs only when `same_page` says both loads stay inside the current page.
/// Otherwise the remainder is handed to the scalar `count_shared_cold` path.
///
/// NOTE(review): reading past the end of a slice is outside Rust's memory model even when it
/// stays on a mapped page; miri builds use the reference implementation instead.
#[cfg(all(feature = "nightly", not(miri)))]
#[inline(always)]
fn count_shared_simd(p: &[u8], q: &[u8]) -> usize {
    use core::simd::{u8x32, cmp::SimdPartialEq};
    const LANES: usize = 32;

    let max_shared = p.len().min(q.len());
    // SAFETY:
    // - `shared < max_shared` holds on every loop iteration, so `shared..` is in bounds for
    //   both slices.
    // - The 32-byte reads may extend past the ends of `ps`/`qs`, but only after `same_page`
    //   has confirmed that they stay within the current page.
    // - `read_unaligned` imposes no alignment requirement.
    unsafe {
        if unlikely(max_shared == 0) { return 0 }

        // Iterate instead of recursing, so that very long equal prefixes cannot exhaust
        // the stack.
        let mut shared = 0;
        loop {
            let ps = p.get_unchecked(shared..);
            let qs = q.get_unchecked(shared..);
            if !(same_page::<LANES>(ps) && same_page::<LANES>(qs)) {
                return shared + count_shared_cold(ps, qs);
            }

            let pv = u8x32::from_array(core::ptr::read_unaligned(ps.as_ptr().cast::<[u8; LANES]>()));
            let qv = u8x32::from_array(core::ptr::read_unaligned(qs.as_ptr().cast::<[u8; LANES]>()));
            // Bit i of the mask is set when byte i is equal; the run of set low bits is the
            // shared prefix within this chunk.
            let count = pv.simd_eq(qv).to_bitmask().trailing_ones() as usize;

            let remaining = max_shared - shared;
            if count != LANES || remaining <= LANES {
                return shared + count.min(remaining);
            }
            shared += LANES;
        }
    }
}
184
185/// Returns the number of initial characters shared between two slices
186///
187/// The fastest (as measured by us) implementation is exported based on the platform and features.
188///
189/// - **AVX2**: AVX2 intrinsics (x86_64)
190/// - **NEON**: NEON intrinsics (aarch64)
191/// - **Portable SIMD**: Portable SIMD (requires nightly)
192/// - **Reference**: Reference scalar implementation
193///
194/// | AVX-512 | AVX2 | NEON | nightly | miri | Implementation    |
195/// |---------|------|------|---------|------|-------------------|
196/// | ✓       | -    | ✗    | -       | ✗    | **AVX2**          |
197/// | -       | ✓    | ✗    | -       | ✗    | **AVX2**          |
198/// | ✗       | ✗    | ✓    | ✗       | ✗    | **NEON**          |
199/// | ✗       | ✗    | ✓    | ✓       | ✗    | **Portable SIMD** |
200/// | -       | -    | -    | -       | ✓    | **Reference**     |
201///
202#[inline]
203pub fn find_prefix_overlap(a: &[u8], b: &[u8]) -> usize {
204    #[cfg(all(target_feature="avx2", not(miri)))]
205    {
206        count_shared_avx2(a, b)
207    }
208    #[cfg(all(not(feature = "nightly"), target_arch = "aarch64", target_feature = "neon", not(miri)))]
209    {
210        count_shared_neon(a, b)
211    }
212    #[cfg(all(feature = "nightly", target_arch = "aarch64", target_feature = "neon", not(miri)))]
213    {
214        count_shared_simd(a, b)
215    }
216    #[cfg(any(all(not(target_feature="avx2"), not(target_feature="neon")), miri))]
217    {
218        count_shared_reference(a, b)
219    }
220}
221
#[test]
fn find_prefix_overlap_test() {
    let tests = [
        ("12345", "67890", 0),
        ("", "12300", 0),
        ("12345", "", 0),
        ("12345", "12300", 3),
        ("123", "123000000", 3),
        ("123456789012345678901234567890xxxx", "123456789012345678901234567890yy", 30),
        ("123456789012345678901234567890123456789012345678901234567890xxxx", "123456789012345678901234567890123456789012345678901234567890yy", 60),
        ("1234567890123456xxxx", "1234567890123456yyyyyyy", 16),
        ("123456789012345xxxx", "123456789012345yyyyyyy", 15),
        ("12345678901234567xxxx", "12345678901234567yyyyyyy", 17),
        ("1234567890123456789012345678901xxxx", "1234567890123456789012345678901yy", 31),
        ("12345678901234567890123456789012xxxx", "12345678901234567890123456789012yy", 32),
        ("123456789012345678901234567890123xxxx", "123456789012345678901234567890123yy", 33),
        ("123456789012345678901234567890123456789012345678901234567890123xxxx", "123456789012345678901234567890123456789012345678901234567890123yy", 63),
        ("1234567890123456789012345678901234567890123456789012345678901234xxxx", "1234567890123456789012345678901234567890123456789012345678901234yy", 64),
        ("12345678901234567890123456789012345678901234567890123456789012345xxxx", "12345678901234567890123456789012345678901234567890123456789012345yy", 65),
    ];

    for (a, b, expected) in tests {
        let (a, b) = (a.as_bytes(), b.as_bytes());
        assert_eq!(find_prefix_overlap(a, b), expected, "{:?} vs {:?}", a, b);
        // The shared-prefix length is symmetric in its arguments...
        assert_eq!(find_prefix_overlap(b, a), expected, "{:?} vs {:?}", b, a);
        // ...and the selected kernel must agree with the scalar reference.
        assert_eq!(count_shared_reference(a, b), expected, "{:?} vs {:?}", a, b);
    }
}

/// Exercises slices that start just before a 4 KiB page boundary.
///
/// In this region the SIMD kernels must switch to the scalar fallback partway through,
/// rather than issue a vector load that crosses the boundary. The results must be the same
/// as anywhere else.
#[test]
fn find_prefix_overlap_page_boundary_test() {
    #[repr(align(4096))]
    struct PageAligned([u8; 2 * 4096]);

    const WINDOW: usize = 80;

    let a = PageAligned([b'a'; 2 * 4096]);
    let mut b = PageAligned([b'a'; 2 * 4096]);

    for start in [4096 - 48, 4096 - 33, 4096 - 32, 4096 - 31, 4096 - 17, 4096 - 16, 4096 - 15, 4096 - 1] {
        // `mismatch == WINDOW` means the two windows are identical.
        for mismatch in 0..=WINDOW {
            if mismatch < WINDOW {
                b.0[start + mismatch] = b'b';
            }
            for len in [0, 1, 15, 16, 17, 31, 32, 33, 64, WINDOW] {
                let x = &a.0[start..start + len];
                let y = &b.0[start..start + len];
                assert_eq!(find_prefix_overlap(x, y), mismatch.min(len), "start={} mismatch={} len={}", start, mismatch, len);
            }
            if mismatch < WINDOW {
                b.0[start + mismatch] = b'a';
            }
        }
    }
}
248
249/// A faster replacement for the stdlib version of [`starts_with`](slice::starts_with)
250#[inline(always)]
251pub fn starts_with(x: &[u8], y: &[u8]) -> bool {
252    if y.len() == 0 { return true }
253    if x.len() == 0 { return false }
254    if y.len() > x.len() { return false }
255    find_prefix_overlap(x, y) == y.len()
256}