m2m/inference/
tokenizer.rs

//! Tokenizer infrastructure for Hydra model.
//!
//! Provides a unified trait for tokenization with multiple backend implementations:
//!
//! - [`Llama3Tokenizer`]: HuggingFace Tokenizers format (Llama 3, Mistral, etc.)
//! - [`TiktokenTokenizer`]: OpenAI tiktoken format (cl100k, o200k)
//! - [`FallbackTokenizer`]: Simple byte-level fallback
//!
//! # Example
//!
//! ```rust,ignore
//! use m2m::inference::{HydraTokenizer, Llama3Tokenizer, TokenizerType};
//!
//! // Load Llama 3 tokenizer from file
//! let tokenizer = Llama3Tokenizer::from_file("./models/hydra/tokenizer.json")?;
//!
//! // Encode text to token IDs
//! let tokens = tokenizer.encode("Hello, world!")?;
//!
//! // Get vocab size (128K for Llama 3)
//! assert_eq!(tokenizer.vocab_size(), 128000);
//! ```

24use std::path::Path;
25use std::sync::Arc;
26
27use crate::error::{M2MError, Result};
28
29// Re-export tiktoken for OpenAI tokenizers
30use tiktoken_rs::{cl100k_base, o200k_base, CoreBPE};
31
32// HuggingFace tokenizers
33use tokenizers::Tokenizer;
34
/// Maximum sequence length for Hydra input
pub const MAX_SEQUENCE_LENGTH: usize = 512;

/// Tokenizer type identifier
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenizerType {
    /// Llama 3 tokenizer (128K vocab, HuggingFace format)
    Llama3,
    /// OpenAI o200k_base (200K vocab, tiktoken)
    O200kBase,
    /// OpenAI cl100k_base (100K vocab, tiktoken)
    Cl100kBase,
    /// Fallback byte-level tokenizer
    Fallback,
}

impl TokenizerType {
    /// Expected vocabulary size for this tokenizer type.
    #[must_use]
    pub fn vocab_size(&self) -> usize {
        match self {
            Self::Llama3 => 128_000,
            Self::O200kBase => 200_019,
            Self::Cl100kBase => 100_256,
            // Byte-level: one token per possible byte value.
            Self::Fallback => 256,
        }
    }

    /// Short, stable display name for this tokenizer type.
    #[must_use]
    pub fn name(&self) -> &'static str {
        match self {
            Self::Llama3 => "llama3",
            Self::O200kBase => "o200k_base",
            Self::Cl100kBase => "cl100k_base",
            Self::Fallback => "fallback",
        }
    }
}

impl std::fmt::Display for TokenizerType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Display mirrors `name()` so logs and error messages stay consistent.
        f.write_str(self.name())
    }
}
80
81// ============================================================================
82// HydraTokenizer Trait
83// ============================================================================
84
85/// Trait for tokenizers used by Hydra model.
86///
87/// This trait abstracts over different tokenizer implementations,
88/// allowing Hydra to work with Llama 3, OpenAI, or other tokenizers.
89pub trait HydraTokenizer: Send + Sync {
90    /// Encode text to token IDs.
91    ///
92    /// Returns a vector of token IDs representing the input text.
93    /// The encoding should NOT include special tokens (BOS/EOS) unless
94    /// the specific tokenizer requires them for correct operation.
95    fn encode(&self, text: &str) -> Result<Vec<u32>>;
96
97    /// Decode token IDs back to text.
98    ///
99    /// Returns the original text (or approximation) from token IDs.
100    fn decode(&self, tokens: &[u32]) -> Result<String>;
101
102    /// Get the vocabulary size.
103    fn vocab_size(&self) -> usize;
104
105    /// Get the tokenizer type.
106    fn tokenizer_type(&self) -> TokenizerType;
107
108    /// Truncate tokens to maximum length for Hydra.
109    ///
110    /// Default implementation truncates to `MAX_SEQUENCE_LENGTH`.
111    fn truncate(&self, tokens: Vec<u32>) -> Vec<u32> {
112        if tokens.len() > MAX_SEQUENCE_LENGTH {
113            tokens[..MAX_SEQUENCE_LENGTH].to_vec()
114        } else {
115            tokens
116        }
117    }
118
119    /// Encode and truncate for Hydra input.
120    fn encode_for_hydra(&self, text: &str) -> Result<Vec<u32>> {
121        let tokens = self.encode(text)?;
122        Ok(self.truncate(tokens))
123    }
124}
125
126// ============================================================================
127// Llama3Tokenizer - HuggingFace Tokenizers format
128// ============================================================================
129
130/// Llama 3 tokenizer using HuggingFace Tokenizers library.
131///
132/// This is the primary tokenizer for Hydra, supporting the 128K vocabulary
133/// used by Llama 3 and compatible models.
134///
135/// # Example
136///
137/// ```rust,ignore
138/// let tokenizer = Llama3Tokenizer::from_file("./tokenizer.json")?;
139/// let tokens = tokenizer.encode("Hello, world!")?;
140/// ```
141pub struct Llama3Tokenizer {
142    inner: Tokenizer,
143    vocab_size: usize,
144}
145
146impl Llama3Tokenizer {
147    /// Load tokenizer from a `tokenizer.json` file.
148    ///
149    /// # Errors
150    ///
151    /// Returns error if the file cannot be read or parsed.
152    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
153        let inner = Tokenizer::from_file(path.as_ref())
154            .map_err(|e| M2MError::Tokenizer(format!("Failed to load tokenizer: {e}")))?;
155
156        let vocab_size = inner.get_vocab_size(true);
157
158        Ok(Self { inner, vocab_size })
159    }
160
161    /// Load tokenizer from JSON string.
162    ///
163    /// # Errors
164    ///
165    /// Returns error if the JSON is invalid.
166    pub fn from_json(json: &str) -> Result<Self> {
167        Self::from_bytes(json.as_bytes())
168    }
169
170    /// Load tokenizer from bytes.
171    ///
172    /// # Errors
173    ///
174    /// Returns error if the bytes are not valid JSON.
175    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
176        let inner = Tokenizer::from_bytes(bytes)
177            .map_err(|e| M2MError::Tokenizer(format!("Failed to parse tokenizer: {e}")))?;
178
179        let vocab_size = inner.get_vocab_size(true);
180
181        Ok(Self { inner, vocab_size })
182    }
183}
184
185impl HydraTokenizer for Llama3Tokenizer {
186    fn encode(&self, text: &str) -> Result<Vec<u32>> {
187        let encoding = self
188            .inner
189            .encode(text, false)
190            .map_err(|e| M2MError::Tokenizer(format!("Encoding failed: {e}")))?;
191
192        Ok(encoding.get_ids().to_vec())
193    }
194
195    fn decode(&self, tokens: &[u32]) -> Result<String> {
196        self.inner
197            .decode(tokens, true)
198            .map_err(|e| M2MError::Tokenizer(format!("Decoding failed: {e}")))
199    }
200
201    fn vocab_size(&self) -> usize {
202        self.vocab_size
203    }
204
205    fn tokenizer_type(&self) -> TokenizerType {
206        TokenizerType::Llama3
207    }
208}
209
210// ============================================================================
211// TiktokenTokenizer - OpenAI tiktoken format
212// ============================================================================
213
214/// OpenAI tiktoken-based tokenizer.
215///
216/// Supports cl100k_base (GPT-4) and o200k_base (GPT-4o) encodings.
217pub struct TiktokenTokenizer {
218    inner: CoreBPE,
219    tokenizer_type: TokenizerType,
220}
221
222impl TiktokenTokenizer {
223    /// Create cl100k_base tokenizer (GPT-3.5, GPT-4).
224    ///
225    /// # Errors
226    ///
227    /// Returns error if tokenizer initialization fails.
228    pub fn cl100k() -> Result<Self> {
229        let inner = cl100k_base()
230            .map_err(|e| M2MError::Tokenizer(format!("Failed to load cl100k: {e}")))?;
231
232        Ok(Self {
233            inner,
234            tokenizer_type: TokenizerType::Cl100kBase,
235        })
236    }
237
238    /// Create o200k_base tokenizer (GPT-4o, o1, o3).
239    ///
240    /// # Errors
241    ///
242    /// Returns error if tokenizer initialization fails.
243    pub fn o200k() -> Result<Self> {
244        let inner =
245            o200k_base().map_err(|e| M2MError::Tokenizer(format!("Failed to load o200k: {e}")))?;
246
247        Ok(Self {
248            inner,
249            tokenizer_type: TokenizerType::O200kBase,
250        })
251    }
252
253    /// Create tokenizer from type.
254    ///
255    /// # Errors
256    ///
257    /// Returns error if tokenizer initialization fails or type is not tiktoken-based.
258    pub fn from_type(tokenizer_type: TokenizerType) -> Result<Self> {
259        match tokenizer_type {
260            TokenizerType::Cl100kBase => Self::cl100k(),
261            TokenizerType::O200kBase => Self::o200k(),
262            _ => Err(M2MError::Tokenizer(format!(
263                "Tokenizer type {tokenizer_type} is not tiktoken-based"
264            ))),
265        }
266    }
267}
268
269impl HydraTokenizer for TiktokenTokenizer {
270    fn encode(&self, text: &str) -> Result<Vec<u32>> {
271        // tiktoken Rank is u32, so direct collect works
272        Ok(self.inner.encode_with_special_tokens(text))
273    }
274
275    fn decode(&self, tokens: &[u32]) -> Result<String> {
276        // tiktoken Rank is u32, so just convert slice to Vec
277        self.inner
278            .decode(tokens.to_vec())
279            .map_err(|e| M2MError::Tokenizer(format!("Decoding failed: {e}")))
280    }
281
282    fn vocab_size(&self) -> usize {
283        self.tokenizer_type.vocab_size()
284    }
285
286    fn tokenizer_type(&self) -> TokenizerType {
287        self.tokenizer_type
288    }
289}
290
// ============================================================================
// HydraByteTokenizer - Byte-level tokenizer matching training
// ============================================================================

/// Byte-level tokenizer that matches Hydra's training tokenizer.
///
/// Uses the same encoding as the Python `SimpleTokenizer`:
/// - PAD = 0, EOS = 1, BOS = 2
/// - Byte values 0-255 map to token IDs 3-258
/// - Sequences are wrapped with BOS and EOS tokens
#[derive(Debug, Clone)]
pub struct HydraByteTokenizer {
    /// Maximum sequence length (default 512)
    max_length: usize,
}

impl HydraByteTokenizer {
    /// PAD token ID
    pub const PAD_TOKEN_ID: u32 = 0;
    /// EOS token ID
    pub const EOS_TOKEN_ID: u32 = 1;
    /// BOS token ID
    pub const BOS_TOKEN_ID: u32 = 2;
    /// Offset for byte values (first 3 IDs reserved for special tokens)
    pub const BYTE_OFFSET: u32 = 3;

    /// Default maximum sequence length.
    const DEFAULT_MAX_LENGTH: usize = 512;

    /// Create new Hydra byte tokenizer with default max length (512).
    #[must_use]
    pub fn new() -> Self {
        Self::with_max_length(Self::DEFAULT_MAX_LENGTH)
    }

    /// Create tokenizer with custom max length.
    #[must_use]
    pub fn with_max_length(max_length: usize) -> Self {
        Self { max_length }
    }
}

impl Default for HydraByteTokenizer {
    fn default() -> Self {
        Self::new()
    }
}
335
336impl HydraTokenizer for HydraByteTokenizer {
337    fn encode(&self, text: &str) -> Result<Vec<u32>> {
338        let mut tokens = Vec::with_capacity(self.max_length.min(text.len() + 2));
339
340        // BOS token
341        tokens.push(Self::BOS_TOKEN_ID);
342
343        // Encode bytes with offset (leave room for EOS)
344        let max_content = self.max_length.saturating_sub(2);
345        for byte in text.bytes().take(max_content) {
346            tokens.push((byte as u32) + Self::BYTE_OFFSET);
347        }
348
349        // EOS token
350        tokens.push(Self::EOS_TOKEN_ID);
351
352        Ok(tokens)
353    }
354
355    fn decode(&self, tokens: &[u32]) -> Result<String> {
356        let bytes: Vec<u8> = tokens
357            .iter()
358            .filter_map(|&t| {
359                // Skip special tokens, decode byte tokens
360                if t >= Self::BYTE_OFFSET && t < Self::BYTE_OFFSET + 256 {
361                    Some((t - Self::BYTE_OFFSET) as u8)
362                } else {
363                    None
364                }
365            })
366            .collect();
367
368        String::from_utf8(bytes)
369            .map_err(|e| M2MError::Tokenizer(format!("Invalid UTF-8 in tokens: {e}")))
370    }
371
372    fn vocab_size(&self) -> usize {
373        // 3 special tokens + 256 byte values = 259, but model uses 32000
374        32000
375    }
376
377    fn tokenizer_type(&self) -> TokenizerType {
378        TokenizerType::Fallback // Use same type for compatibility
379    }
380}
381
// ============================================================================
// FallbackTokenizer - Simple byte-level tokenizer (legacy)
// ============================================================================

/// Fallback byte-level tokenizer.
///
/// Used when no proper tokenizer is available. Maps bytes directly to token IDs.
/// This is NOT recommended for production use but ensures Hydra can always run.
///
/// **Note**: For Hydra inference, prefer [`HydraByteTokenizer`] which matches
/// the training tokenizer exactly.
#[derive(Debug, Clone)]
pub struct FallbackTokenizer {
    // Modulus applied to byte values so token IDs stay within vocab bounds.
    vocab_size: usize,
}

impl FallbackTokenizer {
    /// Create new fallback tokenizer covering the full byte range (256).
    #[must_use]
    pub fn new() -> Self {
        Self { vocab_size: 256 }
    }

    /// Create fallback tokenizer that maps to a specific vocab size.
    ///
    /// Token IDs will be `byte % vocab_size` to ensure they fit within bounds.
    #[must_use]
    pub fn with_vocab_size(vocab_size: usize) -> Self {
        Self { vocab_size }
    }
}

// `Default` is implemented manually instead of derived: the derived impl
// would set `vocab_size` to 0, and `encode` computes `byte % vocab_size`,
// so `FallbackTokenizer::default()` would panic with a remainder-by-zero
// on the first encode call. Defaulting to `new()` (vocab 256) is safe.
impl Default for FallbackTokenizer {
    fn default() -> Self {
        Self::new()
    }
}
413
414impl HydraTokenizer for FallbackTokenizer {
415    fn encode(&self, text: &str) -> Result<Vec<u32>> {
416        Ok(text
417            .bytes()
418            .map(|b| (b as u32) % (self.vocab_size as u32))
419            .collect())
420    }
421
422    fn decode(&self, tokens: &[u32]) -> Result<String> {
423        // Best effort: treat tokens as bytes
424        let bytes: Vec<u8> = tokens
425            .iter()
426            .filter_map(|&t| if t < 256 { Some(t as u8) } else { None })
427            .collect();
428
429        String::from_utf8(bytes)
430            .map_err(|e| M2MError::Tokenizer(format!("Invalid UTF-8 in tokens: {e}")))
431    }
432
433    fn vocab_size(&self) -> usize {
434        self.vocab_size
435    }
436
437    fn tokenizer_type(&self) -> TokenizerType {
438        TokenizerType::Fallback
439    }
440}
441
// ============================================================================
// BoxedTokenizer - Type-erased tokenizer for dynamic dispatch
// ============================================================================

/// Type-erased tokenizer for storing different tokenizer implementations.
///
/// Uses `Arc` so the tokenizer can be shared cheaply across threads
/// (`HydraTokenizer` requires `Send + Sync`).
pub type BoxedTokenizer = Arc<dyn HydraTokenizer>;

/// Create a boxed tokenizer from a specific implementation.
///
/// `'static` is required because the trait object owns the tokenizer.
pub fn boxed<T: HydraTokenizer + 'static>(tokenizer: T) -> BoxedTokenizer {
    Arc::new(tokenizer)
}
453
454// ============================================================================
455// Tokenizer Loading Utilities
456// ============================================================================
457
458/// Load the best available tokenizer for Hydra.
459///
460/// Attempts to load in order:
461/// 1. Llama 3 tokenizer from the specified path
462/// 2. Fallback tokenizer with specified vocab size
463///
464/// # Arguments
465///
466/// * `tokenizer_path` - Optional path to `tokenizer.json`
467/// * `vocab_size` - Fallback vocab size if no tokenizer found
468///
469/// # Example
470///
471/// ```rust,ignore
472/// let tokenizer = load_tokenizer(Some("./models/hydra/tokenizer.json"), 128000)?;
473/// ```
474pub fn load_tokenizer(tokenizer_path: Option<&Path>, vocab_size: usize) -> Result<BoxedTokenizer> {
475    // Try to load Llama 3 tokenizer if path provided
476    if let Some(path) = tokenizer_path {
477        if path.exists() {
478            match Llama3Tokenizer::from_file(path) {
479                Ok(tokenizer) => {
480                    tracing::info!(
481                        "Loaded Llama 3 tokenizer from {} (vocab: {})",
482                        path.display(),
483                        tokenizer.vocab_size()
484                    );
485                    return Ok(boxed(tokenizer));
486                },
487                Err(e) => {
488                    tracing::warn!("Failed to load tokenizer from {}: {e}", path.display());
489                },
490            }
491        }
492    }
493
494    // Fallback
495    tracing::warn!(
496        "Using fallback byte-level tokenizer (vocab_size: {vocab_size}). \
497         For best results, provide a tokenizer.json file."
498    );
499    Ok(boxed(FallbackTokenizer::with_vocab_size(vocab_size)))
500}
501
502/// Load tokenizer by type.
503///
504/// # Errors
505///
506/// Returns error if the specified tokenizer type cannot be loaded.
507pub fn load_tokenizer_by_type(
508    tokenizer_type: TokenizerType,
509    tokenizer_path: Option<&Path>,
510) -> Result<BoxedTokenizer> {
511    match tokenizer_type {
512        TokenizerType::Llama3 => {
513            let path = tokenizer_path
514                .ok_or_else(|| M2MError::Tokenizer("Llama3 tokenizer requires a path".into()))?;
515            Ok(boxed(Llama3Tokenizer::from_file(path)?))
516        },
517        TokenizerType::O200kBase => Ok(boxed(TiktokenTokenizer::o200k()?)),
518        TokenizerType::Cl100kBase => Ok(boxed(TiktokenTokenizer::cl100k()?)),
519        TokenizerType::Fallback => Ok(boxed(FallbackTokenizer::new())),
520    }
521}
522
// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fallback_tokenizer_encode_decode() {
        let tokenizer = FallbackTokenizer::new();
        let text = "Hello";

        let tokens = tokenizer.encode(text).unwrap();
        assert_eq!(tokens.len(), 5); // 5 bytes

        // Verify byte values
        assert_eq!(tokens[0], b'H' as u32);
        assert_eq!(tokens[1], b'e' as u32);

        // Full roundtrip through decode.
        assert_eq!(tokenizer.decode(&tokens).unwrap(), "Hello");
    }

    #[test]
    fn test_fallback_tokenizer_vocab_mapping() {
        let tokenizer = FallbackTokenizer::with_vocab_size(128000);
        let text = "Test";

        let tokens = tokenizer.encode(text).unwrap();

        // All tokens should be < vocab_size
        for &t in &tokens {
            assert!(t < 128000);
        }
    }

    #[test]
    fn test_hydra_byte_tokenizer_roundtrip() {
        let tokenizer = HydraByteTokenizer::new();
        let tokens = tokenizer.encode("Hi").unwrap();

        // BOS, 'H'+offset, 'i'+offset, EOS
        assert_eq!(
            tokens,
            vec![
                HydraByteTokenizer::BOS_TOKEN_ID,
                u32::from(b'H') + HydraByteTokenizer::BYTE_OFFSET,
                u32::from(b'i') + HydraByteTokenizer::BYTE_OFFSET,
                HydraByteTokenizer::EOS_TOKEN_ID,
            ]
        );

        // Decode skips the special tokens and restores the text.
        assert_eq!(tokenizer.decode(&tokens).unwrap(), "Hi");
    }

    #[test]
    fn test_hydra_byte_tokenizer_respects_max_length() {
        // max_length 4 leaves room for BOS + 2 content bytes + EOS.
        let tokenizer = HydraByteTokenizer::with_max_length(4);
        let tokens = tokenizer.encode("Hello").unwrap();

        assert_eq!(tokens.len(), 4);
        assert_eq!(tokens[0], HydraByteTokenizer::BOS_TOKEN_ID);
        assert_eq!(tokens[3], HydraByteTokenizer::EOS_TOKEN_ID);
        assert_eq!(tokenizer.decode(&tokens).unwrap(), "He");
    }

    #[test]
    fn test_tiktoken_cl100k() {
        let tokenizer = TiktokenTokenizer::cl100k().unwrap();

        assert_eq!(tokenizer.tokenizer_type(), TokenizerType::Cl100kBase);
        assert_eq!(tokenizer.vocab_size(), 100_256);

        let tokens = tokenizer.encode("Hello, world!").unwrap();
        assert!(!tokens.is_empty());

        let decoded = tokenizer.decode(&tokens).unwrap();
        assert_eq!(decoded, "Hello, world!");
    }

    #[test]
    fn test_tiktoken_o200k() {
        let tokenizer = TiktokenTokenizer::o200k().unwrap();

        assert_eq!(tokenizer.tokenizer_type(), TokenizerType::O200kBase);
        assert_eq!(tokenizer.vocab_size(), 200_019);

        let tokens = tokenizer.encode("Hello, world!").unwrap();
        assert!(!tokens.is_empty());

        let decoded = tokenizer.decode(&tokens).unwrap();
        assert_eq!(decoded, "Hello, world!");
    }

    #[test]
    fn test_truncate() {
        let tokenizer = FallbackTokenizer::new();

        // Create tokens longer than MAX_SEQUENCE_LENGTH
        let long_text = "x".repeat(MAX_SEQUENCE_LENGTH + 100);
        let tokens = tokenizer.encode(&long_text).unwrap();
        let truncated = tokenizer.truncate(tokens);

        assert_eq!(truncated.len(), MAX_SEQUENCE_LENGTH);
    }

    #[test]
    fn test_tokenizer_type_vocab_size() {
        assert_eq!(TokenizerType::Llama3.vocab_size(), 128_000);
        assert_eq!(TokenizerType::O200kBase.vocab_size(), 200_019);
        assert_eq!(TokenizerType::Cl100kBase.vocab_size(), 100_256);
        assert_eq!(TokenizerType::Fallback.vocab_size(), 256);
    }
}