ferrum_testkit/
tokenizer.rs1use ferrum_interfaces::{
4 tokenizer::{TokenizerInfo, TokenizerType},
5 Tokenizer,
6};
7use ferrum_types::{Result, SpecialTokens, TokenId};
8
9pub struct MockTokenizer {
12 vocab_size: usize,
13 special_tokens: SpecialTokens,
14}
15
16impl MockTokenizer {
17 pub fn new(vocab_size: usize) -> Self {
18 let eos = TokenId::new((vocab_size - 1) as u32);
19 let bos = TokenId::new((vocab_size - 2) as u32);
20 Self {
21 vocab_size,
22 special_tokens: SpecialTokens {
23 bos_token: Some(bos),
24 eos_token: Some(eos),
25 unk_token: Some(TokenId::new(0)),
26 pad_token: None,
27 sep_token: None,
28 cls_token: None,
29 mask_token: None,
30 },
31 }
32 }
33}
34
35impl Tokenizer for MockTokenizer {
36 fn encode(&self, text: &str, add_special: bool) -> Result<Vec<TokenId>> {
37 let mut tokens = Vec::new();
38 if add_special {
39 if let Some(bos) = self.special_tokens.bos_token {
40 tokens.push(bos);
41 }
42 }
43 for word in text.split_whitespace() {
45 let hash = word
46 .bytes()
47 .fold(0u32, |acc, b| acc.wrapping_mul(31).wrapping_add(b as u32));
48 let id = 1 + (hash % (self.vocab_size as u32 - 3));
49 tokens.push(TokenId::new(id));
50 }
51 if tokens.is_empty() {
52 tokens.push(TokenId::new(1)); }
54 Ok(tokens)
55 }
56
57 fn decode(&self, tokens: &[TokenId], _skip_special: bool) -> Result<String> {
58 Ok(tokens
59 .iter()
60 .map(|t| format!("w{}", t.get()))
61 .collect::<Vec<_>>()
62 .join(" "))
63 }
64
65 fn decode_incremental(&self, _prev: &[TokenId], next: TokenId) -> Result<String> {
66 Ok(format!("w{}", next.get()))
67 }
68
69 fn vocab_size(&self) -> usize {
70 self.vocab_size
71 }
72
73 fn special_tokens(&self) -> &SpecialTokens {
74 &self.special_tokens
75 }
76
77 fn token_id(&self, _text: &str) -> Option<TokenId> {
78 None
79 }
80
81 fn token_text(&self, _token_id: TokenId) -> Option<&str> {
82 None
83 }
84
85 fn info(&self) -> TokenizerInfo {
86 TokenizerInfo {
87 tokenizer_type: TokenizerType::Custom,
88 vocab_size: self.vocab_size,
89 special_tokens: self.special_tokens.clone(),
90 supports_incremental: true,
91 supports_chat_template: false,
92 max_token_length: Some(128),
93 model_name: Some("mock".into()),
94 }
95 }
96}