Skip to main content

rlx_embed/
tokenizer.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! HuggingFace tokenizer wrapper for BERT-style text tokenization.
17//!
18//! Handles loading tokenizer files, configuring padding/truncation,
19//! and batch encoding of text inputs.
20
21use std::path::Path;
22
23use tokenizers::{AddedToken, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
24
/// Output of batch tokenization: token IDs, attention masks, and token type IDs.
///
/// Each of the three vectors holds one row per encoded text, in the order the
/// texts were passed to [`BertTokenizer::encode_batch`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenizedBatch {
    /// Token IDs of each encoded text.
    pub input_ids: Vec<Vec<u32>>,
    /// Attention mask of each encoded text, as reported by the tokenizer.
    pub attention_mask: Vec<Vec<u32>>,
    /// Token type IDs of each encoded text, as reported by the tokenizer.
    pub token_type_ids: Vec<Vec<u32>>,
    /// Sequence length (max length in this batch after padding).
    ///
    /// Taken from the first encoding of the batch; the tokenizer is configured
    /// to pad every row to the longest one in the batch.
    pub seq_len: usize,
}
33
/// Wrapper around HuggingFace tokenizer configured for BERT-style encoding.
///
/// Instances are built with [`BertTokenizer::from_dir`] or
/// [`BertTokenizer::from_bytes`], which set up padding, truncation, and
/// special tokens on the wrapped tokenizer before it is stored here.
pub struct BertTokenizer {
    // The configured `tokenizers::Tokenizer` that performs the actual encoding.
    inner: Tokenizer,
}
38
39impl BertTokenizer {
40    /// Load tokenizer from a model directory containing:
41    /// - `tokenizer.json`
42    /// - `config.json`
43    /// - `special_tokens_map.json`
44    /// - `tokenizer_config.json`
45    pub fn from_dir(dir: &Path, max_length: usize) -> anyhow::Result<Self> {
46        let tokenizer_json = std::fs::read(dir.join("tokenizer.json"))?;
47        let config_json = std::fs::read(dir.join("config.json"))?;
48        let special_tokens_map = std::fs::read(dir.join("special_tokens_map.json"))?;
49        let tokenizer_config = std::fs::read(dir.join("tokenizer_config.json"))?;
50
51        Self::from_bytes(
52            &tokenizer_json,
53            &config_json,
54            &special_tokens_map,
55            &tokenizer_config,
56            max_length,
57        )
58    }
59
60    /// Load tokenizer from raw file bytes.
61    pub fn from_bytes(
62        tokenizer_json: &[u8],
63        config_json: &[u8],
64        special_tokens_map_json: &[u8],
65        tokenizer_config_json: &[u8],
66        max_length: usize,
67    ) -> anyhow::Result<Self> {
68        let mut tokenizer = Tokenizer::from_bytes(tokenizer_json)
69            .map_err(|e| anyhow::anyhow!("failed to load tokenizer: {e}"))?;
70
71        // Parse config files
72        let config: serde_json::Value = serde_json::from_slice(config_json)?;
73        let tokenizer_config: serde_json::Value = serde_json::from_slice(tokenizer_config_json)?;
74        let special_tokens_map: serde_json::Value =
75            serde_json::from_slice(special_tokens_map_json)?;
76
77        // Determine max length from tokenizer_config
78        let model_max_length = tokenizer_config
79            .get("model_max_length")
80            .and_then(|v| v.as_f64())
81            .map(|v| v.min(1e9) as usize)
82            .unwrap_or(512);
83        let effective_max_length = max_length.min(model_max_length);
84
85        // Determine pad token and id
86        let pad_token = tokenizer_config
87            .get("pad_token")
88            .and_then(|v| v.as_str())
89            .unwrap_or("[PAD]")
90            .to_string();
91        let pad_token_id = config
92            .get("pad_token_id")
93            .and_then(|v| v.as_u64())
94            .unwrap_or(0) as u32;
95
96        // Configure padding: pad to longest in batch
97        tokenizer.with_padding(Some(PaddingParams {
98            strategy: PaddingStrategy::BatchLongest,
99            pad_token: pad_token.clone(),
100            pad_id: pad_token_id,
101            ..PaddingParams::default()
102        }));
103
104        // Configure truncation
105        tokenizer
106            .with_truncation(Some(TruncationParams {
107                max_length: effective_max_length,
108                ..TruncationParams::default()
109            }))
110            .map_err(|e| anyhow::anyhow!("failed to set truncation: {e}"))?;
111
112        // Add special tokens from special_tokens_map
113        let mut special_tokens = Vec::new();
114        if let Some(map) = special_tokens_map.as_object() {
115            for (_key, value) in map {
116                match value {
117                    serde_json::Value::String(s) => {
118                        special_tokens.push(AddedToken::from(s.clone(), true));
119                    }
120                    serde_json::Value::Object(obj) => {
121                        if let Some(content) = obj.get("content").and_then(|v| v.as_str()) {
122                            special_tokens.push(AddedToken::from(content.to_string(), true));
123                        }
124                    }
125                    serde_json::Value::Array(arr) => {
126                        for item in arr {
127                            match item {
128                                serde_json::Value::String(s) => {
129                                    special_tokens.push(AddedToken::from(s.clone(), true));
130                                }
131                                serde_json::Value::Object(obj) => {
132                                    if let Some(content) =
133                                        obj.get("content").and_then(|v| v.as_str())
134                                    {
135                                        special_tokens
136                                            .push(AddedToken::from(content.to_string(), true));
137                                    }
138                                }
139                                _ => {}
140                            }
141                        }
142                    }
143                    _ => {}
144                }
145            }
146        }
147        if !special_tokens.is_empty() {
148            tokenizer.add_special_tokens(&special_tokens);
149        }
150
151        Ok(Self { inner: tokenizer })
152    }
153
154    /// Tokenize a batch of texts.
155    ///
156    /// Returns input_ids, attention_mask, and token_type_ids for each text,
157    /// all padded to the same length (longest in batch).
158    pub fn encode_batch(&self, texts: &[&str]) -> anyhow::Result<TokenizedBatch> {
159        let encodings = self
160            .inner
161            .encode_batch(texts.to_vec(), true)
162            .map_err(|e| anyhow::anyhow!("tokenization failed: {e}"))?;
163
164        let seq_len = encodings
165            .first()
166            .ok_or_else(|| anyhow::anyhow!("empty batch"))?
167            .len();
168
169        let mut input_ids = Vec::with_capacity(texts.len());
170        let mut attention_mask = Vec::with_capacity(texts.len());
171        let mut token_type_ids = Vec::with_capacity(texts.len());
172
173        for enc in &encodings {
174            input_ids.push(enc.get_ids().to_vec());
175            attention_mask.push(enc.get_attention_mask().to_vec());
176            token_type_ids.push(enc.get_type_ids().to_vec());
177        }
178
179        Ok(TokenizedBatch {
180            input_ids,
181            attention_mask,
182            token_type_ids,
183            seq_len,
184        })
185    }
186}