Skip to main content

mtc_inc_bpe/
vocab.rs

1use std::{hash::Hash, iter::FusedIterator, ops::Index};
2
3use bytes::Bytes;
4use rapidhash::RapidHashMap;
5use thiserror::Error;
6use tinyvec::TinyVec;
7
8use crate::typed_vec::{TypedVec, typed_vec_index};
9
// Declares `TokenId`, a `u32`-backed typed index into the token table.
// Generated by the project's `typed_vec_index!` macro; usage below shows it
// provides at least `TokenId::new`, `TokenId::from(usize)`, `TokenId::MAX`,
// and a public `.0` field — TODO confirm full API in `crate::typed_vec`.
typed_vec_index!(pub TokenId, u32);

// Token-id sequence with inline storage for up to 6 ids before spilling to
// the heap (6 * 4 bytes payload fits the 32-byte `TinyVec` layout).
pub(crate) type TokenIdVec = TinyVec<[TokenId; 6]>;
const _: () = {
    // Compile-time guard: if `TokenId` or the inline capacity changes, this
    // forces a deliberate revisit of the size/capacity trade-off.
    assert!(std::mem::size_of::<TokenIdVec>() == 32);
};

// Tokens are cheaply-cloneable shared byte buffers.
pub type Token = Bytes;

// Upper bound on a single token's byte length (16 KiB - 1).
pub const MAX_TOKEN_LENGTH: usize = (1 << 14) - 1;

const _: () = {
    // Presumably so a token length always fits in a `u16` elsewhere in the
    // crate — TODO confirm the field that relies on this.
    assert!(MAX_TOKEN_LENGTH < u16::MAX as usize);
};
24
/// An immutable token vocabulary with constant-time lookups in both
/// directions (id -> bytes and bytes -> id), plus fast paths for
/// single-byte and single-character tokens.
#[derive(Clone, Debug)]
pub struct Vocab {
    /// Token bytes in id order; a token's position is its [`TokenId`].
    pub(crate) tokens: TypedVec<TokenId, Token>,
    /// Reverse map for *non-empty* tokens only (empty tokens are stored in
    /// `tokens` but never indexed here — see `Vocab::new`).
    token_to_id: RapidHashMap<Token, TokenId>,
    /// Direct table for 1-byte tokens, indexed by the byte value;
    /// `TokenId::MAX` is the "no such token" sentinel.
    u8_to_id: Box<[TokenId; 1 << 8]>,
    /// Map for tokens whose bytes are exactly one UTF-8 scalar value.
    char_to_id: RapidHashMap<char, TokenId>,
}
32
/// Errors reported by [`Vocab::new`]. Marked `#[non_exhaustive]` so new
/// build-time validations can be added without a breaking change.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum VocabBuildError {
    /// The same non-empty token appeared twice: first at id `a`, again at
    /// id `b` (empty tokens are exempt and may repeat).
    #[error("duplicated tokens with ids {a} and {b}")]
    Duplicated { a: TokenId, b: TokenId },
    /// Token length is limited by [`MAX_TOKEN_LENGTH`].
    #[error("token {token_id} exceeds length limit {MAX_TOKEN_LENGTH}")]
    TokenTooLong { token_id: TokenId },
}
42
/// Interprets `token` as UTF-8 and returns its single scalar value, if the
/// bytes encode exactly one `char`.
///
/// Returns `None` for empty input, for anything longer than the 4-byte
/// maximum of one UTF-8 scalar, for invalid UTF-8, and for multi-character
/// strings.
#[inline(always)]
fn utf8_char_token(token: &[u8]) -> Option<char> {
    // One scalar value occupies 1..=4 bytes; reject everything else up
    // front so the validation below is O(1).
    if token.is_empty() || token.len() > 4 {
        return None;
    }
    let text = std::str::from_utf8(token).ok()?;
    let mut scalars = text.chars();
    // Non-empty valid UTF-8 always yields at least one char.
    let first = scalars.next()?;
    // Accept only if the iterator is now exhausted (exactly one char).
    scalars.next().is_none().then_some(first)
}
60
impl Vocab {
    /// Builds a vocabulary from tokens in iteration order; a token's
    /// position in `iter` becomes its [`TokenId`].
    ///
    /// Empty tokens are permitted (and may repeat) but are never indexed
    /// for reverse lookup. Non-empty duplicates fail with
    /// [`VocabBuildError::Duplicated`]; tokens longer than
    /// [`MAX_TOKEN_LENGTH`] fail with [`VocabBuildError::TokenTooLong`].
    pub fn new<T: Into<Token>, I: IntoIterator<Item = T>>(
        iter: I,
    ) -> Result<Self, VocabBuildError> {
        let mut token_to_id = RapidHashMap::default();
        // `TokenId::MAX` is the "no token for this byte" sentinel used by
        // `find_by_byte` below.
        let mut u8_to_id = Box::new([TokenId::MAX; _]);
        let mut char_to_id = RapidHashMap::default();

        // (position, raw token) -> owned token, filling the three lookup
        // tables as a side effect. The side tables are written *before*
        // the error checks; that is harmless because any `Err` aborts the
        // whole build and the partially-filled tables are dropped.
        let convert_token = |(k, token): (usize, T)| {
            let token = token.into();
            let token_id = TokenId::from(k);
            if token.len() == 1 {
                u8_to_id[token.as_ref()[0] as usize] = token_id;
            }
            if let Some(c) = utf8_char_token(&token) {
                char_to_id.insert(c, token_id);
            }
            if token.len() > MAX_TOKEN_LENGTH {
                Err(VocabBuildError::TokenTooLong { token_id })
            } else if !token.is_empty()
                && let Some(other) = token_to_id.insert(token.clone(), token_id)
            {
                // `insert` returned the previous occupant: this token was
                // already registered under id `other`.
                Err(VocabBuildError::Duplicated {
                    a: other,
                    b: token_id,
                })
            } else {
                Ok(token)
            }
        };

        // Collecting into `Result` short-circuits on the first error.
        let tokens: TypedVec<_, _> = iter
            .into_iter()
            .enumerate()
            .map(convert_token)
            .collect::<Result<_, _>>()?;
        // Strict inequality is possible: empty tokens live in `tokens` but
        // are skipped by `token_to_id`.
        debug_assert!(tokens.as_slice().len() >= token_to_id.len());

        Ok(Self {
            tokens,
            token_to_id,
            u8_to_id,
            char_to_id,
        })
    }

    /// Looks up the id of an exact token byte sequence; `None` for unknown
    /// tokens and always `None` for the empty sequence.
    #[inline(always)]
    pub fn find_token_id<T: AsRef<[u8]>>(&self, token: T) -> Option<TokenId> {
        self.token_to_id.get(token.as_ref()).copied()
    }

    /// Returns the token bytes for `token_id`, or `None` if out of range.
    #[inline(always)]
    pub fn get_token<T: Into<TokenId>>(&self, token_id: T) -> Option<&Token> {
        self.tokens.get(token_id.into())
    }

    /// Number of tokens, expressed as the one-past-the-end [`TokenId`].
    #[inline(always)]
    pub fn num_of_tokens(&self) -> TokenId {
        self.tokens.len()
    }

    /// All tokens in id order (includes any empty tokens).
    #[inline(always)]
    pub fn tokens(&self) -> &[Token] {
        self.tokens.as_slice()
    }

    /// Read-only access to the reverse map (non-empty tokens only).
    #[inline(always)]
    pub fn token_to_id_map(&self) -> &RapidHashMap<Token, TokenId> {
        &self.token_to_id
    }

    /// Id of the single-byte token `b`. "Unchecked" means no absence
    /// check, not UB: returns the `TokenId::MAX` sentinel when the
    /// vocabulary has no token for this byte.
    #[inline(always)]
    pub fn find_by_byte_unchecked(&self, b: u8) -> TokenId {
        self.u8_to_id[b as usize]
    }

    /// Id of the single-byte token `b`, or `None` if absent.
    #[inline(always)]
    pub fn find_by_byte(&self, b: u8) -> Option<TokenId> {
        Some(self.find_by_byte_unchecked(b)).filter(|&i| i != TokenId::MAX)
    }

    /// Id of the token whose bytes are exactly the UTF-8 encoding of `c`.
    #[inline(always)]
    pub fn find_by_char(&self, c: char) -> Option<TokenId> {
        self.char_to_id.get(&c).copied()
    }

    /// Maps every byte of `seq` to its single-byte token id; bytes with no
    /// token come back as the `TokenId::MAX` sentinel.
    #[inline(always)]
    pub fn split_bytes_to_tokens_unchecked(
        &self,
        seq: &[u8],
    ) -> impl DoubleEndedIterator<Item = TokenId> + ExactSizeIterator + FusedIterator {
        seq.iter().map(|&b| self.find_by_byte_unchecked(b))
    }

    /// Maps every byte of `seq` to its single-byte token id, `None` for
    /// bytes with no token.
    #[inline(always)]
    pub fn split_bytes_to_tokens(
        &self,
        seq: &[u8],
    ) -> impl DoubleEndedIterator<Item = Option<TokenId>> + ExactSizeIterator + FusedIterator {
        seq.iter().map(|&b| self.find_by_byte(b))
    }

    /// Maps every `char` of `seq` to its single-character token id, `None`
    /// for characters with no token.
    #[inline(always)]
    pub fn split_utf8_to_tokens(
        &self,
        seq: &str,
    ) -> impl DoubleEndedIterator<Item = Option<TokenId>> + FusedIterator {
        seq.chars().map(|c| self.find_by_char(c))
    }
}
171
172impl Index<TokenId> for Vocab {
173    type Output = Token;
174
175    #[inline(always)]
176    fn index(&self, index: TokenId) -> &Self::Output {
177        self.tokens.index(index)
178    }
179}
180
#[cfg(test)]
mod tests {
    use crate::{
        TokenId, Vocab,
        test_utils::{bytes_into_tokens, utf8_into_tokens},
    };

    #[test]
    fn test_vocab() {
        // Overlapping and empty tokens are both accepted at build time.
        assert!(Vocab::new([b"abc" as &[_], b"abcd"]).is_ok());
        assert!(Vocab::new([b"" as &[_], b"abc", b""]).is_ok());

        let vocab = Vocab::new([b"a" as &[_], b"b", b"c", b"d", b"cd", b"bcd", b"abcd"]).unwrap();

        assert_eq!(vocab.num_of_tokens().0, 7);

        // Reverse lookup: hits for stored tokens, misses for everything
        // else (the empty token is never indexed).
        let cases: [(&[u8], Option<u32>); 7] = [
            (b"a", Some(0)),
            (b"b", Some(1)),
            (b"cd", Some(4)),
            (b"abcd", Some(6)),
            (b"", None),
            (b"e", None),
            (b"random", None),
        ];
        for (needle, id) in cases {
            assert_eq!(vocab.find_token_id(needle), id.map(TokenId::new));
        }

        // Forward lookup round-trips; out-of-range ids yield `None`.
        for (id, text) in [(0u32, "a"), (3, "d"), (6, "abcd")] {
            assert_eq!(
                vocab.get_token(id).map(|b| b.as_ref()),
                Some(text.as_bytes())
            );
        }
        assert!(vocab.get_token(7u32).is_none());
    }

    #[test]
    fn test_pre_tokenize() {
        let vocab = Vocab::new([
            b"a" as &[_],
            b"b",
            b"c",
            b"d",
            b"cd",
            b"bcd",
            b"abcd",
            "你".as_bytes(),
            "好".as_bytes(),
            "呀".as_bytes(),
            "你好".as_bytes(),
            "你好呀".as_bytes(),
            b"\xe4",
            b"\xbd",
            b"\xa0",
            b"\xbd\xa0",
        ])
        .unwrap();

        // Byte-level split of "你好" (e4 bd a0 e5 a5 bd): the first three
        // bytes are tokens 12..=14, two bytes of "好" have no token, and
        // the final 0xbd maps back to token 13.
        let byte_ids = bytes_into_tokens(&vocab, "你好", u32::MAX);
        assert_eq!(byte_ids, [12, 13, 14, u32::MAX, u32::MAX, 13].map(TokenId::new));

        // Char-level split finds the single-character tokens directly.
        let char_ids = utf8_into_tokens(&vocab, "你好", u32::MAX);
        assert_eq!(char_ids, [7, 8].map(TokenId::new));
    }
}
244}