1use crate::ctc::DecodedSequence;
2use crate::detection::DetInferenceSession;
3use crate::dictionary::{DictionaryError, RecDictionary};
4use crate::postprocessing::{
5 DetPolygonScaler, DetPolygonScalerConfig, DetPolygonUnclipper, DetPolygonUnclipperConfig,
6 DetPostProcessor, DetPostProcessorConfig, DetPostProcessorError,
7};
8use crate::preprocessing::{
9 DetPreProcessor, DetPreProcessorConfig, DetPreProcessorError, RecPreProcessor,
10 RecPreProcessorConfig, RecPreProcessorError, RecTextRegion,
11};
12use crate::recognition::{
13 RecInferenceSession, RecPostProcessor, RecPostProcessorConfig, RecPostProcessorError,
14};
15use geo_types::Polygon;
16use image::{DynamicImage, GenericImageView, ImageError};
17use std::error::Error;
18use std::fmt;
19use std::fs;
20use std::path::{Path, PathBuf};
21use std::sync::Arc;
22use std::time::{Duration, Instant};
23use tract_onnx::prelude::TractError;
24
25#[derive(Debug)]
27pub enum OcrError {
28 MissingField { field: &'static str },
30 Io {
32 source: std::io::Error,
33 path: PathBuf,
34 },
35 ModelLoad { source: TractError, path: PathBuf },
37 Dictionary { source: DictionaryError },
39 InvalidConfiguration { message: String },
41 ImageDecode { source: ImageError, path: PathBuf },
43 DetectionPreprocess { source: DetPreProcessorError },
45 DetectionInference { source: TractError },
47 DetectionPostProcess { source: DetPostProcessorError },
49 RecognitionPreprocess { source: RecPreProcessorError },
51 RecognitionInference { source: TractError },
53 RecognitionPostProcess { source: RecPostProcessorError },
55 PipelineMismatch {
57 detection_regions: usize,
58 recognition_results: usize,
59 },
60}
61
62impl fmt::Display for OcrError {
63 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64 match self {
65 OcrError::MissingField { field } => {
66 write!(f, "required builder field `{}` was not provided", field)
67 }
68 OcrError::Io { path, source } => {
69 write!(f, "failed to access resource {:?}: {}", path, source)
70 }
71 OcrError::ModelLoad { path, source } => {
72 write!(f, "failed to load ONNX model {:?}: {}", path, source)
73 }
74 OcrError::Dictionary { source } => write!(f, "failed to load dictionary: {}", source),
75 OcrError::InvalidConfiguration { message } => write!(f, "{}", message),
76 OcrError::ImageDecode { path, source } => {
77 write!(f, "failed to decode image {:?}: {}", path, source)
78 }
79 OcrError::DetectionPreprocess { source } => {
80 write!(f, "detection preprocessing failed: {}", source)
81 }
82 OcrError::DetectionInference { source } => {
83 write!(f, "detection inference failed: {}", source)
84 }
85 OcrError::DetectionPostProcess { source } => {
86 write!(f, "detection post-processing failed: {}", source)
87 }
88 OcrError::RecognitionPreprocess { source } => {
89 write!(f, "recognition preprocessing failed: {}", source)
90 }
91 OcrError::RecognitionInference { source } => {
92 write!(f, "recognition inference failed: {}", source)
93 }
94 OcrError::RecognitionPostProcess { source } => {
95 write!(f, "recognition post-processing failed: {}", source)
96 }
97 OcrError::PipelineMismatch {
98 detection_regions,
99 recognition_results,
100 } => write!(
101 f,
102 "pipeline mismatch: detection produced {} regions but recognition returned {} results",
103 detection_regions, recognition_results
104 ),
105 }
106 }
107}
108
109impl From<DetPreProcessorError> for OcrError {
110 fn from(source: DetPreProcessorError) -> Self {
111 OcrError::DetectionPreprocess { source }
112 }
113}
114
115impl From<DetPostProcessorError> for OcrError {
116 fn from(source: DetPostProcessorError) -> Self {
117 OcrError::DetectionPostProcess { source }
118 }
119}
120
121impl From<RecPreProcessorError> for OcrError {
122 fn from(source: RecPreProcessorError) -> Self {
123 OcrError::RecognitionPreprocess { source }
124 }
125}
126
127impl From<RecPostProcessorError> for OcrError {
128 fn from(source: RecPostProcessorError) -> Self {
129 OcrError::RecognitionPostProcess { source }
130 }
131}
132
133fn polygons_to_text_regions(
134 polygons: &[Polygon<f64>],
135 image_dims: (u32, u32),
136) -> Vec<RecTextRegion> {
137 polygons
138 .iter()
139 .map(|polygon| polygon_to_text_region(polygon, image_dims))
140 .collect()
141}
142
143fn polygon_to_text_region(polygon: &Polygon<f64>, image_dims: (u32, u32)) -> RecTextRegion {
144 let mut min_x = f64::INFINITY;
145 let mut min_y = f64::INFINITY;
146 let mut max_x = f64::NEG_INFINITY;
147 let mut max_y = f64::NEG_INFINITY;
148
149 for point in polygon.exterior().points() {
150 let x = point.x();
151 let y = point.y();
152 if x < min_x {
153 min_x = x;
154 }
155 if x > max_x {
156 max_x = x;
157 }
158 if y < min_y {
159 min_y = y;
160 }
161 if y > max_y {
162 max_y = y;
163 }
164 }
165
166 let image_width = image_dims.0.max(1);
167 let image_height = image_dims.1.max(1);
168 let width_limit = image_width as f64;
169 let height_limit = image_height as f64;
170
171 let mut x1 = min_x.floor().max(0.0);
172 let mut y1 = min_y.floor().max(0.0);
173 let mut x2 = max_x.ceil().min(width_limit);
174 let mut y2 = max_y.ceil().min(height_limit);
175
176 if x2 <= x1 {
177 x2 = (x1 + 1.0).min(width_limit);
178 }
179 if y2 <= y1 {
180 y2 = (y1 + 1.0).min(height_limit);
181 }
182
183 if x2 <= x1 {
184 x1 = (width_limit - 1.0).max(0.0);
185 x2 = width_limit;
186 }
187 if y2 <= y1 {
188 y1 = (height_limit - 1.0).max(0.0);
189 y2 = height_limit;
190 }
191
192 let mut x = x1.floor() as u32;
193 let mut y = y1.floor() as u32;
194 if x >= image_width {
195 x = image_width - 1;
196 }
197 if y >= image_height {
198 y = image_height - 1;
199 }
200
201 let mut width = (x2 - x1).ceil().max(1.0) as u32;
202 let mut height = (y2 - y1).ceil().max(1.0) as u32;
203
204 if x + width > image_width {
205 width = image_width.saturating_sub(x);
206 }
207 if y + height > image_height {
208 height = image_height.saturating_sub(y);
209 }
210
211 if width == 0 {
212 width = 1;
213 }
214 if height == 0 {
215 height = 1;
216 }
217
218 RecTextRegion {
219 x,
220 y,
221 width,
222 height,
223 }
224}
225
226impl Error for OcrError {
227 fn source(&self) -> Option<&(dyn Error + 'static)> {
228 match self {
229 OcrError::MissingField { .. } => None,
230 OcrError::Io { source, .. } => Some(source),
231 OcrError::ModelLoad { .. } => None,
232 OcrError::Dictionary { source } => Some(source),
233 OcrError::InvalidConfiguration { .. } => None,
234 OcrError::ImageDecode { source, .. } => Some(source),
235 OcrError::DetectionPreprocess { source } => Some(source),
236 OcrError::DetectionInference { .. } => None,
237 OcrError::DetectionPostProcess { source } => Some(source),
238 OcrError::RecognitionPreprocess { source } => Some(source),
239 OcrError::RecognitionInference { .. } => None,
240 OcrError::RecognitionPostProcess { source } => Some(source),
241 OcrError::PipelineMismatch { .. } => None,
242 }
243 }
244}
245
246impl From<DictionaryError> for OcrError {
247 fn from(source: DictionaryError) -> Self {
248 Self::Dictionary { source }
249 }
250}
251
252#[derive(Debug, Clone)]
254pub struct OcrEngineConfig {
255 pub det_preprocessor: DetPreProcessorConfig,
256 pub det_postprocessor: DetPostProcessorConfig,
257 pub det_unclipper: DetPolygonUnclipperConfig,
258 pub det_polygon_scaler: DetPolygonScalerConfig,
259 pub rec_preprocessor: RecPreProcessorConfig,
260 pub rec_postprocessor: RecPostProcessorConfig,
261 pub rec_batch_size: usize,
262}
263
264impl Default for OcrEngineConfig {
265 fn default() -> Self {
266 Self {
267 det_preprocessor: DetPreProcessorConfig::default(),
268 det_postprocessor: DetPostProcessorConfig::default(),
269 det_unclipper: DetPolygonUnclipperConfig::default(),
270 det_polygon_scaler: DetPolygonScalerConfig::default(),
271 rec_preprocessor: RecPreProcessorConfig::default(),
272 rec_postprocessor: RecPostProcessorConfig::default(),
273 rec_batch_size: 8,
274 }
275 }
276}
277
278#[derive(Debug)]
288pub struct OcrEngine {
289 assets: EngineAssets,
290 detection: DetectionPipeline,
291 recognition: RecognitionPipeline,
292 config: OcrEngineConfig,
293}
294
295#[derive(Debug, Clone)]
297pub struct OcrResult {
298 pub text: String,
299 pub confidence: f32,
300 pub bounding_box: Polygon<f64>,
301}
302
/// Wall-clock time spent in each phase of a single pipeline stage.
#[derive(Debug, Clone)]
pub struct StageTimings {
    /// Time spent preparing model input.
    pub preprocess: Duration,
    /// Time spent running the model session.
    pub inference: Duration,
    /// Time spent converting raw model output into results.
    pub postprocess: Duration,
}

impl StageTimings {
    /// Timings with every phase at zero, used for stages that have not run.
    fn zero() -> Self {
        let nothing = Duration::default();
        Self {
            preprocess: nothing,
            inference: nothing,
            postprocess: nothing,
        }
    }
}
319
320#[derive(Debug, Clone)]
321pub struct OcrTimings {
322 pub total: Duration,
323 pub image_decode: Duration,
324 pub detection: StageTimings,
325 pub recognition: StageTimings,
326}
327
328impl OcrTimings {
329 fn new() -> Self {
330 Self {
331 total: Duration::ZERO,
332 image_decode: Duration::ZERO,
333 detection: StageTimings::zero(),
334 recognition: StageTimings::zero(),
335 }
336 }
337}
338
339#[derive(Debug, Clone)]
340pub struct OcrRunWithMetrics {
341 pub results: Vec<OcrResult>,
342 pub timings: OcrTimings,
343}
344
345impl OcrEngine {
346 fn new(
347 det_model_path: PathBuf,
348 rec_model_path: PathBuf,
349 dictionary_path: PathBuf,
350 det_session: DetInferenceSession,
351 rec_session: RecInferenceSession,
352 dictionary: RecDictionary,
353 config: OcrEngineConfig,
354 ) -> Self {
355 let assets = EngineAssets::new(det_model_path, rec_model_path, dictionary_path);
356
357 let det_session = Arc::new(det_session);
358 let rec_session = Arc::new(rec_session);
359 let dictionary = Arc::new(dictionary);
360
361 let detection = DetectionPipeline::new(
362 Arc::clone(&det_session),
363 config.det_preprocessor,
364 config.det_postprocessor,
365 config.det_unclipper,
366 config.det_polygon_scaler,
367 );
368
369 let recognition = RecognitionPipeline::new(
370 Arc::clone(&rec_session),
371 Arc::clone(&dictionary),
372 config.rec_preprocessor.clone(),
373 config.rec_postprocessor.clone(),
374 );
375
376 Self {
377 assets,
378 detection,
379 recognition,
380 config,
381 }
382 }
383
384 pub fn run_from_path<P: AsRef<Path>>(&self, path: P) -> Result<Vec<OcrResult>, OcrError> {
386 let run = self.run_with_metrics_from_path(path)?;
387 Ok(run.results)
388 }
389
390 pub fn run_with_metrics_from_path<P: AsRef<Path>>(
392 &self,
393 path: P,
394 ) -> Result<OcrRunWithMetrics, OcrError> {
395 let overall_start = Instant::now();
396 let path_ref = path.as_ref();
397 let decode_start = Instant::now();
398 let image = image::open(path_ref).map_err(|source| OcrError::ImageDecode {
399 source,
400 path: path_ref.to_path_buf(),
401 })?;
402 let mut run = self.run_with_metrics_from_image_impl(&image)?;
403 run.timings.image_decode = decode_start.elapsed();
404 run.timings.total = overall_start.elapsed();
405 Ok(run)
406 }
407
408 pub fn config(&self) -> &OcrEngineConfig {
411 &self.config
412 }
413
414 pub fn det_model_path(&self) -> &Path {
416 self.assets.det_model_path()
417 }
418
419 pub fn rec_model_path(&self) -> &Path {
421 self.assets.rec_model_path()
422 }
423
424 pub fn dictionary_path(&self) -> &Path {
426 self.assets.dictionary_path()
427 }
428
429 pub fn rec_batch_size(&self) -> usize {
431 self.config.rec_batch_size
432 }
433
434 pub fn run_from_image(&self, image: &DynamicImage) -> Result<Vec<OcrResult>, OcrError> {
435 let run = self.run_with_metrics_from_image_impl(image)?;
436 Ok(run.results)
437 }
438
439 pub fn run_with_metrics_from_image(
440 &self,
441 image: &DynamicImage,
442 ) -> Result<OcrRunWithMetrics, OcrError> {
443 self.run_with_metrics_from_image_impl(image)
444 }
445
446 fn run_with_metrics_from_image_impl(
447 &self,
448 image: &DynamicImage,
449 ) -> Result<OcrRunWithMetrics, OcrError> {
450 let pipeline_start = Instant::now();
451 let mut timings = OcrTimings::new();
452 let image_dims = image.dimensions();
453
454 let (polygons, detection_timings) = self
455 .detection
456 .detect_polygons_with_timings(image, image_dims)?;
457 timings.detection = detection_timings;
458
459 if polygons.is_empty() {
460 timings.total = pipeline_start.elapsed();
461 return Ok(OcrRunWithMetrics {
462 results: Vec::new(),
463 timings,
464 });
465 }
466
467 let regions = polygons_to_text_regions(&polygons, image_dims);
468 let (sequences, recognition_timings) =
469 self.recognition.run_with_timings(image, ®ions)?;
470 timings.recognition = recognition_timings;
471
472 if sequences.len() != polygons.len() {
473 return Err(OcrError::PipelineMismatch {
474 detection_regions: polygons.len(),
475 recognition_results: sequences.len(),
476 });
477 }
478
479 let results: Vec<OcrResult> = polygons
480 .into_iter()
481 .zip(sequences.into_iter())
482 .map(|(polygon, sequence)| OcrResult {
483 text: sequence.text,
484 confidence: sequence.confidence,
485 bounding_box: polygon,
486 })
487 .collect();
488
489 timings.total = pipeline_start.elapsed();
490
491 Ok(OcrRunWithMetrics { results, timings })
492 }
493}
494
/// Builder for [`OcrEngine`].
///
/// The three asset paths are required; the tuning values start from the stage defaults.
#[derive(Debug, Clone)]
pub struct OcrEngineBuilder {
    /// Detection model path (required).
    det_model_path: Option<PathBuf>,
    /// Recognition model path (required).
    rec_model_path: Option<PathBuf>,
    /// Recognition dictionary path (required).
    dictionary_path: Option<PathBuf>,
    /// Copied into the detection preprocessor's `limit_side_len`.
    det_limit_side_len: u32,
    /// Copied into the detection unclipper's `unclip_ratio`.
    det_unclip_ratio: f32,
    /// Recognition batch size; must be non-zero.
    rec_batch_size: usize,
}
505
506impl Default for OcrEngineBuilder {
507 fn default() -> Self {
508 Self {
509 det_model_path: None,
510 rec_model_path: None,
511 dictionary_path: None,
512 det_limit_side_len: DetPreProcessorConfig::default().limit_side_len,
513 det_unclip_ratio: DetPolygonUnclipperConfig::default().unclip_ratio,
514 rec_batch_size: OcrEngineConfig::default().rec_batch_size,
515 }
516 }
517}
518
519impl OcrEngineBuilder {
520 pub fn new() -> Self {
522 Self::default()
523 }
524
525 pub fn det_model_path<P: AsRef<Path>>(mut self, path: P) -> Self {
527 self.det_model_path = Some(path.as_ref().to_path_buf());
528 self
529 }
530
531 pub fn rec_model_path<P: AsRef<Path>>(mut self, path: P) -> Self {
533 self.rec_model_path = Some(path.as_ref().to_path_buf());
534 self
535 }
536
537 pub fn dictionary_path<P: AsRef<Path>>(mut self, path: P) -> Self {
539 self.dictionary_path = Some(path.as_ref().to_path_buf());
540 self
541 }
542
543 pub fn det_limit_side_len(mut self, len: u32) -> Self {
545 self.det_limit_side_len = len;
546 self
547 }
548
549 pub fn det_unclip_ratio(mut self, ratio: f64) -> Self {
551 self.det_unclip_ratio = ratio as f32;
552 self
553 }
554
555 pub fn rec_batch_size(mut self, size: usize) -> Self {
557 self.rec_batch_size = size;
558 self
559 }
560
561 pub fn build(self) -> Result<OcrEngine, OcrError> {
563 let det_model_path = self.det_model_path.ok_or(OcrError::MissingField {
564 field: "det_model_path",
565 })?;
566 let rec_model_path = self.rec_model_path.ok_or(OcrError::MissingField {
567 field: "rec_model_path",
568 })?;
569 let dictionary_path = self.dictionary_path.ok_or(OcrError::MissingField {
570 field: "dictionary_path",
571 })?;
572
573 if self.rec_batch_size == 0 {
574 return Err(OcrError::InvalidConfiguration {
575 message: "rec_batch_size must be greater than zero".to_string(),
576 });
577 }
578
579 verify_file_exists(&det_model_path)?;
580 verify_file_exists(&rec_model_path)?;
581 verify_file_exists(&dictionary_path)?;
582
583 let det_session =
584 DetInferenceSession::load(&det_model_path).map_err(|source| OcrError::ModelLoad {
585 source,
586 path: det_model_path.clone(),
587 })?;
588 let rec_session =
589 RecInferenceSession::load(&rec_model_path).map_err(|source| OcrError::ModelLoad {
590 source,
591 path: rec_model_path.clone(),
592 })?;
593 let dictionary = RecDictionary::from_path(&dictionary_path)?;
594
595 let mut det_unclipper_config = DetPolygonUnclipperConfig::default();
596 det_unclipper_config.unclip_ratio = self.det_unclip_ratio;
597
598 let mut det_preprocessor_config = DetPreProcessorConfig::default();
599 det_preprocessor_config.limit_side_len = self.det_limit_side_len;
600
601 let mut config = OcrEngineConfig::default();
602 config.det_preprocessor = det_preprocessor_config;
603 config.det_unclipper = det_unclipper_config;
604 config.rec_batch_size = self.rec_batch_size;
605 config.rec_postprocessor.blank_id = dictionary.blank_id();
606
607 Ok(OcrEngine::new(
608 det_model_path,
609 rec_model_path,
610 dictionary_path,
611 det_session,
612 rec_session,
613 dictionary,
614 config,
615 ))
616 }
617}
618
619fn verify_file_exists(path: &Path) -> Result<(), OcrError> {
620 if let Err(source) = fs::metadata(path) {
621 return Err(OcrError::Io {
622 source,
623 path: path.to_path_buf(),
624 });
625 }
626 Ok(())
627}
628
/// Filesystem locations of the assets an engine was built from.
#[derive(Debug)]
struct EngineAssets {
    det_model_path: PathBuf,
    rec_model_path: PathBuf,
    dictionary_path: PathBuf,
}

impl EngineAssets {
    /// Records the three asset paths.
    fn new(det_model_path: PathBuf, rec_model_path: PathBuf, dictionary_path: PathBuf) -> Self {
        EngineAssets {
            det_model_path,
            rec_model_path,
            dictionary_path,
        }
    }

    /// Detection model path.
    fn det_model_path(&self) -> &Path {
        &self.det_model_path
    }

    /// Recognition model path.
    fn rec_model_path(&self) -> &Path {
        &self.rec_model_path
    }

    /// Dictionary path.
    fn dictionary_path(&self) -> &Path {
        &self.dictionary_path
    }
}
657
658#[derive(Debug)]
659struct DetectionPipeline {
660 preprocessor: DetPreProcessor,
661 session: Arc<DetInferenceSession>,
662 postprocessor: DetPostProcessor,
663 unclipper: DetPolygonUnclipper,
664 scaler: DetPolygonScaler,
665}
666
667impl DetectionPipeline {
668 fn new(
669 session: Arc<DetInferenceSession>,
670 preprocessor: DetPreProcessorConfig,
671 postprocessor: DetPostProcessorConfig,
672 unclipper: DetPolygonUnclipperConfig,
673 scaler: DetPolygonScalerConfig,
674 ) -> Self {
675 Self {
676 preprocessor: DetPreProcessor::new(preprocessor),
677 session,
678 postprocessor: DetPostProcessor::new(postprocessor),
679 unclipper: DetPolygonUnclipper::new(unclipper),
680 scaler: DetPolygonScaler::new(scaler),
681 }
682 }
683
684 fn detect_polygons_with_timings(
685 &self,
686 image: &DynamicImage,
687 image_dims: (u32, u32),
688 ) -> Result<(Vec<Polygon<f64>>, StageTimings), OcrError> {
689 let preprocess_start = Instant::now();
690 let preprocessed = self.preprocessor.process(image).map_err(OcrError::from)?;
691 let preprocess_elapsed = preprocess_start.elapsed();
692
693 let inference_start = Instant::now();
694 let inference = self
695 .session
696 .run(&preprocessed)
697 .map_err(|source| OcrError::DetectionInference { source })?;
698 let inference_elapsed = inference_start.elapsed();
699
700 let post_start = Instant::now();
701 let contours = self
702 .postprocessor
703 .process(&inference)
704 .map_err(OcrError::from)?;
705 let unclipped = self.unclipper.unclip_contours(&contours);
706 let scaled = self
707 .scaler
708 .scale_polygons(&unclipped, preprocessed.scale_ratio, image_dims);
709 let post_elapsed = post_start.elapsed();
710
711 let timings = StageTimings {
712 preprocess: preprocess_elapsed,
713 inference: inference_elapsed,
714 postprocess: post_elapsed,
715 };
716
717 Ok((scaled, timings))
718 }
719}
720
721#[derive(Debug)]
722struct RecognitionPipeline {
723 preprocessor: RecPreProcessor,
724 session: Arc<RecInferenceSession>,
725 postprocessor: RecPostProcessor,
726}
727
728impl RecognitionPipeline {
729 fn new(
730 session: Arc<RecInferenceSession>,
731 dictionary: Arc<RecDictionary>,
732 preprocessor: RecPreProcessorConfig,
733 postprocessor: RecPostProcessorConfig,
734 ) -> Self {
735 let postprocessor = RecPostProcessor::new(Arc::clone(&dictionary), postprocessor);
736
737 Self {
738 preprocessor: RecPreProcessor::new(preprocessor),
739 session,
740 postprocessor,
741 }
742 }
743
744 fn run_with_timings(
745 &self,
746 image: &DynamicImage,
747 regions: &[RecTextRegion],
748 ) -> Result<(Vec<DecodedSequence>, StageTimings), OcrError> {
749 let preprocess_start = Instant::now();
750 let batch = self
751 .preprocessor
752 .process(image, regions)
753 .map_err(OcrError::from)?;
754 let preprocess_elapsed = preprocess_start.elapsed();
755
756 let inference_start = Instant::now();
757 let inference = self
758 .session
759 .run(&batch)
760 .map_err(|source| OcrError::RecognitionInference { source })?;
761 let inference_elapsed = inference_start.elapsed();
762
763 let post_start = Instant::now();
764 let sequences = self
765 .postprocessor
766 .process(&inference)
767 .map_err(OcrError::from)?;
768 let post_elapsed = post_start.elapsed();
769
770 let timings = StageTimings {
771 preprocess: preprocess_elapsed,
772 inference: inference_elapsed,
773 postprocess: post_elapsed,
774 };
775
776 Ok((sequences, timings))
777 }
778}
779
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ctc::CtcGreedyDecoderError;
    use crate::dictionary::RecDictionary;
    use crate::postprocessing::DetPostProcessorError;
    use crate::preprocessing::{DetPreProcessorError, RecPreProcessorError};
    use crate::recognition::RecPostProcessorError;
    use geo_types::LineString;
    use std::env;
    use std::path::Path;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Looks for a PP-OCRv5 asset under `$PURE_ONNX_OCR_FIXTURE_DIR` (and its `models/`), then
    /// under the crate's fixture and model directories. Each base is checked with and without
    /// a `ppocrv5/` subdirectory.
    fn locate_ppocrv5_asset(file_name: &str) -> Option<PathBuf> {
        let mut bases: Vec<PathBuf> = Vec::new();
        if let Some(dir) = env::var_os("PURE_ONNX_OCR_FIXTURE_DIR") {
            let env_path = PathBuf::from(dir);
            bases.push(env_path.clone());
            bases.push(env_path.join("models"));
        }

        let manifest = Path::new(env!("CARGO_MANIFEST_DIR"));
        bases.push(manifest.join("tests").join("fixtures").join("models"));
        bases.push(manifest.join("tests").join("fixtures"));
        bases.push(manifest.join("models"));

        for base in bases {
            let candidate = base.join("ppocrv5").join(file_name);
            if candidate.exists() {
                return Some(candidate);
            }

            let alt = base.join(file_name);
            if alt.exists() {
                return Some(alt);
            }
        }

        None
    }

    /// Detection model, recognition model, and dictionary paths, if all three are present.
    fn existing_model_paths() -> Option<(PathBuf, PathBuf, PathBuf)> {
        let det = locate_ppocrv5_asset("det.onnx")?;
        let rec = locate_ppocrv5_asset("rec.onnx")?;
        let dict = locate_ppocrv5_asset("ppocrv5_dict.txt")?;
        Some((det, rec, dict))
    }

    /// A unique-ish PNG path in the system temp directory.
    fn temp_image_path(prefix: &str) -> PathBuf {
        let timestamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        std::env::temp_dir().join(format!("{}_{}.png", prefix, timestamp))
    }

    /// Axis-aligned rectangle polygon with corners `(x0, y0)` and `(x1, y1)`.
    fn rect_polygon(x0: f64, y0: f64, x1: f64, y1: f64) -> Polygon<f64> {
        Polygon::new(
            LineString::from(vec![(x0, y0), (x1, y0), (x1, y1), (x0, y1)]),
            vec![],
        )
    }

    /// `(x, y, width, height)` of a region, so assertions don't require `PartialEq`.
    fn region_tuple(region: &RecTextRegion) -> (u32, u32, u32, u32) {
        (region.x, region.y, region.width, region.height)
    }

    #[test]
    fn missing_det_model_path_returns_error() {
        let err = OcrEngineBuilder::new()
            .rec_model_path("rec.onnx")
            .dictionary_path("dict.txt")
            .build()
            .unwrap_err();

        match err {
            OcrError::MissingField { field } => assert_eq!(field, "det_model_path"),
            other => panic!("expected MissingField error, got {:?}", other),
        }
    }

    #[test]
    fn missing_dictionary_path_returns_error() {
        let err = OcrEngineBuilder::new()
            .det_model_path("det.onnx")
            .rec_model_path("rec.onnx")
            .build()
            .unwrap_err();

        match err {
            OcrError::MissingField { field } => assert_eq!(field, "dictionary_path"),
            other => panic!("expected MissingField error, got {:?}", other),
        }
    }

    #[test]
    fn zero_recognition_batch_size_is_rejected() {
        let err = OcrEngineBuilder::new()
            .det_model_path("det.onnx")
            .rec_model_path("rec.onnx")
            .dictionary_path("dict.txt")
            .rec_batch_size(0)
            .build()
            .unwrap_err();

        match err {
            OcrError::InvalidConfiguration { message } => {
                assert!(message.contains("rec_batch_size"));
            }
            other => panic!("expected InvalidConfiguration error, got {:?}", other),
        }
    }

    #[test]
    fn polygon_to_text_region_snaps_outward_to_whole_pixels() {
        let region = polygon_to_text_region(&rect_polygon(10.2, 5.7, 30.9, 20.1), (100, 100));
        assert_eq!(region_tuple(&region), (10, 5, 21, 16));
    }

    #[test]
    fn polygon_to_text_region_clamps_to_image_bounds() {
        let region = polygon_to_text_region(&rect_polygon(-5.0, -3.0, 150.0, 40.0), (100, 50));
        assert_eq!(region_tuple(&region), (0, 0, 100, 40));
    }

    #[test]
    fn polygon_to_text_region_collapses_out_of_bounds_polygon_to_edge() {
        let region = polygon_to_text_region(&rect_polygon(200.0, 10.0, 210.0, 20.0), (100, 50));
        assert_eq!(region_tuple(&region), (99, 10, 1, 10));
    }

    #[test]
    fn polygon_to_text_region_treats_zero_sized_image_as_one_pixel() {
        let region = polygon_to_text_region(&rect_polygon(2.0, 2.0, 5.0, 5.0), (0, 0));
        assert_eq!(region_tuple(&region), (0, 0, 1, 1));
    }

    #[test]
    fn build_succeeds_when_paths_exist() {
        let (det, rec, dict) = existing_model_paths()
            .expect("expected PP-OCRv5 assets to be present under models/ppocrv5/");

        let engine = OcrEngineBuilder::new()
            .det_model_path(&det)
            .rec_model_path(&rec)
            .dictionary_path(&dict)
            .det_limit_side_len(1024)
            .det_unclip_ratio(2.0)
            .rec_batch_size(4)
            .build()
            .expect("engine should build successfully");

        assert_eq!(engine.config().det_preprocessor.limit_side_len, 1024);
        assert!((engine.config().det_unclipper.unclip_ratio - 2.0).abs() < f32::EPSILON);
        assert_eq!(engine.config().rec_batch_size, 4);
    }

    #[test]
    fn engine_reports_asset_paths_and_batch_size() {
        let (det, rec, dict) = existing_model_paths()
            .expect("expected PP-OCRv5 assets to be present under models/ppocrv5/");

        let engine = OcrEngineBuilder::new()
            .det_model_path(&det)
            .rec_model_path(&rec)
            .dictionary_path(&dict)
            .rec_batch_size(6)
            .build()
            .expect("engine should build successfully");

        assert_eq!(engine.det_model_path(), det.as_path());
        assert_eq!(engine.rec_model_path(), rec.as_path());
        assert_eq!(engine.dictionary_path(), dict.as_path());
        assert_eq!(engine.rec_batch_size(), 6);
    }

    #[test]
    fn recognition_blank_id_matches_dictionary_blank_id() {
        let (det, rec, dict) = existing_model_paths()
            .expect("expected PP-OCRv5 assets to be present under models/ppocrv5/");

        let dictionary_blank_id = RecDictionary::from_path(&dict)
            .expect("dictionary should load successfully")
            .blank_id();

        let engine = OcrEngineBuilder::new()
            .det_model_path(&det)
            .rec_model_path(&rec)
            .dictionary_path(&dict)
            .build()
            .expect("engine should build successfully");

        assert_eq!(
            engine.config().rec_postprocessor.blank_id,
            dictionary_blank_id
        );
    }

    #[test]
    fn run_from_path_processes_blank_image() -> Result<(), OcrError> {
        let (det, rec, dict) = existing_model_paths()
            .expect("expected PP-OCRv5 assets to be present under models/ppocrv5/");

        let engine = OcrEngineBuilder::new()
            .det_model_path(&det)
            .rec_model_path(&rec)
            .dictionary_path(&dict)
            .build()
            .expect("engine should build successfully");

        let temp_path = temp_image_path("run_path_blank");
        let image_buffer = image::ImageBuffer::from_pixel(64, 32, image::Rgb([0, 0, 0]));
        DynamicImage::ImageRgb8(image_buffer)
            .save(&temp_path)
            .expect("failed to save temporary image");

        let outcome = engine.run_from_path(&temp_path);
        // Clean up before propagating any error so a failing run doesn't leak the temp file.
        std::fs::remove_file(&temp_path).ok();
        let results = outcome?;

        assert!(
            results.len() <= engine.rec_batch_size(),
            "number of results should not exceed configured batch size"
        );

        Ok(())
    }

    #[test]
    fn run_from_image_reuses_pipeline() -> Result<(), OcrError> {
        let (det, rec, dict) = existing_model_paths()
            .expect("expected PP-OCRv5 assets to be present under models/ppocrv5/");

        let engine = OcrEngineBuilder::new()
            .det_model_path(&det)
            .rec_model_path(&rec)
            .dictionary_path(&dict)
            .build()
            .expect("engine should build successfully");

        let image_buffer = image::ImageBuffer::from_pixel(32, 192, image::Rgb([255, 255, 255]));
        let dynamic_image = DynamicImage::ImageRgb8(image_buffer);
        let results = engine.run_from_image(&dynamic_image)?;

        assert!(
            results.len() <= engine.rec_batch_size(),
            "number of results should not exceed configured batch size"
        );

        Ok(())
    }

    #[test]
    fn run_with_metrics_reports_timings() -> Result<(), OcrError> {
        let (det, rec, dict) = existing_model_paths()
            .expect("expected PP-OCRv5 assets to be present under models/ppocrv5/");

        let engine = OcrEngineBuilder::new()
            .det_model_path(&det)
            .rec_model_path(&rec)
            .dictionary_path(&dict)
            .build()
            .expect("engine should build successfully");

        let image_buffer = image::ImageBuffer::from_pixel(16, 16, image::Rgb([0, 0, 0]));
        let dynamic_image = DynamicImage::ImageRgb8(image_buffer);

        let run_with_metrics = engine.run_with_metrics_from_image(&dynamic_image)?;
        let baseline_results = engine.run_from_image(&dynamic_image)?;

        assert_eq!(run_with_metrics.results.len(), baseline_results.len());
        assert!(run_with_metrics.timings.total >= run_with_metrics.timings.detection.preprocess);
        assert!(run_with_metrics.timings.recognition.preprocess <= run_with_metrics.timings.total);

        Ok(())
    }

    #[test]
    fn component_errors_convert_to_ocr_error_variants() {
        match OcrError::from(DetPreProcessorError::EmptyImage) {
            OcrError::DetectionPreprocess { .. } => {}
            other => panic!("expected DetectionPreprocess variant, got {:?}", other),
        }

        match OcrError::from(DetPostProcessorError::EmptyProbabilityMap) {
            OcrError::DetectionPostProcess { .. } => {}
            other => panic!("expected DetectionPostProcess variant, got {:?}", other),
        }

        match OcrError::from(RecPreProcessorError::EmptyRegions) {
            OcrError::RecognitionPreprocess { .. } => {}
            other => panic!("expected RecognitionPreprocess variant, got {:?}", other),
        }

        let rec_post_err = RecPostProcessorError::from(CtcGreedyDecoderError::EmptyBatch);
        match OcrError::from(rec_post_err) {
            OcrError::RecognitionPostProcess { .. } => {}
            other => panic!("expected RecognitionPostProcess variant, got {:?}", other),
        }
    }
}