use std::cell::UnsafeCell;
use std::io::Cursor;
use std::path::Path;
use std::sync::{LazyLock, Mutex};
use anyhow::{Context, Result, bail};
use image::ImageReader;
use image::imageops::FilterType;
use ort::session::Session;
use ort::value::Tensor;
use tracing::info;
/// Output-class labels for the OCR model, indexed by class id (class `0` is treated as the CTC
/// blank by `ctc_decode`).
///
/// The JSON array is embedded at compile time and parsed lazily on first access.
///
/// # Panics
/// The first access panics if the embedded `charsets.json` is not a JSON array of strings.
static CHARSET: LazyLock<Vec<String>> = LazyLock::new(|| {
    let embedded = include_str!("./charsets.json");
    serde_json::from_str::<Vec<String>>(embedded).expect("bundled charsets.json is invalid")
});
pub struct CaptchaOcr {
session: UnsafeCell<Session>,
}
unsafe impl Send for CaptchaOcr {}
unsafe impl Sync for CaptchaOcr {}
impl CaptchaOcr {
pub fn load(model_dir: &Path) -> Result<Self> {
let onnx_path = model_dir.join("common.onnx");
if !onnx_path.exists() {
bail!(
"ONNX model not found at {}. \
Download from https://github.com/sml2h3/ddddocr/blob/master/ddddocr/common.onnx",
onnx_path.display()
);
}
let session = Session::builder()
.context("failed to create ONNX session builder")?
.commit_from_file(&onnx_path)
.with_context(|| format!("failed to load ONNX model from {}", onnx_path.display()))?;
info!(model_path = %onnx_path.display(), "ONNX 加载成功");
Ok(Self {
session: UnsafeCell::new(session),
})
}
pub fn recognize(&self, image_bytes: &[u8]) -> Result<String> {
let img = ImageReader::new(Cursor::new(image_bytes))
.with_guessed_format()
.context("failed to guess image format")?
.decode()
.context("failed to decode captcha image")?;
let target_height = 64u32;
let scale = f64::from(target_height) / f64::from(img.height());
let target_width = (f64::from(img.width()) * scale).round() as u32;
let resized = img.resize_exact(target_width, target_height, FilterType::Lanczos3);
let gray = resized.to_luma8();
let width = gray.width() as usize;
let height = gray.height() as usize;
let mut data = Vec::with_capacity(height * width);
for y in 0..height {
for x in 0..width {
let pixel = f32::from(gray.get_pixel(x as u32, y as u32).0[0]);
data.push((pixel / 255.0 - 0.5) / 0.5);
}
}
let input = Tensor::from_array(([1usize, 1, height, width], data.into_boxed_slice()))?;
let session = unsafe { &mut *self.session.get() };
let outputs = session
.run(ort::inputs![input])
.context("ONNX inference failed")?;
let (shape, raw_data) = outputs[0]
.try_extract_tensor::<f32>()
.context("failed to read output tensor")?;
let result = ctc_decode(shape, raw_data, &CHARSET);
Ok(result)
}
}
/// Greedy (best-path) CTC decoding of one sequence of per-timestep class scores.
///
/// `shape` is the model output's shape and `data` its row-major contents. Accepted layouts:
/// - `[1, T, C]`: batch-first with a batch of one;
/// - `[T, B, C]`: time-first; only batch element 0 is decoded;
/// - `[T, C]`.
///
/// At each timestep the highest-scoring class is chosen. Class `0` is the CTC blank and emits
/// nothing. Consecutive repeats of the same class collapse into one label unless a blank
/// separates them. Class ids outside `charset` emit nothing.
///
/// Returns an empty string for unsupported ranks, negative dimensions, or zero classes. If
/// `data` is shorter than `shape` implies, decoding stops after the last complete timestep
/// instead of panicking.
fn ctc_decode(shape: &[i64], data: &[f32], charset: &[String]) -> String {
    // Reject negative (e.g. symbolic -1) dimensions instead of wrapping them to huge `usize`s.
    let Some(dims) = shape
        .iter()
        .map(|&d| usize::try_from(d).ok())
        .collect::<Option<Vec<usize>>>()
    else {
        return String::new();
    };

    // (timesteps, classes, distance in `data` between the starts of consecutive timesteps)
    let (seq_len, num_classes, stride) = match *dims.as_slice() {
        [1, t, c] => (t, c, c),
        [t, b, c] => (t, c, b.saturating_mul(c)),
        [t, c] => (t, c, c),
        _ => return String::new(),
    };
    if num_classes == 0 || stride == 0 {
        return String::new();
    }

    let mut decoded = String::new();
    let mut prev: Option<usize> = None;
    for t in 0..seq_len {
        // Stop at the first timestep whose scores are not fully present in `data`.
        let Some(frame) = t
            .checked_mul(stride)
            .and_then(|start| start.checked_add(num_classes).map(|end| start..end))
            .and_then(|range| data.get(range))
        else {
            break;
        };

        // Highest score wins; on ties `max_by` keeps the last index, and NaN compares equal.
        let best = frame
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map_or(0, |(i, _)| i);

        if best == 0 {
            // Blank: emit nothing, and let the next non-blank class repeat the previous one.
            prev = None;
        } else if prev != Some(best) {
            prev = Some(best);
            if let Some(label) = charset.get(best) {
                decoded.push_str(label);
            }
        }
    }
    decoded
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `[T, B=1, C]` output: a repeated class with no blank in between collapses to one label.
    #[test]
    fn test_ctc_decode_basic() {
        let charset: Vec<String> = vec![String::new(), "a".into(), "b".into(), "c".into()];
        let data: Vec<f32> = vec![
            0.0, 1.0, 0.0, 0.0, // t0 -> "a"
            0.0, 1.0, 0.0, 0.0, // t1 -> "a" (repeat, collapsed)
            0.0, 0.0, 1.0, 0.0, // t2 -> "b"
        ];
        let shape: &[i64] = &[3, 1, 4];
        let result = ctc_decode(shape, &data, &charset);
        assert_eq!(result, "ab");
    }

    /// A blank (class 0) between two identical classes keeps both labels.
    #[test]
    fn test_ctc_decode_with_blanks() {
        let charset: Vec<String> = vec![String::new(), "x".into(), "y".into()];
        let data: Vec<f32> = vec![
            1.0, 0.0, 0.0, // t0 -> blank
            0.0, 1.0, 0.0, // t1 -> "x"
            1.0, 0.0, 0.0, // t2 -> blank
            0.0, 1.0, 0.0, // t3 -> "x"
        ];
        let shape: &[i64] = &[4, 1, 3];
        let result = ctc_decode(shape, &data, &charset);
        assert_eq!(result, "xx");
    }

    /// `recognize` takes `&self`, so a loaded model is meant to be shared across threads.
    #[test]
    fn test_captcha_ocr_is_send_and_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<CaptchaOcr>();
    }

    /// End-to-end check against a sample image. Requires `common.onnx` in the `models`
    /// directory two levels above this crate's manifest directory. The path is anchored to
    /// `CARGO_MANIFEST_DIR` so it does not depend on the process working directory.
    #[test]
    fn test_recognize_sample_captcha() {
        let model_dir = Path::new(env!("CARGO_MANIFEST_DIR")).join("../../models");
        let ocr = CaptchaOcr::load(&model_dir).unwrap();
        let image_bytes = include_bytes!("test_captcha.bmp");
        let result = ocr.recognize(image_bytes).unwrap();
        assert_eq!(result, "48115");
    }
}