// captcha-engine 0.4.10
//
// ONNX-based captcha recognition engine
// Documentation
//! Tokenizer for decoding captcha model output.

use rten_tensor::Tensor;
use rten_tensor::prelude::*;

/// Tokenizer for decoding captcha model output.
///
/// Uses CTC decoding to decode model logits into text.
/// The model outputs 37 classes: blank (index 0) + 36 alphanumeric characters (0-9, a-z).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Tokenizer {
    /// Character set (without special tokens); a non-blank class `i` decodes to `chars[i - 1]`.
    chars: Vec<char>,
    /// Blank token index (index 0 for CTC)
    blank_id: usize,
}

impl Default for Tokenizer {
    fn default() -> Self {
        Self::new()
    }
}

impl Tokenizer {
    /// Creates a tokenizer for the 36-character lowercase alphanumeric charset.
    ///
    /// Class index 0 is the CTC blank; class `i` in `1..=36` decodes to the `(i - 1)`-th
    /// character of `"0123456789abcdefghijklmnopqrstuvwxyz"`.
    #[must_use]
    pub fn new() -> Self {
        // Lowercase charset matching the training script
        let raw_charset = "0123456789abcdefghijklmnopqrstuvwxyz";
        Self {
            chars: raw_charset.chars().collect(),
            blank_id: 0, // Blank is first token (CTC convention)
        }
    }

    /// Returns the number of characters in the charset (the blank token is not counted).
    #[must_use]
    pub const fn charset_len(&self) -> usize {
        self.chars.len()
    }

    /// Decode model output logits to text using CTC greedy decoding.
    ///
    /// At each timestep the highest-scoring class is selected; consecutive repeats of the
    /// same class are collapsed and blanks are removed.
    ///
    /// # Arguments
    /// * `logits` - `Tensor<f32>` with shape `[batch, sequence_length, num_classes]` or
    ///   `[sequence_length, num_classes]`. For 3-D input only the first batch item is decoded.
    ///
    /// # Returns
    /// The decoded text. An empty string is returned for tensors of any other rank, for
    /// tensors with no data, and when every timestep decodes to blank.
    ///
    /// Equal scores resolve to the lowest class index (matching the usual `argmax`
    /// convention), and NaN scores are never selected, so a timestep without a usable
    /// score decodes to blank.
    #[must_use]
    pub fn decode_rten(&self, logits: &Tensor<f32>) -> String {
        let shape = logits.shape();

        let (seq_len, num_classes) = match shape.len() {
            3 => (shape[1], shape[2]), // [Batch, Seq, NumClasses] - decode batch item 0
            2 => (shape[0], shape[1]), // [Seq, NumClasses]
            _ => return String::new(),
        };
        if num_classes == 0 {
            // No scores to choose from (and `chunks_exact(0)` would panic).
            return String::new();
        }

        // `data()` is only available for contiguous tensors; otherwise copy the elements
        // out in logical (row-major) order so the row slicing below stays valid.
        let owned;
        let data: &[f32] = if let Some(slice) = logits.data() {
            slice
        } else {
            owned = logits.iter().copied().collect::<Vec<_>>();
            &owned
        };

        let mut result = String::new();
        let mut prev = self.blank_id;
        // Row-major: timestep `t` is `data[t * num_classes..(t + 1) * num_classes]`.
        // `take(seq_len)` limits 3-D input to the first batch item; `chunks_exact` never
        // yields a partial row, so a buffer shorter than expected (e.g. batch = 0) simply
        // produces fewer timesteps instead of fabricated scores.
        for row in data.chunks_exact(num_classes).take(seq_len) {
            let token = Self::argmax(row);
            if token != self.blank_id && token != prev {
                // Token indices: 0 = blank, 1..=36 = characters
                if let Some(&c) = self.chars.get(token.saturating_sub(1)) {
                    result.push(c);
                }
            }
            prev = token;
        }

        result
    }

    /// Alias for [`Tokenizer::decode_rten`]: CTC greedy decoding of `logits` into text.
    #[must_use]
    pub fn decode(&self, logits: &Tensor<f32>) -> String {
        self.decode_rten(logits)
    }

    /// Index of the largest score in `row`, preferring the lowest index on ties.
    ///
    /// NaN entries never win; if nothing exceeds negative infinity, index 0 (blank) is returned.
    fn argmax(row: &[f32]) -> usize {
        let mut best_idx = 0;
        let mut best_val = f32::NEG_INFINITY;
        for (idx, &val) in row.iter().enumerate() {
            // Strict `>` keeps the first maximum and rejects NaN (all NaN comparisons are false).
            if val > best_val {
                best_idx = idx;
                best_val = val;
            }
        }
        best_idx
    }
}

#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// Model output classes: blank (0), then '0'-'9' (1-10), then 'a'-'z' (11-36).
    const NUM_CLASSES: usize = 37;
    const BLANK: usize = 0;
    const CHAR_A: usize = 11;
    const CHAR_B: usize = 12;
    const CHAR_C: usize = 13;

    /// Builds a `[classes.len(), NUM_CLASSES]` logits tensor in which timestep `t`
    /// scores 1.0 for `classes[t]` and 0.0 for every other class.
    fn one_hot(classes: &[usize]) -> Tensor<f32> {
        let mut data = vec![0.0f32; classes.len() * NUM_CLASSES];
        for (step, &class) in classes.iter().enumerate() {
            data[step * NUM_CLASSES + class] = 1.0;
        }
        Tensor::from_data(&[classes.len(), NUM_CLASSES], data)
    }

    #[test]
    fn test_tokenizer_charset_length() {
        assert_eq!(
            Tokenizer::new().charset_len(),
            36,
            "Charset should have 36 alphanumeric characters (0-9, a-z)"
        );
    }

    #[test]
    fn test_tokenizer_default() {
        assert_eq!(Tokenizer::default().charset_len(), 36);
    }

    #[test]
    fn test_decode_repeated_chars() {
        // Adjacent repeats collapse: a, a, b -> "ab"
        let logits = one_hot(&[CHAR_A, CHAR_A, CHAR_B]);
        assert_eq!(Tokenizer::new().decode_rten(&logits), "ab");
    }

    #[test]
    fn test_decode_with_blanks() {
        // A blank separates repeats, so both survive: a, blank, a -> "aa"
        let logits = one_hot(&[CHAR_A, BLANK, CHAR_A]);
        assert_eq!(Tokenizer::new().decode_rten(&logits), "aa");
    }

    #[test]
    fn test_decode_empty() {
        // Zero timesteps: shape [0, 37] with no data
        let logits = one_hot(&[]);
        assert_eq!(Tokenizer::new().decode_rten(&logits), "");
    }

    #[test]
    fn test_decode_all_blanks() {
        let logits = one_hot(&[BLANK; 5]);
        assert_eq!(Tokenizer::new().decode_rten(&logits), "");
    }

    #[test]
    fn test_decode_complex_pattern() {
        // a, a, blank, b, b, b, blank, c -> "abc"
        let logits = one_hot(&[CHAR_A, CHAR_A, BLANK, CHAR_B, CHAR_B, CHAR_B, BLANK, CHAR_C]);
        assert_eq!(Tokenizer::new().decode_rten(&logits), "abc");
    }
}