Skip to main content

ocr_rs/
engine.rs

//! OCR engine
//!
//! Wraps the complete OCR pipeline (optional orientation correction, text detection and
//! text recognition) so that a single call turns an image into recognized text.
4
5use image::DynamicImage;
6use std::path::{Path, PathBuf};
7
8use crate::det::{DetModel, DetOptions};
9use crate::error::{OcrError, OcrResult};
10use crate::mnn::{Backend, InferenceConfig, PrecisionMode};
11use crate::ori::{OriModel, OriOptions};
12use crate::postprocess::TextBox;
13use crate::rec::{RecModel, RecOptions, RecognitionResult};
14
15/// OCR result
16#[derive(Debug, Clone)]
17pub struct OcrResult_ {
18    /// Recognized text
19    pub text: String,
20    /// Confidence score
21    pub confidence: f32,
22    /// Bounding box
23    pub bbox: TextBox,
24}
25
26impl OcrResult_ {
27    /// Create a new OCR result
28    pub fn new(text: String, confidence: f32, bbox: TextBox) -> Self {
29        Self {
30            text,
31            confidence,
32            bbox,
33        }
34    }
35}
36
37/// OCR engine configuration
38#[derive(Debug, Clone)]
39pub struct OcrEngineConfig {
40    /// Inference backend
41    pub backend: Backend,
42    /// Thread count
43    pub thread_count: i32,
44    /// Precision mode
45    pub precision_mode: PrecisionMode,
46    /// Detection options
47    pub det_options: DetOptions,
48    /// Recognition options
49    pub rec_options: RecOptions,
50    /// Orientation options (used when orientation model is enabled)
51    pub ori_options: OriOptions,
52    /// Whether to enable parallel recognition (use rayon to process multiple text regions in parallel)
53    pub enable_parallel: bool,
54    /// Minimum confidence threshold at result level (recognition results below this value will be filtered)
55    pub min_result_confidence: f32,
56    /// Minimum confidence threshold for orientation correction
57    pub ori_min_confidence: f32,
58}
59
60impl Default for OcrEngineConfig {
61    fn default() -> Self {
62        Self {
63            backend: Backend::CPU,
64            thread_count: 4,
65            precision_mode: PrecisionMode::Normal,
66            det_options: DetOptions::default(),
67            rec_options: RecOptions::default(),
68            ori_options: OriOptions::default(),
69            enable_parallel: true,
70            min_result_confidence: 0.5,
71            ori_min_confidence: 0.3,
72        }
73    }
74}
75
76impl OcrEngineConfig {
77    /// Create new configuration
78    pub fn new() -> Self {
79        Self::default()
80    }
81
82    /// Set inference backend
83    pub fn with_backend(mut self, backend: Backend) -> Self {
84        self.backend = backend;
85        self
86    }
87
88    /// Set thread count
89    pub fn with_threads(mut self, threads: i32) -> Self {
90        self.thread_count = threads;
91        self
92    }
93
94    /// Set precision mode
95    pub fn with_precision(mut self, precision: PrecisionMode) -> Self {
96        self.precision_mode = precision;
97        self
98    }
99
100    /// Set detection options
101    pub fn with_det_options(mut self, options: DetOptions) -> Self {
102        self.det_options = options;
103        self
104    }
105
106    /// Set recognition options
107    pub fn with_rec_options(mut self, options: RecOptions) -> Self {
108        self.rec_options = options;
109        self
110    }
111
112    /// Set orientation options
113    pub fn with_ori_options(mut self, options: OriOptions) -> Self {
114        self.ori_options = options;
115        self
116    }
117
118    /// Enable/disable parallel processing
119    ///
120    /// Note: When multiple text regions are detected, use rayon for parallel recognition.
121    /// If MNN is already set to multi-threading, enabling this option may cause thread contention.
122    pub fn with_parallel(mut self, enable: bool) -> Self {
123        self.enable_parallel = enable;
124        self
125    }
126
127    /// Set minimum confidence threshold at result level
128    ///
129    /// Recognition results below this threshold will be filtered out.
130    /// Recommended values: 0.5 (lenient), 0.7 (balanced), 0.9 (strict)
131    pub fn with_min_result_confidence(mut self, threshold: f32) -> Self {
132        self.min_result_confidence = threshold;
133        self
134    }
135
136    /// Set minimum confidence threshold for orientation correction
137    pub fn with_ori_min_confidence(mut self, threshold: f32) -> Self {
138        self.ori_min_confidence = threshold;
139        self
140    }
141
142    /// Fast mode preset
143    pub fn fast() -> Self {
144        Self {
145            precision_mode: PrecisionMode::Low,
146            det_options: DetOptions::fast(),
147            ..Default::default()
148        }
149    }
150
151    /// GPU mode preset (Metal)
152    #[cfg(any(target_os = "macos", target_os = "ios"))]
153    pub fn gpu() -> Self {
154        Self {
155            backend: Backend::Metal,
156            ..Default::default()
157        }
158    }
159
160    /// GPU mode preset (OpenCL)
161    #[cfg(not(any(target_os = "macos", target_os = "ios")))]
162    pub fn gpu() -> Self {
163        Self {
164            backend: Backend::OpenCL,
165            ..Default::default()
166        }
167    }
168
169    fn to_inference_config(&self) -> InferenceConfig {
170        InferenceConfig {
171            thread_count: self.thread_count,
172            precision_mode: self.precision_mode,
173            backend: self.backend,
174            ..Default::default()
175        }
176    }
177}
178
179/// OCR engine
180///
181/// Encapsulates complete OCR pipeline, including text detection and recognition
182///
183/// # Example
184///
185/// ```ignore
186/// use ocr_rs::{OcrEngine, OcrEngineConfig};
187///
188/// // Create engine
189/// let engine = OcrEngine::new(
190///     "det_model.mnn",
191///     "rec_model.mnn",
192///     "ppocr_keys.txt",
193///     None,
194/// )?;
195///
196/// // Recognize image
197/// let image = image::open("test.jpg")?;
198/// let results = engine.recognize(&image)?;
199///
200/// for result in results {
201///     println!("{}: {:.2}", result.text, result.confidence);
202/// }
203/// ```
204pub struct OcrEngine {
205    det_model: DetModel,
206    rec_model: RecModel,
207    ori_model: Option<OriModel>,
208    config: OcrEngineConfig,
209}
210
211impl OcrEngine {
212    fn build_with_paths(
213        det_model_path: &Path,
214        rec_model_path: &Path,
215        charset_path: &Path,
216        ori_model_path: Option<&Path>,
217        config: Option<OcrEngineConfig>,
218    ) -> OcrResult<Self> {
219        let config = config.unwrap_or_default();
220        let inference_config = config.to_inference_config();
221
222        // Optimization: Directly move the configuration to avoid multiple clones
223        let det_options = config.det_options.clone();
224        let rec_options = config.rec_options.clone();
225        let ori_options = config.ori_options.clone();
226
227        let det_model = DetModel::from_file(det_model_path, Some(inference_config.clone()))?
228            .with_options(det_options);
229
230        let rec_model =
231            RecModel::from_file(rec_model_path, charset_path, Some(inference_config.clone()))?
232                .with_options(rec_options);
233
234        let ori_model = match ori_model_path {
235            Some(path) => {
236                Some(OriModel::from_file(path, Some(inference_config))?.with_options(ori_options))
237            }
238            None => None,
239        };
240
241        Ok(Self {
242            det_model,
243            rec_model,
244            ori_model,
245            config,
246        })
247    }
248
249    /// Create OCR engine from model files
250    ///
251    /// # Parameters
252    /// - `det_model_path`: Detection model file path
253    /// - `rec_model_path`: Recognition model file path
254    /// - `charset_path`: Charset file path
255    /// - `config`: Optional engine configuration
256    pub fn new(
257        det_model_path: impl AsRef<Path>,
258        rec_model_path: impl AsRef<Path>,
259        charset_path: impl AsRef<Path>,
260        config: Option<OcrEngineConfig>,
261    ) -> OcrResult<Self> {
262        Self::build_with_paths(
263            det_model_path.as_ref(),
264            rec_model_path.as_ref(),
265            charset_path.as_ref(),
266            None,
267            config,
268        )
269    }
270
271    /// Create OCR engine from model files with orientation model
272    pub fn new_with_ori(
273        det_model_path: impl AsRef<Path>,
274        rec_model_path: impl AsRef<Path>,
275        charset_path: impl AsRef<Path>,
276        ori_model_path: impl AsRef<Path>,
277        config: Option<OcrEngineConfig>,
278    ) -> OcrResult<Self> {
279        Self::build_with_paths(
280            det_model_path.as_ref(),
281            rec_model_path.as_ref(),
282            charset_path.as_ref(),
283            Some(ori_model_path.as_ref()),
284            config,
285        )
286    }
287
288    /// Create OCR engine from model bytes
289    pub fn from_bytes(
290        det_model_bytes: &[u8],
291        rec_model_bytes: &[u8],
292        charset_bytes: &[u8],
293        config: Option<OcrEngineConfig>,
294    ) -> OcrResult<Self> {
295        let config = config.unwrap_or_default();
296        let inference_config = config.to_inference_config();
297
298        // Optimization: Directly move the configuration to avoid multiple clones
299        let det_options = config.det_options.clone();
300        let rec_options = config.rec_options.clone();
301
302        let det_model = DetModel::from_bytes(det_model_bytes, Some(inference_config.clone()))?
303            .with_options(det_options);
304
305        let rec_model = RecModel::from_bytes_with_charset(
306            rec_model_bytes,
307            charset_bytes,
308            Some(inference_config.clone()),
309        )?
310        .with_options(rec_options);
311
312        Ok(Self {
313            det_model,
314            rec_model,
315            ori_model: None,
316            config,
317        })
318    }
319
320    /// Create OCR engine from model bytes with orientation model
321    pub fn from_bytes_with_ori(
322        det_model_bytes: &[u8],
323        rec_model_bytes: &[u8],
324        charset_bytes: &[u8],
325        ori_model_bytes: &[u8],
326        config: Option<OcrEngineConfig>,
327    ) -> OcrResult<Self> {
328        let config = config.unwrap_or_default();
329        let inference_config = config.to_inference_config();
330
331        let det_options = config.det_options.clone();
332        let rec_options = config.rec_options.clone();
333        let ori_options = config.ori_options.clone();
334
335        let det_model = DetModel::from_bytes(det_model_bytes, Some(inference_config.clone()))?
336            .with_options(det_options);
337
338        let rec_model = RecModel::from_bytes_with_charset(
339            rec_model_bytes,
340            charset_bytes,
341            Some(inference_config.clone()),
342        )?
343        .with_options(rec_options);
344
345        let ori_model = OriModel::from_bytes(ori_model_bytes, Some(inference_config))?
346            .with_options(ori_options);
347
348        Ok(Self {
349            det_model,
350            rec_model,
351            ori_model: Some(ori_model),
352            config,
353        })
354    }
355
356    /// Create detection-only engine
357    pub fn det_only(
358        det_model_path: impl AsRef<Path>,
359        config: Option<OcrEngineConfig>,
360    ) -> OcrResult<DetOnlyEngine> {
361        let config = config.unwrap_or_default();
362        let inference_config = config.to_inference_config();
363
364        let det_model = DetModel::from_file(det_model_path, Some(inference_config))?
365            .with_options(config.det_options);
366
367        Ok(DetOnlyEngine { det_model })
368    }
369
370    /// Create recognition-only engine
371    pub fn rec_only(
372        rec_model_path: impl AsRef<Path>,
373        charset_path: impl AsRef<Path>,
374        config: Option<OcrEngineConfig>,
375    ) -> OcrResult<RecOnlyEngine> {
376        let config = config.unwrap_or_default();
377        let inference_config = config.to_inference_config();
378
379        let rec_model = RecModel::from_file(rec_model_path, charset_path, Some(inference_config))?
380            .with_options(config.rec_options);
381
382        Ok(RecOnlyEngine { rec_model })
383    }
384
385    /// Perform complete OCR recognition
386    ///
387    /// # Parameters
388    /// - `image`: Input image
389    ///
390    /// # Returns
391    /// List of OCR results, each result contains text, confidence and bounding box
392    pub fn recognize(&self, image: &DynamicImage) -> OcrResult<Vec<OcrResult_>> {
393        // 0. Orientation correction for full image (optional)
394        let corrected_image = if let Some(ori_model) = self.ori_model.as_ref() {
395            self.correct_orientation_with_model(ori_model, image.clone())
396        } else {
397            image.clone()
398        };
399
400        // 1. Detect text regions
401        let detections = self.det_model.detect_and_crop(&corrected_image)?;
402
403        if detections.is_empty() {
404            return Ok(Vec::new());
405        }
406
407        // 2. Batch recognition
408        let (mut images, boxes): (Vec<DynamicImage>, Vec<TextBox>) = detections.into_iter().unzip();
409
410        let rec_results = if self.config.enable_parallel && images.len() > 4 {
411            // Parallel recognition: for multiple text regions, use rayon for parallel processing
412            use rayon::prelude::*;
413            images
414                .par_iter()
415                .map(|img| self.rec_model.recognize(img))
416                .collect::<OcrResult<Vec<_>>>()?
417        } else {
418            // Sequential recognition: use batch inference
419            self.rec_model.recognize_batch(&images)?
420        };
421
422        // 3. Combine results and filter low confidence
423        let results: Vec<OcrResult_> = rec_results
424            .into_iter()
425            .zip(boxes)
426            .filter(|(rec, _)| {
427                !rec.text.is_empty() && rec.confidence >= self.config.min_result_confidence
428            })
429            .map(|(rec, bbox)| OcrResult_::new(rec.text, rec.confidence, bbox))
430            .collect();
431
432        Ok(results)
433    }
434
435    /// Perform detection only
436    pub fn detect(&self, image: &DynamicImage) -> OcrResult<Vec<TextBox>> {
437        self.det_model.detect(image)
438    }
439
440    /// Perform recognition only (requires pre-cropped text line images)
441    pub fn recognize_text(&self, image: &DynamicImage) -> OcrResult<RecognitionResult> {
442        self.rec_model.recognize(image)
443    }
444
445    /// Batch recognize text line images
446    pub fn recognize_batch(&self, images: &[DynamicImage]) -> OcrResult<Vec<RecognitionResult>> {
447        self.rec_model.recognize_batch(images)
448    }
449
450    /// Get orientation model reference (if enabled)
451    pub fn ori_model(&self) -> Option<&OriModel> {
452        self.ori_model.as_ref()
453    }
454
455    /// Get detection model reference
456    pub fn det_model(&self) -> &DetModel {
457        &self.det_model
458    }
459
460    /// Get recognition model reference
461    pub fn rec_model(&self) -> &RecModel {
462        &self.rec_model
463    }
464
465    /// Get configuration
466    pub fn config(&self) -> &OcrEngineConfig {
467        &self.config
468    }
469
470    fn correct_orientation_with_model(
471        &self,
472        ori_model: &OriModel,
473        image: DynamicImage,
474    ) -> DynamicImage {
475        let result = match ori_model.classify(&image) {
476            Ok(result) => result,
477            Err(_) => return image,
478        };
479
480        if !result.is_valid(self.config.ori_min_confidence) {
481            return image;
482        }
483
484        if result.angle.rem_euclid(360) == 0 {
485            return image;
486        }
487
488        rotate_by_angle(&image, result.angle)
489    }
490}
491
492/// Builder for OCR engine
493pub struct OcrEngineBuilder {
494    det_model_path: Option<PathBuf>,
495    rec_model_path: Option<PathBuf>,
496    charset_path: Option<PathBuf>,
497    ori_model_path: Option<PathBuf>,
498    config: Option<OcrEngineConfig>,
499}
500
501impl OcrEngineBuilder {
502    /// Create a new builder
503    pub fn new() -> Self {
504        Self {
505            det_model_path: None,
506            rec_model_path: None,
507            charset_path: None,
508            ori_model_path: None,
509            config: None,
510        }
511    }
512
513    /// Set detection model path
514    pub fn with_det_model_path(mut self, path: impl AsRef<Path>) -> Self {
515        self.det_model_path = Some(path.as_ref().to_path_buf());
516        self
517    }
518
519    /// Set recognition model path
520    pub fn with_rec_model_path(mut self, path: impl AsRef<Path>) -> Self {
521        self.rec_model_path = Some(path.as_ref().to_path_buf());
522        self
523    }
524
525    /// Set charset path
526    pub fn with_charset_path(mut self, path: impl AsRef<Path>) -> Self {
527        self.charset_path = Some(path.as_ref().to_path_buf());
528        self
529    }
530
531    /// Set orientation model path
532    pub fn with_ori_model_path(mut self, path: impl AsRef<Path>) -> Self {
533        self.ori_model_path = Some(path.as_ref().to_path_buf());
534        self
535    }
536
537    /// Set engine configuration
538    pub fn with_config(mut self, config: OcrEngineConfig) -> Self {
539        self.config = Some(config);
540        self
541    }
542
543    /// Build OCR engine
544    pub fn build(self) -> OcrResult<OcrEngine> {
545        let det_model_path = self
546            .det_model_path
547            .ok_or_else(|| OcrError::InvalidParameter("Missing det_model_path".to_string()))?;
548        let rec_model_path = self
549            .rec_model_path
550            .ok_or_else(|| OcrError::InvalidParameter("Missing rec_model_path".to_string()))?;
551        let charset_path = self
552            .charset_path
553            .ok_or_else(|| OcrError::InvalidParameter("Missing charset_path".to_string()))?;
554
555        OcrEngine::build_with_paths(
556            det_model_path.as_path(),
557            rec_model_path.as_path(),
558            charset_path.as_path(),
559            self.ori_model_path.as_deref(),
560            self.config,
561        )
562    }
563}
564
565/// Detection-only engine
566pub struct DetOnlyEngine {
567    det_model: DetModel,
568}
569
570impl DetOnlyEngine {
571    /// Detect text regions in image
572    pub fn detect(&self, image: &DynamicImage) -> OcrResult<Vec<TextBox>> {
573        self.det_model.detect(image)
574    }
575
576    /// Detect and return cropped images
577    pub fn detect_and_crop(&self, image: &DynamicImage) -> OcrResult<Vec<(DynamicImage, TextBox)>> {
578        self.det_model.detect_and_crop(image)
579    }
580
581    /// Get detection model reference
582    pub fn model(&self) -> &DetModel {
583        &self.det_model
584    }
585}
586
587/// Recognition-only engine
588pub struct RecOnlyEngine {
589    rec_model: RecModel,
590}
591
592impl RecOnlyEngine {
593    /// Recognize a single image
594    pub fn recognize(&self, image: &DynamicImage) -> OcrResult<RecognitionResult> {
595        self.rec_model.recognize(image)
596    }
597
598    /// Return text only
599    pub fn recognize_text(&self, image: &DynamicImage) -> OcrResult<String> {
600        self.rec_model.recognize_text(image)
601    }
602
603    /// Batch recognition
604    pub fn recognize_batch(&self, images: &[DynamicImage]) -> OcrResult<Vec<RecognitionResult>> {
605        self.rec_model.recognize_batch(images)
606    }
607
608    /// Get recognition model reference
609    pub fn model(&self) -> &RecModel {
610        &self.rec_model
611    }
612}
613
614/// Convenience function: recognize from file
615///
616/// # Example
617///
618/// ```ignore
619/// let results = ocr_rs::ocr_file(
620///     "test.jpg",
621///     "det_model.mnn",
622///     "rec_model.mnn",
623///     "ppocr_keys.txt",
624/// )?;
625/// ```
626pub fn ocr_file(
627    image_path: impl AsRef<Path>,
628    det_model_path: impl AsRef<Path>,
629    rec_model_path: impl AsRef<Path>,
630    charset_path: impl AsRef<Path>,
631) -> OcrResult<Vec<OcrResult_>> {
632    let image = image::open(image_path)?;
633    let engine = OcrEngine::new(det_model_path, rec_model_path, charset_path, None)?;
634    engine.recognize(&image)
635}
636
637/// Convenience function: recognize from file with orientation model
638pub fn ocr_file_with_ori(
639    image_path: impl AsRef<Path>,
640    det_model_path: impl AsRef<Path>,
641    rec_model_path: impl AsRef<Path>,
642    charset_path: impl AsRef<Path>,
643    ori_model_path: impl AsRef<Path>,
644) -> OcrResult<Vec<OcrResult_>> {
645    let image = image::open(image_path)?;
646    let engine = OcrEngine::new_with_ori(
647        det_model_path,
648        rec_model_path,
649        charset_path,
650        ori_model_path,
651        None,
652    )?;
653    engine.recognize(&image)
654}
655
656fn rotate_by_angle(image: &DynamicImage, angle: i32) -> DynamicImage {
657    // The model reports rotation from horizontal; rotate back to correct.
658    match angle.rem_euclid(360) {
659        90 => DynamicImage::ImageRgb8(image::imageops::rotate270(&image.to_rgb8())),
660        180 => DynamicImage::ImageRgb8(image::imageops::rotate180(&image.to_rgb8())),
661        270 => DynamicImage::ImageRgb8(image::imageops::rotate90(&image.to_rgb8())),
662        _ => image.clone(),
663    }
664}
665
#[cfg(test)]
mod tests {
    use super::*;

    /// `OcrResult_::new` stores the text and confidence it was given.
    #[test]
    fn test_ocr_result() {
        let rect = imageproc::rect::Rect::at(0, 0).of_size(100, 20);
        let ocr = OcrResult_::new(String::from("Hello"), 0.95, TextBox::new(rect, 0.9));

        assert_eq!(ocr.text, "Hello");
        assert_eq!(ocr.confidence, 0.95);
    }
}