Skip to main content

entrenar/tokenizer/
char.rs

1//! Character-level tokenizer implementation.
2
3use std::collections::HashMap;
4
5use serde::{Deserialize, Serialize};
6
7use super::config::TokenizerConfig;
8use super::error::{Result, TokenizerError};
9use super::traits::{TokenId, Tokenizer};
10
/// Character-level tokenizer (simple baseline).
///
/// Every distinct `char` kept during training is assigned its own
/// [`TokenId`]. The forward (`vocab`) and reverse (`id_to_char`) tables are
/// filled together by `train`, and `encode` / `decode` return
/// `TokenizerError::NotTrained` until `trained` has been set.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CharTokenizer {
    /// Settings supplied at construction; `train` and `encode` read
    /// `lowercase`, and `train` additionally reads `vocab_size` and
    /// `min_frequency`.
    config: TokenizerConfig,
    /// Forward lookup: character -> token id. Empty until `train` runs.
    vocab: HashMap<char, TokenId>,
    /// Reverse lookup: token id -> character. Written alongside `vocab`.
    id_to_char: HashMap<TokenId, char>,
    /// `false` on construction; set to `true` at the end of `train`.
    trained: bool,
}
19
20impl CharTokenizer {
21    /// Create a new character tokenizer
22    pub fn new(config: TokenizerConfig) -> Self {
23        Self { config, vocab: HashMap::new(), id_to_char: HashMap::new(), trained: false }
24    }
25}
26
27impl Tokenizer for CharTokenizer {
28    fn train(&mut self, corpus: &[&str]) -> Result<()> {
29        let mut id: TokenId = 0;
30
31        // Count character frequencies
32        let mut char_counts: HashMap<char, usize> = HashMap::new();
33        for text in corpus {
34            let processed =
35                if self.config.lowercase { text.to_lowercase() } else { text.to_string() };
36            for c in processed.chars() {
37                *char_counts.entry(c).or_insert(0) += 1;
38            }
39        }
40
41        // Sort by frequency and take top vocab_size
42        let mut chars: Vec<_> = char_counts.into_iter().collect();
43        chars.sort_by(|a, b| b.1.cmp(&a.1));
44
45        for (c, count) in chars.into_iter().take(self.config.vocab_size) {
46            if count >= self.config.min_frequency {
47                self.vocab.insert(c, id);
48                self.id_to_char.insert(id, c);
49                id += 1;
50            }
51        }
52
53        self.trained = true;
54        Ok(())
55    }
56
57    fn encode(&self, text: &str) -> Result<Vec<TokenId>> {
58        if !self.trained {
59            return Err(TokenizerError::NotTrained);
60        }
61
62        let processed = if self.config.lowercase { text.to_lowercase() } else { text.to_string() };
63
64        let mut ids = Vec::new();
65        for c in processed.chars() {
66            if let Some(&id) = self.vocab.get(&c) {
67                ids.push(id);
68            }
69            // Unknown characters are skipped
70        }
71
72        Ok(ids)
73    }
74
75    fn decode(&self, ids: &[TokenId]) -> Result<String> {
76        if !self.trained {
77            return Err(TokenizerError::NotTrained);
78        }
79
80        let mut result = String::new();
81        for &id in ids {
82            if let Some(&c) = self.id_to_char.get(&id) {
83                result.push(c);
84            }
85        }
86
87        Ok(result)
88    }
89
90    fn vocab_size(&self) -> usize {
91        self.vocab.len()
92    }
93
94    fn is_trained(&self) -> bool {
95        self.trained
96    }
97
98    fn id_to_token(&self, _id: TokenId) -> Option<&str> {
99        // Characters are not stored as strings
100        None
101    }
102
103    fn token_to_id(&self, token: &str) -> Option<TokenId> {
104        if token.len() == 1 {
105            self.vocab
106                .get(&token.chars().next().expect("single-char token must have a char"))
107                .copied()
108        } else {
109            None
110        }
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed tokenizer reports itself as untrained.
    #[test]
    fn test_char_new() {
        let config = TokenizerConfig::char();
        let tokenizer = CharTokenizer::new(config);
        assert!(!tokenizer.is_trained());
    }

    /// Training marks the tokenizer as trained and keeps one entry per
    /// distinct character in the corpus.
    #[test]
    fn test_char_train() {
        let config = TokenizerConfig::char().with_min_frequency(1);
        let mut tokenizer = CharTokenizer::new(config);

        let corpus = ["hello", "world"];
        tokenizer.train(&corpus).expect("training should succeed");

        assert!(tokenizer.is_trained());
        // h, e, l, o, w, r, d = 7 unique chars
        assert_eq!(tokenizer.vocab_size(), 7);
    }

    /// Text made only of known characters survives an encode/decode cycle.
    #[test]
    fn test_char_encode_decode() {
        let config = TokenizerConfig::char().with_min_frequency(1);
        let mut tokenizer = CharTokenizer::new(config);

        let corpus = ["hello"];
        tokenizer.train(&corpus).expect("training should succeed");

        let text = "hello";
        let encoded = tokenizer.encode(text).expect("encoding should succeed");
        let decoded = tokenizer.decode(&encoded).expect("decoding should succeed");

        assert_eq!(decoded, text);
    }

    /// Characters absent from the vocabulary are dropped during encoding.
    #[test]
    fn test_char_unknown_chars() {
        let config = TokenizerConfig::char().with_min_frequency(1);
        let mut tokenizer = CharTokenizer::new(config);

        let corpus = ["abc"];
        tokenizer.train(&corpus).expect("training should succeed");

        // 'x' is not in vocabulary, should be skipped
        let encoded = tokenizer.encode("axbc").expect("encoding should succeed");
        let decoded = tokenizer.decode(&encoded).expect("decoding should succeed");

        assert_eq!(decoded, "abc");
    }

    /// With lower-casing enabled, both training and encoding fold case.
    #[test]
    fn test_char_lowercase() {
        let config = TokenizerConfig::char().with_min_frequency(1).with_lowercase(true);
        let mut tokenizer = CharTokenizer::new(config);

        let corpus = ["Hello"];
        tokenizer.train(&corpus).expect("training should succeed");

        let encoded = tokenizer.encode("HELLO").expect("encoding should succeed");
        let decoded = tokenizer.decode(&encoded).expect("decoding should succeed");

        assert_eq!(decoded, "hello");
    }
}
182
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(50))]

        // Training on a text and then encoding/decoding that same text must
        // reproduce it exactly, since every character is in the vocabulary.
        #[test]
        fn prop_char_roundtrip(text in "[a-z]{1,20}") {
            let config = TokenizerConfig::char().with_min_frequency(1);
            let mut tokenizer = CharTokenizer::new(config);
            tokenizer.train(&[&text]).expect("training should succeed");

            let encoded = tokenizer.encode(&text).expect("encoding should succeed");
            let decoded = tokenizer.decode(&encoded).expect("decoding should succeed");

            prop_assert_eq!(decoded, text);
        }

        // With a vocabulary limit far above the alphabet size and a minimum
        // frequency of 1, every distinct character must get an entry.
        #[test]
        fn prop_char_vocab_size_matches_unique_chars(text in "[a-z]{5,30}") {
            let config = TokenizerConfig::char()
                .with_min_frequency(1)
                .with_vocab_size(256);
            let mut tokenizer = CharTokenizer::new(config);
            tokenizer.train(&[&text]).expect("training should succeed");

            let unique_chars: std::collections::HashSet<char> = text.chars().collect();
            prop_assert_eq!(tokenizer.vocab_size(), unique_chars.len());
        }
    }
}
215}