bytepiece 0.1.0

Rust version of bytepiece tokenizer
Documentation
use super::utils::normalize;
use crate::Result;
use aho_corasick::{AhoCorasick, MatchKind};
use ouroboros::self_referencing;
use std::collections::HashMap;
use std::path::Path;

pub type Pieces = HashMap<Vec<u8>, (usize, String, usize)>;

pub fn parse_pieces_from_slice(buf: &[u8]) -> Result<Pieces> {
    let dict: HashMap<&str, (usize, String, usize)> = serde_json::from_slice(buf)?;
    let base64 = base64_simd::STANDARD;
    dict.into_iter()
        .map(|(key, value)| {
            let new_key = base64.decode_to_vec(key)?;
            Ok((new_key, value))
        })
        .collect::<Result<HashMap<_, _>>>()
}

pub trait Tokenize {
    fn tokenize<'s>(&self, text: &'s str, alpha: f64) -> Vec<&'s str>;
    fn piece_to_id(&self, p: &str) -> usize;
    fn id_to_piece(&self, i: usize) -> &str;
    fn vocab_size(&self) -> usize;

    fn pieces_to_ids(&self, pieces: &[&str]) -> Vec<usize> {
        pieces.iter().map(|p| self.piece_to_id(p)).collect()
    }

    fn ids_to_pieces(&self, ids: &[usize]) -> Vec<&str> {
        ids.iter().map(|i| self.id_to_piece(*i)).collect()
    }

    fn encode(&self, text: &str, add_bos: bool, add_eos: bool, alpha: f64) -> Vec<usize> {
        let mut pieces = if add_bos {
            let mut pieces = vec![1];
            for p in self.tokenize(text, alpha) {
                pieces.push(self.piece_to_id(p));
            }
            pieces
        } else {
            self.tokenize(text, alpha)
                .into_iter()
                .map(|p| self.piece_to_id(p))
                .collect()
        };
        if add_eos {
            pieces.push(2);
        }
        pieces
    }

    fn decode(&self, ids: &[usize]) -> String {
        let pieces: Vec<&str> = ids
            .iter()
            .filter(|i| **i > 2)
            .map(|i| self.id_to_piece(*i))
            .collect();
        pieces.join("")
    }
}

pub struct Tokenizer<'a> {
    piece_to_id: HashMap<&'a [u8], usize>,
    id_to_piece: HashMap<usize, &'a [u8]>,
    vocab_size: usize,
    values: Vec<f64>,
    ac: AhoCorasick,
}

impl<'a> Tokenize for Tokenizer<'a> {
    fn id_to_piece(&self, i: usize) -> &str {
        std::str::from_utf8(self.id_to_piece[&i]).unwrap()
    }

    fn piece_to_id(&self, p: &str) -> usize {
        self.piece_to_id[p.as_bytes()]
    }

    fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    fn tokenize<'s>(&self, text: &'s str, alpha: f64) -> Vec<&'s str> {
        normalize(text, 0)
            .into_iter()
            .flat_map(|s| self._tokenize(s, alpha))
            .collect()
    }
}

impl<'a> Tokenizer<'a> {
    pub fn from_pieces(pieces: &'a Pieces) -> Result<Self> {
        let piece_to_id: HashMap<&[u8], usize> = pieces
            .iter()
            .map(|(key, value)| (key.as_slice(), value.0))
            .collect();

        let id_to_piece: HashMap<usize, &[u8]> =
            piece_to_id.iter().map(|(k, v)| (*v, *k)).collect();

        let vocab_size = pieces.len() + 3;

        let total: usize = pieces.values().map(|vs| vs.2).sum();
        let log_total = (total as f64).log2();

        let values: Vec<f64> = pieces
            .values()
            .map(|vs| (vs.2 as f64).log2() - log_total)
            .collect();

        let ac = AhoCorasick::builder()
            .match_kind(MatchKind::Standard)
            .build(pieces.keys())?;

        Ok(Self {
            id_to_piece,
            piece_to_id,
            vocab_size,
            values,
            ac,
        })
    }

    fn _tokenize<'s>(&self, text: &'s str, alpha: f64) -> impl Iterator<Item = &'s str> {
        // Every ending's max score.
        let mut scores = vec![-std::f64::INFINITY; text.len() + 1];
        scores[0] = 0.0;
        // Every ending's start. 
        let mut routes: Vec<usize> = (0..text.len() + 1).collect();

        for mat in self.ac.find_overlapping_iter(text.as_bytes()) {
            let start = mat.start();
            let end = mat.end();

            let index = mat.pattern().as_usize();
            let value = self.values[index];

            let score = scores[start] + value;

            if (alpha <= 0.0 && score > scores[end])
                || (alpha > 0.0 && rand::random::<f64>() < sigmoid((score - scores[end]) * alpha))
            {
                scores[end] = score;
                routes[end] = start;
            }
        }

        let mut text_slice = text;
        let mut end = routes.len() - 1;
        let mut tokens = Vec::new();

        while !text_slice.is_empty() {
            let start = routes[end];
            // dbg!(text_slice, start, end);
            tokens.push(&text_slice[start..end]);
            text_slice = &text[..start];
            end = start;
        }
        tokens.into_iter().rev()
    }
}

#[inline]
fn sigmoid(x: f64) -> f64 {
    if x >= 0. {
        1. / (1. + (-x).exp())
    } else {
        1. - 1. / (1. + x.exp())
    }
}

#[self_referencing]
pub struct OwnedTokenizer {
    pieces: Pieces,
    #[borrows(pieces)]
    #[covariant]
    tokenizer: Tokenizer<'this>,
}

impl Tokenize for OwnedTokenizer {
    fn tokenize<'s>(&self, text: &'s str, alpha: f64) -> Vec<&'s str> {
        self.borrow_tokenizer().tokenize(text, alpha)
    }

    fn vocab_size(&self) -> usize {
        self.borrow_tokenizer().vocab_size()
    }

    fn id_to_piece(&self, i: usize) -> &str {
        self.borrow_tokenizer().id_to_piece(i)
    }

    fn piece_to_id(&self, p: &str) -> usize {
        self.borrow_tokenizer().piece_to_id(p)
    }
}

pub fn make_owned_tokenizer(pieces: Pieces) -> Result<OwnedTokenizer> {
    OwnedTokenizerTryBuilder {
        pieces,
        tokenizer_builder: |pieces: &Pieces| Tokenizer::from_pieces(pieces),
    }
    .try_build()
}

impl OwnedTokenizer {
    pub fn from_path(path: impl AsRef<Path>) -> Result<Self> {
        let buf = std::fs::read(path)?;
        let pieces = parse_pieces_from_slice(&buf)?;
        make_owned_tokenizer(pieces)
    }
}

#[cfg(test)]
mod tests {}