Skip to main content

llm_tokenizer/
traits.rs

1use std::{
2    collections::hash_map::DefaultHasher,
3    hash::{Hash, Hasher},
4};
5
6use anyhow::Result;
7
8use crate::chat_template::{
9    ChatTemplateContentFormat, ChatTemplateParams, ThinkingKeyName, ThinkingToggle,
10};
11
/// Numeric identifier of a single vocabulary entry.
///
/// Every encoder and decoder in this module exchanges token IDs as this type.
pub type TokenIdType = u32;
14
15/// Core encoding trait - separate from decoding for modularity
16pub trait Encoder: Send + Sync {
17    fn encode(&self, input: &str, add_special_tokens: bool) -> Result<Encoding>;
18    fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>>;
19}
20
21/// Core decoding trait - can be implemented independently
22pub trait Decoder: Send + Sync {
23    fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String>;
24
25    /// Incremental decode step — called once per generated token.
26    ///
27    /// Maintains mutable state (`ids`, `prefix`, `prefix_index`) across calls to
28    /// produce incremental text output. The default implementation uses the
29    /// double-decode algorithm (decode prefix, decode prefix+new, diff).
30    ///
31    /// HuggingFace overrides this with the native `step_decode_stream` from the
32    /// `tokenizers` crate, which uses the same algorithm internally but avoids
33    /// trait-method overhead for the two `decode()` calls.
34    fn decode_step(
35        &self,
36        token_id: TokenIdType,
37        ids: &mut Vec<TokenIdType>,
38        prefix: &mut String,
39        prefix_index: &mut usize,
40        skip_special_tokens: bool,
41    ) -> Result<Option<String>> {
42        // Recompute prefix if empty (first call or after incomplete UTF-8)
43        if prefix.is_empty() && !ids.is_empty() {
44            let new_prefix = self.decode(ids, skip_special_tokens)?;
45            if !new_prefix.ends_with('�') {
46                *prefix = new_prefix;
47                *prefix_index = ids.len();
48            }
49        }
50
51        ids.push(token_id);
52        let string = self.decode(ids, skip_special_tokens)?;
53
54        if string.len() > prefix.len() && !string.ends_with('�') {
55            // Find char-safe split point
56            let mut split_at = prefix.len();
57            while !string.is_char_boundary(split_at) && split_at > 0 {
58                split_at -= 1;
59            }
60
61            let new_text = string[split_at..].to_string();
62
63            // Drain consumed tokens and cache new prefix for next call
64            let new_prefix_len = ids.len() - *prefix_index;
65            ids.drain(..*prefix_index);
66            *prefix_index = new_prefix_len;
67            *prefix = self.decode(ids, skip_special_tokens)?;
68
69            Ok(Some(new_text))
70        } else {
71            Ok(None)
72        }
73    }
74}
75
76/// Combined tokenizer trait
77pub trait Tokenizer: Encoder + Decoder {
78    fn vocab_size(&self) -> usize;
79    fn get_special_tokens(&self) -> &SpecialTokens;
80    fn token_to_id(&self, token: &str) -> Option<TokenIdType>;
81    fn id_to_token(&self, id: TokenIdType) -> Option<String>;
82
83    /// Enable downcasting to concrete types
84    fn as_any(&self) -> &dyn std::any::Any;
85
86    /// Apply chat template to messages. Default returns an error for tokenizers without template support.
87    fn apply_chat_template(
88        &self,
89        _messages: &[serde_json::Value],
90        _params: ChatTemplateParams,
91    ) -> Result<String> {
92        Err(anyhow::anyhow!(
93            "Chat template not supported by this tokenizer"
94        ))
95    }
96
97    /// Get the content format expected by the chat template.
98    fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
99        ChatTemplateContentFormat::default()
100    }
101
102    /// Get the thinking toggle support for this template.
103    fn thinking_toggle(&self) -> ThinkingToggle {
104        ThinkingToggle::None
105    }
106
107    /// The variable name the template uses for the thinking toggle.
108    fn thinking_key_name(&self) -> Option<ThinkingKeyName> {
109        None
110    }
111
112    /// Whether the template injects `<think>` in the generation prompt.
113    fn think_in_prefill(&self) -> bool {
114        false
115    }
116
117    /// Set or override the chat template.
118    ///
119    /// Returns an error if the template fails to parse or the tokenizer
120    /// does not support chat templates.
121    fn set_chat_template(&mut self, _template: String) -> Result<()> {
122        Err(anyhow::anyhow!(
123            "set_chat_template is not supported by this tokenizer"
124        ))
125    }
126}
127
128/// Contains the results of tokenizing text: token IDs, string tokens, and their spans
129#[derive(Debug, Clone)]
130pub enum Encoding {
131    /// Hugging Face
132    Hf(Box<tokenizers::tokenizer::Encoding>),
133    /// Plain token ID vector
134    Plain(Vec<TokenIdType>),
135    /// Tiktoken (for GPT models) - now uses u32 in tiktoken-rs 0.7.0
136    Tiktoken(Vec<TokenIdType>),
137}
138
139impl Encoding {
140    /// Returns a reference to token IDs - zero-copy operation
141    #[inline]
142    pub fn token_ids(&self) -> &[TokenIdType] {
143        match self {
144            Encoding::Hf(inner) => inner.get_ids(),
145            Encoding::Plain(inner) => inner,
146            Encoding::Tiktoken(inner) => inner,
147        }
148    }
149
150    /// Get a hash of the token IDs for caching purposes
151    pub fn get_hash(&self) -> u64 {
152        let mut hasher = DefaultHasher::new();
153        self.hash(&mut hasher);
154        hasher.finish()
155    }
156}
157
158/// Hash implementation for Encoding
159impl Hash for Encoding {
160    fn hash<H: Hasher>(&self, state: &mut H) {
161        match self {
162            Encoding::Hf(inner) => inner.get_ids().hash(state),
163            Encoding::Plain(inner) => inner.hash(state),
164            Encoding::Tiktoken(inner) => inner.hash(state),
165        }
166    }
167}
168
/// Special-token strings configured for a tokenizer.
///
/// Each named slot is an `Option`, presumably `None` when the tokenizer
/// defines no such token. `Default` yields every slot `None` and an empty
/// extra list.
#[derive(Clone, Debug, Default)]
pub struct SpecialTokens {
    /// Beginning-of-sequence token.
    pub bos_token: Option<String>,
    /// End-of-sequence token.
    pub eos_token: Option<String>,
    /// Unknown-token placeholder.
    pub unk_token: Option<String>,
    /// Separator token.
    pub sep_token: Option<String>,
    /// Padding token.
    pub pad_token: Option<String>,
    /// Classification token.
    pub cls_token: Option<String>,
    /// Mask token.
    pub mask_token: Option<String>,
    /// Further special tokens beyond the named slots above.
    pub additional_special_tokens: Vec<String>,
}