// llm_tokenizer/traits.rs

use std::{
    collections::hash_map::DefaultHasher,
    hash::{Hash, Hasher},
};

use anyhow::Result;

/// Integer type used to represent a single token ID.
///
/// Every trait and type in this file that exposes token IDs does so through this alias,
/// which currently resolves to `u32`.
pub type TokenIdType = u32;

11/// Core encoding trait - separate from decoding for modularity
12pub trait Encoder: Send + Sync {
13    fn encode(&self, input: &str, add_special_tokens: bool) -> Result<Encoding>;
14    fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>>;
15}
16
17/// Core decoding trait - can be implemented independently
18pub trait Decoder: Send + Sync {
19    fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String>;
20}
21
22/// Combined tokenizer trait
23pub trait Tokenizer: Encoder + Decoder {
24    fn vocab_size(&self) -> usize;
25    fn get_special_tokens(&self) -> &SpecialTokens;
26    fn token_to_id(&self, token: &str) -> Option<TokenIdType>;
27    fn id_to_token(&self, id: TokenIdType) -> Option<String>;
28
29    /// Enable downcasting to concrete types
30    fn as_any(&self) -> &dyn std::any::Any;
31}
32
33/// Contains the results of tokenizing text: token IDs, string tokens, and their spans
34#[derive(Debug, Clone)]
35pub enum Encoding {
36    /// Hugging Face
37    Hf(Box<tokenizers::tokenizer::Encoding>),
38    /// Sentence Piece
39    Sp(Vec<TokenIdType>),
40    /// Tiktoken (for GPT models) - now uses u32 in tiktoken-rs 0.7.0
41    Tiktoken(Vec<TokenIdType>),
42}
43
44impl Encoding {
45    /// Returns a reference to token IDs - zero-copy operation
46    #[inline]
47    pub fn token_ids(&self) -> &[TokenIdType] {
48        match self {
49            Encoding::Hf(inner) => inner.get_ids(),
50            Encoding::Sp(inner) => inner,
51            Encoding::Tiktoken(inner) => inner,
52        }
53    }
54
55    /// Deprecated: Use token_ids() instead (kept for compatibility)
56    #[deprecated(since = "0.1.0", note = "Use token_ids() instead")]
57    pub fn token_ids_ref(&self) -> &[TokenIdType] {
58        self.token_ids()
59    }
60
61    /// Get a hash of the token IDs for caching purposes
62    pub fn get_hash(&self) -> u64 {
63        let mut hasher = DefaultHasher::new();
64        self.hash(&mut hasher);
65        hasher.finish()
66    }
67}
68
69/// Hash implementation for Encoding
70impl Hash for Encoding {
71    fn hash<H: Hasher>(&self, state: &mut H) {
72        match self {
73            Encoding::Hf(inner) => inner.get_ids().hash(state),
74            Encoding::Sp(inner) => inner.hash(state),
75            Encoding::Tiktoken(inner) => inner.hash(state),
76        }
77    }
78}
79
/// The special tokens a tokenizer declares, stored as their string forms.
///
/// Each named slot is `None` when the tokenizer does not define that token.
#[derive(Debug, Clone)]
pub struct SpecialTokens {
    /// Beginning-of-sequence (BOS) token, if any.
    pub bos_token: Option<String>,
    /// End-of-sequence (EOS) token, if any.
    pub eos_token: Option<String>,
    /// Unknown (UNK) token, if any.
    pub unk_token: Option<String>,
    /// Separator (SEP) token, if any.
    pub sep_token: Option<String>,
    /// Padding (PAD) token, if any.
    pub pad_token: Option<String>,
    /// Classification (CLS) token, if any.
    pub cls_token: Option<String>,
    /// Mask (MASK) token, if any.
    pub mask_token: Option<String>,
    /// Any further special tokens beyond the named slots above; may be empty.
    pub additional_special_tokens: Vec<String>,
}