// sigmah/lib.rs — crate root

1#![cfg_attr(not(feature = "std"), no_std)]
2// Enabling SIMD feature means including portable_simd and AVX512(BW), this enables the mask register access for smaller sizes
3#![cfg_attr(feature = "simd", feature(portable_simd, avx512_target_feature))]
4// Unfortunately, we only need a subset of generic_const_exprs which the minimal implementation should suffice
5#![allow(incomplete_features)]
6#![feature(generic_const_exprs)]
7
8use crate::multiversion::{equal_then_find_second_position_simple, match_naive};
9use bitvec::prelude::*;
10use core::mem::transmute_copy;
11
12#[cfg(all(feature = "rayon", feature = "simd"))]
13use {
14    arrayvec::ArrayVec,
15    rayon::{
16        iter::{IndexedParallelIterator, IntoParallelIterator},
17        prelude::*,
18    },
19};
20
21#[cfg(feature = "simd")]
22use {
23    crate::multiversion::simd::{
24        equal_then_find_second_position_simd, match_simd_core, match_simd_select_core,
25    },
26    core::simd::{LaneCount, SupportedLaneCount},
27};
28
29use crate::utils::{pad_zeroes_slice_unchecked, simd::SimdBits};
30
/// A packed bitmask over an `N`-byte signature: bit `i` is set when position
/// `i` of the pattern is significant (came from `true` / `b'x'`) and clear
/// when it is a wildcard (`false` / `b'?'`).
///
/// Backed by a `BitArray` over `N.div_ceil(8)` bytes, so the whole mask stays
/// `Copy` and adds no per-bool storage overhead.
#[derive(Debug, Copy, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(transparent)]
pub struct SignatureMask<const N: usize>(BitArray<[u8; N.div_ceil(u8::BITS as usize)]>)
where
    [(); N.div_ceil(u8::BITS as usize)]:;
37
38impl<const N: usize> SignatureMask<N>
39where
40    [(); N.div_ceil(u8::BITS as usize)]:,
41{
42    pub const MAX: Self = Self({
43        let mut arr: BitArray<[u8; N.div_ceil(u8::BITS as usize)]> = BitArray::ZERO;
44        let mut i = 0;
45        while i < N {
46            const BITS: usize = u8::BITS as usize;
47            arr.data[i / BITS] |= 1 << (i % BITS);
48            i += 1;
49        }
50        arr
51    });
52
53    #[inline(always)]
54    pub fn is_exact(&self) -> bool {
55        let mut i = 0;
56        while i < N {
57            if !unsafe { *self.0.get_unchecked(i) } {
58                return false;
59            }
60            i += 1;
61        }
62        true
63    }
64}
65
66impl<const N: usize> SignatureMask<N>
67where
68    [(); N.div_ceil(u8::BITS as usize)]:,
69{
70    #[inline(always)]
71    pub const fn from_bool_array(pattern: [bool; N]) -> Self {
72        Self::from_bool_slice(&pattern)
73    }
74
75    #[inline(always)]
76    pub const fn from_bool_slice(pattern: &[bool; N]) -> Self {
77        Self(Self::from_bool_slice_to_bitarr(pattern))
78    }
79
80    #[inline(always)]
81    pub const fn from_bool_array_to_bitarr(
82        pattern: [bool; N],
83    ) -> BitArray<[u8; N.div_ceil(u8::BITS as usize)]> {
84        Self::from_bool_slice_to_bitarr(&pattern)
85    }
86
87    #[inline(always)]
88    pub const fn from_bool_slice_to_bitarr(
89        pattern: &[bool; N],
90    ) -> BitArray<[u8; N.div_ceil(u8::BITS as usize)]> {
91        let mut arr: BitArray<[u8; N.div_ceil(u8::BITS as usize)]> = BitArray::ZERO;
92        let mut i = 0;
93        while i < pattern.len() {
94            if pattern[i] {
95                const BITS: usize = u8::BITS as usize;
96                arr.data[i / BITS] |= 1 << (i % BITS);
97            }
98            i += 1;
99        }
100        arr
101    }
102
103    #[inline(always)]
104    pub const fn to_bool_array(&self) -> [bool; N] {
105        let mut arr: [bool; N] = [false; N];
106        let mut i = 0;
107        while i < N {
108            const BITS: usize = u8::BITS as usize;
109            let bit = 1 << (i % BITS);
110            arr[i] = (self.0.data[i / BITS] & bit) == bit;
111            i += 1;
112        }
113        arr
114    }
115}
116
117impl<const N: usize> SignatureMask<N>
118where
119    [(); N.div_ceil(u8::BITS as usize)]:,
120{
121    #[inline(always)]
122    pub const fn from_byte_array(pattern: [u8; N]) -> Self {
123        Self::from_byte_slice(&pattern)
124    }
125
126    #[inline(always)]
127    pub const fn from_byte_slice(pattern: &[u8; N]) -> Self {
128        match Self::try_from_byte_slice_to_bitarr(pattern) {
129            Ok(x) => Self(x),
130            Err(e) => panic!("{}", e),
131        }
132    }
133
134    #[inline(always)]
135    pub const fn try_from_byte_array_to_bitarr(
136        pattern: [u8; N],
137    ) -> Result<BitArray<[u8; N.div_ceil(u8::BITS as usize)]>, &'static str> {
138        Self::try_from_byte_slice_to_bitarr(&pattern)
139    }
140
141    #[inline(always)]
142    pub const fn try_from_byte_slice_to_bitarr(
143        pattern: &[u8; N],
144    ) -> Result<BitArray<[u8; N.div_ceil(u8::BITS as usize)]>, &'static str> {
145        let mut pattern_bool: [bool; N] = [false; N];
146        let mut i = 0;
147        while i < pattern.len() {
148            pattern_bool[i] = match pattern[i] {
149                b'x' => true,
150                b'?' => false,
151                _ => return Err("unknown character in pattern"),
152            };
153            i += 1;
154        }
155        Ok(Self::from_bool_slice_to_bitarr(&pattern_bool))
156    }
157
158    #[inline(always)]
159    pub const fn to_byte_array(&self) -> [u8; N] {
160        let mut arr: [u8; N] = [b'?'; N];
161        let mut i = 0;
162        while i < N {
163            const BITS: usize = u8::BITS as usize;
164            let bit = 1 << (i % BITS);
165            arr[i] = if (self.0.data[i / BITS] & bit) == bit {
166                b'x'
167            } else {
168                b'?'
169            };
170            i += 1;
171        }
172        arr
173    }
174}
175
/// A fixed-length byte signature: `pattern` holds the bytes to search for and
/// `mask` marks which positions are significant versus wildcards.
#[derive(Debug, Copy, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(C, align(1))]
pub struct Signature<const N: usize>
where
    [(); N.div_ceil(u8::BITS as usize)]:,
{
    /// Bytes to match; positions whose mask bit is clear are ignored.
    #[cfg_attr(feature = "serde", serde(with = "serde_big_array::BigArray"))]
    pub pattern: [u8; N],
    /// Per-position significance flags for `pattern`.
    pub mask: SignatureMask<N>,
}
187
impl<const N: usize> Signature<N>
where
    [(); N.div_ceil(u8::BITS as usize)]:,
{
    /// Builds a signature from a pattern plus a `b"x?"`-style mask array;
    /// panics if the mask contains bytes other than `b'x'` / `b'?'`.
    #[inline(always)]
    pub const fn from_pattern_mask(pattern: [u8; N], mask: [u8; N]) -> Self {
        Self {
            pattern,
            mask: SignatureMask::from_byte_array(mask),
        }
    }

    // Notice we cannot use From<([u8; N], [u8; N])> because it will break const guarantee
    /// Tuple-taking convenience wrapper around [`Self::from_pattern_mask`].
    #[inline(always)]
    pub const fn from_pattern_mask_tuple((pattern, mask): ([u8; N], [u8; N])) -> Self {
        Self::from_pattern_mask(pattern, mask)
    }

    /// Builds a signature where `Some(byte)` is a must-match position and
    /// `None` is a wildcard.
    #[inline(always)]
    pub const fn from_option_array(needle: [Option<u8>; N]) -> Self {
        Self::from_option_slice(&needle)
    }

    /// Borrowed-array variant of [`Self::from_option_array`].
    #[inline(always)]
    pub const fn from_option_slice(needle: &[Option<u8>; N]) -> Self {
        // SAFETY: the argument is exactly N elements long, satisfying the
        // length contract of `from_option_slice_unchecked`.
        unsafe { Self::from_option_slice_unchecked(needle) }
    }

    /// Signature whose mask requires every byte of `pattern` to match.
    #[inline(always)]
    pub const fn from_array_with_exact_match_mask(pattern: [u8; N]) -> Self {
        Self {
            pattern,
            mask: SignatureMask::MAX,
        }
    }

    /// Borrowed variant of [`Self::from_array_with_exact_match_mask`].
    #[inline(always)]
    pub const fn from_slice_with_exact_match_mask(pattern: &[u8; N]) -> Self {
        Self::from_array_with_exact_match_mask(*pattern)
    }

    /// Builds a signature from an arbitrary-length option slice.
    ///
    /// # Safety
    ///
    /// `needle.len()` must not exceed `N`: the loop indexes the `N`-sized
    /// `pattern`/`mask` arrays at each `needle` position, so a longer slice
    /// trips the checked array indexing (a panic at runtime, an evaluation
    /// error in const contexts). Positions `needle.len()..N` are left as
    /// wildcards (zero byte, mask bit clear).
    #[inline(always)]
    pub const unsafe fn from_option_slice_unchecked(needle: &[Option<u8>]) -> Self {
        let mut pattern: [u8; N] = [0; N];
        let mut mask: [bool; N] = [false; N];
        let mut i = 0;
        while i < needle.len() {
            if let Some(x) = needle[i] {
                pattern[i] = x;
                mask[i] = true;
            }
            i += 1;
        }
        Self {
            pattern,
            mask: SignatureMask::from_bool_array(mask),
        }
    }
}
247
248impl<const N: usize> Signature<N>
249where
250    [(); N.div_ceil(u8::BITS as usize)]:,
251    [(); N.div_ceil(u64::LANES)]:,
252    [(); N.div_ceil(u32::LANES)]:,
253    [(); N.div_ceil(u16::LANES)]:,
254    [(); N.div_ceil(u8::LANES)]:,
255{
256    pub fn scan<'a>(&self, haystack: &'a [u8]) -> Option<&'a [u8]> {
257        self.scan_inner(haystack, |chunk| self.match_best_effort(chunk))
258    }
259
260    pub fn scan_naive<'a>(&self, haystack: &'a [u8]) -> Option<&'a [u8]> {
261        self.scan_inner(haystack, |chunk| {
262            match_naive(chunk, &self.pattern, &self.mask.0)
263        })
264    }
265
266    #[inline]
267    fn scan_inner<'a>(
268        &self,
269        mut haystack: &'a [u8],
270        f: impl Fn(&[u8; N]) -> bool,
271    ) -> Option<&'a [u8]> {
272        let exact_match = self.mask.is_exact();
273        while !haystack.is_empty() {
274            let haystack_smaller_than_n = haystack.len() < N;
275
276            let window: &[u8; N] = unsafe {
277                if haystack_smaller_than_n {
278                    &pad_zeroes_slice_unchecked::<N>(haystack)
279                } else {
280                    transmute_copy(&haystack)
281                }
282            };
283
284            if f(window) {
285                return Some(haystack);
286            } else if exact_match && haystack_smaller_than_n {
287                // If we are having the mask to match for all, and the chunk is actually smaller than N, we are cooked anyway
288                return None;
289            }
290
291            // Since we are using a sliding window approach, we are safe to determine that we can either:
292            //   1. Skip to the first position of c for all c in window[1..] where c == window[0]
293            //   2. Skip this entire window
294            // The optimization is derived from the Z-Algorithm which constructs an array Z,
295            // where Z[i] represents the length of the longest substring starting from i which is also a prefix of the string.
296            // More formally, given first Z[0] is tombstone, then for i in 1..N:
297            //   Z[i] is the length of the longest substring starting at i that matches the prefix of S (i.e. memcmp(S[0:], S[i:])).
298            // Then we further simplify that to find the first position where Z[i] != 0, it to use the fact that if Z[i] > 0, it has to be a prefix of our pattern,
299            // so it is a potential search point. If all that is in the Z box are 0, then we are safe to assume all patterns are unique and need one-by-one brute force.
300            // Technically speaking, if we repeat this process to each shift of the window with respect to its mask position, we can obtain the Z-box algorithm as well.
301            // It is speculated that we can redefine the special window[0] prefix to a value of "w" and index "i" for any c for all i, c in window[1..] where i == first(for i, m in mask[1..] where m == true),
302            // and then do skip to the "i"th position of c for all c in window[1..] where c == w. For now I'm too lazy to investigate whether the proof is correct.
303            //
304            // If in SIMD manner, we can first take the first character, splat it to vector width and match it with the haystack window after first element,
305            // then do find-first-set and add 1 to cover for the real next position. It is always assumed the scanner will always go at least 1 byte ahead
306            let move_position =
307                if unsafe { *self.mask.0.get_unchecked(0) } && !haystack_smaller_than_n {
308                    self.equal_then_find_second_position(
309                        unsafe { *self.pattern.get_unchecked(0) },
310                        window,
311                    )
312                    .unwrap_or(N)
313                } else {
314                    1
315                };
316            haystack = unsafe { haystack.get_unchecked(move_position..) };
317        }
318        None
319    }
320
321    fn equal_then_find_second_position(&self, first: u8, window: &[u8; N]) -> Option<usize> {
322        #[cfg(feature = "simd")]
323        {
324            if N >= 64 {
325                equal_then_find_second_position_simd::<u64, N>(first, window)
326            } else if N >= 32 {
327                equal_then_find_second_position_simd::<u32, N>(first, window)
328            } else if N >= 16 {
329                equal_then_find_second_position_simd::<u16, N>(first, window)
330            } else if N >= 8 {
331                equal_then_find_second_position_simd::<u8, N>(first, window)
332            } else {
333                // for the lulz
334                equal_then_find_second_position_simple(first, window)
335            }
336        }
337
338        #[cfg(not(feature = "simd"))]
339        {
340            equal_then_find_second_position_simple(first, window)
341        }
342    }
343}
344
345impl<const N: usize> Signature<N>
346where
347    [(); N.div_ceil(u8::BITS as usize)]:,
348    [(); N.div_ceil(u64::LANES)]:,
349    [(); N.div_ceil(u32::LANES)]:,
350    [(); N.div_ceil(u16::LANES)]:,
351    [(); N.div_ceil(u8::LANES)]:,
352{
353    #[inline(always)]
354    pub fn match_best_effort(&self, chunk: &[u8; N]) -> bool {
355        #[cfg(feature = "simd")]
356        {
357            if N >= 64 {
358                self.match_simd::<u64>(chunk)
359            } else if N >= 32 {
360                self.match_simd::<u32>(chunk)
361            } else if N >= 16 {
362                self.match_simd::<u16>(chunk)
363            } else if N >= 8 {
364                self.match_simd::<u8>(chunk)
365            } else {
366                // for the lulz
367                self.match_naive(chunk)
368            }
369        }
370
371        #[cfg(not(feature = "simd"))]
372        {
373            self.match_naive(chunk)
374        }
375    }
376
377    #[inline(always)]
378    pub fn match_naive(&self, chunk: &[u8; N]) -> bool {
379        match_naive(chunk, &self.pattern, &self.mask.0)
380    }
381}
382
#[cfg(feature = "simd")]
impl<const N: usize> Signature<N>
where
    [(); N.div_ceil(u8::BITS as usize)]:,
    [(); N.div_ceil(u64::LANES)]:,
    [(); N.div_ceil(u32::LANES)]:,
    [(); N.div_ceil(u16::LANES)]:,
    [(); N.div_ceil(u8::LANES)]:,
{
    /// Scans `haystack` with the SIMD matcher at `T`'s lane width.
    #[inline(always)]
    pub fn scan_simd<'a, T: SimdBits>(&self, haystack: &'a [u8]) -> Option<&'a [u8]>
    where
        LaneCount<{ T::LANES }>: SupportedLaneCount,
        [(); N.div_ceil(T::LANES)]:,
    {
        let f = |chunk: &[u8; N]| self.match_simd(chunk);
        self.scan_inner(haystack, f)
    }

    /// Scans `haystack` with the select-based SIMD matcher at `T`'s lane width.
    #[inline(always)]
    pub fn scan_simd_select<'a, T: SimdBits>(&self, haystack: &'a [u8]) -> Option<&'a [u8]>
    where
        LaneCount<{ T::LANES }>: SupportedLaneCount,
        [(); N.div_ceil(T::LANES)]:,
    {
        let f = |chunk: &[u8; N]| self.match_simd_select(chunk);
        self.scan_inner(haystack, f)
    }

    /// Matches `chunk` with `match_simd_core`, parallelized via rayon when
    /// that feature is enabled.
    #[inline(always)]
    pub fn match_simd<T: SimdBits>(&self, chunk: &[u8; N]) -> bool
    where
        LaneCount<{ T::LANES }>: SupportedLaneCount,
        [(); N.div_ceil(T::LANES)]:,
    {
        #[cfg(feature = "rayon")]
        {
            self.match_simd_rayon_inner(chunk, match_simd_core)
        }

        #[cfg(not(feature = "rayon"))]
        {
            self.match_simd_simple_inner(chunk, match_simd_core)
        }
    }

    /// Matches `chunk` with `match_simd_select_core`, parallelized via rayon
    /// when that feature is enabled.
    #[inline(always)]
    pub fn match_simd_select<T: SimdBits>(&self, chunk: &[u8; N]) -> bool
    where
        LaneCount<{ T::LANES }>: SupportedLaneCount,
        [(); N.div_ceil(T::LANES)]:,
    {
        #[cfg(feature = "rayon")]
        {
            self.match_simd_rayon_inner(chunk, match_simd_select_core)
        }

        #[cfg(not(feature = "rayon"))]
        {
            self.match_simd_simple_inner(chunk, match_simd_select_core)
        }
    }

    /// Sequential lane-by-lane matcher: walks `chunk`, `pattern`, and the
    /// packed mask in `T::LANES`-sized windows and requires `f` to accept
    /// every window. A short trailing window is zero-padded.
    #[inline(always)]
    pub fn match_simd_simple_inner<T: SimdBits>(
        &self,
        chunk: &[u8; N],
        f: impl Fn(&[u8; T::LANES], &[u8; T::LANES], u64) -> bool + Sync,
    ) -> bool
    where
        [(); T::LANES]:,
    {
        chunk
            .chunks(T::LANES)
            .zip(self.pattern.chunks(T::LANES))
            .zip(
                // Pack each mask window into a T: bit i set => lane i must match.
                self.mask
                    .0
                    .chunks(T::LANES)
                    .map(|mask| mask.iter_ones().fold(T::ZERO, |acc, x| acc | (T::ONE << x))),
            )
            .all(|((haystack, pattern), ref mask)| {
                let haystack: &[u8; T::LANES] = unsafe {
                    if haystack.len() < T::LANES {
                        // Zero-pad the short trailing window.
                        &pad_zeroes_slice_unchecked::<{ T::LANES }>(haystack)
                    } else {
                        // SAFETY: this window is exactly T::LANES bytes, so
                        // casting its data pointer is sound — unlike the
                        // previous `transmute_copy` on the fat slice
                        // reference, whose layout is unspecified.
                        &*haystack.as_ptr().cast::<[u8; T::LANES]>()
                    }
                };

                let pattern: &[u8; T::LANES] = unsafe {
                    if pattern.len() < T::LANES {
                        // Zero-pad the short trailing window.
                        &pad_zeroes_slice_unchecked::<{ T::LANES }>(pattern)
                    } else {
                        // SAFETY: as above, exactly T::LANES bytes.
                        &*pattern.as_ptr().cast::<[u8; T::LANES]>()
                    }
                };
                f(haystack, pattern, mask.to_u64())
            })
    }
}
484
#[cfg(all(feature = "simd", feature = "rayon"))]
impl<const N: usize> Signature<N>
where
    [(); N.div_ceil(u8::BITS as usize)]:,
{
    /// Matches `chunk` using SIMD lanes of `T`, with rayon data-parallelism
    /// across the lane-sized windows.
    #[inline(always)]
    pub fn match_simd_rayon<T: SimdBits>(&self, chunk: &[u8; N]) -> bool
    where
        LaneCount<{ T::LANES }>: SupportedLaneCount,
        [(); N.div_ceil(T::LANES)]:,
    {
        self.match_simd_rayon_inner(chunk, match_simd_core)
    }

    /// Parallel counterpart of `match_simd_simple_inner`: collects the
    /// lane-sized windows of chunk/pattern/mask into stack-backed `ArrayVec`s,
    /// then requires `f` to accept every window in parallel.
    #[inline(always)]
    pub fn match_simd_rayon_inner<T: SimdBits>(
        &self,
        chunk: &[u8; N],
        f: impl Fn(&[u8; T::LANES], &[u8; T::LANES], u64) -> bool + Sync,
    ) -> bool
    where
        [(); N.div_ceil(T::LANES)]:,
    {
        let chunks: ArrayVec<&[u8], { N.div_ceil(T::LANES) }> = chunk.chunks(T::LANES).collect();
        let patterns: ArrayVec<&[u8], { N.div_ceil(T::LANES) }> =
            self.pattern.chunks(T::LANES).collect();
        // Pack each mask window into a T: bit i set => lane i must match.
        let masks = self
            .mask
            .0
            .chunks(T::LANES)
            .map(|mask| mask.iter_ones().fold(T::ZERO, |acc, x| acc | (T::ONE << x)))
            .collect::<ArrayVec<T, { N.div_ceil(T::LANES) }>>();

        chunks
            .into_par_iter()
            .zip(patterns.into_par_iter())
            .zip(masks.into_par_iter())
            .all(|((&haystack, &pattern), mask)| {
                let haystack: &[u8; T::LANES] = unsafe {
                    if haystack.len() < T::LANES {
                        // Zero-pad the short trailing window.
                        &pad_zeroes_slice_unchecked::<{ T::LANES }>(haystack)
                    } else {
                        // SAFETY: this window is exactly T::LANES bytes, so
                        // casting its data pointer is sound — unlike the
                        // previous `transmute_copy` on the fat slice
                        // reference, whose layout is unspecified.
                        &*haystack.as_ptr().cast::<[u8; T::LANES]>()
                    }
                };

                let pattern: &[u8; T::LANES] = unsafe {
                    if pattern.len() < T::LANES {
                        // Zero-pad the short trailing window.
                        &pad_zeroes_slice_unchecked::<{ T::LANES }>(pattern)
                    } else {
                        // SAFETY: as above, exactly T::LANES bytes.
                        &*pattern.as_ptr().cast::<[u8; T::LANES]>()
                    }
                };
                f(haystack, pattern, mask.to_u64())
            })
    }
}
542
543pub(crate) mod multiversion;
544pub(crate) mod utils;