// bpe_tokenizer/tokenizer.rs

use std::{collections::HashMap, fs, iter};

use unicode_segmentation::UnicodeSegmentation;

use crate::{
    constants::*,
    default_vocabs::{new_default, DefaultVocab},
    BytePairEncoderError,
};

/// # Represents a Byte Pair Encoding (BPE) vocabulary used for tokenization.
///
/// This struct holds the mapping of tokens to their respective scores and provides methods for
/// tokenizing text using the BPE algorithm.
///
/// The vocabulary is typically loaded from a file or string where each line
/// contains a token and its score, separated by a tab character.
///
/// ## Example
///
/// ```
/// use bpe_tokenizer::BytePairEncoder;
///
/// let vocab = BytePairEncoder::new_from_str("hello\t1\nworld\t2").unwrap();
/// let tokenized = vocab.tokenize("Hello, world!");
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BytePairEncoder {
    /// # A mapping of tokens to their respective scores.
    ///
    /// In BPE, tokens with higher scores are typically more common and are preferred during the
    /// tokenization process.
    pub(crate) tokens: HashMap<String, isize>,
}

impl BytePairEncoder {
    /// # Creates a new `BytePairEncoder` from a file containing token-score pairs.
    ///
    /// This function reads the contents of the file specified by `file_path` and constructs
    /// a `BytePairEncoder` from it. The file should contain token-score pairs, with each pair
    /// on a separate line and the token and score separated by a tab character (`\t`).
    ///
    /// ## Input Format
    ///
    /// The file is expected to follow this format:
    ///
    /// ```text
    /// <token>\t<score>\n
    /// ```
    ///
    /// Each line should consist of:
    /// * A token (a string) followed by a tab character (`\t`)
    /// * A score (an integer), which may be positive or negative
    ///
    /// Example lines from the file:
    ///
    /// ```text
    /// <unk>    0
    /// ▁t       -0
    /// ▁the     -4
    /// ```
    ///
    /// ## Arguments
    ///
    /// * `file_path` - A string slice that holds the path to the file containing token-score pairs.
    ///
    /// ## Returns
    ///
    /// * `Result<Self, BytePairEncoderError>` - A Result containing the created `BytePairEncoder` if successful,
    ///   or a `BytePairEncoderError` if there was an error reading the file or parsing its contents.
    ///
    /// ## Errors
    ///
    /// This function will return an error if:
    /// * The file cannot be read (returns `BytePairEncoderError::InvalidFile`)
    /// * The file contents are not in the expected format (returns `BytePairEncoderError::InvalidVocabularyInput`)
    ///
    /// ## Example
    ///
    /// ```
    /// use bpe_tokenizer::BytePairEncoder;
    ///
    /// let vocab = BytePairEncoder::new_from_file("path/to/vocabulary/file.txt");
    /// ```
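    ///
    /// Since both failure modes surface through the returned `Result`, a caller can
    /// match on the error; a minimal sketch (it assumes only the two variants
    /// documented above matter here, and falls back to `Debug` output otherwise):
    ///
    /// ```no_run
    /// use bpe_tokenizer::{BytePairEncoder, BytePairEncoderError};
    ///
    /// match BytePairEncoder::new_from_file("path/to/vocabulary/file.txt") {
    ///     Ok(_vocab) => { /* ready to tokenize */ }
    ///     Err(BytePairEncoderError::InvalidFile(path)) => eprintln!("could not read {}", path),
    ///     Err(other) => eprintln!("invalid vocabulary: {:?}", other),
    /// }
    /// ```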
    pub fn new_from_file(file_path: &str) -> Result<Self, BytePairEncoderError> {
        Self::new_from_str(
            fs::read_to_string(file_path)
                .map_err(|_| BytePairEncoderError::InvalidFile(file_path.to_string()))?
                .as_ref(),
        )
    }

    /// # Creates a new `BytePairEncoder` from a string containing token-score pairs.
    ///
    /// This function parses the input string to construct a `BytePairEncoder`. The input should
    /// contain token-score pairs, with each pair on a separate line and the token and score
    /// separated by a tab character (`\t`).
    ///
    /// ## Input Format
    ///
    /// The string must follow this format:
    ///
    /// ```text
    /// <token>\t<score>\n
    /// ```
    ///
    /// Each line in the string should consist of:
    /// * A token (a string) followed by a tab character (`\t`)
    /// * A score (an integer), which may be positive or negative
    ///
    /// For example:
    ///
    /// ```text
    /// hello   1
    /// world   2
    /// ▁the    -4
    /// ```
    ///
    /// ## Arguments
    ///
    /// * `input` - A string slice that holds the token-score pairs.
    ///
    /// ## Returns
    ///
    /// * `Result<Self, BytePairEncoderError>` - A Result containing the created `BytePairEncoder` if successful,
    ///   or a `BytePairEncoderError` if there was an error parsing the input.
    ///
    /// ## Errors
    ///
    /// This function will return `BytePairEncoderError::InvalidVocabularyInput` if:
    /// * A line doesn't contain a tab character to separate token and score.
    /// * The score cannot be parsed as an `isize`.
    ///
    /// ## Example
    ///
    /// ```
    /// use bpe_tokenizer::BytePairEncoder;
    ///
    /// let input = "hello\t1\nworld\t2";
    /// let vocab = BytePairEncoder::new_from_str(input).unwrap();
    /// ```
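    ///
    /// Malformed input is rejected rather than skipped; a quick sketch of the error
    /// path described above:
    ///
    /// ```
    /// use bpe_tokenizer::BytePairEncoder;
    ///
    /// // No tab separator between token and score.
    /// assert!(BytePairEncoder::new_from_str("hello 1").is_err());
    /// // Score that cannot be parsed as an `isize`.
    /// assert!(BytePairEncoder::new_from_str("hello\tone").is_err());
    /// ```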
    pub fn new_from_str(input: &str) -> Result<Self, BytePairEncoderError> {
        let mut tokens = HashMap::new();

        for line in input.lines() {
            let (token, score_str) = match line.split_once('\t') {
                Some(pair) => pair,
                None => return Err(BytePairEncoderError::InvalidVocabularyInput),
            };
            let score = match score_str.parse::<isize>() {
                Ok(score) => score,
                Err(_) => return Err(BytePairEncoderError::InvalidVocabularyInput),
            };
            tokens.insert(token.to_string(), score);
        }

        Ok(BytePairEncoder { tokens })
    }

    /// # Creates a new `BytePairEncoder` with a default small vocabulary size (100,000 tokens).
    ///
    /// This function constructs a `BytePairEncoder` using a pre-trained multilingual vocabulary
    /// that supports 275 languages. The vocabulary is sourced from the
    /// [BPEmb](https://github.com/bheinzerling/bpemb) project, licensed under MIT. The small-sized
    /// vocabulary file consists of 100,000 tokens, allowing for highly compressed tokenization
    /// suitable for tasks with limited memory constraints.
    ///
    /// ## Returns
    ///
    /// A `Result<Self, BytePairEncoderError>`, constructing the `BytePairEncoder` on successful
    /// vocabulary loading, or a corresponding error if initialization fails.
    ///
    /// ## Example
    ///
    /// ```
    /// # #[cfg(feature = "default-small")] {
    /// use bpe_tokenizer::BytePairEncoder;
    ///
    /// let encoder = BytePairEncoder::new_default_small().unwrap();
    /// # }
    /// ```
    ///
    /// ## Note
    ///
    /// This function is only available when the `default-small` feature is enabled in `Cargo.toml`:
    ///
    /// ```toml
    /// [dependencies]
    /// bpe-tokenizer = { version = "<version>", features = ["default-small"] }
    /// ```
    pub fn new_default_small() -> Result<Self, BytePairEncoderError> {
        new_default(DefaultVocab::Small)
    }

    /// # Creates a new `BytePairEncoder` with a default medium vocabulary size (320,000 tokens).
    ///
    /// This function constructs a `BytePairEncoder` using a pre-trained multilingual vocabulary
    /// that supports 275 languages. The vocabulary is sourced from the
    /// [BPEmb](https://github.com/bheinzerling/bpemb) project, licensed under MIT. The
    /// medium-sized vocabulary file consists of 320,000 tokens, offering a balance between token
    /// coverage and memory efficiency, making it suitable for a wide variety of NLP tasks.
    ///
    /// ## Returns
    ///
    /// A `Result<Self, BytePairEncoderError>`, constructing the `BytePairEncoder` on successful
    /// vocabulary loading, or a corresponding error if initialization fails.
    ///
    /// ## Example
    ///
    /// ```
    /// # #[cfg(feature = "default-medium")] {
    /// use bpe_tokenizer::BytePairEncoder;
    ///
    /// let encoder = BytePairEncoder::new_default_medium().unwrap();
    /// # }
    /// ```
    ///
    /// ## Note
    ///
    /// This function is only available when the `default-medium` feature is enabled in `Cargo.toml`:
    ///
    /// ```toml
    /// [dependencies]
    /// bpe-tokenizer = { version = "<version>", features = ["default-medium"] }
    /// ```
    pub fn new_default_medium() -> Result<Self, BytePairEncoderError> {
        new_default(DefaultVocab::Medium)
    }

    /// # Creates a new `BytePairEncoder` with a default large vocabulary size (1,000,000 tokens).
    ///
    /// This function constructs a `BytePairEncoder` using a pre-trained multilingual vocabulary
    /// that supports 275 languages. The vocabulary is sourced from the
    /// [BPEmb](https://github.com/bheinzerling/bpemb) project, licensed under MIT. The large-sized
    /// vocabulary consists of 1,000,000 tokens, providing maximum coverage for detailed language
    /// representation, especially useful in applications requiring high granularity.
    ///
    /// ## Returns
    ///
    /// A `Result<Self, BytePairEncoderError>`, constructing the `BytePairEncoder` on successful
    /// vocabulary loading, or a corresponding error if initialization fails.
    ///
    /// ## Example
    ///
    /// ```
    /// # #[cfg(feature = "default-large")] {
    /// use bpe_tokenizer::BytePairEncoder;
    ///
    /// let encoder = BytePairEncoder::new_default_large().unwrap();
    /// # }
    /// ```
    ///
    /// ## Note
    ///
    /// This function is only available when the `default-large` feature is enabled in `Cargo.toml`:
    ///
    /// ```toml
    /// [dependencies]
    /// bpe-tokenizer = { version = "<version>", features = ["default-large"] }
    /// ```
    pub fn new_default_large() -> Result<Self, BytePairEncoderError> {
        new_default(DefaultVocab::Large)
    }

    /// # Tokenizes a text into sentences, then words, and finally into BPE tokens.
    ///
    /// This function takes a string of text and returns an iterator that yields,
    /// for each sentence, an iterator over that sentence's tokens.
    ///
    /// ## Arguments
    ///
    /// * `text` - A string slice containing the text to be tokenized.
    ///
    /// ## Returns
    ///
    /// An iterator of per-sentence iterators, where each inner iterator yields the
    /// `String` tokens of one sentence.
    ///
    /// ## Example
    ///
    /// ```
    /// use bpe_tokenizer::BytePairEncoder;
    ///
    /// let vocab = BytePairEncoder::new_from_str("hello\t1\nworld\t2").unwrap();
    /// let text = "Hello, world! How are you?";
    /// let tokenized: Vec<Vec<String>> = vocab
    ///     .tokenize_sentences_iter(text)
    ///     .map(|sentence_iter| sentence_iter.collect())  // Collect each inner iterator into a Vec<String>
    ///     .collect();  // Then collect everything into Vec<Vec<String>>
    /// ```
    ///
    /// ## Notes
    ///
    /// - This function uses Unicode-aware sentence and word segmentation.
    /// - Each sentence is wrapped with sentence start (`<s>`) and end (`</s>`) tokens.
    /// - Words are prefixed with the word break character (`▁`).
    /// - Unknown tokens are replaced with the `<unk>` token.
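    ///
    /// Taken together, these notes mean the example above yields marker-wrapped
    /// token streams per sentence, roughly like the following (illustrative; the
    /// exact splits depend on the vocabulary, and unmatched pieces collapse to
    /// `<unk>`):
    ///
    /// ```text
    /// ["<s>", "<unk>", "hello", "<unk>", "world", "</s>"]
    /// ["<s>", "<unk>", "<unk>", "<unk>", "</s>"]
    /// ```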
    pub fn tokenize_sentences_iter<'a>(
        &'a self,
        text: &'a str,
    ) -> impl Iterator<Item = impl Iterator<Item = String> + 'a> + 'a {
        UnicodeSegmentation::unicode_sentences(text)
            .map(move |sentence| self.tokenize_with_sentence_markers_iter(sentence))
    }

    /// # Tokenizes a text into a flat sequence of BPE tokens.
    ///
    /// This function takes a string of text and returns an iterator that yields
    /// individual tokens. It first tokenizes the text into sentences, then words,
    /// and finally into BPE tokens, flattening the result into a single sequence.
    ///
    /// ## Arguments
    ///
    /// * `text` - A string slice containing the text to be tokenized.
    ///
    /// ## Returns
    ///
    /// An iterator that yields `String`s, where each `String` is a single token.
    ///
    /// ## Example
    ///
    /// ```
    /// use bpe_tokenizer::BytePairEncoder;
    ///
    /// let vocab = BytePairEncoder::new_from_str("hello\t1\nworld\t2").unwrap();
    /// let text = "Hello, world! How are you?";
    /// let tokenized: Vec<String> = vocab.tokenize_iter(text).collect();
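    /// // The per-sentence token streams are flattened into one sequence, so the
    /// // sentence markers appear inline, e.g. [..., "</s>", "<s>", ...].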
    /// ```
    ///
    /// ## Notes
    ///
    /// - This function uses Unicode-aware sentence and word segmentation.
    /// - Each sentence is wrapped with sentence start (`<s>`) and end (`</s>`) tokens.
    /// - Words are prefixed with the word break character (`▁`).
    /// - Unknown tokens are replaced with the `<unk>` token.
    pub fn tokenize_iter<'a>(&'a self, text: &'a str) -> impl Iterator<Item = String> + 'a {
        self.tokenize_sentences_iter(text).flatten()
    }

    /// # Tokenizes a text into sentences, then words, and finally into BPE tokens.
    ///
    /// This function takes a string of text and returns a vector of tokenized sentences,
    /// where each sentence is represented as a vector of tokens.
    ///
    /// ## Arguments
    ///
    /// * `text` - A string slice containing the text to be tokenized.
    ///
    /// ## Returns
    ///
    /// A `Vec<Vec<String>>`, where each inner `Vec<String>` represents a tokenized sentence.
    ///
    /// ## Example
    ///
    /// ```
    /// use bpe_tokenizer::BytePairEncoder;
    ///
    /// let vocab = BytePairEncoder::new_from_str("hello\t1\nworld\t2").unwrap();
    /// let text = "Hello, world! How are you?";
    /// let tokenized = vocab.tokenize_sentences(text);
    /// ```
    ///
    /// ## Notes
    ///
    /// - This function uses Unicode-aware sentence and word segmentation.
    /// - Each sentence is wrapped with sentence start (`<s>`) and end (`</s>`) tokens.
    /// - Words are prefixed with the word break character (`▁`).
    /// - Unknown tokens are replaced with the `<unk>` token.
    pub fn tokenize_sentences(&self, text: &str) -> Vec<Vec<String>> {
        self.tokenize_sentences_iter(text)
            .map(|sentence_iter| sentence_iter.collect())
            .collect()
    }

    /// # Tokenizes a text into a flat sequence of BPE tokens.
    ///
    /// This function takes a string of text and returns a vector of tokens.
    /// It first tokenizes the text into sentences, then words, and finally into BPE tokens,
    /// flattening the result into a single sequence.
    ///
    /// ## Arguments
    ///
    /// * `text` - A string slice containing the text to be tokenized.
    ///
    /// ## Returns
    ///
    /// A `Vec<String>`, where each `String` represents a token.
    ///
    /// ## Example
    ///
    /// ```
    /// use bpe_tokenizer::BytePairEncoder;
    ///
    /// let vocab = BytePairEncoder::new_from_str("hello\t1\nworld\t2").unwrap();
    /// let text = "Hello, world! How are you?";
    /// let tokenized = vocab.tokenize(text);
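    /// // With this two-token vocabulary, matched words pass through and everything
    /// // else collapses to `<unk>`; the first sentence contributes roughly
    /// // ["<s>", "<unk>", "hello", "<unk>", "world", "</s>"].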
    /// ```
    ///
    /// ## Notes
    ///
    /// - This function uses Unicode-aware sentence and word segmentation.
    /// - Each sentence is wrapped with sentence start (`<s>`) and end (`</s>`) tokens.
    /// - Words are prefixed with the word break character (`▁`).
    /// - Unknown tokens are replaced with the `<unk>` token.
    pub fn tokenize(&self, text: &str) -> Vec<String> {
        self.tokenize_iter(text).collect()
    }

    /// # Tokenizes a single sentence, adding sentence start and end markers.
    ///
    /// This function breaks down the tokenization process for a single sentence:
    /// 1. Adds a sentence start token.
    /// 2. Splits the sentence into words using Unicode-aware word segmentation.
    /// 3. Prefixes each word with the word break character.
    /// 4. Tokenizes each word using the BPE vocabulary.
    /// 5. Adds a sentence end token.
    ///
    /// ## Arguments
    ///
    /// * `sentence` - A string slice containing a single sentence to be tokenized.
    ///
    /// ## Returns
    ///
    /// An iterator that yields `String`s representing the tokenized sentence,
    /// including start and end markers.
    ///
    /// ## Implementation Notes
    ///
    /// - Uses `unicode_words` for word segmentation to handle various Unicode scripts correctly.
    /// - Converts words to lowercase before tokenization to match the vocabulary.
    /// - Returns an iterator instead of a fully collected `Vec<String>` to allow for
    ///   more efficient tokenization and processing.
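    ///
    /// Conceptually, a sentence like `"Hello, world"` expands as (illustrative):
    ///
    /// ```text
    /// ["<s>"] ++ tokenize_word("▁hello") ++ tokenize_word("▁world") ++ ["</s>"]
    /// ```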
    pub(crate) fn tokenize_with_sentence_markers_iter<'a>(
        &'a self,
        sentence: &'a str,
    ) -> impl Iterator<Item = String> + 'a {
        iter::once(SENTENCE_START_TOKEN.to_string())
            .chain(sentence.unicode_words().flat_map(move |word| {
                self.tokenize_word(&format!("{}{}", WORD_BREAK_CHAR, word.to_lowercase()))
            }))
            .chain(iter::once(SENTENCE_END_TOKEN.to_string()))
    }

    /// # Tokenizes a single word using the Byte Pair Encoding (BPE) algorithm.
    ///
    /// This function implements the core BPE tokenization logic:
    /// 1. If the word is empty, return an empty vector.
    /// 2. Convert the word to a vector of Unicode characters.
    /// 3. Iterate through possible substrings of the word, from longest to shortest.
    /// 4. For each substring length, find all matching tokens in the vocabulary.
    /// 5. Choose the matching token with the highest score in the vocabulary.
    /// 6. Split the word at the chosen token and recursively tokenize the parts before and after.
    /// 7. If no match is found, return the unknown token.
    ///
    /// ## Arguments
    ///
    /// * `text` - A string slice containing a single word to be tokenized.
    ///
    /// ## Returns
    ///
    /// A `Vec<String>` containing the BPE tokens for the input word.
    ///
    /// ## Implementation Notes
    ///
    /// - The algorithm prioritizes longer matches over shorter ones.
    /// - In case of multiple matches of the same length, it chooses the one with the highest score.
    /// - The function is recursive, handling subwords created by splitting at a matched token.
    /// - If no match is found in the vocabulary, it returns the unknown token.
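    ///
    /// ## Illustrative Trace
    ///
    /// Assuming a vocabulary that contains `hello` but neither `▁` nor `▁hello`
    /// (a hypothetical input, not a real vocabulary file), the recursion unfolds as:
    ///
    /// ```text
    /// tokenize_word("▁hello")
    ///   longest match: "hello" (chars 1..6)
    ///   left  "▁" -> no match at any length -> ["<unk>"]
    ///   right ""  -> []
    ///   result: ["<unk>", "hello"]
    /// ```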
    pub(crate) fn tokenize_word(&self, text: &str) -> Vec<String> {
        // Base case: If the input is empty, return an empty vector
        if text.is_empty() {
            return vec![];
        }

        // Convert the `text` to a Vec of `char`s to index by character rather than byte
        let word: Vec<char> = text.chars().collect();

        // Look for the longest matching token in the vocabulary
        for len in (1..=word.len()).rev() {
            let mut matches = vec![];
            // Iterate over each possible start position for substrings of length `len`
            for start in 0..=(word.len() - len) {
                let end = start + len;

                // Extract the candidate substring (convert chars[start..end] back to a String)
                let candidate: String = word[start..end].iter().collect();

                // If we have an exact match, store it for now
                if self.tokens.contains_key(&candidate) {
                    matches.push((candidate, start, end));
                }
            }

            // If we got matches, choose the one with the highest score
            if !matches.is_empty() {
                let (candidate, start, end) = matches
                    .into_iter()
                    .max_by_key(|(candidate, _, _)| {
                        self.tokens.get(candidate).copied().unwrap_or(isize::MIN)
                    })
                    .unwrap();

                // Recursively process the left part (before the match)
                let left: String = word[..start].iter().collect();
                let left_tokens = self.tokenize_word(&left);

                // The middle part is the matched token
                let middle = vec![candidate];

                // Recursively process the right part (after the match)
                let right: String = word[end..].iter().collect();
                let right_tokens = self.tokenize_word(&right);

                // Concatenate the result of left, middle, and right
                return [left_tokens, middle, right_tokens].concat();
            }
        }

        // If no match is found, return <unk> for the whole text
        vec![UNKNOWN_TOKEN.to_string()]
    }
}