/// A predicate over byte values, stored as a 256-bit set.
///
/// Byte `b` satisfies the predicate exactly when bit `b & 63` of word `b >> 6` is set,
/// so word 0 covers bytes 0..=63, word 1 covers 64..=127, and so on.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct BytePred([u64; 4]);
impl BytePred {
    /// The predicate that no byte satisfies.
    #[must_use]
    pub const fn none() -> Self {
        BytePred([0; 4])
    }

    /// The predicate that every byte satisfies.
    #[must_use]
    pub const fn any() -> Self {
        BytePred([u64::MAX; 4])
    }

    /// The predicate satisfied by exactly the byte `b`.
    #[must_use]
    pub fn byte(b: u8) -> Self {
        let mut p = BytePred::none();
        p.insert(b);
        p
    }

    /// The predicate satisfied by every byte in the inclusive range `lo..=hi`.
    ///
    /// Returns the empty predicate when `lo > hi`.
    #[must_use]
    pub fn range(lo: u8, hi: u8) -> Self {
        let mut p = BytePred::none();
        for b in lo..=hi {
            p.insert(b);
        }
        p
    }

    /// Adds byte `b` to the set of bytes satisfying this predicate.
    pub fn insert(&mut self, b: u8) {
        // `b >> 6` is at most 3, so the word index is always in range.
        self.0[usize::from(b >> 6)] |= 1u64 << (b & 63);
    }

    /// Returns `true` when byte `b` satisfies this predicate.
    #[must_use]
    pub fn contains(&self, b: u8) -> bool {
        self.0[usize::from(b >> 6)] & (1u64 << (b & 63)) != 0
    }

    /// Union: bytes satisfying `self` or `o`.
    #[must_use]
    pub fn or(self, o: Self) -> Self {
        BytePred(std::array::from_fn(|i| self.0[i] | o.0[i]))
    }

    /// Intersection: bytes satisfying both `self` and `o`.
    #[must_use]
    pub fn and(self, o: Self) -> Self {
        BytePred(std::array::from_fn(|i| self.0[i] & o.0[i]))
    }

    /// Difference: bytes satisfying `self` but not `o`.
    #[must_use]
    pub fn minus(self, o: Self) -> Self {
        self.and(!o)
    }

    /// Returns `true` when no byte satisfies this predicate.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.0.iter().all(|&w| w == 0)
    }

    /// Returns the smallest byte satisfying this predicate, or `None` when it is empty.
    #[must_use]
    pub fn witness(&self) -> Option<u8> {
        self.0
            .iter()
            .enumerate()
            .find(|&(_, &w)| w != 0)
            // `i <= 3` and `trailing_zeros() <= 63` for a non-zero word, so the
            // result is at most 3 * 64 + 63 = 255 and both casts are lossless.
            .map(|(i, &w)| (i as u8) * 64 + w.trailing_zeros() as u8)
    }

    /// Number of distinct bytes satisfying this predicate (0..=256).
    #[must_use]
    pub fn count(&self) -> u32 {
        self.0.iter().map(|w| w.count_ones()).sum()
    }

    /// Encodes the predicate as 64 lowercase hex digits: the four words in index
    /// order, each rendered as 16 zero-padded digits.
    #[must_use]
    pub fn to_hex(self) -> String {
        use std::fmt::Write as _;
        let mut s = String::with_capacity(64);
        for w in self.0 {
            // Format straight into the buffer instead of allocating a temporary per word.
            write!(s, "{w:016x}").expect("writing to a String cannot fail");
        }
        s
    }

    /// Decodes a predicate from the 64-hex-digit form produced by [`BytePred::to_hex`].
    ///
    /// Upper- and lowercase digits are both accepted. Returns `None` unless `h`
    /// consists of exactly 64 ASCII hex digits.
    #[must_use]
    pub fn from_hex(h: &str) -> Option<Self> {
        // Validate every byte up front. This guarantees the fixed-offset slices
        // below fall on char boundaries (no panic on multi-byte UTF-8 input) and
        // rejects the leading `+` that `from_str_radix` would otherwise accept.
        if h.len() != 64 || !h.bytes().all(|c| c.is_ascii_hexdigit()) {
            return None;
        }
        let mut words = [0u64; 4];
        for (i, word) in words.iter_mut().enumerate() {
            *word = u64::from_str_radix(&h[i * 16..i * 16 + 16], 16).ok()?;
        }
        Some(BytePred(words))
    }
}
impl std::ops::Not for BytePred {
    type Output = BytePred;

    /// Complement: the predicate satisfied by exactly the bytes that do not satisfy `self`.
    fn not(self) -> BytePred {
        let [w0, w1, w2, w3] = self.0;
        BytePred([!w0, !w1, !w2, !w3])
    }
}
/// Index of an automaton state; used to index the per-state `accept` and `delta` tables.
pub type StateId = usize;
/// Transition table: entry `s` lists the `(guard, target)` transitions leaving state `s`.
pub type DeltaTable = Vec<Vec<(BytePred, StateId)>>;
/// A symbolic finite automaton over bytes whose transitions are guarded by [`BytePred`]s.
///
/// `Sfa::new` checks that, for every state, the guards are pairwise disjoint and
/// together cover all 256 bytes, so each byte has exactly one successor state.
#[derive(Debug, Clone)]
pub struct Sfa {
// State in which every run begins.
start: StateId,
// `accept[s]` is true when state `s` is accepting.
accept: Vec<bool>,
// `delta[s]` holds the `(guard, target)` transitions leaving state `s`.
delta: Vec<Vec<(BytePred, StateId)>>,
}
impl Sfa {
    /// Once this many entries are pending in [`Sfa::enumerate_accepted`]'s queue,
    /// dequeued entries are no longer expanded, which bounds memory use.
    const ENUMERATE_QUEUE_CAP: usize = 1_000_000;

    /// Builds an automaton from its start state, acceptance flags and transition table.
    ///
    /// # Panics
    ///
    /// Panics when `accept` and `delta` differ in length, when `start` or any
    /// transition target is not a valid state index, or when the guards of some
    /// state overlap or fail to cover every byte.
    #[must_use]
    pub fn new(start: StateId, accept: Vec<bool>, delta: Vec<Vec<(BytePred, StateId)>>) -> Self {
        assert_eq!(accept.len(), delta.len(), "accept/delta arity mismatch");
        assert!(start < accept.len(), "start state out of range");
        for (s, trans) in delta.iter().enumerate() {
            // Union of the guards seen so far for state `s`.
            let mut cover = BytePred::none();
            for (g, t) in trans {
                assert!(
                    *t < accept.len(),
                    "state {s}: transition target out of range"
                );
                assert!(
                    cover.and(*g).is_empty(),
                    "state {s}: overlapping guards (non-deterministic)"
                );
                cover = cover.or(*g);
            }
            assert_eq!(cover, BytePred::any(), "state {s}: guards are not total");
        }
        Sfa {
            start,
            accept,
            delta,
        }
    }

    /// Number of states.
    #[must_use]
    pub fn len(&self) -> usize {
        self.accept.len()
    }

    /// Returns `true` when the automaton has no states.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.accept.is_empty()
    }

    /// Returns copies of the start state, acceptance flags and transition table.
    #[must_use]
    pub fn export(&self) -> (StateId, Vec<bool>, DeltaTable) {
        (self.start, self.accept.clone(), self.delta.clone())
    }

    /// Rebuilds an automaton from the parts returned by [`Sfa::export`].
    ///
    /// # Panics
    ///
    /// Performs the same validation as [`Sfa::new`] and panics on the same conditions.
    #[must_use]
    pub fn import(start: StateId, accept: Vec<bool>, delta: DeltaTable) -> Self {
        Sfa::new(start, accept, delta)
    }

    /// The state in which every run begins.
    #[must_use]
    pub fn start_state(&self) -> StateId {
        self.start
    }

    /// Returns `true` when state `s` is accepting.
    ///
    /// # Panics
    ///
    /// Panics when `s` is not a valid state index.
    #[must_use]
    pub fn is_accepting(&self, s: StateId) -> bool {
        self.accept[s]
    }

    /// The `(guard, target)` transitions leaving state `s`.
    ///
    /// # Panics
    ///
    /// Panics when `s` is not a valid state index.
    #[must_use]
    pub fn transitions(&self, s: StateId) -> &[(BytePred, StateId)] {
        &self.delta[s]
    }

    /// The state reached from `s` after reading byte `b`.
    ///
    /// # Panics
    ///
    /// Panics when `s` is not a valid state index.
    #[must_use]
    pub fn step_byte(&self, s: StateId, b: u8) -> StateId {
        self.step(s, b)
    }

    /// Follows the transition of `s` whose guard contains `b`.
    fn step(&self, s: StateId, b: u8) -> StateId {
        for (g, t) in &self.delta[s] {
            if g.contains(b) {
                return *t;
            }
        }
        panic!("SFA totality invariant broken: no transition for byte {b} in state {s}");
    }

    /// Runs the automaton over `word` and reports whether it ends in an accepting state.
    #[must_use]
    pub fn accepts(&self, word: &[u8]) -> bool {
        let end = word.iter().fold(self.start, |s, &b| self.step(s, b));
        self.accept[end]
    }

    /// Returns the automaton accepting exactly the words this one rejects.
    #[must_use]
    pub fn complement(&self) -> Sfa {
        // Flipping acceptance leaves the transition structure untouched, so the
        // invariants checked by `new` still hold.
        Sfa {
            start: self.start,
            accept: self.accept.iter().map(|a| !a).collect(),
            delta: self.delta.clone(),
        }
    }

    /// Builds the product automaton over the state pairs reachable from
    /// `(self.start, o.start)`; a pair is accepting when `accept` applied to the
    /// two components' acceptance flags returns `true`.
    fn product(&self, o: &Sfa, accept: impl Fn(bool, bool) -> bool) -> Sfa {
        use std::collections::HashMap;
        // Maps a state pair to its index in the product; `work[i]` is the pair for index `i`.
        let mut id: HashMap<(StateId, StateId), StateId> = HashMap::new();
        let mut work = vec![(self.start, o.start)];
        id.insert((self.start, o.start), 0);
        let mut acc = vec![accept(self.accept[self.start], o.accept[o.start])];
        let mut delta: Vec<Vec<(BytePred, StateId)>> = vec![Vec::new()];
        let mut i = 0;
        // `work` grows while it is being scanned, so iterate by index.
        while i < work.len() {
            let (a, b) = work[i];
            let src = i;
            for (ga, ta) in &self.delta[a] {
                for (gb, tb) in &o.delta[b] {
                    let g = ga.and(*gb);
                    if g.is_empty() {
                        continue;
                    }
                    let key = (*ta, *tb);
                    let next = *id.entry(key).or_insert_with(|| {
                        work.push(key);
                        acc.push(accept(self.accept[*ta], o.accept[*tb]));
                        delta.push(Vec::new());
                        work.len() - 1
                    });
                    delta[src].push((g, next));
                }
            }
            i += 1;
        }
        Sfa::new(0, acc, delta)
    }

    /// Automaton accepting the words accepted by both `self` and `o`.
    #[must_use]
    pub fn intersect(&self, o: &Sfa) -> Sfa {
        self.product(o, |a, b| a && b)
    }

    /// Automaton accepting the words accepted by `self` or `o`.
    #[must_use]
    pub fn union(&self, o: &Sfa) -> Sfa {
        self.product(o, |a, b| a || b)
    }

    /// Automaton accepting the words accepted by `self` but not by `o`.
    #[must_use]
    pub fn difference(&self, o: &Sfa) -> Sfa {
        self.product(o, |a, b| a && !b)
    }

    /// One `(byte, target)` edge per non-empty guard leaving `s`, using the
    /// smallest byte of each guard, sorted ascending.
    fn witness_edges(&self, s: StateId) -> Vec<(u8, StateId)> {
        let mut edges: Vec<(u8, StateId)> = self.delta[s]
            .iter()
            .filter_map(|(g, t)| g.witness().map(|b| (b, *t)))
            .collect();
        edges.sort_unstable();
        edges
    }

    /// Returns a shortest accepted word, or `None` when no word is accepted.
    ///
    /// Found by breadth-first search from the start state, labelling each
    /// transition with the smallest byte of its guard.
    #[must_use]
    pub fn shortest_accepted(&self) -> Option<Vec<u8>> {
        use std::collections::VecDeque;
        let mut seen = vec![false; self.len()];
        let mut q = VecDeque::new();
        q.push_back((self.start, Vec::new()));
        seen[self.start] = true;
        while let Some((s, w)) = q.pop_front() {
            if self.accept[s] {
                return Some(w);
            }
            for (b, t) in self.witness_edges(s) {
                if !seen[t] {
                    seen[t] = true;
                    let mut nw = w.clone();
                    nw.push(b);
                    q.push_back((t, nw));
                }
            }
        }
        None
    }

    /// Returns `true` when the automaton accepts no word at all.
    #[must_use]
    pub fn is_language_empty(&self) -> bool {
        self.shortest_accepted().is_none()
    }

    /// Collects up to `max_words` accepted words of length at most `max_len`,
    /// in breadth-first (non-decreasing length) order.
    ///
    /// Each transition contributes only the smallest byte of its guard, so the
    /// result samples one word per path rather than listing every accepted word.
    /// Entries dequeued while at least `ENUMERATE_QUEUE_CAP` others are pending
    /// are not expanded, so very large searches may be truncated.
    #[must_use]
    pub fn enumerate_accepted(&self, max_words: usize, max_len: usize) -> Vec<Vec<u8>> {
        use std::collections::VecDeque;
        let mut out = Vec::new();
        if max_words == 0 {
            return out;
        }
        let mut q = VecDeque::from([(self.start, Vec::<u8>::new())]);
        while let Some((s, w)) = q.pop_front() {
            if self.accept[s] {
                out.push(w.clone());
                if out.len() == max_words {
                    return out;
                }
            }
            if w.len() >= max_len || q.len() >= Self::ENUMERATE_QUEUE_CAP {
                continue;
            }
            for (b, t) in self.witness_edges(s) {
                let mut nw = w.clone();
                nw.push(b);
                q.push_back((t, nw));
            }
        }
        out
    }

    /// Returns a shortest word accepted by exactly one of `self` and `o`, or
    /// `None` when both accept the same language.
    #[must_use]
    pub fn distinguishing_word(&self, o: &Sfa) -> Option<Vec<u8>> {
        // One product whose acceptance is XOR recognises the symmetric difference
        // directly; there is no need to build two differences and then their union.
        self.product(o, |a, b| a != b).shortest_accepted()
    }

    /// Returns `true` when `self` and `o` accept exactly the same words.
    #[must_use]
    pub fn equivalent(&self, o: &Sfa) -> bool {
        self.distinguishing_word(o).is_none()
    }

    /// Returns an equivalent automaton with unreachable states removed and
    /// indistinguishable states merged.
    ///
    /// The result is deterministic: the start state is state 0, the remaining
    /// states are numbered in order of their first member in `self`, and each
    /// state's transitions are sorted by target with one merged guard per target.
    #[must_use]
    pub fn minimize(&self) -> Sfa {
        use std::collections::hash_map::Entry;
        use std::collections::{BTreeMap, HashMap};

        let n = self.len();

        // States reachable from the start state.
        let mut reach = vec![false; n];
        let mut stack = vec![self.start];
        reach[self.start] = true;
        while let Some(s) = stack.pop() {
            for &(_, t) in &self.delta[s] {
                if !reach[t] {
                    reach[t] = true;
                    stack.push(t);
                }
            }
        }

        // Minterms: split the byte space by every guard so that each guard either
        // contains a minterm entirely or is disjoint from it.
        let mut minterms = vec![BytePred::any()];
        for (g, _) in self.delta.iter().flatten() {
            minterms = minterms
                .iter()
                .flat_map(|m| [m.and(*g), m.and(!*g)])
                .filter(|piece| !piece.is_empty())
                .collect();
        }

        // `step_mt[s][k]` is the successor of `s` on any byte of minterm `k`;
        // one representative byte per minterm suffices.
        let probes: Vec<u8> = minterms
            .iter()
            .map(|m| m.witness().expect("non-empty minterm"))
            .collect();
        let step_mt: Vec<Vec<StateId>> = (0..n)
            .map(|s| probes.iter().map(|&b| self.step(s, b)).collect())
            .collect();

        // Partition refinement: begin with accepting vs. rejecting, then split
        // classes by the classes of their successors until nothing changes.
        // Class ids are assigned in first-occurrence order, so an unchanged
        // partition yields an identical vector and ends the loop.
        let mut class: Vec<usize> = self.accept.iter().map(|&a| usize::from(a)).collect();
        loop {
            let mut ids: HashMap<Vec<usize>, usize> = HashMap::new();
            let mut next = vec![0usize; n];
            for s in (0..n).filter(|&s| reach[s]) {
                let key: Vec<usize> = std::iter::once(class[s])
                    .chain(step_mt[s].iter().map(|&t| class[t]))
                    .collect();
                let fresh = ids.len();
                next[s] = *ids.entry(key).or_insert(fresh);
            }
            if next == class {
                break;
            }
            class = next;
        }

        // Number the classes (start class first) and pick one representative each.
        let mut index_of: HashMap<usize, StateId> = HashMap::new();
        let mut reps: Vec<StateId> = vec![self.start];
        index_of.insert(class[self.start], 0);
        for s in (0..n).filter(|&s| reach[s]) {
            if let Entry::Vacant(slot) = index_of.entry(class[s]) {
                slot.insert(reps.len());
                reps.push(s);
            }
        }

        let mut accept = Vec::with_capacity(reps.len());
        let mut delta: Vec<Vec<(BytePred, StateId)>> = Vec::with_capacity(reps.len());
        for &r in &reps {
            accept.push(self.accept[r]);
            // A BTreeMap keeps the merged transitions sorted by target, so the
            // output does not depend on HashMap's per-process random ordering.
            let mut by_dst: BTreeMap<StateId, BytePred> = BTreeMap::new();
            for (m, &t) in minterms.iter().zip(&step_mt[r]) {
                let dst = index_of[&class[t]];
                let guard = by_dst.entry(dst).or_insert_with(BytePred::none);
                *guard = guard.or(*m);
            }
            delta.push(by_dst.into_iter().map(|(t, g)| (g, t)).collect());
        }
        Sfa::new(0, accept, delta)
    }
}