1use std::path::Path;
22
23use tokenizers::{AddedToken, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
24
/// Token data for a batch of texts, as produced by `BertTokenizer::encode_batch`.
///
/// The three outer vectors each hold one inner vector per encoded text, in the
/// order the tokenizer returned the encodings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenizedBatch {
    /// Token ids of each encoded text.
    pub input_ids: Vec<Vec<u32>>,
    /// Attention mask of each encoded text, as reported by the tokenizer.
    pub attention_mask: Vec<Vec<u32>>,
    /// Token type ids of each encoded text, as reported by the tokenizer.
    pub token_type_ids: Vec<Vec<u32>>,
    /// Length of the first encoding in the batch.
    ///
    /// NOTE(review): the producer pads with a batch-longest strategy, so every
    /// row is presumably this long — confirm before relying on it for shapes.
    pub seq_len: usize,
}
33
34pub struct BertTokenizer {
36 inner: Tokenizer,
37}
38
39impl BertTokenizer {
40 pub fn from_dir(dir: &Path, max_length: usize) -> anyhow::Result<Self> {
46 let tokenizer_json = std::fs::read(dir.join("tokenizer.json"))?;
47 let config_json = std::fs::read(dir.join("config.json"))?;
48 let special_tokens_map = std::fs::read(dir.join("special_tokens_map.json"))?;
49 let tokenizer_config = std::fs::read(dir.join("tokenizer_config.json"))?;
50
51 Self::from_bytes(
52 &tokenizer_json,
53 &config_json,
54 &special_tokens_map,
55 &tokenizer_config,
56 max_length,
57 )
58 }
59
60 pub fn from_bytes(
62 tokenizer_json: &[u8],
63 config_json: &[u8],
64 special_tokens_map_json: &[u8],
65 tokenizer_config_json: &[u8],
66 max_length: usize,
67 ) -> anyhow::Result<Self> {
68 let mut tokenizer = Tokenizer::from_bytes(tokenizer_json)
69 .map_err(|e| anyhow::anyhow!("failed to load tokenizer: {e}"))?;
70
71 let config: serde_json::Value = serde_json::from_slice(config_json)?;
73 let tokenizer_config: serde_json::Value = serde_json::from_slice(tokenizer_config_json)?;
74 let special_tokens_map: serde_json::Value =
75 serde_json::from_slice(special_tokens_map_json)?;
76
77 let model_max_length = tokenizer_config
79 .get("model_max_length")
80 .and_then(|v| v.as_f64())
81 .map(|v| v.min(1e9) as usize)
82 .unwrap_or(512);
83 let effective_max_length = max_length.min(model_max_length);
84
85 let pad_token = tokenizer_config
87 .get("pad_token")
88 .and_then(|v| v.as_str())
89 .unwrap_or("[PAD]")
90 .to_string();
91 let pad_token_id = config
92 .get("pad_token_id")
93 .and_then(|v| v.as_u64())
94 .unwrap_or(0) as u32;
95
96 tokenizer.with_padding(Some(PaddingParams {
98 strategy: PaddingStrategy::BatchLongest,
99 pad_token: pad_token.clone(),
100 pad_id: pad_token_id,
101 ..PaddingParams::default()
102 }));
103
104 tokenizer
106 .with_truncation(Some(TruncationParams {
107 max_length: effective_max_length,
108 ..TruncationParams::default()
109 }))
110 .map_err(|e| anyhow::anyhow!("failed to set truncation: {e}"))?;
111
112 let mut special_tokens = Vec::new();
114 if let Some(map) = special_tokens_map.as_object() {
115 for (_key, value) in map {
116 match value {
117 serde_json::Value::String(s) => {
118 special_tokens.push(AddedToken::from(s.clone(), true));
119 }
120 serde_json::Value::Object(obj) => {
121 if let Some(content) = obj.get("content").and_then(|v| v.as_str()) {
122 special_tokens.push(AddedToken::from(content.to_string(), true));
123 }
124 }
125 serde_json::Value::Array(arr) => {
126 for item in arr {
127 match item {
128 serde_json::Value::String(s) => {
129 special_tokens.push(AddedToken::from(s.clone(), true));
130 }
131 serde_json::Value::Object(obj) => {
132 if let Some(content) =
133 obj.get("content").and_then(|v| v.as_str())
134 {
135 special_tokens
136 .push(AddedToken::from(content.to_string(), true));
137 }
138 }
139 _ => {}
140 }
141 }
142 }
143 _ => {}
144 }
145 }
146 }
147 if !special_tokens.is_empty() {
148 tokenizer.add_special_tokens(&special_tokens);
149 }
150
151 Ok(Self { inner: tokenizer })
152 }
153
154 pub fn encode_batch(&self, texts: &[&str]) -> anyhow::Result<TokenizedBatch> {
159 let encodings = self
160 .inner
161 .encode_batch(texts.to_vec(), true)
162 .map_err(|e| anyhow::anyhow!("tokenization failed: {e}"))?;
163
164 let seq_len = encodings
165 .first()
166 .ok_or_else(|| anyhow::anyhow!("empty batch"))?
167 .len();
168
169 let mut input_ids = Vec::with_capacity(texts.len());
170 let mut attention_mask = Vec::with_capacity(texts.len());
171 let mut token_type_ids = Vec::with_capacity(texts.len());
172
173 for enc in &encodings {
174 input_ids.push(enc.get_ids().to_vec());
175 attention_mask.push(enc.get_attention_mask().to_vec());
176 token_type_ids.push(enc.get_type_ids().to_vec());
177 }
178
179 Ok(TokenizedBatch {
180 input_ids,
181 attention_mask,
182 token_type_ids,
183 seq_len,
184 })
185 }
186}