captcha_engine/
tokenizer.rs1use ndarray::ArrayView;
4
/// Maps per-time-step class indices to characters.
///
/// Class index `blank_id` is the blank symbol and produces no output; every
/// other class index `i` is looked up as `chars[i - 1]` when decoding.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Tokenizer {
    /// Characters that can be emitted, in class order (blank excluded).
    chars: Vec<char>,
    /// Class index treated as "no character".
    blank_id: usize,
}
16
17impl Default for Tokenizer {
18 fn default() -> Self {
19 Self::new()
20 }
21}
22
23impl Tokenizer {
24 #[must_use]
25 pub fn new() -> Self {
26 let raw_charset = "0123456789abcdefghijklmnopqrstuvwxyz";
28 let chars: Vec<char> = raw_charset.chars().collect();
29 Self {
30 chars,
31 blank_id: 0, }
33 }
34
35 #[must_use]
37 pub const fn charset_len(&self) -> usize {
38 self.chars.len()
39 }
40
41 #[must_use]
50 pub fn decode<D: ndarray::Dimension>(&self, logits: &ArrayView<f32, D>) -> String {
51 let shape = logits.shape();
52
53 let (seq_len, num_classes) = match shape.len() {
54 3 => (shape[1], shape[2]),
55 2 => (shape[0], shape[1]),
56 _ => return String::new(),
57 };
58
59 let data_vec;
62 let data = if let Some(slice) = logits.as_slice() {
63 slice
64 } else {
65 data_vec = logits.iter().copied().collect::<Vec<_>>();
66 &data_vec
67 };
68
69 let class_stride = num_classes;
71
72 let tokens: Vec<usize> = (0..seq_len)
74 .map(|t| {
75 let base = t * class_stride;
76 (0..num_classes)
77 .max_by(|&a, &b| {
78 let val_a = data.get(base + a).copied().unwrap_or(f32::NEG_INFINITY);
79 let val_b = data.get(base + b).copied().unwrap_or(f32::NEG_INFINITY);
80 val_a
81 .partial_cmp(&val_b)
82 .unwrap_or(std::cmp::Ordering::Equal)
83 })
84 .unwrap_or(0)
85 })
86 .collect();
87
88 let mut result = String::new();
90 let mut prev = self.blank_id;
91
92 for &token in &tokens {
93 if token != self.blank_id && token != prev {
94 if let Some(&c) = self.chars.get(token.saturating_sub(1)) {
96 result.push(c);
97 }
98 }
99 prev = token;
100 }
101
102 result
103 }
104}
105
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// Classes per time step used by every fixture below.
    const NUM_CLASSES: usize = 37;

    /// Builds a `(winners.len(), NUM_CLASSES)` array in which step `t` has a
    /// single `1.0` at class `winners[t]` and zeros elsewhere, then decodes it
    /// with a fresh tokenizer.
    fn decode_one_hot(winners: &[usize]) -> String {
        let steps = winners.len();
        let mut data = vec![0.0f32; steps * NUM_CLASSES];
        for (step, &class) in winners.iter().enumerate() {
            data[step * NUM_CLASSES + class] = 1.0;
        }

        let probs = ndarray::Array2::from_shape_vec((steps, NUM_CLASSES), data).unwrap();
        Tokenizer::new().decode(&probs.view())
    }

    #[test]
    fn test_tokenizer_charset_length() {
        let tokenizer = Tokenizer::new();
        assert_eq!(
            tokenizer.charset_len(),
            36,
            "Charset should have 36 alphanumeric characters (0-9, a-z)"
        );
    }

    #[test]
    fn test_tokenizer_default() {
        assert_eq!(Tokenizer::default().charset_len(), 36);
    }

    #[test]
    fn test_decode_repeated_chars() {
        // Two consecutive steps of class 11 collapse into one 'a'.
        assert_eq!(decode_one_hot(&[11, 11, 12]), "ab");
    }

    #[test]
    fn test_decode_with_blanks() {
        // A blank between identical classes keeps both characters.
        assert_eq!(decode_one_hot(&[11, 0, 11]), "aa");
    }

    #[test]
    fn test_decode_empty() {
        // Zero time steps: shape (0, 37) with no data.
        assert_eq!(decode_one_hot(&[]), "");
    }

    #[test]
    fn test_decode_all_blanks() {
        assert_eq!(decode_one_hot(&[0; 5]), "");
    }

    #[test]
    fn test_decode_complex_pattern() {
        // a a _ b b b _ c  ->  "abc"
        assert_eq!(decode_one_hot(&[11, 11, 0, 12, 12, 12, 0, 13]), "abc");
    }
}