1use std::{
2 collections::hash_map::DefaultHasher,
3 hash::{Hash, Hasher},
4};
5
6use anyhow::Result;
7
/// Integer type used for token IDs throughout this module's tokenizer traits and `Encoding`.
pub type TokenIdType = u32;
10
/// Turns text into an [`Encoding`].
///
/// The `Send + Sync` supertraits mean implementors can be shared across threads.
pub trait Encoder: Send + Sync {
    /// Encodes a single string.
    ///
    /// `add_special_tokens` asks the implementation to insert its special tokens where
    /// applicable. NOTE(review): exactly which tokens are added is backend-defined — confirm
    /// per implementation.
    ///
    /// # Errors
    /// Returns an error if the underlying tokenizer fails to encode `input`.
    fn encode(&self, input: &str, add_special_tokens: bool) -> Result<Encoding>;

    /// Encodes several strings in one call.
    ///
    /// NOTE(review): presumably yields one `Encoding` per element of `inputs`, in the same
    /// order — the signature does not enforce this; confirm against implementations.
    ///
    /// # Errors
    /// Returns an error if encoding fails.
    fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>>;
}
16
/// Turns token IDs back into text.
///
/// The `Send + Sync` supertraits mean implementors can be shared across threads.
pub trait Decoder: Send + Sync {
    /// Decodes `token_ids` into a string.
    ///
    /// When `skip_special_tokens` is true the implementation is asked to omit special tokens
    /// from the output.
    ///
    /// # Errors
    /// Returns an error if the underlying tokenizer cannot decode the given IDs.
    fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String>;
}
21
/// A full tokenizer: encoding and decoding plus vocabulary and special-token queries.
pub trait Tokenizer: Encoder + Decoder {
    /// Size of the tokenizer's vocabulary.
    ///
    /// NOTE(review): whether added/special tokens are counted is implementation-defined —
    /// confirm per backend.
    fn vocab_size(&self) -> usize;

    /// Borrows the special-token configuration of this tokenizer.
    fn get_special_tokens(&self) -> &SpecialTokens;

    /// Looks up the ID of `token`, returning `None` when it is not known to the tokenizer.
    fn token_to_id(&self, token: &str) -> Option<TokenIdType>;

    /// Looks up the token string for `id`, returning `None` when the ID is not known.
    fn id_to_token(&self, id: TokenIdType) -> Option<String>;

    /// Exposes the tokenizer as `&dyn Any`, so callers holding a `dyn Tokenizer` can attempt
    /// a downcast to a concrete implementation (e.g. via `downcast_ref`).
    fn as_any(&self) -> &dyn std::any::Any;
}
32
/// Output of encoding text with one of the supported tokenizer backends.
///
/// Each variant holds a different backend's result. Use `Encoding::token_ids` to read the IDs
/// uniformly regardless of which variant is present.
#[derive(Debug, Clone)]
pub enum Encoding {
    /// A full `tokenizers::tokenizer::Encoding` from the `tokenizers` crate.
    ///
    /// NOTE(review): presumably boxed to keep `Encoding` small, since the wrapped struct is
    /// much larger than a `Vec` — confirm before unboxing.
    Hf(Box<tokenizers::tokenizer::Encoding>),
    /// Plain token IDs from the `Sp` backend (presumably SentencePiece — confirm).
    Sp(Vec<TokenIdType>),
    /// Plain token IDs from the tiktoken backend.
    Tiktoken(Vec<TokenIdType>),
}
43
44impl Encoding {
45 #[inline]
47 pub fn token_ids(&self) -> &[TokenIdType] {
48 match self {
49 Encoding::Hf(inner) => inner.get_ids(),
50 Encoding::Sp(inner) => inner,
51 Encoding::Tiktoken(inner) => inner,
52 }
53 }
54
55 #[deprecated(since = "0.1.0", note = "Use token_ids() instead")]
57 pub fn token_ids_ref(&self) -> &[TokenIdType] {
58 self.token_ids()
59 }
60
61 pub fn get_hash(&self) -> u64 {
63 let mut hasher = DefaultHasher::new();
64 self.hash(&mut hasher);
65 hasher.finish()
66 }
67}
68
69impl Hash for Encoding {
71 fn hash<H: Hasher>(&self, state: &mut H) {
72 match self {
73 Encoding::Hf(inner) => inner.get_ids().hash(state),
74 Encoding::Sp(inner) => inner.hash(state),
75 Encoding::Tiktoken(inner) => inner.hash(state),
76 }
77 }
78}
79
/// Special-token strings configured for a tokenizer.
///
/// Each named slot is an `Option` because a tokenizer need not define every kind of special
/// token. `Default` produces a value with every slot `None` and no additional tokens. Backends
/// that define only a few tokens can then write
/// `SpecialTokens { eos_token: Some(..), ..Default::default() }` instead of spelling out all
/// eight fields. `PartialEq`/`Eq` allow configurations to be compared directly (e.g. in tests).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SpecialTokens {
    /// Beginning-of-sequence (BOS) token, if defined.
    pub bos_token: Option<String>,
    /// End-of-sequence (EOS) token, if defined.
    pub eos_token: Option<String>,
    /// Unknown-token (UNK) placeholder, if defined.
    pub unk_token: Option<String>,
    /// Separator (SEP) token, if defined.
    pub sep_token: Option<String>,
    /// Padding (PAD) token, if defined.
    pub pad_token: Option<String>,
    /// Classification (CLS) token, if defined.
    pub cls_token: Option<String>,
    /// Mask (MASK) token, if defined.
    pub mask_token: Option<String>,
    /// Further special tokens beyond the named slots above; empty when there are none.
    pub additional_special_tokens: Vec<String>,
}