#[cfg(feature = "rwkv-tokenizer")]
mod rwkv;
use crate::error::{MIError, Result};
use crate::util::positioning::EncodingWithOffsets;
/// Unified tokenizer wrapper over the supported backends.
///
/// `#[non_exhaustive]` so additional backends can be added later without a
/// breaking change for downstream `match`es.
#[non_exhaustive]
pub enum MITokenizer {
/// HuggingFace `tokenizers` backend. Boxed to keep the enum itself small.
HuggingFace(Box<tokenizers::Tokenizer>),
/// RWKV tokenizer backend, available only with the `rwkv-tokenizer` feature.
#[cfg(feature = "rwkv-tokenizer")]
Rwkv(rwkv::RwkvTokenizer),
}
impl MITokenizer {
/// Loads a HuggingFace `tokenizer.json` from `path`.
///
/// # Errors
///
/// Returns [`MIError::Tokenizer`] if the file cannot be read or parsed.
pub fn from_hf_path(path: impl AsRef<std::path::Path>) -> Result<Self> {
    let tok = tokenizers::Tokenizer::from_file(path.as_ref()).map_err(|e| {
        MIError::Tokenizer(format!(
            "failed to load HF tokenizer from {}: {e}",
            path.as_ref().display()
        ))
    })?;
    Ok(Self::HuggingFace(Box::new(tok)))
}

/// Wraps an already-constructed HuggingFace tokenizer.
#[must_use]
pub fn from_hf(tokenizer: tokenizers::Tokenizer) -> Self {
    Self::HuggingFace(Box::new(tokenizer))
}

/// Loads an RWKV vocabulary file from `path`.
///
/// # Errors
///
/// Propagates any error reported by the RWKV loader.
#[cfg(feature = "rwkv-tokenizer")]
pub fn from_rwkv_path(path: impl AsRef<std::path::Path>) -> Result<Self> {
    let tok = rwkv::RwkvTokenizer::from_file(path.as_ref())?;
    Ok(Self::Rwkv(tok))
}

/// Encodes `text` into token IDs, including any special tokens
/// (e.g. BOS/EOS) the tokenizer is configured to add.
///
/// # Errors
///
/// Returns [`MIError::Tokenizer`] if the backend fails to encode.
pub fn encode(&self, text: &str) -> Result<Vec<u32>> {
    self.encode_ids_inner(text, true)
}

/// Encodes `text` into token IDs without adding special tokens.
///
/// # Errors
///
/// Returns [`MIError::Tokenizer`] if the backend fails to encode.
pub fn encode_raw(&self, text: &str) -> Result<Vec<u32>> {
    self.encode_ids_inner(text, false)
}

/// Shared implementation behind [`Self::encode`] and [`Self::encode_raw`],
/// mirroring the existing `encode_with_offsets_inner` pattern.
fn encode_ids_inner(&self, text: &str, add_special_tokens: bool) -> Result<Vec<u32>> {
    match self {
        Self::HuggingFace(tok) => {
            let encoding = tok
                .encode(text, add_special_tokens)
                .map_err(|e| MIError::Tokenizer(format!("HF encode failed: {e}")))?;
            Ok(encoding.get_ids().to_vec())
        }
        // The RWKV backend has no notion of special tokens, so the raw and
        // non-raw paths are identical there.
        #[cfg(feature = "rwkv-tokenizer")]
        Self::Rwkv(tok) => tok.encode(text),
    }
}

/// Encodes `text` and returns IDs together with token strings and their
/// offsets into the input, including special tokens.
///
/// # Errors
///
/// Fails for the RWKV backend (no offset support) or if HF encoding fails.
pub fn encode_with_offsets(&self, text: &str) -> Result<EncodingWithOffsets> {
    self.encode_with_offsets_inner(text, true)
}

/// Like [`Self::encode_with_offsets`], but without special tokens.
///
/// # Errors
///
/// Fails for the RWKV backend (no offset support) or if HF encoding fails.
pub fn encode_raw_with_offsets(&self, text: &str) -> Result<EncodingWithOffsets> {
    self.encode_with_offsets_inner(text, false)
}

/// Shared implementation for the two offset-producing encode methods.
fn encode_with_offsets_inner(
    &self,
    text: &str,
    add_special_tokens: bool,
) -> Result<EncodingWithOffsets> {
    match self {
        Self::HuggingFace(tok) => {
            let encoding = tok
                .encode(text, add_special_tokens)
                .map_err(|e| MIError::Tokenizer(format!("HF encode failed: {e}")))?;
            let ids = encoding.get_ids().to_vec();
            let tokens: Vec<String> = encoding
                .get_tokens()
                .iter()
                .map(ToString::to_string)
                .collect();
            let offsets = encoding.get_offsets().to_vec();
            Ok(EncodingWithOffsets::new(ids, tokens, offsets))
        }
        #[cfg(feature = "rwkv-tokenizer")]
        Self::Rwkv(_) => Err(MIError::Tokenizer(
            "RWKV tokenizer does not support offset mapping".into(),
        )),
    }
}

/// Decodes token IDs back into text. Special tokens are NOT stripped
/// (the HF backend is called with `skip_special_tokens = false`).
///
/// # Errors
///
/// Returns [`MIError::Tokenizer`] if the backend fails to decode.
pub fn decode(&self, ids: &[u32]) -> Result<String> {
    match self {
        Self::HuggingFace(tok) => tok
            .decode(ids, false)
            .map_err(|e| MIError::Tokenizer(format!("HF decode failed: {e}"))),
        #[cfg(feature = "rwkv-tokenizer")]
        Self::Rwkv(tok) => tok.decode(ids),
    }
}

/// Vocabulary size; for the HF backend this includes added tokens
/// (`get_vocab_size(true)`).
#[must_use]
pub fn vocab_size(&self) -> usize {
    match self {
        Self::HuggingFace(tok) => tok.get_vocab_size(true),
        #[cfg(feature = "rwkv-tokenizer")]
        Self::Rwkv(tok) => tok.vocab_size(),
    }
}

/// Finds the single token ID for `word`, preferring the space-prefixed
/// form (most BPE vocabs store word-initial tokens with a leading space).
///
/// Uses raw (no special tokens) encoding throughout. The previous
/// implementation encoded WITH special tokens and assumed `len == 2`
/// meant `[BOS, token]`; for tokenizers that append EOS — or add both
/// BOS and EOS — that returned a special-token ID instead of the word's.
///
/// # Errors
///
/// Returns [`MIError::Tokenizer`] if encoding fails or yields no tokens.
pub fn find_token_id(&self, word: &str) -> Result<u32> {
    let with_space = format!(" {word}");
    let ids = self.encode_raw(&with_space)?;
    if let [id] = ids.as_slice() {
        return Ok(*id);
    }
    // Space-prefixed form is multi-token; try the bare word.
    let bare_ids = self.encode_raw(word)?;
    if let [id] = bare_ids.as_slice() {
        return Ok(*id);
    }
    // No single-token representation exists: fall back to the final
    // sub-token of the space-prefixed encoding, as before.
    ids.last().copied().ok_or_else(|| {
        MIError::Tokenizer(format!("could not find single token ID for \"{word}\""))
    })
}

/// Decodes a single token ID to its string form.
///
/// # Errors
///
/// Returns [`MIError::Tokenizer`] if the backend fails to decode.
pub fn decode_token(&self, token_id: u32) -> Result<String> {
    self.decode(&[token_id])
}
}
impl std::fmt::Debug for MITokenizer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Hand-written so the HF variant shows a placeholder instead of
        // dumping the full tokenizer state.
        match self {
            Self::HuggingFace(_) => {
                let mut builder = f.debug_tuple("HuggingFace");
                builder.field(&"...");
                builder.finish()
            }
            #[cfg(feature = "rwkv-tokenizer")]
            Self::Rwkv(inner) => {
                let mut builder = f.debug_tuple("Rwkv");
                builder.field(inner);
                builder.finish()
            }
        }
    }
}