instant_clip_tokenizer/
lib.rs

//! This crate provides a text tokenizer for [OpenAI's CLIP
//! model](https://github.com/openai/CLIP).
//!
//! It is intended to be a fast replacement for the original Python-based
//! tokenizer included in the CLIP repository, aiming for 100% compatibility
//! with the original implementation. It can also be used with
//! [OpenCLIP](https://github.com/mlfoundations/open_clip) and other
//! implementations using the same tokenizer.
//!
//! # Examples
//!
//! Basic usage with the bundled vocabulary data suitable for OpenAI's CLIP
//! model (requires the `openai-vocabulary-file` [crate
//! feature](#crate-features)):
//!
//! ```
//! # use instant_clip_tokenizer::{Token, Tokenizer};
//! let tokenizer = Tokenizer::new();
//! let mut tokens = vec![tokenizer.start_of_text()];
//! tokenizer.encode("Hi there", &mut tokens);
//! tokens.push(tokenizer.end_of_text());
//! let tokens = tokens.into_iter().map(Token::to_u16).collect::<Vec<_>>();
//! assert_eq!(tokens, [49406, 1883, 997, 49407]);
//! ```
//!
//! Using a custom vocabulary file:
//!
//! ```
//! # use std::fs::File;
//! # use std::io::{self, BufReader};
//! # use instant_clip_tokenizer::{Token, Tokenizer};
//! # fn main() -> io::Result<()> {
//! let f = BufReader::new(File::open("bpe_simple_vocab_16e6.txt")?);
//! let tokenizer = Tokenizer::with_vocabulary(f, 50_000)?;
//! let mut tokens = vec![tokenizer.start_of_text()];
//! tokenizer.encode("Hi there", &mut tokens);
//! tokens.push(tokenizer.end_of_text());
//! let tokens = tokens.into_iter().map(Token::to_u16).collect::<Vec<_>>();
//! assert_eq!(tokens, [49998, 1883, 997, 49999]);
//! # Ok(())
//! # }
//! ```
//!
//! # Crate features
//!
//! This crate provides two features:
//!
//! * **ndarray** - Enables the [`ndarray`](https://docs.rs/ndarray) dependency
//!   and the `Tokenizer::tokenize_batch` method that can be used to tokenize
//!   several input strings at once, returning a matrix suitable for directly
//!   passing to the CLIP neural network.
//! * **openai-vocabulary-file** - This feature bundles the default vocabulary
//!   file used for OpenAI's CLIP model together with this crate and allows
//!   users to construct a new tokenizer simply by calling [`Tokenizer::new`].
//!   When disabled, you will need to supply your own vocabulary file and
//!   construct the tokenizer using [`Tokenizer::with_vocabulary`].
//!
//! The **openai-vocabulary-file** feature is enabled by default. To disable it,
//! use `default-features = false` when specifying the dependency on this crate
//! in your `Cargo.toml`.
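//!
//! For example, a dependency declaration with default features disabled might
//! look roughly like this (the crate name and version requirement shown here
//! are illustrative; adjust them to match the release you actually use):
//!
//! ```toml
//! [dependencies]
//! instant-clip-tokenizer = { version = "*", default-features = false }
//! ```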

use std::io::{self, BufRead};

use ahash::AHashMap;
use regex::Regex;

/// A text tokenizer for the CLIP neural network.
///
/// See the [module-level documentation](index.html) for more.
pub struct Tokenizer {
    byte_to_token: Box<[Token; 256]>,
    merge_rules: AHashMap<(Token, Token), Token>,
    start_of_text: Token,
    end_of_text: Token,
    decoder: AHashMap<Token, Vec<u8>>,
    word_split: Regex,
}

impl Tokenizer {
    /// Create a new `Tokenizer` using the vocabulary data bundled with this
    /// crate.
    ///
    /// The resulting `Tokenizer` is suitable for use with the original CLIP
    /// model.
    ///
    /// Note that creating a new `Tokenizer` is expensive, so it is recommended
    /// to create the `Tokenizer` once and then reuse it.
    #[cfg(any(test, feature = "openai-vocabulary-file"))]
    pub fn new() -> Tokenizer {
        static VOCABULARY_DATA: &str = include_str!("../bpe_simple_vocab_16e6.txt");
        const MAX_VOCABULARY_SIZE: u16 = 49408;
        Tokenizer::with_vocabulary(io::Cursor::new(VOCABULARY_DATA), MAX_VOCABULARY_SIZE)
            .expect("bundled vocabulary data is valid")
    }

    /// Create a new `Tokenizer` by reading the vocabulary data from `reader`.
    ///
    /// The data must be in the format used by the original CLIP tokenizer
    /// implementation from OpenAI.
    ///
    /// Note that creating a new `Tokenizer` is expensive, so it is recommended
    /// to create the `Tokenizer` once and then reuse it.
    ///
    /// # Errors
    ///
    /// If the data format is incorrect or reading from `reader` fails, then an
    /// error is returned.
    pub fn with_vocabulary(
        reader: impl BufRead,
        max_vocabulary_size: u16,
    ) -> io::Result<Tokenizer> {
        let mut string_to_token = AHashMap::default();
        let mut byte_to_token = Box::new([Token(u16::MAX); 256]);
        let mut byte_decoder = AHashMap::default();
        let r1 = b'!'..=b'~';
        let r2 = b'\xA1'..=b'\xAC'; // "¡" to "¬"
        let r3 = b'\xAE'..=b'\xFF'; // "®" to "ÿ"
        let mut token_index = 0;
        for byte in r1.chain(r2).chain(r3) {
            let token = Token(token_index);
            byte_to_token[usize::from(byte)] = token;
            let ch = char::from(byte);
            byte_decoder.insert(ch, byte);
            // Add token and also its corresponding end-of-word token
            string_to_token.insert(format!("{ch}"), token);
            string_to_token.insert(format!("{ch}</w>"), Token(token.0 + 256));
            token_index += 1;
        }
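        // Bytes outside the printable ranges above (whitespace and control
        // bytes) still need stand-in characters; map them to code points
        // starting at U+0100 so that every byte has a unique, printable
        // representative.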
        for (idx, (byte, token)) in byte_to_token
            .iter_mut()
            .enumerate()
            .filter(|(_, token)| **token == Token(u16::MAX))
            .enumerate()
        {
            *token = Token(token_index);
            let ch = char::from_u32(idx as u32 + 256).unwrap();
            let byte = u8::try_from(byte).unwrap();
            byte_decoder.insert(ch, byte);
            string_to_token.insert(format!("{ch}"), *token);
            string_to_token.insert(format!("{ch}</w>"), Token(token.0 + 256));
            token_index += 1;
        }

        // For every increment of `token_index` above we actually also added the
        // corresponding end-of-word token, so we have to double `token_index`
        // now in order for it to be correct again.
        token_index *= 2;

        let mut merge_rules = AHashMap::default();
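        // The first line of the vocabulary data is a header and is skipped.
        // Every following line defines one merge rule as two whitespace-separated
        // symbols. The 512 byte-level tokens and the 2 special marker tokens are
        // not part of this list, hence the reduced `take` count.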
        for line in reader
            .lines()
            .skip(1)
            .take((max_vocabulary_size - 512 - 2).into())
        {
            let line = line?;
            let mut parts = line.split_whitespace();
            let first = parts.next().ok_or(io::Error::new(
                io::ErrorKind::Other,
                "lines must contain 2 tokens",
            ))?;
            let second = parts.next().ok_or(io::Error::new(
                io::ErrorKind::Other,
                "lines must contain 2 tokens",
            ))?;
            let first_token = *string_to_token
                .get(first)
                .ok_or(io::Error::new(io::ErrorKind::Other, "invalid merge rule"))?;
            let second_token = *string_to_token
                .get(second)
                .ok_or(io::Error::new(io::ErrorKind::Other, "invalid merge rule"))?;

            let result_token = Token(token_index);
            merge_rules.insert((first_token, second_token), result_token);
            string_to_token.insert(format!("{first}{second}"), result_token);
            token_index += 1;
        }

        // Note that the values we store in `decoder` are not necessarily valid
        // UTF-8, so we have to use `Vec<u8>` for them.
        let decoder = string_to_token
            .into_iter()
            .map(|(string, token)| (token, string.chars().map(|ch| byte_decoder[&ch]).collect()))
            .collect();

        let word_split = Regex::new(
            r"(?x)
                # Special substrings - these each get encoded as a single marker token
                <start_of_text>|<end_of_text>|
                # Common English contractions
                's|'t|'re|'ve|'m|'ll|'d|
                # Consecutive letters, single numbers, or runs of special chars
                [\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+",
        )
        .unwrap();

        Ok(Tokenizer {
            byte_to_token,
            merge_rules,
            start_of_text: Token(token_index),
            end_of_text: Token(token_index + 1),
            decoder,
            word_split,
        })
    }

    /// Tokenize a batch of multiple input strings.
    ///
    /// Each input string is encoded using the [`encode`] method and its numeric
    /// representation is written to a row of the resulting two-dimensional
    /// matrix of shape `(texts.len(), context_length)`, with the special
    /// `<start_of_text>` token prepended and `<end_of_text>` appended to each
    /// text.
    ///
    /// The individual input strings are lowercased before being tokenized, but
    /// otherwise no pre-processing is performed.
    ///
    /// `context_length` is the maximum number of tokens per text and should be
    /// `77` for all current CLIP models. If tokenization results in fewer than
    /// `context_length` tokens, the resulting row will be padded with trailing
    /// zeros. If tokenizing an input text results in too many tokens, the token
    /// sequence will be truncated to fit within the resulting row of length
    /// `context_length`, always including the `<start_of_text>` and
    /// `<end_of_text>` marker tokens.
    ///
    /// The resulting matrix can be passed directly to the CLIP neural network.
    ///
    /// [`encode`]: Tokenizer::encode
    ///
    /// # Panics
    ///
    /// Panics if `context_length < 3`.
    ///
    /// # Examples
    ///
    /// ```
    /// # use ndarray::array;
    /// # use instant_clip_tokenizer::{Token, Tokenizer};
    /// let tokenizer = Tokenizer::new();
    /// let encoded = tokenizer.tokenize_batch(["Hi", "How are you?"], 5);
    /// assert_eq!(encoded, array![
    ///     [49406, 1883, 49407, 0, 0],
    ///     [49406, 829, 631, 592, 49407],
    /// ]);
    /// ```
    #[cfg(feature = "ndarray")]
    pub fn tokenize_batch<'a, I>(&self, texts: I, context_length: usize) -> ndarray::Array2<u16>
    where
        I: IntoIterator<Item = &'a str>,
        I::IntoIter: std::iter::ExactSizeIterator,
    {
        if context_length < 3 {
            panic!("context length must be at least 3");
        }
        let texts = texts.into_iter();
        let mut result = ndarray::Array2::zeros((texts.len(), context_length));
        let mut tokens = Vec::with_capacity(context_length);
        for (text, mut result_row) in texts.zip(result.rows_mut()) {
            tokens.clear();
            tokens.push(self.start_of_text());
            self.encode(text, &mut tokens);
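            // Reserve the final slot for the <end_of_text> marker: truncate
            // first, then append it, so it is present even for long inputs.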
            tokens.truncate(context_length - 1);
            tokens.push(self.end_of_text());
            for (token, result_element) in tokens.iter().zip(&mut result_row) {
                *result_element = token.to_u16();
            }
        }
        result
    }

    /// Encode a `text` input as a sequence of tokens.
    ///
    /// The resulting tokens are appended to `out`. `text` is lowercased before
    /// being tokenized, but otherwise no pre-processing is performed.
    ///
    /// The encoded token sequence does not include the special
    /// `<start_of_text>` and `<end_of_text>` marker tokens. When these are
    /// needed you can either use the `tokenize_batch` method instead, or add
    /// them manually by using the [`start_of_text`] and [`end_of_text`]
    /// methods, as in the example below.
    ///
    /// [`start_of_text`]: Tokenizer::start_of_text
    /// [`end_of_text`]: Tokenizer::end_of_text
    ///
    /// # Examples
    ///
    /// ```
    /// # use instant_clip_tokenizer::{Token, Tokenizer};
    /// let tokenizer = Tokenizer::new();
    /// let mut tokens = vec![tokenizer.start_of_text()];
    /// tokenizer.encode("Hi there", &mut tokens);
    /// tokens.push(tokenizer.end_of_text());
    /// let tokens = tokens.into_iter().map(Token::to_u16).collect::<Vec<_>>();
    /// assert_eq!(tokens, [49406, 1883, 997, 49407]);
    /// ```
    pub fn encode(&self, text: &str, out: &mut Vec<Token>) {
        let text = text.to_lowercase();
        out.reserve(text.as_bytes().len());
        let words = self.word_split.find_iter(&text).map(|m| m.as_str());
        for word in words {
            if word == "<start_of_text>" {
                out.push(self.start_of_text());
                continue;
            } else if word == "<end_of_text>" {
                out.push(self.end_of_text());
                continue;
            }

            let start_index = out.len();
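            // Map every byte of the word to its byte-level token. The
            // end-of-word variant of a byte-level token always has an id
            // exactly 256 higher, which is what the `+= 256` below relies on.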
            out.extend(
                word.as_bytes()
                    .iter()
                    .map(|b| self.byte_to_token[usize::from(*b)]),
            );
            if start_index < out.len() {
                // If we added anything, mark last character as end-of-word
                // token
                out.last_mut().unwrap().0 += 256;
            }
            self.apply_merge_rules(start_index, out);
        }
    }

    fn apply_merge_rules(&self, start_index: usize, tokens: &mut Vec<Token>) {
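        // Greedy BPE: repeatedly pick the applicable merge rule whose result
        // token has the lowest id (the rule learned earliest) and apply it to
        // all occurrences within the current word, until no rule matches.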
        loop {
            let Some(((first, second), result_token)) = tokens[start_index..]
                .windows(2)
                .map(|pair| (pair[0], pair[1]))
                .filter_map(|pair| {
                    self.merge_rules
                        .get(&pair)
                        .map(|result_token| (pair, *result_token))
                })
                .min_by_key(|&(_, result_token)| result_token)
            else {
                // No merge rules left to apply -> we're done
                break;
            };

            // Reduce all occurrences of this pair to `result_token`
            let mut i = start_index;
            while i < tokens.len() - 1 {
                if tokens[i] == first && tokens[i + 1] == second {
                    tokens[i] = result_token;
                    tokens.remove(i + 1);
                }
                i += 1;
            }
        }
    }

    /// Convert a sequence of `tokens` back to a textual representation.
    ///
    /// Due to the way whitespace and lowercasing are handled, a sequence of
    /// tokens will not always decode back to the exact text that `encode` was
    /// called with; in other words, `decode(encode(text)) == text` does not
    /// always hold. Hence, this function is mostly useful for debugging
    /// purposes.
    ///
    /// # Examples
    ///
    /// ```
    /// # use instant_clip_tokenizer::Tokenizer;
    /// let tokenizer = Tokenizer::new();
    /// let mut tokens = Vec::new();
    /// tokenizer.encode("Hello world!!!", &mut tokens);
    /// let decoded = tokenizer.decode(tokens);
    /// assert_eq!(decoded, "hello world !!! ");
    /// ```
    pub fn decode(&self, tokens: impl IntoIterator<Item = Token>) -> String {
        let bytes = tokens
            .into_iter()
            .flat_map(|token| {
                if token == self.start_of_text {
                    "<start_of_text>".as_bytes()
                } else if token == self.end_of_text {
                    "<end_of_text>".as_bytes()
                } else {
                    &self.decoder[&token]
                }
            })
            .copied()
            .collect::<Vec<_>>();

        String::from_utf8_lossy(&bytes).replace("</w>", " ")
    }

    /// Returns the special `<start_of_text>` marker token.
    ///
    /// See [`encode`] for an example of how to add this token to a token
    /// sequence.
    ///
    /// [`encode`]: Tokenizer::encode
    pub fn start_of_text(&self) -> Token {
        self.start_of_text
    }

    /// Returns the special `<end_of_text>` marker token.
    ///
    /// See [`encode`] for an example of how to add this token to a token
    /// sequence.
    ///
    /// [`encode`]: Tokenizer::encode
    pub fn end_of_text(&self) -> Token {
        self.end_of_text
    }
}

#[cfg(any(test, feature = "openai-vocabulary-file"))]
impl Default for Tokenizer {
    fn default() -> Tokenizer {
        Tokenizer::new()
    }
}

/// Represents a single token.
///
/// Values of this type can only be produced by calls to methods on the
/// [`Tokenizer`] type, mainly [`Tokenizer::encode`]. To input tokens into an
/// actual neural network, the [`to_u16`] method should be used.
///
/// [`to_u16`]: Token::to_u16
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Token(u16);

impl Token {
    /// Create a `Token` from a number, validating it against the given
    /// `tokenizer`.
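    ///
    /// Returns `None` if `token` is larger than the highest token id known to
    /// `tokenizer`.
    ///
    /// # Examples
    ///
    /// A quick sketch; the concrete ids assume the bundled OpenAI vocabulary,
    /// where the largest valid id is the `<end_of_text>` marker (49407):
    ///
    /// ```
    /// # use instant_clip_tokenizer::{Token, Tokenizer};
    /// let tokenizer = Tokenizer::new();
    /// assert!(Token::from_u16(1883, &tokenizer).is_some());
    /// assert!(Token::from_u16(50_000, &tokenizer).is_none());
    /// ```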
    pub fn from_u16(token: u16, tokenizer: &Tokenizer) -> Option<Self> {
        (token <= tokenizer.end_of_text().0).then_some(Self(token))
    }

    /// Returns the numerical representation of this `Token`.
    ///
    /// The resulting number is suitable for feeding into a neural network.
    pub fn to_u16(self) -> u16 {
        self.0
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "ndarray")]
    #[test]
    fn tokenize_batch() {
        let tokenizer = Tokenizer::new();
        let encoded = tokenizer.tokenize_batch(["Hi", "How are you?", "I'm fine, thanks!"], 6);
        let expected = ndarray::array![
            [49406, 1883, 49407, 0, 0, 0],
            [49406, 829, 631, 592, 286, 49407],
            [49406, 328, 880, 3797, 267, 49407],
        ];
        assert_eq!(encoded, expected);
    }

    #[test]
    fn encode_special_chars() {
        let tokens = encode("hello world!!!");
        assert_eq!(tokens, [Token(3306), Token(1002), Token(995)]);
    }

    #[test]
    fn decode_special_chars() {
        let tokenizer = Tokenizer::new();
        let decoded = tokenizer.decode([Token(3306), Token(1002), Token(995)]);
        assert_eq!(decoded, "hello world !!! ");
    }

    #[test]
    fn encode_apostrophe() {
        let tokens = encode("i've seen it");
        assert_eq!(tokens, [Token(328), Token(1200), Token(2041), Token(585)]);
    }

    #[test]
    fn decode_apostrophe() {
        let tokenizer = Tokenizer::new();
        let decoded = tokenizer.decode([Token(328), Token(1200), Token(2041), Token(585)]);
        assert_eq!(decoded, "i 've seen it ");
    }

    #[test]
    fn encode_short() {
        let tokens = encode("Hello Båstad");
        assert_eq!(tokens, [Token(3306), Token(65), Token(23176), Token(16485)]);
    }

    #[test]
    fn decode_short() {
        let tokenizer = Tokenizer::new();
        let decoded = tokenizer.decode([Token(3306), Token(65), Token(23176), Token(16485)]);
        assert_eq!(decoded, "hello båstad ");
    }

    #[test]
    fn encode_realistic() {
        let tokens = encode("A person riding a motorcycle");
        assert_eq!(tokens, [320, 2533, 6765, 320, 10297].map(Token));
    }

    #[test]
    fn decode_realistic() {
        let tokenizer = Tokenizer::new();
        let decoded = tokenizer.decode([320, 2533, 6765, 320, 10297].map(Token));
        assert_eq!(decoded, "a person riding a motorcycle ");
    }

    #[test]
    fn encode_long_word() {
        let tokens = encode("donaudampfschifffahrtsgesellschaftskapitänsmütze");
        assert_eq!(
            tokens,
            [
                1067, 627, 1880, 16680, 13731, 1021, 778, 4810, 2290, 619, 10279, 45588, 83, 909,
                688, 529, 42787, 978, 6522, 83, 1298
            ]
            .map(Token)
        );
    }

    #[test]
    fn decode_long_word() {
        let tokenizer = Tokenizer::new();
        let decoded = tokenizer.decode(
            [
                1067, 627, 1880, 16680, 13731, 1021, 778, 4810, 2290, 619, 10279, 45588, 83, 909,
                688, 529, 42787, 978, 6522, 83, 1298,
            ]
            .map(Token),
        );
        assert_eq!(decoded, "donaudampfschifffahrtsgesellschaftskapitänsmütze ");
    }

    #[test]
    fn encode_start_and_end_of_text() {
        let tokens = encode("<start_of_text>Hi<start_of_text>instant labs<end_of_text>");
        assert_eq!(tokens, [49406, 1883, 49406, 10635, 12021, 49407].map(Token));
    }

    #[test]
    fn encode_start_and_end_of_text_with_special_char() {
        let tokens = encode("<start_of_text>Hi!<end_of_text>");
        // Note how the "<end_of_text>" substring is not encoded as the special
        // marker token (which would be 49407), because the word-splitting regex
        // does not split it as a separate word due to the exclamation mark
        // preceding it. This behavior is somewhat strange, but we preserve it
        // in order to stay compatible with the original Python implementation.
        assert_eq!(
            tokens,
            [49406, 1883, 0, 283, 806, 318, 539, 318, 4160, 285].map(Token)
        );
    }

    #[test]
    fn decode_start_and_end_of_text() {
        let tokenizer = Tokenizer::new();
        let decoded = tokenizer.decode([49406, 1883, 49406, 10635, 12021, 49407].map(Token));
        assert_eq!(
            decoded,
            "<start_of_text>hi <start_of_text>instant labs <end_of_text>"
        );
    }

    fn encode(input: &str) -> Vec<Token> {
        let tokenizer = Tokenizer::new();
        let mut tokens = Vec::with_capacity(input.len());
        tokenizer.encode(input, &mut tokens);
        tokens
    }
}