// captcha_engine/tokenizer.rs
//! Tokenizer for decoding captcha model output.

use rten_tensor::Tensor;
use rten_tensor::prelude::*;

/// Tokenizer for decoding captcha model output.
///
/// Applies CTC decoding to turn model logits into text. The model emits
/// 37 classes: blank (index 0) followed by 36 alphanumerics (0-9, a-z).
#[derive(Debug)]
pub struct Tokenizer {
    /// Recognised characters, excluding special tokens.
    chars: Vec<char>,
    /// Index of the CTC blank token (index 0 by CTC convention).
    blank_id: usize,
}

18impl Default for Tokenizer {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl Tokenizer {
25    #[must_use]
26    pub fn new() -> Self {
27        // Lowercase charset matching the training script
28        let raw_charset = "0123456789abcdefghijklmnopqrstuvwxyz";
29        let chars: Vec<char> = raw_charset.chars().collect();
30        Self {
31            chars,
32            blank_id: 0, // Blank is first token (CTC convention)
33        }
34    }
35
36    /// Returns the number of characters in the charset.
37    #[must_use]
38    pub const fn charset_len(&self) -> usize {
39        self.chars.len()
40    }
41
42    /// Decode model output logits to text using CTC greedy decoding.
43    /// Collapses repeated characters and removes blanks.
44    ///
45    /// # Arguments
46    /// * `logits` - `Tensor<f32>` with shape [batch, `sequence_length`, `num_classes`] or [`sequence_length`, `num_classes`]
47    ///
48    /// # Returns
49    /// Decoded text string
50    #[must_use]
51    pub fn decode_rten(&self, logits: &Tensor<f32>) -> String {
52        let shape = logits.shape();
53
54        // Check dimensions and handle batch dimension if present
55        let (seq_len, num_classes) = match shape.len() {
56            3 => (shape[1], shape[2]), // [Batch, Seq, NumClasses] - assume batch=1
57            2 => (shape[0], shape[1]), // [Seq, NumClasses]
58            _ => return String::new(),
59        };
60
61        // Get raw data slice
62        // Ensure tensor is contiguous. If not, we might need to handle strides or clone.
63        // For simplicity, we assume generic layout or handle it.
64        // rten::Tensor::to_vec() ensures contiguous, or data() if already contiguous.
65        // We'll use iter() to be safe or try to get slice.
66        let data_vec;
67        let data = if let Some(slice) = logits.data() {
68            slice
69        } else {
70            // If not contiguous, copy to vec
71            data_vec = logits.iter().copied().collect::<Vec<_>>();
72            &data_vec
73        };
74
75        // Calculate stride. For [Seq, NumClasses], inner stride is 1, seq stride is NumClasses.
76        // If data is contiguous row-major (standard for C/rten):
77        let class_stride = num_classes;
78
79        // Greedy decoding: take argmax at each timestep
80        let tokens: Vec<usize> = (0..seq_len)
81            .map(|t| {
82                let base = t * class_stride;
83                (0..num_classes)
84                    .max_by(|&a, &b| {
85                        let val_a = data.get(base + a).copied().unwrap_or(f32::NEG_INFINITY);
86                        let val_b = data.get(base + b).copied().unwrap_or(f32::NEG_INFINITY);
87                        val_a
88                            .partial_cmp(&val_b)
89                            .unwrap_or(std::cmp::Ordering::Equal)
90                    })
91                    .unwrap_or(0)
92            })
93            .collect();
94
95        // CTC decoding: collapse repeated tokens and remove blanks
96        let mut result = String::new();
97        let mut prev = self.blank_id;
98
99        for &token in &tokens {
100            if token != self.blank_id && token != prev {
101                // Token indices: 0 = blank, 1..=36 = characters
102                if let Some(&c) = self.chars.get(token.saturating_sub(1)) {
103                    result.push(c);
104                }
105            }
106            prev = token;
107        }
108
109        result
110    }
111
112    // Kept for compatibility if we had other usages, but renamed for clarity in model.rs
113    // Alias to allow cleaner API
114    #[must_use]
115    pub fn decode(&self, logits: &Tensor<f32>) -> String {
116        self.decode_rten(logits)
117    }
118}
119
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// Number of model classes: blank (0) plus 36 characters.
    const CLASSES: usize = 37;

    /// Build a `[hot.len(), 37]` tensor that is one-hot at class `hot[t]`
    /// for every timestep `t`.
    fn one_hot(hot: &[usize]) -> Tensor<f32> {
        let mut buf = vec![0.0f32; hot.len() * CLASSES];
        for (step, &class) in hot.iter().enumerate() {
            buf[step * CLASSES + class] = 1.0;
        }
        Tensor::from_data(&[hot.len(), CLASSES], buf)
    }

    #[test]
    fn test_tokenizer_charset_length() {
        assert_eq!(
            Tokenizer::new().charset_len(),
            36,
            "Charset should have 36 alphanumeric characters (0-9, a-z)"
        );
    }

    #[test]
    fn test_tokenizer_default() {
        assert_eq!(Tokenizer::default().charset_len(), 36);
    }

    #[test]
    fn test_decode_repeated_chars() {
        // "a", "a", "b" collapses to "ab" ('a'=11, 'b'=12; 0=blank, 1-10='0'-'9').
        let probs = one_hot(&[11, 11, 12]);
        assert_eq!(Tokenizer::new().decode_rten(&probs), "ab");
    }

    #[test]
    fn test_decode_with_blanks() {
        // "a", blank, "a" decodes to "aa": the blank splits the repeat.
        let probs = one_hot(&[11, 0, 11]);
        assert_eq!(Tokenizer::new().decode_rten(&probs), "aa");
    }

    #[test]
    fn test_decode_empty() {
        // Zero-length sequence decodes to the empty string.
        let probs = one_hot(&[]);
        assert_eq!(Tokenizer::new().decode_rten(&probs), "");
    }

    #[test]
    fn test_decode_all_blanks() {
        let probs = one_hot(&[0, 0, 0, 0, 0]);
        assert_eq!(Tokenizer::new().decode_rten(&probs), "");
    }

    #[test]
    fn test_decode_complex_pattern() {
        // a a _ b b b _ c  ->  "abc"  ('a'=11, 'b'=12, 'c'=13)
        let probs = one_hot(&[11, 11, 0, 12, 12, 12, 0, 13]);
        assert_eq!(Tokenizer::new().decode_rten(&probs), "abc");
    }
}