// captcha_engine/tokenizer.rs
//! Tokenizer for decoding captcha model output.

use ndarray::ArrayView;

/// Tokenizer for decoding captcha model output.
///
/// Uses CTC decoding to decode model logits into text.
/// The model outputs 37 classes: blank (index 0) + 36 alphanumeric characters (0-9, a-z).
#[derive(Debug)]
pub struct Tokenizer {
    /// Character set (without special tokens).
    /// Charset position `i` corresponds to model class index `i + 1`,
    /// because class 0 is reserved for the CTC blank.
    chars: Vec<char>,
    /// Blank token index (index 0, per the CTC convention).
    blank_id: usize,
}
16
17impl Default for Tokenizer {
18    fn default() -> Self {
19        Self::new()
20    }
21}
22
23impl Tokenizer {
24    #[must_use]
25    pub fn new() -> Self {
26        // Lowercase charset matching the training script
27        let raw_charset = "0123456789abcdefghijklmnopqrstuvwxyz";
28        let chars: Vec<char> = raw_charset.chars().collect();
29        Self {
30            chars,
31            blank_id: 0, // Blank is first token (CTC convention)
32        }
33    }
34
35    /// Returns the number of characters in the charset.
36    #[must_use]
37    pub const fn charset_len(&self) -> usize {
38        self.chars.len()
39    }
40
41    /// Decode model output logits to text using CTC greedy decoding.
42    /// Collapses repeated characters and removes blanks.
43    ///
44    /// # Arguments
45    /// * `logits` - `ArrayView` with shape [batch, `sequence_length`, `num_classes`] or [`sequence_length`, `num_classes`]
46    ///
47    /// # Returns
48    /// Decoded text string
49    #[must_use]
50    pub fn decode<D: ndarray::Dimension>(&self, logits: &ArrayView<f32, D>) -> String {
51        let shape = logits.shape();
52
53        let (seq_len, num_classes) = match shape.len() {
54            3 => (shape[1], shape[2]),
55            2 => (shape[0], shape[1]),
56            _ => return String::new(),
57        };
58
59        // Flatten to raw slice for indexing
60        // Handle non-contiguous arrays by cloning if necessary
61        let data_vec;
62        let data = if let Some(slice) = logits.as_slice() {
63            slice
64        } else {
65            data_vec = logits.iter().copied().collect::<Vec<_>>();
66            &data_vec
67        };
68
69        // Calculate stride based on shape
70        let class_stride = num_classes;
71
72        // Greedy decoding: take argmax at each timestep
73        let tokens: Vec<usize> = (0..seq_len)
74            .map(|t| {
75                let base = t * class_stride;
76                (0..num_classes)
77                    .max_by(|&a, &b| {
78                        let val_a = data.get(base + a).copied().unwrap_or(f32::NEG_INFINITY);
79                        let val_b = data.get(base + b).copied().unwrap_or(f32::NEG_INFINITY);
80                        val_a
81                            .partial_cmp(&val_b)
82                            .unwrap_or(std::cmp::Ordering::Equal)
83                    })
84                    .unwrap_or(0)
85            })
86            .collect();
87
88        // CTC decoding: collapse repeated tokens and remove blanks
89        let mut result = String::new();
90        let mut prev = self.blank_id;
91
92        for &token in &tokens {
93            if token != self.blank_id && token != prev {
94                // Token indices: 0 = blank, 1..=36 = characters
95                if let Some(&c) = self.chars.get(token.saturating_sub(1)) {
96                    result.push(c);
97                }
98            }
99            prev = token;
100        }
101
102        result
103    }
104}
105
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// Number of model classes: 1 blank + 36 charset characters.
    const NUM_CLASSES: usize = 37;

    /// Builds a one-hot `[seq_len, NUM_CLASSES]` logits array where timestep
    /// `t` peaks at class `classes[t]`.
    ///
    /// Class indices: 0 = blank, '0'..='9' = 1..=10, 'a'..='z' = 11..=36
    /// (charset "0123456789abcdefghijklmnopqrstuvwxyz" shifted by the blank).
    fn one_hot(classes: &[usize]) -> ndarray::Array2<f32> {
        let mut data = vec![0.0f32; classes.len() * NUM_CLASSES];
        for (t, &class) in classes.iter().enumerate() {
            data[t * NUM_CLASSES + class] = 1.0;
        }
        ndarray::Array2::from_shape_vec((classes.len(), NUM_CLASSES), data).unwrap()
    }

    #[test]
    fn test_tokenizer_charset_length() {
        let tokenizer = Tokenizer::new();
        assert_eq!(
            tokenizer.charset_len(),
            36,
            "Charset should have 36 alphanumeric characters (0-9, a-z)"
        );
    }

    #[test]
    fn test_tokenizer_default() {
        let tokenizer = Tokenizer::default();
        assert_eq!(tokenizer.charset_len(), 36);
    }

    #[test]
    fn test_decode_repeated_chars() {
        let tokenizer = Tokenizer::new();
        // "a", "a", "b" collapses to "ab" ('a' = class 11, 'b' = class 12).
        let probs = one_hot(&[11, 11, 12]);
        assert_eq!(tokenizer.decode(&probs.view()), "ab");
    }

    #[test]
    fn test_decode_with_blanks() {
        let tokenizer = Tokenizer::new();
        // "a", blank, "a": the blank separates the repeats, so both survive -> "aa".
        let probs = one_hot(&[11, 0, 11]);
        assert_eq!(tokenizer.decode(&probs.view()), "aa");
    }

    #[test]
    fn test_decode_empty() {
        let tokenizer = Tokenizer::new();
        // Zero timesteps decodes to the empty string.
        let probs = one_hot(&[]);
        assert_eq!(tokenizer.decode(&probs.view()), "");
    }

    #[test]
    fn test_decode_all_blanks() {
        let tokenizer = Tokenizer::new();
        // Every timestep is the blank (class 0): nothing is emitted.
        let probs = one_hot(&[0, 0, 0, 0, 0]);
        assert_eq!(tokenizer.decode(&probs.view()), "");
    }

    #[test]
    fn test_decode_complex_pattern() {
        let tokenizer = Tokenizer::new();
        // "a", "a", blank, "b", "b", "b", blank, "c" -> "abc"
        // ('a' = 11, 'b' = 12, 'c' = 13).
        let probs = one_hot(&[11, 11, 0, 12, 12, 12, 0, 13]);
        assert_eq!(tokenizer.decode(&probs.view()), "abc");
    }
}