Skip to main content

llm_tokenizer/
traits.rs

1use std::{
2    collections::hash_map::DefaultHasher,
3    hash::{Hash, Hasher},
4};
5
6use anyhow::Result;
7
8use crate::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
9
/// Type alias for token IDs.
///
/// Shared by [`Encoder`], [`Decoder`], [`Tokenizer`] and [`Encoding`] so every
/// backend in this module exchanges token IDs as the same `u32` type.
pub type TokenIdType = u32;
12
/// Core encoding trait - separate from decoding for modularity.
///
/// Implementors must be `Send + Sync` (a trait bound), so one encoder instance
/// can be shared across threads.
pub trait Encoder: Send + Sync {
    /// Encode a single string into an [`Encoding`].
    ///
    /// `add_special_tokens` is passed through to the backend; presumably it
    /// toggles insertion of BOS/EOS-style tokens — confirm per implementation.
    fn encode(&self, input: &str, add_special_tokens: bool) -> Result<Encoding>;

    /// Encode every string in `inputs`.
    ///
    /// NOTE(review): the result presumably holds one [`Encoding`] per input, in
    /// input order; the trait signature itself does not enforce this.
    fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>>;
}
18
19/// Core decoding trait - can be implemented independently
20pub trait Decoder: Send + Sync {
21    fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String>;
22
23    /// Incremental decode step — called once per generated token.
24    ///
25    /// Maintains mutable state (`ids`, `prefix`, `prefix_index`) across calls to
26    /// produce incremental text output. The default implementation uses the
27    /// double-decode algorithm (decode prefix, decode prefix+new, diff).
28    ///
29    /// HuggingFace overrides this with the native `step_decode_stream` from the
30    /// `tokenizers` crate, which uses the same algorithm internally but avoids
31    /// trait-method overhead for the two `decode()` calls.
32    fn decode_step(
33        &self,
34        token_id: TokenIdType,
35        ids: &mut Vec<TokenIdType>,
36        prefix: &mut String,
37        prefix_index: &mut usize,
38        skip_special_tokens: bool,
39    ) -> Result<Option<String>> {
40        // Recompute prefix if empty (first call or after incomplete UTF-8)
41        if prefix.is_empty() && !ids.is_empty() {
42            let new_prefix = self.decode(ids, skip_special_tokens)?;
43            if !new_prefix.ends_with('�') {
44                *prefix = new_prefix;
45                *prefix_index = ids.len();
46            }
47        }
48
49        ids.push(token_id);
50        let string = self.decode(ids, skip_special_tokens)?;
51
52        if string.len() > prefix.len() && !string.ends_with('�') {
53            // Find char-safe split point
54            let mut split_at = prefix.len();
55            while !string.is_char_boundary(split_at) && split_at > 0 {
56                split_at -= 1;
57            }
58
59            let new_text = string[split_at..].to_string();
60
61            // Drain consumed tokens and cache new prefix for next call
62            let new_prefix_len = ids.len() - *prefix_index;
63            ids.drain(..*prefix_index);
64            *prefix_index = new_prefix_len;
65            *prefix = self.decode(ids, skip_special_tokens)?;
66
67            Ok(Some(new_text))
68        } else {
69            Ok(None)
70        }
71    }
72}
73
74/// Combined tokenizer trait
75pub trait Tokenizer: Encoder + Decoder {
76    fn vocab_size(&self) -> usize;
77    fn get_special_tokens(&self) -> &SpecialTokens;
78    fn token_to_id(&self, token: &str) -> Option<TokenIdType>;
79    fn id_to_token(&self, id: TokenIdType) -> Option<String>;
80
81    /// Enable downcasting to concrete types
82    fn as_any(&self) -> &dyn std::any::Any;
83
84    /// Apply chat template to messages. Default returns an error for tokenizers without template support.
85    fn apply_chat_template(
86        &self,
87        _messages: &[serde_json::Value],
88        _params: ChatTemplateParams,
89    ) -> Result<String> {
90        Err(anyhow::anyhow!(
91            "Chat template not supported by this tokenizer"
92        ))
93    }
94
95    /// Get the content format expected by the chat template.
96    fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
97        ChatTemplateContentFormat::default()
98    }
99
100    /// Set or override the chat template.
101    ///
102    /// Returns an error if the template fails to parse or the tokenizer
103    /// does not support chat templates.
104    fn set_chat_template(&mut self, _template: String) -> Result<()> {
105        Err(anyhow::anyhow!(
106            "set_chat_template is not supported by this tokenizer"
107        ))
108    }
109}
110
/// Result of tokenizing text, tagged by the backend that produced it.
///
/// Every variant exposes its token IDs through [`Encoding::token_ids`]. Only
/// the `Hf` variant carries the backend's full encoding object. NOTE(review):
/// richer per-token data such as string tokens and spans would live in that
/// inner Hugging Face encoding; confirm against the `tokenizers` crate. The
/// `Plain` and `Tiktoken` variants hold IDs only.
#[derive(Debug, Clone)]
pub enum Encoding {
    /// Hugging Face `tokenizers` encoding, boxed (presumably to keep the enum
    /// itself small).
    Hf(Box<tokenizers::tokenizer::Encoding>),
    /// Plain token ID vector
    Plain(Vec<TokenIdType>),
    /// Tiktoken (for GPT models) - now uses u32 in tiktoken-rs 0.7.0
    Tiktoken(Vec<TokenIdType>),
}
121
122impl Encoding {
123    /// Returns a reference to token IDs - zero-copy operation
124    #[inline]
125    pub fn token_ids(&self) -> &[TokenIdType] {
126        match self {
127            Encoding::Hf(inner) => inner.get_ids(),
128            Encoding::Plain(inner) => inner,
129            Encoding::Tiktoken(inner) => inner,
130        }
131    }
132
133    /// Get a hash of the token IDs for caching purposes
134    pub fn get_hash(&self) -> u64 {
135        let mut hasher = DefaultHasher::new();
136        self.hash(&mut hasher);
137        hasher.finish()
138    }
139}
140
141/// Hash implementation for Encoding
142impl Hash for Encoding {
143    fn hash<H: Hasher>(&self, state: &mut H) {
144        match self {
145            Encoding::Hf(inner) => inner.get_ids().hash(state),
146            Encoding::Plain(inner) => inner.hash(state),
147            Encoding::Tiktoken(inner) => inner.hash(state),
148        }
149    }
150}
151
/// Special-token strings for a tokenizer, as returned by
/// [`Tokenizer::get_special_tokens`].
///
/// Each named slot is `None` when no such token is set. The derived `Default`
/// leaves every slot `None` and `additional_special_tokens` empty.
#[derive(Debug, Clone, Default)]
pub struct SpecialTokens {
    /// BOS (beginning-of-sequence) token string.
    pub bos_token: Option<String>,
    /// EOS (end-of-sequence) token string.
    pub eos_token: Option<String>,
    /// UNK (unknown-token) placeholder string.
    pub unk_token: Option<String>,
    /// SEP (separator) token string.
    pub sep_token: Option<String>,
    /// PAD (padding) token string.
    pub pad_token: Option<String>,
    /// CLS (classification) token string.
    pub cls_token: Option<String>,
    /// MASK token string.
    pub mask_token: Option<String>,
    /// Further special tokens beyond the named slots above.
    pub additional_special_tokens: Vec<String>,
}