// aleph_alpha_tokenizer/lib.rs

#![doc(html_root_url = "https://docs.rs/aleph-alpha-tokenizer/0.3.0")]

//! aleph-alpha-tokenizer is a fast word-piece-like tokenizer based on finite
//! state transducers (the `fst` crate).
//!
//! This can be used as a `Model` in huggingface's tokenizers, or standalone.
//!
//! By default, this library builds only the code to be used standalone. Add it
//! to your `Cargo.toml` with the following `[dependencies]` entry:
//!
//! ```toml
//! [dependencies]
//! aleph-alpha-tokenizer = "0.3"
//! ```
//!
//! If you want to use it together with `tokenizers`, you need to enable the
//! `huggingface` feature, so the dependency entry becomes:
//!
//! ```toml
//! [dependencies]
//! aleph-alpha-tokenizer = { version = "0.3", features = ["huggingface"] }
//! ```
//!
//! # Examples
//!
//! To use as a [`Model`](../tokenizers/tokenizer/trait.Model.html), you need
//! to box it:
//!
//! ```
//!# use std::error::Error;
//!
//!# #[cfg(feature = "huggingface")] {
//! use tokenizers::{
//!     tokenizer::{EncodeInput, Model, Tokenizer},
//!     pre_tokenizers::bert::BertPreTokenizer,
//! };
//! use aleph_alpha_tokenizer::AlephAlphaTokenizer;
//!
//! let mut tokenizer = Tokenizer::new(
//!     Box::new(AlephAlphaTokenizer::from_vocab("vocab.txt")?));
//! tokenizer.with_pre_tokenizer(Box::new(BertPreTokenizer));
//! let _result = tokenizer.encode(
//!     EncodeInput::Single("Some Test".to_string()), true)?;
//!# }
//!# Ok::<_, Box<dyn Error + Send + Sync>>(())
//! ```
//!
//! Remember that this depends on the `huggingface` feature. Otherwise, you can
//! use it directly:
//!
//! ```
//!# use std::error::Error;
//! use aleph_alpha_tokenizer::AlephAlphaTokenizer;
//!
//! let source_text = "Ein interessantes Beispiel";
//! let tokenizer = AlephAlphaTokenizer::from_vocab("vocab.txt")?;
//! let mut ids: Vec<i64> = Vec::new();
//! let mut ranges = Vec::new();
//! tokenizer.tokens_into(source_text, &mut ids, &mut ranges, None);
//! for (id, range) in ids.iter().zip(ranges.iter()) {
//!     let _token_source = &source_text[range.clone()];
//!     let _token_text = tokenizer.text_of(*id);
//!     let _is_special = tokenizer.is_special(*id);
//!     // etc.
//! }
//!# Ok::<_, Box<dyn Error + Send + Sync>>(())
//! ```

use fst::raw::{Fst, Output};
use std::error::Error;
use std::fs::File;
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::mem::replace;
use std::ops::Range;
use std::path::PathBuf;

#[cfg(feature = "huggingface")]
use tokenizers::tokenizer::{Model, Token as HfToken};

// TODO: this should be upstreamed into fst
//
// For now, we'll keep it here.
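//
// Walks `input` through the FST one byte at a time, accumulating the output
// value along the transitions. Every time a final state is reached, the byte
// position and accumulated value are remembered, so the result is the longest
// prefix of `input` that is a key in `fst`, together with the value (here:
// the token ID) stored for it.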
#[inline]
fn find_longest_prefix<D: AsRef<[u8]>>(fst: &Fst<D>, input: &[u8]) -> Option<(usize, u64)> {
    let mut node = fst.root();
    let mut out = Output::zero();
    let mut last_match: Option<(usize, Output)> = None;
    for (i, &b) in input.iter().enumerate() {
        if let Some(trans_index) = node.find_input(b) {
            let t = node.transition(trans_index);
            node = fst.node(t.addr);
            out = out.cat(t.out);
            if node.is_final() {
                last_match = Some((i + 1, out.cat(node.final_output())));
            }
        } else {
            break;
        }
    }
    last_match.map(|(i, o)| (i, o.value()))
}

// we use this to calculate offsets in characters instead of bytes
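//
// E.g. `char_offs("äußerst", 0, 0..3)` is `2`: the first three bytes cover
// the two characters "äu".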
fn char_offs(text: &str, last_known_char: usize, range: Range<usize>) -> usize {
    text[range].chars().count() + last_known_char
}

/// A trait to be able to convert token IDs on the fly
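///
/// This lets callers collect token IDs directly into the integer or float
/// type their tensor library expects. A minimal sketch:
///
/// ```
/// use aleph_alpha_tokenizer::TokenID;
///
/// let id: i64 = TokenID::coerce(42u64);
/// assert_eq!(id, 42i64);
/// assert_eq!(id.restore(), 42u64);
/// assert_eq!(<f32 as TokenID>::zero(), 0.0f32);
/// ```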
pub trait TokenID: PartialEq + Clone {
    /// Get a zero value
    fn zero() -> Self;

    /// Convert a `u64` to `Self`
    fn coerce(t: u64) -> Self;

    /// Convert back into `u64`
    fn restore(self) -> u64;
}

impl TokenID for u64 {
    fn zero() -> Self {
        0
    }

    #[inline(always)]
    fn coerce(t: u64) -> Self {
        t
    }

    #[inline(always)]
    fn restore(self) -> u64 {
        self
    }
}

// This can be used in torch Tensors
macro_rules! impl_token_id {
    ($ty:ty, $zero:expr) => {
        impl TokenID for $ty {
            #[inline(always)]
            fn zero() -> Self {
                $zero
            }

            #[inline(always)]
            fn coerce(t: u64) -> Self {
                t as $ty
            }

            #[inline(always)]
            fn restore(self) -> u64 {
                self as u64
            }
        }
    };
}

impl_token_id!(i64, 0);
impl_token_id!(i32, 0);
impl_token_id!(f64, 0.0);
impl_token_id!(f32, 0.0);

/// The Tokenizer. Use [`AlephAlphaTokenizer::from_vocab`] to create an
/// instance.
pub struct AlephAlphaTokenizer {
    tokens: Vec<String>,
    starters: Fst<Vec<u8>>,
    followers: Fst<Vec<u8>>,
    //TODO: perhaps use a SmallVec here
    special_tokens: Vec<u64>,
    unk_id: u32,
    prefix: Option<u32>,
    suffix: Option<u32>,
}

impl AlephAlphaTokenizer {
    /// Creates a tokenizer from the vocabulary.
    ///
    /// For now, we assume the following tokens / IDs:
    ///
    /// * `[CLS]` is classification (and if present is used as prefix)
    /// * `[SEP]` is separator (and if present is used as suffix)
    /// * `[PAD]` is padding and is in position `0`
    /// * `[UNK]` is the *unknown* token specifier
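    ///
    /// The vocabulary file has one token per line; the line number (starting
    /// at `0`) is the token ID, and follower tokens are prefixed with `##`.
    /// A minimal sketch of the layout (not an actual vocabulary):
    ///
    /// ```text
    /// [PAD]
    /// [unused0]
    /// [UNK]
    /// [CLS]
    /// [SEP]
    /// Ein
    /// ##es
    /// ```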
    pub fn from_vocab(path: &str) -> Result<Self, Box<dyn Error + Send + Sync>> {
        let vocab = File::open(path)?;
        let tokens = BufReader::new(vocab)
            .lines()
            .collect::<Result<Vec<String>, std::io::Error>>()?;
        let mut starter: Vec<(Vec<u8>, u64)> = Vec::new();
        let mut follower: Vec<(Vec<u8>, u64)> = Vec::new();
        let mut special_tokens = Vec::new();
        let mut unk_id = None;
        let mut prefix = None;
        let mut suffix = None;
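        // Split the vocabulary into "starter" tokens and "##"-prefixed
        // follower tokens, remembering which IDs are special `[...]` tokens;
        // `[unused...]` placeholders are skipped entirely.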
        for (i, tok) in tokens.iter().enumerate() {
            let token = tok.trim().as_bytes();
            if token.starts_with(b"[") && token.ends_with(b"]") {
                if token.starts_with(b"[unused") {
                    continue;
                }
                if token == b"[UNK]" {
                    unk_id = Some(i as u32);
                } else if token == b"[CLS]" {
                    prefix = Some(i as u32);
                } else if token == b"[SEP]" {
                    suffix = Some(i as u32);
                }
                special_tokens.push(i as u64);
            }
            if token.starts_with(b"##") {
                follower.push((token[2..].to_vec(), i as u64));
            } else {
                starter.push((token.to_vec(), i as u64));
            }
        }
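        // A vocabulary without an `[UNK]` token is unusable; `VarError` is
        // just used as a stand-in error type here.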
        let unk_id = if let Some(u) = unk_id {
            u
        } else {
            return Err(Box::new(std::env::VarError::NotPresent));
        };
        starter.sort_by(|(k, _), (j, _)| k.cmp(j));
        follower.sort_by(|(k, _), (j, _)| k.cmp(j));
        let starters = Fst::from_iter_map(starter)?;
        let followers = Fst::from_iter_map(follower)?;
        Ok(AlephAlphaTokenizer {
            tokens,
            starters,
            followers,
            special_tokens,
            unk_id,
            prefix,
            suffix,
        })
    }

    /// Wraps a UTF-8 byte-range iterator to produce tuples of
    /// (byte range, character range).
    ///
    /// # Examples
    ///
    /// ```
    ///# use aleph_alpha_tokenizer::AlephAlphaTokenizer;
    /// let text = "äußerst";
    /// let ranges = &[0usize..3, 3..7, 7..9];
    /// assert_eq!(&[(0..3, 0..2), (3..7, 2..5), (7..9, 5..7)],
    ///     &AlephAlphaTokenizer::char_ranges(text, ranges.iter()).collect::<Vec<_>>()[..]);
    /// ```
    pub fn char_ranges<'i>(
        text: &'i str,
        ranges: impl Iterator<Item = &'i Range<usize>> + 'i,
    ) -> impl Iterator<Item = (Range<usize>, Range<usize>)> + 'i {
        let (mut last_char, mut last_byte) = (0, 0);
        ranges.map(move |r| {
            let (s, e) = (r.start, r.end);
            let cs = char_offs(text, last_char, last_byte..s);
            last_char = char_offs(text, cs, s..e);
            last_byte = e;
            (r.clone(), cs..last_char)
        })
    }

    #[inline]
    fn add_prefix<T: TokenID>(&self, token_ids: &mut Vec<T>, token_ranges: &mut Vec<Range<usize>>) {
        if let Some(id) = self.prefix {
            token_ids.push(T::coerce(u64::from(id)));
            token_ranges.push(0..0);
        }
    }

    #[inline]
    fn add_suffix<T: TokenID>(&self, token_ids: &mut Vec<T>, token_ranges: &mut Vec<Range<usize>>) {
        if let Some(id) = self.suffix {
            let pos = token_ranges.last().map_or(0, |range| range.end);
            token_ids.push(T::coerce(u64::from(id)));
            token_ranges.push(pos..pos);
        }
    }

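    // Greedily tokenizes a single whitespace-delimited word: first the longest
    // matching starter token, then repeatedly the longest matching follower.
    // If the word cannot be covered completely, everything pushed for it is
    // rolled back and replaced by a single `[UNK]` token spanning the word.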
    fn tokenize_word<T: TokenID>(
        &self,
        text: &str,
        range: Range<usize>,
        token_ids: &mut Vec<T>,
        token_ranges: &mut Vec<Range<usize>>,
    ) {
        let (start, end) = (range.start, range.end);
        let word_index = token_ids.len();
        let mut last_index = start;
        if let Some((len, id)) = find_longest_prefix(&self.starters, text[start..end].as_bytes()) {
            last_index = start + len;
            token_ids.push(T::coerce(id));
            token_ranges.push(start..last_index);
            while last_index < end {
                if let Some((len, id)) =
                    find_longest_prefix(&self.followers, text[last_index..end].as_bytes())
                {
                    let next_index = last_index + len;
                    token_ids.push(T::coerce(id));
                    token_ranges.push(last_index..replace(&mut last_index, next_index));
                } else {
                    break;
                }
            }
        }
        if last_index < end {
            assert!(word_index <= token_ids.len());
            token_ids.truncate(word_index);
            token_ids.push(T::coerce(u64::from(self.unk_id)));
            token_ranges.truncate(word_index);
            token_ranges.push(range);
        }
    }

    /// Tokenizes the given text, writing token IDs into `token_ids` and the
    /// corresponding source byte ranges into `token_ranges`, optionally
    /// filling `words` with ranges of indices into the token arrays that
    /// delimit each word's tokens.
    ///
    /// This works by first splitting on whitespace, then repeatedly taking the
    /// longest prefix found in our token tree (first the starters, then the
    /// followers) until the word is complete, or inserting an `[UNK]` token if
    /// the word couldn't be tokenized completely. This is what wordpiece does,
    /// too.
    ///
    /// Note: The output `Vec`s will be cleared before appending tokens.
    ///
    /// # Examples
    ///
    /// ```
    /// use aleph_alpha_tokenizer::AlephAlphaTokenizer;
    ///
    /// let source_text = "Ein interessantes Beispiel";
    /// let tokenizer = AlephAlphaTokenizer::from_vocab("vocab.txt").unwrap();
    /// let mut ids: Vec<i32> = Vec::new();
    /// let mut ranges = Vec::new();
    /// tokenizer.tokens_into(source_text, &mut ids, &mut ranges, None);
    /// assert_eq!(&[3, 198, 23181, 26902, 2249, 4], &ids[..]);
    /// ```
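    ///
    /// To also find out which tokens belong to which whitespace-separated
    /// word, pass a `Vec` for `words` (a sketch; the IDs and groupings depend
    /// on the vocabulary):
    ///
    /// ```
    /// use aleph_alpha_tokenizer::AlephAlphaTokenizer;
    ///
    /// let tokenizer = AlephAlphaTokenizer::from_vocab("vocab.txt").unwrap();
    /// let mut ids: Vec<i64> = Vec::new();
    /// let mut ranges = Vec::new();
    /// let mut words = Vec::new();
    /// tokenizer.tokens_into("Ein Beispiel", &mut ids, &mut ranges, Some(&mut words));
    /// for word in &words {
    ///     // each entry is a range of indices into `ids` / `ranges` covering
    ///     // the tokens of one word (prefix/suffix tokens are not included)
    ///     let _word_token_ids = &ids[word.clone()];
    /// }
    /// ```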
    pub fn tokens_into<T: TokenID>(
        &self,
        text: &str,
        token_ids: &mut Vec<T>,
        token_ranges: &mut Vec<Range<usize>>,
        words: Option<&mut Vec<Range<usize>>>,
    ) {
        token_ids.clear();
        token_ranges.clear();
        let text_len = text.len();
        let mut words = words;
        if let Some(w) = words.as_mut() {
            w.clear();
        }
        let mut last_offs = 0;
        self.add_prefix(token_ids, token_ranges);
        let mut last_token = token_ids.len();
        //TODO: there may be a faster version of this using SIMD
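        // Walk the text word by word: tokenize everything up to the next
        // whitespace character, then skip over the whitespace run.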
        while let Some(next_ws) = text[last_offs..].find(char::is_whitespace) {
            if next_ws != 0 {
                self.tokenize_word(
                    text,
                    last_offs..last_offs + next_ws,
                    token_ids,
                    token_ranges,
                );
                if let Some(w) = words.as_mut() {
                    w.push(last_token..replace(&mut last_token, token_ids.len()));
                }
            }
            last_offs += next_ws;
            last_offs += text[last_offs..].chars().next().unwrap_or('\0').len_utf8();
            if let Some(non_ws) = text[last_offs..].find(|c: char| !c.is_whitespace()) {
                last_offs += non_ws;
            }
        }
        if last_offs < text_len {
            self.tokenize_word(text, last_offs..text_len, token_ids, token_ranges);
            // also record the final word, which has no trailing whitespace
            if let Some(w) = words.as_mut() {
                w.push(last_token..token_ids.len());
            }
        }
        self.add_suffix(token_ids, token_ranges);
    }

    /// Gets the text of this token.
    ///
    /// # Examples
    ///
    /// ```
    /// use aleph_alpha_tokenizer::AlephAlphaTokenizer;
    /// let tokenizer = AlephAlphaTokenizer::from_vocab("vocab.txt").unwrap();
    ///
    /// assert_eq!("[PAD]", tokenizer.text_of(0));
    /// ```
    #[inline]
    pub fn text_of<T: TokenID>(&self, token_id: T) -> &str {
        &self.tokens[token_id.restore() as usize]
    }

    /// Gets the texts of the tokens.
    ///
    /// # Examples
    ///
    /// ```
    /// use aleph_alpha_tokenizer::AlephAlphaTokenizer;
    /// let tokenizer = AlephAlphaTokenizer::from_vocab("vocab.txt").unwrap();
    ///
    /// assert_eq!(
    ///     vec!["[CLS]", "Super", "[SEP]"],
    ///     tokenizer.texts_of(&[3, 4285, 4])
    /// );
    /// ```
    pub fn texts_of<'t, T: TokenID>(&'t self, token_ids: &[T]) -> Vec<&'t str> {
        token_ids
            .iter()
            .cloned()
            .map(|id| self.text_of(id))
            .collect()
    }

    /// Determines whether this token is a special token.
    ///
    /// Special tokens are e.g. `[CLS]`, `[SEP]`, `[PAD]` or `[UNK]`.
    ///
    /// # Examples
    ///
    /// ```
    /// use aleph_alpha_tokenizer::AlephAlphaTokenizer;
    /// let tokenizer = AlephAlphaTokenizer::from_vocab("vocab.txt").unwrap();
    ///
    /// assert!(tokenizer.is_special(0i32)); // [PAD]
    /// assert!(tokenizer.is_special(3i32)); // [CLS]
    /// assert!(tokenizer.is_special(4i32)); // [SEP]
    /// assert!(!tokenizer.is_special(42i32));
    /// ```
    #[inline]
    pub fn is_special<T: TokenID>(&self, token_id: T) -> bool {
        self.special_tokens.contains(&token_id.restore())
    }

    /// Calculates the attention for this token: zero for padding (token ID
    /// `0`), one for everything else.
    ///
    /// # Examples
    ///
    /// ```
    /// use aleph_alpha_tokenizer::AlephAlphaTokenizer;
    ///
    /// let pad_attention: i64 = AlephAlphaTokenizer::attention(0u64);
    /// let token_attention: f64 = AlephAlphaTokenizer::attention(99i32);
    /// assert_eq!(pad_attention, 0);
    /// assert_eq!(token_attention, 1.0f64);
    /// ```
    #[inline]
    pub fn attention<T: TokenID, U: TokenID>(token_id: T) -> U {
        if token_id == T::zero() {
            U::zero()
        } else {
            U::coerce(1)
        }
    }

    /// Given a slice of token IDs, writes the corresponding attentions into
    /// the given `Vec` (clearing it first).
    ///
    /// # Examples
    ///
    /// ```
    /// use aleph_alpha_tokenizer::AlephAlphaTokenizer;
    ///
    /// let mut attns: Vec<i32> = Vec::new();
    /// AlephAlphaTokenizer::attentions_into(&[3, 4285, 4, 0, 0], &mut attns);
    /// assert_eq!(&attns[..], &[1, 1, 1, 0, 0]);
    /// ```
    pub fn attentions_into<T: TokenID, U: TokenID>(token_ids: &[T], attns: &mut Vec<U>) {
        attns.clear();
        attns.extend(
            token_ids
                .iter()
                .cloned()
                .map(AlephAlphaTokenizer::attention),
        );
    }

    /// Saves the vocabulary back to a file, one token per line in ID order.
    pub fn save_vocab(&self, vocab_path: PathBuf) -> Result<PathBuf, Box<dyn Error + Send + Sync>> {
        let vocab = File::create(&vocab_path)?;
        let mut vocab_writer = BufWriter::new(vocab);
        for token in &self.tokens {
            writeln!(vocab_writer, "{}", token)?;
        }
        //TODO: write out FSTs to reduce load time
        Ok(vocab_path)
    }
}

#[cfg(feature = "huggingface")]
use std::{borrow::Cow, path::Path};

/// This type implements the [`Model`] trait so you can use it within
/// huggingface's tokenizers framework.
#[cfg(feature = "huggingface")]
impl Model for AlephAlphaTokenizer {
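    // Mirrors `tokenize_word` above, but produces huggingface `Token`s whose
    // offsets are counted in characters relative to each pre-tokenized word's
    // offsets.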
    fn tokenize(
        &self,
        tokens: Vec<(String, (usize, usize))>,
    ) -> Result<Vec<HfToken>, Box<dyn Error + Send + Sync>> {
        // we expect at least one token per word.
        let mut result = Vec::with_capacity(tokens.len());
        for (index, (word_str, offsets)) in tokens.into_iter().enumerate() {
            let word = index as u32;
            let word_index = result.len();
            let word_bytes = word_str.as_bytes();
            let word_len = word_bytes.len();
            let mut last_index = 0;
            if let Some((start_index, id)) = find_longest_prefix(&self.starters, word_bytes) {
                let value = word_str[..start_index].to_string();
                let mut last_offset = offsets.0 + value.chars().count();
                result.push(HfToken {
                    id: id as u32,
                    value,
                    offsets: (offsets.0, last_offset),
                    word,
                });
                last_index = start_index;
                while last_index < word_len {
                    if let Some((len, id)) =
                        find_longest_prefix(&self.followers, &word_bytes[last_index..])
                    {
                        let value = &word_str[last_index..last_index + len];
                        let start = last_offset;
                        last_offset += value.chars().count();
                        result.push(HfToken {
                            id: id as u32,
                            value: "##".to_string() + value,
                            offsets: (start, last_offset),
                            word,
                        });
                        last_index += len;
                    } else {
                        break;
                    }
                }
            }
            // in case we couldn't match the whole word, replace all we have so far with an [UNK] token
            if last_index < word_len {
                assert!(word_index <= result.len());
                result.truncate(word_index);
                result.push(HfToken {
                    id: self.unk_id,
                    value: "[UNK]".to_string(),
                    offsets: (offsets.0, offsets.1),
                    word,
                });
            }
        }
        Ok(result)
    }

    fn token_to_id(&self, token: &str) -> Option<u32> {
        if token.starts_with("##") {
            self.followers.get(&token[2..])
        } else {
            self.starters.get(token)
        }
        .map(|x| x.value() as u32)
    }

    fn id_to_token(&self, id: u32) -> Option<String> {
        self.tokens.get(id as usize).cloned()
    }

    fn get_vocab_size(&self) -> usize {
        self.tokens.len()
    }

    /// We won't implement this method because we don't store the tokens in
    /// a `HashMap`, and doing so would increase our memory footprint
    /// considerably.
    fn get_vocab(&self) -> &std::collections::HashMap<String, u32> {
        unimplemented!()
    }

    fn save(
        &self,
        folder: &Path,
        name: Option<&str>,
    ) -> Result<Vec<PathBuf>, Box<dyn Error + Send + Sync>> {
        let vocab_name = name.map_or(Cow::Borrowed("vocab.txt"), |n| {
            Cow::Borrowed(n) + "-vocab.txt"
        });
        let mut vocab_path = folder.to_path_buf();
        vocab_path.push(Path::new(vocab_name.as_ref()));
        self.save_vocab(vocab_path).map(|p| vec![p])
    }
}