//! Tokenizer abstractions shared by the HuggingFace and TikToken
//! backends, plus helpers for incremental detokenization of streamed
//! token ids ([`DecodeStream`], [`Sequence`], [`StopSequenceDecoder`]).

pub mod hf;
pub mod tiktoken;

use std::hash::{DefaultHasher, Hash, Hasher};
use std::ops::Deref;
use std::path::Path;
use std::sync::Arc;

use crate::protocols::TokenIdType;

pub use anyhow::{Error, Result};
pub use hf::HuggingFaceTokenizer;
pub use tiktoken::TikTokenTokenizer;

/// Identifies a tokenizer backend, carrying the identifier it was (or
/// will be) loaded from, such as a file path or model name.
#[derive(Debug)]
pub enum TokenizerType {
    HuggingFace(String),
    TikToken(String),
}

/// `(start, end)` offsets of a token within the source text.
pub type Offsets = (usize, usize);
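
/// A tokenized input, wrapping whichever representation the backend
/// produced.
///
/// Illustrative sketch (assumes a `tokenizer` handle already exists):
///
/// ```ignore
/// let encoding = tokenizer.encode("Hello, world!")?;
/// let ids = encoding.token_ids(); // backend-independent view
/// let key = encoding.get_hash();  // hash of those ids
/// ```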
#[derive(Debug, Clone)]
pub enum Encoding {
    /// A HuggingFace `tokenizers` encoding, boxed to keep the enum small.
    Hf(Box<tokenizers::tokenizer::Encoding>),
    /// A plain list of token ids.
    Sp(Vec<TokenIdType>),
}

impl Encoding {
    /// The token ids of this encoding, regardless of backend.
    pub fn token_ids(&self) -> &[TokenIdType] {
        match self {
            Encoding::Hf(inner) => inner.get_ids(),
            Encoding::Sp(inner) => inner,
        }
    }
}

impl Hash for Encoding {
    /// Hashes only the token ids, so identical token sequences hash
    /// identically regardless of which backend produced them.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.token_ids().hash(state);
    }
}
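
/// Traits a tokenizer backend implements. A hypothetical sketch of a
/// minimal backend (the `ByteTokenizer` name and its byte-level scheme
/// are invented for illustration):
///
/// ```ignore
/// use std::sync::Arc;
///
/// struct ByteTokenizer;
///
/// impl traits::Encoder for ByteTokenizer {
///     fn encode(&self, input: &str) -> Result<Encoding> {
///         // One token id per UTF-8 byte.
///         Ok(Encoding::Sp(input.bytes().map(u32::from).collect()))
///     }
///     fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
///         inputs.iter().map(|s| self.encode(s)).collect()
///     }
/// }
///
/// impl traits::Decoder for ByteTokenizer {
///     fn decode(&self, token_ids: &[TokenIdType], _skip_special_tokens: bool) -> Result<String> {
///         let bytes: Vec<u8> = token_ids.iter().map(|&id| id as u8).collect();
///         Ok(String::from_utf8_lossy(&bytes).into_owned())
///     }
/// }
///
/// impl traits::Tokenizer for ByteTokenizer {}
///
/// // An `Arc` of any implementor converts into the `Tokenizer` handle.
/// let tokenizer: Tokenizer = Arc::new(ByteTokenizer).into();
/// ```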
pub mod traits {
    use super::*;

    /// Encodes text into token ids. `Send + Sync` so implementations can
    /// be shared across threads behind an `Arc`.
    pub trait Encoder: Send + Sync {
        fn encode(&self, input: &str) -> Result<Encoding>;
        fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>>;
    }

    /// Decodes token ids back into text.
    pub trait Decoder: Send + Sync {
        fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String>;
    }

    /// A full tokenizer: anything that can both encode and decode.
    pub trait Tokenizer: Encoder + Decoder {}
}

impl Encoding {
    /// A 64-bit hash of the token ids. `DefaultHasher`'s algorithm is
    /// unspecified and may change between Rust releases, so treat this
    /// as an in-process value and do not persist it.
    pub fn get_hash(&self) -> u64 {
        let mut hasher = DefaultHasher::new();
        self.hash(&mut hasher);
        hasher.finish()
    }
}
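
/// A cheaply clonable handle to a tokenizer backend, shared behind an
/// `Arc<dyn traits::Tokenizer>`.
///
/// Illustrative sketch (assumes a HuggingFace `tokenizer.json` on disk):
///
/// ```ignore
/// let tokenizer = Tokenizer::from_file("tokenizer.json")?;
/// let encoding = tokenizer.encode("Hello, world!")?;
/// let text = tokenizer.decode(encoding.token_ids(), false)?; // round-trip
/// ```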
#[derive(Clone)]
pub struct Tokenizer(Arc<dyn traits::Tokenizer>);

impl Tokenizer {
    /// Loads a tokenizer from a file, choosing the backend from the file
    /// extension (see [`create_tokenizer_from_file`]).
    pub fn from_file(file_path: &str) -> Result<Tokenizer> {
        Ok(Tokenizer(create_tokenizer_from_file(file_path)?))
    }

    /// Creates an incremental detokenizer seeded with the prompt's token ids.
    pub fn decode_stream(
        &self,
        prompt_token_ids: &[TokenIdType],
        skip_special_tokens: bool,
    ) -> DecodeStream {
        DecodeStream::new(self.0.clone(), prompt_token_ids, skip_special_tokens)
    }
}

impl Deref for Tokenizer {
    type Target = Arc<dyn traits::Tokenizer>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl From<Arc<dyn traits::Tokenizer>> for Tokenizer {
    fn from(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
        Tokenizer(tokenizer)
    }
}

impl<T> From<Arc<T>> for Tokenizer
where
    T: traits::Tokenizer + 'static,
{
    fn from(tokenizer: Arc<T>) -> Self {
        Tokenizer(tokenizer)
    }
}

/// Creates a tokenizer from a file, dispatching on the extension:
/// `.json` is loaded as a HuggingFace tokenizer, `.model` or `.tiktoken`
/// as a TikToken tokenizer.
pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
    let path = Path::new(file_path);
    let extension = path
        .extension()
        .and_then(std::ffi::OsStr::to_str)
        .ok_or_else(|| Error::msg("Failed to read file extension"))?;
    match extension {
        "json" => {
            let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
            Ok(Arc::new(tokenizer))
        }
        "model" | "tiktoken" => {
            let tokenizer = TikTokenTokenizer::from_file_auto(file_path)?;
            Ok(Arc::new(tokenizer))
        }
        _ => Err(Error::msg(format!(
            "Unsupported tokenizer file type: .{extension}"
        ))),
    }
}

/// How many prompt-tail tokens the initial decode window includes, so
/// the first generated tokens detokenize with enough context.
const INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET: usize = 5;
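
/// Incrementally decodes generated token ids into text, holding output
/// back while the decoded suffix still ends in an incomplete UTF-8
/// sequence (which renders as U+FFFD).
///
/// Illustrative sketch (the token ids are placeholders, not from a real
/// vocabulary):
///
/// ```ignore
/// let tokenizer = Tokenizer::from_file("tokenizer.json")?;
/// let mut stream = tokenizer.decode_stream(&[1, 2, 3], false);
/// for id in [42, 43, 44] {
///     if let Some(text) = stream.step(id)? {
///         print!("{text}"); // emitted only once a full character is ready
///     }
/// }
/// ```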
pub struct DecodeStream {
    tokenizer: Arc<dyn traits::Tokenizer>,
    skip_special_tokens: bool,
    /// Every token id seen so far: the prompt plus all stepped ids.
    all_token_ids: Vec<TokenIdType>,
    /// Start of the decode window used as already-emitted context.
    prefix_offset: usize,
    /// End of the already-emitted context; ids past this are new.
    read_offset: usize,
}

impl DecodeStream {
    /// Seeds the stream with the prompt's token ids. The context window
    /// starts `INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET` tokens before
    /// the end of the prompt.
    pub fn new(
        tokenizer: Arc<dyn traits::Tokenizer>,
        prompt_token_ids: &[TokenIdType],
        skip_special_tokens: bool,
    ) -> Self {
        let num_input_tokens = prompt_token_ids.len();
        Self {
            tokenizer,
            skip_special_tokens,
            all_token_ids: prompt_token_ids.to_vec(),
            prefix_offset: num_input_tokens
                .saturating_sub(INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET),
            read_offset: num_input_tokens,
        }
    }

    /// Appends one generated token id and returns the newly completed
    /// text, or `None` while output is held back.
    pub fn step(&mut self, id: TokenIdType) -> Result<Option<String>> {
        self.all_token_ids.push(id);
        // Decode the already-emitted context window ...
        let prefix_text = self.tokenizer.decode(
            &self.all_token_ids[self.prefix_offset..self.read_offset],
            self.skip_special_tokens,
        )?;
        // ... and the same window extended with the new token.
        let new_text = self.tokenizer.decode(
            &self.all_token_ids[self.prefix_offset..],
            self.skip_special_tokens,
        )?;
        // Emit only the suffix beyond the context, and only once it no
        // longer ends in U+FFFD (an incomplete UTF-8 sequence).
        if new_text.len() > prefix_text.len() && !new_text.ends_with('�') {
            let new_text = new_text[prefix_text.len()..].to_string();
            // Slide the window forward past the text just emitted.
            self.prefix_offset = self.read_offset;
            self.read_offset = self.all_token_ids.len();
            Ok(Some(new_text))
        } else {
            Ok(None)
        }
    }
}
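
/// A growing token sequence with incremental detokenization state; the
/// building block for [`StopSequenceDecoder`].
///
/// Illustrative sketch (`generated_ids` is a placeholder for ids coming
/// from a model):
///
/// ```ignore
/// let mut seq = Sequence::new(tokenizer.clone());
/// for id in generated_ids {
///     let piece = seq.append_token_id(id)?; // "" while output is held
///     print!("{piece}");
/// }
/// println!("\nfull text: {}", seq.text()?);
/// ```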
pub struct Sequence {
    tokenizer: Tokenizer,
    token_ids: Vec<TokenIdType>,
    /// Start of the decode window used as already-emitted context.
    prefix_offset: usize,
    /// End of the already-emitted context; ids past this are new.
    read_offset: usize,
}

impl std::fmt::Debug for Sequence {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Sequence")
            .field("tokenizer", &"Arc<dyn Tokenizer>")
            .field(
                "token_ids",
                &format_args!("{}", {
                    // Elide the middle of long sequences to keep output readable.
                    let token_ids = self.token_ids();
                    if token_ids.len() <= 20 {
                        format!("{:?}", token_ids)
                    } else {
                        let first_ten = &token_ids[..10];
                        let last_ten = &token_ids[token_ids.len() - 10..];
                        format!("{:?} ... {:?}", first_ten, last_ten)
                    }
                }),
            )
            .field("prefix_offset", &self.prefix_offset)
            .field("read_offset", &self.read_offset)
            .field("token count", &self.token_ids.len())
            .finish()
    }
}

impl Sequence {
    pub fn new(tokenizer: Tokenizer) -> Self {
        Self {
            tokenizer,
            token_ids: Vec::new(),
            prefix_offset: 0,
            read_offset: 0,
        }
    }

    pub fn is_empty(&self) -> bool {
        self.token_ids.is_empty()
    }

    pub fn len(&self) -> usize {
        self.token_ids.len()
    }

    pub fn clear(&mut self) {
        self.token_ids.clear();
        self.prefix_offset = 0;
        self.read_offset = 0;
    }

    /// Tokenizes `input` and appends its token ids to the sequence.
    pub fn append_text(&mut self, input: &str) -> Result<()> {
        let encoding = self.tokenizer.encode(input)?;
        self.token_ids.extend(encoding.token_ids());
        Ok(())
    }

    /// Appends one token id and returns the newly completed text, or an
    /// empty string while output is held back.
    pub fn append_token_id(&mut self, token_id: TokenIdType) -> Result<String> {
        self.token_ids.push(token_id);
        // Decode the already-emitted context window, then the window
        // extended with the new token; the difference is the candidate
        // output.
        let prefix_text = self
            .tokenizer
            .decode(&self.token_ids[self.prefix_offset..self.read_offset], false)?;
        let new_text = self
            .tokenizer
            .decode(&self.token_ids[self.prefix_offset..], false)?;
        // `prefix_text.len()` can land inside a multi-byte character of
        // `new_text`; back up to the nearest char boundary so the slice
        // below cannot panic.
        let mut prefix_text_len = prefix_text.len();
        while prefix_text_len > 0 && !new_text.is_char_boundary(prefix_text_len) {
            prefix_text_len -= 1;
        }
        if new_text.len() > prefix_text.len() {
            if new_text.ends_with('�') {
                // The last token ends mid-character; hold the output.
                return Ok(String::new());
            }
            let new_text = new_text[prefix_text_len..].to_string().replace('�', "");
            self.prefix_offset = self.read_offset;
            self.read_offset = self.token_ids.len();
            return Ok(new_text);
        }
        Ok(String::new())
    }

    /// A handle to the underlying tokenizer (a cheap `Arc` clone).
    pub fn tokenizer(&self) -> Tokenizer {
        self.tokenizer.clone()
    }

    pub fn token_ids(&self) -> &[TokenIdType] {
        &self.token_ids
    }

    /// Decodes the full sequence, keeping special tokens.
    pub fn text(&self) -> Result<String> {
        self.tokenizer.decode(&self.token_ids, false)
    }
}

/// The outcome of feeding one token id to a [`StopSequenceDecoder`].
pub enum SequenceDecoderOutput {
    /// Newly completed text, safe to emit.
    Text(String),
    /// Output is held back: it may still become a stop sequence or a
    /// complete character.
    Held,
    /// A hidden stop condition fired; nothing is emitted.
    Stopped,
    /// A visible stop token fired; emit this final text.
    StoppedWithText(String),
}
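
/// A [`Sequence`] that additionally watches for stop token ids and stop
/// strings, distinguishing visible stops (their text is emitted) from
/// hidden ones (their text is swallowed).
///
/// Illustrative sketch (placeholder token ids; `2` stands in for EOS
/// and the stop string is arbitrary):
///
/// ```ignore
/// let mut decoder = StopSequenceDecoder::builder(tokenizer)
///     .add_stop_token_id_hidden(2)
///     .add_stop_sequence_hidden("</s>")
///     .build()?;
/// for token_id in [5, 6, 2] {
///     match decoder.append_token_id(token_id)? {
///         SequenceDecoderOutput::Text(t) => print!("{t}"),
///         SequenceDecoderOutput::Held => {} // maybe a stop prefix; wait
///         SequenceDecoderOutput::Stopped => break,
///         SequenceDecoderOutput::StoppedWithText(t) => {
///             print!("{t}");
///             break;
///         }
///     }
/// }
/// ```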
#[derive(Debug)]
pub struct StopSequenceDecoder {
    sequence: Sequence,
    stop_token_ids_visible: Vec<TokenIdType>,
    stop_token_ids_hidden: Vec<TokenIdType>,
    /// Accepted by the builder but not yet checked in `append_token_id`.
    #[allow(dead_code)]
    stop_sequences_visible: Vec<String>,
    stop_sequences_hidden: Vec<String>,
    stopped: bool,
    /// Text held back while it could still complete a hidden stop sequence.
    state: String,
}

impl StopSequenceDecoder {
    pub fn builder(tokenizer: Tokenizer) -> StopSequenceDecoderBuilder {
        StopSequenceDecoderBuilder::new(tokenizer)
    }

    /// Feeds one token id through the decoder. Errors if the decoder has
    /// already stopped.
    pub fn append_token_id(&mut self, token_id: TokenIdType) -> Result<SequenceDecoderOutput> {
        if self.stopped {
            return Err(Error::msg("Decoder is stopped"));
        }
        let text = self.sequence.append_token_id(token_id)?;
        self.state.push_str(text.as_str());

        // Stop-token check: a visible stop flushes the held text, a
        // hidden stop swallows it.
        let mut stop = false;
        let mut visible = false;
        if self.stop_token_ids_visible.contains(&token_id) {
            stop = true;
            visible = true;
        }
        if self.stop_token_ids_hidden.contains(&token_id) {
            stop = true;
            visible = false;
        }
        if stop {
            self.stopped = true;
            let state = std::mem::take(&mut self.state);
            if visible {
                return Ok(SequenceDecoderOutput::StoppedWithText(state));
            }
            return Ok(SequenceDecoderOutput::Stopped);
        }

        // Hidden stop strings: while the held text is a proper prefix of
        // some stop sequence, keep holding; on an exact match, stop
        // silently.
        for stop_sequence in self.stop_sequences_hidden.iter() {
            if stop_sequence.starts_with(&self.state) {
                if stop_sequence == &self.state {
                    self.stopped = true;
                    return Ok(SequenceDecoderOutput::Stopped);
                }
                return Ok(SequenceDecoderOutput::Held);
            }
        }

        // No stop condition matched: flush everything held so far.
        let state = std::mem::take(&mut self.state);
        Ok(SequenceDecoderOutput::Text(state))
    }

    pub fn is_empty(&self) -> bool {
        self.sequence.is_empty()
    }

    pub fn len(&self) -> usize {
        self.sequence.len()
    }

    /// True once a stop condition has fired or [`close`](Self::close)
    /// was called.
    pub fn is_complete(&self) -> bool {
        self.stopped
    }

    /// Marks the decoder as stopped; later appends return an error.
    pub fn close(&mut self) {
        self.stopped = true;
    }
}

/// Builder for [`StopSequenceDecoder`].
pub struct StopSequenceDecoderBuilder {
    tokenizer: Tokenizer,
    stop_token_ids_visible: Vec<TokenIdType>,
    stop_token_ids_hidden: Vec<TokenIdType>,
    stop_sequences_visible: Vec<String>,
    stop_sequences_hidden: Vec<String>,
}

impl StopSequenceDecoderBuilder {
    pub fn new(tokenizer: Tokenizer) -> Self {
        Self {
            tokenizer,
            stop_token_ids_visible: Vec::new(),
            stop_token_ids_hidden: Vec::new(),
            stop_sequences_visible: Vec::new(),
            stop_sequences_hidden: Vec::new(),
        }
    }

    pub fn add_stop_token_id_visible(mut self, token_id: TokenIdType) -> Self {
        self.stop_token_ids_visible.push(token_id);
        self
    }

    pub fn add_stop_token_ids_visible(mut self, token_ids: &[TokenIdType]) -> Self {
        self.stop_token_ids_visible.extend(token_ids);
        self
    }

    pub fn add_stop_token_id_hidden(mut self, token_id: TokenIdType) -> Self {
        self.stop_token_ids_hidden.push(token_id);
        self
    }

    pub fn add_stop_token_ids_hidden(mut self, token_ids: &[TokenIdType]) -> Self {
        self.stop_token_ids_hidden.extend(token_ids);
        self
    }

    pub fn add_stop_sequence_visible(mut self, text: &str) -> Self {
        self.stop_sequences_visible.push(text.to_string());
        self
    }

    pub fn add_stop_sequences_visible(mut self, strings: &[&str]) -> Self {
        self.stop_sequences_visible
            .extend(strings.iter().map(|text| text.to_string()));
        self
    }

    pub fn add_stop_sequence_hidden(mut self, text: &str) -> Self {
        self.stop_sequences_hidden.push(text.to_string());
        self
    }

    pub fn add_stop_sequences_hidden(mut self, strings: &[&str]) -> Self {
        self.stop_sequences_hidden
            .extend(strings.iter().map(|text| text.to_string()));
        self
    }

    pub fn build(self) -> Result<StopSequenceDecoder> {
        Ok(StopSequenceDecoder {
            // `self` is consumed here, so the tokenizer can move in
            // directly without an extra clone.
            sequence: Sequence::new(self.tokenizer),
            stop_token_ids_visible: self.stop_token_ids_visible,
            stop_token_ids_hidden: self.stop_token_ids_hidden,
            stop_sequences_visible: self.stop_sequences_visible,
            stop_sequences_hidden: self.stop_sequences_hidden,
            stopped: false,
            state: String::new(),
        })
    }
}