1use std::ops::{AddAssign, Deref, SubAssign};
4
5use crate::{GeneralSam, GeneralSamState, TransitionTable, TrieNodeAlike};
6
7use super::suffixwise::SuffixInTrieData;
8
9#[derive(Clone, Debug)]
22pub struct GreedyTokenizer<
23 TransTable: TransitionTable,
24 TokenIDType: Clone + Default + PartialEq,
25 SamRef: Deref<Target = GeneralSam<TransTable>>,
26> {
27 sam: SamRef,
28 suffix_data: Vec<SuffixInTrieData<TokenIDType>>,
29}
30
31#[derive(Clone, Debug)]
32pub struct OwnedGeneralSam<TransTable: TransitionTable> {
33 pub sam: GeneralSam<TransTable>,
34}
35
36impl<TransTable: TransitionTable> Deref for OwnedGeneralSam<TransTable> {
37 type Target = GeneralSam<TransTable>;
38
39 fn deref(&self) -> &Self::Target {
40 &self.sam
41 }
42}
43
44impl<TransTable: TransitionTable, TokenIDType: Clone + Default + PartialEq>
45 GreedyTokenizer<TransTable, TokenIDType, OwnedGeneralSam<TransTable>>
46{
47 pub fn build_from_sam<
48 TN: TrieNodeAlike<InnerType = TransTable::KeyType>,
49 F: FnMut(&TN) -> TokenIDType,
50 >(
51 sam: GeneralSam<TransTable>,
52 trie_node: TN,
53 f: F,
54 ) -> Self {
55 Self {
56 suffix_data: SuffixInTrieData::build(&sam, trie_node, f),
57 sam: OwnedGeneralSam { sam },
58 }
59 }
60}
61
impl<
        TransTable: TransitionTable,
        TokenIDType: Clone + Default + PartialEq,
        SamRef: Deref<Target = GeneralSam<TransTable>>,
    > GreedyTokenizer<TransTable, TokenIDType, SamRef>
{
    /// Returns the SAM handle exactly as stored (e.g. an owned wrapper or a reference).
    pub fn get_sam(&self) -> &SamRef {
        &self.sam
    }

    /// Returns the underlying [`GeneralSam`], dereferencing the stored handle.
    pub fn get_sam_ref(&self) -> &GeneralSam<TransTable> {
        &self.sam
    }

    /// Returns the per-state token tables computed at construction time.
    ///
    /// `tokenize` indexes this vector by SAM node id.
    pub fn get_suffix_data(&self) -> &Vec<SuffixInTrieData<TokenIDType>> {
        &self.suffix_data
    }

    /// Returns a tokenizer that borrows this tokenizer's SAM instead of using `SamRef`.
    ///
    /// NOTE(review): despite the `as_ref` name this is not a free conversion —
    /// `suffix_data` is cloned, costing an allocation and a copy proportional to its size.
    pub fn inner_as_ref(
        &self,
    ) -> GreedyTokenizer<TransTable, TokenIDType, &GeneralSam<TransTable>> {
        GreedyTokenizer {
            sam: &self.sam,
            suffix_data: self.suffix_data.clone(),
        }
    }

    /// Builds a tokenizer from a SAM handle and a trie node (presumably the root of the
    /// token vocabulary — confirm with callers).
    ///
    /// `f` maps a trie node to the token id reported for it; it is passed to
    /// `SuffixInTrieData::build` together with the dereferenced SAM and `trie_node`.
    pub fn build<
        TN: TrieNodeAlike<InnerType = TransTable::KeyType>,
        F: FnMut(&TN) -> TokenIDType,
    >(
        sam: SamRef,
        trie_node: TN,
        f: F,
    ) -> Self {
        Self {
            // Evaluated before `sam` is moved into the struct below.
            suffix_data: SuffixInTrieData::build(sam.deref(), trie_node, f),
            sam,
        }
    }

    /// Greedily splits `iter` into tokens using the tables built at construction time.
    ///
    /// Returns `(token_id, token_len)` pairs, where `token_len` counts input keys. Keys
    /// that no token covers are reported as `unk_token_id` with length 1, and consecutive
    /// unknown entries are merged into a single entry whose length is the run length.
    ///
    /// # Panics
    ///
    /// Panics with `"invalid state"` if the table of the current SAM state has no entry
    /// for the current buffered length, i.e. if `suffix_data` is inconsistent with the SAM.
    pub fn tokenize<Iter: IntoIterator<Item = TransTable::KeyType>>(
        &self,
        iter: Iter,
        unk_token_id: &TokenIDType,
    ) -> Vec<(TokenIDType, usize)> {
        let mut res = Vec::new();

        // Appends a token, folding an unknown token into an immediately preceding
        // unknown entry instead of pushing a new one.
        let push = |res: &mut Vec<_>, token_id: TokenIDType, token_len: usize| {
            if let Some((last_token_id, last_token_len)) = res.last_mut() {
                if *last_token_id == *unk_token_id && token_id == *unk_token_id {
                    *last_token_len += token_len;
                    return;
                }
            }
            res.push((token_id, token_len))
        };

        // Emits one token from the `cur_len` buffered keys matched by `cur_state` and
        // shrinks `cur_len` by that token's length. The callers then walk suffix links,
        // i.e. the remaining buffer is a suffix of the old one, so the emitted token is
        // the front of the buffer. A missing table entry (presumably: no vocabulary token
        // starts the buffer) emits a single key as unknown.
        let pop_buffer = |cur_len: &mut usize,
                          cur_state: &mut GeneralSamState<TransTable, &GeneralSam<TransTable>>,
                          res: &mut Vec<_>| {
            let inner_data = self.suffix_data[cur_state.node_id]
                .get(*cur_len)
                .expect("invalid state");

            let (token_id, token_len) = inner_data.as_ref().map_or_else(
                || (unk_token_id, 1),
                |token_info| (&token_info.digested_trie_node, token_info.seq_len),
            );

            cur_len.sub_assign(token_len);
            push(res, token_id.clone(), token_len);
        };

        // `cur_state` matches the last `cur_len` keys that have been read but not emitted.
        let mut cur_state = self.sam.get_root_state();
        let mut cur_len = 0;

        for key in iter {
            debug_assert!(!cur_state.is_nil());
            let mut nxt_state = cur_state.get_non_nil_trans(&key);
            // Flush tokens until `key` can extend what is left of the buffer, or the
            // buffer is empty.
            while cur_len > 0 && nxt_state.is_none() {
                pop_buffer(&mut cur_len, &mut cur_state, &mut res);

                // The transition only needs re-querying when the state actually changes.
                // `get_min_suf_len` is presumably the shortest string length the state
                // represents; climb suffix links until `cur_len` fits.
                if cur_len < self.suffix_data[cur_state.node_id].get_min_suf_len() {
                    while cur_len < self.suffix_data[cur_state.node_id].get_min_suf_len() {
                        cur_state.goto_suffix_parent();
                    }
                    nxt_state = cur_state.get_non_nil_trans(&key);
                }
            }
            if let Some(nxt) = nxt_state {
                cur_state = nxt;
                cur_len.add_assign(1);
            } else {
                // Empty buffer and still no transition: `key` cannot start any match,
                // so it is emitted directly as an unknown token.
                debug_assert!(cur_state.is_root());
                push(&mut res, unk_token_id.clone(), 1);
            }
        }

        // Input exhausted: flush everything still buffered.
        while cur_len > 0 {
            pop_buffer(&mut cur_len, &mut cur_state, &mut res);

            while cur_len < self.suffix_data[cur_state.node_id].get_min_suf_len() {
                cur_state.goto_suffix_parent();
            }
        }

        res
    }
}
175
#[cfg(feature = "trie")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "trie")))]
pub mod trie {
    //! Helpers that pair the greedy tokenizer with a [`Trie`] vocabulary, using trie node
    //! ids as token ids.

    use std::ops::Deref;

    use crate::{GeneralSam, TransitionTable, Trie, TrieNodeAlike, TrieNodeID, TrieState};

    use super::OwnedGeneralSam;

    impl<TransTable, SamRef> super::GreedyTokenizer<TransTable, TrieNodeID, SamRef>
    where
        TransTable: TransitionTable,
        SamRef: Deref<Target = GeneralSam<TransTable>>,
    {
        /// Builds a tokenizer over `sam` whose token ids are the node ids of the trie
        /// reached through `trie_state`.
        pub fn build_from_trie<TT>(sam: SamRef, trie_state: TrieState<TT, &Trie<TT>>) -> Self
        where
            TT: TransitionTable<KeyType = TransTable::KeyType>,
        {
            Self::build(sam, trie_state, |node| node.node_id)
        }
    }

    impl<TransTable> super::GreedyTokenizer<TransTable, TrieNodeID, OwnedGeneralSam<TransTable>>
    where
        TransTable: TransitionTable,
    {
        /// Owning variant of `build_from_trie`: takes `sam` by value and uses trie node
        /// ids as token ids.
        pub fn build_from_sam_and_trie<TT>(
            sam: GeneralSam<TransTable>,
            trie_state: TrieState<TT, &Trie<TT>>,
        ) -> Self
        where
            TT: TransitionTable<KeyType = TransTable::KeyType>,
        {
            Self::build_from_sam(sam, trie_state, |node| node.node_id)
        }
    }

    /// Greedy longest-match tokenizer that walks `trie` directly, without a SAM.
    ///
    /// At each unconsumed position the trie is walked from its root as far as the input
    /// allows, and the longest non-empty accepting prefix becomes the next token. If there
    /// is none, a single key is emitted as unknown.
    ///
    /// Returns `(token_id, token_len)` pairs. `token_id` is a trie node id, or
    /// `trie.num_of_nodes()` for unknown keys. Consecutive unknown keys are merged into one
    /// entry whose length is their count.
    pub fn greedy_tokenize_with_trie<
        TransTable: TransitionTable,
        Iter: IntoIterator<Item = TransTable::KeyType>,
    >(
        trie: &Trie<TransTable>,
        seq: Iter,
    ) -> Vec<(usize, usize)> {
        // Presumably node ids lie in `0..num_of_nodes`, so this id cannot be a real token.
        let unk_token_id = trie.num_of_nodes();

        // Materialize the input so every start position can be rescanned from the root.
        let keys: Box<[_]> = seq.into_iter().collect();
        let mut tokens: Vec<(usize, usize)> = Vec::new();

        let mut start = 0;
        while start < keys.len() {
            let rest = &keys[start..];

            // Longest accepting `(node_id, length)` seen so far from `start`.
            let mut longest: Option<(usize, usize)> = None;
            let mut state = trie.get_root_state();
            for (consumed, key) in rest.iter().enumerate() {
                // `state` has consumed exactly `consumed` keys at this point.
                if !state.is_root() && state.is_accepting() {
                    longest = Some((state.node_id, consumed));
                }
                state.goto(key);
                if state.is_nil() {
                    break;
                }
            }
            // States are only checked before stepping, so a match spanning all of `rest`
            // has to be checked once more after the scan.
            if !state.is_root() && !state.is_nil() && state.is_accepting() {
                longest = Some((state.node_id, rest.len()));
            }

            let (token_id, token_len) = longest.unwrap_or((unk_token_id, 1));
            let extends_unknown_run = token_id == unk_token_id
                && matches!(tokens.last(), Some(&(last_id, _)) if last_id == unk_token_id);
            if extends_unknown_run {
                if let Some(last) = tokens.last_mut() {
                    last.1 += token_len;
                }
            } else {
                tokens.push((token_id, token_len));
            }
            start += token_len;
        }

        tokens
    }
}