// ocr_rs/engine.rs

1//! OCR Engine
2//!
3//! Provides complete OCR pipeline encapsulation, performs detection and recognition in one call
4
5use image::DynamicImage;
6use std::path::Path;
7
8use crate::det::{DetModel, DetOptions};
9use crate::error::OcrResult;
10use crate::mnn::{Backend, InferenceConfig, PrecisionMode};
11use crate::postprocess::TextBox;
12use crate::rec::{RecModel, RecOptions, RecognitionResult};
13
/// OCR result
///
/// One recognized text line: the decoded string, the recognizer's
/// confidence score, and the detected region it came from.
/// (The trailing underscore avoids clashing with the crate's
/// `OcrResult` error-handling alias imported above.)
#[derive(Debug, Clone)]
pub struct OcrResult_ {
    /// Recognized text
    pub text: String,
    /// Confidence score reported by the recognition model (higher is more certain)
    pub confidence: f32,
    /// Bounding box of the text region in the original image
    pub bbox: TextBox,
}
24
25impl OcrResult_ {
26    /// Create a new OCR result
27    pub fn new(text: String, confidence: f32, bbox: TextBox) -> Self {
28        Self {
29            text,
30            confidence,
31            bbox,
32        }
33    }
34}
35
/// OCR engine configuration
///
/// Controls the inference runtime (backend, threads, precision) plus the
/// per-stage detection and recognition options. Build one with the `with_*`
/// builder methods or start from a preset (`fast()`, `gpu()`).
#[derive(Debug, Clone)]
pub struct OcrEngineConfig {
    /// Inference backend
    pub backend: Backend,
    /// Thread count passed to the MNN inference config
    pub thread_count: i32,
    /// Precision mode
    pub precision_mode: PrecisionMode,
    /// Detection options
    pub det_options: DetOptions,
    /// Recognition options
    pub rec_options: RecOptions,
    /// Whether to enable parallel recognition (use rayon to process multiple text regions in parallel)
    pub enable_parallel: bool,
    /// Minimum confidence threshold at result level (recognition results below this value will be filtered)
    pub min_result_confidence: f32,
}
54
55impl Default for OcrEngineConfig {
56    fn default() -> Self {
57        Self {
58            backend: Backend::CPU,
59            thread_count: 4,
60            precision_mode: PrecisionMode::Normal,
61            det_options: DetOptions::default(),
62            rec_options: RecOptions::default(),
63            enable_parallel: true,
64            min_result_confidence: 0.5,
65        }
66    }
67}
68
69impl OcrEngineConfig {
70    /// Create new configuration
71    pub fn new() -> Self {
72        Self::default()
73    }
74
75    /// Set inference backend
76    pub fn with_backend(mut self, backend: Backend) -> Self {
77        self.backend = backend;
78        self
79    }
80
81    /// Set thread count
82    pub fn with_threads(mut self, threads: i32) -> Self {
83        self.thread_count = threads;
84        self
85    }
86
87    /// Set precision mode
88    pub fn with_precision(mut self, precision: PrecisionMode) -> Self {
89        self.precision_mode = precision;
90        self
91    }
92
93    /// Set detection options
94    pub fn with_det_options(mut self, options: DetOptions) -> Self {
95        self.det_options = options;
96        self
97    }
98
99    /// Set recognition options
100    pub fn with_rec_options(mut self, options: RecOptions) -> Self {
101        self.rec_options = options;
102        self
103    }
104
105    /// Enable/disable parallel processing
106    ///
107    /// Note: When multiple text regions are detected, use rayon for parallel recognition.
108    /// If MNN is already set to multi-threading, enabling this option may cause thread contention.
109    pub fn with_parallel(mut self, enable: bool) -> Self {
110        self.enable_parallel = enable;
111        self
112    }
113
114    /// Set minimum confidence threshold at result level
115    ///
116    /// Recognition results below this threshold will be filtered out.
117    /// Recommended values: 0.5 (lenient), 0.7 (balanced), 0.9 (strict)
118    pub fn with_min_result_confidence(mut self, threshold: f32) -> Self {
119        self.min_result_confidence = threshold;
120        self
121    }
122
123    /// Fast mode preset
124    pub fn fast() -> Self {
125        Self {
126            precision_mode: PrecisionMode::Low,
127            det_options: DetOptions::fast(),
128            ..Default::default()
129        }
130    }
131
132    /// GPU mode preset (Metal)
133    #[cfg(any(target_os = "macos", target_os = "ios"))]
134    pub fn gpu() -> Self {
135        Self {
136            backend: Backend::Metal,
137            ..Default::default()
138        }
139    }
140
141    /// GPU mode preset (OpenCL)
142    #[cfg(not(any(target_os = "macos", target_os = "ios")))]
143    pub fn gpu() -> Self {
144        Self {
145            backend: Backend::OpenCL,
146            ..Default::default()
147        }
148    }
149
150    fn to_inference_config(&self) -> InferenceConfig {
151        InferenceConfig {
152            thread_count: self.thread_count,
153            precision_mode: self.precision_mode,
154            backend: self.backend,
155            ..Default::default()
156        }
157    }
158}
159
/// OCR engine
///
/// Encapsulates the complete OCR pipeline, including text detection and
/// recognition, behind a single `recognize` call.
///
/// # Example
///
/// ```ignore
/// use ocr_rs::{OcrEngine, OcrEngineConfig};
///
/// // Create engine
/// let engine = OcrEngine::new(
///     "det_model.mnn",
///     "rec_model.mnn",
///     "ppocr_keys.txt",
///     None,
/// )?;
///
/// // Recognize image
/// let image = image::open("test.jpg")?;
/// let results = engine.recognize(&image)?;
///
/// for result in results {
///     println!("{}: {:.2}", result.text, result.confidence);
/// }
/// ```
pub struct OcrEngine {
    // Text detection model (locates text regions)
    det_model: DetModel,
    // Text recognition model (decodes cropped text lines)
    rec_model: RecModel,
    // Engine configuration retained for later inspection via `config()`
    config: OcrEngineConfig,
}
190
191impl OcrEngine {
192    /// Create OCR engine from model files
193    ///
194    /// # Parameters
195    /// - `det_model_path`: Detection model file path
196    /// - `rec_model_path`: Recognition model file path
197    /// - `charset_path`: Charset file path
198    /// - `config`: Optional engine configuration
199    pub fn new(
200        det_model_path: impl AsRef<Path>,
201        rec_model_path: impl AsRef<Path>,
202        charset_path: impl AsRef<Path>,
203        config: Option<OcrEngineConfig>,
204    ) -> OcrResult<Self> {
205        let config = config.unwrap_or_default();
206        let inference_config = config.to_inference_config();
207
208        // Optimization: Directly move the configuration to avoid multiple clones
209        let det_options = config.det_options.clone();
210        let rec_options = config.rec_options.clone();
211
212        let det_model = DetModel::from_file(det_model_path, Some(inference_config.clone()))?
213            .with_options(det_options);
214
215        let rec_model = RecModel::from_file(rec_model_path, charset_path, Some(inference_config))?
216            .with_options(rec_options);
217
218        Ok(Self {
219            det_model,
220            rec_model,
221            config,
222        })
223    }
224
225    /// Create OCR engine from model bytes
226    pub fn from_bytes(
227        det_model_bytes: &[u8],
228        rec_model_bytes: &[u8],
229        charset_bytes: &[u8],
230        config: Option<OcrEngineConfig>,
231    ) -> OcrResult<Self> {
232        let config = config.unwrap_or_default();
233        let inference_config = config.to_inference_config();
234
235        // Optimization: Directly move the configuration to avoid multiple clones
236        let det_options = config.det_options.clone();
237        let rec_options = config.rec_options.clone();
238
239        let det_model = DetModel::from_bytes(det_model_bytes, Some(inference_config.clone()))?
240            .with_options(det_options);
241
242        let rec_model = RecModel::from_bytes_with_charset(
243            rec_model_bytes,
244            charset_bytes,
245            Some(inference_config),
246        )?
247        .with_options(rec_options);
248
249        Ok(Self {
250            det_model,
251            rec_model,
252            config,
253        })
254    }
255
256    /// Create detection-only engine
257    pub fn det_only(
258        det_model_path: impl AsRef<Path>,
259        config: Option<OcrEngineConfig>,
260    ) -> OcrResult<DetOnlyEngine> {
261        let config = config.unwrap_or_default();
262        let inference_config = config.to_inference_config();
263
264        let det_model = DetModel::from_file(det_model_path, Some(inference_config))?
265            .with_options(config.det_options);
266
267        Ok(DetOnlyEngine { det_model })
268    }
269
270    /// Create recognition-only engine
271    pub fn rec_only(
272        rec_model_path: impl AsRef<Path>,
273        charset_path: impl AsRef<Path>,
274        config: Option<OcrEngineConfig>,
275    ) -> OcrResult<RecOnlyEngine> {
276        let config = config.unwrap_or_default();
277        let inference_config = config.to_inference_config();
278
279        let rec_model = RecModel::from_file(rec_model_path, charset_path, Some(inference_config))?
280            .with_options(config.rec_options);
281
282        Ok(RecOnlyEngine { rec_model })
283    }
284
285    /// Perform complete OCR recognition
286    ///
287    /// # Parameters
288    /// - `image`: Input image
289    ///
290    /// # Returns
291    /// List of OCR results, each result contains text, confidence and bounding box
292    pub fn recognize(&self, image: &DynamicImage) -> OcrResult<Vec<OcrResult_>> {
293        // 1. Detect text regions
294        let detections = self.det_model.detect_and_crop(image)?;
295
296        if detections.is_empty() {
297            return Ok(Vec::new());
298        }
299
300        // 2. Batch recognition (avoid cloning)
301        let (images, boxes): (Vec<&DynamicImage>, Vec<TextBox>) = detections
302            .iter()
303            .map(|(img, bbox)| (img, bbox.clone()))
304            .unzip();
305
306        let rec_results = if self.config.enable_parallel && images.len() > 4 {
307            // Parallel recognition: for multiple text regions, use rayon for parallel processing
308            use rayon::prelude::*;
309            images
310                .par_iter()
311                .map(|img| self.rec_model.recognize(img))
312                .collect::<OcrResult<Vec<_>>>()?
313        } else {
314            // Sequential recognition: use batch inference
315            self.rec_model.recognize_batch_ref(&images)?
316        };
317
318        // 3. Combine results and filter low confidence
319        let results: Vec<OcrResult_> = rec_results
320            .into_iter()
321            .zip(boxes)
322            .filter(|(rec, _)| {
323                !rec.text.is_empty() && rec.confidence >= self.config.min_result_confidence
324            })
325            .map(|(rec, bbox)| OcrResult_::new(rec.text, rec.confidence, bbox))
326            .collect();
327
328        Ok(results)
329    }
330
331    /// Perform detection only
332    pub fn detect(&self, image: &DynamicImage) -> OcrResult<Vec<TextBox>> {
333        self.det_model.detect(image)
334    }
335
336    /// Perform recognition only (requires pre-cropped text line images)
337    pub fn recognize_text(&self, image: &DynamicImage) -> OcrResult<RecognitionResult> {
338        self.rec_model.recognize(image)
339    }
340
341    /// Batch recognize text line images
342    pub fn recognize_batch(&self, images: &[DynamicImage]) -> OcrResult<Vec<RecognitionResult>> {
343        self.rec_model.recognize_batch(images)
344    }
345
346    /// Get detection model reference
347    pub fn det_model(&self) -> &DetModel {
348        &self.det_model
349    }
350
351    /// Get recognition model reference
352    pub fn rec_model(&self) -> &RecModel {
353        &self.rec_model
354    }
355
356    /// Get configuration
357    pub fn config(&self) -> &OcrEngineConfig {
358        &self.config
359    }
360}
361
/// Detection-only engine
///
/// Wraps just the text-detection model; build one with [`OcrEngine::det_only`].
pub struct DetOnlyEngine {
    // Text detection model
    det_model: DetModel,
}
366
367impl DetOnlyEngine {
368    /// Detect text regions in image
369    pub fn detect(&self, image: &DynamicImage) -> OcrResult<Vec<TextBox>> {
370        self.det_model.detect(image)
371    }
372
373    /// Detect and return cropped images
374    pub fn detect_and_crop(&self, image: &DynamicImage) -> OcrResult<Vec<(DynamicImage, TextBox)>> {
375        self.det_model.detect_and_crop(image)
376    }
377
378    /// Get detection model reference
379    pub fn model(&self) -> &DetModel {
380        &self.det_model
381    }
382}
383
/// Recognition-only engine
///
/// Wraps just the text-recognition model; build one with [`OcrEngine::rec_only`].
/// Inputs are expected to be pre-cropped text line images.
pub struct RecOnlyEngine {
    // Text recognition model
    rec_model: RecModel,
}
388
389impl RecOnlyEngine {
390    /// Recognize a single image
391    pub fn recognize(&self, image: &DynamicImage) -> OcrResult<RecognitionResult> {
392        self.rec_model.recognize(image)
393    }
394
395    /// Return text only
396    pub fn recognize_text(&self, image: &DynamicImage) -> OcrResult<String> {
397        self.rec_model.recognize_text(image)
398    }
399
400    /// Batch recognition
401    pub fn recognize_batch(&self, images: &[DynamicImage]) -> OcrResult<Vec<RecognitionResult>> {
402        self.rec_model.recognize_batch(images)
403    }
404
405    /// Get recognition model reference
406    pub fn model(&self) -> &RecModel {
407        &self.rec_model
408    }
409}
410
411/// Convenience function: recognize from file
412///
413/// # Example
414///
415/// ```ignore
416/// let results = ocr_rs::ocr_file(
417///     "test.jpg",
418///     "det_model.mnn",
419///     "rec_model.mnn",
420///     "ppocr_keys.txt",
421/// )?;
422/// ```
423pub fn ocr_file(
424    image_path: impl AsRef<Path>,
425    det_model_path: impl AsRef<Path>,
426    rec_model_path: impl AsRef<Path>,
427    charset_path: impl AsRef<Path>,
428) -> OcrResult<Vec<OcrResult_>> {
429    let image = image::open(image_path)?;
430    let engine = OcrEngine::new(det_model_path, rec_model_path, charset_path, None)?;
431    engine.recognize(&image)
432}
433
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_engine_config() {
        // Default preset: CPU backend with 4 threads.
        let default_cfg = OcrEngineConfig::default();
        assert_eq!(default_cfg.thread_count, 4);
        assert_eq!(default_cfg.backend, Backend::CPU);

        // Fast preset lowers precision.
        let fast_cfg = OcrEngineConfig::fast();
        assert_eq!(fast_cfg.precision_mode, PrecisionMode::Low);
    }

    #[test]
    fn test_ocr_result() {
        let rect = imageproc::rect::Rect::at(0, 0).of_size(100, 20);
        let result = OcrResult_::new("Hello".to_string(), 0.95, TextBox::new(rect, 0.9));

        assert_eq!(result.text, "Hello");
        assert_eq!(result.confidence, 0.95);
    }
}