1use std::{
2 collections::hash_map::DefaultHasher,
3 hash::{Hash, Hasher},
4};
5
6use anyhow::Result;
7
8use crate::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
9
/// Integer type used for every token ID passed through the tokenizer traits below.
pub type TokenIdType = u32;
12
13pub trait Encoder: Send + Sync {
15 fn encode(&self, input: &str, add_special_tokens: bool) -> Result<Encoding>;
16 fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>>;
17}
18
19pub trait Decoder: Send + Sync {
21 fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String>;
22
23 fn decode_step(
33 &self,
34 token_id: TokenIdType,
35 ids: &mut Vec<TokenIdType>,
36 prefix: &mut String,
37 prefix_index: &mut usize,
38 skip_special_tokens: bool,
39 ) -> Result<Option<String>> {
40 if prefix.is_empty() && !ids.is_empty() {
42 let new_prefix = self.decode(ids, skip_special_tokens)?;
43 if !new_prefix.ends_with('�') {
44 *prefix = new_prefix;
45 *prefix_index = ids.len();
46 }
47 }
48
49 ids.push(token_id);
50 let string = self.decode(ids, skip_special_tokens)?;
51
52 if string.len() > prefix.len() && !string.ends_with('�') {
53 let mut split_at = prefix.len();
55 while !string.is_char_boundary(split_at) && split_at > 0 {
56 split_at -= 1;
57 }
58
59 let new_text = string[split_at..].to_string();
60
61 let new_prefix_len = ids.len() - *prefix_index;
63 ids.drain(..*prefix_index);
64 *prefix_index = new_prefix_len;
65 *prefix = self.decode(ids, skip_special_tokens)?;
66
67 Ok(Some(new_text))
68 } else {
69 Ok(None)
70 }
71 }
72}
73
74pub trait Tokenizer: Encoder + Decoder {
76 fn vocab_size(&self) -> usize;
77 fn get_special_tokens(&self) -> &SpecialTokens;
78 fn token_to_id(&self, token: &str) -> Option<TokenIdType>;
79 fn id_to_token(&self, id: TokenIdType) -> Option<String>;
80
81 fn as_any(&self) -> &dyn std::any::Any;
83
84 fn apply_chat_template(
86 &self,
87 _messages: &[serde_json::Value],
88 _params: ChatTemplateParams,
89 ) -> Result<String> {
90 Err(anyhow::anyhow!(
91 "Chat template not supported by this tokenizer"
92 ))
93 }
94
95 fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
97 ChatTemplateContentFormat::default()
98 }
99
100 fn set_chat_template(&mut self, _template: String) -> Result<()> {
105 Err(anyhow::anyhow!(
106 "set_chat_template is not supported by this tokenizer"
107 ))
108 }
109}
110
111#[derive(Debug, Clone)]
113pub enum Encoding {
114 Hf(Box<tokenizers::tokenizer::Encoding>),
116 Plain(Vec<TokenIdType>),
118 Tiktoken(Vec<TokenIdType>),
120}
121
122impl Encoding {
123 #[inline]
125 pub fn token_ids(&self) -> &[TokenIdType] {
126 match self {
127 Encoding::Hf(inner) => inner.get_ids(),
128 Encoding::Plain(inner) => inner,
129 Encoding::Tiktoken(inner) => inner,
130 }
131 }
132
133 pub fn get_hash(&self) -> u64 {
135 let mut hasher = DefaultHasher::new();
136 self.hash(&mut hasher);
137 hasher.finish()
138 }
139}
140
141impl Hash for Encoding {
143 fn hash<H: Hasher>(&self, state: &mut H) {
144 match self {
145 Encoding::Hf(inner) => inner.get_ids().hash(state),
146 Encoding::Plain(inner) => inner.hash(state),
147 Encoding::Tiktoken(inner) => inner.hash(state),
148 }
149 }
150}
151
/// Special-token strings a tokenizer is configured with.
///
/// Each named slot is `None` when unset. `Default` leaves every slot unset and the
/// additional list empty.
#[derive(Debug, Clone, Default)]
pub struct SpecialTokens {
    /// Beginning-of-sequence (BOS) token.
    pub bos_token: Option<String>,
    /// End-of-sequence (EOS) token.
    pub eos_token: Option<String>,
    /// Unknown-token (UNK) placeholder.
    pub unk_token: Option<String>,
    /// Separator (SEP) token.
    pub sep_token: Option<String>,
    /// Padding (PAD) token.
    pub pad_token: Option<String>,
    /// Classification (CLS) token.
    pub cls_token: Option<String>,
    /// Mask (MASK) token.
    pub mask_token: Option<String>,
    /// Any further special tokens beyond the named slots above.
    pub additional_special_tokens: Vec<String>,
}