1#![allow(clippy::result_large_err)]
2
3use regex_automata::{
4 nfa::thompson::{BuildError, State, NFA},
5 util::{look::Look, primitives::StateID},
6};
7use tinyvec::TinyVec;
8
/// Per-state cursor recording which byte(s) to try next when expanding an NFA state.
///
/// `ByteRange` and `Dense` states use a single entry (the next byte to try). `Sparse` states use
/// one entry per transition. Entries are `u16` rather than `u8` so a cursor can be stepped one
/// past a transition that ends at byte `255` without overflowing.
type SearchRange = TinyVec<[u16; 12]>;
14
15pub struct NfaIter {
22 pub(crate) regex: NFA,
24 start: StateID,
26 start_range: SearchRange,
27 depth: usize,
29 max_depth: usize,
31 stack: Vec<(StateID, SearchRange, usize, usize)>,
34 str: Vec<u8>,
36}
37
38impl From<NFA> for NfaIter {
39 fn from(nfa: NFA) -> Self {
40 let start = nfa.start_anchored();
43 let start_range = range_for(nfa.state(start));
44
45 Self {
46 regex: nfa,
47 stack: vec![(start, start_range.clone(), 0, 0)],
48 start,
49 start_range,
50 depth: 0,
51 max_depth: 0,
52 str: vec![],
53 }
54 }
55}
56
57fn range_for(s: &State) -> SearchRange {
58 match s {
59 State::ByteRange { trans } => tinyvec::tiny_vec![trans.start as u16],
60 State::Sparse(s) => s
61 .transitions
62 .iter()
63 .map(|trans| trans.start as u16)
64 .collect(),
65 State::Dense(_) => tinyvec::tiny_vec![0],
66 State::Look { .. } => tinyvec::tiny_vec![],
67 State::Union { .. } => tinyvec::tiny_vec![],
68 State::BinaryUnion { .. } => tinyvec::tiny_vec![],
69 State::Capture { .. } => tinyvec::tiny_vec![],
70 State::Fail => tinyvec::tiny_vec![],
71 State::Match { .. } => tinyvec::tiny_vec![],
72 }
73}
74
75impl NfaIter {
76 pub fn new(pattern: &str) -> Result<Self, BuildError> {
84 NFA::compiler().build(pattern).map(Self::from)
85 }
86
87 pub fn new_many<P: AsRef<str>>(patterns: &[P]) -> Result<Self, BuildError> {
95 NFA::compiler().build_many(patterns).map(Self::from)
96 }
97
98 fn range_for(&self, s: StateID) -> SearchRange {
99 range_for(self.regex.state(s))
100 }
101
102 pub fn borrow_next(&mut self) -> Option<&[u8]> {
104 loop {
105 let Some((current, range, byte_depth, depth)) = self.stack.pop() else {
106 if self.max_depth < self.depth {
108 break None;
109 }
110
111 self.depth += 1;
112 self.stack.clear();
113 self.stack.push((self.start, self.start_range.clone(), 0, 0));
114 continue;
115 };
116
117 self.max_depth = usize::max(self.max_depth, depth);
119 self.str.truncate(byte_depth);
120
121 let state = self.regex.state(current);
122
123 if depth < self.depth {
125 match state {
126 State::ByteRange { trans } => {
127 if (range[0] as u8) < trans.end {
129 self.stack.push((
130 current,
131 tinyvec::tiny_vec![range[0] + 1],
132 byte_depth,
133 depth,
134 ));
135 }
136 self.str.push(range[0] as u8);
137 self.stack.push((
138 trans.next,
139 self.range_for(trans.next),
140 byte_depth + 1,
141 depth + 1,
142 ));
143 }
144 State::Sparse(s) => {
145 for (i, &r) in range.iter().enumerate() {
146 let t = s.transitions[i];
147 if r <= t.end as u16 {
148 let mut new_range = range.clone();
150 new_range[i] += 1;
151 self.stack.push((current, new_range, byte_depth, depth));
152
153 self.str.push(r as u8);
154 self.stack.push((
156 t.next,
157 self.range_for(t.next),
158 byte_depth + 1,
159 depth + 1,
160 ));
161 break;
162 }
163 }
164 }
165 State::Dense(d) => {
166 if range[0] < 255 {
168 self.stack.push((
169 current,
170 tinyvec::tiny_vec![range[0] + 1],
171 byte_depth,
172 depth,
173 ));
174 }
175 self.str.push(range[0] as u8);
176 self.stack.push((
177 d.transitions[range[0] as usize],
178 self.range_for(d.transitions[range[0] as usize]),
179 byte_depth + 1,
180 depth + 1,
181 ));
182 }
183 State::Look { look, next } => {
184 let should = match look {
185 Look::Start if byte_depth == 0 => true,
186 Look::StartLF
187 if byte_depth == 0 || self.str[byte_depth - 1] == b'\n' =>
188 {
189 true
190 }
191 Look::StartCRLF
192 if byte_depth == 0
193 || self.str[byte_depth - 1] == b'\n'
194 || self.str[byte_depth - 1] == b'\r' =>
195 {
196 true
197 }
198 Look::End => true,
199 Look::EndLF => true,
200 Look::EndCRLF => true,
201 Look::WordAscii => todo!(),
202 Look::WordAsciiNegate => todo!(),
203 Look::WordUnicode => todo!(),
204 Look::WordUnicodeNegate => todo!(),
205 _ => false,
206 };
207 if should {
208 self.stack
209 .push((*next, self.range_for(*next), byte_depth, depth + 1));
210 }
211 }
212 State::Union { alternates } => {
213 for &alt in alternates.iter().rev() {
215 self.stack
216 .push((alt, self.range_for(alt), byte_depth, depth + 1));
217 }
218 }
219 State::BinaryUnion { alt1, alt2 } => {
220 for &alt in [alt1, alt2].into_iter().rev() {
222 self.stack
223 .push((alt, self.range_for(alt), byte_depth, depth + 1));
224 }
225 }
226 State::Capture { next, .. } => {
227 self.stack
229 .push((*next, self.range_for(*next), byte_depth, depth + 1));
230 }
231 State::Fail => {}
232 State::Match { .. } => {}
233 }
234 } else {
235 if matches!(state, State::Match { .. }) {
237 break Some(&self.str);
238 }
239 }
240 }
241 }
242}
243
244impl Iterator for NfaIter {
245 type Item = Vec<u8>;
246
247 fn next(&mut self) -> Option<Self::Item> {
248 self.borrow_next().map(ToOwned::to_owned)
249 }
250}
251
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Turns string literals into the owned byte strings yielded by `NfaIter`.
    fn owned(words: &[&str]) -> Vec<Vec<u8>> {
        words.iter().map(|word| word.as_bytes().to_vec()).collect()
    }

    #[test]
    fn set() {
        let produced: Vec<Vec<u8>> = NfaIter::new(r"b|(a)?|cc").unwrap().collect();
        assert_eq!(produced, owned(&["b", "", "cc", "a"]));
    }

    #[test]
    fn finite() {
        let nfa = NFA::new(r"[0-1]{4}-[0-1]{2}-[0-1]{2}").unwrap();
        let produced: HashSet<Vec<u8>> = NfaIter::from(nfa).collect();

        // Eight binary digits give 2^8 distinct strings, each 4 + 1 + 2 + 1 + 2 bytes long.
        assert_eq!(produced.len(), 256);
        assert!(produced.iter().all(|s| s.len() == 10));
    }

    #[test]
    fn repeated() {
        let nfa = NFA::new(r"a+(0|1)").unwrap();
        let produced: Vec<Vec<u8>> = NfaIter::from(nfa).take(20).collect();

        // One or more `a`s followed by one digit, shortest first: a0, a1, aa0, aa1, ...
        let expected: Vec<Vec<u8>> = (1..=10)
            .flat_map(|n| ["0", "1"].map(|digit| format!("{}{digit}", "a".repeat(n)).into_bytes()))
            .collect();
        assert_eq!(produced, expected);
    }

    #[test]
    fn complex() {
        let nfa = NFA::new(r"(a+|b+)*").unwrap();
        let produced: Vec<Vec<u8>> = NfaIter::from(nfa).take(13).collect();

        // The pattern is ambiguous, so strings reachable along several NFA paths (here `aa`
        // and `bb`) are yielded once per path.
        let expected = owned(&[
            "", "a", "b", "aa", "bb", "aaa", "bbb", "aaaa", "aa", "ab", "bbbb", "ba", "bb",
        ]);
        assert_eq!(produced, expected);
    }

    #[test]
    fn many() {
        let produced: Vec<Vec<u8>> = NfaIter::new_many(&["[0-1]+", "^[a-b]+"])
            .unwrap()
            .take(12)
            .collect();
        let expected = owned(&[
            "0", "1", "a", "b", "00", "01", "10", "11", "aa", "ab", "ba", "bb",
        ]);
        assert_eq!(produced, expected);
    }
}