//! `llm_tokenizer/traits.rs`: core tokenizer traits and the shared encoding / special-token types.

1use std::{
2    collections::hash_map::DefaultHasher,
3    hash::{Hash, Hasher},
4};
5
6use anyhow::Result;
7
8use crate::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
9
/// Integer type used to represent a single token ID (currently `u32`).
pub type TokenIdType = u32;
12
13/// Core encoding trait - separate from decoding for modularity
14pub trait Encoder: Send + Sync {
15    fn encode(&self, input: &str, add_special_tokens: bool) -> Result<Encoding>;
16    fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>>;
17}
18
19/// Core decoding trait - can be implemented independently
20pub trait Decoder: Send + Sync {
21    fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String>;
22}
23
24/// Combined tokenizer trait
25pub trait Tokenizer: Encoder + Decoder {
26    fn vocab_size(&self) -> usize;
27    fn get_special_tokens(&self) -> &SpecialTokens;
28    fn token_to_id(&self, token: &str) -> Option<TokenIdType>;
29    fn id_to_token(&self, id: TokenIdType) -> Option<String>;
30
31    /// Enable downcasting to concrete types
32    fn as_any(&self) -> &dyn std::any::Any;
33
34    /// Apply chat template to messages. Default returns an error for tokenizers without template support.
35    fn apply_chat_template(
36        &self,
37        _messages: &[serde_json::Value],
38        _params: ChatTemplateParams,
39    ) -> Result<String> {
40        Err(anyhow::anyhow!(
41            "Chat template not supported by this tokenizer"
42        ))
43    }
44
45    /// Get the content format expected by the chat template.
46    fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
47        ChatTemplateContentFormat::default()
48    }
49
50    /// Set or override the chat template.
51    ///
52    /// Returns an error if the template fails to parse or the tokenizer
53    /// does not support chat templates.
54    fn set_chat_template(&mut self, _template: String) -> Result<()> {
55        Err(anyhow::anyhow!(
56            "set_chat_template is not supported by this tokenizer"
57        ))
58    }
59}
60
61/// Contains the results of tokenizing text: token IDs, string tokens, and their spans
62#[derive(Debug, Clone)]
63pub enum Encoding {
64    /// Hugging Face
65    Hf(Box<tokenizers::tokenizer::Encoding>),
66    /// Plain token ID vector
67    Plain(Vec<TokenIdType>),
68    /// Tiktoken (for GPT models) - now uses u32 in tiktoken-rs 0.7.0
69    Tiktoken(Vec<TokenIdType>),
70}
71
72impl Encoding {
73    /// Returns a reference to token IDs - zero-copy operation
74    #[inline]
75    pub fn token_ids(&self) -> &[TokenIdType] {
76        match self {
77            Encoding::Hf(inner) => inner.get_ids(),
78            Encoding::Plain(inner) => inner,
79            Encoding::Tiktoken(inner) => inner,
80        }
81    }
82
83    /// Get a hash of the token IDs for caching purposes
84    pub fn get_hash(&self) -> u64 {
85        let mut hasher = DefaultHasher::new();
86        self.hash(&mut hasher);
87        hasher.finish()
88    }
89}
90
91/// Hash implementation for Encoding
92impl Hash for Encoding {
93    fn hash<H: Hasher>(&self, state: &mut H) {
94        match self {
95            Encoding::Hf(inner) => inner.get_ids().hash(state),
96            Encoding::Plain(inner) => inner.hash(state),
97            Encoding::Tiktoken(inner) => inner.hash(state),
98        }
99    }
100}
101
/// Special tokens associated with a tokenizer.
///
/// Each named slot is optional. `Default` leaves every slot as `None` and the
/// additional-token list empty.
#[derive(Clone, Debug, Default)]
pub struct SpecialTokens {
    /// Beginning-of-sequence token, if any.
    pub bos_token: Option<String>,
    /// End-of-sequence token, if any.
    pub eos_token: Option<String>,
    /// Unknown-token placeholder, if any.
    pub unk_token: Option<String>,
    /// Separator token, if any.
    pub sep_token: Option<String>,
    /// Padding token, if any.
    pub pad_token: Option<String>,
    /// Classification token, if any.
    pub cls_token: Option<String>,
    /// Mask token, if any.
    pub mask_token: Option<String>,
    /// Any further special tokens beyond the named slots above.
    pub additional_special_tokens: Vec<String>,
}