1use std::ptr;
7
8pub fn memchr2(needle1: u8, needle2: u8, haystack: &[u8], offset: usize) -> usize {
14 unsafe {
15 let beg = haystack.as_ptr();
16 let end = beg.add(haystack.len());
17 let it = beg.add(offset.min(haystack.len()));
18 let it = memchr2_raw(needle1, needle2, it, end);
19 it.offset_from_unsigned(beg)
20 }
21}
22
/// Architecture dispatcher for the raw pointer-range search.
///
/// # Safety
///
/// `beg` and `end` must form a valid, readable byte range within the same
/// allocation, with `beg <= end`.
unsafe fn memchr2_raw(needle1: u8, needle2: u8, beg: *const u8, end: *const u8) -> *const u8 {
    // x86/x86_64: go through the function pointer so the first call can
    // detect AVX2 support and patch in the best implementation.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    return unsafe { MEMCHR2_DISPATCH(needle1, needle2, beg, end) };

    // aarch64: NEON is part of the baseline ISA, so no runtime detection
    // is needed.
    #[cfg(target_arch = "aarch64")]
    return unsafe { memchr2_neon(needle1, needle2, beg, end) };

    // Any other architecture: plain scalar loop. This line is dead code on
    // the cfg-matched targets above, hence the allow.
    #[allow(unreachable_code)]
    return unsafe { memchr2_fallback(needle1, needle2, beg, end) };
}
33
/// Scalar byte-at-a-time search: the portable baseline, also used by the
/// SIMD paths to finish the sub-vector-width tail.
///
/// Returns a pointer to the first byte in `[beg, end)` equal to `needle1`
/// or `needle2`, or `end` when there is no match.
///
/// # Safety
///
/// `beg` and `end` must form a valid, readable byte range within the same
/// allocation, with `beg <= end`.
unsafe fn memchr2_fallback(
    needle1: u8,
    needle2: u8,
    mut beg: *const u8,
    end: *const u8,
) -> *const u8 {
    unsafe {
        loop {
            // Exhausted the range without a match: report `end`.
            if ptr::eq(beg, end) {
                return end;
            }
            let byte = *beg;
            if byte == needle1 || byte == needle2 {
                return beg;
            }
            beg = beg.add(1);
        }
    }
}
51
// Self-patching dispatch pointer: starts at `memchr2_dispatch`, which runs
// CPU-feature detection on first use and overwrites this pointer with the
// selected implementation so later calls skip detection entirely.
// NOTE(review): reads/writes of this `static mut` are unsynchronized, so
// concurrent first calls race on the pointer — presumably acceptable since
// every racer writes one of the same valid function pointers, but confirm
// the intended thread-safety story.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
static mut MEMCHR2_DISPATCH: unsafe fn(
    needle1: u8,
    needle2: u8,
    beg: *const u8,
    end: *const u8,
) -> *const u8 = memchr2_dispatch;
63
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
/// One-shot implementation selector: probes for AVX2 at runtime, caches the
/// chosen function in `MEMCHR2_DISPATCH`, then forwards the current call.
///
/// # Safety
///
/// Same contract as `memchr2_raw`: `beg..end` must be a valid, readable
/// range with `beg <= end`.
unsafe fn memchr2_dispatch(needle1: u8, needle2: u8, beg: *const u8, end: *const u8) -> *const u8 {
    let func = if is_x86_feature_detected!("avx2") { memchr2_avx2 } else { memchr2_fallback };
    // Patch the dispatch pointer so subsequent calls jump straight to the
    // detected implementation (unsynchronized, idempotent write — see the
    // note on the static).
    unsafe { MEMCHR2_DISPATCH = func };
    unsafe { func(needle1, needle2, beg, end) }
}
70
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
/// AVX2 search: compares 32 bytes per iteration against both needles and
/// hands the final `< 32`-byte tail to the scalar fallback.
///
/// # Safety
///
/// `beg..end` must be a valid, readable range with `beg <= end`, and the
/// CPU must support AVX2 (guaranteed by the dispatch path).
#[target_feature(enable = "avx2")]
unsafe fn memchr2_avx2(needle1: u8, needle2: u8, mut beg: *const u8, end: *const u8) -> *const u8 {
    unsafe {
        #[cfg(target_arch = "x86")]
        use std::arch::x86::*;
        #[cfg(target_arch = "x86_64")]
        use std::arch::x86_64::*;

        // Splat each needle across all 32 byte lanes.
        let n1 = _mm256_set1_epi8(needle1 as i8);
        let n2 = _mm256_set1_epi8(needle2 as i8);
        let mut remaining = end.offset_from_unsigned(beg);

        while remaining >= 32 {
            // Unaligned 32-byte load; compare against both needles and OR
            // the two per-byte match masks together.
            let v = _mm256_loadu_si256(beg as *const _);
            let a = _mm256_cmpeq_epi8(v, n1);
            let b = _mm256_cmpeq_epi8(v, n2);
            let c = _mm256_or_si256(a, b);
            // Collapse to a 32-bit mask: one bit per byte lane.
            let m = _mm256_movemask_epi8(c) as u32;

            if m != 0 {
                // Lowest set bit = first matching byte within this chunk.
                return beg.add(m.trailing_zeros() as usize);
            }

            beg = beg.add(32);
            remaining -= 32;
        }

        // Scan the remaining tail bytes with the scalar loop.
        memchr2_fallback(needle1, needle2, beg, end)
    }
}
104
#[cfg(target_arch = "aarch64")]
/// NEON search: compares 16 bytes per iteration against both needles and
/// hands the final `< 16`-byte tail to the scalar fallback.
///
/// # Safety
///
/// `beg..end` must be a valid, readable range with `beg <= end`.
unsafe fn memchr2_neon(needle1: u8, needle2: u8, mut beg: *const u8, end: *const u8) -> *const u8 {
    unsafe {
        use std::arch::aarch64::*;

        if end.offset_from_unsigned(beg) >= 16 {
            // Splat each needle across all 16 byte lanes.
            let n1 = vdupq_n_u8(needle1);
            let n2 = vdupq_n_u8(needle2);

            loop {
                let v = vld1q_u8(beg as *const _);
                let a = vceqq_u8(v, n1);
                let b = vceqq_u8(v, n2);
                // Per-byte match mask: 0xFF where either needle matched.
                let c = vorrq_u8(a, b);

                // NEON has no movemask; instead narrow each 16-bit lane by
                // a 4-bit shift, compressing every mask byte into a nibble
                // of a 64-bit scalar (4 mask bits per input byte).
                let m = vreinterpretq_u16_u8(c);
                let m = vshrn_n_u16(m, 4);
                let m = vreinterpret_u64_u8(m);
                let m = vget_lane_u64(m, 0);

                if m != 0 {
                    // 4 mask bits per byte, so divide the bit index by 4
                    // to recover the byte offset of the first match.
                    return beg.add(m.trailing_zeros() as usize >> 2);
                }

                beg = beg.add(16);
                if end.offset_from_unsigned(beg) < 16 {
                    break;
                }
            }
        }

        // Scan the remaining tail bytes with the scalar loop.
        memchr2_fallback(needle1, needle2, beg, end)
    }
}
140
#[cfg(test)]
mod tests {
    use std::slice;

    use super::*;
    use crate::sys;

    // An empty haystack can never match: result is `len()` (= 0).
    #[test]
    fn test_empty() {
        assert_eq!(memchr2(b'a', b'b', b"", 0), 0);
    }

    #[test]
    fn test_basic() {
        let haystack = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
        // 43 bytes: exercises the SIMD main loop plus a non-empty tail.
        let haystack = &haystack[..43];

        assert_eq!(memchr2(b'a', b'z', haystack, 0), 0);
        assert_eq!(memchr2(b'p', b'q', haystack, 0), 15);
        assert_eq!(memchr2(b'Q', b'Z', haystack, 0), 42);
        // No match: the haystack length is returned.
        assert_eq!(memchr2(b'0', b'9', haystack, 0), haystack.len());
    }

    #[test]
    fn test_with_offset() {
        // 40 bytes of a repeating 8-byte pattern.
        let haystack = b"abcdefghabcdefghabcdefghabcdefghabcdefgh";

        assert_eq!(memchr2(b'a', b'b', haystack, 0), 0);
        assert_eq!(memchr2(b'a', b'b', haystack, 1), 1);
        // Offset past both needles in the first repeat: jumps to the next.
        assert_eq!(memchr2(b'a', b'b', haystack, 2), 8);
        assert_eq!(memchr2(b'a', b'b', haystack, 9), 9);
        assert_eq!(memchr2(b'a', b'b', haystack, 16), 16);
        // Offset beyond the end is clamped to `len()` (= 40).
        assert_eq!(memchr2(b'a', b'b', haystack, 41), 40);
    }

    // Commit a single page surrounded by reserved-but-uncommitted guard
    // pages, then search slices that end exactly at the page boundary:
    // a vector load overrunning the slice would fault on the guard page.
    #[test]
    fn test_page_boundary() {
        let page = unsafe {
            const PAGE_SIZE: usize = 64 * 1024;
            let ptr = sys::virtual_reserve(PAGE_SIZE * 3).unwrap();
            // Only the middle page is committed (readable/writable).
            sys::virtual_commit(ptr.add(PAGE_SIZE), PAGE_SIZE).unwrap();
            slice::from_raw_parts_mut(ptr.add(PAGE_SIZE).as_ptr(), PAGE_SIZE)
        };

        page.fill(b'a');

        // Tail just before the trailing guard page must not over-read.
        assert_eq!(memchr2(b'\0', b'\0', &page[page.len() - 40..], 0), 40);
        // Slice at the start of the committed page, no match: returns len.
        assert_eq!(memchr2(b'\0', b'\0', &page[..10], 0), 10);
    }
}
197}