Skip to main content

llm_tokenizer/
mock.rs

1//! Mock tokenizer implementation for testing
2
3use std::collections::HashMap;
4
5use anyhow::Result;
6
7use crate::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
8
9/// Mock tokenizer for testing purposes
10pub struct MockTokenizer {
11    vocab: HashMap<String, u32>,
12    reverse_vocab: HashMap<u32, String>,
13    special_tokens: SpecialTokens,
14}
15
16impl Default for MockTokenizer {
17    fn default() -> Self {
18        Self::new()
19    }
20}
21
22impl MockTokenizer {
23    pub fn new() -> Self {
24        let mut vocab = HashMap::new();
25        let mut reverse_vocab = HashMap::new();
26
27        // Add some basic tokens
28        let tokens = vec![
29            ("Hello", 1),
30            ("world", 2),
31            ("test", 3),
32            ("token", 4),
33            (" ", 5),
34            (".", 6),
35            ("<eos>", 999),
36            ("<bos>", 1000),
37            ("<|im_start|>", 1001),
38            ("<|im_end|>", 1002),
39            ("<|eot_id|>", 1003),
40            ("system", 7),
41            ("user", 8),
42            ("assistant", 9),
43        ];
44
45        for (token, id) in tokens {
46            vocab.insert(token.to_string(), id);
47            reverse_vocab.insert(id, token.to_string());
48        }
49
50        let special_tokens = SpecialTokens {
51            bos_token: Some("<bos>".to_string()),
52            eos_token: Some("<eos>".to_string()),
53            unk_token: Some("<unk>".to_string()),
54            sep_token: None,
55            pad_token: None,
56            cls_token: None,
57            mask_token: None,
58            additional_special_tokens: vec![],
59        };
60
61        Self {
62            vocab,
63            reverse_vocab,
64            special_tokens,
65        }
66    }
67}
68
69impl Encoder for MockTokenizer {
70    fn encode(&self, input: &str, _add_special_tokens: bool) -> Result<Encoding> {
71        // Simple word-based tokenization using the vocab
72        // Split by whitespace and look up each word (decoder adds spaces back)
73        let tokens: Vec<u32> = input
74            .split_whitespace()
75            .filter_map(|word| self.vocab.get(word).copied())
76            .collect();
77
78        Ok(Encoding::Plain(tokens))
79    }
80
81    fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>> {
82        inputs
83            .iter()
84            .map(|input| self.encode(input, add_special_tokens))
85            .collect()
86    }
87}
88
89impl Decoder for MockTokenizer {
90    fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String> {
91        let tokens: Vec<String> = token_ids
92            .iter()
93            .filter_map(|id| {
94                self.reverse_vocab.get(id).and_then(|token| {
95                    if skip_special_tokens && (token == "<eos>" || token == "<bos>") {
96                        None
97                    } else {
98                        Some(token.clone())
99                    }
100                })
101            })
102            .collect();
103
104        Ok(tokens.join(" "))
105    }
106}
107
108impl TokenizerTrait for MockTokenizer {
109    fn vocab_size(&self) -> usize {
110        self.vocab.len()
111    }
112
113    fn get_special_tokens(&self) -> &SpecialTokens {
114        &self.special_tokens
115    }
116
117    fn token_to_id(&self, token: &str) -> Option<u32> {
118        self.vocab.get(token).copied()
119    }
120
121    fn id_to_token(&self, id: u32) -> Option<String> {
122        self.reverse_vocab.get(&id).cloned()
123    }
124
125    fn as_any(&self) -> &dyn std::any::Any {
126        self
127    }
128}