use std::{collections::HashMap, str::FromStr};

use crate::InferenceError;

/// The identifier of a token in a vocabulary.
pub type TokenId = i32;
pub(crate) type Token = Vec<u8>;
pub(crate) type TokenScore = f32;

/// The vocabulary used by a model.
#[derive(Debug, Clone, Default)]
pub struct Vocabulary {
    /// Maps every integer (index) token ID to its corresponding token.
    pub id_to_token: Vec<Token>,

    /// Maps every integer (index) token ID to its corresponding score.
    pub id_to_token_score: Vec<TokenScore>,

    // TODO: use a radix tree
    /// Maps a token to a token ID.
    pub token_to_id: HashMap<Token, TokenId>,

    /// The length of the longest token in this vocabulary.
    pub max_token_length: usize,
}

impl Vocabulary {
    /// Add a token to the vocabulary.
    ///
    /// The token's `id` must come directly after the last token ID in the vocabulary.
    ///
    /// # Panics
    /// - This function panics if `id` does not correspond to the next token in the vocabulary.
    ///   That is, if there are already `n` tokens in the vocabulary, then `id` must be `n`.
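    ///
    /// # Example
    /// A minimal sketch; the token bytes and scores below are illustrative
    /// assumptions, not values from a real model.
    /// ```ignore
    /// let mut vocab = Vocabulary::default();
    /// vocab.push_token(0, b"hello".to_vec(), 0.0);
    /// vocab.push_token(1, b" world".to_vec(), 0.0);
    /// assert_eq!(vocab.max_token_length, 6);
    /// ```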
    pub fn push_token(&mut self, id: TokenId, content: Token, score: TokenScore) {
        // These are loader invariants. If this is broken, then the loader is broken and this is a bug,
        // not an issue with the model itself.
        assert_eq!(self.id_to_token.len(), self.id_to_token_score.len());
        if self.id_to_token.len() != id as usize || self.id_to_token_score.len() != id as usize {
            let expected_id = self.id_to_token.len() as TokenId;
            panic!("the id of the token added should be {expected_id}; is {id}");
        }
        self.max_token_length = self.max_token_length.max(content.len());
        self.id_to_token.push(content.clone());
        self.id_to_token_score.push(score);
        self.token_to_id.insert(content, id);
    }

    /// Returns the token at the given index.
    pub(crate) fn token(&self, idx: usize) -> &[u8] {
        &self.id_to_token[idx]
    }

    // SentencePiece implementation after https://guillaume-be.github.io/2020-05-30/sentence_piece
    /// Tokenize `text` with this vocabulary.
    ///
    /// `bos` controls whether a beginning-of-string token should be inserted.
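    ///
    /// # Example
    /// A hedged sketch of intended usage (`vocab` is assumed to be an
    /// already-loaded vocabulary; it is not constructed here):
    /// ```ignore
    /// let tokens = vocab.tokenize("hello world", true)?;
    /// for (bytes, id) in &tokens {
    ///     println!("{id}: {}", String::from_utf8_lossy(bytes));
    /// }
    /// ```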
    pub fn tokenize<'a>(
        &'a self,
        text: &str,
        bos: bool,
    ) -> Result<Vec<(&'a [u8], TokenId)>, InferenceError> {
        let len = text.len();

        // Forward pass: dynamic programming over byte positions. `score[i]`
        // holds the best score over all tokenizations of the first `i` bytes,
        // and `prev[i]` records the token that ends at byte `i` on that path.
        let mut score = vec![0usize; len + 1];
        let mut prev = vec![TokenId::default(); len + 1];
        for i in 0..len {
            let max_len = (len - i).min(self.max_token_length);
            for sub_len in 1..=max_len {
                let sub = &text.as_bytes()[i..i + sub_len];
                let token = self.token_to_id.get(sub);
                if let Some(token) = token {
                    // Square the length so that longer tokens are strongly preferred.
                    let token_score = sub.len() * sub.len();
                    let local_score = score[i] + token_score;
                    let next = i + sub_len;
                    if score[next] < local_score {
                        score[next] = local_score;
                        prev[next] = *token;
                    }
                }
            }
        }

        // Backward pass: walk `prev` from the end of the text, recovering the
        // chosen tokens in reverse order.
        let mut res = vec![];
        let mut i = len;
        while i > 0 {
            let token_id = prev[i];
            if token_id == 0 {
                return Err(InferenceError::TokenizationFailed);
            }
            let token = self.id_to_token[token_id as usize].as_slice();
            res.push((token, token_id));
            i -= token.len();
        }

        if bos {
            // TODO: replace with vocab.bos
            res.push((&[], 1));
        }

        // Pieces are in reverse order, so correct that.
        res.reverse();

        Ok(res)
    }
}

/// A list of tokens to bias during the process of inferencing.
///
/// When a biased token is encountered, the bias will be used
/// instead of the inferred logit during the sampling process.
///
/// This can be used to disable the generation of responses
/// with specific tokens by setting their corresponding bias
/// to -1.0.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct TokenBias(Vec<(TokenId, f32)>);

impl TokenBias {
    /// Create a [TokenBias] from an existing `Vec`. The list is sorted by
    /// token ID, and duplicate IDs are removed.
    pub fn new(mut v: Vec<(TokenId, f32)>) -> Self {
        v.sort_by_cached_key(|(tid, _)| *tid);
        v.dedup_by_key(|(tid, _)| *tid);
        Self(v)
    }

    /// Retrieves the bias for a given token, if available.
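    ///
    /// A minimal sketch (the IDs and bias values are illustrative only):
    /// ```ignore
    /// let bias = TokenBias::new(vec![(2, -1.0), (5, 0.5)]);
    /// assert_eq!(bias.get(5), Some(0.5));
    /// assert_eq!(bias.get(3), None);
    /// ```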
    pub fn get(&self, tid: TokenId) -> Option<f32> {
        self.0
            .binary_search_by_key(&tid, |(tid, _)| *tid)
            .map(|idx| self.0[idx].1)
            .ok()
    }
}

impl FromStr for TokenBias {
    type Err = String;

    /// Parses a comma-separated list of token biases in the format
    /// "TID=BIAS,TID=BIAS", where TID is an integer token ID and BIAS is a
    /// floating-point number.
    ///
    /// For example, "1=-1.0,2=-1.0" sets the bias for token IDs 1
    /// (start of document) and 2 (end of document) to -1.0, which effectively
    /// prevents the model from generating responses containing those token IDs.
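    ///
    /// A usage sketch (standard `FromStr` parsing; the values are illustrative):
    /// ```ignore
    /// let bias: TokenBias = "1=-1.0,2=-1.0".parse().unwrap();
    /// assert_eq!(bias.get(1), Some(-1.0));
    /// ```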
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let x = s
            .split(',')
            .map(|kv| {
                let (k, v) = kv
                    .trim()
                    .split_once('=')
                    .ok_or_else(|| "Missing '=' in bias item".to_owned())?;
                let tid: TokenId = k
                    .trim()
                    .parse()
                    .map_err(|e: std::num::ParseIntError| e.to_string())?;
                let bias: f32 = v
                    .trim()
                    .parse()
                    .map_err(|e: std::num::ParseFloatError| e.to_string())?;
                Result::<_, String>::Ok((tid, bias))
            })
            .collect::<Result<_, _>>()?;
        Ok(TokenBias::new(x))
    }
}

impl std::fmt::Display for TokenBias {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", self.0)
    }
}
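
// A hedged smoke test, not part of the original source: it exercises the
// dynamic-programming tokenizer against a tiny hand-built vocabulary. The
// token contents and scores below are illustrative placeholders.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tokenize_prefers_longer_tokens() {
        let mut vocab = Vocabulary::default();
        // The backward pass treats token ID 0 as "no token", so fill that
        // slot with a placeholder entry.
        vocab.push_token(0, b"<unk>".to_vec(), 0.0);
        vocab.push_token(1, b"a".to_vec(), 0.0);
        vocab.push_token(2, b"b".to_vec(), 0.0);
        vocab.push_token(3, b"ab".to_vec(), 0.0);

        // The squared-length score favours the single two-byte token over
        // two one-byte tokens (4 > 1 + 1).
        let tokens = vocab.tokenize("ab", false).unwrap();
        assert_eq!(tokens, vec![(b"ab".as_slice(), 3)]);
    }
}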