use std::path::PathBuf;
use serde::{Deserialize, Serialize};
use tokenizers::{models::bpe::BPE, EncodeInput, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
/// Serializable companion shape for a CLIP tokenizer: the inner Hugging Face
/// `Tokenizer` plus the CLIP-specific settings, round-tripped as one JSON
/// object.
///
/// NOTE(review): no users of this type are visible in this chunk — presumably
/// it mirrors `CLIPStandardTokenizer` for (de)serialization; confirm callers.
#[derive(Serialize, Deserialize)]
pub struct CLIPStandardTokenizerWrapper {
    /// Underlying tokenizer; `flatten` merges its serialized fields directly
    /// into this struct's JSON object instead of nesting them under a key.
    #[serde(flatten)]
    pub tokenizer: Tokenizer,
    /// Fixed context length the tokenizer pads/truncates to.
    pub model_max_length: usize,
    /// Beginning-of-sequence token id.
    pub bos_token_id: u32,
    /// End-of-sequence token id.
    pub eos_token_id: u32
}
/// A CLIP text tokenizer: a Hugging Face `Tokenizer` configured (see
/// `from_bytes`) to pad and truncate to a fixed context length.
pub struct CLIPStandardTokenizer {
    /// The underlying Hugging Face tokenizer.
    pub tokenizer: Tokenizer,
    /// Fixed context length; exposed via `len()`.
    model_max_length: usize,
    /// Beginning-of-sequence token id; exposed via `bos()`.
    bos_token_id: u32,
    /// End-of-sequence token id (also used as the pad id); exposed via `eos()`.
    eos_token_id: u32
}
// SAFETY: NOTE(review) — these manually assert thread-safety for the wrapper.
// In current releases of the `tokenizers` crate, `Tokenizer` is itself
// `Send + Sync`, which would make these impls redundant (the compiler would
// derive them); if the pinned version's `Tokenizer` is NOT thread-safe, these
// impls are unsound. Verify against the pinned `tokenizers` version and
// consider removing them.
unsafe impl Send for CLIPStandardTokenizer {}
unsafe impl Sync for CLIPStandardTokenizer {}
impl CLIPStandardTokenizer {
pub fn new(path: impl Into<PathBuf>, reconfigure: bool, model_max_length: usize, bos_token_id: u32, eos_token_id: u32) -> anyhow::Result<Self> {
let path = path.into();
let bytes = std::fs::read(path)?;
Self::from_bytes(bytes, reconfigure, model_max_length, bos_token_id, eos_token_id)
}
pub fn from_bytes<B: AsRef<[u8]>>(bytes: B, reconfigure: bool, model_max_length: usize, bos_token_id: u32, eos_token_id: u32) -> anyhow::Result<Self> {
let mut tokenizer: Tokenizer = serde_json::from_slice(bytes.as_ref())?;
if reconfigure {
tokenizer
.with_padding(Some(PaddingParams {
strategy: PaddingStrategy::Fixed(model_max_length),
pad_id: eos_token_id,
..Default::default()
}))
.with_truncation(Some(TruncationParams {
max_length: model_max_length,
..Default::default()
}));
}
Ok(Self {
tokenizer,
model_max_length,
bos_token_id,
eos_token_id
})
}
#[allow(dead_code)]
pub fn model(&self) -> &BPE {
match self.tokenizer.get_model() {
tokenizers::ModelWrapper::BPE(ref bpe) => bpe,
_ => unreachable!()
}
}
#[allow(clippy::len_without_is_empty)]
pub fn len(&self) -> usize {
self.model_max_length
}
#[allow(dead_code)]
pub fn eos(&self) -> u32 {
self.eos_token_id
}
#[allow(dead_code)]
pub fn bos(&self) -> u32 {
self.bos_token_id
}
pub fn encode<'s, 'e, E>(&self, enc: Vec<E>) -> anyhow::Result<Vec<Vec<u32>>>
where
E: Into<EncodeInput<'s>>
{
let enc_len = enc.len();
let encoded: Vec<Vec<u32>> = enc
.into_iter()
.map(|f| self.tokenizer.encode(f, true).map(|f| f.get_ids().to_vec()))
.scan((), |_, x| x.ok())
.collect();
assert_eq!(encoded.len(), enc_len);
Ok(encoded)
}
}