toktkn 0.1.0

a minimal byte-pair encoding tokenizer implementation
Documentation
use rayon::prelude::*;
use rustc_hash::{FxBuildHasher, FxHashMap};
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, DisplayFromStr};
use std::sync::RwLock;
use std::str;

use crate::config::TokenizerConfig;
use crate::util::{inject_special_tokens, ngram_replace, replace_special_tokens};


pub type Token = u32; // 2^32 - 1 max new tokens

// map aliases
pub type FwdMap = FxHashMap<(Token, Token), Token>;
pub type BkwdMap = FxHashMap<Token, (Token, Token)>;
pub type VocabMap = FxHashMap<String, Token>;

pub trait Tokenizer {
    fn encode(&self, text: &str) -> Vec<Token>;
    fn decode(&self, input_ids: &[Token]) -> String;
}

#[serde_as]
#[derive(Serialize, Deserialize)]
pub struct BPETokenizer {
    #[serde_as(as = "Vec<((DisplayFromStr, DisplayFromStr), DisplayFromStr)>")]
    pub encoder: FwdMap,
    #[serde(skip)]
    pub decoder: RwLock<Option<BkwdMap>>, // thread-safe & nullable
    pub config: TokenizerConfig,
}

impl Tokenizer for BPETokenizer {
    fn encode(&self, text: &str) -> Vec<Token> {
        // parallel
        text.as_bytes()
            .par_chunks(1024)
            .flat_map_iter(|c| {
                let chunk_as_tokens: Vec<Token> = c.iter().map(|&i| i as Token).collect();
                self._encode_chunk(&chunk_as_tokens)
            })
            .collect()
    }

    fn decode(&self, input_ids: &[Token]) -> String {
        // first pass
        let raw_tokens: Vec<Token> = input_ids
            .par_chunks(1024)
            .flat_map_iter(|t| self._decode_chunk(t))
            .collect();

        let bytes: Vec<u8> = raw_tokens.iter().map(|&t| t as u8).collect();

        let s = str::from_utf8(&bytes)
            .unwrap_or_else(|_| panic!("failed to decode into valid utf-8: {:?}", bytes));
        s.into()
    }
}

impl BPETokenizer {
    pub fn new(config: TokenizerConfig) -> Self {
        Self {
            encoder: FwdMap::default(),
            decoder: RwLock::new(None),
            config,
        }
    }

    pub fn len(&self) -> usize {
        match self.config.special_tokens_map.as_ref() {
            Some(map) => map.len() + self.encoder.len(),
            None => self.encoder.len(),
        }
    }

    fn _sync_decoder(&self) {
        let mut inner = self.decoder.write().unwrap();
        inner.replace(
            self.encoder
                .iter()
                .map(|(&k, &v)| (v, k))
                .collect::<BkwdMap>(),
        );
    }

    pub fn add_special_tokens<S: Into<String>>(&mut self, tokens: Vec<S>) {
        let token_id = self.len() + 128;
        let token_map: VocabMap = tokens
            .into_iter()
            .enumerate()
            .map(|(e, s)| (s.into(), (token_id + e) as Token))
            .collect();

        self.config.special_tokens_map =
            self.config
                .special_tokens_map
                .take()
                .map_or(Some(token_map.clone()), |mut m| {
                    m.extend(token_map);
                    Some(m)
                });
    }

    pub fn preprocess(&self, text: &mut String) {
        let preproc = self.config.preproc.into_strategy();
        preproc.normalize(text);
    }

    fn _encode_chunk(&self, chunk: &[Token]) -> Vec<Token> {
        let mut tokens = chunk.to_vec().clone();

        if let Some(map) = self.config.special_tokens_map.as_ref() {
            replace_special_tokens::<Token, FxBuildHasher>(&mut tokens, map);
        }

        loop {
            let mut merges = Vec::new();

            for i in 0..tokens.len() - 1 {
                if let Some(&new_token) = self.encoder.get(&(tokens[i], tokens[i + 1])) {
                    merges.push((i, new_token));
                }
            }
            // early stopping: no more token pairs in merge rules
            if merges.is_empty() {
                break;
            }

            // apply merges and swap in tokens in reverse
            let mut i = merges.len() - 1;

            while i > 0 {
                let x = &mut merges[i - 1..=i];
                let mut l = x[0];
                let r = x[1];

                if r.0 - l.0 > 1 && r.1 != Token::MAX {
                    tokens[r.0] = r.1;
                    tokens.remove(r.0 + 1);
                } else if r.1 < l.1 {
                    tokens[r.0] = r.1;
                    tokens.remove(r.0 + 1);

                    l.1 = Token::MAX;
                    i -= 1;
                }

                //avoid overflow on usize 0-1
                if i == 0 {
                    break;
                }
                i -= 1;
            }

            // edge case
            if merges.len() == 1 || merges[0].1 < merges[1].1 {
                tokens[merges[0].0] = merges[0].1;
                tokens.remove(merges[0].0 + 1);
            }
        }
        tokens
    }

    fn _decode_chunk(&self, tokens: &[Token]) -> Vec<Token> {
        let mut tokens: Vec<Token> = Vec::from(tokens);

        // lazy init
        self._sync_decoder();
        let lock = self.decoder.read().expect("could not acquire lock");

        let decoder = lock.as_ref().unwrap();

        loop {
            let mut demerges = Vec::new();
            for i in 0..tokens.len() {
                let rank = tokens[i];
                if let Some(&tup) = decoder.get(&rank) {
                    demerges.push((i, tup));
                }
            }
            if demerges.is_empty() {
                break;
            }

            for op in demerges.iter().rev() {
                let i = op.0;
                let tup = op.1;
                tokens[i] = tup.0;
                tokens.insert(i + 1, tup.1);
            }
        }

        // special tokens
        if let Some(map) = self.config.special_tokens_map.as_ref() {
            inject_special_tokens::<Token, FxBuildHasher>(&mut tokens, map);
        }

        tokens
    }

    pub fn train(&mut self, text: &str) -> Vec<Token> {
        let mut pieces: Vec<Token>;

        if !self.encoder.is_empty() {
            println!("pretrained tokenizer detected!");
            pieces = self.encode(text);
        } else {
            let text = text.as_bytes();
            pieces = text.iter().map(|&i| i as Token).collect();
        }

        match self.config.vocab_size.checked_sub(self.len()) {
            Some(size) => {
                for _ in tqdm::tqdm(0..size) {
                    let mut counts = FwdMap::default();
                    for i in 0..pieces.len() - 1 {
                        *counts.entry((pieces[i], pieces[i + 1])).or_insert(0) += 1;
                    }

                    let (&p, _) = counts.iter().max_by_key(|&(_, &c)| c).unwrap();
                    let token_id = (self.len() + 127 + 1) as Token;

                    self.encoder.insert(p, token_id);
                    ngram_replace(&mut pieces, &[p.0, p.1], &[token_id]);
                }
            }
            None => println!(
                "requested vocab_size: {} already reached.",
                self.config.vocab_size
            ),
        };

        self._sync_decoder();
        pieces
    }
}