
candle_core/quantized/tokenizer.rs

use crate::quantized::gguf_file;
use crate::{Context, Error, Result};
use std::collections::HashSet;
use tokenizers::{
    decoders::{byte_level::ByteLevel as ByteLevelDecoder, DecoderWrapper},
    models::bpe::{Vocab, BPE},
    normalizers::{unicode::NFC, NormalizerWrapper},
    pre_tokenizers::{
        byte_level::ByteLevel as ByteLevelPre,
        sequence::Sequence,
        split::{Split, SplitPattern},
        PreTokenizerWrapper,
    },
    processors::sequence::Sequence as ProcessorSequence,
    processors::{byte_level::ByteLevel as ByteLevelProcessor, PostProcessorWrapper},
    tokenizer::SplitDelimiterBehavior,
    AddedToken, Tokenizer,
};

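/// Builds a tokenizer directly from the metadata stored in a GGUF file.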
pub trait TokenizerFromGguf: Sized {
    fn from_gguf(ct: &gguf_file::Content) -> Result<Self>;
}

fn metadata_value<'a>(ct: &'a gguf_file::Content, key: &str) -> Result<&'a gguf_file::Value> {
    ct.metadata
        .get(key)
        .with_context(|| format!("missing GGUF metadata key `{key}`"))
}

fn gguf_value_to_u32(v: &gguf_file::Value) -> Result<u32> {
    use gguf_file::Value::*;
    match v {
        U8(v) => Ok(*v as u32),
        I8(v) => Ok(*v as u32),
        U16(v) => Ok(*v as u32),
        I16(v) => Ok(*v as u32),
        U32(v) => Ok(*v),
        I32(v) => Ok(*v as u32),
        U64(v) => Ok(*v as u32),
        I64(v) => Ok(*v as u32),
        _ => crate::bail!("expected numeric value for token type/id, got {v:?}"),
    }
}

fn value_to_string_array(v: &gguf_file::Value, name: &str) -> Result<Vec<String>> {
    let arr = v
        .to_vec()
        .with_context(|| format!("`{name}` is not an array"))?;
    arr.iter()
        .map(|v| {
            v.to_string()
                .map(|s| s.to_string())
                .with_context(|| format!("`{name}` element is not a string: {v:?}"))
        })
        .collect()
}

fn merges_from_value(v: &gguf_file::Value) -> Result<Vec<(String, String)>> {
    value_to_string_array(v, "tokenizer.ggml.merges")?
        .into_iter()
        .map(|m| {
            m.split_once(' ')
                .map(|(a, b)| (a.to_string(), b.to_string()))
                .ok_or_else(|| Error::msg(format!("invalid merge entry `{m}`")))
        })
        .collect()
}

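/// The optional normalizer / pre-tokenizer / decoder / post-processor stages
/// selected for a given `tokenizer.ggml.pre` variant.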
struct Pipeline {
    normalizer: Option<NormalizerWrapper>,
    pretokenizer: Option<PreTokenizerWrapper>,
    decoder: Option<DecoderWrapper>,
    post_processor: Option<PostProcessorWrapper>,
}

impl Pipeline {
    fn apply(self, tokenizer: &mut Tokenizer) {
        if let Some(norm) = self.normalizer {
            tokenizer.with_normalizer(Some(norm));
        }
        if let Some(pt) = self.pretokenizer {
            tokenizer.with_pre_tokenizer(Some(pt));
        }
        if let Some(dec) = self.decoder {
            tokenizer.with_decoder(Some(dec));
        }
        if let Some(pp) = self.post_processor {
            tokenizer.with_post_processor(Some(pp));
        }
    }
}

fn pre_tokenizer_sequence(regex: &str, byte_level: ByteLevelPre) -> Result<PreTokenizerWrapper> {
    let split = Split::new(
        SplitPattern::Regex(regex.to_string()),
        SplitDelimiterBehavior::Isolated,
        false,
    )
    .map_err(Error::wrap)?;
    Ok(Sequence::new(vec![split.into(), byte_level.into()]).into())
}

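/// Maps the GGUF `tokenizer.ggml.pre` hint to a concrete pipeline. Unknown
/// values fall back to the default GPT-2 style byte-level BPE pipeline.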
fn pipeline_from_pre(pre: &str) -> Result<Pipeline> {
    const REGEX_QWEN2: &str = r"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";
    const REGEX_LLAMA3: &str = r"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";

    Ok(match pre {
        // Matches Qwen2 tokenizer.json settings
        "qwen2" => Pipeline {
            normalizer: Some(NFC.into()),
            pretokenizer: Some(pre_tokenizer_sequence(
                REGEX_QWEN2,
                ByteLevelPre::new(false, false, false),
            )?),
            decoder: Some(ByteLevelDecoder::new(false, false, false).into()),
            post_processor: Some(ByteLevelProcessor::new(false, false, false).into()),
        },
        // Matches Smaug/Llama3 style byte-level BPE
        "smaug-bpe" | "lfm2" | "llama3" => Pipeline {
            normalizer: None,
            pretokenizer: Some(pre_tokenizer_sequence(
                REGEX_LLAMA3,
                ByteLevelPre::new(false, true, false),
            )?),
            decoder: Some(ByteLevelDecoder::new(true, true, true).into()),
            post_processor: Some(ByteLevelProcessor::new(true, false, true).into()),
        },
        // Default GPT-2 style BPE
        _ => Pipeline {
            normalizer: None,
            pretokenizer: Some(ByteLevelPre::default().into()),
            decoder: Some(ByteLevelDecoder::default().into()),
            post_processor: Some(ByteLevelProcessor::default().into()),
        },
    })
}

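/// Builds a template post-processor that prepends the BOS token and/or appends
/// the EOS token, mirroring `add_bos_token` / `add_eos_token` from the GGUF
/// metadata. Returns `None` when nothing needs to be added or a required token
/// id is missing.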
fn template_processor(
    tokens: &[String],
    bos_id: Option<u32>,
    eos_id: Option<u32>,
    add_bos: bool,
    add_eos: bool,
) -> Option<PostProcessorWrapper> {
    if (!add_bos && !add_eos) || tokens.is_empty() {
        return None;
    }

    let bos = bos_id.and_then(|id| tokens.get(id as usize)).cloned();
    let eos = eos_id.and_then(|id| tokens.get(id as usize)).cloned();

    let mut specials = Vec::new();
    if add_bos {
        let bos_id = bos_id?;
        let bos_tok = bos.clone()?;
        specials.push((bos_tok.clone(), bos_id));
    }
    if add_eos {
        let eos_id = eos_id?;
        let eos_tok = eos.clone()?;
        specials.push((eos_tok.clone(), eos_id));
    }

    let mut single = Vec::new();
    if add_bos {
        single.push(bos.clone()?);
    }
    single.push("$0".to_string());
    if add_eos {
        single.push(eos.clone()?);
    }

    let mut pair = Vec::new();
    if add_bos {
        pair.push(format!("{}:0", bos.clone()?));
    }
    pair.push("$A:0".to_string());
    if add_eos {
        pair.push(format!("{}:0", eos.clone()?));
    }
    if add_bos {
        pair.push(format!("{}:1", bos.clone()?));
    }
    pair.push("$B:1".to_string());
    if add_eos {
        pair.push(format!("{}:1", eos.clone()?));
    }

    let proc = tokenizers::processors::template::TemplateProcessing::builder()
        .try_single(single)
        .ok()?
        .try_pair(pair)
        .ok()?
        .special_tokens(specials)
        .build()
        .ok()?;

    Some(PostProcessorWrapper::Template(proc))
}

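// Reconstructs a byte-level BPE (GPT-2 family) tokenizer from GGUF metadata:
// the vocab and merges feed the BPE model, `tokenizer.ggml.pre` selects the
// pre-tokenization pipeline, and BOS/EOS handling plus special tokens are
// restored from the remaining keys.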
impl TokenizerFromGguf for Tokenizer {
    fn from_gguf(ct: &gguf_file::Content) -> Result<Self> {
        let model_kind = metadata_value(ct, "tokenizer.ggml.model")?
            .to_string()?
            .to_lowercase();
        if model_kind != "gpt2" {
            crate::bail!("unsupported tokenizer model `{model_kind}`");
        }

        let tokens = value_to_string_array(
            metadata_value(ct, "tokenizer.ggml.tokens")?,
            "tokenizer.ggml.tokens",
        )?;
        let vocab: Vocab = tokens
            .iter()
            .enumerate()
            .map(|(i, t)| (t.clone(), i as u32))
            .collect();
        let merges = merges_from_value(metadata_value(ct, "tokenizer.ggml.merges")?)?;

        let mut builder = BPE::builder().vocab_and_merges(vocab, merges);

        if let Ok(val) = metadata_value(ct, "tokenizer.ggml.unk_token_id") {
            let token_id = gguf_value_to_u32(val)?;
            if let Some(token) = tokens.get(token_id as usize) {
                builder = builder.unk_token(token.clone());
            }
        }

        if let Ok(val) = metadata_value(ct, "tokenizer.ggml.byte_fallback") {
            builder = builder.byte_fallback(val.to_bool()?);
        }

        if let Ok(val) = metadata_value(ct, "tokenizer.ggml.ignore_merges") {
            builder = builder.ignore_merges(val.to_bool()?);
        }

        let bpe = builder.build().map_err(Error::wrap)?;
        let mut tokenizer = Tokenizer::new(bpe);

        let pre = metadata_value(ct, "tokenizer.ggml.pre")
            .and_then(|v| v.to_string())
            .map(|s| s.to_string())
            .unwrap_or_else(|_| "gpt2".to_string());
        let pipeline = pipeline_from_pre(pre.as_str())?;
        let post_processor_base = pipeline.post_processor.clone();

        let add_bos = metadata_value(ct, "tokenizer.ggml.add_bos_token")
            .and_then(|v| v.to_bool())
            .unwrap_or(false);
        let add_eos = metadata_value(ct, "tokenizer.ggml.add_eos_token")
            .and_then(|v| v.to_bool())
            .unwrap_or(false);
        let bos_id = metadata_value(ct, "tokenizer.ggml.bos_token_id")
            .and_then(gguf_value_to_u32)
            .ok();
        let eos_id = metadata_value(ct, "tokenizer.ggml.eos_token_id")
            .and_then(gguf_value_to_u32)
            .ok();

        pipeline.apply(&mut tokenizer);

        // Compose existing post-processor with a template-based one if needed
        let template_pp = template_processor(&tokens, bos_id, eos_id, add_bos, add_eos);
        if template_pp.is_some() || post_processor_base.is_some() {
            let mut steps = Vec::new();
            if let Some(pp) = post_processor_base {
                steps.push(pp);
            }
            if let Some(tp) = template_pp {
                steps.push(tp);
            }
            let pp = if steps.len() == 1 {
                steps.pop().unwrap()
            } else {
                ProcessorSequence::new(steps).into()
            };
            tokenizer.with_post_processor(Some(pp));
        }

        // Mark special tokens so decode(skip_special_tokens = true) behaves as expected
        if let Ok(gguf_file::Value::Array(arr)) = metadata_value(ct, "tokenizer.ggml.token_type") {
            let mut specials = Vec::new();
            for (idx, v) in arr.iter().enumerate() {
                let ty = gguf_value_to_u32(v)?;
                // Aligns with llama_token_type: treat non-normal/non-byte tokens as special.
                let is_special = matches!(ty, 2..=5);
                if is_special {
                    if let Some(tok) = tokens.get(idx) {
                        specials.push(AddedToken::from(tok.clone(), true));
                    }
                }
            }
            if !specials.is_empty() {
                tokenizer.add_special_tokens(&specials);
            }
        }

        let mut explicit_specials = HashSet::new();
        for key in [
            "tokenizer.ggml.bos_token_id",
            "tokenizer.ggml.eos_token_id",
            "tokenizer.ggml.pad_token_id",
            "tokenizer.ggml.sep_token_id",
            "tokenizer.ggml.unk_token_id",
        ] {
            if let Ok(val) = metadata_value(ct, key) {
                explicit_specials.insert(gguf_value_to_u32(val)?);
            }
        }
        if !explicit_specials.is_empty() {
            let specials: Vec<_> = explicit_specials
                .into_iter()
                .filter_map(|id| tokens.get(id as usize))
                .map(|tok| AddedToken::from(tok.clone(), true))
                .collect();
            if !specials.is_empty() {
                tokenizer.add_special_tokens(&specials);
            }
        }

        Ok(tokenizer)
    }
}
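
// Usage sketch (illustrative, not part of the original file): read a GGUF
// model from disk and build a `Tokenizer` from its metadata. The helper name
// and file-handling choices here are assumptions for illustration.
//
// fn tokenizer_from_gguf_file(path: &std::path::Path) -> Result<Tokenizer> {
//     let mut file = std::fs::File::open(path).map_err(Error::wrap)?;
//     let content = gguf_file::Content::read(&mut file)?;
//     Tokenizer::from_gguf(&content)
// }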