use std::collections::hash_map::Entry;
use std::fmt::Debug;

use ahash::AHashMap;
use fixedbitset_stack::FixedBitSet;
use jaggedarray::jagged_array::JaggedArray;
#[cfg(feature = "python")]
use pyo3::prelude::*;
use serde::Deserialize;
#[cfg(feature = "wasm")]
use wasm_bindgen::prelude::*;

use crate::utils;
use crate::utils::ByteSet;
/// One slot for each possible byte value (256) plus one extra slot.
const BYTES_NUM: usize = 257;
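/// A token, represented as its raw byte sequence.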
#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize)]
#[repr(transparent)]
#[cfg_attr(feature = "wasm", wasm_bindgen(getter_with_clone))]
#[cfg_attr(feature = "python", pyclass)]
pub struct Token(pub Box<[u8]>);
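/// Fixed-size inline storage for per-byte data. Implementing [`tinyvec::Array`]
/// below lets a `tinyvec::ArrayVec<FirstBytes>` hold up to [`BYTES_NUM`] `u32`s
/// without heap allocation.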
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct FirstBytes([u32; BYTES_NUM]);
impl tinyvec::Array for FirstBytes {
    type Item = u32;
    const CAPACITY: usize = BYTES_NUM;
    fn as_slice(&self) -> &[Self::Item] {
        &self.0
    }
    fn as_slice_mut(&mut self) -> &mut [Self::Item] {
        &mut self.0
    }
    fn default() -> Self {
        Self([0; BYTES_NUM])
    }
}
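/// A language model's vocabulary: the bidirectional mapping between token ids
/// and the byte sequences of the corresponding tokens.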
#[derive(Clone)]
#[cfg_attr(feature = "wasm", wasm_bindgen)]
#[cfg_attr(feature = "python", pyclass)]
pub struct Vocabulary {
    /// Maps a token to one of its ids (see [`Vocabulary::new`] for how
    /// conflicting ids are resolved).
    pub(crate) token_to_id: AHashMap<Token, u32>,
    /// Maps a token id to its token bytes.
    pub(crate) id_to_token: AHashMap<u32, Token>,
    /// Token bytes laid out contiguously; row `i` holds the bytes of token id `i`.
    pub(crate) id_to_token_contiguous: JaggedArray<u8, Vec<u32>, 2>,
    /// For each possible first byte, the set of token ids whose token starts
    /// with that byte.
    pub(crate) byte_to_token_ids: [FixedBitSet; 256],
    /// Human-readable form of each token, kept for debugging.
    pub(crate) id_to_token_string: AHashMap<u32, String>,
}
impl Debug for Vocabulary {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The contiguous buffer and the per-byte bitsets are derived from
        // `id_to_token` in `new`, so they are omitted from the debug output.
        f.debug_struct("Vocabulary")
            .field("token_to_id", &self.token_to_id)
            .field("id_to_token", &self.id_to_token)
            .field("id_to_token_string", &self.id_to_token_string)
            .finish()
    }
}
#[derive(Debug, thiserror::Error)]
pub enum CreateVocabularyError {
    #[error("The vocabulary size is {0}, while the maximum supported is {1}.")]
    VocabularyTooLarge(usize, usize),
    #[error("The token's length is {0}, while the maximum supported is {1}.")]
    TokenTooLong(usize, usize),
}
impl Vocabulary {
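    /// Creates a [`Vocabulary`] from a map of token ids to tokens and a
    /// parallel map of token ids to their display strings.
    ///
    /// A minimal usage sketch (the token bytes and strings here are made up
    /// for illustration):
    ///
    /// ```ignore
    /// use ahash::AHashMap;
    ///
    /// let mut id_to_token = AHashMap::new();
    /// id_to_token.insert(0u32, Token(b"hi".to_vec().into_boxed_slice()));
    /// let mut id_to_token_string = AHashMap::new();
    /// id_to_token_string.insert(0u32, "hi".to_owned());
    /// let vocabulary = Vocabulary::new(id_to_token, id_to_token_string)
    ///     .expect("the maps should form a valid vocabulary");
    /// ```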
    pub fn new(
        id_to_token: AHashMap<u32, Token>,
        id_to_token_string: AHashMap<u32, String>,
    ) -> Result<Vocabulary, CreateVocabularyError> {
        // Invert the id-to-token map. When several ids share the same token,
        // whichever id happens to be inserted first wins (map iteration order
        // is unspecified); the losing ids are collected and reported below.
        let mut token_to_id = AHashMap::with_capacity(id_to_token.len());
        let mut conflicting_token_ids: Vec<(u32, u32)> = Vec::new();
        for (&token_id, token) in id_to_token.iter() {
            match token_to_id.entry(token.clone()) {
                Entry::Occupied(entry) => {
                    conflicting_token_ids.push((token_id, *entry.get()));
                }
                Entry::Vacant(entry) => {
                    entry.insert(token_id);
                }
            }
        }
        if !conflicting_token_ids.is_empty() {
            let conflicting_pairs: Vec<String> = conflicting_token_ids
                .iter()
                .map(|(new_id, existing_id)| format!("({}, {})", existing_id, new_id))
                .collect();
            log::warn!(
                "Multiple token ids correspond to the same token. Matching \
                tokens to token ids is only used for debugging purposes. The \
                second token id in each pair will be ignored when matching \
                tokens to ids: {}.",
                conflicting_pairs.join(", ")
            );
        }
        // A named const lets the array-repeat expression below create 256
        // independent (non-`Copy`) empty `Vec`s.
        const VEC: Vec<usize> = Vec::new();
        let mut byte_to_token_ids = [VEC; 256];
        Self::check_vocabulary_utf8_support(&token_to_id);
        // Lay the token bytes out contiguously, one row per token id; ids
        // absent from the map get empty rows so that row index == token id.
        let mut sorted_tokens = id_to_token.iter().collect::<Vec<_>>();
        sorted_tokens.sort_by_key(|x| x.0);
        let mut id_to_token_contiguous = JaggedArray::new();
        let mut next_slot = 0;
        for (&token_id, token) in sorted_tokens.into_iter() {
            while next_slot <= token_id {
                id_to_token_contiguous.new_row::<0>();
                next_slot += 1;
            }
            id_to_token_contiguous.extend_last_row(token.0.iter().copied());
            if let Some(first_byte) = token.0.first().copied() {
                byte_to_token_ids[first_byte as usize].push(token_id as usize);
            } else {
                log::warn!("The token with id {} is empty.", token_id);
            }
        }
        // Freeze each per-byte id list into a bitset for fast membership tests.
        let byte_to_token_ids = byte_to_token_ids.map(FixedBitSet::from_iter);
        Ok(Self {
            token_to_id,
            id_to_token,
            id_to_token_contiguous,
            id_to_token_string,
            byte_to_token_ids,
        })
    }
    fn check_vocabulary_utf8_support(token_to_id: &AHashMap<Token, u32>) {
        let mut not_existing_bytes = ByteSet::with_capacity(256);
        fn check_non_existing_byte_in_range(
            token_to_id: &AHashMap<Token, u32>,
            not_existing_bytes: &mut ByteSet,
            start: u8,
            end: u8,
        ) {
            // Record every byte in the range that no token contains.
            for byte in start..=end {
                if !token_to_id.keys().any(|token| token.0.contains(&byte)) {
                    not_existing_bytes.insert(byte as usize);
                }
            }
        }
        // Bytes 0xF8..=0xFF never appear in well-formed UTF-8, so only
        // 0x00..=0xF7 are checked.
        check_non_existing_byte_in_range(token_to_id, &mut not_existing_bytes, 0, 247);
        if !not_existing_bytes.is_clear() {
            log::warn!(
                "\
                The following bytes are not present in any token: {:?}. \
                This likely indicates that the vocabulary loading code is wrong, \
                that the tokenizer applies nonstandard preprocessing, or that the \
                tokenizer is not UTF-8 compatible. Check the vocabulary loading \
                code and the tokenizer for bugs, and/or consider applying the \
                tokenizer's preprocessing to the vocabulary.",
                utils::get_display_form_from_bitset_on_stack(&not_existing_bytes)
            );
        }
    }
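    /// Returns the token bytes associated with `token_id`, if present.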
    pub fn token(&self, token_id: u32) -> Option<&Token> {
        self.id_to_token.get(&token_id)
    }
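    /// Returns the human-readable string form of `token_id`, if present.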
    pub fn token_string(&self, token_id: u32) -> Option<&str> {
        self.id_to_token_string.get(&token_id).map(|x| x.as_str())
    }
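    /// Returns the token id associated with `token`, if present.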
    pub fn token_id(&self, token: &Token) -> Option<u32> {
        self.token_to_id.get(token).copied()
    }
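    /// Returns the vocabulary size, defined as the highest token id plus one.
    /// This may exceed the number of distinct tokens when ids are sparse.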
    pub fn vocab_size(&self) -> usize {
        // Widen to usize before adding one so a maximal u32 id cannot overflow.
        self.id_to_token
            .keys()
            .map(|&x| x as usize + 1)
            .max()
            .unwrap_or(0)
    }
}