Skip to main content

ferrum_testkit/
tokenizer.rs

//! Mock tokenizer that hashes whitespace-separated words to deterministic token IDs.
2
3use ferrum_interfaces::{
4    tokenizer::{TokenizerInfo, TokenizerType},
5    Tokenizer,
6};
7use ferrum_types::{Result, SpecialTokens, TokenId};
8
9/// Mock tokenizer: splits on whitespace, assigns sequential token IDs.
10/// EOS token is vocab_size - 1.
11pub struct MockTokenizer {
12    vocab_size: usize,
13    special_tokens: SpecialTokens,
14}
15
16impl MockTokenizer {
17    pub fn new(vocab_size: usize) -> Self {
18        let eos = TokenId::new((vocab_size - 1) as u32);
19        let bos = TokenId::new((vocab_size - 2) as u32);
20        Self {
21            vocab_size,
22            special_tokens: SpecialTokens {
23                bos_token: Some(bos),
24                eos_token: Some(eos),
25                unk_token: Some(TokenId::new(0)),
26                pad_token: None,
27                sep_token: None,
28                cls_token: None,
29                mask_token: None,
30            },
31        }
32    }
33}
34
35impl Tokenizer for MockTokenizer {
36    fn encode(&self, text: &str, add_special: bool) -> Result<Vec<TokenId>> {
37        let mut tokens = Vec::new();
38        if add_special {
39            if let Some(bos) = self.special_tokens.bos_token {
40                tokens.push(bos);
41            }
42        }
43        // Hash each word to a token ID in range [1, vocab_size - 3]
44        for word in text.split_whitespace() {
45            let hash = word
46                .bytes()
47                .fold(0u32, |acc, b| acc.wrapping_mul(31).wrapping_add(b as u32));
48            let id = 1 + (hash % (self.vocab_size as u32 - 3));
49            tokens.push(TokenId::new(id));
50        }
51        if tokens.is_empty() {
52            tokens.push(TokenId::new(1)); // at least one token
53        }
54        Ok(tokens)
55    }
56
57    fn decode(&self, tokens: &[TokenId], _skip_special: bool) -> Result<String> {
58        Ok(tokens
59            .iter()
60            .map(|t| format!("w{}", t.get()))
61            .collect::<Vec<_>>()
62            .join(" "))
63    }
64
65    fn decode_incremental(&self, _prev: &[TokenId], next: TokenId) -> Result<String> {
66        Ok(format!("w{}", next.get()))
67    }
68
69    fn vocab_size(&self) -> usize {
70        self.vocab_size
71    }
72
73    fn special_tokens(&self) -> &SpecialTokens {
74        &self.special_tokens
75    }
76
77    fn token_id(&self, _text: &str) -> Option<TokenId> {
78        None
79    }
80
81    fn token_text(&self, _token_id: TokenId) -> Option<&str> {
82        None
83    }
84
85    fn info(&self) -> TokenizerInfo {
86        TokenizerInfo {
87            tokenizer_type: TokenizerType::Custom,
88            vocab_size: self.vocab_size,
89            special_tokens: self.special_tokens.clone(),
90            supports_incremental: true,
91            supports_chat_template: false,
92            max_token_length: Some(128),
93            model_name: Some("mock".into()),
94        }
95    }
96}