mtc_token_healing/
token.rs1use std::convert::Infallible;
2
3use general_sam::{
4 BTreeTransTable, BoxBisectTable, GeneralSam, SAM_ROOT_NODE_ID, TransitionTable, TravelEvent,
5 Trie, TrieNodeAlike,
6};
7use tinyvec::TinyVec;
8
9pub type TokenId = u32;
10pub type SortedTokenId = u32;
11
12pub type SmallToken = TinyVec<[u8; 28]>;
13
14const _: () = [(); 1][(core::mem::size_of::<SmallToken>() == 32) as usize ^ 1];
15
16#[derive(Clone, Debug, Default, PartialEq, Eq)]
17#[cfg_attr(feature = "pyo3", pyo3::pyclass(get_all, set_all))]
18pub struct SortedTokenRange {
19 pub lower: SortedTokenId,
20 pub upper: SortedTokenId,
21}
22
// Python bindings for `SortedTokenRange`; compiled only with the `pyo3` feature.
#[cfg(feature = "pyo3")]
mod _pyo3 {
    use pyo3::pymethods;

    use super::{SortedTokenId, SortedTokenRange};

    impl SortedTokenRange {
        /// Formats the range as `SortedTokenRange(lower=<lower>, upper=<upper>)`.
        pub(crate) fn repr_py(&self) -> String {
            let (lower, upper) = (self.lower, self.upper);
            format!("SortedTokenRange(lower={lower}, upper={upper})")
        }
    }

    #[pymethods]
    impl SortedTokenRange {
        // Python constructor; both bounds default to 0 when omitted.
        #[new]
        #[pyo3(signature=(lower=0, upper=0))]
        fn py_new(lower: SortedTokenId, upper: SortedTokenId) -> Self {
            SortedTokenRange { lower, upper }
        }

        // Python `repr()` hook; delegates to the crate-visible formatter.
        fn __repr__(&self) -> String {
            self.repr_py()
        }
    }
}
49
50pub(crate) fn build_sam_of_reversed_tokens<
51 I: Ord + Clone,
52 T: AsRef<[I]>,
53 V: IntoIterator<Item = T>,
54>(
55 vocab: V,
56) -> GeneralSam<BoxBisectTable<I>> {
57 let trie_of_rev_tokens = {
58 let mut trie = Trie::<BTreeTransTable<_>>::default();
59 vocab.into_iter().for_each(|token| {
60 trie.insert(token.as_ref().iter().cloned().rev());
61 });
62 trie
63 };
64 GeneralSam::<BTreeTransTable<_>>::from_trie(trie_of_rev_tokens.get_root_state())
65 .alter_trans_table_into()
66}
67
68#[derive(Debug)]
69pub(crate) struct SortResult {
70 pub rank_ranges: Vec<SortedTokenRange>,
71 pub order: Vec<TokenId>,
72 pub rank: Vec<SortedTokenId>,
73}
74
75pub(crate) fn sort_vocab_with_trie<I: Ord + Clone, T: AsRef<[I]>, V: IntoIterator<Item = T>>(
76 vocab: V,
77) -> SortResult {
78 let (trie, trie_node_ids) = {
79 let mut trie = Trie::<BTreeTransTable<_>>::default();
80 let trie_node_ids: Vec<_> = vocab
81 .into_iter()
82 .map(|token| trie.insert(token.as_ref().iter().cloned()))
83 .collect();
84 (trie, trie_node_ids)
85 };
86
87 let vocab_size = trie_node_ids.len();
88
89 let mut rank_range_in_trie = vec![SortedTokenRange::default(); trie.num_of_nodes()];
90 let mut cnt_tokens_in_trie = vec![0 as SortedTokenId; trie.num_of_nodes()];
91 trie_node_ids
92 .iter()
93 .for_each(|&i| cnt_tokens_in_trie[i] += 1);
94
95 let mut tot_cnt: SortedTokenId = 0;
96
97 let res = trie.get_root_state().dfs_travel(|event| {
98 match event {
99 TravelEvent::PushRoot(state) | TravelEvent::Push(state, _, _) => {
100 let id = state.node_id;
101 let rank_range = &mut rank_range_in_trie[id];
102 rank_range.lower = tot_cnt;
103 tot_cnt += cnt_tokens_in_trie[id];
104 }
105 TravelEvent::Pop(state, _) => {
106 let id = state.node_id;
107 let rank_range = &mut rank_range_in_trie[id];
108 rank_range.upper = tot_cnt;
109 }
110 }
111 Ok::<_, Infallible>(())
112 });
113 match res {
114 Ok(()) => {}
115 Err(e) => match e {},
116 }
117
118 let rank_ranges: Vec<_> = (0..vocab_size)
119 .map(|i| rank_range_in_trie[trie_node_ids[i]].clone())
120 .collect();
121
122 let order = {
123 let mut order: Vec<_> = (0..vocab_size as TokenId).collect();
124 order.sort_by_key(|&i| rank_ranges[i as usize].lower);
125 order
126 };
127
128 let rank = {
129 let mut rank = vec![0; vocab_size];
130 order
131 .iter()
132 .enumerate()
133 .for_each(|(k, &i)| rank[i as usize] = k as SortedTokenId);
134 rank
135 };
136
137 debug_assert_eq!(order.len(), vocab_size);
138 debug_assert_eq!(rank.len(), vocab_size);
139 debug_assert_eq!(rank_ranges.len(), vocab_size);
140
141 SortResult {
142 rank_ranges,
143 order,
144 rank,
145 }
146}
147
148pub(crate) fn label_rank_range_on_sam_of_rev_tokens<
149 K: Ord + Clone,
150 T: AsRef<[K]>,
151 V: IntoIterator<Item = (T, SortedTokenRange)>,
152 TransTable: TransitionTable<KeyType = K>,
153>(
154 sam_of_rev_tokens: &GeneralSam<TransTable>,
155 vocab_and_rank_ranges: V,
156) -> Vec<Option<SortedTokenRange>> {
157 let mut rank_ranges = vec![None; sam_of_rev_tokens.num_of_nodes()];
158
159 for (token, rank_range) in vocab_and_rank_ranges {
160 let mut state = sam_of_rev_tokens.get_root_state();
161 state.feed_ref(token.as_ref().iter().rev());
162 rank_ranges[state.node_id] = Some(rank_range);
163 }
164
165 for &id in sam_of_rev_tokens
166 .get_topo_and_suf_len_sorted_node_ids()
167 .iter()
168 .rev()
169 {
170 if id == SAM_ROOT_NODE_ID {
171 continue;
172 }
173 let Some(node) = sam_of_rev_tokens.get_node(id) else {
174 continue;
175 };
176 let Some(rank_range) = rank_ranges[id].clone() else {
177 continue;
178 };
179 let link_rank_range =
180 rank_ranges[node.get_suffix_parent_id()].get_or_insert_with(|| rank_range.clone());
181 link_rank_range.lower = link_rank_range.lower.min(rank_range.lower);
182 link_rank_range.upper = link_rank_range.upper.max(rank_range.upper);
183 }
184
185 #[cfg(debug_assertions)]
186 for (id, rank_range) in rank_ranges.iter().enumerate() {
187 if id == SAM_ROOT_NODE_ID {
188 continue;
189 }
190 let Some(rank_range) = rank_range else {
191 continue;
192 };
193 let Some(node) = sam_of_rev_tokens.get_node(id) else {
194 continue;
195 };
196
197 let link_rank_range = rank_ranges[node.get_suffix_parent_id()].as_ref();
198
199 debug_assert!(link_rank_range.is_some_and(|link_rank_range| {
200 link_rank_range.lower <= rank_range.lower && link_rank_range.upper >= rank_range.upper
201 }));
202 }
203
204 rank_ranges
205}