// mediarium-ocr 0.1.0
//
// ONNX-based OCR helpers for Mediarium captcha recognition
// Documentation
use std::cell::UnsafeCell;
use std::io::Cursor;
use std::path::Path;
use std::sync::LazyLock;

use anyhow::{Context, Result, bail};
use image::ImageReader;
use image::imageops::FilterType;
use ort::session::Session;
use ort::value::Tensor;
use tracing::info;

static CHARSET: LazyLock<Vec<String>> = LazyLock::new(|| {
    serde_json::from_str(include_str!("./charsets.json")).expect("bundled charsets.json is invalid")
});

pub struct CaptchaOcr {
    /// Use `UnsafeCell` here for interior mutability
    ///
    /// Avoid `recognize` holding a mutable reference to `self`
    session: UnsafeCell<Session>,
}

// SAFETY: ort::Session::run requires &mut self only because the Rust binding is
// conservative. The underlying ONNX Runtime C API (OrtSession::Run) is
// documented as thread-safe and uses internal synchronization.
unsafe impl Send for CaptchaOcr {}
unsafe impl Sync for CaptchaOcr {}

impl CaptchaOcr {
    pub fn load(model_dir: &Path) -> Result<Self> {
        let onnx_path = model_dir.join("common.onnx");

        if !onnx_path.exists() {
            bail!(
                "ONNX model not found at {}. \
                 Download from https://github.com/sml2h3/ddddocr/blob/master/ddddocr/common.onnx",
                onnx_path.display()
            );
        }

        let session = Session::builder()
            .context("failed to create ONNX session builder")?
            .commit_from_file(&onnx_path)
            .with_context(|| format!("failed to load ONNX model from {}", onnx_path.display()))?;

        info!(model_path = %onnx_path.display(), "ONNX 加载成功");

        Ok(Self {
            session: UnsafeCell::new(session),
        })
    }

    pub fn recognize(&self, image_bytes: &[u8]) -> Result<String> {
        let img = ImageReader::new(Cursor::new(image_bytes))
            .with_guessed_format()
            .context("failed to guess image format")?
            .decode()
            .context("failed to decode captcha image")?;

        // Resize to 64px height, maintaining aspect ratio
        let target_height = 64u32;
        let scale = f64::from(target_height) / f64::from(img.height());
        let target_width = (f64::from(img.width()) * scale).round() as u32;
        let resized = img.resize_exact(target_width, target_height, FilterType::Lanczos3);

        // Convert to grayscale
        let gray = resized.to_luma8();

        // Build input tensor: [1, 1, 64, width], normalized
        let width = gray.width() as usize;
        let height = gray.height() as usize;
        let mut data = Vec::with_capacity(height * width);
        for y in 0..height {
            for x in 0..width {
                let pixel = f32::from(gray.get_pixel(x as u32, y as u32).0[0]);
                data.push((pixel / 255.0 - 0.5) / 0.5);
            }
        }
        let input = Tensor::from_array(([1usize, 1, height, width], data.into_boxed_slice()))?;

        // SAFETY: see unsafe impl Sync above - ONNX Runtime handles concurrency
        // internally.
        let session = unsafe { &mut *self.session.get() };
        let outputs = session
            .run(ort::inputs![input])
            .context("ONNX inference failed")?;

        let (shape, raw_data) = outputs[0]
            .try_extract_tensor::<f32>()
            .context("failed to read output tensor")?;

        let result = ctc_decode(shape, raw_data, &CHARSET);

        Ok(result)
    }
}

/// CTC greedy decode: take argmax at each timestep, collapse repeats, remove
/// blanks.
fn ctc_decode(shape: &[i64], data: &[f32], charset: &[String]) -> String {
    let (seq_len, num_classes) = if shape.len() == 3 {
        // edge case: if the 3D tensor has batch=1, seq_len is at index 1
        if shape[0] == 1 {
            (shape[1] as usize, shape[2] as usize)
        } else {
            (shape[0] as usize, shape[2] as usize)
        }
    } else if shape.len() == 2 {
        (shape[0] as usize, shape[1] as usize)
    } else {
        return String::new();
    };

    let mut last_idx: Option<usize> = None;
    let mut result = String::new();

    for t in 0..seq_len {
        let offset = t * num_classes;
        let slice = &data[offset..offset + num_classes];
        let best_idx = slice
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(0);

        // blank token shape: 1.0, 0.0, 0.0, ...
        if best_idx == 0 {
            last_idx = None;

            continue;
        }

        // Collapse repeated characters
        if last_idx == Some(best_idx) {
            continue;
        }

        last_idx = Some(best_idx);

        if let Some(ch) = charset.get(best_idx)
            && !ch.is_empty()
        {
            result.push_str(ch);
        }
    }

    result
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ctc_decode_basic() {
        let charset: Vec<String> = vec![String::new(), "a".into(), "b".into(), "c".into()];
        let data: Vec<f32> = vec![
            0.0, 1.0, 0.0, 0.0, // t0: 'a'
            0.0, 1.0, 0.0, 0.0, // t1: 'a' (repeat, collapsed)
            0.0, 0.0, 1.0, 0.0, // t2: 'b'
        ];
        let shape: &[i64] = &[3, 1, 4];
        let result = ctc_decode(shape, &data, &charset);
        assert_eq!(result, "ab");
    }

    #[test]
    fn test_ctc_decode_with_blanks() {
        let charset: Vec<String> = vec![String::new(), "x".into(), "y".into()];
        let data: Vec<f32> = vec![
            1.0, 0.0, 0.0, // t0: blank
            0.0, 1.0, 0.0, // t1: 'x'
            1.0, 0.0, 0.0, // t2: blank
            0.0, 1.0, 0.0, // t3: 'x'
        ];
        let shape: &[i64] = &[4, 1, 3];
        let result = ctc_decode(shape, &data, &charset);
        assert_eq!(result, "xx");
    }

    #[test]
    fn test_recognize_sample_captcha() {
        let ocr = CaptchaOcr::load(Path::new("../../models")).unwrap();
        let image_bytes = include_bytes!("test_captcha.bmp");
        let result = ocr.recognize(image_bytes).unwrap();
        assert_eq!(result, "48115");
    }
}