1mod builder;
51mod components;
52mod config;
53mod extensible_integration;
54mod image_processing;
55mod orchestration;
56mod result;
57
58pub use builder::OAROCRBuilder;
59pub use config::{OAROCRConfig, OnnxThreadingConfig, ParallelPolicy};
60pub use extensible_integration::{ExtensibleOAROCR, ExtensibleOAROCRBuilder};
61pub use image_processing::ImageProcessor;
62pub use orchestration::{
63 ImageInputSource, ImageProcessingOrchestrator, PipelineExecutor, PipelineStage,
64 PipelineStageConfig, ProcessingStrategy,
65};
66pub use result::{ErrorMetrics, OAROCRResult, TextRegion};
67
68use crate::core::{OCRError, parse_text_line_orientation, traits::StandardPredictor};
69use crate::pipeline::PipelineStats;
70use crate::pipeline::stages::{
71 CroppingConfig, CroppingStageProcessor, RecognitionConfig, RecognitionStageProcessor,
72};
73use crate::predictor::{
74 DocOrientationClassifier, DoctrRectifierPredictor, TextDetPredictor, TextLineClasPredictor,
75 TextRecPredictor,
76};
77use components::ComponentBuilder;
78use image::RgbImage;
79
80use crate::pipeline::StatsManager;
81use std::path::Path;
82use std::sync::{Arc, Once};
83use tracing::{debug, info, warn};
84
/// Guards one-time configuration of the global rayon thread pool; see
/// [`configure_thread_pool_once`].
static THREAD_POOL_INIT: Once = Once::new();
88
/// Bundled inputs for processing a single image once text detection has run.
#[derive(Debug, Clone)]
pub struct SingleImageProcessingParams<'a> {
    // Position of this image within the input batch.
    index: usize,
    // Original input image, shared and unmodified.
    input_img_arc: Arc<RgbImage>,
    // Working image after any preprocessing (orientation/rectification).
    current_img: RgbImage,
    // Text regions produced by the detection stage.
    text_boxes: Vec<crate::processors::BoundingBox>,
    // Document orientation angle, if orientation classification ran.
    orientation_angle: Option<f32>,
    // Rectified image, if document unwarping ran.
    rectified_img: Option<Arc<RgbImage>>,
    // Source path recorded verbatim in the final result.
    image_path: &'a Path,
}
100
101pub fn configure_thread_pool_once(max_threads: usize) -> crate::core::OcrResult<()> {
115 let mut result = Ok(());
116
117 THREAD_POOL_INIT.call_once(|| {
118 debug!(
119 "Configuring global rayon thread pool with {} threads",
120 max_threads
121 );
122 if let Err(e) = rayon::ThreadPoolBuilder::new()
123 .num_threads(max_threads)
124 .build_global()
125 {
126 result = Err(OCRError::config_error(format!(
127 "Failed to configure global thread pool: {e}"
128 )));
129 }
130 });
131
132 result
133}
134
/// End-to-end OCR pipeline: optional document orientation classification and
/// rectification, text detection, optional text-line orientation
/// classification, and text recognition.
pub struct OAROCR {
    // Configuration used to build every component below.
    config: OAROCRConfig,
    // Present only when `config.use_doc_orientation_classify` is set.
    doc_orientation_classifier: Option<DocOrientationClassifier>,
    // Present only when `config.use_doc_unwarping` is set.
    doc_rectifier: Option<DoctrRectifierPredictor>,
    // Mandatory text detection stage.
    text_detector: TextDetPredictor,
    // Present only when `config.use_textline_orientation` is set.
    text_line_classifier: Option<TextLineClasPredictor>,
    // Mandatory text recognition stage.
    text_recognizer: TextRecPredictor,
    // Aggregated processing statistics (counts and timing).
    stats: StatsManager,
}
158
159impl OAROCR {
160 pub fn new(config: OAROCRConfig) -> crate::core::OcrResult<Self> {
173 info!("Initializing OAROCR pipeline with config: {:?}", config);
174
175 if let Some(max_threads) = config.max_threads() {
177 configure_thread_pool_once(max_threads)?;
178 }
179
180 let doc_orientation_classifier = if config.use_doc_orientation_classify {
182 info!("Initializing document orientation classifier");
183 Some(ComponentBuilder::build_doc_orientation_classifier(&config)?)
184 } else {
185 None
186 };
187
188 let doc_rectifier = if config.use_doc_unwarping {
189 info!("Initializing document rectifier");
190 Some(ComponentBuilder::build_doc_rectifier(&config)?)
191 } else {
192 None
193 };
194
195 let text_line_classifier = if config.use_textline_orientation {
196 info!("Initializing text line classifier");
197 Some(ComponentBuilder::build_text_line_classifier(&config)?)
198 } else {
199 None
200 };
201
202 info!("Initializing text detector");
204 let text_detector = ComponentBuilder::build_text_detector(&config)?;
205
206 info!("Initializing text recognizer");
207 let text_recognizer = ComponentBuilder::build_text_recognizer(&config)?;
208
209 let pipeline = Self {
210 config,
211 doc_orientation_classifier,
212 doc_rectifier,
213 text_detector,
214 text_line_classifier,
215 text_recognizer,
216 stats: StatsManager::new(),
217 };
218
219 info!("OAROCR pipeline initialized successfully");
220 Ok(pipeline)
221 }
222
223 pub fn predict(&self, images: &[RgbImage]) -> crate::core::OcrResult<Vec<OAROCRResult>> {
242 let start_time = std::time::Instant::now();
243
244 info!(
245 "Starting OCR pipeline for {} in-memory image(s)",
246 images.len()
247 );
248
249 let result = self.process_images_from_memory(images);
250
251 let processing_time = start_time.elapsed();
253 let total_time_ms = processing_time.as_millis() as f64;
254
255 match &result {
256 Ok(results) => {
257 self.update_stats(images.len(), results.len(), 0, total_time_ms);
258 }
259 Err(_) => {
260 self.update_stats(images.len(), 0, images.len(), total_time_ms);
261 }
262 }
263
264 result
265 }
266
267 fn process_images_from_memory(
269 &self,
270 images: &[RgbImage],
271 ) -> crate::core::OcrResult<Vec<OAROCRResult>> {
272 let orchestrator = ImageProcessingOrchestrator::new(self);
274
275 let inputs: Vec<(usize, &RgbImage)> = images.iter().enumerate().collect();
277
278 let image_threshold = self.config.image_threshold();
280 let strategy = ProcessingStrategy::Auto(image_threshold);
281 let stage_config = PipelineStageConfig::default(); orchestrator.process_batch(inputs, strategy, stage_config)
284 }
285
286 fn process_single_image_from_detection(
291 &self,
292 params: SingleImageProcessingParams,
293 ) -> crate::core::OcrResult<OAROCRResult> {
294 let SingleImageProcessingParams {
296 index,
297 input_img_arc,
298 current_img,
299 text_boxes,
300 orientation_angle,
301 rectified_img,
302 image_path,
303 } = params;
304
305 let cropping_config = CroppingConfig::default();
307
308 let cropping_stage_result = CroppingStageProcessor::process_single(
309 ¤t_img,
310 &text_boxes,
311 Some(&cropping_config),
312 )?;
313
314 let cropped_images = cropping_stage_result.data.cropped_images;
315 let failed_crops = cropping_stage_result.data.failed_crops;
316
317 let mut text_line_orientations: Vec<Option<f32>> = Vec::new();
322 let mut failed_orientations = 0;
323 if self.config.use_textline_orientation && !text_boxes.is_empty() {
324 if let Some(ref classifier) = self.text_line_classifier {
325 let valid_images: Vec<RgbImage> = cropped_images
326 .iter()
327 .filter_map(|o| o.as_ref().cloned())
328 .collect();
329 let valid_images_count = valid_images.len();
330 if !valid_images.is_empty() {
331 match classifier.predict(valid_images, None) {
332 Ok(result) => {
333 let mut result_idx = 0usize;
334 for cropped_img_opt in &cropped_images {
335 if cropped_img_opt.is_some() {
336 if let (Some(labels), Some(score_list)) = (
337 result.label_names.get(result_idx),
338 result.scores.get(result_idx),
339 ) {
340 if let (Some(label), Some(&score)) =
341 (labels.first(), score_list.first())
342 {
343 let confidence_threshold = self
344 .config
345 .text_line_orientation_stage
346 .as_ref()
347 .and_then(|config| config.confidence_threshold);
348
349 let orientation_result = parse_text_line_orientation(
350 label.as_ref(),
351 score,
352 confidence_threshold,
353 );
354
355 if orientation_result.is_confident {
356 text_line_orientations
357 .push(Some(orientation_result.angle));
358 } else {
359 text_line_orientations.push(None);
360 }
361 } else {
362 text_line_orientations.push(None);
363 }
364 } else {
365 text_line_orientations.push(None);
366 }
367 result_idx += 1;
368 } else {
369 text_line_orientations.push(None);
370 }
371 }
372 }
373 Err(e) => {
374 failed_orientations = valid_images_count;
375 warn!(
376 "Text line orientation classification failed for {} images: {}",
377 valid_images_count, e
378 );
379 text_line_orientations.resize(text_boxes.len(), None);
380 }
381 }
382 } else {
383 text_line_orientations.resize(text_boxes.len(), None);
384 }
385 } else {
386 text_line_orientations.resize(text_boxes.len(), None);
387 }
388 } else {
389 text_line_orientations.resize(text_boxes.len(), None);
390 }
391
392 let recognition_config = RecognitionConfig::from_legacy_config(
394 self.config.use_textline_orientation,
395 self.config.aspect_ratio_bucketing.clone(),
396 );
397
398 let recognition_stage_result = RecognitionStageProcessor::process_single(
399 cropped_images,
400 Some(&text_line_orientations),
401 Some(&self.text_recognizer),
402 Some(&recognition_config),
403 )?;
404
405 let rec_texts = recognition_stage_result.data.rec_texts;
406 let rec_scores = recognition_stage_result.data.rec_scores;
407 let failed_recognitions = recognition_stage_result.data.failed_recognitions;
408
409 let score_thresh = self.config.recognition.score_thresh.unwrap_or(0.0);
411 let mut final_texts: Vec<Option<Arc<str>>> = Vec::new();
412 let mut final_scores: Vec<Option<f32>> = Vec::new();
413 let mut final_orientations: Vec<Option<f32>> = Vec::new();
414 for ((text, score), orientation) in rec_texts
415 .into_iter()
416 .zip(rec_scores)
417 .zip(text_line_orientations.iter().cloned())
418 {
419 if score >= score_thresh {
420 final_texts.push(Some(text));
421 final_scores.push(Some(score));
422 final_orientations.push(orientation);
423 } else {
424 final_texts.push(None);
425 final_scores.push(None);
426 final_orientations.push(orientation);
427 }
428 }
429
430 let error_metrics = ErrorMetrics {
432 failed_crops,
433 failed_recognitions,
434 failed_orientations,
435 total_text_boxes: text_boxes.len(),
436 };
437
438 let text_regions = OAROCRResult::create_text_regions_from_vectors(
440 &text_boxes,
441 &final_texts,
442 &final_scores,
443 &final_orientations,
444 );
445
446 Ok(OAROCRResult {
447 input_path: Arc::from(image_path.to_string_lossy().as_ref()),
448 index,
449 input_img: input_img_arc,
450 text_regions,
451 orientation_angle,
452 rectified_img,
453 error_metrics,
454 })
455 }
456
457 pub fn get_stats(&self) -> PipelineStats {
463 self.stats.get_stats()
464 }
465
466 fn update_stats(
475 &self,
476 processed_count: usize,
477 successful_count: usize,
478 failed_count: usize,
479 inference_time_ms: f64,
480 ) {
481 self.stats.update_stats(
482 processed_count,
483 successful_count,
484 failed_count,
485 inference_time_ms,
486 );
487 }
488
489 pub fn reset_stats(&self) {
491 self.stats.reset_stats();
492 }
493
494 pub fn get_config(&self) -> &OAROCRConfig {
500 &self.config
501 }
502
503 }
519
#[cfg(test)]
mod tests {
    //! Unit tests covering builder threshold plumbing, the orchestration
    //! abstractions, result/box alignment, and thread-pool initialization.
    use super::*;

    // The score threshold set via the builder must land in the recognition
    // config unchanged.
    #[test]
    fn test_oarocr_builder_text_rec_score_thresh() {
        let builder = OAROCRBuilder::new(
            "dummy_det_path".to_string(),
            "dummy_rec_path".to_string(),
            "dummy_dict_path".to_string(),
        )
        .text_rec_score_threshold(0.8);

        assert_eq!(builder.get_config().recognition.score_thresh, Some(0.8));
    }

    // Smoke test: all orchestration types are publicly importable and their
    // basic constructors/defaults behave as documented.
    #[test]
    fn test_orchestration_abstraction_imports() {
        use crate::pipeline::oarocr::{
            ImageInputSource, ImageProcessingOrchestrator, PipelineExecutor, PipelineStage,
            PipelineStageConfig, ProcessingStrategy,
        };
        use image::RgbImage;
        use std::path::Path;
        use std::sync::Arc;

        let strategy = ProcessingStrategy::Sequential;
        let config = PipelineStageConfig::default();

        let _orientation_stage = PipelineStage::Orientation;
        let _detection_stage = PipelineStage::Detection;

        // Every input-source variant should be constructible.
        let path = Path::new("test.jpg");
        let _path_source = ImageInputSource::Path(path);

        let img = RgbImage::new(100, 100);
        let _memory_source = ImageInputSource::Memory(&img);

        let img_arc = Arc::new(img);
        let _loaded_source = ImageInputSource::LoadedWithPath(img_arc, path);

        assert!(!strategy.should_use_parallel(10));
        assert_eq!(config.start_from, PipelineStage::Orientation);

        let _orchestrator_type = std::any::type_name::<ImageProcessingOrchestrator>();
        let _executor_type = std::any::type_name::<PipelineExecutor>();
    }

    // Auto(n) switches to parallel strictly above n images; Sequential and
    // Parallel are unconditional.
    #[test]
    fn test_processing_strategy_behavior() {
        let sequential = ProcessingStrategy::Sequential;
        let parallel = ProcessingStrategy::Parallel;
        let auto_5 = ProcessingStrategy::Auto(5);

        assert!(!sequential.should_use_parallel(1));
        assert!(!sequential.should_use_parallel(100));

        assert!(parallel.should_use_parallel(1));
        assert!(parallel.should_use_parallel(100));

        // Auto threshold is exclusive: exactly 5 images stays sequential.
        assert!(!auto_5.should_use_parallel(3));
        assert!(!auto_5.should_use_parallel(5));
        assert!(auto_5.should_use_parallel(6));
        assert!(auto_5.should_use_parallel(10));
    }

    // Defaults start at Orientation with nothing skipped; both fields are
    // mutable afterwards.
    #[test]
    fn test_pipeline_stage_config_customization() {
        let mut config = PipelineStageConfig::default();

        assert_eq!(config.start_from, PipelineStage::Orientation);
        assert!(config.skip_stages.is_empty());
        assert!(config.custom_params.is_none());

        config.start_from = PipelineStage::Detection;
        config.skip_stages.insert(PipelineStage::Recognition);

        assert_eq!(config.start_from, PipelineStage::Detection);
        assert!(config.skip_stages.contains(&PipelineStage::Recognition));
        assert!(!config.skip_stages.contains(&PipelineStage::Orientation));
    }

    // Setting a document-orientation threshold must create the stage config
    // and store the value.
    #[test]
    fn test_oarocr_builder_doc_orientation_confidence_threshold() {
        let builder = OAROCRBuilder::new(
            "dummy_det_path".to_string(),
            "dummy_rec_path".to_string(),
            "dummy_dict_path".to_string(),
        )
        .doc_orientation_threshold(0.8);

        assert!(builder.get_config().orientation_stage.is_some());
        assert_eq!(
            builder
                .get_config()
                .orientation_stage
                .as_ref()
                .unwrap()
                .confidence_threshold,
            Some(0.8)
        );
    }

    // Same as above, but for the text-line orientation stage.
    #[test]
    fn test_oarocr_builder_textline_orientation_confidence_threshold() {
        let builder = OAROCRBuilder::new(
            "dummy_det_path".to_string(),
            "dummy_rec_path".to_string(),
            "dummy_dict_path".to_string(),
        )
        .textline_orientation_threshold(0.9);

        assert!(builder.get_config().text_line_orientation_stage.is_some());
        assert_eq!(
            builder
                .get_config()
                .text_line_orientation_stage
                .as_ref()
                .unwrap()
                .confidence_threshold,
            Some(0.9)
        );
    }

    // Boxes with failed recognition (None text/score) must stay in place so
    // region order keeps matching box order.
    #[test]
    fn test_oarocr_result_alignment_preservation() {
        use crate::processors::BoundingBox;
        use crate::processors::Point;
        use image::RgbImage;
        use std::sync::Arc;

        let text_boxes = vec![
            BoundingBox {
                points: vec![
                    Point { x: 0.0, y: 0.0 },
                    Point { x: 10.0, y: 0.0 },
                    Point { x: 10.0, y: 10.0 },
                    Point { x: 0.0, y: 10.0 },
                ],
            },
            BoundingBox {
                points: vec![
                    Point { x: 20.0, y: 0.0 },
                    Point { x: 30.0, y: 0.0 },
                    Point { x: 30.0, y: 10.0 },
                    Point { x: 20.0, y: 10.0 },
                ],
            },
            BoundingBox {
                points: vec![
                    Point { x: 40.0, y: 0.0 },
                    Point { x: 50.0, y: 0.0 },
                    Point { x: 50.0, y: 10.0 },
                    Point { x: 40.0, y: 10.0 },
                ],
            },
        ];

        // Middle box deliberately has no recognized text.
        let rec_texts = vec![
            Some(Arc::from("Hello")),
            None, Some(Arc::from("World")),
        ];

        let rec_scores = vec![
            Some(0.9),
            None, Some(0.8),
        ];

        let text_regions = OAROCRResult::create_text_regions_from_vectors(
            &text_boxes,
            &rec_texts,
            &rec_scores,
            &[None, None, None],
        );

        let result = OAROCRResult {
            input_path: Arc::from("test.jpg"),
            index: 0,
            input_img: Arc::new(RgbImage::new(100, 100)),
            text_regions,
            orientation_angle: None,
            rectified_img: None,
            error_metrics: ErrorMetrics::default(),
        };

        assert_eq!(result.text_regions.len(), 3);

        for (i, region) in result.text_regions.iter().enumerate() {
            assert!(region.bounding_box.points.len() >= 4);

            match i {
                0 => {
                    assert!(region.text.is_some());
                    assert!(region.confidence.is_some());
                    assert_eq!(region.text.as_ref().unwrap().as_ref(), "Hello");
                    assert_eq!(region.confidence.unwrap(), 0.9);
                }
                1 => {
                    assert!(region.text.is_none());
                    assert!(region.confidence.is_none());
                }
                2 => {
                    assert!(region.text.is_some());
                    assert!(region.confidence.is_some());
                    assert_eq!(region.text.as_ref().unwrap().as_ref(), "World");
                    assert_eq!(region.confidence.unwrap(), 0.8);
                }
                _ => panic!("Unexpected index"),
            }
        }
    }

    // Repeated calls must all return Ok: only the first configures the pool,
    // the rest are swallowed by the Once guard.
    #[test]
    fn test_thread_pool_configuration_once() {
        let result1 = configure_thread_pool_once(2);
        assert!(
            result1.is_ok(),
            "First thread pool configuration should succeed"
        );

        let result2 = configure_thread_pool_once(4);
        assert!(
            result2.is_ok(),
            "Second thread pool configuration should succeed (ignored)"
        );

        let result3 = configure_thread_pool_once(1);
        assert!(
            result3.is_ok(),
            "Third thread pool configuration should succeed (ignored)"
        );
    }
}