m2m/inference/tokenizer.rs

use std::path::Path;
use std::sync::Arc;

use crate::error::{M2MError, Result};

use tiktoken_rs::{cl100k_base, o200k_base, CoreBPE};

use tokenizers::Tokenizer;

pub const MAX_SEQUENCE_LENGTH: usize = 512;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenizerType {
    Llama3,
    O200kBase,
    Cl100kBase,
    Fallback,
}

impl TokenizerType {
    #[must_use]
    pub fn vocab_size(&self) -> usize {
        match self {
            Self::Llama3 => 128_000,
            Self::O200kBase => 200_019,
            Self::Cl100kBase => 100_256,
            Self::Fallback => 256,
        }
    }

    #[must_use]
    pub fn name(&self) -> &'static str {
        match self {
            Self::Llama3 => "llama3",
            Self::O200kBase => "o200k_base",
            Self::Cl100kBase => "cl100k_base",
            Self::Fallback => "fallback",
        }
    }
}

impl std::fmt::Display for TokenizerType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.name())
    }
}

pub trait HydraTokenizer: Send + Sync {
    fn encode(&self, text: &str) -> Result<Vec<u32>>;

    fn decode(&self, tokens: &[u32]) -> Result<String>;

    fn vocab_size(&self) -> usize;

    fn tokenizer_type(&self) -> TokenizerType;

    fn truncate(&self, tokens: Vec<u32>) -> Vec<u32> {
        if tokens.len() > MAX_SEQUENCE_LENGTH {
            tokens[..MAX_SEQUENCE_LENGTH].to_vec()
        } else {
            tokens
        }
    }

    fn encode_for_hydra(&self, text: &str) -> Result<Vec<u32>> {
        let tokens = self.encode(text)?;
        Ok(self.truncate(tokens))
    }
}
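
// Usage sketch (a hypothetical call site; `FallbackTokenizer` is defined
// below). `encode_for_hydra` is the convenience entry point, since it applies
// the MAX_SEQUENCE_LENGTH cap on top of `encode`:
//
//     let tokenizer = FallbackTokenizer::new();
//     let ids = tokenizer.encode_for_hydra("hello")?; // at most 512 tokens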

pub struct Llama3Tokenizer {
    inner: Tokenizer,
    vocab_size: usize,
}

impl Llama3Tokenizer {
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let inner = Tokenizer::from_file(path.as_ref())
            .map_err(|e| M2MError::Tokenizer(format!("Failed to load tokenizer: {e}")))?;

        let vocab_size = inner.get_vocab_size(true);

        Ok(Self { inner, vocab_size })
    }

    pub fn from_json(json: &str) -> Result<Self> {
        Self::from_bytes(json.as_bytes())
    }

    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
        let inner = Tokenizer::from_bytes(bytes)
            .map_err(|e| M2MError::Tokenizer(format!("Failed to parse tokenizer: {e}")))?;

        let vocab_size = inner.get_vocab_size(true);

        Ok(Self { inner, vocab_size })
    }
}

impl HydraTokenizer for Llama3Tokenizer {
    fn encode(&self, text: &str) -> Result<Vec<u32>> {
        let encoding = self
            .inner
            .encode(text, false)
            .map_err(|e| M2MError::Tokenizer(format!("Encoding failed: {e}")))?;

        Ok(encoding.get_ids().to_vec())
    }

    fn decode(&self, tokens: &[u32]) -> Result<String> {
        self.inner
            .decode(tokens, true)
            .map_err(|e| M2MError::Tokenizer(format!("Decoding failed: {e}")))
    }

    fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    fn tokenizer_type(&self) -> TokenizerType {
        TokenizerType::Llama3
    }
}

pub struct TiktokenTokenizer {
    inner: CoreBPE,
    tokenizer_type: TokenizerType,
}

impl TiktokenTokenizer {
    pub fn cl100k() -> Result<Self> {
        let inner = cl100k_base()
            .map_err(|e| M2MError::Tokenizer(format!("Failed to load cl100k: {e}")))?;

        Ok(Self {
            inner,
            tokenizer_type: TokenizerType::Cl100kBase,
        })
    }

    pub fn o200k() -> Result<Self> {
        let inner =
            o200k_base().map_err(|e| M2MError::Tokenizer(format!("Failed to load o200k: {e}")))?;

        Ok(Self {
            inner,
            tokenizer_type: TokenizerType::O200kBase,
        })
    }

    pub fn from_type(tokenizer_type: TokenizerType) -> Result<Self> {
        match tokenizer_type {
            TokenizerType::Cl100kBase => Self::cl100k(),
            TokenizerType::O200kBase => Self::o200k(),
            _ => Err(M2MError::Tokenizer(format!(
                "Tokenizer type {tokenizer_type} is not tiktoken-based"
            ))),
        }
    }
}
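
// Note: `from_type` only covers the tiktoken-backed variants, so
// `TiktokenTokenizer::from_type(TokenizerType::O200kBase)` is equivalent to
// `TiktokenTokenizer::o200k()`, while Llama3 and Fallback return an error.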

impl HydraTokenizer for TiktokenTokenizer {
    fn encode(&self, text: &str) -> Result<Vec<u32>> {
        Ok(self.inner.encode_with_special_tokens(text))
    }

    fn decode(&self, tokens: &[u32]) -> Result<String> {
        self.inner
            .decode(tokens.to_vec())
            .map_err(|e| M2MError::Tokenizer(format!("Decoding failed: {e}")))
    }

    fn vocab_size(&self) -> usize {
        self.tokenizer_type.vocab_size()
    }

    fn tokenizer_type(&self) -> TokenizerType {
        self.tokenizer_type
    }
}

#[derive(Debug, Clone)]
pub struct HydraByteTokenizer {
    max_length: usize,
}

impl HydraByteTokenizer {
    pub const PAD_TOKEN_ID: u32 = 0;
    pub const EOS_TOKEN_ID: u32 = 1;
    pub const BOS_TOKEN_ID: u32 = 2;
    pub const BYTE_OFFSET: u32 = 3;
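
    // Token-ID layout implied by the constants above:
    //   0         PAD
    //   1         EOS
    //   2         BOS
    //   3..=258   raw byte `b`, stored as `b + BYTE_OFFSET`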

    #[must_use]
    pub fn new() -> Self {
        Self { max_length: 512 }
    }

    #[must_use]
    pub fn with_max_length(max_length: usize) -> Self {
        Self { max_length }
    }
}

impl Default for HydraByteTokenizer {
    fn default() -> Self {
        Self::new()
    }
}

impl HydraTokenizer for HydraByteTokenizer {
    fn encode(&self, text: &str) -> Result<Vec<u32>> {
        let mut tokens = Vec::with_capacity(self.max_length.min(text.len() + 2));

        tokens.push(Self::BOS_TOKEN_ID);

        let max_content = self.max_length.saturating_sub(2);
        for byte in text.bytes().take(max_content) {
            tokens.push((byte as u32) + Self::BYTE_OFFSET);
        }

        tokens.push(Self::EOS_TOKEN_ID);

        Ok(tokens)
    }

    fn decode(&self, tokens: &[u32]) -> Result<String> {
        let bytes: Vec<u8> = tokens
            .iter()
            .filter_map(|&t| {
                if t >= Self::BYTE_OFFSET && t < Self::BYTE_OFFSET + 256 {
                    Some((t - Self::BYTE_OFFSET) as u8)
                } else {
                    None
                }
            })
            .collect();

        String::from_utf8(bytes)
            .map_err(|e| M2MError::Tokenizer(format!("Invalid UTF-8 in tokens: {e}")))
    }

    fn vocab_size(&self) -> usize {
        // Token IDs only occupy 0..259 (3 special tokens + 256 byte values);
        // the larger figure here presumably pads the reported vocabulary to
        // match a model's embedding table.
        32000
    }

    fn tokenizer_type(&self) -> TokenizerType {
        TokenizerType::Fallback
    }
}

#[derive(Debug, Clone)]
pub struct FallbackTokenizer {
    vocab_size: usize,
}

impl FallbackTokenizer {
    #[must_use]
    pub fn new() -> Self {
        Self { vocab_size: 256 }
    }

    #[must_use]
    pub fn with_vocab_size(vocab_size: usize) -> Self {
        Self { vocab_size }
    }
}

// Default is implemented by hand rather than derived: a derived Default would
// zero `vocab_size`, and `encode` would then panic on the modulo below.
impl Default for FallbackTokenizer {
    fn default() -> Self {
        Self::new()
    }
}

impl HydraTokenizer for FallbackTokenizer {
    fn encode(&self, text: &str) -> Result<Vec<u32>> {
        Ok(text
            .bytes()
            .map(|b| (b as u32) % (self.vocab_size as u32))
            .collect())
    }

    fn decode(&self, tokens: &[u32]) -> Result<String> {
        let bytes: Vec<u8> = tokens
            .iter()
            .filter_map(|&t| if t < 256 { Some(t as u8) } else { None })
            .collect();

        String::from_utf8(bytes)
            .map_err(|e| M2MError::Tokenizer(format!("Invalid UTF-8 in tokens: {e}")))
    }

    fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    fn tokenizer_type(&self) -> TokenizerType {
        TokenizerType::Fallback
    }
}

pub type BoxedTokenizer = Arc<dyn HydraTokenizer>;

pub fn boxed<T: HydraTokenizer + 'static>(tokenizer: T) -> BoxedTokenizer {
    Arc::new(tokenizer)
}

pub fn load_tokenizer(tokenizer_path: Option<&Path>, vocab_size: usize) -> Result<BoxedTokenizer> {
    if let Some(path) = tokenizer_path {
        if path.exists() {
            match Llama3Tokenizer::from_file(path) {
                Ok(tokenizer) => {
                    tracing::info!(
                        "Loaded Llama 3 tokenizer from {} (vocab: {})",
                        path.display(),
                        tokenizer.vocab_size()
                    );
                    return Ok(boxed(tokenizer));
                },
                Err(e) => {
                    tracing::warn!("Failed to load tokenizer from {}: {e}", path.display());
                },
            }
        }
    }

    tracing::warn!(
        "Using fallback byte-level tokenizer (vocab_size: {vocab_size}). \
         For best results, provide a tokenizer.json file."
    );
    Ok(boxed(FallbackTokenizer::with_vocab_size(vocab_size)))
}
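
// Usage sketch (the path is illustrative, not a real asset in this repo; the
// function degrades to the fallback tokenizer if the file is missing or
// malformed):
//
//     let tokenizer = load_tokenizer(Some(Path::new("models/tokenizer.json")), 128_000)?;
//     let ids = tokenizer.encode_for_hydra("hello world")?;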

pub fn load_tokenizer_by_type(
    tokenizer_type: TokenizerType,
    tokenizer_path: Option<&Path>,
) -> Result<BoxedTokenizer> {
    match tokenizer_type {
        TokenizerType::Llama3 => {
            let path = tokenizer_path
                .ok_or_else(|| M2MError::Tokenizer("Llama3 tokenizer requires a path".into()))?;
            Ok(boxed(Llama3Tokenizer::from_file(path)?))
        },
        TokenizerType::O200kBase => Ok(boxed(TiktokenTokenizer::o200k()?)),
        TokenizerType::Cl100kBase => Ok(boxed(TiktokenTokenizer::cl100k()?)),
        TokenizerType::Fallback => Ok(boxed(FallbackTokenizer::new())),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fallback_tokenizer_encode_decode() {
        let tokenizer = FallbackTokenizer::new();
        let text = "Hello";

        let tokens = tokenizer.encode(text).unwrap();
        assert_eq!(tokens.len(), 5);
        assert_eq!(tokens[0], b'H' as u32);
        assert_eq!(tokens[1], b'e' as u32);

        let decoded = tokenizer.decode(&tokens).unwrap();
        assert_eq!(decoded, text);
    }

    #[test]
    fn test_fallback_tokenizer_vocab_mapping() {
        let tokenizer = FallbackTokenizer::with_vocab_size(128000);
        let text = "Test";

        let tokens = tokenizer.encode(text).unwrap();

        for &t in &tokens {
            assert!(t < 128000);
        }
    }

    #[test]
    fn test_tiktoken_cl100k() {
        let tokenizer = TiktokenTokenizer::cl100k().unwrap();

        assert_eq!(tokenizer.tokenizer_type(), TokenizerType::Cl100kBase);
        assert_eq!(tokenizer.vocab_size(), 100_256);

        let tokens = tokenizer.encode("Hello, world!").unwrap();
        assert!(!tokens.is_empty());

        let decoded = tokenizer.decode(&tokens).unwrap();
        assert_eq!(decoded, "Hello, world!");
    }

    #[test]
    fn test_tiktoken_o200k() {
        let tokenizer = TiktokenTokenizer::o200k().unwrap();

        assert_eq!(tokenizer.tokenizer_type(), TokenizerType::O200kBase);
        assert_eq!(tokenizer.vocab_size(), 200_019);

        let tokens = tokenizer.encode("Hello, world!").unwrap();
        assert!(!tokens.is_empty());

        let decoded = tokenizer.decode(&tokens).unwrap();
        assert_eq!(decoded, "Hello, world!");
    }

    #[test]
    fn test_truncate() {
        let tokenizer = FallbackTokenizer::new();

        let long_text = "x".repeat(MAX_SEQUENCE_LENGTH + 100);
        let tokens = tokenizer.encode(&long_text).unwrap();
        let truncated = tokenizer.truncate(tokens);

        assert_eq!(truncated.len(), MAX_SEQUENCE_LENGTH);
    }

    #[test]
    fn test_tokenizer_type_vocab_size() {
        assert_eq!(TokenizerType::Llama3.vocab_size(), 128_000);
        assert_eq!(TokenizerType::O200kBase.vocab_size(), 200_019);
        assert_eq!(TokenizerType::Cl100kBase.vocab_size(), 100_256);
        assert_eq!(TokenizerType::Fallback.vocab_size(), 256);
    }
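
    // Two added sketches covering paths the original tests skip; expected
    // values are derived from the constants defined in this file.
    #[test]
    fn test_byte_tokenizer_roundtrip() {
        let tokenizer = HydraByteTokenizer::new();

        let tokens = tokenizer.encode("Hi").unwrap();
        // BOS + 2 content bytes + EOS.
        assert_eq!(tokens.len(), 4);
        assert_eq!(tokens[0], HydraByteTokenizer::BOS_TOKEN_ID);
        assert_eq!(tokens[1], u32::from(b'H') + HydraByteTokenizer::BYTE_OFFSET);
        assert_eq!(*tokens.last().unwrap(), HydraByteTokenizer::EOS_TOKEN_ID);

        // decode drops the special tokens and un-shifts the byte values.
        let decoded = tokenizer.decode(&tokens).unwrap();
        assert_eq!(decoded, "Hi");
    }

    #[test]
    fn test_load_tokenizer_by_type_fallback() {
        // The fallback variant needs no external assets, so it always
        // constructs; Llama3 without a path must fail.
        let tokenizer = load_tokenizer_by_type(TokenizerType::Fallback, None).unwrap();
        assert_eq!(tokenizer.tokenizer_type(), TokenizerType::Fallback);
        assert_eq!(tokenizer.vocab_size(), 256);

        assert!(load_tokenizer_by_type(TokenizerType::Llama3, None).is_err());
    }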
}