jetscii/
simd.rs

1// # Warning
2//
3// Everything in this module assumes that the SSE 4.2 feature is available.
4
5use std::{cmp::min, slice};
6
7#[cfg(target_arch = "x86")]
8use std::arch::x86 as target_arch;
9#[cfg(target_arch = "x86_64")]
10use std::arch::x86_64 as target_arch;
11
12use self::target_arch::{
13    __m128i, _mm_cmpestri, _mm_cmpestrm, _mm_extract_epi16, _mm_loadu_si128,
14    _SIDD_CMP_EQUAL_ORDERED,
15};
16
17include!(concat!(env!("OUT_DIR"), "/src/simd_macros.rs"));
18
/// Number of haystack bytes covered by one packed-compare operation
/// (the width of an `__m128i` register in bytes).
const BYTES_PER_OPERATION: usize = 16;
20
/// Reinterprets a 16-byte array as an `__m128i`: callers in this module
/// initialise the `bytes` field and read the value back through `simd`.
union TransmuteToSimd {
    simd: __m128i,
    bytes: [u8; 16],
}
25
/// Supplies the needle operand that `PackedCompare` hands to the
/// packed-compare intrinsics.
trait PackedCompareControl {
    /// The needle bytes packed into a 128-bit value.
    fn needle(&self) -> __m128i;
    /// The explicit needle length passed to the intrinsics alongside
    /// `needle()`.
    fn needle_len(&self) -> i32;
}
30
31#[inline]
32#[target_feature(enable = "sse4.2")]
33unsafe fn find_small<C, const CONTROL_BYTE: i32>(packed: PackedCompare<C, CONTROL_BYTE>, haystack: &[u8]) -> Option<usize>
34where
35    C: PackedCompareControl,
36{
37    let mut tail = [0u8; 16];
38    core::ptr::copy_nonoverlapping(haystack.as_ptr(), tail.as_mut_ptr(), haystack.len());
39    let haystack = &tail[..haystack.len()];
40    debug_assert!(haystack.len() < ::std::i32::MAX as usize);
41    packed.cmpestri(haystack.as_ptr(), haystack.len() as i32)
42}
43
44/// The PCMPxSTRx instructions always read 16 bytes worth of
45/// data. Although the instructions handle unaligned memory access
46/// just fine, they might attempt to read off the end of a page
47/// and into a protected area.
48///
49/// To handle this case, we read in 16-byte aligned chunks with
50/// respect to the *end* of the byte slice. This makes the
51/// complicated part in searching the leftover bytes at the
52/// beginning of the byte slice.
53#[inline]
54#[target_feature(enable = "sse4.2")]
55unsafe fn find<C, const CONTROL_BYTE: i32>(packed: PackedCompare<C, CONTROL_BYTE>, mut haystack: &[u8]) -> Option<usize>
56where
57    C: PackedCompareControl,
58{
59    // FIXME: EXPLAIN SAFETY
60
61    if haystack.is_empty() {
62        return None;
63    }
64
65    if haystack.len() < 16 {
66        return find_small(packed, haystack);
67    }
68
69    let mut offset = 0;
70
71    if let Some(misaligned) = Misalignment::new(haystack) {
72        if let Some(location) = packed.cmpestrm(misaligned.leading, misaligned.leading_junk) {
73            // Since the masking operation covers an entire
74            // 16-byte chunk, we have to see if the match occurred
75            // somewhere *after* our data
76            if location < haystack.len() {
77                return Some(location);
78            }
79        }
80
81        haystack = &haystack[misaligned.bytes_until_alignment..];
82        offset += misaligned.bytes_until_alignment;
83    }
84
85    // TODO: try removing the 16-byte loop and check the disasm
86    let n_complete_chunks = haystack.len() / BYTES_PER_OPERATION;
87
88    // Getting the pointer once before the loop avoids the
89    // overhead of manipulating the length of the slice inside the
90    // loop.
91    let mut haystack_ptr = haystack.as_ptr();
92    let mut chunk_offset = 0;
93    for _ in 0..n_complete_chunks {
94        if let Some(location) = packed.cmpestri(haystack_ptr, BYTES_PER_OPERATION as i32) {
95            return Some(offset + chunk_offset + location);
96        }
97
98        haystack_ptr = haystack_ptr.offset(BYTES_PER_OPERATION as isize);
99        chunk_offset += BYTES_PER_OPERATION;
100    }
101    haystack = &haystack[chunk_offset..];
102    offset += chunk_offset;
103
104    // No data left to search
105    if haystack.is_empty() {
106        return None;
107    }
108
109    find_small(packed, haystack).map(|loc| loc + offset)
110}
111
112struct PackedCompare<T, const CONTROL_BYTE: i32>(T);
113impl<T, const CONTROL_BYTE: i32> PackedCompare<T, CONTROL_BYTE>
114where
115    T: PackedCompareControl,
116{
117    #[inline]
118    #[target_feature(enable = "sse4.2")]
119    unsafe fn cmpestrm(&self, haystack: &[u8], leading_junk: usize) -> Option<usize> {
120        // TODO: document why this is ok
121        let haystack = _mm_loadu_si128(haystack.as_ptr() as *const __m128i);
122
123        let mask = _mm_cmpestrm(
124            self.0.needle(),
125            self.0.needle_len(),
126            haystack,
127            BYTES_PER_OPERATION as i32,
128            CONTROL_BYTE,
129        );
130        let mask = _mm_extract_epi16(mask, 0) as u16;
131
132        if mask.trailing_zeros() < 16 {
133            let mut mask = mask;
134            // Byte: 7 6 5 4 3 2 1 0
135            // Str : &[0, 1, 2, 3, ...]
136            //
137            // Bit-0 corresponds to Str-0; shifting to the right
138            // removes the parts of the string that don't belong to
139            // us.
140            mask >>= leading_junk;
141            // The first 1, starting from Bit-0 and going to Bit-7,
142            // denotes the position of the first match.
143            if mask == 0 {
144                // All of our matches were before the slice started
145                None
146            } else {
147                let first_match = mask.trailing_zeros() as usize;
148                debug_assert!(first_match < 16);
149                Some(first_match)
150            }
151        } else {
152            None
153        }
154    }
155
156    #[inline]
157    #[target_feature(enable = "sse4.2")]
158    unsafe fn cmpestri(&self, haystack: *const u8, haystack_len: i32) -> Option<usize> {
159        debug_assert!(
160            (1..=16).contains(&haystack_len),
161            "haystack_len was {}",
162            haystack_len,
163        );
164
165        // TODO: document why this is ok
166        let haystack = _mm_loadu_si128(haystack as *const __m128i);
167
168        let location = _mm_cmpestri(
169            self.0.needle(),
170            self.0.needle_len(),
171            haystack,
172            haystack_len,
173            CONTROL_BYTE,
174        );
175
176        if location < 16 {
177            Some(location as usize)
178        } else {
179            None
180        }
181    }
182}
183
184#[derive(Debug)]
185struct Misalignment<'a> {
186    leading: &'a [u8],
187    leading_junk: usize,
188    bytes_until_alignment: usize,
189}
190
191impl<'a> Misalignment<'a> {
192    /// # Cases
193    ///
194    /// 0123456789ABCDEF
195    /// |--|                < 1.
196    ///       |--|          < 2.
197    ///             |--|    < 3.
198    ///             |----|  < 4.
199    ///
200    /// 1. The input slice is aligned.
201    /// 2. The input slice is unaligned and is completely within the 16-byte chunk.
202    /// 3. The input slice is unaligned and touches the boundary of the 16-byte chunk.
203    /// 4. The input slice is unaligned and crosses the boundary of the 16-byte chunk.
204    #[inline]
205    fn new(haystack: &[u8]) -> Option<Self> {
206        let aligned_start = ((haystack.as_ptr() as usize) & !0xF) as *const u8;
207
208        // If we are already aligned, there's nothing to do
209        if aligned_start == haystack.as_ptr() {
210            return None;
211        }
212
213        let aligned_end = unsafe { aligned_start.offset(BYTES_PER_OPERATION as isize) };
214
215        let leading_junk = haystack.as_ptr() as usize - aligned_start as usize;
216        let leading_len = min(haystack.len() + leading_junk, BYTES_PER_OPERATION);
217
218        let leading = unsafe { slice::from_raw_parts(aligned_start, leading_len) };
219
220        let bytes_until_alignment = if leading_len == BYTES_PER_OPERATION {
221            aligned_end as usize - haystack.as_ptr() as usize
222        } else {
223            haystack.len()
224        };
225
226        Some(Misalignment {
227            leading,
228            leading_junk,
229            bytes_until_alignment,
230        })
231    }
232}
233
234pub struct Bytes {
235    needle: __m128i,
236    needle_len: i32,
237}
238
239impl Bytes {
240    pub /* const */ fn new(bytes: [u8; 16], needle_len: i32) -> Self {
241        Bytes {
242            needle: unsafe { TransmuteToSimd { bytes }.simd },
243            needle_len,
244        }
245    }
246
247    #[inline]
248    #[target_feature(enable = "sse4.2")]
249    pub unsafe fn find(&self, haystack: &[u8]) -> Option<usize> {
250        find(PackedCompare::<_, 0>(self), haystack)
251    }
252}
253
254impl<'b> PackedCompareControl for &'b Bytes {
255    fn needle(&self) -> __m128i {
256        self.needle
257    }
258    fn needle_len(&self) -> i32 {
259        self.needle_len
260    }
261}
262
263pub struct ByteSubstring<'a> {
264    complete_needle: &'a [u8],
265    needle: __m128i,
266    needle_len: i32,
267}
268
269impl<'a> ByteSubstring<'a> {
270    pub /* const */ fn new(needle: &'a[u8]) -> Self {
271        use std::cmp;
272
273        let mut simd_needle = [0; 16];
274        let len = cmp::min(simd_needle.len(), needle.len());
275        simd_needle[..len].copy_from_slice(&needle[..len]);
276        ByteSubstring {
277            complete_needle: needle,
278            needle: unsafe { TransmuteToSimd { bytes: simd_needle }.simd },
279            needle_len: len as i32,
280        }
281    }
282
283    #[cfg(feature = "pattern")]
284    pub fn needle_len(&self) -> usize {
285        self.complete_needle.len()
286    }
287
288    #[inline]
289    #[target_feature(enable = "sse4.2")]
290    pub unsafe fn find(&self, haystack: &[u8]) -> Option<usize> {
291        let mut offset = 0;
292
293        while let Some(idx) = find(PackedCompare::<_, _SIDD_CMP_EQUAL_ORDERED>(self), &haystack[offset..]) {
294            let abs_offset = offset + idx;
295            // Found a match, but is it really?
296            if haystack[abs_offset..].starts_with(self.complete_needle) {
297                return Some(abs_offset);
298            }
299
300            // Skip past this false positive
301            offset += idx + 1;
302        }
303
304        None
305    }
306}
307
308impl<'a, 'b> PackedCompareControl for &'b ByteSubstring<'a> {
309    fn needle(&self) -> __m128i {
310        self.needle
311    }
312    fn needle_len(&self) -> i32 {
313        self.needle_len
314    }
315}
316
// Unit and property tests for `Bytes` and `ByteSubstring`. Results are
// compared against straightforward slice-based reference searches, and
// several tests deliberately place haystacks at unusual addresses.
#[cfg(test)]
mod test {
    use proptest::prelude::*;
    use std::{fmt, str};
    use memmap::MmapMut;
    use region::Protection;

    use super::*;

    // Searchers shared by several tests below.
    lazy_static! {
        static ref SPACE: Bytes = simd_bytes!(b' ');
        static ref XML_DELIM_3: Bytes = simd_bytes!(b'<', b'>', b'&');
        static ref XML_DELIM_5: Bytes = simd_bytes!(b'<', b'>', b'&', b'\'', b'"');
    }

    /// Reference implementations that the SIMD searchers are checked
    /// against.
    trait SliceFindPolyfill<T> {
        /// Index of the first element that equals any of `needles`.
        fn find_any(&self, needles: &[T]) -> Option<usize>;
        /// Index at which `needle` first occurs as a contiguous run.
        fn find_seq(&self, needle: &[T]) -> Option<usize>;
    }

    impl<T> SliceFindPolyfill<T> for [T]
    where
        T: PartialEq,
    {
        fn find_any(&self, needles: &[T]) -> Option<usize> {
            self.iter().position(|c| needles.contains(c))
        }

        fn find_seq(&self, needle: &[T]) -> Option<usize> {
            (0..self.len()).find(|&l| self[l..].starts_with(needle))
        }
    }

    /// Test data together with an offset into it, so the searched slice
    /// can start at an arbitrary position inside the allocation.
    struct Haystack {
        data: Vec<u8>,
        start: usize,
    }

    impl Haystack {
        /// The complete data, ignoring `start`.
        fn without_start(&self) -> &[u8] {
            &self.data
        }

        /// The data from `start` onwards.
        fn with_start(&self) -> &[u8] {
            &self.data[self.start..]
        }
    }

    // Knowing the address of the data can be important
    impl fmt::Debug for Haystack {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            f.debug_struct("Haystack")
                .field("data", &self.data)
                .field("(addr)", &self.data.as_ptr())
                .field("start", &self.start)
                .finish()
        }
    }

    /// Creates a set of bytes and an offset inside them. Allows
    /// checking arbitrary memory offsets, not just where the
    /// allocator placed a value.
    fn haystack() -> BoxedStrategy<Haystack> {
        any::<Vec<u8>>()
            .prop_flat_map(|data| {
                let len = 0..=data.len();
                (Just(data), len)
            })
            .prop_map(|(data, start)| Haystack { data, start })
            .boxed()
    }

    /// A 16-byte array of which only the first `len` bytes are
    /// considered part of the needle.
    #[derive(Debug)]
    struct Needle {
        data: [u8; 16],
        len: usize,
    }

    impl Needle {
        /// The valid prefix of `data`.
        fn as_slice(&self) -> &[u8] {
            &self.data[..self.len]
        }
    }

    /// Creates an array and the number of valid values
    fn needle() -> BoxedStrategy<Needle> {
        (any::<[u8; 16]>(), 0..=16_usize)
            .prop_map(|(data, len)| Needle { data, len })
            .boxed()
    }

    proptest! {
        #[test]
        fn works_as_find_does_for_up_to_and_including_16_bytes(
            (haystack, needle) in (haystack(), needle())
        ) {
           let haystack = haystack.without_start();

           let us = unsafe { Bytes::new(needle.data, needle.len as i32).find(haystack) };
           let them = haystack.find_any(needle.as_slice());
           assert_eq!(us, them);
        }

        #[test]
        fn works_as_find_does_for_various_memory_offsets(
            (needle, haystack) in (needle(), haystack())
        ) {
            let haystack = haystack.with_start();

            let us = unsafe { Bytes::new(needle.data, needle.len as i32).find(haystack) };
            let them = haystack.find_any(needle.as_slice());
            assert_eq!(us, them);
        }
    }

    #[test]
    fn can_search_for_null_bytes() {
        unsafe {
            let null = simd_bytes!(b'\0');
            assert_eq!(Some(1), null.find(b"a\0"));
            assert_eq!(Some(0), null.find(b"\0"));
            assert_eq!(None, null.find(b""));
        }
    }

    #[test]
    fn can_search_in_null_bytes() {
        unsafe {
            let a = simd_bytes!(b'a');
            assert_eq!(Some(1), a.find(b"\0a"));
            assert_eq!(None, a.find(b"\0"));
        }
    }

    #[test]
    fn space_is_found() {
        unsafe {
            // Since the algorithm operates on 16-byte chunks, it's
            // important to cover tests around that boundary. Since 16
            // isn't that big of a number, we might as well do all of
            // them.

            assert_eq!(Some(0), SPACE.find(b" "));
            assert_eq!(Some(1), SPACE.find(b"0 "));
            assert_eq!(Some(2), SPACE.find(b"01 "));
            assert_eq!(Some(3), SPACE.find(b"012 "));
            assert_eq!(Some(4), SPACE.find(b"0123 "));
            assert_eq!(Some(5), SPACE.find(b"01234 "));
            assert_eq!(Some(6), SPACE.find(b"012345 "));
            assert_eq!(Some(7), SPACE.find(b"0123456 "));
            assert_eq!(Some(8), SPACE.find(b"01234567 "));
            assert_eq!(Some(9), SPACE.find(b"012345678 "));
            assert_eq!(Some(10), SPACE.find(b"0123456789 "));
            assert_eq!(Some(11), SPACE.find(b"0123456789A "));
            assert_eq!(Some(12), SPACE.find(b"0123456789AB "));
            assert_eq!(Some(13), SPACE.find(b"0123456789ABC "));
            assert_eq!(Some(14), SPACE.find(b"0123456789ABCD "));
            assert_eq!(Some(15), SPACE.find(b"0123456789ABCDE "));
            assert_eq!(Some(16), SPACE.find(b"0123456789ABCDEF "));
            assert_eq!(Some(17), SPACE.find(b"0123456789ABCDEFG "));
        }
    }

    #[test]
    fn space_not_found() {
        unsafe {
            // Since the algorithm operates on 16-byte chunks, it's
            // important to cover tests around that boundary. Since 16
            // isn't that big of a number, we might as well do all of
            // them.

            assert_eq!(None, SPACE.find(b""));
            assert_eq!(None, SPACE.find(b"0"));
            assert_eq!(None, SPACE.find(b"01"));
            assert_eq!(None, SPACE.find(b"012"));
            assert_eq!(None, SPACE.find(b"0123"));
            assert_eq!(None, SPACE.find(b"01234"));
            assert_eq!(None, SPACE.find(b"012345"));
            assert_eq!(None, SPACE.find(b"0123456"));
            assert_eq!(None, SPACE.find(b"01234567"));
            assert_eq!(None, SPACE.find(b"012345678"));
            assert_eq!(None, SPACE.find(b"0123456789"));
            assert_eq!(None, SPACE.find(b"0123456789A"));
            assert_eq!(None, SPACE.find(b"0123456789AB"));
            assert_eq!(None, SPACE.find(b"0123456789ABC"));
            assert_eq!(None, SPACE.find(b"0123456789ABCD"));
            assert_eq!(None, SPACE.find(b"0123456789ABCDE"));
            assert_eq!(None, SPACE.find(b"0123456789ABCDEF"));
            assert_eq!(None, SPACE.find(b"0123456789ABCDEFG"));
        }
    }

    #[test]
    fn works_on_nonaligned_beginnings() {
        unsafe {
            // We have special code for strings that don't lie on 16-byte
            // boundaries. Since allocation seems to happen on these
            // boundaries by default, let's walk around a bit.

            let s = b"0123456789ABCDEF ".to_vec();

            assert_eq!(Some(16), SPACE.find(&s[0..]));
            assert_eq!(Some(15), SPACE.find(&s[1..]));
            assert_eq!(Some(14), SPACE.find(&s[2..]));
            assert_eq!(Some(13), SPACE.find(&s[3..]));
            assert_eq!(Some(12), SPACE.find(&s[4..]));
            assert_eq!(Some(11), SPACE.find(&s[5..]));
            assert_eq!(Some(10), SPACE.find(&s[6..]));
            assert_eq!(Some(9), SPACE.find(&s[7..]));
            assert_eq!(Some(8), SPACE.find(&s[8..]));
            assert_eq!(Some(7), SPACE.find(&s[9..]));
            assert_eq!(Some(6), SPACE.find(&s[10..]));
            assert_eq!(Some(5), SPACE.find(&s[11..]));
            assert_eq!(Some(4), SPACE.find(&s[12..]));
            assert_eq!(Some(3), SPACE.find(&s[13..]));
            assert_eq!(Some(2), SPACE.find(&s[14..]));
            assert_eq!(Some(1), SPACE.find(&s[15..]));
            assert_eq!(Some(0), SPACE.find(&s[16..]));
            assert_eq!(None, SPACE.find(&s[17..]));
        }
    }

    // The only needle byte in the allocation sits at index 0, directly
    // in front of the searched slice (`start: 1`), so it must not be
    // reported as a match.
    #[test]
    fn misalignment_does_not_cause_a_false_positive_before_start() {
        const AAAA: u8 = 0x01;

        let needle = Needle {
            data: [
                AAAA, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            ],
            len: 1,
        };
        let haystack = Haystack {
            data: vec![
                AAAA, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                0x00, 0x00,
            ],
            start: 1,
        };

        let haystack = haystack.with_start();

        // Needs to trigger the misalignment code
        assert_ne!(0, (haystack.as_ptr() as usize) % 16);
        // There are 64 bits in the mask and we check to make sure the
        // result is less than the haystack
        // NOTE(review): `cmpestrm` as written inspects a 16-bit mask
        // (`_mm_extract_epi16(..) as u16`), so the 64 here is a
        // conservative bound rather than the actual mask width.
        assert!(haystack.len() > 64);

        let us = unsafe { Bytes::new(needle.data, needle.len as i32).find(haystack) };
        assert_eq!(None, us);
    }

    #[test]
    fn xml_delim_3_is_found() {
        unsafe {
            assert_eq!(Some(0), XML_DELIM_3.find(b"<"));
            assert_eq!(Some(0), XML_DELIM_3.find(b">"));
            assert_eq!(Some(0), XML_DELIM_3.find(b"&"));
            assert_eq!(None, XML_DELIM_3.find(b""));
        }
    }

    #[test]
    fn xml_delim_5_is_found() {
        unsafe {
            assert_eq!(Some(0), XML_DELIM_5.find(b"<"));
            assert_eq!(Some(0), XML_DELIM_5.find(b">"));
            assert_eq!(Some(0), XML_DELIM_5.find(b"&"));
            assert_eq!(Some(0), XML_DELIM_5.find(b"'"));
            assert_eq!(Some(0), XML_DELIM_5.find(b"\""));
            assert_eq!(None, XML_DELIM_5.find(b""));
        }
    }

    proptest! {
        #[test]
        fn works_as_find_does_for_byte_substrings(
            (needle, haystack) in (any::<Vec<u8>>(), any::<Vec<u8>>())
        ) {
            let us = unsafe {
                let s = ByteSubstring::new(&needle);
                s.find(&haystack)
            };
            let them = haystack.find_seq(&needle);
            assert_eq!(us, them);
        }
    }

    #[test]
    fn byte_substring_is_found() {
        unsafe {
            let substr = ByteSubstring::new(b"zz");
            assert_eq!(Some(0), substr.find(b"zz"));
            assert_eq!(Some(1), substr.find(b"0zz"));
            assert_eq!(Some(2), substr.find(b"01zz"));
            assert_eq!(Some(3), substr.find(b"012zz"));
            assert_eq!(Some(4), substr.find(b"0123zz"));
            assert_eq!(Some(5), substr.find(b"01234zz"));
            assert_eq!(Some(6), substr.find(b"012345zz"));
            assert_eq!(Some(7), substr.find(b"0123456zz"));
            assert_eq!(Some(8), substr.find(b"01234567zz"));
            assert_eq!(Some(9), substr.find(b"012345678zz"));
            assert_eq!(Some(10), substr.find(b"0123456789zz"));
            assert_eq!(Some(11), substr.find(b"0123456789Azz"));
            assert_eq!(Some(12), substr.find(b"0123456789ABzz"));
            assert_eq!(Some(13), substr.find(b"0123456789ABCzz"));
            assert_eq!(Some(14), substr.find(b"0123456789ABCDzz"));
            assert_eq!(Some(15), substr.find(b"0123456789ABCDEzz"));
            assert_eq!(Some(16), substr.find(b"0123456789ABCDEFzz"));
            assert_eq!(Some(17), substr.find(b"0123456789ABCDEFGzz"));
        }
    }

    #[test]
    fn byte_substring_is_not_found() {
        unsafe {
            let substr = ByteSubstring::new(b"zz");
            assert_eq!(None, substr.find(b""));
            assert_eq!(None, substr.find(b"0"));
            assert_eq!(None, substr.find(b"01"));
            assert_eq!(None, substr.find(b"012"));
            assert_eq!(None, substr.find(b"0123"));
            assert_eq!(None, substr.find(b"01234"));
            assert_eq!(None, substr.find(b"012345"));
            assert_eq!(None, substr.find(b"0123456"));
            assert_eq!(None, substr.find(b"01234567"));
            assert_eq!(None, substr.find(b"012345678"));
            assert_eq!(None, substr.find(b"0123456789"));
            assert_eq!(None, substr.find(b"0123456789A"));
            assert_eq!(None, substr.find(b"0123456789AB"));
            assert_eq!(None, substr.find(b"0123456789ABC"));
            assert_eq!(None, substr.find(b"0123456789ABCD"));
            assert_eq!(None, substr.find(b"0123456789ABCDE"));
            assert_eq!(None, substr.find(b"0123456789ABCDEF"));
            assert_eq!(None, substr.find(b"0123456789ABCDEFG"));
        }
    }

    #[test]
    fn byte_substring_has_false_positive() {
        unsafe {
            // The PCMPESTRI instruction will mark the "a" before "ab" as
            // a match because it cannot look beyond the 16 byte window
            // of the haystack. We need to double-check any match to
            // ensure it completely matches.

            let substr = ByteSubstring::new(b"ab");
            assert_eq!(Some(16), substr.find(b"aaaaaaaaaaaaaaaaab"))
            //   this "a" is a false positive ~~~~~~~~~~~~~~~^
        };
    }

    // A needle of more than 16 bytes only partly fits into the SIMD
    // operand; the match must still be found.
    #[test]
    fn byte_substring_needle_is_longer_than_16_bytes() {
        unsafe {
            let needle = b"0123456789abcdefg";
            let haystack = b"0123456789abcdefgh";
            assert_eq!(Some(0), ByteSubstring::new(needle).find(haystack));
        }
    }

    /// Copies `value` to the very end of an accessible page that is
    /// followed by an inaccessible one, then passes that copy to `f`.
    fn with_guarded_string(value: &str, f: impl FnOnce(&str)) {
        // Allocate a string that ends directly before a
        // read-protected page.

        let page_size = region::page::size();
        assert!(value.len() <= page_size);

        // Map two rw-accessible pages of anonymous memory
        let mut mmap = MmapMut::map_anon(2 * page_size).unwrap();

        let (first_page, second_page) = mmap.split_at_mut(page_size);

        // Prohibit any access to the second page, so that any attempt
        // to read or write it would trigger a segfault
        unsafe {
            region::protect(second_page.as_ptr(), page_size, Protection::NONE).unwrap();
        }

        // Copy bytes to the end of the first page
        let dest = &mut first_page[page_size - value.len()..];
        dest.copy_from_slice(value.as_bytes());
        f(unsafe { str::from_utf8_unchecked(dest) });
    }

    #[test]
    fn works_at_page_boundary() {
        // PCMPxSTRx instructions are known to read 16 bytes at a
        // time. This behaviour may cause accidental segfaults by
        // reading past the page boundary.
        //
        // For now, this test failing crashes the whole test
        // suite. This could be fixed by setting a custom signal
        // handler, though Rust lacks such facilities at the moment.

        // Allocate a 16-byte string at page boundary.  To verify this
        // test, set protect=false to prevent segfaults.
        // NOTE(review): `with_guarded_string` as written has no
        // `protect` switch; the `region::protect` call would have to be
        // disabled by hand instead.
        with_guarded_string("0123456789abcdef", |text| {
            // Will search for the last char
            let needle = simd_bytes!(b'f');

            // Check all suffixes of our 16-byte string
            for offset in 0..text.len() {
                let tail = &text[offset..];
                unsafe {
                    assert_eq!(Some(tail.len() - 1), needle.find(tail.as_bytes()));
                }
            }
        });
    }

    #[test]
    fn does_not_access_memory_after_haystack_when_haystack_is_multiple_of_16_bytes_and_no_match() {
        // For now, this test failing crashes the whole test
        // suite. This could be fixed by setting a custom signal
        // handler, though Rust lacks such facilities at the moment.
        with_guarded_string("0123456789abcdef", |text| {
            // Will search for a char not present
            let needle = simd_bytes!(b'z');

            unsafe {
                assert_eq!(None, needle.find(text.as_bytes()));
            }
        });
    }
}