mtc-inc-bpe 0.9.0

Incremental BPE tokenization for all prefixes
Documentation
use std::collections::VecDeque;

use derive_more::Deref;

use crate::{
    SkipLen,
    aho_corasick::{AC_NODE_ROOT, ACAutomaton, ACNodeId},
    normalize::ATOMIC_TOKEN_PRIORITY,
    successor::{FOREST_VIRTUAL_ROOT, ForestNodeId, SucForest},
    typed_vec::TypedVec,
};

/// Per-node record stored in `SufSucNodeSet`, one for every successor-forest node.
///
/// `I` is the type of the stored `repr_id`; it defaults to `ForestNodeId`.
#[derive(Clone, Debug)]
pub(crate) struct SufSucNode<I = ForestNodeId> {
    /// Identifier carried by this record. `SufSucNodeSet::new` sets it to the
    /// node's own forest id.
    pub repr_id: I,
    /// Skip length copied from the corresponding forest node.
    pub skip_len: SkipLen,
    /// Skip length of the forest node's parent. This is the value `verify`
    /// hands to its callback.
    pub suc_skip_len: SkipLen,
    /// Half-open id range `[start, end)` that `verify_skipped` tests a
    /// candidate forest node id against.
    pub valid_range: (ForestNodeId, ForestNodeId),
}

// Compile-time layout guard: the build fails if `SufSucNode` stops being exactly
// 16 bytes, both with the default id type and with a plain `u32` id.
const _: () = {
    use std::mem::size_of;

    const EXPECTED_BYTES: usize = 16;

    assert!(size_of::<SufSucNode>() == EXPECTED_BYTES);
    assert!(size_of::<SufSucNode<u32>>() == EXPECTED_BYTES);
};

impl<I> SufSucNode<I> {
    /// Resolves a candidate through `f` and checks it against `valid_range`.
    ///
    /// `f` is invoked exactly once with `suc_skip_len` cast to `usize`; the forest
    /// node id it returns is then tested with [`Self::verify_skipped`].
    #[inline(always)]
    pub fn verify<F: FnOnce(usize) -> ForestNodeId>(&self, f: F) -> bool {
        let candidate = f(self.suc_skip_len as usize);
        self.verify_skipped(candidate)
    }

    /// Returns `true` when `w` lies inside the half-open interval
    /// `valid_range.0 <= w < valid_range.1`.
    #[inline(always)]
    pub fn verify_skipped(&self, w: ForestNodeId) -> bool {
        let (lower, upper) = self.valid_range;
        (lower..upper).contains(&w)
    }
}

/// Lookup tables derived from a successor forest and an Aho-Corasick automaton.
///
/// The derived `Deref` targets `nodes`, so a set can be indexed by forest node
/// id directly (the unit test in this file does so).
#[derive(Debug, Deref)]
pub(crate) struct SufSucNodeSet {
    /// One `SufSucNode` per forest node, keyed by forest node id.
    #[deref]
    nodes: TypedVec<ForestNodeId, SufSucNode>,
    /// For a forest node whose token owns an automaton node: the
    /// `longest_token_node` value of that automaton node's parent in the
    /// automaton's suffix tree. Entries that are never written stay at
    /// `FOREST_VIRTUAL_ROOT`. The unit test in this file checks that this is the
    /// node of the longest canonical token that is a proper suffix of the token.
    pub suffix_parent: TypedVec<ForestNodeId, ForestNodeId>,
    /// For every automaton node: the forest node of the token mapped to it, or,
    /// when it has none, the value inherited from its suffix-tree parent
    /// (`FOREST_VIRTUAL_ROOT` if nothing is inherited). The unit test in this file
    /// checks that this is the node of the longest canonical token that is a
    /// suffix of the string spelled by the automaton node.
    pub longest_token_node: TypedVec<ACNodeId, ForestNodeId>,
}

impl SufSucNodeSet {
    /// Builds the node set for `forest`, using `automaton` to relate tokens by suffix.
    ///
    /// Three tables are produced:
    ///
    /// * `longest_token_node` is seeded from `automaton.token_to_node` and then
    ///   propagated breadth-first along `automaton.suffix`: an automaton node
    ///   without a token of its own inherits its parent's entry.
    /// * `suffix_parent` records, for every automaton node that does have its own
    ///   token, the parent's entry at the moment the node is visited.
    /// * `nodes` holds one `SufSucNode` per forest node with its skip lengths and
    ///   its valid range of forest node ids.
    ///
    /// In debug builds the result is additionally cross-checked against the forest,
    /// and the valid ranges of the children of every non-root node are asserted to
    /// be pairwise non-overlapping.
    pub fn new(forest: &SucForest, automaton: &ACAutomaton) -> Self {
        let ac_node_count = automaton.num_of_nodes();

        // Start every automaton node at the virtual root, then overwrite the ones
        // that are the end of a token with that token's forest node.
        let mut longest_token_node = TypedVec::new_with(FOREST_VIRTUAL_ROOT, ac_node_count);
        for (token, ac_node) in automaton.token_to_node.enumerate_copied() {
            longest_token_node[ac_node] = forest.token_to_node_id[token];
        }

        let mut suffix_parent = TypedVec::new_with(FOREST_VIRTUAL_ROOT, forest.len());

        // Breadth-first walk over the suffix tree, so a parent's entry is final
        // before any of its children reads it.
        let mut pending = VecDeque::with_capacity(ac_node_count.as_usize());
        pending.push_back(AC_NODE_ROOT);
        while let Some(parent) = pending.pop_front() {
            let inherited = longest_token_node[parent];
            for child in automaton.suffix.children(parent) {
                let own = longest_token_node[child];
                if own != FOREST_VIRTUAL_ROOT {
                    // The child ends a token itself: link that token to what the
                    // parent resolved to.
                    suffix_parent[own] = inherited;
                } else {
                    longest_token_node[child] = inherited;
                }
                pending.push_back(child);
            }
        }

        // Computes the half-open range of forest node ids stored as `valid_range`.
        let valid_range_of = |node_id: ForestNodeId| {
            let node = &forest[node_id];
            if node.skip_len <= 1 {
                debug_assert!(
                    node_id == FOREST_VIRTUAL_ROOT && node.skip_len == 0
                        || node.parent == FOREST_VIRTUAL_ROOT && node.skip_len == 1
                );
                // No restriction at all for the virtual root and its direct children.
                return (ForestNodeId::ZERO, ForestNodeId::MAX);
            }

            debug_assert!(
                node.pre_id != FOREST_VIRTUAL_ROOT && node.priority < ATOMIC_TOKEN_PRIORITY
            );
            let pre = &forest[node.pre_id];
            let siblings = &pre.children;

            // `upper` is the exclusive end of the range. The searches below rely on
            // `pre.children` being ordered by `!priority` (see the binary search key).
            let upper = match (siblings.first(), siblings.last()) {
                // `pre` has no children.
                (None, _) => node.pre_id.next(),
                // Even the first child does not have a priority above this node's.
                (Some(&head), _) if node.priority >= forest[head].priority => {
                    node.pre_id.next()
                }
                // Even the last child has a priority above this node's.
                (_, Some(&tail)) if node.priority < forest[tail].priority => {
                    pre.subtree_last_node.next()
                }
                // The boundary lies strictly inside the child list; an exact hit and
                // an insertion point both name the child that ends the range.
                _ => {
                    let (Ok(idx) | Err(idx)) = siblings
                        .binary_search_by_key(&!node.priority, |&c| !forest[c].priority);
                    siblings[idx]
                }
            };

            #[cfg(debug_assertions)]
            {
                for &sibling in siblings {
                    debug_assert!((node.priority < forest[sibling].priority) ^ (sibling >= upper));
                    debug_assert!(sibling > node.pre_id);
                }
            }

            (node.pre_id, upper)
        };

        let nodes: TypedVec<ForestNodeId, _> = forest
            .enumerate()
            .map(|(id, forest_node)| {
                let suc_skip_len = forest[forest_node.parent].skip_len;
                SufSucNode {
                    repr_id: id,
                    skip_len: forest_node.skip_len,
                    suc_skip_len,
                    valid_range: valid_range_of(id),
                }
            })
            .collect();

        #[cfg(debug_assertions)]
        {
            // Every record must mirror the forest node it was built from, and only
            // the virtual root may be its own parent.
            for (id, built) in nodes.enumerate() {
                let parent = forest[id].parent;
                if parent == id {
                    debug_assert_eq!(id, FOREST_VIRTUAL_ROOT);
                }
                debug_assert_eq!(built.repr_id, id);
                debug_assert_eq!(built.skip_len, forest[id].skip_len);
                debug_assert_eq!(built.suc_skip_len, forest[parent].skip_len);
            }

            // The children of any non-root node must have disjoint valid ranges.
            for id in nodes.keys() {
                if id == FOREST_VIRTUAL_ROOT {
                    continue;
                }
                let kids = &forest[id].children;
                let mut ranges = Vec::with_capacity(kids.len());
                for &kid in kids {
                    ranges.push(nodes[kid].valid_range);
                }
                ranges.sort();
                for pair in ranges.windows(2) {
                    let ((l, r), (u, v)) = (pair[0], pair[1]);
                    debug_assert!(l.max(u) >= r.min(v));
                }
            }
        }

        Self {
            nodes,
            suffix_parent,
            longest_token_node,
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        Dictionary, NormalizedDict, Vocab,
        aho_corasick::{AC_NODE_ROOT, ACAutomaton},
        successor::{FOREST_VIRTUAL_ROOT, SucForest},
        suf_suc::SufSucNodeSet,
    };

    /// Builds a `SufSucNodeSet` for the given vocabulary and merge rules and checks
    /// it against brute-force suffix searches over the dictionary.
    ///
    /// Two properties are asserted:
    ///
    /// * for every automaton node reached through the trie, `longest_token_node`
    ///   is the forest node of the longest canonical token that is a suffix of the
    ///   string spelled by that node (or the virtual root if there is none);
    /// * for every token with a non-root forest node, `suffix_parent` is the forest
    ///   node of the longest canonical token that is a proper suffix of it.
    ///
    /// Panics (failing the test) if dictionary construction fails or an assertion
    /// does not hold.
    fn node_set_case(vocab: &[&str], rules: &[(&str, &str)]) {
        let vocab = Vocab::new(vocab.iter().map(|s| s.as_bytes().to_vec())).unwrap();

        let dict = Dictionary::new_from_token_pair(vocab, rules.iter().copied()).unwrap();
        // The raw dictionary is shadowed right away, so move it instead of cloning.
        let dict = NormalizedDict::new_in_bytes(dict).unwrap();
        let automaton = ACAutomaton::new(dict.iter_canonical_or_empty_tokens());
        let forest = SucForest::new(&dict);

        let node_set = SufSucNodeSet::new(&forest, &automaton);

        // Dump every forest node next to its derived record to ease debugging.
        for (node_id, node) in forest.enumerate() {
            let s = if node_id.inner() == 0 {
                "(epsilon)"
            } else {
                std::str::from_utf8(&dict[node.token_id]).unwrap()
            };
            let suf_suc_node = &node_set[node_id];
            println!("{s:12} {node_id:2}: {node:?} {suf_suc_node:?}");
        }

        // Depth-first walk over the trie with an explicit stack of
        // (node, remaining children); `cur_string` always holds the bytes on the
        // path from the root to the node on top of the stack.
        let mut stack = vec![(AC_NODE_ROOT, automaton.trie.children(AC_NODE_ROOT))];
        let mut cur_string = Vec::with_capacity(dict.tokens.iter().map(|t| t.len()).max().unwrap());
        println!("{:?}", automaton.token_to_node);
        while let Some((ac_node_id, child_iter)) = stack.last_mut() {
            let ac_node_id = *ac_node_id;
            let Some((child, byte)) = child_iter.next() else {
                // All children visited: leave this node and drop its byte (the
                // root contributed none).
                stack.pop();
                if ac_node_id != AC_NODE_ROOT {
                    cur_string.pop();
                }
                continue;
            };
            stack.push((child, automaton.trie.children(child)));
            cur_string.push(byte);
            // Brute force: longest canonical token that is a suffix of the path.
            let longest = dict
                .tokens
                .keys()
                .filter(|&i| dict.is_canonical(i) && cur_string.ends_with(&dict[i]))
                .max_by_key(|&i| dict[i].len())
                .map(|i| forest.token_to_node_id[i])
                .unwrap_or(FOREST_VIRTUAL_ROOT);
            // Use the `std::str` free function, matching the call above.
            println!("{child}: {:?} {longest}", std::str::from_utf8(&cur_string));
            assert_eq!(node_set.longest_token_node[child], longest);
        }

        for (token_id, ac_node_id) in automaton.token_to_node.enumerate_copied() {
            let node_id = forest.token_to_node_id[token_id];
            // Tokens mapped to the virtual root have no forest node to check.
            if node_id == FOREST_VIRTUAL_ROOT {
                continue;
            }
            assert_eq!(node_set.longest_token_node[ac_node_id], node_id);
            let token = dict.get_token(token_id).unwrap();
            // Brute force: longest canonical token that is a proper suffix.
            let suf_parent_id = dict
                .tokens
                .keys()
                .filter(|&i| {
                    dict.is_canonical(i) && token.ends_with(&dict[i]) && dict[i].len() < token.len()
                })
                .max_by_key(|&i| dict[i].len())
                .map(|i| forest.token_to_node_id[i])
                .unwrap_or(FOREST_VIRTUAL_ROOT);
            println!("{node_id}: {}", node_set.suffix_parent[node_id]);
            assert_eq!(node_set.suffix_parent[node_id], suf_parent_id);
        }
    }

    /// Runs the node-set checks on a small mixed vocabulary under two rule
    /// orderings that differ only in the relative order of `("b", "a")` and
    /// `("a", "bc")`.
    #[test]
    fn test_node_set() {
        let vocab = [
            "", "a", "abc", "abcde", "abcdef", "b", "ba", "bc", "bcdef", "c", "cd", "cde", "cdefg",
            "d", "de", "def", "e", "ef", "efg", "f", "g",
        ];

        let orderings: [&[(&str, &str)]; 2] = [
            // `("b", "a")` listed before `("a", "bc")`.
            &[
                ("b", "c"),
                ("e", "f"),
                ("d", "e"),
                ("c", "d"),
                ("d", "ef"),
                ("b", "a"),
                ("a", "bc"),
                ("abc", "de"),
                ("abc", "def"),
                ("bc", "def"),
                ("c", "de"),
                ("ef", "g"),
                ("cd", "efg"),
            ],
            // `("a", "bc")` listed before `("b", "a")`.
            &[
                ("b", "c"),
                ("e", "f"),
                ("d", "e"),
                ("c", "d"),
                ("d", "ef"),
                ("a", "bc"),
                ("b", "a"),
                ("abc", "de"),
                ("abc", "def"),
                ("bc", "def"),
                ("c", "de"),
                ("ef", "g"),
                ("cd", "efg"),
            ],
        ];

        for rules in orderings {
            node_set_case(&vocab, rules);
        }
    }

    /// Runs the node-set checks on a single-letter vocabulary: the empty token
    /// followed by `"a"` repeated 1 through 16 times.
    #[test]
    fn test_repeated_suf_suc() {
        // Length 0 yields the empty token; lengths 1..=16 yield the runs of 'a'.
        let vocab: Vec<String> = (0..=16).map(|len| "a".repeat(len)).collect();
        let vocab_ref: Vec<&str> = vocab.iter().map(String::as_str).collect();

        let rules = [
            ("a", "a"),
            ("aa", "a"),
            ("aa", "aa"),
            ("aaaa", "aaaa"),
            ("aaaa", "aa"),
            ("aa", "aaa"),
            ("aaaa", "aaa"),
            ("aaaaaaaa", "aaaaaaaa"),
        ];
        node_set_case(&vocab_ref, &rules);
    }
}