//! llm_base/vocabulary.rs

use std::{collections::HashMap, str::FromStr};

use crate::InferenceError;

/// The identifier of a token in a vocabulary.
pub type TokenId = i32;
/// The raw byte content of a token.
pub(crate) type Token = Vec<u8>;
/// The score associated with a token.
pub(crate) type TokenScore = f32;

/// The vocabulary used by a model.
///
/// Invariant: `id_to_token` and `id_to_token_score` always have the same
/// length (enforced by an assertion in `push_token`).
#[derive(Debug, Clone, Default)]
pub struct Vocabulary {
    /// Maps every integer (index) token ID to its corresponding token.
    pub id_to_token: Vec<Token>,

    /// Maps every integer (index) token ID to corresponding score.
    pub id_to_token_score: Vec<TokenScore>,

    // todo: use a radix tree
    /// Maps a token to a token ID.
    pub token_to_id: HashMap<Token, TokenId>,

    /// The longest token in this vocabulary.
    /// Maintained by `push_token`; used to bound the substring search in `tokenize`.
    pub max_token_length: usize,
}

27impl Vocabulary {
28    /// Add a token to the vocabulary.
29    ///
30    /// The token added must have `id` directly after the last token in the vocabulary.
31    ///
32    /// # Panics
33    /// - This function can panic if `id` does not correspond to the next token in the vocabulary.
34    ///   That is, if there are already `n` tokens in the vocabulary, then `id` must be `n`.
35    pub fn push_token(&mut self, id: TokenId, content: Token, score: TokenScore) {
36        // These are loader invariants. If this is broken, then the loader is broken and this is a bug,
37        // not an issue with the model itself.
38        assert_eq!(self.id_to_token.len(), self.id_to_token_score.len());
39        if self.id_to_token.len() != id as usize || self.id_to_token_score.len() != id as usize {
40            let expected_id = self.id_to_token.len() as TokenId;
41            panic!("the id of token added should be {expected_id}; is {id}");
42        }
43
44        self.max_token_length = self.max_token_length.max(content.len());
45        self.id_to_token.push(content.clone());
46        self.id_to_token_score.push(score);
47        self.token_to_id.insert(content, id);
48    }
49
50    /// Converts a token index to the token it represents in this vocabulary.
51    pub fn token(&self, idx: usize) -> &[u8] {
52        &self.id_to_token[idx]
53    }
54
55    // SentencePiece implementation after https://guillaume-be.github.io/2020-05-30/sentence_piece
56    /// Tokenize a `text` with this vocabulary.
57    ///
58    /// `bos` controls whether a beginning-of-string token should be inserted.
59    pub fn tokenize<'a>(
60        &'a self,
61        text: &str,
62        bos: bool,
63    ) -> Result<Vec<(&'a [u8], TokenId)>, InferenceError> {
64        let len = text.len();
65
66        let mut score = vec![0usize; len + 1];
67        let mut prev = vec![TokenId::default(); len + 1];
68
69        for i in 0..len {
70            let max_len = (len - i).min(self.max_token_length);
71            for sub_len in 1..=max_len {
72                let sub = &text.as_bytes()[i..i + sub_len];
73                let token = self.token_to_id.get(sub);
74
75                if let Some(token) = token {
76                    let token_score = sub.len() * sub.len();
77                    let local_score = score[i] + token_score;
78                    let next = i + sub_len;
79
80                    if score[next] < local_score {
81                        score[next] = local_score;
82                        prev[next] = *token;
83                    }
84                }
85            }
86        }
87
88        // Backward pass
89        let mut res = vec![];
90        let mut i = len;
91        while i > 0 {
92            let token_id = prev[i];
93            if token_id == 0 {
94                return Err(InferenceError::TokenizationFailed);
95            }
96            let token = self.id_to_token[token_id as usize].as_slice();
97            res.push((token, token_id));
98            i -= token.len();
99        }
100
101        if bos {
102            // TODO: replace with vocab.bos
103            res.push((&[], 1));
104        }
105
106        // Pieces are in reverse order so correct that
107        res.reverse();
108
109        Ok(res)
110    }
111}
112
#[derive(Default, Clone, Debug, PartialEq)]
/// A list of tokens to bias during the process of inferencing.
///
/// When a biased token is encountered, the bias will be used
/// instead of the inferred logit during the sampling process.
///
/// This can be used to disable the generation of responses
/// with specific tokens by setting their corresponding bias
/// to -1.0.
///
/// The inner list is kept sorted by token id and deduplicated
/// (see [TokenBias::new]) so that [TokenBias::get] can binary search it.
pub struct TokenBias(Vec<(TokenId, f32)>);

124impl TokenBias {
125    /// Create a [TokenBias] from an existing `Vec`.
126    pub fn new(mut v: Vec<(TokenId, f32)>) -> Self {
127        v.sort_by_cached_key(|(tid, _)| *tid);
128        v.dedup_by_key(|(tid, _)| *tid);
129        Self(v)
130    }
131
132    /// Retrieves the bias for a given token, if available.
133    pub fn get(&self, tid: TokenId) -> Option<f32> {
134        self.0
135            .binary_search_by_key(&tid, |(tid, _)| *tid)
136            .map(|idx| self.0[idx].1)
137            .ok()
138    }
139}
140
141impl FromStr for TokenBias {
142    type Err = String;
143
144    /// A comma separated list of token biases. The list should be in the format
145    /// "TID=BIAS,TID=BIAS" where TID is an integer token ID and BIAS is a
146    /// floating point number.
147    /// For example, "1=-1.0,2=-1.0" sets the bias for token IDs 1
148    /// (start of document) and 2 (end of document) to -1.0 which effectively
149    /// disables the model from generating responses containing those token IDs.
150    fn from_str(s: &str) -> Result<Self, Self::Err> {
151        let x = s
152            .split(',')
153            .map(|kv| {
154                let (k, v) = kv
155                    .trim()
156                    .split_once('=')
157                    .ok_or_else(|| "Missing '=' in bias item".to_owned())?;
158                let tid: TokenId = k
159                    .trim()
160                    .parse()
161                    .map_err(|e: std::num::ParseIntError| e.to_string())?;
162                let bias: f32 = v
163                    .trim()
164                    .parse()
165                    .map_err(|e: std::num::ParseFloatError| e.to_string())?;
166                Result::<_, String>::Ok((tid, bias))
167            })
168            .collect::<Result<_, _>>()?;
169        Ok(TokenBias::new(x))
170    }
171}
172
173impl std::fmt::Display for TokenBias {
174    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175        write!(f, "{:?}", self.0)
176    }
177}