use std::ascii::escape_default;
use std::collections::HashSet;
use std::fmt;
use std::{
collections::{hash_map::Entry, HashMap},
ops::RangeInclusive,
};
use dfa_util::{get_states, iter_matches, OwnedDFA};
use regex_automata::{
dfa::{dense::DFA, Automaton, StartKind},
nfa::thompson::NFA,
util::primitives::StateID,
Anchored, MatchKind,
};
use crate::leaf::{Leaf, LeafId};
mod dfa_util;
mod export;
/// Options controlling how a [`Graph`] is built.
#[derive(Debug)]
pub struct Config {
    /// Forwarded to the Thompson NFA compiler's `utf8` option in
    /// [`Graph::new`].
    pub utf8_mode: bool,
}
/// A problem with the input patterns detected while building a [`Graph`].
///
/// The derived `Ord` follows declaration order of the variants; `Graph::new`
/// sorts its error list with it so the reported order is deterministic.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum GraphError {
    /// `universal_start_state(Anchored::Yes)` returned `None` for the compiled
    /// DFA, so no single anchored start state is available.
    NoUniversalStart,
    /// The DFA can match the empty string and this leaf's pattern has a
    /// minimum length of zero.
    EmptyMatch(LeafId),
    /// Several leaves match in the same state with the same (highest)
    /// priority.
    Disambiguation(Vec<LeafId>),
}
/// Leaf information attached to a single graph state.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct StateType {
    /// Highest-priority leaf matching in this state, if any. Graph
    /// construction may clear it again.
    pub accept: Option<LeafId>,
    /// Set when the state has a transition for every byte and all of its
    /// successors accept the same leaf.
    pub early: Option<LeafId>,
}

impl StateType {
    /// The early leaf when one is known, falling back to the accepted leaf.
    fn early_or_accept(&self) -> Option<LeafId> {
        match self.early {
            Some(leaf) => Some(leaf),
            None => self.accept,
        }
    }
}
/// Dense index of a state inside a [`Graph`].
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct State(usize);

impl fmt::Display for State {
    /// Renders as `state<N>`, e.g. `state3`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let State(index) = self;
        write!(f, "state{index}")
    }
}

impl State {
    /// The state's name in `PascalCase`, e.g. `State3`.
    pub fn pascal_case(&self) -> String {
        let State(index) = *self;
        format!("State{index}")
    }

    /// The state's name in `snake_case`, e.g. `state3` (same text as `Display`).
    pub fn snake_case(&self) -> String {
        self.to_string()
    }
}
/// One state of a [`Graph`]: its leaf information plus its outgoing and
/// incoming edges.
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
pub struct StateData {
    /// Accept / early-accept leaves for this state.
    pub state_type: StateType,
    /// Byte-driven transitions; `set_normal_edges` keeps them sorted by target.
    pub normal: Vec<(ByteClass, State)>,
    /// Transition taken at end of input, if any.
    pub eoi: Option<State>,
    /// Predecessor states, kept sorted and duplicate-free by `add_back_edge`.
    pub backward: Vec<State>,
}

impl StateData {
    /// Creates a state with no leaf information and no edges.
    fn new() -> Self {
        Default::default()
    }

    /// Iterates over every successor: the targets of the byte edges, followed
    /// by the end-of-input target if there is one.
    fn iter_children<'a>(&'a self) -> impl Iterator<Item = State> + 'a {
        self.normal
            .iter()
            .map(|(_bc, s)| *s)
            .chain(self.eoi.iter().copied())
    }

    /// Records `state` as a predecessor, keeping `backward` sorted and free of
    /// duplicates.
    fn add_back_edge(&mut self, state: State) {
        if let Err(index) = self.backward.binary_search(&state) {
            self.backward.insert(index, state);
        }
    }

    /// Replaces the byte edges with `edges` (target state -> byte class),
    /// sorted by target state so the result does not depend on `HashMap`
    /// iteration order.
    fn set_normal_edges(&mut self, edges: HashMap<State, ByteClass>) {
        self.normal = edges.into_iter().map(|(s, bc)| (bc, s)).collect();
        self.normal.sort_unstable_by_key(|(_bc, s)| *s);
    }

    /// Returns `true` if at least one byte value has no outgoing byte edge.
    ///
    /// Coverage is tracked in a 256-entry table rather than by walking sorted
    /// range boundaries, so the answer stays correct when byte classes overlap
    /// or are unsorted, and no `u8` arithmetic can overflow.
    fn can_error(&self) -> bool {
        let mut covered = [false; 256];
        for (bc, _s) in &self.normal {
            for range in &bc.ranges {
                for byte in range.clone() {
                    covered[usize::from(byte)] = true;
                }
            }
        }
        covered.iter().any(|&is_covered| !is_covered)
    }
}
impl fmt::Display for StateData {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "StateData(")?;
if let Some(leaf_id) = self.state_type.accept {
write!(f, "accept({}) ", leaf_id.0)?
}
if let Some(leaf_id) = self.state_type.early {
write!(f, "early({}) ", leaf_id.0)?
}
write!(f, ")")?;
if f.alternate() {
writeln!(f, " {{")?;
for (bc, state) in &self.normal {
writeln!(f, " {} => {}", &bc.to_string(), state)?;
}
if let Some(eoi_state) = &self.eoi {
writeln!(f, " EOI => {eoi_state}")?;
}
write!(f, "}}")?;
}
Ok(())
}
}
/// A set of byte values stored as inclusive ranges.
///
/// Ranges built through `add_byte` (in ascending byte order) or `merge` are
/// sorted, disjoint and non-adjacent.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ByteClass {
    pub ranges: Vec<RangeInclusive<u8>>,
}

impl ByteClass {
    /// Creates an empty byte class.
    fn new() -> Self {
        ByteClass { ranges: Vec::new() }
    }

    /// Appends `byte`, extending the last range when `byte` directly follows it.
    ///
    /// Callers are expected to add bytes in ascending order. The successor of
    /// the last range's end is computed with `checked_add`, so a range ending
    /// at 255 starts a new range instead of overflowing.
    fn add_byte(&mut self, byte: u8) {
        if let Some(last) = self.ranges.last_mut() {
            if last.end().checked_add(1) == Some(byte) {
                *last = *last.start()..=byte;
                return;
            }
        }
        self.ranges.push(byte..=byte);
    }

    /// Returns a membership table: entry `b` is `true` iff byte `b` is in the
    /// class.
    pub fn to_table(&self) -> [bool; 256] {
        let mut table_bits = [false; 256];
        for range in self.ranges.iter() {
            for byte in range.clone() {
                table_bits[byte as usize] = true;
            }
        }
        table_bits
    }

    /// Adds every byte of `other` to this class and rebuilds `ranges` in
    /// sorted, merged form.
    pub fn merge(&mut self, other: &ByteClass) {
        let my_table = self.to_table();
        let other_table = other.to_table();
        self.ranges.clear();
        // Iterate explicitly by reference: `array.into_iter()` means different
        // things on different editions.
        for (byte, (&mine, &theirs)) in my_table.iter().zip(other_table.iter()).enumerate() {
            if mine || theirs {
                // `byte` < 256 because the tables have exactly 256 entries.
                self.add_byte(byte as u8);
            }
        }
    }

    /// Groups the ranges into [`Comparisons`]: two consecutive ranges separated
    /// by exactly one missing byte are fused into one range, with that byte
    /// recorded in `except`.
    pub fn impl_with_cmp(&self) -> Vec<Comparisons> {
        let mut ranges: Vec<Comparisons> = Vec::new();
        for next_range in &self.ranges {
            if let Some(Comparisons { range, except }) = ranges.last_mut() {
                // Compare in u16 so `end + 2` cannot overflow when the previous
                // range ends at 254 or 255.
                if u16::from(*next_range.start()) == u16::from(*range.end()) + 2 {
                    *range = *range.start()..=*next_range.end();
                    // start == end + 2 >= 2 here, so this cannot underflow.
                    except.push(*next_range.start() - 1);
                    continue;
                }
            }
            ranges.push(Comparisons::new(next_range.clone()));
        }
        ranges
    }
}

impl fmt::Display for ByteClass {
    /// Writes each range as `b` or `lo..=hi` (bytes escaped with
    /// `escape_default`), separated by `|`, or by newlines in the alternate
    /// form.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        for (idx, range) in self.ranges.iter().enumerate() {
            if range.start() == range.end() {
                write!(f, "{}", escape_default(*range.start()))?;
            } else {
                write!(
                    f,
                    "{}..={}",
                    escape_default(*range.start()),
                    escape_default(*range.end())
                )?;
            }
            if idx + 1 < self.ranges.len() {
                if f.alternate() {
                    writeln!(f)?;
                } else {
                    write!(f, "|")?;
                }
            }
        }
        Ok(())
    }
}

/// A byte range to test with comparisons, minus individually excluded bytes.
pub struct Comparisons {
    pub range: RangeInclusive<u8>,
    pub except: Vec<u8>,
}

impl Comparisons {
    /// Creates a comparison for `range` with no excluded bytes.
    pub fn new(range: RangeInclusive<u8>) -> Self {
        Comparisons {
            range,
            except: Vec::new(),
        }
    }

    /// Counts the comparisons needed to test membership. A single-byte range
    /// costs one equality check. Otherwise there is one check per bound not
    /// already implied by the `u8` domain (`0` / `255`). Each excluded byte
    /// adds one more check.
    pub fn count_ops(&self) -> usize {
        let bound_checks = if self.range.start() == self.range.end() {
            1
        } else {
            usize::from(*self.range.start() > u8::MIN) + usize::from(*self.range.end() < u8::MAX)
        };
        bound_checks + self.except.len()
    }
}
/// State graph built from a set of [`Leaf`] patterns via a regex DFA.
#[derive(Debug)]
pub struct Graph {
    /// Patterns the graph was built from, indexed by `LeafId`.
    leaves: Vec<Leaf>,
    /// The DFA compiled from all leaf patterns.
    dfa: OwnedDFA,
    /// State table, indexed by [`State`].
    states: Vec<StateData>,
    /// Start state.
    root: State,
    /// Problems with the patterns recorded during construction.
    errors: Vec<GraphError>,
}
impl Graph {
    /// Returns the start state.
    pub fn root(&self) -> State {
        self.root
    }

    /// Iterates over every state, in ascending index order.
    pub fn iter_states(&self) -> impl Iterator<Item = State> {
        (0..self.states.len()).map(State)
    }

    /// Returns the data for `state`.
    ///
    /// # Panics
    ///
    /// Panics if `state` is not a state of this graph.
    pub fn get_state(&self, state: State) -> &StateData {
        &self.states[state.0]
    }

    /// Returns the leaves the graph was built from.
    pub fn leaves(&self) -> &Vec<Leaf> {
        &self.leaves
    }

    /// Returns the underlying DFA.
    pub fn dfa(&self) -> &OwnedDFA {
        &self.dfa
    }

    /// Iterates over the problems recorded while building the graph.
    pub fn errors<'b>(&'b self) -> impl Iterator<Item = &'b GraphError> + 'b {
        self.errors.iter()
    }

    /// Builds the graph for `leaves`.
    ///
    /// Phases:
    /// 1. compile all leaf patterns into one Thompson NFA, then into an
    ///    anchored dense DFA that reports all matches;
    /// 2. copy each DFA state yielded by `get_states` from the anchored start
    ///    state into a [`StateData`] (leaf info, byte edges, end-of-input edge,
    ///    back edges);
    /// 3. mark "early" states, then drop `accept` from states whose
    ///    predecessors all settled on that leaf early;
    /// 4. keep only states that can reach an accepting/early state or the root;
    /// 5. repeatedly merge identical states until the state count is stable.
    ///
    /// Pattern problems do not make this fail; they are recorded and exposed
    /// through [`Graph::errors`]. These are a missing universal start state, a
    /// pattern that matches the empty string, and ambiguous leaves. In the
    /// first two cases the graph is returned without any states.
    ///
    /// # Errors
    ///
    /// Returns a message if the NFA or the DFA cannot be compiled.
    pub fn new(leaves: Vec<Leaf>, config: Config) -> Result<Self, String> {
        let hirs = leaves
            .iter()
            .map(|leaf| leaf.pattern.hir())
            .collect::<Vec<_>>();
        let nfa_config = NFA::config().shrink(true).utf8(config.utf8_mode);
        let nfa = NFA::compiler()
            .configure(nfa_config)
            .build_many_from_hir(&hirs)
            .map_err(|err| {
                format!("Logos encountered an error compiling the NFA for this regex: {err}")
            })?;

        // Byte classes are disabled only when the `debug` feature is enabled.
        let dfa_config = DFA::config()
            .accelerate(false)
            .byte_classes(!cfg!(feature = "debug"))
            .minimize(false)
            .match_kind(MatchKind::All)
            .start_kind(StartKind::Anchored);
        let dfa = DFA::builder()
            .configure(dfa_config)
            .build_from_nfa(&nfa)
            .map_err(|err| {
                format!("Logos encountered an error compiling the DFA for this regex: {err}")
            })?;

        let mut graph = Graph {
            leaves,
            dfa,
            states: Vec::new(),
            root: State(0),
            errors: Vec::new(),
        };

        let Some(start_id) = graph.dfa.universal_start_state(Anchored::Yes) else {
            graph.errors.push(GraphError::NoUniversalStart);
            return Ok(graph);
        };

        if graph.dfa.has_empty() {
            // Report every leaf whose pattern can match the empty string.
            for (leaf_id, leaf) in graph.leaves.iter().enumerate() {
                if leaf.pattern.hir().properties().minimum_len() == Some(0) {
                    graph.errors.push(GraphError::EmptyMatch(LeafId(leaf_id)));
                }
            }
            return Ok(graph);
        }

        // Number the DFA states densely, in the order `get_states` yields them.
        let dfa_lookup = get_states(&graph.dfa, start_id)
            .enumerate()
            .map(|(idx, dfa_id)| (dfa_id, State(idx)))
            .collect::<HashMap<StateID, State>>();
        graph.root = dfa_lookup[&start_id];
        graph.states = vec![StateData::new(); dfa_lookup.len()];

        for (dfa_id, state_id) in dfa_lookup.iter() {
            let dfa_id = *dfa_id;
            let state_data = &mut graph.states[state_id.0];
            match Self::get_state_type(dfa_id, &graph.leaves, &graph.dfa) {
                Ok(state_type) => state_data.state_type = state_type,
                Err(ambiguous_leaves) => graph
                    .errors
                    .push(GraphError::Disambiguation(ambiguous_leaves)),
            }

            // Group bytes by target state; transitions into the DFA's dead
            // state are not recorded as edges.
            let mut result: HashMap<State, ByteClass> = HashMap::new();
            for input_byte in u8::MIN..=u8::MAX {
                let next_id = graph.dfa.next_state(dfa_id, input_byte);
                if graph.dfa.is_dead_state(next_id) {
                    continue;
                }
                let next_state = dfa_lookup[&next_id];
                result
                    .entry(next_state)
                    .or_insert_with(ByteClass::new)
                    .add_byte(input_byte);
            }
            state_data.set_normal_edges(result);

            let eoi_id = graph.dfa.next_eoi_state(dfa_id);
            state_data.eoi = if graph.dfa.is_dead_state(eoi_id) {
                None
            } else {
                Some(dfa_lookup[&eoi_id])
            };

            // Collect first so the mutable borrow of `state_data` ends before
            // the children are modified.
            for child in state_data.iter_children().collect::<Vec<_>>() {
                graph.states[child.0].add_back_edge(*state_id);
            }
        }
        // States were visited in HashMap order; sort for deterministic output.
        graph.errors.sort_unstable();

        // A state with a transition for every byte whose successors (including
        // the end-of-input target) all accept the same leaf is marked `early`.
        for state in graph.iter_states() {
            let state_data = graph.get_state(state);
            if !state_data.can_error() {
                let child_state_types = state_data
                    .iter_children()
                    .map(|child_state| {
                        let child_state_data = graph.get_state(child_state);
                        child_state_data.state_type.accept
                    })
                    .collect::<HashSet<_>>();
                let child_state_types_vec = child_state_types.into_iter().collect::<Vec<_>>();
                if let &[Some(leaf_id)] = &*child_state_types_vec {
                    graph.states[state.0].state_type.early = Some(leaf_id);
                }
            }
        }

        // Drop `accept` when every predecessor already settled on the same leaf
        // early. NOTE: `all` is vacuously true for a state with no predecessors.
        for state in graph.iter_states() {
            let state_data = graph.get_state(state);
            if let Some(leaf_id) = state_data.state_type.accept {
                if state_data.backward.iter().all(|&back_state| {
                    graph.get_state(back_state).state_type.early == Some(leaf_id)
                }) {
                    graph.states[state.0].state_type.accept = None;
                }
            }
        }

        // Walk back edges from every accepting/early state and from the root to
        // find the states worth keeping.
        let mut visit_stack = graph
            .iter_states()
            .filter(|state| {
                graph
                    .get_state(*state)
                    .state_type
                    .early_or_accept()
                    .is_some()
            })
            .collect::<Vec<_>>();
        visit_stack.push(graph.root);
        let mut reach_accept = visit_stack.iter().cloned().collect::<HashSet<_>>();
        while let Some(state) = visit_stack.pop() {
            for parent in &graph.get_state(state).backward {
                if reach_accept.insert(*parent) {
                    visit_stack.push(*parent);
                }
            }
        }

        // Remove edges into discarded states. Back edges are not maintained by
        // `retain_states`/`rewrite_states`. Clearing them also lets otherwise
        // identical states compare equal in the merge loop below.
        for state in graph.iter_states() {
            let state_data = &mut graph.states[state.0];
            state_data
                .normal
                .retain(|(_bc, next_state)| reach_accept.contains(next_state));
            state_data.eoi = state_data.eoi.filter(|state| reach_accept.contains(state));
            state_data.backward.clear();
        }
        graph.retain_states(&reach_accept, true);

        // Merge identical states: keep the first of each group, redirect edges
        // to it and drop the rest. Merging can make further states identical,
        // so repeat until the state count stops changing.
        loop {
            let graph_size = graph.states.len();
            let mut state_indexes = HashMap::new();
            let mut state_lookup = HashMap::new();
            for state in graph.iter_states() {
                let state_data = &graph.states[state.0];
                if let Entry::Vacant(e) = state_indexes.entry(state_data) {
                    e.insert(state);
                } else {
                    state_lookup.insert(state, state_indexes[&state_data]);
                }
            }
            graph.rewrite_states(&state_lookup);
            graph.retain_states(&state_lookup.keys().cloned().collect(), false);
            if graph.states.len() == graph_size {
                break;
            }
        }

        Ok(graph)
    }

    /// Determines the leaf accepted in DFA state `state_id`.
    ///
    /// Returns the highest-priority leaf among those reported by
    /// `iter_matches` for this state, or a default (empty) [`StateType`] if
    /// there are none.
    ///
    /// # Errors
    ///
    /// Returns all leaves sharing the highest priority when there is more than
    /// one of them.
    fn get_state_type(
        state_id: StateID,
        leaves: &[Leaf],
        dfa: &OwnedDFA,
    ) -> Result<StateType, Vec<LeafId>> {
        let matching_leaves = iter_matches(state_id, dfa)
            .map(|leaf_id| (leaf_id, leaves[leaf_id.0].priority))
            .collect::<Vec<_>>();
        if let Some(&(highest_leaf_id, highest_priority)) = matching_leaves
            .iter()
            .max_by_key(|(_leaf_id, priority)| priority)
        {
            let matching_prio_leaves: Vec<LeafId> = matching_leaves
                .into_iter()
                .filter(|(_leaf_id, priority)| *priority == highest_priority)
                .map(|(leaf_id, _priority)| leaf_id)
                .collect();
            if matching_prio_leaves.len() > 1 {
                return Err(matching_prio_leaves);
            }
            Ok(StateType {
                accept: Some(highest_leaf_id),
                early: None,
            })
        } else {
            Ok(StateType::default())
        }
    }

    /// Keeps each state whose membership in `states` equals `keep`. With
    /// `keep == true` the listed states are kept; with `keep == false` they are
    /// removed. Survivors are renumbered densely in their original order, and
    /// edges and the root are rewritten to the new indices.
    ///
    /// Edges pointing at removed states are left untouched, so callers must
    /// drop or redirect them beforehand.
    fn retain_states(&mut self, states: &HashSet<State>, keep: bool) {
        let rewrite_map: HashMap<State, State> = self
            .iter_states()
            .filter(|state| states.contains(state) == keep)
            .enumerate()
            .map(|(new_idx, old_state)| (old_state, State(new_idx)))
            .collect();
        let mut index = 0;
        self.states.retain(|_state_data| {
            let retain = states.contains(&State(index)) == keep;
            index += 1;
            retain
        });
        self.rewrite_states(&rewrite_map);
    }

    /// Redirects every byte edge, end-of-input edge and the root through
    /// `rewrites`; states missing from the map are left unchanged. Byte edges
    /// that end up targeting the same state are merged into one byte class.
    fn rewrite_states(&mut self, rewrites: &HashMap<State, State>) {
        for state in self.iter_states() {
            let state_data = &mut self.states[state.0];
            let mut edge_dedup = HashMap::<State, ByteClass>::new();
            for (bc, next_state) in std::mem::take(&mut state_data.normal) {
                let next_state = *rewrites.get(&next_state).unwrap_or(&next_state);
                match edge_dedup.entry(next_state) {
                    Entry::Occupied(mut entry) => {
                        entry.get_mut().merge(&bc);
                    }
                    Entry::Vacant(entry) => {
                        entry.insert(bc);
                    }
                }
            }
            state_data.set_normal_edges(edge_dedup);
            if let Some(eoi_state) = &mut state_data.eoi {
                if let Some(new_eoi_state) = rewrites.get(eoi_state) {
                    *eoi_state = *new_eoi_state;
                }
            }
        }
        if let Some(new_root) = rewrites.get(&self.root) {
            self.root = *new_root;
        }
    }
}
impl fmt::Display for Graph {
    /// Renders one ` <state> => <data>` entry per state, separated by newlines.
    /// Each entry uses the alternate (multi-line) rendering of the state's
    /// data, and every line after the first is prefixed with a space.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut entries = Vec::with_capacity(self.states.len());
        for state in self.iter_states() {
            let rendered = format!("{:#}", self.get_state(state));
            let mut lines = rendered.lines();
            let mut body = String::new();
            if let Some(first_line) = lines.next() {
                body.push_str(first_line);
            }
            for line in lines {
                body.push('\n');
                body.push(' ');
                body.push_str(line);
            }
            entries.push(format!(" {state} => {body}"));
        }
        f.write_str(&entries.join("\n"))
    }
}