use std::{collections::HashMap, str::FromStr};
use crate::InferenceError;
/// Integer identifier of a token in a [`Vocabulary`].
pub type TokenId = i32;
/// Content of a token, stored as raw bytes (so it need not be valid UTF-8).
pub(crate) type Token = Vec<u8>;
/// Score value stored alongside each token by [`Vocabulary::push_token`].
pub(crate) type TokenScore = f32;
/// Token vocabulary: maps token ids to their byte content and scores, and content back to ids.
#[derive(Debug, Clone, Default)]
pub struct Vocabulary {
    /// Byte content of each token, indexed by token id.
    pub id_to_token: Vec<Token>,
    /// Score of each token, indexed by token id (stored, but not used by [`Vocabulary::tokenize`]).
    pub id_to_token_score: Vec<TokenScore>,
    /// Reverse lookup from token content to token id.
    pub token_to_id: HashMap<Token, TokenId>,
    /// Byte length of the longest token added through [`Vocabulary::push_token`].
    pub max_token_length: usize,
}
impl Vocabulary {
    /// Appends a token with the given `id`, byte `content` and `score`.
    ///
    /// Ids must be assigned densely and in order: `id` has to equal the number of tokens
    /// already stored. `max_token_length` is updated, and `token_to_id` maps `content` to
    /// `id` (replacing any earlier id that had identical content).
    ///
    /// # Panics
    ///
    /// Panics if `id_to_token` and `id_to_token_score` have different lengths, or if `id` is
    /// not the next sequential id.
    pub fn push_token(&mut self, id: TokenId, content: Token, score: TokenScore) {
        assert_eq!(self.id_to_token.len(), self.id_to_token_score.len());
        // The assert above guarantees both vectors have the same length, so a single
        // comparison suffices. Negative ids are rejected explicitly rather than via a
        // wrapping `as usize` cast.
        let expected_id = self.id_to_token.len();
        if id < 0 || id as usize != expected_id {
            panic!("the id of token added should be {expected_id}; is {id}");
        }
        self.max_token_length = self.max_token_length.max(content.len());
        self.id_to_token.push(content.clone());
        self.id_to_token_score.push(score);
        self.token_to_id.insert(content, id);
    }

    /// Returns the byte content of the token at index `idx`.
    ///
    /// # Panics
    ///
    /// Panics if `idx` is out of range for `id_to_token`.
    pub fn token(&self, idx: usize) -> &[u8] {
        &self.id_to_token[idx]
    }

    /// Splits `text` into tokens from `token_to_id`.
    ///
    /// Works on the raw UTF-8 bytes of `text` (token boundaries need not fall on `char`
    /// boundaries). Among all ways to cover the bytes with known tokens, it picks one that
    /// maximizes the sum of squared token byte lengths. That choice favours fewer, longer
    /// tokens. Token scores are not consulted.
    ///
    /// Each returned pair is `(token content, token id)`, with the content borrowed from
    /// `self`. When `bos` is true, `(&[], 1)` (empty content, id 1) is placed first.
    /// NOTE(review): this assumes the BOS token id is 1.
    ///
    /// # Errors
    ///
    /// Returns `InferenceError::TokenizationFailed` if `text` cannot be completely covered
    /// by tokens in the vocabulary.
    pub fn tokenize<'a>(
        &'a self,
        text: &str,
        bos: bool,
    ) -> Result<Vec<(&'a [u8], TokenId)>, InferenceError> {
        let bytes = text.as_bytes();
        let len = bytes.len();
        // best_score[i]: highest total score over segmentations of bytes[..i].
        let mut best_score = vec![0usize; len + 1];
        // best_last[i]: final token (content, id) of that best segmentation, or `None` when
        // bytes[..i] cannot be segmented. Using `Option` (rather than id 0 as a sentinel)
        // keeps token id 0 usable. Position 0 is the empty prefix and needs no token.
        let mut best_last: Vec<Option<(&'a [u8], TokenId)>> = vec![None; len + 1];

        for start in 0..len {
            // Extending from an unreachable position would let an invalid path overwrite a
            // valid one further on, making the backtrack fail on segmentable text.
            if start > 0 && best_last[start].is_none() {
                continue;
            }
            let max_len = (len - start).min(self.max_token_length);
            for sub_len in 1..=max_len {
                let end = start + sub_len;
                // Take the content from the map key so it is exactly the matched bytes.
                if let Some((content, &token_id)) =
                    self.token_to_id.get_key_value(&bytes[start..end])
                {
                    let candidate = best_score[start] + sub_len * sub_len;
                    if best_score[end] < candidate {
                        best_score[end] = candidate;
                        best_last[end] = Some((content.as_slice(), token_id));
                    }
                }
            }
        }

        // Walk back from the end of the text, collecting tokens in reverse order.
        let mut tokens = Vec::new();
        let mut end = len;
        while end > 0 {
            let (content, token_id) = best_last[end].ok_or(InferenceError::TokenizationFailed)?;
            tokens.push((content, token_id));
            // `content` equals the matched bytes, whose length is at least 1, so this always
            // makes progress and never underflows.
            end -= content.len();
        }
        if bos {
            tokens.push((&[], 1));
        }
        tokens.reverse();
        Ok(tokens)
    }
}
/// Per-token bias values, stored sorted by token id with at most one entry per id.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct TokenBias(Vec<(TokenId, f32)>);
impl TokenBias {
    /// Builds a bias table from `(token id, bias)` pairs.
    ///
    /// The pairs are sorted by token id. If an id occurs more than once, only its first
    /// occurrence in `v` is kept.
    pub fn new(mut v: Vec<(TokenId, f32)>) -> Self {
        // `sort_by_key` is stable, so equal ids keep their input order and `dedup_by_key`
        // retains the first one. The key is a cheap integer copy, so the cached-key variant's
        // extra key buffer allocation buys nothing.
        v.sort_by_key(|(tid, _)| *tid);
        v.dedup_by_key(|(tid, _)| *tid);
        Self(v)
    }

    /// Returns the bias recorded for `tid`, or `None` if there is none.
    pub fn get(&self, tid: TokenId) -> Option<f32> {
        // Binary search relies on the sorted order established by `new`.
        let idx = self.0.binary_search_by_key(&tid, |(id, _)| *id).ok()?;
        Some(self.0[idx].1)
    }
}
impl FromStr for TokenBias {
    type Err = String;

    /// Parses comma-separated `TOKEN_ID=BIAS` items, e.g. `"1=-1.0, 29871=2.5"`.
    ///
    /// Whitespace around each item, id and bias is ignored, and the parsed pairs are passed
    /// to `TokenBias::new`. A blank (empty or all-whitespace) string yields an empty bias set.
    ///
    /// # Errors
    ///
    /// Returns `"Missing '=' in bias item"` when an item has no `=`. This includes an empty
    /// item produced by a stray comma. If an id or bias is malformed, returns the message of
    /// the integer/float parse error.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // An empty list is a legitimate value; without this check `split(',')` would yield a
        // single empty item and report a misleading "Missing '='" error.
        if s.trim().is_empty() {
            return Ok(TokenBias::new(Vec::new()));
        }
        let pairs = s
            .split(',')
            .map(|item| {
                let (key, value) = item
                    .trim()
                    .split_once('=')
                    .ok_or_else(|| "Missing '=' in bias item".to_owned())?;
                let tid = key.trim().parse::<TokenId>().map_err(|e| e.to_string())?;
                let bias = value.trim().parse::<f32>().map_err(|e| e.to_string())?;
                Ok::<_, String>((tid, bias))
            })
            .collect::<Result<Vec<_>, _>>()?;
        Ok(TokenBias::new(pairs))
    }
}
impl std::fmt::Display for TokenBias {
    /// Formats the biases as comma-separated `TOKEN_ID=BIAS` pairs, e.g. `1=-1,29871=2.5`.
    ///
    /// This is the same item syntax the `FromStr` impl accepts, so a displayed non-empty
    /// value can be parsed back. An empty bias set formats as the empty string.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        for (i, (tid, bias)) in self.0.iter().enumerate() {
            if i > 0 {
                f.write_str(",")?;
            }
            write!(f, "{tid}={bias}")?;
        }
        Ok(())
    }
}