use std::{hash::Hash, ops::Index};

use bytes::BytesMut;
use derive_more::{Deref, From, Into};
use rapidhash::{HashMapExt, RapidHashMap};
use thiserror::Error;
use tinyvec::TinyVec;

use crate::{
    Token, TokenId, Vocab,
    typed_vec::{TypedVec, typed_vec_index},
};

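/// Strongly typed index identifying a merge rule within a [`Dictionary`].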
typed_vec_index!(pub RuleId, u32);

pub(crate) type RuleIdVec = TinyVec<[RuleId; 6]>;
const _: () = {
    assert!(std::mem::size_of::<RuleIdVec>() == 32);
};

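/// A single merge rule: concatenating the tokens `pre` and `suc` produces the
/// token `merged`, which must already exist in the vocabulary.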
#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Into, From)]
pub struct Rule {
    pub merged: TokenId,
    pub pre: TokenId,
    pub suc: TokenId,
}

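/// A [`Vocab`] extended with an ordered list of merge rules and a lookup table
/// from token pairs to rule ids. Derefs to the underlying [`Vocab`].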
#[derive(Clone, Debug, Deref)]
pub struct Dictionary {
    #[deref]
    vocab: Vocab,
    pub(crate) rules: TypedVec<RuleId, Rule>,
    pair_to_rule_id: RapidHashMap<(TokenId, TokenId), RuleId>,
}

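/// Errors reported while validating merge rules during dictionary construction.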
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DictBuildError {
    #[error("rule {rule_id} uses an unknown token")]
    UnknownToken { rule_id: RuleId, token: Token },
    #[error("rule {rule_id} uses token id {token_id} which exceeds vocab size")]
    InvalidTokenId { rule_id: RuleId, token_id: TokenId },
    #[error("rule {rule_id} uses an empty or special token with id {token_id}")]
    EmptyToken { rule_id: RuleId, token_id: TokenId },
}

impl Dictionary {
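    /// Builds the `(pre, suc)` -> `RuleId` lookup map from an already-validated rule list.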
    fn from_rules(vocab: Vocab, rules: TypedVec<RuleId, Rule>) -> Self {
        let mut pair_to_rule_id = RapidHashMap::with_capacity(rules.len().as_usize());
        for (id, rule) in rules.enumerate() {
            pair_to_rule_id.insert((rule.pre, rule.suc), id);
        }
        Self {
            vocab,
            rules,
            pair_to_rule_id,
        }
    }

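    /// Builds a dictionary from `(left, right)` pairs of token ids.
    ///
    /// Each id must refer to a non-empty token in `vocab`, and the concatenation
    /// of the two tokens must itself be present in `vocab`.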
    pub fn new_from_id_pair<T: Into<TokenId>, R: IntoIterator<Item = (T, T)>>(
        vocab: Vocab,
        rule_iter: R,
    ) -> Result<Self, DictBuildError> {
        let rule_iter = rule_iter.into_iter();
        let mut rules = TypedVec::with_capacity(RuleId::from(rule_iter.size_hint().0));
        let get_token = |rule_id, token_id| {
            vocab
                .get_token(token_id)
                .ok_or(DictBuildError::InvalidTokenId { rule_id, token_id })
                .and_then(|t| {
                    if t.is_empty() {
                        Err(DictBuildError::EmptyToken { rule_id, token_id })
                    } else {
                        Ok(t)
                    }
                })
        };
        for (pos, (left, right)) in rule_iter.map(|(i, j)| (i.into(), j.into())).enumerate() {
            let rule_id = RuleId::from(pos);
            let token = {
                let mut buf = BytesMut::from(get_token(rule_id, left)?.clone());
                buf.extend_from_slice(get_token(rule_id, right)?);
                buf.freeze()
            };
            let merged = vocab
                .find_token_id(&token)
                .ok_or(DictBuildError::UnknownToken { rule_id, token })?;
            rules.push(Rule {
                merged,
                pre: left,
                suc: right,
            });
        }
        Ok(Self::from_rules(vocab, rules))
    }

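    /// Builds a dictionary from `(left, right)` pairs given as raw token bytes.
    ///
    /// Both tokens and their concatenation must be present in `vocab`.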
    pub fn new_from_token_pair<T: AsRef<[u8]>, R: IntoIterator<Item = (T, T)>>(
        vocab: Vocab,
        rule_iter: R,
    ) -> Result<Self, DictBuildError> {
        let rule_iter = rule_iter.into_iter();
        let mut rules = TypedVec::with_capacity(RuleId::from(rule_iter.size_hint().0));
        let get_id = |pos, token: &[u8]| {
            vocab
                .find_token_id(token)
                .ok_or(DictBuildError::UnknownToken {
                    rule_id: pos,
                    token: token.to_owned().into(),
                })
        };
        for (pos, (left, right)) in rule_iter.enumerate() {
            let (left, right) = (left.as_ref(), right.as_ref());
            let pos = RuleId::from(pos);
            let left_id = get_id(pos, left)?;
            let right_id = get_id(pos, right)?;
            let token = {
                let mut buf = BytesMut::from(left);
                buf.extend_from_slice(right);
                buf.freeze()
            };
            let merged = get_id(pos, &token)?;
            rules.push(Rule {
                merged,
                pre: left_id,
                suc: right_id,
            });
        }
        Ok(Self::from_rules(vocab, rules))
    }

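    /// Returns all rules as a slice, in the order they were supplied.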
    #[inline(always)]
    pub fn rules(&self) -> &[Rule] {
        self.rules.as_slice()
    }

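    /// Returns the rule with the given id, or `None` if the id is out of range.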
    #[inline(always)]
    pub fn get_rule(&self, rule_id: RuleId) -> Option<&Rule> {
        self.rules.get(rule_id)
    }

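    /// Returns the number of rules, expressed as a `RuleId`.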
    #[inline(always)]
    pub fn num_of_rules(&self) -> RuleId {
        self.rules.len()
    }

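    /// Looks up the rule that merges the pair `(left, right)`, if any.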
    #[inline(always)]
    pub fn find_rule(&self, left: TokenId, right: TokenId) -> Option<RuleId> {
        self.pair_to_rule_id.get(&(left, right)).copied()
    }
}

impl Index<RuleId> for Dictionary {
    type Output = Rule;

    #[inline(always)]
    fn index(&self, index: RuleId) -> &Self::Output {
        self.rules.index(index)
    }
}

impl Index<TokenId> for Dictionary {
    type Output = Token;

    #[inline(always)]
    fn index(&self, index: TokenId) -> &Self::Output {
        self.vocab.index(index)
    }
}

#[cfg(test)]
mod tests {
    use crate::{Dictionary, Vocab};

    fn build_dict<T: AsRef<[u8]>, R: IntoIterator<Item = (T, T)>>(
        vocab: &Vocab,
        rules: R,
    ) -> Dictionary {
        Dictionary::new_from_token_pair(vocab.clone(), rules).unwrap()
    }

    #[test]
    fn test_dict() {
        let vocab = Vocab::new([
            b"a" as &[_],
            b"b",
            b"c",
            b"d",
            b"cd",
            b"bcd",
            b"abcd",
            "你".as_bytes(),
            "好".as_bytes(),
            "呀".as_bytes(),
            "你好".as_bytes(),
            "你好呀".as_bytes(),
            b"\xe4",
            b"\xbd",
            b"\xa0",
            b"\xbd\xa0",
        ])
        .unwrap();

        assert!(Dictionary::new_from_token_pair(vocab.clone(), [("c", "d")]).is_ok());
        assert!(Dictionary::new_from_token_pair(vocab.clone(), [("a", "b")]).is_err());
        assert!(Dictionary::new_from_id_pair(vocab.clone(), [(2usize, 3)]).is_ok());
        assert!(Dictionary::new_from_id_pair(vocab.clone(), [(0usize, 1)]).is_err());

        build_dict(&vocab, [("c", "d"), ("b", "cd"), ("a", "bcd")]);
        build_dict(&vocab, [("b", "cd")]);
        build_dict(
            &vocab,
            [(b"\xbd" as &[_], b"\xa0" as &[_]), (b"\xe4", b"\xbd\xa0")],
        );
        build_dict(&vocab, [("你", "好")]);
        build_dict(&vocab, [("你", "好"), ("你好", "呀")]);
        build_dict(&vocab, [("你好", "呀"), ("你", "好")]);
    }
}