Skip to main content

llama_rs/tokenizer/
mod.rs

1//! Tokenizer implementations for text encoding/decoding
2//!
//! This module provides a tokenizer loaded from GGUF metadata, with optional
//! HuggingFace-style normalizer, pre-tokenizer and post-processor stages.
//! Supports BPE (Byte Pair Encoding), SentencePiece and WordPiece encoding.
5
6use std::collections::HashMap;
7
8use unicode_normalization::UnicodeNormalization;
9
10use crate::gguf::{GgufFile, MetadataValue};
11
/// Tokenizer algorithm family.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenizerType {
    /// Byte Pair Encoding
    BPE,
    /// SentencePiece (Unigram)
    SentencePiece,
    /// WordPiece
    WordPiece,
    /// Unknown type
    Unknown,
}

impl TokenizerType {
    /// Map a GGUF `tokenizer.ggml.model` string to a [`TokenizerType`].
    ///
    /// Matching is case-insensitive:
    /// - `"llama"`, `"bpe"` and `"gpt2"` map to [`TokenizerType::BPE`];
    /// - `"sentencepiece"` and `"spm"` map to [`TokenizerType::SentencePiece`];
    /// - `"wordpiece"` and `"bert"` map to [`TokenizerType::WordPiece`];
    /// - anything else maps to [`TokenizerType::Unknown`].
    ///
    /// NOTE(review): the GGUF spec describes `"llama"` as a SentencePiece-style
    /// vocabulary; it is deliberately(?) classified as BPE here — confirm intent.
    pub fn from_gguf_str(s: &str) -> Self {
        let lowered = s.to_lowercase();
        match lowered.as_str() {
            "llama" | "bpe" | "gpt2" => Self::BPE,
            "sentencepiece" | "spm" => Self::SentencePiece,
            "wordpiece" | "bert" => Self::WordPiece,
            _ => Self::Unknown,
        }
    }
}
37
/// Token type classification
///
/// Populated from the GGUF `tokenizer.ggml.token_type` array (codes 2, 3 and 6
/// map to `Unknown`, `Control` and `Byte`; every other code maps to `Normal`).
/// Non-empty `Control` tokens are matched literally in input text before the
/// regular encoder runs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TokenType {
    /// Normal token (also used for any unrecognized type code)
    #[default]
    Normal,
    /// Control token (special), e.g. chat-template markers such as `<|im_start|>`
    Control,
    /// Byte fallback token
    Byte,
    /// Unknown token
    Unknown,
}
51
52/// Text normalizer applied before tokenization (from HuggingFace tokenizer.json)
53#[derive(Debug, Clone)]
54pub enum Normalizer {
55    NFC,
56    NFKC,
57    NFD,
58    NFKD,
59    Lowercase,
60    Strip { left: bool, right: bool },
61    Prepend(String),
62    Replace { pattern: String, content: String },
63    StripAccents,
64    Sequence(Vec<Normalizer>),
65}
66
67impl Normalizer {
68    fn apply(&self, text: &str) -> String {
69        match self {
70            Self::NFC => text.nfc().collect(),
71            Self::NFKC => text.nfkc().collect(),
72            Self::NFD => text.nfd().collect(),
73            Self::NFKD => text.nfkd().collect(),
74            Self::Lowercase => text.to_lowercase(),
75            Self::Strip { left, right } => {
76                let s = if *left { text.trim_start() } else { text };
77                if *right { s.trim_end().to_string() } else { s.to_string() }
78            }
79            Self::Prepend(prefix) => format!("{}{}", prefix, text),
80            Self::Replace { pattern, content } => text.replace(pattern.as_str(), content.as_str()),
81            Self::StripAccents => {
82                text.nfkd()
83                    .filter(|c| !unicode_normalization::char::is_combining_mark(*c))
84                    .collect()
85            }
86            Self::Sequence(normalizers) => {
87                let mut result = text.to_string();
88                for n in normalizers {
89                    result = n.apply(&result);
90                }
91                result
92            }
93        }
94    }
95}
96
/// Pre-tokenizer splits text into segments before model encoding
///
/// Mirrors a subset of the HuggingFace `tokenizers` pre-tokenizers. No variant
/// emits empty segments.
#[derive(Debug, Clone)]
pub enum PreTokenizer {
    /// GPT-2 style byte-level splitting. Approximated here as: each space
    /// starts a new segment and stays attached to the text that follows it.
    ByteLevel { add_prefix_space: bool },
    /// Split on whitespace boundaries (the whitespace itself is dropped)
    Whitespace,
    /// SentencePiece metaspace handling: every space becomes `replacement` and
    /// starts a new segment (HF "merged with next" behavior)
    Metaspace { replacement: char, add_prefix_space: bool },
    /// Split around ASCII punctuation; each punctuation char is its own segment
    Punctuation,
    /// Separate ASCII digits from other text: a run of digits forms one
    /// segment, or each digit is its own segment when `individual_digits` is set
    Digits { individual_digits: bool },
    /// Chain multiple pre-tokenizers, each applied to every segment so far
    Sequence(Vec<PreTokenizer>),
}

impl PreTokenizer {
    /// Split `text` into pre-tokens. Empty pieces are never returned.
    fn apply(&self, text: &str) -> Vec<String> {
        match self {
            Self::ByteLevel { add_prefix_space } => {
                let text = if *add_prefix_space && !text.starts_with(' ') {
                    format!(" {}", text)
                } else {
                    text.to_string()
                };
                let mut tokens = Vec::new();
                let mut current = String::new();
                for ch in text.chars() {
                    // A space closes the running segment and opens the next one.
                    if ch == ' ' && !current.is_empty() {
                        tokens.push(std::mem::take(&mut current));
                    }
                    current.push(ch);
                }
                if !current.is_empty() {
                    tokens.push(current);
                }
                tokens
            }
            Self::Whitespace => text.split_whitespace().map(str::to_string).collect(),
            Self::Metaspace { replacement, add_prefix_space } => {
                let text = if *add_prefix_space && !text.starts_with(' ') {
                    format!(" {}", text)
                } else {
                    text.to_string()
                };
                // Every piece after a space carries exactly one replacement
                // marker. The piece before the first space carries none. If that
                // piece is empty (the text starts with a space), it is dropped
                // instead of becoming a stray standalone marker.
                text.split(' ')
                    .enumerate()
                    .map(|(i, s)| {
                        if i == 0 {
                            s.to_string()
                        } else {
                            format!("{replacement}{s}")
                        }
                    })
                    .filter(|s| !s.is_empty())
                    .collect()
            }
            Self::Punctuation => {
                let mut result = Vec::new();
                let mut current = String::new();
                for ch in text.chars() {
                    if ch.is_ascii_punctuation() {
                        if !current.is_empty() {
                            result.push(std::mem::take(&mut current));
                        }
                        result.push(ch.to_string());
                    } else {
                        current.push(ch);
                    }
                }
                if !current.is_empty() {
                    result.push(current);
                }
                result
            }
            Self::Digits { individual_digits } => {
                let mut result = Vec::new();
                let mut current = String::new();
                for ch in text.chars() {
                    let is_digit = ch.is_ascii_digit();
                    // Cut whenever we cross a digit/non-digit boundary, and
                    // before every digit when digits are isolated individually.
                    let cut = current.chars().next_back().is_some_and(|prev| {
                        (*individual_digits && is_digit) || prev.is_ascii_digit() != is_digit
                    });
                    if cut {
                        result.push(std::mem::take(&mut current));
                    }
                    current.push(ch);
                }
                if !current.is_empty() {
                    result.push(current);
                }
                result
            }
            Self::Sequence(pre_tokenizers) => {
                let mut segments = vec![text.to_string()];
                for pt in pre_tokenizers {
                    segments = segments.iter().flat_map(|seg| pt.apply(seg)).collect();
                }
                segments
            }
        }
    }
}
212
/// Element in a template processing sequence
///
/// `SpecialToken` inserts a fixed token id; `Sequence` marks where the encoded
/// input is spliced in (as done by `Tokenizer::encode` for the `single` template).
#[derive(Debug, Clone)]
pub enum TemplateElement {
    /// A literal special token. `id` is presumably the token's string form (HF
    /// template `id`) — TODO confirm; `token_id` is the id that gets emitted.
    SpecialToken { id: String, token_id: u32 },
    /// Placeholder for an encoded input sequence; `type_id` is not read by the
    /// visible encoding code.
    Sequence { type_id: u32 },
}
219
/// Post-processor adds special tokens after encoding
///
/// NOTE(review): in the visible code only `TemplateProcessing::single` is read
/// (by `Tokenizer::encode`, and only when `add_bos` is false); `pair` and
/// `ByteLevel` are not consulted there — confirm whether other code uses them.
#[derive(Debug, Clone)]
pub enum PostProcessor {
    /// Template processing with separate templates for single and pair inputs
    TemplateProcessing {
        single: Vec<TemplateElement>,
        pair: Vec<TemplateElement>,
    },
    /// Byte-level post-processor settings
    ByteLevel { trim_offsets: bool },
}
229
/// Special token IDs
#[derive(Debug, Clone)]
pub struct SpecialTokens {
    /// Beginning of sequence token
    pub bos_token_id: u32,
    /// End of sequence token
    pub eos_token_id: u32,
    /// Padding token (optional)
    pub pad_token_id: Option<u32>,
    /// Unknown token (optional)
    pub unk_token_id: Option<u32>,
}

/// Fallback ids: BOS = 1, EOS = 2, UNK = 0, and no padding token.
impl Default for SpecialTokens {
    fn default() -> Self {
        let unk_token_id = Some(0);
        let pad_token_id = None;
        Self {
            unk_token_id,
            pad_token_id,
            eos_token_id: 2,
            bos_token_id: 1,
        }
    }
}
253
/// Tokenizer error
///
/// Each variant carries a free-form message that is included in its `Display`
/// output.
#[derive(thiserror::Error, Debug)]
pub enum TokenizerError {
    /// Required GGUF tokenizer metadata is absent or has an unexpected shape
    /// (e.g. the token list is not an array of strings).
    #[error("Missing tokenizer data in GGUF: {0}")]
    MissingData(String),

    /// An invalid token was encountered.
    #[error("Invalid token: {0}")]
    InvalidToken(String),

    /// Encoding text into token ids failed.
    #[error("Encoding error: {0}")]
    EncodingError(String),
}

/// Result type used by fallible tokenizer operations.
pub type TokenizerResult<T> = Result<T, TokenizerError>;
268
/// Drain and return everything in `buf` that can be decoded as UTF-8 now.
///
/// - Valid UTF-8 is decoded and removed from `buf`.
/// - Each definitely-invalid byte sequence (as reported by
///   [`std::str::Utf8Error::error_len`]) is replaced by one U+FFFD and removed.
/// - An incomplete multi-byte sequence at the very end (at most 3 bytes) is
///   left in `buf`, so that bytes passed to a later call can complete it.
///
/// Because only a truncated trailing sequence is ever held back, invalid
/// leading bytes cannot delay the output of valid text that follows them.
fn flush_valid_utf8(buf: &mut Vec<u8>) -> String {
    let mut out = String::new();
    loop {
        let (valid_up_to, invalid_len) = match std::str::from_utf8(buf) {
            Ok(_) => (buf.len(), None),
            Err(e) => (e.valid_up_to(), e.error_len()),
        };

        // The prefix was just validated, so this conversion never substitutes.
        out.push_str(&String::from_utf8_lossy(&buf[..valid_up_to]));

        match invalid_len {
            // Either everything decoded, or the tail is a truncated sequence
            // that may still become valid: keep the tail for the next call.
            None => {
                buf.drain(..valid_up_to);
                return out;
            }
            // A definitely-invalid sequence: replace it and keep decoding.
            Some(bad_len) => {
                out.push('\u{FFFD}');
                buf.drain(..valid_up_to + bad_len);
            }
        }
    }
}
302
/// Build both directions of the GPT-2 byte ↔ unicode mapping.
///
/// GPT-2 byte-level BPE represents every byte 0-255 as a printable Unicode
/// character so that token strings are always valid Unicode:
/// - bytes 33..=126, 161..=172 and 174..=255 map to the char with the same code;
/// - the other 68 bytes map, in ascending byte order, to U+0100..=U+0143.
///
/// Returns `(unicode_to_byte, byte_to_unicode)`.
fn build_gpt2_mappings() -> (HashMap<char, u8>, [char; 256]) {
    let maps_to_itself = |b: u8| matches!(b, 33..=126 | 161..=172 | 174..=255);

    let mut byte_to_unicode = ['\0'; 256];
    let mut next_shifted: u32 = 256;
    for b in 0..=255u8 {
        byte_to_unicode[usize::from(b)] = if maps_to_itself(b) {
            char::from(b)
        } else {
            let shifted = char::from_u32(next_shifted).unwrap();
            next_shifted += 1;
            shifted
        };
    }

    let unicode_to_byte: HashMap<char, u8> = (0..=255u8)
        .map(|b| (byte_to_unicode[usize::from(b)], b))
        .collect();

    (unicode_to_byte, byte_to_unicode)
}
336
/// A segment of text that has been split around special/control tokens.
///
/// Produced by `Tokenizer::split_with_special_tokens`; `Tokenizer::encode`
/// encodes `Text` segments and copies `SpecialToken` ids through unchanged.
#[derive(Debug, Clone)]
enum TextSegment {
    /// Regular text to be encoded with BPE/SentencePiece (or the HF pipeline)
    Text(String),
    /// A control/special token that maps directly to a token ID
    SpecialToken(u32),
}
345
/// Tokenizer loaded from GGUF metadata or HuggingFace tokenizer.json
///
/// Holds the vocabulary in both directions, plus the data the `encode*`
/// methods dispatch on: scores, BPE merges, GPT-2 byte tables and optional
/// HF pipeline stages.
#[derive(Debug)]
pub struct Tokenizer {
    /// Token vocabulary (token string -> token id)
    token_to_id: HashMap<String, u32>,
    /// Reverse vocabulary (token id -> token string), indexed by id
    id_to_token: Vec<String>,
    /// Token scores (log probabilities for Unigram models)
    scores: Vec<f32>,
    /// Merge pairs for BPE with priority (lower = merge first)
    /// Maps (token1_id, token2_id) -> (merged_token_id, priority); the priority
    /// is the merge's index in the source merge list
    merges: HashMap<(u32, u32), (u32, usize)>,
    /// Special tokens
    pub special_tokens: SpecialTokens,
    /// Tokenizer type
    pub tokenizer_type: TokenizerType,
    /// Vocabulary size
    pub vocab_size: usize,
    /// Token types (for distinguishing normal, control, byte tokens)
    token_types: Vec<TokenType>,
    /// GPT-2 unicode-to-byte reverse mapping (only for GPT-2 tokenizers)
    gpt2_unicode_to_byte: Option<HashMap<char, u8>>,
    /// GPT-2 byte-to-unicode forward mapping for encoding (only for GPT-2 tokenizers)
    gpt2_byte_to_unicode: Option<[char; 256]>,
    /// HF normalizer pipeline component
    normalizer: Option<Normalizer>,
    /// HF pre-tokenizer pipeline component
    pre_tokenizer: Option<PreTokenizer>,
    /// HF post-processor pipeline component
    post_processor: Option<PostProcessor>,
    /// WordPiece continuation prefix (default "##")
    wordpiece_prefix: String,
    /// Non-empty control token strings sorted by byte length (longest first)
    /// for greedy matching
    control_token_strings: Vec<(String, u32)>,
    /// Whether the GGUF explicitly defined a BOS token ID
    pub has_explicit_bos: bool,
    /// Whether to prepend a space before encoding. It is read from
    /// `tokenizer.ggml.add_space_prefix` and defaults to true. It applies to the
    /// greedy SentencePiece path and the non-GPT-2 BPE path.
    pub add_space_prefix: bool,
}
385
386impl Tokenizer {
387    /// Load tokenizer from GGUF file
388    pub fn from_gguf(gguf: &GgufFile) -> TokenizerResult<Self> {
389        // Get tokenizer type
390        let model_str = gguf
391            .data
392            .get_string("tokenizer.ggml.model")
393            .unwrap_or("bpe");
394        let tokenizer_type = TokenizerType::from_gguf_str(model_str);
395
396        // GPT-2 style tokenizers use byte-level BPE with a unicode mapping
397        let uses_gpt2_bytes = model_str == "gpt2"
398            || gguf
399                .data
400                .get_string("tokenizer.ggml.pre")
401                .is_some_and(|p| {
402                    matches!(
403                        p,
404                        "qwen2" | "gpt-2" | "gpt2" | "starcoder" | "deepseek-llm" | "deepseek-coder"
405                    )
406                });
407
408        // Load vocabulary
409        let tokens = Self::load_tokens(gguf)?;
410        let vocab_size = tokens.len();
411
412        // Build token mappings
413        let mut token_to_id = HashMap::with_capacity(vocab_size);
414        let mut id_to_token = Vec::with_capacity(vocab_size);
415
416        for (id, token) in tokens.into_iter().enumerate() {
417            token_to_id.insert(token.clone(), id as u32);
418            id_to_token.push(token);
419        }
420
421        // Load scores if available
422        let scores = Self::load_scores(gguf, vocab_size);
423
424        // Load token types
425        let token_types = Self::load_token_types(gguf, vocab_size);
426
427        // Load merges for BPE
428        let merges = Self::load_merges(gguf, &token_to_id);
429
430        // Load special tokens
431        let special_tokens = Self::load_special_tokens(gguf);
432
433        let (gpt2_unicode_to_byte, gpt2_byte_to_unicode) = if uses_gpt2_bytes {
434            let (u2b, b2u) = build_gpt2_mappings();
435            (Some(u2b), Some(b2u))
436        } else {
437            (None, None)
438        };
439
440        let has_explicit_bos = gguf.data.get_u32("tokenizer.ggml.bos_token_id").is_some();
441        let add_space_prefix = gguf
442            .data
443            .get_bool("tokenizer.ggml.add_space_prefix")
444            .unwrap_or(true);
445
446        let mut control_token_strings: Vec<(String, u32)> = token_types
447            .iter()
448            .enumerate()
449            .filter(|(_, tt)| **tt == TokenType::Control)
450            .filter_map(|(id, _)| {
451                let s = &id_to_token[id];
452                if !s.is_empty() {
453                    Some((s.clone(), id as u32))
454                } else {
455                    None
456                }
457            })
458            .collect();
459        control_token_strings.sort_by(|a, b| b.0.len().cmp(&a.0.len()));
460
461        Ok(Self {
462            token_to_id,
463            id_to_token,
464            scores,
465            merges,
466            special_tokens,
467            tokenizer_type,
468            vocab_size,
469            token_types,
470            gpt2_unicode_to_byte,
471            gpt2_byte_to_unicode,
472            normalizer: None,
473            pre_tokenizer: None,
474            post_processor: None,
475            wordpiece_prefix: "##".to_string(),
476            control_token_strings,
477            has_explicit_bos,
478            add_space_prefix,
479        })
480    }
481
482    /// Load tokens from GGUF
483    fn load_tokens(gguf: &GgufFile) -> TokenizerResult<Vec<String>> {
484        let tokens_value = gguf
485            .data
486            .metadata
487            .get("tokenizer.ggml.tokens")
488            .ok_or_else(|| TokenizerError::MissingData("tokenizer.ggml.tokens".into()))?;
489
490        match tokens_value {
491            MetadataValue::Array(arr) => {
492                let mut tokens = Vec::with_capacity(arr.values.len());
493                for value in &arr.values {
494                    match value {
495                        MetadataValue::String(s) => tokens.push(s.clone()),
496                        _ => {
497                            return Err(TokenizerError::MissingData(
498                                "Expected string tokens".into(),
499                            ));
500                        }
501                    }
502                }
503                Ok(tokens)
504            }
505            _ => Err(TokenizerError::MissingData("Expected token array".into())),
506        }
507    }
508
509    /// Load token scores from GGUF
510    fn load_scores(gguf: &GgufFile, vocab_size: usize) -> Vec<f32> {
511        let scores_value = gguf.data.metadata.get("tokenizer.ggml.scores");
512
513        match scores_value {
514            Some(MetadataValue::Array(arr)) => {
515                let mut scores = Vec::with_capacity(arr.values.len());
516                for value in &arr.values {
517                    match value {
518                        MetadataValue::Float32(f) => scores.push(*f),
519                        _ => scores.push(0.0),
520                    }
521                }
522                scores
523            }
524            _ => vec![0.0; vocab_size],
525        }
526    }
527
528    /// Load token types from GGUF
529    fn load_token_types(gguf: &GgufFile, vocab_size: usize) -> Vec<TokenType> {
530        let types_value = gguf.data.metadata.get("tokenizer.ggml.token_type");
531
532        match types_value {
533            Some(MetadataValue::Array(arr)) => {
534                let mut types = Vec::with_capacity(arr.values.len());
535                for value in &arr.values {
536                    let token_type = match value {
537                        MetadataValue::Int32(t) => match *t {
538                            1 => TokenType::Normal,
539                            2 => TokenType::Unknown,
540                            3 => TokenType::Control,
541                            6 => TokenType::Byte,
542                            _ => TokenType::Normal,
543                        },
544                        _ => TokenType::Normal,
545                    };
546                    types.push(token_type);
547                }
548                types
549            }
550            _ => vec![TokenType::Normal; vocab_size],
551        }
552    }
553
554    /// Load BPE merges from GGUF with priority ordering
555    fn load_merges(
556        gguf: &GgufFile,
557        token_to_id: &HashMap<String, u32>,
558    ) -> HashMap<(u32, u32), (u32, usize)> {
559        let mut merges = HashMap::new();
560
561        let merges_value = gguf.data.metadata.get("tokenizer.ggml.merges");
562
563        if let Some(MetadataValue::Array(arr)) = merges_value {
564            for (priority, value) in arr.values.iter().enumerate() {
565                if let MetadataValue::String(merge_str) = value {
566                    // Parse merge: "token1 token2"
567                    let parts: Vec<&str> = merge_str.split(' ').collect();
568                    if parts.len() == 2
569                        && let (Some(&id1), Some(&id2)) =
570                            (token_to_id.get(parts[0]), token_to_id.get(parts[1]))
571                    {
572                        // The merged result is typically the concatenation
573                        let merged = format!("{}{}", parts[0], parts[1]);
574                        if let Some(&merged_id) = token_to_id.get(&merged) {
575                            merges.insert((id1, id2), (merged_id, priority));
576                        }
577                    }
578                }
579            }
580        }
581
582        merges
583    }
584
585    /// Load special tokens from GGUF
586    fn load_special_tokens(gguf: &GgufFile) -> SpecialTokens {
587        SpecialTokens {
588            bos_token_id: gguf
589                .data
590                .get_u32("tokenizer.ggml.bos_token_id")
591                .unwrap_or(1),
592            eos_token_id: gguf
593                .data
594                .get_u32("tokenizer.ggml.eos_token_id")
595                .unwrap_or(2),
596            pad_token_id: gguf.data.get_u32("tokenizer.ggml.padding_token_id"),
597            unk_token_id: gguf.data.get_u32("tokenizer.ggml.unknown_token_id"),
598        }
599    }
600
601    /// Split text around control/special token strings.
602    ///
603    /// Scans `text` for any control token literal (e.g. `<|im_start|>`) and
604    /// splits it into alternating Text / SpecialToken segments. Uses greedy
605    /// longest-match so longer control tokens take priority.
606    fn split_with_special_tokens(&self, text: &str) -> Vec<TextSegment> {
607        if self.control_token_strings.is_empty() {
608            return vec![TextSegment::Text(text.to_string())];
609        }
610
611        let mut segments = Vec::new();
612        let mut remaining = text;
613
614        while !remaining.is_empty() {
615            let mut earliest_pos = remaining.len();
616            let mut matched_len = 0;
617            let mut matched_id = 0u32;
618
619            for (tok_str, tok_id) in &self.control_token_strings {
620                if let Some(pos) = remaining.find(tok_str.as_str()) {
621                    if pos < earliest_pos
622                        || (pos == earliest_pos && tok_str.len() > matched_len)
623                    {
624                        earliest_pos = pos;
625                        matched_len = tok_str.len();
626                        matched_id = *tok_id;
627                    }
628                }
629            }
630
631            if matched_len == 0 {
632                segments.push(TextSegment::Text(remaining.to_string()));
633                break;
634            }
635
636            if earliest_pos > 0 {
637                segments.push(TextSegment::Text(remaining[..earliest_pos].to_string()));
638            }
639            segments.push(TextSegment::SpecialToken(matched_id));
640            remaining = &remaining[earliest_pos + matched_len..];
641        }
642
643        segments
644    }
645
646    /// Encode a plain text segment (no special tokens) using the appropriate algorithm.
647    fn encode_text_segment(&self, text: &str) -> TokenizerResult<Vec<u32>> {
648        if text.is_empty() {
649            return Ok(vec![]);
650        }
651        if self.normalizer.is_some() || self.pre_tokenizer.is_some() {
652            let normalized = match &self.normalizer {
653                Some(n) => n.apply(text),
654                None => text.to_string(),
655            };
656            let pre_tokens = match &self.pre_tokenizer {
657                Some(pt) => pt.apply(&normalized),
658                None => vec![normalized],
659            };
660            let mut tokens = Vec::new();
661            for pre_token in &pre_tokens {
662                if pre_token.is_empty() {
663                    continue;
664                }
665                match self.tokenizer_type {
666                    TokenizerType::SentencePiece => {
667                        tokens.extend(self.encode_unigram(pre_token)?);
668                    }
669                    TokenizerType::WordPiece => {
670                        tokens.extend(self.encode_wordpiece(pre_token)?);
671                    }
672                    _ => {
673                        tokens.extend(self.encode_bpe_pretokenized(pre_token)?);
674                    }
675                }
676            }
677            Ok(tokens)
678        } else if !self.merges.is_empty() {
679            self.encode_bpe(text)
680        } else {
681            self.encode_sentencepiece(text)
682        }
683    }
684
685    /// Encode text to token IDs
686    pub fn encode(&self, text: &str, add_bos: bool) -> TokenizerResult<Vec<u32>> {
687        let mut tokens = Vec::new();
688
689        if add_bos {
690            tokens.push(self.special_tokens.bos_token_id);
691        }
692
693        let segments = self.split_with_special_tokens(text);
694        for segment in segments {
695            match segment {
696                TextSegment::Text(t) => {
697                    tokens.extend(self.encode_text_segment(&t)?);
698                }
699                TextSegment::SpecialToken(id) => {
700                    tokens.push(id);
701                }
702            }
703        }
704
705        if !add_bos {
706            if let Some(PostProcessor::TemplateProcessing { ref single, .. }) = self.post_processor {
707                let mut processed = Vec::new();
708                for elem in single {
709                    match elem {
710                        TemplateElement::SpecialToken { token_id, .. } => {
711                            processed.push(*token_id);
712                        }
713                        TemplateElement::Sequence { .. } => {
714                            processed.extend(&tokens);
715                        }
716                    }
717                }
718                return Ok(processed);
719            }
720        }
721
722        Ok(tokens)
723    }
724
725    /// SentencePiece encoding using greedy longest-match algorithm
726    fn encode_sentencepiece(&self, text: &str) -> TokenizerResult<Vec<u32>> {
727        let mut result = Vec::new();
728
729        // Add space prefix for LLaMA-style tokenizers (when configured)
730        let text_with_prefix = if self.add_space_prefix {
731            format!(" {}", text)
732        } else {
733            text.to_string()
734        };
735        let chars: Vec<char> = text_with_prefix.chars().collect();
736        let mut pos = 0;
737
738        while pos < chars.len() {
739            let mut best_len = 0;
740            let mut best_id = None;
741
742            // Try to find the longest matching token starting at current position
743            // Try lengths from longest to shortest for efficiency
744            for end in (pos + 1..=chars.len()).rev() {
745                let substr: String = chars[pos..end].iter().collect();
746
747                // Try with SentencePiece space marker
748                let spm_str = substr.replace(' ', "▁");
749                if let Some(&id) = self.token_to_id.get(&spm_str) {
750                    best_len = end - pos;
751                    best_id = Some(id);
752                    break; // Found longest match
753                }
754
755                // Try original string
756                if let Some(&id) = self.token_to_id.get(&substr) {
757                    best_len = end - pos;
758                    best_id = Some(id);
759                    break; // Found longest match
760                }
761            }
762
763            if let Some(id) = best_id {
764                result.push(id);
765                pos += best_len;
766            } else {
767                // Fallback: try single character with byte fallback
768                let ch = chars[pos];
769                let ch_str = ch.to_string();
770
771                // Try as SentencePiece space
772                if ch == ' '
773                    && let Some(&id) = self.token_to_id.get("▁")
774                {
775                    result.push(id);
776                    pos += 1;
777                    continue;
778                }
779
780                // Try as regular character
781                if let Some(&id) = self.token_to_id.get(&ch_str) {
782                    result.push(id);
783                    pos += 1;
784                    continue;
785                }
786
787                // Byte-level fallback
788                for byte in ch_str.as_bytes() {
789                    let byte_token = format!("<0x{:02X}>", byte);
790                    if let Some(&id) = self.token_to_id.get(&byte_token) {
791                        result.push(id);
792                    } else if let Some(unk_id) = self.special_tokens.unk_token_id {
793                        result.push(unk_id);
794                    }
795                }
796                pos += 1;
797            }
798        }
799
800        Ok(result)
801    }
802
803    /// BPE encoding algorithm
804    fn encode_bpe(&self, text: &str) -> TokenizerResult<Vec<u32>> {
805        if self.gpt2_byte_to_unicode.is_some() {
806            return self.encode_bpe_gpt2(text);
807        }
808
809        let mut result = Vec::new();
810
811        let text_with_prefix = if self.add_space_prefix && !text.starts_with(' ') && !text.is_empty() {
812            format!(" {}", text)
813        } else {
814            text.to_string()
815        };
816
817        for segment in self.split_into_segments(&text_with_prefix) {
818            if segment.is_empty() {
819                continue;
820            }
821
822            if let Some(&id) = self.token_to_id.get(&segment) {
823                result.push(id);
824                continue;
825            }
826
827            let mut tokens = self.text_to_initial_tokens(&segment)?;
828            self.apply_bpe_merges(&mut tokens);
829            result.extend(tokens);
830        }
831
832        Ok(result)
833    }
834
835    /// GPT-2 byte-level BPE encoding.
836    ///
837    /// Converts input bytes through the GPT-2 byte→unicode mapping, splits
838    /// into pretokenized segments, and applies BPE merges on each.
839    fn encode_bpe_gpt2(&self, text: &str) -> TokenizerResult<Vec<u32>> {
840        let b2u = self.gpt2_byte_to_unicode.as_ref().unwrap();
841        let mut result = Vec::new();
842
843        for segment in Self::gpt2_pretokenize(text) {
844            if segment.is_empty() {
845                continue;
846            }
847
848            let mapped: String = segment.as_bytes().iter().map(|&b| b2u[b as usize]).collect();
849
850            if let Some(&id) = self.token_to_id.get(&mapped) {
851                result.push(id);
852                continue;
853            }
854
855            let mut tokens: Vec<u32> = Vec::with_capacity(mapped.len());
856            for ch in mapped.chars() {
857                let ch_str = ch.to_string();
858                if let Some(&id) = self.token_to_id.get(&ch_str) {
859                    tokens.push(id);
860                } else if let Some(unk_id) = self.special_tokens.unk_token_id {
861                    tokens.push(unk_id);
862                }
863            }
864
865            self.apply_bpe_merges(&mut tokens);
866            result.extend(tokens);
867        }
868
869        Ok(result)
870    }
871
872    /// Simple GPT-2 pretokenization: split text into chunks at word boundaries.
873    ///
874    /// Spaces attach to the following word. Newlines and other control
875    /// characters are their own chunks. Runs of letters, runs of digits
876    /// (up to 3), and individual punctuation are separate chunks.
877    fn gpt2_pretokenize(text: &str) -> Vec<String> {
878        let mut chunks = Vec::new();
879        let chars: Vec<char> = text.chars().collect();
880        let mut i = 0;
881
882        while i < chars.len() {
883            let ch = chars[i];
884
885            if ch == ' ' {
886                let mut chunk = String::new();
887                chunk.push(ch);
888                i += 1;
889                if i < chars.len() && (chars[i].is_alphanumeric() || chars[i] == '_') {
890                    while i < chars.len()
891                        && !chars[i].is_whitespace()
892                        && (chars[i].is_alphanumeric() || chars[i] == '_')
893                    {
894                        chunk.push(chars[i]);
895                        i += 1;
896                    }
897                }
898                chunks.push(chunk);
899            } else if ch == '\n' || ch == '\r' || ch == '\t' {
900                let mut chunk = String::new();
901                while i < chars.len()
902                    && (chars[i] == '\n' || chars[i] == '\r' || chars[i] == '\t')
903                {
904                    chunk.push(chars[i]);
905                    i += 1;
906                }
907                chunks.push(chunk);
908            } else if ch.is_alphabetic() || ch == '_' {
909                let mut chunk = String::new();
910                while i < chars.len() && (chars[i].is_alphabetic() || chars[i] == '_') {
911                    chunk.push(chars[i]);
912                    i += 1;
913                }
914                chunks.push(chunk);
915            } else if ch.is_ascii_digit() {
916                let mut chunk = String::new();
917                let mut count = 0;
918                while i < chars.len() && chars[i].is_ascii_digit() && count < 3 {
919                    chunk.push(chars[i]);
920                    i += 1;
921                    count += 1;
922                }
923                chunks.push(chunk);
924            } else {
925                chunks.push(ch.to_string());
926                i += 1;
927            }
928        }
929
930        chunks
931    }
932
933    /// Apply BPE merges iteratively until no more merges are possible.
934    fn apply_bpe_merges(&self, tokens: &mut Vec<u32>) {
935        loop {
936            if tokens.len() < 2 {
937                break;
938            }
939
940            let mut best_merge: Option<(usize, u32, usize)> = None;
941
942            for i in 0..tokens.len() - 1 {
943                let pair = (tokens[i], tokens[i + 1]);
944                if let Some(&(merged_id, priority)) = self.merges.get(&pair)
945                    && (best_merge.is_none() || priority < best_merge.unwrap().2)
946                {
947                    best_merge = Some((i, merged_id, priority));
948                }
949            }
950
951            match best_merge {
952                Some((pos, merged_id, _)) => {
953                    tokens[pos] = merged_id;
954                    tokens.remove(pos + 1);
955                }
956                None => break,
957            }
958        }
959    }
960
961    /// Split text into segments for non-GPT-2 BPE processing
962    /// Split text into segments for SentencePiece-style BPE.
963    ///
964    /// Spaces are prepended to the FOLLOWING word (not appended to the
965    /// preceding one), matching the SentencePiece convention where "▁"
966    /// marks the start of a new word.
967    fn split_into_segments(&self, text: &str) -> Vec<String> {
968        let mut segments = Vec::new();
969        let mut current = String::new();
970
971        for ch in text.chars() {
972            // When we hit a space/whitespace: if there's accumulated text,
973            // push it as a segment. Then start a new segment with the space.
974            if ch.is_whitespace() {
975                if !current.is_empty() {
976                    segments.push(current.clone());
977                    current.clear();
978                }
979                current.push(ch);
980            } else if ch.is_ascii_punctuation() {
981                // Punctuation: flush current, push punctuation as its own segment
982                if !current.is_empty() {
983                    segments.push(current.clone());
984                    current.clear();
985                }
986                segments.push(ch.to_string());
987            } else {
988                current.push(ch);
989            }
990        }
991
992        if !current.is_empty() {
993            segments.push(current);
994        }
995
996        segments
997    }
998
999    /// Convert text segment to initial token sequence (non-GPT-2 path)
1000    fn text_to_initial_tokens(&self, text: &str) -> TokenizerResult<Vec<u32>> {
1001        let mut tokens = Vec::new();
1002
1003        for ch in text.chars() {
1004            let ch_str = ch.to_string();
1005
1006            if let Some(&id) = self.token_to_id.get(&ch_str) {
1007                tokens.push(id);
1008                continue;
1009            }
1010
1011            if ch == ' '
1012                && let Some(&id) = self.token_to_id.get("▁")
1013            {
1014                tokens.push(id);
1015                continue;
1016            }
1017
1018            for byte in ch_str.as_bytes() {
1019                let byte_token = format!("<0x{:02X}>", byte);
1020                if let Some(&id) = self.token_to_id.get(&byte_token) {
1021                    tokens.push(id);
1022                } else if let Some(unk_id) = self.special_tokens.unk_token_id {
1023                    tokens.push(unk_id);
1024                }
1025            }
1026        }
1027
1028        Ok(tokens)
1029    }
1030
1031    /// Fallback encoding (character/byte level) - kept for potential future use
1032    #[allow(dead_code)]
1033    fn encode_fallback(&self, text: &str) -> TokenizerResult<Vec<u32>> {
1034        let mut tokens = Vec::new();
1035
1036        for ch in text.chars() {
1037            let ch_str = ch.to_string();
1038            if let Some(&id) = self.token_to_id.get(&ch_str) {
1039                tokens.push(id);
1040            } else {
1041                // Try byte fallback
1042                for byte in ch_str.as_bytes() {
1043                    let byte_token = format!("<0x{:02X}>", byte);
1044                    if let Some(&id) = self.token_to_id.get(&byte_token) {
1045                        tokens.push(id);
1046                    } else if let Some(unk_id) = self.special_tokens.unk_token_id {
1047                        tokens.push(unk_id);
1048                    }
1049                }
1050            }
1051        }
1052
1053        Ok(tokens)
1054    }
1055
1056    /// Unigram encoding using Viterbi algorithm (for HuggingFace Unigram models)
1057    fn encode_unigram(&self, text: &str) -> TokenizerResult<Vec<u32>> {
1058        if text.is_empty() {
1059            return Ok(vec![]);
1060        }
1061
1062        let char_boundaries: Vec<usize> = text
1063            .char_indices()
1064            .map(|(i, _)| i)
1065            .chain(std::iter::once(text.len()))
1066            .collect();
1067        let n = char_boundaries.len() - 1;
1068
1069        const NEG_INF: f64 = -1e18;
1070        let mut best_score = vec![NEG_INF; n + 1];
1071        let mut best_path: Vec<Option<(u32, usize)>> = vec![None; n + 1];
1072        best_score[0] = 0.0;
1073
1074        let max_token_chars = 128;
1075
1076        for end in 1..=n {
1077            let end_byte = char_boundaries[end];
1078            let min_start = end.saturating_sub(max_token_chars);
1079
1080            for start in (min_start..end).rev() {
1081                if best_score[start] <= NEG_INF {
1082                    continue;
1083                }
1084                let start_byte = char_boundaries[start];
1085                let substr = &text[start_byte..end_byte];
1086
1087                if let Some(&id) = self.token_to_id.get(substr) {
1088                    let score = *self.scores.get(id as usize).unwrap_or(&0.0) as f64;
1089                    let candidate = best_score[start] + score;
1090                    if candidate > best_score[end] {
1091                        best_score[end] = candidate;
1092                        best_path[end] = Some((id, start));
1093                    }
1094                }
1095            }
1096
1097            // Single-char byte fallback if no token found
1098            if best_path[end].is_none() && best_score[end - 1] > NEG_INF {
1099                let start_byte = char_boundaries[end - 1];
1100                let end_byte_val = char_boundaries[end];
1101                let ch_str = &text[start_byte..end_byte_val];
1102
1103                if let Some(&id) = self.token_to_id.get(ch_str) {
1104                    let score = *self.scores.get(id as usize).unwrap_or(&-10.0) as f64;
1105                    best_score[end] = best_score[end - 1] + score;
1106                    best_path[end] = Some((id, end - 1));
1107                } else {
1108                    // Try byte-level fallback tokens
1109                    for byte in ch_str.as_bytes() {
1110                        let byte_token = format!("<0x{:02X}>", byte);
1111                        if let Some(&id) = self.token_to_id.get(&byte_token) {
1112                            let score = *self.scores.get(id as usize).unwrap_or(&-10.0) as f64;
1113                            let candidate = best_score[end - 1] + score;
1114                            if candidate > best_score[end] {
1115                                best_score[end] = candidate;
1116                                best_path[end] = Some((id, end - 1));
1117                            }
1118                        }
1119                    }
1120                }
1121            }
1122        }
1123
1124        if best_score[n] <= NEG_INF {
1125            return self.encode_unigram_fallback(text);
1126        }
1127
1128        let mut result = Vec::new();
1129        let mut pos = n;
1130        while pos > 0 {
1131            if let Some((token_id, start)) = best_path[pos] {
1132                result.push(token_id);
1133                pos = start;
1134            } else {
1135                break;
1136            }
1137        }
1138        result.reverse();
1139        Ok(result)
1140    }
1141
1142    /// Fallback for Unigram when Viterbi cannot find a complete path
1143    fn encode_unigram_fallback(&self, text: &str) -> TokenizerResult<Vec<u32>> {
1144        let mut result = Vec::new();
1145        for ch in text.chars() {
1146            let ch_str = ch.to_string();
1147            if let Some(&id) = self.token_to_id.get(&ch_str) {
1148                result.push(id);
1149            } else {
1150                for byte in ch_str.as_bytes() {
1151                    let byte_token = format!("<0x{:02X}>", byte);
1152                    if let Some(&id) = self.token_to_id.get(&byte_token) {
1153                        result.push(id);
1154                    } else if let Some(unk_id) = self.special_tokens.unk_token_id {
1155                        result.push(unk_id);
1156                    }
1157                }
1158            }
1159        }
1160        Ok(result)
1161    }
1162
1163    /// WordPiece encoding using greedy longest-match with continuation prefix
1164    fn encode_wordpiece(&self, text: &str) -> TokenizerResult<Vec<u32>> {
1165        if text.is_empty() {
1166            return Ok(vec![]);
1167        }
1168
1169        let mut result = Vec::new();
1170        let chars: Vec<char> = text.chars().collect();
1171
1172        // WordPiece operates on whitespace-split words
1173        let words: Vec<String> = text.split_whitespace().map(|s| s.to_string()).collect();
1174        let words = if words.is_empty() {
1175            vec![text.to_string()]
1176        } else {
1177            words
1178        };
1179
1180        for word in &words {
1181            let word_chars: Vec<char> = word.chars().collect();
1182            if word_chars.len() > 200 {
1183                if let Some(unk_id) = self.special_tokens.unk_token_id {
1184                    result.push(unk_id);
1185                }
1186                continue;
1187            }
1188
1189            let mut start = 0;
1190            let mut is_first_subword = true;
1191
1192            while start < word_chars.len() {
1193                let mut end = word_chars.len();
1194                let mut found = false;
1195
1196                while start < end {
1197                    let substr: String = word_chars[start..end].iter().collect();
1198                    let candidate = if is_first_subword {
1199                        substr.clone()
1200                    } else {
1201                        format!("{}{}", self.wordpiece_prefix, substr)
1202                    };
1203
1204                    if let Some(&id) = self.token_to_id.get(&candidate) {
1205                        result.push(id);
1206                        found = true;
1207                        break;
1208                    }
1209                    end -= 1;
1210                }
1211
1212                if !found {
1213                    if let Some(unk_id) = self.special_tokens.unk_token_id {
1214                        result.push(unk_id);
1215                    }
1216                    break;
1217                }
1218
1219                start = end;
1220                is_first_subword = false;
1221            }
1222        }
1223
1224        let _ = chars; // suppress unused warning from pre-existing binding
1225        Ok(result)
1226    }
1227
1228    /// BPE encoding for a pre-tokenized segment (no further splitting)
1229    fn encode_bpe_pretokenized(&self, text: &str) -> TokenizerResult<Vec<u32>> {
1230        if let Some(&id) = self.token_to_id.get(text) {
1231            return Ok(vec![id]);
1232        }
1233
1234        let mut tokens = self.text_to_initial_tokens(text)?;
1235        self.apply_bpe_merges(&mut tokens);
1236        Ok(tokens)
1237    }
1238
1239    /// Decode token IDs to text
1240    pub fn decode(&self, tokens: &[u32]) -> TokenizerResult<String> {
1241        if let Some(ref u2b) = self.gpt2_unicode_to_byte {
1242            return self.decode_gpt2(tokens, u2b);
1243        }
1244        self.decode_sentencepiece(tokens)
1245    }
1246
1247    /// Decode for GPT-2 byte-level BPE tokenizers (Qwen, StarCoder, DeepSeek, etc.)
1248    ///
1249    /// Each character in a GPT-2 token string represents one byte via the byte_to_unicode
1250    /// mapping. Decoding reverses the mapping to recover the raw byte sequence, then
1251    /// interprets those bytes as UTF-8.
1252    fn decode_gpt2(
1253        &self,
1254        tokens: &[u32],
1255        unicode_to_byte: &HashMap<char, u8>,
1256    ) -> TokenizerResult<String> {
1257        let mut raw_bytes: Vec<u8> = Vec::new();
1258
1259        for &token_id in tokens {
1260            if self.is_special_token(token_id) {
1261                continue;
1262            }
1263
1264            let token_str = self.id_to_token.get(token_id as usize).ok_or_else(|| {
1265                TokenizerError::InvalidToken(format!("Unknown token ID: {}", token_id))
1266            })?;
1267
1268            // Skip control tokens that render as literal text (e.g. <|im_end|>)
1269            if self.get_token_type(token_id) == TokenType::Control {
1270                continue;
1271            }
1272
1273            // Handle <0x??> byte fallback tokens
1274            if token_str.starts_with("<0x")
1275                && token_str.ends_with('>')
1276                && token_str.len() == 6
1277                && let Ok(byte) = u8::from_str_radix(&token_str[3..5], 16)
1278            {
1279                raw_bytes.push(byte);
1280                continue;
1281            }
1282
1283            // Map each character through the GPT-2 unicode→byte table
1284            for ch in token_str.chars() {
1285                if let Some(&b) = unicode_to_byte.get(&ch) {
1286                    raw_bytes.push(b);
1287                } else {
1288                    // Character not in the GPT-2 table — encode its UTF-8 bytes directly
1289                    let mut buf = [0u8; 4];
1290                    let encoded = ch.encode_utf8(&mut buf);
1291                    raw_bytes.extend_from_slice(encoded.as_bytes());
1292                }
1293            }
1294        }
1295
1296        Ok(String::from_utf8_lossy(&raw_bytes).into_owned())
1297    }
1298
1299    /// Decode for SentencePiece-style tokenizers (LLaMA, etc.)
1300    fn decode_sentencepiece(&self, tokens: &[u32]) -> TokenizerResult<String> {
1301        let mut text = String::new();
1302        let mut byte_buffer: Vec<u8> = Vec::new();
1303
1304        for &token_id in tokens {
1305            if self.is_special_token(token_id) {
1306                continue;
1307            }
1308
1309            if self.get_token_type(token_id) == TokenType::Control {
1310                continue;
1311            }
1312
1313            let token_str = self.id_to_token.get(token_id as usize).ok_or_else(|| {
1314                TokenizerError::InvalidToken(format!("Unknown token ID: {}", token_id))
1315            })?;
1316
1317            // Handle byte tokens — collect into buffer for proper UTF-8 decoding
1318            if token_str.starts_with("<0x")
1319                && token_str.ends_with('>')
1320                && token_str.len() == 6
1321                && let Ok(byte) = u8::from_str_radix(&token_str[3..5], 16)
1322            {
1323                byte_buffer.push(byte);
1324                continue;
1325            }
1326
1327            // Flush byte buffer before adding text
1328            if !byte_buffer.is_empty() {
1329                text.push_str(&String::from_utf8_lossy(&byte_buffer));
1330                byte_buffer.clear();
1331            }
1332
1333            // SentencePiece uses ▁ for leading spaces
1334            text.push_str(&token_str.replace('▁', " "));
1335        }
1336
1337        // Flush remaining bytes
1338        if !byte_buffer.is_empty() {
1339            text.push_str(&String::from_utf8_lossy(&byte_buffer));
1340        }
1341
1342        Ok(text)
1343    }
1344
1345    /// Decode a single token to string
1346    pub fn decode_token(&self, token_id: u32) -> TokenizerResult<String> {
1347        self.decode(&[token_id])
1348    }
1349
1350    /// Decode a single token in streaming mode, handling incomplete UTF-8 sequences.
1351    ///
1352    /// For GPT-2 byte-level tokenizers, multi-byte UTF-8 characters (like emoji)
1353    /// may be split across multiple tokens. This method accumulates raw bytes in
1354    /// `pending` and only returns text once complete UTF-8 code points are formed.
1355    pub fn decode_token_streaming(
1356        &self,
1357        token_id: u32,
1358        pending: &mut Vec<u8>,
1359    ) -> TokenizerResult<String> {
1360        if self.is_special_token(token_id) || self.get_token_type(token_id) == TokenType::Control {
1361            // Flush any pending bytes before emitting nothing
1362            let flushed = flush_valid_utf8(pending);
1363            return Ok(flushed);
1364        }
1365
1366        let token_str = self.id_to_token.get(token_id as usize).ok_or_else(|| {
1367            TokenizerError::InvalidToken(format!("Unknown token ID: {}", token_id))
1368        })?;
1369
1370        // Handle <0x??> byte fallback tokens
1371        if token_str.starts_with("<0x")
1372            && token_str.ends_with('>')
1373            && token_str.len() == 6
1374            && let Ok(byte) = u8::from_str_radix(&token_str[3..5], 16)
1375        {
1376            pending.push(byte);
1377            return Ok(flush_valid_utf8(pending));
1378        }
1379
1380        if let Some(ref u2b) = self.gpt2_unicode_to_byte {
1381            // GPT-2: each char maps to a byte
1382            for ch in token_str.chars() {
1383                if let Some(&b) = u2b.get(&ch) {
1384                    pending.push(b);
1385                } else {
1386                    let mut buf = [0u8; 4];
1387                    let encoded = ch.encode_utf8(&mut buf);
1388                    pending.extend_from_slice(encoded.as_bytes());
1389                }
1390            }
1391            Ok(flush_valid_utf8(pending))
1392        } else {
1393            // SentencePiece: flush pending, then return token text
1394            let mut result = flush_valid_utf8(pending);
1395            result.push_str(&token_str.replace('▁', " "));
1396            Ok(result)
1397        }
1398    }
1399
1400    /// Get token string by ID
1401    pub fn get_token(&self, id: u32) -> Option<&str> {
1402        self.id_to_token.get(id as usize).map(|s| s.as_str())
1403    }
1404
1405    /// Get token ID by string
1406    pub fn get_token_id(&self, token: &str) -> Option<u32> {
1407        self.token_to_id.get(token).copied()
1408    }
1409
1410    /// Get token type
1411    pub fn get_token_type(&self, id: u32) -> TokenType {
1412        self.token_types
1413            .get(id as usize)
1414            .copied()
1415            .unwrap_or(TokenType::Normal)
1416    }
1417
1418    /// Check if a token is a special token
1419    pub fn is_special_token(&self, id: u32) -> bool {
1420        id == self.special_tokens.bos_token_id
1421            || id == self.special_tokens.eos_token_id
1422            || self.special_tokens.pad_token_id == Some(id)
1423            || self.special_tokens.unk_token_id == Some(id)
1424    }
1425
1426    /// Load tokenizer from a HuggingFace `tokenizer.json` file
1427    ///
1428    /// This parses the JSON format used by HuggingFace tokenizers (the `tokenizers` library).
1429    /// Supports BPE models which cover LLaMA, Mistral, Qwen, and most modern LLMs.
1430    pub fn from_hf_json(path: impl AsRef<std::path::Path>) -> TokenizerResult<Self> {
1431        let path = path.as_ref();
1432        let data = std::fs::read_to_string(path)
1433            .map_err(|e| TokenizerError::MissingData(format!("{}: {}", path.display(), e)))?;
1434
1435        Self::from_hf_json_str(&data)
1436    }
1437
1438    /// Parse tokenizer from a HuggingFace tokenizer.json string
1439    pub fn from_hf_json_str(json: &str) -> TokenizerResult<Self> {
1440        let root: serde_json::Value = serde_json::from_str(json)
1441            .map_err(|e| TokenizerError::EncodingError(format!("Invalid tokenizer.json: {}", e)))?;
1442
1443        let model = root
1444            .get("model")
1445            .ok_or_else(|| TokenizerError::MissingData("model section in tokenizer.json".into()))?;
1446
1447        let model_type = model.get("type").and_then(|v| v.as_str()).unwrap_or("BPE");
1448        let tokenizer_type = match model_type {
1449            "BPE" => TokenizerType::BPE,
1450            "Unigram" => TokenizerType::SentencePiece,
1451            "WordPiece" => TokenizerType::WordPiece,
1452            _ => TokenizerType::Unknown,
1453        };
1454
1455        let mut token_to_id = HashMap::new();
1456        let mut id_to_token: Vec<String>;
1457        let mut scores: Vec<f32>;
1458        let mut merges = HashMap::new();
1459        let mut wordpiece_prefix = "##".to_string();
1460        let mut model_unk_token: Option<String> = None;
1461
1462        match tokenizer_type {
1463            TokenizerType::SentencePiece => {
1464                // Unigram: vocab is [[token, score], ...]
1465                let vocab_arr = model
1466                    .get("vocab")
1467                    .and_then(|v| v.as_array())
1468                    .ok_or_else(|| {
1469                        TokenizerError::MissingData("Unigram vocab array".into())
1470                    })?;
1471
1472                id_to_token = Vec::with_capacity(vocab_arr.len());
1473                scores = Vec::with_capacity(vocab_arr.len());
1474
1475                for (id, entry) in vocab_arr.iter().enumerate() {
1476                    let arr = entry.as_array().ok_or_else(|| {
1477                        TokenizerError::MissingData(format!(
1478                            "Unigram vocab entry {} not an array",
1479                            id
1480                        ))
1481                    })?;
1482                    let token = arr
1483                        .first()
1484                        .and_then(|v| v.as_str())
1485                        .ok_or_else(|| {
1486                            TokenizerError::MissingData(format!(
1487                                "Unigram vocab entry {} missing token",
1488                                id
1489                            ))
1490                        })?;
1491                    let score = arr
1492                        .get(1)
1493                        .and_then(|v| v.as_f64())
1494                        .unwrap_or(0.0) as f32;
1495
1496                    token_to_id.insert(token.to_string(), id as u32);
1497                    id_to_token.push(token.to_string());
1498                    scores.push(score);
1499                }
1500
1501                if let Some(unk_id) = model.get("unk_id").and_then(|v| v.as_u64()) {
1502                    model_unk_token = id_to_token.get(unk_id as usize).cloned();
1503                }
1504            }
1505            TokenizerType::WordPiece => {
1506                // WordPiece: vocab is { token: id, ... }
1507                let vocab_obj = model
1508                    .get("vocab")
1509                    .and_then(|v| v.as_object())
1510                    .ok_or_else(|| {
1511                        TokenizerError::MissingData("WordPiece vocab object".into())
1512                    })?;
1513
1514                let vocab_size = vocab_obj.len();
1515                id_to_token = vec![String::new(); vocab_size];
1516
1517                for (token, id_val) in vocab_obj {
1518                    let id = id_val.as_u64().ok_or_else(|| {
1519                        TokenizerError::MissingData(format!("Invalid vocab ID for '{}'", token))
1520                    })? as u32;
1521                    token_to_id.insert(token.clone(), id);
1522                    if (id as usize) < id_to_token.len() {
1523                        id_to_token[id as usize] = token.clone();
1524                    }
1525                }
1526
1527                if let Some(prefix) = model
1528                    .get("continuing_subword_prefix")
1529                    .and_then(|v| v.as_str())
1530                {
1531                    wordpiece_prefix = prefix.to_string();
1532                }
1533                if let Some(unk) = model.get("unk_token").and_then(|v| v.as_str()) {
1534                    model_unk_token = Some(unk.to_string());
1535                }
1536
1537                scores = vec![0.0; id_to_token.len()];
1538            }
1539            _ => {
1540                // BPE: vocab is { token: id, ... }
1541                let vocab_obj = model
1542                    .get("vocab")
1543                    .and_then(|v| v.as_object())
1544                    .ok_or_else(|| {
1545                        TokenizerError::MissingData("BPE vocab object".into())
1546                    })?;
1547
1548                let vocab_size = vocab_obj.len();
1549                id_to_token = vec![String::new(); vocab_size];
1550
1551                for (token, id_val) in vocab_obj {
1552                    let id = id_val.as_u64().ok_or_else(|| {
1553                        TokenizerError::MissingData(format!("Invalid vocab ID for '{}'", token))
1554                    })? as u32;
1555                    token_to_id.insert(token.clone(), id);
1556                    if (id as usize) < id_to_token.len() {
1557                        id_to_token[id as usize] = token.clone();
1558                    }
1559                }
1560
1561                if let Some(merges_arr) = model.get("merges").and_then(|v| v.as_array()) {
1562                    for (priority, merge_val) in merges_arr.iter().enumerate() {
1563                        // Support two merge formats:
1564                        // 1. String: "▁ t" (space-separated pair)
1565                        // 2. Array: ["▁", "t"] (pair of strings, used by Gemma 4)
1566                        let (part0, part1) = if let Some(merge_str) = merge_val.as_str() {
1567                            let parts: Vec<&str> = merge_str.split(' ').collect();
1568                            if parts.len() == 2 {
1569                                (parts[0].to_string(), parts[1].to_string())
1570                            } else {
1571                                continue;
1572                            }
1573                        } else if let Some(arr) = merge_val.as_array() {
1574                            if arr.len() == 2 {
1575                                if let (Some(a), Some(b)) = (arr[0].as_str(), arr[1].as_str()) {
1576                                    (a.to_string(), b.to_string())
1577                                } else {
1578                                    continue;
1579                                }
1580                            } else {
1581                                continue;
1582                            }
1583                        } else {
1584                            continue;
1585                        };
1586
1587                        if let (Some(&id1), Some(&id2)) =
1588                            (token_to_id.get(&part0), token_to_id.get(&part1))
1589                        {
1590                            let merged = format!("{}{}", part0, part1);
1591                            if let Some(&merged_id) = token_to_id.get(&merged) {
1592                                merges.insert((id1, id2), (merged_id, priority));
1593                            }
1594                        }
1595                    }
1596                }
1597
1598                scores = vec![0.0; id_to_token.len()];
1599            }
1600        }
1601
1602        let vocab_size = id_to_token.len();
1603
1604        // Parse added_tokens and detect special token roles
1605        let mut bos_token_id: Option<u32> = None;
1606        let mut eos_token_id: Option<u32> = None;
1607        let mut pad_token_id: Option<u32> = None;
1608        let mut unk_token_id: Option<u32> = None;
1609
1610        if let Some(added_tokens) = root.get("added_tokens").and_then(|v| v.as_array()) {
1611            for token_obj in added_tokens {
1612                let content = token_obj
1613                    .get("content")
1614                    .and_then(|v| v.as_str())
1615                    .unwrap_or("");
1616                let id = token_obj
1617                    .get("id")
1618                    .and_then(|v| v.as_u64())
1619                    .map(|v| v as u32);
1620                let special = token_obj
1621                    .get("special")
1622                    .and_then(|v| v.as_bool())
1623                    .unwrap_or(false);
1624
1625                if let Some(id) = id {
1626                    token_to_id.insert(content.to_string(), id);
1627                    if (id as usize) < id_to_token.len() {
1628                        id_to_token[id as usize] = content.to_string();
1629                    }
1630
1631                    if special {
1632                        let content_lower = content.to_lowercase();
1633                        if content_lower.contains("bos")
1634                            || content == "<s>"
1635                            || content == "<|begin_of_text|>"
1636                            || content == "<|startoftext|>"
1637                        {
1638                            bos_token_id = Some(id);
1639                        }
1640                        if content_lower.contains("eos")
1641                            || content == "</s>"
1642                            || content == "<|end_of_text|>"
1643                            || content == "<|endoftext|>"
1644                            || content == "<|eot_id|>"
1645                        {
1646                            eos_token_id = Some(id);
1647                        }
1648                        if content_lower.contains("pad") || content == "<pad>" {
1649                            pad_token_id = Some(id);
1650                        }
1651                        if content_lower.contains("unk") || content == "<unk>" {
1652                            unk_token_id = Some(id);
1653                        }
1654                    }
1655                }
1656            }
1657        }
1658
1659        // Resolve unk from model section if not found in added_tokens
1660        if unk_token_id.is_none() {
1661            if let Some(ref unk_str) = model_unk_token {
1662                unk_token_id = token_to_id.get(unk_str).copied();
1663            }
1664        }
1665
1666        // Check post_processor for special token IDs
1667        if let Some(post_proc) = root.get("post_processor") {
1668            if let Some(special_tokens_map) = post_proc.get("special_tokens") {
1669                if let Some(bos_obj) = special_tokens_map
1670                    .get("<s>")
1671                    .or_else(|| special_tokens_map.get("<|begin_of_text|>"))
1672                    && let Some(ids) = bos_obj.get("ids").and_then(|v| v.as_array())
1673                    && let Some(id) = ids.first().and_then(|v| v.as_u64())
1674                {
1675                    bos_token_id = bos_token_id.or(Some(id as u32));
1676                }
1677                if let Some(eos_obj) = special_tokens_map
1678                    .get("</s>")
1679                    .or_else(|| special_tokens_map.get("<|end_of_text|>"))
1680                    && let Some(ids) = eos_obj.get("ids").and_then(|v| v.as_array())
1681                    && let Some(id) = ids.first().and_then(|v| v.as_u64())
1682                {
1683                    eos_token_id = eos_token_id.or(Some(id as u32));
1684                }
1685            }
1686        }
1687
1688        let special_tokens = SpecialTokens {
1689            bos_token_id: bos_token_id.unwrap_or(1),
1690            eos_token_id: eos_token_id.unwrap_or(2),
1691            pad_token_id,
1692            unk_token_id,
1693        };
1694
1695        // Build token types
1696        let mut token_types = vec![TokenType::Normal; vocab_size];
1697        for &id in [special_tokens.bos_token_id, special_tokens.eos_token_id].iter() {
1698            if (id as usize) < token_types.len() {
1699                token_types[id as usize] = TokenType::Control;
1700            }
1701        }
1702        if let Some(pad_id) = special_tokens.pad_token_id
1703            && (pad_id as usize) < token_types.len()
1704        {
1705            token_types[pad_id as usize] = TokenType::Control;
1706        }
1707        if let Some(unk_id) = special_tokens.unk_token_id
1708            && (unk_id as usize) < token_types.len()
1709        {
1710            token_types[unk_id as usize] = TokenType::Control;
1711        }
1712        for (token, &id) in &token_to_id {
1713            if token.starts_with("<0x")
1714                && token.ends_with('>')
1715                && token.len() == 6
1716                && (id as usize) < token_types.len()
1717            {
1718                token_types[id as usize] = TokenType::Byte;
1719            }
1720        }
1721
1722        // Detect GPT-2 byte-level BPE
1723        let uses_byte_level = root
1724            .get("pre_tokenizer")
1725            .and_then(|v| v.get("type").or_else(|| {
1726                // Handle Sequence pre-tokenizer containing ByteLevel
1727                v.get("pretokenizers").and_then(|arr| {
1728                    arr.as_array().and_then(|a| {
1729                        a.iter().find_map(|pt| {
1730                            pt.get("type").filter(|t| t.as_str() == Some("ByteLevel"))
1731                        })
1732                    })
1733                })
1734            }))
1735            .and_then(|v| v.as_str())
1736            .is_some_and(|t| t == "ByteLevel");
1737
1738        let (gpt2_unicode_to_byte, gpt2_byte_to_unicode) = if tokenizer_type == TokenizerType::BPE && uses_byte_level {
1739            let (u2b, b2u) = build_gpt2_mappings();
1740            (Some(u2b), Some(b2u))
1741        } else {
1742            (None, None)
1743        };
1744
1745        // Parse HF pipeline components
1746        let normalizer = root.get("normalizer")
1747            .and_then(|v| if v.is_null() { None } else { Self::parse_normalizer(v) });
1748        let pre_tokenizer = root.get("pre_tokenizer")
1749            .and_then(|v| if v.is_null() { None } else { Self::parse_pre_tokenizer(v) });
1750        let post_processor = root.get("post_processor")
1751            .and_then(|v| if v.is_null() { None } else { Self::parse_post_processor(v, &token_to_id) });
1752
1753        let mut control_token_strings: Vec<(String, u32)> = token_types
1754            .iter()
1755            .enumerate()
1756            .filter(|(_, tt)| **tt == TokenType::Control)
1757            .filter_map(|(id, _)| {
1758                let s = &id_to_token[id];
1759                if !s.is_empty() {
1760                    Some((s.clone(), id as u32))
1761                } else {
1762                    None
1763                }
1764            })
1765            .collect();
1766        control_token_strings.sort_by(|a, b| b.0.len().cmp(&a.0.len()));
1767
1768        Ok(Self {
1769            token_to_id,
1770            id_to_token,
1771            scores,
1772            merges,
1773            special_tokens,
1774            tokenizer_type,
1775            vocab_size,
1776            token_types,
1777            gpt2_unicode_to_byte,
1778            gpt2_byte_to_unicode,
1779            normalizer,
1780            pre_tokenizer,
1781            post_processor,
1782            wordpiece_prefix,
1783            control_token_strings,
1784            has_explicit_bos: bos_token_id.is_some(),
1785            add_space_prefix: true, // JSON tokenizers typically use pre-tokenizer to handle this
1786        })
1787    }
1788
1789    fn parse_normalizer(value: &serde_json::Value) -> Option<Normalizer> {
1790        let type_str = value.get("type")?.as_str()?;
1791        match type_str {
1792            "NFC" => Some(Normalizer::NFC),
1793            "NFKC" => Some(Normalizer::NFKC),
1794            "NFD" => Some(Normalizer::NFD),
1795            "NFKD" => Some(Normalizer::NFKD),
1796            "Lowercase" => Some(Normalizer::Lowercase),
1797            "Strip" => {
1798                let left = value.get("strip_left").and_then(|v| v.as_bool()).unwrap_or(true);
1799                let right = value.get("strip_right").and_then(|v| v.as_bool()).unwrap_or(true);
1800                Some(Normalizer::Strip { left, right })
1801            }
1802            "Prepend" => {
1803                let prepend = value.get("prepend").and_then(|v| v.as_str()).unwrap_or("▁");
1804                Some(Normalizer::Prepend(prepend.to_string()))
1805            }
1806            "Replace" => {
1807                let pattern = value
1808                    .get("pattern")
1809                    .and_then(|v| v.get("String").and_then(|s| s.as_str()))
1810                    .unwrap_or("");
1811                let content = value.get("content").and_then(|v| v.as_str()).unwrap_or("");
1812                Some(Normalizer::Replace {
1813                    pattern: pattern.to_string(),
1814                    content: content.to_string(),
1815                })
1816            }
1817            "StripAccents" => Some(Normalizer::StripAccents),
1818            "Sequence" => {
1819                let normalizers = value.get("normalizers")?.as_array()?;
1820                let parsed: Vec<Normalizer> = normalizers
1821                    .iter()
1822                    .filter_map(|v| Self::parse_normalizer(v))
1823                    .collect();
1824                if parsed.is_empty() {
1825                    None
1826                } else {
1827                    Some(Normalizer::Sequence(parsed))
1828                }
1829            }
1830            "BertNormalizer" => {
1831                let mut seq = Vec::new();
1832                if value
1833                    .get("lowercase")
1834                    .and_then(|v| v.as_bool())
1835                    .unwrap_or(true)
1836                {
1837                    seq.push(Normalizer::Lowercase);
1838                }
1839                if value
1840                    .get("strip_accents")
1841                    .and_then(|v| v.as_bool())
1842                    .unwrap_or(false)
1843                {
1844                    seq.push(Normalizer::StripAccents);
1845                }
1846                match seq.len() {
1847                    0 => None,
1848                    1 => Some(seq.remove(0)),
1849                    _ => Some(Normalizer::Sequence(seq)),
1850                }
1851            }
1852            "Precompiled" => Some(Normalizer::NFC),
1853            _ => None,
1854        }
1855    }
1856
1857    fn parse_pre_tokenizer(value: &serde_json::Value) -> Option<PreTokenizer> {
1858        let type_str = value.get("type")?.as_str()?;
1859        match type_str {
1860            "ByteLevel" => {
1861                let add_prefix_space = value
1862                    .get("add_prefix_space")
1863                    .and_then(|v| v.as_bool())
1864                    .unwrap_or(true);
1865                Some(PreTokenizer::ByteLevel { add_prefix_space })
1866            }
1867            "Whitespace" | "WhitespaceSplit" => Some(PreTokenizer::Whitespace),
1868            "Metaspace" => {
1869                let replacement = value
1870                    .get("replacement")
1871                    .and_then(|v| v.as_str())
1872                    .and_then(|s| s.chars().next())
1873                    .unwrap_or('▁');
1874                let add_prefix_space = value
1875                    .get("add_prefix_space")
1876                    .and_then(|v| v.as_bool())
1877                    .unwrap_or(true);
1878                Some(PreTokenizer::Metaspace {
1879                    replacement,
1880                    add_prefix_space,
1881                })
1882            }
1883            "Punctuation" | "BertPreTokenizer" => Some(PreTokenizer::Punctuation),
1884            "Digits" => {
1885                let individual_digits = value
1886                    .get("individual_digits")
1887                    .and_then(|v| v.as_bool())
1888                    .unwrap_or(false);
1889                Some(PreTokenizer::Digits { individual_digits })
1890            }
1891            "Sequence" => {
1892                let pretokenizers = value.get("pretokenizers")?.as_array()?;
1893                let parsed: Vec<PreTokenizer> = pretokenizers
1894                    .iter()
1895                    .filter_map(|v| Self::parse_pre_tokenizer(v))
1896                    .collect();
1897                if parsed.is_empty() {
1898                    None
1899                } else {
1900                    Some(PreTokenizer::Sequence(parsed))
1901                }
1902            }
1903            _ => None,
1904        }
1905    }
1906
1907    fn parse_post_processor(
1908        value: &serde_json::Value,
1909        token_to_id: &HashMap<String, u32>,
1910    ) -> Option<PostProcessor> {
1911        let type_str = value.get("type")?.as_str()?;
1912        match type_str {
1913            "TemplateProcessing" => {
1914                let parse_template = |arr: &[serde_json::Value]| -> Vec<TemplateElement> {
1915                    arr.iter()
1916                        .filter_map(|item| {
1917                            if let Some(special) = item.get("SpecialToken") {
1918                                let id_str = special.get("id")?.as_str()?;
1919                                let token_id = token_to_id.get(id_str).copied()?;
1920                                Some(TemplateElement::SpecialToken {
1921                                    id: id_str.to_string(),
1922                                    token_id,
1923                                })
1924                            } else if item.get("Sequence").is_some() {
1925                                let type_id = item
1926                                    .get("Sequence")
1927                                    .and_then(|s| s.get("id"))
1928                                    .and_then(|v| v.as_u64())
1929                                    .unwrap_or(0) as u32;
1930                                Some(TemplateElement::Sequence { type_id })
1931                            } else {
1932                                None
1933                            }
1934                        })
1935                        .collect()
1936                };
1937
1938                let single = value
1939                    .get("single")
1940                    .and_then(|v| v.as_array())
1941                    .map(|a| parse_template(a))
1942                    .unwrap_or_default();
1943                let pair = value
1944                    .get("pair")
1945                    .and_then(|v| v.as_array())
1946                    .map(|a| parse_template(a))
1947                    .unwrap_or_default();
1948
1949                Some(PostProcessor::TemplateProcessing { single, pair })
1950            }
1951            "ByteLevel" => {
1952                let trim_offsets = value
1953                    .get("trim_offsets")
1954                    .and_then(|v| v.as_bool())
1955                    .unwrap_or(true);
1956                Some(PostProcessor::ByteLevel { trim_offsets })
1957            }
1958            "BertProcessing" => {
1959                let mut single = Vec::new();
1960                let mut pair = Vec::new();
1961
1962                if let Some(cls) = value.get("cls").and_then(|v| v.as_array()) {
1963                    if let (Some(token), Some(id)) = (
1964                        cls.first().and_then(|v| v.as_str()),
1965                        cls.get(1).and_then(|v| v.as_u64()),
1966                    ) {
1967                        let elem = TemplateElement::SpecialToken {
1968                            id: token.to_string(),
1969                            token_id: id as u32,
1970                        };
1971                        single.push(elem.clone());
1972                        pair.push(elem);
1973                    }
1974                }
1975
1976                single.push(TemplateElement::Sequence { type_id: 0 });
1977                pair.push(TemplateElement::Sequence { type_id: 0 });
1978
1979                if let Some(sep) = value.get("sep").and_then(|v| v.as_array()) {
1980                    if let (Some(token), Some(id)) = (
1981                        sep.first().and_then(|v| v.as_str()),
1982                        sep.get(1).and_then(|v| v.as_u64()),
1983                    ) {
1984                        let elem = TemplateElement::SpecialToken {
1985                            id: token.to_string(),
1986                            token_id: id as u32,
1987                        };
1988                        single.push(elem.clone());
1989                        pair.push(elem.clone());
1990                        pair.push(TemplateElement::Sequence { type_id: 1 });
1991                        pair.push(elem);
1992                    }
1993                }
1994
1995                Some(PostProcessor::TemplateProcessing { single, pair })
1996            }
1997            _ => None,
1998        }
1999    }
2000}
2001
#[cfg(test)]
mod tests {
    //! Unit tests for tokenizer type parsing, GPT-2 byte mappings, the HF pipeline
    //! components (normalizers, pre-tokenizers, post-processors) and HF JSON loading.
    use super::*;

    #[test]
    fn test_tokenizer_type_parsing() {
        assert_eq!(TokenizerType::from_gguf_str("llama"), TokenizerType::BPE);
        assert_eq!(TokenizerType::from_gguf_str("bpe"), TokenizerType::BPE);
        assert_eq!(TokenizerType::from_gguf_str("gpt2"), TokenizerType::BPE);
        assert_eq!(
            TokenizerType::from_gguf_str("sentencepiece"),
            TokenizerType::SentencePiece
        );
        assert_eq!(TokenizerType::from_gguf_str("spm"), TokenizerType::SentencePiece);
        assert_eq!(TokenizerType::from_gguf_str("wordpiece"), TokenizerType::WordPiece);
        assert_eq!(TokenizerType::from_gguf_str("bert"), TokenizerType::WordPiece);

        // Matching is case-insensitive; anything unrecognized maps to Unknown.
        assert_eq!(TokenizerType::from_gguf_str("LLaMA"), TokenizerType::BPE);
        assert_eq!(TokenizerType::from_gguf_str("GPT2"), TokenizerType::BPE);
        assert_eq!(TokenizerType::from_gguf_str("t5"), TokenizerType::Unknown);
        assert_eq!(TokenizerType::from_gguf_str(""), TokenizerType::Unknown);
    }

    #[test]
    fn test_special_tokens_default() {
        let special = SpecialTokens::default();
        assert_eq!(special.bos_token_id, 1);
        assert_eq!(special.eos_token_id, 2);
    }

    #[test]
    fn test_gpt2_unicode_to_byte_table() {
        let (table, _) = build_gpt2_mappings();
        assert_eq!(table.len(), 256);

        // Printable ASCII maps to itself
        assert_eq!(table[&'!'], b'!');
        assert_eq!(table[&'A'], b'A');
        assert_eq!(table[&'~'], b'~');

        // GPT-2 special chars map to their byte values
        assert_eq!(table[&'Ġ'], b' '); // U+0120 → 0x20 (space)
        assert_eq!(table[&'Ċ'], b'\n'); // U+010A → 0x0A (newline)
        assert_eq!(table[&'ĉ'], b'\t'); // U+0109 → 0x09 (tab)

        // Latin-1 supplement bytes map to themselves
        assert_eq!(table[&'¡'], 0xA1);
        assert_eq!(table[&'®'], 0xAE);
        assert_eq!(table[&'ÿ'], 0xFF);
    }

    #[test]
    fn test_gpt2_decode_space_and_emoji() {
        let (table, _) = build_gpt2_mappings();

        // "ĠHello" should decode to " Hello"
        let bytes: Vec<u8> = "ĠHello".chars().map(|c| table[&c]).collect();
        assert_eq!(String::from_utf8(bytes).unwrap(), " Hello");

        // "ðŁĺĬ" is the GPT-2 encoding of 😊 (U+1F60A, UTF-8: F0 9F 98 8A)
        let bytes: Vec<u8> = "ðŁĺĬ".chars().map(|c| table[&c]).collect();
        let decoded = String::from_utf8(bytes).unwrap();
        assert_eq!(decoded, "😊");
    }

    #[test]
    fn test_normalizer_nfc() {
        let norm = Normalizer::NFC;
        // U+00E9 (é precomposed) vs U+0065 + U+0301 (e + combining acute)
        let decomposed = "e\u{0301}";
        let result = norm.apply(decomposed);
        assert_eq!(result, "\u{00E9}");
    }

    #[test]
    fn test_normalizer_lowercase() {
        let norm = Normalizer::Lowercase;
        assert_eq!(norm.apply("HELLO World"), "hello world");
    }

    #[test]
    fn test_normalizer_strip_accents() {
        let norm = Normalizer::StripAccents;
        assert_eq!(norm.apply("café"), "cafe");
        assert_eq!(norm.apply("naïve"), "naive");
    }

    #[test]
    fn test_normalizer_sequence() {
        let norm = Normalizer::Sequence(vec![
            Normalizer::NFKC,
            Normalizer::Lowercase,
        ]);
        assert_eq!(norm.apply("HÉLLO"), "héllo");
    }

    #[test]
    fn test_normalizer_replace() {
        let norm = Normalizer::Replace {
            pattern: " ".to_string(),
            content: "▁".to_string(),
        };
        assert_eq!(norm.apply("hello world"), "hello▁world");
    }

    #[test]
    fn test_pre_tokenizer_whitespace() {
        let pt = PreTokenizer::Whitespace;
        assert_eq!(pt.apply("Hello world  test"), vec!["Hello", "world", "test"]);
    }

    #[test]
    fn test_pre_tokenizer_byte_level() {
        let pt = PreTokenizer::ByteLevel { add_prefix_space: true };
        let result = pt.apply("Hello world");
        assert_eq!(result, vec![" Hello", " world"]);

        let pt_no_space = PreTokenizer::ByteLevel { add_prefix_space: false };
        let result = pt_no_space.apply("Hello world");
        assert_eq!(result, vec!["Hello", " world"]);
    }

    #[test]
    fn test_pre_tokenizer_punctuation() {
        let pt = PreTokenizer::Punctuation;
        let result = pt.apply("Hello, world!");
        assert_eq!(result, vec!["Hello", ",", " world", "!"]);
    }

    #[test]
    fn test_pre_tokenizer_digits() {
        let pt = PreTokenizer::Digits { individual_digits: true };
        let result = pt.apply("abc123def");
        assert_eq!(result, vec!["abc", "1", "2", "3", "def"]);
    }

    #[test]
    fn test_pre_tokenizer_sequence() {
        let pt = PreTokenizer::Sequence(vec![
            PreTokenizer::Whitespace,
            PreTokenizer::Punctuation,
        ]);
        let result = pt.apply("Hello, world!");
        assert_eq!(result, vec!["Hello", ",", "world", "!"]);
    }

    #[test]
    fn test_unigram_from_hf_json() {
        let json = r#"{
            "model": {
                "type": "Unigram",
                "unk_id": 0,
                "vocab": [
                    ["<unk>", 0.0],
                    ["▁", -1.0],
                    ["▁the", -2.0],
                    ["▁a", -2.5],
                    ["h", -3.0],
                    ["e", -3.0],
                    ["l", -3.0],
                    ["o", -3.0],
                    ["he", -2.0],
                    ["llo", -2.5]
                ]
            },
            "pre_tokenizer": {
                "type": "Metaspace",
                "replacement": "▁",
                "add_prefix_space": true
            },
            "added_tokens": [
                {"id": 0, "content": "<unk>", "special": true}
            ]
        }"#;

        let tok = Tokenizer::from_hf_json_str(json).unwrap();
        assert_eq!(tok.tokenizer_type, TokenizerType::SentencePiece);
        assert_eq!(tok.vocab_size, 10);
        assert!(tok.scores.iter().any(|&s| s != 0.0));
    }

    #[test]
    fn test_wordpiece_from_hf_json() {
        let json = r###"{
            "model": {
                "type": "WordPiece",
                "unk_token": "[UNK]",
                "continuing_subword_prefix": "##",
                "vocab": {
                    "[UNK]": 0,
                    "[CLS]": 1,
                    "[SEP]": 2,
                    "hello": 3,
                    "world": 4,
                    "he": 5,
                    "##llo": 6,
                    "wo": 7,
                    "##rld": 8
                }
            },
            "normalizer": {
                "type": "BertNormalizer",
                "lowercase": true,
                "strip_accents": false
            },
            "pre_tokenizer": {
                "type": "BertPreTokenizer"
            },
            "added_tokens": [
                {"id": 0, "content": "[UNK]", "special": true},
                {"id": 1, "content": "[CLS]", "special": true},
                {"id": 2, "content": "[SEP]", "special": true}
            ]
        }"###;

        let tok = Tokenizer::from_hf_json_str(json).unwrap();
        assert_eq!(tok.tokenizer_type, TokenizerType::WordPiece);
        assert_eq!(tok.wordpiece_prefix, "##");

        // "hello" should encode to [3] (direct match)
        let tokens = tok.encode("hello", false).unwrap();
        assert_eq!(tokens, vec![3]);

        // "hello world" should encode to [3, 4] (both direct matches after whitespace split)
        let tokens = tok.encode("hello world", false).unwrap();
        assert_eq!(tokens, vec![3, 4]);
    }

    #[test]
    fn test_wordpiece_subword_splitting() {
        let json = r###"{
            "model": {
                "type": "WordPiece",
                "unk_token": "[UNK]",
                "continuing_subword_prefix": "##",
                "vocab": {
                    "[UNK]": 0,
                    "[BOS]": 1,
                    "[EOS]": 2,
                    "un": 3,
                    "##know": 4,
                    "##n": 5,
                    "unknown": 6,
                    "the": 7,
                    "##s": 8
                }
            },
            "pre_tokenizer": { "type": "Whitespace" },
            "added_tokens": [
                {"id": 0, "content": "[UNK]", "special": true},
                {"id": 1, "content": "[BOS]", "special": true},
                {"id": 2, "content": "[EOS]", "special": true}
            ]
        }"###;

        let tok = Tokenizer::from_hf_json_str(json).unwrap();

        // "unknown" is a direct vocabulary match
        let tokens = tok.encode("unknown", false).unwrap();
        assert_eq!(tokens, vec![6]);

        // "the" should encode to [7]
        let tokens = tok.encode("the", false).unwrap();
        assert_eq!(tokens, vec![7]);

        // "thes" should split to "the" + "##s"
        let tokens = tok.encode("thes", false).unwrap();
        assert_eq!(tokens, vec![7, 8]);
    }

    #[test]
    fn test_unigram_viterbi_encoding() {
        let json = r#"{
            "model": {
                "type": "Unigram",
                "unk_id": 0,
                "vocab": [
                    ["<unk>", 0.0],
                    ["<s>", 0.0],
                    ["</s>", 0.0],
                    ["a", -1.0],
                    ["b", -1.0],
                    ["c", -1.0],
                    ["ab", -0.5],
                    ["bc", -0.5],
                    ["abc", -0.1]
                ]
            },
            "pre_tokenizer": { "type": "Whitespace" },
            "added_tokens": [
                {"id": 0, "content": "<unk>", "special": true},
                {"id": 1, "content": "<s>", "special": true},
                {"id": 2, "content": "</s>", "special": true}
            ]
        }"#;

        let tok = Tokenizer::from_hf_json_str(json).unwrap();

        // "abc" should prefer the single token [abc] (score -0.1) over
        // [a,bc] (score -1.5) or [ab,c] (score -1.5) or [a,b,c] (score -3.0)
        let tokens = tok.encode("abc", false).unwrap();
        assert_eq!(tokens, vec![8]); // id 8 = "abc"
    }

    #[test]
    fn test_bpe_with_pipeline() {
        let json = r#"{
            "model": {
                "type": "BPE",
                "vocab": {
                    "<s>": 0,
                    "</s>": 1,
                    "h": 2,
                    "e": 3,
                    "l": 4,
                    "o": 5,
                    "he": 6,
                    "ll": 7,
                    "hell": 8,
                    "hello": 9,
                    " ": 10
                },
                "merges": [
                    "h e",
                    "l l",
                    "he ll",
                    "hell o"
                ]
            },
            "pre_tokenizer": {
                "type": "ByteLevel",
                "add_prefix_space": false
            },
            "added_tokens": [
                {"id": 0, "content": "<s>", "special": true},
                {"id": 1, "content": "</s>", "special": true}
            ]
        }"#;

        let tok = Tokenizer::from_hf_json_str(json).unwrap();
        assert_eq!(tok.tokenizer_type, TokenizerType::BPE);
        assert!(tok.pre_tokenizer.is_some());

        // Should encode "hello" -> merge h+e=he, l+l=ll, he+ll=hell, hell+o=hello -> [9]
        let tokens = tok.encode("hello", false).unwrap();
        assert_eq!(tokens, vec![9]);
    }

    #[test]
    fn test_parse_normalizer_types() {
        let nfc: serde_json::Value = serde_json::from_str(r#"{"type": "NFC"}"#).unwrap();
        let result = Tokenizer::parse_normalizer(&nfc);
        assert!(matches!(result, Some(Normalizer::NFC)));

        let bert: serde_json::Value = serde_json::from_str(
            r#"{"type": "BertNormalizer", "lowercase": true, "strip_accents": true}"#,
        )
        .unwrap();
        let result = Tokenizer::parse_normalizer(&bert);
        assert!(matches!(result, Some(Normalizer::Sequence(_))));

        // A BertNormalizer with every supported option disabled is a no-op.
        let bert_off: serde_json::Value = serde_json::from_str(
            r#"{"type": "BertNormalizer", "lowercase": false, "strip_accents": false}"#,
        )
        .unwrap();
        assert!(Tokenizer::parse_normalizer(&bert_off).is_none());

        let seq: serde_json::Value = serde_json::from_str(
            r#"{"type": "Sequence", "normalizers": [{"type": "NFC"}, {"type": "Lowercase"}]}"#,
        )
        .unwrap();
        let result = Tokenizer::parse_normalizer(&seq);
        assert!(matches!(result, Some(Normalizer::Sequence(_))));
    }

    #[test]
    fn test_parse_pre_tokenizer_types() {
        let bl: serde_json::Value =
            serde_json::from_str(r#"{"type": "ByteLevel", "add_prefix_space": false}"#).unwrap();
        let result = Tokenizer::parse_pre_tokenizer(&bl);
        assert!(matches!(
            result,
            Some(PreTokenizer::ByteLevel { add_prefix_space: false })
        ));

        let meta: serde_json::Value = serde_json::from_str(
            r#"{"type": "Metaspace", "replacement": "▁", "add_prefix_space": true}"#,
        )
        .unwrap();
        let result = Tokenizer::parse_pre_tokenizer(&meta);
        assert!(matches!(
            result,
            Some(PreTokenizer::Metaspace { add_prefix_space: true, .. })
        ));

        let seq: serde_json::Value = serde_json::from_str(
            r#"{"type": "Sequence", "pretokenizers": [{"type": "Whitespace"}, {"type": "Punctuation"}]}"#,
        )
        .unwrap();
        let result = Tokenizer::parse_pre_tokenizer(&seq);
        assert!(matches!(result, Some(PreTokenizer::Sequence(_))));
    }

    #[test]
    fn test_parse_post_processor_template_processing() {
        let vocab: HashMap<String, u32> = [("<s>".to_string(), 1)].into_iter().collect();
        let tp: serde_json::Value = serde_json::from_str(
            r#"{"type": "TemplateProcessing",
                "single": [{"SpecialToken": {"id": "<s>", "type_id": 0}},
                           {"Sequence": {"id": "A", "type_id": 0}}],
                "pair": []}"#,
        )
        .unwrap();
        let result = Tokenizer::parse_post_processor(&tp, &vocab);
        assert!(matches!(
            &result,
            Some(PostProcessor::TemplateProcessing { single, pair })
                if single.len() == 2
                    && pair.is_empty()
                    && matches!(
                        single.first(),
                        Some(TemplateElement::SpecialToken { token_id: 1, .. })
                    )
        ));
    }

    #[test]
    fn test_parse_post_processor_bert_processing() {
        let bert: serde_json::Value = serde_json::from_str(
            r#"{"type": "BertProcessing", "sep": ["[SEP]", 102], "cls": ["[CLS]", 101]}"#,
        )
        .unwrap();
        let result = Tokenizer::parse_post_processor(&bert, &HashMap::new());
        // single: [CLS] A [SEP]; pair: [CLS] A [SEP] B [SEP]
        assert!(matches!(
            &result,
            Some(PostProcessor::TemplateProcessing { single, pair })
                if single.len() == 3 && pair.len() == 5
        ));

        let unknown: serde_json::Value =
            serde_json::from_str(r#"{"type": "NoSuchProcessor"}"#).unwrap();
        assert!(Tokenizer::parse_post_processor(&unknown, &HashMap::new()).is_none());
    }
}