use std::{
collections::hash_map::DefaultHasher,
hash::{Hash, Hasher},
};
use anyhow::Result;
use crate::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
/// Integer type used for token ids by the encoder, decoder and tokenizer
/// traits declared in this file.
pub type TokenIdType = u32;
/// Converts text into [`Encoding`] values.
///
/// Implementors must be `Send + Sync`.
pub trait Encoder: Send + Sync {
/// Encodes a single `input` string into an [`Encoding`].
///
/// `add_special_tokens` is handed to the implementation as-is.
/// NOTE(review): presumably it controls whether tokenizer-specific special
/// tokens are inserted into the output; confirm against each implementor.
fn encode(&self, input: &str, add_special_tokens: bool) -> Result<Encoding>;
/// Encodes every string in `inputs` and returns the resulting [`Encoding`]s.
///
/// NOTE(review): presumably one output per input, in input order; the
/// signature alone does not guarantee this, so confirm against implementors.
fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>>;
}
/// Converts token ids back into text.
///
/// Implementors provide [`Decoder::decode`]; [`Decoder::decode_step`] is a
/// default streaming helper built on top of it.
pub trait Decoder: Send + Sync {
    /// Decodes `token_ids` into a string.
    ///
    /// `skip_special_tokens` is handed to the implementation as-is.
    fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String>;

    /// Feeds one more token into an incremental decode and returns any newly
    /// available text.
    ///
    /// The caller owns the streaming state and passes it in on every call:
    ///
    /// * `ids` - token ids buffered so far; `token_id` is appended to it.
    /// * `prefix` - text that later decodes are compared against; the bytes of
    ///   the fresh decode past `prefix.len()` are what gets returned.
    /// * `prefix_index` - how many leading entries of `ids` are dropped the
    ///   next time text is emitted.
    ///
    /// Returns `Ok(Some(text))` when the decode of the whole buffer is longer
    /// (in bytes) than `prefix` and does not end in `'�'`. Otherwise returns
    /// `Ok(None)` and the token simply stays buffered in `ids`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Decoder::decode`]. If the final re-decode
    /// fails, `ids` and `prefix_index` have already been updated.
    ///
    /// # Panics
    ///
    /// Panics if `*prefix_index` exceeds `ids.len()` at the point text is
    /// emitted.
    fn decode_step(
        &self,
        token_id: TokenIdType,
        ids: &mut Vec<TokenIdType>,
        prefix: &mut String,
        prefix_index: &mut usize,
        skip_special_tokens: bool,
    ) -> Result<Option<String>> {
        // No prefix yet but tokens are already buffered: try to adopt their
        // decode as the prefix, unless it ends in a replacement character.
        let needs_seed = prefix.is_empty() && !ids.is_empty();
        if needs_seed {
            let seeded = self.decode(ids, skip_special_tokens)?;
            if !seeded.ends_with('�') {
                *prefix_index = ids.len();
                *prefix = seeded;
            }
        }

        ids.push(token_id);
        let decoded = self.decode(ids, skip_special_tokens)?;

        // A trailing replacement character typically means the last token ended
        // mid-character, so nothing is emitted until more tokens arrive.
        let incomplete = decoded.ends_with('�');
        if decoded.len() <= prefix.len() || incomplete {
            return Ok(None);
        }

        // Cut at the prefix length, backing up to the nearest char boundary so
        // the slice below cannot split a multi-byte character.
        let cut = (0..=prefix.len())
            .rev()
            .find(|&pos| decoded.is_char_boundary(pos))
            .unwrap_or(0);
        let emitted = decoded[cut..].to_owned();

        // Drop the ids behind the old prefix index; everything that remains
        // becomes the new prefix.
        let consumed = *prefix_index;
        let remaining = ids.len() - consumed;
        ids.drain(..consumed);
        *prefix_index = remaining;
        *prefix = self.decode(ids, skip_special_tokens)?;

        Ok(Some(emitted))
    }
}
/// Complete tokenizer interface: encoding and decoding plus vocabulary lookups
/// and optional chat-template support.
///
/// The chat-template methods ship with defaults for tokenizers that have no
/// template support; implementors that do support templates override them.
pub trait Tokenizer: Encoder + Decoder {
    /// Returns the size of the tokenizer's vocabulary.
    fn vocab_size(&self) -> usize;

    /// Returns the tokenizer's [`SpecialTokens`].
    fn get_special_tokens(&self) -> &SpecialTokens;

    /// Looks up the id for `token`; `None` when the implementation has no
    /// id for it.
    fn token_to_id(&self, token: &str) -> Option<TokenIdType>;

    /// Looks up the token text for `id`; `None` when the implementation has
    /// no token for it.
    fn id_to_token(&self, id: TokenIdType) -> Option<String>;

    /// Exposes the tokenizer as `&dyn Any`, e.g. so callers can downcast to
    /// the concrete type.
    fn as_any(&self) -> &dyn std::any::Any;

    /// Renders `_messages` through the tokenizer's chat template.
    ///
    /// # Errors
    ///
    /// The default implementation ignores its arguments and always fails.
    fn apply_chat_template(
        &self,
        _messages: &[serde_json::Value],
        _params: ChatTemplateParams,
    ) -> Result<String> {
        anyhow::bail!("Chat template not supported by this tokenizer")
    }

    /// Reports the content format used by the chat template.
    ///
    /// The default implementation returns `ChatTemplateContentFormat::default()`.
    fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
        ChatTemplateContentFormat::default()
    }

    /// Replaces the tokenizer's chat template.
    ///
    /// # Errors
    ///
    /// The default implementation ignores `_template` and always fails.
    fn set_chat_template(&mut self, _template: String) -> Result<()> {
        anyhow::bail!("set_chat_template is not supported by this tokenizer")
    }
}
/// Result of encoding text, tagged by which kind of tokenizer produced it.
///
/// Every variant ultimately carries a sequence of [`TokenIdType`] values.
#[derive(Debug, Clone)]
pub enum Encoding {
/// An encoding object from the `tokenizers` crate, stored behind a `Box`.
/// NOTE(review): presumably boxed to keep the enum itself small; confirm.
Hf(Box<tokenizers::tokenizer::Encoding>),
/// A bare list of token ids.
Plain(Vec<TokenIdType>),
/// A bare list of token ids.
/// NOTE(review): the variant name suggests a tiktoken-style tokenizer as the
/// source; confirm against the code that constructs it.
Tiktoken(Vec<TokenIdType>),
}
impl Encoding {
    /// Borrows the token ids held by this encoding, whichever variant it is.
    #[inline]
    pub fn token_ids(&self) -> &[TokenIdType] {
        match self {
            Encoding::Hf(hf) => hf.get_ids(),
            Encoding::Plain(ids) | Encoding::Tiktoken(ids) => ids.as_slice(),
        }
    }

    /// Feeds this encoding through a fresh [`DefaultHasher`] using the type's
    /// `Hash` implementation and returns the resulting 64-bit value.
    pub fn get_hash(&self) -> u64 {
        let mut state = DefaultHasher::new();
        Hash::hash(self, &mut state);
        state.finish()
    }
}
impl Hash for Encoding {
fn hash<H: Hasher>(&self, state: &mut H) {
match self {
Encoding::Hf(inner) => inner.get_ids().hash(state),
Encoding::Plain(inner) => inner.hash(state),
Encoding::Tiktoken(inner) => inner.hash(state),
}
}
}
/// Special-token strings exposed by a tokenizer.
///
/// Each named slot is `None` when no token is set for it. The derived
/// `Default` leaves every slot `None` and `additional_special_tokens` empty.
/// NOTE(review): the per-field meanings below are read off the field names;
/// nothing in this file enforces them.
#[derive(Debug, Clone, Default)]
pub struct SpecialTokens {
/// Token stored in the `bos` slot (conventionally beginning-of-sequence).
pub bos_token: Option<String>,
/// Token stored in the `eos` slot (conventionally end-of-sequence).
pub eos_token: Option<String>,
/// Token stored in the `unk` slot (conventionally the unknown-token marker).
pub unk_token: Option<String>,
/// Token stored in the `sep` slot (conventionally a separator).
pub sep_token: Option<String>,
/// Token stored in the `pad` slot (conventionally padding).
pub pad_token: Option<String>,
/// Token stored in the `cls` slot (conventionally a classification marker).
pub cls_token: Option<String>,
/// Token stored in the `mask` slot (conventionally a mask marker).
pub mask_token: Option<String>,
/// Any further special tokens beyond the named slots above.
pub additional_special_tokens: Vec<String>,
}