use super::Regex;
use crate::regex::nfa::NFA;
use std::collections::{BTreeSet, HashMap, HashSet, VecDeque};
// Index of a state in `DFA::states` (see `DFA::add_state`).
type StateId = usize;
// Input symbol labelling a transition.
type Symbol = char;
/// Deterministic finite automaton over `char` symbols.
///
/// Transitions are partial: when the current state has no entry for the next input symbol,
/// [`DFA::run`] returns `None` and the input is rejected.
// `DFA` is the conventional name of the structure, so the acronym-casing lint is silenced.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone)]
pub struct DFA {
    // All states; a `StateId` is an index into this vector.
    states: Vec<DFAState>,
    // Id of the initial state.
    start: StateId,
    // Ids of the accepting states.
    accept_states: HashSet<StateId>,
}
/// A single DFA state: its outgoing transitions, at most one target per symbol.
#[derive(Debug, Clone)]
pub struct DFAState {
    // symbol -> id of the target state
    transitions: HashMap<Symbol, StateId>,
}
impl DFAState {
    /// Creates a state with no outgoing transitions.
    pub fn new() -> Self {
        DFAState {
            transitions: HashMap::new(),
        }
    }

    /// Sets the transition on `symbol` to `state`.
    ///
    /// Any previous target for `symbol` is replaced, so each symbol has at most one target.
    pub fn add_transition(&mut self, symbol: Symbol, state: StateId) {
        self.transitions.insert(symbol, state);
    }

    /// Returns the target of the transition on `symbol`, or `None` if there is none.
    pub fn get(&self, symbol: Symbol) -> Option<StateId> {
        self.transitions.get(&symbol).copied()
    }

    /// Returns the target of every outgoing transition.
    ///
    /// There is one entry per symbol, so a target reached by several symbols appears several
    /// times. The order follows `HashMap` iteration and is unspecified.
    pub fn nexts(&self) -> Vec<StateId> {
        self.transitions.values().copied().collect()
    }
}

impl Default for DFAState {
    /// Same as [`DFAState::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl DFA {
pub fn new() -> Self {
DFA {
states: Vec::new(),
start: 0,
accept_states: HashSet::new(),
}
}
pub fn add_state(&mut self) -> StateId {
let id = self.states.len();
self.states.push(DFAState::new());
id
}
pub fn add_transition(&mut self, from: StateId, symbol: Symbol, to: StateId) {
self.states[from].add_transition(symbol, to);
}
pub fn from(nfa: NFA) -> Self {
let mut dfa = DFA::new();
let nfa_start_closure = nfa.epsilon_closure(vec![nfa.start]);
let closure_set: BTreeSet<_> = nfa_start_closure.into_iter().collect();
let dfa_start = dfa.add_state();
dfa.start = dfa_start;
if closure_set.contains(&nfa.accept) {
dfa.accept_states.insert(dfa_start);
}
let mut unmarked_states: VecDeque<BTreeSet<StateId>> = VecDeque::new();
unmarked_states.push_back(closure_set.clone());
let mut state_map: HashMap<BTreeSet<StateId>, StateId> = HashMap::new();
state_map.insert(closure_set.clone(), dfa_start);
while let Some(current_set) = unmarked_states.pop_front() {
let alphabet = extract_alphabet(&nfa);
for &symbol in &alphabet {
let mut move_set: BTreeSet<StateId> = BTreeSet::new();
for &state_id in ¤t_set {
for (trans_symbol, to_state) in nfa.states[state_id].transitions.clone() {
if trans_symbol == symbol {
move_set.insert(to_state);
}
}
}
let mut closure_set: BTreeSet<StateId> = BTreeSet::new();
for &state in &move_set {
let state_closure = nfa.epsilon_closure(vec![state]);
closure_set.extend(state_closure);
}
if !closure_set.is_empty() {
if let Some(&existing_state) = state_map.get(&closure_set) {
let current_dfa_state = state_map[¤t_set];
dfa.add_transition(current_dfa_state, symbol, existing_state);
} else {
let new_dfa_state = dfa.add_state();
state_map.insert(closure_set.clone(), new_dfa_state);
let current_dfa_state = state_map[¤t_set];
dfa.add_transition(current_dfa_state, symbol, new_dfa_state);
if closure_set.contains(&nfa.accept) {
dfa.accept_states.insert(new_dfa_state);
}
unmarked_states.push_back(closure_set);
}
}
}
}
dfa
}
pub fn from_regex(regex: Regex) -> Self {
let nfa = NFA::from(regex);
Self::from(nfa)
}
pub fn product<F>(&self, other: &DFA, accept_fn: F) -> DFA
where
F: Fn(bool, bool) -> bool,
{
let mut state_map = HashMap::new();
let mut queue = VecDeque::new();
let mut product_states = Vec::new();
let mut transitions = Vec::new();
let start_pair = (self.start, other.start);
state_map.insert(start_pair, 0);
queue.push_back(start_pair);
product_states.push(start_pair);
transitions.push(HashMap::new());
let alphabet = self
.extract_dfa_alphabet()
.union(&other.extract_dfa_alphabet())
.copied()
.collect::<HashSet<_>>();
while let Some((s1, s2)) = queue.pop_front() {
let current_id = state_map[&(s1, s2)];
for symbol in &alphabet {
if let (Some(&n1), Some(&n2)) = (
self.states[s1].transitions.get(symbol),
other.states[s2].transitions.get(symbol),
) {
let next_pair = (n1, n2);
let next_id = *state_map.entry(next_pair).or_insert_with(|| {
let id = product_states.len();
product_states.push(next_pair);
transitions.push(HashMap::new());
queue.push_back(next_pair);
id
});
transitions[current_id].insert(*symbol, next_id);
}
}
}
let accept_states = product_states
.iter()
.enumerate()
.filter(|(_, (s1, s2))| {
accept_fn(
self.accept_states.contains(s1),
other.accept_states.contains(s2),
)
})
.map(|(id, _)| id)
.collect();
DFA {
states: transitions
.into_iter()
.map(|t| DFAState { transitions: t })
.collect(),
start: 0,
accept_states,
}
}
pub fn complement(&self) -> DFA {
let accept_states = (0..self.states.len())
.filter(|id| !self.accept_states.contains(id))
.collect();
DFA {
states: self.states.clone(),
start: self.start,
accept_states,
}
}
fn extract_dfa_alphabet(&self) -> HashSet<Symbol> {
self.states
.iter()
.flat_map(|state| state.transitions.keys().copied())
.collect()
}
pub fn run(&self, input: &str) -> Option<StateId> {
let mut state = self.start;
for c in input.chars() {
if let Some(next_state) = self.states[state].get(c) {
state = next_state;
} else {
return None;
}
}
Some(state)
}
pub fn accepts(&self, input: &str) -> bool {
self.run(input)
.is_some_and(|state| self.accept_states.contains(&state))
}
pub fn derive(&self, input: &str) -> Option<DFA> {
self.run(input).map(|state| DFA {
states: self.states.clone(),
start: state,
accept_states: self.accept_states.clone(),
})
}
pub fn clean(&self) -> DFA {
let mut reachable = HashSet::new();
let mut stack = vec![self.start];
while let Some(state) = stack.pop() {
reachable.insert(state);
for &next_state in &self.states[state].nexts() {
if !reachable.contains(&next_state) {
stack.push(next_state);
}
}
}
let mut state_mapping = HashMap::new();
let mut new_states = Vec::new();
for (old_id, state) in self.states.iter().enumerate() {
if reachable.contains(&old_id) {
let new_id = new_states.len();
state_mapping.insert(old_id, new_id);
new_states.push(state.clone());
}
}
let mut cleaned_states = Vec::new();
for (old_id, state) in self.states.iter().enumerate() {
if reachable.contains(&old_id) {
let mut new_state = DFAState::new();
for (symbol, &old_target) in &state.transitions {
if let Some(&new_target) = state_mapping.get(&old_target) {
new_state.add_transition(*symbol, new_target);
}
}
cleaned_states.push(new_state);
}
}
let new_accept_states = self
.accept_states
.iter()
.filter_map(|&old_id| state_mapping.get(&old_id).copied())
.collect();
let new_start = state_mapping[&self.start];
DFA {
states: cleaned_states,
start: new_start,
accept_states: new_accept_states,
}
}
pub fn is_accepting(&self) -> bool {
let mut visited = HashSet::new();
let mut queue = VecDeque::new();
queue.push_back(self.start);
while let Some(state) = queue.pop_front() {
if self.accept_states.contains(&state) {
return true;
}
if visited.contains(&state) {
continue;
}
visited.insert(state);
if let Some(dfa_state) = self.states.get(state) {
for &next_state in dfa_state.transitions.values() {
if !visited.contains(&next_state) {
queue.push_back(next_state);
}
}
}
}
false
}
}
use std::ops::{BitAnd, BitOr, BitXor, Not, Sub};
impl BitAnd for &DFA {
type Output = DFA;
fn bitand(self, other: &DFA) -> DFA {
self.product(other, |a, b| a && b)
}
}
impl BitOr for &DFA {
type Output = DFA;
fn bitor(self, other: &DFA) -> DFA {
self.product(other, |a, b| a || b)
}
}
impl Sub for &DFA {
type Output = DFA;
fn sub(self, other: &DFA) -> DFA {
self.product(other, |a, b| a && !b)
}
}
impl BitXor for &DFA {
type Output = DFA;
fn bitxor(self, other: &DFA) -> DFA {
self.product(other, |a, b| a ^ b)
}
}
impl Not for &DFA {
type Output = DFA;
fn not(self) -> DFA {
self.complement()
}
}
fn extract_alphabet(nfa: &NFA) -> Vec<Symbol> {
let mut alphabet: HashSet<Symbol> = HashSet::new();
for state in &nfa.states {
for (symbol, _) in &state.transitions {
alphabet.insert(*symbol);
}
}
let mut sorted_alphabet: Vec<Symbol> = alphabet.into_iter().collect();
sorted_alphabet.sort_unstable();
sorted_alphabet
}
use std::fmt;
impl fmt::Display for DFA {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "DFA with {} states", self.states.len())?;
writeln!(
f,
"Start: {}, Accept States: {:?}",
self.start, self.accept_states
)?;
writeln!(f)?;
for (id, state) in self.states.iter().enumerate() {
write!(f, "State {}", id)?;
if id == self.start {
write!(f, " (START)")?;
}
if self.accept_states.contains(&id) {
write!(f, " (ACCEPT)")?;
}
writeln!(f)?;
if state.transitions.is_empty() {
writeln!(f, " (no transitions)")?;
} else {
for (symbol, to) in &state.transitions {
writeln!(f, " └─'{}'───> ({})", { *symbol }, to)?;
}
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::regex::Regex;

    /// Compiles `pattern` through the Regex -> NFA -> DFA pipeline.
    fn dfa_for(pattern: &str) -> DFA {
        let regex = Regex::from_str(pattern).unwrap();
        DFA::from(NFA::from(regex))
    }

    #[test]
    fn test_dfa_from_nfa() {
        let dfa = dfa_for("(a|b)bb*|y");
        println!("{}", dfa);
        for input in ["ab", "bb", "abbb", "y"] {
            assert!(dfa.accepts(input), "should accept {input:?}");
        }
        for input in ["", "a", "b", "yy", "aby"] {
            assert!(!dfa.accepts(input), "should reject {input:?}");
        }
    }

    #[test]
    fn test_dfa_derive() {
        let dfa = dfa_for("(a|b)bb*|y");
        let derived_dfa = dfa
            .derive("ab")
            .expect("\"ab\" is a prefix of accepted inputs");
        let cleaned_dfa = derived_dfa.clean();
        println!("{}", cleaned_dfa);
        // After reading "ab" the remaining language is `b*`.
        for input in ["", "b", "bbb"] {
            assert!(cleaned_dfa.accepts(input), "should accept {input:?}");
        }
        for input in ["a", "y", "ba"] {
            assert!(!cleaned_dfa.accepts(input), "should reject {input:?}");
        }
    }

    #[test]
    fn test_dfa_clean_basic() {
        let dfa = dfa_for("a*b");
        let cleaned_dfa = dfa.clean();
        for input in ["b", "ab", "aab", "aaab", "a", ""] {
            assert_eq!(
                dfa.accepts(input),
                cleaned_dfa.accepts(input),
                "input {input:?}"
            );
        }
    }

    #[test]
    fn test_dfa_clean_with_unreachable_states() {
        let mut dfa = DFA::new();
        let s0 = dfa.add_state();
        let s1 = dfa.add_state();
        let s2 = dfa.add_state();
        let _s3 = dfa.add_state(); // never the target of any transition
        dfa.start = s0;
        dfa.accept_states.insert(s2);
        dfa.add_transition(s0, 'a', s1);
        dfa.add_transition(s1, 'b', s2);
        let cleaned_dfa = dfa.clean();
        assert!(cleaned_dfa.states.len() < dfa.states.len());
        assert_eq!(cleaned_dfa.states.len(), 3);
        for input in ["ab", "a", "b", ""] {
            assert_eq!(
                dfa.accepts(input),
                cleaned_dfa.accepts(input),
                "input {input:?}"
            );
        }
    }

    #[test]
    fn test_dfa_clean_preserves_acceptance() {
        let dfa = dfa_for("(a|b)*abb");
        let cleaned_dfa = dfa.clean();
        let test_cases = [
            ("abb", true),
            ("aabb", true),
            ("babb", true),
            ("ababb", true),
            ("ab", false),
            ("abbc", false),
            ("", false),
            ("a", false),
            ("b", false),
        ];
        for (input, expected) in test_cases {
            assert_eq!(dfa.accepts(input), expected, "Failed for input: {}", input);
            assert_eq!(
                cleaned_dfa.accepts(input),
                expected,
                "Failed for input: {}",
                input
            );
            assert_eq!(
                dfa.accepts(input),
                cleaned_dfa.accepts(input),
                "Cleaning changed acceptance for: {}",
                input
            );
        }
    }

    #[test]
    fn test_dfa_clean_idempotent() {
        let dfa = dfa_for("a(b|c)*d");
        let cleaned_once = dfa.clean();
        let cleaned_twice = cleaned_once.clean();
        assert_eq!(cleaned_once.states.len(), cleaned_twice.states.len());
        assert_eq!(cleaned_once.start, cleaned_twice.start);
        assert_eq!(cleaned_once.accept_states, cleaned_twice.accept_states);
        for input in ["ad", "abd", "acd", "abcd", "accbd", "abcbd"] {
            assert_eq!(cleaned_once.accepts(input), cleaned_twice.accepts(input));
        }
    }

    #[test]
    fn test_dfa_operations_intersection() {
        let dfa1 = dfa_for("a*b*");
        let dfa2 = dfa_for("ab*");
        let intersection_dfa = &dfa1 & &dfa2;
        for input in ["a", "ab", "abb", "abbb"] {
            assert!(intersection_dfa.accepts(input), "should accept {input:?}");
        }
        for input in ["b", "ba", ""] {
            assert!(!intersection_dfa.accepts(input), "should reject {input:?}");
        }
    }

    #[test]
    fn test_dfa_operations_union() {
        let dfa1 = dfa_for("a*");
        let dfa2 = dfa_for("a*");
        let union_dfa = &dfa1 | &dfa2;
        for input in ["", "a", "aa", "aaa"] {
            assert!(union_dfa.accepts(input), "should accept {input:?}");
        }
    }

    #[test]
    fn test_dfa_operations_difference() {
        let dfa1 = dfa_for("a*");
        let dfa2 = dfa_for("aa*");
        let difference_dfa = &dfa1 - &dfa2;
        assert!(difference_dfa.accepts(""));
        for input in ["a", "aa", "aaa"] {
            assert!(!difference_dfa.accepts(input), "should reject {input:?}");
        }
    }

    #[test]
    fn test_dfa_operations_symmetric_difference() {
        let dfa1 = dfa_for("a*");
        let dfa2 = dfa_for("aa*");
        let xor_dfa = &dfa1 ^ &dfa2;
        assert!(xor_dfa.accepts(""));
        assert!(!xor_dfa.accepts("a"));
    }

    #[test]
    fn test_dfa_operations_complement() {
        let dfa = dfa_for("a*");
        let complement_dfa = !&dfa;
        for input in ["", "a", "aa"] {
            assert!(!complement_dfa.accepts(input), "should reject {input:?}");
        }
    }

    #[test]
    fn test_dfa_operations_complex() {
        let dfa1 = dfa_for("a*b*");
        let dfa2 = dfa_for("ab*");
        let intersection_dfa = &dfa1 & &dfa2;
        for input in ["a", "ab", "abb"] {
            assert!(intersection_dfa.accepts(input), "should accept {input:?}");
        }
        for input in ["", "b", "bb", "ba", "bab"] {
            assert!(!intersection_dfa.accepts(input), "should reject {input:?}");
        }
    }

    #[test]
    fn test_dfa_is_accepting() {
        let mut dfa = DFA::new();
        let s0 = dfa.add_state();
        let s1 = dfa.add_state();
        dfa.start = s0;
        dfa.accept_states.insert(s1);
        dfa.add_transition(s0, 'a', s1);
        assert!(dfa.is_accepting());

        let mut dfa_no_accept = DFA::new();
        let s0 = dfa_no_accept.add_state();
        dfa_no_accept.start = s0;
        assert!(!dfa_no_accept.is_accepting());
    }
}