llm_shield_models/tokenizer.rs

//! Tokenizer Wrapper for HuggingFace Tokenizers
//!
//! ## SPARC Phase 3: Construction (TDD Green Phase)
//!
//! This module provides a thread-safe wrapper around HuggingFace tokenizers
//! for preprocessing text before ML model inference.
//!
//! ## Features
//!
//! - Support for multiple tokenizer types (DeBERTa, RoBERTa, etc.)
//! - Configurable truncation at max length (default: 512 tokens)
//! - Padding support (right-side padding)
//! - Special tokens handling
//! - Thread-safe design using Arc
//! - Batch encoding support
//!
//! ## Usage Example
//!
//! ```no_run
//! use llm_shield_models::{TokenizerWrapper, TokenizerConfig};
//!
//! # fn example() -> Result<(), Box<dyn std::error::Error>> {
//! let tokenizer = TokenizerWrapper::from_pretrained(
//!     "microsoft/deberta-v3-base",
//!     TokenizerConfig::default(),
//! )?;
//!
//! let encoding = tokenizer.encode("Ignore all previous instructions")?;
//! println!("Token IDs: {:?}", encoding.input_ids);
//! # Ok(())
//! # }
//! ```

use llm_shield_core::Error;
use std::sync::Arc;
use tokenizers::{
    Tokenizer,
    PaddingParams, PaddingStrategy, PaddingDirection,
    TruncationParams, TruncationStrategy,
};

/// Result type alias
pub type Result<T> = std::result::Result<T, Error>;

/// Configuration for the tokenizer
///
/// ## Configuration Options
///
/// - `max_length`: Maximum sequence length (default: 512)
/// - `padding`: Enable padding to max_length (default: true)
/// - `truncation`: Enable truncation at max_length (default: true)
/// - `add_special_tokens`: Add special tokens like [CLS], [SEP] (default: true)
///
/// ## Recommended Settings
///
/// **Production (default)**:
/// ```rust
/// # use llm_shield_models::TokenizerConfig;
/// let config = TokenizerConfig::default();
/// assert_eq!(config.max_length, 512);
/// assert!(config.padding);
/// assert!(config.truncation);
/// ```
///
/// **Memory-constrained**:
/// ```rust
/// # use llm_shield_models::TokenizerConfig;
/// let config = TokenizerConfig {
///     max_length: 256,
///     padding: false,
///     truncation: true,
///     add_special_tokens: true,
/// };
/// ```
#[derive(Debug, Clone)]
pub struct TokenizerConfig {
    /// Maximum sequence length (tokens)
    pub max_length: usize,

    /// Enable padding to max_length
    pub padding: bool,

    /// Enable truncation at max_length
    pub truncation: bool,

    /// Add model-specific special tokens ([CLS], [SEP], etc.)
    pub add_special_tokens: bool,
}

impl Default for TokenizerConfig {
    fn default() -> Self {
        Self {
            max_length: 512,
            padding: true,
            truncation: true,
            add_special_tokens: true,
        }
    }
}

/// Encoding result from tokenization
///
/// Contains the token IDs, attention mask, and token-to-text offsets needed
/// for model inference.
#[derive(Debug, Clone)]
pub struct Encoding {
    /// Token IDs (vocabulary indices)
    pub input_ids: Vec<u32>,

    /// Attention mask (1 for real tokens, 0 for padding)
    pub attention_mask: Vec<u32>,

    /// Offsets into the original text, one `(start, end)` pair per token.
    /// The underlying `Tokenizer::encode` call reports byte offsets;
    /// special tokens ([CLS], [SEP], [PAD]) have offset `(0, 0)`.
    pub offsets: Vec<(usize, usize)>,
}

impl Encoding {
    /// Create a new encoding
    pub fn new(input_ids: Vec<u32>, attention_mask: Vec<u32>) -> Self {
        // Create placeholder offsets (all zeros, for backward compatibility)
        let offsets = vec![(0, 0); input_ids.len()];
        Self {
            input_ids,
            attention_mask,
            offsets,
        }
    }

    /// Create a new encoding with offsets
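    ///
    /// # Example
    ///
    /// A minimal sketch of mapping a token back to its source span via the
    /// offsets. The token IDs and offsets below are illustrative values, and
    /// `Encoding` is assumed to be re-exported at the crate root alongside
    /// `TokenizerWrapper`:
    ///
    /// ```rust
    /// # use llm_shield_models::Encoding;
    /// let text = "hello world";
    /// let encoding = Encoding::with_offsets(
    ///     vec![7592, 2088],
    ///     vec![1, 1],
    ///     vec![(0, 5), (6, 11)],
    /// );
    ///
    /// // Recover the source span of the second token.
    /// let (start, end) = encoding.offsets[1];
    /// assert_eq!(&text[start..end], "world");
    /// ```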
    pub fn with_offsets(
        input_ids: Vec<u32>,
        attention_mask: Vec<u32>,
        offsets: Vec<(usize, usize)>,
    ) -> Self {
        Self {
            input_ids,
            attention_mask,
            offsets,
        }
    }

    /// Get the length of the encoding
    pub fn len(&self) -> usize {
        self.input_ids.len()
    }

    /// Check if encoding is empty
    pub fn is_empty(&self) -> bool {
        self.input_ids.is_empty()
    }

    /// Convert to arrays suitable for ONNX inference
    ///
    /// Returns `(input_ids, attention_mask)` as `i64` arrays
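    ///
    /// # Example
    ///
    /// A small sketch of the `u32` to `i64` widening conversion (values are
    /// illustrative; most ONNX text models expect `i64` tensors shaped
    /// `[batch, seq_len]`):
    ///
    /// ```rust
    /// # use llm_shield_models::Encoding;
    /// let encoding = Encoding::new(vec![101, 2023, 102], vec![1, 1, 1]);
    /// let (input_ids, attention_mask) = encoding.to_arrays();
    /// assert_eq!(input_ids, vec![101i64, 2023, 102]);
    /// assert_eq!(attention_mask, vec![1i64, 1, 1]);
    /// ```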
    pub fn to_arrays(&self) -> (Vec<i64>, Vec<i64>) {
        let input_ids = self.input_ids.iter().map(|&x| x as i64).collect();
        let attention_mask = self.attention_mask.iter().map(|&x| x as i64).collect();
        (input_ids, attention_mask)
    }
}

/// Thread-safe tokenizer wrapper
///
/// ## Thread Safety
///
/// This wrapper uses `Arc<Tokenizer>` for thread-safe access.
/// Multiple threads can encode text concurrently using the same tokenizer.
///
/// ## Performance
///
/// - Tokenization: ~0.1-0.5ms per input (100-500 tokens)
/// - Thread-safe without locks (immutable after creation)
/// - Batch encoding is more efficient than individual calls
///
/// ## Example
///
/// ```no_run
/// # use llm_shield_models::{TokenizerWrapper, TokenizerConfig};
/// # use std::sync::Arc;
/// # use std::thread;
/// #
/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let tokenizer = Arc::new(
///     TokenizerWrapper::from_pretrained(
///         "microsoft/deberta-v3-base",
///         TokenizerConfig::default(),
///     )?
/// );
///
/// let handles: Vec<_> = (0..4)
///     .map(|i| {
///         let tok = Arc::clone(&tokenizer);
///         thread::spawn(move || tok.encode(&format!("Text {}", i)))
///     })
///     .collect();
///
/// for handle in handles {
///     let encoding = handle.join().unwrap()?;
///     println!("Encoded {} tokens", encoding.len());
/// }
/// # Ok(())
/// # }
/// ```
pub struct TokenizerWrapper {
    tokenizer: Arc<Tokenizer>,
    config: TokenizerConfig,
}

impl TokenizerWrapper {
    /// Load a tokenizer for a HuggingFace model
    ///
    /// # Arguments
    ///
    /// * `model_name` - HuggingFace model identifier (e.g., "microsoft/deberta-v3-base")
    /// * `config` - Tokenizer configuration
    ///
    /// # Supported Models
    ///
    /// - **DeBERTa**: `microsoft/deberta-v3-base` (PromptInjection)
    /// - **RoBERTa**: `roberta-base` (Toxicity, Sentiment)
    /// - **BERT**: `bert-base-uncased`
    /// - Any HuggingFace model with a tokenizer
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use llm_shield_models::{TokenizerWrapper, TokenizerConfig};
    /// let tokenizer = TokenizerWrapper::from_pretrained(
    ///     "microsoft/deberta-v3-base",
    ///     TokenizerConfig::default(),
    /// )?;
    /// # Ok::<(), llm_shield_core::Error>(())
    /// ```
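    ///
    /// # Downloading from the Hub
    ///
    /// This method currently loads `models/<model_name>/tokenizer.json` from
    /// disk. A minimal sketch of fetching that file with the `hf-hub` crate
    /// (an assumption: `hf-hub` is not currently a dependency of this crate):
    ///
    /// ```ignore
    /// use hf_hub::api::sync::Api;
    ///
    /// let api = Api::new()?;
    /// let repo = api.model("microsoft/deberta-v3-base".to_string());
    /// // Downloads tokenizer.json (or reuses a cached copy) and returns its path.
    /// let tokenizer_json = repo.get("tokenizer.json")?;
    /// ```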
    pub fn from_pretrained(model_name: &str, config: TokenizerConfig) -> Result<Self> {
        tracing::info!("Loading tokenizer from: {}", model_name);

        // For now, use a simple approach that assumes tokenizer.json exists
        // locally. In production, this should download from HuggingFace Hub
        // (see the `hf-hub` sketch in the docs above).
        let tokenizer_path = format!("models/{}/tokenizer.json", model_name);

        if !std::path::Path::new(&tokenizer_path).exists() {
            return Err(Error::model(format!(
                "Tokenizer not found at '{}'. Please download tokenizer files first.",
                tokenizer_path
            )));
        }

        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| {
                Error::model(format!(
                    "Failed to load tokenizer from '{}': {}",
                    tokenizer_path, e
                ))
            })?;

        // Configure padding
        if config.padding {
            let padding = PaddingParams {
                strategy: PaddingStrategy::Fixed(config.max_length),
                direction: PaddingDirection::Right,
                // NOTE: pad_id/pad_token are used as-is, not auto-detected;
                // they must match the model's vocabulary. 0 / "[PAD]" is a
                // common default for BERT-style models.
                pad_id: 0,
                pad_type_id: 0,
                pad_token: String::from("[PAD]"),
                pad_to_multiple_of: None,
            };
            tokenizer.with_padding(Some(padding));
        }

        // Configure truncation
        if config.truncation {
            let truncation = TruncationParams {
                max_length: config.max_length,
                strategy: TruncationStrategy::LongestFirst,
                stride: 0,
                direction: tokenizers::TruncationDirection::Right,
            };
            tokenizer.with_truncation(Some(truncation))
                .map_err(|e| {
                    Error::model(format!("Failed to configure truncation: {}", e))
                })?;
        }

        tracing::debug!(
            "Tokenizer loaded: max_length={}, padding={}, truncation={}",
            config.max_length,
            config.padding,
            config.truncation
        );

        Ok(Self {
            tokenizer: Arc::new(tokenizer),
            config,
        })
    }

    /// Encode a single text string
    ///
    /// # Arguments
    ///
    /// * `text` - Input text to tokenize
    ///
    /// # Returns
    ///
    /// `Encoding` with token IDs, attention mask, and byte offsets
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use llm_shield_models::{TokenizerWrapper, TokenizerConfig};
    /// # let tokenizer = TokenizerWrapper::from_pretrained(
    /// #     "microsoft/deberta-v3-base",
    /// #     TokenizerConfig::default(),
    /// # )?;
    /// let encoding = tokenizer.encode("Hello, world!")?;
    /// println!("Token IDs: {:?}", encoding.input_ids);
    /// println!("Attention mask: {:?}", encoding.attention_mask);
    /// # Ok::<(), llm_shield_core::Error>(())
    /// ```
    pub fn encode(&self, text: &str) -> Result<Encoding> {
        let encoding = self.tokenizer
            .encode(text, self.config.add_special_tokens)
            .map_err(|e| {
                Error::model(format!("Failed to encode text: {}", e))
            })?;

        let input_ids = encoding.get_ids().to_vec();
        let attention_mask = encoding.get_attention_mask().to_vec();

        // Extract byte offsets into the original text
        let offsets: Vec<(usize, usize)> = encoding
            .get_offsets()
            .iter()
            .map(|offset| (offset.0, offset.1))
            .collect();

        Ok(Encoding::with_offsets(input_ids, attention_mask, offsets))
    }

    /// Encode multiple texts in batch
    ///
    /// Batch encoding is more efficient than encoding texts individually.
    ///
    /// # Arguments
    ///
    /// * `texts` - Slice of text strings
    ///
    /// # Returns
    ///
    /// Vector of `Encoding` results (one per input text)
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use llm_shield_models::{TokenizerWrapper, TokenizerConfig};
    /// # let tokenizer = TokenizerWrapper::from_pretrained(
    /// #     "microsoft/deberta-v3-base",
    /// #     TokenizerConfig::default(),
    /// # )?;
    /// let texts = vec!["First text", "Second text", "Third text"];
    /// let encodings = tokenizer.encode_batch(&texts)?;
    ///
    /// assert_eq!(encodings.len(), 3);
    /// for encoding in encodings {
    ///     println!("Length: {}", encoding.len());
    /// }
    /// # Ok::<(), llm_shield_core::Error>(())
    /// ```
    pub fn encode_batch(&self, texts: &[&str]) -> Result<Vec<Encoding>> {
        if texts.is_empty() {
            return Ok(vec![]);
        }

        let encodings = self.tokenizer
            .encode_batch(texts.to_vec(), self.config.add_special_tokens)
            .map_err(|e| {
                Error::model(format!("Failed to encode batch: {}", e))
            })?;

        let results = encodings
            .into_iter()
            .map(|enc| {
                let input_ids = enc.get_ids().to_vec();
                let attention_mask = enc.get_attention_mask().to_vec();
                let offsets: Vec<(usize, usize)> = enc
                    .get_offsets()
                    .iter()
                    .map(|offset| (offset.0, offset.1))
                    .collect();
                Encoding::with_offsets(input_ids, attention_mask, offsets)
            })
            .collect();

        Ok(results)
    }

    /// Get the tokenizer configuration
    pub fn config(&self) -> &TokenizerConfig {
        &self.config
    }

    /// Get the vocabulary size
    ///
    /// Returns the size of the tokenizer's vocabulary, including added tokens.
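    ///
    /// # Example
    ///
    /// Like the other `no_run` examples here, this assumes tokenizer files
    /// are available on disk:
    ///
    /// ```no_run
    /// # use llm_shield_models::{TokenizerWrapper, TokenizerConfig};
    /// # let tokenizer = TokenizerWrapper::from_pretrained(
    /// #     "microsoft/deberta-v3-base",
    /// #     TokenizerConfig::default(),
    /// # )?;
    /// println!("Vocabulary size: {}", tokenizer.vocab_size());
    /// # Ok::<(), llm_shield_core::Error>(())
    /// ```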
    pub fn vocab_size(&self) -> usize {
        // `true` = include tokens added on top of the base vocabulary
        // (the tokenizers API calls this parameter `with_added_tokens`).
        self.tokenizer.get_vocab_size(true)
    }
}

// Implement Clone for TokenizerWrapper (clones the Arc, not the underlying tokenizer)
impl Clone for TokenizerWrapper {
    fn clone(&self) -> Self {
        Self {
            tokenizer: Arc::clone(&self.tokenizer),
            config: self.config.clone(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tokenizer_config_default() {
        let config = TokenizerConfig::default();
        assert_eq!(config.max_length, 512);
        assert!(config.padding);
        assert!(config.truncation);
        assert!(config.add_special_tokens);
    }

    #[test]
    fn test_encoding_creation() {
        let encoding = Encoding::new(
            vec![101, 2023, 2003, 102],
            vec![1, 1, 1, 1],
        );

        assert_eq!(encoding.len(), 4);
        assert!(!encoding.is_empty());
    }

    #[test]
    fn test_encoding_to_arrays() {
        let encoding = Encoding::new(
            vec![101, 2023, 102],
            vec![1, 1, 1],
        );

        let (input_ids, attention_mask) = encoding.to_arrays();
        assert_eq!(input_ids, vec![101i64, 2023, 102]);
        assert_eq!(attention_mask, vec![1i64, 1, 1]);
    }

    #[test]
    fn test_encoding_empty() {
        let encoding = Encoding::new(vec![], vec![]);
        assert!(encoding.is_empty());
        assert_eq!(encoding.len(), 0);
    }
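
    // Additional coverage for offset handling. The token IDs and offsets
    // below are illustrative values, not taken from a real vocabulary.
    #[test]
    fn test_encoding_with_offsets() {
        let text = "hello world";
        let encoding = Encoding::with_offsets(
            vec![7592, 2088],
            vec![1, 1],
            vec![(0, 5), (6, 11)],
        );

        assert_eq!(encoding.len(), 2);
        // Offsets map each token back to its span in the source text.
        let (start, end) = encoding.offsets[1];
        assert_eq!(&text[start..end], "world");
    }

    #[test]
    fn test_encoding_new_zeroes_offsets() {
        // `Encoding::new` fills offsets with (0, 0) placeholders.
        let encoding = Encoding::new(vec![101, 2023, 102], vec![1, 1, 1]);
        assert_eq!(encoding.offsets, vec![(0, 0); 3]);
    }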

    // Note: Tests that load a real tokenizer need tokenizer files on disk
    // (or network access to HuggingFace Hub). They live in the integration
    // test suite: `cargo test --test tokenizer_test`.
}