// general_sam/utils/tokenize.rs

//! Greedy tokenizer.

use std::ops::{AddAssign, Deref, SubAssign};

use crate::{GeneralSam, GeneralSamState, TransitionTable, TrieNodeAlike};

use super::suffixwise::SuffixInTrieData;

/// Greedy tokenizer with a general suffix automaton of the vocabulary.
///
/// Assuming that the input length is $n$, the maximum word length is $l$,
/// and querying transitions in the trie takes $\mathcal{O}\left(\log{\Sigma}\right)$ time,
/// the overall time complexity of this implementation is
/// $\mathcal{O}\left( n \cdot \left( \log{l} + \log{\Sigma} \right) \right)$.
///
/// The main optimization is to store suffix-wise information in persistent ropes.
/// For each suffix represented by a state of the suffix automaton,
/// the longest word matching a prefix of that suffix is stored in the rope,
/// and this information is further merged into the ropes of the state's successors.
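///
/// # Example
///
/// A minimal usage sketch, not part of the original file. It assumes the
/// crate's `Trie`, `BTreeTransTable`, and `GeneralSam::from_trie` API;
/// `id_of_hello` and `id_of_world` below are placeholders for trie node ids.
///
/// ```ignore
/// use general_sam::{BTreeTransTable, GeneralSam, Trie};
/// use general_sam::utils::tokenize::GreedyTokenizer;
///
/// let mut trie = Trie::<BTreeTransTable<char>>::default();
/// trie.insert("hello".chars());
/// trie.insert("world".chars());
///
/// // Build the suffix automaton of the vocabulary trie,
/// // then the tokenizer over both.
/// let sam = GeneralSam::<BTreeTransTable<char>>::from_trie(trie.get_root_state());
/// let tokenizer = GreedyTokenizer::build_from_sam_and_trie(sam, trie.get_root_state());
///
/// // `num_of_nodes()` lies outside the valid trie node id range,
/// // so it can serve as the unknown token id.
/// let unk = trie.num_of_nodes();
/// let tokens = tokenizer.tokenize("helloworld!".chars(), &unk);
/// // tokens: [(id_of_hello, 5), (id_of_world, 5), (unk, 1)]
/// ```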
#[derive(Clone, Debug)]
pub struct GreedyTokenizer<
    TransTable: TransitionTable,
    TokenIDType: Clone + Default + PartialEq,
    SamRef: Deref<Target = GeneralSam<TransTable>>,
> {
    sam: SamRef,
    suffix_data: Vec<SuffixInTrieData<TokenIDType>>,
}

#[derive(Clone, Debug)]
pub struct OwnedGeneralSam<TransTable: TransitionTable> {
    pub sam: GeneralSam<TransTable>,
}

impl<TransTable: TransitionTable> Deref for OwnedGeneralSam<TransTable> {
    type Target = GeneralSam<TransTable>;

    fn deref(&self) -> &Self::Target {
        &self.sam
    }
}

impl<TransTable: TransitionTable, TokenIDType: Clone + Default + PartialEq>
    GreedyTokenizer<TransTable, TokenIDType, OwnedGeneralSam<TransTable>>
{
    pub fn build_from_sam<
        TN: TrieNodeAlike<InnerType = TransTable::KeyType>,
        F: FnMut(&TN) -> TokenIDType,
    >(
        sam: GeneralSam<TransTable>,
        trie_node: TN,
        f: F,
    ) -> Self {
        Self {
            suffix_data: SuffixInTrieData::build(&sam, trie_node, f),
            sam: OwnedGeneralSam { sam },
        }
    }
}

impl<
        TransTable: TransitionTable,
        TokenIDType: Clone + Default + PartialEq,
        SamRef: Deref<Target = GeneralSam<TransTable>>,
    > GreedyTokenizer<TransTable, TokenIDType, SamRef>
{
    pub fn get_sam(&self) -> &SamRef {
        &self.sam
    }

    pub fn get_sam_ref(&self) -> &GeneralSam<TransTable> {
        &self.sam
    }

    pub fn get_suffix_data(&self) -> &Vec<SuffixInTrieData<TokenIDType>> {
        &self.suffix_data
    }

    pub fn inner_as_ref(
        &self,
    ) -> GreedyTokenizer<TransTable, TokenIDType, &GeneralSam<TransTable>> {
        GreedyTokenizer {
            sam: &self.sam,
            suffix_data: self.suffix_data.clone(),
        }
    }

    pub fn build<
        TN: TrieNodeAlike<InnerType = TransTable::KeyType>,
        F: FnMut(&TN) -> TokenIDType,
    >(
        sam: SamRef,
        trie_node: TN,
        f: F,
    ) -> Self {
        Self {
            suffix_data: SuffixInTrieData::build(sam.deref(), trie_node, f),
            sam,
        }
    }

    pub fn tokenize<Iter: IntoIterator<Item = TransTable::KeyType>>(
        &self,
        iter: Iter,
        unk_token_id: &TokenIDType,
    ) -> Vec<(TokenIDType, usize)> {
        let mut res = Vec::new();

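        // Merge consecutive unknown tokens into a single `(unk, len)` entry
        // while pushing; known tokens are appended as-is.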
        let push = |res: &mut Vec<_>, token_id: TokenIDType, token_len: usize| {
            if let Some((last_token_id, last_token_len)) = res.last_mut() {
                if *last_token_id == *unk_token_id && token_id == *unk_token_id {
                    *last_token_len += token_len;
                    return;
                }
            }
            res.push((token_id, token_len))
        };

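        // Emit one token from the buffered suffix of length `cur_len`:
        // either the longest word matching a prefix of that suffix,
        // or a single unknown key if no word matches.
        // `cur_len` shrinks by the emitted length.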
        let pop_buffer = |cur_len: &mut usize,
                          cur_state: &mut GeneralSamState<TransTable, &GeneralSam<TransTable>>,
                          res: &mut Vec<_>| {
            let inner_data = self.suffix_data[cur_state.node_id]
                .get(*cur_len)
                .expect("invalid state");

            // TODO: Optimize for unknown tokens:
            // find the lower-bound position at which the suffix is prefixed by a token.
            // However, this would not improve the asymptotic time complexity,
            // so it is left pending.
            let (token_id, token_len) = inner_data.as_ref().map_or_else(
                || (unk_token_id, 1),
                |token_info| (&token_info.digested_trie_node, token_info.seq_len),
            );

            cur_len.sub_assign(token_len);
            push(res, token_id.clone(), token_len);
        };

        let mut cur_state = self.sam.get_root_state();
        let mut cur_len = 0;

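        // Scan the input; `cur_state` is the automaton state reached by the
        // last `cur_len` keys, i.e. the currently buffered (untokenized) suffix.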
        for key in iter {
            debug_assert!(!cur_state.is_nil());
            let mut nxt_state = cur_state.get_non_nil_trans(&key);
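            // No transition for `key`: flush tokens from the buffered suffix
            // until a transition appears or the buffer is empty.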
            while cur_len > 0 && nxt_state.is_none() {
                pop_buffer(&mut cur_len, &mut cur_state, &mut res);

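                // The buffer got shorter than the shortest suffix of this
                // state: climb suffix links to the state covering `cur_len`,
                // then retry the transition.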
                if cur_len < self.suffix_data[cur_state.node_id].get_min_suf_len() {
                    while cur_len < self.suffix_data[cur_state.node_id].get_min_suf_len() {
                        cur_state.goto_suffix_parent();
                    }
                    nxt_state = cur_state.get_non_nil_trans(&key);
                }
            }
            if let Some(nxt) = nxt_state {
                cur_state = nxt;
                cur_len.add_assign(1);
            } else {
                debug_assert!(cur_state.is_root());
                push(&mut res, unk_token_id.clone(), 1);
            }
        }

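        // Flush whatever remains in the buffer at the end of the input.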
        while cur_len > 0 {
            pop_buffer(&mut cur_len, &mut cur_state, &mut res);

            while cur_len < self.suffix_data[cur_state.node_id].get_min_suf_len() {
                cur_state.goto_suffix_parent();
            }
        }

        res
    }
}

#[cfg(feature = "trie")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "trie")))]
pub mod trie {
    use std::ops::Deref;

    use crate::{GeneralSam, TransitionTable, Trie, TrieNodeAlike, TrieNodeID, TrieState};

    use super::OwnedGeneralSam;

    impl<TransTable: TransitionTable, SamRef: Deref<Target = GeneralSam<TransTable>>>
        super::GreedyTokenizer<TransTable, TrieNodeID, SamRef>
    {
        pub fn build_from_trie<TT: TransitionTable<KeyType = TransTable::KeyType>>(
            sam: SamRef,
            trie_state: TrieState<TT, &Trie<TT>>,
        ) -> Self {
            Self::build(sam, trie_state, |tn| tn.node_id)
        }
    }

    impl<TransTable: TransitionTable>
        super::GreedyTokenizer<TransTable, TrieNodeID, OwnedGeneralSam<TransTable>>
    {
        pub fn build_from_sam_and_trie<TT: TransitionTable<KeyType = TransTable::KeyType>>(
            sam: GeneralSam<TransTable>,
            trie_state: TrieState<TT, &Trie<TT>>,
        ) -> Self {
            Self::build_from_sam(sam, trie_state, |tn| tn.node_id)
        }
    }

    /// Greedy tokenizer with a trie of the vocabulary.
    ///
    /// Assuming that the input length is $n$, the maximum word length is $l$,
    /// and querying transitions in the trie takes $\mathcal{O}\left(\log{\Sigma}\right)$ time,
    /// the overall time complexity of this implementation is
    /// $\mathcal{O}\left( n \cdot l \cdot \log{\Sigma} \right)$.
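    ///
    /// # Example
    ///
    /// A minimal usage sketch, not part of the original file; it assumes the
    /// crate's `Trie` and `BTreeTransTable` API and that the module path
    /// matches the file location:
    ///
    /// ```ignore
    /// use general_sam::{BTreeTransTable, Trie};
    /// use general_sam::utils::tokenize::trie::greedy_tokenize_with_trie;
    ///
    /// let mut trie = Trie::<BTreeTransTable<char>>::default();
    /// trie.insert("ab".chars());
    /// trie.insert("abc".chars());
    ///
    /// // At each position the longest matching word wins, so "abcab!"
    /// // splits into "abc" + "ab" + one unknown key, where the unknown
    /// // token id is `trie.num_of_nodes()`.
    /// let tokens = greedy_tokenize_with_trie(&trie, "abcab!".chars());
    /// ```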
    pub fn greedy_tokenize_with_trie<
        TransTable: TransitionTable,
        Iter: IntoIterator<Item = TransTable::KeyType>,
    >(
        trie: &Trie<TransTable>,
        seq: Iter,
    ) -> Vec<(usize, usize)> {
        let unk_token_id = trie.num_of_nodes();

        let mut res = Vec::new();

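        // As in `GreedyTokenizer::tokenize`, consecutive unknown tokens
        // are merged into a single `(unk, len)` entry.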
        let push = |res: &mut Vec<_>, token_id: usize, token_len: usize| {
            if let Some((last_token_id, last_token_len)) = res.last_mut() {
                if *last_token_id == unk_token_id && token_id == unk_token_id {
                    *last_token_len += token_len;
                    return;
                }
            }
            res.push((token_id, token_len))
        };

        let seq: Box<[_]> = seq.into_iter().collect();
        let mut cur = 0;
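        // From each position, walk down the trie and remember the last
        // accepting node seen; that is the longest word starting at `cur`.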
        while cur < seq.len() {
            let mut best: Option<(usize, usize)> = None;
            let mut cur_state = trie.get_root_state();
            for i in cur..seq.len() {
                if !cur_state.is_root() && cur_state.is_accepting() {
                    best = Some((cur_state.node_id, i - cur));
                }
                let key = &seq[i];
                cur_state.goto(key);
                if cur_state.is_nil() {
                    break;
                }
            }
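            // A match consuming the rest of the sequence is never seen by
            // the top-of-loop check above, so test the final state as well.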
            if !cur_state.is_root() && !cur_state.is_nil() && cur_state.is_accepting() {
                best = Some((cur_state.node_id, seq.len() - cur));
            }
            if let Some((best_token_id, best_token_len)) = best {
                push(&mut res, best_token_id, best_token_len);
                cur += best_token_len;
            } else {
                push(&mut res, unk_token_id, 1);
                cur += 1;
            }
        }

        res
    }
}