1use image::DynamicImage;
6use std::path::{Path, PathBuf};
7
8use crate::det::{DetModel, DetOptions};
9use crate::error::{OcrError, OcrResult};
10use crate::mnn::{Backend, InferenceConfig, PrecisionMode};
11use crate::ori::{OriModel, OriOptions};
12use crate::postprocess::TextBox;
13use crate::rec::{RecModel, RecOptions, RecognitionResult};
14
15#[derive(Debug, Clone)]
17pub struct OcrResult_ {
18 pub text: String,
20 pub confidence: f32,
22 pub bbox: TextBox,
24}
25
26impl OcrResult_ {
27 pub fn new(text: String, confidence: f32, bbox: TextBox) -> Self {
29 Self {
30 text,
31 confidence,
32 bbox,
33 }
34 }
35}
36
37#[derive(Debug, Clone)]
39pub struct OcrEngineConfig {
40 pub backend: Backend,
42 pub thread_count: i32,
44 pub precision_mode: PrecisionMode,
46 pub det_options: DetOptions,
48 pub rec_options: RecOptions,
50 pub ori_options: OriOptions,
52 pub enable_parallel: bool,
54 pub min_result_confidence: f32,
56 pub ori_min_confidence: f32,
58}
59
60impl Default for OcrEngineConfig {
61 fn default() -> Self {
62 Self {
63 backend: Backend::CPU,
64 thread_count: 4,
65 precision_mode: PrecisionMode::Normal,
66 det_options: DetOptions::default(),
67 rec_options: RecOptions::default(),
68 ori_options: OriOptions::default(),
69 enable_parallel: true,
70 min_result_confidence: 0.5,
71 ori_min_confidence: 0.3,
72 }
73 }
74}
75
76impl OcrEngineConfig {
77 pub fn new() -> Self {
79 Self::default()
80 }
81
82 pub fn with_backend(mut self, backend: Backend) -> Self {
84 self.backend = backend;
85 self
86 }
87
88 pub fn with_threads(mut self, threads: i32) -> Self {
90 self.thread_count = threads;
91 self
92 }
93
94 pub fn with_precision(mut self, precision: PrecisionMode) -> Self {
96 self.precision_mode = precision;
97 self
98 }
99
100 pub fn with_det_options(mut self, options: DetOptions) -> Self {
102 self.det_options = options;
103 self
104 }
105
106 pub fn with_rec_options(mut self, options: RecOptions) -> Self {
108 self.rec_options = options;
109 self
110 }
111
112 pub fn with_ori_options(mut self, options: OriOptions) -> Self {
114 self.ori_options = options;
115 self
116 }
117
118 pub fn with_parallel(mut self, enable: bool) -> Self {
123 self.enable_parallel = enable;
124 self
125 }
126
127 pub fn with_min_result_confidence(mut self, threshold: f32) -> Self {
132 self.min_result_confidence = threshold;
133 self
134 }
135
136 pub fn with_ori_min_confidence(mut self, threshold: f32) -> Self {
138 self.ori_min_confidence = threshold;
139 self
140 }
141
142 pub fn fast() -> Self {
144 Self {
145 precision_mode: PrecisionMode::Low,
146 det_options: DetOptions::fast(),
147 ..Default::default()
148 }
149 }
150
151 #[cfg(any(target_os = "macos", target_os = "ios"))]
153 pub fn gpu() -> Self {
154 Self {
155 backend: Backend::Metal,
156 ..Default::default()
157 }
158 }
159
160 #[cfg(not(any(target_os = "macos", target_os = "ios")))]
162 pub fn gpu() -> Self {
163 Self {
164 backend: Backend::OpenCL,
165 ..Default::default()
166 }
167 }
168
169 fn to_inference_config(&self) -> InferenceConfig {
170 InferenceConfig {
171 thread_count: self.thread_count,
172 precision_mode: self.precision_mode,
173 backend: self.backend,
174 ..Default::default()
175 }
176 }
177}
178
179pub struct OcrEngine {
205 det_model: DetModel,
206 rec_model: RecModel,
207 ori_model: Option<OriModel>,
208 config: OcrEngineConfig,
209}
210
211impl OcrEngine {
212 fn build_with_paths(
213 det_model_path: &Path,
214 rec_model_path: &Path,
215 charset_path: &Path,
216 ori_model_path: Option<&Path>,
217 config: Option<OcrEngineConfig>,
218 ) -> OcrResult<Self> {
219 let config = config.unwrap_or_default();
220 let inference_config = config.to_inference_config();
221
222 let det_options = config.det_options.clone();
224 let rec_options = config.rec_options.clone();
225 let ori_options = config.ori_options.clone();
226
227 let det_model = DetModel::from_file(det_model_path, Some(inference_config.clone()))?
228 .with_options(det_options);
229
230 let rec_model =
231 RecModel::from_file(rec_model_path, charset_path, Some(inference_config.clone()))?
232 .with_options(rec_options);
233
234 let ori_model = match ori_model_path {
235 Some(path) => {
236 Some(OriModel::from_file(path, Some(inference_config))?.with_options(ori_options))
237 }
238 None => None,
239 };
240
241 Ok(Self {
242 det_model,
243 rec_model,
244 ori_model,
245 config,
246 })
247 }
248
249 pub fn new(
257 det_model_path: impl AsRef<Path>,
258 rec_model_path: impl AsRef<Path>,
259 charset_path: impl AsRef<Path>,
260 config: Option<OcrEngineConfig>,
261 ) -> OcrResult<Self> {
262 Self::build_with_paths(
263 det_model_path.as_ref(),
264 rec_model_path.as_ref(),
265 charset_path.as_ref(),
266 None,
267 config,
268 )
269 }
270
271 pub fn new_with_ori(
273 det_model_path: impl AsRef<Path>,
274 rec_model_path: impl AsRef<Path>,
275 charset_path: impl AsRef<Path>,
276 ori_model_path: impl AsRef<Path>,
277 config: Option<OcrEngineConfig>,
278 ) -> OcrResult<Self> {
279 Self::build_with_paths(
280 det_model_path.as_ref(),
281 rec_model_path.as_ref(),
282 charset_path.as_ref(),
283 Some(ori_model_path.as_ref()),
284 config,
285 )
286 }
287
288 pub fn from_bytes(
290 det_model_bytes: &[u8],
291 rec_model_bytes: &[u8],
292 charset_bytes: &[u8],
293 config: Option<OcrEngineConfig>,
294 ) -> OcrResult<Self> {
295 let config = config.unwrap_or_default();
296 let inference_config = config.to_inference_config();
297
298 let det_options = config.det_options.clone();
300 let rec_options = config.rec_options.clone();
301
302 let det_model = DetModel::from_bytes(det_model_bytes, Some(inference_config.clone()))?
303 .with_options(det_options);
304
305 let rec_model = RecModel::from_bytes_with_charset(
306 rec_model_bytes,
307 charset_bytes,
308 Some(inference_config.clone()),
309 )?
310 .with_options(rec_options);
311
312 Ok(Self {
313 det_model,
314 rec_model,
315 ori_model: None,
316 config,
317 })
318 }
319
320 pub fn from_bytes_with_ori(
322 det_model_bytes: &[u8],
323 rec_model_bytes: &[u8],
324 charset_bytes: &[u8],
325 ori_model_bytes: &[u8],
326 config: Option<OcrEngineConfig>,
327 ) -> OcrResult<Self> {
328 let config = config.unwrap_or_default();
329 let inference_config = config.to_inference_config();
330
331 let det_options = config.det_options.clone();
332 let rec_options = config.rec_options.clone();
333 let ori_options = config.ori_options.clone();
334
335 let det_model = DetModel::from_bytes(det_model_bytes, Some(inference_config.clone()))?
336 .with_options(det_options);
337
338 let rec_model = RecModel::from_bytes_with_charset(
339 rec_model_bytes,
340 charset_bytes,
341 Some(inference_config.clone()),
342 )?
343 .with_options(rec_options);
344
345 let ori_model = OriModel::from_bytes(ori_model_bytes, Some(inference_config))?
346 .with_options(ori_options);
347
348 Ok(Self {
349 det_model,
350 rec_model,
351 ori_model: Some(ori_model),
352 config,
353 })
354 }
355
356 pub fn det_only(
358 det_model_path: impl AsRef<Path>,
359 config: Option<OcrEngineConfig>,
360 ) -> OcrResult<DetOnlyEngine> {
361 let config = config.unwrap_or_default();
362 let inference_config = config.to_inference_config();
363
364 let det_model = DetModel::from_file(det_model_path, Some(inference_config))?
365 .with_options(config.det_options);
366
367 Ok(DetOnlyEngine { det_model })
368 }
369
370 pub fn rec_only(
372 rec_model_path: impl AsRef<Path>,
373 charset_path: impl AsRef<Path>,
374 config: Option<OcrEngineConfig>,
375 ) -> OcrResult<RecOnlyEngine> {
376 let config = config.unwrap_or_default();
377 let inference_config = config.to_inference_config();
378
379 let rec_model = RecModel::from_file(rec_model_path, charset_path, Some(inference_config))?
380 .with_options(config.rec_options);
381
382 Ok(RecOnlyEngine { rec_model })
383 }
384
385 pub fn recognize(&self, image: &DynamicImage) -> OcrResult<Vec<OcrResult_>> {
393 let corrected_image = if let Some(ori_model) = self.ori_model.as_ref() {
395 self.correct_orientation_with_model(ori_model, image.clone())
396 } else {
397 image.clone()
398 };
399
400 let detections = self.det_model.detect_and_crop(&corrected_image)?;
402
403 if detections.is_empty() {
404 return Ok(Vec::new());
405 }
406
407 let (mut images, boxes): (Vec<DynamicImage>, Vec<TextBox>) = detections.into_iter().unzip();
409
410 let rec_results = if self.config.enable_parallel && images.len() > 4 {
411 use rayon::prelude::*;
413 images
414 .par_iter()
415 .map(|img| self.rec_model.recognize(img))
416 .collect::<OcrResult<Vec<_>>>()?
417 } else {
418 self.rec_model.recognize_batch(&images)?
420 };
421
422 let results: Vec<OcrResult_> = rec_results
424 .into_iter()
425 .zip(boxes)
426 .filter(|(rec, _)| {
427 !rec.text.is_empty() && rec.confidence >= self.config.min_result_confidence
428 })
429 .map(|(rec, bbox)| OcrResult_::new(rec.text, rec.confidence, bbox))
430 .collect();
431
432 Ok(results)
433 }
434
435 pub fn detect(&self, image: &DynamicImage) -> OcrResult<Vec<TextBox>> {
437 self.det_model.detect(image)
438 }
439
440 pub fn recognize_text(&self, image: &DynamicImage) -> OcrResult<RecognitionResult> {
442 self.rec_model.recognize(image)
443 }
444
445 pub fn recognize_batch(&self, images: &[DynamicImage]) -> OcrResult<Vec<RecognitionResult>> {
447 self.rec_model.recognize_batch(images)
448 }
449
450 pub fn ori_model(&self) -> Option<&OriModel> {
452 self.ori_model.as_ref()
453 }
454
455 pub fn det_model(&self) -> &DetModel {
457 &self.det_model
458 }
459
460 pub fn rec_model(&self) -> &RecModel {
462 &self.rec_model
463 }
464
465 pub fn config(&self) -> &OcrEngineConfig {
467 &self.config
468 }
469
470 fn correct_orientation_with_model(
471 &self,
472 ori_model: &OriModel,
473 image: DynamicImage,
474 ) -> DynamicImage {
475 let result = match ori_model.classify(&image) {
476 Ok(result) => result,
477 Err(_) => return image,
478 };
479
480 if !result.is_valid(self.config.ori_min_confidence) {
481 return image;
482 }
483
484 if result.angle.rem_euclid(360) == 0 {
485 return image;
486 }
487
488 rotate_by_angle(&image, result.angle)
489 }
490}
491
492pub struct OcrEngineBuilder {
494 det_model_path: Option<PathBuf>,
495 rec_model_path: Option<PathBuf>,
496 charset_path: Option<PathBuf>,
497 ori_model_path: Option<PathBuf>,
498 config: Option<OcrEngineConfig>,
499}
500
501impl OcrEngineBuilder {
502 pub fn new() -> Self {
504 Self {
505 det_model_path: None,
506 rec_model_path: None,
507 charset_path: None,
508 ori_model_path: None,
509 config: None,
510 }
511 }
512
513 pub fn with_det_model_path(mut self, path: impl AsRef<Path>) -> Self {
515 self.det_model_path = Some(path.as_ref().to_path_buf());
516 self
517 }
518
519 pub fn with_rec_model_path(mut self, path: impl AsRef<Path>) -> Self {
521 self.rec_model_path = Some(path.as_ref().to_path_buf());
522 self
523 }
524
525 pub fn with_charset_path(mut self, path: impl AsRef<Path>) -> Self {
527 self.charset_path = Some(path.as_ref().to_path_buf());
528 self
529 }
530
531 pub fn with_ori_model_path(mut self, path: impl AsRef<Path>) -> Self {
533 self.ori_model_path = Some(path.as_ref().to_path_buf());
534 self
535 }
536
537 pub fn with_config(mut self, config: OcrEngineConfig) -> Self {
539 self.config = Some(config);
540 self
541 }
542
543 pub fn build(self) -> OcrResult<OcrEngine> {
545 let det_model_path = self
546 .det_model_path
547 .ok_or_else(|| OcrError::InvalidParameter("Missing det_model_path".to_string()))?;
548 let rec_model_path = self
549 .rec_model_path
550 .ok_or_else(|| OcrError::InvalidParameter("Missing rec_model_path".to_string()))?;
551 let charset_path = self
552 .charset_path
553 .ok_or_else(|| OcrError::InvalidParameter("Missing charset_path".to_string()))?;
554
555 OcrEngine::build_with_paths(
556 det_model_path.as_path(),
557 rec_model_path.as_path(),
558 charset_path.as_path(),
559 self.ori_model_path.as_deref(),
560 self.config,
561 )
562 }
563}
564
565pub struct DetOnlyEngine {
567 det_model: DetModel,
568}
569
570impl DetOnlyEngine {
571 pub fn detect(&self, image: &DynamicImage) -> OcrResult<Vec<TextBox>> {
573 self.det_model.detect(image)
574 }
575
576 pub fn detect_and_crop(&self, image: &DynamicImage) -> OcrResult<Vec<(DynamicImage, TextBox)>> {
578 self.det_model.detect_and_crop(image)
579 }
580
581 pub fn model(&self) -> &DetModel {
583 &self.det_model
584 }
585}
586
587pub struct RecOnlyEngine {
589 rec_model: RecModel,
590}
591
592impl RecOnlyEngine {
593 pub fn recognize(&self, image: &DynamicImage) -> OcrResult<RecognitionResult> {
595 self.rec_model.recognize(image)
596 }
597
598 pub fn recognize_text(&self, image: &DynamicImage) -> OcrResult<String> {
600 self.rec_model.recognize_text(image)
601 }
602
603 pub fn recognize_batch(&self, images: &[DynamicImage]) -> OcrResult<Vec<RecognitionResult>> {
605 self.rec_model.recognize_batch(images)
606 }
607
608 pub fn model(&self) -> &RecModel {
610 &self.rec_model
611 }
612}
613
614pub fn ocr_file(
627 image_path: impl AsRef<Path>,
628 det_model_path: impl AsRef<Path>,
629 rec_model_path: impl AsRef<Path>,
630 charset_path: impl AsRef<Path>,
631) -> OcrResult<Vec<OcrResult_>> {
632 let image = image::open(image_path)?;
633 let engine = OcrEngine::new(det_model_path, rec_model_path, charset_path, None)?;
634 engine.recognize(&image)
635}
636
637pub fn ocr_file_with_ori(
639 image_path: impl AsRef<Path>,
640 det_model_path: impl AsRef<Path>,
641 rec_model_path: impl AsRef<Path>,
642 charset_path: impl AsRef<Path>,
643 ori_model_path: impl AsRef<Path>,
644) -> OcrResult<Vec<OcrResult_>> {
645 let image = image::open(image_path)?;
646 let engine = OcrEngine::new_with_ori(
647 det_model_path,
648 rec_model_path,
649 charset_path,
650 ori_model_path,
651 None,
652 )?;
653 engine.recognize(&image)
654}
655
656fn rotate_by_angle(image: &DynamicImage, angle: i32) -> DynamicImage {
657 match angle.rem_euclid(360) {
659 90 => DynamicImage::ImageRgb8(image::imageops::rotate270(&image.to_rgb8())),
660 180 => DynamicImage::ImageRgb8(image::imageops::rotate180(&image.to_rgb8())),
661 270 => DynamicImage::ImageRgb8(image::imageops::rotate90(&image.to_rgb8())),
662 _ => image.clone(),
663 }
664}
665
#[cfg(test)]
mod tests {
    use super::*;

    /// `OcrResult_::new` must store the text and confidence it is given.
    #[test]
    fn test_ocr_result() {
        let region = imageproc::rect::Rect::at(0, 0).of_size(100, 20);
        let ocr = OcrResult_::new(String::from("Hello"), 0.95, TextBox::new(region, 0.9));

        assert_eq!(ocr.text, "Hello");
        assert_eq!(ocr.confidence, 0.95);
    }
}