mtc-inc-bpe 0.9.0

Incremental BPE tokenization for all prefixes
Documentation
use derive_more::{Debug, Deref};

use crate::{
    TokenId,
    aho_corasick::{
        AC_NODE_ROOT, ACNodeId, ACSuffixLinkTree, ACTransTable, ACTrie,
        heavy_light::heavy_light_decomposition,
    },
    typed_vec::TypedVec,
};

#[derive(Debug, Deref)]
pub(crate) struct ACAutomaton {
    #[deref]
    pub trans_table: ACTransTable,
    #[cfg(test)]
    pub trie: ACTrie,
    pub suffix: ACSuffixLinkTree,
    pub token_to_node: TypedVec<TokenId, ACNodeId>,
    pub depths: TypedVec<ACNodeId, u16>,
}

impl ACAutomaton {
    pub fn new<T: AsRef<[u8]>, V: IntoIterator<Item = T>>(vocab: V) -> Self {
        let mut trie = ACTrie::default();

        let vocab = vocab.into_iter();
        let mut token_to_node = TypedVec::with_capacity(TokenId::from(vocab.size_hint().0));

        for token in vocab {
            let mut node = AC_NODE_ROOT;
            for &byte in token.as_ref() {
                node = trie.get_or_add(node, byte);
            }
            token_to_node.push(node);
        }

        let mut suffix = TypedVec::new_with(AC_NODE_ROOT, trie.len());
        for node in trie.bfs() {
            if node == AC_NODE_ROOT {
                continue;
            }
            for (child, byte) in trie.children(node) {
                let mut cursor = suffix[node];
                while cursor != AC_NODE_ROOT && trie.get(cursor, byte).is_none() {
                    cursor = suffix[cursor];
                }
                suffix[child] = trie.get(cursor, byte).unwrap_or(AC_NODE_ROOT);
            }
        }

        let relabeling = heavy_light_decomposition(&trie);

        let trie = trie.apply_relabeling(&relabeling);
        relabeling.apply_to_iter_mut(&mut token_to_node);

        relabeling.apply_to_iter_mut(&mut suffix);
        let suffix = ACSuffixLinkTree::new(relabeling.apply_to_typed_vec(suffix));

        let trans_table = ACTransTable::new(&trie, &suffix);

        let mut depths = TypedVec::new_with(0u16, trie.len());
        trie.bfs().for_each(|node| {
            let val = depths[node] + 1;
            for (child, _) in trie.children(node) {
                depths[child] = val;
            }
        });

        Self {
            trans_table,
            #[cfg(test)]
            trie,
            suffix,
            token_to_node,
            depths,
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        Vocab,
        aho_corasick::{AC_NODE_ROOT, ACAutomaton},
    };

    #[test]
    fn test_ac_automaton() {
        let vocab = Vocab::new([
            b"a" as &[u8],
            b"abc",
            b"abcde",
            b"abcdef",
            b"b",
            b"ba",
            b"bc",
            b"bcdef",
            b"c",
            b"cd",
            b"cde",
            b"cdefg",
            b"d",
            b"de",
            b"def",
            b"e",
            b"ef",
            b"efg",
            b"f",
            b"g",
        ])
        .unwrap();

        let automaton = ACAutomaton::new(vocab.tokens());

        for node in automaton.trie.keys() {
            let suffix = automaton.suffix[node];
            let children: Vec<_> = automaton.trie.children(node).collect();
            println!("{node:2} {suffix:2}: {children:?}");
        }
        for (id, token) in vocab.tokens.enumerate() {
            let node = automaton.token_to_node[id];
            let suffix = automaton.suffix[node];
            println!("{node:2} {suffix:2}: {}", str::from_utf8(token).unwrap());
        }

        let search = |s: &str| {
            let mut node = AC_NODE_ROOT;
            for &b in s.as_bytes() {
                if let Some(next) = automaton.trie.get(node, b) {
                    node = next;
                } else {
                    return None;
                }
            }
            Some(node)
        };

        let id_b = search("b").unwrap();
        let id_ba = search("ba").unwrap();
        assert!(search("babcd").is_none());
        let id_abcd = search("abcd").unwrap();
        let id_abcdef = search("abcdef").unwrap();
        assert!(search("abcdefg").is_none());
        assert!(search("bcdefg").is_none());
        let id_cdefg = search("cdefg").unwrap();

        let feed = |sequences: &[&str]| -> Vec<_> {
            let mut node = AC_NODE_ROOT;
            sequences
                .iter()
                .map(|&s| {
                    node = automaton.feed(node, s);
                    node
                })
                .collect()
        };

        let output = feed(&["b", "a", "bcd", "ef", "g"]);
        assert_eq!(output, vec![id_b, id_ba, id_abcd, id_abcdef, id_cdefg]);
    }
}