1#![doc = include_str!("../README.md")]
2#![no_std]
3#![cfg_attr(feature = "nightly", allow(internal_features), feature(core_intrinsics))]
4#![cfg_attr(feature = "nightly", feature(portable_simd))]
5
6#[cfg(not(feature = "nightly"))]
11#[allow(unused)]
12pub(crate) use core::convert::{identity as likely, identity as unlikely};
13#[cfg(feature = "nightly")]
14#[allow(unused)]
15pub(crate) use core::intrinsics::{likely, unlikely};
16
/// Granularity (in bytes) of the page-boundary check used by the SIMD paths that
/// load a full vector even past the end of a slice.
///
/// 4 KiB is used as the block size: any platform whose real page size is a
/// multiple of it is covered, because a load that stays inside one 4 KiB-aligned
/// block also stays inside the larger page that contains that block.
const PAGE_SIZE: usize = 1 << 12;
62
63#[inline(always)]
64unsafe fn same_page<const VECTOR_SIZE: usize>(slice: &[u8]) -> bool {
65 let address = slice.as_ptr() as usize;
66 let offset_within_page = address & (PAGE_SIZE - 1);
68 offset_within_page < PAGE_SIZE - VECTOR_SIZE
70}
71
/// Portable scalar implementation: returns how many leading bytes `p` and `q`
/// have in common, which is the index of their first differing byte, or the
/// shorter length if one slice is a prefix of the other.
fn count_shared_reference(p: &[u8], q: &[u8]) -> usize {
    let limit = p.len().min(q.len());
    p[..limit]
        .iter()
        .zip(&q[..limit])
        .position(|(lhs, rhs)| lhs != rhs)
        .unwrap_or(limit)
}
77
78#[cold]
79fn count_shared_cold(a: &[u8], b: &[u8]) -> usize {
80 count_shared_reference(a, b)
81}
82
/// AVX-512 implementation of the common-prefix count: returns the number of
/// leading bytes shared by `p` and `q`, comparing 64 bytes per iteration.
///
/// Every load is masked to the bytes that actually remain in both slices, so
/// unlike the AVX2/NEON paths this never reads past the slices and needs no
/// page-boundary check.
///
/// NOTE: `_mm512_mask_loadu_epi8` and `_mm512_cmpeq_epi8_mask` are AVX-512BW
/// instructions. This path must only be selected when `avx512bw` is enabled in
/// addition to `avx512f`.
///
/// Implemented as a loop. A recursive `64 + recurse(..)` formulation is not a
/// tail call, and its stack depth would grow with the length of the common
/// prefix.
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f", not(miri)))]
#[inline(always)]
fn count_shared_avx512(p: &[u8], q: &[u8]) -> usize {
    use core::arch::x86_64::*;
    const LANES: usize = 64;

    let max_shared = p.len().min(q.len());
    let mut offset = 0;
    while offset < max_shared {
        let remaining = max_shared - offset;
        // Enable exactly the lanes that hold real bytes (between 1 and 64).
        let mask: __mmask64 = u64::MAX >> (LANES - remaining.min(LANES));
        // SAFETY: `offset < max_shared <= min(p.len(), q.len())`, so both pointers
        // are in bounds. The masked loads only access lanes whose mask bit is
        // set, i.e. bytes `offset .. offset + remaining.min(64)` of each slice;
        // masked-off lanes are neither read nor faulted on. The byte-granular
        // intrinsics additionally require AVX-512BW (see NOTE above).
        let differs = unsafe {
            let pv = _mm512_mask_loadu_epi8(_mm512_setzero_si512(), mask, p.as_ptr().add(offset).cast());
            let qv = _mm512_mask_loadu_epi8(_mm512_setzero_si512(), mask, q.as_ptr().add(offset).cast());
            !_mm512_cmpeq_epi8_mask(pv, qv)
        };
        // Masked-off lanes are zero in both vectors and therefore compare equal,
        // so any set bit in `differs` is a genuine mismatch inside `remaining`.
        // `trailing_zeros` (64 for an all-zero mask) replaces `_tzcnt_u64`, which
        // would additionally require BMI1.
        let matched = differs.trailing_zeros() as usize;
        if matched < LANES {
            return offset + matched;
        }
        offset += LANES;
    }
    max_shared
}
105
/// AVX2 implementation of the common-prefix count: returns the number of leading
/// bytes shared by `p` and `q`, comparing 32 bytes per iteration.
///
/// Each iteration loads a full 32-byte vector even when fewer bytes remain, so
/// it first checks with `same_page` that the loads cannot cross into another
/// page. If they could, it hands the remainder to the scalar cold path.
///
/// Implemented as a loop. A recursive `32 + recurse(..)` formulation is not a
/// tail call, and its stack depth would grow with the length of the common
/// prefix.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2", not(miri)))]
#[inline(always)]
fn count_shared_avx2(p: &[u8], q: &[u8]) -> usize {
    use core::arch::x86_64::*;
    const LANES: usize = 32;

    let max_shared = p.len().min(q.len());
    let mut offset = 0;
    while offset < max_shared {
        // Non-empty, equal-length tails still to be compared.
        let rest_p = &p[offset..max_shared];
        let rest_q = &q[offset..max_shared];
        // SAFETY: the 32-byte loads start at valid bytes of `rest_p`/`rest_q`
        // but may extend past their ends. `same_page` has confirmed each load
        // stays within the page of its first byte, so no unmapped memory is
        // touched. Reading beyond a slice is still outside Rust's memory model;
        // this path is compiled out under Miri (`not(miri)`).
        let equal_bits = unsafe {
            if !likely(same_page::<LANES>(rest_p) && same_page::<LANES>(rest_q)) {
                return offset + count_shared_cold(rest_p, rest_q);
            }
            let pv = _mm256_loadu_si256(rest_p.as_ptr().cast());
            let qv = _mm256_loadu_si256(rest_q.as_ptr().cast());
            _mm256_movemask_epi8(_mm256_cmpeq_epi8(pv, qv)) as u32
        };
        // Index of the first differing lane (32 if all lanes match).
        // `trailing_zeros` replaces `_tzcnt_u32`, which would additionally
        // require BMI1.
        let matched = (!equal_bits).trailing_zeros() as usize;
        if matched < LANES {
            // Lanes past the slice end hold unrelated bytes; cap at the real length.
            return offset + matched.min(max_shared - offset);
        }
        offset += LANES;
    }
    max_shared
}
132
/// NEON implementation of the common-prefix count (stable builds): returns the
/// number of leading bytes shared by `p` and `q`, comparing 16 bytes per
/// iteration.
///
/// Each iteration loads a full 16-byte vector even when fewer bytes remain, so
/// it first checks with `same_page` that the loads cannot cross into another
/// page. If they could, it hands the remainder to the scalar cold path.
///
/// Implemented as a loop. A recursive `16 + recurse(..)` formulation is not a
/// tail call, and its stack depth would grow with the length of the common
/// prefix.
#[cfg(all(not(feature = "nightly"), target_arch = "aarch64", target_feature = "neon", not(miri)))]
#[inline(always)]
fn count_shared_neon(p: &[u8], q: &[u8]) -> usize {
    use core::arch::aarch64::*;
    const LANES: usize = 16;

    let max_shared = p.len().min(q.len());
    let mut offset = 0;
    while offset < max_shared {
        // Non-empty, equal-length tails still to be compared.
        let rest_p = &p[offset..max_shared];
        let rest_q = &q[offset..max_shared];
        // Per-lane comparison result: 0xFF where the bytes match, 0x00 otherwise.
        let mut lanes = [0u8; LANES];
        // SAFETY: the 16-byte loads start at valid bytes of `rest_p`/`rest_q`
        // but may extend past their ends. `same_page` has confirmed each load
        // stays within the page of its first byte, so no unmapped memory is
        // touched. This path is compiled out under Miri (`not(miri)`). The
        // store writes exactly 16 bytes into `lanes`.
        unsafe {
            if !likely(same_page::<LANES>(rest_p) && same_page::<LANES>(rest_q)) {
                return offset + count_shared_cold(rest_p, rest_q);
            }
            let pv = vld1q_u8(rest_p.as_ptr());
            let qv = vld1q_u8(rest_q.as_ptr());
            vst1q_u8(lanes.as_mut_ptr(), vceqq_u8(pv, qv));
        }
        // Lane 0 is the least-significant byte, so the run of trailing 0xFF bytes
        // (8 one-bits each) is the number of leading matching lanes (16 if all match).
        let matched = (u128::from_le_bytes(lanes).trailing_ones() / 8) as usize;
        if matched < LANES {
            // Lanes past the slice end hold unrelated bytes; cap at the real length.
            return offset + matched.min(max_shared - offset);
        }
        offset += LANES;
    }
    max_shared
}
176
/// Portable `core::simd` implementation of the common-prefix count (requires the
/// `nightly` feature): returns the number of leading bytes shared by `p` and `q`,
/// comparing 32 bytes per iteration.
///
/// Each iteration loads a full 32-byte vector even when fewer bytes remain, so
/// it first checks with `same_page` that the loads cannot cross into another
/// page. If they could, it hands the remainder to the scalar cold path.
///
/// Implemented as a loop. A recursive `32 + recurse(..)` formulation is not a
/// tail call, and its stack depth would grow with the length of the common
/// prefix.
#[cfg(all(feature = "nightly", not(miri)))]
#[inline(always)]
fn count_shared_simd(p: &[u8], q: &[u8]) -> usize {
    use core::simd::{u8x32, cmp::SimdPartialEq};
    const LANES: usize = 32;

    let max_shared = p.len().min(q.len());
    let mut offset = 0;
    while offset < max_shared {
        // Non-empty, equal-length tails still to be compared.
        let rest_p = &p[offset..max_shared];
        let rest_q = &q[offset..max_shared];
        // SAFETY: the 32-byte reads start at valid bytes of `rest_p`/`rest_q`
        // but may extend past their ends. `same_page` has confirmed each read
        // stays within the page of its first byte, so no unmapped memory is
        // touched. This path is compiled out under Miri (`not(miri)`).
        // `[u8; 32]` has alignment 1, and `read_unaligned` makes that explicit.
        let (pv, qv) = unsafe {
            if !likely(same_page::<LANES>(rest_p) && same_page::<LANES>(rest_q)) {
                return offset + count_shared_cold(rest_p, rest_q);
            }
            (
                u8x32::from_array(core::ptr::read_unaligned(rest_p.as_ptr().cast::<[u8; LANES]>())),
                u8x32::from_array(core::ptr::read_unaligned(rest_q.as_ptr().cast::<[u8; LANES]>())),
            )
        };
        // Bit i of the bitmask is set when lane i matches; the trailing run of
        // ones is the number of leading matching lanes (32 if all match).
        let matched = pv.simd_eq(qv).to_bitmask().trailing_ones() as usize;
        if matched < LANES {
            // Lanes past the slice end hold unrelated bytes; cap at the real length.
            return offset + matched.min(max_shared - offset);
        }
        offset += LANES;
    }
    max_shared
}
207
208#[inline]
227pub fn find_prefix_overlap(a: &[u8], b: &[u8]) -> usize {
228 #[cfg(all(target_feature="avx512f", not(miri)))]
229 {
230 count_shared_avx512(a, b)
231 }
232 #[cfg(all(target_feature="avx2", not(target_feature="avx512f"), not(miri)))]
233 {
234 count_shared_avx2(a, b)
235 }
236 #[cfg(all(not(feature = "nightly"), target_arch = "aarch64", target_feature = "neon", not(miri)))]
237 {
238 count_shared_neon(a, b)
239 }
240 #[cfg(all(feature = "nightly", target_arch = "aarch64", target_feature = "neon", not(miri)))]
241 {
242 count_shared_simd(a, b)
243 }
244 #[cfg(any(all(not(target_feature="avx2"), not(target_feature="neon")), miri))]
245 {
246 count_shared_reference(a, b)
247 }
248}
249
/// Checks `find_prefix_overlap` on short literal cases around the 15/16/17,
/// 31/32/33 and 63/64/65 vector-width boundaries. It also checks long
/// page-aligned buffers, so the vector loop runs for many iterations and the
/// page-boundary fallback is taken.
#[test]
fn find_prefix_overlap_test() {
    let tests = [
        ("12345", "67890", 0),
        ("", "12300", 0),
        ("12345", "", 0),
        ("12345", "12300", 3),
        ("123", "123000000", 3),
        ("123456789012345678901234567890xxxx", "123456789012345678901234567890yy", 30),
        ("123456789012345678901234567890123456789012345678901234567890xxxx", "123456789012345678901234567890123456789012345678901234567890yy", 60),
        ("1234567890123456xxxx", "1234567890123456yyyyyyy", 16),
        ("123456789012345xxxx", "123456789012345yyyyyyy", 15),
        ("12345678901234567xxxx", "12345678901234567yyyyyyy", 17),
        ("1234567890123456789012345678901xxxx", "1234567890123456789012345678901yy", 31),
        ("12345678901234567890123456789012xxxx", "12345678901234567890123456789012yy", 32),
        ("123456789012345678901234567890123xxxx", "123456789012345678901234567890123yy", 33),
        ("123456789012345678901234567890123456789012345678901234567890123xxxx", "123456789012345678901234567890123456789012345678901234567890123yy", 63),
        ("1234567890123456789012345678901234567890123456789012345678901234xxxx", "1234567890123456789012345678901234567890123456789012345678901234yy", 64),
        ("12345678901234567890123456789012345678901234567890123456789012345xxxx", "12345678901234567890123456789012345678901234567890123456789012345yy", 65),
    ];

    for (a, b, expected) in tests {
        assert_eq!(find_prefix_overlap(a.as_bytes(), b.as_bytes()), expected, "{:?} vs {:?}", a, b);
        // The common prefix does not depend on argument order.
        assert_eq!(find_prefix_overlap(b.as_bytes(), a.as_bytes()), expected, "{:?} vs {:?}", b, a);
    }

    // Page-aligned buffers let slices start right before a 4 KiB boundary, and
    // end exactly on one, so full-vector loads would cross pages.
    const LEN: usize = 3 * 4096;
    #[repr(align(4096))]
    struct Pages([u8; LEN]);

    let x = Pages([7; LEN]);
    let mut y = Pages([7; LEN]);

    // Identical contents: the whole (possibly long) tail is shared.
    for start in [0, 1, 31, 4000, 4064, 4065, 4090, 4095, 4096, LEN - 1, LEN] {
        assert_eq!(find_prefix_overlap(&x.0[start..], &y.0[start..]), LEN - start, "start {}", start);
    }

    // Unequal lengths are capped at the shorter input.
    assert_eq!(find_prefix_overlap(&x.0[..100], &y.0), 100);
    assert_eq!(find_prefix_overlap(&x.0, &y.0[4090..4100]), 10);

    // A single mismatch at positions around vector widths and page boundaries.
    for pos in [0, 1, 15, 16, 31, 32, 63, 64, 65, 4095, 4096, 8191, LEN - 1] {
        y.0[pos] = 8;
        for start in [0, 1, 4064, 4090] {
            if start <= pos {
                assert_eq!(
                    find_prefix_overlap(&x.0[start..], &y.0[start..]),
                    pos - start,
                    "pos {} start {}",
                    pos,
                    start
                );
            }
        }
        y.0[pos] = 7;
    }
}
276
277#[inline(always)]
279pub fn starts_with(x: &[u8], y: &[u8]) -> bool {
280 if y.len() == 0 { return true }
281 if x.len() == 0 { return false }
282 if y.len() > x.len() { return false }
283 find_prefix_overlap(x, y) == y.len()
284}