Skip to main content

ferrum_interfaces/
tokenizer.rs

1//! Tokenizer interface for text encoding/decoding
2//!
3//! This module provides tokenizer abstractions that are completely separate
4//! from model implementations, supporting incremental decoding and various
5//! tokenization strategies.
6
7use async_trait::async_trait;
8use ferrum_types::{Result, SpecialTokens, TokenId};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
12/// Core tokenizer trait for encoding/decoding operations
13pub trait Tokenizer: Send + Sync {
14    /// Encode text to token IDs
15    fn encode(&self, text: &str, add_special: bool) -> Result<Vec<TokenId>>;
16
17    /// Decode token IDs to text
18    fn decode(&self, tokens: &[TokenId], skip_special: bool) -> Result<String>;
19
20    /// Incremental decode: given previous tokens and new token, return only the new text
21    /// This is crucial for streaming applications to avoid re-decoding all tokens
22    fn decode_incremental(&self, prev: &[TokenId], next: TokenId) -> Result<String>;
23
24    /// Get vocabulary size
25    fn vocab_size(&self) -> usize;
26
27    /// Get special tokens configuration  
28    fn special_tokens(&self) -> &SpecialTokens;
29
30    /// Get token ID for a specific text (if exists in vocabulary)
31    fn token_id(&self, text: &str) -> Option<TokenId>;
32
33    /// Get text for a specific token ID
34    fn token_text(&self, token_id: TokenId) -> Option<&str>;
35
36    /// Check if token is a special token
37    fn is_special_token(&self, token_id: TokenId) -> bool {
38        let special = self.special_tokens();
39        let fallback = TokenId::MAX;
40        token_id == special.bos_token.unwrap_or(fallback)
41            || token_id == special.eos_token.unwrap_or(fallback)
42            || token_id == special.unk_token.unwrap_or(fallback)
43            || token_id == special.pad_token.unwrap_or(fallback)
44    }
45
46    /// Apply chat template if supported
47    fn apply_chat_template(&self, messages: &[ChatMessage]) -> Result<String> {
48        // Default implementation: just concatenate messages
49        let mut result = String::new();
50        for msg in messages {
51            result.push_str(&format!("{}: {}\n", msg.role, msg.content));
52        }
53        Ok(result.trim_end().to_string())
54    }
55
56    /// Get tokenizer information
57    fn info(&self) -> TokenizerInfo;
58}
59
60/// Asynchronous tokenizer operations for I/O-bound tokenization
61#[async_trait]
62pub trait AsyncTokenizer: Tokenizer {
63    /// Asynchronous encoding (useful for very large texts)
64    async fn encode_async(&self, text: &str, add_special: bool) -> Result<Vec<TokenId>>;
65
66    /// Asynchronous decoding
67    async fn decode_async(&self, tokens: &[TokenId], skip_special: bool) -> Result<String>;
68
69    /// Batch encoding for multiple texts
70    async fn encode_batch(&self, texts: &[&str], add_special: bool) -> Result<Vec<Vec<TokenId>>>;
71
72    /// Batch decoding for multiple token sequences
73    async fn decode_batch(
74        &self,
75        token_sequences: &[&[TokenId]],
76        skip_special: bool,
77    ) -> Result<Vec<String>>;
78}
79
80/// Advanced tokenizer capabilities
81pub trait TokenizerCapabilities: Tokenizer {
82    /// Get token probability/likelihood for text
83    fn token_probability(&self, text: &str, token_id: TokenId) -> Option<f32>;
84
85    /// Get all possible tokens for a prefix
86    fn get_prefix_tokens(&self, prefix: &str) -> Result<Vec<TokenId>>;
87
88    /// Check if sequence can be extended with token
89    fn can_extend(&self, tokens: &[TokenId], next_token: TokenId) -> bool;
90
91    /// Get token type (word, subword, punctuation, etc.)
92    fn token_type(&self, token_id: TokenId) -> TokenType;
93
94    /// Normalize text before tokenization
95    fn normalize_text(&self, text: &str) -> String;
96
97    /// Pre-tokenize text (split into words/subwords)
98    fn pre_tokenize(&self, text: &str) -> Vec<String>;
99}
100
101/// Tokenizer factory for creating tokenizer instances
102#[async_trait]
103pub trait TokenizerFactory: Send + Sync {
104    /// Load tokenizer from file path
105    async fn load_from_file(&self, path: &str) -> Result<Box<dyn Tokenizer>>;
106
107    /// Load tokenizer from bytes
108    async fn load_from_bytes(&self, data: &[u8]) -> Result<Box<dyn Tokenizer>>;
109
110    /// Load tokenizer from Hugging Face Hub
111    async fn load_from_hub(
112        &self,
113        repo_id: &str,
114        revision: Option<&str>,
115    ) -> Result<Box<dyn Tokenizer>>;
116
117    /// Create tokenizer from configuration
118    async fn create_from_config(&self, config: &TokenizerConfig) -> Result<Box<dyn Tokenizer>>;
119
120    /// Get supported tokenizer types
121    fn supported_types(&self) -> Vec<TokenizerType>;
122}
123
/// Tokenizer information and metadata
///
/// Snapshot of a tokenizer's identity and capabilities, returned by
/// `Tokenizer::info`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenizerInfo {
    /// Tokenizer type/algorithm (BPE, WordPiece, SentencePiece, …)
    pub tokenizer_type: TokenizerType,
    /// Total number of entries in the vocabulary
    pub vocab_size: usize,
    /// Special tokens (BOS/EOS/UNK/PAD) configuration
    pub special_tokens: SpecialTokens,
    /// Whether tokenizer supports incremental decoding efficiently
    pub supports_incremental: bool,
    /// Whether tokenizer supports chat templates
    pub supports_chat_template: bool,
    /// Maximum token length, if known
    pub max_token_length: Option<usize>,
    /// Model name or identifier this tokenizer was trained for
    pub model_name: Option<String>,
}
142
/// Tokenizer types/algorithms
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TokenizerType {
    /// Byte-Pair Encoding
    BPE,
    /// WordPiece (BERT-style)
    WordPiece,
    /// SentencePiece
    SentencePiece,
    /// Tiktoken (GPT family)
    Tiktoken,
    /// Custom implementation
    Custom,
}
157
/// Token types for classification
///
/// Returned by `TokenizerCapabilities::token_type`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TokenType {
    /// Regular word token
    Word,
    /// Subword token
    Subword,
    /// Punctuation token
    Punctuation,
    /// Number token
    Number,
    /// Special/control token
    Special,
    /// Unknown token
    Unknown,
}
174
/// Chat message for template application
///
/// Consumed by `Tokenizer::apply_chat_template`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Message role (user, assistant, system, etc.)
    pub role: String,
    /// Message content
    pub content: String,
    /// Additional metadata as arbitrary JSON values
    pub metadata: HashMap<String, serde_json::Value>,
}
185
186impl ChatMessage {
187    /// Create user message
188    pub fn user(content: impl Into<String>) -> Self {
189        Self {
190            role: "user".to_string(),
191            content: content.into(),
192            metadata: HashMap::new(),
193        }
194    }
195
196    /// Create assistant message
197    pub fn assistant(content: impl Into<String>) -> Self {
198        Self {
199            role: "assistant".to_string(),
200            content: content.into(),
201            metadata: HashMap::new(),
202        }
203    }
204
205    /// Create system message
206    pub fn system(content: impl Into<String>) -> Self {
207        Self {
208            role: "system".to_string(),
209            content: content.into(),
210            metadata: HashMap::new(),
211        }
212    }
213}
214
/// Tokenizer configuration
///
/// Consumed by `TokenizerFactory::create_from_config`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenizerConfig {
    /// Tokenizer type
    pub tokenizer_type: TokenizerType,
    /// Path to tokenizer files
    pub path: String,
    /// Whether to add special tokens during encoding
    pub add_special_tokens: bool,
    /// Whether to use fast tokenization (if available)
    pub use_fast: bool,
    /// Truncation configuration; `None` disables truncation
    pub truncation: Option<TruncationConfig>,
    /// Padding configuration; `None` disables padding
    pub padding: Option<PaddingConfig>,
    /// Chat template (if any)
    pub chat_template: Option<String>,
    /// Additional tokenizer-specific options as arbitrary JSON values
    pub extra_options: HashMap<String, serde_json::Value>,
}
235
/// Truncation configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TruncationConfig {
    /// Maximum sequence length in tokens
    pub max_length: usize,
    /// Truncation strategy
    pub strategy: TruncationStrategy,
    /// Stride for sliding window truncation (only meaningful with
    /// `TruncationStrategy::SlidingWindow`)
    pub stride: Option<usize>,
}
246
247/// Truncation strategies
248#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
249pub enum TruncationStrategy {
250    /// Remove tokens from the end
251    TruncateEnd,
252    /// Remove tokens from the beginning  
253    TruncateStart,
254    /// Remove tokens from both ends equally
255    TruncateBoth,
256    /// Sliding window approach
257    SlidingWindow,
258}
259
/// Padding configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PaddingConfig {
    /// Padding strategy
    pub strategy: PaddingStrategy,
    /// Token ID used as padding
    pub token_id: TokenId,
    /// Target length (only meaningful with `PaddingStrategy::Fixed`)
    pub length: Option<usize>,
    /// Padding direction (left or right)
    pub direction: PaddingDirection,
}
272
273/// Padding strategies
274#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
275pub enum PaddingStrategy {
276    /// No padding
277    None,
278    /// Pad to longest sequence in batch
279    Longest,
280    /// Pad to multiple of specified value
281    MultipleOf(usize),
282    /// Pad to fixed length
283    Fixed,
284}
285
286/// Padding direction
287#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
288pub enum PaddingDirection {
289    /// Pad on the right
290    Right,
291    /// Pad on the left
292    Left,
293}
294
295/// Incremental tokenizer state for streaming
296pub trait IncrementalTokenizer: Tokenizer {
297    /// Tokenizer state for incremental operations
298    type State: Send + Sync;
299
300    /// Create initial state for incremental decoding
301    fn create_state(&self) -> Self::State;
302
303    /// Add token to state and get incremental text
304    fn decode_incremental_with_state(
305        &self,
306        state: &mut Self::State,
307        token: TokenId,
308    ) -> Result<String>;
309
310    /// Reset state to initial condition
311    fn reset_state(&self, state: &mut Self::State);
312
313    /// Get all decoded text from current state
314    fn get_decoded_text(&self, state: &Self::State) -> String;
315}
316
/// Text processing utilities
pub trait TextProcessor: Send + Sync {
    /// Clean and normalize raw text ahead of tokenization.
    fn preprocess(&self, raw: &str) -> String;

    /// Post-process text produced by decoding.
    fn postprocess(&self, decoded: &str) -> String;

    /// Best-effort language detection; `None` when unsupported or unknown.
    fn detect_language(&self, input: &str) -> Option<String>;

    /// Split text into individual sentences.
    fn sentence_split(&self, input: &str) -> Vec<String>;

    /// Cheap approximate token count that avoids full tokenization.
    fn estimate_token_count(&self, input: &str) -> usize;
}
334
/// Tokenizer performance statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenizerStats {
    /// Total encoding operations performed
    pub encode_operations: u64,
    /// Total decoding operations performed
    pub decode_operations: u64,
    /// Total tokens processed
    pub tokens_processed: u64,
    /// Average encoding time per character (microseconds)
    pub avg_encode_time_per_char_us: f64,
    /// Average decoding time per token (microseconds)
    pub avg_decode_time_per_token_us: f64,
    /// Cache hit rate for incremental decoding, in [0, 1]
    /// (presumably — TODO confirm range convention against the implementation)
    pub incremental_cache_hit_rate: f32,
}
351
352/// Tokenizer registry for managing multiple tokenizers
353pub trait TokenizerRegistry: Send + Sync {
354    /// Register a tokenizer with a name
355    fn register(&mut self, name: &str, tokenizer: Box<dyn Tokenizer>) -> Result<()>;
356
357    /// Get tokenizer by name
358    fn get(&self, name: &str) -> Option<&dyn Tokenizer>;
359
360    /// Remove tokenizer by name
361    fn remove(&mut self, name: &str) -> Option<Box<dyn Tokenizer>>;
362
363    /// List all registered tokenizer names
364    fn list_names(&self) -> Vec<String>;
365
366    /// Check if tokenizer exists
367    fn contains(&self, name: &str) -> bool;
368}