simd_csv/searcher.rs

#[cfg(target_arch = "x86_64")]
mod x86_64 {
    use std::marker::PhantomData;

    use crate::ext::Pointer;

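    /// Normalizes a movemask before `trailing_zeros` is taken: on a
    /// big-endian target the bytes are swapped so the lowest set bit still
    /// corresponds to the earliest matching byte of the chunk.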
    #[inline(always)]
    fn get_for_offset(mask: u32) -> u32 {
        #[cfg(target_endian = "big")]
        {
            mask.swap_bytes()
        }
        #[cfg(target_endian = "little")]
        {
            mask
        }
    }

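    /// Byte index of the first match in the chunk, i.e. the position of the
    /// lowest set bit of the normalized movemask.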
    #[inline(always)]
    fn first_offset(mask: u32) -> usize {
        get_for_offset(mask).trailing_zeros() as usize
    }

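    /// Clears the lowest set bit (`mask & (mask - 1)`) so that matches can be
    /// consumed one at a time across successive calls to `next`.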
    #[inline(always)]
    fn clear_least_significant_bit(mask: u32) -> u32 {
        mask & (mask - 1)
    }

    pub mod sse2 {
        use super::*;

        use core::arch::x86_64::{
            __m128i, _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8, _mm_or_si128,
            _mm_set1_epi8,
        };

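        /// Searcher for up to three needle bytes. It keeps the raw bytes for
        /// the scalar tail loop and each byte splatted across a 16-byte SSE2
        /// register for the vectorized comparisons.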
        #[derive(Debug)]
        pub struct SSE2Searcher {
            n1: u8,
            n2: u8,
            n3: u8,
            v1: __m128i,
            v2: __m128i,
            v3: __m128i,
        }

        impl SSE2Searcher {
            #[inline]
            pub unsafe fn new(n1: u8, n2: u8, n3: u8) -> Self {
                Self {
                    n1,
                    n2,
                    n3,
                    v1: _mm_set1_epi8(n1 as i8),
                    v2: _mm_set1_epi8(n2 as i8),
                    v3: _mm_set1_epi8(n3 as i8),
                }
            }

            #[inline(always)]
            pub fn iter<'s, 'h>(&'s self, haystack: &'h [u8]) -> SSE2Indices<'s, 'h> {
                SSE2Indices::new(self, haystack)
            }
        }

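        /// Iterator over match offsets in a haystack: raw `start`, `end` and
        /// `current` pointers plus the movemask of matches found in the last
        /// chunk but not yet yielded. `PhantomData` ties the raw pointers to
        /// the haystack lifetime `'h`.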
        #[derive(Debug)]
        pub struct SSE2Indices<'s, 'h> {
            searcher: &'s SSE2Searcher,
            haystack: PhantomData<&'h [u8]>,
            start: *const u8,
            end: *const u8,
            current: *const u8,
            mask: u32,
        }

        impl<'s, 'h> SSE2Indices<'s, 'h> {
            #[inline]
            fn new(searcher: &'s SSE2Searcher, haystack: &'h [u8]) -> Self {
                let ptr = haystack.as_ptr();

                Self {
                    searcher,
                    haystack: PhantomData,
                    start: ptr,
                    end: ptr.wrapping_add(haystack.len()),
                    current: ptr,
                    mask: 0,
                }
            }
        }

        const SSE2_STEP: usize = 16;

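        // `next` resumes from the saved iterator state: it first drains any
        // match bits pending in `self.mask`, then scans 16-byte chunks with
        // unaligned loads, and finally finishes the tail byte by byte.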
        impl<'s, 'h> SSE2Indices<'s, 'h> {
            pub unsafe fn next(&mut self) -> Option<usize> {
                if self.start >= self.end {
                    return None;
                }

                let mut mask = self.mask;
                // `wrapping_sub` keeps this pointer arithmetic defined when the
                // haystack is shorter than one 16-byte vector; the vectorized
                // loop below then simply never runs.
                let vectorized_end = self.end.wrapping_sub(SSE2_STEP);
                let mut current = self.current;
                let start = self.start;
                let v1 = self.searcher.v1;
                let v2 = self.searcher.v2;
                let v3 = self.searcher.v3;

                'main: loop {
                    // Processing current move mask
                    if mask != 0 {
                        let offset = current.sub(SSE2_STEP).add(first_offset(mask));
                        self.mask = clear_least_significant_bit(mask);
                        self.current = current;

                        return Some(offset.distance(start));
                    }

                    // Main loop of unaligned loads
                    while current <= vectorized_end {
                        let chunk = _mm_loadu_si128(current as *const __m128i);
                        let cmp1 = _mm_cmpeq_epi8(chunk, v1);
                        let cmp2 = _mm_cmpeq_epi8(chunk, v2);
                        let cmp3 = _mm_cmpeq_epi8(chunk, v3);
                        let cmp = _mm_or_si128(cmp1, cmp2);
                        let cmp = _mm_or_si128(cmp, cmp3);

                        mask = _mm_movemask_epi8(cmp) as u32;

                        current = current.add(SSE2_STEP);

                        if mask != 0 {
                            continue 'main;
                        }
                    }

                    // Processing remaining bytes linearly
                    while current < self.end {
                        if *current == self.searcher.n1
                            || *current == self.searcher.n2
                            || *current == self.searcher.n3
                        {
                            let offset = current.distance(start);
                            self.current = current.add(1);
                            return Some(offset);
                        }
                        current = current.add(1);
                    }

                    return None;
                }
            }
        }
    }
}

#[cfg(target_arch = "aarch64")]
mod aarch64 {
    use core::arch::aarch64::{
        uint8x16_t, vceqq_u8, vdupq_n_u8, vget_lane_u64, vld1q_u8, vorrq_u8, vreinterpret_u64_u8,
        vreinterpretq_u16_u8, vshrn_n_u16,
    };
    use std::marker::PhantomData;

    use crate::ext::Pointer;

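    /// NEON has no direct `_mm_movemask_epi8` equivalent. The `vshrn` trick is
    /// used instead: shifting each 16-bit lane right by 4 and narrowing packs
    /// the comparison result into a 64-bit value with one nibble per input
    /// byte, and masking with 0x8888888888888888 keeps a single bit (at
    /// position `4 * lane + 3`) per matching byte.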
    #[inline(always)]
    unsafe fn neon_movemask(v: uint8x16_t) -> u64 {
        let asu16s = vreinterpretq_u16_u8(v);
        let mask = vshrn_n_u16(asu16s, 4);
        let asu64 = vreinterpret_u64_u8(mask);
        let scalar64 = vget_lane_u64(asu64, 0);

        scalar64 & 0x8888888888888888
    }

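    /// With one bit per matching byte at position `4 * lane + 3`,
    /// `trailing_zeros() >> 2` recovers the byte index of the first match.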
    #[inline(always)]
    fn first_offset(mask: u64) -> usize {
        (mask.trailing_zeros() >> 2) as usize
    }

    #[inline(always)]
    fn clear_least_significant_bit(mask: u64) -> u64 {
        mask & (mask - 1)
    }

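    /// NEON counterpart of the SSE2 searcher: the three needle bytes plus each
    /// byte splatted across a 16-byte vector.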
    #[derive(Debug)]
    pub struct NeonSearcher {
        n1: u8,
        n2: u8,
        n3: u8,
        v1: uint8x16_t,
        v2: uint8x16_t,
        v3: uint8x16_t,
    }

    impl NeonSearcher {
        #[inline]
        pub unsafe fn new(n1: u8, n2: u8, n3: u8) -> Self {
            Self {
                n1,
                n2,
                n3,
                v1: vdupq_n_u8(n1),
                v2: vdupq_n_u8(n2),
                v3: vdupq_n_u8(n3),
            }
        }

        #[inline(always)]
        pub fn iter<'s, 'h>(&'s self, haystack: &'h [u8]) -> NeonIndices<'s, 'h> {
            NeonIndices::new(self, haystack)
        }
    }

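    /// Same iterator state as the SSE2 version: raw pointers into the haystack
    /// plus the pending match mask from the last loaded chunk.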
    #[derive(Debug)]
    pub struct NeonIndices<'s, 'h> {
        searcher: &'s NeonSearcher,
        haystack: PhantomData<&'h [u8]>,
        start: *const u8,
        end: *const u8,
        current: *const u8,
        mask: u64,
    }

    impl<'s, 'h> NeonIndices<'s, 'h> {
        #[inline]
        fn new(searcher: &'s NeonSearcher, haystack: &'h [u8]) -> Self {
            let ptr = haystack.as_ptr();

            Self {
                searcher,
                haystack: PhantomData,
                start: ptr,
                end: ptr.wrapping_add(haystack.len()),
                current: ptr,
                mask: 0,
            }
        }
    }

    const NEON_STEP: usize = 16;

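    // Same resumable scheme as the SSE2 iterator above: drain pending mask
    // bits, scan 16-byte chunks, then finish the tail byte by byte.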
    impl<'s, 'h> NeonIndices<'s, 'h> {
        pub unsafe fn next(&mut self) -> Option<usize> {
            if self.start >= self.end {
                return None;
            }

            let mut mask = self.mask;
            // As in the SSE2 version, `wrapping_sub` keeps this defined for
            // haystacks shorter than one 16-byte vector.
            let vectorized_end = self.end.wrapping_sub(NEON_STEP);
            let mut current = self.current;
            let start = self.start;
            let v1 = self.searcher.v1;
            let v2 = self.searcher.v2;
            let v3 = self.searcher.v3;

            'main: loop {
                // Processing current move mask
                if mask != 0 {
                    let offset = current.sub(NEON_STEP).add(first_offset(mask));
                    self.mask = clear_least_significant_bit(mask);
                    self.current = current;

                    return Some(offset.distance(start));
                }

                // Main loop of unaligned loads
                while current <= vectorized_end {
                    let chunk = vld1q_u8(current);
                    let cmp1 = vceqq_u8(chunk, v1);
                    let cmp2 = vceqq_u8(chunk, v2);
                    let cmp3 = vceqq_u8(chunk, v3);
                    let cmp = vorrq_u8(cmp1, cmp2);
                    let cmp = vorrq_u8(cmp, cmp3);

                    mask = neon_movemask(cmp);

                    current = current.add(NEON_STEP);

                    if mask != 0 {
                        continue 'main;
                    }
                }

                // Processing remaining bytes linearly
                while current < self.end {
                    if *current == self.searcher.n1
                        || *current == self.searcher.n2
                        || *current == self.searcher.n3
                    {
                        let offset = current.distance(start);
                        self.current = current.add(1);
                        return Some(offset);
                    }
                    current = current.add(1);
                }

                return None;
            }
        }
    }
}

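/// Searcher for up to three needle bytes, dispatched at compile time: SSE2 on
/// x86_64, NEON on aarch64, and `memchr`'s scalar `Three` searcher elsewhere.
///
/// Usage mirrors the tests below:
///
/// ```ignore
/// let searcher = Searcher::new(b',', b'"', b'\n');
/// let offsets: Vec<usize> = searcher.search(b"a,b\n").collect();
/// assert_eq!(offsets, vec![1, 3]);
/// ```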
#[derive(Debug)]
pub struct Searcher {
    #[cfg(target_arch = "x86_64")]
    inner: x86_64::sse2::SSE2Searcher,

    #[cfg(target_arch = "aarch64")]
    inner: aarch64::NeonSearcher,

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    inner: memchr::arch::all::memchr::Three,
}

impl Searcher {
    pub fn leveraged_simd_instructions() -> &'static str {
        #[cfg(target_arch = "x86_64")]
        {
            "sse2"
        }

        #[cfg(target_arch = "aarch64")]
        {
            "neon"
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            "none"
        }
    }

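    /// Builds a searcher for the three given needle bytes using whichever
    /// backend this target was compiled with.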
    #[inline(always)]
    pub fn new(n1: u8, n2: u8, n3: u8) -> Self {
        #[cfg(target_arch = "x86_64")]
        {
            unsafe {
                Self {
                    inner: x86_64::sse2::SSE2Searcher::new(n1, n2, n3),
                }
            }
        }

        #[cfg(target_arch = "aarch64")]
        {
            unsafe {
                Self {
                    inner: aarch64::NeonSearcher::new(n1, n2, n3),
                }
            }
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            Self {
                inner: memchr::arch::all::memchr::Three::new(n1, n2, n3),
            }
        }
    }

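    /// Returns an iterator over the offsets of every occurrence of any of the
    /// three needle bytes in `haystack`.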
    #[inline(always)]
    pub fn search<'s, 'h>(&'s self, haystack: &'h [u8]) -> Indices<'s, 'h> {
        #[cfg(target_arch = "x86_64")]
        {
            Indices {
                inner: self.inner.iter(haystack),
            }
        }

        #[cfg(target_arch = "aarch64")]
        {
            Indices {
                inner: self.inner.iter(haystack),
            }
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            Indices {
                inner: self.inner.iter(haystack),
            }
        }
    }
}

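/// Iterator returned by [`Searcher::search`], yielding match offsets in
/// ascending order.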
#[derive(Debug)]
pub struct Indices<'s, 'h> {
    #[cfg(target_arch = "x86_64")]
    inner: x86_64::sse2::SSE2Indices<'s, 'h>,

    #[cfg(target_arch = "aarch64")]
    inner: aarch64::NeonIndices<'s, 'h>,

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    inner: memchr::arch::all::memchr::ThreeIter<'s, 'h>,
}

impl<'s, 'h> Iterator for Indices<'s, 'h> {
    type Item = usize;

    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        #[cfg(target_arch = "x86_64")]
        {
            unsafe { self.inner.next() }
        }

        #[cfg(target_arch = "aarch64")]
        {
            unsafe { self.inner.next() }
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            self.inner.next()
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use memchr::arch::all::memchr::Three;

    static TEST_STRING: &[u8] = b"name,\"surname\",age,color,oper\n,\n,\nation,punctuation\nname,surname,age,color,operation,punctuation";
    static TEST_STRING_OFFSETS: &[usize; 18] = &[
        4, 5, 13, 14, 18, 24, 29, 30, 31, 32, 33, 39, 51, 56, 64, 68, 74, 84,
    ];

    #[test]
    fn test_scalar_searcher() {
        fn split(haystack: &[u8]) -> Vec<usize> {
            let searcher = Three::new(b',', b'"', b'\n');
            searcher.iter(haystack).collect()
        }

        let offsets = split(TEST_STRING);
        assert_eq!(offsets, TEST_STRING_OFFSETS);

        // Not found at all
        assert!(split("b".repeat(75).as_bytes()).is_empty());

        // Regular
        assert_eq!(split("b,".repeat(75).as_bytes()).len(), 75);

        // Exactly 64
        assert_eq!(split("b,".repeat(64).as_bytes()).len(), 64);

        // Less than 32
        assert_eq!(split("b,".repeat(25).as_bytes()).len(), 25);

        // Less than 16
        assert_eq!(split("b,".repeat(13).as_bytes()).len(), 13);
    }

    #[test]
    fn test_searcher() {
        fn split(haystack: &[u8]) -> Vec<usize> {
            let searcher = Searcher::new(b',', b'"', b'\n');
            searcher.search(haystack).collect()
        }

        let offsets = split(TEST_STRING);
        assert_eq!(offsets, TEST_STRING_OFFSETS);

        // Not found at all
        assert!(split("b".repeat(75).as_bytes()).is_empty());

        // Regular
        assert_eq!(split("b,".repeat(75).as_bytes()).len(), 75);

        // Exactly 64
        assert_eq!(split("b,".repeat(64).as_bytes()).len(), 64);

        // Less than 32
        assert_eq!(split("b,".repeat(25).as_bytes()).len(), 25);

        // Less than 16
        assert_eq!(split("b,".repeat(13).as_bytes()).len(), 13);

        // Complex input
        let complex = b"name,surname,age\n\"john\",\"landy, the \"\"everlasting\"\" bastard\",45\nlucy,rose,\"67\"\njermaine,jackson,\"89\"\n\nkarine,loucan,\"52\"\nrose,\"glib\",12\n\"guillaume\",\"plique\",\"42\"\r\n";
        let complex_indices = split(complex);

        assert!(complex_indices
            .iter()
            .copied()
            .all(|c| complex[c] == b',' || complex[c] == b'\n' || complex[c] == b'"'));

        assert_eq!(
            complex_indices,
            Three::new(b',', b'\n', b'"')
                .iter(complex)
                .collect::<Vec<_>>()
        );
    }
}