beans/regex/
matching.rs

1use crate::lexer::TerminalId;
2use newty::newty;
3use serde::{Deserialize, Serialize};
4use unbounded_interval_tree::interval_tree::IntervalTree;
5
#[cfg(test)]
mod tests {
    use super::super::parsing::tests::compile;
    use super::*;

    #[test]
    fn groups() {
        // Test the regexp (a+)(b+)
        let (program, nb_groups) = compile("(a+)(b+)", TerminalId(0)).unwrap();
        let Match {
            char_pos: end,
            id: idx,
            groups: results,
        } = find(&program, "aabbb", nb_groups, &Allowed::All).unwrap();
        assert_eq!(idx, TerminalId(0));
        assert_eq!(end, 5);
        assert_eq!(results, vec![Some(0), Some(2), Some(2), Some(5)]);
    }

    #[test]
    fn chars() {
        let (program, nb_groups) = compile("ab", TerminalId(0)).unwrap();
        let Match {
            char_pos: end,
            id: idx,
            groups: results,
        } = find(&program, "abb", nb_groups, &Allowed::All).unwrap();
        assert_eq!(idx, TerminalId(0));
        assert_eq!(end, 2);
        assert_eq!(results, vec![]);
    }

    #[test]
    fn multiline_comments() {
        let (program, nb_groups) = compile(r"/\*([^*]|\*[^/])*\*/", TerminalId(0)).unwrap();
        let text1 = "/* hello, world */#and other stuff";
        let text2 = "/* hello,\nworld */#and other stuff";
        let text3 = "/* unicode éèàç */#and other stuff";
        let Match {
            char_pos: end, id, ..
        } = find(&program, text1, nb_groups, &Allowed::All).unwrap();
        assert_eq!(id, TerminalId(0));
        assert_eq!(end, 18);
        assert_eq!(text1.chars().nth(end).unwrap(), '#');
        let Match {
            char_pos: end, id, ..
        } = find(&program, text2, nb_groups, &Allowed::All).unwrap();
        assert_eq!(id, TerminalId(0));
        assert_eq!(end, 18);
        assert_eq!(text2.chars().nth(end).unwrap(), '#');
        let Match {
            char_pos: end, id, ..
        } = find(&program, text3, nb_groups, &Allowed::All).unwrap();
        assert_eq!(id, TerminalId(0));
        // `end` counts chars, not bytes: the accented letters take two bytes each.
        assert_eq!(end, 18);
        assert_eq!(text3.chars().nth(end).unwrap(), '#');
    }

    #[test]
    fn escaped() {
        let escaped = vec![
            (
                r"\w",
                vec![
                    ("a", true),
                    ("A", true),
                    ("0", true),
                    ("_", true),
                    ("%", false),
                    ("'", false),
                ],
            ),
            (r"a\b", vec![("a", true), ("ab", false)]),
            (
                r".\b.",
                vec![("a ", true), (" a", true), ("  ", false), ("aa", false)],
            ),
        ];
        for (regex, tests) in escaped {
            let (program, _) = compile(regex, TerminalId(0)).unwrap();
            for (string, result) in tests {
                assert_eq!(find(&program, string, 0, &Allowed::All).is_some(), result);
            }
        }
    }

    #[test]
    fn greedy() {
        let (program, nb_groups) = compile("(a+)(a+)", TerminalId(0)).unwrap();
        let Match {
            char_pos: end,
            id: idx,
            groups: results,
        } = find(&program, "aaaa", nb_groups, &Allowed::All).unwrap();
        assert_eq!(end, 4);
        assert_eq!(idx, TerminalId(0));
        assert_eq!(results, vec![Some(0), Some(3), Some(3), Some(4)]);
    }

    #[test]
    fn partial() {
        let (program, nb_groups) = compile("a+", TerminalId(0)).unwrap();
        let Match {
            char_pos: end,
            id: idx,
            groups: results,
        } = find(&program, "aaabcd", nb_groups, &Allowed::All).unwrap();
        assert_eq!(end, 3);
        assert_eq!(idx, TerminalId(0));
        assert_eq!(results, Vec::new());
    }
}
117
118newty! {
119    pub id InstructionPointer
120    impl {
121    pub fn incr(&self) -> Self {
122        Self(self.0+1)
123    }
124    }
125}
126
127/// # Summary
128///
129/// `Instruction` represents an instruction of the VM.
130///
131/// # Variants
132///
133/// `Switch(ips: Vec<(usize, usize)>)`: fork the current thread once for each `(id, ip)` in `ips`,
134///                                         and set the instruction pointer of each new thread to `ip`, but only
135///                                         if the regex `id` is allowed, or if `ignored`.
136/// `Save(reg: usize)`: save the current location in register `reg`
137/// `Split(ip1: usize, ip2: usize)`: fork the current thread, and set the instruction pointer
138///                                of both thread respectivly to `ip1` and `ip2`
139/// `Char(chr: char)`: match `chr` at the current location, or stop the thread if it doesn't match
140/// `Jump(ip: usize)`: set the instruction pointer of the current thread to `ip`
141/// `Match(id: usize)`: stop the thread and record it as a successful match of the regex `id`
142/// `WordChar`: match /[A-Za-z0-9_]/ at the current location, or stop the thread if it doesn't match
143/// `WordBoundary`: match a word boundary (meaning, the end of the beginning of a word)
144/// `CharacterClass(
145///      class: IntervalTree<char>,
146///      negated: bool
147///  )`: match any character inside `class`, or outside if `negated`.
148///     **Warning**: this instruction is not constant time complexity,
149///              since it is actually a shorthand for many instructions at once.
150///              It is however (much) more efficient than if those instructions
151///              were executed indipendently.
152/// `Any`: match any character at the current location
153#[cfg_attr(test, derive(PartialEq))]
154#[derive(Debug, Serialize, Deserialize)]
155pub enum Instruction {
156    Switch(Vec<(TerminalId, InstructionPointer)>),
157    Save(usize),
158    Split(InstructionPointer, InstructionPointer),
159    Char(char),
160    Jump(InstructionPointer),
161    Match(TerminalId),
162    WordChar,
163    Digit,
164    WordBoundary,
165    Whitespace,
166    CharacterClass(IntervalTree<char>, bool),
167    EOF,
168    Any,
169}
170
171/// # Summary
172/// Set the allowed rules for the regex engine. It can either be `Allowed::All`, to allow
173/// all rules, or `Allowed::Some(rules: FixedBitSet)`, in which case only `rules` are allowed.
174#[derive(Debug)]
175pub enum Allowed {
176    All,
177    Some(AllowedTerminals),
178}
179
180impl Allowed {
181    pub fn contains(&self, i: TerminalId) -> bool {
182        match self {
183            Allowed::All => true,
184            Allowed::Some(allowed) => allowed.contains(i),
185        }
186    }
187}
188
189/// # Summary
190///
191/// Represents a `Match`, as returned by a parse of an input by a regex.
192///
193/// # Definition
194///
195/// `Match` is `pos: usize, id: usize, groups: Vec<Option<usize>>`, where
196/// `pos` is the end position (exclusive) of the match (the start position is 0, since all matches are anchored matches).
197/// `id` is the id of the regex that led to a match. If many had, one has been selected according to the priority rules.
198/// `groups` is a vector of all the groups defined by the regex that led to the match. The groups `i` owns items `2*i`
199///   and `2*i+1`, respectivly the start position (inclusive) and the end position (exclusive) of the match.
200pub struct Match {
201    pub char_pos: usize,
202    pub id: TerminalId,
203    pub groups: Vec<Option<usize>>,
204}
205
206// /// # Summary
207// ///
208// /// A way to referencing a `Program` without have an explicit `&Vec<_>` (which clippy doesn't like), but instead a slice referece.
209// pub type ProgramRef<'a> = &'a [Instruction];
210
211newty! {
212    #[derive(Serialize, Deserialize)]
213    #[cfg_attr(test, derive(PartialEq))]
214    pub vec Program (Instruction) [InstructionPointer]
215    impl {
216    pub fn len_ip(&self) -> InstructionPointer {
217            InstructionPointer(self.len())
218    }
219    }
220}
221
222newty! {
223    pub slice ProgramSlice (Instruction) [InstructionPointer]
224    of Program
225}
226
227newty! {
228    set DoneThreads [InstructionPointer]
229}
230
231newty! {
232    pub set AllowedTerminals [TerminalId]
233}
234
235/// # Summary
236///
237/// `ThreadList` is a `Thread` *priority queue* that doesn't accept twice the same `Thread`.
238/// Currently, the priority is the one of a LIFO queue (also called a stack).
239/// This may change in the future.
240///
241/// # Methods
242///
243/// `new`: create a new `ThreadList`
244/// `add`: insert a new `Thread`
245/// `get`: pop a thread
246/// `from`: create a new `ThreadList` from an existing `Vec<Thread>`
247struct ThreadList {
248    done: DoneThreads,
249    threads: Vec<Thread>,
250}
251
252impl ThreadList {
253    /// Create a new `ThreadList` with a given capacity.
254    fn new(size: usize) -> Self {
255        Self {
256            done: DoneThreads::with_raw_capacity(size),
257            threads: Vec::new(),
258        }
259    }
260
261    /// Insert a new `Thread` in the `ThreadList`. Doesn't do anything if it has already been added once.
262    fn add(&mut self, thread: Thread) {
263        let pos = thread.instruction();
264        if !self.done.contains(pos) {
265            self.done.insert(pos);
266            self.threads.push(thread);
267        }
268    }
269
270    /// Pop a `Thread` from the `ThreadList`. This will **not** make the `ThreadList` accept the same `Thread` again.
271    fn get(&mut self) -> Option<Thread> {
272        self.threads.pop()
273    }
274
275    /// Create a new `ThreadList` with given capacity from a `Vec<Thread>`.
276    fn from(threads: Vec<Thread>, size: usize) -> Self {
277        let mut thread_list = Self::new(size);
278        for thread in threads.into_iter() {
279            thread_list.add(thread);
280        }
281        thread_list
282    }
283}
284
285/// # Summary
286///
287/// `Thread` corresponds to a given execution of the program up to a certain point.
288/// During its lifetime, it only moves forward during the executing of the program,
289/// meaning it never backtracks (it might execute code already executed, but never
290/// without having gone forward in the text matching). If the thread has to make a choice,
291/// a copy of the thread is created, and then each copy makes a different choice.
292/// This is why a `Thread` never backtracks (in the worst case scenario, it just stops).
293///
294/// Fundamentally, a `Thread` is set a registers: the special *instruction pointer* `ip` register,
295/// and a bunch of registers dedicated to group matching.
296///
297/// # Methods
298///
299/// `new`: create a new `Thread`
300/// `instruction`: return the `ip` value
301/// `jump`: set the `ip` to a new value
302/// `save`: set the value for a register dedicated to group matching
303#[derive(Clone, Debug)]
304struct Thread {
305    instruction: InstructionPointer,
306    groups: Vec<Option<usize>>,
307}
308
309impl Thread {
310    /// Create a new `Thread`. This method is public only to be accessible from documentation,
311    /// but this doesn't mean it is meant to be accessed by the end-user.
312    pub fn new(instruction: InstructionPointer, size: usize) -> Self {
313        Self {
314            instruction,
315            groups: vec![None; 2 * size],
316        }
317    }
318
319    /// Return the value stored in the *instruction pointer* register, `ip`.
320    fn instruction(&self) -> InstructionPointer {
321        self.instruction
322    }
323
324    /// Set the *instruction pointer* register value to `pos`.
325    fn jump(&mut self, pos: InstructionPointer) {
326        self.instruction = pos;
327    }
328
329    /// Set the *group* register `idx` value to `pos`.
330    fn save(&mut self, idx: usize, bytes_pos: usize) {
331        self.groups[idx] = Some(bytes_pos);
332    }
333}
334
335/// Execute a single instruction for `thread`, in a given context.
336#[allow(clippy::too_many_arguments)]
337fn match_next(
338    chr: char,
339    bytes_pos: usize,
340    chars_pos: usize,
341    mut thread: Thread,
342    current: &mut ThreadList,
343    next: Option<&mut ThreadList>,
344    prog: &ProgramSlice,
345    best_match: &mut Option<Match>,
346    last: Option<char>,
347    allowed: &Allowed,
348) {
349    /// Return whether `chr` is a word char,
350    /// matched by /[a-zA-Z0-9_]/.
351    fn is_word_char(chr: char) -> bool {
352        chr.is_alphanumeric() || chr == '_'
353    }
354
355    /// Return whether `chr` is a digit,
356    /// matched by /[0-9]/.
357    fn is_digit(chr: char) -> bool {
358        chr.is_ascii_digit()
359    }
360
361    /// Return whether `chr` is a whitespace,
362    /// matched by /[ \t]/.
363    fn is_whitespace(chr: char) -> bool {
364        chr == ' ' || chr == '\t'
365    }
366
367    /// Advance `thread` to the next instruction and queue it
368    /// in the `thread_list`.
369    fn advance(mut thread: Thread, thread_list: Option<&mut ThreadList>) {
370        thread.jump(thread.instruction().incr());
371        if let Some(next) = thread_list {
372            next.add(thread);
373        }
374    }
375
376    match &prog[thread.instruction()] {
377        Instruction::Char(expected) => {
378            if *expected == chr {
379                advance(thread, next);
380            }
381        }
382        Instruction::Any => advance(thread, next),
383        Instruction::WordChar => {
384            if is_word_char(chr) {
385                advance(thread, next);
386            }
387        }
388        Instruction::Digit => {
389            if is_digit(chr) {
390                advance(thread, next);
391            }
392        }
393        Instruction::Whitespace => {
394            if is_whitespace(chr) {
395                advance(thread, next);
396            }
397        }
398        Instruction::Jump(pos) => {
399            thread.jump(*pos);
400            current.add(thread);
401        }
402        Instruction::Save(idx) => {
403            thread.save(*idx, bytes_pos);
404            advance(thread, Some(current));
405        }
406        Instruction::Switch(instructions) => {
407            instructions
408                .iter()
409                .rev()
410                .filter(|(id, _)| allowed.contains(*id))
411                .for_each(|(_, ip)| {
412                    let mut new = thread.clone();
413                    new.jump(*ip);
414                    current.add(new);
415                });
416        }
417        Instruction::Split(pos1, pos2) => {
418            let mut other = thread.clone();
419            other.jump(*pos2);
420            thread.jump(*pos1);
421            current.add(other);
422            current.add(thread);
423        }
424        Instruction::Match(id) => {
425            if let Some(Match {
426                char_pos: p,
427                id: prior,
428                ..
429            }) = best_match
430            {
431                if chars_pos > *p || *prior > *id {
432                    *best_match = Some(Match {
433                        char_pos: chars_pos,
434                        id: *id,
435                        groups: thread.groups,
436                    });
437                }
438            } else {
439                *best_match = Some(Match {
440                    char_pos: chars_pos,
441                    id: *id,
442                    groups: thread.groups,
443                });
444            }
445        }
446        Instruction::CharacterClass(class, negated) => {
447            if negated ^ class.contains_point(&chr) {
448                advance(thread, next);
449            }
450        }
451        Instruction::WordBoundary => {
452            if let Some(last) = last {
453                if is_word_char(last) ^ is_word_char(chr) {
454                    advance(thread, Some(current));
455                }
456            } else {
457                advance(thread, Some(current));
458            }
459        }
460        Instruction::EOF => {
461            if next.is_none() {
462                advance(thread, Some(current));
463            }
464        }
465    }
466}
467
468/// Simulate a VM with program `prog` on `input`. There should be `size` groups.
469pub fn find(prog: &ProgramSlice, input: &str, size: usize, allowed: &Allowed) -> Option<Match> {
470    let mut current =
471        ThreadList::from(vec![Thread::new(InstructionPointer(0), size)], prog.len());
472    let mut best_match = None;
473    let mut last = None;
474    let mut bytes_pos = 0;
475    for (chars_pos, chr) in input.chars().enumerate() {
476        let mut next = ThreadList::new(prog.len());
477        while let Some(thread) = current.get() {
478            match_next(
479                chr,
480                bytes_pos,
481                chars_pos,
482                thread,
483                &mut current,
484                Some(&mut next),
485                prog,
486                &mut best_match,
487                last,
488                allowed,
489            );
490        }
491        current = next;
492        last = Some(chr);
493        bytes_pos += chr.len_utf8();
494    }
495    let chars_pos = input.len();
496    while let Some(thread) = current.get() {
497        match_next(
498            '#',
499            bytes_pos,
500            chars_pos,
501            thread,
502            &mut current,
503            None,
504            prog,
505            &mut best_match,
506            last,
507            allowed,
508        );
509    }
510
511    best_match
512}