instant_clip_tokenizer/lib.rs
//! This crate provides a text tokenizer for [OpenAI's CLIP
//! model](https://github.com/openai/CLIP).
//!
//! It is intended to be a fast replacement for the original Python-based
//! tokenizer included in the CLIP repository, aiming for 100% compatibility
//! with the original implementation. It can also be used with
//! [OpenCLIP](https://github.com/mlfoundations/open_clip) and other
//! implementations using the same tokenizer.
//!
//! # Examples
//!
//! Basic usage with the bundled vocabulary data suitable for OpenAI's CLIP
//! model (requires the `openai-vocabulary-file` [crate
//! feature](#crate-features)):
//!
//! ```
//! # use instant_clip_tokenizer::{Token, Tokenizer};
//! let tokenizer = Tokenizer::new();
//! let mut tokens = vec![tokenizer.start_of_text()];
//! tokenizer.encode("Hi there", &mut tokens);
//! tokens.push(tokenizer.end_of_text());
//! let tokens = tokens.into_iter().map(Token::to_u16).collect::<Vec<_>>();
//! assert_eq!(tokens, [49406, 1883, 997, 49407]);
//! ```
//!
//! Using a custom vocabulary file:
//!
//! ```
//! # use std::fs::File;
//! # use std::io::{self, BufReader};
//! # use instant_clip_tokenizer::{Token, Tokenizer};
//! # fn main() -> io::Result<()> {
//! let f = BufReader::new(File::open("bpe_simple_vocab_16e6.txt")?);
//! let tokenizer = Tokenizer::with_vocabulary(f, 50_000)?;
//! let mut tokens = vec![tokenizer.start_of_text()];
//! tokenizer.encode("Hi there", &mut tokens);
//! tokens.push(tokenizer.end_of_text());
//! let tokens = tokens.into_iter().map(Token::to_u16).collect::<Vec<_>>();
//! assert_eq!(tokens, [49998, 1883, 997, 49999]);
//! # Ok(())
//! # }
//! ```
//!
//! # Crate features
//!
//! This crate provides two features:
//!
//! * **ndarray** - Enables the [`ndarray`](https://docs.rs/ndarray) dependency
//!   and the `Tokenizer::tokenize_batch` method that can be used to tokenize
//!   several input strings at once, returning a matrix suitable for directly
//!   passing to the CLIP neural network.
//! * **openai-vocabulary-file** - This feature bundles the default vocabulary
//!   file used for OpenAI's CLIP model together with this crate and allows
//!   users to construct a new tokenizer simply by calling [`Tokenizer::new`].
//!   When disabled, you will need to supply your own vocabulary file and
//!   construct the tokenizer using [`Tokenizer::with_vocabulary`].
//!
//! The **openai-vocabulary-file** feature is enabled by default. To disable it
//! use `default-features = false` when specifying the dependency on this crate
//! in your `Cargo.toml`.

use std::io::{self, BufRead};

use ahash::AHashMap;
use regex::Regex;

/// A text tokenizer for the CLIP neural network.
///
/// See the [module-level documentation](index.html) for more.
pub struct Tokenizer {
    byte_to_token: Box<[Token; 256]>,
    merge_rules: AHashMap<(Token, Token), Token>,
    start_of_text: Token,
    end_of_text: Token,
    decoder: AHashMap<Token, Vec<u8>>,
    word_split: Regex,
}

impl Tokenizer {
    /// Create a new `Tokenizer` using the vocabulary data bundled with this
    /// crate.
    ///
    /// The resulting `Tokenizer` is suitable for use with the original CLIP
    /// model.
    ///
    /// Note that creating a new `Tokenizer` is expensive, so it is recommended
    /// to create the `Tokenizer` once and then reuse it.
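    ///
    /// One way to create the `Tokenizer` once and reuse it everywhere is to
    /// store it in a static. A minimal sketch, assuming Rust 1.70+ for
    /// `std::sync::OnceLock`:
    ///
    /// ```
    /// # use std::sync::OnceLock;
    /// # use instant_clip_tokenizer::Tokenizer;
    /// static TOKENIZER: OnceLock<Tokenizer> = OnceLock::new();
    ///
    /// let tokenizer = TOKENIZER.get_or_init(Tokenizer::new);
    /// let mut tokens = Vec::new();
    /// tokenizer.encode("Hi there", &mut tokens);
    /// assert_eq!(tokens.len(), 2);
    /// ```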
    #[cfg(any(test, feature = "openai-vocabulary-file"))]
    pub fn new() -> Tokenizer {
        static VOCABULARY_DATA: &str = include_str!("../bpe_simple_vocab_16e6.txt");
        const MAX_VOCABULARY_SIZE: u16 = 49408;
        Tokenizer::with_vocabulary(io::Cursor::new(VOCABULARY_DATA), MAX_VOCABULARY_SIZE)
            .expect("bundled vocabulary data is valid")
    }

    /// Create a new `Tokenizer` by reading the vocabulary data from `reader`.
    ///
    /// The data must be in the format used by the original CLIP tokenizer
    /// implementation from OpenAI: the first line is skipped, and every
    /// following line contains the two parts of one merge rule, separated by
    /// whitespace.
    ///
    /// Note that creating a new `Tokenizer` is expensive, so it is recommended
    /// to create the `Tokenizer` once and then reuse it.
    ///
    /// # Errors
    ///
    /// If the data format is incorrect or reading from `reader` fails, then an
    /// error is returned.
    pub fn with_vocabulary(
        reader: impl BufRead,
        max_vocabulary_size: u16,
    ) -> io::Result<Tokenizer> {
        let mut string_to_token = AHashMap::default();
        let mut byte_to_token = Box::new([Token(u16::MAX); 256]);
        let mut byte_decoder = AHashMap::default();
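        // Build the reversible byte-to-unicode mapping used by CLIP's BPE
        // (inherited from GPT-2): the printable bytes in the three ranges
        // below map to the char with the same code point, while every
        // remaining byte is mapped to a code point starting at U+0100 in the
        // second loop further down.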
        let r1 = b'!'..=b'~';
        let r2 = b'\xA1'..=b'\xAC'; // "¡" to "¬"
        let r3 = b'\xAE'..=b'\xFF'; // "®" to "ÿ"
        let mut token_index = 0;
        for byte in r1.chain(r2).chain(r3) {
            let token = Token(token_index);
            byte_to_token[usize::from(byte)] = token;
            let ch = char::from(byte);
            byte_decoder.insert(ch, byte);
            // Add token and also its corresponding end-of-word token
            string_to_token.insert(format!("{ch}"), token);
            string_to_token.insert(format!("{ch}</w>"), Token(token.0 + 256));
            token_index += 1;
        }
        for (idx, (byte, token)) in byte_to_token
            .iter_mut()
            .enumerate()
            .filter(|(_, token)| **token == Token(u16::MAX))
            .enumerate()
        {
            *token = Token(token_index);
            let ch = char::from_u32(idx as u32 + 256).unwrap();
            let byte = u8::try_from(byte).unwrap();
            byte_decoder.insert(ch, byte);
            string_to_token.insert(format!("{ch}"), *token);
            string_to_token.insert(format!("{ch}</w>"), Token(token.0 + 256));
            token_index += 1;
        }

        // For every increment of `token_index` above we actually also added the
        // corresponding end-of-word token, so we have to double `token_index`
        // now in order for it to be correct again.
        token_index *= 2;

        let mut merge_rules = AHashMap::default();
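        // Every line after the first (which is skipped) defines one merge
        // rule. We read at most `max_vocabulary_size - 512 - 2` of them: 512
        // token IDs are already taken by the byte tokens and their end-of-word
        // variants, and the final two IDs are reserved for the
        // `<start_of_text>` and `<end_of_text>` markers.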
        for line in reader
            .lines()
            .skip(1)
            .take((max_vocabulary_size - 512 - 2).into())
        {
            let line = line?;
            let mut parts = line.split_whitespace();
            let first = parts.next().ok_or(io::Error::new(
                io::ErrorKind::Other,
                "lines must contain 2 tokens",
            ))?;
            let second = parts.next().ok_or(io::Error::new(
                io::ErrorKind::Other,
                "lines must contain 2 tokens",
            ))?;
            let first_token = *string_to_token
                .get(first)
                .ok_or(io::Error::new(io::ErrorKind::Other, "invalid merge rule"))?;
            let second_token = *string_to_token
                .get(second)
                .ok_or(io::Error::new(io::ErrorKind::Other, "invalid merge rule"))?;

            let result_token = Token(token_index);
            merge_rules.insert((first_token, second_token), result_token);
            string_to_token.insert(format!("{first}{second}"), result_token);
            token_index += 1;
        }

        // Note that the values we store in `decoder` are not necessarily valid
        // UTF-8, so we have to use `Vec<u8>` for them.
        let decoder = string_to_token
            .into_iter()
            .map(|(string, token)| (token, string.chars().map(|ch| byte_decoder[&ch]).collect()))
            .collect();

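        // Example (hypothetical input): the regex below splits "i've 2 cats!!"
        // into the words ["i", "'ve", "2", "cats", "!!"]; whitespace never
        // appears in any match.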
        let word_split = Regex::new(
            r"(?x)
            # Special substrings - these each get encoded as a single marker token
            <start_of_text>|<end_of_text>|
            # Common English contractions
            's|'t|'re|'ve|'m|'ll|'d|
            # Consecutive letters, single numbers, or runs of special chars
            [\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+",
        )
        .unwrap();

        Ok(Tokenizer {
            byte_to_token,
            merge_rules,
            start_of_text: Token(token_index),
            end_of_text: Token(token_index + 1),
            decoder,
            word_split,
        })
    }

    /// Tokenize a batch of multiple input strings.
    ///
    /// Each given input string is encoded using the [`encode`] method and its
    /// numeric representation is written to a row of the resulting
    /// two-dimensional matrix of shape `(texts.len(), context_length)`, with
    /// the special `<start_of_text>` token prepended and `<end_of_text>`
    /// appended to each text.
    ///
    /// The individual input strings are lowercased before being tokenized, but
    /// otherwise no pre-processing is performed.
    ///
    /// `context_length` is the maximum number of tokens per text and should be
    /// `77` for all current CLIP models. If tokenization results in fewer than
    /// `context_length` tokens the resulting row will be padded with trailing
    /// zeros. If tokenizing an input text results in too many tokens, the
    /// token sequence will be truncated to fit within the resulting row of
    /// length `context_length`, always including the `<start_of_text>` and
    /// `<end_of_text>` marker tokens.
    ///
    /// The resulting matrix can be passed directly to the CLIP neural network.
    ///
    /// [`encode`]: Tokenizer::encode
    ///
    /// # Panics
    ///
    /// Panics if `context_length < 3`.
    ///
    /// # Examples
    ///
    /// ```
    /// # use ndarray::array;
    /// # use instant_clip_tokenizer::{Token, Tokenizer};
    /// let tokenizer = Tokenizer::new();
    /// let encoded = tokenizer.tokenize_batch(["Hi", "How are you?"], 5);
    /// assert_eq!(encoded, array![
    ///     [49406, 1883, 49407, 0, 0],
    ///     [49406, 829, 631, 592, 49407],
    /// ]);
    /// ```
    #[cfg(feature = "ndarray")]
    pub fn tokenize_batch<'a, I>(&self, texts: I, context_length: usize) -> ndarray::Array2<u16>
    where
        I: IntoIterator<Item = &'a str>,
        I::IntoIter: std::iter::ExactSizeIterator,
    {
        if context_length < 3 {
            panic!("context length must be at least 3");
        }
        let texts = texts.into_iter();
        let mut result = ndarray::Array2::zeros((texts.len(), context_length));
        let mut tokens = Vec::with_capacity(context_length);
        for (text, mut result_row) in texts.zip(result.rows_mut()) {
            tokens.clear();
            tokens.push(self.start_of_text());
            self.encode(text, &mut tokens);
            tokens.truncate(context_length - 1);
            tokens.push(self.end_of_text());
            for (token, result_element) in tokens.iter().zip(&mut result_row) {
                *result_element = token.to_u16();
            }
        }
        result
    }

    /// Encode a `text` input as a sequence of tokens.
    ///
    /// The resulting tokens are appended to `out`. `text` is lowercased before
    /// being tokenized, but otherwise no pre-processing is performed.
    ///
    /// The encoded token sequence does not include the special
    /// `<start_of_text>` and `<end_of_text>` marker tokens. When these are
    /// needed you can either use the `tokenize_batch` method instead, or add
    /// them manually by using the [`start_of_text`] and [`end_of_text`]
    /// methods, as in the example below.
    ///
    /// [`start_of_text`]: Tokenizer::start_of_text
    /// [`end_of_text`]: Tokenizer::end_of_text
    ///
    /// # Examples
    ///
    /// ```
    /// # use instant_clip_tokenizer::{Token, Tokenizer};
    /// let tokenizer = Tokenizer::new();
    /// let mut tokens = vec![tokenizer.start_of_text()];
    /// tokenizer.encode("Hi there", &mut tokens);
    /// tokens.push(tokenizer.end_of_text());
    /// let tokens = tokens.into_iter().map(Token::to_u16).collect::<Vec<_>>();
    /// assert_eq!(tokens, [49406, 1883, 997, 49407]);
    /// ```
    pub fn encode(&self, text: &str, out: &mut Vec<Token>) {
        let text = text.to_lowercase();
        out.reserve(text.as_bytes().len());
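        // Reserving the input's byte length is a safe upper bound: each byte
        // of a matched word initially becomes at most one token, and merging
        // only ever reduces the count.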
        let words = self.word_split.find_iter(&text).map(|m| m.as_str());
        for word in words {
            if word == "<start_of_text>" {
                out.push(self.start_of_text());
                continue;
            } else if word == "<end_of_text>" {
                out.push(self.end_of_text());
                continue;
            }

            let start_index = out.len();
            out.extend(
                word.as_bytes()
                    .iter()
                    .map(|b| self.byte_to_token[usize::from(*b)]),
            );
            if start_index < out.len() {
                // If we added anything, mark the last token as its end-of-word
                // variant
                out.last_mut().unwrap().0 += 256;
            }
            self.apply_merge_rules(start_index, out);
        }
    }

    fn apply_merge_rules(&self, start_index: usize, tokens: &mut Vec<Token>) {
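        // Standard BPE merging: repeatedly pick, among all adjacent token
        // pairs, the applicable merge rule whose result token is smallest
        // (i.e. the rule that appears earliest in the vocabulary file) and
        // apply it, until no rule matches any remaining pair.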
        loop {
            let Some(((first, second), result_token)) = tokens[start_index..]
                .windows(2)
                .map(|pair| (pair[0], pair[1]))
                .filter_map(|pair| {
                    self.merge_rules
                        .get(&pair)
                        .map(|result_token| (pair, *result_token))
                })
                .min_by_key(|&(_, result_token)| result_token)
            else {
                // No merge rules left to apply -> we're done
                break;
            };

            // Reduce all occurrences of this pair to `result_token`
            let mut i = start_index;
            while i < tokens.len() - 1 {
                if tokens[i] == first && tokens[i + 1] == second {
                    tokens[i] = result_token;
                    tokens.remove(i + 1);
                }
                i += 1;
            }
        }
    }

    /// Convert a sequence of `tokens` back to a textual representation.
    ///
    /// Due to the way whitespace and lowercasing are handled, a sequence of
    /// tokens will not always be decoded back to the exact same text that
    /// `encode` was called with; in other words, `decode(encode(text)) == text`
    /// does not always hold. Hence, this function is mostly useful for
    /// debugging purposes.
    ///
    /// # Examples
    ///
    /// ```
    /// # use instant_clip_tokenizer::Tokenizer;
    /// let tokenizer = Tokenizer::new();
    /// let mut tokens = Vec::new();
    /// tokenizer.encode("Hello world!!!", &mut tokens);
    /// let decoded = tokenizer.decode(tokens);
    /// assert_eq!(decoded, "hello world !!! ");
    /// ```
    pub fn decode(&self, tokens: impl IntoIterator<Item = Token>) -> String {
        let bytes = tokens
            .into_iter()
            .flat_map(|token| {
                if token == self.start_of_text {
                    "<start_of_text>".as_bytes()
                } else if token == self.end_of_text {
                    "<end_of_text>".as_bytes()
                } else {
                    &self.decoder[&token]
                }
            })
            .copied()
            .collect::<Vec<_>>();

        String::from_utf8_lossy(&bytes).replace("</w>", " ")
    }

    /// Returns the special `<start_of_text>` marker token.
    ///
    /// See [`encode`] for an example of how to add this token to a token
    /// sequence.
    ///
    /// [`encode`]: Tokenizer::encode
    pub fn start_of_text(&self) -> Token {
        self.start_of_text
    }

    /// Returns the special `<end_of_text>` marker token.
    ///
    /// See [`encode`] for an example of how to add this token to a token
    /// sequence.
    ///
    /// [`encode`]: Tokenizer::encode
    pub fn end_of_text(&self) -> Token {
        self.end_of_text
    }
}

#[cfg(any(test, feature = "openai-vocabulary-file"))]
impl Default for Tokenizer {
    fn default() -> Tokenizer {
        Tokenizer::new()
    }
}

/// Represents a single token.
///
/// Values of this type are produced by the methods on the [`Tokenizer`] type,
/// mainly [`Tokenizer::encode`], or by [`Token::from_u16`]. To input tokens
/// into an actual neural network the [`to_u16`] method should be used.
///
/// [`to_u16`]: Token::to_u16
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Token(u16);

impl Token {
    /// Create a `Token` from a number, validating it against the given
    /// `tokenizer`.
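    ///
    /// Returns `None` if `token` is not a valid token number for `tokenizer`.
    /// A small sketch (requires the `openai-vocabulary-file` feature):
    ///
    /// ```
    /// # use instant_clip_tokenizer::{Token, Tokenizer};
    /// let tokenizer = Tokenizer::new();
    /// let token = Token::from_u16(1883, &tokenizer).expect("in range");
    /// assert_eq!(token.to_u16(), 1883);
    /// assert!(Token::from_u16(u16::MAX, &tokenizer).is_none());
    /// ```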
    pub fn from_u16(token: u16, tokenizer: &Tokenizer) -> Option<Self> {
        (token <= tokenizer.end_of_text().0).then_some(Self(token))
    }

    /// Returns the numerical representation of this `Token`.
    ///
    /// The resulting number is suitable for feeding into a neural network.
    pub fn to_u16(self) -> u16 {
        self.0
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "ndarray")]
    #[test]
    fn tokenize_batch() {
        let tokenizer = Tokenizer::new();
        let encoded = tokenizer.tokenize_batch(["Hi", "How are you?", "I'm fine, thanks!"], 6);
        let expected = ndarray::array![
            [49406, 1883, 49407, 0, 0, 0],
            [49406, 829, 631, 592, 286, 49407],
            [49406, 328, 880, 3797, 267, 49407],
        ];
        assert_eq!(encoded, expected);
    }

    #[test]
    fn encode_special_chars() {
        let tokens = encode("hello world!!!");
        assert_eq!(tokens, [Token(3306), Token(1002), Token(995)]);
    }

    #[test]
    fn decode_special_chars() {
        let tokenizer = Tokenizer::new();
        let decoded = tokenizer.decode([Token(3306), Token(1002), Token(995)]);
        assert_eq!(decoded, "hello world !!! ");
    }

    #[test]
    fn encode_apostrophe() {
        let tokens = encode("i've seen it");
        assert_eq!(tokens, [Token(328), Token(1200), Token(2041), Token(585)]);
    }

    #[test]
    fn decode_apostrophe() {
        let tokenizer = Tokenizer::new();
        let decoded = tokenizer.decode([Token(328), Token(1200), Token(2041), Token(585)]);
        assert_eq!(decoded, "i 've seen it ");
    }

    #[test]
    fn encode_short() {
        let tokens = encode("Hello Båstad");
        assert_eq!(tokens, [Token(3306), Token(65), Token(23176), Token(16485)]);
    }

    #[test]
    fn decode_short() {
        let tokenizer = Tokenizer::new();
        let decoded = tokenizer.decode([Token(3306), Token(65), Token(23176), Token(16485)]);
        assert_eq!(decoded, "hello båstad ");
    }

    #[test]
    fn encode_realistic() {
        let tokens = encode("A person riding a motorcycle");
        assert_eq!(tokens, [320, 2533, 6765, 320, 10297].map(Token));
    }

    #[test]
    fn decode_realistic() {
        let tokenizer = Tokenizer::new();
        let decoded = tokenizer.decode([320, 2533, 6765, 320, 10297].map(Token));
        assert_eq!(decoded, "a person riding a motorcycle ");
    }

    #[test]
    fn encode_long_word() {
        let tokens = encode("donaudampfschifffahrtsgesellschaftskapitänsmütze");
        assert_eq!(
            tokens,
            [
                1067, 627, 1880, 16680, 13731, 1021, 778, 4810, 2290, 619, 10279, 45588, 83, 909,
                688, 529, 42787, 978, 6522, 83, 1298
            ]
            .map(Token)
        );
    }

    #[test]
    fn decode_long_word() {
        let tokenizer = Tokenizer::new();
        let decoded = tokenizer.decode(
            [
                1067, 627, 1880, 16680, 13731, 1021, 778, 4810, 2290, 619, 10279, 45588, 83, 909,
                688, 529, 42787, 978, 6522, 83, 1298,
            ]
            .map(Token),
        );
        assert_eq!(decoded, "donaudampfschifffahrtsgesellschaftskapitänsmütze ");
    }

    #[test]
    fn encode_start_and_end_of_text() {
        let tokens = encode("<start_of_text>Hi<start_of_text>instant labs<end_of_text>");
        assert_eq!(tokens, [49406, 1883, 49406, 10635, 12021, 49407].map(Token));
    }

    #[test]
    fn encode_start_and_end_of_text_with_special_char() {
        let tokens = encode("<start_of_text>Hi!<end_of_text>");
        // Note how the "<end_of_text>" substring is not encoded as the special
        // marker token (which would be 49407), because the word-splitting regex
        // does not split it off as a separate word due to the exclamation mark
        // preceding it. This behavior is somewhat strange, but we preserve it
        // in order to stay compatible with the original Python implementation.
        assert_eq!(
            tokens,
            [49406, 1883, 0, 283, 806, 318, 539, 318, 4160, 285].map(Token)
        );
    }

    #[test]
    fn decode_start_and_end_of_text() {
        let tokenizer = Tokenizer::new();
        let decoded = tokenizer.decode([49406, 1883, 49406, 10635, 12021, 49407].map(Token));
        assert_eq!(
            decoded,
            "<start_of_text>hi <start_of_text>instant labs <end_of_text>"
        );
    }

    fn encode(input: &str) -> Vec<Token> {
        let tokenizer = Tokenizer::new();
        let mut tokens = Vec::with_capacity(input.len());
        tokenizer.encode(input, &mut tokens);
        tokens
    }
}