Skip to main content

oxirs_embed/
tokenizer.rs

//! Text tokenizer for embeddings.
//!
//! Provides BPE (Byte Pair Encoding) and WordPiece tokenization, vocabulary
//! management, special-token handling, bidirectional token-to-ID mapping,
//! text encoding/decoding, truncation to a maximum sequence length, and batch
//! tokenization.
7
8use std::collections::HashMap;
9
// ---------------------------------------------------------------------------
// Special tokens
// ---------------------------------------------------------------------------

/// Well-known special tokens used by transformer models.
///
/// Every variant is registered in the vocabulary of each new `Tokenizer`
/// (in declaration order, so `Cls` receives ID 0) and cannot be removed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SpecialToken {
    /// Classification token, `[CLS]`.
    Cls,
    /// Separator between sentences, `[SEP]`.
    Sep,
    /// Padding token, `[PAD]`.
    Pad,
    /// Unknown / out-of-vocabulary token, `[UNK]`.
    Unk,
    /// Masked token for MLM pre-training, `[MASK]`.
    Mask,
}
28
29impl SpecialToken {
30    /// Canonical string representation (e.g. `[CLS]`).
31    pub fn as_str(&self) -> &'static str {
32        match self {
33            SpecialToken::Cls => "[CLS]",
34            SpecialToken::Sep => "[SEP]",
35            SpecialToken::Pad => "[PAD]",
36            SpecialToken::Unk => "[UNK]",
37            SpecialToken::Mask => "[MASK]",
38        }
39    }
40
41    /// All built-in special tokens.
42    pub fn all() -> &'static [SpecialToken] {
43        &[
44            SpecialToken::Cls,
45            SpecialToken::Sep,
46            SpecialToken::Pad,
47            SpecialToken::Unk,
48            SpecialToken::Mask,
49        ]
50    }
51}
52
// ---------------------------------------------------------------------------
// Tokenizer mode
// ---------------------------------------------------------------------------

/// The sub-word algorithm used by the tokenizer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TokenizerMode {
    /// Byte Pair Encoding: each word starts as single characters and ordered
    /// merge rules join adjacent symbols.
    Bpe,
    /// WordPiece: greedy longest match against the vocabulary; pieces after
    /// the first in a word carry a `##` continuation prefix.
    WordPiece,
}
65
66// ---------------------------------------------------------------------------
67// BPE merge rule
68// ---------------------------------------------------------------------------
69
70/// A single BPE merge rule: pair `(left, right)` merged into `merged`.
71#[derive(Debug, Clone, PartialEq, Eq)]
72pub struct MergeRule {
73    pub left: String,
74    pub right: String,
75    pub merged: String,
76}
77
78// ---------------------------------------------------------------------------
79// Encode result
80// ---------------------------------------------------------------------------
81
82/// The result of encoding a piece of text.
83#[derive(Debug, Clone)]
84pub struct EncodeResult {
85    /// Token string representations.
86    pub tokens: Vec<String>,
87    /// Corresponding integer IDs.
88    pub ids: Vec<u32>,
89}
90
91// ---------------------------------------------------------------------------
92// Tokenizer
93// ---------------------------------------------------------------------------
94
95/// Configuration for building a `Tokenizer`.
96#[derive(Debug, Clone)]
97pub struct TokenizerConfig {
98    /// Sub-word algorithm.
99    pub mode: TokenizerMode,
100    /// Maximum sequence length (tokens). Encoding is truncated to this.
101    pub max_length: usize,
102    /// Whether to automatically lower-case input text before tokenization.
103    pub lowercase: bool,
104}
105
106impl Default for TokenizerConfig {
107    fn default() -> Self {
108        Self {
109            mode: TokenizerMode::Bpe,
110            max_length: 512,
111            lowercase: true,
112        }
113    }
114}
115
116/// A text tokenizer supporting BPE and WordPiece sub-word algorithms.
117pub struct Tokenizer {
118    config: TokenizerConfig,
119    /// token-string -> ID
120    token_to_id: HashMap<String, u32>,
121    /// ID -> token-string
122    id_to_token: HashMap<u32, String>,
123    /// Next ID to assign when adding a new token.
124    next_id: u32,
125    /// Ordered BPE merge rules (only used in BPE mode).
126    merge_rules: Vec<MergeRule>,
127}
128
129impl Tokenizer {
130    // ── Construction ─────────────────────────────────────────────────────
131
132    /// Create a new tokenizer with the given configuration.
133    ///
134    /// Special tokens (`[CLS]`, `[SEP]`, `[PAD]`, `[UNK]`, `[MASK]`) are
135    /// registered automatically.
136    pub fn new(config: TokenizerConfig) -> Self {
137        let mut tok = Self {
138            config,
139            token_to_id: HashMap::new(),
140            id_to_token: HashMap::new(),
141            next_id: 0,
142            merge_rules: Vec::new(),
143        };
144        // Register all special tokens up-front.
145        for st in SpecialToken::all() {
146            tok.add_token(st.as_str());
147        }
148        tok
149    }
150
151    /// Build a default BPE tokenizer.
152    pub fn bpe() -> Self {
153        Self::new(TokenizerConfig {
154            mode: TokenizerMode::Bpe,
155            ..TokenizerConfig::default()
156        })
157    }
158
159    /// Build a default WordPiece tokenizer.
160    pub fn wordpiece() -> Self {
161        Self::new(TokenizerConfig {
162            mode: TokenizerMode::WordPiece,
163            ..TokenizerConfig::default()
164        })
165    }
166
167    // ── Vocabulary management ────────────────────────────────────────────
168
169    /// Add a token to the vocabulary.  Returns its ID.
170    ///
171    /// If the token already exists, the existing ID is returned.
172    pub fn add_token(&mut self, token: &str) -> u32 {
173        if let Some(&id) = self.token_to_id.get(token) {
174            return id;
175        }
176        let id = self.next_id;
177        self.next_id += 1;
178        self.token_to_id.insert(token.to_string(), id);
179        self.id_to_token.insert(id, token.to_string());
180        id
181    }
182
183    /// Remove a token from the vocabulary.  Returns `true` if it existed.
184    ///
185    /// Special tokens cannot be removed; in that case `false` is returned.
186    pub fn remove_token(&mut self, token: &str) -> bool {
187        // Guard special tokens.
188        for st in SpecialToken::all() {
189            if st.as_str() == token {
190                return false;
191            }
192        }
193        if let Some(id) = self.token_to_id.remove(token) {
194            self.id_to_token.remove(&id);
195            return true;
196        }
197        false
198    }
199
200    /// Current vocabulary size (including special tokens).
201    pub fn vocab_size(&self) -> usize {
202        self.token_to_id.len()
203    }
204
205    /// Whether a token is in the vocabulary.
206    pub fn contains_token(&self, token: &str) -> bool {
207        self.token_to_id.contains_key(token)
208    }
209
210    // ── BPE merge rules ──────────────────────────────────────────────────
211
212    /// Add a BPE merge rule.  The merged token is also added to the vocab.
213    pub fn add_merge_rule(&mut self, left: &str, right: &str) {
214        let merged = format!("{left}{right}");
215        self.add_token(&merged);
216        self.merge_rules.push(MergeRule {
217            left: left.to_string(),
218            right: right.to_string(),
219            merged,
220        });
221    }
222
223    /// Number of registered merge rules.
224    pub fn merge_rule_count(&self) -> usize {
225        self.merge_rules.len()
226    }
227
228    // ── Token ↔ ID mapping ──────────────────────────────────────────────
229
230    /// Look up the ID of a token.
231    pub fn token_to_id(&self, token: &str) -> Option<u32> {
232        self.token_to_id.get(token).copied()
233    }
234
235    /// Look up the token string for an ID.
236    pub fn id_to_token(&self, id: u32) -> Option<&str> {
237        self.id_to_token.get(&id).map(String::as_str)
238    }
239
240    /// ID of the `[UNK]` token.
241    pub fn unk_id(&self) -> u32 {
242        self.token_to_id
243            .get(SpecialToken::Unk.as_str())
244            .copied()
245            .unwrap_or(0)
246    }
247
248    /// ID of the `[CLS]` token.
249    pub fn cls_id(&self) -> u32 {
250        self.token_to_id
251            .get(SpecialToken::Cls.as_str())
252            .copied()
253            .unwrap_or(0)
254    }
255
256    /// ID of the `[SEP]` token.
257    pub fn sep_id(&self) -> u32 {
258        self.token_to_id
259            .get(SpecialToken::Sep.as_str())
260            .copied()
261            .unwrap_or(0)
262    }
263
264    /// ID of the `[PAD]` token.
265    pub fn pad_id(&self) -> u32 {
266        self.token_to_id
267            .get(SpecialToken::Pad.as_str())
268            .copied()
269            .unwrap_or(0)
270    }
271
272    // ── Encoding ─────────────────────────────────────────────────────────
273
274    /// Encode a text string into token IDs.
275    ///
276    /// The output is truncated to `config.max_length`.
277    pub fn encode(&self, text: &str) -> EncodeResult {
278        let text = if self.config.lowercase {
279            text.to_lowercase()
280        } else {
281            text.to_string()
282        };
283
284        let sub_tokens = match &self.config.mode {
285            TokenizerMode::Bpe => self.bpe_tokenize(&text),
286            TokenizerMode::WordPiece => self.wordpiece_tokenize(&text),
287        };
288
289        let max = self.config.max_length;
290        let truncated: Vec<String> = sub_tokens.into_iter().take(max).collect();
291        let ids: Vec<u32> = truncated
292            .iter()
293            .map(|t| {
294                self.token_to_id
295                    .get(t.as_str())
296                    .copied()
297                    .unwrap_or_else(|| self.unk_id())
298            })
299            .collect();
300
301        EncodeResult {
302            tokens: truncated,
303            ids,
304        }
305    }
306
307    /// Decode a sequence of token IDs back into a string.
308    ///
309    /// WordPiece continuation tokens (`##…`) are merged back without spaces.
310    pub fn decode(&self, ids: &[u32]) -> String {
311        let mut parts: Vec<String> = Vec::with_capacity(ids.len());
312        for &id in ids {
313            if let Some(tok) = self.id_to_token.get(&id) {
314                // Skip special tokens in decoded output.
315                let is_special = SpecialToken::all().iter().any(|st| st.as_str() == tok);
316                if is_special {
317                    continue;
318                }
319                parts.push(tok.clone());
320            }
321        }
322        self.merge_subwords(&parts)
323    }
324
325    /// Encode a batch of texts.
326    pub fn encode_batch(&self, texts: &[&str]) -> Vec<EncodeResult> {
327        texts.iter().map(|t| self.encode(t)).collect()
328    }
329
330    // ── Sub-word merging (decode helper) ─────────────────────────────────
331
332    /// Merge sub-word tokens back into words.
333    ///
334    /// WordPiece continuations (`##xyz`) are concatenated to the preceding
335    /// token without a space.  BPE tokens are simply joined with spaces.
336    fn merge_subwords(&self, tokens: &[String]) -> String {
337        if tokens.is_empty() {
338            return String::new();
339        }
340
341        match &self.config.mode {
342            TokenizerMode::WordPiece => {
343                let mut result = String::new();
344                for tok in tokens {
345                    if let Some(suffix) = tok.strip_prefix("##") {
346                        result.push_str(suffix);
347                    } else {
348                        if !result.is_empty() {
349                            result.push(' ');
350                        }
351                        result.push_str(tok);
352                    }
353                }
354                result
355            }
356            TokenizerMode::Bpe => tokens.join(" "),
357        }
358    }
359
360    // ── BPE tokenization ─────────────────────────────────────────────────
361
362    /// Tokenize `text` using BPE merge rules.
363    fn bpe_tokenize(&self, text: &str) -> Vec<String> {
364        let words: Vec<&str> = text.split_whitespace().collect();
365        let mut all_tokens: Vec<String> = Vec::new();
366
367        for word in words {
368            // Start with individual characters.
369            let mut symbols: Vec<String> = word.chars().map(|c| c.to_string()).collect();
370
371            // Apply merge rules in priority order.
372            for rule in &self.merge_rules {
373                symbols = Self::apply_merge(&symbols, &rule.left, &rule.right, &rule.merged);
374            }
375
376            // Map to vocab; fall back to [UNK] per symbol.
377            for sym in symbols {
378                if self.token_to_id.contains_key(&sym) {
379                    all_tokens.push(sym);
380                } else {
381                    all_tokens.push(SpecialToken::Unk.as_str().to_string());
382                }
383            }
384        }
385
386        all_tokens
387    }
388
389    /// Apply one merge rule to a symbol sequence.
390    fn apply_merge(symbols: &[String], left: &str, right: &str, merged: &str) -> Vec<String> {
391        let mut result: Vec<String> = Vec::with_capacity(symbols.len());
392        let mut i = 0;
393        while i < symbols.len() {
394            if i + 1 < symbols.len() && symbols[i] == left && symbols[i + 1] == right {
395                result.push(merged.to_string());
396                i += 2;
397            } else {
398                result.push(symbols[i].clone());
399                i += 1;
400            }
401        }
402        result
403    }
404
405    // ── WordPiece tokenization ───────────────────────────────────────────
406
407    /// Tokenize `text` using greedy longest-match WordPiece.
408    fn wordpiece_tokenize(&self, text: &str) -> Vec<String> {
409        let words: Vec<&str> = text.split_whitespace().collect();
410        let mut all_tokens: Vec<String> = Vec::new();
411
412        for word in words {
413            let chars: Vec<char> = word.chars().collect();
414            let n = chars.len();
415            let mut start = 0;
416
417            while start < n {
418                let mut end = n;
419                let mut found = false;
420
421                while start < end {
422                    let sub: String = chars[start..end].iter().collect();
423                    let candidate = if start == 0 {
424                        sub.clone()
425                    } else {
426                        format!("##{sub}")
427                    };
428
429                    if self.token_to_id.contains_key(&candidate) {
430                        all_tokens.push(candidate);
431                        start = end;
432                        found = true;
433                        break;
434                    }
435                    end -= 1;
436                }
437
438                if !found {
439                    // Single character not in vocab → [UNK].
440                    all_tokens.push(SpecialToken::Unk.as_str().to_string());
441                    start += 1;
442                }
443            }
444        }
445
446        all_tokens
447    }
448
449    // ── Config access ────────────────────────────────────────────────────
450
451    /// Maximum sequence length.
452    pub fn max_length(&self) -> usize {
453        self.config.max_length
454    }
455
456    /// Active tokenization mode.
457    pub fn mode(&self) -> &TokenizerMode {
458        &self.config.mode
459    }
460
461    /// Whether input is lowercased before tokenization.
462    pub fn is_lowercase(&self) -> bool {
463        self.config.lowercase
464    }
465}
466
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    fn bpe_tokenizer() -> Tokenizer {
        Tokenizer::bpe()
    }

    fn wp_tokenizer() -> Tokenizer {
        Tokenizer::wordpiece()
    }

    /// Register every character of `chars` as a single-character token.
    fn add_chars(tok: &mut Tokenizer, chars: &str) {
        for c in chars.chars() {
            tok.add_token(&c.to_string());
        }
    }

    // ── SpecialToken ─────────────────────────────────────────────────────

    #[test]
    fn test_special_token_cls_str() {
        assert_eq!(SpecialToken::Cls.as_str(), "[CLS]");
    }

    #[test]
    fn test_special_token_sep_str() {
        assert_eq!(SpecialToken::Sep.as_str(), "[SEP]");
    }

    #[test]
    fn test_special_token_pad_str() {
        assert_eq!(SpecialToken::Pad.as_str(), "[PAD]");
    }

    #[test]
    fn test_special_token_unk_str() {
        assert_eq!(SpecialToken::Unk.as_str(), "[UNK]");
    }

    #[test]
    fn test_special_token_mask_str() {
        assert_eq!(SpecialToken::Mask.as_str(), "[MASK]");
    }

    #[test]
    fn test_special_token_all_count() {
        assert_eq!(SpecialToken::all().len(), 5);
    }

    // ── Tokenizer construction ───────────────────────────────────────────

    #[test]
    fn test_new_bpe_has_special_tokens() {
        let tok = bpe_tokenizer();
        for st in SpecialToken::all() {
            assert!(tok.contains_token(st.as_str()), "missing {}", st.as_str());
        }
    }

    #[test]
    fn test_new_bpe_vocab_size() {
        let tok = bpe_tokenizer();
        assert_eq!(tok.vocab_size(), 5); // 5 special tokens
    }

    #[test]
    fn test_new_wordpiece_mode() {
        let tok = wp_tokenizer();
        assert_eq!(*tok.mode(), TokenizerMode::WordPiece);
    }

    #[test]
    fn test_bpe_mode() {
        let tok = bpe_tokenizer();
        assert_eq!(*tok.mode(), TokenizerMode::Bpe);
    }

    // ── Vocabulary management ────────────────────────────────────────────

    #[test]
    fn test_add_token_returns_new_id() {
        let mut tok = bpe_tokenizer();
        let id1 = tok.add_token("hello");
        let id2 = tok.add_token("world");
        assert_ne!(id1, id2);
    }

    #[test]
    fn test_add_token_idempotent() {
        let mut tok = bpe_tokenizer();
        let id1 = tok.add_token("hello");
        let id2 = tok.add_token("hello");
        assert_eq!(id1, id2);
        // The vocabulary must not grow on a duplicate insert.
        assert_eq!(tok.vocab_size(), 6); // 5 special + 1
    }

    #[test]
    fn test_remove_token_normal() {
        let mut tok = bpe_tokenizer();
        tok.add_token("temp");
        assert!(tok.contains_token("temp"));
        assert!(tok.remove_token("temp"));
        assert!(!tok.contains_token("temp"));
    }

    #[test]
    fn test_remove_special_token_prevented() {
        let mut tok = bpe_tokenizer();
        assert!(!tok.remove_token("[CLS]"));
        assert!(tok.contains_token("[CLS]"));
    }

    #[test]
    fn test_remove_nonexistent_returns_false() {
        let mut tok = bpe_tokenizer();
        assert!(!tok.remove_token("nonexistent"));
    }

    #[test]
    fn test_vocab_size_grows() {
        let mut tok = bpe_tokenizer();
        assert_eq!(tok.vocab_size(), 5);
        tok.add_token("a");
        tok.add_token("b");
        assert_eq!(tok.vocab_size(), 7);
    }

    // ── Token ↔ ID mapping ──────────────────────────────────────────────

    #[test]
    fn test_token_to_id_roundtrip() {
        let mut tok = bpe_tokenizer();
        let id = tok.add_token("cat");
        assert_eq!(tok.token_to_id("cat"), Some(id));
        assert_eq!(tok.id_to_token(id), Some("cat"));
    }

    #[test]
    fn test_token_to_id_missing() {
        let tok = bpe_tokenizer();
        assert_eq!(tok.token_to_id("missing"), None);
    }

    #[test]
    fn test_id_to_token_missing() {
        let tok = bpe_tokenizer();
        assert_eq!(tok.id_to_token(9999), None);
    }

    #[test]
    fn test_unk_id() {
        let tok = bpe_tokenizer();
        assert_eq!(tok.id_to_token(tok.unk_id()), Some("[UNK]"));
    }

    #[test]
    fn test_cls_id() {
        let tok = bpe_tokenizer();
        assert_eq!(tok.id_to_token(tok.cls_id()), Some("[CLS]"));
    }

    #[test]
    fn test_sep_id() {
        let tok = bpe_tokenizer();
        assert_eq!(tok.id_to_token(tok.sep_id()), Some("[SEP]"));
    }

    #[test]
    fn test_pad_id() {
        let tok = bpe_tokenizer();
        assert_eq!(tok.id_to_token(tok.pad_id()), Some("[PAD]"));
    }

    // ── BPE merge rules ──────────────────────────────────────────────────

    #[test]
    fn test_add_merge_rule_creates_merged_token() {
        let mut tok = bpe_tokenizer();
        add_chars(&mut tok, "he");
        tok.add_merge_rule("h", "e");
        assert!(tok.contains_token("he"));
        assert_eq!(tok.merge_rule_count(), 1);
    }

    #[test]
    fn test_bpe_merge_rules_applied_in_order() {
        let mut tok = bpe_tokenizer();
        add_chars(&mut tok, "helo");
        tok.add_merge_rule("h", "e"); // he
        tok.add_merge_rule("l", "o"); // lo
        tok.add_merge_rule("he", "l"); // hel
        tok.add_merge_rule("hel", "lo"); // hello

        let result = tok.encode("hello");
        // The whole word must collapse into exactly one merged token.
        assert_eq!(result.tokens, vec!["hello"]);
    }

    #[test]
    fn test_bpe_merge_overlapping_pairs_left_to_right() {
        let mut tok = bpe_tokenizer();
        add_chars(&mut tok, "a");
        tok.add_merge_rule("a", "a");
        let result = tok.encode("aaa");
        assert_eq!(result.tokens, vec!["aa", "a"]);
    }

    // ── BPE encoding ─────────────────────────────────────────────────────

    #[test]
    fn test_bpe_encode_unknown_chars() {
        let tok = bpe_tokenizer();
        // No character tokens are registered, so every character maps to [UNK].
        // Check the exact length too: `all` alone would pass on empty output.
        let result = tok.encode("xyz");
        assert_eq!(result.ids, vec![tok.unk_id(); 3]);
    }

    #[test]
    fn test_bpe_encode_single_char_tokens() {
        let mut tok = bpe_tokenizer();
        add_chars(&mut tok, "ab");
        let result = tok.encode("ab");
        assert_eq!(result.tokens, vec!["a", "b"]);
    }

    #[test]
    fn test_bpe_encode_multiple_words() {
        let mut tok = bpe_tokenizer();
        add_chars(&mut tok, "hi");
        let result = tok.encode("hi hi");
        assert_eq!(result.tokens, vec!["h", "i", "h", "i"]);
    }

    // ── WordPiece encoding ───────────────────────────────────────────────

    #[test]
    fn test_wordpiece_full_word_match() {
        let mut tok = wp_tokenizer();
        tok.add_token("hello");
        let result = tok.encode("hello");
        assert_eq!(result.tokens, vec!["hello"]);
    }

    #[test]
    fn test_wordpiece_continuation_tokens() {
        let mut tok = wp_tokenizer();
        tok.add_token("un");
        tok.add_token("##believ");
        tok.add_token("##able");
        let result = tok.encode("unbelievable");
        assert_eq!(result.tokens, vec!["un", "##believ", "##able"]);
    }

    #[test]
    fn test_wordpiece_unknown_fallback() {
        let tok = wp_tokenizer();
        let result = tok.encode("xyz");
        // Each unknown character becomes its own [UNK].
        assert_eq!(result.ids, vec![tok.unk_id(); 3]);
    }

    #[test]
    fn test_wordpiece_multiple_words() {
        let mut tok = wp_tokenizer();
        tok.add_token("hello");
        tok.add_token("world");
        let result = tok.encode("hello world");
        assert_eq!(result.tokens, vec!["hello", "world"]);
    }

    // ── Decoding ─────────────────────────────────────────────────────────

    #[test]
    fn test_bpe_decode_simple() {
        let mut tok = bpe_tokenizer();
        let id_a = tok.add_token("hello");
        let id_b = tok.add_token("world");
        assert_eq!(tok.decode(&[id_a, id_b]), "hello world");
    }

    #[test]
    fn test_wordpiece_decode_merges_continuations() {
        let mut tok = wp_tokenizer();
        let id_un = tok.add_token("un");
        let id_do = tok.add_token("##do");
        assert_eq!(tok.decode(&[id_un, id_do]), "undo");
    }

    #[test]
    fn test_decode_skips_special_tokens() {
        let tok = bpe_tokenizer();
        assert_eq!(tok.decode(&[tok.cls_id(), tok.sep_id()]), "");
    }

    #[test]
    fn test_decode_skips_unknown_ids() {
        let mut tok = bpe_tokenizer();
        let id = tok.add_token("ok");
        assert_eq!(tok.decode(&[9999, id, 12345]), "ok");
    }

    #[test]
    fn test_decode_empty() {
        let tok = bpe_tokenizer();
        assert_eq!(tok.decode(&[]), "");
    }

    // ── Max-length truncation ────────────────────────────────────────────

    #[test]
    fn test_truncation_at_max_length() {
        let mut tok = Tokenizer::new(TokenizerConfig {
            mode: TokenizerMode::Bpe,
            max_length: 3,
            lowercase: true,
        });
        add_chars(&mut tok, "abcd");
        let result = tok.encode("a b c d");
        assert_eq!(result.tokens, vec!["a", "b", "c"]);
        assert_eq!(result.ids.len(), 3);
    }

    #[test]
    fn test_truncation_shorter_text_unaffected() {
        let mut tok = Tokenizer::new(TokenizerConfig {
            mode: TokenizerMode::Bpe,
            max_length: 100,
            lowercase: true,
        });
        tok.add_token("x");
        let result = tok.encode("x");
        assert_eq!(result.tokens.len(), 1);
    }

    // ── Batch tokenization ───────────────────────────────────────────────

    #[test]
    fn test_encode_batch_count() {
        let mut tok = bpe_tokenizer();
        tok.add_token("a");
        let results = tok.encode_batch(&["a", "a a", "a a a"]);
        let lengths: Vec<usize> = results.iter().map(|r| r.tokens.len()).collect();
        assert_eq!(lengths, vec![1, 2, 3]);
    }

    #[test]
    fn test_encode_batch_independent() {
        let mut tok = bpe_tokenizer();
        add_chars(&mut tok, "xy");
        let results = tok.encode_batch(&["x", "y"]);
        assert_ne!(results[0].ids, results[1].ids);
    }

    #[test]
    fn test_encode_batch_empty() {
        let tok = bpe_tokenizer();
        assert!(tok.encode_batch(&[]).is_empty());
    }

    // ── Lowercase handling ───────────────────────────────────────────────

    #[test]
    fn test_lowercase_enabled() {
        let mut tok = Tokenizer::new(TokenizerConfig {
            mode: TokenizerMode::Bpe,
            max_length: 512,
            lowercase: true,
        });
        // Register only lower-case characters. The upper-case input can match
        // them only if it was lowered first, and would otherwise become all
        // [UNK], so comparing IDs alone would prove nothing.
        add_chars(&mut tok, "helo");
        let r1 = tok.encode("HELLO");
        let r2 = tok.encode("hello");
        assert_eq!(r1.tokens, vec!["h", "e", "l", "l", "o"]);
        assert_eq!(r1.ids, r2.ids);
        assert!(!r1.ids.contains(&tok.unk_id()));
    }

    #[test]
    fn test_lowercase_disabled() {
        let mut tok = Tokenizer::new(TokenizerConfig {
            mode: TokenizerMode::Bpe,
            max_length: 512,
            lowercase: false,
        });
        tok.add_token("A");
        tok.add_token("a");
        let r1 = tok.encode("A");
        let r2 = tok.encode("a");
        assert_ne!(r1.ids, r2.ids);
    }

    // ── Config accessors ─────────────────────────────────────────────────

    #[test]
    fn test_max_length_accessor() {
        assert_eq!(bpe_tokenizer().max_length(), 512);
    }

    #[test]
    fn test_is_lowercase_accessor() {
        assert!(bpe_tokenizer().is_lowercase());
    }

    // ── Edge cases ───────────────────────────────────────────────────────

    #[test]
    fn test_encode_empty_string() {
        let tok = bpe_tokenizer();
        let result = tok.encode("");
        assert!(result.tokens.is_empty());
        assert!(result.ids.is_empty());
    }

    #[test]
    fn test_encode_whitespace_only() {
        let tok = bpe_tokenizer();
        let result = tok.encode("   ");
        assert!(result.tokens.is_empty());
        assert!(result.ids.is_empty());
    }

    #[test]
    fn test_wordpiece_greedy_longest_match() {
        let mut tok = wp_tokenizer();
        tok.add_token("play");
        tok.add_token("##ing");
        tok.add_token("##i");
        tok.add_token("##n");
        tok.add_token("##g");
        let result = tok.encode("playing");
        // "##ing" must win over "##i" + "##n" + "##g".
        assert_eq!(result.tokens, vec!["play", "##ing"]);
    }

    #[test]
    fn test_merge_rule_struct_fields() {
        let rule = MergeRule {
            left: "a".to_string(),
            right: "b".to_string(),
            merged: "ab".to_string(),
        };
        assert_eq!(rule.left, "a");
        assert_eq!(rule.right, "b");
        assert_eq!(rule.merged, "ab");
    }

    #[test]
    fn test_encode_result_tokens_and_ids_same_length() {
        let mut tok = bpe_tokenizer();
        add_chars(&mut tok, "tes");
        let result = tok.encode("test");
        assert_eq!(result.tokens.len(), 4);
        assert_eq!(result.tokens.len(), result.ids.len());
    }

    #[test]
    fn test_tokenizer_config_default() {
        let cfg = TokenizerConfig::default();
        assert_eq!(cfg.mode, TokenizerMode::Bpe);
        assert_eq!(cfg.max_length, 512);
        assert!(cfg.lowercase);
    }
}