1use std::ops::{AddAssign, Deref, SubAssign};
4
5use crate::{GeneralSam, GeneralSamState, TransitionTable, TrieNodeAlike};
6
7use super::suffixwise::SuffixInTrieData;
8
9#[derive(Clone, Debug)]
22pub struct GreedyTokenizer<
23 TransTable: TransitionTable,
24 TokenIDType: Clone + Default + PartialEq,
25 SamRef: Deref<Target = GeneralSam<TransTable>>,
26> {
27 sam: SamRef,
28 suffix_data: Vec<SuffixInTrieData<TokenIDType>>,
29}
30
31#[derive(Clone, Debug)]
32pub struct OwnedGeneralSam<TransTable: TransitionTable> {
33 pub sam: GeneralSam<TransTable>,
34}
35
36impl<TransTable: TransitionTable> Deref for OwnedGeneralSam<TransTable> {
37 type Target = GeneralSam<TransTable>;
38
39 fn deref(&self) -> &Self::Target {
40 &self.sam
41 }
42}
43
44impl<TransTable: TransitionTable, TokenIDType: Clone + Default + PartialEq>
45 GreedyTokenizer<TransTable, TokenIDType, OwnedGeneralSam<TransTable>>
46{
47 pub fn build_from_sam<
48 TN: TrieNodeAlike<InnerType = TransTable::KeyType>,
49 F: FnMut(&TN) -> TokenIDType,
50 >(
51 sam: GeneralSam<TransTable>,
52 trie_node: TN,
53 f: F,
54 ) -> Self {
55 Self {
56 suffix_data: SuffixInTrieData::build(&sam, trie_node, f),
57 sam: OwnedGeneralSam { sam },
58 }
59 }
60}
61
62impl<
63 TransTable: TransitionTable,
64 TokenIDType: Clone + Default + PartialEq,
65 SamRef: Deref<Target = GeneralSam<TransTable>>,
66> GreedyTokenizer<TransTable, TokenIDType, SamRef>
67{
68 pub fn get_sam(&self) -> &SamRef {
69 &self.sam
70 }
71
72 pub fn get_sam_ref(&self) -> &GeneralSam<TransTable> {
73 &self.sam
74 }
75
76 pub fn get_suffix_data(&self) -> &Vec<SuffixInTrieData<TokenIDType>> {
77 &self.suffix_data
78 }
79
80 pub fn inner_as_ref(
81 &self,
82 ) -> GreedyTokenizer<TransTable, TokenIDType, &GeneralSam<TransTable>> {
83 GreedyTokenizer {
84 sam: &self.sam,
85 suffix_data: self.suffix_data.clone(),
86 }
87 }
88
89 pub fn build<
90 TN: TrieNodeAlike<InnerType = TransTable::KeyType>,
91 F: FnMut(&TN) -> TokenIDType,
92 >(
93 sam: SamRef,
94 trie_node: TN,
95 f: F,
96 ) -> Self {
97 Self {
98 suffix_data: SuffixInTrieData::build(sam.deref(), trie_node, f),
99 sam,
100 }
101 }
102
103 pub fn tokenize<Iter: IntoIterator<Item = TransTable::KeyType>>(
104 &self,
105 iter: Iter,
106 unk_token_id: &TokenIDType,
107 ) -> Vec<(TokenIDType, usize)> {
108 let mut res = Vec::new();
109
110 let push = |res: &mut Vec<_>, token_id: TokenIDType, token_len: usize| {
111 if let Some((last_token_id, last_token_len)) = res.last_mut() {
112 if *last_token_id == *unk_token_id && token_id == *unk_token_id {
113 *last_token_len += token_len;
114 return;
115 }
116 }
117 res.push((token_id, token_len))
118 };
119
120 let pop_buffer = |cur_len: &mut usize,
121 cur_state: &mut GeneralSamState<TransTable, &GeneralSam<TransTable>>,
122 res: &mut Vec<_>| {
123 let inner_data = self.suffix_data[cur_state.node_id]
124 .get(*cur_len)
125 .expect("invalid state");
126
127 let (token_id, token_len) = inner_data.as_ref().map_or_else(
131 || (unk_token_id, 1),
132 |token_info| (&token_info.digested_trie_node, token_info.seq_len),
133 );
134
135 cur_len.sub_assign(token_len);
136 push(res, token_id.clone(), token_len);
137 };
138
139 let mut cur_state = self.sam.get_root_state();
140 let mut cur_len = 0;
141
142 for key in iter {
143 debug_assert!(!cur_state.is_nil());
144 let mut nxt_state = cur_state.get_non_nil_trans(&key);
145 while cur_len > 0 && nxt_state.is_none() {
146 pop_buffer(&mut cur_len, &mut cur_state, &mut res);
147
148 if cur_len < self.suffix_data[cur_state.node_id].get_min_suf_len() {
149 while cur_len < self.suffix_data[cur_state.node_id].get_min_suf_len() {
150 cur_state.goto_suffix_parent();
151 }
152 nxt_state = cur_state.get_non_nil_trans(&key);
153 }
154 }
155 if let Some(nxt) = nxt_state {
156 cur_state = nxt;
157 cur_len.add_assign(1);
158 } else {
159 debug_assert!(cur_state.is_root());
160 push(&mut res, unk_token_id.clone(), 1);
161 }
162 }
163
164 while cur_len > 0 {
165 pop_buffer(&mut cur_len, &mut cur_state, &mut res);
166
167 while cur_len < self.suffix_data[cur_state.node_id].get_min_suf_len() {
168 cur_state.goto_suffix_parent();
169 }
170 }
171
172 res
173 }
174}
175
#[cfg(feature = "trie")]
pub mod trie {
    //! Greedy-tokenizer constructors whose token ids are [`TrieNodeID`]s,
    //! plus a tokenizer that walks a [`Trie`] directly.

    use std::ops::Deref;

    use crate::{GeneralSam, TransitionTable, Trie, TrieNodeAlike, TrieNodeID, TrieState};

    use super::OwnedGeneralSam;

    impl<TransTable, SamRef> super::GreedyTokenizer<TransTable, TrieNodeID, SamRef>
    where
        TransTable: TransitionTable,
        SamRef: Deref<Target = GeneralSam<TransTable>>,
    {
        /// Builds a tokenizer over `sam` from the trie at `trie_state`.
        ///
        /// Each trie node's `node_id` is used as its token id.
        pub fn build_from_trie<TT>(sam: SamRef, trie_state: TrieState<TT, &Trie<TT>>) -> Self
        where
            TT: TransitionTable<KeyType = TransTable::KeyType>,
        {
            Self::build(sam, trie_state, |node| node.node_id)
        }
    }

    impl<TransTable> super::GreedyTokenizer<TransTable, TrieNodeID, OwnedGeneralSam<TransTable>>
    where
        TransTable: TransitionTable,
    {
        /// Like `build_from_trie`, but takes ownership of `sam`.
        ///
        /// Each trie node's `node_id` is used as its token id.
        pub fn build_from_sam_and_trie<TT>(
            sam: GeneralSam<TransTable>,
            trie_state: TrieState<TT, &Trie<TT>>,
        ) -> Self
        where
            TT: TransitionTable<KeyType = TransTable::KeyType>,
        {
            Self::build_from_sam(sam, trie_state, |node| node.node_id)
        }
    }

    /// Greedily tokenizes `seq` by walking `trie` directly, without a SAM.
    ///
    /// From each token start, the trie is walked from the root, and the
    /// longest prefix ending at an accepting non-root node becomes the next
    /// token. If no such prefix exists, one symbol is emitted as unknown.
    /// Adjacent unknown tokens are merged.
    ///
    /// Returns `(token_id, token_len)` pairs. `token_id` is a trie `node_id`,
    /// or `trie.num_of_nodes()` for unknown pieces.
    ///
    /// The walk restarts at the root for every token, so the work per token
    /// grows with the length of the walk.
    pub fn greedy_tokenize_with_trie<
        TransTable: TransitionTable,
        Iter: IntoIterator<Item = TransTable::KeyType>,
    >(
        trie: &Trie<TransTable>,
        seq: Iter,
    ) -> Vec<(usize, usize)> {
        // Presumably outside the range of real node ids; verify against `Trie`.
        let unk_token_id = trie.num_of_nodes();

        // Appends a token, folding an unknown token into a preceding unknown one.
        let emit = |out: &mut Vec<(usize, usize)>, token_id: usize, token_len: usize| {
            if token_id == unk_token_id {
                if let Some((prev_id, prev_len)) = out.last_mut() {
                    if *prev_id == unk_token_id {
                        *prev_len += token_len;
                        return;
                    }
                }
            }
            out.push((token_id, token_len));
        };

        let symbols: Box<[_]> = seq.into_iter().collect();
        let mut tokens = Vec::new();
        let mut start = 0;

        while start < symbols.len() {
            let rest = &symbols[start..];
            let mut longest: Option<(usize, usize)> = None;
            let mut state = trie.get_root_state();

            for (offset, key) in rest.iter().enumerate() {
                // `offset` symbols have been consumed to reach `state`.
                if !state.is_root() && state.is_accepting() {
                    longest = Some((state.node_id, offset));
                }
                state.goto(key);
                if state.is_nil() {
                    break;
                }
            }
            // The loop checks a state before consuming the next symbol. The
            // state reached after consuming all of `rest` must be checked here.
            if !state.is_root() && !state.is_nil() && state.is_accepting() {
                longest = Some((state.node_id, rest.len()));
            }

            let (token_id, token_len) = longest.unwrap_or((unk_token_id, 1));
            emit(&mut tokens, token_id, token_len);
            start += token_len;
        }

        tokens
    }
}