regex_filtered/
model.rs

1use itertools::iproduct;
2use regex_syntax::hir::{self, visit, Hir, HirKind, Visitor};
3use std::cell::Cell;
4use std::fmt::{Display, Formatter, Write};
5use std::str::Utf8Error;
6use std::{collections::BTreeSet, ops::Deref};
7
/// Prefilter model: a boolean condition built from string atoms.
///
/// The first field of every variant is a `Cell<usize>` holding the node's
/// unique id, read and written through `Model::unique_id` and
/// `Model::set_unique_id`. The constructors in this module initialise it to
/// `usize::MAX` (presumably meaning "no id assigned yet" — confirm against
/// the code assigning ids).
#[derive(Debug, Clone)]
pub enum Model {
    /// Everything matches.
    All(Cell<usize>),
    /// Nothing matches.
    None(Cell<usize>),
    /// The string matches.
    Atom(Cell<usize>, String),
    /// All sub-models must match.
    And(Cell<usize>, Vec<Model>),
    /// At least one sub-model must match.
    Or(Cell<usize>, Vec<Model>),
}
// NOTE: this brings `Model::None` into scope under the bare name `None`,
// shadowing `Option::None` for the rest of the module: spell out
// `Option::None` wherever the option variant is meant.
use self::Model::{All, And, Atom, None, Or};
22
23impl std::hash::Hash for Model {
24    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
25        state.write_u8(self.op());
26        match self {
27            All(_) | None(_) => (),
28            Atom(_, s) => s.hash(state),
29            And(_, ps) | Or(_, ps) => {
30                state.write_usize(ps.len());
31                for p in ps {
32                    state.write_usize(p.unique_id());
33                }
34            }
35        }
36    }
37}
38
39impl std::cmp::PartialEq for Model {
40    fn eq(&self, other: &Self) -> bool {
41        match (self, other) {
42            (All(_), All(_)) | (None(_), None(_)) => true,
43            (Atom(_, a), Atom(_, b)) => a == b,
44            (And(_, va), And(_, vb)) | (Or(_, va), Or(_, vb)) => {
45                va.len() == vb.len()
46                    && std::iter::zip(va, vb).all(|(a, b)| a.unique_id() == b.unique_id())
47            }
48            _ => false,
49        }
50    }
51}
52impl Eq for Model {}
53
54impl From<String> for Model {
55    fn from(s: String) -> Self {
56        Atom(Cell::new(usize::MAX), s)
57    }
58}
59
60impl Display for Model {
61    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
62        match &self {
63            All(_) => f.write_str(""),
64            None(_) => f.write_str("*no-matches*"),
65            Atom(_, s) => f.write_str(s),
66            And(_, subs) => {
67                for (i, s) in subs.iter().enumerate() {
68                    if i != 0 {
69                        f.write_char(' ')?;
70                    }
71                    write!(f, "{s}")?;
72                }
73                Ok(())
74            }
75            Or(_, subs) => {
76                f.write_char('(')?;
77                for (i, s) in subs.iter().enumerate() {
78                    if i != 0 {
79                        f.write_char('|')?;
80                    }
81                    write!(f, "{s}")?;
82                }
83                f.write_char(')')
84            }
85        }
86    }
87}
88
89/// Processing errors
90#[derive(Debug)]
91pub enum Error {
92    /// Processing missed or exceeded some of the stack
93    FinalizationError,
94    /// Processing reached HIR nodes limit
95    EarlyStop,
96    /// Literal was not a valid string
97    DecodeError(Utf8Error),
98    /// Non-decodable character class
99    ClassError(hir::ClassBytes),
100}
101impl Display for Error {
102    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
103        write!(f, "{self:?}")
104    }
105}
106impl std::error::Error for Error {}
107impl From<Utf8Error> for Error {
108    fn from(value: Utf8Error) -> Self {
109        Error::DecodeError(value)
110    }
111}
112
113impl Model {
114    pub fn new(r: &Hir) -> Result<Self, Error> {
115        visit(r, InfoVisitor::default())
116    }
117
118    pub fn unique_id(&self) -> usize {
119        match self {
120            All(id) | None(id) | Atom(id, _) | And(id, _) | Or(id, _) => id.get(),
121        }
122    }
123    pub fn set_unique_id(&self, value: usize) {
124        match self {
125            All(id) | None(id) | Atom(id, _) | And(id, _) | Or(id, _) => id.set(value),
126        }
127    }
128
129    pub fn all() -> Self {
130        All(Cell::new(usize::MAX))
131    }
132
133    pub fn none() -> Self {
134        None(Cell::new(usize::MAX))
135    }
136
137    fn or_strings(strings: SSet) -> Self {
138        Model::Or(
139            Cell::new(usize::MAX),
140            simplify_string_set(strings).map(From::from).collect(),
141        )
142    }
143
144    fn op(&self) -> u8 {
145        match self {
146            All(_) => 0,
147            None(_) => 1,
148            Atom(_, _) => 2,
149            And(_, _) => 3,
150            Or(_, _) => 4,
151        }
152    }
153
154    /// Simplifies And and Or nodes
155    fn simplify(self) -> Self {
156        match self {
157            And(uid, v) if v.is_empty() => All(uid),
158            Or(uid, v) if v.is_empty() => None(uid),
159            And(_, mut v) | Or(_, mut v) if v.len() == 1 => {
160                v.pop().expect("we checked the length").simplify()
161            }
162            s => s,
163        }
164    }
165
166    // re2 merges those into separate functions but it only saves on
167    // the header and increases the branching complexity of the rest
168    // so y?
169    fn and(self, mut b: Self) -> Self {
170        let mut a = self.simplify();
171        b = b.simplify();
172
173        // Canonicalize: a->op <= b->op.
174        if a.op() > b.op() {
175            std::mem::swap(&mut a, &mut b);
176        }
177
178        // ALL and NONE are smallest opcodes.
179        a = match a {
180            // ALL and b = b
181            All(..) => return b,
182            // NONE and b = None
183            None(uid) => return None(uid),
184            a => a,
185        };
186
187        match (a, b) {
188            // If a and b match op, merge their contents.
189            (And(unique_id, mut va), And(_, vb)) => {
190                va.extend(vb);
191                And(unique_id, va)
192            }
193            // If a or b matches the operation, merge the other one in
194            (And(unique_id, mut v), vv) | (vv, And(unique_id, mut v)) => {
195                v.push(vv);
196                And(unique_id, v)
197            }
198            (a, b) => And(Cell::new(usize::MAX), vec![a, b]),
199        }
200    }
201
202    fn or(self, mut b: Self) -> Self {
203        let mut a = self.simplify();
204        b = b.simplify();
205
206        // Canonicalize: a->op <= b->op.
207        if a.op() > b.op() {
208            std::mem::swap(&mut a, &mut b);
209        }
210
211        a = match a {
212            // NONE or b = b
213            None(..) => return b,
214            // ALL or b = ALL
215            All(uid) => return All(uid),
216            a => a,
217        };
218
219        match (a, b) {
220            // If a and b match op, merge their contents.
221            (Or(unique_id, mut va), Or(_, vb)) => {
222                va.extend(vb);
223                Or(unique_id, va)
224            }
225            // If a or b matches the operation, merge the other one in
226            (Or(unique_id, mut v), vv) | (vv, Or(unique_id, mut v)) => {
227                v.push(vv);
228                Or(unique_id, v)
229            }
230            (a, b) => Or(Cell::new(usize::MAX), vec![a, b]),
231        }
232    }
233}
234
// String ordering required by `simplify_string_set`: shorter strings first,
// ties broken lexicographically.
//
// Simplifying a string set means removing every "superset", i.e. every
// string which contains another (non-empty) string of the set. The smaller
// atom already marks the pattern as a candidate, so also matching the
// larger atom is useless.
//
// Visiting the strings shortest-first keeps that simple and efficient: a
// string can only be contained in the strings which follow it, so only
// those need to be checked.
#[derive(PartialEq, Eq, Debug, Clone)]
struct LengthThenLex(pub String);

impl Deref for LengthThenLex {
    type Target = String;

    fn deref(&self) -> &String {
        let LengthThenLex(inner) = self;
        inner
    }
}

impl Ord for LengthThenLex {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Tuples compare lexicographically: byte length first, then content.
        (self.0.len(), &self.0).cmp(&(other.0.len(), &other.0))
    }
}

impl PartialOrd for LengthThenLex {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

/// A set of strings iterated shortest-first (see `LengthThenLex`).
type SSet = BTreeSet<LengthThenLex>;
268fn simplify_string_set(strings: SSet) -> impl Iterator<Item = String> {
269    let mut to_keep = vec![true; strings.len()];
270    let mut e = strings.iter().enumerate();
271    while let Some((i, s)) = e.next() {
272        if s.is_empty() || !to_keep[i] {
273            continue;
274        }
275
276        for (keep, (_, s2)) in to_keep[i..].iter_mut().skip(1).zip(e.clone()) {
277            if *keep && s2.len() > s.len() && s2.0.contains(&s.0) {
278                *keep = false;
279            }
280        }
281    }
282
283    std::iter::zip(to_keep, strings)
284        .filter(|v| v.0)
285        .map(|v| v.1 .0)
286}
287
288/// Intermediate information about the set of strings a regex matches,
289/// used for the computation of a prefilter.
290#[derive(Debug)]
291enum Info {
292    Match(Model),
293    Exact(SSet),
294}
295impl Info {
296    fn take_match(self) -> Model {
297        match self {
298            Self::Match(p) => p,
299            Self::Exact(s) => Model::or_strings(s),
300        }
301    }
302
303    fn into_exact(self) -> Option<SSet> {
304        match self {
305            Self::Exact(s) => Some(s),
306            Self::Match(_) => Option::None,
307        }
308    }
309}
310
311struct InfoVisitor {
312    stack: Vec<Info>,
313    max_visits: usize,
314}
315impl Default for InfoVisitor {
316    fn default() -> Self {
317        Self {
318            max_visits: 100_000,
319            stack: Vec::new(),
320        }
321    }
322}
323
324// [`regex_syntax::hir::Visitor`] works pretty differently than
325// `re2::Regexp::Walker` as it does not return / merge anything, so we
326// need to merge down into the stack on post.
327impl Visitor for InfoVisitor {
328    type Output = Model;
329    type Err = Error;
330
331    fn finish(mut self) -> Result<Self::Output, Self::Err> {
332        (self.stack.len() == 1)
333            .then_some(&mut self.stack)
334            .and_then(|s| s.pop())
335            .map(Info::take_match)
336            .ok_or(Error::FinalizationError)
337    }
338
339    fn visit_pre(&mut self, _hir: &Hir) -> Result<(), Self::Err> {
340        // re2 sets `stopped_early` and calls `ShortVisit` but keeps
341        // on keeping on, not clear why & ultimately BuildInfo only
342        // cares about having stopped early
343        self.max_visits = self.max_visits.checked_sub(1).ok_or(Error::EarlyStop)?;
344
345        Ok(())
346    }
347
348    fn visit_post(&mut self, hir: &Hir) -> Result<(), Self::Err> {
349        match hir.kind() {
350            HirKind::Empty | HirKind::Look(_) => {
351                self.stack
352                    .push(Info::Exact([LengthThenLex(String::new())].into()));
353            }
354            HirKind::Literal(hir::Literal(data)) => {
355                if data.is_empty() {
356                    // NoMatch
357                    self.stack.push(Info::Match(Model::none()));
358                } else {
359                    // re2 does this weird as it performs a cross
360                    // product of individual characters, but as far as
361                    // I understand that's just a complicated way to
362                    // build a singleton set of the payload?
363                    self.stack.push(Info::Exact(
364                        [LengthThenLex(std::str::from_utf8(data)?.to_lowercase())].into(),
365                    ));
366                }
367            }
368            HirKind::Class(cls) => {
369                let uc;
370                let c = match cls {
371                    hir::Class::Unicode(c) => c,
372                    hir::Class::Bytes(b) => {
373                        uc = b
374                            .to_unicode_class()
375                            .ok_or_else(|| Error::ClassError(b.clone()))?;
376                        &uc
377                    }
378                };
379                self.stack
380                    .push(if c.iter().map(|r| r.len()).sum::<usize>() > 10 {
381                        Info::Match(Model::all())
382                    } else {
383                        Info::Exact(
384                            c.iter()
385                                .flat_map(|r| (r.start()..=r.end()))
386                                .map(char::to_lowercase)
387                                .map(String::from_iter)
388                                .map(LengthThenLex)
389                                .collect(),
390                        )
391                    });
392            }
393            // Apparently re2 and regex have inverse choices, re2
394            // normalises repetitions to */+/?, regex normalises
395            // everything to {a, b}, so this may or may make any sense
396            HirKind::Repetition(hir::Repetition { min, .. }) => {
397                if *min == 0 {
398                    // corresponds to */? (star/quest)
399                    self.stack.pop();
400                    self.stack.push(Info::Match(Model::all()));
401                } else {
402                    // corresponds to +
403                    let arg = self
404                        .stack
405                        .pop()
406                        .expect("a repetition to be associated with a pattern to repeat")
407                        .take_match();
408                    self.stack.push(Info::Match(arg));
409                }
410            }
411            // should just leave its child on the stack for whoever
412            // lives up
413            HirKind::Capture(_) => (),
414            HirKind::Alternation(alt) => {
415                // needs to pop alt.len() items from the stack, and if
416                // they're ``exact`` then just merge them, otherwise
417                // ``Prefilter::Or`` them
418
419                // sort the topn to have the exacts at the top, largest top
420                let topn = self.stack.len() - alt.len()..;
421                let infos = &mut self.stack[topn.clone()];
422
423                let matches =
424                    topn.start + infos.iter().filter(|v| matches!(v, Info::Match(_))).count();
425                // I think we can do that because we don't actually
426                // regex match so order should not matter question
427                // mark
428                infos.sort_unstable_by_key(|v| match v {
429                    Info::Match(_) => (false, 0),
430                    Info::Exact(s) => (true, s.len()),
431                });
432                // there are exact matches, merge them
433                let exacts = self
434                    .stack
435                    .drain(matches..)
436                    .rev()
437                    .fold(BTreeSet::new(), |mut s, i| {
438                        s.append(
439                            &mut i
440                                .into_exact()
441                                .expect("the top `matches` records should be exacts"),
442                        );
443                        s
444                    });
445                let mut matches = self
446                    .stack
447                    .drain(topn)
448                    .map(Info::take_match)
449                    .collect::<Vec<_>>();
450                self.stack.push(if matches.is_empty() {
451                    Info::Exact(exacts)
452                } else {
453                    if !exacts.is_empty() {
454                        matches.push(Model::or_strings(exacts));
455                    }
456                    Info::Match(
457                        matches
458                            .into_iter()
459                            .map(From::from)
460                            .fold(Model::none(), Model::or),
461                    )
462                });
463            }
464            // and this one gets really painful, like above we need to
465            // take the topn but unlike the above we can't reorder all
466            // our stuff around
467            HirKind::Concat(c) => {
468                let topn = self.stack.len() - c.len()..;
469
470                // ALL is the identity element of AND
471                let mut result = Info::Match(Model::all());
472                let mut exacts = BTreeSet::new();
473                for info in self.stack.drain(topn) {
474                    match info {
475                        Info::Exact(set) if exacts.is_empty() => {
476                            exacts = set;
477                        }
478                        Info::Exact(set) if set.len() * exacts.len() <= 16 => {
479                            // Not useful to consume the existing
480                            // `exacts` up-front, as each item has to
481                            // be splatted over `set`.
482                            exacts = iproduct!(&exacts, &set)
483                                .map(|(s, ss)| {
484                                    let mut r = String::with_capacity(s.len() + ss.len());
485                                    r.push_str(s);
486                                    r.push_str(ss);
487                                    LengthThenLex(r)
488                                })
489                                .collect();
490                        }
491                        i => {
492                            // here AND the combination of info,
493                            // exact, and the existing garbage
494                            let mut p = result.take_match();
495                            if !exacts.is_empty() {
496                                p = Model::and(p, Model::or_strings(std::mem::take(&mut exacts)));
497                            }
498                            p = Model::and(p, i.take_match());
499                            result = Info::Match(p);
500                        }
501                    }
502                }
503
504                if exacts.is_empty() {
505                    self.stack.push(result);
506                } else {
507                    self.stack.push(Info::Match(Model::and(
508                        result.take_match(),
509                        Model::or_strings(exacts),
510                    )));
511                }
512            }
513        }
514        Ok(())
515    }
516}