alith_models/tokenizer.rs
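//! Tokenizer utilities that wrap either a Hugging Face `tokenizers` tokenizer or a
//! tiktoken BPE behind a single [`Tokenizer`] type.
//!
//! Illustrative usage (a sketch, not a compiled doctest; the model id below is only
//! an example):
//!
//! ```ignore
//! let tokenizer = Tokenizer::new_tiktoken("gpt-4")?;
//! let ids = tokenizer.tokenize("hello world");
//! assert_eq!(tokenizer.count_tokens("hello world"), ids.len() as u32);
//! ```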

use super::local_model::hf_loader::{HfTokenTrait, HuggingFaceLoader};
use alith_prompt::PromptTokenizer;
use anyhow::{Result, anyhow};
use std::{
    fmt,
    path::{Path, PathBuf},
};
use tiktoken_rs::{CoreBPE, get_bpe_from_model};
use tokenizers::Tokenizer as HFTokenizer;

/// The backing tokenizer implementation.
pub enum TokenizerBackend {
    HuggingFace(Box<HFTokenizer>),
    Tiktoken(Box<CoreBPE>),
}

impl fmt::Debug for TokenizerBackend {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TokenizerBackend::HuggingFace(_) => {
                write!(f, "TokenizerBackend::HuggingFace")
            }
            TokenizerBackend::Tiktoken(_) => {
                write!(f, "TokenizerBackend::Tiktoken")
            }
        }
    }
}

/// A tokenizer backed by either a Hugging Face `tokenizers` tokenizer or a tiktoken BPE.
#[derive(Debug)]
pub struct Tokenizer {
    pub tokenizer: TokenizerBackend,
    pub tokenizer_path: Option<PathBuf>,
    pub with_special_tokens: bool,
    /// Token id of a single ASCII space, cached at construction.
    pub white_space_token_id: u32,
}

impl Tokenizer {
    /// Creates a tokenizer backed by tiktoken for the given OpenAI model id.
    pub fn new_tiktoken<S: AsRef<str>>(model_id: S) -> Result<Self> {
        let tokenizer = get_bpe_from_model(model_id.as_ref())?;
        let white_space_token_id = tokenizer.encode_ordinary(" ").remove(0);
        Ok(Self {
            tokenizer: TokenizerBackend::Tiktoken(Box::new(tokenizer)),
            tokenizer_path: None,
            with_special_tokens: false,
            white_space_token_id,
        })
    }

    /// Wraps an already constructed Hugging Face tokenizer.
    pub fn new_from_tokenizer(tokenizer: HFTokenizer) -> Result<Self> {
        let white_space_token_id = tokenizer
            .encode(" ", false)
            .map_err(|e| anyhow!(e))?
            .get_ids()[0];
        Ok(Self {
            tokenizer: TokenizerBackend::HuggingFace(Box::new(tokenizer)),
            tokenizer_path: None,
            with_special_tokens: false,
            white_space_token_id,
        })
    }

    /// Loads a Hugging Face tokenizer from a local `tokenizer.json` file.
    pub fn new_from_tokenizer_json<P: AsRef<Path>>(local_path: P) -> Result<Self> {
        let path = local_path.as_ref().to_path_buf();
        let tokenizer = HFTokenizer::from_file(local_path).map_err(|e| anyhow!(e))?;
        let white_space_token_id = tokenizer
            .encode(" ", false)
            .map_err(|e| anyhow!(e))?
            .get_ids()[0];
        Ok(Self {
            tokenizer: TokenizerBackend::HuggingFace(Box::new(tokenizer)),
            tokenizer_path: Some(path),
            with_special_tokens: false,
            white_space_token_id,
        })
    }

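    /// Downloads `tokenizer.json` from a Hugging Face repo and builds a tokenizer from it,
    /// optionally authenticating with an HF token.
    ///
    /// Illustrative sketch (not a compiled doctest); the repo id below is only an example:
    ///
    /// ```ignore
    /// let tokenizer = Tokenizer::new_from_hf_repo(None::<&str>, "openai-community/gpt2")?;
    /// ```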
    pub fn new_from_hf_repo<S: AsRef<str>>(hf_token: Option<S>, repo_id: S) -> Result<Self> {
        let mut api = HuggingFaceLoader::new();
        if let Some(hf_token) = hf_token {
            *api.hf_token_mut() = Some(hf_token.as_ref().to_owned());
        }

        let local_path = api.load_file("tokenizer.json", repo_id.as_ref())?;
        Tokenizer::new_from_tokenizer_json(&local_path)
    }

    /// Encodes text into token ids.
    #[inline]
    pub fn tokenize<T: AsRef<str>>(&self, str: T) -> Vec<u32> {
        self.encode(str.as_ref())
    }

    /// Decodes a single token id back into text.
    #[inline]
    pub fn detokenize_one(&self, token: u32) -> Result<String> {
        self.decode(&[token])
    }

    /// Decodes a slice of token ids back into text.
    #[inline]
    pub fn detokenize_many(&self, tokens: &[u32]) -> Result<String> {
        self.decode(tokens)
    }

    /// Returns the number of tokens in the given text.
    #[inline]
    pub fn count_tokens(&self, str: &str) -> u32 {
        self.tokenize(str).len() as u32
    }

    /// Decodes a token id and returns the result if it is a single whitespace-free string.
    pub fn try_from_single_token_id(&self, try_from_single_token_id: u32) -> Result<String> {
        let detokenize_response = self.detokenize_one(try_from_single_token_id)?;
        let mut strings_maybe: Vec<String> = detokenize_response
            .split_ascii_whitespace()
            .map(|s| s.to_string())
            .collect();
        match strings_maybe.len() {
            0 => Err(anyhow!(
                "Detokenizing token id {} produced no text",
                try_from_single_token_id
            )),
            1 => Ok(strings_maybe.remove(0)),
            n => Err(anyhow!(
                "Detokenizing token id {} produced {n} whitespace-separated strings; expected one",
                try_from_single_token_id
            )),
        }
    }

    /// Tokenizes the text and returns the token id if it maps to exactly one token.
    pub fn try_into_single_token(&self, try_into_single_token: &str) -> Result<u32> {
        let mut tokens = self.tokenize(try_into_single_token);
        match tokens.len() {
            0 => Err(anyhow!("No token found in text: {}", try_into_single_token)),
            1 => Ok(tokens.remove(0)),
            n => Err(anyhow!(
                "Found more than one token ({n} total) in text: {}",
                try_into_single_token
            )),
        }
    }

    /// Creates a window of at most `target_token_size` tokens taken from the center of the text.
    ///
    /// # Arguments
    ///
    /// * `text` - The input text to create a window from.
    /// * `target_token_size` - The desired number of tokens in the window.
    ///
    /// # Returns
    ///
    /// A new string containing the centered window of text, or the original
    /// text if its token count is less than or equal to `target_token_size`.
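    ///
    /// # Example
    ///
    /// A minimal sketch (not a compiled doctest), assuming `tokenizer` and `long_text`
    /// are already in scope:
    ///
    /// ```ignore
    /// // Keep the middle 32 tokens of a long document.
    /// let window = tokenizer.create_text_window(&long_text, 32);
    /// ```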
    pub fn create_text_window(&self, text: &str, target_token_size: u32) -> String {
        let tokens = self.tokenize(text);
        if tokens.len() <= target_token_size as usize {
            return text.to_string();
        }

        let start_token_index = (tokens.len() - target_token_size as usize) / 2;
        let end_token_index = start_token_index + target_token_size as usize;

        let preserved_tokens = &tokens[start_token_index..end_token_index];
        self.detokenize_many(preserved_tokens).unwrap()
    }

    /// Creates a range of text from the specified start and end token indices.
    ///
    /// # Arguments
    ///
    /// * `text` - The input text to create the range from.
    /// * `start_token_index` - The index of the first token to include.
    /// * `end_token_index` - The index one past the last token to include; values larger
    ///   than the total token count are clamped to it.
    ///
    /// # Returns
    ///
    /// A new string containing the text spanned by the requested token range.
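    ///
    /// # Example
    ///
    /// A minimal sketch (not a compiled doctest), assuming `tokenizer` and `long_text`
    /// are already in scope:
    ///
    /// ```ignore
    /// // Text covering tokens 10..20; the end index is clamped to the token count.
    /// let range = tokenizer.create_text_range(&long_text, 10, 20);
    /// ```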
    pub fn create_text_range(
        &self,
        text: &str,
        start_token_index: u32,
        end_token_index: u32,
    ) -> String {
        let tokens = self.tokenize(text);
        // Clamp the end index to the token count, and the start index to the end index,
        // so the slice below cannot go out of bounds.
        let end_token_index = (end_token_index as usize).min(tokens.len());
        let start_token_index = (start_token_index as usize).min(end_token_index);

        let preserved_tokens = &tokens[start_token_index..end_token_index];
        self.detokenize_many(preserved_tokens).unwrap()
    }

    fn encode_tiktoken(&self, tokenizer: &CoreBPE, str: &str) -> Vec<u32> {
        if self.with_special_tokens {
            tokenizer.encode_with_special_tokens(str)
        } else {
            tokenizer.encode_ordinary(str)
        }
    }

    fn encode_hf(&self, tokenizer: &HFTokenizer, str: &str) -> Vec<u32> {
        tokenizer
            .encode(str, self.with_special_tokens)
            .unwrap()
            .get_ids()
            .to_vec()
    }

    #[inline]
    fn encode(&self, str: &str) -> Vec<u32> {
        match &self.tokenizer {
            TokenizerBackend::HuggingFace(tokenizer) => self.encode_hf(tokenizer, str),
            TokenizerBackend::Tiktoken(tokenizer) => self.encode_tiktoken(tokenizer, str),
        }
    }

    #[inline]
    fn decode_tiktoken(&self, tokenizer: &CoreBPE, tokens: &[u32]) -> Result<String> {
        tokenizer.decode(tokens.to_owned()).map_err(|e| anyhow!(e))
    }

    #[inline]
    fn decode_hf(&self, tokenizer: &HFTokenizer, tokens: &[u32]) -> Result<String> {
        tokenizer.decode(tokens, true).map_err(|e| anyhow!(e))
    }

    #[inline]
    fn decode(&self, tokens: &[u32]) -> Result<String> {
        match &self.tokenizer {
            TokenizerBackend::HuggingFace(tokenizer) => self.decode_hf(tokenizer, tokens),
            TokenizerBackend::Tiktoken(tokenizer) => self.decode_tiktoken(tokenizer, tokens),
        }
    }
}

// Delegates to the inherent `Tokenizer` methods (inherent methods take precedence
// over trait methods, so these calls do not recurse).
impl PromptTokenizer for Tokenizer {
    #[inline]
    fn tokenize(&self, input: &str) -> Vec<u32> {
        self.tokenize(input)
    }

    #[inline]
    fn count_tokens(&self, str: &str) -> u32 {
        self.count_tokens(str)
    }
}