1use std::{hash::Hash, iter::FusedIterator, ops::Index};
2
3use bytes::Bytes;
4use rapidhash::RapidHashMap;
5use thiserror::Error;
6use tinyvec::TinyVec;
7
8use crate::typed_vec::{TypedVec, typed_vec_index};
9
// Strongly-typed `u32` index newtype used to address tokens in a `TypedVec`.
typed_vec_index!(pub TokenId, u32);

// Small-vector of token ids: up to 6 ids are stored inline before spilling
// to the heap.
pub(crate) type TokenIdVec = TinyVec<[TokenId; 6]>;
// Compile-time guard: keep the inline small-vector at exactly 32 bytes so it
// stays cheap to move and cache-friendly.
const _: () = {
    assert!(std::mem::size_of::<TokenIdVec>() == 32);
};

// A token is a cheaply clonable, reference-counted byte string.
pub type Token = Bytes;

// Maximum allowed token length in bytes (largest value that fits in 14 bits).
pub const MAX_TOKEN_LENGTH: usize = (1 << 14) - 1;

// Compile-time guard: token lengths must remain representable in a `u16`.
const _: () = {
    assert!(MAX_TOKEN_LENGTH < u16::MAX as usize);
};
24
/// An immutable token vocabulary supporting lookups in both directions
/// (token bytes -> id and id -> token), with dedicated fast paths for
/// single-byte and single-`char` tokens.
#[derive(Clone, Debug)]
pub struct Vocab {
    // All tokens in id order; a token's position is its `TokenId`.
    pub(crate) tokens: TypedVec<TokenId, Token>,
    // Reverse map from token bytes to id; empty tokens are never inserted.
    token_to_id: RapidHashMap<Token, TokenId>,
    // Dense byte -> id table; `TokenId::MAX` is the "no such token" sentinel.
    u8_to_id: Box<[TokenId; 1 << 8]>,
    // Maps a Unicode scalar value to the id of the token whose bytes are
    // exactly that char's UTF-8 encoding.
    char_to_id: RapidHashMap<char, TokenId>,
}
32
/// Errors that can occur while building a [`Vocab`].
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum VocabBuildError {
    /// Two non-empty tokens had identical byte content; `a` is the earlier
    /// id, `b` the later one.
    #[error("duplicated tokens with ids {a} and {b}")]
    Duplicated { a: TokenId, b: TokenId },
    /// The token with id `token_id` exceeded [`MAX_TOKEN_LENGTH`] bytes.
    #[error("token {token_id} exceeds length limit {MAX_TOKEN_LENGTH}")]
    TokenTooLong { token_id: TokenId },
}
42
/// Returns `Some(c)` when `token` is exactly the UTF-8 encoding of the
/// single scalar value `c`, and `None` otherwise (empty input, invalid
/// UTF-8, or more than one char).
#[inline(always)]
fn utf8_char_token(token: &[u8]) -> Option<char> {
    // A single scalar value encodes to 1..=4 bytes; anything else can
    // be rejected without decoding.
    if !(1..=4).contains(&token.len()) {
        return None;
    }
    let decoded = std::str::from_utf8(token).ok()?;
    let mut chars = decoded.chars();
    // Accept only when there is exactly one char in the string.
    match (chars.next(), chars.next()) {
        (Some(c), None) => Some(c),
        _ => None,
    }
}
60
impl Vocab {
    /// Builds a vocabulary from an iterator of tokens; a token's position
    /// in the iterator becomes its `TokenId`.
    ///
    /// # Errors
    /// - [`VocabBuildError::Duplicated`] if two non-empty tokens have the
    ///   same bytes (empty tokens may repeat freely).
    /// - [`VocabBuildError::TokenTooLong`] if a token exceeds
    ///   [`MAX_TOKEN_LENGTH`] bytes.
    pub fn new<T: Into<Token>, I: IntoIterator<Item = T>>(
        iter: I,
    ) -> Result<Self, VocabBuildError> {
        let mut token_to_id = RapidHashMap::default();
        // Dense byte table, initialised to the `TokenId::MAX` "absent" sentinel.
        let mut u8_to_id = Box::new([TokenId::MAX; _]);
        let mut char_to_id = RapidHashMap::default();

        // FnMut closure run per (index, token) pair: it fills the side lookup
        // tables and validates the token. On an `Err` the whole build aborts,
        // so partially-populated tables are simply discarded with `Self`.
        let convert_token = |(k, token): (usize, T)| {
            let token = token.into();
            let token_id = TokenId::from(k);
            if token.len() == 1 {
                u8_to_id[token.as_ref()[0] as usize] = token_id;
            }
            // Register tokens whose bytes are exactly one UTF-8 scalar value.
            if let Some(c) = utf8_char_token(&token) {
                char_to_id.insert(c, token_id);
            }
            if token.len() > MAX_TOKEN_LENGTH {
                Err(VocabBuildError::TokenTooLong { token_id })
            } else if !token.is_empty()
                && let Some(other) = token_to_id.insert(token.clone(), token_id)
            {
                // A previous non-empty token with the same bytes already
                // claimed an id; empty tokens are exempt from this check.
                Err(VocabBuildError::Duplicated {
                    a: other,
                    b: token_id,
                })
            } else {
                Ok(token)
            }
        };

        // Collecting into `Result` short-circuits on the first error.
        let tokens: TypedVec<_, _> = iter
            .into_iter()
            .enumerate()
            .map(convert_token)
            .collect::<Result<_, _>>()?;
        // Empty tokens are skipped by `token_to_id`, so the map may be smaller
        // than the token list.
        debug_assert!(tokens.as_slice().len() >= token_to_id.len());

        Ok(Self {
            tokens,
            token_to_id,
            u8_to_id,
            char_to_id,
        })
    }

    /// Looks up the id of a token by its bytes. Empty tokens always yield
    /// `None` (they are never inserted into the reverse map).
    #[inline(always)]
    pub fn find_token_id<T: AsRef<[u8]>>(&self, token: T) -> Option<TokenId> {
        self.token_to_id.get(token.as_ref()).copied()
    }

    /// Returns the token with the given id, or `None` if out of range.
    #[inline(always)]
    pub fn get_token<T: Into<TokenId>>(&self, token_id: T) -> Option<&Token> {
        self.tokens.get(token_id.into())
    }

    /// Number of tokens, expressed as the one-past-the-end `TokenId`.
    #[inline(always)]
    pub fn num_of_tokens(&self) -> TokenId {
        self.tokens.len()
    }

    /// All tokens in id order.
    #[inline(always)]
    pub fn tokens(&self) -> &[Token] {
        self.tokens.as_slice()
    }

    /// Read-only access to the token -> id map (empty tokens are absent).
    #[inline(always)]
    pub fn token_to_id_map(&self) -> &RapidHashMap<Token, TokenId> {
        &self.token_to_id
    }

    /// Byte lookup without a presence check: returns `TokenId::MAX` when no
    /// single-byte token exists for `b`.
    #[inline(always)]
    pub fn find_by_byte_unchecked(&self, b: u8) -> TokenId {
        self.u8_to_id[b as usize]
    }

    /// Byte lookup; `None` when no single-byte token exists for `b`.
    #[inline(always)]
    pub fn find_by_byte(&self, b: u8) -> Option<TokenId> {
        Some(self.find_by_byte_unchecked(b)).filter(|&i| i != TokenId::MAX)
    }

    /// Looks up the token whose bytes are exactly the UTF-8 encoding of `c`.
    #[inline(always)]
    pub fn find_by_char(&self, c: char) -> Option<TokenId> {
        self.char_to_id.get(&c).copied()
    }

    /// Maps every byte of `seq` to a token id, yielding the `TokenId::MAX`
    /// sentinel for bytes with no single-byte token.
    #[inline(always)]
    pub fn split_bytes_to_tokens_unchecked(
        &self,
        seq: &[u8],
    ) -> impl DoubleEndedIterator<Item = TokenId> + ExactSizeIterator + FusedIterator {
        seq.iter().map(|&b| self.find_by_byte_unchecked(b))
    }

    /// Maps every byte of `seq` to `Some(id)`, or `None` for bytes with no
    /// single-byte token.
    #[inline(always)]
    pub fn split_bytes_to_tokens(
        &self,
        seq: &[u8],
    ) -> impl DoubleEndedIterator<Item = Option<TokenId>> + ExactSizeIterator + FusedIterator {
        seq.iter().map(|&b| self.find_by_byte(b))
    }

    /// Maps every char of `seq` to its single-char token id, if any.
    #[inline(always)]
    pub fn split_utf8_to_tokens(
        &self,
        seq: &str,
    ) -> impl DoubleEndedIterator<Item = Option<TokenId>> + FusedIterator {
        seq.chars().map(|c| self.find_by_char(c))
    }
}
171
172impl Index<TokenId> for Vocab {
173 type Output = Token;
174
175 #[inline(always)]
176 fn index(&self, index: TokenId) -> &Self::Output {
177 self.tokens.index(index)
178 }
179}
180
#[cfg(test)]
mod tests {
    use crate::{
        TokenId, Vocab,
        test_utils::{bytes_into_tokens, utf8_into_tokens},
    };

    #[test]
    fn test_vocab() {
        // Distinct tokens build fine; duplicate *empty* tokens are tolerated.
        assert!(Vocab::new([b"abc" as &[_], b"abcd"]).is_ok());
        assert!(Vocab::new([b"" as &[_], b"abc", b""]).is_ok());

        let vocab = Vocab::new([b"a" as &[_], b"b", b"c", b"d", b"cd", b"bcd", b"abcd"]).unwrap();

        assert_eq!(vocab.num_of_tokens().0, 7);

        // Forward lookup: ids follow insertion order.
        for (bytes, id) in [(b"a" as &[u8], 0u32), (b"b", 1), (b"cd", 4), (b"abcd", 6)] {
            assert_eq!(vocab.find_token_id(bytes), Some(TokenId::new(id)));
        }
        // Empty and unknown tokens have no id.
        for bytes in [b"" as &[u8], b"e", b"random"] {
            assert_eq!(vocab.find_token_id(bytes), None);
        }

        // Reverse lookup by id, plus the out-of-range case.
        for (id, expected) in [(0u32, "a"), (3, "d"), (6, "abcd")] {
            let token = vocab.get_token(id).map(|b| b.as_ref());
            assert_eq!(token, Some(expected.as_bytes()));
        }
        assert!(vocab.get_token(7u32).is_none());
    }

    #[test]
    fn test_pre_tokenize() {
        let vocab = Vocab::new([
            b"a" as &[_],
            b"b",
            b"c",
            b"d",
            b"cd",
            b"bcd",
            b"abcd",
            "你".as_bytes(),
            "好".as_bytes(),
            "呀".as_bytes(),
            "你好".as_bytes(),
            "你好呀".as_bytes(),
            b"\xe4",
            b"\xbd",
            b"\xa0",
            b"\xbd\xa0",
        ])
        .unwrap();

        // Byte-level split: ids 12..=14 cover "你"'s bytes; two of "好"'s
        // bytes have no single-byte token and map to the u32::MAX filler.
        let byte_ids = bytes_into_tokens(&vocab, "你好", u32::MAX);
        assert_eq!(
            byte_ids,
            [12, 13, 14, u32::MAX, u32::MAX, 13].map(TokenId::new)
        );

        // Char-level split resolves each char to its single-char token.
        let char_ids = utf8_into_tokens(&vocab, "你好", u32::MAX);
        assert_eq!(char_ids, [7, 8].map(TokenId::new));
    }
}