edit/simd/
memchr2.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! `memchr`, but with two needles.
5
6use std::ptr;
7
8/// `memchr`, but with two needles.
9///
10/// Returns the index of the first occurrence of either needle in the
11/// `haystack`. If no needle is found, `haystack.len()` is returned.
12/// `offset` specifies the index to start searching from.
13pub fn memchr2(needle1: u8, needle2: u8, haystack: &[u8], offset: usize) -> usize {
14    unsafe {
15        let beg = haystack.as_ptr();
16        let end = beg.add(haystack.len());
17        let it = beg.add(offset.min(haystack.len()));
18        let it = memchr2_raw(needle1, needle2, it, end);
19        it.offset_from_unsigned(beg)
20    }
21}
22
/// Architecture dispatch for the raw pointer-range search.
///
/// # Safety
/// `beg` and `end` must delimit a valid, readable byte range within the same
/// allocation, with `beg <= end`.
unsafe fn memchr2_raw(needle1: u8, needle2: u8, beg: *const u8, end: *const u8) -> *const u8 {
    // x86/x86_64: go through the self-updating function pointer, which picks
    // AVX2 or the scalar fallback on the first call (see MEMCHR2_DISPATCH).
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    return unsafe { MEMCHR2_DISPATCH(needle1, needle2, beg, end) };

    // aarch64: NEON is baseline on AArch64, so no runtime detection is needed.
    #[cfg(target_arch = "aarch64")]
    return unsafe { memchr2_neon(needle1, needle2, beg, end) };

    // Any other architecture: plain scalar loop. The `allow` is needed because
    // on the architectures above one of the earlier returns is always compiled in.
    #[allow(unreachable_code)]
    return unsafe { memchr2_fallback(needle1, needle2, beg, end) };
}
33
/// Scalar fallback: walks the range one byte at a time.
///
/// Returns a pointer to the first byte equal to either needle, or `end` if
/// there is no match.
///
/// # Safety
/// `beg` and `end` must delimit a valid, readable byte range within the same
/// allocation, with `beg <= end`.
unsafe fn memchr2_fallback(
    needle1: u8,
    needle2: u8,
    mut beg: *const u8,
    end: *const u8,
) -> *const u8 {
    unsafe {
        loop {
            // Reached the end without a match: return `end` itself.
            if ptr::eq(beg, end) {
                return beg;
            }
            let byte = *beg;
            if byte == needle1 || byte == needle2 {
                return beg;
            }
            beg = beg.add(1);
        }
    }
}
51
// In order to make `memchr2_raw` slim and fast, we use a function pointer that updates
// itself to the correct implementation on the first call. This reduces binary size.
// It would also reduce branches if we had >2 implementations (a jump still needs to be predicted).
// NOTE that this ONLY works if Control Flow Guard is disabled on Windows.
// NOTE(review): reads/writes of this `static mut` are unsynchronized; concurrent
// first calls race on it under the strict memory model, although every writer
// stores one of the same two valid function pointers. Confirm this tradeoff is
// intentional (an `AtomicPtr` with relaxed ordering would be the strict fix).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
static mut MEMCHR2_DISPATCH: unsafe fn(
    needle1: u8,
    needle2: u8,
    beg: *const u8,
    end: *const u8,
) -> *const u8 = memchr2_dispatch;
63
/// First-call implementation behind `MEMCHR2_DISPATCH`: detects CPU features,
/// self-patches the dispatch pointer to the chosen implementation, and then
/// forwards the current call to it.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
unsafe fn memchr2_dispatch(needle1: u8, needle2: u8, beg: *const u8, end: *const u8) -> *const u8 {
    // Runtime feature detection happens only once; later calls go straight to `func`.
    let func = if is_x86_feature_detected!("avx2") { memchr2_avx2 } else { memchr2_fallback };
    unsafe { MEMCHR2_DISPATCH = func };
    unsafe { func(needle1, needle2, beg, end) }
}
70
// FWIW, I found that adding support for AVX512 was not useful at the time,
// as it only marginally improved file load performance by <5%.
/// AVX2 search: processes 32 bytes per iteration, then hands any tail of
/// fewer than 32 bytes to the scalar fallback.
///
/// # Safety
/// `beg`/`end` must delimit a valid readable range (`beg <= end`), and the
/// CPU must support AVX2 (guaranteed by the dispatcher).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn memchr2_avx2(needle1: u8, needle2: u8, mut beg: *const u8, end: *const u8) -> *const u8 {
    unsafe {
        #[cfg(target_arch = "x86")]
        use std::arch::x86::*;
        #[cfg(target_arch = "x86_64")]
        use std::arch::x86_64::*;

        // Broadcast each needle across all 32 byte lanes of a 256-bit register.
        let n1 = _mm256_set1_epi8(needle1 as i8);
        let n2 = _mm256_set1_epi8(needle2 as i8);
        let mut remaining = end.offset_from_unsigned(beg);

        while remaining >= 32 {
            // Unaligned 32-byte load; compare against both needles and OR the
            // per-byte match masks together (matching lanes become 0xFF).
            let v = _mm256_loadu_si256(beg as *const _);
            let a = _mm256_cmpeq_epi8(v, n1);
            let b = _mm256_cmpeq_epi8(v, n2);
            let c = _mm256_or_si256(a, b);
            // movemask packs the high bit of each byte lane into a 32-bit mask:
            // bit i is set iff byte i matched one of the needles.
            let m = _mm256_movemask_epi8(c) as u32;

            if m != 0 {
                // Lowest set bit = first matching byte in this chunk.
                return beg.add(m.trailing_zeros() as usize);
            }

            beg = beg.add(32);
            remaining -= 32;
        }

        // Tail of <32 bytes: finish with the scalar loop.
        memchr2_fallback(needle1, needle2, beg, end)
    }
}
104
/// NEON search: processes 16 bytes per iteration, then hands any tail of
/// fewer than 16 bytes to the scalar fallback.
///
/// # Safety
/// `beg`/`end` must delimit a valid readable range (`beg <= end`). NEON is
/// baseline on AArch64, so no feature check is required.
#[cfg(target_arch = "aarch64")]
unsafe fn memchr2_neon(needle1: u8, needle2: u8, mut beg: *const u8, end: *const u8) -> *const u8 {
    unsafe {
        use std::arch::aarch64::*;

        if end.offset_from_unsigned(beg) >= 16 {
            // Broadcast each needle across all 16 byte lanes.
            let n1 = vdupq_n_u8(needle1);
            let n2 = vdupq_n_u8(needle2);

            loop {
                // Load 16 bytes; compare against both needles and OR the
                // per-byte match masks together (matching lanes become 0xFF).
                let v = vld1q_u8(beg as *const _);
                let a = vceqq_u8(v, n1);
                let b = vceqq_u8(v, n2);
                let c = vorrq_u8(a, b);

                // NEON has no movemask; instead, narrow each 16-bit pair with a
                // 4-bit shift, producing a 64-bit mask with 4 bits per input byte.
                // https://community.arm.com/arm-community-blogs/b/servers-and-cloud-computing-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
                let m = vreinterpretq_u16_u8(c);
                let m = vshrn_n_u16(m, 4);
                let m = vreinterpret_u64_u8(m);
                let m = vget_lane_u64(m, 0);

                if m != 0 {
                    // 4 mask bits per byte -> divide the bit index by 4.
                    return beg.add(m.trailing_zeros() as usize >> 2);
                }

                beg = beg.add(16);
                if end.offset_from_unsigned(beg) < 16 {
                    break;
                }
            }
        }

        // Tail of <16 bytes: finish with the scalar loop.
        memchr2_fallback(needle1, needle2, beg, end)
    }
}
140
#[cfg(test)]
mod tests {
    use std::slice;

    use super::*;
    use crate::sys;

    // An empty haystack can never contain a needle; expect `len()` (= 0).
    #[test]
    fn test_empty() {
        assert_eq!(memchr2(b'a', b'b', b"", 0), 0);
    }

    // Hits at the start, middle, and end, plus a miss returning `len()`.
    #[test]
    fn test_basic() {
        let haystack = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
        // Truncate to 43 bytes so the length isn't a multiple of the SIMD
        // chunk size and the scalar tail path gets exercised too.
        let haystack = &haystack[..43];

        assert_eq!(memchr2(b'a', b'z', haystack, 0), 0);
        assert_eq!(memchr2(b'p', b'q', haystack, 0), 15);
        assert_eq!(memchr2(b'Q', b'Z', haystack, 0), 42);
        assert_eq!(memchr2(b'0', b'9', haystack, 0), haystack.len());
    }

    // Test that it doesn't match before/after the start offset respectively.
    #[test]
    fn test_with_offset() {
        let haystack = b"abcdefghabcdefghabcdefghabcdefghabcdefgh";

        assert_eq!(memchr2(b'a', b'b', haystack, 0), 0);
        assert_eq!(memchr2(b'a', b'b', haystack, 1), 1);
        assert_eq!(memchr2(b'a', b'b', haystack, 2), 8);
        assert_eq!(memchr2(b'a', b'b', haystack, 9), 9);
        assert_eq!(memchr2(b'a', b'b', haystack, 16), 16);
        // An offset past the end is clamped to `len()` (40 here).
        assert_eq!(memchr2(b'a', b'b', haystack, 41), 40);
    }

    // Test memory access safety at page boundaries.
    // The test is a success if it doesn't segfault.
    #[test]
    fn test_page_boundary() {
        let page = unsafe {
            const PAGE_SIZE: usize = 64 * 1024; // 64 KiB to cover many architectures.

            // 3 pages: uncommitted, committed, uncommitted
            let ptr = sys::virtual_reserve(PAGE_SIZE * 3).unwrap();
            sys::virtual_commit(ptr.add(PAGE_SIZE), PAGE_SIZE).unwrap();
            slice::from_raw_parts_mut(ptr.add(PAGE_SIZE).as_ptr(), PAGE_SIZE)
        };

        page.fill(b'a');

        // Test if it seeks beyond the page boundary.
        assert_eq!(memchr2(b'\0', b'\0', &page[page.len() - 40..], 0), 40);
        // Test if it seeks before the page boundary for the masked/partial load.
        assert_eq!(memchr2(b'\0', b'\0', &page[..10], 0), 10);
    }
}
197}