derivre/
regex.rs

1use std::fmt::Debug;
2
3use crate::HashSet;
4use anyhow::Result;
5
6use crate::{
7    ast::{ExprRef, ExprSet, NextByte},
8    bytecompress::ByteCompressor,
9    deriv::DerivCache,
10    hashcons::VecHashCons,
11    nextbyte::NextByteCache,
12    pp::PrettyPrinter,
13    relevance::RelevanceCache,
14};
15
/// Compile-time switch for the `debug!` tracing macro below.
const DEBUG: bool = false;

/// Prints to stderr like `eprintln!`, but only when `DEBUG` is `true`.
macro_rules! debug {
    ($($t:tt)*) => {{
        if DEBUG {
            eprintln!($($t)*);
        }
    }};
}
25
/// Identifier of a DFA state.
///
/// The state index is stored shifted left by one bit; the lowest bit is a
/// separate "lowest match" flag (see `StateID::has_lowest_match`).
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct StateID(u32);
28
29impl StateID {
30    // DEAD state corresponds to empty vector
31    pub const DEAD: StateID = StateID::new(0);
32    // MISSING state corresponds to yet not computed entries in the state table
33    pub const MISSING: StateID = StateID::new(1);
34
35    pub fn as_usize(&self) -> usize {
36        (self.0 >> 1) as usize
37    }
38
39    pub fn as_u32(&self) -> u32 {
40        self.0 >> 1
41    }
42
43    pub fn is_valid(&self) -> bool {
44        *self != Self::MISSING
45    }
46
47    #[inline(always)]
48    pub fn is_dead(&self) -> bool {
49        *self == Self::DEAD
50    }
51
52    #[inline(always)]
53    pub fn has_lowest_match(&self) -> bool {
54        (self.0 & 1) == 1
55    }
56
57    pub fn _set_lowest_match(self) -> Self {
58        Self(self.0 | 1)
59    }
60
61    pub const fn new(id: u32) -> Self {
62        Self(id << 1)
63    }
64
65    pub fn new_hash_cons() -> VecHashCons {
66        let mut rx_sets = VecHashCons::new();
67        let id = rx_sets.insert(&[]);
68        assert!(id == StateID::DEAD.as_u32());
69        let id = rx_sets.insert(&[ExprRef::INVALID.as_u32()]);
70        assert!(id == StateID::MISSING.as_u32());
71        rx_sets
72    }
73}
74
75impl Debug for StateID {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        if *self == StateID::DEAD {
78            write!(f, "StateID(DEAD)")
79        } else if *self == StateID::MISSING {
80            write!(f, "StateID(MISSING)")
81        } else {
82            write!(f, "StateID({},{})", self.0 >> 1, self.0 & 1)
83        }
84    }
85}
86
/// Mapping from input bytes to alphabet symbols, together with the alphabet size.
#[derive(Clone)]
pub struct AlphabetInfo {
    /// Symbol assigned to each of the 256 byte values.
    mapping: [u8; 256],
    /// Number of alphabet symbols (the width of a state-table row).
    /// `0` is used as the error marker (see `enter_error_state`).
    size: usize,
}
92
93#[derive(Clone)]
94pub struct Regex {
95    exprs: ExprSet,
96    deriv: DerivCache,
97    next_byte: NextByteCache,
98    relevance: RelevanceCache,
99    alpha: AlphabetInfo,
100    initial: StateID,
101    rx_sets: VecHashCons,
102    state_table: Vec<StateID>,
103    state_descs: Vec<StateDesc>,
104    num_transitions: usize,
105    num_ast_nodes: usize,
106    max_states: usize,
107}
108
109#[derive(Clone, Debug, Default)]
110struct StateDesc {
111    lookahead_len: Option<Option<usize>>,
112    next_byte: Option<NextByte>,
113}
114
115// public implementation
116impl Regex {
117    pub fn new(rx: &str) -> Result<Self> {
118        let parser = regex_syntax::ParserBuilder::new().build();
119        Self::new_with_parser(parser, rx)
120    }
121
122    pub fn new_with_parser(parser: regex_syntax::Parser, rx: &str) -> Result<Self> {
123        let mut exprset = ExprSet::new(256);
124        let rx = exprset.parse_expr(parser.clone(), rx, false)?;
125        Self::new_with_exprset(exprset, rx, u64::MAX)
126    }
127
128    pub fn alpha(&self) -> &AlphabetInfo {
129        &self.alpha
130    }
131
132    pub fn initial_state(&mut self) -> StateID {
133        self.initial
134    }
135
136    pub fn always_empty(&mut self) -> bool {
137        self.initial_state().is_dead()
138    }
139
140    pub fn is_accepting(&mut self, state: StateID) -> bool {
141        self.lookahead_len_for_state(state).is_some()
142    }
143
144    fn resolve(rx_sets: &VecHashCons, state: StateID) -> ExprRef {
145        ExprRef::new(rx_sets.get(state.as_u32())[0])
146    }
147
148    pub fn lookahead_len_for_state(&mut self, state: StateID) -> Option<usize> {
149        if state == StateID::DEAD || state == StateID::MISSING {
150            return None;
151        }
152        let desc = &mut self.state_descs[state.as_usize()];
153        if let Some(len) = desc.lookahead_len {
154            return len;
155        }
156        let expr = Self::resolve(&self.rx_sets, state);
157        let mut res = None;
158        if self.exprs.is_nullable(expr) {
159            res = Some(self.exprs.lookahead_len(expr).unwrap_or(0));
160        }
161        desc.lookahead_len = Some(res);
162        res
163    }
164
165    #[inline(always)]
166    pub fn transition(&mut self, state: StateID, b: u8) -> StateID {
167        let idx = self.alpha.map_state(state, b);
168        let new_state = self.state_table[idx];
169        if new_state != StateID::MISSING {
170            new_state
171        } else {
172            let new_state = self.transition_inner(state, b);
173            self.num_transitions += 1;
174            self.state_table[idx] = new_state;
175            new_state
176        }
177    }
178
179    pub fn transition_bytes(&mut self, state: StateID, bytes: &[u8]) -> StateID {
180        let mut state = state;
181        for &b in bytes {
182            state = self.transition(state, b);
183        }
184        state
185    }
186
187    pub fn is_match(&mut self, text: &str) -> bool {
188        self.lookahead_len(text).is_some()
189    }
190
191    pub fn is_match_bytes(&mut self, text: &[u8]) -> bool {
192        self.lookahead_len_bytes(text).is_some()
193    }
194
195    pub fn lookahead_len_bytes(&mut self, text: &[u8]) -> Option<usize> {
196        let mut state = self.initial_state();
197        for b in text {
198            let b = *b;
199            let new_state = self.transition(state, b);
200            debug!("b: {:?} --{:?}--> {:?}", state, b as char, new_state);
201            state = new_state;
202            if state == StateID::DEAD {
203                return None;
204            }
205        }
206        self.lookahead_len_for_state(state)
207    }
208
209    pub fn lookahead_len(&mut self, text: &str) -> Option<usize> {
210        self.lookahead_len_bytes(text.as_bytes())
211    }
212
213    /// Estimate the size of the regex tables in bytes.
214    pub fn num_bytes(&self) -> usize {
215        self.exprs.num_bytes()
216            + self.deriv.num_bytes()
217            + self.next_byte.num_bytes()
218            + self.state_descs.len() * 100
219            + self.state_table.len() * std::mem::size_of::<StateID>()
220            + self.rx_sets.num_bytes()
221    }
222
223    pub fn cost(&self) -> u64 {
224        self.exprs.cost()
225    }
226
227    /// Check if the there is only one transition out of state.
228    /// This is an approximation - see docs for NextByte.
229    pub fn next_byte(&mut self, state: StateID) -> NextByte {
230        if state == StateID::DEAD || state == StateID::MISSING {
231            return NextByte::Dead;
232        }
233
234        let desc = &mut self.state_descs[state.as_usize()];
235        if let Some(next_byte) = desc.next_byte {
236            return next_byte;
237        }
238
239        let e = Self::resolve(&self.rx_sets, state);
240        let next_byte = self.next_byte.next_byte(&self.exprs, e);
241        desc.next_byte = Some(next_byte);
242        next_byte
243    }
244
245    pub fn stats(&self) -> String {
246        format!(
247            "regexp: {} nodes (+ {} derived via {} derivatives), states: {}; transitions: {}; bytes: {}; alphabet size: {}",
248            self.num_ast_nodes,
249            self.exprs.len() - self.num_ast_nodes,
250            self.deriv.num_deriv,
251            self.state_descs.len(),
252            self.num_transitions,
253            self.num_bytes(),
254            self.alpha.len(),
255        )
256    }
257
258    pub fn dfa(&mut self) -> Vec<u8> {
259        let mut used = HashSet::default();
260        let mut designated_bytes = vec![];
261        for b in 0..=255 {
262            let m = self.alpha.map(b);
263            if !used.contains(&m) {
264                used.insert(m);
265                designated_bytes.push(b);
266            }
267        }
268
269        let mut stack = vec![self.initial_state()];
270        let mut visited = HashSet::default();
271        while let Some(state) = stack.pop() {
272            for b in &designated_bytes {
273                let new_state = self.transition(state, *b);
274                if !visited.contains(&new_state) {
275                    stack.push(new_state);
276                    visited.insert(new_state);
277                    assert!(visited.len() < 250);
278                }
279            }
280        }
281
282        assert!(!self.state_table.contains(&StateID::MISSING));
283        let mut res = self.alpha.mapping.to_vec();
284        res.extend(self.state_table.iter().map(|s| s.as_u32() as u8));
285        res
286    }
287
288    pub fn print_state_table(&self) {
289        for (state, row) in self.state_table.chunks(self.alpha.len()).enumerate() {
290            println!("state: {}", state);
291            for (b, &new_state) in row.iter().enumerate() {
292                println!("  s{:?} -> {:?}", b, new_state);
293            }
294        }
295    }
296}
297
298impl AlphabetInfo {
299    pub fn from_exprset(exprset: ExprSet, rx_list: &[ExprRef]) -> (Self, ExprSet, Vec<ExprRef>) {
300        assert!(exprset.alphabet_size == 256);
301
302        debug!("rx0: {}", exprset.expr_to_string_with_info(rx_list[0]));
303
304        let ((mut exprset, rx_list), mapping, alphabet_size) = if cfg!(feature = "compress") {
305            let mut compressor = ByteCompressor::new();
306            let cost0 = exprset.cost;
307            let (mut exprset, rx_list) = compressor.compress(exprset, rx_list);
308            exprset.cost += cost0;
309            exprset.set_pp(PrettyPrinter::new(
310                compressor.mapping.clone(),
311                compressor.alphabet_size,
312            ));
313            (
314                (exprset, rx_list),
315                compressor.mapping,
316                compressor.alphabet_size,
317            )
318        } else {
319            let alphabet_size = exprset.alphabet_size;
320            (
321                (exprset, rx_list.to_vec()),
322                (0..=255).collect(),
323                alphabet_size,
324            )
325        };
326
327        // disable expensive optimizations after initial construction
328        exprset.disable_optimizations();
329
330        debug!(
331            "compressed: {}",
332            exprset.expr_to_string_with_info(rx_list[0])
333        );
334
335        let alpha = AlphabetInfo {
336            mapping: mapping.try_into().unwrap(),
337            size: alphabet_size,
338        };
339        (alpha, exprset, rx_list.to_vec())
340    }
341
342    #[inline(always)]
343    pub fn map(&self, b: u8) -> usize {
344        if cfg!(feature = "compress") {
345            self.mapping[b as usize] as usize
346        } else {
347            b as usize
348        }
349    }
350
351    #[inline(always)]
352    pub fn map_state(&self, state: StateID, b: u8) -> usize {
353        if cfg!(feature = "compress") {
354            self.map(b) + state.as_usize() * self.len()
355        } else {
356            b as usize + state.as_usize() * 256
357        }
358    }
359
360    #[inline(always)]
361    pub fn len(&self) -> usize {
362        self.size
363    }
364
365    #[inline(always)]
366    pub fn is_empty(&self) -> bool {
367        self.size == 0
368    }
369
370    pub fn has_error(&self) -> bool {
371        self.size == 0
372    }
373
374    pub fn enter_error_state(&mut self) {
375        self.size = 0;
376    }
377}
378
379// private implementation
380impl Regex {
381    pub fn is_contained_in_prefixes(
382        exprset: ExprSet,
383        small: ExprRef,
384        big: ExprRef,
385        relevance_fuel: u64,
386    ) -> Result<bool> {
387        let (mut slf, rxes) = Self::prep_regex(exprset, &[small, big]);
388        let small = rxes[0];
389        let big = rxes[1];
390
391        slf.relevance.is_contained_in_prefixes(
392            &mut slf.exprs,
393            &mut slf.deriv,
394            small,
395            big,
396            relevance_fuel,
397            false,
398        )
399    }
400
401    fn prep_regex(exprset: ExprSet, top_rxs: &[ExprRef]) -> (Self, Vec<ExprRef>) {
402        let (alpha, exprset, rx_list) = AlphabetInfo::from_exprset(exprset, top_rxs);
403        let num_ast_nodes = exprset.len();
404
405        let rx_sets = StateID::new_hash_cons();
406
407        let mut slf = Regex {
408            deriv: DerivCache::new(),
409            next_byte: NextByteCache::new(),
410            relevance: RelevanceCache::new(),
411            exprs: exprset,
412            alpha,
413            rx_sets,
414            state_table: vec![],
415            state_descs: vec![],
416            num_transitions: 0,
417            num_ast_nodes,
418            initial: StateID::MISSING,
419            max_states: usize::MAX,
420        };
421
422        let desc = StateDesc {
423            lookahead_len: Some(None),
424            next_byte: Some(NextByte::Dead),
425        };
426
427        // DEAD
428        slf.append_state(desc.clone());
429        // also append state for the "MISSING"
430        slf.append_state(desc);
431        // in fact, transition from MISSING and DEAD should both lead to DEAD
432        slf.state_table.fill(StateID::DEAD);
433        assert!(!slf.alpha.is_empty());
434
435        (slf, rx_list)
436    }
437
438    pub(crate) fn new_with_exprset(
439        exprset: ExprSet,
440        top_rx: ExprRef,
441        relevance_fuel: u64,
442    ) -> Result<Self> {
443        let (mut r, top_rx) = Self::prep_regex(exprset, &[top_rx]);
444        let top_rx = top_rx[0];
445
446        if r.relevance
447            .is_non_empty_limited(&mut r.exprs, top_rx, relevance_fuel)?
448        {
449            r.initial = r.insert_state(top_rx);
450        } else {
451            r.initial = StateID::DEAD;
452        }
453
454        Ok(r)
455    }
456
457    fn append_state(&mut self, state_desc: StateDesc) {
458        let mut new_states = vec![StateID::MISSING; self.alpha.len()];
459        self.state_table.append(&mut new_states);
460        self.state_descs.push(state_desc);
461        if self.state_descs.len() >= self.max_states {
462            self.alpha.enter_error_state();
463        }
464    }
465
466    fn insert_state(&mut self, d: ExprRef) -> StateID {
467        let id = StateID::new(self.rx_sets.insert(&[d.as_u32()]));
468        if id.as_usize() >= self.state_descs.len() {
469            self.append_state(StateDesc::default());
470        }
471        id
472    }
473
474    fn transition_inner(&mut self, state: StateID, b: u8) -> StateID {
475        assert!(state.is_valid());
476
477        let e = Self::resolve(&self.rx_sets, state);
478        let d = self.deriv.derivative(&mut self.exprs, e, b);
479        if d == ExprRef::NO_MATCH {
480            StateID::DEAD
481        } else if self.relevance.is_non_empty(&mut self.exprs, d) {
482            self.insert_state(d)
483        } else {
484            StateID::DEAD
485        }
486    }
487}
488
489impl Debug for Regex {
490    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
491        write!(f, "Regex({})", self.stats())
492    }
493}