mtc_inc_bpe/dict.rs
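//! Merge-rule dictionary: pairs a [`Vocab`] with an ordered list of BPE merge
//! rules and a lookup table from adjacent token pairs to rule ids.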

use std::{hash::Hash, ops::Index};

use bytes::BytesMut;
use derive_more::{Deref, From, Into};
use rapidhash::{HashMapExt, RapidHashMap};
use thiserror::Error;
use tinyvec::TinyVec;

use crate::{
    Token, TokenId, Vocab,
    typed_vec::{TypedVec, typed_vec_index},
};

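// Strongly typed index of a merge rule within the dictionary's rule list,
// generated by the crate's `typed_vec_index!` macro and backed by a `u32`.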
typed_vec_index!(pub RuleId, u32);

pub(crate) type RuleIdVec = TinyVec<[RuleId; 6]>;
const _: () = {
    assert!(std::mem::size_of::<RuleIdVec>() == 32);
};

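/// A single merge rule: the adjacent tokens `pre` and `suc` merge into `merged`.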
#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Into, From)]
pub struct Rule {
    pub merged: TokenId,
    pub pre: TokenId,
    pub suc: TokenId,
}

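/// A merge dictionary: a [`Vocab`] together with its ordered merge rules and a
/// lookup table mapping each `(pre, suc)` token pair to its rule id.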
#[derive(Clone, Debug, Deref)]
pub struct Dictionary {
    #[deref]
    vocab: Vocab,
    pub(crate) rules: TypedVec<RuleId, Rule>,
    pair_to_rule_id: RapidHashMap<(TokenId, TokenId), RuleId>,
}

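/// Errors raised when a set of merge rules fails validation against the vocabulary.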
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DictBuildError {
    #[error("rule {rule_id} uses an unknown token")]
    UnknownToken { rule_id: RuleId, token: Token },
    #[error("rule {rule_id} uses token id {token_id} which exceeds vocab size")]
    InvalidTokenId { rule_id: RuleId, token_id: TokenId },
    #[error("rule {rule_id} uses an empty or special token with id {token_id}")]
    EmptyToken { rule_id: RuleId, token_id: TokenId },
}

impl Dictionary {
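    // Assembles the pair -> rule-id lookup table from an already validated rule list.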
    fn from_rules(vocab: Vocab, rules: TypedVec<RuleId, Rule>) -> Self {
        let mut pair_to_rule_id = RapidHashMap::with_capacity(rules.len().as_usize());
        for (id, rule) in rules.enumerate() {
            pair_to_rule_id.insert((rule.pre, rule.suc), id);
        }
        Self {
            vocab,
            rules,
            pair_to_rule_id,
        }
    }

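    /// Builds a dictionary from `(left, right)` token-id pairs. Each pair must
    /// refer to existing, non-empty tokens whose concatenation is itself a
    /// token in `vocab`; otherwise the offending rule is reported as a
    /// [`DictBuildError`].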
    pub fn new_from_id_pair<T: Into<TokenId>, R: IntoIterator<Item = (T, T)>>(
        vocab: Vocab,
        rule_iter: R,
    ) -> Result<Self, DictBuildError> {
        let rule_iter = rule_iter.into_iter();
        let mut rules = TypedVec::with_capacity(RuleId::from(rule_iter.size_hint().0));
        let get_token = |rule_id, token_id| {
            vocab
                .get_token(token_id)
                .ok_or(DictBuildError::InvalidTokenId { rule_id, token_id })
                .and_then(|t| {
                    if t.is_empty() {
                        Err(DictBuildError::EmptyToken { rule_id, token_id })
                    } else {
                        Ok(t)
                    }
                })
        };
        for (pos, (left, right)) in rule_iter.map(|(i, j)| (i.into(), j.into())).enumerate() {
            let rule_id = RuleId::from(pos);
            let token = {
                let mut buf = BytesMut::from(get_token(rule_id, left)?.clone());
                buf.extend_from_slice(get_token(rule_id, right)?);
                buf.freeze()
            };
            let merged = vocab
                .find_token_id(&token)
                .ok_or(DictBuildError::UnknownToken { rule_id, token })?;
            rules.push(Rule {
                merged,
                pre: left,
                suc: right,
            });
        }
        Ok(Self::from_rules(vocab, rules))
    }

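    /// Builds a dictionary from `(left, right)` byte-string pairs. Both sides
    /// and their concatenation must already be tokens in `vocab`.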
    pub fn new_from_token_pair<T: AsRef<[u8]>, R: IntoIterator<Item = (T, T)>>(
        vocab: Vocab,
        rule_iter: R,
    ) -> Result<Self, DictBuildError> {
        let rule_iter = rule_iter.into_iter();
        let mut rules = TypedVec::with_capacity(RuleId::from(rule_iter.size_hint().0));
        let get_id = |pos, token: &[u8]| {
            vocab
                .find_token_id(token)
                .ok_or(DictBuildError::UnknownToken {
                    rule_id: pos,
                    token: token.to_owned().into(),
                })
        };
        for (pos, (left, right)) in rule_iter.enumerate() {
            let (left, right) = (left.as_ref(), right.as_ref());
            let pos = RuleId::from(pos);
            let left_id = get_id(pos, left)?;
            let right_id = get_id(pos, right)?;
            let token = {
                let mut buf = BytesMut::from(left);
                buf.extend_from_slice(right);
                buf.freeze()
            };
            let merged = get_id(pos, &token)?;
            rules.push(Rule {
                merged,
                pre: left_id,
                suc: right_id,
            });
        }
        Ok(Self::from_rules(vocab, rules))
    }

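    /// Returns all merge rules in the order they were supplied.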
    #[inline(always)]
    pub fn rules(&self) -> &[Rule] {
        self.rules.as_slice()
    }

    #[inline(always)]
    pub fn get_rule(&self, rule_id: RuleId) -> Option<&Rule> {
        self.rules.get(rule_id)
    }

    #[inline(always)]
    pub fn num_of_rules(&self) -> RuleId {
        self.rules.len()
    }

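    /// Looks up the rule that merges the token pair `(left, right)`, if any.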
    #[inline(always)]
    pub fn find_rule(&self, left: TokenId, right: TokenId) -> Option<RuleId> {
        self.pair_to_rule_id.get(&(left, right)).copied()
    }
}

impl Index<RuleId> for Dictionary {
    type Output = Rule;

    #[inline(always)]
    fn index(&self, index: RuleId) -> &Self::Output {
        self.rules.index(index)
    }
}

impl Index<TokenId> for Dictionary {
    type Output = Token;

    #[inline(always)]
    fn index(&self, index: TokenId) -> &Self::Output {
        self.vocab.index(index)
    }
}

#[cfg(test)]
mod tests {
    use crate::{Dictionary, Vocab};

    fn build_dict<T: AsRef<[u8]>, R: IntoIterator<Item = (T, T)>>(
        vocab: &Vocab,
        rules: R,
    ) -> Dictionary {
        Dictionary::new_from_token_pair(vocab.clone(), rules).unwrap()
    }

    #[test]
    fn test_dict() {
        let vocab = Vocab::new([
            b"a" as &[_],
            b"b",
            b"c",
            b"d",
            b"cd",
            b"bcd",
            b"abcd",
            "你".as_bytes(),
            "好".as_bytes(),
            "呀".as_bytes(),
            "你好".as_bytes(),
            "你好呀".as_bytes(),
            b"\xe4",
            b"\xbd",
            b"\xa0",
            b"\xbd\xa0",
        ])
        .unwrap();

        assert!(Dictionary::new_from_token_pair(vocab.clone(), [("c", "d")]).is_ok());
        assert!(Dictionary::new_from_token_pair(vocab.clone(), [("a", "b")]).is_err());
        assert!(Dictionary::new_from_id_pair(vocab.clone(), [(2usize, 3)]).is_ok());
        assert!(Dictionary::new_from_id_pair(vocab.clone(), [(0usize, 1)]).is_err());

        build_dict(&vocab, [("c", "d"), ("b", "cd"), ("a", "bcd")]);
        build_dict(&vocab, [("b", "cd")]);
        build_dict(
            &vocab,
            [(b"\xbd" as &[_], b"\xa0" as &[_]), (b"\xe4", b"\xbd\xa0")],
        );
        build_dict(&vocab, [("你", "好")]);
        build_dict(&vocab, [("你", "好"), ("你好", "呀")]);
        build_dict(&vocab, [("你好", "呀"), ("你", "好")]);
    }
}