vyre-std 0.1.0

Vyre standard library: GPU DFA assembly pipeline, Aho-Corasick construction, and compositional arithmetic helpers
Documentation
//! Hopcroft's DFA minimization algorithm.
//!
//! Partition refinement yields the unique minimal DFA accepting the
//! same language. The output is canonical up to state-id relabeling,
//! so `minimize(minimize(d)) == minimize(d)` holds bytewise after both
//! runs relabel starting from the start state.
//!
//! Per perf-audit L.2.2 (`audits/perf-kimi-AUDIT.md` finding 1): the
//! prior implementation was the textbook Hopcroft shape, but computed
//! the predecessor set `x` for every (splitter, byte) pair by scanning
//! **all** states — O(|Σ|·N²) per refinement step. For 10k literal
//! patterns that's ~30 billion state inspections and 632 s assembly
//! time; `regex-automata` does the same work in <50 ms.
//!
//! This implementation:
//!
//! - Builds a flat inverse transition table once: `pred_head` holds a
//!   `(start, len)` window into `pred_flat` for every `(byte, state)`
//!   pair, so the predecessors of a state under a byte are a slice
//!   lookup rather than a scan over all states.
//! - Keeps partitions as **dense arrays** indexed by class id rather
//!   than `BTreeSet<u32>` clones. Each state has a `class_of[state]`
//!   entry, and each class stores its members in a backing `Vec<u32>`.
//! - Uses a `VecDeque<u32>` worklist of class ids, not a
//!   `VecDeque<BTreeSet<u32>>` of copied sets.
//!
//! The asymptotic bound is now O(|Σ|·N·log N), which matches the
//! published Hopcroft cost model.

use std::collections::VecDeque;

use super::types::{Dfa, INVALID_STATE};

/// Minimize a DFA via partition refinement (Hopcroft's algorithm).
///
/// The returned DFA is the unique minimal equivalent; state ids are
/// renumbered in breadth-first order from the start state (ascending
/// byte order), so the start state is `0`. States that are unreachable
/// from the start state, and states from which no accepting state can
/// be reached, do not appear in the output; transitions into the
/// latter become `INVALID_STATE`.
///
/// A DFA with `state_count == 0` is returned unchanged.
///
/// # Panics
///
/// Panics if `dfa.transitions` holds fewer than `state_count * 256`
/// entries or `dfa.accept` holds fewer than `state_count` entries.
///
/// # Examples
///
/// ```
/// use vyre_std::pattern::{regex_to_nfa::regex_to_nfa, nfa_to_dfa::nfa_to_dfa, dfa_minimize::dfa_minimize};
///
/// let nfa = regex_to_nfa("a|aa").unwrap();
/// let dfa = nfa_to_dfa(&nfa).unwrap();
/// let minimized = dfa_minimize(&dfa);
/// assert!(minimized.state_count <= dfa.state_count);
/// ```
#[must_use]
#[inline]
pub fn dfa_minimize(dfa: &Dfa) -> Dfa {
    let state_count = dfa.state_count as usize;
    if state_count == 0 {
        return dfa.clone();
    }

    // Implicit dead state at index `state_count` so every byte from
    // every state has a defined target. It is pruned again on output.
    let dead = state_count as u32;
    let total_states = state_count + 1;

    // Total transition table: every INVALID_STATE edge is redirected to
    // the dead state, and the appended dead row self-loops on all bytes
    // (the table is pre-filled with `dead`).
    let mut trans = vec![dead; total_states * 256];
    for (slot, &target) in trans
        .iter_mut()
        .zip(&dfa.transitions[..state_count * 256])
    {
        if target != INVALID_STATE {
            *slot = target;
        }
    }

    let (pred_head, pred_flat) = build_predecessor_table(&trans, total_states);
    let (classes, class_of) = refine_partition(&dfa.accept, dead, &pred_head, &pred_flat);
    emit_canonical_dfa(dfa, dead, &trans, &classes, &class_of)
}

/// Builds the inverse of the total transition table `trans`.
///
/// Returns `(pred_head, pred_flat)`: `pred_head[byte * total_states + s]`
/// is a `(start, len)` window into `pred_flat` listing every state `p`
/// with `trans[p * 256 + byte] == s`.
fn build_predecessor_table(trans: &[u32], total_states: usize) -> (Vec<(u32, u32)>, Vec<u32>) {
    let mut head = vec![(0u32, 0u32); 256 * total_states];

    // Pass 1: count the predecessors of every (byte, target) slot.
    for row in trans.chunks_exact(256) {
        for (byte, &target) in row.iter().enumerate() {
            head[byte * total_states + target as usize].1 += 1;
        }
    }

    // Pass 2: prefix sums turn the counts into window start offsets.
    let mut cursor = 0u32;
    for window in head.iter_mut() {
        window.0 = cursor;
        cursor += window.1;
    }

    // Pass 3: drop every source state into its window.
    let mut flat = vec![0u32; cursor as usize];
    let mut filled = vec![0u32; head.len()];
    for (source, row) in trans.chunks_exact(256).enumerate() {
        for (byte, &target) in row.iter().enumerate() {
            let slot = byte * total_states + target as usize;
            flat[(head[slot].0 + filled[slot]) as usize] = source as u32;
            filled[slot] += 1;
        }
    }

    (head, flat)
}

/// Runs the partition refinement and returns `(classes, class_of)`.
///
/// `classes[id]` lists the members of class `id` and `class_of[s]` is
/// the class id of state `s`; both include the implicit `dead` state.
fn refine_partition(
    accept: &[bool],
    dead: u32,
    pred_head: &[(u32, u32)],
    pred_flat: &[u32],
) -> (Vec<Vec<u32>>, Vec<u32>) {
    let state_count = dead as usize;
    let total_states = state_count + 1;

    // Initial partition: accepting vs. non-accepting. The dead state is
    // non-accepting, so the second class is never empty.
    let mut accepting: Vec<u32> = Vec::new();
    let mut rejecting: Vec<u32> = Vec::new();
    for (state, &is_accepting) in accept[..state_count].iter().enumerate() {
        if is_accepting {
            accepting.push(state as u32);
        } else {
            rejecting.push(state as u32);
        }
    }
    rejecting.push(dead);

    let mut classes: Vec<Vec<u32>> = Vec::new();
    // Accepting states, when present, form class 0, which is what the
    // zero-initialised `class_of` already says.
    let mut class_of = vec![0u32; total_states];
    if !accepting.is_empty() {
        classes.push(accepting);
    }
    let rejecting_id = classes.len() as u32;
    for &state in &rejecting {
        class_of[state as usize] = rejecting_id;
    }
    classes.push(rejecting);

    // Seed the worklist with the smaller initial class (class 0 on a
    // tie, or when it is the only class).
    let seed = if classes.len() == 2 && classes[1].len() < classes[0].len() {
        1usize
    } else {
        0usize
    };
    let mut worklist: VecDeque<u32> = VecDeque::new();
    worklist.push_back(seed as u32);
    let mut on_worklist = vec![false; classes.len()];
    on_worklist[seed] = true;

    // Scratch buffers reused across (splitter, byte) iterations.
    let mut splitter: Vec<u32> = Vec::new();
    let mut x_set = vec![false; total_states];
    let mut x_list: Vec<u32> = Vec::new();
    let mut hit = vec![false; classes.len()];
    let mut touched: Vec<u32> = Vec::new();

    while let Some(splitter_id) = worklist.pop_front() {
        on_worklist[splitter_id as usize] = false;

        // Freeze the splitter's members for the whole byte loop. The
        // class may split *itself* on an early byte; if the later bytes
        // then only saw the shrunken remainder, the half that moved out
        // could end up never being used as a splitter for those bytes,
        // leaving non-equivalent states merged.
        splitter.clear();
        splitter.extend_from_slice(&classes[splitter_id as usize]);

        for byte in 0..256usize {
            // x = { p | trans[p, byte] ∈ splitter }, gathered from the
            // inverse table; `x_list` records members as they are marked.
            x_list.clear();
            for &state in &splitter {
                let (start, len) = pred_head[byte * total_states + state as usize];
                for &p in &pred_flat[start as usize..(start + len) as usize] {
                    if !x_set[p as usize] {
                        x_set[p as usize] = true;
                        x_list.push(p);
                    }
                }
            }
            if x_list.is_empty() {
                continue;
            }

            // Only classes that intersect x can be split.
            touched.clear();
            for &p in &x_list {
                let class_id = class_of[p as usize];
                if !hit[class_id as usize] {
                    hit[class_id as usize] = true;
                    touched.push(class_id);
                }
            }

            for &class_id in &touched {
                let c = class_id as usize;
                hit[c] = false;

                // A touched class has at least one member in x; it only
                // splits if it also has a member outside x.
                let inside = classes[c].iter().filter(|&&s| x_set[s as usize]).count();
                if inside == classes[c].len() {
                    continue;
                }

                let mut moved: Vec<u32> = Vec::with_capacity(inside);
                classes[c].retain(|&s| {
                    let stays = !x_set[s as usize];
                    if !stays {
                        moved.push(s);
                    }
                    stays
                });

                let new_id = classes.len() as u32;
                for &s in &moved {
                    class_of[s as usize] = new_id;
                }
                let moved_len = moved.len();
                classes.push(moved);
                on_worklist.push(false);
                hit.push(false);

                // If the original class is still queued, both halves
                // must act as splitters, so queue the new one as well.
                // Otherwise queue only the smaller half (the original
                // id on a tie) — the classic Hopcroft optimization.
                let enqueue = if on_worklist[c] || moved_len < classes[c].len() {
                    new_id
                } else {
                    class_id
                };
                worklist.push_back(enqueue);
                on_worklist[enqueue as usize] = true;
            }

            // Unmark only what was marked for this byte.
            for &p in &x_list {
                x_set[p as usize] = false;
            }
        }
    }

    (classes, class_of)
}

/// Emits the minimized DFA with canonical state ids.
///
/// Classes are numbered in breadth-first order from the start class,
/// visiting bytes in ascending order. The class that contains the dead
/// state is never numbered from a transition, so dead-equivalent states
/// are pruned and edges into them become `INVALID_STATE`.
fn emit_canonical_dfa(
    dfa: &Dfa,
    dead: u32,
    trans: &[u32],
    classes: &[Vec<u32>],
    class_of: &[u32],
) -> Dfa {
    let dead_class = class_of[dead as usize];
    let start_class = class_of[dfa.start as usize];

    let mut canonical: Vec<Option<u32>> = vec![None; classes.len()];
    canonical[start_class as usize] = Some(0);
    // `order[i]` is the class that received canonical id `i`; it doubles
    // as the BFS queue.
    let mut order: Vec<u32> = vec![start_class];

    let mut transitions: Vec<u32> = Vec::new();
    let mut accept: Vec<bool> = Vec::new();

    let mut next = 0;
    while next < order.len() {
        let class_id = order[next] as usize;
        next += 1;

        // All members of a class behave identically, so any real
        // (non-dead) member can stand in for the class.
        let representative = classes[class_id].iter().copied().find(|&s| s != dead);
        accept.push(representative.map_or(false, |s| dfa.accept[s as usize]));

        let row_start = transitions.len();
        transitions.resize(row_start + 256, INVALID_STATE);

        let source = match representative {
            Some(s) => s as usize,
            None => continue,
        };
        for byte in 0..256usize {
            let target_class = class_of[trans[source * 256 + byte] as usize];
            if target_class == dead_class {
                continue;
            }
            let id = match canonical[target_class as usize] {
                Some(id) => id,
                None => {
                    let id = order.len() as u32;
                    canonical[target_class as usize] = Some(id);
                    order.push(target_class);
                    id
                }
            };
            transitions[row_start + byte] = id;
        }
    }

    Dfa {
        state_count: order.len() as u32,
        transitions,
        start: 0,
        accept,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::pattern::{nfa_to_dfa::nfa_to_dfa, regex_to_nfa::regex_to_nfa};

    /// Walks `dfa` over `input` from its start state and reports whether
    /// the walk ends in an accepting state. Hitting `INVALID_STATE`
    /// anywhere along the way rejects the input.
    fn run(dfa: &Dfa, input: &[u8]) -> bool {
        let landed = input.iter().try_fold(dfa.start, |state, &byte| {
            let next = dfa.go(state, byte);
            if next == INVALID_STATE {
                None
            } else {
                Some(next)
            }
        });
        match landed {
            Some(state) => dfa.accept[state as usize],
            None => false,
        }
    }

    /// Compiles `pattern` down to a DFA and returns it together with its
    /// minimized form, as `(original, minimized)`.
    fn compile(pattern: &str) -> (Dfa, Dfa) {
        let nfa = regex_to_nfa(pattern).unwrap();
        let original = nfa_to_dfa(&nfa).unwrap();
        let minimized = dfa_minimize(&original);
        (original, minimized)
    }

    #[test]
    fn minimize_preserves_language_literal() {
        let (original, minimized) = compile("hello");
        assert!(run(&minimized, b"hello"));
        assert!(!run(&minimized, b"world"));
        assert!(minimized.state_count <= original.state_count);
    }

    #[test]
    fn minimize_is_idempotent() {
        // Minimizing an already-minimized DFA must give the same bytes.
        let (_, once) = compile("a|aa|aaa");
        let twice = dfa_minimize(&once);
        assert_eq!(once, twice, "Hopcroft output must be canonical");
    }

    #[test]
    fn minimize_alternation() {
        let (_, minimized) = compile("foo|bar|baz");
        assert!(run(&minimized, b"foo"));
        assert!(run(&minimized, b"bar"));
        assert!(run(&minimized, b"baz"));
        assert!(!run(&minimized, b"qux"));
    }

    #[test]
    fn minimize_collapses_equivalent_states() {
        let (original, minimized) = compile("(a|b)(c|d)");
        assert!(run(&minimized, b"ac"));
        assert!(run(&minimized, b"bd"));
        assert!(!run(&minimized, b"ab"));
        assert!(minimized.state_count <= original.state_count);
    }

    #[test]
    fn minimize_empty_language() {
        // The empty pattern: the empty input is accepted, anything
        // longer is rejected.
        let (_, minimized) = compile("");
        assert!(run(&minimized, b""));
        assert!(!run(&minimized, b"x"));
    }

    #[test]
    fn minimize_kleene_star() {
        let (_, minimized) = compile("a*");
        assert!(run(&minimized, b""));
        assert!(run(&minimized, b"aaaaa"));
        assert!(!run(&minimized, b"b"));
    }
}