llm_shield_models/tokenizer.rs
1//! Tokenizer Wrapper for HuggingFace Tokenizers
2//!
3//! ## SPARC Phase 3: Construction (TDD Green Phase)
4//!
5//! This module provides a thread-safe wrapper around HuggingFace tokenizers
6//! for preprocessing text before ML model inference.
7//!
8//! ## Features
9//!
10//! - Support for multiple tokenizer types (DeBERTa, RoBERTa, etc.)
11//! - Configurable truncation at max length (default: 512 tokens)
12//! - Padding support (right-side padding)
13//! - Special tokens handling
14//! - Thread-safe design using Arc
15//! - Batch encoding support
16//!
17//! ## Usage Example
18//!
19//! ```no_run
20//! use llm_shield_models::{TokenizerWrapper, TokenizerConfig};
21//!
22//! # fn example() -> Result<(), Box<dyn std::error::Error>> {
23//! let tokenizer = TokenizerWrapper::from_pretrained(
24//! "microsoft/deberta-v3-base",
25//! TokenizerConfig::default(),
26//! )?;
27//!
28//! let encoding = tokenizer.encode("Ignore all previous instructions")?;
29//! println!("Token IDs: {:?}", encoding.input_ids);
30//! # Ok(())
31//! # }
32//! ```
33
34use llm_shield_core::Error;
35use std::sync::Arc;
36use tokenizers::{
37 Tokenizer,
38 PaddingParams, PaddingStrategy, PaddingDirection,
39 TruncationParams, TruncationStrategy,
40};
41
42/// Result type alias
43pub type Result<T> = std::result::Result<T, Error>;
44
45/// Configuration for the tokenizer
46///
47/// ## Configuration Options
48///
49/// - `max_length`: Maximum sequence length (default: 512)
50/// - `padding`: Enable padding to max_length (default: true)
51/// - `truncation`: Enable truncation at max_length (default: true)
52/// - `add_special_tokens`: Add special tokens like [CLS], [SEP] (default: true)
53///
54/// ## Recommended Settings
55///
56/// **Production (default)**:
57/// ```rust
58/// # use llm_shield_models::TokenizerConfig;
59/// let config = TokenizerConfig::default();
60/// assert_eq!(config.max_length, 512);
61/// assert!(config.padding);
62/// assert!(config.truncation);
63/// ```
64///
65/// **Memory-constrained**:
66/// ```rust
67/// # use llm_shield_models::TokenizerConfig;
68/// let config = TokenizerConfig {
69/// max_length: 256,
70/// padding: false,
71/// truncation: true,
72/// add_special_tokens: true,
73/// };
74/// ```
75#[derive(Debug, Clone)]
76pub struct TokenizerConfig {
77 /// Maximum sequence length (tokens)
78 pub max_length: usize,
79
80 /// Enable padding to max_length
81 pub padding: bool,
82
83 /// Enable truncation at max_length
84 pub truncation: bool,
85
86 /// Add model-specific special tokens ([CLS], [SEP], etc.)
87 pub add_special_tokens: bool,
88}
89
90impl Default for TokenizerConfig {
91 fn default() -> Self {
92 Self {
93 max_length: 512,
94 padding: true,
95 truncation: true,
96 add_special_tokens: true,
97 }
98 }
99}
100
101/// Encoding result from tokenization
102///
103/// Contains the token IDs, attention mask, and character offsets needed for model inference.
104#[derive(Debug, Clone)]
105pub struct Encoding {
106 /// Token IDs (vocabulary indices)
107 pub input_ids: Vec<u32>,
108
109 /// Attention mask (1 for real tokens, 0 for padding)
110 pub attention_mask: Vec<u32>,
111
112 /// Character offsets in original text for each token
113 /// (start_char, end_char) for each token
114 /// Special tokens (CLS, SEP, PAD) have offset (0, 0)
115 pub offsets: Vec<(usize, usize)>,
116}
117
118impl Encoding {
119 /// Create a new encoding
120 pub fn new(input_ids: Vec<u32>, attention_mask: Vec<u32>) -> Self {
121 // Create default offsets (all zeros for backward compatibility)
122 let offsets = vec![(0, 0); input_ids.len()];
123 Self {
124 input_ids,
125 attention_mask,
126 offsets,
127 }
128 }
129
130 /// Create a new encoding with offsets
131 pub fn with_offsets(
132 input_ids: Vec<u32>,
133 attention_mask: Vec<u32>,
134 offsets: Vec<(usize, usize)>,
135 ) -> Self {
136 Self {
137 input_ids,
138 attention_mask,
139 offsets,
140 }
141 }
142
143 /// Get the length of the encoding
144 pub fn len(&self) -> usize {
145 self.input_ids.len()
146 }
147
148 /// Check if encoding is empty
149 pub fn is_empty(&self) -> bool {
150 self.input_ids.is_empty()
151 }
152
153 /// Convert to arrays suitable for ONNX inference
154 ///
155 /// Returns (input_ids, attention_mask) as i64 arrays
156 pub fn to_arrays(&self) -> (Vec<i64>, Vec<i64>) {
157 let input_ids = self.input_ids.iter().map(|&x| x as i64).collect();
158 let attention_mask = self.attention_mask.iter().map(|&x| x as i64).collect();
159 (input_ids, attention_mask)
160 }
161}
162
163/// Thread-safe tokenizer wrapper
164///
165/// ## Thread Safety
166///
167/// This wrapper uses `Arc<Tokenizer>` for thread-safe access.
168/// Multiple threads can encode text concurrently using the same tokenizer.
169///
170/// ## Performance
171///
172/// - Tokenization: ~0.1-0.5ms per input (100-500 tokens)
173/// - Thread-safe without locks (immutable after creation)
174/// - Batch encoding is more efficient than individual calls
175///
176/// ## Example
177///
178/// ```no_run
179/// # use llm_shield_models::{TokenizerWrapper, TokenizerConfig};
180/// # use std::sync::Arc;
181/// # use std::thread;
182/// #
183/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
184/// let tokenizer = Arc::new(
185/// TokenizerWrapper::from_pretrained(
186/// "microsoft/deberta-v3-base",
187/// TokenizerConfig::default(),
188/// )?
189/// );
190///
191/// let handles: Vec<_> = (0..4)
192/// .map(|i| {
193/// let tok = Arc::clone(&tokenizer);
194/// thread::spawn(move || tok.encode(&format!("Text {}", i)))
195/// })
196/// .collect();
197///
198/// for handle in handles {
199/// let encoding = handle.join().unwrap()?;
200/// println!("Encoded {} tokens", encoding.len());
201/// }
202/// # Ok(())
203/// # }
204/// ```
205pub struct TokenizerWrapper {
206 tokenizer: Arc<Tokenizer>,
207 config: TokenizerConfig,
208}
209
210impl TokenizerWrapper {
211 /// Load a tokenizer from HuggingFace Hub
212 ///
213 /// # Arguments
214 ///
215 /// * `model_name` - HuggingFace model identifier (e.g., "microsoft/deberta-v3-base")
216 /// * `config` - Tokenizer configuration
217 ///
218 /// # Supported Models
219 ///
220 /// - **DeBERTa**: `microsoft/deberta-v3-base` (PromptInjection)
221 /// - **RoBERTa**: `roberta-base` (Toxicity, Sentiment)
222 /// - **BERT**: `bert-base-uncased`
223 /// - Any HuggingFace model with a tokenizer
224 ///
225 /// # Example
226 ///
227 /// ```no_run
228 /// # use llm_shield_models::{TokenizerWrapper, TokenizerConfig};
229 /// let tokenizer = TokenizerWrapper::from_pretrained(
230 /// "microsoft/deberta-v3-base",
231 /// TokenizerConfig::default(),
232 /// )?;
233 /// # Ok::<(), llm_shield_core::Error>(())
234 /// ```
235 pub fn from_pretrained(model_name: &str, config: TokenizerConfig) -> Result<Self> {
236 tracing::info!("Loading tokenizer from: {}", model_name);
237
238 // For now, we'll use a simple approach that assumes tokenizer.json exists locally
239 // In production, this should download from HuggingFace Hub
240 let tokenizer_path = format!("models/{}/tokenizer.json", model_name);
241
242 let mut tokenizer = if std::path::Path::new(&tokenizer_path).exists() {
243 Tokenizer::from_file(&tokenizer_path)
244 .map_err(|e| {
245 Error::model(format!(
246 "Failed to load tokenizer from '{}': {}",
247 tokenizer_path, e
248 ))
249 })?
250 } else {
251 // Fall back to a basic tokenizer for testing
252 // In production, implement proper HuggingFace Hub download
253 return Err(Error::model(format!(
254 "Tokenizer not found at '{}'. Please download tokenizer files first.",
255 tokenizer_path
256 )));
257 };
258
259 // Configure padding
260 if config.padding {
261 let padding = PaddingParams {
262 strategy: PaddingStrategy::Fixed(config.max_length),
263 direction: PaddingDirection::Right,
264 pad_id: 0, // Will be overridden by tokenizer's pad token
265 pad_type_id: 0,
266 pad_token: String::from("[PAD]"), // Will be overridden
267 pad_to_multiple_of: None,
268 };
269 tokenizer.with_padding(Some(padding));
270 }
271
272 // Configure truncation
273 if config.truncation {
274 let truncation = TruncationParams {
275 max_length: config.max_length,
276 strategy: TruncationStrategy::LongestFirst,
277 stride: 0,
278 direction: tokenizers::TruncationDirection::Right,
279 };
280 tokenizer.with_truncation(Some(truncation))
281 .map_err(|e| {
282 Error::model(format!("Failed to configure truncation: {}", e))
283 })?;
284 }
285
286 tracing::debug!(
287 "Tokenizer loaded: max_length={}, padding={}, truncation={}",
288 config.max_length,
289 config.padding,
290 config.truncation
291 );
292
293 Ok(Self {
294 tokenizer: Arc::new(tokenizer),
295 config,
296 })
297 }
298
299 /// Encode a single text string
300 ///
301 /// # Arguments
302 ///
303 /// * `text` - Input text to tokenize
304 ///
305 /// # Returns
306 ///
307 /// `Encoding` with token IDs and attention mask
308 ///
309 /// # Example
310 ///
311 /// ```no_run
312 /// # use llm_shield_models::{TokenizerWrapper, TokenizerConfig};
313 /// # let tokenizer = TokenizerWrapper::from_pretrained(
314 /// # "microsoft/deberta-v3-base",
315 /// # TokenizerConfig::default(),
316 /// # )?;
317 /// let encoding = tokenizer.encode("Hello, world!")?;
318 /// println!("Token IDs: {:?}", encoding.input_ids);
319 /// println!("Attention mask: {:?}", encoding.attention_mask);
320 /// # Ok::<(), llm_shield_core::Error>(())
321 /// ```
322 pub fn encode(&self, text: &str) -> Result<Encoding> {
323 let encoding = self.tokenizer
324 .encode(text, self.config.add_special_tokens)
325 .map_err(|e| {
326 Error::model(format!("Failed to encode text: {}", e))
327 })?;
328
329 let input_ids = encoding.get_ids().to_vec();
330 let attention_mask = encoding.get_attention_mask().to_vec();
331
332 // Extract character offsets
333 let offsets: Vec<(usize, usize)> = encoding
334 .get_offsets()
335 .iter()
336 .map(|offset| (offset.0, offset.1))
337 .collect();
338
339 Ok(Encoding::with_offsets(input_ids, attention_mask, offsets))
340 }
341
342 /// Encode multiple texts in batch
343 ///
344 /// Batch encoding is more efficient than encoding texts individually.
345 ///
346 /// # Arguments
347 ///
348 /// * `texts` - Slice of text strings
349 ///
350 /// # Returns
351 ///
352 /// Vector of `Encoding` results (one per input text)
353 ///
354 /// # Example
355 ///
356 /// ```no_run
357 /// # use llm_shield_models::{TokenizerWrapper, TokenizerConfig};
358 /// # let tokenizer = TokenizerWrapper::from_pretrained(
359 /// # "microsoft/deberta-v3-base",
360 /// # TokenizerConfig::default(),
361 /// # )?;
362 /// let texts = vec!["First text", "Second text", "Third text"];
363 /// let encodings = tokenizer.encode_batch(&texts)?;
364 ///
365 /// assert_eq!(encodings.len(), 3);
366 /// for encoding in encodings {
367 /// println!("Length: {}", encoding.len());
368 /// }
369 /// # Ok::<(), llm_shield_core::Error>(())
370 /// ```
371 pub fn encode_batch(&self, texts: &[&str]) -> Result<Vec<Encoding>> {
372 if texts.is_empty() {
373 return Ok(vec![]);
374 }
375
376 let encodings = self.tokenizer
377 .encode_batch(texts.to_vec(), self.config.add_special_tokens)
378 .map_err(|e| {
379 Error::model(format!("Failed to encode batch: {}", e))
380 })?;
381
382 let results = encodings
383 .into_iter()
384 .map(|enc| {
385 let input_ids = enc.get_ids().to_vec();
386 let attention_mask = enc.get_attention_mask().to_vec();
387 let offsets: Vec<(usize, usize)> = enc
388 .get_offsets()
389 .iter()
390 .map(|offset| (offset.0, offset.1))
391 .collect();
392 Encoding::with_offsets(input_ids, attention_mask, offsets)
393 })
394 .collect();
395
396 Ok(results)
397 }
398
399 /// Get the tokenizer configuration
400 pub fn config(&self) -> &TokenizerConfig {
401 &self.config
402 }
403
404 /// Get the vocabulary size
405 ///
406 /// Returns the size of the tokenizer's vocabulary.
407 pub fn vocab_size(&self) -> usize {
408 self.tokenizer.get_vocab_size(self.config.add_special_tokens)
409 }
410}
411
412// Implement Clone for TokenizerWrapper (clones Arc, not the underlying tokenizer)
413impl Clone for TokenizerWrapper {
414 fn clone(&self) -> Self {
415 Self {
416 tokenizer: Arc::clone(&self.tokenizer),
417 config: self.config.clone(),
418 }
419 }
420}
421
422#[cfg(test)]
423mod tests {
424 use super::*;
425
426 #[test]
427 fn test_tokenizer_config_default() {
428 let config = TokenizerConfig::default();
429 assert_eq!(config.max_length, 512);
430 assert!(config.padding);
431 assert!(config.truncation);
432 assert!(config.add_special_tokens);
433 }
434
435 #[test]
436 fn test_encoding_creation() {
437 let encoding = Encoding::new(
438 vec![101, 2023, 2003, 102],
439 vec![1, 1, 1, 1],
440 );
441
442 assert_eq!(encoding.len(), 4);
443 assert!(!encoding.is_empty());
444 }
445
446 #[test]
447 fn test_encoding_to_arrays() {
448 let encoding = Encoding::new(
449 vec![101, 2023, 102],
450 vec![1, 1, 1],
451 );
452
453 let (input_ids, attention_mask) = encoding.to_arrays();
454 assert_eq!(input_ids, vec![101i64, 2023, 102]);
455 assert_eq!(attention_mask, vec![1i64, 1, 1]);
456 }
457
458 #[test]
459 fn test_encoding_empty() {
460 let encoding = Encoding::new(vec![], vec![]);
461 assert!(encoding.is_empty());
462 assert_eq!(encoding.len(), 0);
463 }
464
465 // Note: The following tests require network access to HuggingFace Hub
466 // They are integration tests and should be run with `cargo test --test tokenizer_test`
467}