1use std::{collections::HashMap, str::FromStr};
2
3use crate::InferenceError;
4
/// Identifier of a token within a [`Vocabulary`].
pub type TokenId = i32;
/// Raw byte content of a single token.
pub(crate) type Token = Vec<u8>;
/// Score associated with a token (assigned when the vocabulary is built).
pub(crate) type TokenScore = f32;
9
/// The vocabulary used by a model: token contents and scores, plus the
/// reverse lookup needed for tokenization.
#[derive(Debug, Clone, Default)]
pub struct Vocabulary {
    /// Token bytes, indexed by token id (parallel to `id_to_token_score`).
    pub id_to_token: Vec<Token>,

    /// Score of each token, indexed by token id (parallel to `id_to_token`).
    pub id_to_token_score: Vec<TokenScore>,

    /// Reverse lookup: maps token bytes back to their token id.
    pub token_to_id: HashMap<Token, TokenId>,

    /// Length in bytes of the longest token; used to bound the substring
    /// search during tokenization.
    pub max_token_length: usize,
}
26
27impl Vocabulary {
28 pub fn push_token(&mut self, id: TokenId, content: Token, score: TokenScore) {
36 assert_eq!(self.id_to_token.len(), self.id_to_token_score.len());
39 if self.id_to_token.len() != id as usize || self.id_to_token_score.len() != id as usize {
40 let expected_id = self.id_to_token.len() as TokenId;
41 panic!("the id of token added should be {expected_id}; is {id}");
42 }
43
44 self.max_token_length = self.max_token_length.max(content.len());
45 self.id_to_token.push(content.clone());
46 self.id_to_token_score.push(score);
47 self.token_to_id.insert(content, id);
48 }
49
50 pub fn token(&self, idx: usize) -> &[u8] {
52 &self.id_to_token[idx]
53 }
54
55 pub fn tokenize<'a>(
60 &'a self,
61 text: &str,
62 bos: bool,
63 ) -> Result<Vec<(&'a [u8], TokenId)>, InferenceError> {
64 let len = text.len();
65
66 let mut score = vec![0usize; len + 1];
67 let mut prev = vec![TokenId::default(); len + 1];
68
69 for i in 0..len {
70 let max_len = (len - i).min(self.max_token_length);
71 for sub_len in 1..=max_len {
72 let sub = &text.as_bytes()[i..i + sub_len];
73 let token = self.token_to_id.get(sub);
74
75 if let Some(token) = token {
76 let token_score = sub.len() * sub.len();
77 let local_score = score[i] + token_score;
78 let next = i + sub_len;
79
80 if score[next] < local_score {
81 score[next] = local_score;
82 prev[next] = *token;
83 }
84 }
85 }
86 }
87
88 let mut res = vec![];
90 let mut i = len;
91 while i > 0 {
92 let token_id = prev[i];
93 if token_id == 0 {
94 return Err(InferenceError::TokenizationFailed);
95 }
96 let token = self.id_to_token[token_id as usize].as_slice();
97 res.push((token, token_id));
98 i -= token.len();
99 }
100
101 if bos {
102 res.push((&[], 1));
104 }
105
106 res.reverse();
108
109 Ok(res)
110 }
111}
112
/// A set of per-token biases: `(token id, bias)` pairs kept sorted by token
/// id so lookups can binary-search. Construct via [`TokenBias::new`] or parse
/// from a string of comma-separated `id=bias` items (see the `FromStr` impl).
#[derive(Default, Clone, Debug, PartialEq)]
pub struct TokenBias(Vec<(TokenId, f32)>);
123
124impl TokenBias {
125 pub fn new(mut v: Vec<(TokenId, f32)>) -> Self {
127 v.sort_by_cached_key(|(tid, _)| *tid);
128 v.dedup_by_key(|(tid, _)| *tid);
129 Self(v)
130 }
131
132 pub fn get(&self, tid: TokenId) -> Option<f32> {
134 self.0
135 .binary_search_by_key(&tid, |(tid, _)| *tid)
136 .map(|idx| self.0[idx].1)
137 .ok()
138 }
139}
140
141impl FromStr for TokenBias {
142 type Err = String;
143
144 fn from_str(s: &str) -> Result<Self, Self::Err> {
151 let x = s
152 .split(',')
153 .map(|kv| {
154 let (k, v) = kv
155 .trim()
156 .split_once('=')
157 .ok_or_else(|| "Missing '=' in bias item".to_owned())?;
158 let tid: TokenId = k
159 .trim()
160 .parse()
161 .map_err(|e: std::num::ParseIntError| e.to_string())?;
162 let bias: f32 = v
163 .trim()
164 .parse()
165 .map_err(|e: std::num::ParseFloatError| e.to_string())?;
166 Result::<_, String>::Ok((tid, bias))
167 })
168 .collect::<Result<_, _>>()?;
169 Ok(TokenBias::new(x))
170 }
171}
172
impl std::fmt::Display for TokenBias {
    /// Renders the bias list using the inner `Vec`'s `Debug` formatting,
    /// e.g. `[(1, -1.0), (2, 3.0)]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", self.0)
    }
}