pie 0.1.1

Programmable Inference Engine (PIE)
Documentation
use base64::{Engine as _, engine::general_purpose};
use fancy_regex::Regex;
use std::collections::{HashMap, HashSet};
use std::fs;
pub type Rank = u32;

// The code below is copied from the tiktoken.
// https://github.com/openai/tiktoken/blob/main/src/lib.rs

fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
    let mut parts = Vec::with_capacity(piece.len() + 1);

    let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
    for i in 0..piece.len() - 1 {
        let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
        if rank < min_rank.0 {
            min_rank = (rank, i);
        }
        parts.push((i, rank));
    }
    parts.push((piece.len() - 1, Rank::MAX));
    parts.push((piece.len(), Rank::MAX));

    let get_rank = {
        |parts: &Vec<(usize, Rank)>, i: usize| {
            if (i + 3) < parts.len() {
                *ranks
                    .get(&piece[parts[i].0..parts[i + 3].0])
                    .unwrap_or(&Rank::MAX)
            } else {
                Rank::MAX
            }
        }
    };

    while min_rank.0 != Rank::MAX {
        let i = min_rank.1;

        if i > 0 {
            parts[i - 1].1 = get_rank(&parts, i - 1);
        }
        parts[i].1 = get_rank(&parts, i);
        parts.remove(i + 1);

        min_rank = (Rank::MAX, usize::MAX);
        for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
            if rank < min_rank.0 {
                min_rank = (rank, i);
            }
        }
    }
    parts
}

pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
    if piece.len() == 1 {
        return vec![ranks[piece]];
    }
    _byte_pair_merge(ranks, piece)
        .windows(2)
        .map(|part| ranks[&piece[part[0].0..part[1].0]])
        .collect()
}

pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<&'a [u8]> {
    assert!(piece.len() > 1);
    _byte_pair_merge(ranks, piece)
        .windows(2)
        .map(|part| &piece[part[0].0..part[1].0])
        .collect()
}

#[derive(Debug, Clone)]
pub struct DecodeKeyError {
    pub token: Rank,
}

impl std::fmt::Display for DecodeKeyError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "Invalid token for decoding: {}", self.token)
    }
}

impl std::error::Error for DecodeKeyError {}

#[derive(Debug, Clone)]
pub struct DecodeError {
    pub message: String,
}

impl std::fmt::Display for DecodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "Could not decode tokens: {}", self.message)
    }
}

impl std::error::Error for DecodeError {}

#[derive(Clone, Debug)]
pub struct BytePairEncoder {
    encoder: HashMap<Vec<u8>, Rank>,
    special_tokens_encoder: HashMap<String, Rank>,
    decoder: HashMap<Rank, Vec<u8>>,
    special_tokens_decoder: HashMap<Rank, Vec<u8>>,
    regex: Regex,
    special_regex: Regex,
}

impl BytePairEncoder {
    pub fn decode(&self, tokens: &[Rank]) -> Result<String, DecodeError> {
        // First, decode raw bytes from the tokens.
        let decoded_bytes = self.decode_bytes(tokens).map_err(|err| DecodeError {
            message: err.to_string(),
        })?;

        // Then, convert the bytes to a UTF-8 string.
        // Using `from_utf8_lossy` would silently replace invalid sequences with
        // the Unicode replacement character;
        let decoded_string = String::from_utf8_lossy(&*decoded_bytes).to_string();

        Ok(decoded_string)
    }
    fn decode_bytes(&self, tokens: &[Rank]) -> Result<Vec<u8>, DecodeKeyError> {
        let mut ret = Vec::with_capacity(tokens.len() * 2);
        for &token in tokens {
            let token_bytes = match self.decoder.get(&token) {
                Some(bytes) => bytes,
                None => self
                    .special_tokens_decoder
                    .get(&token)
                    .ok_or(DecodeKeyError { token })?,
            };
            ret.extend(token_bytes);
        }
        Ok(ret)
    }

    pub fn encode(&self, text: &str, allowed_special: &HashSet<&str>) -> Vec<Rank> {
        let mut ret = vec![];

        let mut start = 0;
        let mut last_piece_token_len = 0;
        loop {
            let mut next_special;
            let mut start_find = start;
            loop {
                next_special = self.special_regex.find_from_pos(text, start_find).unwrap();
                match next_special {
                    Some(m) => {
                        if allowed_special.contains(&text[m.start()..m.end()]) {
                            break;
                        }
                        start_find = m.start() + 1;
                    }
                    None => break,
                }
            }
            let end = next_special.map_or(text.len(), |m| m.start());

            for mat in self.regex.find_iter(&text[start..end]) {
                let piece = mat.unwrap().as_str().as_bytes();
                if let Some(token) = self.encoder.get(piece) {
                    last_piece_token_len = 1;
                    ret.push(*token);
                    continue;
                }
                let tokens = byte_pair_encode(piece, &self.encoder);
                last_piece_token_len = tokens.len();
                ret.extend(&tokens);
            }

            match next_special {
                // And here we push the special token
                Some(m) => {
                    let piece = m.as_str();
                    let token = self.special_tokens_encoder[piece];
                    ret.push(token);
                    start = m.end();
                    last_piece_token_len = 0;
                }
                None => break,
            }
        }

        ret
    }

    pub(crate) fn new(
        decoder: HashMap<Rank, Vec<u8>>,
        special_tokens_encoder: HashMap<String, Rank>,
        pattern: &str,
    ) -> Self {
        let regex = Regex::new(pattern).unwrap();

        let special_regex = {
            let parts = special_tokens_encoder
                .keys()
                .map(|s| fancy_regex::escape(s))
                .collect::<Vec<_>>();
            Regex::new(&parts.join("|")).unwrap()
        };

        let encoder: HashMap<Vec<u8>, Rank> =
            decoder.iter().map(|(k, v)| (v.clone(), *k)).collect();

        assert_eq!(
            encoder.len(),
            decoder.len(),
            "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
        );

        let special_tokens_decoder: HashMap<Rank, Vec<u8>> = special_tokens_encoder
            .iter()
            .map(|(k, v)| (*v, k.as_bytes().to_vec()))
            .collect();

        // Clone because I don't know how to tell Rust I'm not going to change the map

        Self {
            encoder,
            special_tokens_encoder,
            decoder,
            special_tokens_decoder,
            regex,
            special_regex,
        }
    }

    pub fn special_tokens(&self) -> HashSet<&str> {
        self.special_tokens_encoder
            .keys()
            .map(|s| s.as_str())
            .collect()
    }

    pub fn encode_with_special_tokens(&self, text: &str) -> Vec<Rank> {
        let allowed_special = self.special_tokens();
        self.encode(text, &allowed_special)
    }

    pub fn get_vocabs(&self) -> (Vec<Rank>, Vec<Vec<u8>>) {
        // return decoder ranks and bytes
        (
            self.decoder.keys().cloned().collect(),
            self.decoder.values().cloned().collect(),
        )
    }
}

pub fn load_merge_rules(path: &str) -> Result<HashMap<Rank, Vec<u8>>, Box<dyn std::error::Error>> {
    // Read the entire file as a UTF-8 string
    let contents = fs::read_to_string(path)?;

    let mut ret = HashMap::new();

    for (line_number, line) in contents.lines().enumerate() {
        let line = line.trim();
        // Skip empty lines
        if line.is_empty() {
            continue;
        }

        // Expect two parts: base64-encoded token and rank
        let parts: Vec<&str> = line.split_whitespace().collect();
        if parts.len() != 2 {
            return Err(format!(
                "Error parsing line {}: expected two parts, got {} (line: {:?})",
                line_number,
                parts.len(),
                line
            )
            .into());
        }

        let b64_token = parts[0];
        let rank_str = parts[1];

        // Decode base64 token
        let decoded_token = general_purpose::STANDARD
            .decode(b64_token)
            .map_err(|e| format!("Error decoding base64 at line {}: {}", line_number, e))?;

        // Parse rank into i32
        let rank = rank_str
            .parse::<Rank>()
            .map_err(|e| format!("Error parsing rank at line {}: {}", line_number, e))?;

        // Insert into the HashMap
        ret.insert(rank, decoded_token);
    }

    Ok(ret)
}
//
// pub fn empty_tokenizer() -> BytePairEncoder {
//     // Create an empty encoder with no merge rules and no special tokens
//     BytePairEncoder::new(
//         HashMap::new(),
//         HashMap::new(),
//         r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
//     )
// }
//
// // https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py
// pub fn llama3_tokenizer(path: &str) -> Result<BytePairEncoder, Box<dyn std::error::Error>> {
//     // Example usage
//     let mergeable_ranks = load_merge_rules(path)?;
//     let special_tokens = vec![
//         "<|begin_of_text|>",
//         "<|end_of_text|>",
//         "<|reserved_special_token_0|>",
//         "<|reserved_special_token_1|>",
//         "<|reserved_special_token_2|>",
//         "<|reserved_special_token_3|>",
//         "<|start_header_id|>",
//         "<|end_header_id|>",
//         "<|reserved_special_token_4|>",
//         "<|eot_id|>",
//     ];
//     let pattern = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";
//
//     let num_base_tokens = mergeable_ranks.len() as Rank;
//
//     let special_tokens_encoder: HashMap<String, Rank> = special_tokens
//         .into_iter()
//         .enumerate()
//         .map(|(i, s)| (s.to_string(), num_base_tokens + i as Rank))
//         .collect();
//
//     let encoder = BytePairEncoder::new(mergeable_ranks, special_tokens_encoder, pattern);
//     // [9906, 11, 856, 5679, 374, 19369]
//     // encode text
//     //let text = "Hello, my dog is cute";
//     //let tokens = encoder.encode_with_special_tokens(text);
//     //println!("Encoded tokens: {:?}", tokens);
//
//     Ok(encoder)
// }