twoway/
pcmp.rs

//! SSE4.2 (pcmpestri) accelerated substring search
//!
//! Using the two way substring search algorithm.
// wssm: word size string matching
// wslm: word size lexicographical maximum suffix
//
7
8#![allow(dead_code)]
9
10extern crate unchecked_index;
11extern crate memchr;
12
13use std::cmp;
14use std::mem;
15use std::iter::Zip;
16
17use self::unchecked_index::get_unchecked;
18
19use TwoWaySearcher;
20
/// Pair up the items of two iterables.
///
/// Free-function form of `Iterator::zip`: both arguments are converted with
/// `IntoIterator`, so references to slices can be passed directly. Iteration
/// stops as soon as either side is exhausted.
fn zip<I, J>(i: I, j: J) -> Zip<I::IntoIter, J::IntoIter>
where
    I: IntoIterator,
    J: IntoIterator,
{
    Iterator::zip(i.into_iter(), j)
}
27
28#[cfg(target_arch = "x86")]
29use std::arch::x86::*;
30
31#[cfg(target_arch = "x86_64")]
32use std::arch::x86_64::*;
33
34/// `pcmpestri`
35///
36/// “Packed compare explicit length strings (return index)”
37///
38/// PCMPESTRI xmm1, xmm2/m128, imm8
39///
40/// Return value: least index for start of (partial) match, (16 if no match).
41///
42/// Mask: `text` can be at at any point in valid memory, as long as `text_len`
43/// bytes are readable.
44#[target_feature(enable = "sse4.2")]
45unsafe fn pcmpestri_16_mask(text: *const u8, offset: usize, text_len: usize,
46                       needle: __m128i, needle_len: usize) -> u32 {
47    //debug_assert!(text_len + offset <= text.len()); // saturates at 16
48    //debug_assert!(needle_len <= 16); // saturates at 16
49    let text = mask_load(text.offset(offset as _) as *const _, text_len);
50    _mm_cmpestri(needle, needle_len as _, text, text_len as _, _SIDD_CMP_EQUAL_ORDERED) as _
51}
52
/// `pcmpestri` with a plain unaligned 16-byte load of the text.
///
/// “Packed compare explicit length strings (return index)”
///
/// PCMPESTRI xmm1, xmm2/m128, imm8
///
/// Return value: least index for start of (partial) match, (16 if no match).
///
/// No mask: `text` must be at least 16 bytes from the end of a memory region,
/// because a full 16 bytes are loaded from `text + offset` regardless of
/// `text_len`.
///
/// # Safety
///
/// The CPU must support SSE4.2, and `text + offset` must be valid for reads
/// of 16 bytes.
#[target_feature(enable = "sse4.2")]
unsafe fn pcmpestri_16_nomask(text: *const u8, offset: usize, text_len: usize,
                              needle: __m128i, needle_len: usize) -> u32 {
    // Both explicit lengths saturate at 16 inside the instruction.
    let start = text.offset(offset as isize) as *const __m128i;
    let haystack = _mm_loadu_si128(start);
    let index = _mm_cmpestri(needle, needle_len as i32, haystack, text_len as i32,
                             _SIDD_CMP_EQUAL_ORDERED);
    index as u32
}
70
/// `pcmpestrm` in "equal each" mode: byte-wise comparison of two 16-byte
/// blocks.
///
/// “Packed compare explicit length strings (return mask)”
///
/// PCMPESTRM xmm1, xmm2/m128, imm8
///
/// Return value: bitmask in the 16 lsb of the return value.
///
/// # Safety
///
/// The CPU must support SSE4.2. Both `text + offset` and `needle + noffset`
/// *must* be readable for 16 bytes, whatever `text_len` and `needle_len`
/// are (the lengths saturate at 16 inside the instruction).
#[target_feature(enable = "sse4.2")]
unsafe fn pcmpestrm_eq_each(text: *const u8, offset: usize, text_len: usize,
                            needle: *const u8, noffset: usize, needle_len: usize) -> u64 {
    let needle_ptr = needle.offset(noffset as isize) as *const __m128i;
    let text_ptr = text.offset(offset as isize) as *const __m128i;
    let needle_block = _mm_loadu_si128(needle_ptr);
    let text_block = _mm_loadu_si128(text_ptr);
    let mask = _mm_cmpestrm(needle_block, needle_len as i32, text_block, text_len as i32,
                            _SIDD_CMP_EQUAL_EACH);

    // 32-bit targets: assemble the low 64 bits from two 32-bit lanes.
    #[cfg(target_arch = "x86")] {
        let low = _mm_extract_epi32(mask, 0) as u64;
        let high = (_mm_extract_epi32(mask, 1) as u64) << 32;
        low | high
    }

    #[cfg(target_arch = "x86_64")] {
        _mm_extract_epi64(mask, 0) as u64
    }
}
97
98
/// Search for first possible match of `pat` -- might be just a byte.
///
/// Returns `(pos, length)`, where `length` is the number of pattern bytes
/// that are known to match at `pos`.
///
/// Test-only convenience wrapper around `first_start_of_match_mask`; the
/// pattern must be at most 16 bytes long.
///
/// NOTE(review): this does not check `is_supported()` before running SSE4.2
/// code.
#[cfg(test)]
fn first_start_of_match(text: &[u8], pat: &[u8]) -> Option<(usize, usize)> {
    let patl = pat.len();
    assert!(patl <= 16);
    let needle = pat128(pat);
    // The masked search only copies bytes that lie inside `text`.
    unsafe { first_start_of_match_mask(text, patl, needle) }
}
108
109/// Safe wrapper around pcmpestri to find first match of `p` in `text`.
110/// safe to search unaligned for first start of match
111///
112/// the end of text an be close (within 16 bytes) of a page boundary
113#[target_feature(enable = "sse4.2")]
114unsafe fn first_start_of_match_mask(text: &[u8], pat_len: usize, p: __m128i) -> Option<(usize, usize)> {
115    let tp = text.as_ptr();
116    debug_assert!(pat_len <= 16);
117
118    let mut offset = 0;
119
120    while text.len() >= offset + pat_len {
121        let tlen = text.len() - offset;
122        let ret = pcmpestri_16_mask(tp, offset, tlen, p, pat_len) as usize;
123        if ret == 16 {
124            offset += 16;
125        } else {
126            let match_len = cmp::min(pat_len, 16 - ret);
127            return Some((offset + ret, match_len));
128        }
129    }
130
131    None
132}
133
134
135/// Safe wrapper around pcmpestri to find first match of `p` in `text`.
136/// safe to search unaligned for first start of match
137///
138/// unsafe because the end of text must not be close (within 16 bytes) of a page boundary
139#[target_feature(enable = "sse4.2")]
140unsafe fn first_start_of_match_nomask(text: &[u8], pat_len: usize, p: __m128i) -> Option<(usize, usize)> {
141    let tp = text.as_ptr();
142    debug_assert!(pat_len <= 16);
143    debug_assert!(pat_len <= text.len());
144
145    let mut offset = 0;
146
147    while text.len() - pat_len >= offset {
148        let tlen = text.len() - offset;
149        let ret = pcmpestri_16_nomask(tp, offset, tlen, p, pat_len) as usize;
150        if ret == 16 {
151            offset += 16;
152        } else {
153            let match_len = cmp::min(pat_len, 16 - ret);
154            return Some((offset + ret, match_len));
155        }
156    }
157
158    None
159}
160
/// Checks `first_start_of_match` on a few fixed cases and on every window
/// (1 to 16 bytes) of a longer text.
#[test]
fn test_first_start_of_match() {
    let text = b"abc";
    let longer = "longer text and so on";

    // (pattern, expected result) for the short text
    let fixed_cases: [(&[u8], Option<(usize, usize)>); 6] = [
        (&b"d"[..], None),
        (&b"c"[..], Some((2, 1))),
        (&b"abc"[..], Some((0, 3))),
        (&b"T"[..], None),
        (&b"\0text"[..], None),
        (&b"\0"[..], None),
    ];
    for &(pat, expected) in fixed_cases.iter() {
        assert_eq!(first_start_of_match(text, pat), expected);
    }

    // test all windows
    let haystack = longer.as_bytes();
    for wsz in 1..17 {
        for window in haystack.windows(wsz) {
            let needle = ::std::str::from_utf8(window).unwrap();
            let str_find = longer.find(needle);
            assert!(str_find.is_some());
            let first_start = first_start_of_match(haystack, window);
            assert!(first_start.is_some());
            let (pos, len) = first_start.unwrap();
            assert!(len <= wsz);
            // Either the full match is reported at the right place, or a
            // partial match is reported no later than the real one.
            assert!(len == wsz && Some(pos) == str_find
                    || pos <= str_find.unwrap());
        }
    }
}
186
187fn find_2byte_pat(text: &[u8], pat: &[u8]) -> Option<(usize, usize)> {
188    debug_assert!(text.len() >= pat.len());
189    debug_assert!(pat.len() == 2);
190    // Search for the second byte of the pattern, not the first, better for
191    // scripts where we have two-byte encoded codepoints (the first byte will
192    // repeat much more often than the second).
193    let mut off = 1;
194    while let Some(i) = memchr::memchr(pat[1], &text[off..]) {
195        match text.get(off + i - 1) {
196            None => break,
197            Some(&c) if c == pat[0] => return Some((off + i - 1, off + i + 1)),
198            _ => off += i + 1,
199        }
200
201    }
202    None
203}
204
205/// Simd text search optimized for short patterns (<= 8 bytes)
206#[target_feature(enable = "sse4.2")]
207unsafe fn find_short_pat(text: &[u8], pat: &[u8]) -> Option<usize> {
208    debug_assert!(pat.len() <= 8);
209    /*
210    if pat.len() == 2 {
211        return find_2byte_pat(text, pat);
212    }
213    */
214    let r = pat128(pat);
215
216    // safe part of text -- everything but the last 16 bytes
217    let safetext = &text[..cmp::max(text.len(), 16) - 16];
218
219    let mut pos = 0;
220    'search: loop {
221        if pos + pat.len() > safetext.len() {
222            break;
223        }
224        // find the next occurence
225        match first_start_of_match_nomask(&safetext[pos..], pat.len(), r) {
226            None => {
227                pos = cmp::max(pos, safetext.len() - pat.len());
228                break // no matches
229            }
230            Some((mpos, mlen)) => {
231                pos += mpos;
232                if mlen < pat.len() {
233                    if pos > text.len() - pat.len() {
234                        return None;
235                    }
236                    for (&a, &b) in zip(&text[pos + mlen..], &pat[mlen..]) {
237                        if a != b {
238                            pos += 1;
239                            continue 'search;
240                        }
241                    }
242                }
243
244                return Some(pos);
245            }
246        }
247    }
248
249    'tail: loop {
250        if pos > text.len() - pat.len() {
251            return None;
252        }
253        // find the next occurence
254        match first_start_of_match_mask(&text[pos..], pat.len(), r) {
255            None => return None, // no matches
256            Some((mpos, mlen)) => {
257                pos += mpos;
258                if mlen < pat.len() {
259                    if pos > text.len() - pat.len() {
260                        return None;
261                    }
262                    for (&a, &b) in zip(&text[pos + mlen..], &pat[mlen..]) {
263                        if a != b {
264                            pos += 1;
265                            continue 'tail;
266                        }
267                    }
268                }
269
270                return Some(pos);
271            }
272        }
273    }
274}
275
/// `is_supported` checks whether necessary SSE 4.2 feature is supported on current CPU.
///
/// With the `use_std` crate feature the check is done at run time; without
/// it, the answer is whether SSE 4.2 was enabled at compile time.
pub fn is_supported() -> bool {
    #[cfg(feature = "use_std")]
    let supported = is_x86_feature_detected!("sse4.2");
    #[cfg(not(feature = "use_std"))]
    let supported = cfg!(target_feature = "sse4.2");
    supported
}
283
284/// `find` finds the first ocurrence of `pattern` in the `text`.
285///
286/// This is the SSE42 accelerated version.
287pub fn find(text: &[u8], pattern: &[u8]) -> Option<usize> {
288    assert!(is_supported());
289
290    if pattern.is_empty() {
291        return Some(0);
292    } else if text.len() < pattern.len() {
293        return None;
294    } else if pattern.len() == 1 {
295        return memchr::memchr(pattern[0], text);
296    } else {
297        unsafe { find_inner(text, pattern) }
298    }
299}
300
301#[target_feature(enable = "sse4.2")]
302pub(crate) unsafe fn find_inner(text: &[u8], pat: &[u8]) -> Option<usize> {
303    if pat.len() <= 6 {
304        return find_short_pat(text, pat);
305    }
306
307    // real two way algorithm
308    //
309
310    // `memory` is the number of bytes of the left half that we already know
311    let (crit_pos, mut period) = TwoWaySearcher::crit_params(pat);
312    let mut memory;
313
314    if &pat[..crit_pos] == &pat[period.. period + crit_pos] {
315        memory = 0; // use memory
316    } else {
317        memory = !0; // !0 means memory is unused
318        // approximation to the true period
319        period = cmp::max(crit_pos, pat.len() - crit_pos) + 1;
320    }
321
322    //println!("pat: {:?}, crit={}, period={}", pat, crit_pos, period);
323    let (left, right) = pat.split_at(crit_pos);
324    let (right16, _right17) = right.split_at(cmp::min(16, right.len()));
325    assert!(right.len() != 0);
326
327    let r = pat128(right);
328
329    // safe part of text -- everything but the last 16 bytes
330    let safetext = &text[..cmp::max(text.len(), 16) - 16];
331
332    let mut pos = 0;
333    if memory == !0 {
334        // Long period case -- no memory, period is an approximation
335        'search: loop {
336            if pos + pat.len() > safetext.len() {
337                break;
338            }
339            // find the next occurence of the right half
340            let start = crit_pos;
341            match first_start_of_match_nomask(&safetext[pos + start..], right16.len(), r) {
342                None => {
343                    pos = cmp::max(pos, safetext.len() - pat.len());
344                    break // no matches
345                }
346                Some((mpos, mlen)) => {
347                    pos += mpos;
348                    let mut pfxlen = mlen;
349                    if pfxlen < right.len() {
350                        pfxlen += shared_prefix_inner(&text[pos + start + mlen..], &right[mlen..]);
351                    }
352                    if pfxlen != right.len() {
353                        // partial match
354                        // skip by the number of bytes matched
355                        pos += pfxlen + 1;
356                        continue 'search;
357                    } else {
358                        // matches right part
359                    }
360                }
361            }
362
363            // See if the left part of the needle matches
364            // XXX: Original algorithm compares from right to left here
365            if left != &text[pos..pos + left.len()] {
366                pos += period;
367                continue 'search;
368            }
369
370            return Some(pos);
371        }
372    } else {
373        // Short period case -- use memory, true period
374        'search_memory: loop {
375            if pos + pat.len() > safetext.len() {
376                break;
377            }
378            // find the next occurence of the right half
379            //println!("memory trace pos={}, memory={}", pos, memory);
380            let mut pfxlen = if memory == 0 {
381                let start = crit_pos;
382                match first_start_of_match_nomask(&safetext[pos + start..], right16.len(), r) {
383                    None => {
384                        pos = cmp::max(pos, safetext.len() - pat.len());
385                        break // no matches
386                    }
387                    Some((mpos, mlen)) => {
388                        pos += mpos;
389                        mlen
390                    }
391                }
392            } else {
393                memory - crit_pos
394            };
395            if pfxlen < right.len() {
396                pfxlen += shared_prefix_inner(&text[pos + crit_pos + pfxlen..], &right[pfxlen..]);
397            }
398            if pfxlen != right.len() {
399                // partial match
400                // skip by the number of bytes matched
401                pos += pfxlen + 1;
402                memory = 0;
403                continue 'search_memory;
404            } else {
405                // matches right part
406            }
407
408            // See if the left part of the needle matches
409            // XXX: Original algorithm compares from right to left here
410            if memory <= left.len() && &left[memory..] != &text[pos + memory..pos + left.len()] {
411                pos += period;
412                memory = pat.len() - period;
413                continue 'search_memory;
414            }
415
416            return Some(pos);
417        }
418    }
419
420    // no memory used for final part
421    'tail: loop {
422        if pos > text.len() - pat.len() {
423            return None;
424        }
425        // find the next occurence of the right half
426        let start = crit_pos;
427        match first_start_of_match_mask(&text[pos + start..], right16.len(), r) {
428            None => return None,
429            Some((mpos, mlen)) => {
430                pos += mpos;
431                let mut pfxlen = mlen;
432                if pfxlen < right.len() {
433                    pfxlen += shared_prefix_inner(&text[pos + start + mlen..], &right[mlen..]);
434                }
435                if pfxlen != right.len() {
436                    // partial match
437                    // skip by the number of bytes matched
438                    pos += pfxlen + 1;
439                    continue 'tail;
440
441                } else {
442                    // matches right part
443                }
444            }
445        }
446
447        // See if the left part of the needle matches
448        // XXX: Original algorithm compares from right to left here
449        if left != &text[pos..pos + left.len()] {
450            pos += period;
451            continue 'tail;
452        }
453
454        return Some(pos);
455    }
456}
457
/// Compares `find` against `str::find` on fixed cases, on every window of a
/// longer text, and on periodic patterns.
#[test]
fn test_find() {
    let text = b"abc";
    assert_eq!(find(text, b"d"), None);
    assert_eq!(find(text, b"c"), Some(2));

    let longer = "longer text and so on, a bit more";
    let haystack = longer.as_bytes();

    // test all windows
    for wsz in 1..longer.len() {
        for window in haystack.windows(wsz) {
            let window_str = ::std::str::from_utf8(window);
            let str_find = longer.find(window_str.unwrap());
            assert!(str_find.is_some());
            assert_eq!(find(haystack, window), str_find, "{:?} {:?}",
                       longer, window_str);
        }
    }

    // a pattern that does not fit in one 16-byte vector
    let pat = b"ger text and so on";
    assert!(pat.len() > 16);
    assert_eq!(Some(3), find(haystack, pat));

    let periodic_cases = [
        // test short period case
        ("cbabababcbabababab", "abababab"),
        // memoized case -- this is tricky
        ("cbababababababababababababababab", "abababab"),
    ];
    for &(text, n) in periodic_cases.iter() {
        assert_eq!(text.find(n), find(text.as_bytes(), n.as_bytes()));
    }
}
492
493/// Load the first 16 bytes of `pat` into a SIMD vector.
494#[inline(always)]
495fn pat128(pat: &[u8]) -> __m128i {
496    unsafe {
497        mask_load(pat.as_ptr() as *const _, pat.len())
498    }
499}
500
/// Load the first `len` bytes (maximum 16) from `ptr` into a vector, safely.
///
/// The remaining bytes of the vector are zero. No more than
/// `min(len, 16)` bytes are read from `ptr`.
///
/// # Safety
///
/// `ptr` must be valid for reads of `min(len, 16)` bytes.
#[inline(always)]
unsafe fn mask_load(ptr: *const u8, len: usize) -> __m128i {
    let mut vector = _mm_setzero_si128();
    let count = cmp::min(len, mem::size_of::<__m128i>());
    let dst = &mut vector as *mut __m128i as *mut u8;

    // Overwrite the leading bytes of the zeroed vector with the input.
    ::std::ptr::copy_nonoverlapping(ptr, dst, count);
    vector
}
510
511/// Find longest shared prefix, return its length
512///
513/// Alignment safe: works for any text, pat.
514pub fn shared_prefix(text: &[u8], pat: &[u8]) -> usize {
515    assert!(is_supported());
516
517    unsafe { shared_prefix_inner(text, pat) }
518}
519
520#[target_feature(enable = "sse4.2")]
521unsafe fn shared_prefix_inner(text: &[u8], pat: &[u8]) -> usize {
522    let tp = text.as_ptr();
523    let tlen = text.len();
524    let pp = pat.as_ptr();
525    let plen = pat.len();
526    let len = cmp::min(tlen, plen);
527
528    // TODO: do non-aligned prefix manually too(?) aligned text or pat..
529    // all but the end we can process with pcmpestrm
530    let initial_part = len.saturating_sub(16);
531    let mut prefix_len = 0;
532    let mut offset = 0;
533    while offset < initial_part {
534        let initial_tail = initial_part - offset;
535        let mask = pcmpestrm_eq_each(tp, offset, initial_tail, pp, offset, initial_tail);
536        // find zero in the first 16 bits
537        if mask != 0xffff {
538            let first_bit_set = (mask ^ 0xffff).trailing_zeros() as usize;
539            prefix_len += first_bit_set;
540            return prefix_len;
541        } else {
542            prefix_len += cmp::min(initial_tail, 16);
543        }
544        offset += 16;
545    }
546    // so one block left, the last (up to) 16 bytes
547    // unchecked slicing .. we don't want panics in this function
548    let text_suffix = get_unchecked(text, prefix_len..len);
549    let pat_suffix = get_unchecked(pat, prefix_len..len);
550    for (&a, &b) in zip(text_suffix, pat_suffix) {
551        if a != b {
552            break;
553        }
554        prefix_len += 1;
555    }
556
557    prefix_len
558}
559
/// Checks `shared_prefix` on short inputs, inputs longer than one 16-byte
/// block, every suffix of a text against itself, and a long run with a
/// single late mismatch.
#[test]
fn test_prefixlen() {
    let text_long  = b"0123456789abcdefeffect";
    let text_long2 = b"9123456789abcdefeffect";
    let text_long3 = b"0123456789abcdefgffect";

    assert_eq!(shared_prefix(text_long, text_long), text_long.len());
    assert_eq!(shared_prefix(b"abcd", b"abc"), 3);
    assert_eq!(shared_prefix(b"abcd", b"abcf"), 3);
    assert_eq!(0, shared_prefix(text_long, text_long2));
    assert_eq!(0, shared_prefix(text_long, &text_long[1..]));
    assert_eq!(16, shared_prefix(text_long, text_long3));

    // every suffix shares its whole length with itself
    for i in 0..=text_long.len() {
        let suffix = &text_long[i..];
        assert_eq!(text_long.len() - i, shared_prefix(suffix, suffix));
    }

    // a single mismatch at index `off`, approached from every start offset
    let off = 1000;
    let l1 = [7u8; 1024];
    let mut l2 = [7u8; 1024];
    l2[off] = 0;
    for i in 0..off {
        assert_eq!(shared_prefix(&l1[i..], &l2[i..]), off - i);
    }
}