use std::collections::HashMap;
use std::fs::File;
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::Path;

/// Magic string written as the first line of every serialized tokenizer file.
pub const FORMAT_MAGIC: &str = "oxitokenizer v1";

/// Standard base64 alphabet (RFC 4648): A-Z, a-z, 0-9, '+', '/'.
const BASE64_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Encodes `bytes` as standard base64, padding the final group with '='.
pub fn base64_encode(bytes: &[u8]) -> String {
    let mut out = Vec::with_capacity(bytes.len().div_ceil(3) * 4);

    for chunk in bytes.chunks(3) {
        // Pack up to three bytes into a 24-bit group; missing bytes are zero.
        let b0 = chunk[0] as u32;
        let b1 = if chunk.len() > 1 { chunk[1] as u32 } else { 0 };
        let b2 = if chunk.len() > 2 { chunk[2] as u32 } else { 0 };

        let combined = (b0 << 16) | (b1 << 8) | b2;

        // The first two output characters are always present.
        out.push(BASE64_CHARS[((combined >> 18) & 0x3F) as usize]);
        out.push(BASE64_CHARS[((combined >> 12) & 0x3F) as usize]);

        // The last two become '=' padding when the chunk is short.
        if chunk.len() > 1 {
            out.push(BASE64_CHARS[((combined >> 6) & 0x3F) as usize]);
        } else {
            out.push(b'=');
        }

        if chunk.len() > 2 {
            out.push(BASE64_CHARS[(combined & 0x3F) as usize]);
        } else {
            out.push(b'=');
        }
    }

    // BASE64_CHARS is ASCII-only, so the output is always valid UTF-8.
    String::from_utf8(out).unwrap_or_default()
}

/// Decodes standard base64 (with or without '=' padding) into raw bytes.
pub fn base64_decode(s: &str) -> Result<Vec<u8>, SerializationError> {
    let s = s.trim_end_matches('=');
    let mut out = Vec::with_capacity((s.len() * 3) / 4 + 1);

    let decode_char = |c: u8| -> Option<u32> {
        match c {
            b'A'..=b'Z' => Some((c - b'A') as u32),
            b'a'..=b'z' => Some((c - b'a' + 26) as u32),
            b'0'..=b'9' => Some((c - b'0' + 52) as u32),
            b'+' => Some(62),
            b'/' => Some(63),
            _ => None,
        }
    };

    let chars: Vec<u8> = s.bytes().collect();

    for chunk in chars.chunks(4) {
        // A single leftover character can never encode a whole byte.
        if chunk.len() == 1 {
            return Err(SerializationError::Base64Error(
                "truncated base64 group of 1 char".to_string(),
            ));
        }

        // Decode each character of the group; absent positions stay zero.
        let mut vals = [0u32; 4];
        for (i, &c) in chunk.iter().enumerate() {
            vals[i] = decode_char(c).ok_or_else(|| {
                SerializationError::Base64Error(format!("invalid char '{}'", c as char))
            })?;
        }

        let combined = (vals[0] << 18) | (vals[1] << 12) | (vals[2] << 6) | vals[3];

        // A group of 4 chars yields 3 bytes, 3 chars yield 2, and 2 chars yield 1.
        out.push(((combined >> 16) & 0xFF) as u8);
        if chunk.len() > 2 {
            out.push(((combined >> 8) & 0xFF) as u8);
        }
        if chunk.len() > 3 {
            out.push((combined & 0xFF) as u8);
        }
    }

    Ok(out)
}

/// Errors that can occur while saving or loading a tokenizer file.
#[derive(Debug, thiserror::Error)]
pub enum SerializationError {
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    #[error("invalid format magic: expected '{expected}', got '{got}'")]
    InvalidMagic { expected: String, got: String },

    #[error("parse error on line {line}: {msg}")]
    ParseError { line: usize, msg: String },

    #[error("base64 decode error: {0}")]
    Base64Error(String),

    #[error("duplicate token id {0}")]
    DuplicateId(u32),
}

/// Serializable snapshot of a trained tokenizer.
#[derive(Debug)]
pub struct TokenizerState {
    /// Token id to token text.
    pub vocab: HashMap<u32, String>,
    /// Merge rules as (left id, right id, merged id) triples, in training order.
    pub merges: Vec<(u32, u32, u32)>,
    /// Special token name to token id.
    pub special_tokens: HashMap<String, u32>,
}

impl TokenizerState {
    /// Creates an empty state.
    pub fn new() -> Self {
        Self {
            vocab: HashMap::new(),
            merges: Vec::new(),
            special_tokens: HashMap::new(),
        }
    }

    /// Builds a state from a freshly trained tokenizer. Tokens wrapped in
    /// angle brackets (e.g. "<unk>") are treated as special tokens.
    pub fn from_trained(trained: &crate::trainer::TrainedTokenizer) -> Self {
        let mut state = Self::new();

        for (&id, token) in &trained.vocab {
            if token.starts_with('<') && token.ends_with('>') {
                state.special_tokens.insert(token.clone(), id);
            }
            state.vocab.insert(id, token.clone());
        }

        for rule in &trained.merges {
            state.merges.push((rule.left, rule.right, rule.merged));
        }

        state
    }

    /// Number of entries in the vocabulary.
    pub fn vocab_size(&self) -> usize {
        self.vocab.len()
    }

    /// Writes the state in the plain-text format: a magic line, two count
    /// lines, one `tok_id <id> <b64>` line per vocab entry, one
    /// `merge <left> <right> <merged>` line per rule, then
    /// `special <b64> <id>` lines.
    pub fn save_to<W: Write>(&self, writer: &mut W) -> Result<(), SerializationError> {
        writeln!(writer, "{}", FORMAT_MAGIC)?;

        writeln!(writer, "vocab_size {}", self.vocab.len())?;

        writeln!(writer, "merges {}", self.merges.len())?;

        // Sort vocab entries by id so the output is deterministic.
        let mut vocab_entries: Vec<(u32, &str)> =
            self.vocab.iter().map(|(&id, s)| (id, s.as_str())).collect();
        vocab_entries.sort_by_key(|(id, _)| *id);

        for (id, token) in &vocab_entries {
            let encoded = base64_encode(token.as_bytes());
            writeln!(writer, "tok_id {id} {encoded}")?;
        }

        for &(left, right, merged) in &self.merges {
            writeln!(writer, "merge {left} {right} {merged}")?;
        }

        // Sort special tokens by name, again for deterministic output.
        let mut special_entries: Vec<(&str, u32)> = self
            .special_tokens
            .iter()
            .map(|(k, &v)| (k.as_str(), v))
            .collect();
        special_entries.sort_by_key(|(name, _)| *name);

        for (name, id) in &special_entries {
            let encoded = base64_encode(name.as_bytes());
            writeln!(writer, "special {encoded} {id}")?;
        }

        Ok(())
    }

    /// Reads a tokenizer from the plain-text format produced by `save_to`.
    pub fn load_from<R: BufRead>(reader: &mut R) -> Result<Self, SerializationError> {
        let mut lines = reader.lines();
        let mut line_no: usize = 0;

        // Pulls the next line, tracking the line number for error reporting.
        let mut next_line = |line_no: &mut usize| -> Result<String, SerializationError> {
            *line_no += 1;
            match lines.next() {
                Some(Ok(l)) => Ok(l),
                Some(Err(e)) => Err(SerializationError::Io(e)),
                None => Err(SerializationError::ParseError {
                    line: *line_no,
                    msg: "unexpected end of file".to_string(),
                }),
            }
        };

        // Header: magic line, then the vocab and merge counts.
        let magic_line = next_line(&mut line_no)?;
        if magic_line.trim() != FORMAT_MAGIC {
            return Err(SerializationError::InvalidMagic {
                expected: FORMAT_MAGIC.to_string(),
                got: magic_line.trim().to_string(),
            });
        }

        let vocab_size_line = next_line(&mut line_no)?;
        let vocab_size = parse_count_line(&vocab_size_line, "vocab_size", line_no)?;

        let merges_line = next_line(&mut line_no)?;
        let merges_count = parse_count_line(&merges_line, "merges", line_no)?;

        // Vocab entries: `tok_id <id> <b64>`.
        let mut vocab: HashMap<u32, String> = HashMap::with_capacity(vocab_size);
        for _ in 0..vocab_size {
            let l = next_line(&mut line_no)?;
            let parts: Vec<&str> = l.trim().splitn(3, ' ').collect();
            if parts.len() != 3 || parts[0] != "tok_id" {
                return Err(SerializationError::ParseError {
                    line: line_no,
                    msg: format!("expected 'tok_id <id> <b64>', got '{l}'"),
                });
            }
            let id: u32 = parts[1]
                .parse()
                .map_err(|_| SerializationError::ParseError {
                    line: line_no,
                    msg: format!("invalid token id '{}'", parts[1]),
                })?;
            let token_bytes = base64_decode(parts[2])?;
            let token =
                String::from_utf8(token_bytes).map_err(|_| SerializationError::ParseError {
                    line: line_no,
                    msg: "token text is not valid UTF-8".to_string(),
                })?;
            if vocab.contains_key(&id) {
                return Err(SerializationError::DuplicateId(id));
            }
            vocab.insert(id, token);
        }

        // Merge rules: `merge <left> <right> <merged>`.
        let mut merges: Vec<(u32, u32, u32)> = Vec::with_capacity(merges_count);
        for _ in 0..merges_count {
            let l = next_line(&mut line_no)?;
            let parts: Vec<&str> = l.trim().splitn(4, ' ').collect();
            if parts.len() != 4 || parts[0] != "merge" {
                return Err(SerializationError::ParseError {
                    line: line_no,
                    msg: format!("expected 'merge <left> <right> <merged>', got '{l}'"),
                });
            }
            let left: u32 = parts[1]
                .parse()
                .map_err(|_| SerializationError::ParseError {
                    line: line_no,
                    msg: format!("invalid merge left id '{}'", parts[1]),
                })?;
            let right: u32 = parts[2]
                .parse()
                .map_err(|_| SerializationError::ParseError {
                    line: line_no,
                    msg: format!("invalid merge right id '{}'", parts[2]),
                })?;
            let merged: u32 = parts[3]
                .parse()
                .map_err(|_| SerializationError::ParseError {
                    line: line_no,
                    msg: format!("invalid merge merged id '{}'", parts[3]),
                })?;
            merges.push((left, right, merged));
        }

        // Remaining lines are special tokens: `special <b64> <id>`. The
        // `next_line` closure is done, so the iterator is consumed directly.
        let mut special_tokens: HashMap<String, u32> = HashMap::new();
        for maybe_line in lines {
            line_no += 1;
            let l = maybe_line.map_err(SerializationError::Io)?;
            let l = l.trim();
            if l.is_empty() {
                continue;
            }
            let parts: Vec<&str> = l.splitn(3, ' ').collect();
            if parts.len() != 3 || parts[0] != "special" {
                return Err(SerializationError::ParseError {
                    line: line_no,
                    msg: format!("expected 'special <b64> <id>', got '{l}'"),
                });
            }
            let name_bytes = base64_decode(parts[1])?;
            let name =
                String::from_utf8(name_bytes).map_err(|_| SerializationError::ParseError {
                    line: line_no,
                    msg: "special token name is not valid UTF-8".to_string(),
                })?;
            let id: u32 = parts[2]
                .parse()
                .map_err(|_| SerializationError::ParseError {
                    line: line_no,
                    msg: format!("invalid special token id '{}'", parts[2]),
                })?;
            special_tokens.insert(name, id);
        }

        Ok(TokenizerState {
            vocab,
            merges,
            special_tokens,
        })
    }

    /// Saves the state to `path`, creating or truncating the file.
    pub fn save(&self, path: &Path) -> Result<(), SerializationError> {
        let file = File::create(path).map_err(SerializationError::Io)?;
        let mut writer = BufWriter::new(file);
        self.save_to(&mut writer)?;
        writer.flush().map_err(SerializationError::Io)?;
        Ok(())
    }

    /// Loads a state from the file at `path`.
    pub fn load(path: &Path) -> Result<Self, SerializationError> {
        let file = File::open(path).map_err(SerializationError::Io)?;
        let mut reader = BufReader::new(file);
        Self::load_from(&mut reader)
    }

    /// Converts the state into a ready-to-use `OxiTokenizer`.
    pub fn to_oxi_tokenizer(&self) -> crate::OxiTokenizer {
        use crate::{
            bpe::BpeMerges,
            tokenizer::{OxiTokenizer, TokenizerConfig},
            vocab::Vocabulary,
        };

        let mut vocabulary = Vocabulary::new();
        for (&id, token) in &self.vocab {
            if self.special_tokens.contains_key(token.as_str()) {
                vocabulary.add_special(token, id);
            } else {
                vocabulary.insert(token, id);
            }
        }

        let mut bpe_merges = BpeMerges::new();
        for &(left_id, right_id, merged_id) in &self.merges {
            // A merge referencing an unknown id falls back to an empty string.
            let left_str = self.vocab.get(&left_id).map(|s| s.as_str()).unwrap_or("");
            let right_str = self.vocab.get(&right_id).map(|s| s.as_str()).unwrap_or("");
            bpe_merges.add_merge(left_str, right_str, merged_id);
        }

        let config = TokenizerConfig::default();
        OxiTokenizer::new(vocabulary, bpe_merges, config)
    }
}

impl Default for TokenizerState {
    fn default() -> Self {
        Self::new()
    }
}

/// Parses a header line of the form `<keyword> <N>` and returns N.
fn parse_count_line(
    line: &str,
    keyword: &str,
    line_no: usize,
) -> Result<usize, SerializationError> {
    let parts: Vec<&str> = line.trim().splitn(2, ' ').collect();
    if parts.len() != 2 || parts[0] != keyword {
        return Err(SerializationError::ParseError {
            line: line_no,
            msg: format!("expected '{keyword} <N>', got '{line}'"),
        });
    }
    parts[1]
        .parse::<usize>()
        .map_err(|_| SerializationError::ParseError {
            line: line_no,
            msg: format!("invalid count value '{}'", parts[1]),
        })
}

#[cfg(test)]
mod inline_tests {
    use super::*;

    #[test]
    fn base64_encode_decode_hello() {
        let original = b"Hello, World!";
        let encoded = base64_encode(original);
        let decoded = base64_decode(&encoded).expect("decode should succeed");
        assert_eq!(decoded, original);
    }
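
    // Sketch of a broader round-trip check, assuming only the two functions
    // above: encode/decode every byte value at a few lengths so each padding
    // case (0, 1, and 2 trailing bytes per 3-byte group) is exercised.
    #[test]
    fn base64_all_byte_values_roundtrip() {
        let all: Vec<u8> = (0u8..=255).collect();
        for len in [0usize, 1, 2, 3, 255, 256] {
            let slice = &all[..len];
            let encoded = base64_encode(slice);
            let decoded = base64_decode(&encoded).expect("decode should succeed");
            assert_eq!(decoded, slice);
        }
    }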

    #[test]
    fn base64_empty() {
        let encoded = base64_encode(b"");
        assert_eq!(encoded, "");
        let decoded = base64_decode("").expect("decode empty");
        assert!(decoded.is_empty());
    }
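
    // Sketch of the decoder's error paths: a character outside the base64
    // alphabet and a truncated group of one character should both surface
    // as Base64Error rather than panicking.
    #[test]
    fn base64_decode_rejects_bad_input() {
        assert!(matches!(
            base64_decode("ab!d"),
            Err(SerializationError::Base64Error(_))
        ));
        assert!(matches!(
            base64_decode("A"),
            Err(SerializationError::Base64Error(_))
        ));
    }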

    #[test]
    fn tokenizer_state_roundtrip_basic() {
        let mut state = TokenizerState::new();
        state.vocab.insert(0, "<unk>".to_string());
        state.vocab.insert(1, "a".to_string());
        state.merges.push((0, 1, 2));

        let mut buf = Vec::new();
        state.save_to(&mut buf).expect("save should succeed");

        let mut reader = std::io::BufReader::new(buf.as_slice());
        let loaded = TokenizerState::load_from(&mut reader).expect("load should succeed");

        assert_eq!(loaded.vocab.get(&0), Some(&"<unk>".to_string()));
        assert_eq!(loaded.vocab.get(&1), Some(&"a".to_string()));
        assert_eq!(loaded.merges, vec![(0, 1, 2)]);
    }
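
    // Sketch of two further cases, assuming only the public API above: the
    // "<pad>" name is an arbitrary example special token, and the wrong-magic
    // input is arbitrary non-tokenizer text.
    #[test]
    fn tokenizer_state_roundtrip_special_tokens() {
        let mut state = TokenizerState::new();
        state.vocab.insert(0, "<pad>".to_string());
        state.special_tokens.insert("<pad>".to_string(), 0);

        let mut buf = Vec::new();
        state.save_to(&mut buf).expect("save should succeed");

        let mut reader = std::io::BufReader::new(buf.as_slice());
        let loaded = TokenizerState::load_from(&mut reader).expect("load should succeed");

        assert_eq!(loaded.special_tokens.get("<pad>"), Some(&0));
        assert_eq!(loaded.vocab.get(&0), Some(&"<pad>".to_string()));
    }

    #[test]
    fn load_rejects_wrong_magic() {
        let data = b"not a tokenizer file\n";
        let mut reader = std::io::BufReader::new(&data[..]);
        assert!(matches!(
            TokenizerState::load_from(&mut reader),
            Err(SerializationError::InvalidMagic { .. })
        ));
    }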
}