regex_utils/
nfa.rs

1#![allow(clippy::result_large_err)]
2
3use regex_automata::{
4    nfa::thompson::{BuildError, State, NFA},
5    util::{look::Look, primitives::StateID},
6};
7use tinyvec::TinyVec;
8
9/// For Look/Union/BinaryUnion/Capture/Fail/Match: meaningless (should be empty)
10/// For ByteRange: indicates the current byte
11/// For Sparse: indicates the current byte for each ByteRange
12/// For Dense: indicates the current byte (0..=255)
13type SearchRange = TinyVec<[u16; 12]>;
14
15/// `NfaIter` will produce every possible string value that will match with the given nfa regex.
16///
17/// # Note
18///
19/// Regexes can be infinite (eg `a*`). Either use this iterator lazily, or limit the number
20/// of iterations.
21pub struct NfaIter {
22    // the graph to search
23    pub(crate) regex: NFA,
24    // the start node of the graph
25    start: StateID,
26    start_range: SearchRange,
27    // the max depth we currently want to search
28    depth: usize,
29    // the max depth observed in the graph
30    max_depth: usize,
31    // (state, search_range, byte depth, search depth)
32    // the search_range is used differently depending on what state we are exploring
33    stack: Vec<(StateID, SearchRange, usize, usize)>,
34    // the current path
35    str: Vec<u8>,
36}
37
38impl From<NFA> for NfaIter {
39    fn from(nfa: NFA) -> Self {
40        // anchored because if we didn't anchor our search we would have an infinite amount of prefixes that were valid
41        // and that isn't very interesting
42        let start = nfa.start_anchored();
43        let start_range = range_for(nfa.state(start));
44
45        Self {
46            regex: nfa,
47            stack: vec![(start, start_range.clone(), 0, 0)],
48            start,
49            start_range,
50            depth: 0,
51            max_depth: 0,
52            str: vec![],
53        }
54    }
55}
56
57fn range_for(s: &State) -> SearchRange {
58    match s {
59        State::ByteRange { trans } => tinyvec::tiny_vec![trans.start as u16],
60        State::Sparse(s) => s
61            .transitions
62            .iter()
63            .map(|trans| trans.start as u16)
64            .collect(),
65        State::Dense(_) => tinyvec::tiny_vec![0],
66        State::Look { .. } => tinyvec::tiny_vec![],
67        State::Union { .. } => tinyvec::tiny_vec![],
68        State::BinaryUnion { .. } => tinyvec::tiny_vec![],
69        State::Capture { .. } => tinyvec::tiny_vec![],
70        State::Fail => tinyvec::tiny_vec![],
71        State::Match { .. } => tinyvec::tiny_vec![],
72    }
73}
74
75impl NfaIter {
76    /// Parse the given regular expression using a default configuration and
77    /// return the corresponding `NfaIter`.
78    ///
79    /// If you want a non-default configuration, then use the
80    /// [`thompson::Compiler`](regex_automata::nfa::thompson::Compiler) to set your own configuration.
81    ///
82    /// See [`NFA`] for details
83    pub fn new(pattern: &str) -> Result<Self, BuildError> {
84        NFA::compiler().build(pattern).map(Self::from)
85    }
86
87    /// Parse the given regular expressions using a default configuration and
88    /// return the corresponding multi-`NfaIter`.
89    ///
90    /// If you want a non-default configuration, then use the
91    /// [`thompson::Compiler`](regex_automata::nfa::thompson::Compiler) to set your own configuration.
92    ///
93    /// See [`NFA`] for details
94    pub fn new_many<P: AsRef<str>>(patterns: &[P]) -> Result<Self, BuildError> {
95        NFA::compiler().build_many(patterns).map(Self::from)
96    }
97
98    fn range_for(&self, s: StateID) -> SearchRange {
99        range_for(self.regex.state(s))
100    }
101
102    /// Get the next matching string ref from this regex iterator
103    pub fn borrow_next(&mut self) -> Option<&[u8]> {
104        loop {
105            let Some((current, range, byte_depth, depth)) = self.stack.pop() else {
106                // we didn't get any deeper. no more search space
107                if self.max_depth < self.depth {
108                    break None;
109                }
110
111                self.depth += 1;
112                self.stack.clear();
113                self.stack.push((self.start, self.start_range.clone(), 0, 0));
114                continue;
115            };
116
117            // update recorded max depth
118            self.max_depth = usize::max(self.max_depth, depth);
119            self.str.truncate(byte_depth);
120
121            let state = self.regex.state(current);
122
123            // check we can explore deeper
124            if depth < self.depth {
125                match state {
126                    State::ByteRange { trans } => {
127                        // make sure we revisit this state
128                        if (range[0] as u8) < trans.end {
129                            self.stack.push((
130                                current,
131                                tinyvec::tiny_vec![range[0] + 1],
132                                byte_depth,
133                                depth,
134                            ));
135                        }
136                        self.str.push(range[0] as u8);
137                        self.stack.push((
138                            trans.next,
139                            self.range_for(trans.next),
140                            byte_depth + 1,
141                            depth + 1,
142                        ));
143                    }
144                    State::Sparse(s) => {
145                        for (i, &r) in range.iter().enumerate() {
146                            let t = s.transitions[i];
147                            if r <= t.end as u16 {
148                                // make sure we revisit this state
149                                let mut new_range = range.clone();
150                                new_range[i] += 1;
151                                self.stack.push((current, new_range, byte_depth, depth));
152
153                                self.str.push(r as u8);
154                                // add the new state
155                                self.stack.push((
156                                    t.next,
157                                    self.range_for(t.next),
158                                    byte_depth + 1,
159                                    depth + 1,
160                                ));
161                                break;
162                            }
163                        }
164                    }
165                    State::Dense(d) => {
166                        // make sure we revisit this state
167                        if range[0] < 255 {
168                            self.stack.push((
169                                current,
170                                tinyvec::tiny_vec![range[0] + 1],
171                                byte_depth,
172                                depth,
173                            ));
174                        }
175                        self.str.push(range[0] as u8);
176                        self.stack.push((
177                            d.transitions[range[0] as usize],
178                            self.range_for(d.transitions[range[0] as usize]),
179                            byte_depth + 1,
180                            depth + 1,
181                        ));
182                    }
183                    State::Look { look, next } => {
184                        let should = match look {
185                            Look::Start if byte_depth == 0 => true,
186                            Look::StartLF
187                                if byte_depth == 0 || self.str[byte_depth - 1] == b'\n' =>
188                            {
189                                true
190                            }
191                            Look::StartCRLF
192                                if byte_depth == 0
193                                    || self.str[byte_depth - 1] == b'\n'
194                                    || self.str[byte_depth - 1] == b'\r' =>
195                            {
196                                true
197                            }
198                            Look::End => true,
199                            Look::EndLF => true,
200                            Look::EndCRLF => true,
201                            Look::WordAscii => todo!(),
202                            Look::WordAsciiNegate => todo!(),
203                            Look::WordUnicode => todo!(),
204                            Look::WordUnicodeNegate => todo!(),
205                            _ => false,
206                        };
207                        if should {
208                            self.stack
209                                .push((*next, self.range_for(*next), byte_depth, depth + 1));
210                        }
211                    }
212                    State::Union { alternates } => {
213                        // same byte_depth because we matched no bytes
214                        for &alt in alternates.iter().rev() {
215                            self.stack
216                                .push((alt, self.range_for(alt), byte_depth, depth + 1));
217                        }
218                    }
219                    State::BinaryUnion { alt1, alt2 } => {
220                        // same byte_depth because we matched no bytes
221                        for &alt in [alt1, alt2].into_iter().rev() {
222                            self.stack
223                                .push((alt, self.range_for(alt), byte_depth, depth + 1));
224                        }
225                    }
226                    State::Capture { next, .. } => {
227                        // same byte_depth because we matched no bytes
228                        self.stack
229                            .push((*next, self.range_for(*next), byte_depth, depth + 1));
230                    }
231                    State::Fail => {}
232                    State::Match { .. } => {}
233                }
234            } else {
235                // test that this state is final
236                if matches!(state, State::Match { .. }) {
237                    break Some(&self.str);
238                }
239            }
240        }
241    }
242}
243
244impl Iterator for NfaIter {
245    type Item = Vec<u8>;
246
247    fn next(&mut self) -> Option<Self::Item> {
248        self.borrow_next().map(ToOwned::to_owned)
249    }
250}
251
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Owned byte strings for the given literals, in order.
    fn bytes(strs: &[&str]) -> Vec<Vec<u8>> {
        strs.iter().map(|s| s.as_bytes().to_vec()).collect()
    }

    #[test]
    fn set() {
        let found: Vec<Vec<u8>> = NfaIter::new(r"b|(a)?|cc").unwrap().collect();
        assert_eq!(found, bytes(&["b", "", "cc", "a"]));
    }

    #[test]
    fn finite() {
        let nfa = NFA::new(r"[0-1]{4}-[0-1]{2}-[0-1]{2}").unwrap();

        // a finite regex has finite iteration depth; it yields all 2^8 distinct
        // strings, each 10 bytes long
        let found: HashSet<Vec<u8>> = NfaIter::from(nfa).collect();
        assert_eq!(found.len(), 256);
        assert!(found.iter().all(|s| s.len() == 10));
    }

    #[test]
    fn repeated() {
        let nfa = NFA::new(r"a+(0|1)").unwrap();

        // an infinite regex still reaches every case, shortest first
        let found: Vec<Vec<u8>> = NfaIter::from(nfa).take(20).collect();
        let expected: Vec<Vec<u8>> = (1..=10)
            .flat_map(|n| {
                let run = "a".repeat(n);
                [format!("{run}0"), format!("{run}1")]
            })
            .map(String::into_bytes)
            .collect();
        assert_eq!(found, expected);
    }

    #[test]
    fn complex() {
        let nfa = NFA::new(r"(a+|b+)*").unwrap();

        // infinite regex iterates over all cases
        let found: Vec<Vec<u8>> = NfaIter::from(nfa).take(13).collect();
        let expected = bytes(&[
            "", "a", "b", "aa", "bb", "aaa", "bbb", "aaaa",
            // technically a different path
            "aa", "ab", "bbbb", "ba",
            // technically a different path
            "bb",
        ]);
        assert_eq!(found, expected);
    }

    #[test]
    fn many() {
        let search = NfaIter::new_many(&["[0-1]+", "^[a-b]+"]).unwrap();
        let found: Vec<Vec<u8>> = search.take(12).collect();
        let expected = bytes(&[
            "0", "1", "a", "b", "00", "01", "10", "11", "aa", "ab", "ba", "bb",
        ]);
        assert_eq!(found, expected);
    }
}