general_sam/utils/tokenize.rs

//! Greedy tokenizer.

use std::ops::{AddAssign, Deref, SubAssign};

use crate::{GeneralSam, GeneralSamState, TransitionTable, TrieNodeAlike};

use super::suffixwise::SuffixInTrieData;

/// Greedy tokenizer with a general suffix automaton of the vocabulary.
///
/// Assuming the input length is $n$, the maximum word length is $l$,
/// and a transition query takes $\mathcal{O}\left(\log{|\Sigma|}\right)$ time,
/// the overall time complexity of this implementation is
/// $\mathcal{O}\left( n \cdot \left( \log{l} + \log{|\Sigma|} \right) \right)$.
///
/// The main optimization is to store suffix-wise information in persistent ropes.
/// For each suffix of a state of the suffix automaton,
/// the longest word matching a prefix of that suffix is stored in the rope,
/// and the information stored in a state
/// is further merged into the ropes of its successors.
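///
/// # Example
///
/// A minimal sketch of the intended usage (marked `ignore`, since the
/// `Trie` construction and import paths here are assumptions):
///
/// ```ignore
/// use general_sam::{BTreeTransTable, GeneralSam, Trie};
/// use general_sam::utils::tokenize::GreedyTokenizer;
///
/// // Build a trie of the vocabulary, then the suffix automaton of the trie.
/// let mut trie = Trie::<BTreeTransTable<char>>::default();
/// trie.insert("hello".chars());
/// trie.insert("world".chars());
/// let sam = GeneralSam::<BTreeTransTable<char>>::from_trie(trie.get_root_state());
///
/// // Token ids are trie node ids; `trie.num_of_nodes()` serves as the
/// // unknown-token id, since no trie node can have that id.
/// let tokenizer = GreedyTokenizer::build_from_sam_and_trie(sam, trie.get_root_state());
/// let tokens = tokenizer.tokenize("helloworld!".chars(), &trie.num_of_nodes());
/// // `tokens` is a sequence of `(token_id, token_len)` pairs.
/// ```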
#[derive(Clone, Debug)]
pub struct GreedyTokenizer<
    TransTable: TransitionTable,
    TokenIDType: Clone + Default + PartialEq,
    SamRef: Deref<Target = GeneralSam<TransTable>>,
> {
    sam: SamRef,
    suffix_data: Vec<SuffixInTrieData<TokenIDType>>,
}

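/// An owned [`GeneralSam`], wrapped so that it implements
/// `Deref<Target = GeneralSam<TransTable>>` and can thus serve as the
/// `SamRef` parameter of a self-contained [`GreedyTokenizer`].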
#[derive(Clone, Debug)]
pub struct OwnedGeneralSam<TransTable: TransitionTable> {
    pub sam: GeneralSam<TransTable>,
}

impl<TransTable: TransitionTable> Deref for OwnedGeneralSam<TransTable> {
    type Target = GeneralSam<TransTable>;

    fn deref(&self) -> &Self::Target {
        &self.sam
    }
}

impl<TransTable: TransitionTable, TokenIDType: Clone + Default + PartialEq>
    GreedyTokenizer<TransTable, TokenIDType, OwnedGeneralSam<TransTable>>
{
    pub fn build_from_sam<
        TN: TrieNodeAlike<InnerType = TransTable::KeyType>,
        F: FnMut(&TN) -> TokenIDType,
    >(
        sam: GeneralSam<TransTable>,
        trie_node: TN,
        f: F,
    ) -> Self {
        Self {
            suffix_data: SuffixInTrieData::build(&sam, trie_node, f),
            sam: OwnedGeneralSam { sam },
        }
    }
}

impl<
    TransTable: TransitionTable,
    TokenIDType: Clone + Default + PartialEq,
    SamRef: Deref<Target = GeneralSam<TransTable>>,
> GreedyTokenizer<TransTable, TokenIDType, SamRef>
{
    pub fn get_sam(&self) -> &SamRef {
        &self.sam
    }

    pub fn get_sam_ref(&self) -> &GeneralSam<TransTable> {
        &self.sam
    }

    pub fn get_suffix_data(&self) -> &Vec<SuffixInTrieData<TokenIDType>> {
        &self.suffix_data
    }

    pub fn inner_as_ref(
        &self,
    ) -> GreedyTokenizer<TransTable, TokenIDType, &GeneralSam<TransTable>> {
        GreedyTokenizer {
            sam: &self.sam,
            suffix_data: self.suffix_data.clone(),
        }
    }

    pub fn build<
        TN: TrieNodeAlike<InnerType = TransTable::KeyType>,
        F: FnMut(&TN) -> TokenIDType,
    >(
        sam: SamRef,
        trie_node: TN,
        f: F,
    ) -> Self {
        Self {
            suffix_data: SuffixInTrieData::build(sam.deref(), trie_node, f),
            sam,
        }
    }

    pub fn tokenize<Iter: IntoIterator<Item = TransTable::KeyType>>(
        &self,
        iter: Iter,
        unk_token_id: &TokenIDType,
    ) -> Vec<(TokenIDType, usize)> {
        let mut res = Vec::new();

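        // Append a token to the result, merging runs of consecutive
        // unknown tokens into a single `(unk_token_id, total_len)` entry.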
        let push = |res: &mut Vec<_>, token_id: TokenIDType, token_len: usize| {
            if let Some((last_token_id, last_token_len)) = res.last_mut()
                && *last_token_id == *unk_token_id
                && token_id == *unk_token_id
            {
                *last_token_len += token_len;
                return;
            }
            res.push((token_id, token_len))
        };

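        // Pop the longest word matching a prefix of the buffered suffix of
        // length `cur_len`; if no word matches, pop a single unknown key.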
        let pop_buffer = |cur_len: &mut usize,
                          cur_state: &mut GeneralSamState<TransTable, &GeneralSam<TransTable>>,
                          res: &mut Vec<_>| {
            let inner_data = self.suffix_data[cur_state.node_id]
                .get(*cur_len)
                .expect("invalid state");

            // TODO: Optimize for unknown tokens:
            // Find the lower bound position where the suffix is prefixed with a token.
            // But this does not improve the time complexity, pending...
            let (token_id, token_len) = inner_data.as_ref().map_or_else(
                || (unk_token_id, 1),
                |token_info| (&token_info.digested_trie_node, token_info.seq_len),
            );

            cur_len.sub_assign(token_len);
            push(res, token_id.clone(), token_len);
        };

        let mut cur_state = self.sam.get_root_state();
        let mut cur_len = 0;

        for key in iter {
            debug_assert!(!cur_state.is_nil());
            let mut nxt_state = cur_state.get_non_nil_trans(&key);
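            // No transition for `key`: flush tokens from the buffer until a
            // transition appears or the buffer is empty, climbing suffix
            // links whenever the shortened suffix leaves the current state.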
            while cur_len > 0 && nxt_state.is_none() {
                pop_buffer(&mut cur_len, &mut cur_state, &mut res);

                if cur_len < self.suffix_data[cur_state.node_id].get_min_suf_len() {
                    while cur_len < self.suffix_data[cur_state.node_id].get_min_suf_len() {
                        cur_state.goto_suffix_parent();
                    }
                    nxt_state = cur_state.get_non_nil_trans(&key);
                }
            }
            if let Some(nxt) = nxt_state {
                cur_state = nxt;
                cur_len.add_assign(1);
            } else {
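                // `key` has no transition even from the root, so it occurs
                // in no word: emit it directly as an unknown token.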
                debug_assert!(cur_state.is_root());
                push(&mut res, unk_token_id.clone(), 1);
            }
        }

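        // The input is exhausted; flush whatever remains in the buffer.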
        while cur_len > 0 {
            pop_buffer(&mut cur_len, &mut cur_state, &mut res);

            while cur_len < self.suffix_data[cur_state.node_id].get_min_suf_len() {
                cur_state.goto_suffix_parent();
            }
        }

        res
    }
}

#[cfg(feature = "trie")]
pub mod trie {
    use std::ops::Deref;

    use crate::{GeneralSam, TransitionTable, Trie, TrieNodeAlike, TrieNodeID, TrieState};

    use super::OwnedGeneralSam;

    impl<TransTable: TransitionTable, SamRef: Deref<Target = GeneralSam<TransTable>>>
        super::GreedyTokenizer<TransTable, TrieNodeID, SamRef>
    {
        pub fn build_from_trie<TT: TransitionTable<KeyType = TransTable::KeyType>>(
            sam: SamRef,
            trie_state: TrieState<TT, &Trie<TT>>,
        ) -> Self {
            Self::build(sam, trie_state, |tn| tn.node_id)
        }
    }

    impl<TransTable: TransitionTable>
        super::GreedyTokenizer<TransTable, TrieNodeID, OwnedGeneralSam<TransTable>>
    {
        pub fn build_from_sam_and_trie<TT: TransitionTable<KeyType = TransTable::KeyType>>(
            sam: GeneralSam<TransTable>,
            trie_state: TrieState<TT, &Trie<TT>>,
        ) -> Self {
            Self::build_from_sam(sam, trie_state, |tn| tn.node_id)
        }
    }

    /// Greedy tokenizer with a trie of the vocabulary.
    ///
    /// Assuming the input length is $n$, the maximum word length is $l$,
    /// and querying a transition in the trie takes $\mathcal{O}\left(\log{|\Sigma|}\right)$ time,
    /// the overall time complexity of this implementation is
    /// $\mathcal{O}\left( n \cdot l \cdot \log{|\Sigma|} \right)$.
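    ///
    /// # Example
    ///
    /// A minimal sketch of the intended usage (marked `ignore`, since the
    /// `Trie` construction and import paths here are assumptions):
    ///
    /// ```ignore
    /// use general_sam::{BTreeTransTable, Trie};
    /// use general_sam::utils::tokenize::trie::greedy_tokenize_with_trie;
    ///
    /// let mut trie = Trie::<BTreeTransTable<char>>::default();
    /// trie.insert("hi".chars());
    /// // Unknown positions are reported with the id `trie.num_of_nodes()`.
    /// let tokens = greedy_tokenize_with_trie(&trie, "hihi!".chars());
    /// ```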
    pub fn greedy_tokenize_with_trie<
        TransTable: TransitionTable,
        Iter: IntoIterator<Item = TransTable::KeyType>,
    >(
        trie: &Trie<TransTable>,
        seq: Iter,
    ) -> Vec<(usize, usize)> {
        let unk_token_id = trie.num_of_nodes();

        let mut res = Vec::new();

        let push = |res: &mut Vec<_>, token_id: usize, token_len: usize| {
            if let Some((last_token_id, last_token_len)) = res.last_mut()
                && *last_token_id == unk_token_id
                && token_id == unk_token_id
            {
                *last_token_len += token_len;
                return;
            }
            res.push((token_id, token_len))
        };

        let seq: Box<[_]> = seq.into_iter().collect();
        let mut cur = 0;
        while cur < seq.len() {
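            // Walk the trie from position `cur`, remembering the deepest
            // accepting state seen, i.e. the longest word starting at `cur`.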
            let mut best: Option<(usize, usize)> = None;
            let mut cur_state = trie.get_root_state();
            for i in cur..seq.len() {
                if !cur_state.is_root() && cur_state.is_accepting() {
                    best = Some((cur_state.node_id, i - cur));
                }
                let key = &seq[i];
                cur_state.goto(key);
                if cur_state.is_nil() {
                    break;
                }
            }
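            // The loop above only records matches before each transition,
            // so a match reaching the end of the scan is checked here.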
            if !cur_state.is_root() && !cur_state.is_nil() && cur_state.is_accepting() {
                best = Some((cur_state.node_id, seq.len() - cur));
            }
            if let Some((best_token_id, best_token_len)) = best {
                push(&mut res, best_token_id, best_token_len);
                cur += best_token_len;
            } else {
                push(&mut res, unk_token_id, 1);
                cur += 1;
            }
        }

        res
    }
}