1use image::DynamicImage;
6use std::path::Path;
7
8use crate::det::{DetModel, DetOptions};
9use crate::error::OcrResult;
10use crate::mnn::{Backend, InferenceConfig, PrecisionMode};
11use crate::postprocess::TextBox;
12use crate::rec::{RecModel, RecOptions, RecognitionResult};
13
/// One recognized text region: the text read from it, the recognizer's
/// confidence, and the detection box it came from.
#[derive(Debug, Clone)]
pub struct OcrResult_ {
    /// Text produced by the recognition model for this region.
    pub text: String,
    /// Confidence reported by the recognition model for `text`
    /// (scale defined by `RecModel` — TODO confirm range).
    pub confidence: f32,
    /// Region box as produced by the detection model.
    pub bbox: TextBox,
}
24
25impl OcrResult_ {
26 pub fn new(text: String, confidence: f32, bbox: TextBox) -> Self {
28 Self {
29 text,
30 confidence,
31 bbox,
32 }
33 }
34}
35
/// Settings for building and running an [`OcrEngine`].
///
/// Use [`OcrEngineConfig::default`] or one of the presets (`fast`, `gpu`) and
/// refine with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct OcrEngineConfig {
    /// Inference backend forwarded to both the detection and recognition models.
    pub backend: Backend,
    /// Thread count forwarded to the inference configuration of both models.
    pub thread_count: i32,
    /// Precision mode forwarded to the inference configuration of both models.
    pub precision_mode: PrecisionMode,
    /// Options applied to the detection model via `with_options`.
    pub det_options: DetOptions,
    /// Options applied to the recognition model via `with_options`.
    pub rec_options: RecOptions,
    /// When `true` and more than 4 regions are detected, `OcrEngine::recognize`
    /// recognizes the crops in parallel with rayon instead of one batched call.
    pub enable_parallel: bool,
    /// `OcrEngine::recognize` drops results whose confidence is below this value
    /// (results with empty text are dropped regardless).
    pub min_result_confidence: f32,
}
54
55impl Default for OcrEngineConfig {
56 fn default() -> Self {
57 Self {
58 backend: Backend::CPU,
59 thread_count: 4,
60 precision_mode: PrecisionMode::Normal,
61 det_options: DetOptions::default(),
62 rec_options: RecOptions::default(),
63 enable_parallel: true,
64 min_result_confidence: 0.5,
65 }
66 }
67}
68
69impl OcrEngineConfig {
70 pub fn new() -> Self {
72 Self::default()
73 }
74
75 pub fn with_backend(mut self, backend: Backend) -> Self {
77 self.backend = backend;
78 self
79 }
80
81 pub fn with_threads(mut self, threads: i32) -> Self {
83 self.thread_count = threads;
84 self
85 }
86
87 pub fn with_precision(mut self, precision: PrecisionMode) -> Self {
89 self.precision_mode = precision;
90 self
91 }
92
93 pub fn with_det_options(mut self, options: DetOptions) -> Self {
95 self.det_options = options;
96 self
97 }
98
99 pub fn with_rec_options(mut self, options: RecOptions) -> Self {
101 self.rec_options = options;
102 self
103 }
104
105 pub fn with_parallel(mut self, enable: bool) -> Self {
110 self.enable_parallel = enable;
111 self
112 }
113
114 pub fn with_min_result_confidence(mut self, threshold: f32) -> Self {
119 self.min_result_confidence = threshold;
120 self
121 }
122
123 pub fn fast() -> Self {
125 Self {
126 precision_mode: PrecisionMode::Low,
127 det_options: DetOptions::fast(),
128 ..Default::default()
129 }
130 }
131
132 #[cfg(any(target_os = "macos", target_os = "ios"))]
134 pub fn gpu() -> Self {
135 Self {
136 backend: Backend::Metal,
137 ..Default::default()
138 }
139 }
140
141 #[cfg(not(any(target_os = "macos", target_os = "ios")))]
143 pub fn gpu() -> Self {
144 Self {
145 backend: Backend::OpenCL,
146 ..Default::default()
147 }
148 }
149
150 fn to_inference_config(&self) -> InferenceConfig {
151 InferenceConfig {
152 thread_count: self.thread_count,
153 precision_mode: self.precision_mode,
154 backend: self.backend,
155 ..Default::default()
156 }
157 }
158}
159
/// End-to-end OCR pipeline: a text-detection model followed by a
/// text-recognition model.
///
/// Build with [`OcrEngine::new`] (model files) or [`OcrEngine::from_bytes`]
/// (in-memory models), then call [`OcrEngine::recognize`] on an image.
pub struct OcrEngine {
    /// Finds text regions and crops them out of the input image.
    det_model: DetModel,
    /// Reads the text in each cropped region.
    rec_model: RecModel,
    /// Configuration the engine was built with; also read during `recognize`
    /// (parallelism switch and minimum result confidence).
    config: OcrEngineConfig,
}
190
191impl OcrEngine {
192 pub fn new(
200 det_model_path: impl AsRef<Path>,
201 rec_model_path: impl AsRef<Path>,
202 charset_path: impl AsRef<Path>,
203 config: Option<OcrEngineConfig>,
204 ) -> OcrResult<Self> {
205 let config = config.unwrap_or_default();
206 let inference_config = config.to_inference_config();
207
208 let det_options = config.det_options.clone();
210 let rec_options = config.rec_options.clone();
211
212 let det_model = DetModel::from_file(det_model_path, Some(inference_config.clone()))?
213 .with_options(det_options);
214
215 let rec_model = RecModel::from_file(rec_model_path, charset_path, Some(inference_config))?
216 .with_options(rec_options);
217
218 Ok(Self {
219 det_model,
220 rec_model,
221 config,
222 })
223 }
224
225 pub fn from_bytes(
227 det_model_bytes: &[u8],
228 rec_model_bytes: &[u8],
229 charset_bytes: &[u8],
230 config: Option<OcrEngineConfig>,
231 ) -> OcrResult<Self> {
232 let config = config.unwrap_or_default();
233 let inference_config = config.to_inference_config();
234
235 let det_options = config.det_options.clone();
237 let rec_options = config.rec_options.clone();
238
239 let det_model = DetModel::from_bytes(det_model_bytes, Some(inference_config.clone()))?
240 .with_options(det_options);
241
242 let rec_model = RecModel::from_bytes_with_charset(
243 rec_model_bytes,
244 charset_bytes,
245 Some(inference_config),
246 )?
247 .with_options(rec_options);
248
249 Ok(Self {
250 det_model,
251 rec_model,
252 config,
253 })
254 }
255
256 pub fn det_only(
258 det_model_path: impl AsRef<Path>,
259 config: Option<OcrEngineConfig>,
260 ) -> OcrResult<DetOnlyEngine> {
261 let config = config.unwrap_or_default();
262 let inference_config = config.to_inference_config();
263
264 let det_model = DetModel::from_file(det_model_path, Some(inference_config))?
265 .with_options(config.det_options);
266
267 Ok(DetOnlyEngine { det_model })
268 }
269
270 pub fn rec_only(
272 rec_model_path: impl AsRef<Path>,
273 charset_path: impl AsRef<Path>,
274 config: Option<OcrEngineConfig>,
275 ) -> OcrResult<RecOnlyEngine> {
276 let config = config.unwrap_or_default();
277 let inference_config = config.to_inference_config();
278
279 let rec_model = RecModel::from_file(rec_model_path, charset_path, Some(inference_config))?
280 .with_options(config.rec_options);
281
282 Ok(RecOnlyEngine { rec_model })
283 }
284
285 pub fn recognize(&self, image: &DynamicImage) -> OcrResult<Vec<OcrResult_>> {
293 let detections = self.det_model.detect_and_crop(image)?;
295
296 if detections.is_empty() {
297 return Ok(Vec::new());
298 }
299
300 let (images, boxes): (Vec<&DynamicImage>, Vec<TextBox>) = detections
302 .iter()
303 .map(|(img, bbox)| (img, bbox.clone()))
304 .unzip();
305
306 let rec_results = if self.config.enable_parallel && images.len() > 4 {
307 use rayon::prelude::*;
309 images
310 .par_iter()
311 .map(|img| self.rec_model.recognize(img))
312 .collect::<OcrResult<Vec<_>>>()?
313 } else {
314 self.rec_model.recognize_batch_ref(&images)?
316 };
317
318 let results: Vec<OcrResult_> = rec_results
320 .into_iter()
321 .zip(boxes)
322 .filter(|(rec, _)| {
323 !rec.text.is_empty() && rec.confidence >= self.config.min_result_confidence
324 })
325 .map(|(rec, bbox)| OcrResult_::new(rec.text, rec.confidence, bbox))
326 .collect();
327
328 Ok(results)
329 }
330
331 pub fn detect(&self, image: &DynamicImage) -> OcrResult<Vec<TextBox>> {
333 self.det_model.detect(image)
334 }
335
336 pub fn recognize_text(&self, image: &DynamicImage) -> OcrResult<RecognitionResult> {
338 self.rec_model.recognize(image)
339 }
340
341 pub fn recognize_batch(&self, images: &[DynamicImage]) -> OcrResult<Vec<RecognitionResult>> {
343 self.rec_model.recognize_batch(images)
344 }
345
346 pub fn det_model(&self) -> &DetModel {
348 &self.det_model
349 }
350
351 pub fn rec_model(&self) -> &RecModel {
353 &self.rec_model
354 }
355
356 pub fn config(&self) -> &OcrEngineConfig {
358 &self.config
359 }
360}
361
362pub struct DetOnlyEngine {
364 det_model: DetModel,
365}
366
367impl DetOnlyEngine {
368 pub fn detect(&self, image: &DynamicImage) -> OcrResult<Vec<TextBox>> {
370 self.det_model.detect(image)
371 }
372
373 pub fn detect_and_crop(&self, image: &DynamicImage) -> OcrResult<Vec<(DynamicImage, TextBox)>> {
375 self.det_model.detect_and_crop(image)
376 }
377
378 pub fn model(&self) -> &DetModel {
380 &self.det_model
381 }
382}
383
384pub struct RecOnlyEngine {
386 rec_model: RecModel,
387}
388
389impl RecOnlyEngine {
390 pub fn recognize(&self, image: &DynamicImage) -> OcrResult<RecognitionResult> {
392 self.rec_model.recognize(image)
393 }
394
395 pub fn recognize_text(&self, image: &DynamicImage) -> OcrResult<String> {
397 self.rec_model.recognize_text(image)
398 }
399
400 pub fn recognize_batch(&self, images: &[DynamicImage]) -> OcrResult<Vec<RecognitionResult>> {
402 self.rec_model.recognize_batch(images)
403 }
404
405 pub fn model(&self) -> &RecModel {
407 &self.rec_model
408 }
409}
410
411pub fn ocr_file(
424 image_path: impl AsRef<Path>,
425 det_model_path: impl AsRef<Path>,
426 rec_model_path: impl AsRef<Path>,
427 charset_path: impl AsRef<Path>,
428) -> OcrResult<Vec<OcrResult_>> {
429 let image = image::open(image_path)?;
430 let engine = OcrEngine::new(det_model_path, rec_model_path, charset_path, None)?;
431 engine.recognize(&image)
432}
433
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_engine_config() {
        let config = OcrEngineConfig::default();
        assert_eq!(config.thread_count, 4);
        assert_eq!(config.backend, Backend::CPU);
        assert_eq!(config.precision_mode, PrecisionMode::Normal);
        assert!(config.enable_parallel);
        assert_eq!(config.min_result_confidence, 0.5);

        let config = OcrEngineConfig::fast();
        assert_eq!(config.precision_mode, PrecisionMode::Low);
        // The fast preset must not disturb unrelated defaults.
        assert_eq!(config.backend, Backend::CPU);
        assert_eq!(config.thread_count, 4);
        assert!(config.enable_parallel);
    }

    #[test]
    fn test_engine_config_builder() {
        let config = OcrEngineConfig::new()
            .with_threads(2)
            .with_precision(PrecisionMode::Low)
            .with_parallel(false)
            .with_min_result_confidence(0.75);

        assert_eq!(config.thread_count, 2);
        assert_eq!(config.precision_mode, PrecisionMode::Low);
        assert!(!config.enable_parallel);
        assert_eq!(config.min_result_confidence, 0.75);
        // Fields that were not touched keep their defaults.
        assert_eq!(config.backend, Backend::CPU);
    }

    #[test]
    fn test_gpu_config_backend() {
        let config = OcrEngineConfig::gpu();

        #[cfg(any(target_os = "macos", target_os = "ios"))]
        let expected = Backend::Metal;
        #[cfg(not(any(target_os = "macos", target_os = "ios")))]
        let expected = Backend::OpenCL;

        assert_eq!(config.backend, expected);
        assert_eq!(config.thread_count, 4);
    }

    #[test]
    fn test_ocr_result() {
        let bbox = TextBox::new(imageproc::rect::Rect::at(0, 0).of_size(100, 20), 0.9);
        let result = OcrResult_::new("Hello".to_string(), 0.95, bbox);

        assert_eq!(result.text, "Hello");
        assert_eq!(result.confidence, 0.95);
    }
}