1use std::collections::HashMap;
13
/// Character-level tokenizer: every distinct character of the training text
/// becomes one token, with ids assigned in ascending `char` order.
#[derive(Debug, Clone)]
pub struct CharTokenizer {
    /// `(character, id)` pairs; sorted by character when built by `from_text`.
    pub char_to_idx: Vec<(char, usize)>,
    /// Inverse table: `idx_to_char[id]` is the character for token `id`.
    pub idx_to_char: Vec<char>,
    /// Number of distinct characters (equals `idx_to_char.len()` after `from_text`).
    pub vocab_size: usize,
}

impl CharTokenizer {
    /// Builds a vocabulary from every distinct character in `text`.
    ///
    /// An empty `text` yields an empty vocabulary.
    pub fn from_text(text: &str) -> Self {
        // A BTreeSet both deduplicates and yields characters in sorted order,
        // so no separate sort pass is needed.
        let idx_to_char: Vec<char> = text
            .chars()
            .collect::<std::collections::BTreeSet<_>>()
            .into_iter()
            .collect();
        let char_to_idx: Vec<(char, usize)> = idx_to_char
            .iter()
            .enumerate()
            .map(|(i, &c)| (c, i))
            .collect();
        let vocab_size = idx_to_char.len();
        Self {
            char_to_idx,
            idx_to_char,
            vocab_size,
        }
    }

    /// Encodes `text` into token ids, one id per character.
    ///
    /// Characters absent from the vocabulary are mapped to id `0`. Note that
    /// `0` is itself a real character, so such input does not round-trip
    /// through [`CharTokenizer::decode`].
    pub fn encode(&self, text: &str) -> Vec<usize> {
        text.chars()
            .map(|c| {
                // Linear scan: `char_to_idx` is a public field and is not
                // guaranteed to be sorted when constructed by callers.
                self.char_to_idx
                    .iter()
                    .find(|&&(ch, _)| ch == c)
                    .map_or(0, |&(_, idx)| idx)
            })
            .collect()
    }

    /// Decodes token ids back into a string; out-of-range ids become `'?'`.
    pub fn decode(&self, tokens: &[usize]) -> String {
        tokens
            .iter()
            .map(|&idx| self.idx_to_char.get(idx).copied().unwrap_or('?'))
            .collect()
    }
}
60
/// Byte-pair-encoding tokenizer over characters.
///
/// Training starts from the distinct characters of the text and repeatedly
/// merges the most frequent adjacent token pair (counted within each line)
/// until the target vocabulary size is reached or no pair occurs twice.
#[derive(Debug, Clone)]
pub struct BpeTokenizer {
    /// Learned merges, in the order `encode` must replay them.
    merges: Vec<(String, String)>,
    /// Token string -> id.
    token_to_idx: HashMap<String, usize>,
    /// Id -> token string.
    idx_to_token: Vec<String>,
    /// Number of entries in the vocabulary (`idx_to_token.len()`).
    pub vocab_size: usize,
}

impl BpeTokenizer {
    /// Trains a BPE vocabulary of at most `target_vocab` tokens on `text`.
    ///
    /// Base tokens are the distinct characters of `text` (sorted). Merges are
    /// learned per line, so no merge ever spans a `\n`. Training stops early
    /// once the most frequent pair occurs fewer than two times. Among pairs
    /// with equal counts the lexicographically smallest is chosen, so the
    /// result is deterministic.
    pub fn train(text: &str, target_vocab: usize) -> Self {
        let base_chars: Vec<char> = text
            .chars()
            .collect::<std::collections::BTreeSet<_>>()
            .into_iter()
            .collect();
        let base_vocab_size = base_chars.len();

        let mut idx_to_token: Vec<String> = base_chars.iter().map(char::to_string).collect();
        let mut token_to_idx: HashMap<String, usize> = idx_to_token
            .iter()
            .enumerate()
            .map(|(i, s)| (s.clone(), i))
            .collect();

        // One token sequence per line; merges are learned within lines only.
        let mut corpus_tokens: Vec<Vec<String>> = text
            .lines()
            .map(|line| line.chars().map(String::from).collect())
            .collect();

        let num_merges = target_vocab.saturating_sub(base_vocab_size);
        // Not preallocated: `target_vocab` may be far larger than the number
        // of merges the corpus can actually support.
        let mut merges: Vec<(String, String)> = Vec::new();

        for _ in 0..num_merges {
            let best_pair = match Self::most_frequent_pair(&corpus_tokens) {
                Some(pair) => pair,
                None => break,
            };

            let merged = format!("{}{}", best_pair.0, best_pair.1);
            token_to_idx.insert(merged.clone(), idx_to_token.len());
            idx_to_token.push(merged.clone());

            for seq in &mut corpus_tokens {
                Self::apply_merge(seq, &best_pair.0, &best_pair.1, &merged);
            }
            merges.push(best_pair);
        }

        let vocab_size = idx_to_token.len();
        println!(
            "BPE: {} merges, vocab = {} (base {} + {} merges)",
            merges.len(),
            vocab_size,
            base_vocab_size,
            merges.len()
        );

        Self {
            merges,
            token_to_idx,
            idx_to_token,
            vocab_size,
        }
    }

    /// Returns the most frequent adjacent pair in `corpus`, if it occurs at
    /// least twice.
    ///
    /// Ties are broken toward the lexicographically smallest pair: `HashMap`
    /// iteration order is randomized per map, so picking "any" maximum would
    /// make training non-reproducible.
    fn most_frequent_pair(corpus: &[Vec<String>]) -> Option<(String, String)> {
        let mut pair_counts: HashMap<(&str, &str), usize> = HashMap::new();
        for seq in corpus {
            for window in seq.windows(2) {
                *pair_counts
                    .entry((window[0].as_str(), window[1].as_str()))
                    .or_insert(0) += 1;
            }
        }
        pair_counts
            .into_iter()
            .filter(|&(_, count)| count >= 2)
            .max_by(|(pair_x, count_x), (pair_y, count_y)| {
                count_x.cmp(count_y).then_with(|| pair_y.cmp(pair_x))
            })
            .map(|((a, b), _)| (a.to_owned(), b.to_owned()))
    }

    /// Replaces every non-overlapping adjacent `(left, right)` in `tokens`
    /// with `merged`, scanning left to right, and returns the number of
    /// replacements. Rebuilds the vector in one pass instead of calling
    /// `Vec::remove` per hit (which is O(n) each).
    fn apply_merge(tokens: &mut Vec<String>, left: &str, right: &str, merged: &str) -> usize {
        let mut out = Vec::with_capacity(tokens.len());
        let mut replaced = 0;
        let mut iter = std::mem::take(tokens).into_iter().peekable();
        while let Some(tok) = iter.next() {
            if tok == left && matches!(iter.peek(), Some(next) if next == right) {
                iter.next();
                out.push(merged.to_owned());
                replaced += 1;
            } else {
                out.push(tok);
            }
        }
        *tokens = out;
        replaced
    }

    /// Encodes `text` by splitting it into characters and replaying every
    /// learned merge in training order.
    ///
    /// Tokens absent from the vocabulary (characters unseen during training)
    /// are mapped to id `0`, which is lossy.
    pub fn encode(&self, text: &str) -> Vec<usize> {
        let mut tokens: Vec<String> = text.chars().map(String::from).collect();
        for (a, b) in &self.merges {
            let merged = format!("{}{}", a, b);
            Self::apply_merge(&mut tokens, a, b, &merged);
        }
        tokens
            .iter()
            .map(|t| self.token_to_idx.get(t).copied().unwrap_or(0))
            .collect()
    }

    /// Decodes token ids by concatenating their strings; unknown ids become `"?"`.
    pub fn decode(&self, tokens: &[usize]) -> String {
        tokens
            .iter()
            .map(|&idx| self.idx_to_token.get(idx).map(|s| s.as_str()).unwrap_or("?"))
            .collect()
    }
}
205
/// Tokenizer that merges character pairs by pointwise mutual information.
///
/// Each training round scores every adjacent pair occurring at least twice by
/// `ln(p(ab) / (p(a) * p(b)))` and merges, strongest first, every pair whose
/// score exceeds `ln(φ)` (φ = golden ratio), until the target vocabulary size
/// is reached.
#[derive(Debug, Clone)]
pub struct MiTokenizer {
    /// Learned merges, in the order `encode` must replay them.
    merges: Vec<(String, String)>,
    /// Token string -> id.
    token_to_idx: HashMap<String, usize>,
    /// Id -> token string.
    idx_to_token: Vec<String>,
    /// Number of entries in the vocabulary (`idx_to_token.len()`).
    pub vocab_size: usize,
}

impl MiTokenizer {
    /// Maximum number of scoring/merging rounds performed by `train`.
    const MAX_ROUNDS: usize = 8;

    /// Trains an MI-merge vocabulary of at most `target_vocab` tokens on `text`.
    ///
    /// Base tokens are the distinct characters of `text` (sorted). Statistics
    /// are gathered per line, so no merge spans a `\n`. Training stops after
    /// `MAX_ROUNDS` rounds, when the vocabulary is full, or when no pair clears
    /// the threshold. Candidate order is deterministic (ties in MI are broken
    /// by pair order).
    pub fn train(text: &str, target_vocab: usize) -> Self {
        let phi_threshold = 1.618_033_988_f64.ln();

        let base_chars: Vec<char> = text
            .chars()
            .collect::<std::collections::BTreeSet<_>>()
            .into_iter()
            .collect();

        let mut idx_to_token: Vec<String> = base_chars.iter().map(char::to_string).collect();
        let mut token_to_idx: HashMap<String, usize> = idx_to_token
            .iter()
            .enumerate()
            .map(|(i, s)| (s.clone(), i))
            .collect();

        // One token sequence per line.
        let mut corpus: Vec<Vec<String>> = text
            .lines()
            .map(|line| line.chars().map(String::from).collect())
            .collect();

        let mut all_merges: Vec<(String, String)> = Vec::new();

        for round in 0..Self::MAX_ROUNDS {
            if idx_to_token.len() >= target_vocab {
                break;
            }
            let candidates = Self::rank_pairs(&corpus, phi_threshold);
            if candidates.is_empty() {
                break;
            }

            let mut applied = 0;
            for (a, b) in candidates {
                if idx_to_token.len() >= target_vocab {
                    break;
                }
                let merged = format!("{}{}", a, b);
                let hits: usize = corpus
                    .iter_mut()
                    .map(|seq| Self::apply_merge(seq, &a, &b, &merged))
                    .sum();
                // All candidates of a round are ranked on the same (pre-round)
                // statistics, so an earlier merge may have consumed every
                // occurrence of this pair. Don't spend a vocab slot on a token
                // that no longer occurs.
                if hits == 0 {
                    continue;
                }
                token_to_idx.insert(merged.clone(), idx_to_token.len());
                idx_to_token.push(merged);
                all_merges.push((a, b));
                applied += 1;
            }

            println!(
                "MI round {}: {} merges (MI > ln(φ)={:.3}), vocab = {}",
                round,
                applied,
                phi_threshold,
                idx_to_token.len()
            );
        }

        let vocab_size = idx_to_token.len();
        println!(
            "MI tokenizer: {} total merges, vocab = {}",
            all_merges.len(),
            vocab_size
        );

        Self {
            merges: all_merges,
            token_to_idx,
            idx_to_token,
            vocab_size,
        }
    }

    /// Ranks adjacent pairs in `corpus` that occur at least twice and whose
    /// PMI exceeds `threshold`, strongest first.
    ///
    /// Probabilities are counts divided by the total number of tokens. Equal
    /// scores are ordered by pair (lexicographically) so the result does not
    /// depend on `HashMap` iteration order.
    fn rank_pairs(corpus: &[Vec<String>], threshold: f64) -> Vec<(String, String)> {
        let mut unigram: HashMap<&str, usize> = HashMap::new();
        let mut bigram: HashMap<(&str, &str), usize> = HashMap::new();
        let mut total: usize = 0;
        for seq in corpus {
            total += seq.len();
            for tok in seq {
                *unigram.entry(tok.as_str()).or_default() += 1;
            }
            for w in seq.windows(2) {
                *bigram.entry((w[0].as_str(), w[1].as_str())).or_default() += 1;
            }
        }
        if total < 2 {
            return Vec::new();
        }
        let total_f = total as f64;

        let mut scored: Vec<((&str, &str), f64)> = bigram
            .into_iter()
            .filter_map(|(pair, count)| {
                if count < 2 {
                    return None;
                }
                let p_ab = count as f64 / total_f;
                let p_a = unigram.get(pair.0).copied().unwrap_or(1) as f64 / total_f;
                let p_b = unigram.get(pair.1).copied().unwrap_or(1) as f64 / total_f;
                let mi = (p_ab / (p_a * p_b)).ln();
                if mi > threshold {
                    Some((pair, mi))
                } else {
                    None
                }
            })
            .collect();

        scored.sort_by(|x, y| {
            y.1.partial_cmp(&x.1)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| x.0.cmp(&y.0))
        });
        scored
            .into_iter()
            .map(|((a, b), _)| (a.to_owned(), b.to_owned()))
            .collect()
    }

    /// Replaces every non-overlapping adjacent `(left, right)` in `tokens`
    /// with `merged`, scanning left to right, and returns the number of
    /// replacements. Rebuilds the vector in one pass instead of calling
    /// `Vec::remove` per hit (which is O(n) each).
    fn apply_merge(tokens: &mut Vec<String>, left: &str, right: &str, merged: &str) -> usize {
        let mut out = Vec::with_capacity(tokens.len());
        let mut replaced = 0;
        let mut iter = std::mem::take(tokens).into_iter().peekable();
        while let Some(tok) = iter.next() {
            if tok == left && matches!(iter.peek(), Some(next) if next == right) {
                iter.next();
                out.push(merged.to_owned());
                replaced += 1;
            } else {
                out.push(tok);
            }
        }
        *tokens = out;
        replaced
    }

    /// Encodes `text` by splitting it into characters and replaying every
    /// learned merge in training order.
    ///
    /// Tokens absent from the vocabulary (characters unseen during training)
    /// are mapped to id `0`, which is lossy.
    pub fn encode(&self, text: &str) -> Vec<usize> {
        let mut tokens: Vec<String> = text.chars().map(String::from).collect();
        for (a, b) in &self.merges {
            let merged = format!("{}{}", a, b);
            Self::apply_merge(&mut tokens, a, b, &merged);
        }
        tokens
            .iter()
            .map(|t| self.token_to_idx.get(t).copied().unwrap_or(0))
            .collect()
    }

    /// Decodes token ids by concatenating their strings; unknown ids become `"?"`.
    pub fn decode(&self, tokens: &[usize]) -> String {
        tokens
            .iter()
            .map(|&idx| self.idx_to_token.get(idx).map(|s| s.as_str()).unwrap_or("?"))
            .collect()
    }
}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// Number of distinct characters in `text`: the base vocabulary size
    /// every tokenizer in this module starts from before any merges.
    fn base_vocab(text: &str) -> usize {
        text.chars().collect::<std::collections::HashSet<_>>().len()
    }

    #[test]
    fn char_roundtrip() {
        let text = "hello world";
        let tok = CharTokenizer::from_text(text);
        let encoded = tok.encode(text);
        let decoded = tok.decode(&encoded);
        assert_eq!(decoded, text);
    }

    #[test]
    fn char_vocab_size_correct() {
        let tok = CharTokenizer::from_text("abcabc");
        assert_eq!(tok.vocab_size, 3);
    }

    #[test]
    fn char_unknown_input_and_ids() {
        let tok = CharTokenizer::from_text("abc");
        // Unknown characters fall back to id 0; out-of-range ids decode to '?'.
        assert_eq!(tok.encode("z"), vec![0]);
        assert_eq!(tok.decode(&[42]), "?");
    }

    #[test]
    fn bpe_trains_and_encodes() {
        let text = "abababab cdcdcdcd abababab";
        let bpe = BpeTokenizer::train(text, 20);
        // At least two merges beyond the base characters.
        assert!(
            bpe.vocab_size > base_vocab(text) + 1,
            "BPE should have merged some pairs"
        );
        let encoded = bpe.encode("abab");
        let decoded = bpe.decode(&encoded);
        assert_eq!(decoded, "abab");
    }

    #[test]
    fn bpe_roundtrip() {
        let text = "the cat sat on the mat the cat sat on the mat";
        let bpe = BpeTokenizer::train(text, 30);
        let encoded = bpe.encode(text);
        let decoded = bpe.decode(&encoded);
        assert_eq!(decoded, text);
    }

    #[test]
    fn bpe_compression() {
        let text = "aaaa bbbb aaaa bbbb aaaa bbbb";
        let bpe = BpeTokenizer::train(text, 20);
        // Compare against the character count (what a char tokenizer emits),
        // not the UTF-8 byte length.
        let char_len = text.chars().count();
        let bpe_len = bpe.encode(text).len();
        assert!(bpe_len < char_len, "BPE should compress: {} < {}", bpe_len, char_len);
    }

    #[test]
    fn mi_roundtrip() {
        let text = "the cat sat on the mat the cat sat on the mat";
        let mi = MiTokenizer::train(text, 30);
        let encoded = mi.encode(text);
        let decoded = mi.decode(&encoded);
        assert_eq!(decoded, text);
    }

    #[test]
    fn mi_compression() {
        let text = "aaaa bbbb aaaa bbbb aaaa bbbb cccc dddd cccc dddd";
        let mi = MiTokenizer::train(text, 30);
        let char_len = text.chars().count();
        let mi_len = mi.encode(text).len();
        assert!(mi_len < char_len, "MI should compress: {} < {}", mi_len, char_len);
    }

    #[test]
    fn mi_merges_high_mi_pairs() {
        let text = "the the the the the the the the the the other this that them then";
        let mi = MiTokenizer::train(text, 50);
        // Compare against the base vocabulary: a fixed bound like `> 10` is
        // vacuous here because the text alone has 11 distinct characters.
        let base = base_vocab(text);
        assert!(
            mi.vocab_size > base,
            "MI should have merged pairs, got vocab={} (base {})",
            mi.vocab_size,
            base
        );
        let encoded = mi.encode("the");
        assert!(
            encoded.len() < 3,
            "\"the\" should be compressed: {} tokens",
            encoded.len()
        );
    }

    #[test]
    fn empty_text_yields_empty_vocab() {
        assert_eq!(CharTokenizer::from_text("").vocab_size, 0);

        let bpe = BpeTokenizer::train("", 10);
        assert_eq!(bpe.vocab_size, 0);
        assert!(bpe.encode("").is_empty());

        let mi = MiTokenizer::train("", 10);
        assert_eq!(mi.vocab_size, 0);
        assert!(mi.encode("").is_empty());
    }
}