1use std::{cmp::min, slice};
6
7#[cfg(target_arch = "x86")]
8use std::arch::x86 as target_arch;
9#[cfg(target_arch = "x86_64")]
10use std::arch::x86_64 as target_arch;
11
12use self::target_arch::{
13 __m128i, _mm_cmpestri, _mm_cmpestrm, _mm_extract_epi16, _mm_loadu_si128,
14 _SIDD_CMP_EQUAL_ORDERED,
15};
16
17include!(concat!(env!("OUT_DIR"), "/src/simd_macros.rs"));
18
/// Number of haystack bytes examined by one packed-compare instruction,
/// i.e. the width in bytes of a single `__m128i` register.
const BYTES_PER_OPERATION: usize = 16;
20
/// Reinterprets a 16-byte array as a `__m128i` register value.
///
/// Initialise the `bytes` field and read the `simd` field to perform the
/// conversion; both fields are 16 bytes wide.
union TransmuteToSimd {
    simd: __m128i,
    bytes: [u8; 16],
}
25
/// Supplies the needle operand for the packed string-compare instructions
/// issued by `PackedCompare`.
trait PackedCompareControl {
    /// The needle bytes, packed into a 128-bit register.
    fn needle(&self) -> __m128i;
    /// The explicit needle length handed to the compare instruction, i.e. how
    /// many leading bytes of `needle()` take part in the comparison.
    fn needle_len(&self) -> i32;
}
30
31#[inline]
32#[target_feature(enable = "sse4.2")]
33unsafe fn find_small<C, const CONTROL_BYTE: i32>(packed: PackedCompare<C, CONTROL_BYTE>, haystack: &[u8]) -> Option<usize>
34where
35 C: PackedCompareControl,
36{
37 let mut tail = [0u8; 16];
38 core::ptr::copy_nonoverlapping(haystack.as_ptr(), tail.as_mut_ptr(), haystack.len());
39 let haystack = &tail[..haystack.len()];
40 debug_assert!(haystack.len() < ::std::i32::MAX as usize);
41 packed.cmpestri(haystack.as_ptr(), haystack.len() as i32)
42}
43
44#[inline]
54#[target_feature(enable = "sse4.2")]
55unsafe fn find<C, const CONTROL_BYTE: i32>(packed: PackedCompare<C, CONTROL_BYTE>, mut haystack: &[u8]) -> Option<usize>
56where
57 C: PackedCompareControl,
58{
59 if haystack.is_empty() {
62 return None;
63 }
64
65 if haystack.len() < 16 {
66 return find_small(packed, haystack);
67 }
68
69 let mut offset = 0;
70
71 if let Some(misaligned) = Misalignment::new(haystack) {
72 if let Some(location) = packed.cmpestrm(misaligned.leading, misaligned.leading_junk) {
73 if location < haystack.len() {
77 return Some(location);
78 }
79 }
80
81 haystack = &haystack[misaligned.bytes_until_alignment..];
82 offset += misaligned.bytes_until_alignment;
83 }
84
85 let n_complete_chunks = haystack.len() / BYTES_PER_OPERATION;
87
88 let mut haystack_ptr = haystack.as_ptr();
92 let mut chunk_offset = 0;
93 for _ in 0..n_complete_chunks {
94 if let Some(location) = packed.cmpestri(haystack_ptr, BYTES_PER_OPERATION as i32) {
95 return Some(offset + chunk_offset + location);
96 }
97
98 haystack_ptr = haystack_ptr.offset(BYTES_PER_OPERATION as isize);
99 chunk_offset += BYTES_PER_OPERATION;
100 }
101 haystack = &haystack[chunk_offset..];
102 offset += chunk_offset;
103
104 if haystack.is_empty() {
106 return None;
107 }
108
109 find_small(packed, haystack).map(|loc| loc + offset)
110}
111
112struct PackedCompare<T, const CONTROL_BYTE: i32>(T);
113impl<T, const CONTROL_BYTE: i32> PackedCompare<T, CONTROL_BYTE>
114where
115 T: PackedCompareControl,
116{
117 #[inline]
118 #[target_feature(enable = "sse4.2")]
119 unsafe fn cmpestrm(&self, haystack: &[u8], leading_junk: usize) -> Option<usize> {
120 let haystack = _mm_loadu_si128(haystack.as_ptr() as *const __m128i);
122
123 let mask = _mm_cmpestrm(
124 self.0.needle(),
125 self.0.needle_len(),
126 haystack,
127 BYTES_PER_OPERATION as i32,
128 CONTROL_BYTE,
129 );
130 let mask = _mm_extract_epi16(mask, 0) as u16;
131
132 if mask.trailing_zeros() < 16 {
133 let mut mask = mask;
134 mask >>= leading_junk;
141 if mask == 0 {
144 None
146 } else {
147 let first_match = mask.trailing_zeros() as usize;
148 debug_assert!(first_match < 16);
149 Some(first_match)
150 }
151 } else {
152 None
153 }
154 }
155
156 #[inline]
157 #[target_feature(enable = "sse4.2")]
158 unsafe fn cmpestri(&self, haystack: *const u8, haystack_len: i32) -> Option<usize> {
159 debug_assert!(
160 (1..=16).contains(&haystack_len),
161 "haystack_len was {}",
162 haystack_len,
163 );
164
165 let haystack = _mm_loadu_si128(haystack as *const __m128i);
167
168 let location = _mm_cmpestri(
169 self.0.needle(),
170 self.0.needle_len(),
171 haystack,
172 haystack_len,
173 CONTROL_BYTE,
174 );
175
176 if location < 16 {
177 Some(location as usize)
178 } else {
179 None
180 }
181 }
182}
183
184#[derive(Debug)]
185struct Misalignment<'a> {
186 leading: &'a [u8],
187 leading_junk: usize,
188 bytes_until_alignment: usize,
189}
190
191impl<'a> Misalignment<'a> {
192 #[inline]
205 fn new(haystack: &[u8]) -> Option<Self> {
206 let aligned_start = ((haystack.as_ptr() as usize) & !0xF) as *const u8;
207
208 if aligned_start == haystack.as_ptr() {
210 return None;
211 }
212
213 let aligned_end = unsafe { aligned_start.offset(BYTES_PER_OPERATION as isize) };
214
215 let leading_junk = haystack.as_ptr() as usize - aligned_start as usize;
216 let leading_len = min(haystack.len() + leading_junk, BYTES_PER_OPERATION);
217
218 let leading = unsafe { slice::from_raw_parts(aligned_start, leading_len) };
219
220 let bytes_until_alignment = if leading_len == BYTES_PER_OPERATION {
221 aligned_end as usize - haystack.as_ptr() as usize
222 } else {
223 haystack.len()
224 };
225
226 Some(Misalignment {
227 leading,
228 leading_junk,
229 bytes_until_alignment,
230 })
231 }
232}
233
234pub struct Bytes {
235 needle: __m128i,
236 needle_len: i32,
237}
238
239impl Bytes {
240 pub fn new(bytes: [u8; 16], needle_len: i32) -> Self {
241 Bytes {
242 needle: unsafe { TransmuteToSimd { bytes }.simd },
243 needle_len,
244 }
245 }
246
247 #[inline]
248 #[target_feature(enable = "sse4.2")]
249 pub unsafe fn find(&self, haystack: &[u8]) -> Option<usize> {
250 find(PackedCompare::<_, 0>(self), haystack)
251 }
252}
253
254impl<'b> PackedCompareControl for &'b Bytes {
255 fn needle(&self) -> __m128i {
256 self.needle
257 }
258 fn needle_len(&self) -> i32 {
259 self.needle_len
260 }
261}
262
263pub struct ByteSubstring<'a> {
264 complete_needle: &'a [u8],
265 needle: __m128i,
266 needle_len: i32,
267}
268
269impl<'a> ByteSubstring<'a> {
270 pub fn new(needle: &'a[u8]) -> Self {
271 use std::cmp;
272
273 let mut simd_needle = [0; 16];
274 let len = cmp::min(simd_needle.len(), needle.len());
275 simd_needle[..len].copy_from_slice(&needle[..len]);
276 ByteSubstring {
277 complete_needle: needle,
278 needle: unsafe { TransmuteToSimd { bytes: simd_needle }.simd },
279 needle_len: len as i32,
280 }
281 }
282
283 #[cfg(feature = "pattern")]
284 pub fn needle_len(&self) -> usize {
285 self.complete_needle.len()
286 }
287
288 #[inline]
289 #[target_feature(enable = "sse4.2")]
290 pub unsafe fn find(&self, haystack: &[u8]) -> Option<usize> {
291 let mut offset = 0;
292
293 while let Some(idx) = find(PackedCompare::<_, _SIDD_CMP_EQUAL_ORDERED>(self), &haystack[offset..]) {
294 let abs_offset = offset + idx;
295 if haystack[abs_offset..].starts_with(self.complete_needle) {
297 return Some(abs_offset);
298 }
299
300 offset += idx + 1;
302 }
303
304 None
305 }
306}
307
308impl<'a, 'b> PackedCompareControl for &'b ByteSubstring<'a> {
309 fn needle(&self) -> __m128i {
310 self.needle
311 }
312 fn needle_len(&self) -> i32 {
313 self.needle_len
314 }
315}
316
// Tests for the SSE4.2 searchers: fixed examples, property tests against
// straightforward reference implementations, and guard-page tests that fault
// if a search reads past the end of its input.
#[cfg(test)]
mod test {
    use proptest::prelude::*;
    use std::{fmt, str};
    use memmap::MmapMut;
    use region::Protection;

    use super::*;

    // Searchers shared by several tests below.
    lazy_static! {
        static ref SPACE: Bytes = simd_bytes!(b' ');
        static ref XML_DELIM_3: Bytes = simd_bytes!(b'<', b'>', b'&');
        static ref XML_DELIM_5: Bytes = simd_bytes!(b'<', b'>', b'&', b'\'', b'"');
    }

    /// Simple reference implementations that the SIMD searchers are compared
    /// against.
    trait SliceFindPolyfill<T> {
        /// Index of the first element that equals any element of `needles`.
        fn find_any(&self, needles: &[T]) -> Option<usize>;
        /// Index of the first position at which `needle` occurs as a
        /// contiguous subsequence.
        fn find_seq(&self, needle: &[T]) -> Option<usize>;
    }

    impl<T> SliceFindPolyfill<T> for [T]
    where
        T: PartialEq,
    {
        fn find_any(&self, needles: &[T]) -> Option<usize> {
            self.iter().position(|c| needles.contains(c))
        }

        fn find_seq(&self, needle: &[T]) -> Option<usize> {
            (0..self.len()).find(|&l| self[l..].starts_with(needle))
        }
    }

    /// Test input: a byte buffer plus a starting offset into it, so the same
    /// data can be searched from different memory addresses.
    struct Haystack {
        data: Vec<u8>,
        start: usize,
    }

    impl Haystack {
        /// The whole buffer, ignoring `start`.
        fn without_start(&self) -> &[u8] {
            &self.data
        }

        /// The buffer from `start` onwards.
        fn with_start(&self) -> &[u8] {
            &self.data[self.start..]
        }
    }

    // Hand-written so that the buffer's address is part of the failure output.
    impl fmt::Debug for Haystack {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            f.debug_struct("Haystack")
                .field("data", &self.data)
                .field("(addr)", &self.data.as_ptr())
                .field("start", &self.start)
                .finish()
        }
    }

    /// Strategy producing an arbitrary byte vector together with a start
    /// offset anywhere from 0 up to and including its length.
    fn haystack() -> BoxedStrategy<Haystack> {
        any::<Vec<u8>>()
            .prop_flat_map(|data| {
                let len = 0..=data.len();
                (Just(data), len)
            })
            .prop_map(|(data, start)| Haystack { data, start })
            .boxed()
    }

    /// Test needle: 16 bytes of storage of which the first `len` are used.
    #[derive(Debug)]
    struct Needle {
        data: [u8; 16],
        len: usize,
    }

    impl Needle {
        /// The used portion of the needle.
        fn as_slice(&self) -> &[u8] {
            &self.data[..self.len]
        }
    }

    /// Strategy producing 16 arbitrary bytes and a length from 0 to 16.
    fn needle() -> BoxedStrategy<Needle> {
        (any::<[u8; 16]>(), 0..=16_usize)
            .prop_map(|(data, len)| Needle { data, len })
            .boxed()
    }

    proptest! {
        // `Bytes::find` agrees with `find_any` when searching whole buffers.
        #[test]
        fn works_as_find_does_for_up_to_and_including_16_bytes(
            (haystack, needle) in (haystack(), needle())
        ) {
            let haystack = haystack.without_start();

            let us = unsafe { Bytes::new(needle.data, needle.len as i32).find(haystack) };
            let them = haystack.find_any(needle.as_slice());
            assert_eq!(us, them);
        }

        // Same comparison, but the slice begins at an arbitrary offset into
        // the buffer so different start addresses are exercised.
        #[test]
        fn works_as_find_does_for_various_memory_offsets(
            (needle, haystack) in (needle(), haystack())
        ) {
            let haystack = haystack.with_start();

            let us = unsafe { Bytes::new(needle.data, needle.len as i32).find(haystack) };
            let them = haystack.find_any(needle.as_slice());
            assert_eq!(us, them);
        }
    }

    /// A zero byte can be used as the needle.
    #[test]
    fn can_search_for_null_bytes() {
        unsafe {
            let null = simd_bytes!(b'\0');
            assert_eq!(Some(1), null.find(b"a\0"));
            assert_eq!(Some(0), null.find(b"\0"));
            assert_eq!(None, null.find(b""));
        }
    }

    /// Zero bytes in the haystack do not end the search early.
    #[test]
    fn can_search_in_null_bytes() {
        unsafe {
            let a = simd_bytes!(b'a');
            assert_eq!(Some(1), a.find(b"\0a"));
            assert_eq!(None, a.find(b"\0"));
        }
    }

    /// A trailing space is found at every haystack length from 1 to 18 bytes.
    #[test]
    fn space_is_found() {
        unsafe {
            assert_eq!(Some(0), SPACE.find(b" "));
            assert_eq!(Some(1), SPACE.find(b"0 "));
            assert_eq!(Some(2), SPACE.find(b"01 "));
            assert_eq!(Some(3), SPACE.find(b"012 "));
            assert_eq!(Some(4), SPACE.find(b"0123 "));
            assert_eq!(Some(5), SPACE.find(b"01234 "));
            assert_eq!(Some(6), SPACE.find(b"012345 "));
            assert_eq!(Some(7), SPACE.find(b"0123456 "));
            assert_eq!(Some(8), SPACE.find(b"01234567 "));
            assert_eq!(Some(9), SPACE.find(b"012345678 "));
            assert_eq!(Some(10), SPACE.find(b"0123456789 "));
            assert_eq!(Some(11), SPACE.find(b"0123456789A "));
            assert_eq!(Some(12), SPACE.find(b"0123456789AB "));
            assert_eq!(Some(13), SPACE.find(b"0123456789ABC "));
            assert_eq!(Some(14), SPACE.find(b"0123456789ABCD "));
            assert_eq!(Some(15), SPACE.find(b"0123456789ABCDE "));
            assert_eq!(Some(16), SPACE.find(b"0123456789ABCDEF "));
            assert_eq!(Some(17), SPACE.find(b"0123456789ABCDEFG "));
        }
    }

    /// No match is reported for space-free haystacks of 0 to 17 bytes.
    #[test]
    fn space_not_found() {
        unsafe {
            assert_eq!(None, SPACE.find(b""));
            assert_eq!(None, SPACE.find(b"0"));
            assert_eq!(None, SPACE.find(b"01"));
            assert_eq!(None, SPACE.find(b"012"));
            assert_eq!(None, SPACE.find(b"0123"));
            assert_eq!(None, SPACE.find(b"01234"));
            assert_eq!(None, SPACE.find(b"012345"));
            assert_eq!(None, SPACE.find(b"0123456"));
            assert_eq!(None, SPACE.find(b"01234567"));
            assert_eq!(None, SPACE.find(b"012345678"));
            assert_eq!(None, SPACE.find(b"0123456789"));
            assert_eq!(None, SPACE.find(b"0123456789A"));
            assert_eq!(None, SPACE.find(b"0123456789AB"));
            assert_eq!(None, SPACE.find(b"0123456789ABC"));
            assert_eq!(None, SPACE.find(b"0123456789ABCD"));
            assert_eq!(None, SPACE.find(b"0123456789ABCDE"));
            assert_eq!(None, SPACE.find(b"0123456789ABCDEF"));
            assert_eq!(None, SPACE.find(b"0123456789ABCDEFG"));
        }
    }

    /// Searching every suffix of one buffer starts the search at 18
    /// consecutive addresses, and the reported index is always relative to
    /// the start of the suffix.
    #[test]
    fn works_on_nonaligned_beginnings() {
        unsafe {
            let s = b"0123456789ABCDEF ".to_vec();

            assert_eq!(Some(16), SPACE.find(&s[0..]));
            assert_eq!(Some(15), SPACE.find(&s[1..]));
            assert_eq!(Some(14), SPACE.find(&s[2..]));
            assert_eq!(Some(13), SPACE.find(&s[3..]));
            assert_eq!(Some(12), SPACE.find(&s[4..]));
            assert_eq!(Some(11), SPACE.find(&s[5..]));
            assert_eq!(Some(10), SPACE.find(&s[6..]));
            assert_eq!(Some(9), SPACE.find(&s[7..]));
            assert_eq!(Some(8), SPACE.find(&s[8..]));
            assert_eq!(Some(7), SPACE.find(&s[9..]));
            assert_eq!(Some(6), SPACE.find(&s[10..]));
            assert_eq!(Some(5), SPACE.find(&s[11..]));
            assert_eq!(Some(4), SPACE.find(&s[12..]));
            assert_eq!(Some(3), SPACE.find(&s[13..]));
            assert_eq!(Some(2), SPACE.find(&s[14..]));
            assert_eq!(Some(1), SPACE.find(&s[15..]));
            assert_eq!(Some(0), SPACE.find(&s[16..]));
            assert_eq!(None, SPACE.find(&s[17..]));
        }
    }

    /// The only occurrence of the needle byte is at `data[0]`, one byte
    /// before the searched slice begins, so the search must report `None`.
    #[test]
    fn misalignment_does_not_cause_a_false_positive_before_start() {
        const AAAA: u8 = 0x01;

        let needle = Needle {
            data: [
                AAAA, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            ],
            len: 1,
        };
        let haystack = Haystack {
            data: vec![
                AAAA, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                0x00, 0x00,
            ],
            start: 1,
        };

        let haystack = haystack.with_start();

        // Preconditions of the scenario: the slice must start off a 16-byte
        // boundary and be longer than 64 bytes.
        assert_ne!(0, (haystack.as_ptr() as usize) % 16);
        assert!(haystack.len() > 64);

        let us = unsafe { Bytes::new(needle.data, needle.len as i32).find(haystack) };
        assert_eq!(None, us);
    }

    /// Each of the three delimiters is found; the empty haystack yields `None`.
    #[test]
    fn xml_delim_3_is_found() {
        unsafe {
            assert_eq!(Some(0), XML_DELIM_3.find(b"<"));
            assert_eq!(Some(0), XML_DELIM_3.find(b">"));
            assert_eq!(Some(0), XML_DELIM_3.find(b"&"));
            assert_eq!(None, XML_DELIM_3.find(b""));
        }
    }

    /// Each of the five delimiters is found; the empty haystack yields `None`.
    #[test]
    fn xml_delim_5_is_found() {
        unsafe {
            assert_eq!(Some(0), XML_DELIM_5.find(b"<"));
            assert_eq!(Some(0), XML_DELIM_5.find(b">"));
            assert_eq!(Some(0), XML_DELIM_5.find(b"&"));
            assert_eq!(Some(0), XML_DELIM_5.find(b"'"));
            assert_eq!(Some(0), XML_DELIM_5.find(b"\""));
            assert_eq!(None, XML_DELIM_5.find(b""));
        }
    }

    proptest! {
        // `ByteSubstring::find` agrees with `find_seq` for arbitrary needles
        // and haystacks.
        #[test]
        fn works_as_find_does_for_byte_substrings(
            (needle, haystack) in (any::<Vec<u8>>(), any::<Vec<u8>>())
        ) {
            let us = unsafe {
                let s = ByteSubstring::new(&needle);
                s.find(&haystack)
            };
            let them = haystack.find_seq(&needle);
            assert_eq!(us, them);
        }
    }

    /// A trailing `zz` is found behind prefixes of 0 to 17 bytes.
    #[test]
    fn byte_substring_is_found() {
        unsafe {
            let substr = ByteSubstring::new(b"zz");
            assert_eq!(Some(0), substr.find(b"zz"));
            assert_eq!(Some(1), substr.find(b"0zz"));
            assert_eq!(Some(2), substr.find(b"01zz"));
            assert_eq!(Some(3), substr.find(b"012zz"));
            assert_eq!(Some(4), substr.find(b"0123zz"));
            assert_eq!(Some(5), substr.find(b"01234zz"));
            assert_eq!(Some(6), substr.find(b"012345zz"));
            assert_eq!(Some(7), substr.find(b"0123456zz"));
            assert_eq!(Some(8), substr.find(b"01234567zz"));
            assert_eq!(Some(9), substr.find(b"012345678zz"));
            assert_eq!(Some(10), substr.find(b"0123456789zz"));
            assert_eq!(Some(11), substr.find(b"0123456789Azz"));
            assert_eq!(Some(12), substr.find(b"0123456789ABzz"));
            assert_eq!(Some(13), substr.find(b"0123456789ABCzz"));
            assert_eq!(Some(14), substr.find(b"0123456789ABCDzz"));
            assert_eq!(Some(15), substr.find(b"0123456789ABCDEzz"));
            assert_eq!(Some(16), substr.find(b"0123456789ABCDEFzz"));
            assert_eq!(Some(17), substr.find(b"0123456789ABCDEFGzz"));
        }
    }

    /// No match is reported for haystacks of 0 to 17 bytes without `zz`.
    #[test]
    fn byte_substring_is_not_found() {
        unsafe {
            let substr = ByteSubstring::new(b"zz");
            assert_eq!(None, substr.find(b""));
            assert_eq!(None, substr.find(b"0"));
            assert_eq!(None, substr.find(b"01"));
            assert_eq!(None, substr.find(b"012"));
            assert_eq!(None, substr.find(b"0123"));
            assert_eq!(None, substr.find(b"01234"));
            assert_eq!(None, substr.find(b"012345"));
            assert_eq!(None, substr.find(b"0123456"));
            assert_eq!(None, substr.find(b"01234567"));
            assert_eq!(None, substr.find(b"012345678"));
            assert_eq!(None, substr.find(b"0123456789"));
            assert_eq!(None, substr.find(b"0123456789A"));
            assert_eq!(None, substr.find(b"0123456789AB"));
            assert_eq!(None, substr.find(b"0123456789ABC"));
            assert_eq!(None, substr.find(b"0123456789ABCD"));
            assert_eq!(None, substr.find(b"0123456789ABCDE"));
            assert_eq!(None, substr.find(b"0123456789ABCDEF"));
            assert_eq!(None, substr.find(b"0123456789ABCDEFG"));
        }
    }

    /// The haystack is seventeen `a`s followed by `b`, so the only real match
    /// for `ab` starts at index 16. Judging by the test's name, an earlier
    /// `a` is expected to be reported by the SIMD pass as a candidate that
    /// `ByteSubstring::find` must then reject.
    #[test]
    fn byte_substring_has_false_positive() {
        unsafe {
            let substr = ByteSubstring::new(b"ab");
            assert_eq!(Some(16), substr.find(b"aaaaaaaaaaaaaaaaab"))
        };
    }

    /// A 17-byte needle does not fit in one SIMD register; the match must
    /// still be found.
    #[test]
    fn byte_substring_needle_is_longer_than_16_bytes() {
        unsafe {
            let needle = b"0123456789abcdefg";
            let haystack = b"0123456789abcdefgh";
            assert_eq!(Some(0), ByteSubstring::new(needle).find(haystack));
        }
    }

    /// Calls `f` with a copy of `value` placed at the very end of a mapped
    /// page whose following page has its protection set to
    /// `Protection::NONE`, so that any access past the end of the string
    /// hits the protected page.
    fn with_guarded_string(value: &str, f: impl FnOnce(&str)) {
        let page_size = region::page::size();
        // The string has to fit inside the first page.
        assert!(value.len() <= page_size);

        // Two anonymous pages: the first holds the data, the second is the
        // guard.
        let mut mmap = MmapMut::map_anon(2 * page_size).unwrap();

        let (first_page, second_page) = mmap.split_at_mut(page_size);

        unsafe {
            region::protect(second_page.as_ptr(), page_size, Protection::NONE).unwrap();
        }

        // Place the string so that its last byte is the last byte of the
        // first page.
        let dest = &mut first_page[page_size - value.len()..];
        dest.copy_from_slice(value.as_bytes());
        f(unsafe { str::from_utf8_unchecked(dest) });
    }

    /// Every suffix of a string ending exactly at a page boundary can be
    /// searched, and the final byte is found, without touching the guard page.
    #[test]
    fn works_at_page_boundary() {
        with_guarded_string("0123456789abcdef", |text| {
            let needle = simd_bytes!(b'f');

            for offset in 0..text.len() {
                let tail = &text[offset..];
                unsafe {
                    assert_eq!(Some(tail.len() - 1), needle.find(tail.as_bytes()));
                }
            }
        });
    }

    /// A 16-byte haystack ending at a page boundary with no match must be
    /// searched completely without reading into the guard page.
    #[test]
    fn does_not_access_memory_after_haystack_when_haystack_is_multiple_of_16_bytes_and_no_match() {
        with_guarded_string("0123456789abcdef", |text| {
            let needle = simd_bytes!(b'z');

            unsafe {
                assert_eq!(None, needle.find(text.as_bytes()));
            }
        });
    }
}