// sapient_tokenizers/tokenizer.rs

//! `SapientTokenizer` — wraps the HuggingFace `tokenizers` crate.

3use std::path::Path;
4
5use anyhow::{Context, Result};
6use tokenizers::Tokenizer;

// ── TokenizerOptions ──────────────────────────────────────────────────────────

/// Options applied by `SapientTokenizer::encode` on top of the ids produced
/// by the underlying HuggingFace tokenizer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenizerOptions {
    /// Add BOS token at the start of every encoding (default: true for LLMs).
    pub add_bos: bool,
    /// Add EOS token at the end of every encoding (default: false — let templates control this).
    pub add_eos: bool,
    /// Truncate to this many tokens (0 = no limit).
    pub max_length: usize,
}

impl Default for TokenizerOptions {
    /// BOS on, EOS off, no truncation limit.
    fn default() -> Self {
        Self {
            add_bos: true,
            add_eos: false,
            max_length: 0,
        }
    }
}

// ── SapientTokenizer ──────────────────────────────────────────────────────────

32/// A tokenizer loaded from a HuggingFace `tokenizer.json`.
33///
34/// Supports every tokenizer type HF ships: BPE, WordPiece, Unigram (SentencePiece).
35/// Known end-of-turn / end-of-sequence marker tokens across model families.
36/// A model may define several (e.g. Qwen has both `<|endoftext|>` and
37/// `<|im_end|>`); generation must stop on *any* of them.
38const EOS_CANDIDATES: &[&str] = &[
39    "</s>",
40    "<eos>",
41    "<|endoftext|>",
42    "<|end_of_text|>",
43    "<|eot_id|>",
44    "<|im_end|>",
45    "<end_of_turn>",
46    "<|redacted_EOS|>",
47];
48
49pub struct SapientTokenizer {
50    inner: Tokenizer,
51    pub bos_id: Option<u32>,
52    pub eos_id: Option<u32>,
53    /// Every EOS/turn-end token id present in this tokenizer's vocab.
54    pub eos_ids: Vec<u32>,
55    pub pad_id: Option<u32>,
56    opts: TokenizerOptions,
57}
58
59impl SapientTokenizer {
60    /// Load from a `tokenizer.json` file.
61    pub fn from_file(path: &Path, opts: TokenizerOptions) -> Result<Self> {
62        match Tokenizer::from_file(path) {
63            Ok(inner) => Self::from_inner(inner, opts),
64            Err(first_err) => {
65                let normalized = normalize_tokenizer_json(path).with_context(|| {
66                    format!("Failed to load tokenizer and could not normalize it: {first_err}")
67                })?;
68                let inner = Tokenizer::from_bytes(&normalized)
69                    .map_err(|e| anyhow::anyhow!("Failed to load normalized tokenizer: {e}"))?;
70                Self::from_inner(inner, opts)
71            }
72        }
73    }
74
75    /// Load from a HuggingFace model ID string (uses the HF Hub cache).
76    pub fn from_pretrained(model_id: &str) -> Result<Self> {
77        let inner = Tokenizer::from_pretrained(model_id, None)
78            .map_err(|e| anyhow::anyhow!("Failed to load tokenizer for '{model_id}': {e}"))?;
79
80        let bos_id = Self::special_token_id(&inner, &["<s>", "<bos>", "<|begin_of_text|>"]);
81        let eos_ids = Self::all_special_token_ids(&inner, EOS_CANDIDATES);
82        let eos_id = eos_ids.first().copied();
83        let pad_id = Self::special_token_id(&inner, &["<pad>"]);
84
85        Ok(Self {
86            inner,
87            bos_id,
88            eos_id,
89            eos_ids,
90            pad_id,
91            opts: TokenizerOptions::default(),
92        })
93    }
94
95    /// Encode a text string to token IDs.
96    pub fn encode(&self, text: &str) -> Result<Vec<u32>> {
97        let encoding = self
98            .inner
99            .encode(text, true)
100            .map_err(|e| anyhow::anyhow!("Tokenizer encode error: {e}"))?;
101
102        let mut ids = encoding.get_ids().to_vec();
103
104        // Prepend BOS if configured.
105        if self.opts.add_bos {
106            if let Some(bos) = self.bos_id {
107                if ids.first() != Some(&bos) {
108                    ids.insert(0, bos);
109                }
110            }
111        }
112
113        // Append EOS if configured.
114        if self.opts.add_eos {
115            if let Some(eos) = self.eos_id {
116                ids.push(eos);
117            }
118        }
119
120        // Truncate if needed.
121        if self.opts.max_length > 0 && ids.len() > self.opts.max_length {
122            ids.truncate(self.opts.max_length);
123        }
124
125        Ok(ids)
126    }
127
128    /// Decode token IDs back to a string.
129    pub fn decode(&self, ids: &[u32], skip_special: bool) -> Result<String> {
130        self.inner
131            .decode(ids, skip_special)
132            .map_err(|e| anyhow::anyhow!("Tokenizer decode error: {e}"))
133    }
134
135    /// Decode a single token ID to a string (for streaming).
136    pub fn decode_token(&self, id: u32) -> Result<String> {
137        self.decode(&[id], true)
138    }
139
140    /// Vocabulary size.
141    pub fn vocab_size(&self) -> usize {
142        self.inner.get_vocab_size(true)
143    }
144
145    // ── Helpers ───────────────────────────────────────────────────────────────
146
147    /// True if `id` is any of this tokenizer's end-of-sequence markers.
148    pub fn is_eos(&self, id: u32) -> bool {
149        self.eos_ids.contains(&id)
150    }
151
152    fn special_token_id(tok: &Tokenizer, candidates: &[&str]) -> Option<u32> {
153        for c in candidates {
154            if let Some(id) = tok.token_to_id(c) {
155                return Some(id);
156            }
157        }
158        None
159    }
160
161    /// All ids (in candidate order, deduplicated) for tokens present in the vocab.
162    fn all_special_token_ids(tok: &Tokenizer, candidates: &[&str]) -> Vec<u32> {
163        let mut ids = Vec::new();
164        for c in candidates {
165            if let Some(id) = tok.token_to_id(c) {
166                if !ids.contains(&id) {
167                    ids.push(id);
168                }
169            }
170        }
171        ids
172    }
173
174    fn from_inner(inner: Tokenizer, opts: TokenizerOptions) -> Result<Self> {
175        let bos_id =
176            Self::special_token_id(&inner, &["<s>", "<bos>", "<|begin_of_text|>", "[BOS]"]);
177        let eos_ids = Self::all_special_token_ids(&inner, EOS_CANDIDATES);
178        let eos_id = eos_ids.first().copied();
179        let pad_id =
180            Self::special_token_id(&inner, &["<pad>", "<|finetune_right_pad_id|>", "[PAD]"]);
181
182        Ok(Self {
183            inner,
184            bos_id,
185            eos_id,
186            eos_ids,
187            pad_id,
188            opts,
189        })
190    }
191}
192
193/// Normalize newer HuggingFace tokenizer JSON into a format older `tokenizers`
194/// versions can deserialize (e.g. BPE merges stored as `[a, b]` pairs).
195fn normalize_tokenizer_json(path: &Path) -> Result<Vec<u8>> {
196    let text = std::fs::read_to_string(path)?;
197    let mut value: serde_json::Value = serde_json::from_str(&text)?;
198
199    let Some(model) = value.get_mut("model") else {
200        anyhow::bail!("tokenizer.json missing model section");
201    };
202    let Some(merges) = model.get_mut("merges") else {
203        anyhow::bail!("tokenizer.json missing BPE merges");
204    };
205    let Some(arr) = merges.as_array_mut() else {
206        anyhow::bail!("tokenizer merges are not an array");
207    };
208    if arr.is_empty() {
209        return Ok(text.into_bytes());
210    }
211    if !arr[0].is_array() {
212        anyhow::bail!("tokenizer merges already use string format");
213    }
214
215    let normalized: Vec<String> = arr
216        .iter()
217        .filter_map(|entry| {
218            let pair = entry.as_array()?;
219            if pair.len() != 2 {
220                return None;
221            }
222            Some(format!("{} {}", pair[0].as_str()?, pair[1].as_str()?))
223        })
224        .collect();
225
226    if normalized.len() != arr.len() {
227        anyhow::bail!("failed to normalize all BPE merges");
228    }
229
230    *merges = serde_json::Value::Array(
231        normalized
232            .into_iter()
233            .map(serde_json::Value::String)
234            .collect(),
235    );
236
237    Ok(serde_json::to_vec(&value)?)
238}
239
#[cfg(test)]
mod tests {
    // This module contains no tests yet. Tokenizer integration tests need a
    // real `tokenizer.json`, either downloaded from the HF Hub (network access)
    // or read from local disk. Mark such tests `#[ignore]` and run them with:
    //     cargo test -p sapient-tokenizers -- --ignored
}