1use itertools::iproduct;
2use regex_syntax::hir::{self, visit, Hir, HirKind, Visitor};
3use std::cell::Cell;
4use std::fmt::{Display, Formatter, Write};
5use std::str::Utf8Error;
6use std::{collections::BTreeSet, ops::Deref};
7
/// A boolean prefilter over literal strings ("atoms") derived from a regex.
///
/// Every variant carries a `Cell<usize>` unique id. Nodes built in this module
/// start with `usize::MAX`; `set_unique_id` can assign a real id later through
/// a shared reference. `PartialEq` and `Hash` compare `And`/`Or` children by
/// that id rather than structurally.
#[derive(Clone, Debug)]
pub enum Model {
    /// Imposes no constraint (everything passes). Displays as an empty string.
    All(Cell<usize>),
    /// Nothing can pass. Displays as `*no-matches*`.
    None(Cell<usize>),
    /// A literal string that must be present.
    Atom(Cell<usize>, String),
    /// Every child must be satisfied.
    And(Cell<usize>, Vec<Model>),
    /// At least one child must be satisfied.
    Or(Cell<usize>, Vec<Model>),
}
21use Model::{All, And, Atom, None, Or};
22
23impl std::hash::Hash for Model {
24 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
25 state.write_u8(self.op());
26 match self {
27 All(_) | None(_) => (),
28 Atom(_, s) => s.hash(state),
29 And(_, ps) | Or(_, ps) => {
30 state.write_usize(ps.len());
31 for p in ps {
32 state.write_usize(p.unique_id());
33 }
34 }
35 }
36 }
37}
38
39impl std::cmp::PartialEq for Model {
40 fn eq(&self, other: &Self) -> bool {
41 match (self, other) {
42 (All(_), All(_)) | (None(_), None(_)) => true,
43 (Atom(_, a), Atom(_, b)) => a == b,
44 (And(_, va), And(_, vb)) | (Or(_, va), Or(_, vb)) => {
45 va.len() == vb.len()
46 && std::iter::zip(va, vb).all(|(a, b)| a.unique_id() == b.unique_id())
47 }
48 _ => false,
49 }
50 }
51}
52impl Eq for Model {}
53
54impl From<String> for Model {
55 fn from(s: String) -> Self {
56 Atom(Cell::new(usize::MAX), s)
57 }
58}
59
60impl Display for Model {
61 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
62 match &self {
63 All(_) => f.write_str(""),
64 None(_) => f.write_str("*no-matches*"),
65 Atom(_, s) => f.write_str(s),
66 And(_, subs) => {
67 for (i, s) in subs.iter().enumerate() {
68 if i != 0 {
69 f.write_char(' ')?;
70 }
71 write!(f, "{s}")?;
72 }
73 Ok(())
74 }
75 Or(_, subs) => {
76 f.write_char('(')?;
77 for (i, s) in subs.iter().enumerate() {
78 if i != 0 {
79 f.write_char('|')?;
80 }
81 write!(f, "{s}")?;
82 }
83 f.write_char(')')
84 }
85 }
86 }
87}
88
89#[derive(Debug)]
91pub enum Error {
92 FinalizationError,
94 EarlyStop,
96 DecodeError(Utf8Error),
98 ClassError(hir::ClassBytes),
100}
101impl Display for Error {
102 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
103 write!(f, "{self:?}")
104 }
105}
106impl std::error::Error for Error {}
107impl From<Utf8Error> for Error {
108 fn from(value: Utf8Error) -> Self {
109 Error::DecodeError(value)
110 }
111}
112
113impl Model {
114 pub fn new(r: &Hir) -> Result<Self, Error> {
115 visit(r, InfoVisitor::default())
116 }
117
118 pub fn unique_id(&self) -> usize {
119 match self {
120 All(id) | None(id) | Atom(id, _) | And(id, _) | Or(id, _) => id.get(),
121 }
122 }
123 pub fn set_unique_id(&self, value: usize) {
124 match self {
125 All(id) | None(id) | Atom(id, _) | And(id, _) | Or(id, _) => id.set(value),
126 }
127 }
128
129 pub fn all() -> Self {
130 All(Cell::new(usize::MAX))
131 }
132
133 pub fn none() -> Self {
134 None(Cell::new(usize::MAX))
135 }
136
137 fn or_strings(strings: SSet) -> Self {
138 Model::Or(
139 Cell::new(usize::MAX),
140 simplify_string_set(strings).map(From::from).collect(),
141 )
142 }
143
144 fn op(&self) -> u8 {
145 match self {
146 All(_) => 0,
147 None(_) => 1,
148 Atom(_, _) => 2,
149 And(_, _) => 3,
150 Or(_, _) => 4,
151 }
152 }
153
154 fn simplify(self) -> Self {
156 match self {
157 And(uid, v) if v.is_empty() => All(uid),
158 Or(uid, v) if v.is_empty() => None(uid),
159 And(_, mut v) | Or(_, mut v) if v.len() == 1 => {
160 v.pop().expect("we checked the length").simplify()
161 }
162 s => s,
163 }
164 }
165
166 fn and(self, mut b: Self) -> Self {
170 let mut a = self.simplify();
171 b = b.simplify();
172
173 if a.op() > b.op() {
175 std::mem::swap(&mut a, &mut b);
176 }
177
178 a = match a {
180 All(..) => return b,
182 None(uid) => return None(uid),
184 a => a,
185 };
186
187 match (a, b) {
188 (And(unique_id, mut va), And(_, vb)) => {
190 va.extend(vb);
191 And(unique_id, va)
192 }
193 (And(unique_id, mut v), vv) | (vv, And(unique_id, mut v)) => {
195 v.push(vv);
196 And(unique_id, v)
197 }
198 (a, b) => And(Cell::new(usize::MAX), vec![a, b]),
199 }
200 }
201
202 fn or(self, mut b: Self) -> Self {
203 let mut a = self.simplify();
204 b = b.simplify();
205
206 if a.op() > b.op() {
208 std::mem::swap(&mut a, &mut b);
209 }
210
211 a = match a {
212 None(..) => return b,
214 All(uid) => return All(uid),
216 a => a,
217 };
218
219 match (a, b) {
220 (Or(unique_id, mut va), Or(_, vb)) => {
222 va.extend(vb);
223 Or(unique_id, va)
224 }
225 (Or(unique_id, mut v), vv) | (vv, Or(unique_id, mut v)) => {
227 v.push(vv);
228 Or(unique_id, v)
229 }
230 (a, b) => Or(Cell::new(usize::MAX), vec![a, b]),
231 }
232 }
233}
234
/// A string ordered by byte length first, then lexicographically.
///
/// Sets of these iterate shortest-first. `simplify_string_set` relies on that
/// order, because only a shorter string can be a proper substring of a
/// longer one.
#[derive(PartialEq, Eq, Debug, Clone)]
struct LengthThenLex(pub String);

impl Deref for LengthThenLex {
    type Target = String;

    fn deref(&self) -> &String {
        let LengthThenLex(inner) = self;
        inner
    }
}

impl Ord for LengthThenLex {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Tuples compare element by element: byte length first, then text.
        (self.0.len(), &self.0).cmp(&(other.0.len(), &other.0))
    }
}

impl PartialOrd for LengthThenLex {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

/// A set of exact strings, iterated shortest-first.
type SSet = BTreeSet<LengthThenLex>;

/// Removes redundant strings from `strings`, yielding the rest shortest-first.
///
/// A string is dropped when it contains a strictly shorter, non-empty string
/// that was itself kept. Requiring the shorter string already implies the
/// longer one. The empty string is always kept, but it never causes other
/// strings to be dropped, since every string contains it.
fn simplify_string_set(strings: SSet) -> impl Iterator<Item = String> {
    let mut kept: Vec<String> = Vec::with_capacity(strings.len());
    for LengthThenLex(candidate) in strings {
        let redundant = kept.iter().any(|shorter| {
            !shorter.is_empty()
                && shorter.len() < candidate.len()
                && candidate.contains(shorter.as_str())
        });
        if !redundant {
            kept.push(candidate);
        }
    }
    kept.into_iter()
}
287
288#[derive(Debug)]
291enum Info {
292 Match(Model),
293 Exact(SSet),
294}
295impl Info {
296 fn take_match(self) -> Model {
297 match self {
298 Self::Match(p) => p,
299 Self::Exact(s) => Model::or_strings(s),
300 }
301 }
302
303 fn into_exact(self) -> Option<SSet> {
304 match self {
305 Self::Exact(s) => Some(s),
306 Self::Match(_) => Option::None,
307 }
308 }
309}
310
311struct InfoVisitor {
312 stack: Vec<Info>,
313 max_visits: usize,
314}
315impl Default for InfoVisitor {
316 fn default() -> Self {
317 Self {
318 max_visits: 100_000,
319 stack: Vec::new(),
320 }
321 }
322}
323
324impl Visitor for InfoVisitor {
328 type Output = Model;
329 type Err = Error;
330
331 fn finish(mut self) -> Result<Self::Output, Self::Err> {
332 (self.stack.len() == 1)
333 .then_some(&mut self.stack)
334 .and_then(|s| s.pop())
335 .map(Info::take_match)
336 .ok_or(Error::FinalizationError)
337 }
338
339 fn visit_pre(&mut self, _hir: &Hir) -> Result<(), Self::Err> {
340 self.max_visits = self.max_visits.checked_sub(1).ok_or(Error::EarlyStop)?;
344
345 Ok(())
346 }
347
348 fn visit_post(&mut self, hir: &Hir) -> Result<(), Self::Err> {
349 match hir.kind() {
350 HirKind::Empty | HirKind::Look(_) => {
351 self.stack
352 .push(Info::Exact([LengthThenLex(String::new())].into()));
353 }
354 HirKind::Literal(hir::Literal(data)) => {
355 if data.is_empty() {
356 self.stack.push(Info::Match(Model::none()));
358 } else {
359 self.stack.push(Info::Exact(
364 [LengthThenLex(std::str::from_utf8(data)?.to_lowercase())].into(),
365 ));
366 }
367 }
368 HirKind::Class(cls) => {
369 let uc;
370 let c = match cls {
371 hir::Class::Unicode(c) => c,
372 hir::Class::Bytes(b) => {
373 uc = b
374 .to_unicode_class()
375 .ok_or_else(|| Error::ClassError(b.clone()))?;
376 &uc
377 }
378 };
379 self.stack
380 .push(if c.iter().map(|r| r.len()).sum::<usize>() > 10 {
381 Info::Match(Model::all())
382 } else {
383 Info::Exact(
384 c.iter()
385 .flat_map(|r| (r.start()..=r.end()))
386 .map(char::to_lowercase)
387 .map(String::from_iter)
388 .map(LengthThenLex)
389 .collect(),
390 )
391 });
392 }
393 HirKind::Repetition(hir::Repetition { min, .. }) => {
397 if *min == 0 {
398 self.stack.pop();
400 self.stack.push(Info::Match(Model::all()));
401 } else {
402 let arg = self
404 .stack
405 .pop()
406 .expect("a repetition to be associated with a pattern to repeat")
407 .take_match();
408 self.stack.push(Info::Match(arg));
409 }
410 }
411 HirKind::Capture(_) => (),
414 HirKind::Alternation(alt) => {
415 let topn = self.stack.len() - alt.len()..;
421 let infos = &mut self.stack[topn.clone()];
422
423 let matches =
424 topn.start + infos.iter().filter(|v| matches!(v, Info::Match(_))).count();
425 infos.sort_unstable_by_key(|v| match v {
429 Info::Match(_) => (false, 0),
430 Info::Exact(s) => (true, s.len()),
431 });
432 let exacts = self
434 .stack
435 .drain(matches..)
436 .rev()
437 .fold(BTreeSet::new(), |mut s, i| {
438 s.append(
439 &mut i
440 .into_exact()
441 .expect("the top `matches` records should be exacts"),
442 );
443 s
444 });
445 let mut matches = self
446 .stack
447 .drain(topn)
448 .map(Info::take_match)
449 .collect::<Vec<_>>();
450 self.stack.push(if matches.is_empty() {
451 Info::Exact(exacts)
452 } else {
453 if !exacts.is_empty() {
454 matches.push(Model::or_strings(exacts));
455 }
456 Info::Match(
457 matches
458 .into_iter()
459 .map(From::from)
460 .fold(Model::none(), Model::or),
461 )
462 });
463 }
464 HirKind::Concat(c) => {
468 let topn = self.stack.len() - c.len()..;
469
470 let mut result = Info::Match(Model::all());
472 let mut exacts = BTreeSet::new();
473 for info in self.stack.drain(topn) {
474 match info {
475 Info::Exact(set) if exacts.is_empty() => {
476 exacts = set;
477 }
478 Info::Exact(set) if set.len() * exacts.len() <= 16 => {
479 exacts = iproduct!(&exacts, &set)
483 .map(|(s, ss)| {
484 let mut r = String::with_capacity(s.len() + ss.len());
485 r.push_str(s);
486 r.push_str(ss);
487 LengthThenLex(r)
488 })
489 .collect();
490 }
491 i => {
492 let mut p = result.take_match();
495 if !exacts.is_empty() {
496 p = Model::and(p, Model::or_strings(std::mem::take(&mut exacts)));
497 }
498 p = Model::and(p, i.take_match());
499 result = Info::Match(p);
500 }
501 }
502 }
503
504 if exacts.is_empty() {
505 self.stack.push(result);
506 } else {
507 self.stack.push(Info::Match(Model::and(
508 result.take_match(),
509 Model::or_strings(exacts),
510 )));
511 }
512 }
513 }
514 Ok(())
515 }
516}