use rten_tensor::Tensor;
use rten_tensor::prelude::*;
/// Maps CTC model output classes to the captcha character set.
///
/// Class index 0 is reserved for the CTC blank symbol; class `i > 0`
/// corresponds to `chars[i - 1]`.
#[derive(Debug, Clone)]
pub struct Tokenizer {
    /// Decodable characters in model-class order (class = index + 1).
    chars: Vec<char>,
    /// Class index of the CTC blank symbol (always 0 in `new()`).
    blank_id: usize,
}
17
18impl Default for Tokenizer {
19 fn default() -> Self {
20 Self::new()
21 }
22}
23
24impl Tokenizer {
25 #[must_use]
26 pub fn new() -> Self {
27 let raw_charset = "0123456789abcdefghijklmnopqrstuvwxyz";
29 let chars: Vec<char> = raw_charset.chars().collect();
30 Self {
31 chars,
32 blank_id: 0, }
34 }
35
36 #[must_use]
38 pub const fn charset_len(&self) -> usize {
39 self.chars.len()
40 }
41
42 #[must_use]
51 pub fn decode_rten(&self, logits: &Tensor<f32>) -> String {
52 let shape = logits.shape();
53
54 let (seq_len, num_classes) = match shape.len() {
56 3 => (shape[1], shape[2]), 2 => (shape[0], shape[1]), _ => return String::new(),
59 };
60
61 let data_vec;
67 let data = if let Some(slice) = logits.data() {
68 slice
69 } else {
70 data_vec = logits.iter().copied().collect::<Vec<_>>();
72 &data_vec
73 };
74
75 let class_stride = num_classes;
78
79 let tokens: Vec<usize> = (0..seq_len)
81 .map(|t| {
82 let base = t * class_stride;
83 (0..num_classes)
84 .max_by(|&a, &b| {
85 let val_a = data.get(base + a).copied().unwrap_or(f32::NEG_INFINITY);
86 let val_b = data.get(base + b).copied().unwrap_or(f32::NEG_INFINITY);
87 val_a
88 .partial_cmp(&val_b)
89 .unwrap_or(std::cmp::Ordering::Equal)
90 })
91 .unwrap_or(0)
92 })
93 .collect();
94
95 let mut result = String::new();
97 let mut prev = self.blank_id;
98
99 for &token in &tokens {
100 if token != self.blank_id && token != prev {
101 if let Some(&c) = self.chars.get(token.saturating_sub(1)) {
103 result.push(c);
104 }
105 }
106 prev = token;
107 }
108
109 result
110 }
111
112 #[must_use]
115 pub fn decode(&self, logits: &Tensor<f32>) -> String {
116 self.decode_rten(logits)
117 }
118}
119
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// Total model classes: 36 charset characters plus the CTC blank (class 0).
    const NUM_CLASSES: usize = 37;

    /// Builds a `[steps.len(), NUM_CLASSES]` tensor whose row `t` is one-hot
    /// at class `steps[t]` — a deterministic per-timestep prediction.
    fn one_hot_probs(steps: &[usize]) -> Tensor<f32> {
        let mut data = vec![0.0f32; steps.len() * NUM_CLASSES];
        for (t, &class) in steps.iter().enumerate() {
            data[t * NUM_CLASSES + class] = 1.0;
        }
        Tensor::from_data(&[steps.len(), NUM_CLASSES], data)
    }

    #[test]
    fn test_tokenizer_charset_length() {
        let tokenizer = Tokenizer::new();
        assert_eq!(
            tokenizer.charset_len(),
            36,
            "Charset should have 36 alphanumeric characters (0-9, a-z)"
        );
    }

    #[test]
    fn test_tokenizer_default() {
        let tokenizer = Tokenizer::default();
        assert_eq!(tokenizer.charset_len(), 36);
    }

    #[test]
    fn test_decode_repeated_chars() {
        // Classes 11, 11, 12 -> 'a', 'a', 'b'; the duplicate 'a' collapses.
        let tokenizer = Tokenizer::new();
        let probs = one_hot_probs(&[11, 11, 12]);
        assert_eq!(tokenizer.decode_rten(&probs), "ab");
    }

    #[test]
    fn test_decode_with_blanks() {
        // A blank (class 0) between two 'a's prevents collapsing them.
        let tokenizer = Tokenizer::new();
        let probs = one_hot_probs(&[11, 0, 11]);
        assert_eq!(tokenizer.decode_rten(&probs), "aa");
    }

    #[test]
    fn test_decode_empty() {
        // Zero timesteps must decode to the empty string.
        let tokenizer = Tokenizer::new();
        let probs = one_hot_probs(&[]);
        assert_eq!(tokenizer.decode_rten(&probs), "");
    }

    #[test]
    fn test_decode_all_blanks() {
        let tokenizer = Tokenizer::new();
        let probs = one_hot_probs(&[0; 5]);
        assert_eq!(tokenizer.decode_rten(&probs), "");
    }

    #[test]
    fn test_decode_complex_pattern() {
        // Pattern "aa.bbb.c" (with '.' = blank) should collapse to "abc".
        let tokenizer = Tokenizer::new();
        let probs = one_hot_probs(&[11, 11, 0, 12, 12, 12, 0, 13]);
        assert_eq!(tokenizer.decode_rten(&probs), "abc");
    }
}