use crate::quantized::gguf_file;
use crate::{Context, Error, Result};
use std::collections::HashSet;
use tokenizers::{
    decoders::{byte_level::ByteLevel as ByteLevelDecoder, DecoderWrapper},
    models::bpe::{Vocab, BPE},
    normalizers::{unicode::NFC, NormalizerWrapper},
    pre_tokenizers::{
        byte_level::ByteLevel as ByteLevelPre,
        sequence::Sequence,
        split::{Split, SplitPattern},
        PreTokenizerWrapper,
    },
    processors::sequence::Sequence as ProcessorSequence,
    processors::{byte_level::ByteLevel as ByteLevelProcessor, PostProcessorWrapper},
    tokenizer::SplitDelimiterBehavior,
    AddedToken, Tokenizer,
};
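/// Types that can be constructed from the tokenizer metadata embedded in a GGUF file.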
pub trait TokenizerFromGguf: Sized {
    fn from_gguf(ct: &gguf_file::Content) -> Result<Self>;
}
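/// Looks up `key` in the GGUF metadata, reporting the key name when it is missing.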
fn metadata_value<'a>(ct: &'a gguf_file::Content, key: &str) -> Result<&'a gguf_file::Value> {
    ct.metadata
        .get(key)
        .with_context(|| format!("missing GGUF metadata key `{key}`"))
}
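/// Converts any integer-typed GGUF value to a `u32`, bailing out on non-numeric values.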
fn gguf_value_to_u32(v: &gguf_file::Value) -> Result<u32> {
    use gguf_file::Value::*;
    match v {
        U8(v) => Ok(*v as u32),
        I8(v) => Ok(*v as u32),
        U16(v) => Ok(*v as u32),
        I16(v) => Ok(*v as u32),
        U32(v) => Ok(*v),
        I32(v) => Ok(*v as u32),
        U64(v) => Ok(*v as u32),
        I64(v) => Ok(*v as u32),
        _ => crate::bail!("expected numeric value for token type/id, got {v:?}"),
    }
}
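/// Interprets a GGUF value as an array of strings, using `name` in error messages.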
fn value_to_string_array(v: &gguf_file::Value, name: &str) -> Result<Vec<String>> {
    let arr = v
        .to_vec()
        .with_context(|| format!("`{name}` is not an array"))?;
    arr.iter()
        .map(|v| {
            v.to_string()
                .map(|s| s.to_string())
                .with_context(|| format!("`{name}` element is not a string: {v:?}"))
        })
        .collect()
}
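/// Parses `tokenizer.ggml.merges` entries of the form `"left right"` into merge pairs.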
fn merges_from_value(v: &gguf_file::Value) -> Result<Vec<(String, String)>> {
    value_to_string_array(v, "tokenizer.ggml.merges")?
        .into_iter()
        .map(|m| {
            m.split_once(' ')
                .map(|(a, b)| (a.to_string(), b.to_string()))
                .ok_or_else(|| Error::msg(format!("invalid merge entry `{m}`")))
        })
        .collect()
}
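/// The optional normalizer, pre-tokenizer, decoder and post-processor that surround the BPE model.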
struct Pipeline {
    normalizer: Option<NormalizerWrapper>,
    pretokenizer: Option<PreTokenizerWrapper>,
    decoder: Option<DecoderWrapper>,
    post_processor: Option<PostProcessorWrapper>,
}

impl Pipeline {
    fn apply(self, tokenizer: &mut Tokenizer) {
        if let Some(norm) = self.normalizer {
            tokenizer.with_normalizer(Some(norm));
        }
        if let Some(pt) = self.pretokenizer {
            tokenizer.with_pre_tokenizer(Some(pt));
        }
        if let Some(dec) = self.decoder {
            tokenizer.with_decoder(Some(dec));
        }
        if let Some(pp) = self.post_processor {
            tokenizer.with_post_processor(Some(pp));
        }
    }
}
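/// Builds a pre-tokenizer that first splits on `regex` (keeping the delimiters) and then
/// applies byte-level encoding.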
fn pre_tokenizer_sequence(regex: &str, byte_level: ByteLevelPre) -> Result<PreTokenizerWrapper> {
    let split = Split::new(
        SplitPattern::Regex(regex.to_string()),
        SplitDelimiterBehavior::Isolated,
        false,
    )
    .map_err(Error::wrap)?;
    Ok(Sequence::new(vec![split.into(), byte_level.into()]).into())
}
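/// Maps the `tokenizer.ggml.pre` identifier to a normalizer/pre-tokenizer/decoder/post-processor
/// configuration, falling back to plain GPT-2 style byte-level components for unknown identifiers.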
fn pipeline_from_pre(pre: &str) -> Result<Pipeline> {
    const REGEX_QWEN2: &str = r"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";
    const REGEX_LLAMA3: &str = r"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";

    Ok(match pre {
        "qwen2" => Pipeline {
            normalizer: Some(NFC.into()),
            pretokenizer: Some(pre_tokenizer_sequence(
                REGEX_QWEN2,
                ByteLevelPre::new(false, false, false),
            )?),
            decoder: Some(ByteLevelDecoder::new(false, false, false).into()),
            post_processor: Some(ByteLevelProcessor::new(false, false, false).into()),
        },
        "smaug-bpe" | "lfm2" | "llama3" => Pipeline {
            normalizer: None,
            pretokenizer: Some(pre_tokenizer_sequence(
                REGEX_LLAMA3,
                ByteLevelPre::new(false, true, false),
            )?),
            decoder: Some(ByteLevelDecoder::new(true, true, true).into()),
            post_processor: Some(ByteLevelProcessor::new(true, false, true).into()),
        },
        _ => Pipeline {
            normalizer: None,
            pretokenizer: Some(ByteLevelPre::default().into()),
            decoder: Some(ByteLevelDecoder::default().into()),
            post_processor: Some(ByteLevelProcessor::default().into()),
        },
    })
}
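/// Builds a `TemplateProcessing` post-processor that attaches the BOS/EOS tokens requested by the
/// GGUF `add_bos_token`/`add_eos_token` flags. Returns `None` when nothing needs to be added or
/// when a referenced token id cannot be resolved against the vocabulary.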
fn template_processor(
    tokens: &[String],
    bos_id: Option<u32>,
    eos_id: Option<u32>,
    add_bos: bool,
    add_eos: bool,
) -> Option<PostProcessorWrapper> {
    if (!add_bos && !add_eos) || tokens.is_empty() {
        return None;
    }

    let bos = bos_id.and_then(|id| tokens.get(id as usize)).cloned();
    let eos = eos_id.and_then(|id| tokens.get(id as usize)).cloned();

    let mut specials = Vec::new();
    if add_bos {
        let bos_id = bos_id?;
        let bos_tok = bos.clone()?;
        specials.push((bos_tok.clone(), bos_id));
    }
    if add_eos {
        let eos_id = eos_id?;
        let eos_tok = eos.clone()?;
        specials.push((eos_tok.clone(), eos_id));
    }

    let mut single = Vec::new();
    if add_bos {
        single.push(bos.clone()?);
    }
    single.push("$0".to_string());
    if add_eos {
        single.push(eos.clone()?);
    }

    let mut pair = Vec::new();
    if add_bos {
        pair.push(format!("{}:0", bos.clone()?));
    }
    pair.push("$A:0".to_string());
    if add_eos {
        pair.push(format!("{}:0", eos.clone()?));
    }
    if add_bos {
        pair.push(format!("{}:1", bos.clone()?));
    }
    pair.push("$B:1".to_string());
    if add_eos {
        pair.push(format!("{}:1", eos.clone()?));
    }

    let proc = tokenizers::processors::template::TemplateProcessing::builder()
        .try_single(single)
        .ok()?
        .try_pair(pair)
        .ok()?
        .special_tokens(specials)
        .build()
        .ok()?;

    Some(PostProcessorWrapper::Template(proc))
}
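// Construct a `tokenizers::Tokenizer` from GGUF metadata. Only GPT-2 style BPE tokenizers
// (`tokenizer.ggml.model` == "gpt2") are supported.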
impl TokenizerFromGguf for Tokenizer {
    fn from_gguf(ct: &gguf_file::Content) -> Result<Self> {
        let model_kind = metadata_value(ct, "tokenizer.ggml.model")?
            .to_string()?
            .to_lowercase();
        if model_kind != "gpt2" {
            crate::bail!("unsupported tokenizer model `{model_kind}`");
        }
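        // Build the BPE vocabulary from the flat token list (ids are the array indices) and the
        // merge rules.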
        let tokens = value_to_string_array(
            metadata_value(ct, "tokenizer.ggml.tokens")?,
            "tokenizer.ggml.tokens",
        )?;
        let vocab: Vocab = tokens
            .iter()
            .enumerate()
            .map(|(i, t)| (t.clone(), i as u32))
            .collect();
        let merges = merges_from_value(metadata_value(ct, "tokenizer.ggml.merges")?)?;

        let mut builder = BPE::builder().vocab_and_merges(vocab, merges);

        if let Ok(val) = metadata_value(ct, "tokenizer.ggml.unk_token_id") {
            let token_id = gguf_value_to_u32(val)?;
            if let Some(token) = tokens.get(token_id as usize) {
                builder = builder.unk_token(token.clone());
            }
        }

        if let Ok(val) = metadata_value(ct, "tokenizer.ggml.byte_fallback") {
            builder = builder.byte_fallback(val.to_bool()?);
        }

        if let Ok(val) = metadata_value(ct, "tokenizer.ggml.ignore_merges") {
            builder = builder.ignore_merges(val.to_bool()?);
        }

        let bpe = builder.build().map_err(Error::wrap)?;
        let mut tokenizer = Tokenizer::new(bpe);

        let pre = metadata_value(ct, "tokenizer.ggml.pre")
            .and_then(|v| v.to_string())
            .map(|s| s.to_string())
            .unwrap_or_else(|_| "gpt2".to_string());
        let pipeline = pipeline_from_pre(pre.as_str())?;
        let post_processor_base = pipeline.post_processor.clone();
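        // BOS/EOS ids and the add_bos/add_eos flags drive the optional template post-processor.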
        let add_bos = metadata_value(ct, "tokenizer.ggml.add_bos_token")
            .and_then(|v| v.to_bool())
            .unwrap_or(false);
        let add_eos = metadata_value(ct, "tokenizer.ggml.add_eos_token")
            .and_then(|v| v.to_bool())
            .unwrap_or(false);
        let bos_id = metadata_value(ct, "tokenizer.ggml.bos_token_id")
            .and_then(gguf_value_to_u32)
            .ok();
        let eos_id = metadata_value(ct, "tokenizer.ggml.eos_token_id")
            .and_then(gguf_value_to_u32)
            .ok();

        pipeline.apply(&mut tokenizer);
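        // Combine the byte-level post-processor with the optional BOS/EOS template: a single
        // step is used directly, otherwise the two run in sequence.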
        let template_pp = template_processor(&tokens, bos_id, eos_id, add_bos, add_eos);
        if template_pp.is_some() || post_processor_base.is_some() {
            let mut steps = Vec::new();
            if let Some(pp) = post_processor_base {
                steps.push(pp);
            }
            if let Some(tp) = template_pp {
                steps.push(tp);
            }
            let pp = if steps.len() == 1 {
                steps.pop().unwrap()
            } else {
                ProcessorSequence::new(steps).into()
            };
            tokenizer.with_post_processor(Some(pp));
        }
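        // Tokens whose `token_type` marks them as non-normal (types 2..=5) are registered as
        // special added tokens so they are matched as whole units.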
        if let Ok(gguf_file::Value::Array(arr)) = metadata_value(ct, "tokenizer.ggml.token_type") {
            let mut specials = Vec::new();
            for (idx, v) in arr.iter().enumerate() {
                let ty = gguf_value_to_u32(v)?;
                let is_special = matches!(ty, 2..=5);
                if is_special {
                    if let Some(tok) = tokens.get(idx) {
                        specials.push(AddedToken::from(tok.clone(), true));
                    }
                }
            }
            if !specials.is_empty() {
                tokenizer.add_special_tokens(&specials);
            }
        }
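        // Also register tokens referenced by explicit `*_token_id` metadata keys as special, in
        // case they were not covered by the `token_type` array.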
        let mut explicit_specials = HashSet::new();
        for key in [
            "tokenizer.ggml.bos_token_id",
            "tokenizer.ggml.eos_token_id",
            "tokenizer.ggml.pad_token_id",
            "tokenizer.ggml.sep_token_id",
            "tokenizer.ggml.unk_token_id",
        ] {
            if let Ok(val) = metadata_value(ct, key) {
                explicit_specials.insert(gguf_value_to_u32(val)?);
            }
        }
        if !explicit_specials.is_empty() {
            let specials: Vec<_> = explicit_specials
                .into_iter()
                .filter_map(|id| tokens.get(id as usize))
                .map(|tok| AddedToken::from(tok.clone(), true))
                .collect();
            if !specials.is_empty() {
                tokenizer.add_special_tokens(&specials);
            }
        }

        Ok(tokenizer)
    }
}