aleph_alpha_tokenizer/lib.rs
#![doc(html_root_url = "https://docs.rs/aleph-alpha-tokenizer/0.3.0")]

//! aleph-alpha-tokenizer is a fast word-piece-like tokenizer based on fst
//!
//! This can be used as a `Model` in huggingface's tokenizers, or standalone.
//!
//! By default, this library builds only the code to be used standalone. Add it
//! to your `Cargo.toml` with the following `[dependencies]` entry:
//!
//! ```toml
//! [dependencies]
//! aleph-alpha-tokenizer = "0.3"
//! ```
//!
//! If you want to use it together with `tokenizers`, you need to enable the
//! `huggingface` feature, so the dependency entry becomes:
//!
//! ```toml
//! [dependencies]
//! aleph-alpha-tokenizer = { version = "0.3", features = ["huggingface"] }
//! ```
//!
//! # Examples
//!
//! To use as a [`Model`](../tokenizers/tokenizer/trait.Model.html), you need
//! to box it:
//!
//! ```
//!# use std::error::Error;
//!
//!# #[cfg(feature = "huggingface")] {
//! use tokenizers::{
//!     tokenizer::{EncodeInput, Model, Tokenizer},
//!     pre_tokenizers::bert::BertPreTokenizer,
//! };
//! use aleph_alpha_tokenizer::AlephAlphaTokenizer;
//!
//! let mut tokenizer = Tokenizer::new(
//!     Box::new(AlephAlphaTokenizer::from_vocab("vocab.txt")?));
//! tokenizer.with_pre_tokenizer(Box::new(BertPreTokenizer));
//! let _result = tokenizer.encode(
//!     EncodeInput::Single("Some Test".to_string()), true)?;
//!# }
//!# Ok::<_, Box<dyn Error + Send + Sync>>(())
//! ```
//!
//! Remember this depends on the `huggingface` feature. Otherwise, you can use
//! it directly:
//!
//! ```
//!# use std::error::Error;
//! use aleph_alpha_tokenizer::AlephAlphaTokenizer;
//!
//! let source_text = "Ein interessantes Beispiel";
//! let tokenizer = AlephAlphaTokenizer::from_vocab("vocab.txt")?;
//! let mut ids: Vec<i64> = Vec::new();
//! let mut ranges = Vec::new();
//! tokenizer.tokens_into(source_text, &mut ids, &mut ranges, None);
//! for (id, range) in ids.iter().zip(ranges.iter()) {
//!     let _token_source = &source_text[range.clone()];
//!     let _token_text = tokenizer.text_of(*id);
//!     let _is_special = tokenizer.is_special(*id);
//!     // etc.
//! }
//!# Ok::<_, Box<dyn Error + Send + Sync>>(())
//! ```

use fst::raw::{Fst, Output};
use std::error::Error;
use std::fs::File;
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::mem::replace;
use std::ops::Range;
use std::path::PathBuf;

#[cfg(feature = "huggingface")]
use tokenizers::tokenizer::{Model, Token as HfToken};

// TODO: this should be upstreamed into fst
//
// For now, we'll keep it here.
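//
// Walks the FST byte by byte and returns the byte length and output value of
// the longest prefix of `input` that is a complete key in `fst`, if any.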
#[inline]
fn find_longest_prefix<D: AsRef<[u8]>>(fst: &Fst<D>, input: &[u8]) -> Option<(usize, u64)> {
    let mut node = fst.root();
    let mut out = Output::zero();
    let mut last_match: Option<(usize, Output)> = None;
    for (i, &b) in input.iter().enumerate() {
        if let Some(trans_index) = node.find_input(b) {
            let t = node.transition(trans_index);
            node = fst.node(t.addr);
            out = out.cat(t.out);
            if node.is_final() {
                last_match = Some((i + 1, out.cat(node.final_output())));
            }
        } else {
            break;
        }
    }
    last_match.map(|(i, o)| (i, o.value()))
}

// we use this to calculate offsets in characters instead of bytes
fn char_offs(text: &str, last_known_char: usize, range: Range<usize>) -> usize {
    text[range].chars().count() + last_known_char
}

/// A trait to be able to convert token IDs on the fly
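///
/// For example, converting a raw `u64` ID into `i64` (one of the provided
/// impls) and back:
///
/// ```
/// use aleph_alpha_tokenizer::TokenID;
///
/// let id: i64 = TokenID::coerce(42u64);
/// assert_eq!(42u64, id.restore());
/// ```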
pub trait TokenID: PartialEq + Clone {
    /// Get a zero value
    fn zero() -> Self;

    /// Convert a `u64` to `Self`
    fn coerce(t: u64) -> Self;

    /// Convert back into `u64`
    fn restore(self) -> u64;
}

impl TokenID for u64 {
    fn zero() -> Self {
        0
    }

    #[inline(always)]
    fn coerce(t: u64) -> Self {
        t
    }

    #[inline(always)]
    fn restore(self) -> u64 {
        self
    }
}

// This can be used in torch Tensors
macro_rules! impl_token_id {
    ($ty:ty, $zero:expr) => {
        impl TokenID for $ty {
            #[inline(always)]
            fn zero() -> Self {
                $zero
            }

            #[inline(always)]
            fn coerce(t: u64) -> Self {
                t as $ty
            }

            #[inline(always)]
            fn restore(self) -> u64 {
                self as u64
            }
        }
    };
}

impl_token_id!(i64, 0);
impl_token_id!(i32, 0);
impl_token_id!(f64, 0.0);
impl_token_id!(f32, 0.0);

/// The Tokenizer. Use [`AlephAlphaTokenizer::from_vocab`] to create an
/// instance.
pub struct AlephAlphaTokenizer {
    tokens: Vec<String>,
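    // `starters` holds word-initial pieces, `followers` the `##`-prefixed
    // continuation pieces (stored without their `##`); both map bytes to IDs.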
    starters: Fst<Vec<u8>>,
    followers: Fst<Vec<u8>>,
    //TODO: perhaps use a SmallVec here
    special_tokens: Vec<u64>,
    unk_id: u32,
    prefix: Option<u32>,
    suffix: Option<u32>,
}

impl AlephAlphaTokenizer {
    /// Creates a tokenizer from the vocabulary.
    ///
    /// For now, we assume the following tokens / IDs:
    ///
    /// * `[CLS]` is classification (and if present is used as prefix)
    /// * `[SEP]` is separator (and if present is used as suffix)
    /// * `[PAD]` is padding and is in position `0`
    /// * `[UNK]` is the *unknown* token specifier
    pub fn from_vocab(path: &str) -> Result<Self, Box<dyn Error + Send + Sync>> {
        let vocab = File::open(path)?;
        let tokens = BufReader::new(vocab)
            .lines()
            .collect::<Result<Vec<String>, std::io::Error>>()?;
        let mut starter: Vec<(Vec<u8>, u64)> = Vec::new();
        let mut follower: Vec<(Vec<u8>, u64)> = Vec::new();
        let mut special_tokens = Vec::new();
        let mut unk_id = None;
        let mut prefix = None;
        let mut suffix = None;
        for (i, tok) in tokens.iter().enumerate() {
            let token = tok.trim().as_bytes();
            if token.starts_with(b"[") && token.ends_with(b"]") {
                if token.starts_with(b"[unused") {
                    continue;
                }
                if token == b"[UNK]" {
                    unk_id = Some(i as u32);
                } else if token == b"[CLS]" {
                    prefix = Some(i as u32);
                } else if token == b"[SEP]" {
                    suffix = Some(i as u32);
                }
                special_tokens.push(i as u64);
            }
            if token.starts_with(b"##") {
                follower.push((token[2..].to_vec(), i as u64));
            } else {
                starter.push((token.to_vec(), i as u64));
            }
        }
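        // A vocabulary without an `[UNK]` token is unusable; `VarError::NotPresent`
        // only serves as a readily available placeholder error type here.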
        let unk_id = if let Some(u) = unk_id {
            u
        } else {
            return Err(Box::new(std::env::VarError::NotPresent));
        };
        starter.sort_by(|(k, _), (j, _)| k.cmp(j));
        follower.sort_by(|(k, _), (j, _)| k.cmp(j));
        let starters = Fst::from_iter_map(starter)?;
        let followers = Fst::from_iter_map(follower)?;
        Ok(AlephAlphaTokenizer {
            tokens,
            starters,
            followers,
            special_tokens,
            unk_id,
            prefix,
            suffix,
        })
    }

    /// Wraps a UTF-8 byte range iterator to produce tuples of (byte range, character range).
    ///
    /// # Examples
    ///
    /// ```
    ///# use aleph_alpha_tokenizer::AlephAlphaTokenizer;
    /// let text = "äußerst";
    /// let ranges = &[0usize..3, 3..7, 7..9];
    /// assert_eq!(&[(0..3, 0..2), (3..7, 2..5), (7..9, 5..7)],
    ///     &AlephAlphaTokenizer::char_ranges(text, ranges.iter()).collect::<Vec<_>>()[..]);
    /// ```
    pub fn char_ranges<'i>(
        text: &'i str,
        ranges: impl Iterator<Item = &'i Range<usize>> + 'i,
    ) -> impl Iterator<Item = (Range<usize>, Range<usize>)> + 'i {
        let (mut last_char, mut last_byte) = (0, 0);
        ranges.map(move |r| {
            let (s, e) = (r.start, r.end);
            let cs = char_offs(text, last_char, last_byte..s);
            last_char = char_offs(text, cs, s..e);
            last_byte = e;
            (r.clone(), cs..last_char)
        })
    }

    #[inline]
    fn add_prefix<T: TokenID>(&self, token_ids: &mut Vec<T>, token_ranges: &mut Vec<Range<usize>>) {
        if let Some(id) = self.prefix {
            token_ids.push(T::coerce(u64::from(id)));
            token_ranges.push(0..0);
        }
    }

    #[inline]
    fn add_suffix<T: TokenID>(&self, token_ids: &mut Vec<T>, token_ranges: &mut Vec<Range<usize>>) {
        if let Some(id) = self.suffix {
            let pos = token_ranges.last().map_or(0, |range| range.end);
            token_ids.push(T::coerce(u64::from(id)));
            token_ranges.push(pos..pos);
        }
    }

    fn tokenize_word<T: TokenID>(
        &self,
        text: &str,
        range: Range<usize>,
        token_ids: &mut Vec<T>,
        token_ranges: &mut Vec<Range<usize>>,
    ) {
        let (start, end) = (range.start, range.end);
        let word_index = token_ids.len();
        let mut last_index = start;
        if let Some((len, id)) = find_longest_prefix(&self.starters, text[start..end].as_bytes()) {
            last_index = start + len;
            token_ids.push(T::coerce(id));
            token_ranges.push(start..last_index);
            while last_index < end {
                if let Some((len, id)) =
                    find_longest_prefix(&self.followers, text[last_index..end].as_bytes())
                {
                    let next_index = last_index + len;
                    token_ids.push(T::coerce(id));
                    token_ranges.push(last_index..replace(&mut last_index, next_index));
                } else {
                    break;
                }
            }
        }
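        // if the word could not be matched completely, fall back to a single
        // [UNK] token covering the whole word (this is what wordpiece does, too)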
        if last_index < end {
            assert!(word_index <= token_ids.len());
            token_ids.truncate(word_index);
            token_ids.push(T::coerce(u64::from(self.unk_id)));
            token_ranges.truncate(word_index);
            token_ranges.push(range);
        }
    }

    /// Tokenizes the given text, writing the token IDs into `token_ids` and
    /// the corresponding source byte ranges into `token_ranges`. Optionally, a
    /// `words` `Vec` can be supplied, which will be filled with one range per
    /// word, indexing into the token arrays.
    ///
    /// This works by first splitting by whitespace, then gathering the longest
    /// prefix in our token tree (first the starters, then the followers) until
    /// the word is complete, or inserting a `[UNK]` token if the word couldn't
    /// fully be tokenized. This is what wordpiece does, too.
    ///
    /// Note: The output `Vec`s will be cleared before appending tokens.
    ///
    /// # Examples
    ///
    /// ```
    /// use aleph_alpha_tokenizer::AlephAlphaTokenizer;
    ///
    /// let source_text = "Ein interessantes Beispiel";
    /// let tokenizer = AlephAlphaTokenizer::from_vocab("vocab.txt").unwrap();
    /// let mut ids: Vec<i32> = Vec::new();
    /// let mut ranges = Vec::new();
    /// tokenizer.tokens_into(source_text, &mut ids, &mut ranges, None);
    /// assert_eq!(&[3, 198, 23181, 26902, 2249, 4], &ids[..]);
    /// ```
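    ///
    /// If you also need to know which tokens belong to which word, pass a
    /// `Vec` for `words`; each entry is a range of indices into `ids` /
    /// `ranges`. A minimal sketch (assuming, as above, a `vocab.txt` in the
    /// working directory):
    ///
    /// ```
    /// use aleph_alpha_tokenizer::AlephAlphaTokenizer;
    ///
    /// let tokenizer = AlephAlphaTokenizer::from_vocab("vocab.txt").unwrap();
    /// let mut ids: Vec<i64> = Vec::new();
    /// let mut ranges = Vec::new();
    /// let mut words = Vec::new();
    /// tokenizer.tokens_into("Ein interessantes Beispiel", &mut ids, &mut ranges, Some(&mut words));
    /// for word in &words {
    ///     // the tokens making up this word
    ///     let _word_token_ids = &ids[word.clone()];
    /// }
    /// ```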
    pub fn tokens_into<T: TokenID>(
        &self,
        text: &str,
        token_ids: &mut Vec<T>,
        token_ranges: &mut Vec<Range<usize>>,
        words: Option<&mut Vec<Range<usize>>>,
    ) {
        token_ids.clear();
        token_ranges.clear();
        let text_len = text.len();
        let mut words = words;
        if let Some(w) = words.as_mut() {
            w.clear();
        }
        let mut last_offs = 0;
        self.add_prefix(token_ids, token_ranges);
        let mut last_token = token_ids.len();
        //TODO: there may be a faster version of this using SIMD
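        // scan for the next whitespace, tokenize the word before it, then skip
        // the whitespace run before looking for the next word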
        while let Some(next_ws) = text[last_offs..].find(char::is_whitespace) {
            if next_ws != 0 {
                self.tokenize_word(
                    text,
                    last_offs..last_offs + next_ws,
                    token_ids,
                    token_ranges,
                );
                if let Some(w) = words.as_mut() {
                    w.push(last_token..replace(&mut last_token, token_ids.len()));
                }
            }
            last_offs += next_ws;
            last_offs += text[last_offs..].chars().next().unwrap_or('\0').len_utf8();
            if let Some(non_ws) = text[last_offs..].find(|c: char| !c.is_whitespace()) {
                last_offs += non_ws;
            }
        }
        if last_offs < text_len {
            self.tokenize_word(text, last_offs..text_len, token_ids, token_ranges);
            // also record the trailing word's token indices
            if let Some(w) = words.as_mut() {
                w.push(last_token..token_ids.len());
            }
        }
        self.add_suffix(token_ids, token_ranges);
    }

    /// Gets the text of this token.
    ///
    /// # Examples
    ///
    /// ```
    /// use aleph_alpha_tokenizer::AlephAlphaTokenizer;
    /// let tokenizer = AlephAlphaTokenizer::from_vocab("vocab.txt").unwrap();
    ///
    /// assert_eq!("[PAD]", tokenizer.text_of(0));
    /// ```
    #[inline]
    pub fn text_of<T: TokenID>(&self, token_id: T) -> &str {
        &self.tokens[token_id.restore() as usize]
    }

    /// Gets the texts of the tokens.
    ///
    /// # Examples
    ///
    /// ```
    /// use aleph_alpha_tokenizer::AlephAlphaTokenizer;
    /// let tokenizer = AlephAlphaTokenizer::from_vocab("vocab.txt").unwrap();
    ///
    /// assert_eq!(
    ///     vec!["[CLS]", "Super", "[SEP]"],
    ///     tokenizer.texts_of(&[3, 4285, 4])
    /// );
    /// ```
    pub fn texts_of<'t, T: TokenID>(&'t self, token_ids: &[T]) -> Vec<&'t str> {
        token_ids
            .iter()
            .cloned()
            .map(|id| self.text_of(id))
            .collect()
    }

    /// Determines whether this token is a special token.
    ///
    /// Special tokens are e.g. `[CLS]`, `[SEP]`, `[PAD]` or `[UNK]`.
    ///
    /// # Examples
    ///
    /// ```
    /// use aleph_alpha_tokenizer::AlephAlphaTokenizer;
    /// let tokenizer = AlephAlphaTokenizer::from_vocab("vocab.txt").unwrap();
    ///
    /// assert!(tokenizer.is_special(0i32)); // [PAD]
    /// assert!(tokenizer.is_special(3i32)); // [CLS]
    /// assert!(tokenizer.is_special(4i32)); // [SEP]
    /// assert!(!tokenizer.is_special(42i32));
    /// ```
    #[inline]
    pub fn is_special<T: TokenID>(&self, token_id: T) -> bool {
        self.special_tokens.contains(&token_id.restore())
    }

    /// Calculates the required attention for this token.
    ///
    /// # Examples
    ///
    /// ```
    /// use aleph_alpha_tokenizer::AlephAlphaTokenizer;
    ///
    /// let pad_attention: i64 = AlephAlphaTokenizer::attention(0u64);
    /// let token_attention: f64 = AlephAlphaTokenizer::attention(99i32);
    /// assert_eq!(pad_attention, 0);
    /// assert_eq!(token_attention, 1.0f64);
    /// ```
    #[inline]
    pub fn attention<T: TokenID, U: TokenID>(token_id: T) -> U {
        if token_id == T::zero() {
            U::zero()
        } else {
            U::coerce(1)
        }
    }

    /// Given a slice of token IDs, fills the given `Vec` with the
    /// corresponding attention values (clearing it first).
    ///
    /// # Examples
    ///
    /// ```
    /// use aleph_alpha_tokenizer::AlephAlphaTokenizer;
    ///
    /// let mut attns: Vec<i32> = Vec::new();
    /// AlephAlphaTokenizer::attentions_into(&[3, 4285, 4, 0, 0], &mut attns);
    /// assert_eq!(&attns[..], &[1, 1, 1, 0, 0]);
    /// ```
    pub fn attentions_into<T: TokenID, U: TokenID>(token_ids: &[T], attns: &mut Vec<U>) {
        attns.clear();
        attns.extend(
            token_ids
                .iter()
                .cloned()
                .map(AlephAlphaTokenizer::attention),
        );
    }

    /// Save the vocabulary back to a file.
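    ///
    /// For example (not run here; `my-vocab.txt` is just an illustrative path):
    ///
    /// ```no_run
    /// use std::path::PathBuf;
    /// use aleph_alpha_tokenizer::AlephAlphaTokenizer;
    ///
    /// let tokenizer = AlephAlphaTokenizer::from_vocab("vocab.txt").unwrap();
    /// let written = tokenizer.save_vocab(PathBuf::from("my-vocab.txt")).unwrap();
    /// assert_eq!(PathBuf::from("my-vocab.txt"), written);
    /// ```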
    pub fn save_vocab(&self, vocab_path: PathBuf) -> Result<PathBuf, Box<dyn Error + Send + Sync>> {
        let vocab = File::create(&vocab_path)?;
        let mut vocab_writer = BufWriter::new(vocab);
        for token in &self.tokens {
            writeln!(vocab_writer, "{}", token)?;
        }
        //TODO: write out FSTs to reduce load time
        Ok(vocab_path)
    }
}

#[cfg(feature = "huggingface")]
use std::{borrow::Cow, path::Path};

/// This type implements the [`Model`] trait so you can use it within
/// huggingface's tokenizers framework.
#[cfg(feature = "huggingface")]
impl Model for AlephAlphaTokenizer {
    fn tokenize(
        &self,
        tokens: Vec<(String, (usize, usize))>,
    ) -> Result<Vec<HfToken>, Box<dyn Error + Send + Sync>> {
        // we expect at least one token per word.
        let mut result = Vec::with_capacity(tokens.len());
        for (index, (word_str, offsets)) in tokens.into_iter().enumerate() {
            let word = index as u32;
            let word_index = result.len();
            let word_bytes = word_str.as_bytes();
            let word_len = word_bytes.len();
            let mut last_index = 0;
            if let Some((start_index, id)) = find_longest_prefix(&self.starters, word_bytes) {
                let value = word_str[..start_index].to_string();
                let mut last_offset = offsets.0 + value.chars().count();
                result.push(HfToken {
                    id: id as u32,
                    value,
                    offsets: (offsets.0, last_offset),
                    word,
                });
                last_index = start_index;
                while last_index < word_len {
                    if let Some((len, id)) =
                        find_longest_prefix(&self.followers, &word_bytes[last_index..])
                    {
                        let value = &word_str[last_index..last_index + len];
                        let start = last_offset;
                        last_offset += value.chars().count();
                        result.push(HfToken {
                            id: id as u32,
                            value: "##".to_string() + value,
                            offsets: (start, last_offset),
                            word,
                        });
                        last_index += len;
                    } else {
                        break;
                    }
                }
            }
            // in case we couldn't match the whole word, replace all we have so far with an [UNK] token
            if last_index < word_len {
                assert!(word_index <= result.len());
                result.truncate(word_index);
                result.push(HfToken {
                    id: self.unk_id,
                    value: "[UNK]".to_string(),
                    offsets: (offsets.0, offsets.1),
                    word,
                });
            }
        }
        Ok(result)
    }

    fn token_to_id(&self, token: &str) -> Option<u32> {
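        // follower pieces are stored in the FST without their `##` prefix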
        if token.starts_with("##") {
            self.followers.get(&token[2..])
        } else {
            self.starters.get(token)
        }
        .map(|x| x.value() as u32)
    }

    fn id_to_token(&self, id: u32) -> Option<String> {
        self.tokens.get(id as usize).cloned()
    }

    fn get_vocab_size(&self) -> usize {
        self.tokens.len()
    }

    /// We won't implement this method because we don't store the tokens in
    /// a `HashMap`, and doing so would increase our memory footprint
    /// considerably.
    fn get_vocab(&self) -> &std::collections::HashMap<String, u32> {
        unimplemented!()
    }

    fn save(
        &self,
        folder: &Path,
        name: Option<&str>,
    ) -> Result<Vec<PathBuf>, Box<dyn Error + Send + Sync>> {
        let vocab_name = name.map_or(Cow::Borrowed("vocab.txt"), |n| {
            Cow::Borrowed(n) + "-vocab.txt"
        });
        let mut vocab_path = folder.to_path_buf();
        vocab_path.push(Path::new(vocab_name.as_ref()));
        self.save_vocab(vocab_path).map(|p| vec![p])
    }
}