Skip to main content

entrenar/tokenizer/
config.rs

//! Tokenizer configuration types.
2
3use serde::{Deserialize, Serialize};
4
/// Vocabulary size used when a configuration does not override it
/// (the size used by the LLaMA/Mistral family).
const DEFAULT_VOCAB_SIZE: usize = 32_000;
7
8/// Special tokens
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct SpecialTokens {
11    /// Unknown token
12    pub unk: String,
13    /// Beginning of sequence
14    pub bos: String,
15    /// End of sequence
16    pub eos: String,
17    /// Padding token
18    pub pad: String,
19    /// Mask token (for MLM)
20    pub mask: String,
21}
22
23impl Default for SpecialTokens {
24    fn default() -> Self {
25        Self {
26            unk: "<unk>".to_string(),
27            bos: "<s>".to_string(),
28            eos: "</s>".to_string(),
29            pad: "<pad>".to_string(),
30            mask: "<mask>".to_string(),
31        }
32    }
33}
34
35/// Tokenizer type
36#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
37pub enum TokenizerType {
38    /// Byte Pair Encoding
39    BPE,
40    /// WordPiece (BERT-style)
41    WordPiece,
42    /// Character-level
43    Char,
44}
45
46/// Unicode normalization mode applied before byte-level encoding.
47///
48/// `tokenizer-bpe-v1.yaml` INV-TOK-003 mandates NFC for MODEL-2; without it,
49/// composed and decomposed variants of the same grapheme (e.g. `café`) hash
50/// to different byte sequences and the tokenizer drifts between training-
51/// time corpus preparation and inference-time input.
52#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
53pub enum Normalization {
54    /// No normalization (backward-compat default).
55    #[default]
56    None,
57    /// Unicode NFC (canonical composition).
58    NFC,
59}
60
61/// Tokenizer configuration
62#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct TokenizerConfig {
64    /// Target vocabulary size
65    pub vocab_size: usize,
66    /// Minimum token frequency for training
67    pub min_frequency: usize,
68    /// Special tokens
69    pub special_tokens: SpecialTokens,
70    /// Whether to lowercase input
71    pub lowercase: bool,
72    /// Tokenizer type
73    pub tokenizer_type: TokenizerType,
74    /// Unicode normalization mode
75    #[serde(default)]
76    pub normalization: Normalization,
77}
78
79impl Default for TokenizerConfig {
80    fn default() -> Self {
81        Self {
82            vocab_size: DEFAULT_VOCAB_SIZE,
83            min_frequency: 2,
84            special_tokens: SpecialTokens::default(),
85            lowercase: false,
86            tokenizer_type: TokenizerType::BPE,
87            normalization: Normalization::default(),
88        }
89    }
90}
91
92impl TokenizerConfig {
93    /// Create a BPE tokenizer config
94    pub fn bpe() -> Self {
95        Self { tokenizer_type: TokenizerType::BPE, ..Default::default() }
96    }
97
98    /// Create a WordPiece tokenizer config
99    pub fn wordpiece() -> Self {
100        Self { tokenizer_type: TokenizerType::WordPiece, ..Default::default() }
101    }
102
103    /// Create a character-level tokenizer config
104    pub fn char() -> Self {
105        Self { tokenizer_type: TokenizerType::Char, vocab_size: 256, ..Default::default() }
106    }
107
108    /// Set vocabulary size
109    pub fn with_vocab_size(mut self, size: usize) -> Self {
110        self.vocab_size = size;
111        self
112    }
113
114    /// Set minimum frequency
115    pub fn with_min_frequency(mut self, freq: usize) -> Self {
116        self.min_frequency = freq;
117        self
118    }
119
120    /// Enable lowercase preprocessing
121    pub fn with_lowercase(mut self, lowercase: bool) -> Self {
122        self.lowercase = lowercase;
123        self
124    }
125
126    /// Set the Unicode normalization mode.
127    pub fn with_normalization(mut self, normalization: Normalization) -> Self {
128        self.normalization = normalization;
129        self
130    }
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// Every field of the default config is pinned, not just the headline ones.
    #[test]
    fn test_tokenizer_config_default() {
        let config = TokenizerConfig::default();
        assert_eq!(config.vocab_size, 32000);
        assert_eq!(config.min_frequency, 2);
        assert!(!config.lowercase);
        assert_eq!(config.tokenizer_type, TokenizerType::BPE);
        assert_eq!(config.normalization, Normalization::None);
    }

    #[test]
    fn test_tokenizer_config_bpe() {
        let config = TokenizerConfig::bpe().with_vocab_size(1000);
        assert_eq!(config.vocab_size, 1000);
        assert_eq!(config.tokenizer_type, TokenizerType::BPE);
    }

    #[test]
    fn test_tokenizer_config_wordpiece() {
        let config = TokenizerConfig::wordpiece();
        assert_eq!(config.tokenizer_type, TokenizerType::WordPiece);
        // Only the tokenizer type differs from the default config.
        assert_eq!(config.vocab_size, 32000);
    }

    #[test]
    fn test_tokenizer_config_char() {
        let config = TokenizerConfig::char();
        assert_eq!(config.tokenizer_type, TokenizerType::Char);
        assert_eq!(config.vocab_size, 256);
        // Fields not overridden by `char()` keep their defaults.
        assert_eq!(config.min_frequency, 2);
    }

    #[test]
    fn test_with_min_frequency() {
        let config = TokenizerConfig::default().with_min_frequency(5);
        assert_eq!(config.min_frequency, 5);
    }

    #[test]
    fn test_with_lowercase() {
        let config = TokenizerConfig::default().with_lowercase(true);
        assert!(config.lowercase);
        let config = config.with_lowercase(false);
        assert!(!config.lowercase);
    }

    #[test]
    fn test_with_normalization() {
        let config = TokenizerConfig::default().with_normalization(Normalization::NFC);
        assert_eq!(config.normalization, Normalization::NFC);
    }

    /// Chained builders must not clobber each other's fields.
    #[test]
    fn test_builders_compose() {
        let config = TokenizerConfig::wordpiece()
            .with_vocab_size(8000)
            .with_min_frequency(3)
            .with_lowercase(true)
            .with_normalization(Normalization::NFC);
        assert_eq!(config.tokenizer_type, TokenizerType::WordPiece);
        assert_eq!(config.vocab_size, 8000);
        assert_eq!(config.min_frequency, 3);
        assert!(config.lowercase);
        assert_eq!(config.normalization, Normalization::NFC);
    }

    #[test]
    fn test_special_tokens_default() {
        let special = SpecialTokens::default();
        assert_eq!(special.unk, "<unk>");
        assert_eq!(special.bos, "<s>");
        assert_eq!(special.eos, "</s>");
        assert_eq!(special.pad, "<pad>");
        assert_eq!(special.mask, "<mask>");
    }
}