1use std::fmt::Debug;
2
3use crate::HashSet;
4use anyhow::Result;
5
6use crate::{
7 ast::{ExprRef, ExprSet, NextByte},
8 bytecompress::ByteCompressor,
9 deriv::DerivCache,
10 hashcons::VecHashCons,
11 nextbyte::NextByteCache,
12 pp::PrettyPrinter,
13 relevance::RelevanceCache,
14};
15
/// Compile-time switch for the `debug!` tracing macro.
const DEBUG: bool = false;

/// Works like `eprintln!`, but only prints (and only evaluates its arguments)
/// when `DEBUG` is `true`.
macro_rules! debug {
    ($($tokens:tt)*) => {
        if DEBUG {
            eprintln!($($tokens)*);
        }
    };
}
25
/// Identifier of an automaton state.
///
/// The state number is stored in the upper 31 bits; bit 0 is a separate
/// "lowest match" flag (queried by `has_lowest_match`).
#[derive(Copy, Clone, Hash, PartialEq, Eq)]
pub struct StateID(u32);
28
29impl StateID {
30 pub const DEAD: StateID = StateID::new(0);
32 pub const MISSING: StateID = StateID::new(1);
34
35 pub fn as_usize(&self) -> usize {
36 (self.0 >> 1) as usize
37 }
38
39 pub fn as_u32(&self) -> u32 {
40 self.0 >> 1
41 }
42
43 pub fn is_valid(&self) -> bool {
44 *self != Self::MISSING
45 }
46
47 #[inline(always)]
48 pub fn is_dead(&self) -> bool {
49 *self == Self::DEAD
50 }
51
52 #[inline(always)]
53 pub fn has_lowest_match(&self) -> bool {
54 (self.0 & 1) == 1
55 }
56
57 pub fn _set_lowest_match(self) -> Self {
58 Self(self.0 | 1)
59 }
60
61 pub const fn new(id: u32) -> Self {
62 Self(id << 1)
63 }
64
65 pub fn new_hash_cons() -> VecHashCons {
66 let mut rx_sets = VecHashCons::new();
67 let id = rx_sets.insert(&[]);
68 assert!(id == StateID::DEAD.as_u32());
69 let id = rx_sets.insert(&[ExprRef::INVALID.as_u32()]);
70 assert!(id == StateID::MISSING.as_u32());
71 rx_sets
72 }
73}
74
75impl Debug for StateID {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 if *self == StateID::DEAD {
78 write!(f, "StateID(DEAD)")
79 } else if *self == StateID::MISSING {
80 write!(f, "StateID(MISSING)")
81 } else {
82 write!(f, "StateID({},{})", self.0 >> 1, self.0 & 1)
83 }
84 }
85}
86
/// Mapping from input bytes to alphabet symbols (byte classes).
#[derive(Clone)]
pub struct AlphabetInfo {
    /// Symbol assigned to each of the 256 possible byte values.
    mapping: [u8; 256],
    /// Number of distinct symbols, i.e. the width of one transition-table row.
    /// Zero means the alphabet is in its error state.
    size: usize,
}
92
93#[derive(Clone)]
94pub struct Regex {
95 exprs: ExprSet,
96 deriv: DerivCache,
97 next_byte: NextByteCache,
98 relevance: RelevanceCache,
99 alpha: AlphabetInfo,
100 initial: StateID,
101 rx_sets: VecHashCons,
102 state_table: Vec<StateID>,
103 state_descs: Vec<StateDesc>,
104 num_transitions: usize,
105 num_ast_nodes: usize,
106 max_states: usize,
107}
108
109#[derive(Clone, Debug, Default)]
110struct StateDesc {
111 lookahead_len: Option<Option<usize>>,
112 next_byte: Option<NextByte>,
113}
114
115impl Regex {
117 pub fn new(rx: &str) -> Result<Self> {
118 let parser = regex_syntax::ParserBuilder::new().build();
119 Self::new_with_parser(parser, rx)
120 }
121
122 pub fn new_with_parser(parser: regex_syntax::Parser, rx: &str) -> Result<Self> {
123 let mut exprset = ExprSet::new(256);
124 let rx = exprset.parse_expr(parser.clone(), rx, false)?;
125 Self::new_with_exprset(exprset, rx, u64::MAX)
126 }
127
128 pub fn alpha(&self) -> &AlphabetInfo {
129 &self.alpha
130 }
131
132 pub fn initial_state(&mut self) -> StateID {
133 self.initial
134 }
135
136 pub fn always_empty(&mut self) -> bool {
137 self.initial_state().is_dead()
138 }
139
140 pub fn is_accepting(&mut self, state: StateID) -> bool {
141 self.lookahead_len_for_state(state).is_some()
142 }
143
144 fn resolve(rx_sets: &VecHashCons, state: StateID) -> ExprRef {
145 ExprRef::new(rx_sets.get(state.as_u32())[0])
146 }
147
148 pub fn lookahead_len_for_state(&mut self, state: StateID) -> Option<usize> {
149 if state == StateID::DEAD || state == StateID::MISSING {
150 return None;
151 }
152 let desc = &mut self.state_descs[state.as_usize()];
153 if let Some(len) = desc.lookahead_len {
154 return len;
155 }
156 let expr = Self::resolve(&self.rx_sets, state);
157 let mut res = None;
158 if self.exprs.is_nullable(expr) {
159 res = Some(self.exprs.lookahead_len(expr).unwrap_or(0));
160 }
161 desc.lookahead_len = Some(res);
162 res
163 }
164
165 #[inline(always)]
166 pub fn transition(&mut self, state: StateID, b: u8) -> StateID {
167 let idx = self.alpha.map_state(state, b);
168 let new_state = self.state_table[idx];
169 if new_state != StateID::MISSING {
170 new_state
171 } else {
172 let new_state = self.transition_inner(state, b);
173 self.num_transitions += 1;
174 self.state_table[idx] = new_state;
175 new_state
176 }
177 }
178
179 pub fn transition_bytes(&mut self, state: StateID, bytes: &[u8]) -> StateID {
180 let mut state = state;
181 for &b in bytes {
182 state = self.transition(state, b);
183 }
184 state
185 }
186
187 pub fn is_match(&mut self, text: &str) -> bool {
188 self.lookahead_len(text).is_some()
189 }
190
191 pub fn is_match_bytes(&mut self, text: &[u8]) -> bool {
192 self.lookahead_len_bytes(text).is_some()
193 }
194
195 pub fn lookahead_len_bytes(&mut self, text: &[u8]) -> Option<usize> {
196 let mut state = self.initial_state();
197 for b in text {
198 let b = *b;
199 let new_state = self.transition(state, b);
200 debug!("b: {:?} --{:?}--> {:?}", state, b as char, new_state);
201 state = new_state;
202 if state == StateID::DEAD {
203 return None;
204 }
205 }
206 self.lookahead_len_for_state(state)
207 }
208
209 pub fn lookahead_len(&mut self, text: &str) -> Option<usize> {
210 self.lookahead_len_bytes(text.as_bytes())
211 }
212
213 pub fn num_bytes(&self) -> usize {
215 self.exprs.num_bytes()
216 + self.deriv.num_bytes()
217 + self.next_byte.num_bytes()
218 + self.state_descs.len() * 100
219 + self.state_table.len() * std::mem::size_of::<StateID>()
220 + self.rx_sets.num_bytes()
221 }
222
223 pub fn cost(&self) -> u64 {
224 self.exprs.cost()
225 }
226
227 pub fn next_byte(&mut self, state: StateID) -> NextByte {
230 if state == StateID::DEAD || state == StateID::MISSING {
231 return NextByte::Dead;
232 }
233
234 let desc = &mut self.state_descs[state.as_usize()];
235 if let Some(next_byte) = desc.next_byte {
236 return next_byte;
237 }
238
239 let e = Self::resolve(&self.rx_sets, state);
240 let next_byte = self.next_byte.next_byte(&self.exprs, e);
241 desc.next_byte = Some(next_byte);
242 next_byte
243 }
244
245 pub fn stats(&self) -> String {
246 format!(
247 "regexp: {} nodes (+ {} derived via {} derivatives), states: {}; transitions: {}; bytes: {}; alphabet size: {}",
248 self.num_ast_nodes,
249 self.exprs.len() - self.num_ast_nodes,
250 self.deriv.num_deriv,
251 self.state_descs.len(),
252 self.num_transitions,
253 self.num_bytes(),
254 self.alpha.len(),
255 )
256 }
257
258 pub fn dfa(&mut self) -> Vec<u8> {
259 let mut used = HashSet::default();
260 let mut designated_bytes = vec![];
261 for b in 0..=255 {
262 let m = self.alpha.map(b);
263 if !used.contains(&m) {
264 used.insert(m);
265 designated_bytes.push(b);
266 }
267 }
268
269 let mut stack = vec![self.initial_state()];
270 let mut visited = HashSet::default();
271 while let Some(state) = stack.pop() {
272 for b in &designated_bytes {
273 let new_state = self.transition(state, *b);
274 if !visited.contains(&new_state) {
275 stack.push(new_state);
276 visited.insert(new_state);
277 assert!(visited.len() < 250);
278 }
279 }
280 }
281
282 assert!(!self.state_table.contains(&StateID::MISSING));
283 let mut res = self.alpha.mapping.to_vec();
284 res.extend(self.state_table.iter().map(|s| s.as_u32() as u8));
285 res
286 }
287
288 pub fn print_state_table(&self) {
289 for (state, row) in self.state_table.chunks(self.alpha.len()).enumerate() {
290 println!("state: {}", state);
291 for (b, &new_state) in row.iter().enumerate() {
292 println!(" s{:?} -> {:?}", b, new_state);
293 }
294 }
295 }
296}
297
298impl AlphabetInfo {
299 pub fn from_exprset(exprset: ExprSet, rx_list: &[ExprRef]) -> (Self, ExprSet, Vec<ExprRef>) {
300 assert!(exprset.alphabet_size == 256);
301
302 debug!("rx0: {}", exprset.expr_to_string_with_info(rx_list[0]));
303
304 let ((mut exprset, rx_list), mapping, alphabet_size) = if cfg!(feature = "compress") {
305 let mut compressor = ByteCompressor::new();
306 let cost0 = exprset.cost;
307 let (mut exprset, rx_list) = compressor.compress(exprset, rx_list);
308 exprset.cost += cost0;
309 exprset.set_pp(PrettyPrinter::new(
310 compressor.mapping.clone(),
311 compressor.alphabet_size,
312 ));
313 (
314 (exprset, rx_list),
315 compressor.mapping,
316 compressor.alphabet_size,
317 )
318 } else {
319 let alphabet_size = exprset.alphabet_size;
320 (
321 (exprset, rx_list.to_vec()),
322 (0..=255).collect(),
323 alphabet_size,
324 )
325 };
326
327 exprset.disable_optimizations();
329
330 debug!(
331 "compressed: {}",
332 exprset.expr_to_string_with_info(rx_list[0])
333 );
334
335 let alpha = AlphabetInfo {
336 mapping: mapping.try_into().unwrap(),
337 size: alphabet_size,
338 };
339 (alpha, exprset, rx_list.to_vec())
340 }
341
342 #[inline(always)]
343 pub fn map(&self, b: u8) -> usize {
344 if cfg!(feature = "compress") {
345 self.mapping[b as usize] as usize
346 } else {
347 b as usize
348 }
349 }
350
351 #[inline(always)]
352 pub fn map_state(&self, state: StateID, b: u8) -> usize {
353 if cfg!(feature = "compress") {
354 self.map(b) + state.as_usize() * self.len()
355 } else {
356 b as usize + state.as_usize() * 256
357 }
358 }
359
360 #[inline(always)]
361 pub fn len(&self) -> usize {
362 self.size
363 }
364
365 #[inline(always)]
366 pub fn is_empty(&self) -> bool {
367 self.size == 0
368 }
369
370 pub fn has_error(&self) -> bool {
371 self.size == 0
372 }
373
374 pub fn enter_error_state(&mut self) {
375 self.size = 0;
376 }
377}
378
379impl Regex {
381 pub fn is_contained_in_prefixes(
382 exprset: ExprSet,
383 small: ExprRef,
384 big: ExprRef,
385 relevance_fuel: u64,
386 ) -> Result<bool> {
387 let (mut slf, rxes) = Self::prep_regex(exprset, &[small, big]);
388 let small = rxes[0];
389 let big = rxes[1];
390
391 slf.relevance.is_contained_in_prefixes(
392 &mut slf.exprs,
393 &mut slf.deriv,
394 small,
395 big,
396 relevance_fuel,
397 false,
398 )
399 }
400
401 fn prep_regex(exprset: ExprSet, top_rxs: &[ExprRef]) -> (Self, Vec<ExprRef>) {
402 let (alpha, exprset, rx_list) = AlphabetInfo::from_exprset(exprset, top_rxs);
403 let num_ast_nodes = exprset.len();
404
405 let rx_sets = StateID::new_hash_cons();
406
407 let mut slf = Regex {
408 deriv: DerivCache::new(),
409 next_byte: NextByteCache::new(),
410 relevance: RelevanceCache::new(),
411 exprs: exprset,
412 alpha,
413 rx_sets,
414 state_table: vec![],
415 state_descs: vec![],
416 num_transitions: 0,
417 num_ast_nodes,
418 initial: StateID::MISSING,
419 max_states: usize::MAX,
420 };
421
422 let desc = StateDesc {
423 lookahead_len: Some(None),
424 next_byte: Some(NextByte::Dead),
425 };
426
427 slf.append_state(desc.clone());
429 slf.append_state(desc);
431 slf.state_table.fill(StateID::DEAD);
433 assert!(!slf.alpha.is_empty());
434
435 (slf, rx_list)
436 }
437
438 pub(crate) fn new_with_exprset(
439 exprset: ExprSet,
440 top_rx: ExprRef,
441 relevance_fuel: u64,
442 ) -> Result<Self> {
443 let (mut r, top_rx) = Self::prep_regex(exprset, &[top_rx]);
444 let top_rx = top_rx[0];
445
446 if r.relevance
447 .is_non_empty_limited(&mut r.exprs, top_rx, relevance_fuel)?
448 {
449 r.initial = r.insert_state(top_rx);
450 } else {
451 r.initial = StateID::DEAD;
452 }
453
454 Ok(r)
455 }
456
457 fn append_state(&mut self, state_desc: StateDesc) {
458 let mut new_states = vec![StateID::MISSING; self.alpha.len()];
459 self.state_table.append(&mut new_states);
460 self.state_descs.push(state_desc);
461 if self.state_descs.len() >= self.max_states {
462 self.alpha.enter_error_state();
463 }
464 }
465
466 fn insert_state(&mut self, d: ExprRef) -> StateID {
467 let id = StateID::new(self.rx_sets.insert(&[d.as_u32()]));
468 if id.as_usize() >= self.state_descs.len() {
469 self.append_state(StateDesc::default());
470 }
471 id
472 }
473
474 fn transition_inner(&mut self, state: StateID, b: u8) -> StateID {
475 assert!(state.is_valid());
476
477 let e = Self::resolve(&self.rx_sets, state);
478 let d = self.deriv.derivative(&mut self.exprs, e, b);
479 if d == ExprRef::NO_MATCH {
480 StateID::DEAD
481 } else if self.relevance.is_non_empty(&mut self.exprs, d) {
482 self.insert_state(d)
483 } else {
484 StateID::DEAD
485 }
486 }
487}
488
489impl Debug for Regex {
490 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
491 write!(f, "Regex({})", self.stats())
492 }
493}