1use std::ops::{AddAssign, Deref, SubAssign};
4
5use crate::{GeneralSam, GeneralSamState, TransitionTable, TrieNodeAlike};
6
7use super::suffixwise::SuffixInTrieData;
8
9#[derive(Clone, Debug)]
22pub struct GreedyTokenizer<
23 TransTable: TransitionTable,
24 TokenIDType: Clone + Default + PartialEq,
25 SamRef: Deref<Target = GeneralSam<TransTable>>,
26> {
27 sam: SamRef,
28 suffix_data: Vec<SuffixInTrieData<TokenIDType>>,
29}
30
31#[derive(Clone, Debug)]
32pub struct OwnedGeneralSam<TransTable: TransitionTable> {
33 pub sam: GeneralSam<TransTable>,
34}
35
36impl<TransTable: TransitionTable> Deref for OwnedGeneralSam<TransTable> {
37 type Target = GeneralSam<TransTable>;
38
39 fn deref(&self) -> &Self::Target {
40 &self.sam
41 }
42}
43
44impl<TransTable: TransitionTable, TokenIDType: Clone + Default + PartialEq>
45 GreedyTokenizer<TransTable, TokenIDType, OwnedGeneralSam<TransTable>>
46{
47 pub fn build_from_sam<
48 TN: TrieNodeAlike<InnerType = TransTable::KeyType>,
49 F: FnMut(&TN) -> TokenIDType,
50 >(
51 sam: GeneralSam<TransTable>,
52 trie_node: TN,
53 f: F,
54 ) -> Self {
55 Self {
56 suffix_data: SuffixInTrieData::build(&sam, trie_node, f),
57 sam: OwnedGeneralSam { sam },
58 }
59 }
60}
61
62impl<
63 TransTable: TransitionTable,
64 TokenIDType: Clone + Default + PartialEq,
65 SamRef: Deref<Target = GeneralSam<TransTable>>,
66> GreedyTokenizer<TransTable, TokenIDType, SamRef>
67{
68 pub fn get_sam(&self) -> &SamRef {
69 &self.sam
70 }
71
72 pub fn get_sam_ref(&self) -> &GeneralSam<TransTable> {
73 &self.sam
74 }
75
76 pub fn get_suffix_data(&self) -> &Vec<SuffixInTrieData<TokenIDType>> {
77 &self.suffix_data
78 }
79
80 pub fn inner_as_ref(
81 &self,
82 ) -> GreedyTokenizer<TransTable, TokenIDType, &GeneralSam<TransTable>> {
83 GreedyTokenizer {
84 sam: &self.sam,
85 suffix_data: self.suffix_data.clone(),
86 }
87 }
88
89 pub fn build<
90 TN: TrieNodeAlike<InnerType = TransTable::KeyType>,
91 F: FnMut(&TN) -> TokenIDType,
92 >(
93 sam: SamRef,
94 trie_node: TN,
95 f: F,
96 ) -> Self {
97 Self {
98 suffix_data: SuffixInTrieData::build(sam.deref(), trie_node, f),
99 sam,
100 }
101 }
102
103 pub fn tokenize<Iter: IntoIterator<Item = TransTable::KeyType>>(
104 &self,
105 iter: Iter,
106 unk_token_id: &TokenIDType,
107 ) -> Vec<(TokenIDType, usize)> {
108 let mut res = Vec::new();
109
110 let push = |res: &mut Vec<_>, token_id: TokenIDType, token_len: usize| {
111 if let Some((last_token_id, last_token_len)) = res.last_mut()
112 && *last_token_id == *unk_token_id
113 && token_id == *unk_token_id
114 {
115 *last_token_len += token_len;
116 return;
117 }
118 res.push((token_id, token_len))
119 };
120
121 let pop_buffer = |cur_len: &mut usize,
122 cur_state: &mut GeneralSamState<TransTable, &GeneralSam<TransTable>>,
123 res: &mut Vec<_>| {
124 let inner_data = self.suffix_data[cur_state.node_id]
125 .get(*cur_len)
126 .expect("invalid state");
127
128 let (token_id, token_len) = inner_data.as_ref().map_or_else(
132 || (unk_token_id, 1),
133 |token_info| (&token_info.digested_trie_node, token_info.seq_len),
134 );
135
136 cur_len.sub_assign(token_len);
137 push(res, token_id.clone(), token_len);
138 };
139
140 let mut cur_state = self.sam.get_root_state();
141 let mut cur_len = 0;
142
143 for key in iter {
144 debug_assert!(!cur_state.is_nil());
145 let mut nxt_state = cur_state.get_non_nil_trans(&key);
146 while cur_len > 0 && nxt_state.is_none() {
147 pop_buffer(&mut cur_len, &mut cur_state, &mut res);
148
149 if cur_len < self.suffix_data[cur_state.node_id].get_min_suf_len() {
150 while cur_len < self.suffix_data[cur_state.node_id].get_min_suf_len() {
151 cur_state.goto_suffix_parent();
152 }
153 nxt_state = cur_state.get_non_nil_trans(&key);
154 }
155 }
156 if let Some(nxt) = nxt_state {
157 cur_state = nxt;
158 cur_len.add_assign(1);
159 } else {
160 debug_assert!(cur_state.is_root());
161 push(&mut res, unk_token_id.clone(), 1);
162 }
163 }
164
165 while cur_len > 0 {
166 pop_buffer(&mut cur_len, &mut cur_state, &mut res);
167
168 while cur_len < self.suffix_data[cur_state.node_id].get_min_suf_len() {
169 cur_state.goto_suffix_parent();
170 }
171 }
172
173 res
174 }
175}
176
#[cfg(feature = "trie")]
pub mod trie {
    use std::ops::Deref;

    use crate::{GeneralSam, TransitionTable, Trie, TrieNodeAlike, TrieNodeID, TrieState};

    use super::OwnedGeneralSam;

    impl<TransTable, SamRef> super::GreedyTokenizer<TransTable, TrieNodeID, SamRef>
    where
        TransTable: TransitionTable,
        SamRef: Deref<Target = GeneralSam<TransTable>>,
    {
        /// Builds a tokenizer whose token ids are node ids of the trie behind `trie_state`.
        pub fn build_from_trie<TT: TransitionTable<KeyType = TransTable::KeyType>>(
            sam: SamRef,
            trie_state: TrieState<TT, &Trie<TT>>,
        ) -> Self {
            Self::build(sam, trie_state, |node| node.node_id)
        }
    }

    impl<TransTable> super::GreedyTokenizer<TransTable, TrieNodeID, OwnedGeneralSam<TransTable>>
    where
        TransTable: TransitionTable,
    {
        /// Same as `build_from_trie`, but takes ownership of `sam`.
        pub fn build_from_sam_and_trie<TT: TransitionTable<KeyType = TransTable::KeyType>>(
            sam: GeneralSam<TransTable>,
            trie_state: TrieState<TT, &Trie<TT>>,
        ) -> Self {
            Self::build_from_trie(OwnedGeneralSam { sam }, trie_state)
        }
    }

    /// Greedy longest-match tokenization that walks `trie` directly.
    ///
    /// From each position it follows `trie` as far as the input allows. It keeps the
    /// longest non-empty prefix that ends at an accepting node, emits
    /// `(node_id, length)` for it, and resumes right after it. If no such prefix
    /// exists, one key is emitted with id `trie.num_of_nodes()`; that id is presumably
    /// one past the largest node id. Consecutive entries with that id are merged.
    ///
    /// Every position restarts from the trie root, so the worst case is
    /// O(input length × trie depth).
    pub fn greedy_tokenize_with_trie<
        TransTable: TransitionTable,
        Iter: IntoIterator<Item = TransTable::KeyType>,
    >(
        trie: &Trie<TransTable>,
        seq: Iter,
    ) -> Vec<(usize, usize)> {
        let unk_token_id = trie.num_of_nodes();
        let keys: Box<[TransTable::KeyType]> = seq.into_iter().collect();

        let mut tokens: Vec<(usize, usize)> = Vec::new();
        // Appends a token, folding it into the previous entry when both are unknown.
        let mut emit = |id: usize, len: usize| match tokens.last_mut() {
            Some((last_id, last_len)) if *last_id == unk_token_id && id == unk_token_id => {
                *last_len += len
            }
            _ => tokens.push((id, len)),
        };

        let mut start = 0;
        while start < keys.len() {
            let rest = &keys[start..];
            let mut state = trie.get_root_state();
            let mut longest: Option<(usize, usize)> = None;

            for (depth, key) in rest.iter().enumerate() {
                // `state` has consumed exactly `depth` keys of `rest` at this point.
                if !state.is_root() && state.is_accepting() {
                    longest = Some((state.node_id, depth));
                }
                state.goto(key);
                if state.is_nil() {
                    break;
                }
            }
            // The walk may have consumed all of `rest` without falling off the trie.
            if !state.is_root() && !state.is_nil() && state.is_accepting() {
                longest = Some((state.node_id, rest.len()));
            }

            let (id, len) = longest.unwrap_or((unk_token_id, 1));
            emit(id, len);
            start += len;
        }

        tokens
    }
}