1use std::collections::HashMap;
4
5use anyhow::Result;
6
7use crate::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
8
9pub struct MockTokenizer {
11 vocab: HashMap<String, u32>,
12 reverse_vocab: HashMap<u32, String>,
13 special_tokens: SpecialTokens,
14}
15
16impl Default for MockTokenizer {
17 fn default() -> Self {
18 Self::new()
19 }
20}
21
22impl MockTokenizer {
23 pub fn new() -> Self {
24 let mut vocab = HashMap::new();
25 let mut reverse_vocab = HashMap::new();
26
27 let tokens = vec![
29 ("Hello", 1),
30 ("world", 2),
31 ("test", 3),
32 ("token", 4),
33 (" ", 5),
34 (".", 6),
35 ("<eos>", 999),
36 ("<bos>", 1000),
37 ("<|im_start|>", 1001),
38 ("<|im_end|>", 1002),
39 ("<|eot_id|>", 1003),
40 ("system", 7),
41 ("user", 8),
42 ("assistant", 9),
43 ];
44
45 for (token, id) in tokens {
46 vocab.insert(token.to_string(), id);
47 reverse_vocab.insert(id, token.to_string());
48 }
49
50 let special_tokens = SpecialTokens {
51 bos_token: Some("<bos>".to_string()),
52 eos_token: Some("<eos>".to_string()),
53 unk_token: Some("<unk>".to_string()),
54 sep_token: None,
55 pad_token: None,
56 cls_token: None,
57 mask_token: None,
58 additional_special_tokens: vec![],
59 };
60
61 Self {
62 vocab,
63 reverse_vocab,
64 special_tokens,
65 }
66 }
67}
68
69impl Encoder for MockTokenizer {
70 fn encode(&self, input: &str, _add_special_tokens: bool) -> Result<Encoding> {
71 let tokens: Vec<u32> = input
74 .split_whitespace()
75 .filter_map(|word| self.vocab.get(word).copied())
76 .collect();
77
78 Ok(Encoding::Plain(tokens))
79 }
80
81 fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>> {
82 inputs
83 .iter()
84 .map(|input| self.encode(input, add_special_tokens))
85 .collect()
86 }
87}
88
89impl Decoder for MockTokenizer {
90 fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String> {
91 let tokens: Vec<String> = token_ids
92 .iter()
93 .filter_map(|id| {
94 self.reverse_vocab.get(id).and_then(|token| {
95 if skip_special_tokens && (token == "<eos>" || token == "<bos>") {
96 None
97 } else {
98 Some(token.clone())
99 }
100 })
101 })
102 .collect();
103
104 Ok(tokens.join(" "))
105 }
106}
107
108impl TokenizerTrait for MockTokenizer {
109 fn vocab_size(&self) -> usize {
110 self.vocab.len()
111 }
112
113 fn get_special_tokens(&self) -> &SpecialTokens {
114 &self.special_tokens
115 }
116
117 fn token_to_id(&self, token: &str) -> Option<u32> {
118 self.vocab.get(token).copied()
119 }
120
121 fn id_to_token(&self, id: u32) -> Option<String> {
122 self.reverse_vocab.get(&id).cloned()
123 }
124
125 fn as_any(&self) -> &dyn std::any::Any {
126 self
127 }
128}