1use crate::lexer::TerminalId;
2use newty::newty;
3use serde::{Deserialize, Serialize};
4use unbounded_interval_tree::interval_tree::IntervalTree;
5
#[cfg(test)]
mod tests {
    use super::super::parsing::tests::compile;
    use super::*;

    /// Two consecutive groups: each group contributes a start and an end
    /// offset to the result vector.
    #[test]
    fn groups() {
        let (program, nb_groups) = compile("(a+)(b+)", TerminalId(0)).unwrap();
        let Match {
            char_pos: end,
            id: idx,
            groups: results,
        } = find(&program, "aabbb", nb_groups, &Allowed::All).unwrap();
        assert_eq!(idx, TerminalId(0));
        assert_eq!(end, 5);
        assert_eq!(results, vec![Some(0), Some(2), Some(2), Some(5)]);
    }

    /// A literal pattern matches a prefix of the input and reports no groups.
    #[test]
    fn chars() {
        let (program, nb_groups) = compile("ab", TerminalId(0)).unwrap();
        let Match {
            char_pos: end,
            id: idx,
            groups: results,
        } = find(&program, "abb", nb_groups, &Allowed::All).unwrap();
        assert_eq!(idx, TerminalId(0));
        assert_eq!(end, 2);
        assert_eq!(results, vec![]);
    }

    /// C-style comments, including ones that span lines or contain
    /// multi-byte characters. The reported end is a character index.
    #[test]
    fn multiline_comments() {
        let (program, nb_groups) = compile(r"/\*([^*]|\*[^/])*\*/", TerminalId(0)).unwrap();
        let text1 = "/* hello, world */#and other stuff";
        let text2 = "/* hello,\nworld */#and other stuff";
        let text3 = "/* unicode éèàç */#and other stuff";
        let Match {
            char_pos: end, id, ..
        } = find(&program, text1, nb_groups, &Allowed::All).unwrap();
        assert_eq!(id, TerminalId(0));
        assert_eq!(end, 18);
        assert_eq!(text1.chars().nth(end).unwrap(), '#');
        let Match {
            char_pos: end, id, ..
        } = find(&program, text2, nb_groups, &Allowed::All).unwrap();
        assert_eq!(id, TerminalId(0));
        assert_eq!(end, 18);
        assert_eq!(text2.chars().nth(end).unwrap(), '#');
        let Match {
            char_pos: end, id, ..
        } = find(&program, text3, nb_groups, &Allowed::All).unwrap();
        assert_eq!(id, TerminalId(0));
        assert_eq!(end, 18);
        // Check against the text that was actually matched (`text3`), so the
        // character-vs-byte distinction is exercised on non-ASCII input.
        assert_eq!(text3.chars().nth(end).unwrap(), '#');
    }

    /// Escape sequences: `\w` and the zero-width `\b`.
    #[test]
    fn escaped() {
        let escaped = vec![
            (
                r"\w",
                vec![
                    ("a", true),
                    ("A", true),
                    ("0", true),
                    ("_", true),
                    ("%", false),
                    ("'", false),
                ],
            ),
            (r"a\b", vec![("a", true), ("ab", false)]),
            (
                r".\b.",
                vec![("a ", true), (" a", true), (" ", false), ("aa", false)],
            ),
        ];
        for (regex, tests) in escaped {
            let (program, _) = compile(regex, TerminalId(0)).unwrap();
            for (string, result) in tests {
                assert_eq!(find(&program, string, 0, &Allowed::All).is_some(), result);
            }
        }
    }

    /// The first repetition is greedy: it takes as many characters as it can
    /// while still letting the second group match.
    #[test]
    fn greedy() {
        let (program, nb_groups) = compile("(a+)(a+)", TerminalId(0)).unwrap();
        let Match {
            char_pos: end,
            id: idx,
            groups: results,
        } = find(&program, "aaaa", nb_groups, &Allowed::All).unwrap();
        assert_eq!(end, 4);
        assert_eq!(idx, TerminalId(0));
        assert_eq!(results, vec![Some(0), Some(3), Some(3), Some(4)]);
    }

    /// A match may cover only a prefix of the input.
    #[test]
    fn partial() {
        let (program, nb_groups) = compile("a+", TerminalId(0)).unwrap();
        let Match {
            char_pos: end,
            id: idx,
            groups: results,
        } = find(&program, "aaabcd", nb_groups, &Allowed::All).unwrap();
        assert_eq!(end, 3);
        assert_eq!(idx, TerminalId(0));
        assert_eq!(results, Vec::new());
    }
}
117
// Typed index of an instruction. `incr` returns the pointer to the
// instruction that immediately follows this one (the wrapped value plus one).
newty! {
    pub id InstructionPointer
    impl {
        pub fn incr(&self) -> Self {
            Self(self.0+1)
        }
    }
}
126
/// One instruction of the regex virtual machine executed by `match_next`.
///
/// Instructions either consume the current input character (the thread then
/// continues on the next character) or are zero-width (the thread keeps
/// running on the same character).
#[cfg_attr(test, derive(PartialEq))]
#[derive(Debug, Serialize, Deserialize)]
pub enum Instruction {
    /// Zero-width: fork one thread per `(terminal, target)` entry whose
    /// terminal is allowed; earlier entries are explored first.
    Switch(Vec<(TerminalId, InstructionPointer)>),
    /// Zero-width: store the current byte offset in the given group slot.
    Save(usize),
    /// Zero-width: fork into two threads; the first target is explored first.
    Split(InstructionPointer, InstructionPointer),
    /// Consume the current character if it equals this one.
    Char(char),
    /// Zero-width: continue at the given instruction.
    Jump(InstructionPointer),
    /// Zero-width: report a match for the given terminal at the current position.
    Match(TerminalId),
    /// Consume an alphanumeric character or `_`.
    WordChar,
    /// Consume an ASCII digit.
    Digit,
    /// Zero-width: succeed when exactly one of the previous and current
    /// characters is a word character, or when there is no previous character.
    WordBoundary,
    /// Consume a space or a tab.
    Whitespace,
    /// Consume a character according to its membership in the interval tree;
    /// the flag inverts the test (negated class).
    CharacterClass(IntervalTree<char>, bool),
    /// Zero-width: succeed only once the whole input has been read.
    EOF,
    /// Consume any character.
    Any,
}
170
/// Which terminals the matcher is permitted to consider.
#[derive(Debug)]
pub enum Allowed {
    /// Every terminal is permitted.
    All,
    /// Only the terminals contained in the set are permitted.
    Some(AllowedTerminals),
}
179
180impl Allowed {
181 pub fn contains(&self, i: TerminalId) -> bool {
182 match self {
183 Allowed::All => true,
184 Allowed::Some(allowed) => allowed.contains(i),
185 }
186 }
187}
188
/// Result of a successful call to `find`.
pub struct Match {
    /// Position just past the matched text, counted in characters from the
    /// start of the input.
    pub char_pos: usize,
    /// Terminal whose `Match` instruction was reached.
    pub id: TerminalId,
    /// Group slots filled by `Save` instructions, as byte offsets into the
    /// input; `None` for slots that were never written.
    /// NOTE(review): presumably laid out as start/end pairs per group — confirm
    /// against the compiler that emits `Save`.
    pub groups: Vec<Option<usize>>,
}
205
// A compiled program: a sequence of `Instruction`s addressed by
// `InstructionPointer`. `len_ip` returns the program length as an
// `InstructionPointer`, i.e. the pointer one past the last instruction.
newty! {
    #[derive(Serialize, Deserialize)]
    #[cfg_attr(test, derive(PartialEq))]
    pub vec Program (Instruction) [InstructionPointer]
    impl {
        pub fn len_ip(&self) -> InstructionPointer {
            InstructionPointer(self.len())
        }
    }
}
221
// Slice counterpart of `Program`, holding `Instruction`s indexed by
// `InstructionPointer`; this is the type `find` and `match_next` take.
// NOTE(review): presumably the borrowed view a `Program` dereferences to — confirm
// against the `newty` documentation.
newty! {
    pub slice ProgramSlice (Instruction) [InstructionPointer]
    of Program
}
226
// Set of instruction pointers; `ThreadList` uses it to remember which
// instructions have already been scheduled so none is queued twice.
newty! {
    set DoneThreads [InstructionPointer]
}
230
// Set of terminal ids; wrapped by `Allowed::Some` to restrict which
// terminals may be matched.
newty! {
    pub set AllowedTerminals [TerminalId]
}
234
/// Pending threads for one input position.
struct ThreadList {
    /// Instruction pointers that have already been added to this list.
    done: DoneThreads,
    /// Threads waiting to run; used as a stack (last added runs first).
    threads: Vec<Thread>,
}
251
252impl ThreadList {
253 fn new(size: usize) -> Self {
255 Self {
256 done: DoneThreads::with_raw_capacity(size),
257 threads: Vec::new(),
258 }
259 }
260
261 fn add(&mut self, thread: Thread) {
263 let pos = thread.instruction();
264 if !self.done.contains(pos) {
265 self.done.insert(pos);
266 self.threads.push(thread);
267 }
268 }
269
270 fn get(&mut self) -> Option<Thread> {
272 self.threads.pop()
273 }
274
275 fn from(threads: Vec<Thread>, size: usize) -> Self {
277 let mut thread_list = Self::new(size);
278 for thread in threads.into_iter() {
279 thread_list.add(thread);
280 }
281 thread_list
282 }
283}
284
/// One thread of execution of the virtual machine.
#[derive(Clone, Debug)]
struct Thread {
    /// Instruction this thread will execute next.
    instruction: InstructionPointer,
    /// Group slots written by `save`; `None` until a slot is written.
    groups: Vec<Option<usize>>,
}
308
309impl Thread {
310 pub fn new(instruction: InstructionPointer, size: usize) -> Self {
313 Self {
314 instruction,
315 groups: vec![None; 2 * size],
316 }
317 }
318
319 fn instruction(&self) -> InstructionPointer {
321 self.instruction
322 }
323
324 fn jump(&mut self, pos: InstructionPointer) {
326 self.instruction = pos;
327 }
328
329 fn save(&mut self, idx: usize, bytes_pos: usize) {
331 self.groups[idx] = Some(bytes_pos);
332 }
333}
334
335#[allow(clippy::too_many_arguments)]
337fn match_next(
338 chr: char,
339 bytes_pos: usize,
340 chars_pos: usize,
341 mut thread: Thread,
342 current: &mut ThreadList,
343 next: Option<&mut ThreadList>,
344 prog: &ProgramSlice,
345 best_match: &mut Option<Match>,
346 last: Option<char>,
347 allowed: &Allowed,
348) {
349 fn is_word_char(chr: char) -> bool {
352 chr.is_alphanumeric() || chr == '_'
353 }
354
355 fn is_digit(chr: char) -> bool {
358 chr.is_ascii_digit()
359 }
360
361 fn is_whitespace(chr: char) -> bool {
364 chr == ' ' || chr == '\t'
365 }
366
367 fn advance(mut thread: Thread, thread_list: Option<&mut ThreadList>) {
370 thread.jump(thread.instruction().incr());
371 if let Some(next) = thread_list {
372 next.add(thread);
373 }
374 }
375
376 match &prog[thread.instruction()] {
377 Instruction::Char(expected) => {
378 if *expected == chr {
379 advance(thread, next);
380 }
381 }
382 Instruction::Any => advance(thread, next),
383 Instruction::WordChar => {
384 if is_word_char(chr) {
385 advance(thread, next);
386 }
387 }
388 Instruction::Digit => {
389 if is_digit(chr) {
390 advance(thread, next);
391 }
392 }
393 Instruction::Whitespace => {
394 if is_whitespace(chr) {
395 advance(thread, next);
396 }
397 }
398 Instruction::Jump(pos) => {
399 thread.jump(*pos);
400 current.add(thread);
401 }
402 Instruction::Save(idx) => {
403 thread.save(*idx, bytes_pos);
404 advance(thread, Some(current));
405 }
406 Instruction::Switch(instructions) => {
407 instructions
408 .iter()
409 .rev()
410 .filter(|(id, _)| allowed.contains(*id))
411 .for_each(|(_, ip)| {
412 let mut new = thread.clone();
413 new.jump(*ip);
414 current.add(new);
415 });
416 }
417 Instruction::Split(pos1, pos2) => {
418 let mut other = thread.clone();
419 other.jump(*pos2);
420 thread.jump(*pos1);
421 current.add(other);
422 current.add(thread);
423 }
424 Instruction::Match(id) => {
425 if let Some(Match {
426 char_pos: p,
427 id: prior,
428 ..
429 }) = best_match
430 {
431 if chars_pos > *p || *prior > *id {
432 *best_match = Some(Match {
433 char_pos: chars_pos,
434 id: *id,
435 groups: thread.groups,
436 });
437 }
438 } else {
439 *best_match = Some(Match {
440 char_pos: chars_pos,
441 id: *id,
442 groups: thread.groups,
443 });
444 }
445 }
446 Instruction::CharacterClass(class, negated) => {
447 if negated ^ class.contains_point(&chr) {
448 advance(thread, next);
449 }
450 }
451 Instruction::WordBoundary => {
452 if let Some(last) = last {
453 if is_word_char(last) ^ is_word_char(chr) {
454 advance(thread, Some(current));
455 }
456 } else {
457 advance(thread, Some(current));
458 }
459 }
460 Instruction::EOF => {
461 if next.is_none() {
462 advance(thread, Some(current));
463 }
464 }
465 }
466}
467
468pub fn find(prog: &ProgramSlice, input: &str, size: usize, allowed: &Allowed) -> Option<Match> {
470 let mut current =
471 ThreadList::from(vec![Thread::new(InstructionPointer(0), size)], prog.len());
472 let mut best_match = None;
473 let mut last = None;
474 let mut bytes_pos = 0;
475 for (chars_pos, chr) in input.chars().enumerate() {
476 let mut next = ThreadList::new(prog.len());
477 while let Some(thread) = current.get() {
478 match_next(
479 chr,
480 bytes_pos,
481 chars_pos,
482 thread,
483 &mut current,
484 Some(&mut next),
485 prog,
486 &mut best_match,
487 last,
488 allowed,
489 );
490 }
491 current = next;
492 last = Some(chr);
493 bytes_pos += chr.len_utf8();
494 }
495 let chars_pos = input.len();
496 while let Some(thread) = current.get() {
497 match_next(
498 '#',
499 bytes_pos,
500 chars_pos,
501 thread,
502 &mut current,
503 None,
504 prog,
505 &mut best_match,
506 last,
507 allowed,
508 );
509 }
510
511 best_match
512}