use crate::lexer::TerminalId;
use newty::newty;
use serde::{Deserialize, Serialize};
use unbounded_interval_tree::interval_tree::IntervalTree;
#[cfg(test)]
mod tests {
    use super::super::parsing::tests::compile;
    use super::*;

    #[test]
    fn groups() {
        let (program, nb_groups) = compile("(a+)(b+)", TerminalId(0)).unwrap();
        let Match {
            char_pos: end,
            id: idx,
            groups: results,
        } = find(&program, "aabbb", nb_groups, &Allowed::All).unwrap();
        assert_eq!(idx, TerminalId(0));
        assert_eq!(end, 5);
        assert_eq!(results, vec![Some(0), Some(2), Some(2), Some(5)]);
    }

    #[test]
    fn chars() {
        let (program, nb_groups) = compile("ab", TerminalId(0)).unwrap();
        let Match {
            char_pos: end,
            id: idx,
            groups: results,
        } = find(&program, "abb", nb_groups, &Allowed::All).unwrap();
        assert_eq!(idx, TerminalId(0));
        assert_eq!(end, 2);
        assert_eq!(results, vec![]);
    }

    #[test]
    fn multiline_comments() {
        let (program, nb_groups) = compile(r"/\*([^*]|\*[^/])*\*/", TerminalId(0)).unwrap();
        let text1 = "/* hello, world */#and other stuff";
        let text2 = "/* hello,\nworld */#and other stuff";
        let text3 = "/* unicode éèàç */#and other stuff";
        let Match {
            char_pos: end, id, ..
        } = find(&program, text1, nb_groups, &Allowed::All).unwrap();
        assert_eq!(id, TerminalId(0));
        assert_eq!(end, 18);
        assert_eq!(text1.chars().nth(end).unwrap(), '#');
        let Match {
            char_pos: end, id, ..
        } = find(&program, text2, nb_groups, &Allowed::All).unwrap();
        assert_eq!(id, TerminalId(0));
        assert_eq!(end, 18);
        assert_eq!(text2.chars().nth(end).unwrap(), '#');
        let Match {
            char_pos: end, id, ..
        } = find(&program, text3, nb_groups, &Allowed::All).unwrap();
        assert_eq!(id, TerminalId(0));
        assert_eq!(end, 18);
        // Check the unicode text itself: `end` is a character offset, so
        // indexing by chars must land on the '#' right after the comment.
        assert_eq!(text3.chars().nth(end).unwrap(), '#');
    }

    #[test]
    fn escaped() {
        let escaped = vec![
            (
                r"\w",
                vec![
                    ("a", true),
                    ("A", true),
                    ("0", true),
                    ("_", true),
                    ("%", false),
                    ("'", false),
                ],
            ),
            (r"a\b", vec![("a", true), ("ab", false)]),
            (
                r".\b.",
                vec![("a ", true), (" a", true), (" ", false), ("aa", false)],
            ),
        ];
        for (regex, tests) in escaped {
            let (program, _) = compile(regex, TerminalId(0)).unwrap();
            for (string, result) in tests {
                assert_eq!(find(&program, string, 0, &Allowed::All).is_some(), result);
            }
        }
    }

    #[test]
    fn greedy() {
        let (program, nb_groups) = compile("(a+)(a+)", TerminalId(0)).unwrap();
        let Match {
            char_pos: end,
            id: idx,
            groups: results,
        } = find(&program, "aaaa", nb_groups, &Allowed::All).unwrap();
        assert_eq!(end, 4);
        assert_eq!(idx, TerminalId(0));
        assert_eq!(results, vec![Some(0), Some(3), Some(3), Some(4)]);
    }

    #[test]
    fn partial() {
        let (program, nb_groups) = compile("a+", TerminalId(0)).unwrap();
        let Match {
            char_pos: end,
            id: idx,
            groups: results,
        } = find(&program, "aaabcd", nb_groups, &Allowed::All).unwrap();
        assert_eq!(end, 3);
        assert_eq!(idx, TerminalId(0));
        assert_eq!(results, Vec::new());
    }
}
// Index of an instruction inside a `Program`; threads of the matching VM
// carry one as their program counter.
newty! {
pub id InstructionPointer
impl {
// Pointer to the instruction immediately after this one.
pub fn incr(&self) -> Self {
Self(self.0+1)
}
}
}
/// One instruction of the matching VM executed by [`find`].
///
/// "Consumes" instructions succeed by reading the current character and
/// continuing at the following instruction on the next character; the others
/// are zero-width and are resolved at the current position.
// Variant order is part of the serialized form (serde derive); do not reorder.
#[cfg_attr(test, derive(PartialEq))]
#[derive(Debug, Serialize, Deserialize)]
pub enum Instruction {
/// Forks the thread to each target whose terminal passes the [`Allowed`]
/// filter given to [`find`]; the first listed target is explored first.
Switch(Vec<(TerminalId, InstructionPointer)>),
/// Records the current byte offset in the given capture slot, then falls
/// through to the next instruction.
Save(usize),
/// Forks the thread; the first target is explored before the second.
Split(InstructionPointer, InstructionPointer),
/// Consumes exactly this character.
Char(char),
/// Continues at the given instruction without consuming input.
Jump(InstructionPointer),
/// Reports a match of this terminal ending at the current character offset.
Match(TerminalId),
/// Consumes an alphanumeric character or `_`.
WordChar,
/// Consumes an ASCII digit.
Digit,
/// Zero-width: succeeds where exactly one of the previous and current
/// characters is a word character. With no previous character (start of the
/// input) it always succeeds.
WordBoundary,
/// Consumes a space or a tab only (not `\n`, `\r`, or other whitespace).
Whitespace,
/// Consumes a character contained in the interval set or, when the flag is
/// `true`, a character *not* contained in it.
CharacterClass(IntervalTree<char>, bool),
/// Zero-width: succeeds only at the end of the input.
EOF,
/// Consumes any single character.
Any,
}
/// Restricts which terminals a `Switch` instruction may branch to.
#[derive(Debug)]
pub enum Allowed {
/// Every terminal may be matched.
All,
/// Only the terminals in this set may be matched.
Some(AllowedTerminals),
}
impl Allowed {
    /// Returns `true` when terminal `i` passes this filter.
    ///
    /// [`Allowed::All`] accepts every terminal; [`Allowed::Some`] defers to
    /// membership in its set.
    pub fn contains(&self, i: TerminalId) -> bool {
        if let Allowed::Some(terminals) = self {
            terminals.contains(i)
        } else {
            true
        }
    }
}
/// Outcome of a successful [`find`].
// `Debug` so callers can inspect or `unwrap`-with-message a result; every
// field type already implements it.
#[derive(Debug)]
pub struct Match {
    /// End of the match, counted in characters (not bytes) from the start of
    /// the input.
    pub char_pos: usize,
    /// Terminal whose `Match` instruction was reached.
    pub id: TerminalId,
    /// Capture slots, two per group, holding byte offsets into the input;
    /// `None` for a slot that was never reached. In the tests, slots `2k` and
    /// `2k + 1` are the start and end of group `k`.
    pub groups: Vec<Option<usize>>,
}
// A compiled matcher: instructions addressed by `InstructionPointer`.
newty! {
#[derive(Serialize, Deserialize)]
#[cfg_attr(test, derive(PartialEq))]
pub vec Program (Instruction) [InstructionPointer]
impl {
// Number of instructions, as a pointer one past the last instruction
// (e.g. the address the next pushed instruction will get).
pub fn len_ip(&self) -> InstructionPointer {
InstructionPointer(self.len())
}
}
}
// Borrowed view of a `Program`; this is what `find` executes.
newty! {
pub slice ProgramSlice (Instruction) [InstructionPointer]
of Program
}
// Set of instruction pointers already queued in a `ThreadList`, used to drop
// duplicate threads at the same input position.
newty! {
set DoneThreads [InstructionPointer]
}
// Set of terminal ids permitted by `Allowed::Some`.
newty! {
pub set AllowedTerminals [TerminalId]
}
/// Worklist of VM threads for a single input position.
struct ThreadList {
/// Instructions that already had a thread queued at this position. Entries
/// are never removed, so `Jump`/`Split` cycles cannot re-queue forever.
done: DoneThreads,
/// Pending threads, used as a stack: the last one added is taken first.
threads: Vec<Thread>,
}
impl ThreadList {
fn new(size: usize) -> Self {
Self {
done: DoneThreads::with_raw_capacity(size),
threads: Vec::new(),
}
}
fn add(&mut self, thread: Thread) {
let pos = thread.instruction();
if !self.done.contains(pos) {
self.done.insert(pos);
self.threads.push(thread);
}
}
fn get(&mut self) -> Option<Thread> {
self.threads.pop()
}
fn from(threads: Vec<Thread>, size: usize) -> Self {
let mut thread_list = Self::new(size);
for thread in threads.into_iter() {
thread_list.add(thread);
}
thread_list
}
}
/// One VM thread: a program counter plus its own capture slots.
#[derive(Clone, Debug)]
struct Thread {
/// Instruction this thread executes next.
instruction: InstructionPointer,
/// Capture slots (two per group) holding byte offsets, `None` until saved.
groups: Vec<Option<usize>>,
}
impl Thread {
pub fn new(instruction: InstructionPointer, size: usize) -> Self {
Self {
instruction,
groups: vec![None; 2 * size],
}
}
fn instruction(&self) -> InstructionPointer {
self.instruction
}
fn jump(&mut self, pos: InstructionPointer) {
self.instruction = pos;
}
fn save(&mut self, idx: usize, bytes_pos: usize) {
self.groups[idx] = Some(bytes_pos);
}
}
/// Executes the instruction `thread` points at, for the current input position.
///
/// `chr` is the character at this position (a placeholder during the
/// end-of-input pass, when `next` is `None`), `bytes_pos`/`chars_pos` are its
/// byte and character offsets, and `last` is the previously consumed
/// character, if any.
///
/// Consuming instructions move a successful thread onto `next` (or drop it
/// when `next` is `None`). Zero-width instructions (`Save`, `Jump`, `Split`,
/// `Switch`, `WordBoundary`, `EOF`) re-queue onto `current`. `Match` replaces
/// `best_match` when this position is further along or this terminal id is
/// lower than the recorded one.
// VM step shares all of `find`'s loop state; bundling it would add no clarity.
#[allow(clippy::too_many_arguments)]
fn match_next(
    chr: char,
    bytes_pos: usize,
    chars_pos: usize,
    mut thread: Thread,
    current: &mut ThreadList,
    next: Option<&mut ThreadList>,
    prog: &ProgramSlice,
    best_match: &mut Option<Match>,
    last: Option<char>,
    allowed: &Allowed,
) {
    /// Word characters for `\w` and `\b`: alphanumerics plus `_`.
    fn is_word_char(chr: char) -> bool {
        chr.is_alphanumeric() || chr == '_'
    }

    /// Moves `thread` to the following instruction and queues it on `list`;
    /// with no list the thread simply dies.
    fn step_into(mut thread: Thread, list: Option<&mut ThreadList>) {
        let following = thread.instruction().incr();
        thread.jump(following);
        if let Some(list) = list {
            list.add(thread);
        }
    }

    // Consuming instructions evaluate to whether `chr` is accepted; the
    // zero-width ones do their own queueing and return early.
    let consumes = match &prog[thread.instruction()] {
        Instruction::Char(expected) => chr == *expected,
        Instruction::Any => true,
        Instruction::WordChar => is_word_char(chr),
        Instruction::Digit => chr.is_ascii_digit(),
        Instruction::Whitespace => chr == ' ' || chr == '\t',
        Instruction::CharacterClass(class, negated) => class.contains_point(&chr) != *negated,
        Instruction::Jump(target) => {
            thread.jump(*target);
            current.add(thread);
            return;
        }
        Instruction::Save(slot) => {
            thread.save(*slot, bytes_pos);
            step_into(thread, Some(current));
            return;
        }
        Instruction::Switch(branches) => {
            // `current` is popped last-in first-out, so pushing in reverse
            // makes the first listed branch run first.
            for (terminal, target) in branches.iter().rev() {
                if allowed.contains(*terminal) {
                    let mut branch = thread.clone();
                    branch.jump(*target);
                    current.add(branch);
                }
            }
            return;
        }
        Instruction::Split(preferred, alternative) => {
            let mut fallback = thread.clone();
            fallback.jump(*alternative);
            thread.jump(*preferred);
            // Pushed last so it is popped (and explored) first.
            current.add(fallback);
            current.add(thread);
            return;
        }
        Instruction::Match(id) => {
            let improves = match best_match.as_ref() {
                Some(Match {
                    char_pos,
                    id: prior,
                    ..
                }) => chars_pos > *char_pos || *prior > *id,
                None => true,
            };
            if improves {
                *best_match = Some(Match {
                    char_pos: chars_pos,
                    id: *id,
                    groups: thread.groups,
                });
            }
            return;
        }
        Instruction::WordBoundary => {
            // No previous character (start of input) counts as a boundary.
            let at_boundary = match last {
                Some(prev) => is_word_char(prev) != is_word_char(chr),
                None => true,
            };
            if at_boundary {
                step_into(thread, Some(current));
            }
            return;
        }
        Instruction::EOF => {
            // Only the end-of-input pass runs without a `next` list.
            if next.is_none() {
                step_into(thread, Some(current));
            }
            return;
        }
    };
    if consumes {
        step_into(thread, next);
    }
}
/// Runs `prog` anchored at the start of `input` and returns the best match.
///
/// `size` is the number of capture groups (each thread carries `2 * size`
/// slots), and `allowed` limits which terminals a `Switch` instruction may
/// branch to. Among candidate matches, the one ending furthest into `input`
/// wins; ties go to the lower terminal id.
///
/// [`Match::char_pos`] is a character offset, while capture slots hold byte
/// offsets into `input`.
pub fn find(prog: &ProgramSlice, input: &str, size: usize, allowed: &Allowed) -> Option<Match> {
    let mut current =
        ThreadList::from(vec![Thread::new(InstructionPointer(0), size)], prog.len());
    let mut best_match = None;
    let mut last = None;
    let mut bytes_pos = 0;
    // Counted separately from `bytes_pos`: the two differ on non-ASCII input.
    let mut chars_pos = 0;
    for chr in input.chars() {
        // The search is anchored (one initial thread only), so once every
        // thread has died nothing can match later. Stopping here avoids
        // scanning the entire remaining input for each token.
        if current.threads.is_empty() {
            return best_match;
        }
        let mut next = ThreadList::new(prog.len());
        while let Some(thread) = current.get() {
            match_next(
                chr,
                bytes_pos,
                chars_pos,
                thread,
                &mut current,
                Some(&mut next),
                prog,
                &mut best_match,
                last,
                allowed,
            );
        }
        current = next;
        last = Some(chr);
        bytes_pos += chr.len_utf8();
        chars_pos += 1;
    }
    // End-of-input pass. With no `next` list, consuming instructions cannot
    // succeed; only zero-width ones (`EOF`, `Save`, `Match`, ...) progress.
    // `'#'` is a placeholder non-word character, so `\b` succeeds here exactly
    // when the input ends with a word character.
    while let Some(thread) = current.get() {
        match_next(
            '#',
            bytes_pos,
            chars_pos,
            thread,
            &mut current,
            None,
            prog,
            &mut best_match,
            last,
            allowed,
        );
    }
    best_match
}