// alith_interface/requests/logit_bias.rs
use alith_models::tokenizer::Tokenizer;

use crate::requests::req_components::RequestConfigTrait;

use std::{collections::HashMap, sync::Arc};

/// Accumulates logit-bias requests from several input forms (raw token ids,
/// single chars, single words, whole texts) and lazily builds the merged,
/// validated bias maps sent with a request.
#[derive(Clone, Default)]
pub struct LogitBias {
    /// Merged, validated `token id -> bias` map; built on demand by
    /// `build_base` and invalidated whenever a new bias is added.
    pub base_logit_bias: Option<HashMap<u32, f32>>,
    /// OpenAI wire-format view (`stringified token id -> integer bias`)
    /// derived from `base_logit_bias`.
    pub built_openai_bias: OpenAILogitBias,
    // Pending, not-yet-tokenized inputs; drained into `base_logit_bias`
    // by `build_base`.
    from_token_ids: FromTokenIds,
    from_chars: FromChars,
    from_words: FromWords,
    from_texts: FromTexts,
}
16
impl LogitBias {
    /// Creates an empty `LogitBias` (same as `Default::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Queues a bias for a single token id and invalidates any built maps.
    pub fn add_token_id(&mut self, token_id: u32, bias: f32) -> &mut Self {
        self.from_token_ids.add_token_id(token_id, bias);
        self.clear_built();
        self
    }

    /// Queues biases for many token ids at once and invalidates any built maps.
    pub fn add_token_ids(&mut self, logit_bias: HashMap<u32, f32>) -> &mut Self {
        self.from_token_ids.add_token_ids(logit_bias);
        self.clear_built();
        self
    }

    /// Queues a bias for the token that encodes `char` and invalidates any
    /// built maps. Resolution to a token id happens later, in `build_base`.
    pub fn add_from_char(&mut self, char: char, bias: f32) -> &mut Self {
        self.from_chars.add_char(char, bias);
        self.clear_built();
        self
    }

    /// Queues a bias for every token of a single `word` (must contain no
    /// whitespace) and invalidates any built maps.
    pub fn add_from_word(&mut self, word: &str, bias: f32) -> &mut Self {
        self.from_words.add_word(word, bias);
        self.clear_built();
        self
    }

    /// Queues a bias for every token of an arbitrary `text` and invalidates
    /// any built maps.
    pub fn add_from_text(&mut self, text: &str, bias: f32) -> &mut Self {
        self.from_texts.add_text(text, bias);
        self.clear_built();
        self
    }

    /// Removes all pending inputs and all built maps, returning the
    /// instance to its empty state.
    pub fn clear_logit_bias(&mut self) -> &mut Self {
        self.from_token_ids.clear();
        self.from_chars.clear();
        self.from_words.clear();
        self.from_texts.clear();
        self.clear_built();
        self
    }

    /// Ensures the OpenAI-format map is built, first building the base map
    /// from pending inputs if needed. A no-op when already built.
    ///
    /// # Errors
    /// Propagates tokenization/validation failures from `build_base`.
    pub(crate) fn build_openai(&mut self, tokenizer: &Arc<Tokenizer>) -> crate::Result<()> {
        if !self.built_openai_bias.is_none() {
            // Already built since the last add/clear; nothing to do.
            return Ok(());
        }
        if self.base_logit_bias.is_none() {
            self.build_base(tokenizer)?;
        }
        if let Some(base_logit_bias) = &self.base_logit_bias {
            self.built_openai_bias.build(base_logit_bias);
        }
        Ok(())
    }

    /// Returns a clone of the built OpenAI-format map, or `None` when
    /// nothing has been built.
    pub(crate) fn get_openai(&self) -> Option<HashMap<String, serde_json::Value>> {
        self.built_openai_bias.get()
    }

    /// Consumes all pending inputs, resolves them to token ids with
    /// `tokenizer`, merges and validates them into `base_logit_bias`.
    ///
    /// On token-id collisions, later sources win: texts override words,
    /// which override chars, which override raw token ids (a consequence
    /// of the insert order in `merge_logit_biases`).
    ///
    /// Each pending source is cleared immediately after being consumed, so
    /// an error part-way through leaves earlier sources already cleared.
    ///
    /// # Errors
    /// Returns an error if any input fails tokenization or if any bias
    /// value falls outside `[-100.0, 100.0]`.
    fn build_base(&mut self, tokenizer: &Arc<Tokenizer>) -> crate::Result<()> {
        if self.from_token_ids.is_none()
            && self.from_chars.is_none()
            && self.from_words.is_none()
            && self.from_texts.is_none()
        {
            // Nothing pending: leave `base_logit_bias` untouched.
            return Ok(());
        }
        let validated_logit_bias = self.from_token_ids.get(tokenizer)?;
        self.from_token_ids.clear();

        let validated_logit_bias = Self::merge_logit_biases(vec![
            &validated_logit_bias,
            &self.from_chars.get(tokenizer)?,
        ]);
        self.from_chars.clear();

        let validated_logit_bias = Self::merge_logit_biases(vec![
            &validated_logit_bias,
            &self.from_words.get(tokenizer)?,
        ]);
        self.from_words.clear();

        let validated_logit_bias = Self::merge_logit_biases(vec![
            &validated_logit_bias,
            &self.from_texts.get(tokenizer)?,
        ]);
        self.from_texts.clear();

        if !validated_logit_bias.is_empty() {
            Self::validate_logit_bias_values(&validated_logit_bias)?;
            self.base_logit_bias = Some(validated_logit_bias);
        }
        Ok(())
    }

    /// Invalidates both built maps; called whenever pending inputs change.
    fn clear_built(&mut self) -> &mut Self {
        self.base_logit_bias = None;
        self.built_openai_bias.clear();
        self
    }

    /// Rejects any bias value outside `[-100.0, 100.0]` (the range the
    /// OpenAI API accepts).
    fn validate_logit_bias_values(logit_bias: &HashMap<u32, f32>) -> crate::Result<()> {
        for value in logit_bias.values() {
            if *value > 100.0 || *value < -100.0 {
                return Err(crate::anyhow!(
                    "logit_bias value must be between -100.0 and 100.0. Given value: {}",
                    value
                ));
            }
        }
        Ok(())
    }

    /// Merges maps in order; on duplicate token ids, a later map's bias
    /// overwrites an earlier one's.
    fn merge_logit_biases(logit_biases: Vec<&HashMap<u32, f32>>) -> HashMap<u32, f32> {
        let mut merged_logit_bias: HashMap<u32, f32> = HashMap::new();
        for logit_bias in logit_biases {
            for (token_id, bias) in logit_bias {
                merged_logit_bias.insert(*token_id, *bias);
            }
        }
        merged_logit_bias
    }
}
160
161#[derive(Clone, Default)]
162struct FromTokenIds {
163 pub token_ids: Option<HashMap<u32, f32>>,
164}
165
166impl FromTokenIds {
167 fn is_none(&self) -> bool {
168 self.token_ids.is_none()
169 }
170
171 fn clear(&mut self) {
172 self.token_ids = None;
173 }
174
175 fn get(&self, tokenizer: &Arc<Tokenizer>) -> crate::Result<HashMap<u32, f32>> {
176 if let Some(token_ids) = &self.token_ids {
177 for token_id in token_ids.keys() {
178 tokenizer.try_from_single_token_id(*token_id)?;
179 }
180 Ok(token_ids.clone())
181 } else {
182 Ok(HashMap::new())
183 }
184 }
185
186 fn add_token_id(&mut self, token_id: u32, bias: f32) {
187 self.token_ids
188 .get_or_insert_with(HashMap::new)
189 .entry(token_id)
190 .or_insert(bias);
191 }
192
193 fn add_token_ids(&mut self, logit_bias: HashMap<u32, f32>) {
194 self.token_ids
195 .get_or_insert_with(HashMap::new)
196 .extend(logit_bias);
197 }
198}
199
200#[derive(Clone, Default)]
201struct FromChars {
202 pub chars: Option<HashMap<char, f32>>,
203}
204
205impl FromChars {
206 fn is_none(&self) -> bool {
207 self.chars.is_none()
208 }
209
210 fn clear(&mut self) {
211 self.chars = None;
212 }
213
214 fn get(&self, tokenizer: &Arc<Tokenizer>) -> crate::Result<HashMap<u32, f32>> {
215 if let Some(chars) = &self.chars {
216 let mut token_logit_bias: HashMap<u32, f32> = HashMap::new();
217 for (char, bias) in chars {
218 let token_id = tokenizer.try_into_single_token(&char.to_string())?;
219 token_logit_bias.insert(token_id, *bias);
220 }
221 Ok(token_logit_bias)
222 } else {
223 Ok(HashMap::new())
224 }
225 }
226
227 fn add_char(&mut self, char: char, bias: f32) {
228 self.chars
229 .get_or_insert_with(HashMap::new)
230 .entry(char)
231 .or_insert(bias);
232 }
233}
234
235#[derive(Clone, Default)]
236struct FromWords {
237 pub words: Option<HashMap<String, f32>>,
238}
239
240impl FromWords {
241 fn is_none(&self) -> bool {
242 self.words.is_none()
243 }
244
245 fn clear(&mut self) {
246 self.words = None;
247 }
248
249 fn get(&self, tokenizer: &Arc<Tokenizer>) -> crate::Result<HashMap<u32, f32>> {
250 if let Some(words) = &self.words {
251 let mut token_logit_bias: HashMap<u32, f32> = HashMap::new();
252 for (word_maybe, bias) in words {
253 let mut words_maybe: Vec<String> = word_maybe
254 .trim()
255 .split_ascii_whitespace()
256 .map(|s| s.trim().to_string())
257 .collect();
258 let word = if words_maybe.is_empty() {
259 return Err(crate::anyhow!(
260 "logit_bias contains an empty word. Given word: {}",
261 word_maybe
262 ));
263 } else if words_maybe.len() > 1 {
264 return Err(crate::anyhow!(
265 "logit_bias contains a word seperated by whitespace. Given word: {}",
266 word_maybe
267 ));
268 } else {
269 words_maybe.remove(0)
270 };
271 let token_ids = tokenizer.tokenize(&word);
272 for id in token_ids {
273 if id == tokenizer.white_space_token_id {
274 panic!(
275 "logit_bias contains a whitespace token. Given word: {}",
276 word
277 )
278 }
279 token_logit_bias.insert(id, *bias);
280 }
281 }
282 Ok(token_logit_bias)
283 } else {
284 Ok(HashMap::new())
285 }
286 }
287
288 fn add_word(&mut self, word: &str, bias: f32) {
289 self.words
290 .get_or_insert_with(HashMap::new)
291 .entry(word.to_owned())
292 .or_insert(bias);
293 }
294}
295
296#[derive(Clone, Default)]
297struct FromTexts {
298 pub texts: Option<HashMap<String, f32>>,
299}
300
301impl FromTexts {
302 fn is_none(&self) -> bool {
303 self.texts.is_none()
304 }
305
306 fn clear(&mut self) {
307 self.texts = None;
308 }
309
310 fn get(&self, tokenizer: &Arc<Tokenizer>) -> crate::Result<HashMap<u32, f32>> {
311 if let Some(texts) = &self.texts {
312 let mut token_logit_bias: HashMap<u32, f32> = HashMap::new();
313 for (text, bias) in texts {
314 let token_ids = tokenizer.tokenize(text);
315 for id in token_ids {
316 if id == tokenizer.white_space_token_id {
317 continue;
318 }
319 token_logit_bias.insert(id, *bias);
320 }
321 }
322 Ok(token_logit_bias)
323 } else {
324 Ok(HashMap::new())
325 }
326 }
327
328 fn add_text(&mut self, text: &str, bias: f32) {
329 self.texts
330 .get_or_insert_with(HashMap::new)
331 .entry(text.to_owned())
332 .or_insert(bias);
333 }
334}
335
336#[derive(Clone, Default)]
337pub struct OpenAILogitBias {
338 pub built_logit_bias: Option<HashMap<String, serde_json::Value>>,
339}
340
341impl OpenAILogitBias {
342 fn is_none(&self) -> bool {
343 self.built_logit_bias.is_none()
344 }
345
346 fn clear(&mut self) {
347 self.built_logit_bias = None;
348 }
349
350 fn build(&mut self, logit_bias: &HashMap<u32, f32>) {
351 let mut openai_logit_bias: HashMap<String, serde_json::Value> = HashMap::new();
352 for (token_id, value) in logit_bias {
353 openai_logit_bias.insert(
354 token_id.to_string(),
355 serde_json::Value::Number(serde_json::Number::from(value.ceil() as i32)),
356 );
357 }
358 }
359
360 fn get(&self) -> Option<HashMap<String, serde_json::Value>> {
361 self.built_logit_bias.clone()
362 }
363}
364
365pub trait LogitBiasTrait: RequestConfigTrait {
366 fn lb_mut(&mut self) -> &mut Option<LogitBias>;
367
368 fn logit_bias(&mut self) -> &mut LogitBias {
369 if self.lb_mut().is_none() {
370 *self.lb_mut() = Some(LogitBias::default());
371 }
372 self.lb_mut().as_mut().unwrap()
373 }
374
375 fn add_logit_bias_token_id(&mut self, token_id: u32, bias: f32) -> &mut Self {
382 self.logit_bias().add_token_id(token_id, bias);
383 self
384 }
385
386 fn add_logit_bias_token_ids(&mut self, logit_bias: HashMap<u32, f32>) -> &mut Self {
392 self.logit_bias().add_token_ids(logit_bias);
393 self
394 }
395
396 fn add_logit_bias_from_char(&mut self, char: char, bias: f32) -> &mut Self {
404 self.logit_bias().add_from_char(char, bias);
405 self
406 }
407
408 fn add_logit_bias_from_word(&mut self, word: &str, bias: f32) -> &mut Self {
416 self.logit_bias().add_from_word(word, bias);
417 self
418 }
419
420 fn add_logit_bias_from_text(&mut self, text: &str, bias: f32) -> &mut Self {
427 self.logit_bias().add_from_text(text, bias);
428 self
429 }
430
431 fn clear_logit_bias(&mut self) -> &mut Self {
433 self.logit_bias().clear_logit_bias();
434 self
435 }
436}