1use std::{
2 collections::hash_map::DefaultHasher,
3 hash::{Hash, Hasher},
4};
5
6use anyhow::Result;
7
8use crate::chat_template::{
9 ChatTemplateContentFormat, ChatTemplateParams, ThinkingKeyName, ThinkingToggle,
10};
11
/// Integer type used for token ids throughout the tokenizer traits and
/// [`Encoding`] variants in this module.
pub type TokenIdType = u32;
14
/// Turns text into token ids, wrapped in an [`Encoding`].
///
/// Implementors must be `Send + Sync` so a single encoder can be shared
/// across threads.
pub trait Encoder: Send + Sync {
    /// Encode a single string.
    ///
    /// `add_special_tokens` is passed through to the implementation; whether
    /// and which special tokens get added is implementation-defined.
    fn encode(&self, input: &str, add_special_tokens: bool) -> Result<Encoding>;
    /// Encode several strings in one call.
    ///
    /// NOTE(review): presumably returns one `Encoding` per input, in input
    /// order — confirm against implementors.
    fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>>;
}
20
21pub trait Decoder: Send + Sync {
23 fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String>;
24
25 fn decode_step(
35 &self,
36 token_id: TokenIdType,
37 ids: &mut Vec<TokenIdType>,
38 prefix: &mut String,
39 prefix_index: &mut usize,
40 skip_special_tokens: bool,
41 ) -> Result<Option<String>> {
42 if prefix.is_empty() && !ids.is_empty() {
44 let new_prefix = self.decode(ids, skip_special_tokens)?;
45 if !new_prefix.ends_with('�') {
46 *prefix = new_prefix;
47 *prefix_index = ids.len();
48 }
49 }
50
51 ids.push(token_id);
52 let string = self.decode(ids, skip_special_tokens)?;
53
54 if string.len() > prefix.len() && !string.ends_with('�') {
55 let mut split_at = prefix.len();
57 while !string.is_char_boundary(split_at) && split_at > 0 {
58 split_at -= 1;
59 }
60
61 let new_text = string[split_at..].to_string();
62
63 let new_prefix_len = ids.len() - *prefix_index;
65 ids.drain(..*prefix_index);
66 *prefix_index = new_prefix_len;
67 *prefix = self.decode(ids, skip_special_tokens)?;
68
69 Ok(Some(new_text))
70 } else {
71 Ok(None)
72 }
73 }
74}
75
76pub trait Tokenizer: Encoder + Decoder {
78 fn vocab_size(&self) -> usize;
79 fn get_special_tokens(&self) -> &SpecialTokens;
80 fn token_to_id(&self, token: &str) -> Option<TokenIdType>;
81 fn id_to_token(&self, id: TokenIdType) -> Option<String>;
82
83 fn as_any(&self) -> &dyn std::any::Any;
85
86 fn apply_chat_template(
88 &self,
89 _messages: &[serde_json::Value],
90 _params: ChatTemplateParams,
91 ) -> Result<String> {
92 Err(anyhow::anyhow!(
93 "Chat template not supported by this tokenizer"
94 ))
95 }
96
97 fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
99 ChatTemplateContentFormat::default()
100 }
101
102 fn thinking_toggle(&self) -> ThinkingToggle {
104 ThinkingToggle::None
105 }
106
107 fn thinking_key_name(&self) -> Option<ThinkingKeyName> {
109 None
110 }
111
112 fn think_in_prefill(&self) -> bool {
114 false
115 }
116
117 fn set_chat_template(&mut self, _template: String) -> Result<()> {
122 Err(anyhow::anyhow!(
123 "set_chat_template is not supported by this tokenizer"
124 ))
125 }
126}
127
/// Output of an [`Encoder`], tagged by the backend that produced it.
///
/// Use [`Encoding::token_ids`] to read the ids uniformly across variants.
#[derive(Debug, Clone)]
pub enum Encoding {
    /// Full encoding from the Hugging Face `tokenizers` crate.
    /// NOTE(review): presumably boxed to keep the enum small — confirm.
    Hf(Box<tokenizers::tokenizer::Encoding>),
    /// Bare token ids.
    Plain(Vec<TokenIdType>),
    /// Bare token ids from a tiktoken-backed encoder (presumably — confirm
    /// against the producer).
    Tiktoken(Vec<TokenIdType>),
}
138
139impl Encoding {
140 #[inline]
142 pub fn token_ids(&self) -> &[TokenIdType] {
143 match self {
144 Encoding::Hf(inner) => inner.get_ids(),
145 Encoding::Plain(inner) => inner,
146 Encoding::Tiktoken(inner) => inner,
147 }
148 }
149
150 pub fn get_hash(&self) -> u64 {
152 let mut hasher = DefaultHasher::new();
153 self.hash(&mut hasher);
154 hasher.finish()
155 }
156}
157
158impl Hash for Encoding {
160 fn hash<H: Hasher>(&self, state: &mut H) {
161 match self {
162 Encoding::Hf(inner) => inner.get_ids().hash(state),
163 Encoding::Plain(inner) => inner.hash(state),
164 Encoding::Tiktoken(inner) => inner.hash(state),
165 }
166 }
167}
168
/// Special-token strings associated with a tokenizer.
///
/// `Default` yields every `Option` field as `None` and an empty
/// `additional_special_tokens`. NOTE(review): `None` presumably means the
/// tokenizer defines no such token — confirm against producers.
#[derive(Debug, Clone, Default)]
pub struct SpecialTokens {
    /// BOS token (conventionally beginning-of-sequence), if set.
    pub bos_token: Option<String>,
    /// EOS token (conventionally end-of-sequence), if set.
    pub eos_token: Option<String>,
    /// UNK token (conventionally the unknown-token placeholder), if set.
    pub unk_token: Option<String>,
    /// SEP token (conventionally a separator), if set.
    pub sep_token: Option<String>,
    /// PAD token (conventionally used for padding), if set.
    pub pad_token: Option<String>,
    /// CLS token (conventionally a classification marker), if set.
    pub cls_token: Option<String>,
    /// MASK token (conventionally used for masking), if set.
    pub mask_token: Option<String>,
    /// Further special-token strings beyond the named fields above.
    pub additional_special_tokens: Vec<String>,
}