1use std::iter::FusedIterator;
2
3#[cfg(target_arch = "x86_64")]
4mod x86_64 {
5 use std::marker::PhantomData;
6
7 use crate::ext::Pointer;
8
9 #[inline(always)]
10 fn get_for_offset(mask: u32) -> u32 {
11 #[cfg(target_endian = "big")]
12 {
13 mask.swap_bytes()
14 }
15 #[cfg(target_endian = "little")]
16 {
17 mask
18 }
19 }
20
21 #[inline(always)]
22 fn first_offset(mask: u32) -> usize {
23 get_for_offset(mask).trailing_zeros() as usize
24 }
25
26 #[inline(always)]
27 fn clear_least_significant_bit(mask: u32) -> u32 {
28 mask & (mask - 1)
29 }
30
31 pub mod sse2 {
32 use super::*;
33
34 use core::arch::x86_64::{
35 __m128i, _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8, _mm_or_si128,
36 _mm_set1_epi8,
37 };
38
        /// SSE2-backed searcher for three needle bytes.
        ///
        /// Holds both the scalar needles (for the scalar tail loop) and each
        /// needle splatted across all 16 lanes of an SSE2 register (for the
        /// vectorized comparisons).
        #[derive(Debug)]
        pub struct SSE2Searcher {
            // Scalar copies of the three needle bytes.
            n1: u8,
            n2: u8,
            n3: u8,
            // Each needle broadcast to all 16 lanes of a vector register.
            v1: __m128i,
            v2: __m128i,
            v3: __m128i,
        }
48
        impl SSE2Searcher {
            /// Builds a searcher for the needle bytes `n1`, `n2`, `n3`,
            /// splatting each one across an SSE2 register.
            ///
            /// # Safety
            /// Uses SSE2 intrinsics; the caller must ensure SSE2 is available
            /// (it is part of the x86_64 baseline, so this always holds there).
            #[inline]
            pub unsafe fn new(n1: u8, n2: u8, n3: u8) -> Self {
                Self {
                    n1,
                    n2,
                    n3,
                    v1: _mm_set1_epi8(n1 as i8),
                    v2: _mm_set1_epi8(n2 as i8),
                    v3: _mm_set1_epi8(n3 as i8),
                }
            }

            /// Returns an iterator over the offsets of every occurrence of any
            /// of the three needles in `haystack`.
            #[inline(always)]
            pub fn iter<'s, 'h>(&'s self, haystack: &'h [u8]) -> SSE2Indices<'s, 'h> {
                SSE2Indices::new(self, haystack)
            }
        }
67
        /// Iterator state for an SSE2 search over one haystack.
        ///
        /// The haystack is tracked as raw `start`/`end`/`current` pointers;
        /// `PhantomData` ties the iterator to the haystack's lifetime `'h` so
        /// the borrow checker keeps the underlying slice alive.
        #[derive(Debug)]
        pub struct SSE2Indices<'s, 'h> {
            searcher: &'s SSE2Searcher,
            haystack: PhantomData<&'h [u8]>,
            // First byte of the haystack; offsets are reported relative to it.
            start: *const u8,
            // One past the last byte of the haystack.
            end: *const u8,
            // Scan cursor; after a vectorized step it points one full chunk
            // past the bytes that produced `mask`.
            current: *const u8,
            // Pending movemask bits from the last vectorized comparison;
            // zero when no matches are queued.
            mask: u32,
        }
77
78 impl<'s, 'h> SSE2Indices<'s, 'h> {
79 #[inline]
80 fn new(searcher: &'s SSE2Searcher, haystack: &'h [u8]) -> Self {
81 let ptr = haystack.as_ptr();
82
83 Self {
84 searcher,
85 haystack: PhantomData,
86 start: ptr,
87 end: ptr.wrapping_add(haystack.len()),
88 current: ptr,
89 mask: 0,
90 }
91 }
92 }
93
        /// Width in bytes of one SSE2 register (`__m128i`): how far the cursor
        /// advances per vectorized step.
        const SSE2_STEP: usize = 16;

        impl SSE2Indices<'_, '_> {
            /// Drains the next match from the mask saved by the previous
            /// vectorized step, without scanning any new haystack bytes.
            ///
            /// Returns `None` as soon as the saved mask is empty; callers then
            /// fall back to `next` to scan further.
            #[inline]
            pub unsafe fn _next_in_current_mask(&mut self) -> Option<usize> {
                let mask = self.mask;
                let current = self.current;

                if mask != 0 {
                    // `current` was advanced one full step past the chunk the
                    // mask belongs to, so rewind by SSE2_STEP before adding
                    // the bit-derived offset.
                    let offset = current.sub(SSE2_STEP).add(first_offset(mask));
                    self.mask = clear_least_significant_bit(mask);
                    self.current = current;

                    Some(offset.distance(self.start))
                } else {
                    None
                }
            }

            /// Offset (relative to the haystack start) of the next byte equal
            /// to any of the three needles, or `None` when exhausted.
            ///
            /// Strategy: first drain any pending matches from the saved mask,
            /// then compare 16 bytes at a time with SSE2, then finish the
            /// sub-16-byte tail with a scalar loop.
            pub unsafe fn next(&mut self) -> Option<usize> {
                // Empty haystack: nothing to scan.
                if self.start >= self.end {
                    return None;
                }

                let mut mask = self.mask;

                let mut current = self.current;
                let start = self.start;
                let len = self.end.distance(start);
                let v1 = self.searcher.v1;
                let v2 = self.searcher.v2;
                let v3 = self.searcher.v3;

                'main: loop {
                    if mask != 0 {
                        // Same rewind-by-one-step bookkeeping as
                        // `_next_in_current_mask`.
                        let offset = current.sub(SSE2_STEP).add(first_offset(mask));
                        self.mask = clear_least_significant_bit(mask);
                        self.current = current;

                        return Some(offset.distance(start));
                    }

                    // Only attempt vector loads when the whole haystack is at
                    // least one register wide (also keeps `end - 16` in range).
                    if len >= SSE2_STEP {
                        // Last position from which a full 16-byte load stays
                        // in bounds.
                        let vectorized_end = self.end.sub(SSE2_STEP);

                        while current <= vectorized_end {
                            // Unaligned load; compare against all three
                            // needles and OR the lane results together.
                            let chunk = _mm_loadu_si128(current as *const __m128i);
                            let cmp1 = _mm_cmpeq_epi8(chunk, v1);
                            let cmp2 = _mm_cmpeq_epi8(chunk, v2);
                            let cmp3 = _mm_cmpeq_epi8(chunk, v3);
                            let cmp = _mm_or_si128(cmp1, cmp2);
                            let cmp = _mm_or_si128(cmp, cmp3);

                            mask = _mm_movemask_epi8(cmp) as u32;

                            // Advance before reporting so `current` always
                            // sits one step past the masked chunk.
                            current = current.add(SSE2_STEP);

                            if mask != 0 {
                                continue 'main;
                            }
                        }
                    }

                    // Scalar fallback for the final (len % 16) bytes.
                    while current < self.end {
                        if *current == self.searcher.n1
                            || *current == self.searcher.n2
                            || *current == self.searcher.n3
                        {
                            let offset = current.distance(start);
                            self.current = current.add(1);
                            return Some(offset);
                        }
                        current = current.add(1);
                    }

                    return None;
                }
            }
        }
176 }
177}
178
179#[cfg(target_arch = "aarch64")]
180mod aarch64 {
181 use core::arch::aarch64::{
182 uint8x16_t, vceqq_u8, vdupq_n_u8, vget_lane_u64, vld1q_u8, vorrq_u8, vreinterpret_u64_u8,
183 vreinterpretq_u16_u8, vshrn_n_u16,
184 };
185 use std::marker::PhantomData;
186
187 use crate::ext::Pointer;
188
    /// Compresses a 16-lane byte comparison result into a 64-bit scalar mask
    /// with 4 bits per input lane (the NEON shift-right-narrow trick).
    ///
    /// Reinterpreting the vector as eight u16 lanes and shifting each right by
    /// 4 while narrowing to u8 packs a nibble from every input byte into a
    /// 64-bit scalar; masking with 0x8888_8888_8888_8888 keeps a single bit
    /// per original lane. `first_offset` later divides trailing zeros by 4 to
    /// undo the 4-bits-per-lane encoding.
    #[inline(always)]
    unsafe fn neon_movemask(v: uint8x16_t) -> u64 {
        let asu16s = vreinterpretq_u16_u8(v);
        let mask = vshrn_n_u16(asu16s, 4);
        let asu64 = vreinterpret_u64_u8(mask);
        let scalar64 = vget_lane_u64(asu64, 0);

        scalar64 & 0x8888888888888888
    }
198
199 #[inline(always)]
200 fn first_offset(mask: u64) -> usize {
201 (mask.trailing_zeros() >> 2) as usize
202 }
203
204 #[inline(always)]
205 fn clear_least_significant_bit(mask: u64) -> u64 {
206 mask & (mask - 1)
207 }
208
    /// NEON-backed searcher for three needle bytes.
    ///
    /// Holds both the scalar needles (for the scalar tail loop) and each
    /// needle splatted across all 16 lanes of a NEON register (for the
    /// vectorized comparisons).
    #[derive(Debug)]
    pub struct NeonSearcher {
        // Scalar copies of the three needle bytes.
        n1: u8,
        n2: u8,
        n3: u8,
        // Each needle broadcast to all 16 lanes of a vector register.
        v1: uint8x16_t,
        v2: uint8x16_t,
        v3: uint8x16_t,
    }
218
    impl NeonSearcher {
        /// Builds a searcher for the needle bytes `n1`, `n2`, `n3`, splatting
        /// each one across a NEON register.
        ///
        /// # Safety
        /// Uses NEON intrinsics; the caller must ensure NEON is available
        /// (it is part of the aarch64 baseline, so this always holds there).
        #[inline]
        pub unsafe fn new(n1: u8, n2: u8, n3: u8) -> Self {
            Self {
                n1,
                n2,
                n3,
                v1: vdupq_n_u8(n1),
                v2: vdupq_n_u8(n2),
                v3: vdupq_n_u8(n3),
            }
        }

        /// Returns an iterator over the offsets of every occurrence of any of
        /// the three needles in `haystack`.
        #[inline(always)]
        pub fn iter<'s, 'h>(&'s self, haystack: &'h [u8]) -> NeonIndices<'s, 'h> {
            NeonIndices::new(self, haystack)
        }
    }
237
    /// Iterator state for a NEON search over one haystack.
    ///
    /// The haystack is tracked as raw `start`/`end`/`current` pointers;
    /// `PhantomData` ties the iterator to the haystack's lifetime `'h` so the
    /// borrow checker keeps the underlying slice alive.
    #[derive(Debug)]
    pub struct NeonIndices<'s, 'h> {
        searcher: &'s NeonSearcher,
        haystack: PhantomData<&'h [u8]>,
        // First byte of the haystack; offsets are reported relative to it.
        start: *const u8,
        // One past the last byte of the haystack.
        end: *const u8,
        // Scan cursor; after a vectorized step it points one full chunk past
        // the bytes that produced `mask`.
        current: *const u8,
        // Pending movemask bits (4 bits per lane) from the last vectorized
        // comparison; zero when no matches are queued.
        mask: u64,
    }
247
248 impl<'s, 'h> NeonIndices<'s, 'h> {
249 #[inline]
250 fn new(searcher: &'s NeonSearcher, haystack: &'h [u8]) -> Self {
251 let ptr = haystack.as_ptr();
252
253 Self {
254 searcher,
255 haystack: PhantomData,
256 start: ptr,
257 end: ptr.wrapping_add(haystack.len()),
258 current: ptr,
259 mask: 0,
260 }
261 }
262 }
263
    /// Width in bytes of one NEON register (`uint8x16_t`): how far the cursor
    /// advances per vectorized step.
    const NEON_STEP: usize = 16;

    impl NeonIndices<'_, '_> {
        /// Drains the next match from the mask saved by the previous
        /// vectorized step, without scanning any new haystack bytes.
        ///
        /// Returns `None` as soon as the saved mask is empty; callers then
        /// fall back to `next` to scan further.
        #[inline]
        pub unsafe fn _next_in_current_mask(&mut self) -> Option<usize> {
            let mask = self.mask;
            let current = self.current;

            if mask != 0 {
                // `current` was advanced one full step past the chunk the
                // mask belongs to, so rewind by NEON_STEP before adding the
                // bit-derived offset.
                let offset = current.sub(NEON_STEP).add(first_offset(mask));
                self.mask = clear_least_significant_bit(mask);
                self.current = current;

                Some(offset.distance(self.start))
            } else {
                None
            }
        }

        /// Offset (relative to the haystack start) of the next byte equal to
        /// any of the three needles, or `None` when exhausted.
        ///
        /// Strategy: first drain any pending matches from the saved mask,
        /// then compare 16 bytes at a time with NEON, then finish the
        /// sub-16-byte tail with a scalar loop.
        pub unsafe fn next(&mut self) -> Option<usize> {
            // Empty haystack: nothing to scan.
            if self.start >= self.end {
                return None;
            }

            let mut mask = self.mask;
            let mut current = self.current;
            let start = self.start;
            let len = self.end.distance(start);
            let v1 = self.searcher.v1;
            let v2 = self.searcher.v2;
            let v3 = self.searcher.v3;

            'main: loop {
                if mask != 0 {
                    // Same rewind-by-one-step bookkeeping as
                    // `_next_in_current_mask`.
                    let offset = current.sub(NEON_STEP).add(first_offset(mask));
                    self.mask = clear_least_significant_bit(mask);
                    self.current = current;

                    return Some(offset.distance(start));
                }

                // Only attempt vector loads when the whole haystack is at
                // least one register wide (also keeps `end - 16` in range).
                if len >= NEON_STEP {
                    // Last position from which a full 16-byte load stays in
                    // bounds.
                    let vectorized_end = self.end.sub(NEON_STEP);

                    while current <= vectorized_end {
                        // Load 16 bytes; compare against all three needles
                        // and OR the lane results together.
                        let chunk = vld1q_u8(current);
                        let cmp1 = vceqq_u8(chunk, v1);
                        let cmp2 = vceqq_u8(chunk, v2);
                        let cmp3 = vceqq_u8(chunk, v3);
                        let cmp = vorrq_u8(cmp1, cmp2);
                        let cmp = vorrq_u8(cmp, cmp3);

                        mask = neon_movemask(cmp);

                        // Advance before reporting so `current` always sits
                        // one step past the masked chunk.
                        current = current.add(NEON_STEP);

                        if mask != 0 {
                            continue 'main;
                        }
                    }
                }

                // Scalar fallback for the final (len % 16) bytes.
                while current < self.end {
                    if *current == self.searcher.n1
                        || *current == self.searcher.n2
                        || *current == self.searcher.n3
                    {
                        let offset = current.distance(start);
                        self.current = current.add(1);
                        return Some(offset);
                    }
                    current = current.add(1);
                }

                return None;
            }
        }
    }
345}
346
/// Name of the SIMD instruction set the searcher uses on this target:
/// `"sse2"`, `"neon"`, or `"none"` for the scalar fallback.
pub fn searcher_simd_instructions() -> &'static str {
    // `cfg!` evaluates at compile time; all arms are plain string literals,
    // so every branch type-checks on every target.
    if cfg!(target_arch = "x86_64") {
        "sse2"
    } else if cfg!(target_arch = "aarch64") {
        "neon"
    } else {
        "none"
    }
}
368
/// Arch-dispatching searcher for three needle bytes.
///
/// Exactly one `inner` field exists per build: the SSE2 implementation on
/// x86_64, the NEON implementation on aarch64, and the scalar `memchr`
/// implementation everywhere else.
#[derive(Debug)]
pub struct Searcher {
    #[cfg(target_arch = "x86_64")]
    inner: x86_64::sse2::SSE2Searcher,

    #[cfg(target_arch = "aarch64")]
    inner: aarch64::NeonSearcher,

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    inner: memchr::arch::all::memchr::Three,
}
380
impl Searcher {
    /// Builds a searcher for the three needle bytes, selecting the SIMD
    /// backend at compile time from the target architecture.
    #[inline(always)]
    pub fn new(n1: u8, n2: u8, n3: u8) -> Self {
        #[cfg(target_arch = "x86_64")]
        {
            // SAFETY: SSE2 is part of the x86_64 baseline, so the intrinsics
            // used by the constructor are always available.
            unsafe {
                Self {
                    inner: x86_64::sse2::SSE2Searcher::new(n1, n2, n3),
                }
            }
        }

        #[cfg(target_arch = "aarch64")]
        {
            // SAFETY: NEON is part of the aarch64 baseline, so the intrinsics
            // used by the constructor are always available.
            unsafe {
                Self {
                    inner: aarch64::NeonSearcher::new(n1, n2, n3),
                }
            }
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            // Portable scalar fallback from the `memchr` crate.
            Self {
                inner: memchr::arch::all::memchr::Three::new(n1, n2, n3),
            }
        }
    }

    /// Returns an iterator over the offsets of every occurrence of any of
    /// the three needles in `haystack`.
    #[inline(always)]
    pub fn search<'s, 'h>(&'s self, haystack: &'h [u8]) -> Indices<'s, 'h> {
        #[cfg(target_arch = "x86_64")]
        {
            Indices {
                inner: self.inner.iter(haystack),
            }
        }

        #[cfg(target_arch = "aarch64")]
        {
            Indices {
                inner: self.inner.iter(haystack),
            }
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            Indices {
                inner: self.inner.iter(haystack),
            }
        }
    }
}
434
/// Iterator over the offsets of the three needle bytes in one haystack,
/// wrapping whichever arch-specific iterator this build compiled in.
#[derive(Debug)]
pub struct Indices<'s, 'h> {
    #[cfg(target_arch = "x86_64")]
    inner: x86_64::sse2::SSE2Indices<'s, 'h>,

    #[cfg(target_arch = "aarch64")]
    inner: aarch64::NeonIndices<'s, 'h>,

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    inner: memchr::arch::all::memchr::ThreeIter<'s, 'h>,
}
446
// Once `next` has returned `None` the underlying iterators keep returning
// `None` (the cursor never moves backwards), so advertising `FusedIterator`
// is sound.
impl FusedIterator for Indices<'_, '_> {}

impl Iterator for Indices<'_, '_> {
    type Item = usize;

    /// Delegates to the compiled-in arch-specific iterator.
    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        #[cfg(target_arch = "x86_64")]
        {
            // SAFETY: the raw pointers in `inner` were derived from the
            // haystack slice passed to `search`, which lifetime `'h` keeps
            // alive for as long as this iterator exists.
            unsafe { self.inner.next() }
        }

        #[cfg(target_arch = "aarch64")]
        {
            // SAFETY: the raw pointers in `inner` were derived from the
            // haystack slice passed to `search`, which lifetime `'h` keeps
            // alive for as long as this iterator exists.
            unsafe { self.inner.next() }
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            self.inner.next()
        }
    }
}
470
impl Indices<'_, '_> {
    /// Yields the next match from the mask cached by the last vectorized
    /// step, without scanning further haystack bytes.
    ///
    /// Always `None` on the scalar fallback, which has no mask concept.
    #[inline(always)]
    pub fn _next_in_current_mask(&mut self) -> Option<usize> {
        #[cfg(target_arch = "x86_64")]
        {
            // SAFETY: only already-validated pointer state saved by a prior
            // vectorized step is consulted; no new haystack bytes are read.
            unsafe { self.inner._next_in_current_mask() }
        }

        #[cfg(target_arch = "aarch64")]
        {
            // SAFETY: only already-validated pointer state saved by a prior
            // vectorized step is consulted; no new haystack bytes are read.
            unsafe { self.inner._next_in_current_mask() }
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            None
        }
    }
}
490
#[cfg(test)]
mod tests {
    use super::*;

    use memchr::arch::all::memchr::Three;

    // Haystack mixing commas, quotes, and newlines, including runs of
    // consecutive delimiters.
    static TEST_STRING: &[u8] = b"name,\"surname\",age,color,oper\n,\n,\nation,punctuation\nname,surname,age,color,operation,punctuation";
    // Expected offsets of every `,`, `"`, and `\n` in TEST_STRING.
    static TEST_STRING_OFFSETS: &[usize; 18] = &[
        4, 5, 13, 14, 18, 24, 29, 30, 31, 32, 33, 39, 51, 56, 64, 68, 74, 84,
    ];

    // Sanity-checks the scalar `memchr` implementation used as the reference
    // for the SIMD searcher.
    #[test]
    fn test_scalar_searcher() {
        fn split(haystack: &[u8]) -> Vec<usize> {
            let searcher = Three::new(b',', b'"', b'\n');
            searcher.iter(haystack).collect()
        }

        let offsets = split(TEST_STRING);
        assert_eq!(offsets, TEST_STRING_OFFSETS);

        // No delimiters at all.
        assert!(split("b".repeat(75).as_bytes()).is_empty());

        // Sizes chosen to exercise full vector chunks plus scalar tails.
        assert_eq!(split("b,".repeat(75).as_bytes()).len(), 75);

        assert_eq!(split("b,".repeat(64).as_bytes()).len(), 64);

        assert_eq!(split("b,".repeat(25).as_bytes()).len(), 25);

        assert_eq!(split("b,".repeat(13).as_bytes()).len(), 13);
    }

    // Exercises the arch-dispatching `Searcher` and checks it against the
    // scalar reference implementation.
    #[test]
    fn test_searcher() {
        fn split(haystack: &[u8]) -> Vec<usize> {
            let searcher = Searcher::new(b',', b'"', b'\n');
            searcher.search(haystack).collect()
        }

        let offsets = split(TEST_STRING);
        assert_eq!(offsets, TEST_STRING_OFFSETS);

        // No delimiters at all.
        assert!(split("b".repeat(75).as_bytes()).is_empty());

        // Sizes chosen to exercise full vector chunks plus scalar tails.
        assert_eq!(split("b,".repeat(75).as_bytes()).len(), 75);

        assert_eq!(split("b,".repeat(64).as_bytes()).len(), 64);

        assert_eq!(split("b,".repeat(25).as_bytes()).len(), 25);

        assert_eq!(split("b,".repeat(13).as_bytes()).len(), 13);

        // Messy CSV-like input with escaped quotes and a CRLF ending.
        let complex = b"name,surname,age\n\"john\",\"landy, the \"\"everlasting\"\" bastard\",45\nlucy,rose,\"67\"\njermaine,jackson,\"89\"\n\nkarine,loucan,\"52\"\nrose,\"glib\",12\n\"guillaume\",\"plique\",\"42\"\r\n";
        let complex_indices = split(complex);

        // Every reported offset must point at one of the three needles.
        assert!(complex_indices
            .iter()
            .copied()
            .all(|c| complex[c] == b',' || complex[c] == b'\n' || complex[c] == b'"'));

        // And the SIMD result must agree with the scalar reference.
        assert_eq!(
            complex_indices,
            Three::new(b',', b'\n', b'"')
                .iter(complex)
                .collect::<Vec<_>>()
        );
    }
}