alith_interface/requests/logit_bias.rs

use alith_models::tokenizer::Tokenizer;

use crate::requests::req_components::RequestConfigTrait;

use std::{collections::HashMap, sync::Arc};

#[derive(Clone, Default)]
pub struct LogitBias {
    pub base_logit_bias: Option<HashMap<u32, f32>>,
    pub built_openai_bias: OpenAILogitBias,
    from_token_ids: FromTokenIds,
    from_chars: FromChars,
    from_words: FromWords,
    from_texts: FromTexts,
}

impl LogitBias {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn add_token_id(&mut self, token_id: u32, bias: f32) -> &mut Self {
        self.from_token_ids.add_token_id(token_id, bias);
        self.clear_built();
        self
    }

    pub fn add_token_ids(&mut self, logit_bias: HashMap<u32, f32>) -> &mut Self {
        self.from_token_ids.add_token_ids(logit_bias);
        self.clear_built();
        self
    }

    pub fn add_from_char(&mut self, char: char, bias: f32) -> &mut Self {
        self.from_chars.add_char(char, bias);
        self.clear_built();
        self
    }

    pub fn add_from_word(&mut self, word: &str, bias: f32) -> &mut Self {
        self.from_words.add_word(word, bias);
        self.clear_built();
        self
    }

    pub fn add_from_text(&mut self, text: &str, bias: f32) -> &mut Self {
        self.from_texts.add_text(text, bias);
        self.clear_built();
        self
    }

    pub fn clear_logit_bias(&mut self) -> &mut Self {
        self.from_token_ids.clear();
        self.from_chars.clear();
        self.from_words.clear();
        self.from_texts.clear();
        self.clear_built();
        self
    }

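    /// Builds the OpenAI-style logit bias from the configured sources, resolving chars,
    /// words, and texts into token IDs with the given tokenizer. Does nothing if a bias
    /// has already been built since the last change.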
    pub(crate) fn build_openai(&mut self, tokenizer: &Arc<Tokenizer>) -> crate::Result<()> {
        if !self.built_openai_bias.is_none() {
            return Ok(());
        }
        if self.base_logit_bias.is_none() {
            self.build_base(tokenizer)?;
        }
        if let Some(base_logit_bias) = &self.base_logit_bias {
            self.built_openai_bias.build(base_logit_bias);
        }
        Ok(())
    }

    pub(crate) fn get_openai(&self) -> Option<HashMap<String, serde_json::Value>> {
        self.built_openai_bias.get()
    }

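    /// Resolves every pending source to token IDs, merges the results into
    /// `base_logit_bias`, and validates the bias values. On conflicting token IDs,
    /// later sources win: token IDs, then chars, then words, then texts.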
    fn build_base(&mut self, tokenizer: &Arc<Tokenizer>) -> crate::Result<()> {
        if self.from_token_ids.is_none()
            && self.from_chars.is_none()
            && self.from_words.is_none()
            && self.from_texts.is_none()
        {
            return Ok(());
        }
        let validated_logit_bias = self.from_token_ids.get(tokenizer)?;
        self.from_token_ids.clear();

        let validated_logit_bias = Self::merge_logit_biases(vec![
            &validated_logit_bias,
            &self.from_chars.get(tokenizer)?,
        ]);
        self.from_chars.clear();

        let validated_logit_bias = Self::merge_logit_biases(vec![
            &validated_logit_bias,
            &self.from_words.get(tokenizer)?,
        ]);
        self.from_words.clear();

        let validated_logit_bias = Self::merge_logit_biases(vec![
            &validated_logit_bias,
            &self.from_texts.get(tokenizer)?,
        ]);
        self.from_texts.clear();

        if !validated_logit_bias.is_empty() {
            Self::validate_logit_bias_values(&validated_logit_bias)?;
            self.base_logit_bias = Some(validated_logit_bias);
        }
        Ok(())
    }

    fn clear_built(&mut self) -> &mut Self {
        self.base_logit_bias = None;
        self.built_openai_bias.clear();
        self
    }

    /// Validates the logit bias values by checking if they are within the range of -100.0 to 100.0.
    ///
    /// # Arguments
    ///
    /// * `logit_bias` - A reference to the `HashMap` containing the logit biases with token IDs as keys and bias values as values.
    ///
    /// # Returns
    ///
    /// Returns `Result<(), anyhow::Error>` indicating success or an error if any of the bias values are out of range.
    fn validate_logit_bias_values(logit_bias: &HashMap<u32, f32>) -> crate::Result<()> {
        for value in logit_bias.values() {
            if *value > 100.0 || *value < -100.0 {
                return Err(crate::anyhow!(
                    "logit_bias value must be between -100.0 and 100.0. Given value: {}",
                    value
                ));
            }
        }
        Ok(())
    }

    /// Merges multiple logit biases into a single `HashMap` of token IDs and bias values.
    ///
    /// # Arguments
    ///
    /// * `logit_biases` - A vector of references to `HashMap`s containing logit biases with token IDs as keys and bias values as values.
    ///
    /// # Returns
    ///
    /// Returns a `HashMap<u32, f32>` containing the merged logit biases.
    fn merge_logit_biases(logit_biases: Vec<&HashMap<u32, f32>>) -> HashMap<u32, f32> {
        let mut merged_logit_bias: HashMap<u32, f32> = HashMap::new();
        for logit_bias in logit_biases {
            for (token_id, bias) in logit_bias {
                merged_logit_bias.insert(*token_id, *bias);
            }
        }
        merged_logit_bias
    }
}

#[derive(Clone, Default)]
struct FromTokenIds {
    pub token_ids: Option<HashMap<u32, f32>>,
}

impl FromTokenIds {
    fn is_none(&self) -> bool {
        self.token_ids.is_none()
    }

    fn clear(&mut self) {
        self.token_ids = None;
    }

    fn get(&self, tokenizer: &Arc<Tokenizer>) -> crate::Result<HashMap<u32, f32>> {
        if let Some(token_ids) = &self.token_ids {
            for token_id in token_ids.keys() {
                tokenizer.try_from_single_token_id(*token_id)?;
            }
            Ok(token_ids.clone())
        } else {
            Ok(HashMap::new())
        }
    }

    fn add_token_id(&mut self, token_id: u32, bias: f32) {
        self.token_ids
            .get_or_insert_with(HashMap::new)
            .entry(token_id)
            .or_insert(bias);
    }

    fn add_token_ids(&mut self, logit_bias: HashMap<u32, f32>) {
        self.token_ids
            .get_or_insert_with(HashMap::new)
            .extend(logit_bias);
    }
}

#[derive(Clone, Default)]
struct FromChars {
    pub chars: Option<HashMap<char, f32>>,
}

impl FromChars {
    fn is_none(&self) -> bool {
        self.chars.is_none()
    }

    fn clear(&mut self) {
        self.chars = None;
    }

    fn get(&self, tokenizer: &Arc<Tokenizer>) -> crate::Result<HashMap<u32, f32>> {
        if let Some(chars) = &self.chars {
            let mut token_logit_bias: HashMap<u32, f32> = HashMap::new();
            for (char, bias) in chars {
                let token_id = tokenizer.try_into_single_token(&char.to_string())?;
                token_logit_bias.insert(token_id, *bias);
            }
            Ok(token_logit_bias)
        } else {
            Ok(HashMap::new())
        }
    }

    fn add_char(&mut self, char: char, bias: f32) {
        self.chars
            .get_or_insert_with(HashMap::new)
            .entry(char)
            .or_insert(bias);
    }
}

#[derive(Clone, Default)]
struct FromWords {
    pub words: Option<HashMap<String, f32>>,
}

impl FromWords {
    fn is_none(&self) -> bool {
        self.words.is_none()
    }

    fn clear(&mut self) {
        self.words = None;
    }

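    /// Resolves each word to its token IDs and applies the word's bias to every token.
    /// Errors if an entry is empty or contains whitespace-separated words.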
    fn get(&self, tokenizer: &Arc<Tokenizer>) -> crate::Result<HashMap<u32, f32>> {
        if let Some(words) = &self.words {
            let mut token_logit_bias: HashMap<u32, f32> = HashMap::new();
            for (word_maybe, bias) in words {
                let mut words_maybe: Vec<String> = word_maybe
                    .trim()
                    .split_ascii_whitespace()
                    .map(|s| s.trim().to_string())
                    .collect();
                let word = if words_maybe.is_empty() {
                    return Err(crate::anyhow!(
                        "logit_bias contains an empty word. Given word: {}",
                        word_maybe
                    ));
                } else if words_maybe.len() > 1 {
                    return Err(crate::anyhow!(
                        "logit_bias contains a word separated by whitespace. Given word: {}",
                        word_maybe
                    ));
                } else {
                    words_maybe.remove(0)
                };
                let token_ids = tokenizer.tokenize(&word);
                for id in token_ids {
                    if id == tokenizer.white_space_token_id {
                        panic!(
                            "logit_bias contains a whitespace token. Given word: {}",
                            word
                        )
                    }
                    token_logit_bias.insert(id, *bias);
                }
            }
            Ok(token_logit_bias)
        } else {
            Ok(HashMap::new())
        }
    }

    fn add_word(&mut self, word: &str, bias: f32) {
        self.words
            .get_or_insert_with(HashMap::new)
            .entry(word.to_owned())
            .or_insert(bias);
    }
}

#[derive(Clone, Default)]
struct FromTexts {
    pub texts: Option<HashMap<String, f32>>,
}

impl FromTexts {
    fn is_none(&self) -> bool {
        self.texts.is_none()
    }

    fn clear(&mut self) {
        self.texts = None;
    }

    fn get(&self, tokenizer: &Arc<Tokenizer>) -> crate::Result<HashMap<u32, f32>> {
        if let Some(texts) = &self.texts {
            let mut token_logit_bias: HashMap<u32, f32> = HashMap::new();
            for (text, bias) in texts {
                let token_ids = tokenizer.tokenize(text);
                for id in token_ids {
                    if id == tokenizer.white_space_token_id {
                        continue;
                    }
                    token_logit_bias.insert(id, *bias);
                }
            }
            Ok(token_logit_bias)
        } else {
            Ok(HashMap::new())
        }
    }

    fn add_text(&mut self, text: &str, bias: f32) {
        self.texts
            .get_or_insert_with(HashMap::new)
            .entry(text.to_owned())
            .or_insert(bias);
    }
}

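/// The logit bias converted into the string-keyed, integer-valued map shape used by
/// OpenAI-compatible `logit_bias` request fields.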
#[derive(Clone, Default)]
pub struct OpenAILogitBias {
    pub built_logit_bias: Option<HashMap<String, serde_json::Value>>,
}

impl OpenAILogitBias {
    fn is_none(&self) -> bool {
        self.built_logit_bias.is_none()
    }

    fn clear(&mut self) {
        self.built_logit_bias = None;
    }

    fn build(&mut self, logit_bias: &HashMap<u32, f32>) {
        let mut openai_logit_bias: HashMap<String, serde_json::Value> = HashMap::new();
        for (token_id, value) in logit_bias {
            openai_logit_bias.insert(
                token_id.to_string(),
                serde_json::Value::Number(serde_json::Number::from(value.ceil() as i32)),
            );
        }
        // Store the result so `get` returns the built bias.
        self.built_logit_bias = Some(openai_logit_bias);
    }

    fn get(&self) -> Option<HashMap<String, serde_json::Value>> {
        self.built_logit_bias.clone()
    }
}

pub trait LogitBiasTrait: RequestConfigTrait {
    fn lb_mut(&mut self) -> &mut Option<LogitBias>;

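    /// Returns a mutable reference to the request's `LogitBias`, creating a default
    /// instance on first use.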
    fn logit_bias(&mut self) -> &mut LogitBias {
        if self.lb_mut().is_none() {
            *self.lb_mut() = Some(LogitBias::default());
        }
        self.lb_mut().as_mut().unwrap()
    }

    /// Adds a logit bias for a specific token ID. Useful if you use your own tokenizer or otherwise already have token IDs.
    ///
    /// # Arguments
    ///
    /// * `token_id` - The token ID.
    /// * `bias` - The bias value.
    fn add_logit_bias_token_id(&mut self, token_id: u32, bias: f32) -> &mut Self {
        self.logit_bias().add_token_id(token_id, bias);
        self
    }

    /// Adds multiple logit biases for token IDs. Useful if you use your own tokenizer or otherwise already have token IDs.
    ///
    /// # Arguments
    ///
    /// * `logit_bias` - A `HashMap` containing token IDs as keys and bias values as values.
    fn add_logit_bias_token_ids(&mut self, logit_bias: HashMap<u32, f32>) -> &mut Self {
        self.logit_bias().add_token_ids(logit_bias);
        self
    }

    /// Adds a logit bias for a specific character.
    /// Of limited use on its own: it does not necessarily affect every occurrence of the character, since the character may also appear inside other tokens.
    ///
    /// # Arguments
    ///
    /// * `char` - The character.
    /// * `bias` - The bias value.
    fn add_logit_bias_from_char(&mut self, char: char, bias: f32) -> &mut Self {
        self.logit_bias().add_from_char(char, bias);
        self
    }

    /// Adds a logit bias for a specific word. If the word tokenizes to more than one token, the bias is applied to each of those tokens.
    /// Building the bias errors if the word is empty or contains whitespace.
    ///
    /// # Arguments
    ///
    /// * `word` - The word.
    /// * `bias` - The bias value.
    fn add_logit_bias_from_word(&mut self, word: &str, bias: f32) -> &mut Self {
        self.logit_bias().add_from_word(word, bias);
        self
    }

    /// Adds a logit bias for a text. The text is split into tokens and the bias is applied to each token, except the whitespace token.
    ///
    /// # Arguments
    ///
    /// * `text` - The text.
    /// * `bias` - The bias value.
    fn add_logit_bias_from_text(&mut self, text: &str, bias: f32) -> &mut Self {
        self.logit_bias().add_from_text(text, bias);
        self
    }

    /// Clears the logit bias configuration so the request object can be reused for another request. Mostly useful for testing.
    fn clear_logit_bias(&mut self) -> &mut Self {
        self.logit_bias().clear_logit_bias();
        self
    }
}
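
// A minimal illustrative test sketch (not part of the original module): it exercises
// only the pieces above that need no `Tokenizer`, namely the builder-style `add_*`
// methods, the merge precedence of `merge_logit_biases`, and the range check in
// `validate_logit_bias_values`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn later_biases_overwrite_earlier_ones_when_merged() {
        let first = HashMap::from([(1_u32, 1.0_f32), (2, 2.0)]);
        let second = HashMap::from([(2_u32, -2.0_f32), (3, 3.0)]);
        let merged = LogitBias::merge_logit_biases(vec![&first, &second]);
        assert_eq!(merged.get(&1), Some(&1.0));
        // The later map wins on the conflicting token ID.
        assert_eq!(merged.get(&2), Some(&-2.0));
        assert_eq!(merged.get(&3), Some(&3.0));
    }

    #[test]
    fn out_of_range_bias_values_are_rejected() {
        let in_range = HashMap::from([(1_u32, 99.9_f32)]);
        assert!(LogitBias::validate_logit_bias_values(&in_range).is_ok());

        let out_of_range = HashMap::from([(1_u32, 150.0_f32)]);
        assert!(LogitBias::validate_logit_bias_values(&out_of_range).is_err());
    }

    #[test]
    fn nothing_is_built_until_build_openai_runs() {
        let mut logit_bias = LogitBias::new();
        logit_bias.add_token_id(42, 5.0).add_from_word("hello", -5.0);
        // Sources are only staged here; building requires a tokenizer via `build_openai`.
        assert!(logit_bias.base_logit_bias.is_none());
        assert!(logit_bias.get_openai().is_none());
    }
}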