// alith_models/tokenizer.rs

use super::local_model::hf_loader::{HfTokenTrait, HuggingFaceLoader};
use alith_prompt::PromptTokenizer;
use anyhow::{Result, anyhow};
use std::{
    fmt,
    path::{Path, PathBuf},
};
use tiktoken_rs::{CoreBPE, get_bpe_from_model};
use tokenizers::Tokenizer as HFTokenizer;

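/// Selects the concrete tokenizer implementation backing a [`Tokenizer`]:
/// either a Hugging Face `tokenizers` tokenizer or a tiktoken BPE.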
pub enum TokenizerBackend {
    HuggingFace(Box<HFTokenizer>),
    Tiktoken(Box<CoreBPE>),
}

impl fmt::Debug for TokenizerBackend {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TokenizerBackend::HuggingFace(_) => {
                write!(f, "TokenizerBackend::HuggingFace")
            }
            TokenizerBackend::Tiktoken(_) => {
                write!(f, "TokenizerBackend::Tiktoken")
            }
        }
    }
}

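/// Tokenizer wrapper that hides whether the backend is a Hugging Face
/// `tokenizers` tokenizer or a tiktoken BPE, and records the token id of a
/// single ASCII space in `white_space_token_id`.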
#[derive(Debug)]
pub struct Tokenizer {
    pub tokenizer: TokenizerBackend,
    pub tokenizer_path: Option<PathBuf>,
    pub with_special_tokens: bool,
    pub white_space_token_id: u32,
}

impl Tokenizer {
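    /// Builds a tokenizer backed by tiktoken for the given model id.
    ///
    /// Illustrative sketch (assumes a model id that `tiktoken-rs` recognizes,
    /// e.g. "gpt-4"; not a doc-test):
    ///
    /// ```ignore
    /// let tokenizer = Tokenizer::new_tiktoken("gpt-4")?;
    /// let ids = tokenizer.tokenize("hello world");
    /// assert!(!ids.is_empty());
    /// ```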
    pub fn new_tiktoken<S: AsRef<str>>(model_id: S) -> Result<Self> {
        let tokenizer = get_bpe_from_model(model_id.as_ref())?;
        let white_space_token_id = tokenizer.encode_ordinary(" ").remove(0);
        Ok(Self {
            tokenizer: TokenizerBackend::Tiktoken(Box::new(tokenizer)),
            tokenizer_path: None,
            with_special_tokens: false,
            white_space_token_id,
        })
    }

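    /// Wraps an already constructed Hugging Face tokenizer.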
    pub fn new_from_tokenizer(tokenizer: HFTokenizer) -> Result<Self> {
        let white_space_token_id = tokenizer
            .encode(" ", false)
            .map_err(|e| anyhow!(e))?
            .get_ids()[0];
        Ok(Self {
            tokenizer: TokenizerBackend::HuggingFace(Box::new(tokenizer)),
            tokenizer_path: None,
            with_special_tokens: false,
            white_space_token_id,
        })
    }

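    /// Loads a Hugging Face tokenizer from a local `tokenizer.json` file and
    /// remembers the path it was loaded from.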
    pub fn new_from_tokenizer_json<P: AsRef<Path>>(local_path: P) -> Result<Self> {
        let path = local_path.as_ref().to_path_buf();
        let tokenizer = HFTokenizer::from_file(local_path).map_err(|e| anyhow!(e))?;
        let white_space_token_id = tokenizer
            .encode(" ", false)
            .map_err(|e| anyhow!(e))?
            .get_ids()[0];
        Ok(Self {
            tokenizer: TokenizerBackend::HuggingFace(Box::new(tokenizer)),
            tokenizer_path: Some(path),
            with_special_tokens: false,
            white_space_token_id,
        })
    }

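    /// Downloads `tokenizer.json` from a Hugging Face Hub repo (optionally
    /// authenticated with `hf_token`) and builds the tokenizer from it.
    ///
    /// Illustrative sketch (the repo id is a placeholder; not a doc-test):
    ///
    /// ```ignore
    /// let tokenizer = Tokenizer::new_from_hf_repo(None::<&str>, "org/model")?;
    /// ```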
    pub fn new_from_hf_repo<S: AsRef<str>>(hf_token: Option<S>, repo_id: S) -> Result<Self> {
        let mut api: HuggingFaceLoader = HuggingFaceLoader::new();
        if let Some(hf_token) = hf_token {
            *api.hf_token_mut() = Some(hf_token.as_ref().to_owned());
        }

        let local_path = api.load_file("tokenizer.json", repo_id.as_ref())?;
        Tokenizer::new_from_tokenizer_json(&local_path)
    }

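    /// Encodes `str` into token ids using the active backend.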
    #[inline]
    pub fn tokenize<T: AsRef<str>>(&self, str: T) -> Vec<u32> {
        self.encode(str.as_ref())
    }

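    /// Decodes a single token id back into text.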
    #[inline]
    pub fn detokenize_one(&self, token: u32) -> Result<String> {
        self.decode(&[token])
    }

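    /// Decodes a slice of token ids back into text.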
    #[inline]
    pub fn detokenize_many(&self, tokens: &[u32]) -> Result<String> {
        self.decode(tokens)
    }

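    /// Counts how many tokens `str` encodes to.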
    #[inline]
    pub fn count_tokens(&self, str: &str) -> u32 {
        self.tokenize(str).len() as u32
    }

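    /// Decodes `try_from_single_token_id` and returns the resulting text only
    /// if it contains exactly one whitespace-separated word; otherwise returns
    /// an error.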
    pub fn try_from_single_token_id(&self, try_from_single_token_id: u32) -> Result<String> {
        let detokenize_response = self.detokenize_one(try_from_single_token_id)?;
        let mut strings_maybe: Vec<String> = detokenize_response
            .split_ascii_whitespace()
            .map(|s| s.to_string())
            .collect();
        match strings_maybe.len() {
            0 => Err(anyhow!(
                "token_id is empty for try_from_single_token_id: {}",
                try_from_single_token_id
            )),
            1 => Ok(strings_maybe.remove(0)),
            n => Err(anyhow!(
                "Found more than one token ({n} total) in try_from_single_token_id: {}",
                try_from_single_token_id
            )),
        }
    }

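    /// Encodes `try_into_single_token` and returns the id only if the text
    /// maps to exactly one token; otherwise returns an error.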
    pub fn try_into_single_token(&self, try_into_single_token: &str) -> Result<u32> {
        let mut tokens = self.tokenize(try_into_single_token);
        match tokens.len() {
            0 => Err(anyhow!("No token found in text: {}", try_into_single_token)),
            1 => Ok(tokens.remove(0)),
            n => Err(anyhow!(
                "Found more than one token ({n} total) in text: {}",
                try_into_single_token
            )),
        }
    }

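    /// Truncates `text` to roughly `target_token_size` tokens by keeping a
    /// window centered in the middle of the text. Returns the input unchanged
    /// if it is already short enough.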
    pub fn create_text_window(&self, text: &str, target_token_size: u32) -> String {
        let tokens = self.tokenize(text);
        if tokens.len() <= target_token_size as usize {
            return text.to_string();
        }

        let start_token_index = (tokens.len() - target_token_size as usize) / 2;
        let end_token_index = start_token_index + target_token_size as usize;

        let preserved_tokens = &tokens[start_token_index..end_token_index];
        self.detokenize_many(preserved_tokens).unwrap()
    }

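    /// Returns the text covered by the token range
    /// `[start_token_index, end_token_index)`, clamping the end index to the
    /// total token count.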
    pub fn create_text_range(
        &self,
        text: &str,
        start_token_index: u32,
        end_token_index: u32,
    ) -> String {
        let tokens = self.tokenize(text);
        let end_token_index = if tokens.len() <= end_token_index as usize {
            tokens.len()
        } else {
            end_token_index as usize
        };

        let preserved_tokens = &tokens[start_token_index as usize..end_token_index];
        self.detokenize_many(preserved_tokens).unwrap()
    }

    fn encode_tiktoken(&self, tokenizer: &CoreBPE, str: &str) -> Vec<u32> {
        if self.with_special_tokens {
            tokenizer.encode_with_special_tokens(str)
        } else {
            tokenizer.encode_ordinary(str)
        }
    }

    fn encode_hf(&self, tokenizer: &HFTokenizer, str: &str) -> Vec<u32> {
        let tokens = if self.with_special_tokens {
            tokenizer.encode(str, true)
        } else {
            tokenizer.encode(str, false)
        };
        tokens.unwrap().get_ids().to_vec()
    }

    #[inline]
    fn encode(&self, str: &str) -> Vec<u32> {
        match &self.tokenizer {
            TokenizerBackend::HuggingFace(tokenizer) => self.encode_hf(tokenizer, str),
            TokenizerBackend::Tiktoken(tokenizer) => self.encode_tiktoken(tokenizer, str),
        }
    }

    #[inline]
    fn decode_tiktoken(&self, tokenizer: &CoreBPE, tokens: &[u32]) -> Result<String> {
        tokenizer.decode(tokens.to_owned()).map_err(|e| anyhow!(e))
    }

    #[inline]
    fn decode_hf(&self, tokenizer: &HFTokenizer, tokens: &[u32]) -> Result<String> {
        tokenizer.decode(tokens, true).map_err(|e| anyhow!(e))
    }

    #[inline]
    fn decode(&self, tokens: &[u32]) -> Result<String> {
        match &self.tokenizer {
            TokenizerBackend::HuggingFace(tokenizer) => self.decode_hf(tokenizer, tokens),
            TokenizerBackend::Tiktoken(tokenizer) => self.decode_tiktoken(tokenizer, tokens),
        }
    }
}

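// Wires this tokenizer into alith_prompt's `PromptTokenizer` trait by
// delegating to the inherent methods above.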
impl PromptTokenizer for Tokenizer {
    #[inline]
    fn tokenize(&self, input: &str) -> Vec<u32> {
        self.tokenize(input)
    }

    #[inline]
    fn count_tokens(&self, str: &str) -> u32 {
        self.count_tokens(str)
    }
}