mtc_token_healing/
token.rs

1use std::convert::Infallible;
2
3use general_sam::{
4    BTreeTransTable, BoxBisectTable, GeneralSam, SAM_ROOT_NODE_ID, TransitionTable, TravelEvent,
5    Trie, TrieNodeAlike,
6};
7use tinyvec::TinyVec;
8
9pub type TokenId = u32;
10pub type SortedTokenId = u32;
11
12pub type SmallToken = TinyVec<[u8; 28]>;
13
14const _: () = [(); 1][(core::mem::size_of::<SmallToken>() == 32) as usize ^ 1];
15
16#[derive(Clone, Debug, Default, PartialEq, Eq)]
17#[cfg_attr(feature = "pyo3", pyo3::pyclass(get_all, set_all))]
18pub struct SortedTokenRange {
19    pub lower: SortedTokenId,
20    pub upper: SortedTokenId,
21}
22
// Python-facing helpers for `SortedTokenRange`; compiled only with the `pyo3` feature.
#[cfg(feature = "pyo3")]
mod _pyo3 {
    use pyo3::pymethods;

    use super::{SortedTokenId, SortedTokenRange};

    impl SortedTokenRange {
        /// Renders the range the way it is shown to Python callers,
        /// e.g. `SortedTokenRange(lower=1, upper=4)`.
        pub(crate) fn repr_py(&self) -> String {
            let (lower, upper) = (self.lower, self.upper);
            format!("SortedTokenRange(lower={lower}, upper={upper})")
        }
    }

    #[pymethods]
    impl SortedTokenRange {
        // Python constructor; both bounds default to 0 when omitted.
        #[new]
        #[pyo3(signature=(lower=0, upper=0))]
        fn py_new(lower: SortedTokenId, upper: SortedTokenId) -> Self {
            SortedTokenRange { lower, upper }
        }

        // Python `repr()` simply forwards to the crate-internal formatter.
        fn __repr__(&self) -> String {
            self.repr_py()
        }
    }
}
49
50pub(crate) fn build_sam_of_reversed_tokens<
51    I: Ord + Clone,
52    T: AsRef<[I]>,
53    V: IntoIterator<Item = T>,
54>(
55    vocab: V,
56) -> GeneralSam<BoxBisectTable<I>> {
57    let trie_of_rev_tokens = {
58        let mut trie = Trie::<BTreeTransTable<_>>::default();
59        vocab.into_iter().for_each(|token| {
60            trie.insert(token.as_ref().iter().cloned().rev());
61        });
62        trie
63    };
64    GeneralSam::<BTreeTransTable<_>>::from_trie(trie_of_rev_tokens.get_root_state())
65        .alter_trans_table_into()
66}
67
68#[derive(Debug)]
69pub(crate) struct SortResult {
70    pub rank_ranges: Vec<SortedTokenRange>,
71    pub order: Vec<TokenId>,
72    pub rank: Vec<SortedTokenId>,
73}
74
75pub(crate) fn sort_vocab_with_trie<I: Ord + Clone, T: AsRef<[I]>, V: IntoIterator<Item = T>>(
76    vocab: V,
77) -> SortResult {
78    let (trie, trie_node_ids) = {
79        let mut trie = Trie::<BTreeTransTable<_>>::default();
80        let trie_node_ids: Vec<_> = vocab
81            .into_iter()
82            .map(|token| trie.insert(token.as_ref().iter().cloned()))
83            .collect();
84        (trie, trie_node_ids)
85    };
86
87    let vocab_size = trie_node_ids.len();
88
89    let mut rank_range_in_trie = vec![SortedTokenRange::default(); trie.num_of_nodes()];
90    let mut cnt_tokens_in_trie = vec![0 as SortedTokenId; trie.num_of_nodes()];
91    trie_node_ids
92        .iter()
93        .for_each(|&i| cnt_tokens_in_trie[i] += 1);
94
95    let mut tot_cnt: SortedTokenId = 0;
96
97    let res = trie.get_root_state().dfs_travel(|event| {
98        match event {
99            TravelEvent::PushRoot(state) | TravelEvent::Push(state, _, _) => {
100                let id = state.node_id;
101                let rank_range = &mut rank_range_in_trie[id];
102                rank_range.lower = tot_cnt;
103                tot_cnt += cnt_tokens_in_trie[id];
104            }
105            TravelEvent::Pop(state, _) => {
106                let id = state.node_id;
107                let rank_range = &mut rank_range_in_trie[id];
108                rank_range.upper = tot_cnt;
109            }
110        }
111        Ok::<_, Infallible>(())
112    });
113    match res {
114        Ok(()) => {}
115        Err(e) => match e {},
116    }
117
118    let rank_ranges: Vec<_> = (0..vocab_size)
119        .map(|i| rank_range_in_trie[trie_node_ids[i]].clone())
120        .collect();
121
122    let order = {
123        let mut order: Vec<_> = (0..vocab_size as TokenId).collect();
124        order.sort_by_key(|&i| rank_ranges[i as usize].lower);
125        order
126    };
127
128    let rank = {
129        let mut rank = vec![0; vocab_size];
130        order
131            .iter()
132            .enumerate()
133            .for_each(|(k, &i)| rank[i as usize] = k as SortedTokenId);
134        rank
135    };
136
137    debug_assert_eq!(order.len(), vocab_size);
138    debug_assert_eq!(rank.len(), vocab_size);
139    debug_assert_eq!(rank_ranges.len(), vocab_size);
140
141    SortResult {
142        rank_ranges,
143        order,
144        rank,
145    }
146}
147
148pub(crate) fn label_rank_range_on_sam_of_rev_tokens<
149    K: Ord + Clone,
150    T: AsRef<[K]>,
151    V: IntoIterator<Item = (T, SortedTokenRange)>,
152    TransTable: TransitionTable<KeyType = K>,
153>(
154    sam_of_rev_tokens: &GeneralSam<TransTable>,
155    vocab_and_rank_ranges: V,
156) -> Vec<Option<SortedTokenRange>> {
157    let mut rank_ranges = vec![None; sam_of_rev_tokens.num_of_nodes()];
158
159    for (token, rank_range) in vocab_and_rank_ranges {
160        let mut state = sam_of_rev_tokens.get_root_state();
161        state.feed_ref(token.as_ref().iter().rev());
162        rank_ranges[state.node_id] = Some(rank_range);
163    }
164
165    for &id in sam_of_rev_tokens
166        .get_topo_and_suf_len_sorted_node_ids()
167        .iter()
168        .rev()
169    {
170        if id == SAM_ROOT_NODE_ID {
171            continue;
172        }
173        let Some(node) = sam_of_rev_tokens.get_node(id) else {
174            continue;
175        };
176        let Some(rank_range) = rank_ranges[id].clone() else {
177            continue;
178        };
179        let link_rank_range =
180            rank_ranges[node.get_suffix_parent_id()].get_or_insert_with(|| rank_range.clone());
181        link_rank_range.lower = link_rank_range.lower.min(rank_range.lower);
182        link_rank_range.upper = link_rank_range.upper.max(rank_range.upper);
183    }
184
185    #[cfg(debug_assertions)]
186    for (id, rank_range) in rank_ranges.iter().enumerate() {
187        if id == SAM_ROOT_NODE_ID {
188            continue;
189        }
190        let Some(rank_range) = rank_range else {
191            continue;
192        };
193        let Some(node) = sam_of_rev_tokens.get_node(id) else {
194            continue;
195        };
196
197        let link_rank_range = rank_ranges[node.get_suffix_parent_id()].as_ref();
198
199        debug_assert!(link_rank_range.is_some_and(|link_rank_range| {
200            link_rank_range.lower <= rank_range.lower && link_rank_range.upper >= rank_range.upper
201        }));
202    }
203
204    rank_ranges
205}