mtc-inc-bpe 0.9.0

Incremental BPE tokenization for all prefixes
Documentation
use crate::{
    aho_corasick::{AC_NODE_ROOT, ACNodeId, ACTrie, relabeling::Relabeling},
    typed_vec::{TypedVec, vec_with_head},
};

pub(super) fn heavy_light_decomposition(trie: &ACTrie) -> Relabeling<ACNodeId> {
    let len = trie.len();

    let mut size = TypedVec::new_with(1u32, len);
    let mut largest = TypedVec::new_with(AC_NODE_ROOT, len);

    let mut stack = vec_with_head(
        (AC_NODE_ROOT, None::<ACNodeId>, trie.children(AC_NODE_ROOT)),
        len,
    );
    while let Some((node_id, parent_id, child_iter)) = stack.last_mut() {
        let node_id = *node_id;

        if let Some((child_id, _)) = child_iter.next() {
            stack.push((child_id, Some(node_id), trie.children(child_id)));
            continue;
        }

        if let Some(parent_id) = *parent_id {
            let cur_size = size[node_id];
            size[parent_id] += cur_size;
            if largest[parent_id] == AC_NODE_ROOT || size[largest[parent_id]] < cur_size {
                largest[parent_id] = node_id;
            }
        }

        stack.pop();
    }

    let mut stack = vec_with_head(AC_NODE_ROOT, len);
    let mut order = TypedVec::with_capacity(len);
    while let Some(mut head) = stack.pop() {
        loop {
            order.push(head);
            let next_head = largest[head];
            if next_head == AC_NODE_ROOT {
                break;
            }
            for (child_id, _) in trie.children(head).rev() {
                if child_id != next_head {
                    stack.push(child_id);
                }
            }
            head = next_head;
        }
    }
    debug_assert_eq!(order[AC_NODE_ROOT], AC_NODE_ROOT);
    debug_assert_eq!(order.len(), len);
    Relabeling::new(order)
}

#[cfg(test)]
mod tests {
    use crate::aho_corasick::{
        AC_NODE_ROOT, ACNodeId, ACTrie, heavy_light::heavy_light_decomposition,
    };

    #[test]
    fn test_ac_trie_heavy_light_decomposition() {
        let mut trie = ACTrie::default();
        for token in ["abcd", "aaa", "acd", "rose", "rad", "rand", "rb"] {
            let mut node = AC_NODE_ROOT;
            for &byte in token.as_bytes() {
                node = trie.get_or_add(node, byte);
            }
        }
        let relabeling = heavy_light_decomposition(&trie);
        let order = relabeling.apply_to_typed_vec(trie.keys().collect());
        let order: Vec<_> = order.into_iter().map(|i| i.inner()).collect();
        let expected = [0, 9, 13, 15, 16, 14, 17, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8];
        assert_eq!(order, expected);

        let trie = trie.apply_relabeling(&relabeling);

        for (node_id, expected) in [
            (0, &[(1u32, b'r'), (10, b'a')] as &[_]),
            (1, &[(2, b'a'), (6, b'b'), (7, b'o')]),
            (2, &[(3, b'n'), (5, b'd')]),
            (10, &[(11, b'b'), (14, b'a'), (16, b'c')]),
        ] {
            let mut children: Vec<_> = trie
                .children(ACNodeId::new(node_id))
                .map(|(u, v)| (u.inner(), v))
                .collect();
            children.sort();
            let expected = expected.to_vec();
            assert_eq!(children, expected);
        }
    }
}