pure_onnx_ocr/
engine.rs

1use crate::ctc::DecodedSequence;
2use crate::detection::DetInferenceSession;
3use crate::dictionary::{DictionaryError, RecDictionary};
4use crate::postprocessing::{
5    DetPolygonScaler, DetPolygonScalerConfig, DetPolygonUnclipper, DetPolygonUnclipperConfig,
6    DetPostProcessor, DetPostProcessorConfig, DetPostProcessorError,
7};
8use crate::preprocessing::{
9    DetPreProcessor, DetPreProcessorConfig, DetPreProcessorError, RecPreProcessor,
10    RecPreProcessorConfig, RecPreProcessorError, RecTextRegion,
11};
12use crate::recognition::{
13    RecInferenceSession, RecPostProcessor, RecPostProcessorConfig, RecPostProcessorError,
14};
15use geo_types::Polygon;
16use image::{DynamicImage, GenericImageView, ImageError};
17use std::error::Error;
18use std::fmt;
19use std::fs;
20use std::path::{Path, PathBuf};
21use std::sync::Arc;
22use std::time::{Duration, Instant};
23use tract_onnx::prelude::TractError;
24
25/// Errors that can occur while building or using the OCR engine.
26#[derive(Debug)]
27pub enum OcrError {
28    /// A required builder field was not provided.
29    MissingField { field: &'static str },
30    /// An IO error occurred while accessing a resource.
31    Io {
32        source: std::io::Error,
33        path: PathBuf,
34    },
35    /// Loading an ONNX model failed.
36    ModelLoad { source: TractError, path: PathBuf },
37    /// Loading the recognition dictionary failed.
38    Dictionary { source: DictionaryError },
39    /// The provided configuration contained invalid values.
40    InvalidConfiguration { message: String },
41    /// Failed to decode the input image.
42    ImageDecode { source: ImageError, path: PathBuf },
43    /// Detection preprocessing failed.
44    DetectionPreprocess { source: DetPreProcessorError },
45    /// Detection inference failed.
46    DetectionInference { source: TractError },
47    /// Detection post-processing failed.
48    DetectionPostProcess { source: DetPostProcessorError },
49    /// Recognition preprocessing failed.
50    RecognitionPreprocess { source: RecPreProcessorError },
51    /// Recognition inference failed.
52    RecognitionInference { source: TractError },
53    /// Recognition post-processing failed.
54    RecognitionPostProcess { source: RecPostProcessorError },
55    /// The number of recognition results did not match detected regions.
56    PipelineMismatch {
57        detection_regions: usize,
58        recognition_results: usize,
59    },
60}
61
62impl fmt::Display for OcrError {
63    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64        match self {
65            OcrError::MissingField { field } => {
66                write!(f, "required builder field `{}` was not provided", field)
67            }
68            OcrError::Io { path, source } => {
69                write!(f, "failed to access resource {:?}: {}", path, source)
70            }
71            OcrError::ModelLoad { path, source } => {
72                write!(f, "failed to load ONNX model {:?}: {}", path, source)
73            }
74            OcrError::Dictionary { source } => write!(f, "failed to load dictionary: {}", source),
75            OcrError::InvalidConfiguration { message } => write!(f, "{}", message),
76            OcrError::ImageDecode { path, source } => {
77                write!(f, "failed to decode image {:?}: {}", path, source)
78            }
79            OcrError::DetectionPreprocess { source } => {
80                write!(f, "detection preprocessing failed: {}", source)
81            }
82            OcrError::DetectionInference { source } => {
83                write!(f, "detection inference failed: {}", source)
84            }
85            OcrError::DetectionPostProcess { source } => {
86                write!(f, "detection post-processing failed: {}", source)
87            }
88            OcrError::RecognitionPreprocess { source } => {
89                write!(f, "recognition preprocessing failed: {}", source)
90            }
91            OcrError::RecognitionInference { source } => {
92                write!(f, "recognition inference failed: {}", source)
93            }
94            OcrError::RecognitionPostProcess { source } => {
95                write!(f, "recognition post-processing failed: {}", source)
96            }
97            OcrError::PipelineMismatch {
98                detection_regions,
99                recognition_results,
100            } => write!(
101                f,
102                "pipeline mismatch: detection produced {} regions but recognition returned {} results",
103                detection_regions, recognition_results
104            ),
105        }
106    }
107}
108
109impl From<DetPreProcessorError> for OcrError {
110    fn from(source: DetPreProcessorError) -> Self {
111        OcrError::DetectionPreprocess { source }
112    }
113}
114
115impl From<DetPostProcessorError> for OcrError {
116    fn from(source: DetPostProcessorError) -> Self {
117        OcrError::DetectionPostProcess { source }
118    }
119}
120
121impl From<RecPreProcessorError> for OcrError {
122    fn from(source: RecPreProcessorError) -> Self {
123        OcrError::RecognitionPreprocess { source }
124    }
125}
126
127impl From<RecPostProcessorError> for OcrError {
128    fn from(source: RecPostProcessorError) -> Self {
129        OcrError::RecognitionPostProcess { source }
130    }
131}
132
133fn polygons_to_text_regions(
134    polygons: &[Polygon<f64>],
135    image_dims: (u32, u32),
136) -> Vec<RecTextRegion> {
137    polygons
138        .iter()
139        .map(|polygon| polygon_to_text_region(polygon, image_dims))
140        .collect()
141}
142
143fn polygon_to_text_region(polygon: &Polygon<f64>, image_dims: (u32, u32)) -> RecTextRegion {
144    let mut min_x = f64::INFINITY;
145    let mut min_y = f64::INFINITY;
146    let mut max_x = f64::NEG_INFINITY;
147    let mut max_y = f64::NEG_INFINITY;
148
149    for point in polygon.exterior().points() {
150        let x = point.x();
151        let y = point.y();
152        if x < min_x {
153            min_x = x;
154        }
155        if x > max_x {
156            max_x = x;
157        }
158        if y < min_y {
159            min_y = y;
160        }
161        if y > max_y {
162            max_y = y;
163        }
164    }
165
166    let image_width = image_dims.0.max(1);
167    let image_height = image_dims.1.max(1);
168    let width_limit = image_width as f64;
169    let height_limit = image_height as f64;
170
171    let mut x1 = min_x.floor().max(0.0);
172    let mut y1 = min_y.floor().max(0.0);
173    let mut x2 = max_x.ceil().min(width_limit);
174    let mut y2 = max_y.ceil().min(height_limit);
175
176    if x2 <= x1 {
177        x2 = (x1 + 1.0).min(width_limit);
178    }
179    if y2 <= y1 {
180        y2 = (y1 + 1.0).min(height_limit);
181    }
182
183    if x2 <= x1 {
184        x1 = (width_limit - 1.0).max(0.0);
185        x2 = width_limit;
186    }
187    if y2 <= y1 {
188        y1 = (height_limit - 1.0).max(0.0);
189        y2 = height_limit;
190    }
191
192    let mut x = x1.floor() as u32;
193    let mut y = y1.floor() as u32;
194    if x >= image_width {
195        x = image_width - 1;
196    }
197    if y >= image_height {
198        y = image_height - 1;
199    }
200
201    let mut width = (x2 - x1).ceil().max(1.0) as u32;
202    let mut height = (y2 - y1).ceil().max(1.0) as u32;
203
204    if x + width > image_width {
205        width = image_width.saturating_sub(x);
206    }
207    if y + height > image_height {
208        height = image_height.saturating_sub(y);
209    }
210
211    if width == 0 {
212        width = 1;
213    }
214    if height == 0 {
215        height = 1;
216    }
217
218    RecTextRegion {
219        x,
220        y,
221        width,
222        height,
223    }
224}
225
226impl Error for OcrError {
227    fn source(&self) -> Option<&(dyn Error + 'static)> {
228        match self {
229            OcrError::MissingField { .. } => None,
230            OcrError::Io { source, .. } => Some(source),
231            OcrError::ModelLoad { .. } => None,
232            OcrError::Dictionary { source } => Some(source),
233            OcrError::InvalidConfiguration { .. } => None,
234            OcrError::ImageDecode { source, .. } => Some(source),
235            OcrError::DetectionPreprocess { source } => Some(source),
236            OcrError::DetectionInference { .. } => None,
237            OcrError::DetectionPostProcess { source } => Some(source),
238            OcrError::RecognitionPreprocess { source } => Some(source),
239            OcrError::RecognitionInference { .. } => None,
240            OcrError::RecognitionPostProcess { source } => Some(source),
241            OcrError::PipelineMismatch { .. } => None,
242        }
243    }
244}
245
246impl From<DictionaryError> for OcrError {
247    fn from(source: DictionaryError) -> Self {
248        Self::Dictionary { source }
249    }
250}
251
252/// Aggregated configuration used by [`OcrEngine`] during inference.
253#[derive(Debug, Clone)]
254pub struct OcrEngineConfig {
255    pub det_preprocessor: DetPreProcessorConfig,
256    pub det_postprocessor: DetPostProcessorConfig,
257    pub det_unclipper: DetPolygonUnclipperConfig,
258    pub det_polygon_scaler: DetPolygonScalerConfig,
259    pub rec_preprocessor: RecPreProcessorConfig,
260    pub rec_postprocessor: RecPostProcessorConfig,
261    pub rec_batch_size: usize,
262}
263
264impl Default for OcrEngineConfig {
265    fn default() -> Self {
266        Self {
267            det_preprocessor: DetPreProcessorConfig::default(),
268            det_postprocessor: DetPostProcessorConfig::default(),
269            det_unclipper: DetPolygonUnclipperConfig::default(),
270            det_polygon_scaler: DetPolygonScalerConfig::default(),
271            rec_preprocessor: RecPreProcessorConfig::default(),
272            rec_postprocessor: RecPostProcessorConfig::default(),
273            rec_batch_size: 8,
274        }
275    }
276}
277
278/// Fully prepared OCR engine orchestrating the detection and recognition pipelines.
279///
280/// The engine executes inference synchronously: upcoming methods such as
281/// [`OcrEngine::run_from_path`](#method.run_from_path) and
282/// [`OcrEngine::run_from_image`](#method.run_from_image) (implemented in later tasks)
283/// will block the caller until the complete pipeline finishes. Internally, every heavy-weight
284/// component (preprocessors, ONNX sessions, dictionary and post-processors) is wrapped in
285/// `Arc`, allowing callers to share a single engine instance across threads or to clone the
286/// engine for concurrent use when needed.
287#[derive(Debug)]
288pub struct OcrEngine {
289    assets: EngineAssets,
290    detection: DetectionPipeline,
291    recognition: RecognitionPipeline,
292    config: OcrEngineConfig,
293}
294
295/// Result of running the full OCR pipeline for a single detected region.
296#[derive(Debug, Clone)]
297pub struct OcrResult {
298    pub text: String,
299    pub confidence: f32,
300    pub bounding_box: Polygon<f64>,
301}
302
/// Wall-clock time spent in each phase of one pipeline stage.
#[derive(Debug, Clone)]
pub struct StageTimings {
    /// Time spent preparing the model input.
    pub preprocess: Duration,
    /// Time spent running the ONNX session.
    pub inference: Duration,
    /// Time spent turning the model output into results.
    pub postprocess: Duration,
}

impl StageTimings {
    /// Returns timings with every phase set to zero.
    fn zero() -> Self {
        let nothing = Duration::from_secs(0);
        Self {
            preprocess: nothing,
            inference: nothing,
            postprocess: nothing,
        }
    }
}
319
320#[derive(Debug, Clone)]
321pub struct OcrTimings {
322    pub total: Duration,
323    pub image_decode: Duration,
324    pub detection: StageTimings,
325    pub recognition: StageTimings,
326}
327
328impl OcrTimings {
329    fn new() -> Self {
330        Self {
331            total: Duration::ZERO,
332            image_decode: Duration::ZERO,
333            detection: StageTimings::zero(),
334            recognition: StageTimings::zero(),
335        }
336    }
337}
338
339#[derive(Debug, Clone)]
340pub struct OcrRunWithMetrics {
341    pub results: Vec<OcrResult>,
342    pub timings: OcrTimings,
343}
344
345impl OcrEngine {
346    fn new(
347        det_model_path: PathBuf,
348        rec_model_path: PathBuf,
349        dictionary_path: PathBuf,
350        det_session: DetInferenceSession,
351        rec_session: RecInferenceSession,
352        dictionary: RecDictionary,
353        config: OcrEngineConfig,
354    ) -> Self {
355        let assets = EngineAssets::new(det_model_path, rec_model_path, dictionary_path);
356
357        let det_session = Arc::new(det_session);
358        let rec_session = Arc::new(rec_session);
359        let dictionary = Arc::new(dictionary);
360
361        let detection = DetectionPipeline::new(
362            Arc::clone(&det_session),
363            config.det_preprocessor,
364            config.det_postprocessor,
365            config.det_unclipper,
366            config.det_polygon_scaler,
367        );
368
369        let recognition = RecognitionPipeline::new(
370            Arc::clone(&rec_session),
371            Arc::clone(&dictionary),
372            config.rec_preprocessor.clone(),
373            config.rec_postprocessor.clone(),
374        );
375
376        Self {
377            assets,
378            detection,
379            recognition,
380            config,
381        }
382    }
383
384    /// Executes the full OCR pipeline on an image located on disk.
385    pub fn run_from_path<P: AsRef<Path>>(&self, path: P) -> Result<Vec<OcrResult>, OcrError> {
386        let run = self.run_with_metrics_from_path(path)?;
387        Ok(run.results)
388    }
389
390    /// Executes the full OCR pipeline on an image located on disk and returns benchmarking data.
391    pub fn run_with_metrics_from_path<P: AsRef<Path>>(
392        &self,
393        path: P,
394    ) -> Result<OcrRunWithMetrics, OcrError> {
395        let overall_start = Instant::now();
396        let path_ref = path.as_ref();
397        let decode_start = Instant::now();
398        let image = image::open(path_ref).map_err(|source| OcrError::ImageDecode {
399            source,
400            path: path_ref.to_path_buf(),
401        })?;
402        let mut run = self.run_with_metrics_from_image_impl(&image)?;
403        run.timings.image_decode = decode_start.elapsed();
404        run.timings.total = overall_start.elapsed();
405        Ok(run)
406    }
407
408    /// Executes the full OCR pipeline on an image already loaded in memory.
409    /// Returns the effective configuration for this engine.
410    pub fn config(&self) -> &OcrEngineConfig {
411        &self.config
412    }
413
414    /// Returns the path used for the detection model.
415    pub fn det_model_path(&self) -> &Path {
416        self.assets.det_model_path()
417    }
418
419    /// Returns the path used for the recognition model.
420    pub fn rec_model_path(&self) -> &Path {
421        self.assets.rec_model_path()
422    }
423
424    /// Returns the path used for the recognition dictionary.
425    pub fn dictionary_path(&self) -> &Path {
426        self.assets.dictionary_path()
427    }
428
429    /// Returns the configured recognition batch size.
430    pub fn rec_batch_size(&self) -> usize {
431        self.config.rec_batch_size
432    }
433
434    pub fn run_from_image(&self, image: &DynamicImage) -> Result<Vec<OcrResult>, OcrError> {
435        let run = self.run_with_metrics_from_image_impl(image)?;
436        Ok(run.results)
437    }
438
439    pub fn run_with_metrics_from_image(
440        &self,
441        image: &DynamicImage,
442    ) -> Result<OcrRunWithMetrics, OcrError> {
443        self.run_with_metrics_from_image_impl(image)
444    }
445
446    fn run_with_metrics_from_image_impl(
447        &self,
448        image: &DynamicImage,
449    ) -> Result<OcrRunWithMetrics, OcrError> {
450        let pipeline_start = Instant::now();
451        let mut timings = OcrTimings::new();
452        let image_dims = image.dimensions();
453
454        let (polygons, detection_timings) = self
455            .detection
456            .detect_polygons_with_timings(image, image_dims)?;
457        timings.detection = detection_timings;
458
459        if polygons.is_empty() {
460            timings.total = pipeline_start.elapsed();
461            return Ok(OcrRunWithMetrics {
462                results: Vec::new(),
463                timings,
464            });
465        }
466
467        let regions = polygons_to_text_regions(&polygons, image_dims);
468        let (sequences, recognition_timings) =
469            self.recognition.run_with_timings(image, &regions)?;
470        timings.recognition = recognition_timings;
471
472        if sequences.len() != polygons.len() {
473            return Err(OcrError::PipelineMismatch {
474                detection_regions: polygons.len(),
475                recognition_results: sequences.len(),
476            });
477        }
478
479        let results: Vec<OcrResult> = polygons
480            .into_iter()
481            .zip(sequences.into_iter())
482            .map(|(polygon, sequence)| OcrResult {
483                text: sequence.text,
484                confidence: sequence.confidence,
485                bounding_box: polygon,
486            })
487            .collect();
488
489        timings.total = pipeline_start.elapsed();
490
491        Ok(OcrRunWithMetrics { results, timings })
492    }
493}
494
/// Builder for constructing [`OcrEngine`] instances.
///
/// The three asset paths are mandatory. The tuning values start from the defaults
/// of the corresponding component configurations.
#[derive(Debug, Clone)]
pub struct OcrEngineBuilder {
    /// Detection model path (required).
    det_model_path: Option<PathBuf>,
    /// Recognition model path (required).
    rec_model_path: Option<PathBuf>,
    /// Recognition dictionary path (required).
    dictionary_path: Option<PathBuf>,
    /// Forwarded to `DetPreProcessorConfig::limit_side_len`.
    det_limit_side_len: u32,
    /// Forwarded to `DetPolygonUnclipperConfig::unclip_ratio`.
    det_unclip_ratio: f32,
    /// Maximum recognition batch size; must be non-zero.
    rec_batch_size: usize,
}
505
506impl Default for OcrEngineBuilder {
507    fn default() -> Self {
508        Self {
509            det_model_path: None,
510            rec_model_path: None,
511            dictionary_path: None,
512            det_limit_side_len: DetPreProcessorConfig::default().limit_side_len,
513            det_unclip_ratio: DetPolygonUnclipperConfig::default().unclip_ratio,
514            rec_batch_size: OcrEngineConfig::default().rec_batch_size,
515        }
516    }
517}
518
519impl OcrEngineBuilder {
520    /// Creates a new builder instance using default configuration values.
521    pub fn new() -> Self {
522        Self::default()
523    }
524
525    /// Sets the path to the DBNet detection ONNX model.
526    pub fn det_model_path<P: AsRef<Path>>(mut self, path: P) -> Self {
527        self.det_model_path = Some(path.as_ref().to_path_buf());
528        self
529    }
530
531    /// Sets the path to the SVTR recognition ONNX model.
532    pub fn rec_model_path<P: AsRef<Path>>(mut self, path: P) -> Self {
533        self.rec_model_path = Some(path.as_ref().to_path_buf());
534        self
535    }
536
537    /// Sets the path to the recognition dictionary file.
538    pub fn dictionary_path<P: AsRef<Path>>(mut self, path: P) -> Self {
539        self.dictionary_path = Some(path.as_ref().to_path_buf());
540        self
541    }
542
543    /// Sets the maximum side length for detection preprocessing.
544    pub fn det_limit_side_len(mut self, len: u32) -> Self {
545        self.det_limit_side_len = len;
546        self
547    }
548
549    /// Sets the unclip ratio used during polygon offsetting.
550    pub fn det_unclip_ratio(mut self, ratio: f64) -> Self {
551        self.det_unclip_ratio = ratio as f32;
552        self
553    }
554
555    /// Sets the maximum batch size for recognition.
556    pub fn rec_batch_size(mut self, size: usize) -> Self {
557        self.rec_batch_size = size;
558        self
559    }
560
561    /// Consumes the builder and attempts to construct an [`OcrEngine`].
562    pub fn build(self) -> Result<OcrEngine, OcrError> {
563        let det_model_path = self.det_model_path.ok_or(OcrError::MissingField {
564            field: "det_model_path",
565        })?;
566        let rec_model_path = self.rec_model_path.ok_or(OcrError::MissingField {
567            field: "rec_model_path",
568        })?;
569        let dictionary_path = self.dictionary_path.ok_or(OcrError::MissingField {
570            field: "dictionary_path",
571        })?;
572
573        if self.rec_batch_size == 0 {
574            return Err(OcrError::InvalidConfiguration {
575                message: "rec_batch_size must be greater than zero".to_string(),
576            });
577        }
578
579        verify_file_exists(&det_model_path)?;
580        verify_file_exists(&rec_model_path)?;
581        verify_file_exists(&dictionary_path)?;
582
583        let det_session =
584            DetInferenceSession::load(&det_model_path).map_err(|source| OcrError::ModelLoad {
585                source,
586                path: det_model_path.clone(),
587            })?;
588        let rec_session =
589            RecInferenceSession::load(&rec_model_path).map_err(|source| OcrError::ModelLoad {
590                source,
591                path: rec_model_path.clone(),
592            })?;
593        let dictionary = RecDictionary::from_path(&dictionary_path)?;
594
595        let mut det_unclipper_config = DetPolygonUnclipperConfig::default();
596        det_unclipper_config.unclip_ratio = self.det_unclip_ratio;
597
598        let mut det_preprocessor_config = DetPreProcessorConfig::default();
599        det_preprocessor_config.limit_side_len = self.det_limit_side_len;
600
601        let mut config = OcrEngineConfig::default();
602        config.det_preprocessor = det_preprocessor_config;
603        config.det_unclipper = det_unclipper_config;
604        config.rec_batch_size = self.rec_batch_size;
605        config.rec_postprocessor.blank_id = dictionary.blank_id();
606
607        Ok(OcrEngine::new(
608            det_model_path,
609            rec_model_path,
610            dictionary_path,
611            det_session,
612            rec_session,
613            dictionary,
614            config,
615        ))
616    }
617}
618
619fn verify_file_exists(path: &Path) -> Result<(), OcrError> {
620    if let Err(source) = fs::metadata(path) {
621        return Err(OcrError::Io {
622            source,
623            path: path.to_path_buf(),
624        });
625    }
626    Ok(())
627}
628
/// Paths of the on-disk assets an engine was built from, kept for introspection.
#[derive(Debug)]
struct EngineAssets {
    det_model_path: PathBuf,
    rec_model_path: PathBuf,
    dictionary_path: PathBuf,
}

impl EngineAssets {
    /// Records the three asset paths.
    fn new(det_model_path: PathBuf, rec_model_path: PathBuf, dictionary_path: PathBuf) -> Self {
        EngineAssets {
            det_model_path,
            rec_model_path,
            dictionary_path,
        }
    }

    /// Path of the detection model.
    fn det_model_path(&self) -> &Path {
        &self.det_model_path
    }

    /// Path of the recognition model.
    fn rec_model_path(&self) -> &Path {
        &self.rec_model_path
    }

    /// Path of the recognition dictionary.
    fn dictionary_path(&self) -> &Path {
        &self.dictionary_path
    }
}
657
658#[derive(Debug)]
659struct DetectionPipeline {
660    preprocessor: DetPreProcessor,
661    session: Arc<DetInferenceSession>,
662    postprocessor: DetPostProcessor,
663    unclipper: DetPolygonUnclipper,
664    scaler: DetPolygonScaler,
665}
666
667impl DetectionPipeline {
668    fn new(
669        session: Arc<DetInferenceSession>,
670        preprocessor: DetPreProcessorConfig,
671        postprocessor: DetPostProcessorConfig,
672        unclipper: DetPolygonUnclipperConfig,
673        scaler: DetPolygonScalerConfig,
674    ) -> Self {
675        Self {
676            preprocessor: DetPreProcessor::new(preprocessor),
677            session,
678            postprocessor: DetPostProcessor::new(postprocessor),
679            unclipper: DetPolygonUnclipper::new(unclipper),
680            scaler: DetPolygonScaler::new(scaler),
681        }
682    }
683
684    fn detect_polygons_with_timings(
685        &self,
686        image: &DynamicImage,
687        image_dims: (u32, u32),
688    ) -> Result<(Vec<Polygon<f64>>, StageTimings), OcrError> {
689        let preprocess_start = Instant::now();
690        let preprocessed = self.preprocessor.process(image).map_err(OcrError::from)?;
691        let preprocess_elapsed = preprocess_start.elapsed();
692
693        let inference_start = Instant::now();
694        let inference = self
695            .session
696            .run(&preprocessed)
697            .map_err(|source| OcrError::DetectionInference { source })?;
698        let inference_elapsed = inference_start.elapsed();
699
700        let post_start = Instant::now();
701        let contours = self
702            .postprocessor
703            .process(&inference)
704            .map_err(OcrError::from)?;
705        let unclipped = self.unclipper.unclip_contours(&contours);
706        let scaled = self
707            .scaler
708            .scale_polygons(&unclipped, preprocessed.scale_ratio, image_dims);
709        let post_elapsed = post_start.elapsed();
710
711        let timings = StageTimings {
712            preprocess: preprocess_elapsed,
713            inference: inference_elapsed,
714            postprocess: post_elapsed,
715        };
716
717        Ok((scaled, timings))
718    }
719}
720
721#[derive(Debug)]
722struct RecognitionPipeline {
723    preprocessor: RecPreProcessor,
724    session: Arc<RecInferenceSession>,
725    postprocessor: RecPostProcessor,
726}
727
728impl RecognitionPipeline {
729    fn new(
730        session: Arc<RecInferenceSession>,
731        dictionary: Arc<RecDictionary>,
732        preprocessor: RecPreProcessorConfig,
733        postprocessor: RecPostProcessorConfig,
734    ) -> Self {
735        let postprocessor = RecPostProcessor::new(Arc::clone(&dictionary), postprocessor);
736
737        Self {
738            preprocessor: RecPreProcessor::new(preprocessor),
739            session,
740            postprocessor,
741        }
742    }
743
744    fn run_with_timings(
745        &self,
746        image: &DynamicImage,
747        regions: &[RecTextRegion],
748    ) -> Result<(Vec<DecodedSequence>, StageTimings), OcrError> {
749        let preprocess_start = Instant::now();
750        let batch = self
751            .preprocessor
752            .process(image, regions)
753            .map_err(OcrError::from)?;
754        let preprocess_elapsed = preprocess_start.elapsed();
755
756        let inference_start = Instant::now();
757        let inference = self
758            .session
759            .run(&batch)
760            .map_err(|source| OcrError::RecognitionInference { source })?;
761        let inference_elapsed = inference_start.elapsed();
762
763        let post_start = Instant::now();
764        let sequences = self
765            .postprocessor
766            .process(&inference)
767            .map_err(OcrError::from)?;
768        let post_elapsed = post_start.elapsed();
769
770        let timings = StageTimings {
771            preprocess: preprocess_elapsed,
772            inference: inference_elapsed,
773            postprocess: post_elapsed,
774        };
775
776        Ok((sequences, timings))
777    }
778}
779
780#[cfg(test)]
781mod tests {
782    use super::*;
783    use crate::ctc::CtcGreedyDecoderError;
784    use crate::dictionary::RecDictionary;
785    use crate::postprocessing::DetPostProcessorError;
786    use crate::preprocessing::{DetPreProcessorError, RecPreProcessorError};
787    use crate::recognition::RecPostProcessorError;
788    use std::env;
789    use std::path::Path;
790    use std::time::{SystemTime, UNIX_EPOCH};
791
792    fn locate_ppocrv5_asset(file_name: &str) -> Option<PathBuf> {
793        let mut bases: Vec<PathBuf> = Vec::new();
794        if let Some(dir) = env::var_os("PURE_ONNX_OCR_FIXTURE_DIR") {
795            let env_path = PathBuf::from(dir);
796            bases.push(env_path.clone());
797            bases.push(env_path.join("models"));
798        }
799
800        let manifest = Path::new(env!("CARGO_MANIFEST_DIR"));
801        bases.push(manifest.join("tests").join("fixtures").join("models"));
802        bases.push(manifest.join("tests").join("fixtures"));
803        bases.push(manifest.join("models"));
804
805        for base in bases {
806            let ppocr_dir = base.join("ppocrv5");
807            let candidate = ppocr_dir.join(file_name);
808            if candidate.exists() {
809                return Some(candidate);
810            }
811
812            let alt = base.join(file_name);
813            if alt.exists() {
814                return Some(alt);
815            }
816        }
817
818        None
819    }
820
821    fn existing_model_paths() -> Option<(PathBuf, PathBuf, PathBuf)> {
822        let det = locate_ppocrv5_asset("det.onnx")?;
823        let rec = locate_ppocrv5_asset("rec.onnx")?;
824        let dict = locate_ppocrv5_asset("ppocrv5_dict.txt")?;
825        Some((det, rec, dict))
826    }
827
828    fn temp_image_path(prefix: &str) -> PathBuf {
829        let timestamp = SystemTime::now()
830            .duration_since(UNIX_EPOCH)
831            .unwrap()
832            .as_nanos();
833        std::env::temp_dir().join(format!("{}_{}.png", prefix, timestamp))
834    }
835
836    #[test]
837    fn missing_det_model_path_returns_error() {
838        let err = OcrEngineBuilder::new()
839            .rec_model_path("rec.onnx")
840            .dictionary_path("dict.txt")
841            .build()
842            .unwrap_err();
843
844        match err {
845            OcrError::MissingField { field } => assert_eq!(field, "det_model_path"),
846            other => panic!("expected MissingField error, got {:?}", other),
847        }
848    }
849
850    #[test]
851    fn missing_dictionary_path_returns_error() {
852        let err = OcrEngineBuilder::new()
853            .det_model_path("det.onnx")
854            .rec_model_path("rec.onnx")
855            .build()
856            .unwrap_err();
857
858        match err {
859            OcrError::MissingField { field } => assert_eq!(field, "dictionary_path"),
860            other => panic!("expected MissingField error, got {:?}", other),
861        }
862    }
863
864    #[test]
865    fn zero_recognition_batch_size_is_rejected() {
866        let err = OcrEngineBuilder::new()
867            .det_model_path("det.onnx")
868            .rec_model_path("rec.onnx")
869            .dictionary_path("dict.txt")
870            .rec_batch_size(0)
871            .build()
872            .unwrap_err();
873
874        match err {
875            OcrError::InvalidConfiguration { message } => {
876                assert!(message.contains("rec_batch_size"));
877            }
878            other => panic!("expected InvalidConfiguration error, got {:?}", other),
879        }
880    }
881
882    #[test]
883    fn build_succeeds_when_paths_exist() {
884        let (det, rec, dict) = existing_model_paths()
885            .expect("expected PP-OCRv5 assets to be present under models/ppocrv5/");
886
887        let engine = OcrEngineBuilder::new()
888            .det_model_path(&det)
889            .rec_model_path(&rec)
890            .dictionary_path(&dict)
891            .det_limit_side_len(1024)
892            .det_unclip_ratio(2.0)
893            .rec_batch_size(4)
894            .build()
895            .expect("engine should build successfully");
896
897        assert_eq!(engine.config().det_preprocessor.limit_side_len, 1024);
898        assert!((engine.config().det_unclipper.unclip_ratio - 2.0).abs() < f32::EPSILON);
899        assert_eq!(engine.config().rec_batch_size, 4);
900    }
901
902    #[test]
903    fn engine_reports_asset_paths_and_batch_size() {
904        let (det, rec, dict) = existing_model_paths()
905            .expect("expected PP-OCRv5 assets to be present under models/ppocrv5/");
906
907        let engine = OcrEngineBuilder::new()
908            .det_model_path(&det)
909            .rec_model_path(&rec)
910            .dictionary_path(&dict)
911            .rec_batch_size(6)
912            .build()
913            .expect("engine should build successfully");
914
915        assert_eq!(engine.det_model_path(), det.as_path());
916        assert_eq!(engine.rec_model_path(), rec.as_path());
917        assert_eq!(engine.dictionary_path(), dict.as_path());
918        assert_eq!(engine.rec_batch_size(), 6);
919    }
920
921    #[test]
922    fn recognition_blank_id_matches_dictionary_blank_id() {
923        let (det, rec, dict) = existing_model_paths()
924            .expect("expected PP-OCRv5 assets to be present under models/ppocrv5/");
925
926        let dictionary_blank_id = RecDictionary::from_path(&dict)
927            .expect("dictionary should load successfully")
928            .blank_id();
929
930        let engine = OcrEngineBuilder::new()
931            .det_model_path(&det)
932            .rec_model_path(&rec)
933            .dictionary_path(&dict)
934            .build()
935            .expect("engine should build successfully");
936
937        assert_eq!(
938            engine.config().rec_postprocessor.blank_id,
939            dictionary_blank_id
940        );
941    }
942
943    #[test]
944    fn run_from_path_processes_blank_image() -> Result<(), OcrError> {
945        let (det, rec, dict) = existing_model_paths()
946            .expect("expected PP-OCRv5 assets to be present under models/ppocrv5/");
947
948        let engine = OcrEngineBuilder::new()
949            .det_model_path(&det)
950            .rec_model_path(&rec)
951            .dictionary_path(&dict)
952            .build()
953            .expect("engine should build successfully");
954
955        let temp_path = temp_image_path("run_path_blank");
956        let image_buffer = image::ImageBuffer::from_pixel(64, 32, image::Rgb([0, 0, 0]));
957        DynamicImage::ImageRgb8(image_buffer)
958            .save(&temp_path)
959            .expect("failed to save temporary image");
960
961        let results = engine.run_from_path(&temp_path)?;
962        assert!(
963            results.len() <= engine.rec_batch_size(),
964            "number of results should not exceed configured batch size"
965        );
966
967        std::fs::remove_file(&temp_path).ok();
968        Ok(())
969    }
970
971    #[test]
972    fn run_from_image_reuses_pipeline() -> Result<(), OcrError> {
973        let (det, rec, dict) = existing_model_paths()
974            .expect("expected PP-OCRv5 assets to be present under models/ppocrv5/");
975
976        let engine = OcrEngineBuilder::new()
977            .det_model_path(&det)
978            .rec_model_path(&rec)
979            .dictionary_path(&dict)
980            .build()
981            .expect("engine should build successfully");
982
983        let image_buffer = image::ImageBuffer::from_pixel(32, 192, image::Rgb([255, 255, 255]));
984        let dynamic_image = DynamicImage::ImageRgb8(image_buffer);
985        let results = engine.run_from_image(&dynamic_image)?;
986
987        assert!(
988            results.len() <= engine.rec_batch_size(),
989            "number of results should not exceed configured batch size"
990        );
991
992        Ok(())
993    }
994
995    #[test]
996    fn run_with_metrics_reports_timings() -> Result<(), OcrError> {
997        let (det, rec, dict) = existing_model_paths()
998            .expect("expected PP-OCRv5 assets to be present under models/ppocrv5/");
999
1000        let engine = OcrEngineBuilder::new()
1001            .det_model_path(&det)
1002            .rec_model_path(&rec)
1003            .dictionary_path(&dict)
1004            .build()
1005            .expect("engine should build successfully");
1006
1007        let image_buffer = image::ImageBuffer::from_pixel(16, 16, image::Rgb([0, 0, 0]));
1008        let dynamic_image = DynamicImage::ImageRgb8(image_buffer);
1009
1010        let run_with_metrics = engine.run_with_metrics_from_image(&dynamic_image)?;
1011        let baseline_results = engine.run_from_image(&dynamic_image)?;
1012
1013        assert_eq!(run_with_metrics.results.len(), baseline_results.len());
1014        assert!(run_with_metrics.timings.total >= run_with_metrics.timings.detection.preprocess);
1015        assert!(run_with_metrics.timings.recognition.preprocess <= run_with_metrics.timings.total);
1016
1017        Ok(())
1018    }
1019
1020    #[test]
1021    fn component_errors_convert_to_ocr_error_variants() {
1022        match OcrError::from(DetPreProcessorError::EmptyImage) {
1023            OcrError::DetectionPreprocess { .. } => {}
1024            other => panic!("expected DetectionPreprocess variant, got {:?}", other),
1025        }
1026
1027        match OcrError::from(DetPostProcessorError::EmptyProbabilityMap) {
1028            OcrError::DetectionPostProcess { .. } => {}
1029            other => panic!("expected DetectionPostProcess variant, got {:?}", other),
1030        }
1031
1032        match OcrError::from(RecPreProcessorError::EmptyRegions) {
1033            OcrError::RecognitionPreprocess { .. } => {}
1034            other => panic!("expected RecognitionPreprocess variant, got {:?}", other),
1035        }
1036
1037        let rec_post_err = RecPostProcessorError::from(CtcGreedyDecoderError::EmptyBatch);
1038        match OcrError::from(rec_post_err) {
1039            OcrError::RecognitionPostProcess { .. } => {}
1040            other => panic!("expected RecognitionPostProcess variant, got {:?}", other),
1041        }
1042    }
1043}