entrenar/tokenizer/
char.rs1use std::collections::HashMap;
4
5use serde::{Deserialize, Serialize};
6
7use super::config::TokenizerConfig;
8use super::error::{Result, TokenizerError};
9use super::traits::{TokenId, Tokenizer};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct CharTokenizer {
14 config: TokenizerConfig,
15 vocab: HashMap<char, TokenId>,
16 id_to_char: HashMap<TokenId, char>,
17 trained: bool,
18}
19
20impl CharTokenizer {
21 pub fn new(config: TokenizerConfig) -> Self {
23 Self { config, vocab: HashMap::new(), id_to_char: HashMap::new(), trained: false }
24 }
25}
26
27impl Tokenizer for CharTokenizer {
28 fn train(&mut self, corpus: &[&str]) -> Result<()> {
29 let mut id: TokenId = 0;
30
31 let mut char_counts: HashMap<char, usize> = HashMap::new();
33 for text in corpus {
34 let processed =
35 if self.config.lowercase { text.to_lowercase() } else { text.to_string() };
36 for c in processed.chars() {
37 *char_counts.entry(c).or_insert(0) += 1;
38 }
39 }
40
41 let mut chars: Vec<_> = char_counts.into_iter().collect();
43 chars.sort_by(|a, b| b.1.cmp(&a.1));
44
45 for (c, count) in chars.into_iter().take(self.config.vocab_size) {
46 if count >= self.config.min_frequency {
47 self.vocab.insert(c, id);
48 self.id_to_char.insert(id, c);
49 id += 1;
50 }
51 }
52
53 self.trained = true;
54 Ok(())
55 }
56
57 fn encode(&self, text: &str) -> Result<Vec<TokenId>> {
58 if !self.trained {
59 return Err(TokenizerError::NotTrained);
60 }
61
62 let processed = if self.config.lowercase { text.to_lowercase() } else { text.to_string() };
63
64 let mut ids = Vec::new();
65 for c in processed.chars() {
66 if let Some(&id) = self.vocab.get(&c) {
67 ids.push(id);
68 }
69 }
71
72 Ok(ids)
73 }
74
75 fn decode(&self, ids: &[TokenId]) -> Result<String> {
76 if !self.trained {
77 return Err(TokenizerError::NotTrained);
78 }
79
80 let mut result = String::new();
81 for &id in ids {
82 if let Some(&c) = self.id_to_char.get(&id) {
83 result.push(c);
84 }
85 }
86
87 Ok(result)
88 }
89
90 fn vocab_size(&self) -> usize {
91 self.vocab.len()
92 }
93
94 fn is_trained(&self) -> bool {
95 self.trained
96 }
97
98 fn id_to_token(&self, _id: TokenId) -> Option<&str> {
99 None
101 }
102
103 fn token_to_id(&self, token: &str) -> Option<TokenId> {
104 if token.len() == 1 {
105 self.vocab
106 .get(&token.chars().next().expect("single-char token must have a char"))
107 .copied()
108 } else {
109 None
110 }
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed tokenizer reports itself as untrained.
    #[test]
    fn test_char_new() {
        let config = TokenizerConfig::char();
        let tokenizer = CharTokenizer::new(config);
        assert!(!tokenizer.is_trained());
    }

    /// Encoding or decoding before training yields `NotTrained`.
    #[test]
    fn test_char_not_trained() {
        let tokenizer = CharTokenizer::new(TokenizerConfig::char());
        assert!(matches!(tokenizer.encode("abc"), Err(TokenizerError::NotTrained)));
        assert!(matches!(tokenizer.decode(&[]), Err(TokenizerError::NotTrained)));
    }

    /// Training collects one vocabulary entry per distinct character.
    #[test]
    fn test_char_train() {
        let config = TokenizerConfig::char().with_min_frequency(1);
        let mut tokenizer = CharTokenizer::new(config);

        let corpus = ["hello", "world"];
        tokenizer.train(&corpus).expect("training should succeed");

        assert!(tokenizer.is_trained());
        // Distinct characters: h, e, l, o, w, r, d.
        assert_eq!(tokenizer.vocab_size(), 7);
    }

    /// Text made only of known characters survives an encode/decode round trip.
    #[test]
    fn test_char_encode_decode() {
        let config = TokenizerConfig::char().with_min_frequency(1);
        let mut tokenizer = CharTokenizer::new(config);

        let corpus = ["hello"];
        tokenizer.train(&corpus).expect("training should succeed");

        let text = "hello";
        let encoded = tokenizer.encode(text).expect("encoding should succeed");
        let decoded = tokenizer.decode(&encoded).expect("decoding should succeed");

        assert_eq!(decoded, text);
    }

    /// Characters missing from the vocabulary are dropped during encoding.
    #[test]
    fn test_char_unknown_chars() {
        let config = TokenizerConfig::char().with_min_frequency(1);
        let mut tokenizer = CharTokenizer::new(config);

        let corpus = ["abc"];
        tokenizer.train(&corpus).expect("training should succeed");

        let encoded = tokenizer.encode("axbc").expect("encoding should succeed");
        let decoded = tokenizer.decode(&encoded).expect("decoding should succeed");

        assert_eq!(decoded, "abc");
    }

    /// With lowercasing enabled, input case is folded before lookup.
    #[test]
    fn test_char_lowercase() {
        let config = TokenizerConfig::char().with_min_frequency(1).with_lowercase(true);
        let mut tokenizer = CharTokenizer::new(config);

        let corpus = ["Hello"];
        tokenizer.train(&corpus).expect("training should succeed");

        let encoded = tokenizer.encode("HELLO").expect("encoding should succeed");
        let decoded = tokenizer.decode(&encoded).expect("decoding should succeed");

        assert_eq!(decoded, "hello");
    }

    /// `token_to_id` agrees with `encode` for single characters and rejects
    /// empty, multi-character, and unknown tokens.
    #[test]
    fn test_char_token_to_id() {
        let config = TokenizerConfig::char().with_min_frequency(1);
        let mut tokenizer = CharTokenizer::new(config);

        let corpus = ["abc"];
        tokenizer.train(&corpus).expect("training should succeed");

        let encoded = tokenizer.encode("a").expect("encoding should succeed");
        assert_eq!(encoded.len(), 1);
        assert!(tokenizer.token_to_id("a") == Some(encoded[0]));

        assert!(tokenizer.token_to_id("").is_none());
        assert!(tokenizer.token_to_id("ab").is_none());
        assert!(tokenizer.token_to_id("z").is_none());
    }
}
182
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(50))]

        // Training on a text and then encoding/decoding that same text must
        // reproduce it exactly, because every character is in the vocabulary.
        #[test]
        fn prop_char_roundtrip(text in "[a-z]{1,20}") {
            let config = TokenizerConfig::char().with_min_frequency(1);
            let mut tokenizer = CharTokenizer::new(config);
            tokenizer.train(&[&text]).expect("training should succeed");

            let encoded = tokenizer.encode(&text).expect("encoding should succeed");
            let decoded = tokenizer.decode(&encoded).expect("decoding should succeed");

            prop_assert_eq!(decoded, text);
        }

        // With a vocabulary limit of 256 and a minimum frequency of 1, every
        // distinct character of a short lowercase text gets an entry.
        #[test]
        fn prop_char_vocab_size_matches_unique_chars(text in "[a-z]{5,30}") {
            let config = TokenizerConfig::char()
                .with_min_frequency(1)
                .with_vocab_size(256);
            let mut tokenizer = CharTokenizer::new(config);
            tokenizer.train(&[&text]).expect("training should succeed");

            let unique_chars: std::collections::HashSet<char> = text.chars().collect();
            prop_assert_eq!(tokenizer.vocab_size(), unique_chars.len());
        }
    }
}