//! Tokenizers: a decode-only `SimpleTokenizer` backed by an id -> token
//! table, and a `BpeTokenizer` that applies ordered merge rules with
//! special-token splitting and `<0xHH>` byte fallback.

use super::AprV2Model;
use std::collections::HashMap;

/// Decode-only tokenizer backed by an id -> token lookup table.
#[derive(Debug, Clone)]
pub struct SimpleTokenizer {
    pub id_to_token: Vec<String>,
    pub bos_token_id: Option<u32>,
    pub eos_token_id: Option<u32>,
}

impl SimpleTokenizer {
    #[must_use]
    pub fn new(vocab: Vec<String>, bos_id: Option<u32>, eos_id: Option<u32>) -> Self {
        Self {
            id_to_token: vocab,
            bos_token_id: bos_id,
            eos_token_id: eos_id,
        }
    }

    /// Decode `token_ids` back into a string.
    #[must_use]
    pub fn decode(&self, token_ids: &[u32]) -> String {
        AprV2Model::decode_tokens(&self.id_to_token, token_ids)
    }

    /// Number of entries in the vocabulary.
    #[must_use]
    pub fn vocab_size(&self) -> usize {
        self.id_to_token.len()
    }

    /// True when `token_id` is the configured end-of-sequence token.
    #[must_use]
    pub fn is_eos(&self, token_id: u32) -> bool {
        self.eos_token_id.is_some_and(|eos| token_id == eos)
    }

    /// True when `token_id` is the configured beginning-of-sequence token.
    #[must_use]
    pub fn is_bos(&self, token_id: u32) -> bool {
        self.bos_token_id.is_some_and(|bos| token_id == bos)
    }
}
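
// Usage sketch for `SimpleTokenizer` (illustrative; the three-token vocab
// and the id assignments below are hypothetical).
#[cfg(test)]
mod simple_tokenizer_sketch {
    use super::SimpleTokenizer;

    #[test]
    fn bos_eos_and_vocab_size() {
        let tok = SimpleTokenizer::new(
            vec!["<s>".to_string(), "</s>".to_string(), "hi".to_string()],
            Some(0),
            Some(1),
        );
        assert_eq!(tok.vocab_size(), 3);
        assert!(tok.is_bos(0));
        assert!(tok.is_eos(1));
        assert!(!tok.is_eos(2));
    }
}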

/// Byte-pair-encoding tokenizer: a vocabulary, ordered merge rules, and a
/// table of special tokens that are matched before BPE runs.
#[derive(Debug, Clone)]
pub struct BpeTokenizer {
    pub token_to_id: HashMap<String, u32>,
    pub id_to_token: Vec<String>,
    pub merge_rules: Vec<(String, String)>,
    pub bos_id: Option<u32>,
    pub eos_id: Option<u32>,
    pub special_tokens: HashMap<String, u32>,
}

impl BpeTokenizer {
    /// Encode `text` to token ids: special tokens are matched first, then
    /// each remaining segment is byte-pair encoded.
    #[must_use]
    pub fn encode(&self, text: &str) -> Vec<u32> {
        bpe_encode(
            text,
            &self.token_to_id,
            &self.merge_rules,
            &self.special_tokens,
        )
    }

    /// Decode `token_ids` back into a string.
    #[must_use]
    pub fn decode(&self, token_ids: &[u32]) -> String {
        AprV2Model::decode_tokens(&self.id_to_token, token_ids)
    }
}
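
// Encode sketch for `BpeTokenizer` (the vocab and the single merge rule are
// hypothetical; real models load these from tokenizer files). The "ab"
// token is only reachable because the merge ("a", "b") builds it.
#[cfg(test)]
mod bpe_tokenizer_sketch {
    use super::BpeTokenizer;
    use std::collections::HashMap;

    #[test]
    fn encode_applies_merges() {
        let tok = BpeTokenizer {
            token_to_id: [("a", 0u32), ("b", 1), ("ab", 2)]
                .into_iter()
                .map(|(s, i)| (s.to_string(), i))
                .collect(),
            id_to_token: vec!["a".to_string(), "b".to_string(), "ab".to_string()],
            merge_rules: vec![("a".to_string(), "b".to_string())],
            bos_id: None,
            eos_id: None,
            special_tokens: HashMap::new(),
        };
        assert_eq!(tok.encode("ab"), vec![2]);
        // "ba" never matches the ("a", "b") pair, so it stays two tokens.
        assert_eq!(tok.encode("ba"), vec![1, 0]);
    }
}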

/// BPE-encode `text`: split around special tokens, emit their ids verbatim,
/// and run the merge loop over every regular segment.
pub(crate) fn bpe_encode(
    text: &str,
    vocab: &HashMap<String, u32>,
    merges: &[(String, String)],
    special_tokens: &HashMap<String, u32>,
) -> Vec<u32> {
    let segments = split_by_special_tokens(text, special_tokens);
    let mut result = Vec::new();
    for segment in segments {
        match segment {
            TextSegment::Special(id) => result.push(id),
            TextSegment::Regular(s) => {
                result.extend(bpe_encode_segment(&s, vocab, merges));
            },
        }
    }
    result
}

/// A piece of input text: either a recognized special token (already an id)
/// or a regular run of text still to be byte-pair encoded.
enum TextSegment {
    Special(u32),
    Regular(String),
}

/// If `remaining` starts with one of the special tokens, return its id and
/// the number of bytes consumed. `sorted_tokens` must be sorted longest
/// first so the longest match wins.
fn try_match_special_at_start(
    remaining: &str,
    sorted_tokens: &[(&String, &u32)],
) -> Option<(u32, usize)> {
    for (token_str, &token_id) in sorted_tokens {
        if remaining.starts_with(token_str.as_str()) {
            return Some((token_id, token_str.len()));
        }
    }
    None
}

/// Byte offset of the earliest special-token occurrence in `remaining`,
/// or `remaining.len()` if none occurs.
fn find_earliest_special_pos(remaining: &str, sorted_tokens: &[(&String, &u32)]) -> usize {
    let mut earliest = remaining.len();
    for (token_str, _) in sorted_tokens {
        if let Some(pos) = remaining.find(token_str.as_str()) {
            earliest = earliest.min(pos);
        }
    }
    earliest
}

/// Split `text` into special-token and regular segments. Special tokens are
/// tried longest first, so when one special token is a prefix of another the
/// longer match wins.
fn split_by_special_tokens(text: &str, special_tokens: &HashMap<String, u32>) -> Vec<TextSegment> {
    if special_tokens.is_empty() {
        return vec![TextSegment::Regular(text.to_string())];
    }
    let mut sorted_tokens: Vec<(&String, &u32)> = special_tokens.iter().collect();
    sorted_tokens.sort_by(|a, b| b.0.len().cmp(&a.0.len()));
    let mut segments = Vec::new();
    let mut remaining = text;
    while !remaining.is_empty() {
        if let Some((token_id, consumed)) = try_match_special_at_start(remaining, &sorted_tokens) {
            segments.push(TextSegment::Special(token_id));
            remaining = &remaining[consumed..];
        } else {
            // No special token at the front: take everything up to the next
            // occurrence (or the rest of the string) as a regular segment.
            // `next_pos > 0` always holds here, since a match at offset 0
            // would have been taken by the branch above.
            let next_pos = find_earliest_special_pos(remaining, &sorted_tokens);
            segments.push(TextSegment::Regular(remaining[..next_pos].to_string()));
            remaining = &remaining[next_pos..];
        }
    }
    segments
}
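
// Behavior sketch for `split_by_special_tokens` with one hypothetical
// special token: text around the match stays regular, the match itself
// becomes its id.
#[cfg(test)]
mod split_specials_sketch {
    use super::{split_by_special_tokens, TextSegment};
    use std::collections::HashMap;

    #[test]
    fn splits_around_special() {
        let specials: HashMap<String, u32> =
            [("<eos>".to_string(), 7)].into_iter().collect();
        let segs = split_by_special_tokens("hi<eos>!", &specials);
        assert_eq!(segs.len(), 3);
        assert!(matches!(segs[0], TextSegment::Regular(ref s) if s == "hi"));
        assert!(matches!(segs[1], TextSegment::Special(7)));
        assert!(matches!(segs[2], TextSegment::Regular(ref s) if s == "!"));
    }
}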

/// Map one char to its BPE token string: GPT-2-style markers for
/// whitespace, pass-through for other ASCII, and `<0xHH>` fallbacks for
/// everything else.
fn char_to_bpe_token(c: char) -> String {
    match c {
        ' ' => "Ġ".to_string(),
        '\n' => "Ċ".to_string(),
        '\t' => "ĉ".to_string(),
        c if c.is_ascii() => c.to_string(),
        c => {
            // Expand the char into its UTF-8 bytes. Iterating `chars()` of
            // the encoded str and casting to `u8` would just truncate the
            // code point, so iterate the bytes themselves.
            let mut buf = [0u8; 4];
            c.encode_utf8(&mut buf)
                .bytes()
                .map(byte_to_bpe_char)
                .collect()
        },
    }
}
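
// Mapping sketch for `char_to_bpe_token`: 'é' is U+00E9, whose UTF-8
// encoding is the two bytes 0xC3 0xA9, so it expands to two fallback
// placeholders in one token string.
#[cfg(test)]
mod char_mapping_sketch {
    use super::char_to_bpe_token;

    #[test]
    fn whitespace_ascii_and_non_ascii() {
        assert_eq!(char_to_bpe_token(' '), "Ġ");
        assert_eq!(char_to_bpe_token('x'), "x");
        assert_eq!(char_to_bpe_token('é'), "<0xC3><0xA9>");
    }
}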

/// One left-to-right pass replacing every adjacent (`first`, `second`) pair
/// with `merged`. Returns whether anything changed so the caller can loop
/// to a fixed point.
fn apply_bpe_merge(tokens: &mut Vec<String>, first: &str, second: &str, merged: &str) -> bool {
    let mut found = false;
    let mut i = 0;
    while i + 1 < tokens.len() {
        if tokens[i] == first && tokens[i + 1] == second {
            tokens[i] = merged.to_string();
            tokens.remove(i + 1);
            found = true;
        }
        i += 1;
    }
    found
}
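
// Pass-by-pass sketch for `apply_bpe_merge`: one sweep merges every
// non-overlapping ("a", "b") pair; a second sweep finds nothing left.
#[cfg(test)]
mod apply_merge_sketch {
    use super::apply_bpe_merge;

    #[test]
    fn merges_left_to_right() {
        let mut tokens: Vec<String> =
            ["a", "b", "a", "b"].iter().map(|s| s.to_string()).collect();
        assert!(apply_bpe_merge(&mut tokens, "a", "b", "ab"));
        assert_eq!(tokens, vec!["ab".to_string(), "ab".to_string()]);
        assert!(!apply_bpe_merge(&mut tokens, "a", "b", "ab"));
    }
}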

/// BPE-encode one regular (non-special) segment: seed with one token per
/// char, apply each merge rule to a fixed point in rule order, then look
/// the surviving tokens up in the vocab (unknown tokens are dropped).
fn bpe_encode_segment(
    text: &str,
    vocab: &HashMap<String, u32>,
    merges: &[(String, String)],
) -> Vec<u32> {
    let mut tokens: Vec<String> = text.chars().map(char_to_bpe_token).collect();
    for (first, second) in merges {
        let merged = format!("{first}{second}");
        while apply_bpe_merge(&mut tokens, first, second, &merged) {}
    }
    tokens
        .iter()
        .filter_map(|t| vocab.get(t).copied())
        .collect()
}
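
// Segment sketch for `bpe_encode_segment`: a leading space becomes "Ġ"
// before merging, so this hypothetical vocab carries a "Ġa" entry built by
// the merge ("Ġ", "a").
#[cfg(test)]
mod encode_segment_sketch {
    use super::bpe_encode_segment;
    use std::collections::HashMap;

    #[test]
    fn space_marker_then_merge() {
        let vocab: HashMap<String, u32> = [("Ġ", 0u32), ("a", 1), ("Ġa", 2)]
            .into_iter()
            .map(|(s, i)| (s.to_string(), i))
            .collect();
        let merges = vec![("Ġ".to_string(), "a".to_string())];
        assert_eq!(bpe_encode_segment(" a", &vocab, &merges), vec![2]);
    }
}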

/// Map one byte to its BPE token string. `is_ascii_graphic` already covers
/// alphanumerics, so a single check suffices; non-printable bytes get a hex
/// placeholder token.
pub fn byte_to_bpe_char(b: u8) -> String {
    match b {
        b' ' => "Ġ".to_string(),
        b'\n' => "Ċ".to_string(),
        b'\t' => "ĉ".to_string(),
        _ if b.is_ascii_graphic() => (b as char).to_string(),
        _ => format!("<0x{b:02X}>"),
    }
}
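
// Mapping sketch for `byte_to_bpe_char`: printable ASCII passes through,
// whitespace uses the marker chars, and any other byte gets a hex
// placeholder.
#[cfg(test)]
mod byte_mapping_sketch {
    use super::byte_to_bpe_char;

    #[test]
    fn printable_markers_and_fallback() {
        assert_eq!(byte_to_bpe_char(b'x'), "x");
        assert_eq!(byte_to_bpe_char(b' '), "Ġ");
        assert_eq!(byte_to_bpe_char(0xC3), "<0xC3>");
    }
}
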
include!("tokenizer_tests.rs");