Skip to main content

oxibonsai_tokenizer/
serialization.rs

1//! Tokenizer serialization: save and load tokenizer state to/from text files.
2//!
3//! Format (plain text, UTF-8):
//! Line 1: `"oxitokenizer v1"`
5//! Line 2: `"vocab_size <N>"`
6//! Line 3: `"merges <M>"`
7//! Lines 4..(4+N): `"tok_id <id> <token_text_base64>"`
8//! Lines (4+N)..(4+N+M): `"merge <left_id> <right_id> <merged_id>"`
9//! Special tokens (if any): `"special <token_text_base64> <id>"`
10
11use std::collections::HashMap;
12use std::fs::File;
13use std::io::{BufRead, BufReader, BufWriter, Write};
14use std::path::Path;
15
/// Header magic string.
///
/// `TokenizerState::save_to` writes it as the first line of the output, and
/// `TokenizerState::load_from` fails with `SerializationError::InvalidMagic`
/// when the trimmed first line of the input is anything else.
pub const FORMAT_MAGIC: &str = "oxitokenizer v1";
18
19// ── Base64 implementation ─────────────────────────────────────────────────────
20
/// The standard base64 alphabet, indexed by 6-bit value.
const BASE64_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Encode bytes to standard base64 (3 bytes → 4 chars, padded with '=').
///
/// Each group of up to three input bytes is packed big-endian into a 24-bit
/// word and emitted as four alphabet symbols; a short final group has its
/// unused trailing symbols replaced by `=`. Empty input yields an empty string.
pub fn base64_encode(bytes: &[u8]) -> String {
    let mut encoded = String::with_capacity(bytes.len().div_ceil(3) * 4);

    // Look up the alphabet symbol for the 6-bit field starting at `shift`.
    let symbol =
        |word: u32, shift: u32| char::from(BASE64_CHARS[((word >> shift) & 0x3F) as usize]);

    for group in bytes.chunks(3) {
        // Missing bytes of a short final group contribute zero bits.
        let word = group
            .iter()
            .enumerate()
            .fold(0u32, |acc, (pos, &byte)| acc | (u32::from(byte) << (16 - 8 * pos)));

        encoded.push(symbol(word, 18));
        encoded.push(symbol(word, 12));
        encoded.push(if group.len() > 1 { symbol(word, 6) } else { '=' });
        encoded.push(if group.len() > 2 { symbol(word, 0) } else { '=' });
    }

    encoded
}
57
58/// Decode a standard base64 string back to bytes.
59pub fn base64_decode(s: &str) -> Result<Vec<u8>, SerializationError> {
60    let s = s.trim_end_matches('=');
61    let mut out = Vec::with_capacity((s.len() * 3) / 4 + 1);
62
63    let decode_char = |c: u8| -> Option<u32> {
64        match c {
65            b'A'..=b'Z' => Some((c - b'A') as u32),
66            b'a'..=b'z' => Some((c - b'a' + 26) as u32),
67            b'0'..=b'9' => Some((c - b'0' + 52) as u32),
68            b'+' => Some(62),
69            b'/' => Some(63),
70            _ => None,
71        }
72    };
73
74    let chars: Vec<u8> = s.bytes().collect();
75
76    for chunk in chars.chunks(4) {
77        match chunk.len() {
78            4 => {
79                let v0 = decode_char(chunk[0]).ok_or_else(|| {
80                    SerializationError::Base64Error(format!("invalid char '{}'", chunk[0] as char))
81                })?;
82                let v1 = decode_char(chunk[1]).ok_or_else(|| {
83                    SerializationError::Base64Error(format!("invalid char '{}'", chunk[1] as char))
84                })?;
85                let v2 = decode_char(chunk[2]).ok_or_else(|| {
86                    SerializationError::Base64Error(format!("invalid char '{}'", chunk[2] as char))
87                })?;
88                let v3 = decode_char(chunk[3]).ok_or_else(|| {
89                    SerializationError::Base64Error(format!("invalid char '{}'", chunk[3] as char))
90                })?;
91                let combined = (v0 << 18) | (v1 << 12) | (v2 << 6) | v3;
92                out.push(((combined >> 16) & 0xFF) as u8);
93                out.push(((combined >> 8) & 0xFF) as u8);
94                out.push((combined & 0xFF) as u8);
95            }
96            3 => {
97                let v0 = decode_char(chunk[0]).ok_or_else(|| {
98                    SerializationError::Base64Error(format!("invalid char '{}'", chunk[0] as char))
99                })?;
100                let v1 = decode_char(chunk[1]).ok_or_else(|| {
101                    SerializationError::Base64Error(format!("invalid char '{}'", chunk[1] as char))
102                })?;
103                let v2 = decode_char(chunk[2]).ok_or_else(|| {
104                    SerializationError::Base64Error(format!("invalid char '{}'", chunk[2] as char))
105                })?;
106                let combined = (v0 << 18) | (v1 << 12) | (v2 << 6);
107                out.push(((combined >> 16) & 0xFF) as u8);
108                out.push(((combined >> 8) & 0xFF) as u8);
109            }
110            2 => {
111                let v0 = decode_char(chunk[0]).ok_or_else(|| {
112                    SerializationError::Base64Error(format!("invalid char '{}'", chunk[0] as char))
113                })?;
114                let v1 = decode_char(chunk[1]).ok_or_else(|| {
115                    SerializationError::Base64Error(format!("invalid char '{}'", chunk[1] as char))
116                })?;
117                let combined = (v0 << 18) | (v1 << 12);
118                out.push(((combined >> 16) & 0xFF) as u8);
119            }
120            1 => {
121                return Err(SerializationError::Base64Error(
122                    "truncated base64 group of 1 char".to_string(),
123                ));
124            }
125            _ => {}
126        }
127    }
128
129    Ok(out)
130}
131
132// ── SerializationError ────────────────────────────────────────────────────────
133
/// Errors that can occur during tokenizer serialization/deserialization.
#[derive(Debug, thiserror::Error)]
pub enum SerializationError {
    /// An underlying read or write failed. Converted automatically from
    /// `std::io::Error` through the `#[from]` attribute.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// The first line of the input was not the expected format magic.
    /// `expected` holds the magic string, `got` the trimmed line that was read.
    #[error("invalid format magic: expected '{expected}', got '{got}'")]
    InvalidMagic { expected: String, got: String },

    /// A line was structurally invalid, held an unparsable number or
    /// non-UTF-8 text, or the input ended early. `line` is the line counter
    /// maintained by the loader (the first line is 1).
    #[error("parse error on line {line}: {msg}")]
    ParseError { line: usize, msg: String },

    /// A base64 payload contained an invalid character or a truncated group.
    #[error("base64 decode error: {0}")]
    Base64Error(String),

    /// The same token id appeared in more than one `tok_id` line.
    #[error("duplicate token id {0}")]
    DuplicateId(u32),
}
152
153// ── TokenizerState ────────────────────────────────────────────────────────────
154
/// A serializable snapshot of a trained tokenizer.
///
/// Holds exactly the data written by `save_to` and read back by `load_from`:
/// the vocabulary, the merge rules, and the special-token table.
#[derive(Debug)]
pub struct TokenizerState {
    /// id → token string
    pub vocab: HashMap<u32, String>,
    /// (left_id, right_id, merged_id), kept in insertion order; `save_to`
    /// writes them in this order.
    pub merges: Vec<(u32, u32, u32)>,
    /// special token name → id (e.g. `"<BOS>"` → 1)
    pub special_tokens: HashMap<String, u32>,
}
165
166impl TokenizerState {
167    /// Create an empty state.
168    pub fn new() -> Self {
169        Self {
170            vocab: HashMap::new(),
171            merges: Vec::new(),
172            special_tokens: HashMap::new(),
173        }
174    }
175
176    /// Build a `TokenizerState` from a [`crate::trainer::TrainedTokenizer`].
177    pub fn from_trained(trained: &crate::trainer::TrainedTokenizer) -> Self {
178        let mut state = Self::new();
179
180        for (&id, token) in &trained.vocab {
181            if token.starts_with('<') && token.ends_with('>') {
182                state.special_tokens.insert(token.clone(), id);
183            }
184            state.vocab.insert(id, token.clone());
185        }
186
187        for rule in &trained.merges {
188            state.merges.push((rule.left, rule.right, rule.merged));
189        }
190
191        state
192    }
193
194    /// Number of vocabulary entries.
195    pub fn vocab_size(&self) -> usize {
196        self.vocab.len()
197    }
198
199    /// Save to a writer.
200    ///
201    /// The format is deterministic: vocab entries are written sorted by id,
202    /// merges in their original order, special tokens sorted by name.
203    pub fn save_to<W: Write>(&self, writer: &mut W) -> Result<(), SerializationError> {
204        // Line 1: magic
205        writeln!(writer, "{}", FORMAT_MAGIC)?;
206
207        // Line 2: vocab_size
208        writeln!(writer, "vocab_size {}", self.vocab.len())?;
209
210        // Line 3: merges count
211        writeln!(writer, "merges {}", self.merges.len())?;
212
213        // Lines 4..(4+N): tok_id entries sorted by id for determinism
214        let mut vocab_entries: Vec<(u32, &str)> =
215            self.vocab.iter().map(|(&id, s)| (id, s.as_str())).collect();
216        vocab_entries.sort_by_key(|(id, _)| *id);
217
218        for (id, token) in &vocab_entries {
219            let encoded = base64_encode(token.as_bytes());
220            writeln!(writer, "tok_id {id} {encoded}")?;
221        }
222
223        // Merge rules in original order
224        for &(left, right, merged) in &self.merges {
225            writeln!(writer, "merge {left} {right} {merged}")?;
226        }
227
228        // Special tokens sorted by name
229        let mut special_entries: Vec<(&str, u32)> = self
230            .special_tokens
231            .iter()
232            .map(|(k, &v)| (k.as_str(), v))
233            .collect();
234        special_entries.sort_by_key(|(name, _)| *name);
235
236        for (name, id) in &special_entries {
237            let encoded = base64_encode(name.as_bytes());
238            writeln!(writer, "special {encoded} {id}")?;
239        }
240
241        Ok(())
242    }
243
244    /// Load from a reader.
245    pub fn load_from<R: BufRead>(reader: &mut R) -> Result<Self, SerializationError> {
246        let mut lines = reader.lines();
247        let mut line_no: usize = 0;
248
249        // Helper to read the next non-empty line
250        let mut next_line = |line_no: &mut usize| -> Result<String, SerializationError> {
251            *line_no += 1;
252            match lines.next() {
253                Some(Ok(l)) => Ok(l),
254                Some(Err(e)) => Err(SerializationError::Io(e)),
255                None => Err(SerializationError::ParseError {
256                    line: *line_no,
257                    msg: "unexpected end of file".to_string(),
258                }),
259            }
260        };
261
262        // Line 1: magic
263        let magic_line = next_line(&mut line_no)?;
264        if magic_line.trim() != FORMAT_MAGIC {
265            return Err(SerializationError::InvalidMagic {
266                expected: FORMAT_MAGIC.to_string(),
267                got: magic_line.trim().to_string(),
268            });
269        }
270
271        // Line 2: vocab_size <N>
272        let vocab_size_line = next_line(&mut line_no)?;
273        let vocab_size = parse_count_line(&vocab_size_line, "vocab_size", line_no)?;
274
275        // Line 3: merges <M>
276        let merges_line = next_line(&mut line_no)?;
277        let merges_count = parse_count_line(&merges_line, "merges", line_no)?;
278
279        // Read vocab entries
280        let mut vocab: HashMap<u32, String> = HashMap::with_capacity(vocab_size);
281        for _ in 0..vocab_size {
282            let l = next_line(&mut line_no)?;
283            let parts: Vec<&str> = l.trim().splitn(3, ' ').collect();
284            if parts.len() != 3 || parts[0] != "tok_id" {
285                return Err(SerializationError::ParseError {
286                    line: line_no,
287                    msg: format!("expected 'tok_id <id> <b64>', got '{l}'"),
288                });
289            }
290            let id: u32 = parts[1]
291                .parse()
292                .map_err(|_| SerializationError::ParseError {
293                    line: line_no,
294                    msg: format!("invalid token id '{}'", parts[1]),
295                })?;
296            let token_bytes = base64_decode(parts[2])?;
297            let token =
298                String::from_utf8(token_bytes).map_err(|_| SerializationError::ParseError {
299                    line: line_no,
300                    msg: "token text is not valid UTF-8".to_string(),
301                })?;
302            if vocab.contains_key(&id) {
303                return Err(SerializationError::DuplicateId(id));
304            }
305            vocab.insert(id, token);
306        }
307
308        // Read merge rules
309        let mut merges: Vec<(u32, u32, u32)> = Vec::with_capacity(merges_count);
310        for _ in 0..merges_count {
311            let l = next_line(&mut line_no)?;
312            let parts: Vec<&str> = l.trim().splitn(4, ' ').collect();
313            if parts.len() != 4 || parts[0] != "merge" {
314                return Err(SerializationError::ParseError {
315                    line: line_no,
316                    msg: format!("expected 'merge <left> <right> <merged>', got '{l}'"),
317                });
318            }
319            let left: u32 = parts[1]
320                .parse()
321                .map_err(|_| SerializationError::ParseError {
322                    line: line_no,
323                    msg: format!("invalid merge left id '{}'", parts[1]),
324                })?;
325            let right: u32 = parts[2]
326                .parse()
327                .map_err(|_| SerializationError::ParseError {
328                    line: line_no,
329                    msg: format!("invalid merge right id '{}'", parts[2]),
330                })?;
331            let merged: u32 = parts[3]
332                .parse()
333                .map_err(|_| SerializationError::ParseError {
334                    line: line_no,
335                    msg: format!("invalid merge merged id '{}'", parts[3]),
336                })?;
337            merges.push((left, right, merged));
338        }
339
340        // Read remaining lines as special tokens (optional section)
341        let mut special_tokens: HashMap<String, u32> = HashMap::new();
342        for maybe_line in lines {
343            line_no += 1;
344            let l = maybe_line.map_err(SerializationError::Io)?;
345            let l = l.trim();
346            if l.is_empty() {
347                continue;
348            }
349            let parts: Vec<&str> = l.splitn(3, ' ').collect();
350            if parts.len() != 3 || parts[0] != "special" {
351                return Err(SerializationError::ParseError {
352                    line: line_no,
353                    msg: format!("expected 'special <b64> <id>', got '{l}'"),
354                });
355            }
356            let name_bytes = base64_decode(parts[1])?;
357            let name =
358                String::from_utf8(name_bytes).map_err(|_| SerializationError::ParseError {
359                    line: line_no,
360                    msg: "special token name is not valid UTF-8".to_string(),
361                })?;
362            let id: u32 = parts[2]
363                .parse()
364                .map_err(|_| SerializationError::ParseError {
365                    line: line_no,
366                    msg: format!("invalid special token id '{}'", parts[2]),
367                })?;
368            special_tokens.insert(name, id);
369        }
370
371        Ok(TokenizerState {
372            vocab,
373            merges,
374            special_tokens,
375        })
376    }
377
378    /// Save to a file path.
379    pub fn save(&self, path: &Path) -> Result<(), SerializationError> {
380        let file = File::create(path).map_err(SerializationError::Io)?;
381        let mut writer = BufWriter::new(file);
382        self.save_to(&mut writer)?;
383        writer.flush().map_err(SerializationError::Io)?;
384        Ok(())
385    }
386
387    /// Load from a file path.
388    pub fn load(path: &Path) -> Result<Self, SerializationError> {
389        let file = File::open(path).map_err(SerializationError::Io)?;
390        let mut reader = BufReader::new(file);
391        Self::load_from(&mut reader)
392    }
393
394    /// Convert to an [`crate::OxiTokenizer`] (char-level fallback using our vocab).
395    pub fn to_oxi_tokenizer(&self) -> crate::OxiTokenizer {
396        use crate::{
397            bpe::BpeMerges,
398            tokenizer::{OxiTokenizer, TokenizerConfig},
399            vocab::Vocabulary,
400        };
401
402        let mut vocabulary = Vocabulary::new();
403        for (&id, token) in &self.vocab {
404            if self.special_tokens.contains_key(token.as_str()) {
405                vocabulary.add_special(token, id);
406            } else {
407                vocabulary.insert(token, id);
408            }
409        }
410
411        let mut bpe_merges = BpeMerges::new();
412        for &(left_id, right_id, merged_id) in &self.merges {
413            let left_str = self.vocab.get(&left_id).map(|s| s.as_str()).unwrap_or("");
414            let right_str = self.vocab.get(&right_id).map(|s| s.as_str()).unwrap_or("");
415            bpe_merges.add_merge(left_str, right_str, merged_id);
416        }
417
418        let config = TokenizerConfig::default();
419        OxiTokenizer::new(vocabulary, bpe_merges, config)
420    }
421}
422
423impl Default for TokenizerState {
424    fn default() -> Self {
425        Self::new()
426    }
427}
428
429// ── Helpers ───────────────────────────────────────────────────────────────────
430
431/// Parse a line of the form `<keyword> <count>`.
432fn parse_count_line(
433    line: &str,
434    keyword: &str,
435    line_no: usize,
436) -> Result<usize, SerializationError> {
437    let parts: Vec<&str> = line.trim().splitn(2, ' ').collect();
438    if parts.len() != 2 || parts[0] != keyword {
439        return Err(SerializationError::ParseError {
440            line: line_no,
441            msg: format!("expected '{keyword} <N>', got '{line}'"),
442        });
443    }
444    parts[1]
445        .parse::<usize>()
446        .map_err(|_| SerializationError::ParseError {
447            line: line_no,
448            msg: format!("invalid count value '{}'", parts[1]),
449        })
450}
451
452// ── Tests ─────────────────────────────────────────────────────────────────────
453
#[cfg(test)]
mod inline_tests {
    use super::*;

    /// Encoding then decoding a short ASCII message returns the same bytes.
    #[test]
    fn base64_encode_decode_hello() {
        let original = b"Hello, World!";
        let decoded =
            base64_decode(&base64_encode(original)).expect("decode should succeed");
        assert_eq!(decoded, original);
    }

    /// Empty input encodes to an empty string, which decodes to no bytes.
    #[test]
    fn base64_empty() {
        assert_eq!(base64_encode(b""), "");

        let decoded = base64_decode("").expect("decode empty");
        assert!(decoded.is_empty());
    }

    /// A small state survives an in-memory save followed by a load.
    #[test]
    fn tokenizer_state_roundtrip_basic() {
        let mut state = TokenizerState::new();
        for (id, text) in [(0u32, "<unk>"), (1u32, "a")] {
            state.vocab.insert(id, text.to_string());
        }
        state.merges.push((0, 1, 2));

        let mut buf = Vec::new();
        state.save_to(&mut buf).expect("save should succeed");

        let mut reader = std::io::BufReader::new(buf.as_slice());
        let loaded = TokenizerState::load_from(&mut reader).expect("load should succeed");

        assert_eq!(loaded.vocab.get(&0).map(String::as_str), Some("<unk>"));
        assert_eq!(loaded.vocab.get(&1).map(String::as_str), Some("a"));
        assert_eq!(loaded.merges, vec![(0, 1, 2)]);
    }
}