Skip to main content

mediarium_ocr/
lib.rs

1use std::cell::UnsafeCell;
2use std::io::Cursor;
3use std::path::Path;
4use std::sync::LazyLock;
5
6use anyhow::{Context, Result, bail};
7use image::ImageReader;
8use image::imageops::FilterType;
9use ort::session::Session;
10use ort::value::Tensor;
11use tracing::info;
12
13static CHARSET: LazyLock<Vec<String>> = LazyLock::new(|| {
14    serde_json::from_str(include_str!("./charsets.json")).expect("bundled charsets.json is invalid")
15});
16
17pub struct CaptchaOcr {
18    /// Use `UnsafeCell` here for interior mutability
19    ///
20    /// Avoid `recognize` holding a mutable reference to `self`
21    session: UnsafeCell<Session>,
22}
23
24// SAFETY: ort::Session::run requires &mut self only because the Rust binding is
25// conservative. The underlying ONNX Runtime C API (OrtSession::Run) is
26// documented as thread-safe and uses internal synchronization.
27unsafe impl Send for CaptchaOcr {}
28unsafe impl Sync for CaptchaOcr {}
29
30impl CaptchaOcr {
31    pub fn load(model_dir: &Path) -> Result<Self> {
32        let onnx_path = model_dir.join("common.onnx");
33
34        if !onnx_path.exists() {
35            bail!(
36                "ONNX model not found at {}. \
37                 Download from https://github.com/sml2h3/ddddocr/blob/master/ddddocr/common.onnx",
38                onnx_path.display()
39            );
40        }
41
42        let session = Session::builder()
43            .context("failed to create ONNX session builder")?
44            .commit_from_file(&onnx_path)
45            .with_context(|| format!("failed to load ONNX model from {}", onnx_path.display()))?;
46
47        info!(model_path = %onnx_path.display(), "ONNX 加载成功");
48
49        Ok(Self {
50            session: UnsafeCell::new(session),
51        })
52    }
53
54    pub fn recognize(&self, image_bytes: &[u8]) -> Result<String> {
55        let img = ImageReader::new(Cursor::new(image_bytes))
56            .with_guessed_format()
57            .context("failed to guess image format")?
58            .decode()
59            .context("failed to decode captcha image")?;
60
61        // Resize to 64px height, maintaining aspect ratio
62        let target_height = 64u32;
63        let scale = f64::from(target_height) / f64::from(img.height());
64        let target_width = (f64::from(img.width()) * scale).round() as u32;
65        let resized = img.resize_exact(target_width, target_height, FilterType::Lanczos3);
66
67        // Convert to grayscale
68        let gray = resized.to_luma8();
69
70        // Build input tensor: [1, 1, 64, width], normalized
71        let width = gray.width() as usize;
72        let height = gray.height() as usize;
73        let mut data = Vec::with_capacity(height * width);
74        for y in 0..height {
75            for x in 0..width {
76                let pixel = f32::from(gray.get_pixel(x as u32, y as u32).0[0]);
77                data.push((pixel / 255.0 - 0.5) / 0.5);
78            }
79        }
80        let input = Tensor::from_array(([1usize, 1, height, width], data.into_boxed_slice()))?;
81
82        // SAFETY: see unsafe impl Sync above - ONNX Runtime handles concurrency
83        // internally.
84        let session = unsafe { &mut *self.session.get() };
85        let outputs = session
86            .run(ort::inputs![input])
87            .context("ONNX inference failed")?;
88
89        let (shape, raw_data) = outputs[0]
90            .try_extract_tensor::<f32>()
91            .context("failed to read output tensor")?;
92
93        let result = ctc_decode(shape, raw_data, &CHARSET);
94
95        Ok(result)
96    }
97}
98
/// CTC greedy decode: take argmax at each timestep, collapse repeats, remove
/// blanks.
///
/// `shape` describes the layout of `data`:
///
/// * `[1, T, C]` — batch-first with a single batch item,
/// * `[T, B, C]` — time-major; only batch item 0 is decoded,
/// * `[T, C]` — no batch dimension.
///
/// Class index 0 is the CTC blank. Any other index is looked up in `charset`;
/// indices outside `charset` contribute no text but still take part in repeat
/// collapsing.
///
/// The function never panics on inconsistent input: an unsupported rank or a
/// negative dimension yields an empty string, and if `data` holds fewer
/// complete timesteps than `shape` promises, decoding stops at the last
/// complete one.
fn ctc_decode(shape: &[i64], data: &[f32], charset: &[String]) -> String {
    // Resolve (timesteps, batch size, classes) from the raw dimensions.
    let (steps, batch, classes) = match *shape {
        [1, steps, classes] => (steps, 1, classes),
        [steps, batch, classes] => (steps, batch, classes),
        [steps, classes] => (steps, 1, classes),
        _ => return String::new(),
    };
    let (steps, batch, classes) = match (
        usize::try_from(steps),
        usize::try_from(batch),
        usize::try_from(classes),
    ) {
        (Ok(steps), Ok(batch), Ok(classes)) => (steps, batch, classes),
        // Negative (or unrepresentable) dimension: nothing sensible to decode.
        _ => return String::new(),
    };

    // Distance between consecutive timesteps of batch item 0. For time-major
    // tensors each timestep holds `batch` rows of `classes` scores.
    let stride = match batch.checked_mul(classes) {
        Some(stride) if stride > 0 => stride,
        _ => return String::new(),
    };

    let mut text = String::new();
    let mut previous: Option<usize> = None;

    for frame in data.chunks(stride).take(steps) {
        // The final chunk may be truncated; stop once a full row is missing.
        let scores = match frame.get(..classes) {
            Some(scores) => scores,
            None => break,
        };

        let best = scores
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map_or(0, |(idx, _)| idx);

        // Blank token: resets repeat tracking so "x, blank, x" yields "xx".
        if best == 0 {
            previous = None;
            continue;
        }

        // Collapse repeated characters
        if previous == Some(best) {
            continue;
        }
        previous = Some(best);

        // Pushing an empty symbol is a no-op, so no emptiness check is needed.
        if let Some(symbol) = charset.get(best) {
            text.push_str(symbol);
        }
    }

    text
}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned charset from string literals; index 0 is the blank.
    fn charset_of(symbols: &[&str]) -> Vec<String> {
        symbols.iter().map(|symbol| (*symbol).to_owned()).collect()
    }

    /// Repeated argmax on consecutive timesteps collapses to one character.
    #[test]
    fn test_ctc_decode_basic() {
        let charset = charset_of(&["", "a", "b", "c"]);
        #[rustfmt::skip]
        let scores: [f32; 12] = [
            0.0, 1.0, 0.0, 0.0, // t0: 'a'
            0.0, 1.0, 0.0, 0.0, // t1: 'a' (repeat, collapsed)
            0.0, 0.0, 1.0, 0.0, // t2: 'b'
        ];

        let decoded = ctc_decode(&[3, 1, 4], &scores, &charset);

        assert_eq!(decoded, "ab");
    }

    /// A blank between two identical characters keeps both of them.
    #[test]
    fn test_ctc_decode_with_blanks() {
        let charset = charset_of(&["", "x", "y"]);
        #[rustfmt::skip]
        let scores: [f32; 12] = [
            1.0, 0.0, 0.0, // t0: blank
            0.0, 1.0, 0.0, // t1: 'x'
            1.0, 0.0, 0.0, // t2: blank
            0.0, 1.0, 0.0, // t3: 'x'
        ];

        let decoded = ctc_decode(&[4, 1, 3], &scores, &charset);

        assert_eq!(decoded, "xx");
    }

    /// End-to-end check against the bundled sample image; needs the ONNX
    /// model to be present under `../../models`.
    #[test]
    fn test_recognize_sample_captcha() {
        let model_dir = Path::new("../../models");
        let ocr = CaptchaOcr::load(model_dir).unwrap();

        let decoded = ocr
            .recognize(include_bytes!("test_captcha.bmp"))
            .unwrap();

        assert_eq!(decoded, "48115");
    }
}