bpe_tokenizer/tokenizer.rs
use std::{collections::HashMap, fs, iter};

use unicode_segmentation::UnicodeSegmentation;

use crate::{
    constants::*,
    default_vocabs::{new_default, DefaultVocab},
    BytePairEncoderError,
};

/// # Represents a Byte Pair Encoding (BPE) vocabulary used for tokenization.
///
/// This struct holds the mapping of tokens to their respective scores and provides methods for
/// tokenizing text using the BPE algorithm.
///
/// The vocabulary is typically loaded from a file or string where each line
/// contains a token and its score, separated by a tab character.
///
/// ## Example
///
/// ```
/// use bpe_tokenizer::BytePairEncoder;
///
/// let vocab = BytePairEncoder::new_from_str("hello\t1\nworld\t2").unwrap();
/// let tokenized = vocab.tokenize("Hello, world!");
/// ```
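///
/// With this two-token vocabulary, anything outside it falls back to the unknown
/// token; the call above would yield roughly (assuming the default `<s>`, `</s>`,
/// and `<unk>` marker constants):
///
/// ```text
/// ["<s>", "<unk>", "hello", "<unk>", "world", "</s>"]
/// ```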
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BytePairEncoder {
    /// # A mapping of tokens to their respective scores.
    ///
    /// In BPE, tokens with higher scores are typically more common and are preferred during the
    /// tokenization process.
    pub(crate) tokens: HashMap<String, isize>,
}

impl BytePairEncoder {
    /// # Creates a new `BytePairEncoder` from a file containing token-score pairs.
    ///
    /// This function reads the contents of the file specified by `file_path` and constructs
    /// a `BytePairEncoder` from it. The file should contain token-score pairs, with each pair
    /// on a separate line and the token and score separated by a tab character (`\t`).
    ///
    /// ## Input Format
    ///
    /// The file is expected to follow this format:
    ///
    /// ```text
    /// <token>\t<score>\n
    /// ```
    ///
    /// Each line should consist of:
    /// * A token (a string) followed by a tab character (`\t`)
    /// * A score (an integer), which may be positive or negative
    ///
    /// Example lines from the file:
    ///
    /// ```text
    /// <unk> 0
    /// ▁t -0
    /// ▁the -4
    /// ```
    ///
    /// ## Arguments
    ///
    /// * `file_path` - A string slice that holds the path to the file containing token-score pairs.
    ///
    /// ## Returns
    ///
    /// * `Result<Self, BytePairEncoderError>` - A `Result` containing the created `BytePairEncoder`
    ///   if successful, or a `BytePairEncoderError` if there was an error reading the file or
    ///   parsing its contents.
    ///
    /// ## Errors
    ///
    /// This function will return an error if:
    /// * The file cannot be read (returns `BytePairEncoderError::InvalidFile`)
    /// * The file contents are not in the expected format (returns `BytePairEncoderError::InvalidVocabularyInput`)
    ///
    /// ## Example
    ///
    /// ```
    /// use bpe_tokenizer::BytePairEncoder;
    ///
    /// let vocab = BytePairEncoder::new_from_file("path/to/vocabulary/file.txt");
    /// ```
    pub fn new_from_file(file_path: &str) -> Result<Self, BytePairEncoderError> {
        Self::new_from_str(
            fs::read_to_string(file_path)
                .map_err(|_| BytePairEncoderError::InvalidFile(file_path.to_string()))?
                .as_ref(),
        )
    }

    /// # Creates a new `BytePairEncoder` from a string containing token-score pairs.
    ///
    /// This function parses the input string to construct a `BytePairEncoder`. The input should
    /// contain token-score pairs, with each pair on a separate line and the token and score
    /// separated by a tab character (`\t`).
    ///
    /// ## Input Format
    ///
    /// The string must follow this format:
    ///
    /// ```text
    /// <token>\t<score>\n
    /// ```
    ///
    /// Each line in the string should consist of:
    /// * A token (a string) followed by a tab character (`\t`)
    /// * A score (an integer), which may be positive or negative
    ///
    /// For example:
    ///
    /// ```text
    /// hello 1
    /// world 2
    /// ▁the -4
    /// ```
    ///
    /// ## Arguments
    ///
    /// * `input` - A string slice that holds the token-score pairs.
    ///
    /// ## Returns
    ///
    /// * `Result<Self, BytePairEncoderError>` - A `Result` containing the created `BytePairEncoder`
    ///   if successful, or a `BytePairEncoderError` if there was an error parsing the input.
    ///
    /// ## Errors
    ///
    /// This function will return `BytePairEncoderError::InvalidVocabularyInput` if:
    /// * A line doesn't contain a tab character to separate token and score.
    /// * The score cannot be parsed as an `isize`.
    ///
    /// ## Example
    ///
    /// ```
    /// use bpe_tokenizer::BytePairEncoder;
    ///
    /// let input = "hello\t1\nworld\t2";
    /// let vocab = BytePairEncoder::new_from_str(input).unwrap();
    /// ```
    pub fn new_from_str(input: &str) -> Result<Self, BytePairEncoderError> {
        let mut tokens = HashMap::new();

        for line in input.lines() {
            // Each line must be `<token>\t<score>`; reject lines without a tab
            let (token, score_str) = match line.split_once('\t') {
                Some(pair) => pair,
                None => return Err(BytePairEncoderError::InvalidVocabularyInput),
            };
            // The score must parse as a (possibly negative) integer
            let score = match score_str.parse::<isize>() {
                Ok(score) => score,
                Err(_) => return Err(BytePairEncoderError::InvalidVocabularyInput),
            };
            tokens.insert(token.to_string(), score);
        }

        Ok(BytePairEncoder { tokens })
    }

    /// # Creates a new `BytePairEncoder` with a default small vocabulary size (100,000 tokens).
    ///
    /// This function constructs a `BytePairEncoder` using a pre-trained multilingual vocabulary
    /// that supports 275 languages. The vocabulary is sourced from the
    /// [BPEmb](https://github.com/bheinzerling/bpemb) project, licensed under MIT. The small-sized
    /// vocabulary file consists of 100,000 tokens, allowing for highly compressed tokenization
    /// suitable for tasks with limited memory constraints.
    ///
    /// ## Returns
    ///
    /// A `Result<Self, BytePairEncoderError>`, containing the `BytePairEncoder` on successful
    /// vocabulary loading, or a corresponding error if initialization fails.
    ///
    /// ## Example
    ///
    /// ```
    /// # #[cfg(feature = "default-small")] {
    /// use bpe_tokenizer::BytePairEncoder;
    ///
    /// let encoder = BytePairEncoder::new_default_small().unwrap();
    /// # }
    /// ```
    ///
    /// ## Note
    ///
    /// This constructor is only available when the `default-small` feature is enabled in
    /// `Cargo.toml`:
    ///
    /// ```toml
    /// [dependencies]
    /// bpe-tokenizer = { version = "<version>", features = ["default-small"] }
    /// ```
    pub fn new_default_small() -> Result<Self, BytePairEncoderError> {
        new_default(DefaultVocab::Small)
    }

    /// # Creates a new `BytePairEncoder` with a default medium vocabulary size (320,000 tokens).
    ///
    /// This function constructs a `BytePairEncoder` using a pre-trained multilingual vocabulary
    /// that supports 275 languages. The vocabulary is sourced from the
    /// [BPEmb](https://github.com/bheinzerling/bpemb) project, licensed under MIT. The
    /// medium-sized vocabulary file consists of 320,000 tokens, offering a balance between token
    /// coverage and memory efficiency, making it suitable for a wide variety of NLP tasks.
    ///
    /// ## Returns
    ///
    /// A `Result<Self, BytePairEncoderError>`, containing the `BytePairEncoder` on successful
    /// vocabulary loading, or a corresponding error if initialization fails.
    ///
    /// ## Example
    ///
    /// ```
    /// # #[cfg(feature = "default-medium")] {
    /// use bpe_tokenizer::BytePairEncoder;
    ///
    /// let encoder = BytePairEncoder::new_default_medium().unwrap();
    /// # }
    /// ```
    ///
    /// ## Note
    ///
    /// This constructor is only available when the `default-medium` feature is enabled in
    /// `Cargo.toml`:
    ///
    /// ```toml
    /// [dependencies]
    /// bpe-tokenizer = { version = "<version>", features = ["default-medium"] }
    /// ```
    pub fn new_default_medium() -> Result<Self, BytePairEncoderError> {
        new_default(DefaultVocab::Medium)
    }

    /// # Creates a new `BytePairEncoder` with a default large vocabulary size (1,000,000 tokens).
    ///
    /// This function constructs a `BytePairEncoder` using a pre-trained multilingual vocabulary
    /// that supports 275 languages. The vocabulary is sourced from the
    /// [BPEmb](https://github.com/bheinzerling/bpemb) project, licensed under MIT. The large-sized
    /// vocabulary consists of 1,000,000 tokens, providing maximum coverage for detailed language
    /// representation, especially useful in applications requiring high granularity.
    ///
    /// ## Returns
    ///
    /// A `Result<Self, BytePairEncoderError>`, containing the `BytePairEncoder` on successful
    /// vocabulary loading, or a corresponding error if initialization fails.
    ///
    /// ## Example
    ///
    /// ```
    /// # #[cfg(feature = "default-large")] {
    /// use bpe_tokenizer::BytePairEncoder;
    ///
    /// let encoder = BytePairEncoder::new_default_large().unwrap();
    /// # }
    /// ```
    ///
    /// ## Note
    ///
    /// This constructor is only available when the `default-large` feature is enabled in
    /// `Cargo.toml`:
    ///
    /// ```toml
    /// [dependencies]
    /// bpe-tokenizer = { version = "<version>", features = ["default-large"] }
    /// ```
    pub fn new_default_large() -> Result<Self, BytePairEncoderError> {
        new_default(DefaultVocab::Large)
    }

    /// # Tokenizes a text into sentences, then words, and finally into BPE tokens.
    ///
    /// This function takes a string of text and returns an iterator that yields,
    /// for each sentence, an inner iterator over that sentence's tokens.
    ///
    /// ## Arguments
    ///
    /// * `text` - A string slice containing the text to be tokenized.
    ///
    /// ## Returns
    ///
    /// An iterator whose items are themselves iterators over `String` tokens, one
    /// inner iterator per sentence.
    ///
    /// ## Example
    ///
    /// ```
    /// use bpe_tokenizer::BytePairEncoder;
    ///
    /// let vocab = BytePairEncoder::new_from_str("hello\t1\nworld\t2").unwrap();
    /// let text = "Hello, world! How are you?";
    /// let tokenized: Vec<Vec<String>> = vocab
    ///     .tokenize_sentences_iter(text)
    ///     .map(|sentence_iter| sentence_iter.collect()) // Collect each inner iterator into a Vec<String>
    ///     .collect(); // Then collect everything into Vec<Vec<String>>
    /// ```
    ///
    /// ## Notes
    ///
    /// - This function uses Unicode-aware sentence and word segmentation.
    /// - Each sentence is wrapped with sentence start (`<s>`) and end (`</s>`) tokens.
    /// - Words are prefixed with the word break character (`▁`).
    /// - Unknown tokens are replaced with the `<unk>` token.
    pub fn tokenize_sentences_iter<'a>(
        &'a self,
        text: &'a str,
    ) -> impl Iterator<Item = impl Iterator<Item = String> + 'a> + 'a {
        UnicodeSegmentation::unicode_sentences(text)
            .map(move |sentence| self.tokenize_with_sentence_markers_iter(sentence))
    }

    /// # Tokenizes a text into a flat sequence of BPE tokens.
    ///
    /// This function takes a string of text and returns an iterator that yields
    /// individual tokens. It first tokenizes the text into sentences, then words,
    /// and finally into BPE tokens, flattening the result into a single sequence.
    ///
    /// ## Arguments
    ///
    /// * `text` - A string slice containing the text to be tokenized.
    ///
    /// ## Returns
    ///
    /// An iterator that yields `String`, where each `String` represents a token.
    ///
    /// ## Example
    ///
    /// ```
    /// use bpe_tokenizer::BytePairEncoder;
    ///
    /// let vocab = BytePairEncoder::new_from_str("hello\t1\nworld\t2").unwrap();
    /// let text = "Hello, world! How are you?";
    /// let tokenized: Vec<String> = vocab.tokenize_iter(text).collect();
    /// ```
    ///
    /// ## Notes
    ///
    /// - This function uses Unicode-aware sentence and word segmentation.
    /// - Each sentence is wrapped with sentence start (`<s>`) and end (`</s>`) tokens.
    /// - Words are prefixed with the word break character (`▁`).
    /// - Unknown tokens are replaced with the `<unk>` token.
    pub fn tokenize_iter<'a>(&'a self, text: &'a str) -> impl Iterator<Item = String> + 'a {
        self.tokenize_sentences_iter(text).flatten()
    }

    /// # Tokenizes a text into sentences, then words, and finally into BPE tokens.
    ///
    /// This function takes a string of text and returns a vector of tokenized sentences,
    /// where each sentence is represented as a vector of tokens.
    ///
    /// ## Arguments
    ///
    /// * `text` - A string slice containing the text to be tokenized.
    ///
    /// ## Returns
    ///
    /// A `Vec<Vec<String>>`, where each inner `Vec<String>` represents a tokenized sentence.
    ///
    /// ## Example
    ///
    /// ```
    /// use bpe_tokenizer::BytePairEncoder;
    ///
    /// let vocab = BytePairEncoder::new_from_str("hello\t1\nworld\t2").unwrap();
    /// let text = "Hello, world! How are you?";
    /// let tokenized = vocab.tokenize_sentences(text);
    /// ```
    ///
    /// ## Notes
    ///
    /// - This function uses Unicode-aware sentence and word segmentation.
    /// - Each sentence is wrapped with sentence start (`<s>`) and end (`</s>`) tokens.
    /// - Words are prefixed with the word break character (`▁`).
    /// - Unknown tokens are replaced with the `<unk>` token.
    pub fn tokenize_sentences(&self, text: &str) -> Vec<Vec<String>> {
        self.tokenize_sentences_iter(text)
            .map(|sentence_iter| sentence_iter.collect())
            .collect()
    }

    /// # Tokenizes a text into a flat sequence of BPE tokens.
    ///
    /// This function takes a string of text and returns a vector of tokens.
    /// It first tokenizes the text into sentences, then words, and finally into BPE tokens,
    /// flattening the result into a single sequence.
    ///
    /// ## Arguments
    ///
    /// * `text` - A string slice containing the text to be tokenized.
    ///
    /// ## Returns
    ///
    /// A `Vec<String>`, where each `String` represents a token.
    ///
    /// ## Example
    ///
    /// ```
    /// use bpe_tokenizer::BytePairEncoder;
    ///
    /// let vocab = BytePairEncoder::new_from_str("hello\t1\nworld\t2").unwrap();
    /// let text = "Hello, world! How are you?";
    /// let tokenized = vocab.tokenize(text);
    /// ```
    ///
    /// ## Notes
    ///
    /// - This function uses Unicode-aware sentence and word segmentation.
    /// - Each sentence is wrapped with sentence start (`<s>`) and end (`</s>`) tokens.
    /// - Words are prefixed with the word break character (`▁`).
    /// - Unknown tokens are replaced with the `<unk>` token.
    pub fn tokenize(&self, text: &str) -> Vec<String> {
        self.tokenize_iter(text).collect()
    }

    /// # Tokenizes a single sentence, adding sentence start and end markers.
    ///
    /// This function breaks down the tokenization process for a single sentence:
    /// 1. Adds a sentence start token.
    /// 2. Splits the sentence into words using Unicode-aware word segmentation.
    /// 3. Prepends each word with the word break character.
    /// 4. Tokenizes each word using the BPE vocabulary.
    /// 5. Adds a sentence end token.
    ///
    /// ## Arguments
    ///
    /// * `sentence` - A string slice containing a single sentence to be tokenized.
    ///
    /// ## Returns
    ///
    /// An iterator that yields `String`s representing the tokenized sentence,
    /// including start and end markers.
    ///
    /// ## Implementation Notes
    ///
    /// - Uses `unicode_words` for word segmentation to handle various Unicode scripts correctly.
    /// - Converts words to lowercase before tokenization to match the vocabulary.
    /// - Returns an iterator instead of a fully collected `Vec<String>` to allow for
    ///   more efficient tokenization and processing.
    pub(crate) fn tokenize_with_sentence_markers_iter<'a>(
        &'a self,
        sentence: &'a str,
    ) -> impl Iterator<Item = String> + 'a {
        iter::once(SENTENCE_START_TOKEN.to_string())
            .chain(sentence.unicode_words().flat_map(move |word| {
                self.tokenize_word(&format!("{}{}", WORD_BREAK_CHAR, word.to_lowercase()))
            }))
            .chain(iter::once(SENTENCE_END_TOKEN.to_string()))
    }

    /// # Tokenizes a single word using the Byte Pair Encoding (BPE) algorithm.
    ///
    /// This function implements the core BPE tokenization logic:
    /// 1. If the word is empty, return an empty vector.
    /// 2. Convert the word to a vector of Unicode characters.
    /// 3. Iterate through possible substrings of the word, from longest to shortest.
    /// 4. For each substring length, find all matching tokens in the vocabulary.
    /// 5. Choose the matching token with the highest score in the vocabulary.
    /// 6. Split the word at the chosen token and recursively tokenize the parts before and after.
    /// 7. If no match is found, return the unknown token.
    ///
    /// ## Arguments
    ///
    /// * `text` - A string slice containing a single word to be tokenized.
    ///
    /// ## Returns
    ///
    /// A `Vec<String>` containing the BPE tokens for the input word.
    ///
    /// ## Implementation Notes
    ///
    /// - The algorithm prioritizes longer matches over shorter ones.
    /// - In case of multiple matches of the same length, it chooses the one with the highest score.
    /// - The function is recursive, handling subwords created by splitting at a matched token.
    /// - If no match is found in the vocabulary, it returns the unknown token.
    pub(crate) fn tokenize_word(&self, text: &str) -> Vec<String> {
        // Base case: If the input is empty, return an empty vector
        if text.is_empty() {
            return vec![];
        }

        // Convert the `text` to a Vec of `char`s to index by character rather than byte
        let word: Vec<char> = text.chars().collect();

        // Look for the longest matching token in the vocabulary
        for len in (1..=word.len()).rev() {
            let mut matches = vec![];
            // Iterate over each possible start position for substrings of length `len`
            for start in 0..=(word.len() - len) {
                let end = start + len;

                // Extract candidate substring (convert chars[start..end] back to a &str)
                let candidate = &word[start..end].iter().collect::<String>();

                // If we have an exact match, just store it for now
                if self.tokens.contains_key(candidate) {
                    matches.push((candidate.to_string(), start, end));
                }
            }

            // If we got matches, choose the one with the highest score
            if !matches.is_empty() {
                let (candidate, start, end) = matches
                    .into_iter()
                    .max_by_key(|(candidate, _, _)| {
                        self.tokens.get(candidate).copied().unwrap_or(isize::MIN)
                    })
                    .unwrap();

                // Recursively process the left part (before the match)
                let left: String = word[..start].iter().collect();
                let left_tokens = self.tokenize_word(&left);

                // The middle part is the matched token
                let middle = vec![candidate];

                // Recursively process the right part (after the match)
                let right: String = word[end..].iter().collect();
                let right_tokens = self.tokenize_word(&right);

                // Concatenate the result of left, middle, and right
                return [left_tokens, middle, right_tokens].concat();
            }
        }

        // If no match is found, return <unk> for the whole text
        vec![UNKNOWN_TOKEN.to_string()]
    }
}
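
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal sketch of the matching rules documented on `tokenize_word`;
    // these tiny vocabularies are invented for illustration.
    #[test]
    fn longest_match_is_preferred() {
        let vocab = BytePairEncoder::new_from_str("ab\t1\nabc\t1\nc\t1").unwrap();
        // "abc" (length 3) is tried before "ab" + "c", so the whole word matches.
        assert_eq!(vocab.tokenize_word("abc"), vec!["abc".to_string()]);
    }

    #[test]
    fn higher_score_wins_among_equal_length_matches() {
        // Both "he" and "lo" are length-2 candidates inside "hello"; "lo" has
        // the higher score, so it is split out first and "hel" is re-tokenized.
        let vocab = BytePairEncoder::new_from_str("he\t1\nlo\t5\nl\t0").unwrap();
        assert_eq!(
            vocab.tokenize_word("hello"),
            vec!["he".to_string(), "l".to_string(), "lo".to_string()]
        );
    }
}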