1#![doc = include_str!("../README.md")]
2#![no_std]
3#![cfg_attr(feature = "nightly", allow(internal_features), feature(core_intrinsics))]
4#![cfg_attr(feature = "nightly", feature(portable_simd))]
5
6#[cfg(not(feature = "nightly"))]
11#[allow(unused)]
12pub(crate) use core::convert::{identity as likely, identity as unlikely};
13#[cfg(feature = "nightly")]
14#[allow(unused)]
15pub(crate) use core::intrinsics::{likely, unlikely};
16
/// Granularity, in bytes, of the aligned regions that [`same_page`] reasons about.
///
/// NOTE(review): hard-coded to 4 KiB; presumably chosen as the smallest page size of
/// the supported targets — confirm before porting to a target with smaller pages.
#[allow(unused)]
const PAGE_SIZE: usize = 4096;

/// Reports whether a `VECTOR_SIZE`-byte read that starts at the first byte of `slice`
/// stays inside the `PAGE_SIZE`-aligned region containing that byte.
///
/// Only the slice's address is inspected; its length and contents are ignored.
/// The comparison is strict, so a read that would end exactly on the last byte of the
/// region is also reported as `false` (the check errs on the conservative side).
///
/// # Safety
///
/// The body performs no unsafe operation and has no preconditions of its own.
/// NOTE(review): the `unsafe` qualifier looks like a marker that callers use the
/// result to justify reading past the end of `slice` — confirm the intent.
#[allow(unused)]
#[inline(always)]
unsafe fn same_page<const VECTOR_SIZE: usize>(slice: &[u8]) -> bool {
    // First in-page offset at which a `VECTOR_SIZE` read is no longer accepted.
    let offset_limit = PAGE_SIZE - VECTOR_SIZE;
    let start = slice.as_ptr() as usize;
    // `PAGE_SIZE` is a power of two, so the remainder equals the low address bits.
    start % PAGE_SIZE < offset_limit
}
73
/// Portable scalar implementation: returns the number of leading bytes that `p` and
/// `q` have in common.
///
/// The result never exceeds the length of the shorter slice, and is `0` when either
/// slice is empty or the first bytes differ.
fn count_shared_reference(p: &[u8], q: &[u8]) -> usize {
    let shorter_len = p.len().min(q.len());
    // The index of the first mismatching pair is exactly the number of equal bytes
    // before it; if no pair differs, the whole shorter slice is shared.
    p.iter()
        .zip(q.iter())
        .position(|(x, y)| x != y)
        .unwrap_or(shorter_len)
}
79
80#[allow(unused)]
81#[cold]
82fn count_shared_cold(a: &[u8], b: &[u8]) -> usize {
83 count_shared_reference(a, b)
84}
85
/// AVX-512 implementation: returns the number of leading bytes that `p` and `q`
/// have in common, comparing up to 64 bytes per step.
///
/// Each step uses a masked load that only enables the lanes still inside both
/// slices, so no byte outside `p` or `q` is ever read. Lanes that are masked off are
/// zero in both vectors and therefore compare equal; the per-step result is clamped
/// to the number of bytes that actually remain.
///
/// Implemented as a loop rather than by recursion so that stack usage stays constant
/// for arbitrarily long inputs, including in unoptimised builds.
#[cfg(all(target_feature = "avx512f", target_feature = "avx512bw"))]
#[inline(always)]
fn count_shared_avx512(p: &[u8], q: &[u8]) -> usize {
    use core::arch::x86_64::*;

    const LANES: usize = 64;

    let max_shared = p.len().min(q.len());
    if unlikely(max_shared == 0) {
        return 0;
    }

    // Number of bytes already known to be equal; always a multiple of `LANES`.
    let mut matched = 0;
    loop {
        // Invariant: `matched < max_shared`, hence `remaining >= 1` and the shift
        // amount below is in `0..=63`.
        let remaining = max_shared - matched;
        let mask: __mmask64 = !0u64 >> (LANES - remaining.min(LANES));

        // SAFETY: `matched < max_shared <= p.len(), q.len()`, so both offset
        // pointers are in bounds, and `mask` enables only the first
        // `min(remaining, 64)` lanes, all of which lie inside both slices.
        let equal = unsafe {
            let zero = _mm512_setzero_si512();
            let pv = _mm512_mask_loadu_epi8(zero, mask, p.as_ptr().add(matched) as _);
            let qv = _mm512_mask_loadu_epi8(zero, mask, q.as_ptr().add(matched) as _);
            // Bit `i` of the inverted mask is set when lane `i` differs.
            (!_mm512_cmpeq_epi8_mask(pv, qv)).trailing_zeros() as usize
        };

        // Stop on the first mismatch or when this step consumed the final bytes.
        if equal != LANES || remaining <= LANES {
            return matched + equal.min(remaining);
        }
        matched += LANES;
    }
}
108
/// AVX2 implementation: returns the number of leading bytes that `p` and `q` have in
/// common, comparing 32 bytes per step.
///
/// A step always loads a full 32-byte vector from each input, even when fewer bytes
/// remain. Without the `miri_safe` feature this over-read is only attempted when
/// [`same_page`] accepts both addresses; otherwise the rest of the comparison is
/// delegated to [`count_shared_cold`]. With `miri_safe` the vector path is taken only
/// while at least 32 bytes remain in both inputs, so nothing outside the slices is
/// read. Bytes beyond the shorter slice never influence the result because every
/// per-step count is clamped to the number of bytes that remain.
///
/// Implemented as a loop rather than by recursion so that stack usage stays constant
/// for arbitrarily long inputs, including in unoptimised builds.
#[allow(unused)]
#[cfg(target_feature = "avx2")]
#[inline(always)]
fn count_shared_avx2(p: &[u8], q: &[u8]) -> usize {
    use core::arch::x86_64::*;

    const LANES: usize = 32;

    let max_shared = p.len().min(q.len());
    if unlikely(max_shared == 0) {
        return 0;
    }

    // Number of bytes already known to be equal; always a multiple of `LANES`.
    let mut matched = 0;
    loop {
        // Invariant: `matched < max_shared`, so both tails are non-empty.
        let remaining = max_shared - matched;
        let p_rest = &p[matched..max_shared];
        let q_rest = &q[matched..max_shared];

        let use_simd = if cfg!(feature = "miri_safe") {
            remaining >= LANES
        } else {
            // SAFETY: `same_page` only inspects the address of the slice.
            unsafe { same_page::<LANES>(p_rest) && same_page::<LANES>(q_rest) }
        };
        if unlikely(!use_simd) {
            return matched + count_shared_cold(p_rest, q_rest);
        }

        // SAFETY: with `miri_safe`, both tails hold at least `LANES` bytes. Otherwise
        // the loads may extend past the end of the slices; the `same_page` check
        // above is what the code relies on to keep such a load from crossing into
        // another page. NOTE(review): that over-read is outside what the language
        // itself guarantees, which is presumably why the `miri_safe` feature exists.
        let equal = unsafe {
            let pv = _mm256_loadu_si256(p_rest.as_ptr() as _);
            let qv = _mm256_loadu_si256(q_rest.as_ptr() as _);
            // Bit `i` of the inverted mask is set when byte `i` differs.
            let differing = !(_mm256_movemask_epi8(_mm256_cmpeq_epi8(pv, qv)) as u32);
            differing.trailing_zeros() as usize
        };

        // Stop on the first mismatch or when this step consumed the final bytes.
        if equal != LANES || remaining <= LANES {
            return matched + equal.min(remaining);
        }
        matched += LANES;
    }
}
143
/// NEON implementation: returns the number of leading bytes that `p` and `q` have in
/// common, comparing 16 bytes per step.
///
/// A step always loads a full 16-byte vector from each input, even when fewer bytes
/// remain. Without the `miri_safe` feature this over-read is only attempted when
/// [`same_page`] accepts both addresses; otherwise the rest of the comparison is
/// delegated to [`count_shared_cold`]. With `miri_safe` the vector path is taken only
/// while at least 16 bytes remain in both inputs. Bytes beyond the shorter slice never
/// influence the result because every per-step count is clamped to the number of
/// bytes that remain.
///
/// Implemented as a loop rather than by recursion so that stack usage stays constant
/// for arbitrarily long inputs, including in unoptimised builds.
#[cfg(all(not(feature = "nightly"), target_arch = "aarch64", target_feature = "neon"))]
#[inline(always)]
fn count_shared_neon(p: &[u8], q: &[u8]) -> usize {
    use core::arch::aarch64::*;

    const LANES: usize = 16;

    let max_shared = p.len().min(q.len());
    if unlikely(max_shared == 0) {
        return 0;
    }

    // Number of bytes already known to be equal; always a multiple of `LANES`.
    let mut matched = 0;
    loop {
        // Invariant: `matched < max_shared`, so both tails are non-empty.
        let remaining = max_shared - matched;
        let p_rest = &p[matched..max_shared];
        let q_rest = &q[matched..max_shared];

        let use_simd = if cfg!(feature = "miri_safe") {
            remaining >= LANES
        } else {
            // SAFETY: `same_page` only inspects the address of the slice.
            unsafe { same_page::<LANES>(p_rest) && same_page::<LANES>(q_rest) }
        };
        if !use_simd {
            return matched + count_shared_cold(p_rest, q_rest);
        }

        // SAFETY: with `miri_safe`, both tails hold at least `LANES` bytes. Otherwise
        // the loads may extend past the end of the slices; the `same_page` check
        // above is what the code relies on to keep such a load from crossing into
        // another page. The store targets a local buffer of exactly `LANES` bytes.
        let equal = unsafe {
            // Each lane becomes 0xFF where the bytes are equal and 0x00 otherwise.
            let eq = vceqq_u8(vld1q_u8(p_rest.as_ptr()), vld1q_u8(q_rest.as_ptr()));
            let mut lanes = [0u8; LANES];
            vst1q_u8(lanes.as_mut_ptr(), eq);
            // Eight trailing one-bits per equal leading byte.
            (u128::from_le_bytes(lanes).trailing_ones() / 8) as usize
        };

        // Stop on the first mismatch or when this step consumed the final bytes.
        if equal != LANES || remaining <= LANES {
            return matched + equal.min(remaining);
        }
        matched += LANES;
    }
}
193
/// Portable-SIMD implementation (nightly only): returns the number of leading bytes
/// that `p` and `q` have in common, comparing 32 bytes per step.
///
/// A step always reads a full 32-byte array from each input, even when fewer bytes
/// remain. Without the `miri_safe` feature this over-read is only attempted when
/// [`same_page`] accepts both addresses; otherwise the rest of the comparison is
/// delegated to [`count_shared_cold`]. With `miri_safe` the vector path is taken only
/// while at least 32 bytes remain in both inputs. Bytes beyond the shorter slice never
/// influence the result because every per-step count is clamped to the number of
/// bytes that remain.
///
/// Implemented as a loop rather than by recursion so that stack usage stays constant
/// for arbitrarily long inputs, including in unoptimised builds.
// Compiled whenever `nightly` is enabled, but not called on every such target.
#[allow(unused)]
#[cfg(feature = "nightly")]
#[inline(always)]
fn count_shared_simd(p: &[u8], q: &[u8]) -> usize {
    use core::simd::{u8x32, cmp::SimdPartialEq};

    const LANES: usize = 32;

    let max_shared = p.len().min(q.len());
    if unlikely(max_shared == 0) {
        return 0;
    }

    // Number of bytes already known to be equal; always a multiple of `LANES`.
    let mut matched = 0;
    loop {
        // Invariant: `matched < max_shared`, so both tails are non-empty.
        let remaining = max_shared - matched;
        let p_rest = &p[matched..max_shared];
        let q_rest = &q[matched..max_shared];

        let use_simd = if cfg!(feature = "miri_safe") {
            remaining >= LANES
        } else {
            // SAFETY: `same_page` only inspects the address of the slice.
            unsafe { same_page::<LANES>(p_rest) && same_page::<LANES>(q_rest) }
        };
        if !use_simd {
            return matched + count_shared_cold(p_rest, q_rest);
        }

        // SAFETY: with `miri_safe`, both tails hold at least `LANES` bytes. Otherwise
        // the reads may extend past the end of the slices; the `same_page` check
        // above is what the code relies on to keep such a read from crossing into
        // another page. `[u8; LANES]` has alignment 1, so any address is acceptable.
        let (pv, qv) = unsafe {
            (
                u8x32::from_array(p_rest.as_ptr().cast::<[u8; LANES]>().read_unaligned()),
                u8x32::from_array(q_rest.as_ptr().cast::<[u8; LANES]>().read_unaligned()),
            )
        };
        // Bit `i` of the bitmask is set when lane `i` is equal.
        let equal = pv.simd_eq(qv).to_bitmask().trailing_ones() as usize;

        // Stop on the first mismatch or when this step consumed the final bytes.
        if equal != LANES || remaining <= LANES {
            return matched + equal.min(remaining);
        }
        matched += LANES;
    }
}
231
232#[inline]
251pub fn find_prefix_overlap(a: &[u8], b: &[u8]) -> usize {
252 #[cfg(all(target_feature="avx512f", target_feature="avx512bw"))]
253 {
254 count_shared_avx512(a, b)
255 }
256 #[cfg(all(target_feature="avx2", not(all(target_feature="avx512f", target_feature="avx512bw"))))]
257 {
258 count_shared_avx2(a, b)
259 }
260 #[cfg(all(not(feature = "nightly"), target_arch = "aarch64", target_feature = "neon"))]
261 {
262 count_shared_neon(a, b)
263 }
264 #[cfg(all(feature = "nightly", target_arch = "aarch64", target_feature = "neon"))]
265 {
266 count_shared_simd(a, b)
267 }
268 #[cfg(all(not(target_feature="avx2"), not(target_feature="neon")))]
269 {
270 count_shared_reference(a, b)
271 }
272}
273
/// Checks `find_prefix_overlap` against a table of string pairs whose shared prefix
/// lengths sit on and around the 16-, 32- and 64-byte vector widths.
#[test]
fn find_prefix_overlap_test() {
    // (left input, right input, expected shared prefix length)
    const CASES: &[(&str, &str, usize)] = &[
        ("12345", "67890", 0),
        ("", "12300", 0),
        ("12345", "", 0),
        ("12345", "12300", 3),
        ("123", "123000000", 3),
        ("123456789012345678901234567890xxxx", "123456789012345678901234567890yy", 30),
        ("123456789012345678901234567890123456789012345678901234567890xxxx", "123456789012345678901234567890123456789012345678901234567890yy", 60),
        ("1234567890123456xxxx", "1234567890123456yyyyyyy", 16),
        ("123456789012345xxxx", "123456789012345yyyyyyy", 15),
        ("12345678901234567xxxx", "12345678901234567yyyyyyy", 17),
        ("1234567890123456789012345678901xxxx", "1234567890123456789012345678901yy", 31),
        ("12345678901234567890123456789012xxxx", "12345678901234567890123456789012yy", 32),
        ("123456789012345678901234567890123xxxx", "123456789012345678901234567890123yy", 33),
        ("123456789012345678901234567890123456789012345678901234567890123xxxx", "123456789012345678901234567890123456789012345678901234567890123yy", 63),
        ("1234567890123456789012345678901234567890123456789012345678901234xxxx", "1234567890123456789012345678901234567890123456789012345678901234yy", 64),
        ("12345678901234567890123456789012345678901234567890123456789012345xxxx", "12345678901234567890123456789012345678901234567890123456789012345yy", 65),
    ];

    for &(left, right, expected) in CASES {
        let overlap = find_prefix_overlap(left.as_bytes(), right.as_bytes());
        assert_eq!(overlap, expected);
    }
}
300
/// Checks `find_prefix_overlap` on 70-byte inputs, with the single differing byte
/// placed at the end, just past the 64-byte mark, and at the very start.
#[test]
fn find_prefix_overlap_long_test() {
    let a = [b'A'; 70];

    // Overlap between `a` and a copy of `a` that differs at `changed` (if any).
    let overlap_with_change_at = |changed: Option<usize>| {
        let mut b = a;
        if let Some(index) = changed {
            b[index] = b'B';
        }
        find_prefix_overlap(&a, &b)
    };

    assert_eq!(overlap_with_change_at(None), 70);
    assert_eq!(overlap_with_change_at(Some(69)), 69);
    assert_eq!(overlap_with_change_at(Some(64)), 64);
    assert_eq!(overlap_with_change_at(Some(0)), 0);
}
315
316#[inline(always)]
318pub fn starts_with(x: &[u8], y: &[u8]) -> bool {
319 if y.len() == 0 { return true }
320 if x.len() == 0 { return false }
321 if y.len() > x.len() { return false }
322 find_prefix_overlap(x, y) == y.len()
323}