Skip to main content

dreamwell_intelligence/
tokenizer.rs

// Tokenizer — character-level and BPE (Byte Pair Encoding).
//
// CharTokenizer: one token per character. Simple, deterministic.
// BpeTokenizer: subword tokens learned from the corpus via iterative pair merging.
//   Each token carries ~4-5 characters of content, so a 128-position context
//   window sees ~500-600 characters instead of 128 — a 4-5x improvement in
//   effective context length.
//
// Clean Compute: no external dependencies. BPE training performs one counting
// pass over the corpus per learned merge, i.e. at least O(M × corpus_len) where
// M = target vocab size minus the number of base characters. Encoding replays
// every learned merge over the text, i.e. at least O(M × text_len).
11
12use std::collections::HashMap;
13
/// Character-level tokenizer: every distinct character of the training text
/// is one token.
///
/// Token ids are handed out in ascending `char` order, so the mapping depends
/// only on which characters occur, not on where they occur.
#[derive(Clone)]
pub struct CharTokenizer {
    /// `(character, token id)` pairs, ordered by character.
    pub char_to_idx: Vec<(char, usize)>,
    /// Token id → character; position `i` holds the character with id `i`.
    pub idx_to_char: Vec<char>,
    /// Number of distinct characters in the vocabulary.
    pub vocab_size: usize,
}

impl CharTokenizer {
    /// Builds a tokenizer whose vocabulary is the set of unique characters in `text`.
    pub fn from_text(text: &str) -> Self {
        // The BTreeSet de-duplicates and iterates in ascending order, so no
        // separate sort is required.
        let unique: std::collections::BTreeSet<char> = text.chars().collect();
        let idx_to_char: Vec<char> = unique.into_iter().collect();
        let char_to_idx: Vec<(char, usize)> = idx_to_char.iter().copied().zip(0usize..).collect();
        let vocab_size = idx_to_char.len();

        Self {
            char_to_idx,
            idx_to_char,
            vocab_size,
        }
    }

    /// Maps each character of `text` to its token id.
    ///
    /// Characters that are not in the vocabulary are encoded as id `0`.
    pub fn encode(&self, text: &str) -> Vec<usize> {
        let mut ids = Vec::new();
        for wanted in text.chars() {
            let id = match self.char_to_idx.iter().find(|entry| entry.0 == wanted) {
                Some(&(_, id)) => id,
                None => 0,
            };
            ids.push(id);
        }
        ids
    }

    /// Maps token ids back to characters.
    ///
    /// Ids outside the vocabulary are rendered as `'?'`.
    pub fn decode(&self, tokens: &[usize]) -> String {
        let mut text = String::new();
        for &id in tokens {
            match self.idx_to_char.get(id) {
                Some(&ch) => text.push(ch),
                None => text.push('?'),
            }
        }
        text
    }
}
60
/// BPE (Byte Pair Encoding) tokenizer. Subword tokens learned from a corpus.
///
/// Training: repeatedly merge the most frequent adjacent token pair until the
/// target vocabulary size is reached or no pair occurs at least twice. Each
/// merge creates one new token from two existing ones. Equally frequent pairs
/// are ordered lexicographically, so training on the same corpus always
/// produces the same vocabulary and the same token ids.
///
/// Encoding: replay the merges in the order they were learned.
/// Decoding: concatenate the string of every token.
#[derive(Clone)]
pub struct BpeTokenizer {
    /// Merge rules `(left, right)` in the order they were learned; the merged
    /// token is the concatenation `left + right`.
    merges: Vec<(String, String)>,
    /// Token → index mapping.
    token_to_idx: HashMap<String, usize>,
    /// Index → token string.
    idx_to_token: Vec<String>,
    /// Vocabulary size (base chars + merges).
    pub vocab_size: usize,
}

impl BpeTokenizer {
    /// Train a BPE tokenizer from a text corpus.
    ///
    /// `target_vocab`: desired vocabulary size (base characters + merge tokens).
    /// Typical values: 512, 1024, 2048. Larger = more semantic content per token.
    /// The resulting `vocab_size` can be smaller when the corpus runs out of
    /// pairs that occur at least twice.
    ///
    /// The corpus is split on line boundaries, so merges never span a line
    /// break. A one-line summary is printed to stdout when training finishes.
    pub fn train(text: &str, target_vocab: usize) -> Self {
        // Step 1: collect base vocabulary (unique characters, ascending order).
        let base_chars: Vec<char> = text
            .chars()
            .collect::<std::collections::BTreeSet<_>>()
            .into_iter()
            .collect();
        let base_vocab_size = base_chars.len();

        // Initial token → index map (one token per character).
        let mut token_to_idx: HashMap<String, usize> = HashMap::with_capacity(base_vocab_size);
        let mut idx_to_token: Vec<String> = Vec::with_capacity(base_vocab_size);
        for (i, c) in base_chars.iter().enumerate() {
            let s = c.to_string();
            token_to_idx.insert(s.clone(), i);
            idx_to_token.push(s);
        }

        // Step 2: split corpus into character-level token sequences, one per line.
        let mut corpus_tokens: Vec<Vec<String>> = text
            .lines()
            .map(|line| line.chars().map(|c| c.to_string()).collect())
            .collect();

        // Step 3: iteratively merge the most frequent pair.
        let num_merges = target_vocab.saturating_sub(base_vocab_size);
        // Not pre-sized: `target_vocab` may be far larger than the number of
        // merges the corpus can support (e.g. `usize::MAX` meaning "no limit").
        let mut merges: Vec<(String, String)> = Vec::new();

        for _ in 0..num_merges {
            // Count adjacent pairs by reference; only the winner is copied.
            let best_pair = {
                let mut pair_counts: HashMap<(&str, &str), usize> = HashMap::new();
                for seq in &corpus_tokens {
                    for window in seq.windows(2) {
                        *pair_counts
                            .entry((window[0].as_str(), window[1].as_str()))
                            .or_insert(0) += 1;
                    }
                }
                pair_counts
                    .into_iter()
                    .filter(|&(_, count)| count >= 2)
                    // Highest count wins; among equal counts the
                    // lexicographically smallest pair wins. Without the
                    // tie-break the result would follow HashMap iteration
                    // order, which differs from run to run.
                    .max_by(|x, y| x.1.cmp(&y.1).then_with(|| y.0.cmp(&x.0)))
                    .map(|((left, right), _)| (left.to_owned(), right.to_owned()))
            };
            let (left, right) = match best_pair {
                Some(pair) => pair,
                None => break, // no pair occurs more than once — done
            };

            // Register the merged token.
            let merged = format!("{}{}", left, right);
            token_to_idx.insert(merged.clone(), idx_to_token.len());
            idx_to_token.push(merged.clone());

            // Apply the merge to every sequence.
            for seq in &mut corpus_tokens {
                Self::merge_pair(seq, &left, &right, &merged);
            }
            merges.push((left, right));
        }

        let vocab_size = idx_to_token.len();
        println!(
            "BPE: {} merges, vocab = {} (base {} + {} merges)",
            merges.len(),
            vocab_size,
            base_vocab_size,
            merges.len()
        );

        Self {
            merges,
            token_to_idx,
            idx_to_token,
            vocab_size,
        }
    }

    /// Encode text into token indices using learned BPE merges.
    ///
    /// Tokens that are not in the vocabulary (characters never seen during
    /// training) are encoded as index `0`.
    pub fn encode(&self, text: &str) -> Vec<usize> {
        // Start with character-level tokens.
        let mut tokens: Vec<String> = text.chars().map(|c| c.to_string()).collect();

        // Replay merges in learned order.
        for (left, right) in &self.merges {
            let merged = format!("{}{}", left, right);
            Self::merge_pair(&mut tokens, left, right, &merged);
        }

        // Map to indices (unknown → 0).
        tokens
            .iter()
            .map(|t| self.token_to_idx.get(t).copied().unwrap_or(0))
            .collect()
    }

    /// Decode token indices back to text.
    ///
    /// Indices outside the vocabulary are rendered as `"?"`.
    pub fn decode(&self, tokens: &[usize]) -> String {
        tokens
            .iter()
            .map(|&idx| self.idx_to_token.get(idx).map(|s| s.as_str()).unwrap_or("?"))
            .collect()
    }

    /// Replaces every non-overlapping, left-to-right occurrence of the adjacent
    /// pair (`left`, `right`) in `tokens` with `merged`.
    ///
    /// Works in place in a single pass: surviving tokens are shifted down to a
    /// write cursor instead of calling `Vec::remove` once per merge.
    fn merge_pair(tokens: &mut Vec<String>, left: &str, right: &str, merged: &str) {
        let len = tokens.len();
        let mut write = 0;
        let mut read = 0;
        while read < len {
            // Slots at or after `read` are still untouched, so the comparison
            // always sees original tokens.
            if read + 1 < len && tokens[read] == left && tokens[read + 1] == right {
                tokens[write] = merged.to_owned();
                read += 2;
            } else {
                if write != read {
                    tokens.swap(write, read);
                }
                read += 1;
            }
            write += 1;
        }
        tokens.truncate(write);
    }
}
205
/// Single-Pass Mutual Information Tokenizer (BA-37).
///
/// BPE learns one merge per counting pass. This tokenizer instead counts
/// unigrams and bigrams once per round and scores every adjacent pair with
///   MI(a,b) = ln(P(ab) / (P(a) × P(b)))
/// where each probability is a count divided by the number of tokens in the
/// corpus. Every pair that occurs at least twice and has MI > ln(φ) ≈ 0.481
/// (it co-occurs more than φ× as often as independence predicts) is merged in
/// that round, highest MI first, up to the remaining vocabulary budget.
///
/// Rounds repeat on the progressively shorter corpus until the target
/// vocabulary is reached, no pair clears the threshold, or the round limit is
/// hit. Pairs with equal MI are ordered lexicographically, so training on the
/// same corpus always produces the same vocabulary and the same token ids.
#[derive(Clone)]
pub struct MiTokenizer {
    /// Merge rules `(left, right)` in the order they are replayed by `encode`.
    merges: Vec<(String, String)>,
    /// Token → index mapping.
    token_to_idx: HashMap<String, usize>,
    /// Index → token string.
    idx_to_token: Vec<String>,
    /// Vocabulary size (base chars + merges).
    pub vocab_size: usize,
}

impl MiTokenizer {
    /// Maximum number of count-score-merge rounds performed by `train`.
    const MAX_ROUNDS: usize = 8;

    /// Train the MI tokenizer from a text corpus.
    ///
    /// `target_vocab`: desired vocabulary size. Merges stop when it is reached;
    /// the resulting `vocab_size` can be smaller when no more pairs qualify.
    ///
    /// The corpus is split on line boundaries, so merges never span a line
    /// break. Progress is printed to stdout after every round and at the end.
    pub fn train(text: &str, target_vocab: usize) -> Self {
        // Base vocabulary: unique characters in ascending order.
        let base_chars: Vec<char> = text
            .chars()
            .collect::<std::collections::BTreeSet<_>>()
            .into_iter()
            .collect();

        let mut token_to_idx: HashMap<String, usize> = HashMap::with_capacity(base_chars.len());
        let mut idx_to_token: Vec<String> = Vec::with_capacity(base_chars.len());
        for (i, c) in base_chars.iter().enumerate() {
            let s = c.to_string();
            token_to_idx.insert(s.clone(), i);
            idx_to_token.push(s);
        }

        // Split corpus into token sequences, one per line.
        let mut corpus: Vec<Vec<String>> = text
            .lines()
            .map(|line| line.chars().map(|c| c.to_string()).collect())
            .collect();

        let mut all_merges: Vec<(String, String)> = Vec::new();
        let phi_threshold = (1.618033988_f64).ln(); // ln(φ) ≈ 0.481

        for round in 0..Self::MAX_ROUNDS {
            let remaining = target_vocab.saturating_sub(idx_to_token.len());
            if remaining == 0 {
                break;
            }

            // Steps 1-3: count, score and rank the candidate pairs.
            let round_merges = Self::select_merges(&corpus, phi_threshold, remaining);
            if round_merges.is_empty() {
                break; // corpus too short or no more significant pairs
            }

            // Step 4: register and apply every merge (one sweep per merge).
            // The pairs were chosen from counts taken before this round's
            // merges, so a later pair may no longer occur once an earlier one
            // has consumed its tokens; it still receives a vocabulary slot.
            for (a, b) in &round_merges {
                let merged = format!("{}{}", a, b);
                token_to_idx.insert(merged.clone(), idx_to_token.len());
                idx_to_token.push(merged.clone());

                for seq in &mut corpus {
                    Self::merge_pair(seq, a, b, &merged);
                }
            }

            println!(
                "MI round {}: {} merges (MI > ln(φ)={:.3}), vocab = {}",
                round,
                round_merges.len(),
                phi_threshold,
                idx_to_token.len()
            );
            all_merges.extend(round_merges);
        }

        let vocab_size = idx_to_token.len();
        println!(
            "MI tokenizer: {} total merges, vocab = {}",
            all_merges.len(),
            vocab_size
        );

        Self {
            merges: all_merges,
            token_to_idx,
            idx_to_token,
            vocab_size,
        }
    }

    /// Encode text into token indices by replaying the learned merges in order.
    ///
    /// Tokens that are not in the vocabulary are encoded as index `0`.
    pub fn encode(&self, text: &str) -> Vec<usize> {
        let mut tokens: Vec<String> = text.chars().map(|c| c.to_string()).collect();
        for (a, b) in &self.merges {
            let merged = format!("{}{}", a, b);
            Self::merge_pair(&mut tokens, a, b, &merged);
        }
        tokens
            .iter()
            .map(|t| self.token_to_idx.get(t).copied().unwrap_or(0))
            .collect()
    }

    /// Decode token indices back to text.
    ///
    /// Indices outside the vocabulary are rendered as `"?"`.
    pub fn decode(&self, tokens: &[usize]) -> String {
        tokens
            .iter()
            .map(|&idx| self.idx_to_token.get(idx).map(|s| s.as_str()).unwrap_or("?"))
            .collect()
    }

    /// Counts unigrams and bigrams in one pass over `corpus` and returns up to
    /// `limit` pairs whose MI exceeds `threshold`, best first.
    ///
    /// Returns an empty vector when the corpus holds fewer than two tokens or
    /// no pair qualifies.
    fn select_merges(corpus: &[Vec<String>], threshold: f64, limit: usize) -> Vec<(String, String)> {
        let mut unigram: HashMap<&str, usize> = HashMap::new();
        let mut bigram: HashMap<(&str, &str), usize> = HashMap::new();
        let mut total: usize = 0;
        for seq in corpus {
            total += seq.len();
            for tok in seq {
                *unigram.entry(tok.as_str()).or_default() += 1;
            }
            for w in seq.windows(2) {
                *bigram.entry((w[0].as_str(), w[1].as_str())).or_default() += 1;
            }
        }
        if total < 2 {
            return Vec::new();
        }
        let total_f = total as f64;

        let mut scored: Vec<((&str, &str), f64)> = bigram
            .into_iter()
            .filter_map(|((a, b), count)| {
                if count < 2 {
                    return None;
                }
                // Both halves of a bigram were counted as unigrams in the same
                // pass, so the lookups always succeed and are non-zero.
                let p_ab = count as f64 / total_f;
                let p_a = unigram[a] as f64 / total_f;
                let p_b = unigram[b] as f64 / total_f;
                let mi = (p_ab / (p_a * p_b)).ln();
                if mi > threshold {
                    Some(((a, b), mi))
                } else {
                    None
                }
            })
            .collect();

        // MI descending; equal scores fall back to the pair itself so the
        // order never depends on HashMap iteration order.
        scored.sort_by(|x, y| {
            y.1.partial_cmp(&x.1)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| x.0.cmp(&y.0))
        });

        scored
            .into_iter()
            .take(limit)
            .map(|((a, b), _)| (a.to_owned(), b.to_owned()))
            .collect()
    }

    /// Replaces every non-overlapping, left-to-right occurrence of the adjacent
    /// pair (`left`, `right`) in `tokens` with `merged`.
    ///
    /// Works in place in a single pass: surviving tokens are shifted down to a
    /// write cursor instead of calling `Vec::remove` once per merge.
    fn merge_pair(tokens: &mut Vec<String>, left: &str, right: &str, merged: &str) {
        let len = tokens.len();
        let mut write = 0;
        let mut read = 0;
        while read < len {
            // Slots at or after `read` are still untouched, so the comparison
            // always sees original tokens.
            if read + 1 < len && tokens[read] == left && tokens[read + 1] == right {
                tokens[write] = merged.to_owned();
                read += 2;
            } else {
                if write != read {
                    tokens.swap(write, read);
                }
                read += 1;
            }
            write += 1;
        }
        tokens.truncate(write);
    }
}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// Corpus shared by the BPE and MI round-trip tests.
    const CAT_MAT: &str = "the cat sat on the mat the cat sat on the mat";

    /// Encoding and then decoding the training text must reproduce it.
    #[test]
    fn char_roundtrip() {
        let text = "hello world";
        let tok = CharTokenizer::from_text(text);
        let ids = tok.encode(text);
        assert_eq!(tok.decode(&ids), text);
    }

    /// Repeated characters count once towards the vocabulary.
    #[test]
    fn char_vocab_size_correct() {
        assert_eq!(CharTokenizer::from_text("abcabc").vocab_size, 3);
    }

    /// Training on a repetitive corpus learns merges and still round-trips.
    #[test]
    fn bpe_trains_and_encodes() {
        let bpe = BpeTokenizer::train("abababab cdcdcdcd abababab", 20);
        // "ab" should be merged (appears 8 times)
        assert!(bpe.vocab_size > 6, "BPE should have merged some pairs");
        let ids = bpe.encode("abab");
        assert_eq!(bpe.decode(&ids), "abab");
    }

    /// BPE decode inverts encode on the training text.
    #[test]
    fn bpe_roundtrip() {
        let bpe = BpeTokenizer::train(CAT_MAT, 30);
        let ids = bpe.encode(CAT_MAT);
        assert_eq!(bpe.decode(&ids), CAT_MAT);
    }

    /// BPE emits fewer tokens than there are characters.
    #[test]
    fn bpe_compression() {
        let text = "aaaa bbbb aaaa bbbb aaaa bbbb";
        let bpe = BpeTokenizer::train(text, 20);
        // The corpus is ASCII, so the byte length equals the character count.
        let char_len = text.len();
        let bpe_len = bpe.encode(text).len();
        assert!(bpe_len < char_len, "BPE should compress: {} < {}", bpe_len, char_len);
    }

    /// MI decode inverts encode on the training text.
    #[test]
    fn mi_roundtrip() {
        let mi = MiTokenizer::train(CAT_MAT, 30);
        let ids = mi.encode(CAT_MAT);
        assert_eq!(mi.decode(&ids), CAT_MAT);
    }

    /// MI emits fewer tokens than there are characters.
    #[test]
    fn mi_compression() {
        let text = "aaaa bbbb aaaa bbbb aaaa bbbb cccc dddd cccc dddd";
        let mi = MiTokenizer::train(text, 30);
        // The corpus is ASCII, so the byte length equals the character count.
        let char_len = text.len();
        let mi_len = mi.encode(text).len();
        assert!(mi_len < char_len, "MI should compress: {} < {}", mi_len, char_len);
    }

    /// Pairs that co-occur far more often than chance are merged.
    #[test]
    fn mi_merges_high_mi_pairs() {
        // "th" and "he" should merge because they co-occur far more than chance
        let mi = MiTokenizer::train(
            "the the the the the the the the the the other this that them then",
            50,
        );
        assert!(
            mi.vocab_size > 10,
            "MI should have merged pairs, got vocab={}",
            mi.vocab_size
        );

        // "the" should be fewer tokens than 3 characters
        let token_count = mi.encode("the").len();
        assert!(
            token_count < 3,
            "\"the\" should be compressed: {} tokens",
            token_count
        );
    }
}