#[allow(clippy::wildcard_imports)]
use super::*;
use crate::error::{AprenderError, Result};
use serde::Deserialize;
use std::collections::HashMap;
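
/// Deserializes the HuggingFace `model.merges` field, which may contain either
/// pre-joined `"left right"` strings or `["left", "right"]` pairs; both forms
/// are normalized to the `"left right"` representation used internally.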
fn deserialize_merges<'de, D>(deserializer: D) -> std::result::Result<Vec<String>, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::{self, SeqAccess, Visitor};
use std::fmt;
struct MergesVisitor;
impl<'de> Visitor<'de> for MergesVisitor {
type Value = Vec<String>;
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str("a list of merge rules (strings or 2-element arrays)")
}
fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Vec<String>, A::Error>
where
A: SeqAccess<'de>,
{
let mut merges = Vec::with_capacity(seq.size_hint().unwrap_or(0));
while let Some(element) = seq.next_element::<MergeEntry>()? {
match element {
MergeEntry::String(s) => merges.push(s),
MergeEntry::Array(pair) => {
if pair.len() == 2 {
merges.push(format!("{} {}", pair[0], pair[1]));
} else {
return Err(de::Error::custom(format!(
"merge array must have exactly 2 elements, got {}",
pair.len()
)));
}
}
}
}
Ok(merges)
}
}
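
// Untagged helper that accepts either merge representation from tokenizer.json.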
#[derive(Deserialize)]
#[serde(untagged)]
enum MergeEntry {
String(String),
Array(Vec<String>),
}
deserializer.deserialize_seq(MergesVisitor)
}
impl Qwen2BpeTokenizer {
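// ChatML special-token ids, taken from `crate::demo::SpecialTokens::qwen2()`.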
pub const IM_START_ID: u32 = crate::demo::SpecialTokens::qwen2().im_start_id;
pub const IM_END_ID: u32 = crate::demo::SpecialTokens::qwen2().im_end_id;
pub const ENDOFTEXT_ID: u32 = crate::demo::SpecialTokens::qwen2().bos_id;
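
/// Builds the Qwen2 tokenizer: registers the ChatML special tokens and seeds
/// the vocabulary with one byte-level token per byte value (ids 0-255) from
/// the byte encoder.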
#[must_use]
pub fn new() -> Self {
let config = BpeConfig::qwen2();
let mut base = BpeTokenizer::new(config);
base.add_special_token("<|endoftext|>", Self::ENDOFTEXT_ID);
base.add_special_token("<|im_start|>", Self::IM_START_ID);
base.add_special_token("<|im_end|>", Self::IM_END_ID);
for i in 0..=255u8 {
if let Some(&c) = base.byte_encoder.get(&i) {
let token = c.to_string();
let id = u32::from(i);
base.vocab.insert(token.clone(), id);
base.id_to_token.insert(id, token);
}
}
Self {
base,
im_start_id: Self::IM_START_ID,
im_end_id: Self::IM_END_ID,
endoftext_id: Self::ENDOFTEXT_ID,
}
}
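
/// Returns `true` for tokens that end generation (`<|im_end|>` or `<|endoftext|>`).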
#[must_use]
pub fn is_eos(&self, token_id: u32) -> bool {
token_id == self.im_end_id || token_id == self.endoftext_id
}
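
/// Returns `true` for the `<|im_start|>` token that opens a ChatML turn.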
#[must_use]
pub fn is_bos(&self, token_id: u32) -> bool {
token_id == self.im_start_id
}
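
/// Reports the full Qwen2 vocabulary size (`crate::demo::Qwen2Config::VOCAB_SIZE`),
/// independent of how many entries are currently loaded.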
#[must_use]
pub fn vocab_size(&self) -> usize {
crate::demo::Qwen2Config::VOCAB_SIZE
}
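
/// Encodes `text` into token ids via the base BPE encoder, wrapped in the
/// module's `contract_*` pre/post checks.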
#[must_use]
pub fn encode(&self, text: &str) -> Vec<u32> {
contract_pre_encode!();
contract_pre_roundtrip_encoding!();
contract_pre_tokenizer_consistency!();
let result = self.base.encode(text);
contract_post_tokenizer_consistency!(&result);
result
}
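
/// Decodes token ids back into a string via the base BPE decoder.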
#[must_use]
pub fn decode(&self, ids: &[u32]) -> String {
self.base.decode(ids)
}
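
/// Formats a single chat turn in ChatML; e.g. `format_chat("user", "Hi")`
/// returns `"<|im_start|>user\nHi<|im_end|>\n"`.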
#[must_use]
pub fn format_chat(&self, role: &str, content: &str) -> String {
format!("<|im_start|>{role}\n{content}<|im_end|>\n")
}
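
/// Concatenates each `(role, content)` turn in ChatML form and appends an
/// opening `<|im_start|>assistant\n` so the model can continue the dialogue.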
#[must_use]
pub fn format_conversation(&self, messages: &[(&str, &str)]) -> String {
let mut result = String::new();
for (role, content) in messages {
result.push_str(&self.format_chat(role, content));
}
result.push_str("<|im_start|>assistant\n");
result
}
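
/// Token id used for `<|im_start|>`.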
#[must_use]
pub fn im_start_id(&self) -> u32 {
self.im_start_id
}
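
/// Token id used for `<|im_end|>`.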
#[must_use]
pub fn im_end_id(&self) -> u32 {
self.im_end_id
}
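
/// Loads a tokenizer from a HuggingFace-style `tokenizer.json` on disk.
/// Illustrative usage (path is an example): `Qwen2BpeTokenizer::from_file("tokenizer.json")?`.
///
/// # Errors
///
/// Returns `AprenderError::FormatError` if the file cannot be read or the JSON
/// cannot be parsed.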
pub fn from_file<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
let json =
std::fs::read_to_string(path.as_ref()).map_err(|e| AprenderError::FormatError {
message: format!("Failed to read tokenizer file: {e}"),
})?;
Self::from_json(&json)
}
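
/// Builds a tokenizer from HuggingFace `tokenizer.json` contents, falling back
/// to the default ChatML ids when the special tokens are absent from the vocab.
///
/// # Errors
///
/// Returns `AprenderError::FormatError` if the JSON is empty or malformed.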
pub fn from_json(json: &str) -> Result<Self> {
let base = load_from_json(json)?;
let im_start_id = base
.vocab
.get("<|im_start|>")
.copied()
.unwrap_or(Self::IM_START_ID);
let im_end_id = base
.vocab
.get("<|im_end|>")
.copied()
.unwrap_or(Self::IM_END_ID);
let endoftext_id = base
.vocab
.get("<|endoftext|>")
.copied()
.unwrap_or(Self::ENDOFTEXT_ID);
Ok(Self {
base,
im_start_id,
im_end_id,
endoftext_id,
})
}
}
impl Default for Qwen2BpeTokenizer {
fn default() -> Self {
Self::new()
}
}
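
/// GPT-2-style byte-to-unicode mapping: printable bytes map to themselves and
/// the remaining bytes are shifted into the range starting at U+0100, giving a
/// reversible encoder/decoder pair.
/// Round-trip sketch: `let (enc, dec) = bytes_to_unicode(); assert_eq!(dec[&enc[&b'A']], b'A');`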
#[must_use]
pub fn bytes_to_unicode() -> (HashMap<u8, char>, HashMap<char, u8>) {
let mut encoder = HashMap::new();
let mut decoder = HashMap::new();
let mut n = 0u32;
for b in 0..=255u8 {
let c = if (b'!'..=b'~').contains(&b)
|| (b'\xa1'..=b'\xac').contains(&b)
|| (b'\xae'..=b'\xff').contains(&b)
{
char::from(b)
} else {
let c = char::from_u32(256 + n).unwrap_or('?');
n += 1;
c
};
encoder.insert(b, c);
decoder.insert(c, b);
}
(encoder, decoder)
}
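
/// Minimal view of a HuggingFace `tokenizer.json`: only the fields needed here.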
#[derive(Debug, Deserialize)]
struct HfTokenizerJson {
model: HfModel,
#[serde(default)]
added_tokens: Vec<HfAddedToken>,
}
#[derive(Debug, Deserialize)]
struct HfModel {
vocab: HashMap<String, u32>,
#[serde(deserialize_with = "deserialize_merges")]
merges: Vec<String>,
}
#[derive(Debug, Deserialize)]
struct HfAddedToken {
id: u32,
content: String,
#[serde(default)]
special: bool,
}
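
/// Parses a HuggingFace `tokenizer.json` string into a `BpeTokenizer`, choosing
/// a base config from the vocabulary size and then loading the vocab, merges,
/// and added tokens.
///
/// # Errors
///
/// Returns `AprenderError::FormatError` for empty or unparseable JSON.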
pub fn load_from_json(json: &str) -> Result<BpeTokenizer> {
if json.is_empty() {
return Err(AprenderError::FormatError {
message: "Empty tokenizer JSON".to_string(),
});
}
let hf_tokenizer: HfTokenizerJson =
serde_json::from_str(json).map_err(|e| AprenderError::FormatError {
message: format!("Failed to parse tokenizer JSON: {e}"),
})?;
let vocab_size = hf_tokenizer.model.vocab.len();
let merge_count = hf_tokenizer.model.merges.len();
let config = config_from_vocab_size(vocab_size);
let mut tokenizer = BpeTokenizer::with_capacity(config, vocab_size, merge_count);
load_vocab_owned(&mut tokenizer, hf_tokenizer.model.vocab);
load_merges_fast(&mut tokenizer, hf_tokenizer.model.merges);
load_added_tokens(&mut tokenizer, &hf_tokenizer.added_tokens);
Ok(tokenizer)
}
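
/// Heuristic config selection by vocabulary size: >150k Qwen2, >50k Whisper,
/// >40k GPT-2, otherwise LLaMA.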
fn config_from_vocab_size(vocab_size: usize) -> BpeConfig {
if vocab_size > 150_000 {
BpeConfig::qwen2()
} else if vocab_size > 50_000 {
BpeConfig::whisper()
} else if vocab_size > 40_000 {
BpeConfig::gpt2()
} else {
BpeConfig::llama()
}
}
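
/// Copies a borrowed vocab map into both lookup directions (token -> id, id -> token).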
fn load_vocab_into(tokenizer: &mut BpeTokenizer, vocab: &HashMap<String, u32>) {
for (token, id) in vocab {
tokenizer.vocab.insert(token.clone(), *id);
tokenizer.id_to_token.insert(*id, token.clone());
}
}
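
/// Owned variant of `load_vocab_into`: consumes the map so each token string is
/// cloned only once.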
fn load_vocab_owned(tokenizer: &mut BpeTokenizer, vocab: HashMap<String, u32>) {
for (token, id) in vocab {
tokenizer.id_to_token.insert(id, token.clone());
tokenizer.vocab.insert(token, id);
}
}
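
/// Registers merges already normalized to the `"left right"` form, skipping any
/// entry without a space separator.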
fn load_merges_fast(tokenizer: &mut BpeTokenizer, merges: Vec<String>) {
for merge_str in merges {
if let Some((first, second)) = merge_str.split_once(' ') {
tokenizer.add_merge_owned(first.to_string(), second.to_string());
}
}
}
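
/// Registers `added_tokens`: special tokens go through `add_special_token`,
/// ordinary added tokens are inserted directly into the vocab maps.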
fn load_added_tokens(tokenizer: &mut BpeTokenizer, added_tokens: &[HfAddedToken]) {
for added in added_tokens {
if added.special {
tokenizer.add_special_token(&added.content, added.id);
} else {
tokenizer.vocab.insert(added.content.clone(), added.id);
tokenizer
.id_to_token
.insert(added.id, added.content.clone());
}
}
}
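
/// Builds a tokenizer from a separate `vocab.json` plus `merges.txt` pair
/// (the older GPT-2-style layout).
///
/// # Errors
///
/// Returns `AprenderError::FormatError` if the vocabulary JSON is empty or
/// cannot be parsed.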
pub fn load_from_files(vocab_json: &str, merges_txt: &str) -> Result<BpeTokenizer> {
if vocab_json.is_empty() {
return Err(AprenderError::FormatError {
message: "Empty vocabulary JSON".to_string(),
});
}
let vocab: HashMap<String, u32> =
serde_json::from_str(vocab_json).map_err(|e| AprenderError::FormatError {
message: format!("Failed to parse vocabulary JSON: {e}"),
})?;
let config = config_from_vocab_size(vocab.len());
let mut tokenizer = BpeTokenizer::new(config);
load_vocab_into(&mut tokenizer, &vocab);
load_merges_from_text(&mut tokenizer, merges_txt);
Ok(tokenizer)
}
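
/// Parses `merges.txt` lines, skipping blanks and `#` comments, and registers
/// the first two whitespace-separated tokens of each line as a merge.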
fn load_merges_from_text(tokenizer: &mut BpeTokenizer, merges_txt: &str) {
for line in merges_txt.lines() {
let line = line.trim();
if line.is_empty() || line.starts_with('#') {
continue;
}
let parts: Vec<&str> = line.split_whitespace().collect();
if parts.len() >= 2 {
tokenizer.add_merge(parts[0], parts[1]);
}
}
}
#[cfg(test)]
#[path = "tests.rs"]
mod tests;