1use std::cell::UnsafeCell;
2use std::io::Cursor;
3use std::path::Path;
4use std::sync::LazyLock;
5
6use anyhow::{Context, Result, bail};
7use image::ImageReader;
8use image::imageops::FilterType;
9use ort::session::Session;
10use ort::value::Tensor;
11use tracing::info;
12
13static CHARSET: LazyLock<Vec<String>> = LazyLock::new(|| {
14 serde_json::from_str(include_str!("./charsets.json")).expect("bundled charsets.json is invalid")
15});
16
17pub struct CaptchaOcr {
18 session: UnsafeCell<Session>,
22}
23
24unsafe impl Send for CaptchaOcr {}
28unsafe impl Sync for CaptchaOcr {}
29
30impl CaptchaOcr {
31 pub fn load(model_dir: &Path) -> Result<Self> {
32 let onnx_path = model_dir.join("common.onnx");
33
34 if !onnx_path.exists() {
35 bail!(
36 "ONNX model not found at {}. \
37 Download from https://github.com/sml2h3/ddddocr/blob/master/ddddocr/common.onnx",
38 onnx_path.display()
39 );
40 }
41
42 let session = Session::builder()
43 .context("failed to create ONNX session builder")?
44 .commit_from_file(&onnx_path)
45 .with_context(|| format!("failed to load ONNX model from {}", onnx_path.display()))?;
46
47 info!(model_path = %onnx_path.display(), "ONNX 加载成功");
48
49 Ok(Self {
50 session: UnsafeCell::new(session),
51 })
52 }
53
54 pub fn recognize(&self, image_bytes: &[u8]) -> Result<String> {
55 let img = ImageReader::new(Cursor::new(image_bytes))
56 .with_guessed_format()
57 .context("failed to guess image format")?
58 .decode()
59 .context("failed to decode captcha image")?;
60
61 let target_height = 64u32;
63 let scale = f64::from(target_height) / f64::from(img.height());
64 let target_width = (f64::from(img.width()) * scale).round() as u32;
65 let resized = img.resize_exact(target_width, target_height, FilterType::Lanczos3);
66
67 let gray = resized.to_luma8();
69
70 let width = gray.width() as usize;
72 let height = gray.height() as usize;
73 let mut data = Vec::with_capacity(height * width);
74 for y in 0..height {
75 for x in 0..width {
76 let pixel = f32::from(gray.get_pixel(x as u32, y as u32).0[0]);
77 data.push((pixel / 255.0 - 0.5) / 0.5);
78 }
79 }
80 let input = Tensor::from_array(([1usize, 1, height, width], data.into_boxed_slice()))?;
81
82 let session = unsafe { &mut *self.session.get() };
85 let outputs = session
86 .run(ort::inputs![input])
87 .context("ONNX inference failed")?;
88
89 let (shape, raw_data) = outputs[0]
90 .try_extract_tensor::<f32>()
91 .context("failed to read output tensor")?;
92
93 let result = ctc_decode(shape, raw_data, &CHARSET);
94
95 Ok(result)
96 }
97}
98
/// Greedy CTC decoding of a score tensor into text.
///
/// For every time step the highest-scoring class is taken. Class `0` is the blank symbol;
/// consecutive repeats of the same class are collapsed unless a blank separates them. The
/// remaining class indices are looked up in `charset`; indices outside `charset` are skipped.
///
/// Accepted shapes:
/// * `[1, seq_len, num_classes]`
/// * `[seq_len, batch, num_classes]` — only batch item 0 is decoded
/// * `[seq_len, num_classes]`
///
/// Any other rank, a negative dimension, or a zero-sized frame yields an empty string. If
/// `data` holds fewer values than the shape implies, only the complete frames that are
/// actually present are decoded instead of panicking.
fn ctc_decode(shape: &[i64], data: &[f32], charset: &[String]) -> String {
    let (seq_len, batch, num_classes) = match *shape {
        [1, seq, classes] => (seq, 1, classes),
        [seq, batch, classes] => (seq, batch, classes),
        [seq, classes] => (seq, 1, classes),
        _ => return String::new(),
    };
    let (Ok(seq_len), Ok(batch), Ok(num_classes)) = (
        usize::try_from(seq_len),
        usize::try_from(batch),
        usize::try_from(num_classes),
    ) else {
        return String::new();
    };
    // Distance between two consecutive time steps of batch item 0.
    let Some(stride) = batch.checked_mul(num_classes).filter(|&s| s != 0) else {
        return String::new();
    };

    let mut decoded = String::new();
    let mut previous: Option<usize> = None;

    for frame in data.chunks(stride).take(seq_len) {
        // The scores of batch item 0 come first within each time step; a frame that is cut
        // short marks the end of the usable data.
        let Some(scores) = frame.get(..num_classes) else {
            break;
        };

        let best = scores
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map_or(0, |(idx, _)| idx);

        if best == 0 {
            // Blank: emits nothing and allows the next symbol to repeat the previous one.
            previous = None;
            continue;
        }
        if previous == Some(best) {
            // Same symbol stretched over several time steps.
            continue;
        }
        previous = Some(best);

        if let Some(ch) = charset.get(best) {
            decoded.push_str(ch);
        }
    }

    decoded
}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned charset from literals; index 0 plays the role of the CTC blank.
    fn charset_of(symbols: &[&str]) -> Vec<String> {
        symbols.iter().map(|symbol| (*symbol).to_owned()).collect()
    }

    /// Two identical consecutive frames collapse into one character.
    #[test]
    fn test_ctc_decode_basic() {
        let charset = charset_of(&["", "a", "b", "c"]);
        // Three frames of four class scores each: 'a', 'a', 'b'.
        let scores: [f32; 12] = [
            0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0,
        ];

        let decoded = ctc_decode(&[3, 1, 4], &scores, &charset);

        assert_eq!(decoded, "ab");
    }

    /// A blank frame between two identical characters keeps both of them.
    #[test]
    fn test_ctc_decode_with_blanks() {
        let charset = charset_of(&["", "x", "y"]);
        // Four frames of three class scores each: blank, 'x', blank, 'x'.
        let scores: [f32; 12] = [
            1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0,
        ];

        let decoded = ctc_decode(&[4, 1, 3], &scores, &charset);

        assert_eq!(decoded, "xx");
    }

    /// End-to-end check; needs the ONNX model under `../../models` at test run time.
    #[test]
    fn test_recognize_sample_captcha() {
        let sample: &[u8] = include_bytes!("test_captcha.bmp");
        let recognizer = CaptchaOcr::load(Path::new("../../models")).unwrap();

        let text = recognizer.recognize(sample).unwrap();

        assert_eq!(text, "48115");
    }
}
190}