1use std::{
2 collections::hash_map::DefaultHasher,
3 hash::{Hash, Hasher},
4};
5
6use anyhow::Result;
7
8use crate::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
9
/// Integer type used for token ids throughout the tokenizer abstractions.
pub type TokenIdType = u32;
12
13pub trait Encoder: Send + Sync {
15 fn encode(&self, input: &str, add_special_tokens: bool) -> Result<Encoding>;
16 fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>>;
17}
18
19pub trait Decoder: Send + Sync {
21 fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String>;
22}
23
24pub trait Tokenizer: Encoder + Decoder {
26 fn vocab_size(&self) -> usize;
27 fn get_special_tokens(&self) -> &SpecialTokens;
28 fn token_to_id(&self, token: &str) -> Option<TokenIdType>;
29 fn id_to_token(&self, id: TokenIdType) -> Option<String>;
30
31 fn as_any(&self) -> &dyn std::any::Any;
33
34 fn apply_chat_template(
36 &self,
37 _messages: &[serde_json::Value],
38 _params: ChatTemplateParams,
39 ) -> Result<String> {
40 Err(anyhow::anyhow!(
41 "Chat template not supported by this tokenizer"
42 ))
43 }
44
45 fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
47 ChatTemplateContentFormat::default()
48 }
49
50 fn set_chat_template(&mut self, _template: String) -> Result<()> {
55 Err(anyhow::anyhow!(
56 "set_chat_template is not supported by this tokenizer"
57 ))
58 }
59}
60
/// Output of an [`Encoder`], tagged by the backend that produced it.
///
/// Use [`Encoding::token_ids`] for a backend-independent view of the ids.
#[derive(Debug, Clone)]
pub enum Encoding {
    /// Full encoding object from the Hugging Face `tokenizers` crate.
    /// Boxed, presumably to keep `Encoding` itself small (NOTE(review): confirm intent).
    Hf(Box<tokenizers::tokenizer::Encoding>),
    /// A plain list of token ids.
    Plain(Vec<TokenIdType>),
    /// Token ids from a tiktoken-style backend (inferred from the variant name — TODO confirm).
    Tiktoken(Vec<TokenIdType>),
}
71
72impl Encoding {
73 #[inline]
75 pub fn token_ids(&self) -> &[TokenIdType] {
76 match self {
77 Encoding::Hf(inner) => inner.get_ids(),
78 Encoding::Plain(inner) => inner,
79 Encoding::Tiktoken(inner) => inner,
80 }
81 }
82
83 pub fn get_hash(&self) -> u64 {
85 let mut hasher = DefaultHasher::new();
86 self.hash(&mut hasher);
87 hasher.finish()
88 }
89}
90
91impl Hash for Encoding {
93 fn hash<H: Hasher>(&self, state: &mut H) {
94 match self {
95 Encoding::Hf(inner) => inner.get_ids().hash(state),
96 Encoding::Plain(inner) => inner.hash(state),
97 Encoding::Tiktoken(inner) => inner.hash(state),
98 }
99 }
100}
101
/// Special-token strings configured for a tokenizer.
///
/// A named slot holding `None` means that token is not set. `PartialEq`/`Eq`
/// are derived so callers can compare two configurations directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SpecialTokens {
    /// Beginning-of-sequence token.
    pub bos_token: Option<String>,
    /// End-of-sequence token.
    pub eos_token: Option<String>,
    /// Unknown-token placeholder.
    pub unk_token: Option<String>,
    /// Separator token.
    pub sep_token: Option<String>,
    /// Padding token.
    pub pad_token: Option<String>,
    /// Classification token.
    pub cls_token: Option<String>,
    /// Mask token.
    pub mask_token: Option<String>,
    /// Further special tokens beyond the named slots above.
    pub additional_special_tokens: Vec<String>,
}
112}