Skip to main content

pipa/regexp/
string_search.rs

1use crate::util::memchr::memmem;
2
/// Precompiled Boyer–Moore–Horspool searcher for a fixed byte pattern.
///
/// Build it once with [`StringSearcher::new`] and reuse it to search many
/// haystacks without rebuilding the skip table.
#[derive(Clone, Debug)]
pub struct StringSearcher {
    /// The pattern bytes (the UTF-8 encoding of the `&str` given to `new`).
    pattern: Vec<u8>,

    /// Bad-character shift table indexed by a haystack byte: how far the
    /// search window advances when that byte sits under the window's last
    /// position after a mismatch.
    skip_table: [usize; 256],

    /// Cached length of `pattern` in bytes.
    pattern_len: usize,
}
11
12impl StringSearcher {
13    pub fn new(pattern: &str) -> Option<Self> {
14        let pattern = pattern.as_bytes();
15        let pattern_len = pattern.len();
16
17        if pattern_len == 0 {
18            return None;
19        }
20
21        if pattern_len < 3 {
22            return None;
23        }
24
25        let mut skip_table = [pattern_len; 256];
26
27        for i in 0..pattern_len - 1 {
28            skip_table[pattern[i] as usize] = pattern_len - 1 - i;
29        }
30
31        Some(Self {
32            pattern: pattern.to_vec(),
33            skip_table,
34            pattern_len,
35        })
36    }
37
38    #[inline(always)]
39    pub fn find(&self, text: &str) -> Option<usize> {
40        let text = text.as_bytes();
41        let text_len = text.len();
42        let pat = &self.pattern;
43        let pat_len = self.pattern_len;
44
45        if pat_len > text_len {
46            return None;
47        }
48
49        let mut pos = 0;
50        let max_pos = text_len - pat_len;
51
52        while pos <= max_pos {
53            let mut i = pat_len - 1;
54
55            while text[pos + i] == pat[i] {
56                if i == 0 {
57                    return Some(pos);
58                }
59                i -= 1;
60            }
61
62            let bad_char = text[pos + pat_len - 1];
63            pos += self.skip_table[bad_char as usize];
64        }
65
66        None
67    }
68
69    pub fn find_all(&self, text: &str) -> Vec<(usize, usize)> {
70        let mut matches = Vec::new();
71        let mut pos = 0;
72
73        while let Some(m) = self.find(&text[pos..]) {
74            let abs_pos = pos + m;
75            matches.push((abs_pos, abs_pos + self.pattern_len));
76            pos = abs_pos + self.pattern_len;
77
78            if pos >= text.len() {
79                break;
80            }
81        }
82
83        matches
84    }
85}
86
87pub fn fast_find(haystack: &str, needle: &str) -> Option<usize> {
88    if needle.is_empty() {
89        return Some(0);
90    }
91
92    if needle.len() > haystack.len() {
93        return None;
94    }
95
96    if needle.len() < 4 {
97        return memmem::find(haystack.as_bytes(), needle.as_bytes());
98    }
99
100    if let Some(searcher) = StringSearcher::new(needle) {
101        searcher.find(haystack)
102    } else {
103        memmem::find(haystack.as_bytes(), needle.as_bytes())
104    }
105}
106
107pub fn fast_find_with(searcher: &StringSearcher, haystack: &str) -> Option<usize> {
108    searcher.find(haystack)
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bmh_basic() {
        let searcher = StringSearcher::new("world").unwrap();
        assert_eq!(searcher.find("hello world"), Some(6));
    }

    #[test]
    fn test_bmh_not_found() {
        let searcher = StringSearcher::new("xyz").unwrap();
        assert_eq!(searcher.find("hello world"), None);
    }

    #[test]
    fn test_bmh_multiple() {
        let searcher = StringSearcher::new("abc").unwrap();
        assert_eq!(searcher.find("abcabcabc"), Some(0));
    }

    #[test]
    fn test_bmh_long_pattern() {
        let text = "The quick brown fox jumps over the lazy dog. ";
        let pattern = "jumps over";
        let searcher = StringSearcher::new(pattern).unwrap();
        assert_eq!(searcher.find(text), Some(20));
    }

    #[test]
    fn test_new_rejects_short_patterns() {
        assert!(StringSearcher::new("").is_none());
        assert!(StringSearcher::new("a").is_none());
        assert!(StringSearcher::new("ab").is_none());
        assert!(StringSearcher::new("abc").is_some());
    }

    #[test]
    fn test_bmh_match_at_end() {
        // Exercises several skip-table shifts before the final window matches.
        let searcher = StringSearcher::new("abcd").unwrap();
        assert_eq!(searcher.find("xxxxxabcd"), Some(5));
    }

    #[test]
    fn test_bmh_pattern_longer_than_text() {
        let searcher = StringSearcher::new("abcdef").unwrap();
        assert_eq!(searcher.find("abc"), None);
    }

    #[test]
    fn test_bmh_whole_text() {
        let searcher = StringSearcher::new("abc").unwrap();
        assert_eq!(searcher.find("abc"), Some(0));
    }

    #[test]
    fn test_bmh_non_ascii_byte_offset() {
        // Offsets are byte positions: 'é' occupies two bytes in UTF-8.
        let searcher = StringSearcher::new("wörld").unwrap();
        assert_eq!(searcher.find("héllo wörld"), Some(7));
    }

    #[test]
    fn test_fast_find_short() {
        assert_eq!(fast_find("hello", "ll"), Some(2));
        assert_eq!(fast_find("hello", "abc"), None);
    }

    #[test]
    fn test_fast_find_long() {
        let text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. ";
        assert_eq!(fast_find(text, "consectetur"), Some(28));
    }

    #[test]
    fn test_fast_find_edge_cases() {
        assert_eq!(fast_find("hello", ""), Some(0));
        assert_eq!(fast_find("", ""), Some(0));
        assert_eq!(fast_find("abc", "abcd"), None);
    }

    #[test]
    fn test_find_all() {
        let searcher = StringSearcher::new("abc").unwrap();
        let matches = searcher.find_all("abcxabcxabc");
        assert_eq!(matches, vec![(0, 3), (4, 7), (8, 11)]);
    }

    #[test]
    fn test_find_all_non_overlapping() {
        // After a match the scan resumes at its end, so overlaps are skipped.
        let searcher = StringSearcher::new("aaa").unwrap();
        assert_eq!(searcher.find_all("aaaaaa"), vec![(0, 3), (3, 6)]);
        assert_eq!(searcher.find_all("aaaaa"), vec![(0, 3)]);
    }

    #[test]
    fn test_find_all_no_match() {
        let searcher = StringSearcher::new("xyz").unwrap();
        assert!(searcher.find_all("hello world").is_empty());
    }
}