vyre-std 0.1.0

Vyre standard library: GPU DFA assembly pipeline, Aho-Corasick construction, and compositional arithmetic helpers
Documentation
//! Transition-table compression for GPU dispatch.
//!
//! The pack formats trade memory footprint for scan speed. Dense is the
//! fastest but largest; EquivClass collapses redundant byte columns to
//! shrink the table when the effective alphabet is small.

use super::types::{Dfa, DfaPackFormat, PackedDfa, INVALID_STATE};

/// Pack a DFA into GPU-uploadable bytes using the selected format.
///
/// # Examples
///
/// ```
/// use vyre_std::pattern::{regex_to_nfa::regex_to_nfa, nfa_to_dfa::nfa_to_dfa, dfa_minimize::dfa_minimize, dfa_pack::dfa_pack, types::DfaPackFormat};
///
/// let nfa = regex_to_nfa("foo|bar").unwrap();
/// let dfa = dfa_minimize(&nfa_to_dfa(&nfa).unwrap());
/// let packed = dfa_pack(&dfa, DfaPackFormat::Dense);
/// assert!(!packed.bytes.is_empty());
/// ```
#[must_use]
#[inline]
pub fn dfa_pack(dfa: &Dfa, format: DfaPackFormat) -> PackedDfa {
    match format {
        DfaPackFormat::Dense => pack_dense(dfa),
        DfaPackFormat::EquivClass => pack_equiv_class(dfa),
    }
}

/// Serialize `dfa` in the Dense layout.
///
/// All integers are little-endian `u32`:
/// `[tag = 0][state_count][start][accept bitmap][transitions: state_count × 256]`.
/// The accept bitmap is a word count followed by that many bitmap words.
fn pack_dense(dfa: &Dfa) -> PackedDfa {
    const DENSE_TAG: u32 = 0;
    let mut bytes: Vec<u8> = [DENSE_TAG, dfa.state_count, dfa.start]
        .iter()
        .flat_map(|field| field.to_le_bytes())
        .collect();
    write_accept(&mut bytes, &dfa.accept);
    bytes.extend(dfa.transitions.iter().flat_map(|next| next.to_le_bytes()));
    PackedDfa {
        format: DfaPackFormat::Dense,
        state_count: dfa.state_count,
        start: dfa.start,
        bytes,
    }
}

/// Serialize `dfa` in the EquivClass layout.
///
/// Two bytes share a class when every state transitions identically on them,
/// meaning their transition columns are equal. Classes are numbered in order
/// of the first byte that exhibits them, so byte 0 is always class 0. There
/// are at most 256 classes.
///
/// Layout (little-endian `u32` unless noted):
/// `[tag = 1][state_count][start][num_classes]`
/// `[class table: 256 × u8, zero-padded to a 4-byte boundary]`
/// `[accept bitmap]`
/// `[transitions: state_count × num_classes, row-major by state]`
fn pack_equiv_class(dfa: &Dfa) -> PackedDfa {
    use std::collections::HashMap;

    let state_count = dfa.state_count as usize;

    // Hash each column once. The previous approach compared every column
    // against every existing representative and cost
    // O(256 × classes × state_count) element comparisons.
    let mut class_of_column: HashMap<Vec<u32>, u8> = HashMap::with_capacity(256);
    // For each class, the first byte that produced it. That byte's column is
    // the class's column, so no column has to be cloned or stored twice.
    let mut representative_bytes: Vec<usize> = Vec::with_capacity(256);
    let mut classes = [0u8; 256];
    for (byte, class) in classes.iter_mut().enumerate() {
        let column: Vec<u32> = (0..state_count)
            .map(|s| dfa.transitions[s * 256 + byte])
            .collect();
        let next_class = representative_bytes.len();
        *class = *class_of_column.entry(column).or_insert_with(|| {
            representative_bytes.push(byte);
            // At most 256 distinct columns exist (one per byte), so ids fit in u8.
            next_class as u8
        });
    }
    let num_classes = representative_bytes.len() as u32;

    let mut bytes = Vec::new();
    bytes.extend_from_slice(&1u32.to_le_bytes()); // format tag 1 = EquivClass
    bytes.extend_from_slice(&dfa.state_count.to_le_bytes());
    bytes.extend_from_slice(&dfa.start.to_le_bytes());
    bytes.extend_from_slice(&num_classes.to_le_bytes());
    bytes.extend_from_slice(&classes);
    // Pad the class table to 4-byte alignment so the following u32s are aligned.
    while bytes.len() % 4 != 0 {
        bytes.push(0);
    }
    write_accept(&mut bytes, &dfa.accept);
    // Emit [state][class]: for each state row, the target on each class's
    // representative byte.
    for row in dfa.transitions.chunks_exact(256).take(state_count) {
        for &byte in &representative_bytes {
            bytes.extend_from_slice(&row[byte].to_le_bytes());
        }
    }
    PackedDfa {
        format: DfaPackFormat::EquivClass,
        state_count: dfa.state_count,
        start: dfa.start,
        bytes,
    }
}

/// Append `accept` to `bytes` as a little-endian bitmap.
///
/// The output is a `u32` word count followed by that many `u32` words. Bit
/// `i % 32` of word `i / 32` is set iff `accept[i]` is true. Unused high bits
/// of the final word are zero.
fn write_accept(bytes: &mut Vec<u8>, accept: &[bool]) {
    let word_count = accept.len().div_ceil(32) as u32;
    bytes.extend_from_slice(&word_count.to_le_bytes());
    for group in accept.chunks(32) {
        let packed = group
            .iter()
            .enumerate()
            .filter(|&(_, &is_accepting)| is_accepting)
            .fold(0u32, |acc, (bit, _)| acc | (1u32 << bit));
        bytes.extend_from_slice(&packed.to_le_bytes());
    }
}

/// Decode a [`PackedDfa`] back into a [`Dfa`].
///
/// Intended for tests and for consumers that want to verify a packed buffer
/// round-trips.
///
/// # Errors
///
/// Returns `None` in three cases: the buffer is too short to hold a format
/// tag, the tag does not agree with `packed.format`, or the format-specific
/// decoder rejects the payload.
#[must_use]
#[inline]
pub fn dfa_unpack(packed: &PackedDfa) -> Option<Dfa> {
    let bytes = packed.bytes.as_slice();
    let tag_bytes: [u8; 4] = bytes.get(..4)?.try_into().ok()?;
    let tag = u32::from_le_bytes(tag_bytes);
    match (tag, packed.format) {
        (0, DfaPackFormat::Dense) => unpack_dense(bytes),
        (1, DfaPackFormat::EquivClass) => unpack_equiv_class(bytes),
        _ => None,
    }
}

/// Decode a Dense-format buffer as written by `pack_dense`.
///
/// Layout (little-endian `u32`s): `[tag][state_count][start][accept_words]`,
/// then `[accept bitmap: accept_words words]`, then
/// `[transitions: state_count × 256]`. The caller checks the tag. Bytes
/// after the transition table are ignored.
///
/// Returns `None` when the buffer is truncated or its declared sizes
/// overflow, instead of panicking on an out-of-range slice. The transition
/// region is bounds-checked before `read_accept` allocates `state_count`
/// flags, so a corrupt `state_count` cannot trigger a huge allocation.
fn unpack_dense(bytes: &[u8]) -> Option<Dfa> {
    let read_u32 = |offset: usize| -> Option<u32> {
        let word = bytes.get(offset..offset.checked_add(4)?)?;
        Some(u32::from_le_bytes(word.try_into().ok()?))
    };
    let state_count = read_u32(4)?;
    let start = read_u32(8)?;
    let states = usize::try_from(state_count).ok()?;
    let accept_words = usize::try_from(read_u32(12)?).ok()?;

    let accept_start: usize = 16;
    let accept_end = accept_start.checked_add(accept_words.checked_mul(4)?)?;
    let accept_bytes = bytes.get(accept_start..accept_end)?;

    let trans_end = accept_end.checked_add(states.checked_mul(256 * 4)?)?;
    let trans_bytes = bytes.get(accept_end..trans_end)?;

    let accept = read_accept(accept_bytes, states);
    let transitions: Vec<u32> = trans_bytes
        .chunks_exact(4)
        .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect();
    Some(Dfa {
        state_count,
        transitions,
        start,
        accept,
    })
}

/// Decode an EquivClass-format buffer as written by `pack_equiv_class`.
///
/// Layout (little-endian `u32` unless noted):
/// `[tag][state_count][start][num_classes]`
/// `[class table: 256 × u8, padded to 4 bytes][accept_words][accept bitmap]`
/// `[transitions: state_count × num_classes]`.
/// The caller checks the tag. The result is expanded back to a dense
/// `state_count × 256` table.
///
/// Returns `None` in three cases:
/// - Any region is truncated.
/// - A declared size overflows.
/// - A class id in the table is `>= num_classes`. Such an id would otherwise
///   read from a neighbouring state's row or past the end of the table.
///
/// Every region is validated before anything sized by the untrusted header
/// fields is allocated.
fn unpack_equiv_class(bytes: &[u8]) -> Option<Dfa> {
    let read_u32 = |offset: usize| -> Option<u32> {
        let word = bytes.get(offset..offset.checked_add(4)?)?;
        Some(u32::from_le_bytes(word.try_into().ok()?))
    };
    let state_count = read_u32(4)?;
    let start = read_u32(8)?;
    let states = usize::try_from(state_count).ok()?;
    let num_classes = usize::try_from(read_u32(12)?).ok()?;

    let class_start: usize = 16;
    let class_end = class_start + 256;
    let classes = bytes.get(class_start..class_end)?;
    // With no states the class table is never consulted, so only validate it
    // when there are rows to index.
    if states > 0 && classes.iter().any(|&c| usize::from(c) >= num_classes) {
        return None;
    }

    let aligned_end = (class_end + 3) & !3;
    let accept_words = usize::try_from(read_u32(aligned_end)?).ok()?;
    let accept_start = aligned_end + 4;
    let accept_end = accept_start.checked_add(accept_words.checked_mul(4)?)?;
    let accept_bytes = bytes.get(accept_start..accept_end)?;

    let trans_len = states.checked_mul(num_classes)?.checked_mul(4)?;
    let trans_end = accept_end.checked_add(trans_len)?;
    let class_trans: Vec<u32> = bytes
        .get(accept_end..trans_end)?
        .chunks_exact(4)
        .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect();

    let accept = read_accept(accept_bytes, states);

    // `states > 0` implies `num_classes >= 1` (validated above), and the
    // transition region holds exactly `states × num_classes` entries, so the
    // row slices and `row[class]` below are in bounds.
    let mut transitions = vec![INVALID_STATE; states.checked_mul(256)?];
    for state in 0..states {
        let row = &class_trans[state * num_classes..(state + 1) * num_classes];
        let out = &mut transitions[state * 256..(state + 1) * 256];
        for (slot, &class) in out.iter_mut().zip(classes) {
            *slot = row[usize::from(class)];
        }
    }
    Some(Dfa {
        state_count,
        transitions,
        start,
        accept,
    })
}

/// Expand a little-endian accept bitmap into exactly `state_count` flags.
///
/// Bit `i % 32` of the `i / 32`-th complete 4-byte word gives state `i`'s
/// flag. The function handles mismatched sizes as follows:
/// - A trailing partial word is ignored.
/// - States beyond the end of the bitmap are reported as non-accepting.
/// - Bits beyond `state_count` are dropped.
fn read_accept(bytes: &[u8], state_count: usize) -> Vec<bool> {
    let bitmap_flags = bytes
        .chunks_exact(4)
        .map(|w| u32::from_le_bytes([w[0], w[1], w[2], w[3]]))
        .flat_map(|word| (0..32).map(move |bit| word & (1 << bit) != 0));
    bitmap_flags
        .chain(std::iter::repeat(false))
        .take(state_count)
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::pattern::{
        dfa_minimize::dfa_minimize, nfa_to_dfa::nfa_to_dfa, regex_to_nfa::regex_to_nfa,
    };

    /// Compile `regex` all the way down to a minimized DFA.
    fn minimized(regex: &str) -> Dfa {
        let nfa = regex_to_nfa(regex).unwrap();
        dfa_minimize(&nfa_to_dfa(&nfa).unwrap())
    }

    /// Pack the DFA for `regex` with `format`, unpack it, and require every
    /// field to come back unchanged.
    fn roundtrip(regex: &str, format: DfaPackFormat) {
        let original = minimized(regex);
        let restored = dfa_unpack(&dfa_pack(&original, format)).expect("unpack");
        assert_eq!(restored.state_count, original.state_count, "regex `{regex}`");
        assert_eq!(restored.start, original.start);
        assert_eq!(restored.accept, original.accept);
        assert_eq!(restored.transitions, original.transitions);
    }

    #[test]
    fn dense_roundtrip_literal() {
        roundtrip("hello", DfaPackFormat::Dense);
    }

    #[test]
    fn dense_roundtrip_alternation() {
        roundtrip("foo|bar|baz", DfaPackFormat::Dense);
    }

    #[test]
    fn dense_roundtrip_kleene() {
        roundtrip("a*b+c?", DfaPackFormat::Dense);
    }

    #[test]
    fn equiv_class_roundtrip_literal() {
        roundtrip("hello", DfaPackFormat::EquivClass);
    }

    #[test]
    fn equiv_class_roundtrip_char_class() {
        roundtrip("[a-z]+", DfaPackFormat::EquivClass);
    }

    #[test]
    fn equiv_class_fewer_bytes_than_dense_for_small_alphabet() {
        let dfa = minimized("abc");
        let dense_len = dfa_pack(&dfa, DfaPackFormat::Dense).bytes.len();
        let equiv_len = dfa_pack(&dfa, DfaPackFormat::EquivClass).bytes.len();
        assert!(
            equiv_len < dense_len,
            "equiv-class must be smaller for narrow alphabets: dense={} equiv={}",
            dense_len,
            equiv_len
        );
    }
}